In [1]:
#default_exp lightning.callbacks

In [2]:
#export
import sys
import time
import datetime
import tqdm
import wandb
import logging

import torch
import pytorch_lightning as pl
from pytorch_lightning import _logger as log

from src.lightning.core import *
from src.core import *

In [3]:
#export
class WandbImageClassificationCallback(pl.Callback):
    """Custom callback that adds some extra functionality to the wandb logger.

    Depending on the flags it can log, once per epoch:
      * a slice of the training batch stored on `pl_module.one_batch`,
      * predictions of the module on a fixed slice of the first validation batch,
      * a confusion matrix built from `pl_module.val_preds_list` and
        `pl_module.val_labels_list`.

    Args:
        num_batches: number of samples sliced from a batch for image logging.
        log_train_batch: log training images at the end of every training epoch.
        log_preds: log validation images captioned with prediction and label.
        log_conf_mat: log a confusion matrix at the end of every epoch.
    """
    def __init__(self,
                 num_batches:int = 16, 
                 log_train_batch: bool = False,
                 log_preds: bool = False,
                 log_conf_mat: bool = True,):
        super().__init__()

        # class names for the confusion matrix
        self.class_names = list(conf_mat_idx2lbl.values())

        # number of samples taken from a batch when logging images
        self.num_bs = num_batches
        self.curr_epoch = 0

        self.log_train_batch = log_train_batch
        self.log_preds = log_preds
        self.log_conf_mat = log_conf_mat

        # cached on first use so the same validation samples are logged every epoch
        self.val_imgs, self.val_labels = None, None

    def on_train_start(self, trainer, pl_module, *args, **kwargs):
        "Ask wandb to watch the model; best effort, failures are only logged."
        try:
            # log model to the wandb experiment
            wandb.watch(models=pl_module.model, criterion=pl_module.loss_func)
        except Exception:
            # best effort (e.g. no active wandb run) -- never abort training for this
            log.info("Skipping wandb.watch --->")
            log.debug("wandb.watch failed", exc_info=True)

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        "Log the first `num_bs` samples of `pl_module.one_batch` as images."
        if not self.log_train_batch:
            return

        one_batch = getattr(pl_module, "one_batch", None)
        if one_batch is None:
            # the original message read a non-existent `self.config_defaults`
            # attribute and raised AttributeError instead of skipping
            log.info("Training batch samples not available . Skipping --->")
            return

        train_ims = one_batch[:self.num_bs].data.to('cpu')
        trainer.logger.experiment.log({"train_batch":[wandb.Image(x) for x in train_ims]}, commit=False)

    def on_validation_epoch_end(self, trainer, pl_module, *args, **kwargs):
        "Log predictions of the module on a cached slice of the first validation batch."
        if not self.log_preds:
            return

        if self.val_imgs is None and self.val_labels is None:
            imgs, labels = next(iter(pl_module.val_dataloader()))
            self.val_imgs = imgs[:self.num_bs].to(device=pl_module.device)
            self.val_labels = labels[:self.num_bs]

        # inference only: no autograd graph is needed for logging
        with torch.no_grad():
            logits = pl_module(self.val_imgs)
        preds = torch.argmax(logits, 1).detach().cpu()

        ims = [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") for x,pred,y in zip(self.val_imgs,preds,self.val_labels)]
        log_dict = {"predictions": ims}
        # wandb.log needs the dictionary, not the bare list of images
        wandb.log(log_dict,commit=False)

    def on_epoch_start(self, trainer, pl_module, *args, **kwargs):
        "Reset the per-epoch prediction/label buffers stored on the module."
        pl_module.val_labels_list = []
        pl_module.val_preds_list  = []

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        "Log a confusion matrix built from the buffers collected during the epoch."
        if not self.log_conf_mat:
            return

        preds_list  = getattr(pl_module, "val_preds_list", [])
        labels_list = getattr(pl_module, "val_labels_list", [])
        if len(preds_list) == 0 or len(labels_list) == 0:
            # nothing was collected this epoch, so there is nothing to plot
            return

        val_preds  = torch.tensor(preds_list).data.cpu().numpy()
        val_labels = torch.tensor(labels_list).data.cpu().numpy()
        # keyword arguments: the positional order of `wandb.plot.confusion_matrix`
        # differs between wandb releases (newer ones take `probs` first)
        conf_mat = wandb.plot.confusion_matrix(preds=val_preds, y_true=val_labels,
                                               class_names=self.class_names)
        wandb.log({'conf_mat': conf_mat},commit=False)

In [4]:
#export
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes formatted records through `tqdm.tqdm.write`.

    The constructor is inherited from `logging.Handler` and accepts the same
    optional `level` argument (default `logging.NOTSET`).
    """

    def emit(self, record):
        "Format `record` and write it with `tqdm.tqdm.write`."
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except Exception:
            # Only ordinary errors go to `handleError`; KeyboardInterrupt,
            # SystemExit and other BaseException subclasses propagate untouched.
            self.handleError(record)

In [5]:
# Shared format for the root logger and for the tqdm-aware handler below.
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
LOG_DATEFMT = "%m/%d/%Y %H:%M:%S"

logging.basicConfig(
        format=LOG_FORMAT,
        datefmt=LOG_DATEFMT,
        level=logging.INFO,
)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# Attach the tqdm handler only once, so re-running this cell does not stack handlers.
# Compared by class name because re-running the class cell creates a new class object.
if not any(type(h).__name__ == "TqdmLoggingHandler" for h in log.handlers):
    tqdm_handler = TqdmLoggingHandler()
    tqdm_handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=LOG_DATEFMT))
    log.addHandler(tqdm_handler)

# The tqdm handler already prints the formatted record; without this every
# message is printed a second time by the root handler from `basicConfig`.
log.propagate = False

In [6]:
#export
class LitProgressBar(pl.callbacks.ProgressBar):
    """Custom Progressbar callback for Lightning Training.

    Every bar is written to `sys.stdout`. The validation bar is always created
    disabled, so only the sanity-check, training and testing bars can be shown.
    `process_position`, `is_disabled` and `train_batch_idx` are read from `self`
    (presumably provided by the Lightning `ProgressBar` base class).
    """

    def init_sanity_tqdm(self) -> tqdm.tqdm:
        """ Override this to customize the tqdm bar for the validation sanity run. """
        return tqdm.tqdm(
            desc='Validation sanity check',
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )

    def init_train_tqdm(self) -> tqdm.tqdm:
        """ Override this to customize the tqdm bar for training. """
        return tqdm.tqdm(
            desc='Training',
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )

    def init_validation_tqdm(self) -> tqdm.tqdm:
        """ Override this to customize the tqdm bar for validation. """
        # created with disable=True: the validation bar is never displayed
        return tqdm.tqdm(
            desc='Validating',
            position=(2 * self.process_position + 1),
            disable=True,
            leave=False,
            dynamic_ncols=False,
            file=sys.stdout
        )

    def init_test_tqdm(self) -> tqdm.tqdm:
        """ Override this to customize the tqdm bar for testing. """
        return tqdm.tqdm(
            desc='Testing',
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout
        )

In [7]:
#export
class PrintLogsCallback(pl.Callback):
    """Logs Training logs to console after every epoch.

    Args:
        print_str: optional `str.format` template for the per-epoch line. It is
            formatted with six positional values: epoch, elapsed time, train loss,
            train accuracy, validation loss, validation accuracy. When `None`
            the built-in template is used.
    """
    def __init__(self, print_str: str = None):
        super().__init__()
        default_str = 'Epoch: [{}] eta: {} loss: {:.4f} acc: {:.4f} valid_loss: {:.4f} valid_acc: {:.4f}'
        # honour a caller-supplied template (it used to be silently ignored)
        self.print_str = default_str if print_str is None else print_str
        self.logger = logging.getLogger(__name__)
        # initialised here so `on_epoch_end` works even if no epoch-start hook fired
        self.eta_start = time.time()

    @staticmethod
    def _metric(metrics, key):
        "Return `metrics[key]`, or NaN when the metric has not been logged."
        value = metrics.get(key)
        return float('nan') if value is None else value

    def on_epoch_start(self, *args, **kwargs):
        "Remember when the epoch started."
        self.eta_start = time.time()

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        "Log epoch number, elapsed time and train/valid metrics for the epoch."
        metrics = trainer.callback_metrics
        train_loss = self._metric(metrics, 'train/loss')
        train_acc  = self._metric(metrics, 'train/acc')
        valid_loss = self._metric(metrics, 'valid/loss')
        valid_acc  = self._metric(metrics, 'valid/acc')

        end_time = time.time()
        self.eta_string = str(datetime.timedelta(seconds=int(end_time-self.eta_start)))
        self.curr_epoch = int(trainer.current_epoch)
        print_str = self.print_str.format(self.curr_epoch, self.eta_string, 
                                        train_loss, train_acc, 
                                        valid_loss, valid_acc)
        self.logger.info(print_str)

    def on_test_epoch_end(self, trainer, pl_module, *args, **kwargs):
        "Log a train/valid/test summary; metrics that were never logged show as nan."
        metrics = trainer.callback_metrics
        summaries = (
            ("Summary: [Train] loss: {:.4f} acc: {:.4f}", 'train/loss', 'train/acc'),
            ("Summary: [Valid] loss: {:.4f} acc: {:.4f}", 'valid/loss', 'valid/acc'),
            ("Summary: [Test]  loss: {:.4f} acc: {:.4f}", 'test/loss',  'test/acc'),
        )
        for fmt_str, loss_key, acc_key in summaries:
            self.logger.info(fmt_str.format(self._metric(metrics, loss_key),
                                            self._metric(metrics, acc_key)))

In [8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from torch import nn
from src.networks import *
from omegaconf import OmegaConf

In [9]:
# Augmentations applied to the training images.
train_augs = A.Compose([
    A.RandomResizedCrop(224, 224, p=1.0),
    A.RandomBrightness(limit=0.1),
    A.HueSaturationValue(20, 20, 20),
    A.HorizontalFlip(),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0)])

# Deterministic resize + normalisation for the validation images.
valid_augs = A.Compose([
    A.Resize(224, 224, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0)])

csv = "../../leaf-disease-classification-kaggle/data/stratified-data-5folds.csv"
ims = "../../Datasets/cassava/train_images/"
dm = CassavaLightningDataModule(csv, ims, curr_fold=0, train_augs=train_augs, valid_augs=valid_augs, bs=8, num_workers=0)


model_hparams = dict(
    mixmethod        = None,
    loss             = dict(_target_='src.losses.LabelSmoothingCrossEntropy', eps=0.1),
    learning_rate    = 1e-03,
    lr_mult          = 100,
    optimizer        = dict(_target_='src.opts.Ranger', weight_decay=1e-02, betas=[0.95, 0.999], eps=1e-05),
    scheduler        = dict(_target_='src.opts.FlatCos', num_epochs=10, pct_start=0.7),
    metric_to_track  = None,
    scheduler_interval= "epoch",
)

# Keep the OmegaConf config: the result of `OmegaConf.create` used to be discarded,
# so the model received the plain dict instead.
model_conf = OmegaConf.create(model_hparams)


# Build the network first, then wrap it in the Lightning module, under separate
# names so `model` is bound only once.
encoder    = timm.create_model('resnet18', pretrained=False)
classifier = TransferLearningModel(encoder, cut=-2, c=5, act=nn.ReLU(inplace=True))
model      = LightningCassava(model=classifier, conf=model_conf)

01/06/2021 09:07:21 - INFO - src.lightning.core - Loss Function : LabelSmoothingCrossEntropy()


In [10]:
# Tiny smoke-test run: one batch per stage, two epochs, no sanity validation.
trainer_kwargs = dict(
    callbacks=[LitProgressBar(), PrintLogsCallback()],
    num_sanity_val_steps=0,
    max_epochs=2,
    limit_train_batches=1,
    limit_val_batches=1,
    limit_test_batches=1,
    weights_summary=None,
)
trainer = pl.Trainer(**trainer_kwargs)

GPU available: False, used: False
01/06/2021 09:07:22 - INFO - lightning - GPU available: False, used: False
TPU available: False, using: 0 TPU cores
01/06/2021 09:07:22 - INFO - lightning - TPU available: False, using: 0 TPU cores


In [11]:
# Fit `model` using the dataloaders provided by the data module `dm`.
trainer.fit(model, datamodule=dm)

01/06/2021 09:07:23 - INFO - src.lightning.core - DATA: ../../Datasets/cassava/train_images/
01/06/2021 09:07:23 - INFO - src.lightning.core - FOLD: 0  BATCH_SIZE: 8
01/06/2021 09:07:23 - INFO - src.lightning.core - Optimizer: Ranger  LR's: (1e-05, 0.001)
01/06/2021 09:07:23 - INFO - src.lightning.core - LR Scheculer: FlatCos


Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers
Epoch: [0] eta: 0:00:02 loss: 1.7546 acc: 0.1250 valid_loss: 1.6144 valid_acc: 0.0000
Epoch 0: 100%|██████████| 2/2 [00:02<00:00,  1.36s/it, loss=1.755, v_num=2]

01/06/2021 09:07:26 - INFO - __main__ - Epoch: [0] eta: 0:00:02 loss: 1.7546 acc: 0.1250 valid_loss: 1.6144 valid_acc: 0.0000


Epoch: [1] eta: 0:00:02 loss: 1.6246 acc: 0.2500 valid_loss: 1.6092 valid_acc: 0.1250
Epoch 1: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it, loss=1.690, v_num=2]

01/06/2021 09:07:28 - INFO - __main__ - Epoch: [1] eta: 0:00:02 loss: 1.6246 acc: 0.2500 valid_loss: 1.6092 valid_acc: 0.1250


Epoch 1: 100%|██████████| 2/2 [00:02<00:00,  1.27s/it, loss=1.690, v_num=2]


1

In [12]:
# Run the test loop; the return value is bound to `_` so the cell shows no result repr.
_ = trainer.test(model, datamodule=dm, verbose=True)

Testing: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': tensor(0.1250),
 'test/loss': tensor(1.6092),
 'train/acc': tensor(0.2500),
 'train/acc_epoch': tensor(0.2500),
 'train/acc_step': tensor(0.2500),
 'train/loss': tensor(1.6246),
 'train/loss_epoch': tensor(1.6246),
 'train/loss_step': tensor(1.6246),
 'valid/acc': tensor(0.1250),
 'valid/loss': tensor(1.6092)}
--------------------------------------------------------------------------------
Summary: [Train] loss: 1.6246 acc: 0.2500             
Testing: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]

01/06/2021 09:07:37 - INFO - __main__ - Summary: [Train] loss: 1.6246 acc: 0.2500


Summary: [Valid] loss: 1.6092 acc: 0.1250             
Testing: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]

01/06/2021 09:07:37 - INFO - __main__ - Summary: [Valid] loss: 1.6092 acc: 0.1250


Summary: [Test]  loss: 1.6092 acc: 0.1250             
Testing: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]

01/06/2021 09:07:37 - INFO - __main__ - Summary: [Test]  loss: 1.6092 acc: 0.1250


Testing: 100%|██████████| 1/1 [00:00<00:00,  1.69it/s]


In [13]:
#hide
# Import only the function that is used instead of star-importing `nbdev.export`.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03_layers.ipynb.
Converted 03a_networks.ipynb.
Converted 04_optimizers_schedules.ipynb.
Converted 05_lightning.core.ipynb.
Converted 05a_lightning.callbacks.ipynb.
Converted 06_fastai.core.ipynb.
Converted index.ipynb.
