From a42f0d683f3fe63d41a3f0cfffadb90fef82d85b Mon Sep 17 00:00:00 2001 From: "@jmmshn" Date: Sun, 28 Jan 2024 23:34:45 -0800 Subject: [PATCH] fixed serialization bug --- monty/json.py | 10 +++++++--- tests/test_json.py | 30 ++++++++++++++++++++++++------ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/monty/json.py b/monty/json.py index 68d46ca2..a704754c 100644 --- a/monty/json.py +++ b/monty/json.py @@ -412,8 +412,6 @@ def default(self, o) -> dict: # pylint: disable=E0202 try: if pydantic is not None and isinstance(o, pydantic.BaseModel): d = o.dict() - elif isinstance(o, Enum): - d = {"value": o.value} elif ( dataclasses is not None and (not issubclass(o.__class__, MSONable)) @@ -421,8 +419,14 @@ def default(self, o) -> dict: # pylint: disable=E0202 ): # This handles dataclasses that are not subclasses of MSONAble. d = dataclasses.asdict(o) - else: + elif hasattr(o, "as_dict"): d = o.as_dict() + elif isinstance(o, Enum): + d = {"value": o.value} + else: + raise TypeError( + f"Object of type {o.__class__.__name__} is not JSON serializable" + ) if "@module" not in d: d["@module"] = str(o.__class__.__module__) diff --git a/tests/test_json.py b/tests/test_json.py index e3cba850..16ba2d6c 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -20,11 +20,6 @@ test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") -class A(Enum): - name_a = "value_a" - name_b = "value_b" - - class GoodMSONClass(MSONable): def __init__(self, a, b, c, d=1, *values, **kwargs): self.a = a @@ -105,6 +100,23 @@ def my_callable(a, b): return a + b +class EnumNoAsDict(Enum): + name_a = "value_a" + name_b = "value_b" + + +class EnumAsDict(Enum): + name_a = "value_a" + name_b = "value_b" + + def as_dict(self): + return {"v": self.value} + + @classmethod + def from_dict(cls, d): + return cls(d["v"]) + + class EnumTest(MSONable, Enum): a = 1 b = 2 @@ -815,7 +827,13 @@ def test_dataclass(self): assert isinstance(ndc2, NestedDataClass) def test_enum(self): - s = MontyEncoder().encode(A.name_a) + s = MontyEncoder().encode(EnumNoAsDict.name_a) p = MontyDecoder().decode(s) assert p.name == "name_a" assert p.value == "value_a" + + na1 = EnumAsDict.name_a + d_ = na1.as_dict() + assert d_ == {"v": "value_a"} + na2 = EnumAsDict.from_dict(d_) + assert na2 == na1