From 247d90ebd4f6b23682ad596c4dca89b1d0338a95 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Sun, 7 May 2023 19:57:29 -0700 Subject: [PATCH] Make naming consistent with numpy. --- monty/json.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monty/json.py b/monty/json.py index 6fbb1e46..be7543a7 100644 --- a/monty/json.py +++ b/monty/json.py @@ -295,7 +295,7 @@ def default(self, o) -> dict: # pylint: disable=E0202 d = { "@module": "torch", "@class": "Tensor", - "type": o.type(), + "dtype": o.type(), } if "Complex" in o.type(): d["data"] = [o.real.tolist(), o.imag.tolist()] @@ -451,11 +451,11 @@ def process_decoded(self, d): d = {k: self.process_decoded(v) for k, v in data.items()} return cls_(**d) elif torch is not None and modname == "torch" and classname == "Tensor": - if "Complex" in d["type"]: + if "Complex" in d["dtype"]: return torch.tensor( [np.array(r) + np.array(i) * 1j for r, i in zip(*d["data"])], - ).type(d["type"]) - return torch.tensor(d["data"]).type(d["type"]) # pylint: disable=E1101 + ).type(d["dtype"]) + return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101 elif np is not None and modname == "numpy" and classname == "array": if d["dtype"].startswith("complex"): return np.array(