Skip to content

Commit

Permalink
Added test for dump and load for file like objects
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed Jan 28, 2016
1 parent 7d80d3a commit 8d34107
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
20 changes: 14 additions & 6 deletions serialize/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
ClassHelper = namedtuple('ClassHelper', 'to_builtin from_builtin')

#: Stores information and function about each format type.
FormatHelper = namedtuple('Format', 'extension dump dumps load loads')
Format = namedtuple('Format', 'extension dump dumps load loads')
UnavailableFormat = namedtuple('UnavailableFormat', 'extension msg')

#: Map unavailable formats to the corresponding error message.
# :type: str -> str
# :type: str -> UnavailableFormat
UNAVAILABLE_FORMATS = {}

#: Map available format names to the corresponding dumper and loader.
# :type: str -> FormatHelper
# :type: str -> Format
FORMATS = {}

#: Map extension to format name.
Expand Down Expand Up @@ -260,13 +261,13 @@ def raiser(*args, **kwargs):
if extension is MISSING:
extension = fmt

FORMATS[fmt] = FormatHelper(extension, dumper, dumpser, loader, loadser)
FORMATS[fmt] = Format(extension, dumper, dumpser, loader, loadser)

if extension:
FORMAT_BY_EXTENSION[extension.lower()] = fmt


def register_unavailable(fmt, msg='', pkg=''):
def register_unavailable(fmt, msg='', pkg='', extension=MISSING):
"""Register an unavailable serialization format.
Unavailable formats are those known by Serialize but that cannot be used
Expand All @@ -276,7 +277,14 @@ def register_unavailable(fmt, msg='', pkg=''):
if pkg:
msg = 'This serialization format requires the %s package.' % pkg

UNAVAILABLE_FORMATS[fmt] = msg

if extension is MISSING:
extension = fmt

FORMATS[fmt] = UnavailableFormat(extension, msg)

if extension:
FORMAT_BY_EXTENSION[extension.lower()] = fmt


def dumps(obj, fmt):
Expand Down
19 changes: 18 additions & 1 deletion serialize/testsuite/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import io
import os
from unittest import TestCase, skipIf

Expand Down Expand Up @@ -69,7 +70,23 @@ class _TestEncoderDecoder:
FMT = None

def _test_round_trip(self, obj):
self.assertEqual(obj, loads(dumps(obj, self.FMT), self.FMT))

dumped = dumps(obj, self.FMT)

# dumps / loads
self.assertEqual(obj, loads(dumped, self.FMT))

buf = io.BytesIO()
dump(obj, buf, self.FMT)

# dump / dumps
self.assertEqual(dumped, buf.getvalue())

buf.seek(0)
# dump / load
self.assertEqual(obj, load(buf, self.FMT))



def test_format_from_ext(self):
if ':' in self.FMT:
Expand Down

0 comments on commit 8d34107

Please sign in to comment.