Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support save snapshot by iteration #1204

Merged
merged 7 commits into from
Sep 25, 2019
Merged

support save snapshot by iteration #1204

merged 7 commits into from
Sep 25, 2019

Conversation

fanlu
Copy link
Contributor

@fanlu fanlu commented Sep 19, 2019

the epoch num may not be very large on large dataset training because of time limit. And average_checkpoints is very important to achieve lower CER. So I suspect that it's necessary to support save snapshot when iterations(10000,20000 etc.).
Any idea will be appreciate, thanks

Copy link
Member

@kan-bayashi kan-bayashi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for late review.

@@ -428,11 +429,11 @@ def train(args):
# we used an empty collate function instead which returns list
train_iter = {'main': ChainerDataLoader(
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
batch_size=1, num_workers=args.n_iter_processes,
batch_size=1, num_workers=args.n_iter_processes, pin_memory=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better not to include this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I will delete it

shuffle=not use_sortagrad, collate_fn=lambda x: x[0])}
valid_iter = {'main': ChainerDataLoader(
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
batch_size=1, shuffle=False, collate_fn=lambda x: x[0],
batch_size=1, pin_memory=True, shuffle=False, collate_fn=lambda x: x[0],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@@ -284,6 +284,21 @@ def torch_snapshot(trainer):
return torch_snapshot


def torch_snapshot_iter(savefun=torch.save,
filename='snapshot.iter.{.updater.iteration}'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it is not necessary to make a new function.
Just reuse torch_snapshot is fine.

@@ -490,6 +494,8 @@ def train(args):

# save snapshot which contains model and optimizer states
trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))
if args.save_interval_iters > 0:
trainer.extend(torch_snapshot_iter(), trigger=(args.save_interval_iters, 'iteration'))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't you reuse torch_snapshot instead of torch_snapshot_iter?

Suggested change
trainer.extend(torch_snapshot_iter(), trigger=(args.save_interval_iters, 'iteration'))
trainer.extend(torch_snapshot(filename='snapshot.iter.{.updater.iteration}'),
trigger=(args.save_interval_iters, 'iteration'))

@codecov
Copy link

codecov bot commented Sep 23, 2019

Codecov Report

Merging #1204 into v.0.6.0 will increase coverage by <.01%.
The diff coverage is 100%.

Impacted file tree graph

@@             Coverage Diff             @@
##           v.0.6.0    #1204      +/-   ##
===========================================
+ Coverage    78.32%   78.32%   +<.01%     
===========================================
  Files          100      100              
  Lines         9295     9296       +1     
===========================================
+ Hits          7280     7281       +1     
  Misses        2015     2015
Impacted Files Coverage Δ
espnet/bin/asr_train.py 64.53% <100%> (+0.2%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1d62a17...0a4058b. Read the comment docs.

@sw005320
Copy link
Contributor

Thanks a lot!

@sw005320 sw005320 merged commit 2626d0c into espnet:v.0.6.0 Sep 25, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants