Skip to content

Commit

Permalink
fix: override AutoField default value only for Spanner (#780)
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaFaer committed Jul 20, 2022
1 parent 8cad6f6 commit 3bf2c77
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
8 changes: 7 additions & 1 deletion django_spanner/__init__.py
Expand Up @@ -15,6 +15,7 @@
import pkg_resources
from google.cloud.spanner_v1 import JsonObject
from django.db.models.fields import (
NOT_PROVIDED,
AutoField,
Field,
)
Expand Down Expand Up @@ -61,7 +62,12 @@ def gen_rand_int64():
def autofield_init(self, *args, **kwargs):
kwargs["blank"] = True
Field.__init__(self, *args, **kwargs)
self.default = gen_rand_int64

if (
django.db.connection.settings_dict["ENGINE"] == "django_spanner"
and self.default == NOT_PROVIDED
):
self.default = gen_rand_int64


AutoField.__init__ = autofield_init
Expand Down
33 changes: 31 additions & 2 deletions tests/unit/django_spanner/test_schema.py
Expand Up @@ -6,9 +6,10 @@


from .models import Author
from django.db import NotSupportedError
from django.db import NotSupportedError, connection
from django.db.models import Index
from django.db.models.fields import IntegerField
from django.db.models.fields import AutoField, IntegerField
from django_spanner import gen_rand_int64
from django_spanner.schema import DatabaseSchemaEditor
from tests._helpers import HAS_OPENTELEMETRY_INSTALLED
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass
Expand Down Expand Up @@ -404,3 +405,31 @@ def constraint_names(*args, **kwargs):
new_field.set_attributes_from_name("author_num")
with self.assertRaises(NotSupportedError):
schema_editor.alter_field(Author, old_field, new_field)

def test_autofield_no_default(self):
"""Spanner, default is not provided."""
field = AutoField(name="field_name")
assert gen_rand_int64 == field.default

def test_autofield_default(self):
"""Spanner, default provided."""
mock_func = mock.Mock()
field = AutoField(name="field_name", default=mock_func)
assert gen_rand_int64 != field.default
assert mock_func == field.default

def test_autofield_not_spanner(self):
"""Not Spanner, default not provided."""
connection.settings_dict["ENGINE"] = "another_db"
field = AutoField(name="field_name")
assert gen_rand_int64 != field.default
connection.settings_dict["ENGINE"] = "django_spanner"

def test_autofield_not_spanner_w_default(self):
"""Not Spanner, default provided."""
connection.settings_dict["ENGINE"] = "another_db"
mock_func = mock.Mock()
field = AutoField(name="field_name", default=mock_func)
assert gen_rand_int64 != field.default
assert mock_func == field.default
connection.settings_dict["ENGINE"] = "django_spanner"

0 comments on commit 3bf2c77

Please sign in to comment.