Skip to content

Commit

Permalink
Make naming consistent with numpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed May 8, 2023
1 parent f1c7dc5 commit 247d90e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def default(self, o) -> dict: # pylint: disable=E0202
d = {
"@module": "torch",
"@class": "Tensor",
"type": o.type(),
"dtype": o.type(),
}
if "Complex" in o.type():
d["data"] = [o.real.tolist(), o.imag.tolist()]
Expand Down Expand Up @@ -451,11 +451,11 @@ def process_decoded(self, d):
d = {k: self.process_decoded(v) for k, v in data.items()}
return cls_(**d)
elif torch is not None and modname == "torch" and classname == "Tensor":
if "Complex" in d["type"]:
if "Complex" in d["dtype"]:
return torch.tensor(
[np.array(r) + np.array(i) * 1j for r, i in zip(*d["data"])],
).type(d["type"])
return torch.tensor(d["data"]).type(d["type"]) # pylint: disable=E1101
).type(d["dtype"])
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101
elif np is not None and modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
return np.array(
Expand Down

0 comments on commit 247d90e

Please sign in to comment.