Skip to content

Commit

Permalink
Fixed #373 -- Added CompositeField-based Meta.primary_key.
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz committed Apr 20, 2024
1 parent 53719d6 commit 4be1c68
Show file tree
Hide file tree
Showing 19 changed files with 1,439 additions and 7 deletions.
13 changes: 13 additions & 0 deletions django/db/backends/base/schema.py
Expand Up @@ -105,6 +105,7 @@ class BaseDatabaseSchemaEditor:
sql_check_constraint = "CHECK (%(check)s)"
sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_constraint = "CONSTRAINT %(name)s %(constraint)s"
sql_pk_constraint = "PRIMARY KEY (%(columns)s)"

sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = sql_delete_constraint
Expand Down Expand Up @@ -268,6 +269,13 @@ def table_sql(self, model):
constraint.constraint_sql(model, self)
for constraint in model._meta.constraints
]

# If the model defines Meta.primary_key, add the primary key constraint
# to the table definition.
# It's expected primary_key=True isn't set on any fields (see E042).
if model._meta.primary_key:
constraints.append(self._pk_constraint_sql(model._meta.primary_key))

sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table),
"definition": ", ".join(
Expand Down Expand Up @@ -1967,6 +1975,11 @@ def _constraint_names(
result.append(name)
return result

def _pk_constraint_sql(self, fields):
return self.sql_pk_constraint % {
"columns": ", ".join(self.quote_name(field) for field in fields)
}

def _delete_primary_key(self, model, strict=False):
constraint_names = self._constraint_names(model, primary_key=True)
if strict and len(constraint_names) != 1:
Expand Down
30 changes: 29 additions & 1 deletion django/db/models/base.py
Expand Up @@ -30,6 +30,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import CASCADE, Collector
from django.db.models.expressions import DatabaseDefault
from django.db.models.fields.composite import is_pk_set
from django.db.models.fields.related import (
ForeignObjectRel,
OneToOneField,
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def _save_table(
if pk_val is None:
pk_val = meta.pk.get_pk_value_on_save(self)
setattr(self, meta.pk.attname, pk_val)
pk_set = pk_val is not None
pk_set = is_pk_set(pk_val)
if not pk_set and (force_update or update_fields):
raise ValueError("Cannot force an update in save() with no primary key.")
updated = False
Expand Down Expand Up @@ -1686,6 +1687,7 @@ def check(cls, **kwargs):
*cls._check_constraints(databases),
*cls._check_default_pk(),
*cls._check_db_table_comment(databases),
*cls._check_composite_pk(),
]

return errors
Expand All @@ -1694,6 +1696,9 @@ def check(cls, **kwargs):
def _check_default_pk(cls):
if (
not cls._meta.abstract
# If the model defines Meta.primary_key, the check should be skipped,
# since there's no default primary key.
and not cls._meta.primary_key
and cls._meta.pk.auto_created
and
# Inherited PKs are checked in parents models.
Expand Down Expand Up @@ -1722,6 +1727,24 @@ def _check_default_pk(cls):
]
return []

@classmethod
def _check_composite_pk(cls):
errors = []

if cls._meta.primary_key and any(
field for field in cls._meta.fields if field.primary_key
):
errors.append(
checks.Error(
"primary_key=True must not be set if Meta.primary_key "
"is defined.",
obj=cls,
id="models.E042",
)
)

return errors

@classmethod
def _check_db_table_comment(cls, databases):
if not cls._meta.db_table_comment:
Expand Down Expand Up @@ -1842,6 +1865,11 @@ def _check_m2m_through_same_relationship(cls):
@classmethod
def _check_id_field(cls):
"""Check if `id` field is a primary key."""
# If the model defines Meta.primary_key, the check should be skipped,
# since primary_key=True can't be set on any fields (including `id`).
if cls._meta.primary_key:
return []

fields = [
f for f in cls._meta.local_fields if f.name == "id" and f != cls._meta.pk
]
Expand Down
7 changes: 6 additions & 1 deletion django/db/models/fields/__init__.py
Expand Up @@ -2794,6 +2794,11 @@ def check(self, **kwargs):
]

def _check_primary_key(self):
# If the model defines Meta.primary_key, primary_key=True can't be set on
# any field (including AutoFields).
if self.model._meta.primary_key:
return []

if not self.primary_key:
return [
checks.Error(
Expand All @@ -2808,7 +2813,7 @@ def _check_primary_key(self):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
del kwargs["blank"]
kwargs["primary_key"] = True
kwargs["primary_key"] = self.primary_key
return name, path, args, kwargs

def validate(self, value, model_instance):
Expand Down
176 changes: 176 additions & 0 deletions django/db/models/fields/composite.py
@@ -0,0 +1,176 @@
from collections.abc import Iterable

from django.core.exceptions import FieldDoesNotExist
from django.db.models import Field
from django.db.models.expressions import Col, Expression
from django.db.models.lookups import Exact, In
from django.db.models.signals import class_prepared
from django.utils.functional import cached_property


class TupleExact(Exact):
def get_prep_lookup(self):
if not isinstance(self.lhs, Cols):
raise ValueError(
"The left-hand side of the 'exact' lookup must be an instance of Cols"
)
if not isinstance(self.rhs, Iterable):
raise ValueError(
"The right-hand side of the 'exact' lookup must be an iterable"
)
if len(list(self.lhs)) != len(list(self.rhs)):
raise ValueError(
"The left-hand side and right-hand side of the 'exact' lookup must "
"have the same number of elements"
)

return super().get_prep_lookup()

def as_sql(self, compiler, connection):
from django.db.models.sql.where import AND, WhereNode

cols = self.lhs.get_source_expressions()
exprs = [Exact(col, val) for col, val in zip(cols, self.rhs)]

return compiler.compile(WhereNode(exprs, connector=AND))


class TupleIn(In):
def get_prep_lookup(self):
if not isinstance(self.lhs, Cols):
raise ValueError(
"The left-hand side of the 'in' lookup must be an instance of Cols"
)
if not isinstance(self.rhs, Iterable):
raise ValueError(
"The right-hand side of the 'in' lookup must be an iterable"
)
if not all(isinstance(vals, Iterable) for vals in self.rhs):
raise ValueError(
"The right-hand side of the 'in' lookup must be an iterable of "
"iterables"
)
lhs_len = len(tuple(self.lhs))
if not all(lhs_len == len(tuple(vals)) for vals in self.rhs):
raise ValueError(
"The left-hand side and right-hand side of the 'in' lookup must "
"have the same number of elements"
)

return super().get_prep_lookup()

def as_sql(self, compiler, connection):
from django.db.models.sql.where import AND, OR, WhereNode

exprs = []
cols = self.lhs.get_source_expressions()

for vals in self.rhs:
exprs.append(
WhereNode(
[Exact(col, val) for col, val in zip(cols, vals)], connector=AND
)
)

return compiler.compile(WhereNode(exprs, connector=OR))


class Cols(Expression):
def __init__(self, alias, targets, output_field):
super().__init__(output_field=output_field)
self.alias, self.targets = alias, targets

def get_source_expressions(self):
return [Col(self.alias, target) for target in self.targets]

def set_source_expressions(self, exprs):
assert all(isinstance(expr, Col) for expr in exprs)
assert len(exprs) == len(self.targets)

def as_sql(self, compiler, connection):
sqls = []
cols = self.get_source_expressions()

for col in cols:
sql, _ = col.as_sql(compiler, connection)
sqls.append(sql)

return ", ".join(sqls), []

def __iter__(self):
return iter(self.get_source_expressions())


def is_pk_not_set(pk):
return pk is None or (isinstance(pk, tuple) and any(f is None for f in pk))


def is_pk_set(pk):
return not is_pk_not_set(pk)


class CompositeAttribute:
def __init__(self, field):
self.field = field

def __get__(self, instance, cls=None):
return tuple(
getattr(instance, field_name) for field_name in self.field.field_names
)

def __set__(self, instance, values):
if values is None:
values = (None,) * len(self.field.field_names)

for field_name, value in zip(self.field.field_names, values):
setattr(instance, field_name, value)


class CompositeField(Field):
descriptor_class = CompositeAttribute

def __init__(self, *args, **kwargs):
kwargs["db_column"] = None
kwargs["editable"] = False
super().__init__(**kwargs)
self.field_names = args
self.fields = None

def contribute_to_class(self, cls, name, **_):
super().contribute_to_class(cls, name, private_only=True)
cls._meta.pk = self
setattr(cls, self.attname, self.descriptor_class(self))

def get_attname_column(self):
return self.get_attname(), self.db_column

def __iter__(self):
return iter(self.fields)

@cached_property
def cached_col(self):
return Cols(self.model._meta.db_table, self.fields, self)

def get_col(self, alias, output_field=None):
return self.cached_col

def get_lookup(self, lookup_name):
if lookup_name == "exact":
return TupleExact
elif lookup_name == "in":
return TupleIn

return super().get_lookup(lookup_name)


def resolve_fields(*args, **kwargs):
meta = kwargs["sender"]._meta
for field in meta.private_fields:
if isinstance(field, CompositeField) and field.fields is None:
try:
field.fields = tuple(meta.get_field(name) for name in field.field_names)
except FieldDoesNotExist:
continue


class_prepared.connect(resolve_fields)
10 changes: 10 additions & 0 deletions django/db/models/fields/related.py
Expand Up @@ -615,6 +615,16 @@ def _check_unique_target(self):
if not self.foreign_related_fields:
return []

# If a model defines Meta.primary_key and a foreign key refers to it,
# the check should be skipped (since primary keys are unique).
pk = self.remote_field.model._meta.primary_key
if pk:
pk = set(pk)
if pk == {f.attname for f in self.foreign_related_fields}:
return []
elif pk == {f.name for f in self.foreign_related_fields}:
return []

has_unique_constraint = any(
rel_field.unique for rel_field in self.foreign_related_fields
)
Expand Down
8 changes: 7 additions & 1 deletion django/db/models/options.py
Expand Up @@ -7,6 +7,7 @@
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db import connections
from django.db.models import AutoField, Manager, OrderWrt, UniqueConstraint
from django.db.models.fields.composite import CompositeField
from django.db.models.query_utils import PathInfo
from django.utils.datastructures import ImmutableList, OrderedSet
from django.utils.functional import cached_property
Expand All @@ -24,6 +25,7 @@
)

DEFAULT_NAMES = (
"primary_key",
"verbose_name",
"verbose_name_plural",
"db_table",
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(self, meta, app_label=None):
self.base_manager_name = None
self.default_manager_name = None
self.model_name = None
self.primary_key = None
self.verbose_name = None
self.verbose_name_plural = None
self.db_table = ""
Expand Down Expand Up @@ -296,7 +299,10 @@ def _prepare(self, model):
self.order_with_respect_to = None

if self.pk is None:
if self.parents:
if self.primary_key:
pk = CompositeField(*self.primary_key)
model.add_to_class("primary_key", pk)
elif self.parents:
# Promote the first parent link in lieu of adding yet another
# field.
field = next(iter(self.parents.values()))
Expand Down
5 changes: 4 additions & 1 deletion django/db/models/query.py
Expand Up @@ -24,6 +24,7 @@
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Value, When
from django.db.models.fields.composite import is_pk_not_set
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
Expand Down Expand Up @@ -813,7 +814,9 @@ def bulk_create(
objs = list(objs)
self._prepare_for_bulk_create(objs)
with transaction.atomic(using=self.db, savepoint=False):
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
objs_with_pk, objs_without_pk = partition(
lambda o: is_pk_not_set(o.pk), objs
)
if objs_with_pk:
returned_columns = self._batched_insert(
objs_with_pk,
Expand Down

0 comments on commit 4be1c68

Please sign in to comment.