Skip to content

Commit

Permalink
fixed serialization bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jmmshn committed Jan 29, 2024
1 parent cb1ac52 commit a42f0d6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
10 changes: 7 additions & 3 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,21 @@ def default(self, o) -> dict: # pylint: disable=E0202
try:
if pydantic is not None and isinstance(o, pydantic.BaseModel):
d = o.dict()
elif isinstance(o, Enum):
d = {"value": o.value}
elif (
dataclasses is not None
and (not issubclass(o.__class__, MSONable))
and dataclasses.is_dataclass(o)
):
# This handles dataclasses that are not subclasses of MSONAble.
d = dataclasses.asdict(o)
else:
elif hasattr(o, "as_dict"):
d = o.as_dict()
elif isinstance(o, Enum):
d = {"value": o.value}
else:
raise TypeError(
f"Object of type {o.__class__.__name__} is not JSON serializable"
)

if "@module" not in d:
d["@module"] = str(o.__class__.__module__)
Expand Down
30 changes: 24 additions & 6 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files")


class A(Enum):
name_a = "value_a"
name_b = "value_b"


class GoodMSONClass(MSONable):
def __init__(self, a, b, c, d=1, *values, **kwargs):
self.a = a
Expand Down Expand Up @@ -105,6 +100,23 @@ def my_callable(a, b):
return a + b


class EnumNoAsDict(Enum):
name_a = "value_a"
name_b = "value_b"


class EnumAsDict(Enum):
name_a = "value_a"
name_b = "value_b"

def as_dict(self):
return {"v": self.value}

@classmethod
def from_dict(cls, d):
return cls(d["v"])


class EnumTest(MSONable, Enum):
a = 1
b = 2
Expand Down Expand Up @@ -815,7 +827,13 @@ def test_dataclass(self):
assert isinstance(ndc2, NestedDataClass)

def test_enum(self):
s = MontyEncoder().encode(A.name_a)
s = MontyEncoder().encode(EnumNoAsDict.name_a)
p = MontyDecoder().decode(s)
assert p.name == "name_a"
assert p.value == "value_a"

na1 = EnumAsDict.name_a
d_ = na1.as_dict()
assert d_ == {"v": "value_a"}
na2 = EnumAsDict.from_dict(d_)
assert na2 == na1

0 comments on commit a42f0d6

Please sign in to comment.