From 617cdcee4ebd468aae0dbf45e14999130b0902da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Breitbart?= Date: Wed, 6 Apr 2022 18:29:29 +0200 Subject: [PATCH] more tests for copy_update --- .../migrations/0002_fieldupdatenotnull.py | 41 +++++++++++++ .../migrations/0003_auto_20220406_1404.py | 31 ++++++++++ example/postgres_tests/models.py | 28 +++++++++ example/postgres_tests/tests.py | 53 ++++++++++++++++- fast_update/copy.py | 59 ++++++------------- 5 files changed, 168 insertions(+), 44 deletions(-) create mode 100644 example/postgres_tests/migrations/0002_fieldupdatenotnull.py create mode 100644 example/postgres_tests/migrations/0003_auto_20220406_1404.py diff --git a/example/postgres_tests/migrations/0002_fieldupdatenotnull.py b/example/postgres_tests/migrations/0002_fieldupdatenotnull.py new file mode 100644 index 0000000..261c677 --- /dev/null +++ b/example/postgres_tests/migrations/0002_fieldupdatenotnull.py @@ -0,0 +1,41 @@ +# Generated by Django 3.2.12 on 2022-04-06 13:59 + +import datetime +from decimal import Decimal +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('postgres_tests', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='FieldUpdateNotNull', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('f_biginteger', models.BigIntegerField(default=0)), + ('f_binary', models.BinaryField(default=b'')), + ('f_boolean', models.BooleanField(default=False)), + ('f_char', models.CharField(default='', max_length=32)), + ('f_date', models.DateField(default=datetime.date(1000, 1, 1))), + ('f_datetime', models.DateTimeField(default=datetime.datetime(1000, 1, 1, 0, 0))), + ('f_decimal', models.DecimalField(decimal_places=2, default=Decimal('0'), max_digits=10)), + ('f_duration', models.DurationField(default=datetime.timedelta(1))), + ('f_email', models.EmailField(default='', max_length=254)), + ('f_float', models.FloatField(default=0.0)), + ('f_integer', models.IntegerField(default=0)), + ('f_ip', models.GenericIPAddressField(default='')), + ('f_json', models.JSONField(default=dict)), + ('f_slug', models.SlugField(default='')), + ('f_smallinteger', models.SmallIntegerField(default=0)), + ('f_text', models.TextField(default='')), + ('f_time', models.TimeField(default=datetime.time(1, 0))), + ('f_url', models.URLField(default='')), + ('f_uuid', models.UUIDField(default=uuid.UUID('5e0b7757-b58e-40c2-b305-a35912b4a43a'))), + ], + ), + ] diff --git a/example/postgres_tests/migrations/0003_auto_20220406_1404.py b/example/postgres_tests/migrations/0003_auto_20220406_1404.py new file mode 100644 index 0000000..465ee95 --- /dev/null +++ b/example/postgres_tests/migrations/0003_auto_20220406_1404.py @@ -0,0 +1,31 @@ +# Generated by Django 3.2.12 on 2022-04-06 14:04 + +import datetime +from django.db import migrations, models +from django.utils.timezone import utc +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('postgres_tests', '0002_fieldupdatenotnull'), + ] + + operations = [ + migrations.AlterField( + model_name='fieldupdatenotnull', + name='f_datetime', + field=models.DateTimeField(default=datetime.datetime(1000, 1, 1, 0, 0, tzinfo=utc)), + ), + migrations.AlterField( + model_name='fieldupdatenotnull', + name='f_ip', + field=models.GenericIPAddressField(default='127.0.0.1'), + ), + migrations.AlterField( + model_name='fieldupdatenotnull', + name='f_uuid', + field=models.UUIDField(default=uuid.UUID('cb36b1fa-976a-47ae-bf8b-42efa4fa59de')), + ), + ] diff --git a/example/postgres_tests/models.py b/example/postgres_tests/models.py index 4bffb2f..c0fd929 100644 --- a/example/postgres_tests/models.py +++ b/example/postgres_tests/models.py @@ -2,7 +2,12 @@ from django.contrib.postgres.fields import (ArrayField, HStoreField, IntegerRangeField, DateTimeRangeField, DateRangeField) from fast_update.query import FastUpdateManager +from datetime import date, datetime, timedelta, time +from decimal import Decimal +from uuid import uuid4 +import pytz +fixed_uuid = uuid4() class PostgresFields(models.Model): objects = FastUpdateManager() @@ -14,3 +19,26 @@ class PostgresFields(models.Model): int_r = IntegerRangeField(null=True) dt_r = DateTimeRangeField(null=True) date_r = DateRangeField(null=True) + + +class FieldUpdateNotNull(models.Model): + objects = FastUpdateManager() + f_biginteger = models.BigIntegerField(default=0) + f_binary = models.BinaryField(default=b'') + f_boolean = models.BooleanField(default=False) + f_char = models.CharField(max_length=32, default='') + f_date = models.DateField(default=date(1000, 1, 1)) + f_datetime = models.DateTimeField(default=datetime(1000,1,1,0,0,0, tzinfo=pytz.UTC)) + f_decimal = models.DecimalField(max_digits=10, decimal_places=2, default=Decimal(0)) + f_duration = models.DurationField(default=timedelta(days=1)) + f_email = models.EmailField(default='') + f_float = models.FloatField(default=0.0) + f_integer = models.IntegerField(default=0) + f_ip = models.GenericIPAddressField(default='127.0.0.1') + f_json = models.JSONField(default=dict) + f_slug = models.SlugField(default='') + f_smallinteger = models.SmallIntegerField(default=0) + f_text = models.TextField(default='') + f_time = models.TimeField(default=time(1, 0, 0)) + f_url = models.URLField(default='') + f_uuid = models.UUIDField(default=fixed_uuid) diff --git a/example/postgres_tests/tests.py b/example/postgres_tests/tests.py index 6be97b2..9b56e10 100644 --- a/example/postgres_tests/tests.py +++ b/example/postgres_tests/tests.py @@ -1,5 +1,5 @@ from django.test import TestCase -from .models import PostgresFields +from .models import PostgresFields, FieldUpdateNotNull from exampleapp.models import FieldUpdate from psycopg2.extras import NumericRange, DateTimeTZRange, DateRange import datetime @@ -181,61 +181,89 @@ def _single(self, fieldname): res_a, res_b = FieldUpdate.objects.all().values(fieldname) self.assertEqual(res_b[fieldname], res_a[fieldname]) + def _single_raise(self, fieldname, wrong_value, msg): + a = FieldUpdate.objects.create() + setattr(a, fieldname, wrong_value) + self.assertRaisesMessage(TypeError, msg, lambda : FieldUpdate.objects.copy_update([a], [fieldname])) + def test_biginteger(self): self._single('f_biginteger') + self._single_raise('f_biginteger', 'wrong', 'expected int or NoneType') + def test_binary(self): self._single('f_binary') + # test big binary data + data = b'1234567890' * 10000 + obj = FieldUpdate.objects.all().first() + obj.f_binary = data + FieldUpdate.objects.copy_update([obj], ['f_binary']) + self.assertEqual(FieldUpdate.objects.get(pk=obj.pk).f_binary.tobytes(), data) + self._single_raise('f_binary', 'wrong', 'expected memoryview, bytes or NoneType') def test_boolean(self): self._single('f_boolean') + self._single_raise('f_boolean', 'wrong', 'expected bool or NoneType') def test_char(self): self._single('f_char') + self._single_raise('f_char', 123, 'expected str or NoneType') def test_date(self): self._single('f_date') + self._single_raise('f_date', 'wrong', 'expected datetime.date or NoneType') def test_datetime(self): self._single('f_datetime') + self._single_raise('f_datetime', 'wrong', 'expected datetime or NoneType') def test_decimal(self): self._single('f_decimal') + self._single_raise('f_decimal', 'wrong', 'expected Decimal or NoneType') def test_duration(self): self._single('f_duration') + self._single_raise('f_duration', 'wrong', 'expected timedelta or NoneType') def test_email(self): self._single('f_email') + self._single_raise('f_email', 123, 'expected str or NoneType') def test_float(self): self._single('f_float') + self._single_raise('f_float', 'wrong', 'expected float, int or NoneType') def test_integer(self): self._single('f_integer') + self._single_raise('f_integer', 'wrong', 'expected int or NoneType') def test_ip(self): self._single('f_ip') + self._single_raise('f_ip', 123, 'expected str or NoneType') def test_json(self): self._single('f_json') def test_slug(self): self._single('f_slug') + self._single_raise('f_slug', 123, 'expected str or NoneType') def test_text(self): self._single('f_text') + self._single_raise('f_text', 123, 'expected str or NoneType') def test_time(self): self._single('f_time') + self._single_raise('f_time', 'wrong', 'expected datetime.time or NoneType') def test_uuid(self): self._single('f_uuid') + self._single_raise('f_uuid', 'wrong', 'expected UUID or NoneType') def test_updatefull_multiple(self): a = [] b = [] - for _ in range(100): + for _ in range(1000): a.append(FieldUpdate.objects.create()) b.append(FieldUpdate.objects.create()) update_a = [] @@ -251,3 +279,24 @@ def test_updatefull_multiple(self): for r in results[1:]: for f in CU_FIELDS: self.assertEqual(r[f], first[f]) + +class TestCopyUpdateNotNull(TestCase): + def test_updatefull_multiple(self): + a = [] + b = [] + for _ in range(1000): + a.append(FieldUpdateNotNull.objects.create()) + b.append(FieldUpdateNotNull.objects.create()) + update_a = [] + for _a in a: + update_a.append(FieldUpdateNotNull(pk=_a.pk, **CU_EXAMPLE)) + update_b = [] + for _b in b: + update_b.append(FieldUpdateNotNull(pk=_b.pk, **CU_EXAMPLE)) + FieldUpdateNotNull.objects.bulk_update(update_a, CU_FIELDS) + FieldUpdateNotNull.objects.copy_update(update_b, CU_FIELDS) + results = list(FieldUpdateNotNull.objects.all().values(*CU_FIELDS)) + first = results[0] + for r in results[1:]: + for f in CU_FIELDS: + self.assertEqual(r[f], first[f]) diff --git a/fast_update/copy.py b/fast_update/copy.py index e5c791a..588a355 100644 --- a/fast_update/copy.py +++ b/fast_update/copy.py @@ -18,34 +18,13 @@ # TODO: copy encoders and array impl from playground # TODO: tons of tests... - -def AsNone(v, lazy): - """Treat field value as ``None`` converted to NULL.""" - return NULL - - -def AsIs(v, lazy): +def textEscape(v): """ - Field value passed along unchecked. - - Can be used for a performance gain, if all provided field values - are known to correctly translate into Postgres' COPY TEXT format - by python's string formatting. That is sometimes the case for values, - where type checking/narrowing happened in an earlier step. - Python types known to work that way are: int, float, ... - Nullish field values may use ``AsIsOrNone`` instead. - When used for string values, make sure that the strings never contain - characters, that need explicit escaping in the TEXT format. + Escape str-like data for postgres' TEXT format. """ - # FIXME: move description above to top level docs - return v - - -def AsIsOrNone(v, lazy): - """Same as ``AsIs``, additionally handling ``None`` as NULL.""" - if v is None: - return NULL - return v + return (v.replace('\\', '\\\\') + .replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n') + .replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v')) def Int(v, lazy): @@ -116,13 +95,13 @@ def BooleanOrNone(v, lazy): return NULL if isinstance(v, bool): return v - raise TypeError('expected bool type or NoneType') + raise TypeError('expected bool or NoneType') def Date(v, lazy): if isinstance(v, date): return v - raise TypeError('expected date type') + raise TypeError('expected datetime.date type') def DateOrNone(v, lazy): @@ -130,7 +109,7 @@ def DateOrNone(v, lazy): return NULL if isinstance(v, date): return v - raise TypeError('expected date type or NoneType') + raise TypeError('expected datetime.date or NoneType') def Datetime(v, lazy): @@ -144,7 +123,7 @@ def DatetimeOrNone(v, lazy): return NULL if isinstance(v, datetime): return v - raise TypeError('expected datetime type or NoneType') + raise TypeError('expected datetime or NoneType') def Numeric(v, lazy): @@ -158,7 +137,7 @@ def NumericOrNone(v, lazy): return NULL if isinstance(v, Decimal): return v - raise TypeError('expected Decimal type or NoneType') + raise TypeError('expected Decimal or NoneType') def Duration(v, lazy): @@ -172,7 +151,7 @@ def DurationOrNone(v, lazy): return NULL if isinstance(v, timedelta): return v - raise TypeError('expected timedelta type or NoneType') + raise TypeError('expected timedelta or NoneType') def Float(v, lazy): @@ -190,13 +169,13 @@ def FloatOrNone(v, lazy): def Json(v, lazy): - return Text(dumps(v), lazy) + return textEscape(dumps(v)) def JsonOrNone(v, lazy): if v is None: return NULL - return Text(dumps(v), lazy) + return textEscape(dumps(v)) def Text(v, lazy): @@ -207,9 +186,7 @@ def Text(v, lazy): for the TEXT format of COPY FROM. """ if isinstance(v, str): - return (v.replace('\\', '\\\\') - .replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n') - .replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v')) + return textEscape(v) raise TypeError('expected str type') @@ -218,9 +195,7 @@ def TextOrNone(v, lazy): if v is None: return NULL if isinstance(v, str): - return (v.replace('\\', '\\\\') - .replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n') - .replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v')) + return textEscape(v) raise TypeError('expected str or NoneType') @@ -235,7 +210,7 @@ def TimeOrNone(v, lazy): return NULL if isinstance(v, dt_time): return v - raise TypeError('expected datetime.time type or NoneType') + raise TypeError('expected datetime.time or NoneType') def Uuid(v, lazy): @@ -249,7 +224,7 @@ def UuidOrNone(v, lazy): return NULL if isinstance(v, UUID): return v - raise TypeError('expected UUID type or NoneType') + raise TypeError('expected UUID or NoneType') ENCODERS = {