From c2b560e72bc01e34e8dba7578f239e37bbd6782c Mon Sep 17 00:00:00 2001 From: Eduardo Carvalho Date: Tue, 13 Oct 2020 00:02:29 +0200 Subject: [PATCH] fix: add map_location to torch load --- pytorch_tabnet/abstract_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 279f6edf..4ef3f4d4 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -313,7 +313,7 @@ def load_model(self, filepath): loaded_params = json.load(f) with z.open("network.pt") as f: try: - saved_state_dict = torch.load(f) + saved_state_dict = torch.load(f, map_location=self.device) except io.UnsupportedOperation: # In Python <3.7, the returned file object is not seekable (which at least # some versions of PyTorch require) - so we'll try buffering it in to a