Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Fixing issue loading checkpoint #48
Browse files Browse the repository at this point in the history
  • Loading branch information
adefossez committed Feb 8, 2021
1 parent dcef2cb commit 01cc7f8
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion denoiser/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ def get_model(args):
if args.model_path:
logger.info("Loading model from %s", args.model_path)
pkg = torch.load(args.model_path)
model = deserialize_model(pkg)
if 'model' in pkg:
if 'best_state' in pkg:
pkg['model']['state'] = pkg['best_state']
model = deserialize_model(pkg['model'])
else:
model = deserialize_model(pkg)
elif args.dns64:
logger.info("Loading pre-trained real time H=64 model trained on DNS.")
model = dns64()
Expand Down

0 comments on commit 01cc7f8

Please sign in to comment.