diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 951632363..582a0703f 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -93,9 +93,9 @@ def to_python(self, value): return value instance = self.embedded_model( **{ - field.attname: field.to_python(value[field.attname]) + field.attname: field.to_python(value[field.column]) for field in self.embedded_model._meta.fields - if field.attname in value + if field.column in value } ) instance._state.adding = False @@ -122,7 +122,7 @@ def get_db_prep_save(self, embedded_instance, connection): # Exclude unset primary keys (e.g. {'id': None}). if field.primary_key and value is None: continue - field_values[field.attname] = value + field_values[field.column] = value # This instance will exist in the database soon. embedded_instance._state.adding = False return field_values @@ -186,17 +186,17 @@ def get_transform(self, name): def as_mql(self, compiler, connection, as_path=False): previous = self - key_transforms = [] + columns = [] while isinstance(previous, KeyTransform): - key_transforms.insert(0, previous.key_name) + columns.insert(0, previous.ref_field.column) previous = previous.lhs if as_path: mql = previous.as_mql(compiler, connection, as_path=True) - mql_path = ".".join(key_transforms) + mql_path = ".".join(columns) return f"{mql}.{mql_path}" mql = previous.as_mql(compiler, connection) - for key in key_transforms: - mql = {"$getField": {"input": mql, "field": key}} + for column in columns: + mql = {"$getField": {"input": mql, "field": column}} return mql @property diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model.py b/django_mongodb_backend/fields/polymorphic_embedded_model.py index e41038443..98a33368e 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model.py @@ -121,9 +121,9 @@ def to_python(self, value): model_class = self._get_model_from_label(value.pop("_label")) instance = model_class( **{ - field.attname: field.to_python(value[field.attname]) + field.attname: field.to_python(value[field.column]) for field in model_class._meta.fields - if field.attname in value + if field.column in value } ) instance._state.adding = False @@ -150,7 +150,7 @@ def get_db_prep_save(self, embedded_instance, connection): # Exclude unset primary keys (e.g. {'id': None}). if field.primary_key and value is None: continue - field_values[field.attname] = value + field_values[field.column] = value # Store the model's label to know the class to use for initializing # upon retrieval. field_values["_label"] = embedded_instance._meta.label diff --git a/django_mongodb_backend/operations.py b/django_mongodb_backend/operations.py index 79ac030da..5106678cc 100644 --- a/django_mongodb_backend/operations.py +++ b/django_mongodb_backend/operations.py @@ -184,14 +184,14 @@ def convert_embeddedmodelfield_value(self, value, expression, connection): if value is not None: # Apply database converters to each field of the embedded model. for field in expression.output_field.embedded_model._meta.fields: - if field.attname not in value: + if field.column not in value: continue field_expr = Expression(output_field=field) converters = connection.ops.get_db_converters( field_expr ) + field_expr.get_db_converters(connection) for converter in converters: - value[field.attname] = converter(value[field.attname], field_expr, connection) + value[field.column] = converter(value[field.column], field_expr, connection) return value def convert_jsonfield_value(self, value, expression, connection): @@ -206,14 +206,14 @@ def convert_polymorphicembeddedmodelfield_value(self, value, expression, connect model_class = expression.output_field._get_model_from_label(value["_label"]) # Apply database converters to each field of the embedded model. for field in model_class._meta.fields: - if field.attname not in value: + if field.column not in value: continue field_expr = Expression(output_field=field) converters = connection.ops.get_db_converters( field_expr ) + field_expr.get_db_converters(connection) for converter in converters: - value[field.attname] = converter(value[field.attname], field_expr, connection) + value[field.column] = converter(value[field.column], field_expr, connection) return value def convert_timefield_value(self, value, expression, connection): diff --git a/docs/releases/5.2.x.rst b/docs/releases/5.2.x.rst index c49d87b7f..47fb16704 100644 --- a/docs/releases/5.2.x.rst +++ b/docs/releases/5.2.x.rst @@ -19,6 +19,8 @@ Bug fixes that use a database converter, if the field isn't present in the data (e.g. data not written by Django, or after a field was added to an existing ``EmbeddedModel``). +- Made ``EmbeddedModel`` fields respect + :attr:`~django.db.models.Field.db_column`. Deprecated features ------------------- diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 21e0af249..ab8731e02 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -125,7 +125,7 @@ class Holder(models.Model): class Data(EmbeddedModel): - integer = models.IntegerField(db_column="custom_column") + integer = models.IntegerField(db_column="integer_") auto_now = models.DateTimeField(auto_now=True) auto_now_add = models.DateTimeField(auto_now_add=True) json_value = models.JSONField() @@ -175,7 +175,7 @@ def __str__(self): class Review(EmbeddedModel): - title = models.CharField(max_length=255) + title = models.CharField(max_length=255, db_column="title_") rating = models.DecimalField(max_digits=6, decimal_places=1) def __str__(self): @@ -261,7 +261,7 @@ def __str__(self): class Cat(EmbeddedModel): - name = models.CharField(max_length=100) + name = models.CharField(max_length=100, db_column="name_") purs = models.BooleanField(default=True) weight = models.DecimalField(max_digits=4, decimal_places=2, blank=True, null=True) favorite_toy = PolymorphicEmbeddedModelField(["Mouse"], blank=True, null=True) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 199fde1b2..1a219613f 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -114,9 +114,18 @@ def test_missing_field_in_data(self): this case, integer is an IntegerField) doesn't crash. """ Holder.objects.create(data=Data(integer=5)) - connection.database.model_fields__holder.update_many({}, {"$unset": {"data.integer": ""}}) + connection.database.model_fields__holder.update_many({}, {"$unset": {"data.integer_": ""}}) self.assertIsNone(Holder.objects.first().data.integer) + def test_embedded_model_field_respects_db_column(self): + """ + EmbeddedModel data respects Field.db_column. In this case, Data.integer + has db_column="integer_". + """ + obj = Holder.objects.create(data=Data(integer=5)) + query = connection.database.model_fields__holder.find({"_id": obj.pk}) + self.assertEqual(query[0]["data"]["integer_"], 5) + class QueryingTests(TestCase): @classmethod diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 381ffd5e9..291afdf7c 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -74,6 +74,15 @@ def test_missing_field_in_data(self): ) self.assertIsNone(Movie.objects.first().reviews[0].rating) + def test_embedded_model_field_respects_db_column(self): + """ + EmbeddedModel data respects Field.db_column. In this case, + Review.title has db_column="title_". + """ + obj = Movie.objects.create(title="Lion King", reviews=[Review(title="Awesome", rating=10)]) + query = connection.database.model_fields__movie.find({"_id": obj.pk}) + self.assertEqual(query[0]["reviews"][0]["title_"], "Awesome") + class QueryingTests(TestCase): @classmethod diff --git a/tests/model_fields_/test_polymorphic_embedded_model.py b/tests/model_fields_/test_polymorphic_embedded_model.py index 5a34d25f3..bf743b22c 100644 --- a/tests/model_fields_/test_polymorphic_embedded_model.py +++ b/tests/model_fields_/test_polymorphic_embedded_model.py @@ -101,6 +101,15 @@ def test_missing_field_in_data(self): connection.database.model_fields__person.update_many({}, {"$unset": {"pet.weight": ""}}) self.assertIsNone(Person.objects.first().pet.weight) + def test_embedded_model_field_respects_db_column(self): + """ + EmbeddedModel data respects Field.db_column. In this case, Cat.name + has db_column="name_". + """ + obj = Person.objects.create(pet=Cat(name="Phoebe")) + query = connection.database.model_fields__person.find({"_id": obj.pk}) + self.assertEqual(query[0]["pet"]["name_"], "Phoebe") + class QueryingTests(TestCase): @classmethod diff --git a/tests/model_fields_/test_polymorphic_embedded_model_array.py b/tests/model_fields_/test_polymorphic_embedded_model_array.py index ca357d9fc..403decec1 100644 --- a/tests/model_fields_/test_polymorphic_embedded_model_array.py +++ b/tests/model_fields_/test_polymorphic_embedded_model_array.py @@ -72,6 +72,15 @@ def test_missing_field_in_data(self): connection.database.model_fields__owner.update_many({}, {"$unset": {"pets.$[].weight": ""}}) self.assertIsNone(Owner.objects.first().pets[0].weight) + def test_embedded_model_field_respects_db_column(self): + """ + EmbeddedModel data respects Field.db_column. In this case, Cat.name + has db_column="name_". + """ + obj = Owner.objects.create(name="Bob", pets=[Cat(name="Phoebe", weight="3.5")]) + query = connection.database.model_fields__owner.find({"_id": obj.pk}) + self.assertEqual(query[0]["pets"][0]["name_"], "Phoebe") + class QueryingTests(TestCase): @classmethod