From d86c6defc864b3493955a41f95a85fc5aa8d5649 Mon Sep 17 00:00:00 2001 From: Andrei Shabanski Date: Sun, 16 Jul 2023 11:04:27 +0300 Subject: [PATCH] feat: extend update_fields with translation fields in Model.save() (#687) --- .github/workflows/test.yml | 2 +- modeltranslation/manager.py | 38 +++++--- modeltranslation/tests/settings.py | 2 +- modeltranslation/tests/tests.py | 142 +++++++++++++++++++++++++++++ poetry.lock | 53 ++++------- pyproject.toml | 1 + 6 files changed, 184 insertions(+), 54 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 185b2f11..91402195 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -88,7 +88,7 @@ jobs: if [[ $DB == postgres ]]; then pip install -q psycopg2-binary fi - pip install typing-extensions coverage pytest pytest-django pytest-cov $(./get-django-version.py ${{ matrix.django }}) + pip install typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }}) - name: Run tests run: | pytest --cov-report term diff --git a/modeltranslation/manager.py b/modeltranslation/manager.py index d2faf4f7..d7452e88 100644 --- a/modeltranslation/manager.py +++ b/modeltranslation/manager.py @@ -6,11 +6,13 @@ """ import itertools from functools import reduce +from typing import List, Tuple, Type, Any, Optional from django import VERSION from django.contrib.admin.utils import get_model_from_relation from django.core.exceptions import FieldDoesNotExist from django.db import models +from django.db.models import Field, Model from django.db.models.expressions import Col from django.db.models.lookups import Lookup from django.db.models.query import QuerySet, ValuesIterable @@ -243,21 +245,6 @@ def select_related(self, *fields, **kwargs): new_args.append(rewrite_lookup_key(self.model, key)) return super().select_related(*new_args, **kwargs) - def update_or_create(self, defaults=None, **kwargs): - """ - Updates or creates a database record with the specified kwargs. The method first - rewrites the keys in the defaults dictionary using a custom function named - `rewrite_lookup_key`. This ensures that the keys are valid for the current model - before calling the inherited update_or_create() method from the super class. - Returns the updated or created model instance. - """ - if defaults is not None: - rewritten_defaults = {} - for key, value in defaults.items(): - rewritten_defaults[rewrite_lookup_key(self.model, key)] = value - defaults = rewritten_defaults - return super().update_or_create(defaults=defaults, **kwargs) - # This method was not present in django-linguo def _rewrite_col(self, col): """Django >= 1.7 column name rewriting""" @@ -386,6 +373,27 @@ def update(self, **kwargs): update.alters_data = True + def _update(self, values: List[Tuple[Field, Optional[Type[Model]], Any]]): + """ + This method is called in .save() method to update an existing record. + Here we force to update translation fields as well if the original + field only is passed in `save()` in argument `update_fields`. + """ + # TODO: Should the original field (field without lang code suffix) be updated + # when only the default translation field (`field_`) is passed in `update_fields`? + # Currently, we don't synchronize values of the original and default translation fields in that case. + field_names_to_update = {field.name for field, *_ in values} + + translation_values = [] + for field, model, value in values: + translation_field_name = rewrite_lookup_key(self.model, field.name) + if translation_field_name not in field_names_to_update: + translatable_field = self.model._meta.get_field(translation_field_name) + translation_values.append((translatable_field, model, value)) + + values += translation_values + return super()._update(values) + # This method was not present in django-linguo @property def _populate_mode(self): diff --git a/modeltranslation/tests/settings.py b/modeltranslation/tests/settings.py index 8d075d18..f7f693db 100644 --- a/modeltranslation/tests/settings.py +++ b/modeltranslation/tests/settings.py @@ -29,7 +29,7 @@ def _get_database_config(): { 'ENGINE': 'django.db.backends.postgresql', 'USER': os.getenv('POSTGRES_USER', 'postgres'), - 'PASSWORD': os.getenv('POSTGRES_DB', 'postgres'), + 'PASSWORD': os.getenv('POSTGRES_PASSWORD', 'postgres'), 'NAME': os.getenv('POSTGRES_DB', 'modeltranslation'), 'HOST': host, } diff --git a/modeltranslation/tests/tests.py b/modeltranslation/tests/tests.py index 8d43966d..3544b9b9 100644 --- a/modeltranslation/tests/tests.py +++ b/modeltranslation/tests/tests.py @@ -22,6 +22,7 @@ from django.test import TestCase, TransactionTestCase from django.test.utils import override_settings from django.utils.translation import get_language, override, trans_real +from parameterized import parameterized from modeltranslation import admin from modeltranslation import settings as mt_settings @@ -79,6 +80,20 @@ def get_field_names(model): return names +def assert_db_record(instance, **expected_fields): + """ + Compares field values stored in the db. + """ + actual = ( + type(instance) + .objects.rewrite(False) + .filter(pk=instance.pk) + .values(*expected_fields.keys()) + .first() + ) + assert actual == expected_fields + + class ModeltranslationTransactionTestBase(TransactionTestCase): cache = django_apps @@ -358,6 +373,7 @@ def test_set_translation(self): assert n.title == title_de assert n.title_en == title_en assert n.title_de == title_de + assert_db_record(n, title=title_de, title_de=title_de, title_en=title_en) # Queries are also language-aware: assert 1 == models.TestModel.objects.filter(title=title_de).count() @@ -463,6 +479,89 @@ def test_constructor(self): ) self._test_constructor(keywords) + @parameterized.expand( + [ + ({'title': 'DE'}, ['title'], {'title': 'DE', 'title_de': 'DE', 'title_en': None}), + ({'title_de': 'DE'}, ['title'], {'title': 'DE', 'title_de': 'DE', 'title_en': None}), + ({'title': 'DE'}, ['title_de'], {'title': 'old', 'title_de': 'DE', 'title_en': None}), + ( + {'title_de': 'DE'}, + ['title_de'], + {'title': 'old', 'title_de': 'DE', 'title_en': None}, + ), + ( + {'title': 'DE', 'title_en': 'EN'}, + ['title', 'title_en'], + {'title': 'DE', 'title_de': 'DE', 'title_en': 'EN'}, + ), + ( + {'title_de': 'DE', 'title_en': 'EN'}, + ['title_de', 'title_en'], + {'title': 'old', 'title_de': 'DE', 'title_en': 'EN'}, + ), + ( + {'title_de': 'DE', 'title_en': 'EN'}, + ['title', 'title_de', 'title_en'], + {'title': 'DE', 'title_de': 'DE', 'title_en': 'EN'}, + ), + ] + ) + def test_save_original_translation_field(self, field_values, update_fields, expected_db_values): + obj = models.TestModel.objects.create(title='old') + + for field, value in field_values.items(): + setattr(obj, field, value) + + obj.save(update_fields=update_fields) + assert_db_record(obj, **expected_db_values) + + @parameterized.expand( + [ + ({'title': 'EN'}, ['title'], {'title': 'EN', 'title_de': None, 'title_en': 'EN'}), + ({'title_en': 'EN'}, ['title'], {'title': 'EN', 'title_de': None, 'title_en': 'EN'}), + ({'title': 'EN'}, ['title_en'], {'title': 'old', 'title_de': None, 'title_en': 'EN'}), + ( + {'title_en': 'EN'}, + ['title_en'], + {'title': 'old', 'title_de': None, 'title_en': 'EN'}, + ), + ( + {'title': 'EN', 'title_de': 'DE'}, + ['title', 'title_de'], + {'title': 'EN', 'title_de': 'DE', 'title_en': 'EN'}, + ), + ( + {'title_de': 'DE', 'title_en': 'EN'}, + ['title_de', 'title_en'], + {'title': 'old', 'title_de': 'DE', 'title_en': 'EN'}, + ), + ( + {'title_de': 'DE', 'title_en': 'EN'}, + ['title', 'title_de', 'title_en'], + {'title': 'EN', 'title_de': 'DE', 'title_en': 'EN'}, + ), + ] + ) + def test_save_active_translation_field(self, field_values, update_fields, expected_db_values): + with override('en'): + obj = models.TestModel.objects.create(title='old') + + for field, value in field_values.items(): + setattr(obj, field, value) + + obj.save(update_fields=update_fields) + assert_db_record(obj, **expected_db_values) + + def test_save_non_original_translation_field(self): + obj = models.TestModel.objects.create(title='old') + + obj.title_en = 'en value' + obj.save(update_fields=['title']) + assert_db_record(obj, title='old', title_de='old', title_en=None) + + obj.save(update_fields=['title_en']) + assert_db_record(obj, title='old', title_de='old', title_en='en value') + def test_update_or_create_existing(self): """ Test that update_or_create works as expected @@ -477,6 +576,43 @@ def test_update_or_create_existing(self): assert instance.title == 'NEW DE TITLE' assert instance.title_en == 'old en' assert instance.title_de == 'NEW DE TITLE' + assert_db_record( + instance, + title='NEW DE TITLE', + title_en='old en', + title_de='NEW DE TITLE', + ) + + instance, created = models.TestModel.objects.update_or_create( + pk=obj.pk, defaults={'title_de': 'NEW DE TITLE 2'} + ) + + assert created is False + assert instance.title == 'NEW DE TITLE 2' + assert instance.title_en == 'old en' + assert instance.title_de == 'NEW DE TITLE 2' + assert_db_record( + instance, + # title='NEW DE TITLE', # TODO: django < 4.2 doesn't pass `"title"` into `.save(update_fields)` + title_en='old en', + title_de='NEW DE TITLE 2', + ) + + with override('en'): + instance, created = models.TestModel.objects.update_or_create( + pk=obj.pk, defaults={'title': 'NEW EN TITLE'} + ) + + assert created is False + assert instance.title == 'NEW EN TITLE' + assert instance.title_en == 'NEW EN TITLE' + assert instance.title_de == 'NEW DE TITLE 2' + assert_db_record( + instance, + title='NEW EN TITLE', + title_en='NEW EN TITLE', + title_de='NEW DE TITLE 2', + ) def test_update_or_create_new(self): instance, created = models.TestModel.objects.update_or_create( @@ -488,6 +624,12 @@ def test_update_or_create_new(self): assert instance.title == 'old de' assert instance.title_en == 'old en' assert instance.title_de == 'old de' + assert_db_record( + instance, + title='old de', + title_en='old en', + title_de='old de', + ) class ModeltranslationTransactionTest(ModeltranslationTransactionTestBase): diff --git a/poetry.lock b/poetry.lock index 32cda3ae..f49dcd75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand. [[package]] name = "asgiref" version = "3.6.0" description = "ASGI specs, helper code, and adapters" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -19,7 +18,6 @@ tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] name = "attrs" version = "22.2.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -38,7 +36,6 @@ tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy name = "backports-zoneinfo" version = "0.2.1" description = "Backport of the standard library zoneinfo module" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -67,7 +64,6 @@ tzdata = ["tzdata"] name = "black" version = "23.1.0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -117,7 +113,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -132,7 +127,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -144,7 +138,6 @@ files = [ name = "coverage" version = "7.1.0" description = "Code coverage measurement for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -211,7 +204,6 @@ toml = ["tomli"] name = "django" version = "4.2.1" description = "A high-level Python web framework that encourages rapid development and clean, pragmatic design." -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -233,7 +225,6 @@ bcrypt = ["bcrypt"] name = "django-types" version = "0.16.0" description = "Type stubs for Django" -category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -245,7 +236,6 @@ files = [ name = "exceptiongroup" version = "1.1.0" description = "Backport of PEP 654 (exception groups)" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -260,7 +250,6 @@ test = ["pytest (>=6)"] name = "fancycompleter" version = "0.9.1" description = "colorful TAB completion for Python prompt" -category = "dev" optional = false python-versions = "*" files = [ @@ -276,7 +265,6 @@ pyrepl = ">=0.8.2" name = "flake8" version = "5.0.4" description = "the modular source code checker: pep8 pyflakes and co" -category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -293,7 +281,6 @@ pyflakes = ">=2.5.0,<2.6.0" name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -305,7 +292,6 @@ files = [ name = "mccabe" version = "0.7.0" description = "McCabe checker, plugin for flake8" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -317,7 +303,6 @@ files = [ name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -329,7 +314,6 @@ files = [ name = "packaging" version = "23.0" description = "Core utilities for Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -337,11 +321,24 @@ files = [ {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, ] +[[package]] +name = "parameterized" +version = "0.9.0" +description = "Parameterized testing with any Python test framework" +optional = false +python-versions = ">=3.7" +files = [ + {file = "parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b"}, + {file = "parameterized-0.9.0.tar.gz", hash = "sha256:7fc905272cefa4f364c1a3429cbbe9c0f98b793988efb5bf90aac80f08db09b1"}, +] + +[package.extras] +dev = ["jinja2"] + [[package]] name = "pathspec" version = "0.11.0" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -353,7 +350,6 @@ files = [ name = "pdbpp" version = "0.10.3" description = "pdb++, a drop-in replacement for pdb" -category = "dev" optional = false python-versions = "*" files = [ @@ -374,7 +370,6 @@ testing = ["funcsigs", "pytest"] name = "platformdirs" version = "3.0.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -390,7 +385,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -406,7 +400,6 @@ testing = ["pytest", "pytest-benchmark"] name = "pycodestyle" version = "2.9.1" description = "Python style guide checker" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -418,7 +411,6 @@ files = [ name = "pyflakes" version = "2.5.0" description = "passive checker of Python programs" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -430,7 +422,6 @@ files = [ name = "pygments" version = "2.14.0" description = "Pygments is a syntax highlighting package written in Python." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -445,7 +436,6 @@ plugins = ["importlib-metadata"] name = "pyreadline" version = "2.1" description = "A python implmementation of GNU readline." -category = "dev" optional = false python-versions = "*" files = [ @@ -456,7 +446,6 @@ files = [ name = "pyrepl" version = "0.9.0" description = "A library for building flexible command line interfaces" -category = "dev" optional = false python-versions = "*" files = [ @@ -467,7 +456,6 @@ files = [ name = "pytest" version = "7.2.1" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -491,7 +479,6 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "pytest-cov" version = "4.0.0" description = "Pytest plugin for measuring coverage." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -510,7 +497,6 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "pytest-django" version = "4.5.2" description = "A Django plugin for pytest." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -529,7 +515,6 @@ testing = ["Django", "django-configurations (>=2.0)"] name = "pytest-sugar" version = "0.9.6" description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)." -category = "dev" optional = false python-versions = "*" files = [ @@ -546,7 +531,6 @@ termcolor = ">=1.1.0" name = "sqlparse" version = "0.4.4" description = "A non-validating SQL parser." -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -563,7 +547,6 @@ test = ["pytest", "pytest-cov"] name = "termcolor" version = "2.2.0" description = "ANSI color formatting for output in terminal" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -578,7 +561,6 @@ tests = ["pytest", "pytest-cov"] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -590,7 +572,6 @@ files = [ name = "typing-extensions" version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -602,7 +583,6 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" -category = "main" optional = false python-versions = ">=2" files = [ @@ -614,7 +594,6 @@ files = [ name = "wmctrl" version = "0.4" description = "A tool to programmatically control windows inside X" -category = "dev" optional = false python-versions = "*" files = [ @@ -624,4 +603,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.8,<4" -content-hash = "9e9b9964450d62b14b6f80348b2ec2f8d61347fc81c112489636280912bd7737" +content-hash = "5d767bc3b0567ffe33222e6d66a9204a5898dc62cf1dc33c80d2d0a40f1ef5b4" diff --git a/pyproject.toml b/pyproject.toml index 7c7dc230..e45684b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ typing-extensions = "*" pdbpp = "*" flake8 = "*" black = "*" +parameterized = "*" pytest-cov = "*" pytest = "*" pytest-sugar = "*"