
Introduce breakpoint API #1940

Merged
muellerzr merged 13 commits into main from breakpoint on Sep 13, 2023

Conversation

@muellerzr (Collaborator) commented on Sep 7, 2023

Introduce breakpoint API

What does this add?

This PR adds two new functions to Accelerator: set_breakpoint and check_breakpoint, based on https://discuss.pytorch.org/t/how-to-use-break-in-distributeddataparallel-training/88296/7

Who is it for?

https://discuss.huggingface.co/t/early-stopping-for-eval-loss-causes-timeout/51349

Why is it needed?

When doing early stopping in DDP, each process may check a condition that is not synchronized across all of them, so a break can happen on process 0 but not on process 1. As a result, the code hangs indefinitely until a timeout occurs.
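
For illustration, a minimal sketch of the pattern that hangs; `training_step` and `should_stop` here are hypothetical stand-ins for the usual forward/backward step and the early-stopping condition:

```python
# Each process evaluates `should_stop` independently, so one rank may break out
# while the others keep training and then block forever in the next collective
# (e.g. the gradient all-reduce of the following backward pass).
for batch in train_dataloader:
    loss = training_step(model, batch)  # hypothetical helper
    if should_stop(loss):               # may be True on rank 0 only
        break                           # ranks that don't break will hang
```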

To address that, this PR introduces a new set/check breakpoint API, which should be used in tandem with such conditionals to ensure that the breakpoint is reached on every process. This API can be used with any distributed type, not just DDP/multi-GPU.

What parts of the API does this impact?

User-facing:

Accelerator.set_breakpoint and Accelerator.check_breakpoint

Internal structure:

The Accelerator object now keeps track of a tensor at self.flag_tensor
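
Roughly, the idea is a per-process flag tensor that is reduced across processes when checked. A minimal sketch of that mechanism (not the PR's exact code), assuming an existing `accelerator` object and using `Accelerator.reduce`:

```python
import torch

flag_tensor = torch.zeros(1, device=accelerator.device)

def set_flag():
    global flag_tensor
    flag_tensor = torch.ones(1, device=accelerator.device)

def check_flag():
    global flag_tensor
    # Summing the flag over all processes acts as a distributed `any`:
    # if any rank set it, every rank sees a value >= 1.
    synced_flag = accelerator.reduce(flag_tensor, reduction="sum")
    if synced_flag.item() >= 1:
        flag_tensor = torch.zeros(1, device=accelerator.device)  # reset for reuse
        return True
    return False
```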

Basic Usage Example(s):

```python
# Assume `should_do_breakpoint` is a custom defined function that returns a conditional,
# and that conditional might be true only on process 1
if should_do_breakpoint(loss):
    accelerator.set_breakpoint()

# Later in the training script when we need to check for the breakpoint
if accelerator.check_breakpoint():
    break
```

When would I use it, and when wouldn't I?

A prime example would be triggering early stopping, where a stop condition met on process 0 should trigger stopping across all processes.
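
For example, a sketch of an early-stopping loop built on this API; `patience_exceeded` and `training_step` are hypothetical helpers standing in for real validation-loss tracking and the usual forward/backward step:

```python
should_stop = False
for epoch in range(num_epochs):
    for batch in train_dataloader:
        loss = training_step(model, batch)
        if patience_exceeded(loss):          # may be True on a single rank only
            accelerator.set_breakpoint()
        if accelerator.check_breakpoint():   # same result on every rank
            should_stop = True
            break
    if should_stop:
        break
```

Since the check synchronizes across processes, it should be reached the same number of times on every rank; here that holds because each rank sees the same number of batches from the prepared dataloader.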

@HuggingFaceDocBuilderDev commented on Sep 7, 2023

The documentation is not available anymore as the PR was closed or merged.

@BenjaminBossan (Member) left a comment

Thanks for investigating that issue and coming up with a solution.

This is not an in-depth review yet, but maybe these comments can already help a bit. Apart from those, I also wondered about the term "breakpoint". I mostly associate it with debugging, which is not what is happening here, but maybe my view is too narrow. In reality, this feature is really a synchronized flag and could in theory be used as a basis for all kinds of features, right?
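
For instance, something like this (purely hypothetical) could be built on the same primitive, reacting when any rank sees a non-finite loss:

```python
# Hypothetical reuse of the synchronized flag beyond early stopping:
# save a rescue checkpoint as soon as *any* process sees a non-finite loss.
if not torch.isfinite(loss).all():
    accelerator.set_breakpoint()
if accelerator.check_breakpoint():
    accelerator.save_state("rescue_checkpoint")
```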

Review comments (resolved): src/accelerate/accelerator.py, docs/source/concept_guides/deferring_execution.md
@muellerzr (Collaborator, Author) commented:

@BenjaminBossan I'd say it's more of a conditional trigger setup than synchronizing flags. "Synchronizing flags" isn't quite accurate either: it checks whether a particular condition that affects all processes was met on a single process (or on more than one of them).

Agreed the naming could use some work, though. If anyone has ideas, lmk; otherwise I'll think on it :)

@BenjaminBossan (Member) commented:

> I'd say it's more of a conditional trigger setup than synchronizing flags

True, it's basically an any operator that works across processes.
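
In bare torch.distributed terms it is roughly this (a sketch, assuming an initialized process group and, for NCCL, the right CUDA device already set):

```python
import torch
import torch.distributed as dist

def any_across_processes(local_condition: bool) -> bool:
    # MAX all-reduce over a 0/1 flag: True everywhere if the condition
    # holds on at least one rank.
    flag = torch.tensor(float(local_condition), device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```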

> Agreed the naming could use some work, though. If anyone has ideas, lmk; otherwise I'll think on it :)

Always the hard questions. Since you used the word yourself, I wonder if it could be something with "trigger".

@BenjaminBossan (Member) left a comment

Thanks a lot, LGTM. I have a few suggestions, but no blockers for merging. Feel free to follow or ignore them.

Review comments (resolved): docs/source/concept_guides/deferring_execution.md, examples/by_feature/early_stopping.py, src/accelerate/accelerator.py, src/accelerate/test_utils/scripts/test_script.py
@SunMarc (Member) left a comment

LGTM! Left a few suggestions ;)

@muellerzr merged commit 40a73e0 into main on Sep 13, 2023
26 checks passed
@muellerzr deleted the breakpoint branch on September 13, 2023 at 16:42
@beyondguo commented:

Hi @muellerzr, is set_breakpoint disabled now? I got AttributeError: 'Accelerator' object has no attribute 'set_breakpoint' with the latest accelerate (v0.31.0).

Actually I am very excited to find this PR, since it seems quite relevant to my problem (similar to the early stopping problem). I'm using 2 GPUs to train a RoBERTa model, and I want to calculate the validation loss every eval_steps. Here's the core code:

```python
...
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader)
...
best_eval_loss = float('inf')

# Training loop:
for epoch in range(num_epochs): 
    for batch in train_dataloader:
        model.train()
        outputs = model(**batch)
        loss = outputs.loss 
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        # validation when reach 'eval_steps'
        if progress_bar.n > 10 and progress_bar.n % args.eval_steps == 0:
            model.eval()
            losses = []
            for val_batch in val_dataloader:
                # ----------------> stuck here!
                with torch.no_grad():
                    outputs = model(**val_batch)
                e_loss = outputs.loss
                losses.append(accelerator.gather_for_metrics(e_loss.repeat(val_bs)))
            losses = torch.cat(losses)
            eval_loss = torch.mean(losses)
            print("eval_loss", eval_loss.item())
            
            # save the current best model
            if eval_loss.item() < best_eval_loss:
                print("Current best model!, Steps:", progress_bar.n, "Eval loss:", eval_loss.item())
                # save
                ...
```

The process just gets stuck when computing the first validation batch. Both GPUs are at 100% utilization, so I guess the process is hanging.

After a long time, an error occurs:

```
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2315, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808543 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2315, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1808546 milliseconds before timing out.
...
Traceback (most recent call last):
  File "rolling_model_train.py", line 289, in <module>
    train()
  File "rolling_model_train.py", line 195, in train
    losses.append(accelerator.gather_for_metrics(loss.repeat(val_bs)))
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/accelerate/accelerator.py", line 2217, in gather_for_metrics
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2315, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1808546 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2315, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1808543 milliseconds before timing out.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 89866) of binary: /home/guoby/app/Anaconda3-2021.05/envs/news/bin/python
Traceback (most recent call last):
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/accelerate/commands/launch.py", line 977, in launch_command
    multi_gpu_launcher(args)
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/accelerate/commands/launch.py", line 646, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
======================================================
rolling_model_train.py FAILED
------------------------------------------------------
Failures:
[1]:
  time      : 2024-06-20_15:02:22
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 89867)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 89867
------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-06-20_15:02:22
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 89866)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 89866
======================================================
```

Note that if I move the validation code to after each epoch, instead of running it during an epoch, the program runs fine! It looks like this:

```python
# Training loop:
for epoch in range(num_epochs): 
    for batch in train_dataloader:
        model.train()
        outputs = model(**batch)
        loss = outputs.loss 
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # validation:
    model.eval()
    losses = []
    for val_batch in val_dataloader:
        with torch.no_grad():
            outputs = model(**val_batch)
        e_loss = outputs.loss
        losses.append(accelerator.gather_for_metrics(e_loss.repeat(val_bs)))
    losses = torch.cat(losses)
    eval_loss = torch.mean(losses)
    print("eval_loss", eval_loss.item())

Could you help me with this? Thanks a lot!
