Skip to content

Commit

Permalink
Updated implementation to accept serializers from parameters to match…
Browse files Browse the repository at this point in the history
… Djangos improved interface.
  • Loading branch information
mitsuhiko committed Jun 29, 2011
1 parent 6648bc0 commit f5a2158
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
36 changes: 31 additions & 5 deletions itsdangerous.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,27 @@ class Serializer(object):
:meth:`dump_payload` functions.
This implementation uses simplejson for dumping and loading.
.. versionchanged:: 0.
"""
default_serializer = simplejson

def __init__(self, secret_key, salt='itsdangerous'):
def __init__(self, secret_key, salt='itsdangerous', serializer=None):
self.secret_key = secret_key
self.salt = salt
if serializer is None:
serializer = self.default_serializer
self.serializer = serializer

def load_payload(self, payload):
"""Loads the encoded object. This implementation uses simplejson."""
return simplejson.loads(payload)
return self.serializer.loads(payload)

def dump_payload(self, obj):
"""Dumps the encoded object into a bytestring. This implementation
uses simplejson.
"""
return simplejson.dumps(obj)
return self.serializer.dumps(obj)

def make_signer(self):
"""A method that creates a new instance of the signer to be used.
Expand Down Expand Up @@ -268,6 +274,10 @@ def loads(self, s, max_age=None, return_timestamp=False):


class URLSafeSerializerMixin(object):
"""Mixed in with a regular serializer it will attempt to zlib compress
the string to make it shorter if necessary. It will also base64 encode
the string so that it can safely be placed in a URL.
"""

def load_payload(self, payload):
decompress = False
Expand All @@ -277,10 +287,10 @@ def load_payload(self, payload):
json = base64_decode(payload)
if decompress:
json = zlib.decompress(json)
return simplejson.loads(json)
return super(URLSafeSerializerMixin, self).load_payload(json)

def dump_payload(self, obj):
json = simplejson.dumps(obj, separators=(',', ':'))
json = super(URLSafeSerializerMixin, self).dump_payload(obj)
is_compressed = False
compressed = zlib.compress(json)
if len(compressed) < (len(json) - 1):
Expand All @@ -292,15 +302,31 @@ def dump_payload(self, obj):
return base64d


class _CompactJSON(object):
"""Wrapper around simplejson that strips whitespace.
"""

def loads(self, payload):
return simplejson.loads(payload)

def dumps(self, obj):
return simplejson.dumps(obj, separators=(',', ':'))


compact_json = _CompactJSON()


class URLSafeSerializer(URLSafeSerializerMixin, Serializer):
"""Works like :class:`Serializer` but dumps and loads into a URL
safe string consisting of the upper and lowercase character of the
alphabet as well as ``'_'``, ``'-'`` and ``'.'``.
"""
default_serializer = compact_json


class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer):
"""Works like :class:`TimedSerializer` but dumps and loads into a URL
safe string consisting of the upper and lowercase character of the
alphabet as well as ``'_'``, ``'-'`` and ``'.'``.
"""
default_serializer = compact_json
37 changes: 32 additions & 5 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import time
import pickle
import unittest
import itsdangerous as idmod


class SerializerTestCase(unittest.TestCase):
serializer_class = idmod.Serializer

def make_serializer(self, *args, **kwargs):
return self.serializer_class(*args, **kwargs)

def test_dumps_loads(self):
objects = (['a', 'list'], 'a string', u'a unicode string \u2019',
{'a': 'dictionary'}, 42, 42.5)
s = self.serializer_class('Test')
s = self.make_serializer('Test')
for o in objects:
value = s.dumps(o)
self.assertNotEqual(o, value)
Expand All @@ -26,7 +30,7 @@ def test_decode_detects_tampering(self):
'foo': 'bar',
'baz': 1,
}
s = self.serializer_class('Test')
s = self.make_serializer('Test')
encoded = s.dumps(value)
self.assertEqual(value, s.loads(encoded))
for transform in transforms:
Expand All @@ -36,7 +40,7 @@ def test_decode_detects_tampering(self):
def test_accepts_unicode(self):
objects = (['a', 'list'], 'a string', u'a unicode string \u2019',
{'a': 'dictionary'}, 42, 42.5)
s = self.serializer_class('Test')
s = self.make_serializer('Test')
for o in objects:
value = s.dumps(o)
self.assertNotEqual(o, value)
Expand All @@ -52,7 +56,7 @@ def test_decode_with_timeout(self):
_time = time.time
time.time = lambda: idmod.EPOCH
try:
s = self.serializer_class(secret_key)
s = self.make_serializer(secret_key)
ts = s.dumps(value)
self.assertNotEqual(ts, idmod.Serializer(secret_key).dumps(value))

Expand All @@ -74,14 +78,21 @@ def test_is_base62(self):
'ABCDEFGHIJKLMNOPQRSTUVWXYZ_-.')
objects = (['a', 'list'], 'a string', u'a unicode string \u2019',
{'a': 'dictionary'}, 42, 42.5)
s = self.serializer_class('Test')
s = self.make_serializer('Test')
for o in objects:
value = s.dumps(o)
self.assert_(set(value).issubset(set(allowed)))
self.assertNotEqual(o, value)
self.assertEqual(o, s.loads(value))


class PickleSerializerMixin(object):

def make_serializer(self, *args, **kwargs):
kwargs.setdefault('serializer', pickle)
return super(PickleSerializerMixin, self).make_serializer(*args, **kwargs)


class URLSafeSerializerTestCase(URLSafeSerializerMixin, SerializerTestCase):
serializer_class = idmod.URLSafeSerializer

Expand All @@ -90,5 +101,21 @@ class URLSafeTimedSerializerTestCase(URLSafeSerializerMixin, TimedSerializerTest
serializer_class = idmod.URLSafeTimedSerializer


class PickleSerializerTestCase(PickleSerializerMixin, SerializerTestCase):
pass


class PickleTimedSerializerTestCase(PickleSerializerMixin, TimedSerializerTestCase):
pass


class PickleURLSafeSerializerTestCase(PickleSerializerMixin, URLSafeSerializerTestCase):
pass


class PickleTimedSerializerTestCase(PickleSerializerMixin, URLSafeTimedSerializerTestCase):
pass


if __name__ == '__main__':
unittest.main()

0 comments on commit f5a2158

Please sign in to comment.