Skip to content

Commit

Permalink
fix: CRF save in torch_transformers_sequence_tagger (#1637)
Browse files Browse the repository at this point in the history
Co-authored-by: Fedor Ignatov <ignatov.fedor@gmail.com>
  • Loading branch information
dmitrijeuseew and IgnatovFedor committed Apr 7, 2023
1 parent 3626367 commit e9d0b7a
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,10 @@ def load(self, fname=None):
log.warning(f"Init from scratch. Load path {weights_path_crf} does not exist.")

def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
super().save()
super().save(fname)
if self.use_crf:
if fname is None:
fname = self.save_path
weights_path_crf = Path(f"{fname}_crf").resolve()
weights_path_crf = weights_path_crf.with_suffix(".pth.tar")
torch.save({"model_state_dict": self.crf.cpu().state_dict()}, weights_path_crf)
Expand Down

0 comments on commit e9d0b7a

Please sign in to comment.