Skip to content

Commit

Permalink
Added support for pathlib.Path
Browse files Browse the repository at this point in the history
  • Loading branch information
maurosilber committed Aug 22, 2020
1 parent 5891627 commit f8afb94
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
32 changes: 18 additions & 14 deletions serialize/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from collections import namedtuple
from io import BytesIO
from os.path import splitext
from pathlib import Path

#: Stores the functions to convert custom classes to and from builtin types.
ClassHelper = namedtuple('ClassHelper', 'to_builtin from_builtin')
Expand Down Expand Up @@ -293,21 +293,23 @@ def dumps(obj, fmt):
return _get_format(fmt).dumps(obj)


def dump(obj, filename_or_file, fmt=None):
def dump(obj, file, fmt=None):
"""Serialize `obj` to a file using the format specified by `fmt`
The file can be specified by a file-like object or filename.
In the latter case the fmt is not need if it can be guessed from the extension.
"""
if isinstance(file, str):
file = Path(file)

if isinstance(filename_or_file, str):
if isinstance(file, Path):
if fmt is None:
_, ext = splitext(filename_or_file)
fmt = _get_format_from_ext(ext.strip('.'))
with open(filename_or_file, 'wb') as fp:
fmt = _get_format_from_ext(file.suffix.lstrip('.'))

with open(file, 'wb') as fp:
dump(obj, fp, fmt)
else:
_get_format(fmt).dump(obj, filename_or_file)
_get_format(fmt).dump(obj, file)


def loads(serialized, fmt):
Expand All @@ -317,21 +319,23 @@ def loads(serialized, fmt):
return _get_format(fmt).loads(serialized)


def load(filename_or_file, fmt=None):
def load(file, fmt=None):
"""Deserialize from a file using the format specified by `fmt`
The file can be specified by a file-like object or filename.
In the latter case the fmt is not need if it can be guessed from the extension.
"""
if isinstance(file, str):
file = Path(file)

if isinstance(filename_or_file, str):
if isinstance(file, Path):
if fmt is None:
_, ext = splitext(filename_or_file)
fmt = _get_format_from_ext(ext.strip('.'))
with open(filename_or_file, 'rb') as fp:
return load(fp, fmt)
fmt = _get_format_from_ext(file.suffix.lstrip('.'))

return _get_format(fmt).load(filename_or_file)
with open(file, 'rb') as fp:
return load(fp, fmt)
else:
return _get_format(fmt).load(file)


def register_class(klass, to_builtin, from_builtin):
Expand Down
9 changes: 9 additions & 0 deletions serialize/testsuite/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import io
import os
from pathlib import Path
from unittest import TestCase, skipIf

from serialize import register_class, loads, dumps, load, dump
Expand Down Expand Up @@ -123,6 +124,14 @@ def test_file_by_name(self):
finally:
os.remove(filename2)

filename3 = Path("tmp." + fh.extension)
dump(obj, filename3)
try:
obj1 = load(filename3)
self.assertEqual(obj, obj1)
finally:
os.remove(filename3)

def test_format_from_ext(self):
if ':' in self.FMT:
return
Expand Down

0 comments on commit f8afb94

Please sign in to comment.