Skip to content

Commit

Permalink
Add `map_location` to `torch.load` calls so model loading works when CUDA is unavailable
Browse files Browse the repository at this point in the history
  • Loading branch information
cning112 committed Dec 28, 2021
1 parent a0f49fe commit a3859bd
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion qlib/contrib/model/pytorch_gats.py
Expand Up @@ -260,7 +260,7 @@ def fit(

if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))

model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/model/pytorch_gats_ts.py
Expand Up @@ -276,7 +276,7 @@ def fit(

if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))

model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
Expand Down
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_nn.py
Expand Up @@ -257,7 +257,7 @@ def fit(
self.scheduler.step(cur_loss_val)

# restore the optimal parameters after training
self.dnn_model.load_state_dict(torch.load(save_path))
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
if self.use_gpu:
torch.cuda.empty_cache()

Expand Down Expand Up @@ -296,7 +296,7 @@ def load(self, buffer, **kwargs):
]
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.dnn_model.load_state_dict(torch.load(_model_path))
self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device))
self.fitted = True


Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/model/pytorch_tabnet.py
Expand Up @@ -160,7 +160,7 @@ def fit(
self.logger.info("Pretrain...")
self.pretrain_fn(dataset, self.pretrain_file)
self.logger.info("Load Pretrain model")
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file))
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device))

# adding one more linear layer to fit the final output dimension
self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device)
Expand Down
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_tcts.py
Expand Up @@ -350,9 +350,9 @@ def training(
break

print("best loss:", best_loss, "@", best_epoch)
best_param = torch.load(save_path + "_fore_model.bin")
best_param = torch.load(save_path + "_fore_model.bin", map_location=self.device)
self.fore_model.load_state_dict(best_param)
best_param = torch.load(save_path + "_weight_model.bin")
best_param = torch.load(save_path + "_weight_model.bin", map_location=self.device)
self.weight_model.load_state_dict(best_param)
self.fitted = True

Expand Down

0 comments on commit a3859bd

Please sign in to comment.