In [None]:
# default_exp lightning.callbacks

In [None]:
# export
import os
import sys
import time
import datetime
import logging
import time
from collections import namedtuple
from tqdm.auto import tqdm

import wandb

import torch
import pytorch_lightning as pl
from pytorch_lightning.core.memory import ModelSummary, get_human_readable_count

from src.all import *

In [None]:
import colorlog

# Colour-coded console logging for this notebook session: the logger name is
# printed in purple and the level name in a per-level colour.
fmt = "%(purple)s%(name)s%(reset)s:%(log_color)s%(levelname)s%(reset)s: %(message)s"
colors = {
    "DEBUG": "purple",
    "INFO": "green",
    "WARNING": "yellow",
    "ERROR": "red",
    "CRITICAL": "red",
}
formatter = colorlog.ColoredFormatter(fmt=fmt, log_colors=colors)

# Stream handler that renders records through the coloured formatter.
handler = colorlog.StreamHandler()
handler.setFormatter(formatter)

# Install the handler on the root logger at INFO level.
logging.basicConfig(format=fmt, level=logging.INFO, handlers=[handler])

In [None]:
# export
class WandbImageClassificationCallback(pl.Callback):
    """Lightning callback that adds image-classification extras to a wandb run.

    Depending on the flags it logs a slice of `pl_module.one_batch`, a fixed
    set of validation images captioned with the model's predictions, and a
    confusion matrix built from `pl_module.val_preds_list` /
    `pl_module.val_labels_list`.

    Args:
        num_batches: maximum number of images sliced from a batch for logging.
        log_train_batch: log images from `pl_module.one_batch` at the end of
            every training epoch.
        log_preds: log captioned validation predictions at the end of every
            validation epoch.
        log_conf_mat: log a confusion matrix at the end of every epoch.
    """

    def __init__(
        self,
        num_batches: int = 16,
        log_train_batch: bool = False,
        log_preds: bool = False,
        log_conf_mat: bool = True,
    ):

        # class names for the confusion matrix
        self.class_names = list(conf_mat_idx2lbl.values())

        # number of images kept from a batch when logging
        self.num_bs = num_batches
        self.curr_epoch = 0

        self.log_train_batch = log_train_batch
        self.log_preds = log_preds
        self.log_conf_mat = log_conf_mat

        # validation samples are fetched lazily on the first validation epoch
        self.val_imgs, self.val_labels = None, None

    def on_train_start(self, trainer, pl_module, *args, **kwargs):
        """Register the model and loss with wandb (best effort)."""
        try:
            wandb.watch(models=pl_module.model, criterion=pl_module.loss_func)
        except Exception as err:
            # Watching the model is optional: keep training, but say why it was
            # skipped instead of swallowing the error silently.
            logging.getLogger(__name__).warning(
                "wandb.watch failed, continuing without it: %s", err
            )

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log up to `num_bs` images of `pl_module.one_batch`, if enabled."""
        if not self.log_train_batch or pl_module.one_batch is None:
            return

        train_ims = pl_module.one_batch[: self.num_bs].data.to("cpu")
        trainer.logger.experiment.log(
            {"train_batch": [wandb.Image(x) for x in train_ims]}, commit=False
        )

    def on_validation_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log validation images captioned with predicted and true labels."""
        if not self.log_preds:
            return

        if self.val_imgs is None and self.val_labels is None:
            # Cache the first validation batch so every epoch logs the same samples.
            imgs, labels = next(iter(pl_module.val_dataloader()))
            self.val_imgs = imgs[: self.num_bs].to(device=pl_module.device)
            self.val_labels = labels[: self.num_bs]

        logits = pl_module(self.val_imgs)
        preds = torch.argmax(logits, 1).data.cpu()

        ims = [
            wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
            for x, pred, y in zip(self.val_imgs, preds, self.val_labels)
        ]
        # wandb.log expects a dict of name -> value; the bare list was a bug.
        wandb.log({"predictions": ims}, commit=False)

    def on_epoch_start(self, trainer, pl_module, *args, **kwargs):
        """Reset the per-epoch prediction/label buffers on the module."""
        pl_module.val_labels_list = []
        pl_module.val_preds_list = []

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log a confusion matrix built from the module's per-epoch buffers."""
        if not self.log_conf_mat:
            return

        val_preds = torch.tensor(pl_module.val_preds_list).data.cpu().numpy()
        val_labels = torch.tensor(pl_module.val_labels_list).data.cpu().numpy()
        # Keyword arguments keep predictions/labels/class names bound to the
        # right parameters regardless of the positional order used by the
        # installed wandb version.
        conf_mat = wandb.plot.confusion_matrix(
            preds=val_preds, y_true=val_labels, class_names=self.class_names
        )
        wandb.log({"conf_mat": conf_mat}, commit=False)

In [None]:
# export
class DisableValidationBar(pl.callbacks.ProgressBar):
    "Lightning progress bar that keeps the sanity/train/test bars and silences the validation bar"

    def init_sanity_tqdm(self) -> tqdm:
        """Build the tqdm bar used for the validation sanity run."""
        return tqdm(
            desc="Validation sanity check",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            dynamic_ncols=True,
        )

    def init_train_tqdm(self) -> tqdm:
        """Build the tqdm bar used for training, starting at the current batch index."""
        return tqdm(
            desc="Training",
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            dynamic_ncols=True,
        )

    def init_validation_tqdm(self) -> tqdm:
        """Build the validation bar; it is always created disabled."""
        return tqdm(
            desc="Validating",
            position=(2 * self.process_position + 1),
            disable=True,
            dynamic_ncols=False,
        )

    def init_test_tqdm(self) -> tqdm:
        """Build the tqdm bar used for testing."""
        return tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            dynamic_ncols=True,
        )

In [None]:
# export
class PrintLogsCallback(pl.Callback):
    """Logs training and test metrics to the console via the "train" logger.

    Train/validation metrics are read from `trainer.callback_metrics` at the
    end of every epoch, test metrics at the end of the test epoch.
    """

    TrainResult = namedtuple("TrainOutput", ["loss", "acc", "val_loss", "val_acc"])
    TestResult = namedtuple("TestOutput", ["test_loss", "test_acc"])

    logger = logging.getLogger("train")

    @staticmethod
    def _scalar(metric, ndigits: int):
        """Convert a single-element tensor metric into a rounded Python number."""
        return round(metric.data.cpu().numpy().item(), ndigits)

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the epoch-level train/valid loss and accuracy.

        Accepts extra positional/keyword arguments like the other hooks in this
        file, so the callback keeps working if the hook is invoked with more
        arguments.

        Raises:
            KeyError: if one of the expected metric keys has not been logged.
        """
        metrics = trainer.callback_metrics
        keys = ("train/loss_epoch", "train/acc_epoch", "valid/loss", "valid/acc")
        trn_res = self.TrainResult(*(self._scalar(metrics[key], 3) for key in keys))

        curr_epoch = int(trainer.current_epoch)
        self.logger.info(f"EPOCH {curr_epoch}: {trn_res}")

    def on_test_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the test loss and accuracy rounded to two decimals.

        Raises:
            KeyError: if "test/loss" or "test/acc" has not been logged.
        """
        metrics = trainer.callback_metrics
        tst_res = self.TestResult(
            self._scalar(metrics["test/loss"], 2),
            self._scalar(metrics["test/acc"], 2),
        )
        self.logger.info(f"{tst_res}")

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from torch import nn
from src.networks import *

from hydra.experimental import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf

In [None]:
# Hydra overrides applied on top of the `effnet-base` config composed below.
overrides = [
    "image_dims=120",
    "datamodule.bs=5",
    "datamodule.num_workers=0",
    "general=default",
    "trainer=fast-dev-cpu",
    "mixmethod=mixup",
    "network=transferlearning",
    "augmentations=custom-augs",
    "mixmethod.steps=2",
    "network.activation=mish",
]

# Compose the config from the `../conf/` directory (path made relative to the
# current working directory); `cfg` is reused by the cells further down.
with initialize(config_path=os.path.relpath("../conf/")):
    cfg = compose(config_name="effnet-base", overrides=overrides)

In [None]:
# No augmentation pipelines are passed to the datamodule in this demo.
trn_augs = None
val_augs = None

# Instantiate the datamodule described by `cfg.datamodule`.
dm = instantiate(
    cfg.datamodule, train_augs=trn_augs, valid_augs=val_augs, default_config=cfg
)

# Look up the activation class named in the config and build an untrained
# (pretrained=False) encoder that uses it.
activation_func = activation_map[cfg.network.activation]
logging.info(f"Using {activation_func()} activation function.")
encoder = timm.create_model(cfg.encoder, pretrained=False, act_layer=activation_func)
# Wrap the encoder in the transfer-learning model, then in the LightningModule.
# NOTE: `model` is rebound on the second line below.
model = TransferLearningModel(encoder, cut=-2, c=5, act=activation_func(inplace=True))
model = LightningCassava(model=model, conf=cfg)

[35mroot[0m:[32mINFO[0m: Using Mish() activation function.[0m


In [None]:
# Small trainer to exercise DisableValidationBar and PrintLogsCallback:
# 3 epochs, one batch per train/val/test stage, no sanity-check steps and no
# weights summary.
trainer = pl.Trainer(
    callbacks=[DisableValidationBar(), PrintLogsCallback()],
    num_sanity_val_steps=0,
    max_epochs=3,
    limit_train_batches=1,
    limit_val_batches=1,
    limit_test_batches=1,
    weights_summary=None,
    progress_bar_refresh_rate=0,
)

GPU available: False, used: False
[35mlightning[0m:[32mINFO[0m: GPU available: False, used: False[0m
TPU available: False, using: 0 TPU cores
[35mlightning[0m:[32mINFO[0m: TPU available: False, using: 0 TPU cores[0m


In [None]:
# Run the short fit; PrintLogsCallback logs one "EPOCH n: ..." line per epoch.
trainer.fit(model, datamodule=dm)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

[35mtrain[0m:[32mINFO[0m: EPOCH 0: TrainOutput(loss=0.793, acc=0.0, val_loss=0.658, val_acc=0.0)[0m
[35mtrain[0m:[32mINFO[0m: EPOCH 1: TrainOutput(loss=0.84, acc=0.8, val_loss=0.663, val_acc=0.2)[0m
[35msrc.mixmethods[0m:[32mINFO[0m: Threshold steps reached, stopping Mixup(probability=1.0, alpha=0.4, iters=2)[0m
[35mtrain[0m:[32mINFO[0m: EPOCH 2: TrainOutput(loss=0.61, acc=0.2, val_loss=0.672, val_acc=0.0)[0m





1

In [None]:
# Run the test loop; PrintLogsCallback logs the rounded test loss/accuracy.
test_outputs = trainer.test(datamodule=dm, verbose=False)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

[35mtrain[0m:[32mINFO[0m: TestOutput(test_loss=0.67, test_acc=0.0)[0m





In [None]:
# export
class DisableProgressBar(pl.callbacks.ProgressBar):
    "Custom Progressbar callback for Lightning Training which disables all progress bars (sanity, train, validation and test)"

    def init_sanity_tqdm(self) -> tqdm:
        """Build the (always disabled) tqdm bar for the validation sanity run."""
        return tqdm(
            desc="Validation sanity check",
            position=(2 * self.process_position),
            disable=True,
            dynamic_ncols=True,
        )

    def init_train_tqdm(self) -> tqdm:
        """Build the (always disabled) tqdm bar for training."""
        return tqdm(
            desc="Training",
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=True,
            dynamic_ncols=True,
        )

    def init_validation_tqdm(self) -> tqdm:
        """Build the (always disabled) tqdm bar for validation."""
        return tqdm(
            desc="Validating",
            position=(2 * self.process_position + 1),
            disable=True,
            dynamic_ncols=False,
        )

    def init_test_tqdm(self) -> tqdm:
        """Build the (always disabled) tqdm bar for testing."""
        return tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=True,
            dynamic_ncols=True,
        )

In [None]:
# export
class ConsoleLogger(pl.Callback):
    """Fancy logger for console-logging.

    Writes a one-time run summary (model, dataset, parameters) when training
    starts, periodic step lines during training, a result line per epoch and
    the total train/test durations, all through the "train" logger.

    Args:
        print_every: a step line is logged whenever the per-epoch step counter
            is a multiple of this value.
    """

    trn_res = namedtuple("TrainOutput", ["loss", "acc", "val_loss", "val_acc"])
    tst_res = namedtuple("TestOutput", ["test_loss", "test_acc"])
    curr_step = 0
    has_init = False
    logger = logging.getLogger("train")

    def __init__(self, print_every: int = 50):
        self.print_every = print_every

    @staticmethod
    def _checkpoint_dir(trainer) -> str:
        """Relative checkpoint directory, or "N/A" when none is configured."""
        checkpoint_callback = getattr(trainer, "checkpoint_callback", None)
        dirpath = getattr(checkpoint_callback, "dirpath", None)
        # Without this guard a trainer with checkpointing disabled (or an unset
        # dirpath) would crash the logger with AttributeError/TypeError.
        return os.path.relpath(dirpath) if dirpath else "N/A"

    @staticmethod
    def _current_lrs(trainer) -> list:
        """Current learning rates, rounded to three significant digits."""
        schedulers = trainer.lr_schedulers
        if schedulers:
            scheduler = schedulers[0]["scheduler"]
            # `get_last_lr` is the documented way to read the learning rate
            # outside of `scheduler.step()`; fall back to `get_lr` for
            # scheduler objects that do not provide it.
            getter = getattr(scheduler, "get_last_lr", None) or scheduler.get_lr
            lrs = getter()
        else:
            # No scheduler configured: read the rates straight from the optimizers.
            lrs = [
                group["lr"]
                for optimizer in trainer.optimizers
                for group in optimizer.param_groups
            ]
        return [float("{0:.2e}".format(v)) for v in lrs]

    @staticmethod
    def _scalar(metric, ndigits: int):
        """Convert a single-element tensor metric into a rounded Python number."""
        return round(metric.data.cpu().numpy().item(), ndigits)

    def _log_run_summary(self, trainer, pl_module):
        """Log the model, dataset and training-parameter summary."""
        cfg = pl_module.hparams

        model_class = str(cfg.network.transfer_learning_model._target_).split(".")[-1]
        summary = ModelSummary(pl_module)
        param_count = get_human_readable_count(summary.param_nums[0])
        # log information about the model
        self.log_msg(f"[Model] model_class: {model_class}")
        self.log_msg(f"[Model] base_model: {cfg.encoder}")
        self.log_msg(f"[Model] total_parameters: {param_count}")
        self.log_line()

        path = os.path.relpath(cfg.datamodule.im_dir)
        oof_fold = str(cfg.datamodule.curr_fold)
        trn_batches = len(pl_module.train_dataloader())
        val_batches = len(pl_module.val_dataloader())
        tst_batches = len(pl_module.test_dataloader())
        # log information about the dataset
        self.log_msg(f"[Dataset] path: {path}")
        self.log_msg(f"[Dataset] validation_fold: {oof_fold}")
        self.log_msg(
            f"[Dataset] batches: {trn_batches} train + {val_batches} valid + {tst_batches} test"
        )
        self.log_line()

        lr_scheduler = str(cfg.scheduler.function._target_).split(".")[-1]
        optimizer = str(cfg.optimizer._target_).split(".")[-1]
        # log information about the training parameters
        self.log_msg(
            f"[Parameters] input_dimensions: {(cfg.image_dims, cfg.image_dims, 3)}"
        )
        self.log_msg(f"[Parameters] max_epochs: {trainer.max_epochs}")
        self.log_msg(f"[Parameters] mini_batch_size: {str(cfg.datamodule.bs)}")
        self.log_msg(
            f"[Parameters] accumulate_batches: {trainer.accumulate_grad_batches}"
        )
        self.log_msg(f"[Parameters] optimizer: {optimizer}")
        self.log_msg(f"[Parameters] learning_rates: {str(pl_module.lr_list)}")
        self.log_msg(
            f"[Parameters] weight_decay: {str(cfg.optimizer.weight_decay)}"
        )
        self.log_msg(f"[Parameters] scheduler: {lr_scheduler}")
        self.log_msg(f"[Parameters] gradient_clipping: {trainer.gradient_clip_val}")
        self.log_msg(f"[Parameters] loss_function: {pl_module.loss_func}")
        self.log_msg(f"[Parameters] mix_method: {pl_module.mix_fn}")
        self.log_line()

    def on_train_start(self, trainer, pl_module, *args, **kwargs):
        """Start the training timer and log the run header."""
        # start timer
        self.timer = datetime.datetime.now()
        self.current_iteration = 0

        # The summary is logged once; `on_test_start` resets the flag.
        if not self.has_init:
            self._log_run_summary(trainer, pl_module)
            self.has_init = True

        # start the training job
        self.log_msg(f"Start TRAIN / VALIDATION from epoch {trainer.current_epoch}")
        self.log_msg(f"Model training base path: {self._checkpoint_dir(trainer)}")
        self.log_msg(f"Device: {pl_module.device}")
        self.log_line()

    def on_train_epoch_start(self, *args, **kwargs):
        """Reset the per-epoch step counter, running sums and reference time."""
        self.curr_step = 0
        self.seen_batches = 0
        self._stp_loss = 0
        self._stp_acc = 0
        self.running_time = datetime.datetime.now()

    def on_train_batch_start(self, *args, **kwargs):
        """Remember when the current batch started."""
        self.start_time = datetime.datetime.now()

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        """Accumulate step metrics and log a step line every `print_every` steps."""
        msg = "eta: {} iteration: {} loss: {:.3f} accuracy: {:.3f} lrs: {}"

        _stp_metrics = trainer.callback_metrics

        self.seen_batches += 1

        self._stp_loss += _stp_metrics["train/loss_step"]
        self._stp_acc += _stp_metrics["train/acc_step"]

        if self.curr_step % self.print_every == 0:
            # average loss/accuracy over the batches seen so far in this epoch
            avg_loss = self._stp_loss / self.seen_batches
            avg_acc = self._stp_acc / self.seen_batches
            # NOTE(review): the value printed as "eta" is `now - running_time`,
            # where `running_time` is the epoch start shifted forward by every
            # completed batch duration (see below); it is not an estimate of
            # the remaining time. Arithmetic kept as-is; confirm the intent.
            batch_time = str(datetime.datetime.now() - self.running_time).split(".")[0]
            # the learning-rates of the model parameters
            lrs = self._current_lrs(trainer)

            self.log_msg(
                msg.format(batch_time, self.current_iteration, avg_loss, avg_acc, lrs)
            )

        # time taken for completion of one batch
        time_delta = datetime.datetime.now() - self.start_time
        self.running_time += time_delta

        # increment iterations
        self.curr_step += 1
        self.current_iteration += 1

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the epoch-level train/valid loss and accuracy.

        Raises:
            KeyError: if one of the expected metric keys has not been logged.
        """
        metrics = trainer.callback_metrics

        train_loss = metrics["train/loss_epoch"]
        train_acc = metrics["train/acc_epoch"]

        valid_loss = metrics["valid/loss"]
        valid_acc = metrics["valid/acc"]

        _res = self.trn_res(
            self._scalar(train_loss, 3),
            self._scalar(train_acc, 3),
            self._scalar(valid_loss, 3),
            self._scalar(valid_acc, 3),
        )

        curr_epoch = int(trainer.current_epoch)
        self.log_line()
        self.log_msg(f"EPOCH {curr_epoch}: {_res}")
        self.log_line()

    def on_train_end(self, *args, **kwargs):
        """Log the wall-clock time elapsed since `on_train_start`."""
        time_elapsed = datetime.datetime.now() - self.timer
        self.log_msg(f"Total training time: {str(time_elapsed).split('.')[0]}")
        self.log_line()

    def on_test_start(self, trainer, pl_module, *args, **kwargs):
        """Start the test timer and log the test header."""
        # allow the run summary to be logged again by a later fit
        self.has_init = False
        self.test_time_start = datetime.datetime.now()

        path = self._checkpoint_dir(trainer)
        self.log_msg("Start TEST")
        self.log_msg(f"Model testing base path: {path}")
        self.log_msg(f"Device: {pl_module.device}")

    def on_test_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the test loss and accuracy rounded to two decimals.

        Raises:
            KeyError: if "test/loss" or "test/acc" has not been logged.
        """
        metrics = trainer.callback_metrics
        test_loss = metrics["test/loss"]
        test_acc = metrics["test/acc"]

        loss = self._scalar(test_loss, 2)
        accuracy = self._scalar(test_acc, 2)
        self.log_msg(f"{self.tst_res(loss, accuracy)}")

    def on_test_end(self, *args, **kwargs):
        """Log the wall-clock time elapsed since `on_test_start`."""
        test_time = datetime.datetime.now() - self.test_time_start
        self.log_msg(f"Total testing time: {str(test_time).split('.')[0]}")

    def log_msg(self, msg: str):
        """Emit `msg` at INFO level on the class logger."""
        self.logger.info(msg)

    def log_line(self):
        """Emit a 94-character horizontal rule."""
        self.logger.info("-" * 94)

In [None]:
# NOTE: mutates the shared `cfg` composed in an earlier cell; re-running cells
# above this one afterwards will see image_dims=255.
cfg.image_dims = 255

# No augmentation pipelines are passed to the datamodule in this demo.
trn_augs = None
val_augs = None

# Rebuild the datamodule and model from the (modified) config.
dm = instantiate(
    cfg.datamodule, train_augs=trn_augs, valid_augs=val_augs, default_config=cfg
)

activation_func = activation_map[cfg.network.activation]
logging.info(f"Using {activation_func()} activation function.")
encoder = timm.create_model(cfg.encoder, pretrained=False, act_layer=activation_func)
model = TransferLearningModel(encoder, cut=-2, c=5, act=activation_func(inplace=True))
model = LightningCassava(model=model, conf=cfg)

# Prepare and set up the datamodule explicitly before building the trainer.
dm.prepare_data()
dm.setup()

# Trainer built from `cfg.trainer`, using DisableProgressBar and a
# ConsoleLogger that prints a step line every 5 steps; 2 epochs with
# 20 train / 6 val / 6 test batches.
trainer = instantiate(
    cfg.trainer,
    callbacks=[DisableProgressBar(), ConsoleLogger(print_every=5)],
    max_epochs=2,
    limit_train_batches=20,
    limit_val_batches=6,
    limit_test_batches=6,
    weights_summary=None,
)

[35mroot[0m:[32mINFO[0m: Using Mish() activation function.[0m
GPU available: False, used: False
[35mlightning[0m:[32mINFO[0m: GPU available: False, used: False[0m
TPU available: False, using: 0 TPU cores
[35mlightning[0m:[32mINFO[0m: TPU available: False, using: 0 TPU cores[0m


In [None]:
# Run the short fit; ConsoleLogger logs the run summary, step lines and epoch results.
trainer.fit(model, datamodule=dm)

[35mtrain[0m:[32mINFO[0m: [Model] model_class: TransferLearningModel[0m
[35mtrain[0m:[32mINFO[0m: [Model] base_model: tf_efficientnet_b0_ns[0m
[35mtrain[0m:[32mINFO[0m: [Model] total_parameters: 5.3 M[0m
[35mtrain[0m:[32mINFO[0m: ----------------------------------------------------------------------------------------------[0m
[35mtrain[0m:[32mINFO[0m: [Dataset] path: ../../Datasets/cassava/train_images[0m
[35mtrain[0m:[32mINFO[0m: [Dataset] validation_fold: 0[0m
[35mtrain[0m:[32mINFO[0m: [Dataset] batches: 3424 train + 856 valid + 856 test[0m
[35mtrain[0m:[32mINFO[0m: ----------------------------------------------------------------------------------------------[0m
[35mtrain[0m:[32mINFO[0m: [Parameters] input_dimensions: (255, 255, 3)[0m
[35mtrain[0m:[32mINFO[0m: [Parameters] max_epochs: 2[0m
[35mtrain[0m:[32mINFO[0m: [Parameters] mini_batch_size: 5[0m
[35mtrain[0m:[32mINFO[0m: [Parameters] accumulate_batches: 1[0m
[35mtrain[

1

In [None]:
# Run the test loop; the return value is discarded because ConsoleLogger logs the results.
_ = trainer.test(datamodule=dm, verbose=False)

[35mtrain[0m:[32mINFO[0m: Start TEST[0m
[35mtrain[0m:[32mINFO[0m: Model testing base path: lightning_logs/version_8/checkpoints[0m
[35mtrain[0m:[32mINFO[0m: Device: cpu[0m
[35mtrain[0m:[32mINFO[0m: TestOutput(test_loss=0.7, test_acc=0.17)[0m
[35mtrain[0m:[32mINFO[0m: Total testing time: 0:00:03[0m


In [None]:
# hide
from nbdev.export import *

# Export the `# export` cells of the project's notebooks to Python modules via nbdev.
notebook2script()

Converted 00_core.ipynb.
Converted 01_mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03_layers.ipynb.
Converted 03a_networks.ipynb.
Converted 04_optimizers_schedules.ipynb.
Converted 05_lightning.data.ipynb.
Converted 05a_lightning.core.ipynb.
Converted 05b_lightning.callbacks.ipynb.
Converted 06_fastai.core.ipynb.
Converted index.ipynb.
