diff --git a/.travis.yml b/.travis.yml index 380ff79..7797281 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,8 @@ env: - DJANGO='Django<=1.4' PYSCOPG2='psycopg2==2.4.1' # see django issue #16250 - DJANGO='Django<=1.5' PYSCOPG2='psycopg2' - DJANGO='Django<=1.6' PYSCOPG2='psycopg2' + - DJANGO='Django<=1.7' PYSCOPG2='psycopg2' + - DJANGO='Django<=1.8' PYSCOPG2='psycopg2' install: - pip install -q $PYSCOPG2 --use-mirrors @@ -19,3 +21,10 @@ before_script: - psql -c 'create database netfields;' -U postgres script: "./manage.py test" + +matrix: + exclude: + - python: "2.6" + env: DJANGO='Django<=1.7' PYSCOPG2='psycopg2' + - python: "2.6" + env: DJANGO='Django<=1.8' PYSCOPG2='psycopg2' diff --git a/netfields/apps.py b/netfields/apps.py index 67d032e..c58add0 100644 --- a/netfields/apps.py +++ b/netfields/apps.py @@ -1,17 +1,39 @@ from django.apps import AppConfig +from django.db.models.lookups import default_lookups from netfields.fields import CidrAddressField, InetAddressField -from netfields.lookups import NetContained, NetContains, NetContainedOrEqual, NetContainsOrEquals +from netfields.lookups import NetContained, NetContains, NetContainedOrEqual, NetContainsOrEquals, InvalidLookup +from netfields.lookups import EndsWith, IEndsWith, StartsWith, IStartsWith, Regex, IRegex class NetfieldsConfig(AppConfig): name = 'netfields' + for lookup in default_lookups.keys(): + if lookup not in ['contains', 'startswith', 'endswith', 'icontains', 'istartswith', 'iendswith', 'isnull', 'in', + 'exact', 'iexact', 'regex', 'iregex', 'lt', 'lte', 'gt', 'gte', 'equals', 'iequals', 'range']: + invalid_lookup = InvalidLookup + invalid_lookup.lookup_name = lookup + CidrAddressField.register_lookup(invalid_lookup) + InetAddressField.register_lookup(invalid_lookup) + + CidrAddressField.register_lookup(EndsWith) + CidrAddressField.register_lookup(IEndsWith) + CidrAddressField.register_lookup(StartsWith) + CidrAddressField.register_lookup(IStartsWith) + CidrAddressField.register_lookup(Regex) + CidrAddressField.register_lookup(IRegex) CidrAddressField.register_lookup(NetContained) CidrAddressField.register_lookup(NetContains) CidrAddressField.register_lookup(NetContainedOrEqual) CidrAddressField.register_lookup(NetContainsOrEquals) + InetAddressField.register_lookup(EndsWith) + InetAddressField.register_lookup(IEndsWith) + InetAddressField.register_lookup(StartsWith) + InetAddressField.register_lookup(IStartsWith) + InetAddressField.register_lookup(Regex) + InetAddressField.register_lookup(IRegex) InetAddressField.register_lookup(NetContained) InetAddressField.register_lookup(NetContains) InetAddressField.register_lookup(NetContainedOrEqual) diff --git a/netfields/fields.py b/netfields/fields.py index 277f805..cfd32fb 100644 --- a/netfields/fields.py +++ b/netfields/fields.py @@ -1,6 +1,7 @@ from netaddr import IPAddress, IPNetwork, EUI from netaddr.core import AddrFormatError +from django import VERSION from django.db import models from django.core.exceptions import ValidationError @@ -35,7 +36,7 @@ def get_prep_lookup(self, lookup_type, value): NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS): if lookup_type.startswith('net_contained') and value is not None: # Argument will be CIDR - return unicode(value) + return str(value) return self.get_prep_value(value) return super(_NetAddressField, self).get_prep_lookup( @@ -45,7 +46,7 @@ def get_prep_value(self, value): if not value: return None - return unicode(self.to_python(value)) + return str(self.to_python(value)) def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): @@ -64,6 +65,12 @@ def formfield(self, **kwargs): defaults.update(kwargs) return super(_NetAddressField, self).formfield(**defaults) + if VERSION[:2] >= (1, 7): + def deconstruct(self): + name, path, args, kwargs = super(_NetAddressField, self).deconstruct() + if self.max_length is not None: + kwargs['max_length'] = self.max_length + return name, path, args, kwargs class InetAddressField(_NetAddressField): @@ -115,19 +122,20 @@ def get_prep_value(self, value): if not value: return None - return unicode(self.to_python(value)) + return str(self.to_python(value)) def formfield(self, **kwargs): defaults = {'form_class': MACAddressFormField} defaults.update(kwargs) return super(MACAddressField, self).formfield(**defaults) -try: - from south.modelsinspector import add_introspection_rules - add_introspection_rules([], [ - "^netfields\.fields\.InetAddressField", - "^netfields\.fields\.CidrAddressField", - "^netfields\.fields\.MACAddressField", - ]) -except ImportError: - pass +if VERSION[:2] < (1, 7): + try: + from south.modelsinspector import add_introspection_rules + add_introspection_rules([], [ + "^netfields\.fields\.InetAddressField", + "^netfields\.fields\.CidrAddressField", + "^netfields\.fields\.MACAddressField", + ]) + except ImportError: + pass diff --git a/netfields/forms.py b/netfields/forms.py index fdeacd8..7143ac0 100644 --- a/netfields/forms.py +++ b/netfields/forms.py @@ -1,7 +1,6 @@ from netaddr import IPAddress, IPNetwork, EUI, AddrFormatError from django import forms -from django.utils.encoding import force_unicode from django.utils.safestring import mark_safe from django.core.exceptions import ValidationError @@ -17,7 +16,7 @@ def render(self, name, value, attrs=None): value = '' final_attrs = self.build_attrs(attrs, type=self.input_type, name=name) if value: - final_attrs['value'] = force_unicode(value) + final_attrs['value'] = value return mark_safe(u'' % forms.util.flatatt(final_attrs)) @@ -39,7 +38,7 @@ def to_python(self, value): try: return IPAddress(value) - except (AddrFormatError, TypeError), e: + except (AddrFormatError, TypeError) as e: raise ValidationError(str(e)) @@ -61,7 +60,7 @@ def to_python(self, value): try: return IPNetwork(value) - except (AddrFormatError, TypeError), e: + except (AddrFormatError, TypeError) as e: raise ValidationError(str(e)) diff --git a/netfields/lookups.py b/netfields/lookups.py index 5ec547e..c8b9ad2 100644 --- a/netfields/lookups.py +++ b/netfields/lookups.py @@ -1,5 +1,47 @@ from django.db.models import Lookup -from django.db.models.lookups import IExact, IContains +from django.db.models.lookups import BuiltinLookup +from netfields.fields import InetAddressField, CidrAddressField + + +class InvalidLookup(BuiltinLookup): + def as_sql(self, qn, connection): + raise ValueError('Invalid lookup type "%s"' % self.lookup_name) + + +class NetFieldDecoratorMixin(object): + def process_lhs(self, qn, connection, lhs=None): + lhs = lhs or self.lhs + lhs_string, lhs_params = qn.compile(lhs) + if isinstance(lhs.source if hasattr(lhs, 'source') else lhs.output_field, InetAddressField): + lhs_string = 'HOST(%s)' % lhs_string + elif isinstance(lhs.source if hasattr(lhs, 'source') else lhs.output_field, CidrAddressField): + lhs_string = 'TEXT(%s)' % lhs_string + return lhs_string, lhs_params + + +class EndsWith(NetFieldDecoratorMixin, BuiltinLookup): + lookup_name = 'endswith' + + +class IEndsWith(NetFieldDecoratorMixin, BuiltinLookup): + lookup_name = 'iendswith' + + +class StartsWith(NetFieldDecoratorMixin, BuiltinLookup): + lookup_name = 'startswith' + + +class IStartsWith(NetFieldDecoratorMixin, BuiltinLookup): + lookup_name = 'istartswith' + + +class Regex(NetFieldDecoratorMixin, BuiltinLookup): + lookup_name = 'regex' + + +class IRegex(NetFieldDecoratorMixin, BuiltinLookup): + lookup_name = 'iregex' + class NetContains(Lookup): lookup_name = 'net_contains' @@ -38,4 +80,3 @@ def as_sql(self, qn, connection): rhs, rhs_params = self.process_rhs(qn, connection) params = lhs_params + rhs_params return '%s <<= %s' % (lhs, rhs), params - diff --git a/netfields/managers.py b/netfields/managers.py index f76ace0..af6dbe6 100644 --- a/netfields/managers.py +++ b/netfields/managers.py @@ -48,7 +48,6 @@ def _prepare_data(self, data): # emptiness and transform any non-empty values correctly. value = list(value) - # The "value_annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. Specifically, datetime # and empty values need special handling. Other types could be used @@ -101,64 +100,70 @@ def add(self, data, connector): tree.Node.add(self, (obj, lookup_type, value_annotation, value), connector) - def make_atom(self, child, qn, conn): - lvalue, lookup_type, value_annot, params_or_value = child - - if hasattr(lvalue, 'process'): - try: - lvalue, params = lvalue.process(lookup_type, params_or_value, - connection) - except sql.where.EmptyShortCircuit: - raise query.EmptyResultSet - else: - return super(NetWhere, self).make_atom(child, qn, conn) + if VERSION[:2] < (1, 7): + def make_atom(self, child, qn, conn): + lvalue, lookup_type, value_annot, params_or_value = child - table_alias, name, db_type = lvalue + if hasattr(lvalue, 'process'): + try: + lvalue, params = lvalue.process(lookup_type, params_or_value, + connection) + except sql.where.EmptyShortCircuit: + raise query.EmptyResultSet + else: + return super(NetWhere, self).make_atom(child, qn, conn) - if db_type not in ['inet', 'cidr']: - return super(NetWhere, self).make_atom(child, qn, conn) + table_alias, name, db_type = lvalue - if table_alias: - field_sql = '%s.%s' % (qn(table_alias), qn(name)) - else: - field_sql = qn(name) + if db_type not in ['inet', 'cidr']: + return super(NetWhere, self).make_atom(child, qn, conn) - if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS: - if db_type == 'inet': - field_sql = 'HOST(%s)' % field_sql + if table_alias: + field_sql = '%s.%s' % (qn(table_alias), qn(name)) else: - field_sql = 'TEXT(%s)' % field_sql + field_sql = qn(name) - if isinstance(params, QueryWrapper): - extra, params = params.data - else: - extra = '' - - if isinstance(params, basestring): - params = (params,) - - if lookup_type in NET_OPERATORS: - return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), - params) - elif lookup_type == 'in': - if not value_annot: - raise sql.datastructures.EmptyResultSet - if extra: - return ('%s IN %s' % (field_sql, extra), params) - return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * - len(params))), params) - elif lookup_type == 'range': - return ('%s BETWEEN %%s and %%s' % field_sql, params) - elif lookup_type == 'isnull': - return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or - '')), params) - - raise ValueError('Invalid lookup type "%s"' % lookup_type) + if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS: + if db_type == 'inet': + field_sql = 'HOST(%s)' % field_sql + else: + field_sql = 'TEXT(%s)' % field_sql + + if isinstance(params, QueryWrapper): + extra, params = params.data + else: + extra = '' + + if isinstance(params, basestring): + params = (params,) + + if lookup_type in NET_OPERATORS: + return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), + params) + elif lookup_type == 'in': + if not value_annot: + raise sql.datastructures.EmptyResultSet + if extra: + return ('%s IN %s' % (field_sql, extra), params) + return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * + len(params))), params) + elif lookup_type == 'range': + return ('%s BETWEEN %%s and %%s' % field_sql, params) + elif lookup_type == 'isnull': + return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or + '')), params) + + raise ValueError('Invalid lookup type "%s"' % lookup_type) class NetManager(models.Manager): use_for_related_fields = True - def get_query_set(self): + def get_queryset(self): q = NetQuery(self.model, NetWhere) return query.QuerySet(self.model, q) + + if VERSION[:2] < (1, 6): + def get_query_set(self): + q = NetQuery(self.model, NetWhere) + return query.QuerySet(self.model, q) diff --git a/netfields/tests.py b/netfields/tests.py index 7a395fe..b6cba75 100644 --- a/netfields/tests.py +++ b/netfields/tests.py @@ -1,5 +1,5 @@ from django.core.exceptions import ValidationError -from netaddr import IPAddress, IPNetwork, EUI, AddrFormatError +from netaddr import IPAddress, IPNetwork, EUI from django import VERSION from django.db import IntegrityError @@ -37,36 +37,64 @@ def test_save(self): self.model(field=self.value1).save() def test_equals_lookup(self): - self.assertSqlEquals(self.qs.filter(field=self.value1), - self.select + 'WHERE "table"."field" = %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field=self.value1), + self.select + 'WHERE "table"."field" = %s ') + else: + self.assertSqlEquals(self.qs.filter(field=self.value1), + self.select + 'WHERE "table"."field" = %s') def test_exact_lookup(self): - self.assertSqlEquals(self.qs.filter(field__exact=self.value1), - self.select + 'WHERE "table"."field" = %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__exact=self.value1), + self.select + 'WHERE "table"."field" = %s ') + else: + self.assertSqlEquals(self.qs.filter(field__exact=self.value1), + self.select + 'WHERE "table"."field" = %s') def test_in_lookup(self): self.assertSqlEquals(self.qs.filter(field__in=[self.value1, self.value2]), self.select + 'WHERE "table"."field" IN (%s, %s)') def test_gt_lookup(self): - self.assertSqlEquals(self.qs.filter(field__gt=self.value1), - self.select + 'WHERE "table"."field" > %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__gt=self.value1), + self.select + 'WHERE "table"."field" > %s ') + else: + self.assertSqlEquals(self.qs.filter(field__gt=self.value1), + self.select + 'WHERE "table"."field" > %s') def test_gte_lookup(self): - self.assertSqlEquals(self.qs.filter(field__gte=self.value1), - self.select + 'WHERE "table"."field" >= %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__gte=self.value1), + self.select + 'WHERE "table"."field" >= %s ') + else: + self.assertSqlEquals(self.qs.filter(field__gte=self.value1), + self.select + 'WHERE "table"."field" >= %s') def test_lt_lookup(self): - self.assertSqlEquals(self.qs.filter(field__lt=self.value1), - self.select + 'WHERE "table"."field" < %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__lt=self.value1), + self.select + 'WHERE "table"."field" < %s ') + else: + self.assertSqlEquals(self.qs.filter(field__lt=self.value1), + self.select + 'WHERE "table"."field" < %s') def test_lte_lookup(self): - self.assertSqlEquals(self.qs.filter(field__lte=self.value1), - self.select + 'WHERE "table"."field" <= %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__lte=self.value1), + self.select + 'WHERE "table"."field" <= %s ') + else: + self.assertSqlEquals(self.qs.filter(field__lte=self.value1), + self.select + 'WHERE "table"."field" <= %s') def test_range_lookup(self): - self.assertSqlEquals(self.qs.filter(field__range=(self.value1, self.value3)), - self.select + 'WHERE "table"."field" BETWEEN %s and %s') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__range=(self.value1, self.value3)), + self.select + 'WHERE "table"."field" BETWEEN %s and %s') + else: + self.assertSqlEquals(self.qs.filter(field__range=(self.value1, self.value3)), + self.select + 'WHERE "table"."field" BETWEEN %s AND %s') @@ -78,8 +106,12 @@ def test_init_with_text_fails(self): self.assertRaises(ValidationError, self.model, field='abc') def test_iexact_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iexact=self.value1), - self.select + 'WHERE "table"."field" = %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iexact=self.value1), + self.select + 'WHERE "table"."field" = %s ') + else: + self.assertSqlEquals(self.qs.filter(field__iexact=self.value1), + self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s)') def test_search_lookup_fails(self): self.assertSqlRaises(self.qs.filter(field__search='10'), ValueError) @@ -94,12 +126,20 @@ def test_day_lookup_fails(self): self.assertSqlRaises(self.qs.filter(field__day=1), ValueError) def test_net_contained(self): - self.assertSqlEquals(self.qs.filter(field__net_contained='10.0.0.1/24'), - self.select + 'WHERE "table"."field" << %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__net_contained='10.0.0.1/24'), + self.select + 'WHERE "table"."field" << %s ') + else: + self.assertSqlEquals(self.qs.filter(field__net_contained='10.0.0.1/24'), + self.select + 'WHERE "table"."field" << %s') def test_net_contained_or_equals(self): - self.assertSqlEquals(self.qs.filter(field__net_contained_or_equal='10.0.0.1/24'), - self.select + 'WHERE "table"."field" <<= %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__net_contained_or_equal='10.0.0.1/24'), + self.select + 'WHERE "table"."field" <<= %s ') + else: + self.assertSqlEquals(self.qs.filter(field__net_contained_or_equal='10.0.0.1/24'), + self.select + 'WHERE "table"."field" <<= %s') class BaseInetFieldTestCase(BaseInetTestCase): @@ -108,28 +148,52 @@ class BaseInetFieldTestCase(BaseInetTestCase): value3 = '10.0.0.10' def test_startswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__startswith='10.'), - self.select + 'WHERE HOST("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__startswith='10.'), + self.select + 'WHERE HOST("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__startswith='10.'), + self.select + 'WHERE HOST("table"."field") LIKE %s') def test_istartswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__istartswith='10.'), - self.select + 'WHERE HOST("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__istartswith='10.'), + self.select + 'WHERE HOST("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__istartswith='10.'), + self.select + 'WHERE HOST("table"."field") LIKE UPPER(%s)') def test_endswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__endswith='.1'), - self.select + 'WHERE HOST("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__endswith='.1'), + self.select + 'WHERE HOST("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__endswith='.1'), + self.select + 'WHERE HOST("table"."field") LIKE %s') def test_iendswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iendswith='.1'), - self.select + 'WHERE HOST("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iendswith='.1'), + self.select + 'WHERE HOST("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__iendswith='.1'), + self.select + 'WHERE HOST("table"."field") LIKE UPPER(%s)') def test_regex_lookup(self): - self.assertSqlEquals(self.qs.filter(field__regex='10'), - self.select + 'WHERE HOST("table"."field") ~* %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__regex='10'), + self.select + 'WHERE HOST("table"."field") ~* %s ') + else: + self.assertSqlEquals(self.qs.filter(field__regex='10'), + self.select + 'WHERE HOST("table"."field") ~ %s') def test_iregex_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iregex='10'), - self.select + 'WHERE HOST("table"."field") ~* %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iregex='10'), + self.select + 'WHERE HOST("table"."field") ~* %s ') + else: + self.assertSqlEquals(self.qs.filter(field__iregex='10'), + self.select + 'WHERE HOST("table"."field") ~* %s') class BaseCidrFieldTestCase(BaseInetTestCase): @@ -138,36 +202,68 @@ class BaseCidrFieldTestCase(BaseInetTestCase): value3 = '10.0.0.10/16' def test_startswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__startswith='10.'), - self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__startswith='10.'), + self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__startswith='10.'), + self.select + 'WHERE TEXT("table"."field") LIKE %s') def test_istartswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__istartswith='10.'), - self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__istartswith='10.'), + self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__istartswith='10.'), + self.select + 'WHERE TEXT("table"."field") LIKE UPPER(%s)') def test_endswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__endswith='.1'), - self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__endswith='.1'), + self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__endswith='.1'), + self.select + 'WHERE TEXT("table"."field") LIKE %s') def test_iendswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iendswith='.1'), - self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iendswith='.1'), + self.select + 'WHERE TEXT("table"."field") ILIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__iendswith='.1'), + self.select + 'WHERE TEXT("table"."field") LIKE UPPER(%s)') def test_regex_lookup(self): - self.assertSqlEquals(self.qs.filter(field__regex='10'), - self.select + 'WHERE TEXT("table"."field") ~* %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__regex='10'), + self.select + 'WHERE TEXT("table"."field") ~* %s ') + else: + self.assertSqlEquals(self.qs.filter(field__regex='10'), + self.select + 'WHERE TEXT("table"."field") ~ %s') def test_iregex_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iregex='10'), - self.select + 'WHERE TEXT("table"."field") ~* %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iregex='10'), + self.select + 'WHERE TEXT("table"."field") ~* %s ') + else: + self.assertSqlEquals(self.qs.filter(field__iregex='10'), + self.select + 'WHERE TEXT("table"."field") ~* %s') def test_net_contains_lookup(self): - self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'), - self.select + 'WHERE "table"."field" >> %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'), + self.select + 'WHERE "table"."field" >> %s ') + else: + self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'), + self.select + 'WHERE "table"."field" >> %s') def test_net_contains_or_equals(self): - self.assertSqlEquals(self.qs.filter(field__net_contains_or_equals='10.0.0.1'), - self.select + 'WHERE "table"."field" >>= %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__net_contains_or_equals='10.0.0.1'), + self.select + 'WHERE "table"."field" >>= %s ') + else: + self.assertSqlEquals(self.qs.filter(field__net_contains_or_equals='10.0.0.1'), + self.select + 'WHERE "table"."field" >>= %s') class TestInetField(BaseInetFieldTestCase, TestCase): @@ -259,6 +355,7 @@ def test_save_nonunique(self): class InetAddressTestModelForm(ModelForm): class Meta: model = InetTestModel + exclude = [] class TestInetAddressFormField(TestCase): @@ -286,6 +383,7 @@ def test_form_ipv6_invalid(self): class UniqueInetAddressTestModelForm(ModelForm): class Meta: model = UniqueInetTestModel + exclude = [] class TestUniqueInetAddressFormField(TestInetAddressFormField): @@ -295,6 +393,7 @@ class TestUniqueInetAddressFormField(TestInetAddressFormField): class CidrAddressTestModelForm(ModelForm): class Meta: model = CidrTestModel + exclude = [] class TestCidrAddressFormField(TestCase): @@ -322,6 +421,7 @@ def test_form_ipv6_invalid(self): class UniqueCidrAddressTestModelForm(ModelForm): class Meta: model = UniqueCidrTestModel + exclude = [] class TestUniqueCidrAddressFormField(TestCidrAddressFormField): @@ -337,40 +437,66 @@ def test_save_object(self): self.model(field=EUI(self.value1)).save() def test_iexact_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iexact=self.value1), - self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s) ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iexact=self.value1), + self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s) ') + else: + self.assertSqlEquals(self.qs.filter(field__iexact=self.value1), + self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s)') def test_startswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__startswith='00:'), - self.select + 'WHERE "table"."field"::text LIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__startswith='00:'), + self.select + 'WHERE "table"."field"::text LIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__startswith='00:'), + self.select + 'WHERE "table"."field"::text LIKE %s') def test_istartswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__istartswith='00:'), - self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__istartswith='00:'), + self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ') + else: + self.assertSqlEquals(self.qs.filter(field__istartswith='00:'), + self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s)') def test_endswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__endswith=':ff'), - self.select + 'WHERE "table"."field"::text LIKE %s ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__endswith=':ff'), + self.select + 'WHERE "table"."field"::text LIKE %s ') + else: + self.assertSqlEquals(self.qs.filter(field__endswith=':ff'), + self.select + 'WHERE "table"."field"::text LIKE %s') def test_iendswith_lookup(self): - self.assertSqlEquals(self.qs.filter(field__iendswith=':ff'), - self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ') + if VERSION[:2] < (1, 7): + self.assertSqlEquals(self.qs.filter(field__iendswith=':ff'), + self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ') + else: + self.assertSqlEquals(self.qs.filter(field__iendswith=':ff'), + self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s)') def test_regex_lookup(self): if VERSION[:2] < (1, 6): self.assertSqlEquals(self.qs.filter(field__regex='00'), self.select + 'WHERE "table"."field" ~ %s ') - else: + elif VERSION[:2] < (1, 7): self.assertSqlEquals(self.qs.filter(field__regex='00'), self.select + 'WHERE "table"."field"::text ~ %s ') + else: + self.assertSqlEquals(self.qs.filter(field__regex='00'), + self.select + 'WHERE "table"."field"::text ~ %s') def test_iregex_lookup(self): if VERSION[:2] < (1, 6): self.assertSqlEquals(self.qs.filter(field__iregex='00'), self.select + 'WHERE "table"."field" ~* %s ') - else: + elif VERSION[:2] < (1, 7): self.assertSqlEquals(self.qs.filter(field__iregex='00'), self.select + 'WHERE "table"."field"::text ~* %s ') + else: + self.assertSqlEquals(self.qs.filter(field__iregex='00'), + self.select + 'WHERE "table"."field"::text ~* %s') class TestMacAddressField(BaseMacTestCase, TestCase): @@ -395,6 +521,7 @@ def test_invalid_fails(self): class MacAddressTestModelForm(ModelForm): class Meta: model = MACTestModel + exclude = [] class TestMacAddressFormField(TestCase): diff --git a/testsettings.py b/testsettings.py index f481075..3d9afc5 100644 --- a/testsettings.py +++ b/testsettings.py @@ -9,4 +9,10 @@ 'netfields', ) +MIDDLEWARE_CLASSES = ( + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', +) + SECRET_KEY = "notimportant"