# Callbacks

##### Callbacks in PyTorch Lightning are a way to execute custom code at specific points during the training and validation process. They provide a flexible mechanism to perform actions such as logging metrics, saving model checkpoints, adjusting learning rates, or early stopping based on certain conditions.

##### Here’s a quick rundown of common callbacks in PyTorch Lightning:

##### ModelCheckpoint: Saves the model at specified intervals or when certain metrics improve. Useful for keeping the best version of the model.

##### EarlyStopping: Stops training early if a monitored metric does not improve for a certain number of epochs, preventing overfitting and saving computation.

##### LearningRateMonitor: Logs learning rate changes during training, which can help in visualizing and debugging the learning rate schedule.

##### ProgressBar: Displays a progress bar during training and validation so you can monitor progress visually. In recent Lightning versions the default implementation is TQDMProgressBar, with RichProgressBar as an alternative.

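##### TensorBoard: Lightning integrates with TensorBoard through TensorBoardLogger, which logs metrics and, optionally, the model graph. Strictly speaking it is a logger passed to the Trainer's logger argument rather than a callback, although a custom callback can also write to TensorBoard directly, as section 5 does.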

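##### You can also create custom callbacks by subclassing pl.Callback and overriding hooks such as on_train_epoch_end, on_train_batch_end, or on_validation_end to insert your custom logic. The older on_epoch_end and on_batch_end hooks were removed in newer Lightning releases, so prefer the train/validation-specific names.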
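
To make the hook mechanism concrete, here is a minimal pure-Python sketch of how a training loop can dispatch to callbacks. `ToyTrainer` and `RecordingCallback` are hypothetical stand-ins written for this illustration; they are not Lightning classes, and Lightning's real Trainer passes additional arguments (such as `pl_module`) to each hook.

```python
class RecordingCallback:
    """Stand-in for a pl.Callback subclass; records which hooks fired."""

    def __init__(self):
        self.events = []

    def on_train_start(self, trainer):
        self.events.append("train_start")

    def on_train_epoch_end(self, trainer):
        self.events.append(f"epoch_end:{trainer.current_epoch}")

    def on_train_end(self, trainer):
        self.events.append("train_end")


class ToyTrainer:
    """Hypothetical, heavily simplified loop (not Lightning's Trainer)."""

    def __init__(self, max_epochs, callbacks):
        self.max_epochs = max_epochs
        self.callbacks = callbacks
        self.current_epoch = 0

    def _call_hook(self, name):
        # Invoke the named hook on every callback that defines it.
        for cb in self.callbacks:
            hook = getattr(cb, name, None)
            if hook is not None:
                hook(self)

    def fit(self):
        self._call_hook("on_train_start")
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            # A real trainer would run the training batches here.
            self._call_hook("on_train_epoch_end")
        self._call_hook("on_train_end")


cb = RecordingCallback()
ToyTrainer(max_epochs=2, callbacks=[cb]).fit()
print(cb.events)
# ['train_start', 'epoch_end:0', 'epoch_end:1', 'train_end']
```

The point of the sketch is that a callback never drives the loop; it only reacts when the trainer reaches a named point in it.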

### 1. ModelCheckpoint

In [None]:
import os

import pytorch_lightning as pl

class CustomModelCheckpoint(pl.Callback):
    def __init__(self, monitor='val_loss', mode='min', save_top_k=1):
        super().__init__()
        self.monitor = monitor
        self.mode = mode
        self.save_top_k = save_top_k
        # Start from the worst possible score so the first validation run counts as an improvement.
        self.best_score = float('inf') if mode == 'min' else -float('inf')
        self.saved_models = []

    def on_validation_end(self, trainer, pl_module):
        # Skip the pre-training sanity check so it does not produce a checkpoint.
        if trainer.sanity_checking:
            return
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            return
        current_score = float(current_score)  # logged metrics are stored as tensors
        if self.mode == 'min':
            improved = current_score < self.best_score
        else:
            improved = current_score > self.best_score
        if improved:
            self.best_score = current_score
            # Save model checkpoint
            checkpoint_path = f"model_checkpoint_epoch_{trainer.current_epoch}.ckpt"
            trainer.save_checkpoint(checkpoint_path)
            self.saved_models.append(checkpoint_path)
            # Keep only the top_k models. Checkpoints are saved only on
            # improvement, so the oldest one is always the worst.
            if len(self.saved_models) > self.save_top_k:
                oldest_model = self.saved_models.pop(0)
                if os.path.exists(oldest_model):
                    os.remove(oldest_model)
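
The pruning step can be tried in isolation, without a trainer. The sketch below is a standalone illustration: `record_checkpoint` is a hypothetical helper, and the checkpoint files are placeholder text files in a temporary directory rather than real Lightning checkpoints.

```python
import os
import tempfile


def record_checkpoint(saved, path, save_top_k):
    """Append a newly saved path and delete the oldest files beyond save_top_k."""
    saved.append(path)
    while len(saved) > save_top_k:
        oldest = saved.pop(0)
        if os.path.exists(oldest):
            os.remove(oldest)


with tempfile.TemporaryDirectory() as tmp:
    saved = []
    # Pretend the monitored metric improved in three consecutive epochs.
    for epoch in range(3):
        path = os.path.join(tmp, f"model_checkpoint_epoch_{epoch}.ckpt")
        with open(path, "w") as f:
            f.write("placeholder")
        record_checkpoint(saved, path, save_top_k=2)
    remaining = sorted(os.listdir(tmp))

print(remaining)
# ['model_checkpoint_epoch_1.ckpt', 'model_checkpoint_epoch_2.ckpt']
```

Dropping the oldest file is equivalent to dropping the worst one only because a file is written solely when the metric improves. A callback that saved every epoch would need to sort by score instead.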


### 2. EarlyStopping


In [None]:
import pytorch_lightning as pl

class CustomEarlyStopping(pl.Callback):
    def __init__(self, monitor='val_loss', patience=3, mode='min'):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.best_score = float('inf') if mode == 'min' else -float('inf')
        self.wait = 0           # validation runs since the last improvement
        self.stopped_epoch = 0

    def on_validation_end(self, trainer, pl_module):
        # Ignore the pre-training sanity check; its metrics are not meaningful.
        if trainer.sanity_checking:
            return
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            return
        current_score = float(current_score)  # logged metrics are stored as tensors
        if self.mode == 'min':
            improved = current_score < self.best_score
        else:
            improved = current_score > self.best_score
        if improved:
            self.best_score = current_score
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                # The trainer checks this flag and exits the fit loop.
                trainer.should_stop = True
                self.stopped_epoch = trainer.current_epoch
                print(f"Early stopping at epoch {self.stopped_epoch}")
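
To see how the patience counter behaves, the same logic can be run over a plain list of scores. This is a sketch with a hypothetical helper, `epoch_of_stop`, and made-up validation losses; it does not involve a trainer.

```python
def epoch_of_stop(scores, patience=3, mode="min"):
    """Return the epoch index at which training would stop, or None if it never does."""
    best = float("inf") if mode == "min" else -float("inf")
    wait = 0
    for epoch, score in enumerate(scores):
        improved = score < best if mode == "min" else score > best
        if improved:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Illustrative losses: improvement for three epochs, then a plateau.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.60]
print(epoch_of_stop(val_losses, patience=3))
# 5
```

Training stops at epoch 5, so the better score at epoch 6 is never reached. That is the trade-off a small `patience` makes: less wasted compute, but a higher chance of stopping before a late improvement.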


### 3. LearningRateMonitor


In [None]:
import pytorch_lightning as pl

class CustomLearningRateMonitor(pl.Callback):
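    # on_batch_end was removed from newer Lightning releases; use the train-specific hook.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Report the learning rate of every parameter group of the first optimizer.
        for i, param_group in enumerate(trainer.optimizers[0].param_groups):
            lr = param_group['lr']
            print(f"Epoch: {trainer.current_epoch}, Step: {trainer.global_step}, Group: {i}, Learning Rate: {lr}")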
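
The callback relies on the fact that an optimizer exposes its learning rates through `param_groups`, a list of dicts with an `'lr'` key. The sketch below mimics that structure with a `SimpleNamespace` stand-in instead of a real torch optimizer; `current_lrs` and `step_decay` are hypothetical helpers for this illustration.

```python
from types import SimpleNamespace


def current_lrs(optimizer):
    """Collect the learning rate of each parameter group."""
    return [group["lr"] for group in optimizer.param_groups]


def step_decay(optimizer, gamma):
    """Multiply every group's learning rate by gamma, as a simple schedule would."""
    for group in optimizer.param_groups:
        group["lr"] *= gamma


# Stand-in with the same param_groups layout as a torch optimizer.
opt = SimpleNamespace(param_groups=[{"lr": 0.1}, {"lr": 0.01}])

history = [current_lrs(opt)]
for _ in range(2):
    step_decay(opt, gamma=0.5)
    history.append(current_lrs(opt))

for step, lrs in enumerate(history):
    print(step, lrs)
```

Each parameter group can follow its own schedule, which is why the callback loops over all groups rather than reading a single value.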


### 4. ProgressBar


In [None]:
import pytorch_lightning as pl

class CustomProgressBar(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        print(f"Training started for {trainer.max_epochs} epochs.")
    
    # on_epoch_end was removed from newer Lightning releases; use on_train_epoch_end.
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} ended.")
    
    def on_train_end(self, trainer, pl_module):
        print("Training finished.")
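
The callback above only prints messages. If an actual bar is wanted, one simple option is to render it as text. The helper below is a hypothetical sketch, not part of Lightning; a callback could call it from `on_train_batch_end` with the current batch index and the total number of batches.

```python
def render_bar(done, total, width=20):
    """Return a text progress bar such as '[###.......] 3/10'."""
    filled = int(width * done / total)
    return f"[{'#' * filled}{'.' * (width - filled)}] {done}/{total}"


print(render_bar(3, 10, width=10))
# [###.......] 3/10
```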


### 5. TensorBoard


In [None]:
import pytorch_lightning as pl
from torch.utils.tensorboard import SummaryWriter

class CustomTensorBoard(pl.Callback):
    def __init__(self, log_dir='logs'):
        self.writer = SummaryWriter(log_dir=log_dir)
    
    # on_epoch_end was removed from newer Lightning releases; use on_train_epoch_end.
    def on_train_epoch_end(self, trainer, pl_module):
        for name, param in pl_module.named_parameters():
            self.writer.add_histogram(name, param, trainer.current_epoch)
        self.writer.flush()
    
    def on_train_end(self, trainer, pl_module):
        self.writer.close()


In [None]:
trainer = pl.Trainer(
    callbacks=[
        CustomModelCheckpoint(),
        CustomEarlyStopping(),
        CustomLearningRateMonitor(),
        CustomProgressBar(),
        CustomTensorBoard(),
    ]
)
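
For comparison, roughly the same setup can be assembled from Lightning's built-in classes. This is a configuration sketch under several assumptions: a recent `pytorch_lightning` release is installed along with a TensorBoard backend, the LightningModule logs a metric named `val_loss` via `self.log`, and the values chosen for `max_epochs`, `patience`, `refresh_rate`, and `save_dir` are illustrative only.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    TQDMProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger

callbacks = [
    # Keep only the single best checkpoint according to val_loss.
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
    # Stop after 3 validation runs without improvement.
    EarlyStopping(monitor="val_loss", mode="min", patience=3),
    # Log the learning rate at every optimizer step (requires a logger).
    LearningRateMonitor(logging_interval="step"),
    # Refresh the progress bar every 10 batches.
    TQDMProgressBar(refresh_rate=10),
]

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=callbacks,
    # TensorBoard support comes from a logger, not a callback.
    logger=TensorBoardLogger(save_dir="logs"),
)
```

The custom classes in this notebook are useful for understanding what these built-ins do. For real training runs the built-ins are usually the safer choice, since they also handle details such as distributed training and resuming from a checkpoint.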