Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_transform(self, name):
if transform:
return transform
field = self.embedded_model._meta.get_field(name)
return KeyTransformFactory(name, field)
return EmbeddedModelTransformFactory(field)

def validate(self, value, model_instance):
super().validate(value, model_instance)
Expand All @@ -156,39 +156,40 @@ def formfield(self, **kwargs):
)


class KeyTransform(Transform):
def __init__(self, key_name, ref_field, *args, **kwargs):
class EmbeddedModelTransform(Transform):
def __init__(self, field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)
self.ref_field = ref_field
# self.field aliases self._field via BaseExpression.field returning
# self.output_field.
self._field = field

def get_lookup(self, name):
return self.ref_field.get_lookup(name)
return self.field.get_lookup(name)

def get_transform(self, name):
"""
Validate that `name` is either a field of an embedded model or a
lookup on an embedded model's field.
"""
if transform := self.ref_field.get_transform(name):
if transform := self.field.get_transform(name):
return transform
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
suggested_lookups = difflib.get_close_matches(name, self.field.get_lookups())
if suggested_lookups:
suggested_lookups = " or ".join(suggested_lookups)
suggestion = f", perhaps you meant {suggested_lookups}?"
else:
suggestion = "."
raise FieldDoesNotExist(
f"Unsupported lookup '{name}' for "
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
f"{self.field.__class__.__name__} '{self.field.name}'"
f"{suggestion}"
)

def as_mql(self, compiler, connection, as_path=False):
previous = self
columns = []
while isinstance(previous, KeyTransform):
columns.insert(0, previous.ref_field.column)
while isinstance(previous, EmbeddedModelTransform):
columns.insert(0, previous.field.column)
previous = previous.lhs
if as_path:
mql = previous.as_mql(compiler, connection, as_path=True)
Expand All @@ -201,13 +202,12 @@ def as_mql(self, compiler, connection, as_path=False):

@property
def output_field(self):
return self.ref_field
return self._field


class KeyTransformFactory:
def __init__(self, key_name, ref_field):
self.key_name = key_name
self.ref_field = ref_field
class EmbeddedModelTransformFactory:
def __init__(self, field):
self.field = field

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
return EmbeddedModelTransform(self.field, *args, **kwargs)
24 changes: 11 additions & 13 deletions django_mongodb_backend/fields/embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name, self)
field = self.base_field.embedded_model._meta.get_field(name)
return EmbeddedModelArrayFieldTransformFactory(field)

def _get_lookup(self, lookup_name):
lookup = super()._get_lookup(lookup_name)
Expand Down Expand Up @@ -223,17 +224,15 @@ class EmbeddedModelArrayFieldLessThanOrEqual(
pass


class KeyTransform(Transform):
class EmbeddedModelArrayFieldTransform(Transform):
field_class_name = "EmbeddedModelArrayField"

def __init__(self, key_name, array_field, *args, **kwargs):
def __init__(self, field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.array_field = array_field
self.key_name = key_name
# Lookups iterate over the array of embedded models. A virtual column
# of the queried field's type represents each element.
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
column_name = f"$item.{key_name}"
column_target = field.clone()
column_name = f"$item.{field.column}"
column_target.db_column = column_name
column_target.set_attributes_from_name(column_name)
self._lhs = Col(None, column_target)
Expand All @@ -254,7 +253,7 @@ def get_transform(self, name):
# Once the sub-lhs is a transform, all the filters are applied over it.
# Otherwise get the transform from the nested embedded model field.
if transform := self._lhs.get_transform(name):
if isinstance(transform, KeyTransformFactory):
if isinstance(transform, EmbeddedModelArrayFieldTransformFactory):
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
self._sub_transform = transform
return self
Expand Down Expand Up @@ -296,10 +295,9 @@ def output_field(self):
return _EmbeddedModelArrayOutputField(self._lhs.output_field)


class KeyTransformFactory:
def __init__(self, key_name, base_field):
self.key_name = key_name
self.base_field = base_field
class EmbeddedModelArrayFieldTransformFactory:
def __init__(self, field):
self.field = field

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)
return EmbeddedModelArrayFieldTransform(self.field, *args, **kwargs)
4 changes: 2 additions & 2 deletions django_mongodb_backend/fields/polymorphic_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.db import models
from django.db.models.fields.related import lazy_related_operation

from .embedded_model import KeyTransformFactory
from .embedded_model import EmbeddedModelTransformFactory
from .utils import get_mongodb_connection


Expand Down Expand Up @@ -170,7 +170,7 @@ def get_transform(self, name):
raise FieldDoesNotExist(
f"The models of field '{self.name}' have no field named '{name}'."
)
return KeyTransformFactory(name, field)
return EmbeddedModelTransformFactory(field)

def validate(self, value, model_instance):
super().validate(value, model_instance)
Expand Down
39 changes: 20 additions & 19 deletions django_mongodb_backend/fields/polymorphic_embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from . import PolymorphicEmbeddedModelField
from .array import ArrayField, ArrayLenTransform
from .embedded_model_array import KeyTransform as ArrayFieldKeyTransform
from .embedded_model_array import KeyTransformFactory as ArrayFieldKeyTransformFactory
from .embedded_model_array import (
EmbeddedModelArrayFieldTransform,
EmbeddedModelArrayFieldTransformFactory,
)


class PolymorphicEmbeddedModelArrayField(ArrayField):
Expand Down Expand Up @@ -62,7 +64,15 @@ def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name, self)
for model in self.base_field.embedded_models:
with contextlib.suppress(FieldDoesNotExist):
field = model._meta.get_field(name)
break
else:
raise FieldDoesNotExist(
f"The models of field '{self.name}' have no field named '{name}'."
)
Comment on lines +67 to +74
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a test failure caused by moving this logic here:

======================================================================
ERROR: test_nested_lookup (model_fields_.test_polymorphic_embedded_model_array.QueryingTests.test_nested_lookup)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/tim/code/django-mongodb/tests/model_fields_/test_polymorphic_embedded_model_array.py", line 190, in test_nested_lookup
    Owner.objects.filter(pets__toys__name="")
  File "/home/tim/code/django/django/db/models/manager.py", line 87, in manager_method
    return getattr(self.get_queryset(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/query.py", line 1493, in filter
    return self._filter_or_exclude(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/query.py", line 1511, in _filter_or_exclude
    clone._filter_or_exclude_inplace(negate, args, kwargs)
  File "/home/tim/code/django/django/db/models/query.py", line 1518, in _filter_or_exclude_inplace
    self._query.add_q(Q(*args, **kwargs))
  File "/home/tim/code/django/django/db/models/sql/query.py", line 1646, in add_q
    clause, _ = self._add_q(q_object, can_reuse)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/sql/query.py", line 1678, in _add_q
    child_clause, needed_inner = self.build_filter(
                                 ^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/sql/query.py", line 1588, in build_filter
    condition = self.build_lookup(lookups, col, value)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/sql/query.py", line 1409, in build_lookup
    lhs = self.try_transform(lhs, lookup_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/sql/query.py", line 1441, in try_transform
    transform_class = lhs.get_transform(name)
                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django-mongodb/django_mongodb_backend/fields/embedded_model_array.py", line 255, in get_transform
    if transform := self._lhs.get_transform(name):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django/django/db/models/expressions.py", line 406, in get_transform
    return self.output_field.get_transform(name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/code/django-mongodb/django_mongodb_backend/fields/polymorphic_embedded_model_array.py", line 70, in get_transform
    raise FieldDoesNotExist(
django.core.exceptions.FieldDoesNotExist: The models of field '$item.toys' have no field named 'name'.

It preempts the correct exception from being raised here:

if transform := self._lhs.get_transform(name):
if isinstance(transform, KeyTransformFactory):
raise ValueError("Cannot perform multiple levels of array traversal in a query.")

Copy link
Collaborator

@WaVEV WaVEV Sep 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think it is a feature, haha 😄 . The field name does not exist in Dog's toy. So it is okey if the test raises.
if we define the Bone as follows

class Bone(EmbeddedModel):
    name = models.CharField(max_length=100, null=True)
    brand = models.CharField(max_length=100)

    def __str__(self):
        return self.brand

the test pass

Or the idea was: I will try to filter whatever it is. If the field isn't there it will match None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. The behavior changed of those error messages changed slightly, but I think it's okay. I added additional tests to demonstrate.

return PolymorphicArrayFieldTransformFactory(field)

def _get_lookup(self, lookup_name):
lookup = super()._get_lookup(lookup_name)
Expand All @@ -79,32 +89,23 @@ def as_mql(self, compiler, connection):
return EmbeddedModelArrayFieldLookups


class KeyTransform(ArrayFieldKeyTransform):
class PolymorphicArrayFieldTransform(EmbeddedModelArrayFieldTransform):
field_class_name = "PolymorphicEmbeddedModelArrayField"

def __init__(self, key_name, array_field, *args, **kwargs):
# Skip ArrayFieldKeyTransform.__init__()
def __init__(self, field, *args, **kwargs):
# Skip EmbeddedModelArrayFieldTransform.__init__()
Transform.__init__(self, *args, **kwargs)
self.array_field = array_field
self.key_name = key_name
for model in array_field.base_field.embedded_models:
with contextlib.suppress(FieldDoesNotExist):
field = model._meta.get_field(key_name)
break
else:
raise FieldDoesNotExist(
f"The models of field '{array_field.name}' have no field named '{key_name}'."
)
# Lookups iterate over the array of embedded models. A virtual column
# of the queried field's type represents each element.
column_target = field.clone()
column_name = f"$item.{key_name}"
column_name = f"$item.{field.column}"
column_target.name = f"{field.name}"
column_target.db_column = column_name
column_target.set_attributes_from_name(column_name)
self._lhs = Col(None, column_target)
self._sub_transform = None


class KeyTransformFactory(ArrayFieldKeyTransformFactory):
class PolymorphicArrayFieldTransformFactory(EmbeddedModelArrayFieldTransformFactory):
def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)
return PolymorphicArrayFieldTransform(self.field, *args, **kwargs)
5 changes: 5 additions & 0 deletions tests/model_fields_/test_embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@ def test_invalid_field(self):
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Exhibit.objects.filter(sections__section__in=[10]).first()

def test_invalid_nested_field(self):
msg = "Artifact has no field named 'xx'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Exhibit.objects.filter(sections__artifacts__xx=10).first()

def test_invalid_lookup(self):
msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Expand Down
7 changes: 6 additions & 1 deletion tests/model_fields_/test_polymorphic_embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ def test_invalid_field(self):
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Owner.objects.filter(pets__xxx=10).first()

def test_invalid_nested_field(self):
msg = "The models of field 'toys' have no field named 'xxx'."
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Owner.objects.filter(pets__toys__xxx=10).first()

def test_invalid_lookup(self):
msg = "Unsupported lookup 'return' for PolymorphicEmbeddedModelArrayField of 'CharField'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Expand All @@ -197,7 +202,7 @@ def test_missing_lookup_suggestions(self):
def test_nested_lookup(self):
msg = "Cannot perform multiple levels of array traversal in a query."
with self.assertRaisesMessage(ValueError, msg):
Owner.objects.filter(pets__toys__name="")
Owner.objects.filter(pets__toys__brand="")


@isolate_apps("model_fields_")
Expand Down
Loading