support save snapshot by iteration #1204
Conversation
Sorry for the late review.
espnet/asr/pytorch_backend/asr.py (Outdated)

@@ -428,11 +429,11 @@ def train(args):
     # we used an empty collate function instead which returns list
     train_iter = {'main': ChainerDataLoader(
         dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
-        batch_size=1, num_workers=args.n_iter_processes,
+        batch_size=1, num_workers=args.n_iter_processes, pin_memory=True,
It might be better not to include this change in this PR.
OK. I will delete it.
espnet/asr/pytorch_backend/asr.py (Outdated)

         shuffle=not use_sortagrad, collate_fn=lambda x: x[0])}
     valid_iter = {'main': ChainerDataLoader(
         dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
-        batch_size=1, shuffle=False, collate_fn=lambda x: x[0],
+        batch_size=1, pin_memory=True, shuffle=False, collate_fn=lambda x: x[0],
ditto
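For context, pin_memory=True on a PyTorch DataLoader (which the keyword arguments in these diffs suggest ChainerDataLoader wraps) copies each fetched batch into page-locked host memory, so a later .to('cuda', non_blocking=True) can overlap the host-to-device transfer with computation. A minimal, self-contained sketch; the toy dataset and shapes are illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for ESPnet's TransformDataset.
dataset = TensorDataset(torch.randn(8, 4))
loader = DataLoader(dataset, batch_size=1, pin_memory=True)

for (batch,) in loader:
    # Batches arrive in pinned memory, so this copy can run asynchronously.
    if torch.cuda.is_available():
        batch = batch.to('cuda', non_blocking=True)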
espnet/asr/asr_utils.py (Outdated)

@@ -284,6 +284,21 @@ def torch_snapshot(trainer):
     return torch_snapshot


+def torch_snapshot_iter(savefun=torch.save,
+                        filename='snapshot.iter.{.updater.iteration}'):
I think it is not necessary to make a new function. Just reusing torch_snapshot is fine.
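For reference, reusing one function works because Chainer-style snapshot extensions format the filename template against the trainer at call time, so '{.updater.iteration}' resolves to the current iteration just as '{.updater.epoch}' resolves to the current epoch. A simplified sketch of that mechanism (not ESPnet's actual torch_snapshot body; the stub classes are illustrative):

class Updater:
    epoch, iteration = 3, 15000

class Trainer:
    updater = Updater()

def format_snapshot_name(trainer, filename='snapshot.ep.{.updater.epoch}'):
    # str.format with a leading '.' walks attributes of the first argument.
    return filename.format(trainer)

trainer = Trainer()
print(format_snapshot_name(trainer))  # snapshot.ep.3
print(format_snapshot_name(trainer, 'snapshot.iter.{.updater.iteration}'))  # snapshot.iter.15000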
espnet/asr/pytorch_backend/asr.py (Outdated)

@@ -490,6 +494,8 @@ def train(args):

     # save snapshot which contains model and optimizer states
     trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))
+    if args.save_interval_iters > 0:
+        trainer.extend(torch_snapshot_iter(), trigger=(args.save_interval_iters, 'iteration'))
Why don't you reuse torch_snapshot instead of torch_snapshot_iter?
Suggested change:

-        trainer.extend(torch_snapshot_iter(), trigger=(args.save_interval_iters, 'iteration'))
+        trainer.extend(torch_snapshot(filename='snapshot.iter.{.updater.iteration}'),
+                       trigger=(args.save_interval_iters, 'iteration'))
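For readers unfamiliar with Chainer triggers: a trigger=(N, 'iteration') tuple makes the trainer run the extension every N iterations, while (1, 'epoch') fires once per epoch. A simplified stand-in that shows the behavior the tuple expands to (not Chainer's real IntervalTrigger implementation):

class IntervalTrigger:
    # The trainer calls the trigger after every update and runs the
    # extension whenever the trigger returns True.
    def __init__(self, period, unit):
        assert unit in ('iteration', 'epoch')
        self.period = period
        self.unit = unit

    def __call__(self, trainer):
        count = getattr(trainer.updater, self.unit)
        return count > 0 and count % self.period == 0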
Codecov Report

@@            Coverage Diff             @@
##           v.0.6.0    #1204    +/-   ##
===========================================
+ Coverage    78.32%   78.32%   +<.01%
===========================================
  Files          100      100
  Lines         9295     9296       +1
===========================================
+ Hits          7280     7281       +1
  Misses        2015     2015

Continue to review the full report at Codecov.
Thanks a lot!
The number of epochs may not be very large when training on a large dataset because of time limits, and average_checkpoints is very important for achieving a lower CER. So I think it is necessary to support saving a snapshot every N iterations (10000, 20000, etc.).
Any ideas would be appreciated, thanks.
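To make the motivation concrete: checkpoint averaging takes the element-wise mean of the model weights across several snapshots. A hedged sketch in the spirit of ESPnet's average_checkpoints step, assuming each snapshot written by torch_snapshot keeps the weights under a 'model' key; the helper name and file names are illustrative:

import torch

def average_snapshots(paths):
    """Element-wise mean of the 'model' state dicts from several snapshots."""
    avg = None
    for path in paths:
        # Assumption: the snapshot dict keeps model weights under 'model'.
        state = torch.load(path, map_location='cpu')['model']
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}

# e.g. averaging the iteration snapshots this PR writes (illustrative names):
# averaged = average_snapshots(['snapshot.iter.10000', 'snapshot.iter.20000'])
# torch.save({'model': averaged}, 'model.avg')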