From 66c9a95b70342eeb40fdb26a33ec3aa2aa9ed490 Mon Sep 17 00:00:00 2001 From: Guilherme Martins Crocetti Date: Thu, 17 Jun 2021 18:13:49 -0300 Subject: [PATCH] Fixed #27147: Add 'default_bounds' argument for fields subclassing RangeField. --- django/contrib/postgres/fields/ranges.py | 42 ++++++++- django/contrib/postgres/forms/ranges.py | 5 +- docs/ref/contrib/postgres/fields.txt | 28 +++++- docs/releases/4.1.txt | 6 ++ tests/postgres_tests/fields.py | 13 ++- .../migrations/0002_create_test_models.py | 1 + tests/postgres_tests/models.py | 1 + tests/postgres_tests/test_apps.py | 39 ++++++++- tests/postgres_tests/test_ranges.py | 85 ++++++++++++++++++- 9 files changed, 205 insertions(+), 15 deletions(-) diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index 8eab2cd2d997a..fcd60d174aafa 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -12,10 +12,13 @@ __all__ = [ 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField', 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField', - 'RangeBoundary', 'RangeOperators', + 'RangeBoundary', 'RangeOperators', 'ContinuousRangeField', ] +CANONICAL_RANGE_BOUNDS = '[)' + + class RangeBoundary(models.Expression): """A class that represents range boundaries.""" def __init__(self, inclusive_lower=True, inclusive_upper=False): @@ -45,6 +48,8 @@ class RangeField(models.Field): def __init__(self, *args, **kwargs): # Initializing base_field here ensures that its model matches the model for self. + if 'default_bounds' in kwargs: + raise TypeError(f"Cannot use 'default_bounds' with {self.__class__.__name__}.") if hasattr(self, 'base_field'): self.base_field = self.base_field() super().__init__(*args, **kwargs) @@ -112,6 +117,37 @@ def formfield(self, **kwargs): return super().formfield(**kwargs) +class ContinuousRangeField(RangeField): + """ + This class specializes RangeField allowing user to choose a default bounds value + for the field. This assumption is conservative in the sense that we enforce the + chosen 'default_bounds' exclusively for inputs of the types (list, tuple). + We're not replacing the bounds attribute by 'default_bounds' for objects of the type + `psycopg2.extras.Range`. + """ + + def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs): + if default_bounds not in ('[)', '(]', '()', '[]'): + raise ValueError(f'Invalid bound flags: {default_bounds}.') + self.default_bounds = default_bounds + super().__init__(*args, **kwargs) + + def get_prep_value(self, value): + if isinstance(value, (list, tuple)): + return self.range_type(value[0], value[1], self.default_bounds) + return super().get_prep_value(value) + + def formfield(self, **kwargs): + kwargs.setdefault('default_bounds', self.default_bounds) + return super().formfield(**kwargs) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS: + kwargs['default_bounds'] = self.default_bounds + return name, path, args, kwargs + + class IntegerRangeField(RangeField): base_field = models.IntegerField range_type = NumericRange @@ -130,7 +166,7 @@ def db_type(self, connection): return 'int8range' -class DecimalRangeField(RangeField): +class DecimalRangeField(ContinuousRangeField): base_field = models.DecimalField range_type = NumericRange form_field = forms.DecimalRangeField @@ -139,7 +175,7 @@ def db_type(self, connection): return 'numrange' -class DateTimeRangeField(RangeField): +class DateTimeRangeField(ContinuousRangeField): base_field = models.DateTimeField range_type = DateTimeTZRange form_field = forms.DateTimeRangeField diff --git a/django/contrib/postgres/forms/ranges.py b/django/contrib/postgres/forms/ranges.py index 5a20975eb4e16..cba39259e0424 100644 --- a/django/contrib/postgres/forms/ranges.py +++ b/django/contrib/postgres/forms/ranges.py @@ -42,6 +42,9 @@ def __init__(self, **kwargs): kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)] kwargs.setdefault('required', False) kwargs.setdefault('require_all_fields', False) + self.range_kwargs = {} + if default_bounds := kwargs.pop("default_bounds", None): + self.range_kwargs = {"bounds": default_bounds} super().__init__(**kwargs) def prepare_value(self, value): @@ -68,7 +71,7 @@ def compress(self, values): code='bound_ordering', ) try: - range_value = self.range_type(lower, upper) + range_value = self.range_type(lower, upper, **self.range_kwargs) except TypeError: raise exceptions.ValidationError( self.error_messages['invalid'], diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt index c439a70b89617..32e754b774f22 100644 --- a/docs/ref/contrib/postgres/fields.txt +++ b/docs/ref/contrib/postgres/fields.txt @@ -503,7 +503,9 @@ All of the range fields translate to :ref:`psycopg2 Range objects ` in Python, but also accept tuples as input if no bounds information is necessary. The default is lower bound included, upper bound excluded, that is ``[)`` (see the PostgreSQL documentation for details about -`different bounds`_). +`different bounds`_). The default value for bounds can be changed by passing a +`default_bounds` parameter for the ``DecimalRangeField`` and ``DateTimeRangeField``. + .. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO @@ -538,23 +540,43 @@ excluded, that is ``[)`` (see the PostgreSQL documentation for details about ``DecimalRangeField`` --------------------- -.. class:: DecimalRangeField(**options) +.. class:: DecimalRangeField(default_bounds='[)', **options) Stores a range of floating point values. Based on a :class:`~django.db.models.DecimalField`. Represented by a ``numrange`` in the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in Python. + .. attribute:: DecimalRangeField.default_bounds + + Optional. The value for ``bounds`` at tuple inputs. `default_bounds` + is not used at :class:`~psycopg2:psycopg2.extras.NumericRange` input. + The default is lower bound included, upper bound excluded, that is + ``[)`` (see the PostgreSQL documentation for details about `different + bounds`_). + + .. versionadded:: 4.1 + ``DateTimeRangeField`` ---------------------- -.. class:: DateTimeRangeField(**options) +.. class:: DateTimeRangeField(default_bounds='[)', **options) Stores a range of timestamps. Based on a :class:`~django.db.models.DateTimeField`. Represented by a ``tstzrange`` in the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in Python. + .. attribute:: DateTimeRangeField.default_bounds + + Optional. The value for ``bounds`` at tuple inputs. `default_bounds` + is not used at :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` input. + The default is lower bound included, upper bound excluded, that is + ``[)`` (see the PostgreSQL documentation for details about `different + bounds`_). + + .. versionadded:: 4.1 + ``DateRangeField`` ------------------ diff --git a/docs/releases/4.1.txt b/docs/releases/4.1.txt index a8223b7ef8cb8..4459c8238ae7b 100644 --- a/docs/releases/4.1.txt +++ b/docs/releases/4.1.txt @@ -76,6 +76,12 @@ Minor features supports covering exclusion constraints using SP-GiST indexes on PostgreSQL 14+. +* The new :class:`~django.contrib.postgres.fields.ranges.ContinuousRangeField` + specializes the existing + :class:`~django.contrib.postgres.fields.ranges.RangeField` adding a + `default_bounds` parameter, this parameter allows users to choose the default + value for `bounds` in case of list or tuple inputs. + :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py index b1bb6668d66fe..1c0cb05d4664a 100644 --- a/tests/postgres_tests/fields.py +++ b/tests/postgres_tests/fields.py @@ -26,14 +26,23 @@ def deconstruct(self): }) return name, path, args, kwargs + class DummyContinuousRangeField(models.Field): + def __init__(self, *args, default_bounds='[)', **kwargs): + super().__init__(**kwargs) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + kwargs['default_bounds'] = '[)' + return name, path, args, kwargs + ArrayField = DummyArrayField BigIntegerRangeField = models.Field CICharField = models.Field CIEmailField = models.Field CITextField = models.Field DateRangeField = models.Field - DateTimeRangeField = models.Field - DecimalRangeField = models.Field + DateTimeRangeField = DummyContinuousRangeField + DecimalRangeField = DummyContinuousRangeField HStoreField = models.Field IntegerRangeField = models.Field SearchVector = models.Expression diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 377e220db18ad..258a5580921e7 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -249,6 +249,7 @@ class Migration(migrations.Migration): ('decimals', DecimalRangeField(null=True, blank=True)), ('timestamps', DateTimeRangeField(null=True, blank=True)), ('timestamps_inner', DateTimeRangeField(null=True, blank=True)), + ('timestamps_with_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')), ('dates', DateRangeField(null=True, blank=True)), ('dates_inner', DateRangeField(null=True, blank=True)), ], diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index adb2e89201644..21cf99c2d8c2b 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -135,6 +135,7 @@ class RangesModel(PostgreSQLModel): decimals = DecimalRangeField(blank=True, null=True) timestamps = DateTimeRangeField(blank=True, null=True) timestamps_inner = DateTimeRangeField(blank=True, null=True) + timestamps_with_closed_bounds = DateTimeRangeField(blank=True, null=True, default_bounds='[]') dates = DateRangeField(blank=True, null=True) dates_inner = DateRangeField(blank=True, null=True) diff --git a/tests/postgres_tests/test_apps.py b/tests/postgres_tests/test_apps.py index bfb7568d50bb9..6c6c06252b10d 100644 --- a/tests/postgres_tests/test_apps.py +++ b/tests/postgres_tests/test_apps.py @@ -1,3 +1,5 @@ +from decimal import Decimal + from django.db.backends.signals import connection_created from django.db.migrations.writer import MigrationWriter from django.test.utils import modify_settings @@ -10,7 +12,8 @@ ) from django.contrib.postgres.fields import ( - DateRangeField, DateTimeRangeField, IntegerRangeField, + DateRangeField, DateTimeRangeField, DecimalRangeField, + IntegerRangeField, ) except ImportError: pass @@ -24,11 +27,10 @@ def test_register_type_handlers_connection(self): self.assertIn(register_type_handlers, connection_created._live_receivers(None)) self.assertNotIn(register_type_handlers, connection_created._live_receivers(None)) - def test_register_serializer_for_migrations(self): + def test_register_serializer_for_migrations_discrete_ranges(self): tests = ( (DateRange(empty=True), DateRangeField), (DateTimeRange(empty=True), DateRangeField), - (DateTimeTZRange(None, None, '[]'), DateTimeRangeField), (NumericRange(1, 10), IntegerRangeField), ) @@ -58,3 +60,34 @@ def assertNotSerializable(): serialized_field ) assertNotSerializable() + + def test_register_serializer_for_migrations_continuous_ranges(self): + tests = ( + (DateTimeTZRange(None, None, '[]'), DateTimeRangeField), + (NumericRange(Decimal('1.0'), Decimal('5.0'), '(]'), DecimalRangeField), + ) + + def assertNotSerializable(): + for default, test_field in tests: + with self.subTest(default=default): + field = test_field(default=default) + with self.assertRaisesMessage(ValueError, 'Cannot serialize: %s' % default.__class__.__name__): + MigrationWriter.serialize(field) + + assertNotSerializable() + with self.modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}): + for default, test_field in tests: + with self.subTest(default=default): + field = test_field(default=default) + field_string = ( + f'{field.__module__}.{field.__class__.__name__}' + f"(default=psycopg2.extras.{default!r}, default_bounds={field.default_bounds!r})" + ) + serialized_field, imports = MigrationWriter.serialize(field) + self.assertEqual(imports, { + 'import django.contrib.postgres.fields.ranges', + 'import psycopg2.extras', + }) + self.assertIn(field_string, serialized_field) + + assertNotSerializable() diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py index 180678578ed8c..d282bb358d4df 100644 --- a/tests/postgres_tests/test_ranges.py +++ b/tests/postgres_tests/test_ranges.py @@ -50,6 +50,45 @@ class Model(PostgreSQLModel): instance = Model(field=value) self.assertEqual(instance.get_field_display(), display) + def test_discrete_range_fields_invalid_default_bounds_argument(self): + discrete_range_fields = [ + pg_fields.IntegerRangeField, + pg_fields.BigIntegerRangeField, + pg_fields.DateRangeField + ] + + for discrete_range_field in discrete_range_fields: + msg = f"Cannot use 'default_bounds' with {discrete_range_field.__class__.__name__}" + with self.assertRaises(TypeError, msg=msg): + discrete_range_field( + choices=[ + ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]], + ((51, 100), '51-100'), + ], + default_bounds='[]' + ) + + def test_continuous_range_fields_with_default_bounds_argument(self): + continuous_range_fields = [ + pg_fields.DecimalRangeField, + pg_fields.DateTimeRangeField, + ] + + for continuous_range_field in continuous_range_fields: + assert continuous_range_field( + choices=[ + ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]], + ((51, 100), '51-100'), + ], + default_bounds='[]' + ) + + def test_default_bounds_with_invalid_bound_flags(self): + for invalid_bounds in (')]', ')[', '](', '])', '([', '[('): + msg = f'Invalid bound flags: {invalid_bounds}.' + with self.assertRaises(ValueError, msg=msg): + pg_fields.ContinuousRangeField(default_bounds=invalid_bounds) + class TestSaveLoad(PostgreSQLTestCase): @@ -117,6 +156,22 @@ def test_model_set_on_base_field(self): self.assertEqual(field.model, RangesModel) self.assertEqual(field.base_field.model, RangesModel) + def test_range_object_with_default_bounds(self): + r = (datetime.datetime(year=2015, month=1, day=1), datetime.datetime(year=2015, month=2, day=1)) + instance = RangesModel(timestamps_with_closed_bounds=r, timestamps=r) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(loaded.timestamps_with_closed_bounds, DateTimeTZRange(r[0], r[1], '[]')) + self.assertEqual(loaded.timestamps, DateTimeTZRange(r[0], r[1], '[)')) + + def test_range_object_with_default_bounds_not_being_replaced(self): + interval = (datetime.datetime(year=2015, month=1, day=1), datetime.datetime(year=2015, month=2, day=1)) + r = DateTimeTZRange(*interval, bounds='()') + instance = RangesModel(timestamps_with_closed_bounds=r) + instance.save() + loaded = RangesModel.objects.get() + self.assertEqual(loaded.timestamps_with_closed_bounds, DateTimeTZRange(interval[0], interval[1], '()')) + class TestRangeContainsLookup(PostgreSQLTestCase): @@ -478,6 +533,8 @@ class TestSerialization(PostgreSQLSimpleTestCase): '"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", ' '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", ' '"timestamps_inner": null, ' + '"timestamps_with_closed_bounds": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", ' + '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[]\\"}", ' '"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", ' '"dates_inner": null }, ' '"model": "postgres_tests.rangesmodel", "pk": null}]' @@ -492,15 +549,17 @@ def test_dumping(self): instance = RangesModel( ints=NumericRange(0, 10), decimals=NumericRange(empty=True), timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt), + timestamps_with_closed_bounds=DateTimeTZRange(self.lower_dt, self.upper_dt, bounds='[]'), dates=DateRange(self.lower_date, self.upper_date), ) data = serializers.serialize('json', [instance]) dumped = json.loads(data) - for field in ('ints', 'dates', 'timestamps'): + for field in ('ints', 'dates', 'timestamps', 'timestamps_with_closed_bounds'): dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field]) check = json.loads(self.test_data) - for field in ('ints', 'dates', 'timestamps'): + for field in ('ints', 'dates', 'timestamps', 'timestamps_with_closed_bounds'): check[0]['fields'][field] = json.loads(check[0]['fields'][field]) + self.assertEqual(dumped, check) def test_loading(self): @@ -886,26 +945,46 @@ def test_model_field_formfield_integer(self): model_field = pg_fields.IntegerRangeField() form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.IntegerRangeField) + self.assertEqual(form_field.range_kwargs, {}) def test_model_field_formfield_biginteger(self): model_field = pg_fields.BigIntegerRangeField() form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.IntegerRangeField) + self.assertEqual(form_field.range_kwargs, {}) def test_model_field_formfield_float(self): - model_field = pg_fields.DecimalRangeField() + expected_bounds = '()' + model_field = pg_fields.DecimalRangeField(default_bounds=expected_bounds) form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.DecimalRangeField) + self.assertEqual(form_field.range_kwargs, {'bounds': expected_bounds}) def test_model_field_formfield_date(self): model_field = pg_fields.DateRangeField() form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.DateRangeField) + self.assertEqual(form_field.range_kwargs, {}) def test_model_field_formfield_datetime(self): model_field = pg_fields.DateTimeRangeField() form_field = model_field.formfield() self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) + self.assertEqual(form_field.range_kwargs, {"bounds": pg_fields.ranges.CANONICAL_RANGE_BOUNDS}) + + def test_model_field_formfield_date_inclusive(self): + expected_bounds = '[]' + model_field = pg_fields.DateTimeRangeField(default_bounds=expected_bounds) + form_field = model_field.formfield() + self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) + self.assertEqual(form_field.range_kwargs, {'bounds': expected_bounds}) + + def test_model_field_with_default_bounds(self): + field = pg_forms.DateTimeRangeField(default_bounds='[]') + value = field.clean(['01/01/2014 00:00:00', '02/02/2014 12:12:12']) + lower = datetime.datetime(2014, 1, 1, 0, 0, 0) + upper = datetime.datetime(2014, 2, 2, 12, 12, 12) + self.assertEqual(value, DateTimeTZRange(lower, upper, '[]')) def test_has_changed(self): for field, value in (