Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions netfields/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from netaddr import IPNetwork

from django import VERSION
from django.db import models, connection
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
from django.db.models.fields import DateTimeField
from django.db.models import sql, query
from django.db.models.query_utils import QueryWrapper
from django.utils import tree
Expand Down Expand Up @@ -30,27 +32,29 @@ class NetQuery(sql.Query):


class NetWhere(sql.where.WhereNode):
def add(self, data, connector):


def _prepare_data(self, data):
"""
Special form of WhereNode.add() that does not automatically consume the
__iter__ method of IPNetwork objects.
Special form of WhereNode._prepare_data() that does not automatically consume the
__iter__ method of IPNetwork objects. This is used in Django >= 1.6
"""
if not isinstance(data, (list, tuple)):
# Need to bypass WhereNode
tree.Node.add(self, data, connector)
return

if not isinstance(data, (list, tuple)):
return data
obj, lookup_type, value = data
if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
# Consume any generators immediately, so that we can determine
# emptiness and transform any non-empty values correctly.
value = list(value)


# The "value_annotation" parameter is used to pass auxilliary information
# about the value(s) to the query construction. Specifically, datetime
# and empty values need special handling. Other types could be used
# here in the future (using Python types is suggested for consistency).
if isinstance(value, datetime.datetime):
if (isinstance(value, datetime.datetime)
or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
value_annotation = datetime.datetime
elif hasattr(value, 'value_annotation'):
value_annotation = value.value_annotation
Expand All @@ -59,10 +63,43 @@ def add(self, data, connector):

if hasattr(obj, "prepare"):
value = obj.prepare(lookup_type, value)
return (obj, lookup_type, value_annotation, value)


if VERSION[:2] < (1, 6):
def add(self, data, connector):
"""
Special form of WhereNode.add() that does not automatically consume the
__iter__ method of IPNetwork objects.
"""
if not isinstance(data, (list, tuple)):
# Need to bypass WhereNode
tree.Node.add(self, data, connector)
return

obj, lookup_type, value = data
if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
# Consume any generators immediately, so that we can determine
# emptiness and transform any non-empty values correctly.
value = list(value)

# The "value_annotation" parameter is used to pass auxilliary information
# about the value(s) to the query construction. Specifically, datetime
# and empty values need special handling. Other types could be used
# here in the future (using Python types is suggested for consistency).
if isinstance(value, datetime.datetime):
value_annotation = datetime.datetime
elif hasattr(value, 'value_annotation'):
value_annotation = value.value_annotation
else:
value_annotation = bool(value)

# Need to bypass WhereNode
tree.Node.add(self,
(obj, lookup_type, value_annotation, value), connector)
if hasattr(obj, "prepare"):
value = obj.prepare(lookup_type, value)

# Need to bypass WhereNode
tree.Node.add(self,
(obj, lookup_type, value_annotation, value), connector)

def make_atom(self, child, qn, conn):
lvalue, lookup_type, value_annot, params_or_value = child
Expand Down
17 changes: 13 additions & 4 deletions netfields/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from netaddr import IPAddress, IPNetwork, EUI, AddrFormatError

from django import VERSION
from django.db import IntegrityError
from django.forms import ModelForm
from django.test import TestCase
Expand Down Expand Up @@ -355,12 +356,20 @@ def test_iendswith_lookup(self):
self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')

def test_regex_lookup(self):
self.assertSqlEquals(self.qs.filter(field__regex='00'),
self.select + 'WHERE "table"."field" ~ %s ')
if VERSION[:2] < (1, 6):
self.assertSqlEquals(self.qs.filter(field__regex='00'),
self.select + 'WHERE "table"."field" ~ %s ')
else:
self.assertSqlEquals(self.qs.filter(field__regex='00'),
self.select + 'WHERE "table"."field"::text ~ %s ')

def test_iregex_lookup(self):
self.assertSqlEquals(self.qs.filter(field__iregex='00'),
self.select + 'WHERE "table"."field" ~* %s ')
if VERSION[:2] < (1, 6):
self.assertSqlEquals(self.qs.filter(field__iregex='00'),
self.select + 'WHERE "table"."field" ~* %s ')
else:
self.assertSqlEquals(self.qs.filter(field__iregex='00'),
self.select + 'WHERE "table"."field"::text ~* %s ')


class TestMacAddressField(BaseMacTestCase, TestCase):
Expand Down