In [None]:
# default_exp callbacks

# Custom List of Lightning Callbacks

In [None]:
import warnings

warnings.filterwarnings('ignore')

In [None]:
# export
import time
from collections import namedtuple

import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning import Callback, Trainer
from timm.utils import AverageMeter
from tqdm.auto import tqdm
import copy

from src import _logger
from src.core import conf_mat_idx2lbl, idx2lbl
from src.models import Task

## Custom Wandb Callback -
> meant to be used in conjunction with `pl.loggers.WandbLogger`

In [None]:
# export
class WandbTask(Callback):
    """Custom callback that adds some extra functionality to the wandb logger.

    Meant to be used in conjunction with `pl.loggers.WandbLogger`.
    Does the following:
        1. Registers the model and criterion with `wandb.watch` when training starts.
        2. Logs a confusion matrix of preds/labels for each validation epoch.
        3. Logs a confusion matrix of preds/labels after testing.

    The callback reads `pl_module.labels` and `pl_module.preds` after every
    validation/test batch and accumulates them for the whole epoch.
    """
    class_names = list(conf_mat_idx2lbl.values())

    def on_train_start(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        "Ask wandb to watch the model; a failure here is reported but never stops training."
        try:
            wandb.watch(models=pl_module.model, criterion=pl_module.criterion)
        except Exception as err:
            # Best effort only: keep training, but make the failure visible.
            _logger.warning(f"wandb.watch failed, continuing without it: {err}")

    def _reset_buffers(self) -> None:
        "Start a fresh epoch-level collection of labels and predictions."
        self.labels, self.predictions = [], []

    def _collect_batch(self, pl_module: Task) -> None:
        "Append the labels/preds the module exposes for the batch that just finished."
        # assumes `pl_module.labels` / `pl_module.preds` are flat lists of class indices - TODO confirm
        self.labels.extend(pl_module.labels)
        self.predictions.extend(pl_module.preds)

    def _log_confusion_matrix(self, key: str) -> None:
        "Build a confusion matrix from the collected values and log it to wandb under `key`."
        if not self.labels or not self.predictions:
            # Nothing was collected (e.g. zero batches were run), so there is nothing to plot.
            return

        preds  = torch.tensor(self.predictions).cpu().numpy()
        labels = torch.tensor(self.labels).cpu().numpy()

        # Keyword arguments keep the call correct regardless of the positional
        # order of `wandb.plot.confusion_matrix` (newer releases put `probs` first).
        matrix = wandb.plot.confusion_matrix(preds=preds, y_true=labels, class_names=self.class_names)
        wandb.log({key: matrix}, commit=False)

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._reset_buffers()

    def on_validation_batch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._collect_batch(pl_module)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._log_confusion_matrix("valid_confusion_matrix")

    def on_test_epoch_start(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._reset_buffers()

    def on_test_batch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._collect_batch(pl_module)

    def on_test_epoch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._log_confusion_matrix("test_confusion_matrix")

## Custom Progress Bar Callback -

In [None]:
# export
class DisableValidationBar(pl.callbacks.ProgressBar):
    """Custom progress bar callback for Lightning training which disables the validation bar.

    The sanity-check, training and testing bars follow `self.is_disabled`;
    the validation bar is always created disabled.
    """

    def init_sanity_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for the validation sanity run. """
        # Honour `is_disabled` like the train/test bars do, so the sanity bar
        # is not drawn when the progress bar has been switched off.
        bar = tqdm(desc="Validation sanity check", disable=self.is_disabled, dynamic_ncols=True,)
        return bar

    def init_train_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for training. """
        bar = tqdm(desc="Training", disable=self.is_disabled, dynamic_ncols=True,)
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for validation. """
        # Always disabled: hiding this bar is the purpose of the class.
        bar = tqdm(desc="Validating", disable=True, dynamic_ncols=False,)
        return bar

    def init_test_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for testing. """
        bar = tqdm(desc="Testing", disable=self.is_disabled, dynamic_ncols=True,)
        return bar

## Information Logger Callback -

In [None]:
# export
class LogInformationCallback(pl.Callback):
    """Logs loss/metric values and batch timings to the console.

    After every epoch the train/valid loss and accuracy found in
    `trainer.callback_metrics` are logged; after testing the test loss and
    accuracy are logged. The `Time:` field shows the duration of the last batch
    followed by the average batch duration (in parentheses), in seconds.
    """
    TrainResult = namedtuple("TrainOutput", ["loss", "acc", "val_loss", "val_acc"])
    TestResult  = namedtuple("TestOutput",  ["test_loss", "test_acc"])

    def _start_timer(self) -> None:
        "Create a fresh meter and remember when the first batch starts."
        self.batch_time = AverageMeter()
        self.end = time.time()

    def _record_batch(self) -> None:
        "Record the duration of the batch that just ended and restart the clock."
        now = time.time()
        self.batch_time.update(now - self.end)
        # Restart the clock; without this every sample would be the time since
        # the epoch started rather than the duration of a single batch.
        self.end = now

    @staticmethod
    def _to_scalar(value) -> float:
        "Convert a metric tensor into a Python float rounded to 4 decimals."
        return round(value.data.cpu().numpy().item(), 4)

    def on_train_epoch_start(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._start_timer()

    def on_train_batch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._record_batch()

    def on_epoch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        "Log the epoch-level train/valid metrics together with the batch timings."
        # Shallow copy so the lookups below work on a snapshot of the metrics dict.
        metrics = copy.copy(trainer.callback_metrics)

        res = self.TrainResult(
            self._to_scalar(metrics["train/loss_epoch"]),
            self._to_scalar(metrics["train/acc_epoch"]),
            self._to_scalar(metrics["valid/loss"]),
            self._to_scalar(metrics["valid/acc"]),
        )

        curr_epoch = int(pl_module.current_epoch)
        total_epoch = int(trainer.max_epochs)
        _logger.info(f"Train: [ {curr_epoch}/{total_epoch}] Time: {self.batch_time.val:.3f} ({self.batch_time.avg:.3f}) {res}")

    def on_test_epoch_start(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._start_timer()

    def on_test_batch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        self._record_batch()

    def on_test_epoch_end(self, trainer: Trainer, pl_module: Task, *args, **kwargs) -> None:
        "Log the test metrics together with the batch timings."
        metrics = trainer.callback_metrics

        res = self.TestResult(
            self._to_scalar(metrics["test/loss"]),
            self._to_scalar(metrics["test/acc"]),
        )

        _logger.info(f"Test: Time: {self.batch_time.val:.3f} ({self.batch_time.avg:.3f}) {res}")

In [None]:
import os
from hydra.experimental import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pl_bolts.callbacks import BatchGradientVerificationCallback, TrainingDataMonitor

In [None]:
# Hydra overrides for a small CPU smoke test of the callbacks.
# `training.batch_size` is set exactly once: the list previously also carried
# `training.batch_size=5`, which the later `=64` entry superseded.
overrides = [
    "input.input_size=120",
    "augmentations=tfms-v0",
    "data.dataloader.num_workers=0",
    "general=default",
    "trainer=fast-dev-cpu",
    "mixmethod=snapmix",
    "training.mix_epochs=3",
    "training.batch_size=64",
    "model=v0",
    "model.base_model.activation=mish",
    "model.head.params.act_layer=mish",
    "training.accumulate_grad_batches=1",
    "loss=crossentropy",
]

# Compose the `effnet-base` config from ../conf with the overrides applied.
with initialize(config_path=os.path.relpath("../conf/")):
    cfg = compose(config_name="effnet-base", overrides=overrides)

In [None]:
model = Task(cfg)

# Callbacks exercised by this smoke run.
cbs = [
    LogInformationCallback(),
    DisableValidationBar(),
    pl.callbacks.LearningRateMonitor("step"),
    TrainingDataMonitor(),
]

# Build the trainer from the composed config, capping the number of batches
# per stage so the run stays short.
trainer = instantiate(
    cfg.trainer,
    callbacks=cbs,
    max_epochs=8,
    limit_train_batches=10,
    limit_val_batches=5,
    limit_test_batches=5,
    weights_summary="top",
    profiler=False,
    terminate_on_nan=True,
)
trainer.fit(model)

[[32m01/31 20:56:16[0m [35msrc.models.builder[0m]: Configuration for the current model :
[[32m01/31 20:56:16[0m [35msrc.models.builder[0m]:  feature_extractor: tf_efficientnet_b3_ns
[[32m01/31 20:56:16[0m [35msrc.models.builder[0m]:  activation: mish
[[32m01/31 20:56:16[0m [35msrc.models.builder[0m]:  params: {'drop_path_rate': 0.25}
[[32m01/31 20:56:16[0m [35msrc.models.builder[0m]:  head: CnnHeadV0
[[32m01/31 20:56:16[0m [35msrc.models.builder[0m]:  params: {'n_out': 5, 'pool_type': 'avg', 'use_conv': False, 'act_layer': 'mish'}
[[32m01/31 20:56:17[0m [35msrc.models.task[0m]: LossFunction: CrossEntropyLoss()
[[32m01/31 20:56:17[0m [35msrc.models.task[0m]: Training with Snapmix(alpha=5.0, conf_prob=0.5, num_iters=3).
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
[[32m01/31 20:56:17[0m [35msrc.data.dataset_factory[0m]: Generating Datasets for FOLD :0
[[32m01/31 20:56:18[0m [35msrc.data.dataset_factory[0m]: Train Datase

Training: 0it [00:00, ?it/s]

[[32m01/31 20:58:56[0m [35m__main__[0m]: Train: [ 0/8] Time: 138.052 (75.587) TrainOutput(loss=1.7081, acc=0.4422, val_loss=1.6205, val_acc=0.0938)


1

In [None]:
# Run the test loop with the trainer fitted above; the return value is bound
# to `_` so the cell does not display it.
_ = trainer.test(verbose=False)

[[32m01/31 20:59:58[0m [35msrc.data.dataset_factory[0m]: Generating Datasets for FOLD :0
[[32m01/31 20:59:59[0m [35msrc.data.dataset_factory[0m]: Train Dataset has 17117, Validation Dataset has 4280 instances.


Testing: 0it [00:00, ?it/s]

[[32m01/31 21:00:19[0m [35m__main__[0m]: Test: Time: 18.367 (11.300) TestOutput(test_loss=1.6166, test_acc=0.1094)


In [None]:
# hide
# Convert the project's notebooks into library modules.
# Import only the function that is used instead of `import *`.
from nbdev.export import notebook2script

notebook2script()

Converted 00_core.ipynb.
Converted 01a_data.datasets.ipynb.
Converted 01b_data.datasests_factory.ipynb.
Converted 01c_data.mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03a_optimizers.ipynb.
Converted 03b_schedulers.ipynb.
Converted 04a_models.utils.ipynb.
Converted 04b_models.layers.ipynb.
Converted 04c_models.classifiers.ipynb.
Converted 04d_models.builder.ipynb.
Converted 04e_models.task.ipynb.
Converted 05_callbacks.ipynb.
Converted index.ipynb.
