Skip to content

Commit

Permalink
Fixed #27147: Add 'default_bounds' argument for fields subclassing Ra…
Browse files Browse the repository at this point in the history
…ngeField.
  • Loading branch information
gmcrocetti committed Nov 3, 2021
1 parent 073b7b5 commit 66c9a95
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 15 deletions.
42 changes: 39 additions & 3 deletions django/contrib/postgres/fields/ranges.py
Expand Up @@ -12,10 +12,13 @@
__all__ = [
'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
'RangeBoundary', 'RangeOperators',
'RangeBoundary', 'RangeOperators', 'ContinuousRangeField',
]


CANONICAL_RANGE_BOUNDS = '[)'


class RangeBoundary(models.Expression):
"""A class that represents range boundaries."""
def __init__(self, inclusive_lower=True, inclusive_upper=False):
Expand Down Expand Up @@ -45,6 +48,8 @@ class RangeField(models.Field):

def __init__(self, *args, **kwargs):
# Initializing base_field here ensures that its model matches the model for self.
if 'default_bounds' in kwargs:
raise TypeError(f"Cannot use 'default_bounds' with {self.__class__.__name__}.")
if hasattr(self, 'base_field'):
self.base_field = self.base_field()
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -112,6 +117,37 @@ def formfield(self, **kwargs):
return super().formfield(**kwargs)


class ContinuousRangeField(RangeField):
"""
This class specializes RangeField allowing user to choose a default bounds value
for the field. This assumption is conservative in the sense that we enforce the
chosen 'default_bounds' exclusively for inputs of the types (list, tuple).
We're not replacing the bounds attribute by 'default_bounds' for objects of the type
`psycopg2.extras.Range`.
"""

def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
if default_bounds not in ('[)', '(]', '()', '[]'):
raise ValueError(f'Invalid bound flags: {default_bounds}.')
self.default_bounds = default_bounds
super().__init__(*args, **kwargs)

def get_prep_value(self, value):
if isinstance(value, (list, tuple)):
return self.range_type(value[0], value[1], self.default_bounds)
return super().get_prep_value(value)

def formfield(self, **kwargs):
kwargs.setdefault('default_bounds', self.default_bounds)
return super().formfield(**kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
kwargs['default_bounds'] = self.default_bounds
return name, path, args, kwargs


class IntegerRangeField(RangeField):
base_field = models.IntegerField
range_type = NumericRange
Expand All @@ -130,7 +166,7 @@ def db_type(self, connection):
return 'int8range'


class DecimalRangeField(RangeField):
class DecimalRangeField(ContinuousRangeField):
base_field = models.DecimalField
range_type = NumericRange
form_field = forms.DecimalRangeField
Expand All @@ -139,7 +175,7 @@ def db_type(self, connection):
return 'numrange'


class DateTimeRangeField(RangeField):
class DateTimeRangeField(ContinuousRangeField):
base_field = models.DateTimeField
range_type = DateTimeTZRange
form_field = forms.DateTimeRangeField
Expand Down
5 changes: 4 additions & 1 deletion django/contrib/postgres/forms/ranges.py
Expand Up @@ -42,6 +42,9 @@ def __init__(self, **kwargs):
kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)]
kwargs.setdefault('required', False)
kwargs.setdefault('require_all_fields', False)
self.range_kwargs = {}
if default_bounds := kwargs.pop("default_bounds", None):
self.range_kwargs = {"bounds": default_bounds}
super().__init__(**kwargs)

def prepare_value(self, value):
Expand All @@ -68,7 +71,7 @@ def compress(self, values):
code='bound_ordering',
)
try:
range_value = self.range_type(lower, upper)
range_value = self.range_type(lower, upper, **self.range_kwargs)
except TypeError:
raise exceptions.ValidationError(
self.error_messages['invalid'],
Expand Down
28 changes: 25 additions & 3 deletions docs/ref/contrib/postgres/fields.txt
Expand Up @@ -503,7 +503,9 @@ All of the range fields translate to :ref:`psycopg2 Range objects
<psycopg2:adapt-range>` in Python, but also accept tuples as input if no bounds
information is necessary. The default is lower bound included, upper bound
excluded, that is ``[)`` (see the PostgreSQL documentation for details about
`different bounds`_).
`different bounds`_). The default value for bounds can be changed by passing a
`default_bounds` parameter for the ``DecimalRangeField`` and ``DateTimeRangeField``.


.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO

Expand Down Expand Up @@ -538,23 +540,43 @@ excluded, that is ``[)`` (see the PostgreSQL documentation for details about
``DecimalRangeField``
---------------------

.. class:: DecimalRangeField(**options)
.. class:: DecimalRangeField(default_bounds='[)', **options)

Stores a range of floating point values. Based on a
:class:`~django.db.models.DecimalField`. Represented by a ``numrange`` in
the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in
Python.

.. attribute:: DecimalRangeField.default_bounds

Optional. The value for ``bounds`` at tuple inputs. `default_bounds`
is not used at :class:`~psycopg2:psycopg2.extras.NumericRange` input.
The default is lower bound included, upper bound excluded, that is
``[)`` (see the PostgreSQL documentation for details about `different
bounds`_).

.. versionadded:: 4.1

``DateTimeRangeField``
----------------------

.. class:: DateTimeRangeField(**options)
.. class:: DateTimeRangeField(default_bounds='[)', **options)

Stores a range of timestamps. Based on a
:class:`~django.db.models.DateTimeField`. Represented by a ``tstzrange`` in
the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in
Python.

.. attribute:: DateTimeRangeField.default_bounds

Optional. The value for ``bounds`` at tuple inputs. `default_bounds`
is not used at :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` input.
The default is lower bound included, upper bound excluded, that is
``[)`` (see the PostgreSQL documentation for details about `different
bounds`_).

.. versionadded:: 4.1

``DateRangeField``
------------------

Expand Down
6 changes: 6 additions & 0 deletions docs/releases/4.1.txt
Expand Up @@ -76,6 +76,12 @@ Minor features
supports covering exclusion constraints using SP-GiST indexes on PostgreSQL
14+.

* The new :class:`~django.contrib.postgres.fields.ranges.ContinuousRangeField`
specializes the existing
:class:`~django.contrib.postgres.fields.ranges.RangeField` adding a
`default_bounds` parameter, this parameter allows users to choose the default
value for `bounds` in case of list or tuple inputs.

:mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
13 changes: 11 additions & 2 deletions tests/postgres_tests/fields.py
Expand Up @@ -26,14 +26,23 @@ def deconstruct(self):
})
return name, path, args, kwargs

class DummyContinuousRangeField(models.Field):
def __init__(self, *args, default_bounds='[)', **kwargs):
super().__init__(**kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs['default_bounds'] = '[)'
return name, path, args, kwargs

ArrayField = DummyArrayField
BigIntegerRangeField = models.Field
CICharField = models.Field
CIEmailField = models.Field
CITextField = models.Field
DateRangeField = models.Field
DateTimeRangeField = models.Field
DecimalRangeField = models.Field
DateTimeRangeField = DummyContinuousRangeField
DecimalRangeField = DummyContinuousRangeField
HStoreField = models.Field
IntegerRangeField = models.Field
SearchVector = models.Expression
Expand Down
1 change: 1 addition & 0 deletions tests/postgres_tests/migrations/0002_create_test_models.py
Expand Up @@ -249,6 +249,7 @@ class Migration(migrations.Migration):
('decimals', DecimalRangeField(null=True, blank=True)),
('timestamps', DateTimeRangeField(null=True, blank=True)),
('timestamps_inner', DateTimeRangeField(null=True, blank=True)),
('timestamps_with_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')),
('dates', DateRangeField(null=True, blank=True)),
('dates_inner', DateRangeField(null=True, blank=True)),
],
Expand Down
1 change: 1 addition & 0 deletions tests/postgres_tests/models.py
Expand Up @@ -135,6 +135,7 @@ class RangesModel(PostgreSQLModel):
decimals = DecimalRangeField(blank=True, null=True)
timestamps = DateTimeRangeField(blank=True, null=True)
timestamps_inner = DateTimeRangeField(blank=True, null=True)
timestamps_with_closed_bounds = DateTimeRangeField(blank=True, null=True, default_bounds='[]')
dates = DateRangeField(blank=True, null=True)
dates_inner = DateRangeField(blank=True, null=True)

Expand Down
39 changes: 36 additions & 3 deletions tests/postgres_tests/test_apps.py
@@ -1,3 +1,5 @@
from decimal import Decimal

from django.db.backends.signals import connection_created
from django.db.migrations.writer import MigrationWriter
from django.test.utils import modify_settings
Expand All @@ -10,7 +12,8 @@
)

from django.contrib.postgres.fields import (
DateRangeField, DateTimeRangeField, IntegerRangeField,
DateRangeField, DateTimeRangeField, DecimalRangeField,
IntegerRangeField,
)
except ImportError:
pass
Expand All @@ -24,11 +27,10 @@ def test_register_type_handlers_connection(self):
self.assertIn(register_type_handlers, connection_created._live_receivers(None))
self.assertNotIn(register_type_handlers, connection_created._live_receivers(None))

def test_register_serializer_for_migrations(self):
def test_register_serializer_for_migrations_discrete_ranges(self):
tests = (
(DateRange(empty=True), DateRangeField),
(DateTimeRange(empty=True), DateRangeField),
(DateTimeTZRange(None, None, '[]'), DateTimeRangeField),
(NumericRange(1, 10), IntegerRangeField),
)

Expand Down Expand Up @@ -58,3 +60,34 @@ def assertNotSerializable():
serialized_field
)
assertNotSerializable()

def test_register_serializer_for_migrations_continuous_ranges(self):
tests = (
(DateTimeTZRange(None, None, '[]'), DateTimeRangeField),
(NumericRange(Decimal('1.0'), Decimal('5.0'), '(]'), DecimalRangeField),
)

def assertNotSerializable():
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
with self.assertRaisesMessage(ValueError, 'Cannot serialize: %s' % default.__class__.__name__):
MigrationWriter.serialize(field)

assertNotSerializable()
with self.modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}):
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
field_string = (
f'{field.__module__}.{field.__class__.__name__}'
f"(default=psycopg2.extras.{default!r}, default_bounds={field.default_bounds!r})"
)
serialized_field, imports = MigrationWriter.serialize(field)
self.assertEqual(imports, {
'import django.contrib.postgres.fields.ranges',
'import psycopg2.extras',
})
self.assertIn(field_string, serialized_field)

assertNotSerializable()

0 comments on commit 66c9a95

Please sign in to comment.