Skip to content

Commit

Permalink
Improve torch support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed May 8, 2023
1 parent 2b1a524 commit bb1df75
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def default(self, o) -> dict: # pylint: disable=E0202
return {
"@module": "torch",
"@class": "Tensor",
"dtype": str(o.dtype),
"type": o.type(),
"data": o.cpu().detach().numpy().tolist(),
}

Expand Down Expand Up @@ -447,7 +447,7 @@ def process_decoded(self, d):
d = {k: self.process_decoded(v) for k, v in data.items()}
return cls_(**d)
elif torch is not None and modname == "torch" and classname == "Tensor":
return torch.tensor(d["data"]) # pylint: disable=E1101
return torch.tensor(d["data"]).type(d["type"]) # pylint: disable=E1101
elif np is not None and modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
return np.array(
Expand Down
1 change: 1 addition & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def test_torch(self):
jsonstr = json.dumps(t, cls=MontyEncoder)
t2 = json.loads(jsonstr, cls=MontyDecoder)
self.assertEqual(type(t2), torch.Tensor)
self.assertEqual(t2.type(), t.type())
self.assertTrue(np.array_equal(t2, t))

def test_datetime(self):
Expand Down

0 comments on commit bb1df75

Please sign in to comment.