In [None]:
# default_exp lightning.callbacks

In [None]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [None]:
# export
import os
import sys
import time
import datetime
import logging
from collections import namedtuple
from tqdm.auto import tqdm

import wandb

import torch
import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.core.memory import ModelSummary

from src.all import *

<IPython.core.display.Javascript object>

In [None]:
import colorlog

# Colourised record layout: [time][logger name][level] - message
fmt = (
    "[%(cyan)s%(asctime)s%(reset)s]"
    "[%(blue)s%(name)s%(reset)s]"
    "[%(log_color)s%(levelname)s%(reset)s] - %(message)s"
)

# colour used for the %(log_color)s placeholder, per level
colors = {
    "DEBUG": "purple",
    "INFO": "green",
    "WARNING": "yellow",
    "ERROR": "red",
    "CRITICAL": "red",
}

formatter = colorlog.ColoredFormatter(fmt=fmt, log_colors=colors)
handler = colorlog.StreamHandler()
handler.setFormatter(formatter)

# route root-level INFO+ records through the coloured handler
logging.basicConfig(level=logging.INFO, format=fmt, handlers=[handler])

<IPython.core.display.Javascript object>

In [None]:
# export
class WandbImageClassificationCallback(pl.Callback):
    """Custom callback that adds some extra functionalities to the wandb logger.

    Depending on the flags, it logs to the active wandb run:
      - up to ``num_batches`` training images taken from ``pl_module.one_batch``
        at the end of every training epoch (``log_train_batch``),
      - predictions on up to ``num_batches`` images of the first validation
        batch at the end of every validation epoch (``log_preds``),
      - a confusion matrix built from ``pl_module.val_preds_list`` /
        ``pl_module.val_labels_list`` at the end of every epoch (``log_conf_mat``).

    Args:
        num_batches: number of images taken from the front of a batch for
            logging. Despite the name, this is an image count, not a batch count.
        log_train_batch: log training images.
        log_preds: log validation predictions.
        log_conf_mat: log the confusion matrix.
    """

    def __init__(
        self,
        num_batches: int = 16,
        log_train_batch: bool = False,
        log_preds: bool = False,
        log_conf_mat: bool = True,
    ):
        # class names for the confusion matrix
        # (``conf_mat_idx2lbl`` presumably comes from ``from src.all import *``)
        self.class_names = list(conf_mat_idx2lbl.values())

        # number of images logged from a batch
        self.num_bs = num_batches
        self.curr_epoch = 0

        self.log_train_batch = log_train_batch
        self.log_preds = log_preds
        self.log_conf_mat = log_conf_mat

        # fixed validation images/labels, cached on first use
        self.val_imgs, self.val_labels = None, None

    def on_train_start(self, trainer, pl_module, *args, **kwargs):
        """Best effort: ask wandb to watch ``pl_module.model`` with ``pl_module.loss_func``."""
        try:
            wandb.watch(models=pl_module.model, criterion=pl_module.loss_func)
        except Exception as exc:
            # e.g. no active wandb run, or the module lacks ``model``/``loss_func``.
            # Training should go on, but say why the model is not being watched.
            log.warning(f"wandb.watch failed, model will not be watched: {exc}")

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the first ``num_bs`` images of ``pl_module.one_batch``, if it is set."""
        if not self.log_train_batch:
            return
        one_batch = getattr(pl_module, "one_batch", None)
        if one_batch is None:
            return
        train_ims = one_batch[: self.num_bs].data.to("cpu")
        trainer.logger.experiment.log(
            {"train_batch": [wandb.Image(x) for x in train_ims]}, commit=False
        )

    def on_validation_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log predictions vs. labels for a fixed slice of the first validation batch."""
        if not self.log_preds:
            return

        if self.val_imgs is None and self.val_labels is None:
            # cache the same images for every epoch so predictions are comparable
            imgs, labels = next(iter(pl_module.val_dataloader()))
            self.val_imgs = imgs[: self.num_bs].to(device=pl_module.device)
            self.val_labels = labels[: self.num_bs]

        # inference only: no autograd graph is needed
        with torch.no_grad():
            logits = pl_module(self.val_imgs)
        preds = torch.argmax(logits, 1).cpu()

        ims = [
            wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
            for x, pred, y in zip(self.val_imgs.cpu(), preds, self.val_labels)
        ]
        # wandb.log expects a dict mapping keys to values
        wandb.log({"predictions": ims}, commit=False)

    def on_epoch_start(self, trainer, pl_module, *args, **kwargs):
        """Reset the per-epoch prediction/label buffers on the module."""
        pl_module.val_labels_list = []
        pl_module.val_preds_list = []

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log a confusion matrix of the predictions collected during the epoch."""
        if not self.log_conf_mat:
            return
        preds_list = getattr(pl_module, "val_preds_list", None)
        labels_list = getattr(pl_module, "val_labels_list", None)
        if not preds_list or not labels_list:
            # nothing was collected this epoch (e.g. no validation ran); nothing to plot
            return
        val_preds = torch.tensor(preds_list).cpu().tolist()
        val_labels = torch.tensor(labels_list).cpu().tolist()
        # Pass keywords: the positional order of this function has changed
        # between wandb releases.
        log_dict = {
            "conf_mat": wandb.plot.confusion_matrix(
                preds=val_preds, y_true=val_labels, class_names=self.class_names
            )
        }
        wandb.log(log_dict, commit=False)

<IPython.core.display.Javascript object>

In [None]:
# export
class DisableValidationBar(pl.callbacks.ProgressBar):
    """Lightning progress bar whose validation bar is always hidden.

    The sanity-check, training and testing bars follow the normal enabled or
    disabled state of the callback. Only the validation bar is forced off.
    """

    def _new_bar(self, desc: str, offset: int = 0, **overrides) -> tqdm:
        """Build a tqdm bar on row ``2 * process_position + offset``.

        Defaults follow the callback's own disabled state with dynamic column
        sizing. Any keyword in ``overrides`` replaces the matching default.
        """
        options = dict(
            desc=desc,
            position=2 * self.process_position + offset,
            disable=self.is_disabled,
            dynamic_ncols=True,
        )
        options.update(overrides)
        return tqdm(**options)

    def init_sanity_tqdm(self) -> tqdm:
        """Bar shown during the validation sanity run."""
        return self._new_bar("Validation sanity check")

    def init_train_tqdm(self) -> tqdm:
        """Bar shown during training, resuming from the current batch index."""
        return self._new_bar("Training", initial=self.train_batch_idx)

    def init_validation_tqdm(self) -> tqdm:
        """Validation bar: always disabled, one row below the main bar."""
        return self._new_bar("Validating", offset=1, disable=True, dynamic_ncols=False)

    def init_test_tqdm(self) -> tqdm:
        """Bar shown during testing."""
        return self._new_bar("Testing")

<IPython.core.display.Javascript object>

In [None]:
# export
class PrintLogsCallback(pl.Callback):
    """Logs training/validation metrics to the console after every epoch.

    Values are read from ``trainer.callback_metrics`` and logged through the
    ``"src.logs"`` logger. If some expected metric is missing (e.g. the hook
    fires before it has been logged), the message is skipped instead of
    raising ``KeyError``.
    """

    TrainResult = namedtuple("TrainResult", ["loss", "acc", "valid_loss", "valid_acc"])
    TestResult = namedtuple("TestResult", ["test_loss", "test_acc"])

    logger = logging.getLogger("src.logs")

    # metric keys read from ``trainer.callback_metrics``, in namedtuple field order
    _TRAIN_KEYS = ("train/loss_epoch", "train/acc_epoch", "valid/loss", "valid/acc")
    _TEST_KEYS = ("test/loss", "test/acc")

    def _collect(self, trainer, keys, ndigits):
        """Return the metrics for ``keys`` rounded to ``ndigits``, or ``None`` if any is missing.

        ``float()`` accepts both 0-d tensors (on any device) and plain numbers.
        """
        metrics = trainer.callback_metrics
        missing = [k for k in keys if k not in metrics]
        if missing:
            self.logger.debug(f"Skipping console log, metrics not available: {missing}")
            return None
        return [round(float(metrics[k]), ndigits) for k in keys]

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log ``TrainResult(loss, acc, valid_loss, valid_acc)`` for the finished epoch."""
        values = self._collect(trainer, self._TRAIN_KEYS, 3)
        if values is None:
            return
        trn_res = self.TrainResult(*values)
        curr_epoch = int(trainer.current_epoch)
        self.logger.info(f"[{curr_epoch}]: (100.00% done), {trn_res}")

    def on_test_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log ``TestResult(test_loss, test_acc)`` rounded to two decimals."""
        values = self._collect(trainer, self._TEST_KEYS, 2)
        if values is None:
            return
        self.logger.info(f"{self.TestResult(*values)}")

<IPython.core.display.Javascript object>

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from torch import nn
from src.networks import *
from omegaconf import OmegaConf

<IPython.core.display.Javascript object>

In [None]:
# Configuration for the quick smoke-test runs in this notebook.
SEED = 42
IMG_SIZE = 224
# NOTE: relative paths assume the notebook is run from its own directory.
DATA_CSV = "../../leaf-disease-classification-kaggle/data/stratified-data-5folds.csv"
IMAGE_DIR = "../../Datasets/cassava/train_images/"

# Seed python/numpy/torch so the random model init is repeatable across runs.
pl.seed_everything(SEED)

train_augs = A.Compose(
    [
        A.RandomResizedCrop(IMG_SIZE, IMG_SIZE, p=1.0),
        A.RandomBrightness(limit=0.1),
        A.HueSaturationValue(20, 20, 20),
        A.HorizontalFlip(),
        A.Normalize(p=1.0),
        ToTensorV2(p=1.0),
    ]
)

valid_augs = A.Compose(
    [A.Resize(IMG_SIZE, IMG_SIZE, p=1.0), A.Normalize(p=1.0), ToTensorV2(p=1.0)]
)

# Tiny batches and in-process loading keep these debugging runs light.
dm = CassavaLightningDataModule(
    DATA_CSV,
    IMAGE_DIR,
    curr_fold=0,
    train_augs=train_augs,
    valid_augs=valid_augs,
    bs=8,
    num_workers=0,
)


model_hparams = dict(
    mixmethod=None,
    loss=dict(_target_="src.losses.LabelSmoothingCrossEntropy", eps=0.1),
    learning_rate=1e-03,
    lr_mult=100,
    optimizer=dict(_target_="torch.optim.Adam"),
    scheduler=dict(
        # NOTE(review): num_epochs=10 here vs max_epochs=2 in the trainers
        # below; confirm the schedule length is intended.
        function=dict(_target_="src.opts.FlatCos", num_epochs=10, pct_start=0.7),
        metric_to_track=None,
        scheduler_interval="step",
    ),
)

cfg = OmegaConf.create(model_hparams)


encoder = timm.create_model("resnet18", pretrained=False)
model = TransferLearningModel(encoder, cut=-2, c=5, act=nn.ReLU(inplace=True))
model = LightningCassava(model=model, conf=cfg)

[[36m2021-01-22 17:25:07,244[0m][[34mLitModel[0m][[32mINFO[0m] - Loss Function : LabelSmoothingCrossEntropy()[0m


<IPython.core.display.Javascript object>

In [None]:
# Quick smoke run: one batch per stage, no sanity check, no weights summary;
# the custom bar hides validation progress and the callback prints epoch metrics.
smoke_trainer_kwargs = dict(
    callbacks=[DisableValidationBar(), PrintLogsCallback()],
    num_sanity_val_steps=0,
    max_epochs=2,
    limit_train_batches=1,
    limit_val_batches=1,
    limit_test_batches=1,
    weights_summary=None,
    progress_bar_refresh_rate=0,
)
trainer = pl.Trainer(**smoke_trainer_kwargs)

GPU available: False, used: False
[[36m2021-01-22 17:25:21,511[0m][[34mlightning[0m][[32mINFO[0m] - GPU available: False, used: False[0m
TPU available: False, using: 0 TPU cores
[[36m2021-01-22 17:25:21,512[0m][[34mlightning[0m][[32mINFO[0m] - TPU available: False, using: 0 TPU cores[0m


<IPython.core.display.Javascript object>

In [None]:
# train the module on the data module configured above
trainer.fit(model=model, datamodule=dm)

[[36m2021-01-22 17:25:24,338[0m][[34mdatamodule[0m][[32mINFO[0m] - Data(fold=0, batch_size=8, im_path='../../Datasets/cassava/train_images')[0m
[[36m2021-01-22 17:25:24,365[0m][[34mLitModel[0m][[32mINFO[0m] - Optimization Parameters: 
Optimizer(optimizer='Adam', scheduler='FlatCos', lrs=Lrs(encoder_lr=1e-05, fc_lr=0.001), wd=None)[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

[[36m2021-01-22 17:25:27,020[0m][[34msrc.logs[0m][[32mINFO[0m] - [0]: (100.00% done), TrainResult(loss=1.83, acc=0.0, valid_loss=1.607, valid_acc=0.125)[0m
[[36m2021-01-22 17:25:29,448[0m][[34msrc.logs[0m][[32mINFO[0m] - [1]: (100.00% done), TrainResult(loss=1.584, acc=0.25, valid_loss=1.581, valid_acc=0.5)[0m





1

<IPython.core.display.Javascript object>

In [None]:
# export
class DisableProgressBar(pl.callbacks.ProgressBar):
    "Custom Progressbar callback for Lightning Training which disables every progress bar (sanity check, training, validation and testing)"

    def _hidden_tqdm(
        self, desc: str, offset: int = 0, dynamic_ncols: bool = True, **kwargs
    ) -> tqdm:
        """Create an always-disabled tqdm bar on row ``2 * process_position + offset``.

        Extra keyword arguments (e.g. ``initial``) are forwarded to ``tqdm``.
        """
        return tqdm(
            desc=desc,
            position=(2 * self.process_position + offset),
            disable=True,
            dynamic_ncols=dynamic_ncols,
            **kwargs,
        )

    def init_sanity_tqdm(self) -> tqdm:
        """Hidden bar for the validation sanity run."""
        return self._hidden_tqdm("Validation sanity check")

    def init_train_tqdm(self) -> tqdm:
        """Hidden bar for training, starting at the current batch index."""
        return self._hidden_tqdm("Training", initial=self.train_batch_idx)

    def init_validation_tqdm(self) -> tqdm:
        """Hidden bar for validation, one row below the main bar."""
        return self._hidden_tqdm("Validating", offset=1, dynamic_ncols=False)

    def init_test_tqdm(self) -> tqdm:
        """Hidden bar for testing."""
        return self._hidden_tqdm("Testing")

<IPython.core.display.Javascript object>

In [None]:
# export
class ConsoleLogger(pl.Callback):
    """Fancy logger that reports training and testing progress to the console.

    All messages go to the ``"src.log"`` logger:
      - at fit start (once): model summary, trainer parameters and number of batches,
      - every ``print_every`` training batches: step loss and accuracy,
      - at epoch end: train/valid loss and accuracy,
      - at test start / test epoch end: stage header and test metrics.

    Metrics are read from ``trainer.callback_metrics``. If an expected key is
    missing, that message is skipped instead of raising ``KeyError``.

    Args:
        print_every: log step metrics every ``print_every`` training batches.

    Raises:
        ValueError: if ``print_every`` is smaller than 1.
    """

    tst_res = namedtuple("TestResult", ["test_loss", "test_acc"])
    trn_res = namedtuple("TrainResult", ["loss", "acc", "val_loss", "val_acc"])
    # Class-level defaults kept for backward compatibility; ``__init__`` sets
    # per-instance values so separate loggers never share this state.
    curr_step = 0
    has_init = False
    logger = logging.getLogger("src.log")

    # metric keys read from ``trainer.callback_metrics``
    _STEP_KEYS = ("train/loss_step", "train/acc_step")
    _EPOCH_KEYS = ("train/loss_epoch", "train/acc_epoch", "valid/loss", "valid/acc")
    _TEST_KEYS = ("test/loss", "test/acc")

    def __init__(self, print_every: int = 50):
        if print_every < 1:
            # ``curr_step % print_every`` would divide by zero / never trigger
            raise ValueError(f"print_every must be >= 1, got {print_every}")
        self.print_every = print_every
        self.curr_step = 0
        self.has_init = False

    # ------------------------------------------------------------------ helpers
    def _metrics(self, trainer, keys, ndigits=None):
        """Return the metrics for ``keys`` as floats, or ``None`` if any is missing.

        ``float()`` accepts both 0-d tensors (on any device) and plain numbers.
        When ``ndigits`` is given, every value is rounded to it.
        """
        metrics = trainer.callback_metrics
        missing = [k for k in keys if k not in metrics]
        if missing:
            self.logger.debug(f"Skipping console log, metrics not available: {missing}")
            return None
        values = [float(metrics[k]) for k in keys]
        if ndigits is not None:
            values = [round(v, ndigits) for v in values]
        return values

    @staticmethod
    def _ckpt_dir(trainer) -> str:
        """Checkpoint directory relative to the CWD, or ``"N/A"`` when there is none."""
        ckpt_cb = getattr(trainer, "checkpoint_callback", None)
        dirpath = getattr(ckpt_cb, "dirpath", None)
        return os.path.relpath(dirpath) if dirpath else "N/A"

    @staticmethod
    def _num_batches(get_loader):
        """``len`` of the loader returned by ``get_loader``, or ``"?"`` when unavailable."""
        try:
            return len(get_loader())
        except (TypeError, NotImplementedError):
            # e.g. no loader defined (returns None) or a loader without ``__len__``
            return "?"

    # -------------------------------------------------------------------- hooks
    def on_fit_start(self, trainer, pl_module, *args, **kwargs):
        """Log the run header, only once per logger instance.

        This hook can fire again (e.g. from ``trainer.test``), so ``has_init``
        guards against printing the header twice.
        """
        if self.has_init:
            return
        self.log_line()
        self.log_msg(f"Model: \n{ModelSummary(trainer.model)}")
        self.log_line()
        self.log_msg("Parameters:")
        self.log_msg(f" - max_epochs: {trainer.max_epochs}")
        self.log_msg(f" - accumulate_batches: {trainer.accumulate_grad_batches}")
        self.log_msg(f" - gradient_clipping: {trainer.gradient_clip_val}")
        self.log_line()
        n_train = self._num_batches(pl_module.train_dataloader)
        n_valid = self._num_batches(pl_module.val_dataloader)
        n_test = self._num_batches(pl_module.test_dataloader)
        self.log_msg(
            f"DATASET: 'data: {n_train} train + {n_valid} valid + {n_test} test batches'"
        )
        # set on the instance; a bare ``has_init = True`` would only bind a local
        self.has_init = True

    def on_epoch_start(self, *args, **kwargs):
        self.log_line()

    def on_train_start(self, trainer, pl_module, *args, **kwargs):
        """Log the train/validation stage header."""
        self.log_line()
        self.log_msg("STAGE: TRAIN / VALIDATION")
        self.log_line()
        self.log_msg(f"Model training base path: {self._ckpt_dir(trainer)}")
        self.log_line()
        self.log_msg(f"Device: {pl_module.device}")

    def on_train_epoch_start(self, *args, **kwargs):
        # resets the current step
        self.curr_step = 0

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        """Log step loss/accuracy every ``print_every`` training batches."""
        if self.curr_step % self.print_every == 0:
            values = self._metrics(trainer, self._STEP_KEYS)
            if values is not None:
                loss, acc = values
                # NOTE: number of batches Lightning will actually run this epoch
                # (unlike len(train_dataloader()), it reflects limit_train_batches)
                total = trainer.num_training_batches
                self.log_msg(
                    f"epoch - {trainer.current_epoch} - iteration {self.curr_step + 1}/{total} - loss {loss:.3f} - acc {acc:.3f}"
                )

        self.curr_step += 1

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the epoch's train/valid loss and accuracy rounded to 3 decimals."""
        values = self._metrics(trainer, self._EPOCH_KEYS, ndigits=3)
        if values is None:
            return
        _res = self.trn_res(*values)

        curr_epoch = int(trainer.current_epoch)
        self.log_line()
        self.log_msg(f"EPOCH {curr_epoch}: {_res}")

    def on_fit_end(self, *args, **kwargs):
        self.log_line()

    def on_test_start(self, trainer, pl_module, *args, **kwargs):
        """Log the test stage header."""
        self.log_line()
        self.log_msg("STAGE: TEST")
        self.log_line()
        self.log_msg(f"Model testing base path: {self._ckpt_dir(trainer)}")
        self.log_line()
        self.log_msg(f"Device: {pl_module.device}")
        self.log_line()

    def on_test_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log test loss/accuracy rounded to 2 decimals."""
        values = self._metrics(trainer, self._TEST_KEYS, ndigits=2)
        if values is None:
            return
        self.log_msg(f"{self.tst_res(*values)}")

    def log_line(self):
        """Log a 70-character separator line."""
        self.logger.info("-" * 70)

    def log_msg(self, msg: str):
        """Log ``msg`` at INFO level."""
        self.logger.info(msg)

<IPython.core.display.Javascript object>

In [None]:
# Prepare the data module explicitly so its dataloaders exist before fit starts.
dm.prepare_data()
dm.setup()

# Fresh, untrained model so this demo does not reuse the weights trained above.
encoder = timm.create_model("resnet18", pretrained=False)
net = TransferLearningModel(encoder, cut=-2, c=5, act=nn.ReLU(inplace=True))
model = LightningCassava(model=net, conf=cfg)

# step metrics every 2 batches; all tqdm bars hidden so the console log is readable
info_logger = ConsoleLogger(print_every=2)
console_trainer_kwargs = dict(
    callbacks=[DisableProgressBar(), info_logger],
    num_sanity_val_steps=0,
    max_epochs=2,
    limit_train_batches=4,
    limit_val_batches=1,
    limit_test_batches=1,
    weights_summary=None,
    progress_bar_refresh_rate=0,
)
trainer = pl.Trainer(**console_trainer_kwargs)

[[36m2021-01-22 17:26:09,063[0m][[34mdatamodule[0m][[32mINFO[0m] - Data(fold=0, batch_size=8, im_path='../../Datasets/cassava/train_images')[0m
[[36m2021-01-22 17:26:09,782[0m][[34mLitModel[0m][[32mINFO[0m] - Loss Function : LabelSmoothingCrossEntropy()[0m
GPU available: False, used: False
[[36m2021-01-22 17:26:09,784[0m][[34mlightning[0m][[32mINFO[0m] - GPU available: False, used: False[0m
TPU available: False, using: 0 TPU cores
[[36m2021-01-22 17:26:09,785[0m][[34mlightning[0m][[32mINFO[0m] - TPU available: False, using: 0 TPU cores[0m


<IPython.core.display.Javascript object>

In [None]:
# train with the ConsoleLogger attached
trainer.fit(model=model, datamodule=dm)

[[36m2021-01-22 17:26:13,216[0m][[34mLitModel[0m][[32mINFO[0m] - Optimization Parameters: 
Optimizer(optimizer='Adam', scheduler='FlatCos', lrs=Lrs(encoder_lr=1e-05, fc_lr=0.001), wd=None)[0m
[[36m2021-01-22 17:26:13,217[0m][[34msrc.log[0m][[32mINFO[0m] - ----------------------------------------------------------------------[0m
[[36m2021-01-22 17:26:13,220[0m][[34msrc.log[0m][[32mINFO[0m] - Model: 
  | Name      | Type                       | Params
---------------------------------------------------------
0 | model     | TransferLearningModel      | 11.7 M
1 | accuracy  | Accuracy                   | 0     
2 | loss_func | LabelSmoothingCrossEntropy | 0     [0m
[[36m2021-01-22 17:26:13,221[0m][[34msrc.log[0m][[32mINFO[0m] - ----------------------------------------------------------------------[0m
[[36m2021-01-22 17:26:13,222[0m][[34msrc.log[0m][[32mINFO[0m] - Parameters:[0m
[[36m2021-01-22 17:26:13,223[0m][[34msrc.log[0m][[32mINFO[0m] -  - max_

1

<IPython.core.display.Javascript object>

In [None]:
# run the test loop; keep the results in a named variable instead of `_`
test_results = trainer.test(datamodule=dm, verbose=False)

[[36m2021-01-22 17:26:40,863[0m][[34mLitModel[0m][[32mINFO[0m] - Optimization Parameters: 
Optimizer(optimizer='Adam', scheduler='FlatCos', lrs=Lrs(encoder_lr=1e-05, fc_lr=0.001), wd=None)[0m
[[36m2021-01-22 17:26:40,872[0m][[34msrc.log[0m][[32mINFO[0m] - ----------------------------------------------------------------------[0m
[[36m2021-01-22 17:26:40,875[0m][[34msrc.log[0m][[32mINFO[0m] - Model: 
  | Name      | Type                       | Params
---------------------------------------------------------
0 | model     | TransferLearningModel      | 11.7 M
1 | accuracy  | Accuracy                   | 0     
2 | loss_func | LabelSmoothingCrossEntropy | 0     [0m
[[36m2021-01-22 17:26:40,876[0m][[34msrc.log[0m][[32mINFO[0m] - ----------------------------------------------------------------------[0m
[[36m2021-01-22 17:26:40,876[0m][[34msrc.log[0m][[32mINFO[0m] - Parameters:[0m
[[36m2021-01-22 17:26:40,877[0m][[34msrc.log[0m][[32mINFO[0m] -  - max_

<IPython.core.display.Javascript object>

In [None]:
# hide
# explicit import: only `notebook2script` is needed, so avoid star-importing
# nbdev's whole export namespace into the notebook
from nbdev.export import notebook2script

# export the `# export` cells of the project notebooks to python modules
notebook2script()

Converted 00_core.ipynb.
Converted 01_mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03_layers.ipynb.
Converted 03a_networks.ipynb.
Converted 04_optimizers_schedules.ipynb.
Converted 05_lightning.core.ipynb.
Converted 05a_lightning.callbacks.ipynb.
Converted 06_fastai.core.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>