Skip to content

Commit

Permalink
Make compatible with psycopg 3.
Browse files Browse the repository at this point in the history
  • Loading branch information
hwalinga committed Apr 19, 2023
1 parent 64f06b3 commit 772cfa8
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 29 deletions.
57 changes: 40 additions & 17 deletions example/postgres_tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,41 @@
from django.db import connection

import unittest
if connection.vendor != 'postgresql' or connection.Database.__version__ > '3':
raise unittest.SkipTest('postgres with pscopg2 only tests')
if connection.vendor != 'postgresql':
raise unittest.SkipTest('postgres only tests')

from django.test import TestCase
from .models import PostgresFields, FieldUpdateNotNull, CustomField, FieldUpdateArray, TestCoverage
from exampleapp.models import FieldUpdate, MultiSub, Child, Parent
from psycopg2.extras import NumericRange, DateTimeTZRange, DateRange
import datetime
import pytz
import uuid
from decimal import Decimal
from fast_update.copy import get_encoder, register_fieldclass, Int, IntOrNone, array_factory
import json

if connection.Database.__version__ > '3':
from psycopg.types.range import Range

# For psycopg 3, Range is adapted automatically, see documentation:
# https://www.psycopg.org/psycopg3/docs/basic/pgtypes.html#range-adaptation
# The built-in range objects are adapted automatically:
# if a Range object contains date bounds,
# it is dumped using the daterange OID,
# and of course daterange values are loaded back as Range[date].

DateRange, DateTimeTZRange, NumericRange = Range, Range, Range
else:
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange


def tobytes(val):
    """Normalize a binary field value to ``bytes``.

    psycopg 3 returns binary columns as plain ``bytes`` while psycopg2
    returns ``memoryview``; materialize the latter via ``tobytes()`` and
    pass anything else through unchanged.
    """
    converter = getattr(val, 'tobytes', None)
    if converter is None:
        return val
    return converter()


dt = datetime.datetime.now()
dt_utc = datetime.datetime.now(tz=pytz.UTC)
Expand Down Expand Up @@ -203,25 +224,25 @@ def test_biginteger(self):
def test_binary(self):
self._single('f_binary')
self._single_raise('f_binary', 'wrong', "expected types <class 'memoryview'>, <class 'bytes'> or None")

def test_binary_big(self):
# >64k
data = b'1234567890' * 10000
obj = FieldUpdate.objects.create()
obj.f_binary = data
FieldUpdate.objects.copy_update([obj], ['f_binary'])
self.assertEqual(FieldUpdate.objects.get(pk=obj.pk).f_binary.tobytes(), data)
self.assertEqual(tobytes(FieldUpdate.objects.get(pk=obj.pk).f_binary), data)
# <64k
data = b'1234567890' * 1000
obj = FieldUpdate.objects.create()
obj.f_binary = data
FieldUpdate.objects.copy_update([obj], ['f_binary'])
self.assertEqual(FieldUpdate.objects.get(pk=obj.pk).f_binary.tobytes(), data)
self.assertEqual(tobytes(FieldUpdate.objects.get(pk=obj.pk).f_binary), data)

def test_boolean(self):
self._single('f_boolean')
self._single_raise('f_boolean', 'wrong', "expected type <class 'bool'> or None")

def test_char(self):
self._single('f_char')
self._single_raise('f_char', 123, "expected type <class 'str'> or None")
Expand Down Expand Up @@ -302,17 +323,17 @@ def test_big_lazy(self):
obj.f_binary = b'0' * 100000
obj.f_text = 'x' * 100000
FieldUpdate.objects.copy_update([obj], ['f_binary', 'f_text'])
self.assertEqual(FieldUpdate.objects.get(pk=obj.pk).f_binary.tobytes(), b'0' * 100000)
self.assertEqual(tobytes(FieldUpdate.objects.get(pk=obj.pk).f_binary), b'0' * 100000)
self.assertEqual(FieldUpdate.objects.get(pk=obj.pk).f_text, 'x' * 100000)

def test_lazy_after_big(self):
obj1 = FieldUpdate.objects.create()
obj1.f_text = 'x' * 70000
obj2 = FieldUpdate.objects.create()
obj2.f_binary = b'0' * 100000
FieldUpdate.objects.copy_update([obj1, obj2], ['f_binary', 'f_text'])
self.assertEqual(FieldUpdate.objects.get(pk=obj1.pk).f_text, 'x' * 70000)
self.assertEqual(FieldUpdate.objects.get(pk=obj2.pk).f_binary.tobytes(), b'0' * 100000)
self.assertEqual(tobytes(FieldUpdate.objects.get(pk=obj2.pk).f_binary), b'0' * 100000)


class TestCopyUpdateNotNull(TestCase):
Expand Down Expand Up @@ -342,25 +363,25 @@ def test_biginteger(self):
def test_binary(self):
self._single('f_binary')
self._single_raise('f_binary', 'wrong', "expected types <class 'memoryview'> or <class 'bytes'>")

def test_binary_big(self):
# >64k
data = b'1234567890' * 10000
obj = FieldUpdateNotNull.objects.create()
obj.f_binary = data
FieldUpdateNotNull.objects.copy_update([obj], ['f_binary'])
self.assertEqual(FieldUpdateNotNull.objects.get(pk=obj.pk).f_binary.tobytes(), data)
self.assertEqual(tobytes(FieldUpdateNotNull.objects.get(pk=obj.pk).f_binary), data)
# <64k
data = b'1234567890' * 1000
obj = FieldUpdateNotNull.objects.create()
obj.f_binary = data
FieldUpdateNotNull.objects.copy_update([obj], ['f_binary'])
self.assertEqual(FieldUpdateNotNull.objects.get(pk=obj.pk).f_binary.tobytes(), data)
self.assertEqual(tobytes(FieldUpdateNotNull.objects.get(pk=obj.pk).f_binary), data)

def test_boolean(self):
self._single('f_boolean')
self._single_raise('f_boolean', 'wrong', "expected type <class 'bool'>")

def test_char(self):
self._single('f_char')
self._single_raise('f_char', 123, "expected type <class 'str'>")
Expand Down Expand Up @@ -490,7 +511,7 @@ def test_local_nonlocal_mixed(self):
list(MultiSub.objects.all().values_list('b1', 'b2', 's1', 's2').order_by('pk')),
[(i, i*10, i*100, i*1000) for i in range(10)]
)

def test_nonlocal_only(self):
objs = [MultiSub.objects.create() for _ in range(10)]
for i, obj in enumerate(objs):
Expand Down Expand Up @@ -805,6 +826,7 @@ def test_singles(self):
res_a, res_b = FieldUpdateArray.objects.all().values(fieldname)
self.assertEqual(res_b[fieldname], res_a[fieldname])

@unittest.skipIf(connection.Database.__version__ > '3', "psycopg 3 does not perform array reduction")
def test_2d(self):
for fieldname, values in ARRAY_SINGLES.items():
fieldname += '2'
Expand Down Expand Up @@ -953,6 +975,7 @@ def test_override_array(self):


class TestArrayEvaluation(TestCase):
@unittest.skipIf(connection.Database.__version__ > '3', "psycopg 3 does not perform array reduction")
def test_empty_reduction(self):
a = FieldUpdateArray.objects.create()
b = FieldUpdateArray.objects.create()
Expand Down Expand Up @@ -1016,7 +1039,7 @@ def test_hstore_array(self):
hstore={}, int_r=NumericRange(1,8), int_2d=[], hstore_2d=[], int_r_2d=[])
b = TestCoverage.objects.create(
hstore={}, int_r=NumericRange(1,8), int_2d=[], hstore_2d=[], int_r_2d=[])

# 1d
a.hstore_2d = value
b.hstore_2d = value
Expand Down Expand Up @@ -1044,7 +1067,7 @@ def test_range_array(self):
hstore={}, int_r=NumericRange(1,8), int_2d=[], hstore_2d=[], int_r_2d=[])
b = TestCoverage.objects.create(
hstore={}, int_r=NumericRange(1,8), int_2d=[], hstore_2d=[], int_r_2d=[])

# 1d
a.int_r_2d = value
b.int_r_2d = value
Expand Down
34 changes: 26 additions & 8 deletions fast_update/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from django.db.models.fields.related import RelatedField
from django.contrib.postgres.fields import (HStoreField, ArrayField, IntegerRangeField,
BigIntegerRangeField, DecimalRangeField, DateTimeRangeField, DateRangeField)
from psycopg2.extras import Range

from django.db import connection
if connection.Database.__version__ > '3':
from psycopg.types.range import Range
else:
from psycopg2.extras import Range

# typings imports
from django.db.backends.utils import CursorWrapper
Expand Down Expand Up @@ -83,6 +88,13 @@ class FieldEncoder(Generic[EncoderProto]):
}


def get_encoding(conn) -> str:
    """Return the client encoding of *conn* for either psycopg generation.

    psycopg 3 exposes it directly on ``conn.info``; psycopg2 reports a
    postgres-side encoding name that is translated through
    ``CONNECTION_ENCODINGS`` first.
    """
    # NOTE(review): the driver generation is detected by a lexicographic
    # version-string compare, matching the convention used elsewhere in
    # this file.
    is_psycopg3 = connection.Database.__version__ > '3'
    if not is_psycopg3:
        return CONNECTION_ENCODINGS[conn.encoding]
    return conn.info.encoding


# NULL placeholder for COPY FROM
NULL = '\\N'
# SQL NULL
Expand Down Expand Up @@ -161,7 +173,7 @@ def some_encoder(value: Any, fname: str, lazy: List[Any]) -> ReturnType: ...
the field ``fname``. ``lazy`` is a context helper object for lazy encoders
(see Lazy Encoder below).
The return type should be a string in postgres' TEXT format. For proper escaping
The return type should be a string in postgres' TEXT format. For proper escaping
the helper ``text_escape`` can be used. For None/nullish values ``NULL`` is predefined.
It is also possible to return a python type directly, if it is known to translate
correctly into the TEXT format from ``__str__`` output (some default encoders
Expand Down Expand Up @@ -426,7 +438,7 @@ def JsonOrNone(v: Any, fname: str, lazy: List[Any]):
def Text(v: Any, fname: str, lazy: List[Any]):
"""
Test and encode ``str``, raise for any other.
The encoder escapes characters as denoted in the postgres documentation
for the TEXT format of COPY FROM.
"""
Expand Down Expand Up @@ -762,13 +774,19 @@ def write_lazy(f: BinaryIO, data: bytearray, stack: List[Any]) -> None:
f.write(m[idx:])


def compat_copy_from(
    c: CursorWrapper,
    fr: BinaryIO,
    tname: str,
    columns: Tuple[str, ...]
) -> None:
    """A copy_from operation compatible between psycopg2 and psycopg 3.

    :param c:       database cursor to run the COPY on
    :param fr:      binary stream providing rows in postgres TEXT format
    :param tname:   unquoted target table name
    :param columns: unquoted target column names
    """
    if not connection.Database.__version__ > '3':
        # psycopg2 ships a dedicated helper; 64k buffer matches the writer.
        c.copy_from(fr, tname, size=65536, columns=columns)
    else:
        # psycopg 3 replaced copy_from with a cursor.copy() context manager
        # that streams arbitrary chunks to COPY ... FROM STDIN.
        # Quote the identifiers so reserved words / mixed-case names work,
        # consistent with the quoted temp-table DDL in copy_update().
        cols = ','.join(f'"{col}"' for col in columns)
        with c.copy(f'COPY "{tname}" ({cols}) FROM STDIN') as copy:
            # read in 64k chunks to mirror the psycopg2 branch's buffer size
            while data := fr.read(65536):
                copy.write(data)


def copy_from(
Expand Down Expand Up @@ -800,7 +818,7 @@ def copy_from(
fr = os.fdopen(r, 'rb')
fw = os.fdopen(w, 'wb')
t = Thread(
target=threaded_copy,
target=compat_copy_from,
args=[c.connection.cursor(), fr, tname, columns]
)
t.start()
Expand Down Expand Up @@ -839,7 +857,7 @@ def copy_from(
f.seek(0)
else:
f = BytesIO(payload)
c.copy_from(f, tname, size=65536, columns=columns)
compat_copy_from(c, f, tname, columns)
f.close()


Expand Down Expand Up @@ -905,7 +923,7 @@ def copy_update(
c.execute(f'DROP TABLE IF EXISTS "{temp}"')
c.execute(f'CREATE TEMPORARY TABLE "{temp}" ({create_columns(column_def)})')
copy_from(c, temp, objs, attnames, colnames, get, encs,
encoding or CONNECTION_ENCODINGS[c.connection.encoding])
encoding or get_encoding(c.connection))
# optimization (~6x speedup in ./manage.py perf for 10 instances):
# for small changesets ANALYZE is much more expensive than
# a sequential scan of the temp table
Expand Down
7 changes: 3 additions & 4 deletions fast_update/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,16 @@ def copy_update(
"""
self._for_write = True
connection = connections[self.db]
if connection.vendor != 'postgresql' or connection.Database.__version__ > '3':
raise NotSupportedError(
f'copy_update() only supported on "postgres" backend with psycopg2')
if connection.vendor != 'postgresql':
raise NotSupportedError(f'copy_update() is not supported on "{connection.vendor}" backend')
from .copy import copy_update # TODO: better in conditional import?
if not objs:
return 0
objs = tuple(objs)
fields_ = set(fields or [])
sanity_check(self.model, objs, fields_, 'copy_update()')
return copy_update(self, objs, fields_, field_encoders, encoding)

copy_update.alters_data = True


Expand Down

0 comments on commit 772cfa8

Please sign in to comment.