Skip to content

Commit

Permalink
Fixed #22648 -- Transform.output_type should respect overridden custo…
Browse files Browse the repository at this point in the history
…m_lookup and custom_transform.

Previously, class lookups from the output_type would be used, but any
changes to custom_lookup or custom_transform would be ignored.
  • Loading branch information
mjtamlyn committed May 17, 2014
1 parent effa9da commit 253ef84
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
6 changes: 4 additions & 2 deletions django/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ def _get_lookup(self, lookup_name):
except AttributeError:
# This class didn't have any class_lookups
pass
if hasattr(self, 'output_type'):
return self.output_type.get_lookup(lookup_name)
return None

def get_lookup(self, lookup_name):
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, 'output_type'):
return self.output_type.get_lookup(lookup_name)
if found is not None and not issubclass(found, Lookup):
return None
return found

def get_transform(self, lookup_name):
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, 'output_type'):
return self.output_type.get_transform(lookup_name)
if found is not None and not issubclass(found, Transform):
return None
return found
Expand Down
60 changes: 60 additions & 0 deletions tests/custom_lookups/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,47 @@ def as_sql(self, qn, connection):
YearTransform.register_lookup(YearLte)


class SQLFunc(models.Lookup):
def __init__(self, name, *args, **kwargs):
super(SQLFunc, self).__init__(*args, **kwargs)
self.name = name

def as_sql(self, qn, connection):
return '%s()', [self.name]

@property
def output_type(self):
return CustomField()


class SQLFuncFactory(object):

def __init__(self, name):
self.name = name

def __call__(self, *args, **kwargs):
return SQLFunc(self.name, *args, **kwargs)


class CustomField(models.Field):

def get_lookup(self, lookup_name):
if lookup_name.startswith('lookupfunc_'):
key, name = lookup_name.split('_', 1)
return SQLFuncFactory(name)
return super(CustomField, self).get_lookup(lookup_name)

def get_transform(self, lookup_name):
if lookup_name.startswith('transformfunc_'):
key, name = lookup_name.split('_', 1)
return SQLFuncFactory(name)
return super(CustomField, self).get_transform(lookup_name)


class CustomModel(models.Model):
field = CustomField()


# We will register this class temporarily in the test method.


Expand Down Expand Up @@ -341,3 +382,22 @@ def test_call_order(self):

finally:
models.DateField._unregister_lookup(TrackCallsYearTransform)


class CustomisedMethodsTests(TestCase):

def test_overridden_get_lookup(self):
q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
self.assertIn('monkeys()', str(q.query))

def test_overridden_get_transform(self):
q = CustomModel.objects.filter(field__transformfunc_banana=3)
self.assertIn('banana()', str(q.query))

def test_overridden_get_lookup_chain(self):
q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3)
self.assertIn('elephants()', str(q.query))

def test_overridden_get_transform_chain(self):
q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3)
self.assertIn('pear()', str(q.query))

0 comments on commit 253ef84

Please sign in to comment.