diff --git a/peewee.py b/peewee.py index efecd23ca..2279feb6c 100644 --- a/peewee.py +++ b/peewee.py @@ -1128,6 +1128,9 @@ def alias(self, alias): def unalias(self): return self + def bind_to(self, dest): + return BindTo(self, dest) + def cast(self, as_type): return Cast(self, as_type) @@ -1338,6 +1341,15 @@ def __sql__(self, ctx): return ctx.sql(Entity(self._alias)) +class BindTo(WrappedNode): + def __init__(self, node, dest): + super(BindTo, self).__init__(node) + self.dest = dest + + def __sql__(self, ctx): + return ctx.sql(self.node) + + class Negated(WrappedNode): def __invert__(self): return self.node @@ -7721,6 +7733,12 @@ def initialize(self): key = field.source else: key = field.model + elif isinstance(node, BindTo): + if node.dest not in self.key_to_constructor: + raise ValueError('%s specifies bind-to %s, but %s is not ' + 'among the selected sources.' % + (node.unwrap(), node.dest, node.dest)) + key = node.dest else: if isinstance(node, Node): node = node.unwrap() diff --git a/tests/models.py b/tests/models.py index 3bed3d6d0..040f8a8f9 100644 --- a/tests/models.py +++ b/tests/models.py @@ -4868,3 +4868,52 @@ def test_data_modifying_cte_insert(self): self.assertEqual(sorted([(p.name, p.price) for p in C_Product]), [ ('p0', 0), ('p1', 1), ('p2', 2), ('p3', 3), ('p4', 4), ('p5', 5), ('p6', 6)]) + + +class TestBindTo(ModelTestCase): + requires = [User, Tweet] + + def test_bind_to(self): + for i in (1, 2, 3): + user = User.create(username='u%s' % i) + Tweet.create(user=user, content='t%s' % i) + + # Alias to a particular field-name. + name = Case(User.username, [ + ('u1', 'user 1'), + ('u2', 'user 2')], 'someone else') + q = (Tweet + .select(Tweet.content, name.alias('username').bind_to(User)) + .join(User) + .order_by(Tweet.content)) + with self.assertQueryCount(1): + self.assertEqual([(t.content, t.user.username) for t in q], [ + ('t1', 'user 1'), + ('t2', 'user 2'), + ('t3', 'someone else')]) + + # Use a different alias. + q = (Tweet + .select(Tweet.content, name.alias('display').bind_to(User)) + .join(User) + .order_by(Tweet.content)) + with self.assertQueryCount(1): + self.assertEqual([(t.content, t.user.display) for t in q], [ + ('t1', 'user 1'), + ('t2', 'user 2'), + ('t3', 'someone else')]) + + # Ensure works with model and field aliases. + TA, UA = Tweet.alias(), User.alias() + name = Case(UA.username, [ + ('u1', 'user 1'), + ('u2', 'user 2')], 'someone else') + q = (TA + .select(TA.content, name.alias('display').bind_to(UA)) + .join(UA, on=(UA.id == TA.user)) + .order_by(TA.content)) + with self.assertQueryCount(1): + self.assertEqual([(t.content, t.user.display) for t in q], [ + ('t1', 'user 1'), + ('t2', 'user 2'), + ('t3', 'someone else')])