diff --git a/monty/json.py b/monty/json.py index 6fbb1e46..be7543a7 100644 --- a/monty/json.py +++ b/monty/json.py @@ -295,7 +295,7 @@ def default(self, o) -> dict: # pylint: disable=E0202 d = { "@module": "torch", "@class": "Tensor", - "type": o.type(), + "dtype": o.type(), } if "Complex" in o.type(): d["data"] = [o.real.tolist(), o.imag.tolist()] @@ -451,11 +451,11 @@ def process_decoded(self, d): d = {k: self.process_decoded(v) for k, v in data.items()} return cls_(**d) elif torch is not None and modname == "torch" and classname == "Tensor": - if "Complex" in d["type"]: + if "Complex" in d["dtype"]: return torch.tensor( [np.array(r) + np.array(i) * 1j for r, i in zip(*d["data"])], - ).type(d["type"]) - return torch.tensor(d["data"]).type(d["type"]) # pylint: disable=E1101 + ).type(d["dtype"]) + return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101 elif np is not None and modname == "numpy" and classname == "array": if d["dtype"].startswith("complex"): return np.array(