Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Torch classifier agent][bug fix]Fix optimizer loading in classifier agent #4406

Merged
merged 1 commit into from Mar 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 9 additions & 3 deletions parlai/core/torch_classifier_agent.py
Expand Up @@ -557,9 +557,15 @@ def __init__(self, opt: Opt, shared=None):
if 'optimizer' in shared:
self.optimizer = shared['optimizer']
elif self._should_initialize_optimizer():
optim_params = [p for p in self.model.parameters() if p.requires_grad]
self.init_optim(optim_params)
self.build_lr_scheduler(states, hard_reset=self.is_finetune)
was_reset = self.init_optim(
[p for p in self.model.parameters() if p.requires_grad],
optim_states=states.get('optimizer'),
saved_optim_type=states.get('optimizer_type'),
is_finetune=self.is_finetune,
)
if was_reset:
logging.warning("Optimizer was reset. Also resetting LR scheduler.")
self.build_lr_scheduler(states, hard_reset=self.is_finetune or was_reset)

def build_criterion(self):
weight_tensor = torch.FloatTensor(self.class_weights)
Expand Down