Skip to content

Commit

Permalink
Comments incorporated
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Apr 17, 2024
1 parent a92db3a commit 628b286
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
4 changes: 2 additions & 2 deletions django_spanner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from django_spanner import USING_DJANGO_3, USING_DJANGO_4
from django_spanner import USING_DJANGO_3


class DatabaseWrapper(BaseDatabaseWrapper):
Expand Down Expand Up @@ -130,7 +130,7 @@ def allow_transactions_in_auto_commit(self):
return self.settings_dict["ALLOW_TRANSACTIONS_IN_AUTO_COMMIT"]
if USING_DJANGO_3:
return False
if USING_DJANGO_4:
else:
return True

@property
Expand Down
9 changes: 9 additions & 0 deletions django_spanner/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def get_combinator_sql(self, combinator, all):
:returns: A tuple containing SQL statement(s) with some additional
parameters.
"""
# This method copies the complete code of this overridden method from
# Django core and modify it for Spanner by adding one line
if USING_DJANGO_3:
features = self.connection.features
compilers = [
Expand Down Expand Up @@ -101,6 +103,8 @@ def get_combinator_sql(self, combinator, all):
if not parts:
raise EmptyResultSet
combinator_sql = self.connection.ops.set_operators[combinator]
# This is the only line that is changed from the Django core
# implementation of this method
combinator_sql += " ALL" if all else " DISTINCT"
braces = (
"({})"
Expand All @@ -116,6 +120,9 @@ def get_combinator_sql(self, combinator, all):
params.extend(part)

return result, params
# As the code of this method has somewhat changed in Django 4.2 core
# version, so we are copying the complete code of this overridden method
# and modifying it for Spanner
else:
features = self.connection.features
compilers = [
Expand Down Expand Up @@ -191,6 +198,8 @@ def get_combinator_sql(self, combinator, all):
if not parts:
raise EmptyResultSet
combinator_sql = self.connection.ops.set_operators[combinator]
# This is the only line that is changed from the Django core
# implementation of this method
combinator_sql += " ALL" if all else " DISTINCT"
braces = "{}"
if (
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/django_spanner/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.db.utils import DatabaseError
from google.cloud.spanner_dbapi.types import DateStr

from django_spanner import USING_DJANGO_3, USING_DJANGO_4
from django_spanner import USING_DJANGO_3
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass
import uuid

Expand Down Expand Up @@ -118,7 +118,7 @@ def test_date_extract_sql(self):
self.db_operations.date_extract_sql("week", "dummy_field"),
"EXTRACT(isoweek FROM dummy_field)",
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.date_extract_sql("week", "dummy_field"),
("EXTRACT(isoweek FROM dummy_field)", None),
Expand All @@ -132,7 +132,7 @@ def test_date_extract_sql_lookup_type_dayofweek(self):
),
"EXTRACT(dayofweek FROM dummy_field)",
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.date_extract_sql(
"dayofweek", "dummy_field"
Expand All @@ -149,7 +149,7 @@ def test_datetime_extract_sql(self):
),
'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "IST")',
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.datetime_extract_sql(
"dayofweek", "dummy_field", None, "IST"
Expand All @@ -169,7 +169,7 @@ def test_datetime_extract_sql_use_tz_false(self):
),
'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")',
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.datetime_extract_sql(
"dayofweek", "dummy_field", None, "IST"
Expand All @@ -189,7 +189,7 @@ def test_time_extract_sql(self):
),
'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")',
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.time_extract_sql(
"dayofweek", "dummy_field"
Expand All @@ -206,7 +206,7 @@ def test_time_trunc_sql(self):
self.db_operations.time_trunc_sql("dayofweek", "dummy_field"),
'TIMESTAMP_TRUNC(dummy_field, dayofweek, "UTC")',
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.time_trunc_sql(
"dayofweek", "dummy_field", None
Expand All @@ -222,7 +222,7 @@ def test_datetime_cast_date_sql(self):
),
'DATE(dummy_field, "IST")',
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.datetime_cast_date_sql(
"dummy_field", None, "IST"
Expand All @@ -239,7 +239,7 @@ def test_datetime_cast_time_sql(self):
),
"TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'IST'))",
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.datetime_cast_time_sql(
"dummy_field", None, "IST"
Expand All @@ -259,7 +259,7 @@ def test_datetime_cast_time_sql_use_tz_false(self):
),
"TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'UTC'))",
)
elif USING_DJANGO_4:
else:
self.assertEqual(
self.db_operations.datetime_cast_time_sql(
"dummy_field", None, "IST"
Expand Down

0 comments on commit 628b286

Please sign in to comment.