Skip to content

Commit

Permalink
fix: typing and fname usage in TorchModel.save/load, resolves #1420
Browse files Browse the repository at this point in the history
  • Loading branch information
yurakuratov committed Mar 24, 2021
1 parent 93b6ac8 commit f1efeec
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions deeppavlov/core/models/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import deepcopy
from logging import getLogger
from pathlib import Path
from typing import Optional
from typing import Optional, Union

import torch
from overrides import overrides
Expand Down Expand Up @@ -123,7 +123,7 @@ def init_from_opt(self, model_func: str) -> None:
raise AttributeError("Model is not defined.")

@overrides
def load(self, fname: Optional[str] = None, *args, **kwargs) -> None:
def load(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None:
"""Load model from `fname` (if `fname` is not given, use `self.load_path`) to `self.model` along with
the optimizer `self.optimizer`, optionally `self.lr_scheduler`.
If `fname` (if `fname` is not given, use `self.load_path`) does not exist, initialize model from scratch.
Expand All @@ -139,15 +139,18 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None:
if fname is not None:
self.load_path = fname

if isinstance(self.load_path, str):
self.load_path = Path(self.load_path)

model_func = getattr(self, self.opt.get("model_name"), None)

if self.load_path:
log.info(f"Load path {self.load_path} is given.")
if isinstance(self.load_path, Path) and not self.load_path.parent.is_dir():
if not self.load_path.parent.is_dir():
raise ConfigError("Provided load path is incorrect!")

weights_path = Path(self.load_path.resolve())
weights_path = weights_path.with_suffix(f".pth.tar")
weights_path = self.load_path.resolve()
weights_path = weights_path.with_suffix(".pth.tar")
if weights_path.exists():
log.info(f"Load path {weights_path} exists.")
log.info(f"Initializing `{self.__class__.__name__}` from saved.")
Expand All @@ -169,7 +172,7 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None:
self.init_from_opt(model_func)

@overrides
def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
def save(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None:
"""Save torch model to `fname` (if `fname` is not given, use `self.save_path`). Checkpoint includes
`model_state_dict`, `optimizer_state_dict`, and `epochs_done` (number of training epochs).
Expand All @@ -184,10 +187,11 @@ def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
if fname is None:
fname = self.save_path

fname = Path(fname)
if not fname.parent.is_dir():
raise ConfigError("Provided save path is incorrect!")

weights_path = Path(fname).with_suffix(f".pth.tar")
weights_path = fname.with_suffix(".pth.tar")
log.info(f"Saving model to {weights_path}.")
# move the model to `cpu` before saving to provide consistency
torch.save({
Expand Down

0 comments on commit f1efeec

Please sign in to comment.