Skip to content

Commit

Permalink
Fixed #27452 -- Added serial fields to django.contrib.postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz committed May 6, 2024
1 parent 9a27c76 commit 4c086eb
Show file tree
Hide file tree
Showing 14 changed files with 1,307 additions and 57 deletions.
1 change: 1 addition & 0 deletions django/contrib/postgres/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .hstore import * # NOQA
from .jsonb import * # NOQA
from .ranges import * # NOQA
from .serial import * # NOQA
86 changes: 86 additions & 0 deletions django/contrib/postgres/fields/serial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from django.core import checks
from django.db import models
from django.db.models import NOT_PROVIDED
from django.db.models.expressions import DatabaseDefault
from django.utils.translation import gettext_lazy as _

__all__ = ("BigSerialField", "SmallSerialField", "SerialField")


class SerialFieldMixin:
db_returning = True

def __init__(self, *args, **kwargs):
kwargs["blank"] = True
super().__init__(*args, **kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs.pop("blank")
return name, path, args, kwargs

def check(self, **kwargs):
return [
*super().check(**kwargs),
*self._check_not_null(),
*self._check_default(),
]

def _check_not_null(self):
if self.null:
return [
checks.Error(
"SerialFields do not accept null values.",
obj=self,
id="fields.E910",
),
]
else:
return []

def _check_default(self):
if self.default is NOT_PROVIDED:
return []
return [
checks.Error(
"SerialFields do not accept default values.",
obj=self,
id="fields.E911",
),
]

def get_prep_value(self, value):
value = super().get_prep_value(value)
if value is None:
value = DatabaseDefault()
return value


class BigSerialField(SerialFieldMixin, models.BigIntegerField):
description = _("Big serial")

def get_internal_type(self):
return "BigSerialField"

def db_type(self, connection):
return "bigserial"


class SmallSerialField(SerialFieldMixin, models.SmallIntegerField):
description = _("Small serial")

def get_internal_type(self):
return "SmallSerialField"

def db_type(self, connection):
return "smallserial"


class SerialField(SerialFieldMixin, models.IntegerField):
description = _("Serial")

def get_internal_type(self):
return "SerialField"

def db_type(self, connection):
return "serial"
13 changes: 8 additions & 5 deletions django/db/backends/postgresql/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):

def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.is_autofield or (
# Required for pre-Django 4.1 serial columns.
description.default
and "nextval" in description.default
):
if description.is_autofield:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
elif description.default and "nextval" in description.default:
if field_type == "IntegerField":
return "SerialField"
elif field_type == "BigIntegerField":
return "BigSerialField"
elif field_type == "SmallIntegerField":
return "SmallSerialField"
return field_type

def get_table_list(self, cursor):
Expand Down
24 changes: 17 additions & 7 deletions django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import lru_cache, partial

from django.conf import settings
from django.contrib.postgres.fields.serial import SerialFieldMixin
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.postgresql.psycopg_any import (
Inet,
Expand All @@ -12,6 +13,7 @@
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
from django.db.models.fields import AutoFieldMixin
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile

Expand Down Expand Up @@ -43,6 +45,15 @@ class DatabaseOperations(BaseDatabaseOperations):
"AutoField": "integer",
"BigAutoField": "bigint",
"SmallAutoField": "smallint",
"SerialField": "integer",
"BigSerialField": "bigint",
"SmallSerialField": "smallint",
}
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
"SmallSerialField": (-32768, 32767),
"SerialField": (-2147483648, 2147483647),
"BigSerialField": (-9223372036854775808, 9223372036854775807),
}

if is_psycopg3:
Expand All @@ -55,6 +66,9 @@ class DatabaseOperations(BaseDatabaseOperations):
"PositiveSmallIntegerField": numeric.Int2,
"PositiveIntegerField": numeric.Int4,
"PositiveBigIntegerField": numeric.Int8,
"SmallSerialField": numeric.Int2,
"SerialField": numeric.Int4,
"BigSerialField": numeric.Int8,
}

def unification_cast_sql(self, output_field):
Expand Down Expand Up @@ -237,20 +251,18 @@ def tablespace_sql(self, tablespace, inline=False):
return "TABLESPACE %s" % self.quote_name(tablespace)

def sequence_reset_sql(self, style, model_list):
from django.db import models

output = []
qn = self.quote_name

for model in model_list:
# Use `coalesce` to set the sequence for each model to the max pk
# value if there are records, or 1 if there are none. Set the
# `is_called` property (the third argument to `setval`) to true if
# there are records (as the max pk value is already in use),
# otherwise set it to false. Use pg_get_serial_sequence to get the
# underlying sequence name from the table name and column name.

for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
if isinstance(f, (AutoFieldMixin, SerialFieldMixin)):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;"
Expand All @@ -265,9 +277,7 @@ def sequence_reset_sql(self, style, model_list):
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
# Only one AutoField is allowed per model, so don't bother
# continuing.
break

return output

def prep_for_iexact_query(self, x):
Expand Down

0 comments on commit 4c086eb

Please sign in to comment.