Skip to content

Commit

Permalink
Add custom loads and dumps settings
Browse files Browse the repository at this point in the history
  • Loading branch information
massover committed May 28, 2021
1 parent c6e018e commit a075744
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ django_picklefield.egg-info/*
build/
.coverage
.tox
.tool-versions
15 changes: 15 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ and assign whatever you like (as long as it's picklable) to the field:
>>> obj.args = ['fancy', {'objects': 'inside'}]
>>> obj.save()
--------
Settings
--------

`PICKLEFIELD_DEFAULT_PROTOCOL`

Set the default pickle protocol on dumps encoding

`PICKLEFIELD_DUMPS`

Path to a dumps function, defaults to `pickle.dumps`

`PICKLEFIELD_LOADS`

Path to a loads function, defaults to `pickle.loads`

-----
Notes
Expand Down
22 changes: 21 additions & 1 deletion picklefield/fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import functools
import pickle
from base64 import b64decode, b64encode
from copy import deepcopy
from pickle import dumps, loads
from zlib import compress, decompress

from django.conf import settings
from django.core import checks
from django.db import models
from django.utils.encoding import force_str
from django.utils.module_loading import import_string

from .constants import DEFAULT_PROTOCOL

Expand Down Expand Up @@ -50,6 +52,22 @@ def get_default_protocol():
return getattr(settings, 'PICKLEFIELD_DEFAULT_PROTOCOL', DEFAULT_PROTOCOL)


@functools.lru_cache(maxsize=1)
def get_dumps():
attr = getattr(settings, 'PICKLEFIELD_DUMPS', None)
if attr is None:
return pickle.dumps
return import_string(attr)


@functools.lru_cache(maxsize=1)
def get_loads():
attr = getattr(settings, 'PICKLEFIELD_LOADS', None)
if attr is None:
return pickle.loads
return import_string(attr)


def dbsafe_encode(value, compress_object=False, pickle_protocol=None, copy=True):
# We use deepcopy() here to avoid a problem with cPickle, where dumps
# can generate different character streams for same lookup value if
Expand All @@ -63,6 +81,7 @@ def dbsafe_encode(value, compress_object=False, pickle_protocol=None, copy=True)
# Copy can be very expensive if users aren't going to perform lookups
# on the value anyway.
value = deepcopy(value)
dumps = get_dumps()
value = dumps(value, protocol=pickle_protocol)
if compress_object:
value = compress(value)
Expand All @@ -75,6 +94,7 @@ def dbsafe_decode(value, compress_object=False):
value = b64decode(value)
if compress_object:
value = decompress(value)
loads = get_loads()
return loads(value)


Expand Down
1 change: 1 addition & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import date

from django.db import models

from picklefield import PickledObjectField

S1 = 'Hello World'
Expand Down
33 changes: 32 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import json
import pickle
from datetime import date
from unittest.mock import patch

from django.core import checks, serializers
from django.db import IntegrityError, models
from django.test import SimpleTestCase, TestCase
from django.test.utils import isolate_apps

from picklefield.fields import (
PickledObjectField, dbsafe_encode, wrap_conflictual_object,
PickledObjectField, dbsafe_encode, get_dumps, get_loads,
wrap_conflictual_object,
)

from .models import (
Expand Down Expand Up @@ -251,3 +254,31 @@ class Model(models.Model):
set_field = PickledObjectField(default=set)

self.assertEqual(Model().check(), [])


def custom_fn():
pass


class TestSettings(SimpleTestCase):
def setUp(self):
get_dumps.cache_clear()
get_loads.cache_clear()

def test_get_dumps_defaults_to_pickle(self):
dumps = get_dumps()
self.assertEqual(dumps, pickle.dumps)

def test_get_loads_defaults_to_pickle(self):
loads = get_loads()
self.assertEqual(loads, pickle.loads)

def test_get_dumps_from_setting(self):
with self.settings(PICKLEFIELD_DUMPS='tests.tests.custom_fn'):
dumps = get_dumps()
self.assertEqual(dumps, custom_fn)

def test_get_loads_from_setting(self):
with self.settings(PICKLEFIELD_LOADS='tests.tests.custom_fn'):
loads = get_loads()
self.assertEqual(loads, custom_fn)

0 comments on commit a075744

Please sign in to comment.