Skip to content

Commit

Permalink
feat: extend update_fields with translation fields in Model.save() (#687
Browse files Browse the repository at this point in the history
)
  • Loading branch information
andrei-shabanski committed Jul 16, 2023
1 parent c68104c commit d86c6de
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ jobs:
if [[ $DB == postgres ]]; then
pip install -q psycopg2-binary
fi
pip install typing-extensions coverage pytest pytest-django pytest-cov $(./get-django-version.py ${{ matrix.django }})
pip install typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }})
- name: Run tests
run: |
pytest --cov-report term
38 changes: 23 additions & 15 deletions modeltranslation/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
"""
import itertools
from functools import reduce
from typing import List, Tuple, Type, Any, Optional

from django import VERSION
from django.contrib.admin.utils import get_model_from_relation
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import Field, Model
from django.db.models.expressions import Col
from django.db.models.lookups import Lookup
from django.db.models.query import QuerySet, ValuesIterable
Expand Down Expand Up @@ -243,21 +245,6 @@ def select_related(self, *fields, **kwargs):
new_args.append(rewrite_lookup_key(self.model, key))
return super().select_related(*new_args, **kwargs)

def update_or_create(self, defaults=None, **kwargs):
"""
Updates or creates a database record with the specified kwargs. The method first
rewrites the keys in the defaults dictionary using a custom function named
`rewrite_lookup_key`. This ensures that the keys are valid for the current model
before calling the inherited update_or_create() method from the super class.
Returns the updated or created model instance.
"""
if defaults is not None:
rewritten_defaults = {}
for key, value in defaults.items():
rewritten_defaults[rewrite_lookup_key(self.model, key)] = value
defaults = rewritten_defaults
return super().update_or_create(defaults=defaults, **kwargs)

# This method was not present in django-linguo
def _rewrite_col(self, col):
"""Django >= 1.7 column name rewriting"""
Expand Down Expand Up @@ -386,6 +373,27 @@ def update(self, **kwargs):

update.alters_data = True

def _update(self, values: List[Tuple[Field, Optional[Type[Model]], Any]]):
"""
This method is called in .save() method to update an existing record.
Here we force to update translation fields as well if the original
field only is passed in `save()` in argument `update_fields`.
"""
# TODO: Should the original field (field without lang code suffix) be updated
# when only the default translation field (`field_<DEFAULT_LANG_CODE>`) is passed in `update_fields`?
# Currently, we don't synchronize values of the original and default translation fields in that case.
field_names_to_update = {field.name for field, *_ in values}

translation_values = []
for field, model, value in values:
translation_field_name = rewrite_lookup_key(self.model, field.name)
if translation_field_name not in field_names_to_update:
translatable_field = self.model._meta.get_field(translation_field_name)
translation_values.append((translatable_field, model, value))

values += translation_values
return super()._update(values)

# This method was not present in django-linguo
@property
def _populate_mode(self):
Expand Down
2 changes: 1 addition & 1 deletion modeltranslation/tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _get_database_config():
{
'ENGINE': 'django.db.backends.postgresql',
'USER': os.getenv('POSTGRES_USER', 'postgres'),
'PASSWORD': os.getenv('POSTGRES_DB', 'postgres'),
'PASSWORD': os.getenv('POSTGRES_PASSWORD', 'postgres'),
'NAME': os.getenv('POSTGRES_DB', 'modeltranslation'),
'HOST': host,
}
Expand Down
142 changes: 142 additions & 0 deletions modeltranslation/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from django.test import TestCase, TransactionTestCase
from django.test.utils import override_settings
from django.utils.translation import get_language, override, trans_real
from parameterized import parameterized

from modeltranslation import admin
from modeltranslation import settings as mt_settings
Expand Down Expand Up @@ -79,6 +80,20 @@ def get_field_names(model):
return names


def assert_db_record(instance, **expected_fields):
"""
Compares field values stored in the db.
"""
actual = (
type(instance)
.objects.rewrite(False)
.filter(pk=instance.pk)
.values(*expected_fields.keys())
.first()
)
assert actual == expected_fields


class ModeltranslationTransactionTestBase(TransactionTestCase):
cache = django_apps

Expand Down Expand Up @@ -358,6 +373,7 @@ def test_set_translation(self):
assert n.title == title_de
assert n.title_en == title_en
assert n.title_de == title_de
assert_db_record(n, title=title_de, title_de=title_de, title_en=title_en)

# Queries are also language-aware:
assert 1 == models.TestModel.objects.filter(title=title_de).count()
Expand Down Expand Up @@ -463,6 +479,89 @@ def test_constructor(self):
)
self._test_constructor(keywords)

@parameterized.expand(
[
({'title': 'DE'}, ['title'], {'title': 'DE', 'title_de': 'DE', 'title_en': None}),
({'title_de': 'DE'}, ['title'], {'title': 'DE', 'title_de': 'DE', 'title_en': None}),
({'title': 'DE'}, ['title_de'], {'title': 'old', 'title_de': 'DE', 'title_en': None}),
(
{'title_de': 'DE'},
['title_de'],
{'title': 'old', 'title_de': 'DE', 'title_en': None},
),
(
{'title': 'DE', 'title_en': 'EN'},
['title', 'title_en'],
{'title': 'DE', 'title_de': 'DE', 'title_en': 'EN'},
),
(
{'title_de': 'DE', 'title_en': 'EN'},
['title_de', 'title_en'],
{'title': 'old', 'title_de': 'DE', 'title_en': 'EN'},
),
(
{'title_de': 'DE', 'title_en': 'EN'},
['title', 'title_de', 'title_en'],
{'title': 'DE', 'title_de': 'DE', 'title_en': 'EN'},
),
]
)
def test_save_original_translation_field(self, field_values, update_fields, expected_db_values):
obj = models.TestModel.objects.create(title='old')

for field, value in field_values.items():
setattr(obj, field, value)

obj.save(update_fields=update_fields)
assert_db_record(obj, **expected_db_values)

@parameterized.expand(
[
({'title': 'EN'}, ['title'], {'title': 'EN', 'title_de': None, 'title_en': 'EN'}),
({'title_en': 'EN'}, ['title'], {'title': 'EN', 'title_de': None, 'title_en': 'EN'}),
({'title': 'EN'}, ['title_en'], {'title': 'old', 'title_de': None, 'title_en': 'EN'}),
(
{'title_en': 'EN'},
['title_en'],
{'title': 'old', 'title_de': None, 'title_en': 'EN'},
),
(
{'title': 'EN', 'title_de': 'DE'},
['title', 'title_de'],
{'title': 'EN', 'title_de': 'DE', 'title_en': 'EN'},
),
(
{'title_de': 'DE', 'title_en': 'EN'},
['title_de', 'title_en'],
{'title': 'old', 'title_de': 'DE', 'title_en': 'EN'},
),
(
{'title_de': 'DE', 'title_en': 'EN'},
['title', 'title_de', 'title_en'],
{'title': 'EN', 'title_de': 'DE', 'title_en': 'EN'},
),
]
)
def test_save_active_translation_field(self, field_values, update_fields, expected_db_values):
with override('en'):
obj = models.TestModel.objects.create(title='old')

for field, value in field_values.items():
setattr(obj, field, value)

obj.save(update_fields=update_fields)
assert_db_record(obj, **expected_db_values)

def test_save_non_original_translation_field(self):
obj = models.TestModel.objects.create(title='old')

obj.title_en = 'en value'
obj.save(update_fields=['title'])
assert_db_record(obj, title='old', title_de='old', title_en=None)

obj.save(update_fields=['title_en'])
assert_db_record(obj, title='old', title_de='old', title_en='en value')

def test_update_or_create_existing(self):
"""
Test that update_or_create works as expected
Expand All @@ -477,6 +576,43 @@ def test_update_or_create_existing(self):
assert instance.title == 'NEW DE TITLE'
assert instance.title_en == 'old en'
assert instance.title_de == 'NEW DE TITLE'
assert_db_record(
instance,
title='NEW DE TITLE',
title_en='old en',
title_de='NEW DE TITLE',
)

instance, created = models.TestModel.objects.update_or_create(
pk=obj.pk, defaults={'title_de': 'NEW DE TITLE 2'}
)

assert created is False
assert instance.title == 'NEW DE TITLE 2'
assert instance.title_en == 'old en'
assert instance.title_de == 'NEW DE TITLE 2'
assert_db_record(
instance,
# title='NEW DE TITLE', # TODO: django < 4.2 doesn't pass `"title"` into `.save(update_fields)`
title_en='old en',
title_de='NEW DE TITLE 2',
)

with override('en'):
instance, created = models.TestModel.objects.update_or_create(
pk=obj.pk, defaults={'title': 'NEW EN TITLE'}
)

assert created is False
assert instance.title == 'NEW EN TITLE'
assert instance.title_en == 'NEW EN TITLE'
assert instance.title_de == 'NEW DE TITLE 2'
assert_db_record(
instance,
title='NEW EN TITLE',
title_en='NEW EN TITLE',
title_de='NEW DE TITLE 2',
)

def test_update_or_create_new(self):
instance, created = models.TestModel.objects.update_or_create(
Expand All @@ -488,6 +624,12 @@ def test_update_or_create_new(self):
assert instance.title == 'old de'
assert instance.title_en == 'old en'
assert instance.title_de == 'old de'
assert_db_record(
instance,
title='old de',
title_en='old en',
title_de='old de',
)


class ModeltranslationTransactionTest(ModeltranslationTransactionTestBase):
Expand Down
Loading

0 comments on commit d86c6de

Please sign in to comment.