diff --git a/README.rst b/README.rst index 828f64c..d68f300 100644 --- a/README.rst +++ b/README.rst @@ -165,8 +165,9 @@ Changes UNRELEASED ========== +* Allowed default pickle protocol to be overriden using the + `PICKLEFIELD_DEFAULT_PROTOCOL` setting. * Dropped support for Python 2. -* Updated default pickle protocol to version 3. * Added testing against Django 3.0. * Dropped support for Django 1.11. diff --git a/picklefield/constants.py b/picklefield/constants.py index e8232d2..dc6d5c1 100644 --- a/picklefield/constants.py +++ b/picklefield/constants.py @@ -1,3 +1 @@ -from __future__ import unicode_literals - -DEFAULT_PROTOCOL = 3 +DEFAULT_PROTOCOL = 2 diff --git a/picklefield/fields.py b/picklefield/fields.py index 1e87690..10f0bf1 100644 --- a/picklefield/fields.py +++ b/picklefield/fields.py @@ -1,21 +1,15 @@ -from __future__ import unicode_literals - from base64 import b64decode, b64encode from copy import deepcopy +from pickle import dumps, loads from zlib import compress, decompress -from django import VERSION as DJANGO_VERSION +from django.conf import settings from django.core import checks from django.db import models -from django.utils.encoding import force_text +from django.utils.encoding import force_str from .constants import DEFAULT_PROTOCOL -try: - from cPickle import loads, dumps # pragma: no cover -except ImportError: - from pickle import loads, dumps # pragma: no cover - class PickledObject(str): """ @@ -52,13 +46,19 @@ def wrap_conflictual_object(obj): return obj -def dbsafe_encode(value, compress_object=False, pickle_protocol=DEFAULT_PROTOCOL, copy=True): +def get_default_protocol(): + return getattr(settings, 'PICKLEFIELD_DEFAULT_PROTOCOL', DEFAULT_PROTOCOL) + + +def dbsafe_encode(value, compress_object=False, pickle_protocol=None, copy=True): # We use deepcopy() here to avoid a problem with cPickle, where dumps # can generate different character streams for same lookup value if # they are referenced differently. # The reason this is important is because we do all of our lookups as # simple string matches, thus the character streams must be the same # for the lookups to work properly. See tests.py for more information. + if pickle_protocol is None: + pickle_protocol = get_default_protocol() if copy: # Copy can be very expensive if users aren't going to perform lookups # on the value anyway. @@ -92,7 +92,10 @@ class PickledObjectField(models.Field): def __init__(self, *args, **kwargs): self.compress = kwargs.pop('compress', False) - self.protocol = kwargs.pop('protocol', DEFAULT_PROTOCOL) + protocol = kwargs.pop('protocol', None) + if protocol is None: + protocol = get_default_protocol() + self.protocol = protocol self.copy = kwargs.pop('copy', True) kwargs.setdefault('editable', False) super().__init__(*args, **kwargs) @@ -143,6 +146,14 @@ def check(self, **kwargs): errors.extend(self._check_default()) return errors + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if self.compress: + kwargs['compress'] = True + if self.protocol != get_default_protocol(): + kwargs['protocol'] = self.protocol + return name, path, args, kwargs + def to_python(self, value): """ B64decode and unpickle the object, optionally decompressing it. @@ -170,12 +181,8 @@ def pre_save(self, model_instance, add): value = super().pre_save(model_instance, add) return wrap_conflictual_object(value) - if DJANGO_VERSION < (2, 0): - def from_db_value(self, value, expression, connection, context): # pragma: no cover - return self.to_python(value) # pragma: no cover - else: - def from_db_value(self, value, expression, connection): # pragma: no cover - return self.to_python(value) # pragma: no cover + def from_db_value(self, value, expression, connection): + return self.to_python(value) def get_db_prep_value(self, value, connection=None, prepared=False): """ @@ -189,13 +196,13 @@ def get_db_prep_value(self, value, connection=None, prepared=False): """ if value is not None and not isinstance(value, PickledObject): - # We call force_text here explicitly, so that the encoded string + # We call force_str here explicitly, so that the encoded string # isn't rejected by the postgresql_psycopg2 backend. Alternatively, # we could have just registered PickledObject with the psycopg # marshaller (telling it to store it like it would a string), but # since both of these methods result in the same value being stored, # doing things this way is much easier. - value = force_text(dbsafe_encode(value, self.compress, self.protocol, self.copy)) + value = force_str(dbsafe_encode(value, self.compress, self.protocol, self.copy)) return value def value_to_string(self, obj): diff --git a/tests/models.py b/tests/models.py index 3c80b6e..eca1dbf 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - from datetime import date from django.db import models diff --git a/tests/settings.py b/tests/settings.py index e2a67fa..6911ad4 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - SECRET_KEY = 'not-anymore' DATABASES = { diff --git a/tests/tests.py b/tests/tests.py index 1bb42f8..88be048 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,5 +1,6 @@ import json from datetime import date +from unittest.mock import patch from django.core import checks, serializers from django.db import IntegrityError, models @@ -14,18 +15,13 @@ TestCustomDataType, TestingModel, ) -try: - from unittest.mock import patch # pragma: no cover -except ImportError: - from mock import patch # pragma: no cover - class PickledObjectFieldTests(TestCase): def setUp(self): self.testing_data = (D2, S1, T1, L1, TestCustomDataType(S1), MinimalTestingModel) - return super(PickledObjectFieldTests, self).setUp() + return super().setUp() def test_data_integrity(self): """ @@ -203,6 +199,17 @@ def mock_decode_error(*args, **kwargs): self.assertEqual(encoded_value, MinimalTestingModel.objects.get(pk=model.pk).pickle_field) +class PickledObjectFieldDeconstructTests(SimpleTestCase): + def test_protocol(self): + field = PickledObjectField() + self.assertNotIn('protocol', field.deconstruct()[3]) + with self.settings(PICKLEFIELD_DEFAULT_PROTOCOL=3): + field = PickledObjectField(protocol=4) + self.assertEqual(field.deconstruct()[3].get('protocol'), 4) + field = PickledObjectField(protocol=3) + self.assertNotIn('protocol', field.deconstruct()[3]) + + @isolate_apps('tests') class PickledObjectFieldCheckTests(SimpleTestCase): def test_mutable_default_check(self):