From 7a72922f00e5ac90b2ec1aec65cacb4fac0de694 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Thu, 25 Sep 2025 16:00:25 -0400 Subject: [PATCH 1/3] Simplify and rename EmbeddedModelField's transform classes --- .../fields/embedded_model.py | 34 +++++++++---------- .../fields/polymorphic_embedded_model.py | 4 +-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 582a0703f..fbc1d53a1 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -132,7 +132,7 @@ def get_transform(self, name): if transform: return transform field = self.embedded_model._meta.get_field(name) - return KeyTransformFactory(name, field) + return EmbeddedModelTransformFactory(field) def validate(self, value, model_instance): super().validate(value, model_instance) @@ -156,23 +156,24 @@ def formfield(self, **kwargs): ) -class KeyTransform(Transform): - def __init__(self, key_name, ref_field, *args, **kwargs): +class EmbeddedModelTransform(Transform): + def __init__(self, field, *args, **kwargs): super().__init__(*args, **kwargs) - self.key_name = str(key_name) - self.ref_field = ref_field + # self.field aliases self._field via BaseExpression.field returning + # self.output_field. + self._field = field def get_lookup(self, name): - return self.ref_field.get_lookup(name) + return self.field.get_lookup(name) def get_transform(self, name): """ Validate that `name` is either a field of an embedded model or a lookup on an embedded model's field. """ - if transform := self.ref_field.get_transform(name): + if transform := self.field.get_transform(name): return transform - suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups()) + suggested_lookups = difflib.get_close_matches(name, self.field.get_lookups()) if suggested_lookups: suggested_lookups = " or ".join(suggested_lookups) suggestion = f", perhaps you meant {suggested_lookups}?" @@ -180,15 +181,15 @@ def get_transform(self, name): suggestion = "." raise FieldDoesNotExist( f"Unsupported lookup '{name}' for " - f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'" + f"{self.field.__class__.__name__} '{self.field.name}'" f"{suggestion}" ) def as_mql(self, compiler, connection, as_path=False): previous = self columns = [] - while isinstance(previous, KeyTransform): - columns.insert(0, previous.ref_field.column) + while isinstance(previous, EmbeddedModelTransform): + columns.insert(0, previous.field.column) previous = previous.lhs if as_path: mql = previous.as_mql(compiler, connection, as_path=True) @@ -201,13 +202,12 @@ def as_mql(self, compiler, connection, as_path=False): @property def output_field(self): - return self.ref_field + return self._field -class KeyTransformFactory: - def __init__(self, key_name, ref_field): - self.key_name = key_name - self.ref_field = ref_field +class EmbeddedModelTransformFactory: + def __init__(self, field): + self.field = field def __call__(self, *args, **kwargs): - return KeyTransform(self.key_name, self.ref_field, *args, **kwargs) + return EmbeddedModelTransform(self.field, *args, **kwargs) diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model.py b/django_mongodb_backend/fields/polymorphic_embedded_model.py index 98a33368e..d584cd7c1 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model.py @@ -5,7 +5,7 @@ from django.db import models from django.db.models.fields.related import lazy_related_operation -from .embedded_model import KeyTransformFactory +from .embedded_model import EmbeddedModelTransformFactory from .utils import get_mongodb_connection @@ -170,7 +170,7 @@ def get_transform(self, name): raise FieldDoesNotExist( f"The models of field '{self.name}' have no field named '{name}'." ) - return KeyTransformFactory(name, field) + return EmbeddedModelTransformFactory(field) def validate(self, value, model_instance): super().validate(value, model_instance) From e3a0464644791d4f8235bdc701c808c632cd688e Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Tue, 30 Sep 2025 11:54:14 -0400 Subject: [PATCH 2/3] Add invalid nested field tests for Polymorphic/EmbeddedModelArrayField --- tests/model_fields_/test_embedded_model_array.py | 5 +++++ .../model_fields_/test_polymorphic_embedded_model_array.py | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 291afdf7c..d8883af3d 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -282,6 +282,11 @@ def test_invalid_field(self): with self.assertRaisesMessage(FieldDoesNotExist, msg): Exhibit.objects.filter(sections__section__in=[10]).first() + def test_invalid_nested_field(self): + msg = "Cannot perform multiple levels of array traversal in a query." + with self.assertRaisesMessage(ValueError, msg): + Exhibit.objects.filter(sections__artifacts__xx=10).first() + def test_invalid_lookup(self): msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'" with self.assertRaisesMessage(FieldDoesNotExist, msg): diff --git a/tests/model_fields_/test_polymorphic_embedded_model_array.py b/tests/model_fields_/test_polymorphic_embedded_model_array.py index 403decec1..bf1e96806 100644 --- a/tests/model_fields_/test_polymorphic_embedded_model_array.py +++ b/tests/model_fields_/test_polymorphic_embedded_model_array.py @@ -176,6 +176,11 @@ def test_invalid_field(self): with self.assertRaisesMessage(FieldDoesNotExist, msg): Owner.objects.filter(pets__xxx=10).first() + def test_invalid_nested_field(self): + msg = "Cannot perform multiple levels of array traversal in a query." + with self.assertRaisesMessage(ValueError, msg): + Owner.objects.filter(pets__toys__xxx=10).first() + def test_invalid_lookup(self): msg = "Unsupported lookup 'return' for PolymorphicEmbeddedModelArrayField of 'CharField'" with self.assertRaisesMessage(FieldDoesNotExist, msg): @@ -197,7 +202,7 @@ def test_missing_lookup_suggestions(self): def test_nested_lookup(self): msg = "Cannot perform multiple levels of array traversal in a query." with self.assertRaisesMessage(ValueError, msg): - Owner.objects.filter(pets__toys__name="") + Owner.objects.filter(pets__toys__brand="") @isolate_apps("model_fields_") From d9dabd69e71dca9f41d9dcc59ab2eeac46c6b016 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Thu, 25 Sep 2025 19:31:28 -0400 Subject: [PATCH 3/3] Simplify and rename EmbeddedModelArrayField's transform classes --- .../fields/embedded_model_array.py | 24 ++++++------ .../polymorphic_embedded_model_array.py | 39 ++++++++++--------- .../test_embedded_model_array.py | 4 +- .../test_polymorphic_embedded_model_array.py | 4 +- 4 files changed, 35 insertions(+), 36 deletions(-) diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index d04b99db1..e880931c9 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -67,7 +67,8 @@ def get_transform(self, name): transform = super().get_transform(name) if transform: return transform - return KeyTransformFactory(name, self) + field = self.base_field.embedded_model._meta.get_field(name) + return EmbeddedModelArrayFieldTransformFactory(field) def _get_lookup(self, lookup_name): lookup = super()._get_lookup(lookup_name) @@ -223,17 +224,15 @@ class EmbeddedModelArrayFieldLessThanOrEqual( pass -class KeyTransform(Transform): +class EmbeddedModelArrayFieldTransform(Transform): field_class_name = "EmbeddedModelArrayField" - def __init__(self, key_name, array_field, *args, **kwargs): + def __init__(self, field, *args, **kwargs): super().__init__(*args, **kwargs) - self.array_field = array_field - self.key_name = key_name # Lookups iterate over the array of embedded models. A virtual column # of the queried field's type represents each element. - column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone() - column_name = f"$item.{key_name}" + column_target = field.clone() + column_name = f"$item.{field.column}" column_target.db_column = column_name column_target.set_attributes_from_name(column_name) self._lhs = Col(None, column_target) @@ -254,7 +253,7 @@ def get_transform(self, name): # Once the sub-lhs is a transform, all the filters are applied over it. # Otherwise get the transform from the nested embedded model field. if transform := self._lhs.get_transform(name): - if isinstance(transform, KeyTransformFactory): + if isinstance(transform, EmbeddedModelArrayFieldTransformFactory): raise ValueError("Cannot perform multiple levels of array traversal in a query.") self._sub_transform = transform return self @@ -296,10 +295,9 @@ def output_field(self): return _EmbeddedModelArrayOutputField(self._lhs.output_field) -class KeyTransformFactory: - def __init__(self, key_name, base_field): - self.key_name = key_name - self.base_field = base_field +class EmbeddedModelArrayFieldTransformFactory: + def __init__(self, field): + self.field = field def __call__(self, *args, **kwargs): - return KeyTransform(self.key_name, self.base_field, *args, **kwargs) + return EmbeddedModelArrayFieldTransform(self.field, *args, **kwargs) diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py index e625d2ea8..6325ca4fc 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py @@ -7,8 +7,10 @@ from . import PolymorphicEmbeddedModelField from .array import ArrayField, ArrayLenTransform -from .embedded_model_array import KeyTransform as ArrayFieldKeyTransform -from .embedded_model_array import KeyTransformFactory as ArrayFieldKeyTransformFactory +from .embedded_model_array import ( + EmbeddedModelArrayFieldTransform, + EmbeddedModelArrayFieldTransformFactory, +) class PolymorphicEmbeddedModelArrayField(ArrayField): @@ -62,7 +64,15 @@ def get_transform(self, name): transform = super().get_transform(name) if transform: return transform - return KeyTransformFactory(name, self) + for model in self.base_field.embedded_models: + with contextlib.suppress(FieldDoesNotExist): + field = model._meta.get_field(name) + break + else: + raise FieldDoesNotExist( + f"The models of field '{self.name}' have no field named '{name}'." + ) + return PolymorphicArrayFieldTransformFactory(field) def _get_lookup(self, lookup_name): lookup = super()._get_lookup(lookup_name) @@ -79,32 +89,23 @@ def as_mql(self, compiler, connection): return EmbeddedModelArrayFieldLookups -class KeyTransform(ArrayFieldKeyTransform): +class PolymorphicArrayFieldTransform(EmbeddedModelArrayFieldTransform): field_class_name = "PolymorphicEmbeddedModelArrayField" - def __init__(self, key_name, array_field, *args, **kwargs): - # Skip ArrayFieldKeyTransform.__init__() + def __init__(self, field, *args, **kwargs): + # Skip EmbeddedModelArrayFieldTransform.__init__() Transform.__init__(self, *args, **kwargs) - self.array_field = array_field - self.key_name = key_name - for model in array_field.base_field.embedded_models: - with contextlib.suppress(FieldDoesNotExist): - field = model._meta.get_field(key_name) - break - else: - raise FieldDoesNotExist( - f"The models of field '{array_field.name}' have no field named '{key_name}'." - ) # Lookups iterate over the array of embedded models. A virtual column # of the queried field's type represents each element. column_target = field.clone() - column_name = f"$item.{key_name}" + column_name = f"$item.{field.column}" + column_target.name = f"{field.name}" column_target.db_column = column_name column_target.set_attributes_from_name(column_name) self._lhs = Col(None, column_target) self._sub_transform = None -class KeyTransformFactory(ArrayFieldKeyTransformFactory): +class PolymorphicArrayFieldTransformFactory(EmbeddedModelArrayFieldTransformFactory): def __call__(self, *args, **kwargs): - return KeyTransform(self.key_name, self.base_field, *args, **kwargs) + return PolymorphicArrayFieldTransform(self.field, *args, **kwargs) diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index d8883af3d..5ae396e2a 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -283,8 +283,8 @@ def test_invalid_field(self): Exhibit.objects.filter(sections__section__in=[10]).first() def test_invalid_nested_field(self): - msg = "Cannot perform multiple levels of array traversal in a query." - with self.assertRaisesMessage(ValueError, msg): + msg = "Artifact has no field named 'xx'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): Exhibit.objects.filter(sections__artifacts__xx=10).first() def test_invalid_lookup(self): diff --git a/tests/model_fields_/test_polymorphic_embedded_model_array.py b/tests/model_fields_/test_polymorphic_embedded_model_array.py index bf1e96806..453f4ad50 100644 --- a/tests/model_fields_/test_polymorphic_embedded_model_array.py +++ b/tests/model_fields_/test_polymorphic_embedded_model_array.py @@ -177,8 +177,8 @@ def test_invalid_field(self): Owner.objects.filter(pets__xxx=10).first() def test_invalid_nested_field(self): - msg = "Cannot perform multiple levels of array traversal in a query." - with self.assertRaisesMessage(ValueError, msg): + msg = "The models of field 'toys' have no field named 'xxx'." + with self.assertRaisesMessage(FieldDoesNotExist, msg): Owner.objects.filter(pets__toys__xxx=10).first() def test_invalid_lookup(self):