diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 116c98f432cee..7633445d3c974 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -105,6 +105,7 @@ class BaseDatabaseSchemaEditor: sql_check_constraint = "CHECK (%(check)s)" sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_constraint = "CONSTRAINT %(name)s %(constraint)s" + sql_pk_constraint = "PRIMARY KEY (%(columns)s)" sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_delete_check = sql_delete_constraint @@ -268,6 +269,13 @@ def table_sql(self, model): constraint.constraint_sql(model, self) for constraint in model._meta.constraints ] + + # If the model defines Meta.primary_key, add the primary key constraint + # to the table definition. + # It's expected primary_key=True isn't set on any fields (see E042). + if model._meta.primary_key: + constraints.append(self._pk_constraint_sql(model._meta.primary_key)) + sql = self.sql_create_table % { "table": self.quote_name(model._meta.db_table), "definition": ", ".join( @@ -1967,6 +1975,11 @@ def _constraint_names( result.append(name) return result + def _pk_constraint_sql(self, fields): + return self.sql_pk_constraint % { + "columns": ", ".join(self.quote_name(field) for field in fields) + } + def _delete_primary_key(self, model, strict=False): constraint_names = self._constraint_names(model, primary_key=True) if strict and len(constraint_names) != 1: diff --git a/django/db/models/base.py b/django/db/models/base.py index e68baf4e57632..bf693ffc64754 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -30,6 +30,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.deletion import CASCADE, Collector from django.db.models.expressions import DatabaseDefault +from django.db.models.fields.composite import is_pk_set from django.db.models.fields.related import ( ForeignObjectRel, OneToOneField, @@ -1080,7 +1081,7 @@ def _save_table( if pk_val is None: pk_val = meta.pk.get_pk_value_on_save(self) setattr(self, meta.pk.attname, pk_val) - pk_set = pk_val is not None + pk_set = is_pk_set(pk_val) if not pk_set and (force_update or update_fields): raise ValueError("Cannot force an update in save() with no primary key.") updated = False @@ -1686,6 +1687,7 @@ def check(cls, **kwargs): *cls._check_constraints(databases), *cls._check_default_pk(), *cls._check_db_table_comment(databases), + *cls._check_composite_pk(), ] return errors @@ -1694,6 +1696,9 @@ def check(cls, **kwargs): def _check_default_pk(cls): if ( not cls._meta.abstract + # If the model defines Meta.primary_key, the check should be skipped, + # since there's no default primary key. + and not cls._meta.primary_key and cls._meta.pk.auto_created and # Inherited PKs are checked in parents models. @@ -1722,6 +1727,24 @@ def _check_default_pk(cls): ] return [] + @classmethod + def _check_composite_pk(cls): + errors = [] + + if cls._meta.primary_key and any( + field for field in cls._meta.fields if field.primary_key + ): + errors.append( + checks.Error( + "primary_key=True must not be set if Meta.primary_key " + "is defined.", + obj=cls, + id="models.E042", + ) + ) + + return errors + @classmethod def _check_db_table_comment(cls, databases): if not cls._meta.db_table_comment: @@ -1842,6 +1865,11 @@ def _check_m2m_through_same_relationship(cls): @classmethod def _check_id_field(cls): """Check if `id` field is a primary key.""" + # If the model defines Meta.primary_key, the check should be skipped, + # since primary_key=True can't be set on any fields (including `id`). + if cls._meta.primary_key: + return [] + fields = [ f for f in cls._meta.local_fields if f.name == "id" and f != cls._meta.pk ] diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 796c4d23c458b..cd62e53c68ab9 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -2794,6 +2794,11 @@ def check(self, **kwargs): ] def _check_primary_key(self): + # If the model defines Meta.primary_key, primary_key=True can't be set on + # any field (including AutoFields). + if self.model._meta.primary_key: + return [] + if not self.primary_key: return [ checks.Error( @@ -2808,7 +2813,7 @@ def _check_primary_key(self): def deconstruct(self): name, path, args, kwargs = super().deconstruct() del kwargs["blank"] - kwargs["primary_key"] = True + kwargs["primary_key"] = self.primary_key return name, path, args, kwargs def validate(self, value, model_instance): diff --git a/django/db/models/fields/composite.py b/django/db/models/fields/composite.py new file mode 100644 index 0000000000000..5941107376497 --- /dev/null +++ b/django/db/models/fields/composite.py @@ -0,0 +1,176 @@ +from collections.abc import Iterable + +from django.core.exceptions import FieldDoesNotExist +from django.db.models import Field +from django.db.models.expressions import Col, Expression +from django.db.models.lookups import Exact, In +from django.db.models.signals import class_prepared +from django.utils.functional import cached_property + + +class TupleExact(Exact): + def get_prep_lookup(self): + if not isinstance(self.lhs, Cols): + raise ValueError( + "The left-hand side of the 'exact' lookup must be an instance of Cols" + ) + if not isinstance(self.rhs, Iterable): + raise ValueError( + "The right-hand side of the 'exact' lookup must be an iterable" + ) + if len(list(self.lhs)) != len(list(self.rhs)): + raise ValueError( + "The left-hand side and right-hand side of the 'exact' lookup must " + "have the same number of elements" + ) + + return super().get_prep_lookup() + + def as_sql(self, compiler, connection): + from django.db.models.sql.where import AND, WhereNode + + cols = self.lhs.get_source_expressions() + exprs = [Exact(col, val) for col, val in zip(cols, self.rhs)] + + return compiler.compile(WhereNode(exprs, connector=AND)) + + +class TupleIn(In): + def get_prep_lookup(self): + if not isinstance(self.lhs, Cols): + raise ValueError( + "The left-hand side of the 'in' lookup must be an instance of Cols" + ) + if not isinstance(self.rhs, Iterable): + raise ValueError( + "The right-hand side of the 'in' lookup must be an iterable" + ) + if not all(isinstance(vals, Iterable) for vals in self.rhs): + raise ValueError( + "The right-hand side of the 'in' lookup must be an iterable of " + "iterables" + ) + lhs_len = len(tuple(self.lhs)) + if not all(lhs_len == len(tuple(vals)) for vals in self.rhs): + raise ValueError( + "The left-hand side and right-hand side of the 'in' lookup must " + "have the same number of elements" + ) + + return super().get_prep_lookup() + + def as_sql(self, compiler, connection): + from django.db.models.sql.where import AND, OR, WhereNode + + exprs = [] + cols = self.lhs.get_source_expressions() + + for vals in self.rhs: + exprs.append( + WhereNode( + [Exact(col, val) for col, val in zip(cols, vals)], connector=AND + ) + ) + + return compiler.compile(WhereNode(exprs, connector=OR)) + + +class Cols(Expression): + def __init__(self, alias, targets, output_field): + super().__init__(output_field=output_field) + self.alias, self.targets = alias, targets + + def get_source_expressions(self): + return [Col(self.alias, target) for target in self.targets] + + def set_source_expressions(self, exprs): + assert all(isinstance(expr, Col) for expr in exprs) + assert len(exprs) == len(self.targets) + + def as_sql(self, compiler, connection): + sqls = [] + cols = self.get_source_expressions() + + for col in cols: + sql, _ = col.as_sql(compiler, connection) + sqls.append(sql) + + return ", ".join(sqls), [] + + def __iter__(self): + return iter(self.get_source_expressions()) + + +def is_pk_not_set(pk): + return pk is None or (isinstance(pk, tuple) and any(f is None for f in pk)) + + +def is_pk_set(pk): + return not is_pk_not_set(pk) + + +class CompositeAttribute: + def __init__(self, field): + self.field = field + + def __get__(self, instance, cls=None): + return tuple( + getattr(instance, field_name) for field_name in self.field.field_names + ) + + def __set__(self, instance, values): + if values is None: + values = (None,) * len(self.field.field_names) + + for field_name, value in zip(self.field.field_names, values): + setattr(instance, field_name, value) + + +class CompositeField(Field): + descriptor_class = CompositeAttribute + + def __init__(self, *args, **kwargs): + kwargs["db_column"] = None + kwargs["editable"] = False + super().__init__(**kwargs) + self.field_names = args + self.fields = None + + def contribute_to_class(self, cls, name, **_): + super().contribute_to_class(cls, name, private_only=True) + cls._meta.pk = self + setattr(cls, self.attname, self.descriptor_class(self)) + + def get_attname_column(self): + return self.get_attname(), self.db_column + + def __iter__(self): + return iter(self.fields) + + @cached_property + def cached_col(self): + return Cols(self.model._meta.db_table, self.fields, self) + + def get_col(self, alias, output_field=None): + return self.cached_col + + def get_lookup(self, lookup_name): + if lookup_name == "exact": + return TupleExact + elif lookup_name == "in": + return TupleIn + + return super().get_lookup(lookup_name) + + +def resolve_fields(*args, **kwargs): + meta = kwargs["sender"]._meta + for field in meta.private_fields: + if isinstance(field, CompositeField) and field.fields is None: + try: + field.fields = tuple(meta.get_field(name) for name in field.field_names) + except FieldDoesNotExist: + continue + + +class_prepared.connect(resolve_fields) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 3e4bfe34c1b17..8855a7b0d2671 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -615,6 +615,16 @@ def _check_unique_target(self): if not self.foreign_related_fields: return [] + # If a model defines Meta.primary_key and a foreign key refers to it, + # the check should be skipped (since primary keys are unique). + pk = self.remote_field.model._meta.primary_key + if pk: + pk = set(pk) + if pk == {f.attname for f in self.foreign_related_fields}: + return [] + elif pk == {f.name for f in self.foreign_related_fields}: + return [] + has_unique_constraint = any( rel_field.unique for rel_field in self.foreign_related_fields ) diff --git a/django/db/models/options.py b/django/db/models/options.py index ed7be7dd7a84e..5998b38bbedda 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -7,6 +7,7 @@ from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.db import connections from django.db.models import AutoField, Manager, OrderWrt, UniqueConstraint +from django.db.models.fields.composite import CompositeField from django.db.models.query_utils import PathInfo from django.utils.datastructures import ImmutableList, OrderedSet from django.utils.functional import cached_property @@ -24,6 +25,7 @@ ) DEFAULT_NAMES = ( + "primary_key", "verbose_name", "verbose_name_plural", "db_table", @@ -106,6 +108,7 @@ def __init__(self, meta, app_label=None): self.base_manager_name = None self.default_manager_name = None self.model_name = None + self.primary_key = None self.verbose_name = None self.verbose_name_plural = None self.db_table = "" @@ -296,7 +299,10 @@ def _prepare(self, model): self.order_with_respect_to = None if self.pk is None: - if self.parents: + if self.primary_key: + pk = CompositeField(*self.primary_key) + model.add_to_class("primary_key", pk) + elif self.parents: # Promote the first parent link in lieu of adding yet another # field. field = next(iter(self.parents.values())) diff --git a/django/db/models/query.py b/django/db/models/query.py index cb5c63c0d17c2..8fed5f2cee6ba 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -24,6 +24,7 @@ from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.deletion import Collector from django.db.models.expressions import Case, F, Value, When +from django.db.models.fields.composite import is_pk_not_set from django.db.models.functions import Cast, Trunc from django.db.models.query_utils import FilteredRelation, Q from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE @@ -813,7 +814,9 @@ def bulk_create( objs = list(objs) self._prepare_for_bulk_create(objs) with transaction.atomic(using=self.db, savepoint=False): - objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) + objs_with_pk, objs_without_pk = partition( + lambda o: is_pk_not_set(o.pk), objs + ) if objs_with_pk: returned_columns = self._batched_insert( objs_with_pk, diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 676625df6fe7a..bf154823e7586 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -8,6 +8,7 @@ from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value +from django.db.models.fields.composite import CompositeField from django.db.models.functions import Cast, Random from django.db.models.lookups import Lookup from django.db.models.query_utils import select_related_descend @@ -979,6 +980,12 @@ def get_default_columns( # be used by local fields. seen_models = {None: start_alias} + select_mask_fields = set() + for field in select_mask: + select_mask_fields.update( + field.fields if isinstance(field, CompositeField) else [field] + ) + for field in opts.concrete_fields: model = field.model._meta.concrete_model # A proxy model will have a different model and concrete_model. We @@ -998,7 +1005,7 @@ def get_default_columns( # parent model data is already present in the SELECT clause, # and we want to avoid reloading the same data again. continue - if select_mask and field not in select_mask: + if select_mask and field not in select_mask_fields: continue alias = self.query.join_parent_model(opts, model, start_alias, seen_models) column = field.get_col(alias) @@ -2051,7 +2058,7 @@ def pre_sql_setup(self): must_pre_select = ( count > 1 and not self.connection.features.update_can_self_select - ) + ) or meta.primary_key # Now we adjust the current query: reset the where clause and get rid # of all the tables we don't need (since they're in the sub-select). @@ -2063,7 +2070,10 @@ def pre_sql_setup(self): idents = [] related_ids = collections.defaultdict(list) for rows in query.get_compiler(self.using).execute_sql(MULTI): - idents.extend(r[0] for r in rows) + if meta.primary_key: + idents.extend(rows) + else: + idents.extend(r[0] for r in rows) for parent, index in related_ids_index: related_ids[parent].extend(r[index] for r in rows) self.query.add_filter("pk__in", idents) diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index efc8cf666a235..13be27aa4a5ac 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -416,6 +416,8 @@ Models * **models.W040**: ```` does not support indexes with non-key columns. * **models.E041**: ``constraints`` refers to the joined field ````. +* **models.E042**: primary_key=True must not be set if Meta.primary_key + is defined. * **models.W042**: Auto-created primary key used when not defining a primary key type, by default ``django.db.models.AutoField``. * **models.W043**: ```` does not support indexes on expressions. diff --git a/tests/composite_pk/__init__.py b/tests/composite_pk/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/composite_pk/models/__init__.py b/tests/composite_pk/models/__init__.py new file mode 100644 index 0000000000000..beb137af7639c --- /dev/null +++ b/tests/composite_pk/models/__init__.py @@ -0,0 +1,7 @@ +from .tenant import Comment, Tenant, User + +__all__ = [ + "Tenant", + "User", + "Comment", +] diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py new file mode 100644 index 0000000000000..191f8ccd60d68 --- /dev/null +++ b/tests/composite_pk/models/tenant.py @@ -0,0 +1,35 @@ +from django.db import connection, models + +# SQLite doesn't support non-primary auto fields. +ID = ( + models.SmallIntegerField if connection.vendor == "sqlite" else models.SmallAutoField +) + + +class Tenant(models.Model): + pass + + +class User(models.Model): + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + id = ID(unique=True) + email = models.EmailField() + + class Meta: + primary_key = ("tenant_id", "id") + + +class Comment(models.Model): + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) + id = ID(unique=True) + user_id = models.SmallIntegerField() + user = models.ForeignObject( + User, + on_delete=models.CASCADE, + from_fields=("tenant_id", "user_id"), + to_fields=("tenant_id", "id"), + related_name="+", + ) + + class Meta: + primary_key = ("tenant_id", "id") diff --git a/tests/composite_pk/test_create.py b/tests/composite_pk/test_create.py new file mode 100644 index 0000000000000..a8987adce69fe --- /dev/null +++ b/tests/composite_pk/test_create.py @@ -0,0 +1,214 @@ +import unittest + +from django.db import connection +from django.test import TestCase +from django.test.utils import CaptureQueriesContext + +from .models import Tenant, User + + +class CompositePKCreateTests(TestCase): + """ + Test the .create(), .save(), .bulk_create(), .get_or_create(), .update_or_create() + methods of composite_pk models. + """ + + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") + def test_create_user_in_sqlite(self): + test_cases = [ + {"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"}, + {"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"}, + {"pk": (self.tenant.id, 7424), "email": "user7424@example.com"}, + ] + + for fields in test_cases: + user = User(**fields) + self.assertIsNotNone(user.id) + self.assertIsNotNone(user.email) + + with self.subTest(fields=fields): + with CaptureQueriesContext(connection) as context: + obj = User.objects.create(**fields) + + self.assertEqual(obj.tenant_id, self.tenant.id) + self.assertEqual(obj.id, user.id) + self.assertEqual(obj.pk, (self.tenant.id, user.id)) + self.assertEqual(obj.email, user.email) + self.assertEqual(len(context.captured_queries), 1) + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'INSERT INTO "{u}" ("tenant_id", "id", "email") ' + f"VALUES ({self.tenant.id}, {user.id}, '{user.email}')", + ) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") + def test_create_user_in_postgresql(self): + test_cases = [ + {"tenant": self.tenant, "id": 5231, "email": "user5231@example.com"}, + {"tenant_id": self.tenant.id, "id": 6123, "email": "user6123@example.com"}, + {"pk": (self.tenant.id, 3513), "email": "user3513@example.com"}, + ] + + for fields in test_cases: + user = User(**fields) + self.assertIsNotNone(user.id) + self.assertIsNotNone(user.email) + + with self.subTest(fields=fields): + with CaptureQueriesContext(connection) as context: + obj = User.objects.create(**fields) + + self.assertEqual(obj.tenant_id, self.tenant.id) + self.assertEqual(obj.id, user.id) + self.assertEqual(obj.pk, (self.tenant.id, user.id)) + self.assertEqual(obj.email, user.email) + self.assertEqual(len(context.captured_queries), 1) + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'INSERT INTO "{u}" ("tenant_id", "id", "email") ' + f"VALUES ({self.tenant.id}, {user.id}, '{user.email}') " + f'RETURNING "{u}"."id"', + ) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") + def test_create_user_with_autofield_in_postgresql(self): + test_cases = [ + {"tenant": self.tenant, "email": "user1111@example.com"}, + {"tenant_id": self.tenant.id, "email": "user2222@example.com"}, + ] + + for fields in test_cases: + user = User(**fields) + self.assertIsNotNone(user.email) + + with CaptureQueriesContext(connection) as context: + obj = User.objects.create(**fields) + + self.assertEqual(obj.tenant_id, self.tenant.id) + self.assertIsInstance(obj.id, int) + self.assertGreater(obj.id, 0) + self.assertEqual(obj.pk, (self.tenant.id, obj.id)) + self.assertEqual(obj.email, user.email) + self.assertEqual(len(context.captured_queries), 1) + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'INSERT INTO "{u}" ("tenant_id", "email") ' + f"VALUES ({self.tenant.id}, '{user.email}') " + f'RETURNING "{u}"."id"', + ) + + def test_save_user(self): + user = User(tenant=self.tenant, id=9241, email="user9241@example.com") + user.save() + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.tenant, self.tenant) + self.assertEqual(user.id, 9241) + self.assertEqual(user.pk, (self.tenant.id, 9241)) + self.assertEqual(user.email, "user9241@example.com") + + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") + def test_bulk_create_users_in_sqlite(self): + objs = [ + User(tenant=self.tenant, id=8291, email="user8291@example.com"), + User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"), + User(pk=(self.tenant.id, 8214), email="user8214@example.com"), + ] + + with CaptureQueriesContext(connection) as context: + result = User.objects.bulk_create(objs) + + obj_1, obj_2, obj_3 = result + self.assertEqual(obj_1.tenant_id, self.tenant.id) + self.assertEqual(obj_1.id, 8291) + self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id)) + self.assertEqual(obj_2.tenant_id, self.tenant.id) + self.assertEqual(obj_2.id, 4021) + self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id)) + self.assertEqual(obj_3.tenant_id, self.tenant.id) + self.assertEqual(obj_3.id, 8214) + self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id)) + self.assertEqual(len(context.captured_queries), 1) + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'INSERT INTO "{u}" ("tenant_id", "id", "email") ' + f"VALUES ({self.tenant.id}, 8291, 'user8291@example.com'), " + f"({self.tenant.id}, 4021, 'user4021@example.com'), " + f"({self.tenant.id}, 8214, 'user8214@example.com')", + ) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") + def test_bulk_create_users_in_postgresql(self): + objs = [ + User(tenant=self.tenant, id=8361, email="user8361@example.com"), + User(tenant_id=self.tenant.id, id=2819, email="user2819@example.com"), + User(pk=(self.tenant.id, 9136), email="user9136@example.com"), + User(tenant=self.tenant, email="user1111@example.com"), + User(tenant_id=self.tenant.id, email="user2222@example.com"), + ] + + with CaptureQueriesContext(connection) as context: + result = User.objects.bulk_create(objs) + + obj_1, obj_2, obj_3, obj_4, obj_5 = result + self.assertEqual(obj_1.tenant_id, self.tenant.id) + self.assertEqual(obj_1.id, 8361) + self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id)) + self.assertEqual(obj_2.tenant_id, self.tenant.id) + self.assertEqual(obj_2.id, 2819) + self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id)) + self.assertEqual(obj_3.tenant_id, self.tenant.id) + self.assertEqual(obj_3.id, 9136) + self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id)) + self.assertEqual(obj_4.tenant_id, self.tenant.id) + self.assertIsInstance(obj_4.id, int) + self.assertGreater(obj_4.id, 0) + self.assertEqual(obj_4.pk, (obj_4.tenant_id, obj_4.id)) + self.assertEqual(obj_5.tenant_id, self.tenant.id) + self.assertIsInstance(obj_5.id, int) + self.assertGreater(obj_5.id, obj_4.id) + self.assertEqual(obj_5.pk, (obj_5.tenant_id, obj_5.id)) + self.assertEqual(len(context.captured_queries), 2) + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'INSERT INTO "{u}" ("tenant_id", "id", "email") ' + f"VALUES ({self.tenant.id}, 8361, 'user8361@example.com'), " + f"({self.tenant.id}, 2819, 'user2819@example.com'), " + f"({self.tenant.id}, 9136, 'user9136@example.com') " + f'RETURNING "{u}"."id"', + ) + self.assertEqual( + context.captured_queries[1]["sql"], + f'INSERT INTO "{u}" ("tenant_id", "email") ' + f"VALUES ({self.tenant.id}, 'user1111@example.com'), " + f"({self.tenant.id}, 'user2222@example.com') " + f'RETURNING "{u}"."id"', + ) + + def test_get_or_create_user_by_pk(self): + user, created = User.objects.get_or_create(pk=(self.tenant.id, 8314)) + + self.assertTrue(created) + self.assertEqual(1, User.objects.all().count()) + self.assertEqual(user.pk, (self.tenant.id, 8314)) + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.id, 8314) + + def test_update_or_create_user_by_pk(self): + user, created = User.objects.update_or_create(pk=(self.tenant.id, 2931)) + + self.assertTrue(created) + self.assertEqual(1, User.objects.all().count()) + self.assertEqual(user.pk, (self.tenant.id, 2931)) + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.id, 2931) diff --git a/tests/composite_pk/test_delete.py b/tests/composite_pk/test_delete.py new file mode 100644 index 0000000000000..4cd38e9d68898 --- /dev/null +++ b/tests/composite_pk/test_delete.py @@ -0,0 +1,171 @@ +from django.db import connection +from django.test import TestCase +from django.test.utils import CaptureQueriesContext + +from .models import Comment, Tenant, User + + +class CompositePKDeleteTests(TestCase): + """ + Test the .delete(), .exists() methods of composite_pk models. + """ + + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create(tenant=cls.tenant, id=1) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + def test_delete_tenant_by_pk(self): + with CaptureQueriesContext(connection) as context: + result = Tenant.objects.filter(pk=self.tenant.pk).delete() + + self.assertEqual( + result, + ( + 3, + { + "composite_pk.Comment": 1, + "composite_pk.User": 1, + "composite_pk.Tenant": 1, + }, + ), + ) + + self.assertFalse(Tenant.objects.filter(id=self.tenant.id).exists()) + self.assertFalse(User.objects.filter(id=self.user.id).exists()) + self.assertFalse(Comment.objects.filter(id=self.comment.id).exists()) + + self.assertEqual(len(context.captured_queries), 6) + if connection.vendor in ("sqlite", "postgresql"): + t = Tenant._meta.db_table + u = User._meta.db_table + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{t}"."id" FROM "{t}" WHERE "{t}"."id" = {self.tenant.id}', + ) + self.assertEqual( + context.captured_queries[1]["sql"], + f'SELECT "{u}"."tenant_id", "{u}"."id" ' + f'FROM "{u}" ' + f'WHERE "{u}"."tenant_id" IN ({self.tenant.id})', + ) + self.assertEqual( + context.captured_queries[2]["sql"], + f'DELETE FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = {self.user.id})', + ) + self.assertEqual( + context.captured_queries[3]["sql"], + f'DELETE FROM "{c}" WHERE "{c}"."tenant_id" IN ({self.tenant.id})', + ) + self.assertEqual( + context.captured_queries[4]["sql"], + f'DELETE FROM "{u}" ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id})', + ) + self.assertEqual( + context.captured_queries[5]["sql"], + f'DELETE FROM "{t}" WHERE "{t}"."id" IN ({self.tenant.id})', + ) + + def test_delete_user_by_id(self): + with CaptureQueriesContext(connection) as context: + result = User.objects.only("pk").filter(id=self.user.id).delete() + + self.assertEqual( + result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1}) + ) + + self.assertTrue(Tenant.objects.filter(id=self.tenant.id).exists()) + self.assertFalse(User.objects.filter(id=self.user.id).exists()) + self.assertFalse(Comment.objects.filter(id=self.comment.id).exists()) + + self.assertEqual(len(context.captured_queries), 3) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{u}"."tenant_id", "{u}"."id" ' + f'FROM "{u}" ' + f'WHERE "{u}"."id" = {self.user.id}', + ) + self.assertEqual( + context.captured_queries[1]["sql"], + f'DELETE FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = {self.user.id})', + ) + self.assertEqual( + context.captured_queries[2]["sql"], + f'DELETE FROM "{u}" ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id})', + ) + + def test_delete_user_by_pk(self): + with CaptureQueriesContext(connection) as context: + result = User.objects.only("pk").filter(pk=self.user.pk).delete() + + self.assertEqual( + result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1}) + ) + + self.assertTrue(Tenant.objects.filter(id=self.tenant.id).exists()) + self.assertFalse(User.objects.filter(id=self.user.id).exists()) + self.assertFalse(Comment.objects.filter(id=self.comment.id).exists()) + + self.assertEqual(len(context.captured_queries), 3) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{u}"."tenant_id", "{u}"."id" ' + f'FROM "{u}" ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id})', + ) + self.assertEqual( + context.captured_queries[1]["sql"], + f'DELETE FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = {self.user.id})', + ) + self.assertEqual( + context.captured_queries[2]["sql"], + f'DELETE FROM "{u}" ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id})', + ) + + def test_delete_comments_by_user(self): + user = User.objects.create(pk=(self.tenant.id, 8259)) + comment_1 = Comment.objects.create(pk=(self.tenant.id, 1923), user=user) + comment_2 = Comment.objects.create(pk=(self.tenant.id, 8123), user=user) + comment_3 = Comment.objects.create(pk=(self.tenant.id, 8219), user=user) + + with CaptureQueriesContext(connection) as context: + result = Comment.objects.filter(user=user).delete() + + self.assertEqual(result, (3, {"composite_pk.Comment": 3})) + + self.assertFalse(Comment.objects.filter(id=comment_1.id).exists()) + self.assertFalse(Comment.objects.filter(id=comment_2.id).exists()) + self.assertFalse(Comment.objects.filter(id=comment_3.id).exists()) + + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'DELETE FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = 8259)', + ) diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py new file mode 100644 index 0000000000000..94a83a20935bd --- /dev/null +++ b/tests/composite_pk/test_filter.py @@ -0,0 +1,156 @@ +from django.db import connection +from django.test import TestCase +from django.test.utils import CaptureQueriesContext + +from .models import Comment, Tenant, User + + +class CompositePKFilterTests(TestCase): + """ + Test the .filter(), .order_by(), .first(), .last(), .latest(), .earliest(), + .exclude() methods of composite_pk models. + """ + + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create(tenant=cls.tenant, id=1) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + def test_filter_and_count_user_by_pk(self): + test_cases = [ + {"pk": self.user.pk}, + {"pk": (self.tenant.id, self.user.id)}, + ] + + for lookup in test_cases: + with self.subTest(lookup=lookup): + with CaptureQueriesContext(connection) as context: + result = User.objects.filter(**lookup).count() + + self.assertEqual(result, 1) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + 'SELECT COUNT(*) AS "__count" ' + f'FROM "{u}" ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id})', + ) + + def test_filter_comments_by_user_and_order_by_pk_asc(self): + user = User.objects.create(pk=(self.tenant.id, 2491)) + comment_1 = Comment.objects.create(pk=(self.tenant.id, 9471), user=user) + comment_2 = Comment.objects.create(pk=(self.tenant.id, 5128), user=user) + comment_3 = Comment.objects.create(pk=(self.tenant.id, 4823), user=user) + + with CaptureQueriesContext(connection) as context: + result = list(Comment.objects.filter(user=user).order_by("pk")) + + self.assertEqual(result, [comment_3, comment_2, comment_1]) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = 2491) ' + f'ORDER BY "{c}"."tenant_id", "{c}"."id" ASC', + ) + + def test_filter_comments_by_user_and_order_by_pk_desc(self): + user = User.objects.create(pk=(self.tenant.id, 8316)) + comment_1 = Comment.objects.create(pk=(self.tenant.id, 3571), user=user) + comment_2 = Comment.objects.create(pk=(self.tenant.id, 7234), user=user) + comment_3 = Comment.objects.create(pk=(self.tenant.id, 1035), user=user) + + with CaptureQueriesContext(connection) as context: + result = list(Comment.objects.filter(user=user).order_by("-pk")) + + self.assertEqual(result, [comment_2, comment_1, comment_3]) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = 8316) ' + f'ORDER BY "{c}"."tenant_id", "{c}"."id" DESC', + ) + + def test_filter_comments_by_user_and_order_by_pk(self): + user = User.objects.create(pk=(self.tenant.id, 9314)) + objs = [ + Comment.objects.create(pk=(self.tenant.id, 3931), user=user), + Comment.objects.create(pk=(self.tenant.id, 2912), user=user), + Comment.objects.create(pk=(self.tenant.id, 5312), user=user), + ] + + qs = Comment.objects.filter(user=user) + self.assertEqual(qs.latest("pk"), objs[2]) + self.assertEqual(qs.earliest("pk"), objs[1]) + self.assertEqual(qs.latest("-pk"), objs[1]) + self.assertEqual(qs.earliest("-pk"), objs[2]) + self.assertEqual(qs.order_by("pk").first(), objs[1]) + self.assertEqual(qs.order_by("pk").last(), objs[2]) + self.assertEqual(qs.order_by("-pk").first(), objs[2]) + self.assertEqual(qs.order_by("-pk").last(), objs[1]) + self.assertEqual(qs.latest("pk", "user"), objs[2]) + self.assertEqual(qs.earliest("pk", "user"), objs[1]) + self.assertEqual(qs.latest("-pk", "user"), objs[1]) + self.assertEqual(qs.earliest("-pk", "user"), objs[2]) + self.assertEqual(qs.order_by("pk", "user").first(), objs[1]) + self.assertEqual(qs.order_by("pk", "user").last(), objs[2]) + self.assertEqual(qs.order_by("-pk", "user").first(), objs[2]) + self.assertEqual(qs.order_by("-pk", "user").last(), objs[1]) + + def test_filter_comments_by_user_and_exclude_by_pk(self): + user = User.objects.create(pk=(self.tenant.id, 2491)) + comment_1 = Comment.objects.create(pk=(self.tenant.id, 9214), user=user) + comment_2 = Comment.objects.create(pk=(self.tenant.id, 3512), user=user) + comment_3 = Comment.objects.create(pk=(self.tenant.id, 7313), user=user) + + with CaptureQueriesContext(connection) as context: + comments = list( + Comment.objects.filter(user=user) + .exclude(pk=comment_2.pk) + .order_by("-pk") + ) + + self.assertEqual(comments, [comment_1, comment_3]) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" WHERE (' + f'("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."user_id" = 2491) ' + f"AND NOT (" + f'("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 3512))) ' + f'ORDER BY "{c}"."tenant_id", "{c}"."id" DESC', + ) + + def test_contains_comment(self): + with CaptureQueriesContext(connection) as context: + result = Comment.objects.contains(self.comment) + + self.assertTrue(result) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT 1 AS "a" ' + f'FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 1) ' + f"LIMIT 1", + ) diff --git a/tests/composite_pk/test_get.py b/tests/composite_pk/test_get.py new file mode 100644 index 0000000000000..38c99958ef2c9 --- /dev/null +++ b/tests/composite_pk/test_get.py @@ -0,0 +1,227 @@ +from django.db import connection +from django.db.models.query import MAX_GET_RESULTS +from django.test import TestCase +from django.test.utils import CaptureQueriesContext + +from .models import Comment, Tenant, User + + +class CompositePKGetTests(TestCase): + """ + Test the .get(), .get_or_create() methods of composite_pk models. + """ + + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create(tenant=cls.tenant, id=1) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + def test_get_tenant_by_pk(self): + test_cases = [ + {"id": self.tenant.id}, + {"pk": self.tenant.pk}, + ] + + for lookup in test_cases: + with self.subTest(lookup=lookup): + with CaptureQueriesContext(connection) as context: + obj = Tenant.objects.get(**lookup) + + self.assertEqual(obj, self.tenant) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + t = Tenant._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{t}"."id" ' + f'FROM "{t}" ' + f'WHERE "{t}"."id" = {self.tenant.id} ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_user_by_pk(self): + test_cases = [ + {"pk": (self.tenant.id, self.user.id)}, + {"pk": self.user.pk}, + ] + + for lookup in test_cases: + with self.subTest(lookup=lookup): + with CaptureQueriesContext(connection) as context: + obj = User.objects.only("pk").get(**lookup) + + self.assertEqual(obj, self.user) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{u}"."tenant_id", "{u}"."id" ' + f'FROM "{u}" ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id}) ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_user_by_field(self): + test_cases = [ + ({"id": self.user.id}, "id", self.user.id), + ({"tenant": self.tenant}, "tenant_id", self.tenant.id), + ({"tenant_id": self.tenant.id}, "tenant_id", self.tenant.id), + ({"tenant__id": self.tenant.id}, "tenant_id", self.tenant.id), + ({"tenant__pk": self.tenant.id}, "tenant_id", self.tenant.id), + ] + + for lookup, column, value in test_cases: + with self.subTest(lookup=lookup, column=column, value=value): + with CaptureQueriesContext(connection) as context: + obj = User.objects.only("pk").get(**lookup) + + self.assertEqual(obj, self.user) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{u}"."tenant_id", "{u}"."id" ' + f'FROM "{u}" ' + f'WHERE "{u}"."{column}" = {value} ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_comment_by_pk(self): + with CaptureQueriesContext(connection) as context: + obj = Comment.objects.get(pk=(self.tenant.id, self.comment.id)) + + self.assertEqual(obj, self.comment) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."id" = {self.comment.id}) ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_comment_by_field(self): + test_cases = [ + ({"id": self.comment.id}, "id", self.comment.id), + ({"user_id": self.user.id}, "user_id", self.user.id), + ({"user__id": self.user.id}, "user_id", self.user.id), + ({"tenant": self.tenant}, "tenant_id", self.tenant.id), + ({"tenant_id": self.tenant.id}, "tenant_id", self.tenant.id), + ({"tenant__id": self.tenant.id}, "tenant_id", self.tenant.id), + ({"tenant__pk": self.tenant.id}, "tenant_id", self.tenant.id), + ] + + for lookup, column, value in test_cases: + with self.subTest(lookup=lookup, column=column, value=value): + with CaptureQueriesContext(connection) as context: + obj = Comment.objects.get(**lookup) + + self.assertEqual(obj, self.comment) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" ' + f'WHERE "{c}"."{column}" = {value} ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_comment_by_user(self): + with CaptureQueriesContext(connection) as context: + obj = Comment.objects.get(user=self.user) + + self.assertEqual(obj, self.comment) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."user_id" = {self.user.id}) ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_comment_by_user_pk(self): + with CaptureQueriesContext(connection) as context: + obj = Comment.objects.get(user__pk=self.user.pk) + + self.assertEqual(obj, self.comment) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id", "{c}"."user_id" ' + f'FROM "{c}" ' + f'INNER JOIN "{u}" ON ("{c}"."tenant_id" = "{u}"."tenant_id" ' + f'AND "{c}"."user_id" = "{u}"."id") ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id}) ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_comment_by_pk_only_pk(self): + with CaptureQueriesContext(connection) as context: + obj = Comment.objects.only("pk").get(pk=self.comment.pk) + + self.assertEqual(obj, self.comment) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id" ' + f'FROM "{c}" ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."id" = {self.user.id}) ' + f"LIMIT {MAX_GET_RESULTS}", + ) + + def test_get_or_create_user_by_pk(self): + user, created = User.objects.get_or_create(pk=self.user.pk) + + self.assertFalse(created) + self.assertEqual(1, User.objects.all().count()) + self.assertEqual(user, self.user) + + def test_lookup_errors(self): + with self.assertRaisesMessage( + ValueError, "The right-hand side of the 'exact' lookup must be an iterable" + ): + Comment.objects.get(pk=1) + with self.assertRaisesMessage( + ValueError, + "The left-hand side and right-hand side of the 'exact' " + "lookup must have the same number of elements", + ): + Comment.objects.get(pk=(1, 2, 3)) + with self.assertRaisesMessage( + ValueError, "The right-hand side of the 'in' lookup must be an iterable" + ): + Comment.objects.get(pk__in=1) + with self.assertRaisesMessage( + ValueError, + "The right-hand side of the 'in' lookup must be an iterable " + "of iterables", + ): + Comment.objects.get(pk__in=(1, 2, 3)) + with self.assertRaisesMessage( + ValueError, + "The left-hand side and right-hand side of the 'in' lookup must " + "have the same number of elements", + ): + Comment.objects.get(pk__in=((1, 2, 3),)) diff --git a/tests/composite_pk/test_update.py b/tests/composite_pk/test_update.py new file mode 100644 index 0000000000000..7220b57e6b267 --- /dev/null +++ b/tests/composite_pk/test_update.py @@ -0,0 +1,135 @@ +import unittest + +from django.db import connection +from django.test import TestCase +from django.test.utils import CaptureQueriesContext + +from .models import Comment, Tenant, User + + +class CompositePKUpdateTests(TestCase): + """ + Test the .update(), .save(), .bulk_update(), .update_or_create() methods of + composite_pk models. + """ + + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create( + tenant=cls.tenant, id=1, email="user1@example.com" + ) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + def test_update_user(self): + with CaptureQueriesContext(connection) as context: + result = User.objects.filter(pk=self.user.pk).update(id=8341) + + self.assertEqual(result, 1) + self.assertFalse(User.objects.filter(pk=self.user.pk).exists()) + self.assertEqual(User.objects.all().count(), 1) + user = User.objects.get(pk=(self.tenant.id, 8341)) + self.assertEqual(user.tenant, self.tenant) + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.id, 8341) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'UPDATE "{u}" ' + 'SET "id" = 8341 ' + f'WHERE ("{u}"."tenant_id" = {self.tenant.id} ' + f'AND "{u}"."id" = {self.user.id})', + ) + + def test_save_comment(self): + comment = Comment.objects.get(pk=self.comment.pk) + comment.user = User.objects.create(tenant=self.tenant, id=8214) + + with CaptureQueriesContext(connection) as context: + comment.save() + + self.assertEqual(Comment.objects.all().count(), 1) + self.assertEqual(len(context.captured_queries), 1) + if connection.vendor in ("sqlite", "postgresql"): + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'UPDATE "{c}" ' + f'SET "tenant_id" = {self.tenant.id}, "id" = {self.comment.id}, ' + f'"user_id" = 8214 ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."id" = {self.comment.id})', + ) + + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") + def test_bulk_update_comments_in_sqlite(self): + user_1 = User.objects.create(pk=(self.tenant.id, 1352)) + user_2 = User.objects.create(pk=(self.tenant.id, 9314)) + comment_1 = Comment.objects.create(pk=(self.tenant.id, 1934), user=user_1) + comment_2 = Comment.objects.create(pk=(self.tenant.id, 8314), user=user_1) + comment_3 = Comment.objects.create(pk=(self.tenant.id, 9214), user=user_1) + comment_1.user = user_2 + comment_2.user = user_2 + comment_3.user = user_2 + + with CaptureQueriesContext(connection) as context: + result = Comment.objects.bulk_update( + [comment_1, comment_2, comment_3], ["user_id"] + ) + + self.assertEqual(result, 3) + self.assertEqual(len(context.captured_queries), 1) + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'UPDATE "{c}" ' + f'SET "user_id" = CASE ' + f'WHEN (("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 1934)) ' + f"THEN 9314 " + f'WHEN (("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 8314)) ' + f"THEN 9314 " + f'WHEN (("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 9214)) ' + f"THEN 9314 ELSE NULL END " + f'WHERE (("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 1934) ' + f'OR ("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 8314) ' + f'OR ("{c}"."tenant_id" = {self.tenant.id} AND "{c}"."id" = 9214))', + ) + + def test_update_or_create_user_by_pk(self): + user, created = User.objects.update_or_create(pk=self.user.pk) + + self.assertFalse(created) + self.assertEqual(1, User.objects.all().count()) + self.assertEqual(user.pk, self.user.pk) + self.assertEqual(user.tenant_id, self.tenant.id) + self.assertEqual(user.id, self.user.id) + + def test_update_comment(self): + with CaptureQueriesContext(connection) as context: + result = Comment.objects.filter(user__email=self.user.email).update(id=2914) + + self.assertEqual(result, 1) + self.assertEqual(len(context.captured_queries), 2) + if connection.vendor in ("sqlite", "postgresql"): + u = User._meta.db_table + c = Comment._meta.db_table + self.assertEqual( + context.captured_queries[0]["sql"], + f'SELECT "{c}"."tenant_id", "{c}"."id" ' + f'FROM "{c}" ' + f'INNER JOIN "{u}" ' + f'ON ("{c}"."tenant_id" = "{u}"."tenant_id" ' + f'AND "{c}"."user_id" = "{u}"."id") ' + f'WHERE "{u}"."email" = \'{self.user.email}\'', + ) + self.assertEqual( + context.captured_queries[1]["sql"], + f'UPDATE "{c}" ' + f'SET "id" = 2914 ' + f'WHERE ("{c}"."tenant_id" = {self.tenant.id} ' + f'AND "{c}"."id" = {self.comment.id})', + ) diff --git a/tests/composite_pk/tests.py b/tests/composite_pk/tests.py new file mode 100644 index 0000000000000..ef4082595b149 --- /dev/null +++ b/tests/composite_pk/tests.py @@ -0,0 +1,215 @@ +import unittest + +from django.db import connection +from django.db.models.query_utils import PathInfo +from django.db.models.sql import Query +from django.test import TestCase + +from .models import Comment, Tenant, User + + +def get_constraints(table): + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + +class CompositePKTests(TestCase): + maxDiff = None + + @classmethod + def setUpTestData(cls): + cls.tenant = Tenant.objects.create() + cls.user = User.objects.create(tenant=cls.tenant, id=1) + cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user) + + def test_fields(self): + self.assertIsInstance(self.tenant.pk, int) + self.assertGreater(self.tenant.id, 0) + self.assertEqual(self.tenant.pk, self.tenant.id) + + self.assertIsInstance(self.user.id, int) + self.assertGreater(self.user.id, 0) + self.assertEqual(self.user.tenant_id, self.tenant.id) + self.assertEqual(self.user.pk, (self.user.tenant_id, self.user.id)) + self.assertEqual(self.user.primary_key, self.user.pk) + + self.assertIsInstance(self.comment.id, int) + self.assertGreater(self.comment.id, 0) + self.assertEqual(self.comment.user_id, self.user.id) + self.assertEqual(self.comment.tenant_id, self.tenant.id) + self.assertEqual(self.comment.pk, (self.comment.tenant_id, self.comment.id)) + self.assertEqual(self.comment.primary_key, self.comment.pk) + + def test_pk_updated_if_field_updated(self): + user = User.objects.get(pk=self.user.pk) + self.assertEqual(user.pk, (self.tenant.id, self.user.id)) + user.tenant_id = 9831 + self.assertEqual(user.pk, (9831, self.user.id)) + user.id = 4321 + self.assertEqual(user.pk, (9831, 4321)) + user.pk = (9132, 3521) + self.assertEqual(user.tenant_id, 9132) + self.assertEqual(user.id, 3521) + + def test_composite_pk_in_fields(self): + user_fields = {f.name for f in User._meta.get_fields()} + self.assertEqual(user_fields, {"id", "tenant", "primary_key", "email"}) + + comment_fields = {f.name for f in Comment._meta.get_fields()} + self.assertEqual( + comment_fields, {"id", "tenant", "user_id", "user", "primary_key"} + ) + + def test_error_on_pk_conflict(self): + with self.assertRaises(Exception): + User.objects.create(tenant=self.tenant, id=self.user.id) + with self.assertRaises(Exception): + Comment.objects.create(tenant=self.tenant, id=self.comment.id) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test") + def test_pk_constraints_in_postgresql(self): + user_constraints = get_constraints(User._meta.db_table) + user_pk = user_constraints["composite_pk_user_pkey"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertTrue(user_pk["primary_key"]) + + comment_constraints = get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["composite_pk_comment_pkey"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "id"]) + self.assertTrue(comment_pk["primary_key"]) + + @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test") + def test_pk_constraints_in_sqlite(self): + user_constraints = get_constraints(User._meta.db_table) + user_pk = user_constraints["__primary__"] + self.assertEqual(user_pk["columns"], ["tenant_id", "id"]) + self.assertTrue(user_pk["primary_key"]) + + comment_constraints = get_constraints(Comment._meta.db_table) + comment_pk = comment_constraints["__primary__"] + self.assertEqual(comment_pk["columns"], ["tenant_id", "id"]) + self.assertTrue(comment_pk["primary_key"]) + + def test_in_bulk(self): + """ + Test the .in_bulk() method of composite_pk models. + """ + result = Comment.objects.in_bulk() + self.assertEqual(result, {self.comment.pk: self.comment}) + + result = Comment.objects.in_bulk([self.comment.pk]) + self.assertEqual(result, {self.comment.pk: self.comment}) + + def test_iterator(self): + """ + Test the .iterator() method of composite_pk models. + """ + result = list(Comment.objects.iterator()) + self.assertEqual(result, [self.comment]) + + +class NamesToPathTests(TestCase): + def test_id(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path(["id"], User._meta) + + self.assertEqual(path, []) + self.assertEqual(final_field, User._meta.get_field("id")) + self.assertEqual(targets, (User._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_pk(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path(["pk"], User._meta) + + self.assertEqual(path, []) + self.assertEqual(final_field, User._meta.get_field("primary_key")) + self.assertEqual(targets, (User._meta.get_field("primary_key"),)) + self.assertEqual(rest, []) + + def test_tenant_id(self): + query = Query(User) + path, final_field, targets, rest = query.names_to_path( + ["tenant", "id"], User._meta + ) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=User._meta, + to_opts=Tenant._meta, + target_fields=(Tenant._meta.get_field("id"),), + join_field=User._meta.get_field("tenant"), + m2m=False, + direct=True, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, Tenant._meta.get_field("id")) + self.assertEqual(targets, (Tenant._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_user_id(self): + query = Query(Comment) + path, final_field, targets, rest = query.names_to_path( + ["user", "id"], Comment._meta + ) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=Comment._meta, + to_opts=User._meta, + target_fields=( + User._meta.get_field("tenant"), + User._meta.get_field("id"), + ), + join_field=Comment._meta.get_field("user"), + m2m=False, + direct=True, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, User._meta.get_field("id")) + self.assertEqual(targets, (User._meta.get_field("id"),)) + self.assertEqual(rest, []) + + def test_user_tenant_id(self): + query = Query(Comment) + path, final_field, targets, rest = query.names_to_path( + ["user", "tenant", "id"], Comment._meta + ) + + self.assertEqual( + path, + [ + PathInfo( + from_opts=Comment._meta, + to_opts=User._meta, + target_fields=( + User._meta.get_field("tenant"), + User._meta.get_field("id"), + ), + join_field=Comment._meta.get_field("user"), + m2m=False, + direct=True, + filtered_relation=None, + ), + PathInfo( + from_opts=User._meta, + to_opts=Tenant._meta, + target_fields=(Tenant._meta.get_field("id"),), + join_field=User._meta.get_field("tenant"), + m2m=False, + direct=True, + filtered_relation=None, + ), + ], + ) + self.assertEqual(final_field, Tenant._meta.get_field("id")) + self.assertEqual(targets, (Tenant._meta.get_field("id"),)) + self.assertEqual(rest, []) diff --git a/tests/invalid_models_tests/test_models.py b/tests/invalid_models_tests/test_models.py index 8b6d705acb697..4d139539b4013 100644 --- a/tests/invalid_models_tests/test_models.py +++ b/tests/invalid_models_tests/test_models.py @@ -2815,3 +2815,22 @@ class Meta: ] self.assertEqual(Bar.check(databases=self.databases), []) + + def test_composite_pk_with_primary_key_set_to_true(self): + class Model(models.Model): + id_1 = models.IntegerField(primary_key=True) + id_2 = models.IntegerField() + + class Meta: + primary_key = ("id_1", "id_2") + + self.assertEqual( + Model.check(databases=self.databases), + [ + Error( + "primary_key=True must not be set if Meta.primary_key is defined.", + obj=Model, + id="models.E042", + ), + ], + )