# Early Stopping and Threshold Stopping

In Composer, callbacks modify trainer behavior and are called at the relevant Events in the training loop. This tutorial focuses on two callbacks, the EarlyStopper and the ThresholdStopper, both of which halt training early based on different criteria.
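To make the idea of "callbacks called at Events" concrete, here is a toy, pure-Python training loop. It is a simplified sketch and not Composer's actual Engine or Callback API: `ToyState`, `run_toy_loop`, and `stop_above_half` are hypothetical names invented for illustration. The loop fires an `"epoch_end"` event after each epoch, and a callback can inspect the state and request that training stop, which is the same pattern the two stoppers below rely on.

```python
from typing import Callable, Dict, List


class ToyState:
    """Hypothetical stand-in for a trainer's state (not Composer's State)."""

    def __init__(self, max_epochs: int):
        self.epoch = 0
        self.max_epochs = max_epochs
        self.stop_requested = False
        self.metrics: Dict[str, float] = {}


def run_toy_loop(
    state: ToyState,
    callbacks: List[Callable[[str, ToyState], None]],
    metric_per_epoch: List[float],
) -> int:
    """Run a fake training loop that fires an 'epoch_end' event after each epoch.

    Returns the number of epochs actually completed.
    """
    while state.epoch < state.max_epochs and not state.stop_requested:
        # Pretend one epoch of training ran and produced this accuracy.
        state.metrics["Accuracy"] = metric_per_epoch[state.epoch]
        state.epoch += 1
        for callback in callbacks:
            callback("epoch_end", state)  # every callback sees the event and the state
    return state.epoch


def stop_above_half(event: str, state: ToyState) -> None:
    """Toy callback: request a stop once accuracy exceeds 0.5."""
    if event == "epoch_end" and state.metrics["Accuracy"] > 0.5:
        state.stop_requested = True


state = ToyState(max_epochs=5)
epochs_run = run_toy_loop(state, [stop_above_half], [0.3, 0.45, 0.55, 0.6, 0.7])
print(epochs_run)  # 3: the callback stops training after the third epoch
```

The trainer owns the loop, and the callback only reads state and signals a stop; this separation is what lets stopping logic be added without touching the training code.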


## Setup



In this tutorial, we'll train a ComposerModel and halt training based on criteria that we set. We'll use the same model and setup as in the "Getting Up and Running with Composer" tutorial.

First, install composer if you haven't already:

In [None]:
!pip install mosaicml 

import torch
import composer
from torchvision import datasets, transforms

torch.manual_seed(42)

In [None]:
data_directory = "../data"

# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # no need to shuffle evaluation data

## Model, Optimizer, Scheduler Setup

In [None]:
from composer import models

model = models.ComposerResNetCIFAR(model_name='resnet_56', num_classes=10)

optimizer = composer.optim.DecoupledSGDW(
    model.parameters(), # Model parameters to update
    lr=0.05, # Peak learning rate
    momentum=0.9,
    weight_decay=2.0e-3 # If this looks large, it's because it's not scaled by the LR as in non-decoupled weight decay
)

lr_scheduler = composer.optim.LinearWithWarmupScheduler(
    t_warmup="1ep", # Warm up over 1 epoch
    alpha_i=1.0, # Flat LR schedule achieved by having alpha_i == alpha_f
    alpha_f=1.0
)

## EarlyStopper

The EarlyStopper callback tracks a particular training or evaluation metric and stops training if the metric does not improve within a given time interval. 


Here, we'll use it to track the Accuracy metric and set the patience value.

The EarlyStopper takes several parameters. Here are the ones we'll use in this tutorial:
- `monitor`: The string name of the metric to track.

- `dataloader_label`: The label of the dataloader whose metric should be tracked. The same metric (such as Accuracy) can be computed on several dataloaders, so `dataloader_label` tells the EarlyStopper which one to monitor. In our example, it must match the label of our Evaluator, `"test_eval_label"`. When not using Evaluators, the dataloader labels are usually `"train"` and `"eval"`.

- `patience`: How long the EarlyStopper waits for the metric to improve before stopping training. An integer is interpreted as a number of epochs; you can also use time strings from Composer's Time library, such as `"50ba"` for 50 batches or `"2ep"` for 2 epochs.
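The interplay of `monitor`, `dataloader_label`, and `patience` can be sketched in plain Python. This is a simplified model of the behavior described above, not Composer's implementation: `early_stop_epoch` is a hypothetical helper, the nested `{label: {metric_name: value}}` layout of the metric history is an assumption made for illustration, patience is counted in epochs only (no time strings), higher values are assumed to be better, and Composer's exact off-by-one convention for patience may differ.

```python
from typing import Dict, List, Optional


def early_stop_epoch(
    history: List[Dict[str, Dict[str, float]]],
    monitor: str,
    dataloader_label: str,
    patience: int,
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None.

    history[i] holds the metrics after epoch i + 1, keyed first by dataloader
    label and then by metric name (a hypothetical layout for this sketch).
    Higher values of `monitor` are treated as better.
    """
    best: Optional[float] = None
    best_epoch = 0
    for epoch, metrics in enumerate(history, start=1):
        # A mismatched label raises KeyError, which is why the label must
        # match the Evaluator's label exactly.
        value = metrics[dataloader_label][monitor]
        if best is None or value > best:
            best, best_epoch = value, epoch
        elif epoch - best_epoch >= patience:
            return epoch  # no improvement for `patience` epochs
    return None


history = [
    {"test_eval_label": {"Accuracy": 0.40}},
    {"test_eval_label": {"Accuracy": 0.52}},
    {"test_eval_label": {"Accuracy": 0.51}},
    {"test_eval_label": {"Accuracy": 0.55}},
]
print(early_stop_epoch(history, "Accuracy", "test_eval_label", patience=1))  # 3
print(early_stop_epoch(history, "Accuracy", "test_eval_label", patience=2))  # None
```

With `patience=1`, the dip at epoch 3 is enough to stop training; with `patience=2`, the recovery at epoch 4 resets the counter and training continues.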

In [None]:
from composer.callbacks.early_stopper import EarlyStopper

early_stopper = EarlyStopper(monitor="Accuracy", dataloader_label="test_eval_label", patience=1)

There are several other parameters you can specify:
- `min_delta`: The minimum change that counts as an improvement. If `min_delta` is nonzero, a change in the metric smaller than `min_delta` is not treated as an improvement, so the EarlyStopper can still halt training even if the metric is technically getting better.

- `comp`: A comparison operator used to decide whether the monitored metric has improved. It is called as `comp(current_value, previous_best)`.

More details can be found in the documentation for the EarlyStopper callback.
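How `min_delta` and `comp` combine into a single "did the metric improve?" decision can be sketched as follows. This is an illustrative assumption about the logic, written with Python's `operator` module on plain floats rather than with tensors; `is_improvement` is a hypothetical helper, not part of Composer.

```python
import operator
from typing import Callable


def is_improvement(
    current: float,
    best: float,
    comp: Callable[[float, float], bool] = operator.gt,
    min_delta: float = 0.0,
) -> bool:
    """Sketch: a change counts as an improvement only if comp(current, best)
    holds AND the size of the change exceeds min_delta."""
    return comp(current, best) and abs(current - best) > min_delta


# Accuracy: higher is better, so the default operator.gt applies.
print(is_improvement(0.705, 0.70))                  # True
print(is_improvement(0.705, 0.70, min_delta=0.01))  # False: a 0.005 gain is too small
# Loss: lower is better, so pass operator.lt as the comparison.
print(is_improvement(0.90, 1.00, comp=operator.lt, min_delta=0.05))  # True
```

Passing the comparison as a function keeps the stopping rule agnostic to whether the metric should go up (accuracy) or down (loss).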

In [None]:
from composer.callbacks.early_stopper import EarlyStopper
from torchmetrics.classification.accuracy import Accuracy
from composer.core import Evaluator


early_stopper = EarlyStopper("Accuracy", "test_eval_label", patience=1)
evaluator = Evaluator(
    dataloader = test_dataloader,
    label = "test_eval_label",  # must match the EarlyStopper's dataloader_label
    metrics = Accuracy()
)


Now that we have our EarlyStopper callback object tracking our Evaluator metric, we can instantiate our Trainer and start training.

In [None]:
train_epochs = "3ep" # Train for 3 epochs because we're assuming Colab environment and hardware
device = "gpu" if torch.cuda.is_available() else "cpu" # select the device

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=evaluator,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
    callbacks=[early_stopper]
)

trainer.fit()  # training halts early if the EarlyStopper's criterion is met

## ThresholdStopper

The ThresholdStopper is similar to the EarlyStopper, but it halts training when the metric crosses a threshold set in the ThresholdStopper callback.
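Unlike the EarlyStopper, a threshold check needs no memory of the best value so far: each new metric value is compared against a fixed threshold. The sketch below is a hypothetical, pure-Python illustration (`threshold_stop_epoch` is not a Composer function), and whether Composer's default comparison is strict (`>`) or not is an assumption here.

```python
import operator
from typing import Callable, List, Optional


def threshold_stop_epoch(
    values: List[float],
    threshold: float,
    comp: Callable[[float, float], bool] = operator.gt,
) -> Optional[int]:
    """Return the 1-based evaluation index at which comp(value, threshold)
    first holds, or None if the threshold is never crossed."""
    for i, value in enumerate(values, start=1):
        if comp(value, threshold):
            return i  # stop as soon as the threshold is crossed
    return None


accuracy_per_epoch = [0.48, 0.61, 0.67, 0.70]
print(threshold_stop_epoch(accuracy_per_epoch, threshold=0.65))  # 3
print(threshold_stop_epoch(accuracy_per_epoch, threshold=0.80))  # None
```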

The parameters for the ThresholdStopper are similar to those of the EarlyStopper. See the descriptions above of `monitor`, `dataloader_label`, and `comp`.

The other parameters are:
- `threshold`: The float value that the monitored metric must cross for training to halt.

- `stop_on_batch`: If `True`, the ThresholdStopper checks the monitored training metric after every batch, so training can halt in the middle of an epoch as soon as the threshold is crossed.
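The effect of `stop_on_batch` can be simulated with a per-batch training-accuracy trace. This is a simplified sketch under stated assumptions: `run_with_threshold` is a hypothetical helper, and it treats the last batch's accuracy as the epoch-level metric that is checked at the end of each epoch, which is not necessarily how Composer aggregates training metrics.

```python
from typing import List, Tuple


def run_with_threshold(
    batch_accuracies: List[List[float]],
    threshold: float,
    stop_on_batch: bool,
) -> Tuple[int, int]:
    """Simulate training where batch_accuracies[e][b] is the training accuracy
    after batch b of epoch e. Returns (epoch, batches_run) at the point where
    training halts (or ends)."""
    batches_run = 0
    for epoch, epoch_batches in enumerate(batch_accuracies, start=1):
        for acc in epoch_batches:
            batches_run += 1
            if stop_on_batch and acc > threshold:
                return epoch, batches_run  # halt in the middle of the epoch
        # Without stop_on_batch, only check at the end of the epoch
        # (this sketch uses the last batch's accuracy as the epoch metric).
        if epoch_batches[-1] > threshold:
            return epoch, batches_run
    return len(batch_accuracies), batches_run


accs = [[0.30, 0.40, 0.50], [0.60, 0.70, 0.72]]
print(run_with_threshold(accs, threshold=0.65, stop_on_batch=True))   # (2, 5)
print(run_with_threshold(accs, threshold=0.65, stop_on_batch=False))  # (2, 6)
```

With `stop_on_batch=True`, training stops one batch earlier, partway through epoch 2; otherwise it runs to the end of that epoch before checking.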

We will reuse the same setup for the ThresholdStopper example.

In [None]:
from composer.callbacks.threshold_stopper import ThresholdStopper
from torchmetrics.classification.accuracy import Accuracy


threshold_stopper = ThresholdStopper("Accuracy", "test_eval_label", threshold=0.65)
evaluator = Evaluator(
    dataloader = test_dataloader,
    label = "test_eval_label",  # must match the ThresholdStopper's dataloader_label
    metrics = Accuracy()
)

train_epochs = "3ep" # Train for 3 epochs because we're assuming Colab environment and hardware
device = "gpu" if torch.cuda.is_available() else "cpu" # select the device

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=evaluator,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
    callbacks=[threshold_stopper]
)

trainer.fit()  # training halts once test Accuracy crosses 0.65