Skip to content

Commit

Permalink
Add CheckpointManager to keep avg checkpoint weights in memory to red…
Browse files Browse the repository at this point in the history
…uce disk read when averaging + various checkpoint refactoring

Summary: Pull Request resolved: pytorch/translate#315

Reviewed By: akinh

Differential Revision: D13510446

fbshipit-source-id: 22a6594af9253130a93e638285a47183a974e0de
  • Loading branch information
theweiho authored and facebook-github-bot committed Feb 6, 2019
1 parent 829bd8c commit c49c292
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion fairseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def save_checkpoint(self, filename, extra_state):
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
utils.save_state(
filename, self.args, self.get_model(), self.criterion, self.optimizer,
filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)

Expand Down
4 changes: 2 additions & 2 deletions fairseq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
return state_dict


def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None):
if optim_history is None:
optim_history = []
if extra_state is None:
extra_state = {}
state_dict = {
'args': args,
'model': model.state_dict() if model else {},
'model': model_state_dict if model_state_dict else {},
'optimizer_history': optim_history + [
{
'criterion_name': criterion.__class__.__name__,
Expand Down

0 comments on commit c49c292

Please sign in to comment.