diff --git a/django_spanner/base.py b/django_spanner/base.py index fcf61a22f8..bdbdbc4de1 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -17,7 +17,7 @@ from .introspection import DatabaseIntrospection from .operations import DatabaseOperations from .schema import DatabaseSchemaEditor -from django_spanner import USING_DJANGO_3, USING_DJANGO_4 +from django_spanner import USING_DJANGO_3 class DatabaseWrapper(BaseDatabaseWrapper): @@ -130,7 +130,7 @@ def allow_transactions_in_auto_commit(self): return self.settings_dict["ALLOW_TRANSACTIONS_IN_AUTO_COMMIT"] if USING_DJANGO_3: return False - if USING_DJANGO_4: + else: return True @property diff --git a/django_spanner/compiler.py b/django_spanner/compiler.py index ad809e337a..a2175113f7 100644 --- a/django_spanner/compiler.py +++ b/django_spanner/compiler.py @@ -39,6 +39,8 @@ def get_combinator_sql(self, combinator, all): :returns: A tuple containing SQL statement(s) with some additional parameters. """ + # This method copies the complete code of this overridden method from + # Django core and modify it for Spanner by adding one line if USING_DJANGO_3: features = self.connection.features compilers = [ @@ -101,6 +103,8 @@ def get_combinator_sql(self, combinator, all): if not parts: raise EmptyResultSet combinator_sql = self.connection.ops.set_operators[combinator] + # This is the only line that is changed from the Django core + # implementation of this method combinator_sql += " ALL" if all else " DISTINCT" braces = ( "({})" @@ -116,6 +120,9 @@ def get_combinator_sql(self, combinator, all): params.extend(part) return result, params + # As the code of this method has somewhat changed in Django 4.2 core + # version, so we are copying the complete code of this overridden method + # and modifying it for Spanner else: features = self.connection.features compilers = [ @@ -191,6 +198,8 @@ def get_combinator_sql(self, combinator, all): if not parts: raise EmptyResultSet combinator_sql = self.connection.ops.set_operators[combinator] + # This is the only line that is changed from the Django core + # implementation of this method combinator_sql += " ALL" if all else " DISTINCT" braces = "{}" if ( diff --git a/tests/unit/django_spanner/test_operations.py b/tests/unit/django_spanner/test_operations.py index 197bd56688..a1d87520a3 100644 --- a/tests/unit/django_spanner/test_operations.py +++ b/tests/unit/django_spanner/test_operations.py @@ -12,7 +12,7 @@ from django.db.utils import DatabaseError from google.cloud.spanner_dbapi.types import DateStr -from django_spanner import USING_DJANGO_3, USING_DJANGO_4 +from django_spanner import USING_DJANGO_3 from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass import uuid @@ -118,7 +118,7 @@ def test_date_extract_sql(self): self.db_operations.date_extract_sql("week", "dummy_field"), "EXTRACT(isoweek FROM dummy_field)", ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.date_extract_sql("week", "dummy_field"), ("EXTRACT(isoweek FROM dummy_field)", None), @@ -132,7 +132,7 @@ def test_date_extract_sql_lookup_type_dayofweek(self): ), "EXTRACT(dayofweek FROM dummy_field)", ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.date_extract_sql( "dayofweek", "dummy_field" @@ -149,7 +149,7 @@ def test_datetime_extract_sql(self): ), 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "IST")', ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.datetime_extract_sql( "dayofweek", "dummy_field", None, "IST" @@ -169,7 +169,7 @@ def test_datetime_extract_sql_use_tz_false(self): ), 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.datetime_extract_sql( "dayofweek", "dummy_field", None, "IST" @@ -189,7 +189,7 @@ def test_time_extract_sql(self): ), 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.time_extract_sql( "dayofweek", "dummy_field" @@ -206,7 +206,7 @@ def test_time_trunc_sql(self): self.db_operations.time_trunc_sql("dayofweek", "dummy_field"), 'TIMESTAMP_TRUNC(dummy_field, dayofweek, "UTC")', ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.time_trunc_sql( "dayofweek", "dummy_field", None @@ -222,7 +222,7 @@ def test_datetime_cast_date_sql(self): ), 'DATE(dummy_field, "IST")', ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.datetime_cast_date_sql( "dummy_field", None, "IST" @@ -239,7 +239,7 @@ def test_datetime_cast_time_sql(self): ), "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'IST'))", ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.datetime_cast_time_sql( "dummy_field", None, "IST" @@ -259,7 +259,7 @@ def test_datetime_cast_time_sql_use_tz_false(self): ), "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'UTC'))", ) - elif USING_DJANGO_4: + else: self.assertEqual( self.db_operations.datetime_cast_time_sql( "dummy_field", None, "IST"