Introduce breakpoint API #1940
Conversation
Thanks for investigating that issue and coming up with a solution.
This is not an in-depth review yet, but maybe these comments can already help a bit. Apart from those, I also wondered about the term "breakpoint". I mostly associate it with the context of debugging, which is not what is happening here, but maybe my view is too narrow. In reality, this feature is really a synchronized flag and could in theory be used as a basis for all kinds of features, right?
@BenjaminBossan I'd say it's more of a conditional trigger setup than a synchronized flag. "Synchronized flag" isn't quite accurate either: it checks whether a particular condition that affects all processes has been met on a single process (or on more than one of them). Agreed the naming could use some work, though; if anyone has ideas, let me know, otherwise I'll think on it :)
True, it's basically an
Always the hard questions. Since you used the word yourself, I wonder if it could be something with "trigger".
Thanks a lot, LGTM. I have a few suggestions, but no blockers for merging. Feel free to follow or ignore them.
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
LGTM! Left a few suggestions ;)
Hi @muellerzr, I am actually very excited to find this PR, since it seems quite relevant to my problem (similar to the early stopping problem). I'm using 2 GPUs to train a RoBERTa model, and I want to calculate the validation loss after ...
```python
...
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader
)
...
best_eval_loss = float('inf')

# Training loop:
for epoch in range(num_epochs):
    for batch in train_dataloader:
        model.train()
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        # validation when we reach 'eval_steps'
        if progress_bar.n > 10 and progress_bar.n % args.eval_steps == 0:
            model.eval()
            losses = []
            for val_batch in val_dataloader:
                # ----------------> stuck here!
                with torch.no_grad():
                    outputs = model(**val_batch)
                    e_loss = outputs.loss
                    losses.append(accelerator.gather_for_metrics(e_loss.repeat(val_bs)))
            losses = torch.cat(losses)
            eval_loss = torch.mean(losses)
            print("eval_loss", eval_loss.item())

            # save the current best model
            if eval_loss.item() < best_eval_loss:
                print("Current best model! Steps:", progress_bar.n, "Eval loss:", eval_loss.item())
                # save
                ...
```

The process just gets stuck when computing the first validation batch. Both GPUs are at 100% utilization, so I guess the process is hanging. After a long time, an error will occur:
Note that if I move the validation code to run after each epoch, instead of during an epoch, the program runs fine. Like this:

```python
# Training loop:
for epoch in range(num_epochs):
    for batch in train_dataloader:
        model.train()
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # validation:
    model.eval()
    losses = []
    for val_batch in val_dataloader:
        with torch.no_grad():
            outputs = model(**val_batch)
            e_loss = outputs.loss
            losses.append(accelerator.gather_for_metrics(e_loss.repeat(val_bs)))
    losses = torch.cat(losses)
    eval_loss = torch.mean(losses)
    print("eval_loss", eval_loss.item())
```

Could you help me with this? Thanks a lot!
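For context, the failure mode this PR targets is a condition that is not evaluated identically on every rank, so different ranks execute different numbers of collective calls. Here is a tiny single-process sketch of that mismatch (the loss sequences, `patience` value, and `steps_before_break` helper are all hypothetical, purely for illustration; no distributed backend is involved):

```python
def steps_before_break(losses, patience=2):
    """Toy early-stopping loop: count how many 'synchronized' steps run
    before a rank-LOCAL break condition fires."""
    best, bad, steps = float("inf"), 0, 0
    for loss in losses:
        steps += 1          # stands in for one step containing collective ops
        if loss < best:
            best, bad = loss, 0
        else:
            bad += 1
        if bad >= patience:  # rank-local condition: dangerous in DDP
            break
    return steps

# Each fake rank sees slightly different losses, so the break fires at
# different steps -> mismatched collective calls -> the real job would hang.
rank0 = steps_before_break([1.0, 0.9, 0.95, 0.97, 0.8])   # plateaus early
rank1 = steps_before_break([1.0, 0.9, 0.85, 0.8, 0.75])   # keeps improving
print(rank0, rank1)
```

The point of the `set`/`check` breakpoint API below is exactly to make all ranks agree on when to break, so counts like these can never diverge.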
Introduce breakpoint API
What does this add?
This PR adds two new functions to `Accelerator`: `set_breakpoint` and `check_breakpoint`, based on https://discuss.pytorch.org/t/how-to-use-break-in-distributeddataparallel-training/88296/7

Who is it for?
https://discuss.huggingface.co/t/early-stopping-for-eval-loss-causes-timeout/51349
Why is it needed?
When doing early stopping in DDP, if each process checks a conditional that may not be synchronized across all of them, a `break` can happen on process 0 but not on process 1. As a result, the code will hang indefinitely until a timeout occurs.

To address that, this PR introduces a new `set` and `check` breakpoint API, which should be used in tandem with such conditionals to ensure that the breakpoint is reached on every process. This API can be used with any distributed type, not just DDP/multi-GPU.

What parts of the API does this impact?
User-facing: `Accelerator.set_breakpoint` and `Accelerator.check_breakpoint`
Internal structure:
The `Accelerator` object now keeps track of a tensor at `self.flag_tensor`
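Based on that description (a flag tensor plus a reduction), the mechanism can be sketched without any distributed backend. This is a toy single-process model of the idea, not the actual implementation: `set_breakpoint` flips one rank's flag, and `check_breakpoint` performs an `all_reduce`-style sum over the flags, so every rank computes the same answer (the `ToyBreakpoint` class and its storage layout are assumptions for illustration only):

```python
class ToyBreakpoint:
    """Toy single-process model of a flag synchronized across fake ranks."""

    def __init__(self, world_size):
        self.flags = [0.0] * world_size  # stand-in for each rank's flag_tensor

    def set_breakpoint(self, rank):
        self.flags[rank] = 1.0  # this rank requests a break

    def check_breakpoint(self):
        # sum across ranks ~ all_reduce(SUM): nonzero means some rank set it,
        # and every rank sees the same total, so they all agree
        return sum(self.flags) > 0

bp = ToyBreakpoint(world_size=2)
print(bp.check_breakpoint())  # False: no rank has set it yet
bp.set_breakpoint(rank=0)     # only rank 0 hits the stopping condition
print(bp.check_breakpoint())  # True: all ranks now agree to break
```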
Basic Usage Example(s):
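This section is empty in this extract, so here is a hedged sketch of the intended usage pattern, based only on the function names given in this PR (`set_breakpoint` / `check_breakpoint`). The `FakeAccelerator` stub and the loss values below are hypothetical, included only so the loop shape can run single-process; with real `accelerate` you would call these methods on your prepared `Accelerator`:

```python
class FakeAccelerator:
    """Single-process stand-in for accelerate.Accelerator (illustration only)."""
    def __init__(self):
        self._flag = False
        self.is_main_process = True

    def set_breakpoint(self):
        self._flag = True

    def check_breakpoint(self):
        # in real accelerate this would involve a collective op so all ranks agree
        return self._flag

accelerator = FakeAccelerator()
eval_losses = [0.9, 0.7, 0.8, 0.85]  # hypothetical per-evaluation losses
best = float("inf")

for step, eval_loss in enumerate(eval_losses):
    # Only the main process evaluates the (possibly rank-local) condition...
    if accelerator.is_main_process and eval_loss > best:
        accelerator.set_breakpoint()
    best = min(best, eval_loss)
    # ...but EVERY process checks it, so they all break on the same step.
    if accelerator.check_breakpoint():
        print(f"early stop at step {step}")
        break
```

The key point is that the decision is made in one place but the `break` itself is gated on `check_breakpoint()`, which every process executes, so no rank can run ahead of the others.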
When would I use it, and when wouldn't I?
A prime example is early stopping, where a stopping condition met on process 0 should trigger stopping across all processes.