diff --git a/serialize/all.py b/serialize/all.py index b10c87b..9798946 100644 --- a/serialize/all.py +++ b/serialize/all.py @@ -12,7 +12,7 @@ from collections import namedtuple from io import BytesIO -from os.path import splitext +from pathlib import Path #: Stores the functions to convert custom classes to and from builtin types. ClassHelper = namedtuple('ClassHelper', 'to_builtin from_builtin') @@ -293,21 +293,23 @@ def dumps(obj, fmt): return _get_format(fmt).dumps(obj) -def dump(obj, filename_or_file, fmt=None): +def dump(obj, file, fmt=None): """Serialize `obj` to a file using the format specified by `fmt` The file can be specified by a file-like object or filename. In the latter case the fmt is not need if it can be guessed from the extension. """ + if isinstance(file, str): + file = Path(file) - if isinstance(filename_or_file, str): + if isinstance(file, Path): if fmt is None: - _, ext = splitext(filename_or_file) - fmt = _get_format_from_ext(ext.strip('.')) - with open(filename_or_file, 'wb') as fp: + fmt = _get_format_from_ext(file.suffix.lstrip('.')) + + with open(file, 'wb') as fp: dump(obj, fp, fmt) else: - _get_format(fmt).dump(obj, filename_or_file) + _get_format(fmt).dump(obj, file) def loads(serialized, fmt): @@ -317,21 +319,23 @@ def loads(serialized, fmt): return _get_format(fmt).loads(serialized) -def load(filename_or_file, fmt=None): +def load(file, fmt=None): """Deserialize from a file using the format specified by `fmt` The file can be specified by a file-like object or filename. In the latter case the fmt is not need if it can be guessed from the extension. """ + if isinstance(file, str): + file = Path(file) - if isinstance(filename_or_file, str): + if isinstance(file, Path): if fmt is None: - _, ext = splitext(filename_or_file) - fmt = _get_format_from_ext(ext.strip('.')) - with open(filename_or_file, 'rb') as fp: - return load(fp, fmt) + fmt = _get_format_from_ext(file.suffix.lstrip('.')) - return _get_format(fmt).load(filename_or_file) + with open(file, 'rb') as fp: + return load(fp, fmt) + else: + return _get_format(fmt).load(file) def register_class(klass, to_builtin, from_builtin): diff --git a/serialize/testsuite/test_basic.py b/serialize/testsuite/test_basic.py index 0b65420..9411768 100644 --- a/serialize/testsuite/test_basic.py +++ b/serialize/testsuite/test_basic.py @@ -1,6 +1,7 @@ import io import os +from pathlib import Path from unittest import TestCase, skipIf from serialize import register_class, loads, dumps, load, dump @@ -123,6 +124,14 @@ def test_file_by_name(self): finally: os.remove(filename2) + filename3 = Path("tmp." + fh.extension) + dump(obj, filename3) + try: + obj1 = load(filename3) + self.assertEqual(obj, obj1) + finally: + os.remove(filename3) + def test_format_from_ext(self): if ':' in self.FMT: return