Skip to content

Commit

Permalink
Allow pickle protocol to be specified in serializer.
Browse files Browse the repository at this point in the history
Also adds test to deserialization when compression is enabled, to first
check whether the data actually is compressed before decompressing. A
corollary check is **not** added for when compression is disabled (the
default).
  • Loading branch information
coleifer committed Jan 6, 2020
1 parent a40351a commit 49dd3ad
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
23 changes: 20 additions & 3 deletions huey/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
zlib = None
import hashlib
import hmac
import logging
import pickle
import sys

from huey.exceptions import ConfigurationError
from huey.utils import encode


logger = logging.getLogger('huey.serializer')


if gzip is not None:
if sys.version_info[0] > 2:
gzip_compress = gzip.compress
Expand All @@ -39,11 +43,21 @@ def gzip_decompress(data):
fh.close()


if sys.version_info[0] == 2:
def is_compressed(data):
return data and (data[0] == b'\x1f' or data[0] == b'\x78')
else:
def is_compressed(data):
return data and data[0] == 0x1for data[0] == 0x78

This comment has been minimized.

Copy link
@adamchainz

adamchainz Jan 6, 2020

Contributor

I think you missed the space between 0x1f and or ?



class Serializer(object):
def __init__(self, compression=False, compression_level=6, use_zlib=False):
def __init__(self, compression=False, compression_level=6, use_zlib=False,
pickle_protocol=pickle.HIGHEST_PROTOCOL):
self.comp = compression
self.comp_level = compression_level
self.use_zlib = use_zlib
self.pickle_protocol = pickle_protocol or pickle.HIGHEST_PROTOCOL
if self.comp:
if self.use_zlib and zlib is None:
raise ConfigurationError('use_zlib specified, but zlib module '
Expand All @@ -53,7 +67,7 @@ def __init__(self, compression=False, compression_level=6, use_zlib=False):
'compression.')

def _serialize(self, data):
return pickle.dumps(data, pickle.HIGHEST_PROTOCOL)
return pickle.dumps(data, self.pickle_protocol)

def _deserialize(self, data):
return pickle.loads(data)
Expand All @@ -69,7 +83,10 @@ def serialize(self, data):

def deserialize(self, data):
if self.comp:
if self.use_zlib:
if not is_compressed(data):
logger.warning('compression enabled but message data does not '
'appear to be compressed.')
elif self.use_zlib:
data = zlib.decompress(data)
else:
data = gzip_decompress(data)
Expand Down
9 changes: 9 additions & 0 deletions huey/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,12 @@ def test_serializer_gzip(self):
@unittest.skipIf(zlib is None, 'zlib module not installed')
def test_serializer_zlib(self):
self._test_serializer(Serializer(compression=True, use_zlib=True))

@unittest.skipIf(zlib is None, 'zlib module not installed')
@unittest.skipIf(gzip is None, 'gzip module not installed')
def test_mismatched_compression(self):
for use_zlib in (False, True):
s = Serializer()
scomp = Serializer(compression=True, use_zlib=use_zlib)
for item in self.data:
self.assertEqual(scomp.deserialize(s.serialize(item)), item)

0 comments on commit 49dd3ad

Please sign in to comment.