Skip to content

Commit

Permalink
fix: add map_location to torch load
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp authored and Optimox committed Oct 12, 2020
1 parent 5a01359 commit c2b560e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch_tabnet/abstract_model.py
Expand Up @@ -313,7 +313,7 @@ def load_model(self, filepath):
loaded_params = json.load(f)
with z.open("network.pt") as f:
try:
saved_state_dict = torch.load(f)
saved_state_dict = torch.load(f, map_location=self.device)
except io.UnsupportedOperation:
# In Python <3.7, the returned file object is not seekable (which at least
# some versions of PyTorch require) - so we'll try buffering it in to a
Expand Down

0 comments on commit c2b560e

Please sign in to comment.