Skip to content

Commit

Permalink
Tests for dumping and loading to file by name
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed Jan 28, 2016
1 parent ac8b895 commit e518775
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
19 changes: 9 additions & 10 deletions serialize/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def traverse_and_decode(obj, decode_func=None, trav_dict=None):
def register_format(fmt, dumpser=None, loadser=None, dumper=None, loader=None, extension=MISSING):
"""Register an available serialization format.
`fmt` is a unique string identifying the format, such as `json`.
`fmt` is a unique string identifying the format, such as `json`. Use a colon (`:`) to
separate between subformats.
`dumpser` and `dumper` should be callables with the same purpose and arguments
that `json.dumps` and `json.dump`. If one of those is missing, it will be
Expand All @@ -217,9 +218,8 @@ def register_format(fmt, dumpser=None, loadser=None, dumper=None, loader=None, e
generated automatically from the other.
`extension` is the file extension used to guess the desired serialization format when loading
from or dumping to a file. If not given, `fmt` will be used.
If `None`, the format will not be associated with any extension (use for serialization formats
with multiple subformats in which you want to select a default).
from or dumping to a file. If not given, the part before the colon of `fmt` will be used.
If `None`, the format will not be associated with any extension.
"""

# For simplicity. We do not allow to overwrite format.
Expand Down Expand Up @@ -259,11 +259,11 @@ def raiser(*args, **kwargs):
loader = loadser = raiser

if extension is MISSING:
extension = fmt
extension = fmt.split(':', 1)[0]

FORMATS[fmt] = Format(extension, dumper, dumpser, loader, loadser)

if extension:
if extension and extension not in FORMAT_BY_EXTENSION:
FORMAT_BY_EXTENSION[extension.lower()] = fmt


Expand All @@ -277,13 +277,12 @@ def register_unavailable(fmt, msg='', pkg='', extension=MISSING):
if pkg:
msg = 'This serialization format requires the %s package.' % pkg


if extension is MISSING:
extension = fmt
extension = fmt.split(':', 1)[0]

UNAVAILABLE_FORMATS[fmt] = UnavailableFormat(extension, msg)

if extension:
if extension and extension not in FORMAT_BY_EXTENSION:
FORMAT_BY_EXTENSION[extension.lower()] = fmt


Expand Down Expand Up @@ -329,7 +328,7 @@ def load(filename_or_file, fmt=None):
if fmt is None:
_, ext = splitext(filename_or_file)
fmt = _get_format_from_ext(ext.strip('.'))
with open(filename_or_file, 'wb') as fp:
with open(filename_or_file, 'rb') as fp:
return load(fp, fmt)

return _get_format(fmt).load(filename_or_file)
Expand Down
2 changes: 1 addition & 1 deletion serialize/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def loads(content):
# The first (default) is compact, the second is pretty.

all.register_format('json', dumps, loads)
all.register_format('json:pretty', dumps_pretty, loads, extension=None)
all.register_format('json:pretty', dumps_pretty, loads)
23 changes: 21 additions & 2 deletions serialize/testsuite/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from unittest import TestCase, skipIf

from serialize import register_class, loads, dumps, load, dump
from serialize.all import (FORMATS, UNAVAILABLE_FORMATS, _get_format_from_ext,
from serialize.all import (FORMATS, UNAVAILABLE_FORMATS,
_get_format_from_ext, _get_format,
register_format)


Expand Down Expand Up @@ -102,7 +103,25 @@ def _test_round_trip(self, obj):
# dump / load
self.assertEqual(obj, load(buf, self.FMT))


def test_file_by_name(self):
fh = _get_format(self.FMT)
obj = dict(answer=42)

filename1 = 'tmp.' + fh.extension
dump(obj, filename1)
try:
obj1 = load(filename1)
self.assertEqual(obj, obj1)
finally:
os.remove(filename1)

filename2 = 'tmp.' + fh.extension + '.bla'
dump(obj, filename2, fmt=self.FMT)
try:
obj2 = load(filename2, fmt=self.FMT)
self.assertEqual(obj, obj2)
finally:
os.remove(filename2)

def test_format_from_ext(self):
if ':' in self.FMT:
Expand Down

0 comments on commit e518775

Please sign in to comment.