Skip to content

Commit

Permalink
more tests for copy_update
Browse files Browse the repository at this point in the history
  • Loading branch information
jerch committed Apr 6, 2022
1 parent 9fa6c79 commit 617cdce
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 44 deletions.
41 changes: 41 additions & 0 deletions example/postgres_tests/migrations/0002_fieldupdatenotnull.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Generated by Django 3.2.12 on 2022-04-06 13:59

import datetime
from decimal import Decimal
from django.db import migrations, models
import uuid


class Migration(migrations.Migration):

dependencies = [
('postgres_tests', '0001_initial'),
]

operations = [
migrations.CreateModel(
name='FieldUpdateNotNull',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('f_biginteger', models.BigIntegerField(default=0)),
('f_binary', models.BinaryField(default=b'')),
('f_boolean', models.BooleanField(default=False)),
('f_char', models.CharField(default='', max_length=32)),
('f_date', models.DateField(default=datetime.date(1000, 1, 1))),
('f_datetime', models.DateTimeField(default=datetime.datetime(1000, 1, 1, 0, 0))),
('f_decimal', models.DecimalField(decimal_places=2, default=Decimal('0'), max_digits=10)),
('f_duration', models.DurationField(default=datetime.timedelta(1))),
('f_email', models.EmailField(default='', max_length=254)),
('f_float', models.FloatField(default=0.0)),
('f_integer', models.IntegerField(default=0)),
('f_ip', models.GenericIPAddressField(default='')),
('f_json', models.JSONField(default=dict)),
('f_slug', models.SlugField(default='')),
('f_smallinteger', models.SmallIntegerField(default=0)),
('f_text', models.TextField(default='')),
('f_time', models.TimeField(default=datetime.time(1, 0))),
('f_url', models.URLField(default='')),
('f_uuid', models.UUIDField(default=uuid.UUID('5e0b7757-b58e-40c2-b305-a35912b4a43a'))),
],
),
]
31 changes: 31 additions & 0 deletions example/postgres_tests/migrations/0003_auto_20220406_1404.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Generated by Django 3.2.12 on 2022-04-06 14:04

import datetime
from django.db import migrations, models
from django.utils.timezone import utc
import uuid


class Migration(migrations.Migration):

dependencies = [
('postgres_tests', '0002_fieldupdatenotnull'),
]

operations = [
migrations.AlterField(
model_name='fieldupdatenotnull',
name='f_datetime',
field=models.DateTimeField(default=datetime.datetime(1000, 1, 1, 0, 0, tzinfo=utc)),
),
migrations.AlterField(
model_name='fieldupdatenotnull',
name='f_ip',
field=models.GenericIPAddressField(default='127.0.0.1'),
),
migrations.AlterField(
model_name='fieldupdatenotnull',
name='f_uuid',
field=models.UUIDField(default=uuid.UUID('cb36b1fa-976a-47ae-bf8b-42efa4fa59de')),
),
]
28 changes: 28 additions & 0 deletions example/postgres_tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from django.contrib.postgres.fields import (ArrayField, HStoreField,
IntegerRangeField, DateTimeRangeField, DateRangeField)
from fast_update.query import FastUpdateManager
from datetime import date, datetime, timedelta, time
from decimal import Decimal
from uuid import uuid4
import pytz

fixed_uuid = uuid4()

class PostgresFields(models.Model):
objects = FastUpdateManager()
Expand All @@ -14,3 +19,26 @@ class PostgresFields(models.Model):
int_r = IntegerRangeField(null=True)
dt_r = DateTimeRangeField(null=True)
date_r = DateRangeField(null=True)


class FieldUpdateNotNull(models.Model):
objects = FastUpdateManager()
f_biginteger = models.BigIntegerField(default=0)
f_binary = models.BinaryField(default=b'')
f_boolean = models.BooleanField(default=False)
f_char = models.CharField(max_length=32, default='')
f_date = models.DateField(default=date(1000, 1, 1))
f_datetime = models.DateTimeField(default=datetime(1000,1,1,0,0,0, tzinfo=pytz.UTC))
f_decimal = models.DecimalField(max_digits=10, decimal_places=2, default=Decimal(0))
f_duration = models.DurationField(default=timedelta(days=1))
f_email = models.EmailField(default='')
f_float = models.FloatField(default=0.0)
f_integer = models.IntegerField(default=0)
f_ip = models.GenericIPAddressField(default='127.0.0.1')
f_json = models.JSONField(default=dict)
f_slug = models.SlugField(default='')
f_smallinteger = models.SmallIntegerField(default=0)
f_text = models.TextField(default='')
f_time = models.TimeField(default=time(1, 0, 0))
f_url = models.URLField(default='')
f_uuid = models.UUIDField(default=fixed_uuid)
53 changes: 51 additions & 2 deletions example/postgres_tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.test import TestCase
from .models import PostgresFields
from .models import PostgresFields, FieldUpdateNotNull
from exampleapp.models import FieldUpdate
from psycopg2.extras import NumericRange, DateTimeTZRange, DateRange
import datetime
Expand Down Expand Up @@ -181,61 +181,89 @@ def _single(self, fieldname):
res_a, res_b = FieldUpdate.objects.all().values(fieldname)
self.assertEqual(res_b[fieldname], res_a[fieldname])

def _single_raise(self, fieldname, wrong_value, msg):
a = FieldUpdate.objects.create()
setattr(a, fieldname, wrong_value)
self.assertRaisesMessage(TypeError, msg, lambda : FieldUpdate.objects.copy_update([a], [fieldname]))

def test_biginteger(self):
self._single('f_biginteger')
self._single_raise('f_biginteger', 'wrong', 'expected int or NoneType')


def test_binary(self):
self._single('f_binary')
# test big binary data
data = b'1234567890' * 10000
obj = FieldUpdate.objects.all().first()
obj.f_binary = data
FieldUpdate.objects.copy_update([obj], ['f_binary'])
self.assertEqual(FieldUpdate.objects.get(pk=obj.pk).f_binary.tobytes(), data)
self._single_raise('f_binary', 'wrong', 'expected memoryview, bytes or NoneType')

def test_boolean(self):
self._single('f_boolean')
self._single_raise('f_boolean', 'wrong', 'expected bool or NoneType')

def test_char(self):
self._single('f_char')
self._single_raise('f_char', 123, 'expected str or NoneType')

def test_date(self):
self._single('f_date')
self._single_raise('f_date', 'wrong', 'expected datetime.date or NoneType')

def test_datetime(self):
self._single('f_datetime')
self._single_raise('f_datetime', 'wrong', 'expected datetime or NoneType')

def test_decimal(self):
self._single('f_decimal')
self._single_raise('f_decimal', 'wrong', 'expected Decimal or NoneType')

def test_duration(self):
self._single('f_duration')
self._single_raise('f_duration', 'wrong', 'expected timedelta or NoneType')

def test_email(self):
self._single('f_email')
self._single_raise('f_email', 123, 'expected str or NoneType')

def test_float(self):
self._single('f_float')
self._single_raise('f_float', 'wrong', 'expected float, int or NoneType')

def test_integer(self):
self._single('f_integer')
self._single_raise('f_integer', 'wrong', 'expected int or NoneType')

def test_ip(self):
self._single('f_ip')
self._single_raise('f_ip', 123, 'expected str or NoneType')

def test_json(self):
self._single('f_json')

def test_slug(self):
self._single('f_slug')
self._single_raise('f_slug', 123, 'expected str or NoneType')

def test_text(self):
self._single('f_text')
self._single_raise('f_text', 123, 'expected str or NoneType')

def test_time(self):
self._single('f_time')
self._single_raise('f_time', 'wrong', 'expected datetime.time or NoneType')

def test_uuid(self):
self._single('f_uuid')
self._single_raise('f_uuid', 'wrong', 'expected UUID or NoneType')

def test_updatefull_multiple(self):
a = []
b = []
for _ in range(100):
for _ in range(1000):
a.append(FieldUpdate.objects.create())
b.append(FieldUpdate.objects.create())
update_a = []
Expand All @@ -251,3 +279,24 @@ def test_updatefull_multiple(self):
for r in results[1:]:
for f in CU_FIELDS:
self.assertEqual(r[f], first[f])

class TestCopyUpdateNotNull(TestCase):
def test_updatefull_multiple(self):
a = []
b = []
for _ in range(1000):
a.append(FieldUpdateNotNull.objects.create())
b.append(FieldUpdateNotNull.objects.create())
update_a = []
for _a in a:
update_a.append(FieldUpdateNotNull(pk=_a.pk, **CU_EXAMPLE))
update_b = []
for _b in b:
update_b.append(FieldUpdateNotNull(pk=_b.pk, **CU_EXAMPLE))
FieldUpdateNotNull.objects.bulk_update(update_a, CU_FIELDS)
FieldUpdateNotNull.objects.copy_update(update_b, CU_FIELDS)
results = list(FieldUpdateNotNull.objects.all().values(*CU_FIELDS))
first = results[0]
for r in results[1:]:
for f in CU_FIELDS:
self.assertEqual(r[f], first[f])
59 changes: 17 additions & 42 deletions fast_update/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,13 @@
# TODO: copy encoders and array impl from playground
# TODO: tons of tests...


def AsNone(v, lazy):
"""Treat field value as ``None`` converted to NULL."""
return NULL


def AsIs(v, lazy):
def textEscape(v):
"""
Field value passed along unchecked.
Can be used for a performance gain, if all provided field values
are known to correctly translate into Postgres' COPY TEXT format
by python's string formatting. That is sometimes the case for values,
where type checking/narrowing happened in an earlier step.
Python types known to work that way are: int, float, ...
Nullish field values may use ``AsIsOrNone`` instead.
When used for string values, make sure that the strings never contain
characters, that need explicit escaping in the TEXT format.
Escape str-like data for postgres' TEXT format.
"""
# FIXME: move description above to top level docs
return v


def AsIsOrNone(v, lazy):
"""Same as ``AsIs``, additionally handling ``None`` as NULL."""
if v is None:
return NULL
return v
return (v.replace('\\', '\\\\')
.replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n')
.replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v'))


def Int(v, lazy):
Expand Down Expand Up @@ -116,21 +95,21 @@ def BooleanOrNone(v, lazy):
return NULL
if isinstance(v, bool):
return v
raise TypeError('expected bool type or NoneType')
raise TypeError('expected bool or NoneType')


def Date(v, lazy):
if isinstance(v, date):
return v
raise TypeError('expected date type')
raise TypeError('expected datetime.date type')


def DateOrNone(v, lazy):
if v is None:
return NULL
if isinstance(v, date):
return v
raise TypeError('expected date type or NoneType')
raise TypeError('expected datetime.date or NoneType')


def Datetime(v, lazy):
Expand All @@ -144,7 +123,7 @@ def DatetimeOrNone(v, lazy):
return NULL
if isinstance(v, datetime):
return v
raise TypeError('expected datetime type or NoneType')
raise TypeError('expected datetime or NoneType')


def Numeric(v, lazy):
Expand All @@ -158,7 +137,7 @@ def NumericOrNone(v, lazy):
return NULL
if isinstance(v, Decimal):
return v
raise TypeError('expected Decimal type or NoneType')
raise TypeError('expected Decimal or NoneType')


def Duration(v, lazy):
Expand All @@ -172,7 +151,7 @@ def DurationOrNone(v, lazy):
return NULL
if isinstance(v, timedelta):
return v
raise TypeError('expected timedelta type or NoneType')
raise TypeError('expected timedelta or NoneType')


def Float(v, lazy):
Expand All @@ -190,13 +169,13 @@ def FloatOrNone(v, lazy):


def Json(v, lazy):
return Text(dumps(v), lazy)
return textEscape(dumps(v))


def JsonOrNone(v, lazy):
if v is None:
return NULL
return Text(dumps(v), lazy)
return textEscape(dumps(v))


def Text(v, lazy):
Expand All @@ -207,9 +186,7 @@ def Text(v, lazy):
for the TEXT format of COPY FROM.
"""
if isinstance(v, str):
return (v.replace('\\', '\\\\')
.replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n')
.replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v'))
return textEscape(v)
raise TypeError('expected str type')


Expand All @@ -218,9 +195,7 @@ def TextOrNone(v, lazy):
if v is None:
return NULL
if isinstance(v, str):
return (v.replace('\\', '\\\\')
.replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n')
.replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v'))
return textEscape(v)
raise TypeError('expected str or NoneType')


Expand All @@ -235,7 +210,7 @@ def TimeOrNone(v, lazy):
return NULL
if isinstance(v, dt_time):
return v
raise TypeError('expected datetime.time type or NoneType')
raise TypeError('expected datetime.time or NoneType')


def Uuid(v, lazy):
Expand All @@ -249,7 +224,7 @@ def UuidOrNone(v, lazy):
return NULL
if isinstance(v, UUID):
return v
raise TypeError('expected UUID type or NoneType')
raise TypeError('expected UUID or NoneType')


ENCODERS = {
Expand Down

0 comments on commit 617cdce

Please sign in to comment.