Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #33308 -- Added support for psycopg version 3 #15687

Merged
merged 1 commit into from Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions django/contrib/gis/db/backends/postgis/adapter.py
@@ -1,8 +1,6 @@
"""
This object provides quoting for GEOS geometries into PostgreSQL/PostGIS.
"""
from psycopg2.extensions import ISQLQuote

from django.contrib.gis.db.backends.postgis.pgraster import to_pgraster
from django.contrib.gis.geos import GEOSGeometry
from django.db.backends.postgresql.psycopg_any import sql
Expand All @@ -27,6 +25,8 @@ def __init__(self, obj, geography=False):

def __conform__(self, proto):
"""Does the given protocol conform to what Psycopg2 expects?"""
from psycopg2.extensions import ISQLQuote

if proto == ISQLQuote:
return self
else:
Expand Down
126 changes: 122 additions & 4 deletions django/contrib/gis/db/backends/postgis/base.py
@@ -1,17 +1,93 @@
from functools import lru_cache
apollo13 marked this conversation as resolved.
Show resolved Hide resolved

from django.db.backends.base.base import NO_DB_ALIAS
from django.db.backends.postgresql.base import (
DatabaseWrapper as Psycopg2DatabaseWrapper,
)
from django.db.backends.postgresql.base import DatabaseWrapper as PsycopgDatabaseWrapper
from django.db.backends.postgresql.psycopg_any import is_psycopg3

from .adapter import PostGISAdapter
from .features import DatabaseFeatures
from .introspection import PostGISIntrospection
from .operations import PostGISOperations
from .schema import PostGISSchemaEditor

if is_psycopg3:
from psycopg.adapt import Dumper
from psycopg.pq import Format
from psycopg.types import TypeInfo
from psycopg.types.string import TextBinaryLoader, TextLoader

class GeometryType:
pass

class GeographyType:
pass

class RasterType:
pass

class BaseTextDumper(Dumper):
def dump(self, obj):
# Return bytes as hex for text formatting
return obj.ewkb.hex().encode()

class BaseBinaryDumper(Dumper):
format = Format.BINARY

def dump(self, obj):
return obj.ewkb

@lru_cache
def postgis_adapters(geo_oid, geog_oid, raster_oid):
class BaseDumper(Dumper):
def __init_subclass__(cls, base_dumper):
super().__init_subclass__()

cls.GeometryDumper = type(
"GeometryDumper", (base_dumper,), {"oid": geo_oid}
)
cls.GeographyDumper = type(
"GeographyDumper", (base_dumper,), {"oid": geog_oid}
)
cls.RasterDumper = type(
"RasterDumper", (BaseTextDumper,), {"oid": raster_oid}
)

def get_key(self, obj, format):
if obj.is_geometry:
return GeographyType if obj.geography else GeometryType
else:
return RasterType

def upgrade(self, obj, format):
if obj.is_geometry:
if obj.geography:
return self.GeographyDumper(GeographyType)
else:
return self.GeometryDumper(GeometryType)
else:
return self.RasterDumper(RasterType)

def dump(self, obj):
raise NotImplementedError

class PostGISTextDumper(BaseDumper, base_dumper=BaseTextDumper):
pass

class DatabaseWrapper(Psycopg2DatabaseWrapper):
class PostGISBinaryDumper(BaseDumper, base_dumper=BaseBinaryDumper):
format = Format.BINARY

return PostGISTextDumper, PostGISBinaryDumper


class DatabaseWrapper(PsycopgDatabaseWrapper):
SchemaEditorClass = PostGISSchemaEditor

_type_infos = {
"geometry": {},
"geography": {},
"raster": {},
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if kwargs.get("alias", "") != NO_DB_ALIAS:
Expand All @@ -27,3 +103,45 @@ def prepare_database(self):
if bool(cursor.fetchone()):
return
cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
if is_psycopg3:
# Ensure adapters are registers if PostGIS is used within this
# connection.
self.register_geometry_adapters(self.connection, True)

def get_new_connection(self, conn_params):
connection = super().get_new_connection(conn_params)
if is_psycopg3:
self.register_geometry_adapters(connection)
return connection

if is_psycopg3:

def _register_type(self, pg_connection, typename):
registry = self._type_infos[typename]
try:
info = registry[self.alias]
except KeyError:
info = TypeInfo.fetch(pg_connection, typename)
registry[self.alias] = info

if info: # Can be None if the type does not exist (yet).
info.register(pg_connection)
pg_connection.adapters.register_loader(info.oid, TextLoader)
pg_connection.adapters.register_loader(info.oid, TextBinaryLoader)

return info.oid if info else None

def register_geometry_adapters(self, pg_connection, clear_caches=False):
if clear_caches:
for typename in self._type_infos:
self._type_infos[typename].pop(self.alias, None)

geo_oid = self._register_type(pg_connection, "geometry")
geog_oid = self._register_type(pg_connection, "geography")
raster_oid = self._register_type(pg_connection, "raster")

PostGISTextDumper, PostGISBinaryDumper = postgis_adapters(
geo_oid, geog_oid, raster_oid
)
pg_connection.adapters.register_dumper(PostGISAdapter, PostGISTextDumper)
pg_connection.adapters.register_dumper(PostGISAdapter, PostGISBinaryDumper)
4 changes: 2 additions & 2 deletions django/contrib/gis/db/backends/postgis/features.py
@@ -1,10 +1,10 @@
from django.contrib.gis.db.backends.base.features import BaseSpatialFeatures
from django.db.backends.postgresql.features import (
DatabaseFeatures as Psycopg2DatabaseFeatures,
DatabaseFeatures as PsycopgDatabaseFeatures,
)


class DatabaseFeatures(BaseSpatialFeatures, Psycopg2DatabaseFeatures):
class DatabaseFeatures(BaseSpatialFeatures, PsycopgDatabaseFeatures):
supports_geography = True
supports_3d_storage = True
supports_3d_functions = True
Expand Down
6 changes: 5 additions & 1 deletion django/contrib/gis/db/backends/postgis/operations.py
Expand Up @@ -11,6 +11,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.db import NotSupportedError, ProgrammingError
from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.backends.postgresql.psycopg_any import is_psycopg3
from django.db.models import Func, Value
from django.utils.functional import cached_property
from django.utils.version import get_version_tuple
Expand Down Expand Up @@ -161,7 +162,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):

unsupported_functions = set()

select = "%s::bytea"
select = "%s" if is_psycopg3 else "%s::bytea"

select_extent = None

@cached_property
Expand Down Expand Up @@ -407,6 +409,8 @@ def get_geometry_converter(self, expression):
geom_class = expression.output_field.geom_class

def converter(value, expression, connection):
if isinstance(value, str): # Coming from hex strings.
value = value.encode("ascii")
return None if value is None else GEOSGeometryBase(read(value), geom_class)

return converter
Expand Down
2 changes: 1 addition & 1 deletion django/contrib/postgres/fields/array.py
Expand Up @@ -237,7 +237,7 @@ def formfield(self, **kwargs):

class ArrayRHSMixin:
def __init__(self, lhs, rhs):
# Don't wrap arrays that contains only None values, psycopg2 doesn't
# Don't wrap arrays that contains only None values, psycopg doesn't
# allow this.
if isinstance(rhs, (tuple, list)) and any(self._rhs_not_none_values(rhs)):
expressions = []
Expand Down
10 changes: 9 additions & 1 deletion django/contrib/postgres/fields/ranges.py
Expand Up @@ -9,6 +9,7 @@
NumericRange,
Range,
)
from django.db.models.functions import Cast
from django.db.models.lookups import PostgresOperatorLookup

from .utils import AttributeSetter
Expand Down Expand Up @@ -208,7 +209,14 @@ def db_type(self, connection):
return "daterange"


RangeField.register_lookup(lookups.DataContains)
class RangeContains(lookups.DataContains):
def get_prep_lookup(self):
if not isinstance(self.rhs, (list, tuple, Range)):
return Cast(self.rhs, self.lhs.field.base_field)
return super().get_prep_lookup()


RangeField.register_lookup(RangeContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)

Expand Down
4 changes: 4 additions & 0 deletions django/contrib/postgres/operations.py
Expand Up @@ -35,6 +35,10 @@ def database_forwards(self, app_label, schema_editor, from_state, to_state):
# installed, otherwise a subsequent data migration would use the same
# connection.
register_type_handlers(schema_editor.connection)
if hasattr(schema_editor.connection, "register_geometry_adapters"):
schema_editor.connection.register_geometry_adapters(
schema_editor.connection.connection, True
)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
if not router.allow_migrate(schema_editor.connection.alias, app_label):
Expand Down
14 changes: 13 additions & 1 deletion django/contrib/postgres/search.py
Expand Up @@ -39,6 +39,11 @@ def db_type(self, connection):
return "tsquery"


class _Float4Field(Field):
def db_type(self, connection):
return "float4"


class SearchConfig(Expression):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -138,7 +143,11 @@ def as_sql(self, compiler, connection, function=None, template=None):
if clone.weight:
weight_sql, extra_params = compiler.compile(clone.weight)
sql = "setweight({}, {})".format(sql, weight_sql)
return sql, config_params + params + extra_params

# These parameters must be bound on the client side because we may
# want to create an index on this expression.
sql = connection.ops.compose_sql(sql, config_params + params + extra_params)
return sql, []


class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
Expand Down Expand Up @@ -244,6 +253,8 @@ def __init__(
normalization=None,
cover_density=False,
):
from .fields.array import ArrayField

if not hasattr(vector, "resolve_expression"):
vector = SearchVector(vector)
if not hasattr(query, "resolve_expression"):
Expand All @@ -252,6 +263,7 @@ def __init__(
if weights is not None:
if not hasattr(weights, "resolve_expression"):
weights = Value(weights)
weights = Cast(weights, ArrayField(_Float4Field()))
expressions = (weights,) + expressions
if normalization is not None:
if not hasattr(normalization, "resolve_expression"):
Expand Down
77 changes: 48 additions & 29 deletions django/contrib/postgres/signals.py
@@ -1,10 +1,8 @@
import functools

import psycopg2
from psycopg2.extras import register_hstore

from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS
from django.db.backends.postgresql.psycopg_any import is_psycopg3


def get_type_oids(connection_alias, type_name):
Expand Down Expand Up @@ -32,30 +30,51 @@ def get_citext_oids(connection_alias):
return get_type_oids(connection_alias, "citext")


def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return

oids, array_oids = get_hstore_oids(connection.alias)
# Don't register handlers when hstore is not available on the database.
#
# If someone tries to create an hstore field it will error there. This is
# necessary as someone may be using PSQL without extensions installed but
# be using other features of contrib.postgres.
#
# This is also needed in order to create the connection in order to install
# the hstore extension.
if oids:
register_hstore(
connection.connection, globally=True, oid=oids, array_oid=array_oids
)
if is_psycopg3:
from psycopg.types import TypeInfo, hstore

oids, citext_oids = get_citext_oids(connection.alias)
# Don't register handlers when citext is not available on the database.
#
# The same comments in the above call to register_hstore() also apply here.
if oids:
array_type = psycopg2.extensions.new_array_type(
citext_oids, "citext[]", psycopg2.STRING
)
psycopg2.extensions.register_type(array_type, None)
def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return

oids, array_oids = get_hstore_oids(connection.alias)
for oid, array_oid in zip(oids, array_oids):
ti = TypeInfo("hstore", oid, array_oid)
hstore.register_hstore(ti, connection.connection)

_, citext_oids = get_citext_oids(connection.alias)
for array_oid in citext_oids:
ti = TypeInfo("citext", 0, array_oid)
felixxm marked this conversation as resolved.
Show resolved Hide resolved
ti.register(connection.connection)

else:
import psycopg2
from psycopg2.extras import register_hstore

def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return

oids, array_oids = get_hstore_oids(connection.alias)
# Don't register handlers when hstore is not available on the database.
#
# If someone tries to create an hstore field it will error there. This is
# necessary as someone may be using PSQL without extensions installed but
# be using other features of contrib.postgres.
#
# This is also needed in order to create the connection in order to install
# the hstore extension.
if oids:
register_hstore(
connection.connection, globally=True, oid=oids, array_oid=array_oids
)

oids, citext_oids = get_citext_oids(connection.alias)
# Don't register handlers when citext is not available on the database.
#
# The same comments in the above call to register_hstore() also apply here.
if oids:
felixxm marked this conversation as resolved.
Show resolved Hide resolved
array_type = psycopg2.extensions.new_array_type(
citext_oids, "citext[]", psycopg2.STRING
)
psycopg2.extensions.register_type(array_type, None)
2 changes: 1 addition & 1 deletion django/core/management/commands/loaddata.py
Expand Up @@ -207,7 +207,7 @@ def save_obj(self, obj):
self.models.add(obj.object.__class__)
try:
obj.save(using=self.using)
# psycopg2 raises ValueError if data contains NUL chars.
# psycopg raises ValueError if data contains NUL chars.
except (DatabaseError, IntegrityError, ValueError) as e:
e.args = (
"Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s"
Expand Down