In [1]:
#default_exp lightning.callbacks

In [2]:
#export
import sys
import time
import datetime
from tqdm.auto import tqdm
import wandb
import logging

import torch
import pytorch_lightning as pl
from pytorch_lightning import _logger as log

from src.lightning.core import *
from src.core import *

In [3]:
# set up python logging
LOG_FORMAT = '[%(asctime)s][%(levelname)s]: %(message)s'
LOG_DATEFMT = "%m/%d/%Y %H:%M:%S"
logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
)

In [4]:
#export
class WandbImageClassificationCallback(pl.Callback):
    """Custom callback that adds some extra functionalities to the wandb logger.

    Depending on the flags it can:
      * ``wandb.watch`` the wrapped model when training starts,
      * log images from ``pl_module.one_batch`` after each training epoch,
      * log predictions on a fixed slice of the first validation batch after
        each validation epoch,
      * log a wandb confusion matrix built from ``pl_module.val_preds_list`` /
        ``pl_module.val_labels_list`` at the end of every epoch.

    Args:
        num_batches: number of samples taken from a batch when logging images.
        log_train_batch: log training-batch images after each training epoch.
        log_preds: log validation predictions after each validation epoch.
        log_conf_mat: log a confusion matrix at the end of each epoch.
    """
    def __init__(self,
                 num_batches: int = 16,
                 log_train_batch: bool = False,
                 log_preds: bool = False,
                 log_conf_mat: bool = True,):
        super().__init__()

        # class names for the confusion matrix (`conf_mat_idx2lbl` comes from the project star-imports)
        self.class_names = list(conf_mat_idx2lbl.values())

        # number of samples to log from a batch
        self.num_bs = num_batches
        self.curr_epoch = 0

        self.log_train_batch = log_train_batch
        self.log_preds = log_preds
        self.log_conf_mat = log_conf_mat

        # fixed validation sample, cached lazily on the first validation epoch
        self.val_imgs, self.val_labels = None, None

    def on_train_start(self, trainer, pl_module, *args, **kwargs):
        """Ask wandb to watch the model and loss; skipped (best effort) if that fails."""
        try:
            wandb.watch(models=pl_module.model, criterion=pl_module.loss_func)
        except Exception as e:
            # watching is optional: e.g. no active wandb run or missing attributes
            log.info(f"Skipping wandb.watch ({e!r}) --->")

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log up to ``num_bs`` images from ``pl_module.one_batch``, if available."""
        if not self.log_train_batch:
            return

        one_batch = getattr(pl_module, 'one_batch', None)
        if one_batch is None:
            # the original message read `self.config_defaults`, which is never set
            log.info("Training batch samples not available. Skipping --->")
            return

        train_ims = one_batch[:self.num_bs].detach().cpu()
        trainer.logger.experiment.log({"train_batch": [wandb.Image(x) for x in train_ims]}, commit=False)

    def on_validation_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log predictions on a fixed slice of the first validation batch."""
        if not self.log_preds:
            return

        if self.val_imgs is None or self.val_labels is None:
            val_imgs, val_labels = next(iter(pl_module.val_dataloader()))
            self.val_imgs = val_imgs[:self.num_bs].to(device=pl_module.device)
            self.val_labels = val_labels[:self.num_bs]

        # inference only: no autograd graph is needed
        with torch.no_grad():
            logits = pl_module(self.val_imgs)
        preds = torch.argmax(logits, 1).cpu()

        # move images to the CPU before handing them to wandb
        ims = [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
               for x, pred, y in zip(self.val_imgs.cpu(), preds, self.val_labels)]
        # wandb.log expects a dict of name -> media; the original passed the bare list
        wandb.log({"predictions": ims}, commit=False)

    def on_epoch_start(self, trainer, pl_module, *args, **kwargs):
        """Reset the prediction/label accumulators stored on the LightningModule."""
        # NOTE(review): presumably the LightningModule appends to these during validation - confirm
        pl_module.val_labels_list = []
        pl_module.val_preds_list = []

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log a wandb confusion matrix from the accumulated predictions/labels."""
        if not self.log_conf_mat:
            return

        val_preds_list = getattr(pl_module, 'val_preds_list', None)
        val_labels_list = getattr(pl_module, 'val_labels_list', None)
        if val_preds_list is None or val_labels_list is None or len(val_preds_list) == 0:
            log.info("No validation predictions collected. Skipping confusion matrix --->")
            return

        val_preds = torch.tensor(val_preds_list).cpu().numpy()
        val_labels = torch.tensor(val_labels_list).cpu().numpy()
        # keyword arguments: the positional order of this API differs across wandb versions
        conf_mat = wandb.plot.confusion_matrix(preds=val_preds, y_true=val_labels,
                                               class_names=self.class_names)
        wandb.log({'conf_mat': conf_mat}, commit=False)

In [12]:
#export
class LitProgressBar(pl.callbacks.ProgressBar):
    """Lightning progress-bar callback with simplified tqdm bars.

    The sanity-check, training and testing bars follow ``self.is_disabled`` and
    resize with the terminal; the validation bar is always created disabled.
    """

    @staticmethod
    def _new_bar(desc: str, disable: bool, dynamic_ncols: bool) -> tqdm:
        """Build a plain tqdm bar with the given description and display options."""
        return tqdm(desc=desc, disable=disable, dynamic_ncols=dynamic_ncols)

    def init_sanity_tqdm(self) -> tqdm:
        """Bar shown during the validation sanity run."""
        return self._new_bar('Validation sanity check', self.is_disabled, True)

    def init_train_tqdm(self) -> tqdm:
        """Bar shown during training."""
        return self._new_bar('Training', self.is_disabled, True)

    def init_validation_tqdm(self) -> tqdm:
        """Bar for validation; created disabled so it never renders."""
        return self._new_bar('Validating', True, False)

    def init_test_tqdm(self) -> tqdm:
        """Bar shown during testing."""
        return self._new_bar('Testing', self.is_disabled, True)

In [13]:
#export
class PrintLogsCallback(pl.Callback):
    """Log training logs to the console after every epoch, and a train/test summary after testing.

    Args:
        print_str: format string for the per-epoch line. It is filled positionally with
            (elapsed time, train loss, train acc, valid loss, valid acc). When ``None``,
            ``DEFAULT_PRINT_STR`` is used.
    """
    DEFAULT_PRINT_STR = 'eta: {} loss: {:.4f} acc: {:.4f} valid_loss: {:.4f} valid_acc: {:.4f}'

    def __init__(self, print_str: str = None):
        super().__init__()
        # honour a user-supplied format string (previously it was silently ignored)
        self.print_str = print_str if print_str is not None else self.DEFAULT_PRINT_STR
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.eta_start = None
        self.curr_epoch = 0

    @staticmethod
    def _collect(metrics, keys):
        """Split ``keys`` into (values present in ``metrics``, names of missing keys)."""
        missing = [k for k in keys if k not in metrics]
        values = [metrics[k] for k in keys if k in metrics]
        return values, missing

    def on_epoch_start(self, *args, **kwargs):
        self.eta_start = time.time()

    def on_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the epoch number and a formatted line of train/valid metrics."""
        self.curr_epoch = int(trainer.current_epoch)
        keys = ('train/loss_epoch', 'train/acc_epoch', 'valid/loss', 'valid/acc')
        values, missing = self._collect(trainer.callback_metrics, keys)
        if missing:
            # e.g. an epoch on which validation did not run
            self.logger.warning(f"Epoch {self.curr_epoch}: metrics {missing} not available. Skipping --->")
            return

        # despite the "eta" name, this is the wall time elapsed since on_epoch_start
        elapsed = time.time() - self.eta_start if self.eta_start is not None else 0
        self.eta_string = str(datetime.timedelta(seconds=int(elapsed)))
        print_str = self.print_str.format(self.eta_string, *values)

        self.logger.info(f"Epoch {self.curr_epoch} ")
        self.logger.info(print_str)

    def on_test_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Log the last train metrics and the test metrics, skipping whichever are missing."""
        metrics = trainer.callback_metrics
        rows = (("[Train]", 'train/loss_epoch', 'train/acc_epoch'),
                ("[Test ]", 'test/loss', 'test/acc'))

        self.logger.info("Finished !")
        for tag, loss_key, acc_key in rows:
            values, missing = self._collect(metrics, (loss_key, acc_key))
            if missing:
                # e.g. trainer.test() called without a preceding fit
                self.logger.info(f"{tag} metrics {missing} not available")
                continue
            self.logger.info("{} loss: {:.4f} acc: {:.4f}".format(tag, *values))

In [14]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from torch import nn
from src.networks import *
from omegaconf import OmegaConf

In [15]:
# quick smoke-test setup: paths, augmentations, datamodule and model
SEED = 42
IMG_SIZE = 224
CSV_PATH = "../../leaf-disease-classification-kaggle/data/stratified-data-5folds.csv"
IMAGES_DIR = "../../Datasets/cassava/train_images/"

# seed python / numpy / torch so the random augmentations and the randomly
# initialised (pretrained=False) encoder are reproducible across runs
pl.seed_everything(SEED)

train_augs = A.Compose([
    A.RandomResizedCrop(IMG_SIZE, IMG_SIZE, p=1.0),
    # equivalent replacement for the deprecated `A.RandomBrightness(limit=0.1)`
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0),
    A.HueSaturationValue(20, 20, 20),
    A.HorizontalFlip(),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0)])

valid_augs = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0)])

dm = CassavaLightningDataModule(CSV_PATH, IMAGES_DIR, curr_fold=0, train_augs=train_augs,
                                valid_augs=valid_augs, bs=8, num_workers=0)


model_hparams = dict(
    mixmethod        = None,
    loss             = dict(_target_='src.losses.LabelSmoothingCrossEntropy', eps=0.1),
    learning_rate    = 1e-03,
    lr_mult          = 100,
    optimizer        = dict(_target_='torch.optim.Adam'),
    scheduler        = dict(function=dict(_target_='src.opts.FlatCos', num_epochs=10, pct_start=0.7),
                            metric_to_track=None, scheduler_interval='step'),
)

# only checks that the hparams form a valid OmegaConf config (raises early otherwise);
# the returned DictConfig is unused - LightningCassava receives the plain dict below
OmegaConf.create(model_hparams)

encoder  = timm.create_model('resnet18', pretrained=False)
backbone = TransferLearningModel(encoder, cut=-2, c=5, act=nn.ReLU(inplace=True))
model    = LightningCassava(model=backbone, conf=model_hparams)

In [16]:
# tiny smoke-test run: one batch per split, two epochs, no sanity check
callbacks = [LitProgressBar(), PrintLogsCallback()]
trainer = pl.Trainer(
    callbacks=callbacks,
    num_sanity_val_steps=0,
    max_epochs=2,
    limit_train_batches=1,
    limit_val_batches=1,
    limit_test_batches=1,
    weights_summary=None,
)

GPU available: False, used: False
[01/17/2021 13:04:54][INFO]: GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[01/17/2021 13:04:54][INFO]: TPU available: False, using: 0 TPU cores


In [17]:
# trailing `;` keeps fit's return value from being echoed as cell output
trainer.fit(model, datamodule=dm);

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

[01/17/2021 13:04:57][INFO]: Epoch 0 
[01/17/2021 13:04:57][INFO]: eta: 0:00:02 loss: 1.7270 acc: 0.1250 valid_loss: 1.6077 valid_acc: 0.3750
[01/17/2021 13:05:00][INFO]: Epoch 1 
[01/17/2021 13:05:00][INFO]: eta: 0:00:02 loss: 1.4613 acc: 0.1250 valid_loss: 1.5937 valid_acc: 0.6250





1

In [18]:
# descriptive name instead of `_`, which would shadow IPython's output-history variable
test_results = trainer.test(model, datamodule=dm, verbose=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

[01/17/2021 13:05:03][INFO]: Finished !
[01/17/2021 13:05:03][INFO]: [Train] loss: 1.4613 acc: 0.1250
[01/17/2021 13:05:03][INFO]: [Test ] loss: 1.5937 acc: 0.6250


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': tensor(0.6250),
 'test/loss': tensor(1.5937),
 'train/acc': tensor(0.1250),
 'train/acc_epoch': tensor(0.1250),
 'train/acc_step': tensor(0.1250),
 'train/loss': tensor(1.4613),
 'train/loss_epoch': tensor(1.4613),
 'train/loss_step': tensor(1.4613),
 'valid/acc': tensor(0.6250),
 'valid/loss': tensor(1.5937)}
--------------------------------------------------------------------------------



In [19]:
#hide
# export all `#export` cells of the project notebooks to the library modules
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_mixmethods.ipynb.
Converted 02_losses.ipynb.
Converted 03_layers.ipynb.
Converted 03a_networks.ipynb.
Converted 04_optimizers_schedules.ipynb.
Converted 05_lightning.core.ipynb.
Converted 05a_lightning.callbacks.ipynb.
Converted 06_fastai.core.ipynb.
Converted index.ipynb.
