Skip to content

Commit

Permalink
Fixed #33308 -- Added support for psycopg version 3.
Browse files Browse the repository at this point in the history
Thanks Simon Charette, Tim Graham, and Adam Johnson for reviews.
Co-authored-by: Florian Apolloner <florian@apolloner.eu>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
  • Loading branch information
dvarrazzo authored and felixxm committed Dec 14, 2022
1 parent 3226845 commit 86ec886
Show file tree
Hide file tree
Showing 25 changed files with 579 additions and 166 deletions.
4 changes: 2 additions & 2 deletions django/contrib/gis/db/backends/postgis/adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""
This object provides quoting for GEOS geometries into PostgreSQL/PostGIS.
"""
from psycopg2.extensions import ISQLQuote

from django.contrib.gis.db.backends.postgis.pgraster import to_pgraster
from django.contrib.gis.geos import GEOSGeometry
from django.db.backends.postgresql.psycopg_any import sql
Expand All @@ -27,6 +25,8 @@ def __init__(self, obj, geography=False):

def __conform__(self, proto):
"""Does the given protocol conform to what Psycopg2 expects?"""
from psycopg2.extensions import ISQLQuote

if proto == ISQLQuote:
return self
else:
Expand Down
126 changes: 122 additions & 4 deletions django/contrib/gis/db/backends/postgis/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,93 @@
from functools import lru_cache

from django.db.backends.base.base import NO_DB_ALIAS
from django.db.backends.postgresql.base import (
DatabaseWrapper as Psycopg2DatabaseWrapper,
)
from django.db.backends.postgresql.base import DatabaseWrapper as PsycopgDatabaseWrapper
from django.db.backends.postgresql.psycopg_any import is_psycopg3

from .adapter import PostGISAdapter
from .features import DatabaseFeatures
from .introspection import PostGISIntrospection
from .operations import PostGISOperations
from .schema import PostGISSchemaEditor

if is_psycopg3:
from psycopg.adapt import Dumper
from psycopg.pq import Format
from psycopg.types import TypeInfo
from psycopg.types.string import TextBinaryLoader, TextLoader

class GeometryType:
pass

class GeographyType:
pass

class RasterType:
pass

class BaseTextDumper(Dumper):
def dump(self, obj):
# Return bytes as hex for text formatting
return obj.ewkb.hex().encode()

class BaseBinaryDumper(Dumper):
format = Format.BINARY

def dump(self, obj):
return obj.ewkb

@lru_cache
def postgis_adapters(geo_oid, geog_oid, raster_oid):
class BaseDumper(Dumper):
def __init_subclass__(cls, base_dumper):
super().__init_subclass__()

cls.GeometryDumper = type(
"GeometryDumper", (base_dumper,), {"oid": geo_oid}
)
cls.GeographyDumper = type(
"GeographyDumper", (base_dumper,), {"oid": geog_oid}
)
cls.RasterDumper = type(
"RasterDumper", (BaseTextDumper,), {"oid": raster_oid}
)

def get_key(self, obj, format):
if obj.is_geometry:
return GeographyType if obj.geography else GeometryType
else:
return RasterType

def upgrade(self, obj, format):
if obj.is_geometry:
if obj.geography:
return self.GeographyDumper(GeographyType)
else:
return self.GeometryDumper(GeometryType)
else:
return self.RasterDumper(RasterType)

def dump(self, obj):
raise NotImplementedError

class PostGISTextDumper(BaseDumper, base_dumper=BaseTextDumper):
pass

class DatabaseWrapper(Psycopg2DatabaseWrapper):
class PostGISBinaryDumper(BaseDumper, base_dumper=BaseBinaryDumper):
format = Format.BINARY

return PostGISTextDumper, PostGISBinaryDumper


class DatabaseWrapper(PsycopgDatabaseWrapper):
SchemaEditorClass = PostGISSchemaEditor

_type_infos = {
"geometry": {},
"geography": {},
"raster": {},
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if kwargs.get("alias", "") != NO_DB_ALIAS:
Expand All @@ -27,3 +103,45 @@ def prepare_database(self):
if bool(cursor.fetchone()):
return
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
if is_psycopg3:
# Ensure adapters are registers if PostGIS is used within this
# connection.
self.register_geometry_adapters(self.connection, True)

def get_new_connection(self, conn_params):
connection = super().get_new_connection(conn_params)
if is_psycopg3:
self.register_geometry_adapters(connection)
return connection

if is_psycopg3:

def _register_type(self, pg_connection, typename):
registry = self._type_infos[typename]
try:
info = registry[self.alias]
except KeyError:
info = TypeInfo.fetch(pg_connection, typename)
registry[self.alias] = info

if info: # Can be None if the type does not exist (yet).
info.register(pg_connection)
pg_connection.adapters.register_loader(info.oid, TextLoader)
pg_connection.adapters.register_loader(info.oid, TextBinaryLoader)

return info.oid if info else None

def register_geometry_adapters(self, pg_connection, clear_caches=False):
if clear_caches:
for typename in self._type_infos:
self._type_infos[typename].pop(self.alias, None)

geo_oid = self._register_type(pg_connection, "geometry")
geog_oid = self._register_type(pg_connection, "geography")
raster_oid = self._register_type(pg_connection, "raster")

PostGISTextDumper, PostGISBinaryDumper = postgis_adapters(
geo_oid, geog_oid, raster_oid
)
pg_connection.adapters.register_dumper(PostGISAdapter, PostGISTextDumper)
pg_connection.adapters.register_dumper(PostGISAdapter, PostGISBinaryDumper)
6 changes: 5 additions & 1 deletion django/contrib/gis/db/backends/postgis/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.db import NotSupportedError, ProgrammingError
from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.backends.postgresql.psycopg_any import is_psycopg3
from django.db.models import Func, Value
from django.utils.functional import cached_property
from django.utils.version import get_version_tuple
Expand Down Expand Up @@ -161,7 +162,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):

unsupported_functions = set()

select = "%s::bytea"
select = "%s" if is_psycopg3 else "%s::bytea"

select_extent = None

@cached_property
Expand Down Expand Up @@ -407,6 +409,8 @@ def get_geometry_converter(self, expression):
geom_class = expression.output_field.geom_class

def converter(value, expression, connection):
if isinstance(value, str): # Coming from hex strings.
value = value.encode("ascii")
return None if value is None else GEOSGeometryBase(read(value), geom_class)

return converter
Expand Down
10 changes: 9 additions & 1 deletion django/contrib/postgres/fields/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NumericRange,
Range,
)
from django.db.models.functions import Cast
from django.db.models.lookups import PostgresOperatorLookup

from .utils import AttributeSetter
Expand Down Expand Up @@ -208,7 +209,14 @@ def db_type(self, connection):
return "daterange"


RangeField.register_lookup(lookups.DataContains)
class RangeContains(lookups.DataContains):
def get_prep_lookup(self):
if not isinstance(self.rhs, (list, tuple, Range)):
return Cast(self.rhs, self.lhs.field.base_field)
return super().get_prep_lookup()


RangeField.register_lookup(RangeContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)

Expand Down
4 changes: 4 additions & 0 deletions django/contrib/postgres/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def database_forwards(self, app_label, schema_editor, from_state, to_state):
# installed, otherwise a subsequent data migration would use the same
# connection.
register_type_handlers(schema_editor.connection)
if hasattr(schema_editor.connection, "register_geometry_adapters"):
schema_editor.connection.register_geometry_adapters(
schema_editor.connection.connection, True
)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
if not router.allow_migrate(schema_editor.connection.alias, app_label):
Expand Down
14 changes: 13 additions & 1 deletion django/contrib/postgres/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def db_type(self, connection):
return "tsquery"


class _Float4Field(Field):
def db_type(self, connection):
return "float4"


class SearchConfig(Expression):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -138,7 +143,11 @@ def as_sql(self, compiler, connection, function=None, template=None):
if clone.weight:
weight_sql, extra_params = compiler.compile(clone.weight)
sql = "setweight({}, {})".format(sql, weight_sql)
return sql, config_params + params + extra_params

# These parameters must be bound on the client side because we may
# want to create an index on this expression.
sql = connection.ops.compose_sql(sql, config_params + params + extra_params)
return sql, []


class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
Expand Down Expand Up @@ -244,6 +253,8 @@ def __init__(
normalization=None,
cover_density=False,
):
from .fields.array import ArrayField

if not hasattr(vector, "resolve_expression"):
vector = SearchVector(vector)
if not hasattr(query, "resolve_expression"):
Expand All @@ -252,6 +263,7 @@ def __init__(
if weights is not None:
if not hasattr(weights, "resolve_expression"):
weights = Value(weights)
weights = Cast(weights, ArrayField(_Float4Field()))
expressions = (weights,) + expressions
if normalization is not None:
if not hasattr(normalization, "resolve_expression"):
Expand Down
77 changes: 48 additions & 29 deletions django/contrib/postgres/signals.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import functools

import psycopg2
from psycopg2.extras import register_hstore

from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS
from django.db.backends.postgresql.psycopg_any import is_psycopg3


def get_type_oids(connection_alias, type_name):
Expand Down Expand Up @@ -32,30 +30,51 @@ def get_citext_oids(connection_alias):
return get_type_oids(connection_alias, "citext")


def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return

oids, array_oids = get_hstore_oids(connection.alias)
# Don't register handlers when hstore is not available on the database.
#
# If someone tries to create an hstore field it will error there. This is
# necessary as someone may be using PSQL without extensions installed but
# be using other features of contrib.postgres.
#
# This is also needed in order to create the connection in order to install
# the hstore extension.
if oids:
register_hstore(
connection.connection, globally=True, oid=oids, array_oid=array_oids
)
if is_psycopg3:
from psycopg.types import TypeInfo, hstore

oids, citext_oids = get_citext_oids(connection.alias)
# Don't register handlers when citext is not available on the database.
#
# The same comments in the above call to register_hstore() also apply here.
if oids:
array_type = psycopg2.extensions.new_array_type(
citext_oids, "citext[]", psycopg2.STRING
)
psycopg2.extensions.register_type(array_type, None)
def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return

oids, array_oids = get_hstore_oids(connection.alias)
for oid, array_oid in zip(oids, array_oids):
ti = TypeInfo("hstore", oid, array_oid)
hstore.register_hstore(ti, connection.connection)

_, citext_oids = get_citext_oids(connection.alias)
for array_oid in citext_oids:
ti = TypeInfo("citext", 0, array_oid)
ti.register(connection.connection)

else:
import psycopg2
from psycopg2.extras import register_hstore

def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return

oids, array_oids = get_hstore_oids(connection.alias)
# Don't register handlers when hstore is not available on the database.
#
# If someone tries to create an hstore field it will error there. This is
# necessary as someone may be using PSQL without extensions installed but
# be using other features of contrib.postgres.
#
# This is also needed in order to create the connection in order to install
# the hstore extension.
if oids:
register_hstore(
connection.connection, globally=True, oid=oids, array_oid=array_oids
)

oids, citext_oids = get_citext_oids(connection.alias)
# Don't register handlers when citext is not available on the database.
#
# The same comments in the above call to register_hstore() also apply here.
if oids:
array_type = psycopg2.extensions.new_array_type(
citext_oids, "citext[]", psycopg2.STRING
)
psycopg2.extensions.register_type(array_type, None)
5 changes: 5 additions & 0 deletions django/db/backends/base/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class BaseDatabaseFeatures:
# Can we roll back DDL in a transaction?
can_rollback_ddl = False

schema_editor_uses_clientside_param_binding = False

# Does it support operations requiring references rename in a transaction?
supports_atomic_references_rename = True

Expand Down Expand Up @@ -335,6 +337,9 @@ class BaseDatabaseFeatures:
# Does the backend support the logical XOR operator?
supports_logical_xor = False

# Set to (exception, message) if null characters in text are disallowed.
prohibits_null_characters_in_text_exception = None

# Collation names for use by the Django test suite.
test_collations = {
"ci": None, # Case-insensitive.
Expand Down
3 changes: 3 additions & 0 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,9 @@ def adapt_unknown_value(self, value):
else:
return value

def adapt_integerfield_value(self, value, internal_type):
return value

def adapt_datefield_value(self, value):
"""
Transform a date value to an object compatible with what is expected
Expand Down

0 comments on commit 86ec886

Please sign in to comment.