Skip to content

Commit

Permalink
Merge pull request #654 from matthewcarbone/mc-dev-653
Browse files Browse the repository at this point in the history
Add save and load functionality to MSONable
  • Loading branch information
shyuep committed Apr 11, 2024
2 parents 2a2391b + d619737 commit dbaba61
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 14 deletions.
179 changes: 165 additions & 14 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import pathlib
import pickle
import traceback
import types
from collections import OrderedDict, defaultdict
Expand All @@ -14,7 +15,8 @@
from importlib import import_module
from inspect import getfullargspec
from pathlib import Path
from uuid import UUID
from typing import Any, Dict
from uuid import UUID, uuid4

try:
import numpy as np
Expand Down Expand Up @@ -139,12 +141,12 @@ class MSONable:
class MSONClass(MSONable):
def __init__(self, a, b, c, d=1, **kwargs):
self.a = a
self.b = b
self._c = c
self._d = d
self.kwargs = kwargs
def __init__(self, a, b, c, d=1, **kwargs):
self.a = a
self.b = b
self._c = c
self._d = d
self.kwargs = kwargs
For such classes, you merely need to inherit from MSONable and you do not
need to implement your own as_dict or from_dict protocol.
Expand Down Expand Up @@ -222,17 +224,27 @@ def recursive_as_dict(obj):
d.update({"value": self.value}) # pylint: disable=E1101
return d

@staticmethod
def decoded_from_dict(d, name_object_map):
    """
    Decode the values of a dict representation, skipping "@"-prefixed keys.

    :param d: Dict representation whose values should be decoded.
    :param name_object_map: Mapping handed to the decoder as its
        ``_name_object_map``; may be None.
    :return: Dict of decoded values keyed by the non-"@" keys of ``d``.
    """
    decoder = MontyDecoder()
    decoder._name_object_map = name_object_map

    decoded = {}
    for key, value in d.items():
        # Keys starting with "@" carry metadata, not constructor arguments.
        if key.startswith("@"):
            continue
        decoded[key] = decoder.process_decoded(value)
    return decoded

@classmethod
def _from_dict(cls, d, name_object_map):
    """
    Build an instance of ``cls`` from a dict representation, forwarding
    ``name_object_map`` to the decoding step.

    :param d: Dict representation.
    :param name_object_map: Mapping passed through to
        ``MSONable.decoded_from_dict``.
    :return: Instance of ``cls`` built from the decoded values.
    """
    return cls(**MSONable.decoded_from_dict(d, name_object_map=name_object_map))

@classmethod
def from_dict(cls, d):
    """
    Build an instance of ``cls`` from its dict representation.

    :param d: Dict representation.
    :return: MSONable class.
    """
    # Decode exactly once via the shared helper. No name/object map is
    # supplied here, so the decoder's ``_name_object_map`` is set to None.
    decoded = MSONable.decoded_from_dict(d, name_object_map=None)
    return cls(**decoded)

def to_json(self) -> str:
Expand All @@ -241,6 +253,111 @@ def to_json(self) -> str:
"""
return json.dumps(self, cls=MontyEncoder)

def save(
    self,
    save_dir=None,
    mkdir=True,
    pickle_kwargs=None,
    json_kwargs=None,
    return_results=False,
    strict=True,
):
    """Encode this object to JSON and optionally write it to disk.

    The object is encoded with a MontyEncoder created with
    ``allow_unserializable_objects=True``. The encoder collects objects in
    its name -> object map, and that map is pickled alongside the JSON.

    When ``save_dir`` is given, two files are written:

    * ``{save_dir}/class.json`` -- the encoded JSON string.
    * ``{save_dir}/class.pkl`` -- the pickled name -> object map of the
      encoder.

    Parameters
    ----------
    save_dir : os.PathLike
        The directory to which to save the class information. May be
        None only if ``return_results`` is True.
    mkdir : bool
        If True, makes the provided directory, including all parent
        directories.
    pickle_kwargs : dict
        Keyword arguments to pass to pickle.dump.
    json_kwargs : dict
        Keyword arguments to pass to the MontyEncoder constructor.
    return_results : bool
        If True, also returns the encoded JSON string, as well as the
        mapping between the object references and the objects
        themselves.
    strict : bool
        If True, will not allow you to overwrite existing files.

    Returns
    -------
    None or tuple
        ``(encoded, name_object_map)`` if ``return_results`` is True,
        otherwise None.

    Raises
    ------
    ValueError
        If ``save_dir`` is None and ``return_results`` is False.
    FileExistsError
        If ``strict`` is True and class.json or class.pkl already exists
        in ``save_dir``.
    """

    if save_dir is None and not return_results:
        raise ValueError("save_dir must be set and/or return_results must be True")

    if pickle_kwargs is None:
        pickle_kwargs = {}
    if json_kwargs is None:
        json_kwargs = {}
    encoder = MontyEncoder(allow_unserializable_objects=True, **json_kwargs)
    encoded = encoder.encode(self)

    if save_dir is not None:
        save_dir = Path(save_dir)
        if mkdir:
            save_dir.mkdir(exist_ok=True, parents=True)
        json_path = save_dir / "class.json"
        pickle_path = save_dir / "class.pkl"
        # Check both targets before writing either, so a strict failure
        # never leaves a half-written pair behind.
        if strict and json_path.exists():
            raise FileExistsError(f"strict is true and file {json_path} exists")
        if strict and pickle_path.exists():
            raise FileExistsError(f"strict is true and file {pickle_path} exists")

        with open(json_path, "w") as outfile:
            outfile.write(encoded)
        # Use a context manager so the pickle file is flushed and closed
        # even if pickle.dump raises.
        with open(pickle_path, "wb") as outfile:
            pickle.dump(encoder._name_object_map, outfile, **pickle_kwargs)

    if return_results:
        return encoded, encoder._name_object_map
    return None

@classmethod
def load(cls, load_dir):
    """Loads a class from the {load_dir}/class.json and
    {load_dir}/class.pkl files. Both files are read unconditionally.

    Parameters
    ----------
    load_dir : os.PathLike
        The directory from which to reload the class from.

    Returns
    -------
    MSONable
        An instance of the class being reloaded.

    Raises
    ------
    FileNotFoundError
        If class.json or class.pkl does not exist in ``load_dir``.
    """

    load_dir = Path(load_dir)

    json_path = load_dir / "class.json"
    pickle_path = load_dir / "class.pkl"

    with open(json_path, "r") as infile:
        d = json.load(infile)

    # SECURITY: unpickling can execute arbitrary code. Only call load() on
    # directories whose contents come from a trusted source.
    # The context manager guarantees the file handle is closed even if
    # unpickling fails.
    with open(pickle_path, "rb") as infile:
        name_object_map = pickle.load(infile)

    decoded = MSONable.decoded_from_dict(d, name_object_map)
    return cls(**decoded)

def unsafe_hash(self):
"""
Returns an hash of the current object. This uses a generic but low
Expand Down Expand Up @@ -365,6 +482,18 @@ class MontyEncoder(json.JSONEncoder):
json.dumps(object, cls=MontyEncoder)
"""

def __init__(self, *args, allow_unserializable_objects=False, **kwargs):
    """
    Set up the encoder and its bookkeeping for object references.

    :param args: Positional arguments forwarded to the parent encoder.
    :param allow_unserializable_objects: Stored as
        ``_track_unserializable_objects``; defaults to False.
    :param kwargs: Keyword arguments forwarded to the parent encoder.
    """
    super().__init__(*args, **kwargs)
    # Running counter used when generating reference names.
    self._index = 0
    # Reference name -> original object.
    self._name_object_map: Dict[str, Any] = {}
    self._track_unserializable_objects = allow_unserializable_objects

def _update_name_object_map(self, o):
    """
    Register ``o`` under a freshly generated name and return a reference.

    The name is the current index zero-padded to 12 digits, a dash, and a
    random UUID4. The index is advanced by one on every call.

    :param o: Object to store in the name -> object map.
    :return: Dict of the form ``{"@object_reference": name}``.
    """
    position = self._index
    reference = "{:012}-{}".format(position, uuid4())
    self._index = position + 1
    self._name_object_map[reference] = o
    return {"@object_reference": reference}

def default(self, o) -> dict: # pylint: disable=E0202
"""
Overriding default method for JSON encoding. This method does two
Expand Down Expand Up @@ -432,7 +561,13 @@ def default(self, o) -> dict: # pylint: disable=E0202
return {"@module": "bson.objectid", "@class": "ObjectId", "oid": str(o)}

if callable(o) and not isinstance(o, MSONable):
return _serialize_callable(o)
try:
return _serialize_callable(o)
except AttributeError as e:
# Some callables may not have instance __name__
if self._track_unserializable_objects:
return self._update_name_object_map(o)
raise AttributeError(e)

try:
if pydantic is not None and isinstance(o, pydantic.BaseModel):
Expand All @@ -448,6 +583,11 @@ def default(self, o) -> dict: # pylint: disable=E0202
d = o.as_dict()
elif isinstance(o, Enum):
d = {"value": o.value}
elif self._track_unserializable_objects:
# Last resort logic. We keep track of some name of the object
# as a reference, and instead of the object, store that
# name, which of course is json-serializable
d = self._update_name_object_map(o)
else:
raise TypeError(
f"Object of type {o.__class__.__name__} is not JSON serializable"
Expand Down Expand Up @@ -484,13 +624,19 @@ class MontyDecoder(json.JSONDecoder):
json.loads(json_string, cls=MontyDecoder)
"""

_name_object_map = None

def process_decoded(self, d):
"""
Recursive method to support decoding dicts and lists containing
pymatgen objects.
"""

if isinstance(d, dict):
if "@module" in d and "@class" in d:
if "@object_reference" in d and self._name_object_map is not None:
name = d["@object_reference"]
return self._name_object_map.pop(name)
elif "@module" in d and "@class" in d:
modname = d["@module"]
classname = d["@class"]
if cls_redirect := MSONable.REDIRECT.get(modname, {}).get(classname):
Expand Down Expand Up @@ -553,6 +699,11 @@ def process_decoded(self, d):
if hasattr(mod, classname):
cls_ = getattr(mod, classname)
data = {k: v for k, v in d.items() if not k.startswith("@")}
if hasattr(cls_, "_from_dict"):
# New functionality with save/load requires this
return cls_._from_dict(
data, name_object_map=self._name_object_map
)
if hasattr(cls_, "from_dict"):
return cls_.from_dict(data)
if issubclass(cls_, Enum):
Expand Down
81 changes: 81 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,29 @@ def __eq__(self, other):
)


class GoodNOTMSONClass:
    """Same constructor and attributes as GoodMSONClass, except it does not
    have the MSONable inheritance!

    Equality is attribute-based: two objects are equal when their ``a``,
    ``b``, ``_c``, ``_d``, ``kwargs`` and ``values`` all compare equal.
    """

    # Attributes that must be present on the other operand for __eq__ to
    # attempt a comparison at all.
    _COMPARED_ATTRS = ("a", "b", "_c", "_d", "kwargs", "values")

    def __init__(self, a, b, c, d=1, *values, **kwargs):
        self.a = a
        self.b = b
        self._c = c
        self._d = d
        self.values = values
        self.kwargs = kwargs

    def __eq__(self, other):
        # Comparing against an unrelated object (None, an int, ...) must not
        # raise AttributeError; defer to Python's default handling instead.
        if not all(hasattr(other, attr) for attr in self._COMPARED_ATTRS):
            return NotImplemented
        return (
            self.a == other.a
            and self.b == other.b
            and self._c == other._c
            and self._d == other._d
            and self.kwargs == other.kwargs
            and self.values == other.values
        )


class LimitedMSONClass(MSONable):
"""An MSONable class that only accepts a limited number of options"""

Expand Down Expand Up @@ -367,6 +390,64 @@ def test_enum_serialization_no_msonable(self):
f = jsanitize(d, enum_values=True)
assert f["123"] == "value_a"

def test_save_load(self, tmp_path):
    """Round-trip a partially serializable object through save() and load()."""

    # Non-MSONable payloads, built in the order they appear in the kwargs.
    lone_opaque = GoodNOTMSONClass("Hello2", "World2", "Python2", values=[])
    opaque_pair = [
        GoodNOTMSONClass("Hello4", "World4", "Python4", values=[]),
        GoodNOTMSONClass("Hello4", "World4", "Python4", values=[]),
    ]
    mixed_dicts = [
        {
            "tmp": GoodMSONClass("Hello5", "World5", "Python5", values=[]),
            "tmp2": 2,
            "tmp3": [1, 2, 3],
        },
        {
            "tmp5": GoodNOTMSONClass("aHello5", "aWorld5", "aPython5", values=[]),
            "tmp2": 5,
            "tmp3": {"test": "test123"},
        },
        # An MSONable object hidden deep in the structure must still be
        # serialized correctly.
        {"actually_good": GoodMSONClass("1", "2", "3", values=[])},
    ]

    original = GoodMSONClass(
        "Hello",
        "World",
        "Python",
        cant_serialize_me=lone_opaque,
        cant_serialize_me2=opaque_pair,
        cant_serialize_me3=mixed_dicts,
        values=[],
    )

    # Converting to a dict is expected to succeed ...
    original.as_dict()

    # ... while plain JSON encoding is expected to raise.
    with pytest.raises(TypeError):
        original.to_json()

    # save() is expected to succeed for the same object.
    destination = tmp_path / "test_dir123"
    original.save(destination, json_kwargs={"indent": 4, "sort_keys": True})

    # Saving again into the same directory in strict mode must be refused.
    with pytest.raises(FileExistsError):
        original.save(destination, strict=True)

    # The reloaded object must compare equal to the one that was saved.
    reloaded = GoodMSONClass.load(destination)

    assert original == reloaded


class TestJson:
def test_as_from_dict(self):
Expand Down

0 comments on commit dbaba61

Please sign in to comment.