diff --git a/serialize/dill.py b/serialize/dill.py index 5c2eae2..e1d6987 100644 --- a/serialize/dill.py +++ b/serialize/dill.py @@ -12,40 +12,17 @@ """ from . import all +from . import pickle try: import dill - import copyreg except ImportError: all.register_unavailable('dill', pkg='dill') raise - -class DispatchTable: - - def __getitem__(self, item): - if item in all.CLASSES: - return lambda obj: (all.CLASSES[item].from_builtin, - (all.CLASSES[item].to_builtin(obj),), - None, None, None) - - return copyreg.dispatch_table[item] - - def __setitem__(self, key, value): - copyreg.dispatch_table[key] = value - - def get(self, key, default=None): - if key in all.CLASSES: - return lambda obj: (all.CLASSES[key].from_builtin, - (all.CLASSES[key].to_builtin(obj),), - None, None, None) - - return copyreg.dispatch_table.get(key, default) - - class MyPickler(dill.Pickler): - dispatch_table = DispatchTable() + dispatch_table = pickle.DispatchTable() def dump(obj, fp): diff --git a/serialize/pickle.py b/serialize/pickle.py index 300a986..d5f0428 100644 --- a/serialize/pickle.py +++ b/serialize/pickle.py @@ -11,6 +11,8 @@ :license: BSD, see LICENSE for more details. """ +import collections + from . import all try: @@ -21,7 +23,7 @@ raise -class DispatchTable: +class DispatchTable(collections.MutableMapping): def __getitem__(self, item): if item in all.CLASSES: @@ -34,13 +36,14 @@ def __getitem__(self, item): def __setitem__(self, key, value): copyreg.dispatch_table[key] = value - def get(self, key, default=None): - if key in all.CLASSES: - return lambda obj: (all.CLASSES[key].from_builtin, - (all.CLASSES[key].to_builtin(obj),), - None, None, None) + def __delitem__(self, key): + del copyreg.dispatch_table[key] + + def __iter__(self): + return copyreg.dispatch_table.__iter__() - return copyreg.dispatch_table.get(key, default) + def __len__(self): + return copyreg.dispatch_table.__len__() class MyPickler(pickle.Pickler):