Skip to content

Commit

Permalink
Bugfix json.py, make tests more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewcarbone committed Apr 5, 2024
1 parent 51b4d13 commit d619737
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
11 changes: 11 additions & 0 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ def decoded_from_dict(d, name_object_map):
}
return decoded

@classmethod
def _from_dict(cls, d, name_object_map):
decoded = MSONable.decoded_from_dict(d, name_object_map=name_object_map)
return cls(**decoded)

@classmethod
def from_dict(cls, d):
"""
Expand Down Expand Up @@ -626,6 +631,7 @@ def process_decoded(self, d):
Recursive method to support decoding dicts and lists containing
pymatgen objects.
"""

if isinstance(d, dict):
if "@object_reference" in d and self._name_object_map is not None:
name = d["@object_reference"]
Expand Down Expand Up @@ -693,6 +699,11 @@ def process_decoded(self, d):
if hasattr(mod, classname):
cls_ = getattr(mod, classname)
data = {k: v for k, v in d.items() if not k.startswith("@")}
if hasattr(cls_, "_from_dict"):
# New functionality with save/load requires this
return cls_._from_dict(
data, name_object_map=self._name_object_map
)
if hasattr(cls_, "from_dict"):
return cls_.from_dict(data)
if issubclass(cls_, Enum):
Expand Down
31 changes: 9 additions & 22 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,44 +399,30 @@ def test_save_load(self, tmp_path):
"Python",
**{
"cant_serialize_me": GoodNOTMSONClass(
"Hello2",
"World2",
"Python2",
"Hello2", "World2", "Python2", **{"values": []}
),
"cant_serialize_me2": [
GoodNOTMSONClass(
"Hello4",
"World4",
"Python4",
),
GoodNOTMSONClass(
"Hello4",
"World4",
"Python4",
),
GoodNOTMSONClass("Hello4", "World4", "Python4", **{"values": []}),
GoodNOTMSONClass("Hello4", "World4", "Python4", **{"values": []}),
],
"cant_serialize_me3": [
{
"tmp": GoodNOTMSONClass(
"Hello5",
"World5",
"Python5",
"tmp": GoodMSONClass(
"Hello5", "World5", "Python5", **{"values": []}
),
"tmp2": 2,
"tmp3": [1, 2, 3],
},
{
"tmp5": GoodNOTMSONClass(
"aHello5",
"aWorld5",
"aPython5",
"aHello5", "aWorld5", "aPython5", **{"values": []}
),
"tmp2": 5,
"tmp3": {"test": "test123"},
},
# Gotta check that if I hide an MSONable class somewhere
# it still gets correctly serialized.
{"actually_good": GoodMSONClass("1", "2", "3")},
{"actually_good": GoodMSONClass("1", "2", "3", **{"values": []})},
],
"values": [],
},
Expand All @@ -451,14 +437,15 @@ def test_save_load(self, tmp_path):

# This should also pass though
target = tmp_path / "test_dir123"
test_good_class.save(target)
test_good_class.save(target, json_kwargs={"indent": 4, "sort_keys": True})

# This will fail
with pytest.raises(FileExistsError):
test_good_class.save(target, strict=True)

# Now check that reloading this, the classes are equal!
test_good_class2 = GoodMSONClass.load(target)

assert test_good_class == test_good_class2


Expand Down

0 comments on commit d619737

Please sign in to comment.