Skip to content

Commit

Permalink
Add helper for specifying model to bind column-like to.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Aug 11, 2022
1 parent 67f061b commit eeff376
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
18 changes: 18 additions & 0 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,9 @@ def alias(self, alias):
def unalias(self):
return self

def bind_to(self, dest):
return BindTo(self, dest)

def cast(self, as_type):
return Cast(self, as_type)

Expand Down Expand Up @@ -1338,6 +1341,15 @@ def __sql__(self, ctx):
return ctx.sql(Entity(self._alias))


class BindTo(WrappedNode):
def __init__(self, node, dest):
super(BindTo, self).__init__(node)
self.dest = dest

def __sql__(self, ctx):
return ctx.sql(self.node)


class Negated(WrappedNode):
def __invert__(self):
return self.node
Expand Down Expand Up @@ -7721,6 +7733,12 @@ def initialize(self):
key = field.source
else:
key = field.model
elif isinstance(node, BindTo):
if node.dest not in self.key_to_constructor:
raise ValueError('%s specifies bind-to %s, but %s is not '
'among the selected sources.' %
(node.unwrap(), node.dest, node.dest))
key = node.dest
else:
if isinstance(node, Node):
node = node.unwrap()
Expand Down
49 changes: 49 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4868,3 +4868,52 @@ def test_data_modifying_cte_insert(self):
self.assertEqual(sorted([(p.name, p.price) for p in C_Product]), [
('p0', 0), ('p1', 1), ('p2', 2), ('p3', 3), ('p4', 4), ('p5', 5),
('p6', 6)])


class TestBindTo(ModelTestCase):
requires = [User, Tweet]

def test_bind_to(self):
for i in (1, 2, 3):
user = User.create(username='u%s' % i)
Tweet.create(user=user, content='t%s' % i)

# Alias to a particular field-name.
name = Case(User.username, [
('u1', 'user 1'),
('u2', 'user 2')], 'someone else')
q = (Tweet
.select(Tweet.content, name.alias('username').bind_to(User))
.join(User)
.order_by(Tweet.content))
with self.assertQueryCount(1):
self.assertEqual([(t.content, t.user.username) for t in q], [
('t1', 'user 1'),
('t2', 'user 2'),
('t3', 'someone else')])

# Use a different alias.
q = (Tweet
.select(Tweet.content, name.alias('display').bind_to(User))
.join(User)
.order_by(Tweet.content))
with self.assertQueryCount(1):
self.assertEqual([(t.content, t.user.display) for t in q], [
('t1', 'user 1'),
('t2', 'user 2'),
('t3', 'someone else')])

# Ensure works with model and field aliases.
TA, UA = Tweet.alias(), User.alias()
name = Case(UA.username, [
('u1', 'user 1'),
('u2', 'user 2')], 'someone else')
q = (TA
.select(TA.content, name.alias('display').bind_to(UA))
.join(UA, on=(UA.id == TA.user))
.order_by(TA.content))
with self.assertQueryCount(1):
self.assertEqual([(t.content, t.user.display) for t in q], [
('t1', 'user 1'),
('t2', 'user 2'),
('t3', 'someone else')])

0 comments on commit eeff376

Please sign in to comment.