Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove compatiblity shims and allow default pickle override through a setting. #52

Merged
merged 4 commits into from
Jun 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ Changes
UNRELEASED
==========

* Allowed default pickle protocol to be overriden using the
`PICKLEFIELD_DEFAULT_PROTOCOL` setting.
* Dropped support for Python 2.
* Updated default pickle protocol to version 3.
* Added testing against Django 3.0.
* Dropped support for Django 1.11.

Expand Down
4 changes: 1 addition & 3 deletions picklefield/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from __future__ import unicode_literals

DEFAULT_PROTOCOL = 3
DEFAULT_PROTOCOL = 2
45 changes: 26 additions & 19 deletions picklefield/fields.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
from __future__ import unicode_literals

from base64 import b64decode, b64encode
from copy import deepcopy
from pickle import dumps, loads
from zlib import compress, decompress

from django import VERSION as DJANGO_VERSION
from django.conf import settings
from django.core import checks
from django.db import models
from django.utils.encoding import force_text
from django.utils.encoding import force_str

from .constants import DEFAULT_PROTOCOL

try:
from cPickle import loads, dumps # pragma: no cover
except ImportError:
from pickle import loads, dumps # pragma: no cover


class PickledObject(str):
"""
Expand Down Expand Up @@ -52,13 +46,19 @@ def wrap_conflictual_object(obj):
return obj


def dbsafe_encode(value, compress_object=False, pickle_protocol=DEFAULT_PROTOCOL, copy=True):
def get_default_protocol():
return getattr(settings, 'PICKLEFIELD_DEFAULT_PROTOCOL', DEFAULT_PROTOCOL)


def dbsafe_encode(value, compress_object=False, pickle_protocol=None, copy=True):
# We use deepcopy() here to avoid a problem with cPickle, where dumps
# can generate different character streams for same lookup value if
# they are referenced differently.
# The reason this is important is because we do all of our lookups as
# simple string matches, thus the character streams must be the same
# for the lookups to work properly. See tests.py for more information.
if pickle_protocol is None:
pickle_protocol = get_default_protocol()
if copy:
# Copy can be very expensive if users aren't going to perform lookups
# on the value anyway.
Expand Down Expand Up @@ -92,7 +92,10 @@ class PickledObjectField(models.Field):

def __init__(self, *args, **kwargs):
self.compress = kwargs.pop('compress', False)
self.protocol = kwargs.pop('protocol', DEFAULT_PROTOCOL)
protocol = kwargs.pop('protocol', None)
if protocol is None:
protocol = get_default_protocol()
self.protocol = protocol
self.copy = kwargs.pop('copy', True)
kwargs.setdefault('editable', False)
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -143,6 +146,14 @@ def check(self, **kwargs):
errors.extend(self._check_default())
return errors

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.compress:
kwargs['compress'] = True
if self.protocol != get_default_protocol():
kwargs['protocol'] = self.protocol
return name, path, args, kwargs

def to_python(self, value):
"""
B64decode and unpickle the object, optionally decompressing it.
Expand Down Expand Up @@ -170,12 +181,8 @@ def pre_save(self, model_instance, add):
value = super().pre_save(model_instance, add)
return wrap_conflictual_object(value)

if DJANGO_VERSION < (2, 0):
def from_db_value(self, value, expression, connection, context): # pragma: no cover
return self.to_python(value) # pragma: no cover
else:
def from_db_value(self, value, expression, connection): # pragma: no cover
return self.to_python(value) # pragma: no cover
def from_db_value(self, value, expression, connection):
return self.to_python(value)

def get_db_prep_value(self, value, connection=None, prepared=False):
"""
Expand All @@ -189,13 +196,13 @@ def get_db_prep_value(self, value, connection=None, prepared=False):

"""
if value is not None and not isinstance(value, PickledObject):
# We call force_text here explicitly, so that the encoded string
# We call force_str here explicitly, so that the encoded string
# isn't rejected by the postgresql_psycopg2 backend. Alternatively,
# we could have just registered PickledObject with the psycopg
# marshaller (telling it to store it like it would a string), but
# since both of these methods result in the same value being stored,
# doing things this way is much easier.
value = force_text(dbsafe_encode(value, self.compress, self.protocol, self.copy))
value = force_str(dbsafe_encode(value, self.compress, self.protocol, self.copy))
return value

def value_to_string(self, obj):
Expand Down
2 changes: 0 additions & 2 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import unicode_literals

from datetime import date

from django.db import models
Expand Down
2 changes: 0 additions & 2 deletions tests/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import unicode_literals

SECRET_KEY = 'not-anymore'

DATABASES = {
Expand Down
19 changes: 13 additions & 6 deletions tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from datetime import date
from unittest.mock import patch

from django.core import checks, serializers
from django.db import IntegrityError, models
Expand All @@ -14,18 +15,13 @@
TestCustomDataType, TestingModel,
)

try:
from unittest.mock import patch # pragma: no cover
except ImportError:
from mock import patch # pragma: no cover


class PickledObjectFieldTests(TestCase):
def setUp(self):
self.testing_data = (D2, S1, T1, L1,
TestCustomDataType(S1),
MinimalTestingModel)
return super(PickledObjectFieldTests, self).setUp()
return super().setUp()

def test_data_integrity(self):
"""
Expand Down Expand Up @@ -203,6 +199,17 @@ def mock_decode_error(*args, **kwargs):
self.assertEqual(encoded_value, MinimalTestingModel.objects.get(pk=model.pk).pickle_field)


class PickledObjectFieldDeconstructTests(SimpleTestCase):
def test_protocol(self):
field = PickledObjectField()
self.assertNotIn('protocol', field.deconstruct()[3])
with self.settings(PICKLEFIELD_DEFAULT_PROTOCOL=3):
field = PickledObjectField(protocol=4)
self.assertEqual(field.deconstruct()[3].get('protocol'), 4)
field = PickledObjectField(protocol=3)
self.assertNotIn('protocol', field.deconstruct()[3])


@isolate_apps('tests')
class PickledObjectFieldCheckTests(SimpleTestCase):
def test_mutable_default_check(self):
Expand Down