Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce breakpoint API #1940

Merged
merged 13 commits into from
Sep 13, 2023
17 changes: 17 additions & 0 deletions docs/source/concept_guides/deferring_execution.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,20 @@ with accelerator.main_process_first():
remove_columns=["idx", "sentence1", "sentence2"],
)
```

## Applying checks such as Early Stopping

To have a check that works with a flag set by a particular process, the `check` and `set` breakpoint API should be used. Useful examples
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
for doing so can include situations such as using early stopping and monitoring the loss (as each loss slightly differs on each process).

Call [`Accelerator.set_trigger`] when your condition has been met, and [`Accelerator.check_trigger`] when checking if that condition has been met in any process:

```python
# Assume `should_do_early_stopping` is a custom defined function that returns a conditional
if should_do_early_stopping(loss):
accelerator.set_trigger()

# Later in the training script when we need to check for the breakpoint
if accelerator.check_trigger():
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
break
```
246 changes: 246 additions & 0 deletions examples/by_feature/early_stopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# coding=utf-8
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

from accelerate import Accelerator, DistributedType


########################################################################
# This is a fully working simple example to use Accelerate
# specifically showcasing how to perform early stopping,
# and builds off the `nlp_example.py` script
#
# This example trains a Bert base model on GLUE MRPC
# in any of the following settings (with the same script):
# - single CPU or single GPU
# - multi GPUS (using PyTorch distributed mode)
# - (multi) TPUs
# - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, follow the instructions
# in the readme for examples:
# https://github.com/huggingface/accelerate/tree/main/examples
#
########################################################################


MAX_GPU_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32


def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
"""
Creates a set of `DataLoader`s for the `glue` dataset,
using "bert-base-cased" as the tokenizer.

Args:
accelerator (`Accelerator`):
An `Accelerator` object
batch_size (`int`, *optional*):
The batch size for the train and validation DataLoaders.
"""
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
datasets = load_dataset("glue", "mrpc")

def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
return outputs

# Apply the method we just defined to all the examples in all the splits of the dataset
# starting with the main process first:
with accelerator.main_process_first():
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)

# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
# transformers library
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
elif accelerator.mixed_precision != "no":
pad_to_multiple_of = 8
else:
pad_to_multiple_of = None

return tokenizer.pad(
examples,
padding="longest",
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
)

# Instantiate dataloaders.
train_dataloader = DataLoader(
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"],
shuffle=False,
collate_fn=collate_fn,
batch_size=EVAL_BATCH_SIZE,
drop_last=(accelerator.mixed_precision == "fp8"),
)

return train_dataloader, eval_dataloader


# New code
class EarlyStoppingCallback:
"A callback class that helps with early stopping"

def __init__(self, min_delta=0, patience=5):
self.min_delta = min_delta
self.patience = patience
self.counter = 0
self.lowest_loss = float("inf")

def check_early_stopping(self, eval_loss):
delta = self.lowest_loss - eval_loss
if delta >= self.min_delta:
self.lowest_loss = eval_loss
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
return True
return False


callback = EarlyStoppingCallback()


def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
seed = int(config["seed"])
batch_size = int(config["batch_size"])

metric = evaluate.load("glue", "mrpc")

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

set_seed(seed)
train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

# We could avoid this line since the accelerator is set with `device_placement=True` (default value).
# Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
# creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
model = model.to(accelerator.device)
# Instantiate optimizer
optimizer = AdamW(params=model.parameters(), lr=lr)

# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

# Prepare everything
# There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
# prepare method.

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Now we train the model
for epoch in range(num_epochs):
model.train()
for step, batch in enumerate(train_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
outputs = model(**batch)
loss = outputs.loss
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
if step % gradient_accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

# New code
# Check if we should stop the training on any processes
if callback.check_early_stopping(loss.item()):
accelerator.set_trigger()

# If so, we break the loop
if accelerator.check_trigger():
break

model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,
references=references,
)

eval_metric = metric.compute()

# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)


def main():
parser = argparse.ArgumentParser(description="Simple example of training script.")
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16", "fp8"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
)
parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)


if __name__ == "__main__":
main()
62 changes: 62 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ def __init__(
if self.rng_types is None:
self.rng_types = ["generator"]

# Set a flag tensor for early stopping and other breakpoints
self.flag_tensor = None

@property
def use_distributed(self):
"""
Expand Down Expand Up @@ -1963,6 +1966,65 @@ def backward(self, loss, **kwargs):
else:
loss.backward(**kwargs)

def set_trigger(self):
"""
Sets the internal trigger tensor to 1 on the current process. A latter check should follow using this which
will check across all processes.

Note:
Does not require `wait_for_everyone()`

Example:

```python
>>> from accelerate import Accelerator

>>> accelerator = Accelerator()
>>> # Assume later in the training script
>>> # `should_do_breakpoint` is a custom function to monitor when to break,
>>> # e.g. when the loss is NaN
>>> if should_do_breakpoint(loss):
... accelerator.set_trigger()
>>> # Assume later in the training script
>>> if accelerator.check_breakpoint():
... break
```
"""
self.flag_tensor = torch.tensor(1, device=self.device)

def check_trigger(self):
"""
Checks if the internal trigger tensor has been set to 1 in any of the processes. If so, will return `True` and
reset the trigger tensor to 0.
SunMarc marked this conversation as resolved.
Show resolved Hide resolved

Note:
Does not require `wait_for_everyone()`

Example:

```python
>>> from accelerate import Accelerator

>>> accelerator = Accelerator()
>>> # Assume later in the training script
>>> # `should_do_breakpoint` is a custom function to monitor when to break,
>>> # e.g. when the loss is NaN
>>> if should_do_breakpoint(loss):
... accelerator.set_trigger()
>>> # Assume later in the training script
>>> if accelerator.check_trigger():
... break
```
"""
# Now that we are outside `__init__`, we can initialize it if it is `None` on device
if self.flag_tensor is None:
self.flag_tensor = torch.tensor(0, device=self.device)
flag_tensor = self.reduce(self.flag_tensor)
if flag_tensor.item() >= 1:
self.flag_tensor = None
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
return True
return False

def unscale_gradients(self, optimizer=None):
"""
Unscale the gradients in mixed precision training with AMP. This is a noop in all other settings.
Expand Down
18 changes: 18 additions & 0 deletions src/accelerate/test_utils/scripts/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,20 @@ def test_split_between_processes_tensor():
state.wait_for_everyone()


def test_breakpoint():
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
accelerator = Accelerator()
# should start with being false
assert accelerator.check_trigger() is False

# set a breakpoint on the main process
if accelerator.is_main_process:
accelerator.set_trigger()

# check it's been activated across all processes
# calls `all_reduce` and triggers a sync
assert accelerator.check_trigger() is True
muellerzr marked this conversation as resolved.
Show resolved Hide resolved


def main():
accelerator = Accelerator()
state = accelerator.state
Expand Down Expand Up @@ -590,6 +604,10 @@ def main():
print("\n**Training integration test**")
training_check()

if state.local_process_index == 0:
print("\n**Breakpoint test**")
test_breakpoint()


if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"fsdp_with_peak_mem_tracking.py",
"deepspeed_with_config_support.py",
"megatron_lm_gpt_pretraining.py",
"early_stopping.py",
]


Expand Down Expand Up @@ -222,3 +223,7 @@ def test_gradient_accumulation(self):
def test_local_sgd(self):
testargs = ["examples/by_feature/local_sgd.py"]
run_command(self._launch_args + testargs)

def test_early_stopping(self):
testargs = ["examples/by_feature/early_stopping.py"]
run_command(self._launch_args + testargs)
Loading