In [None]:
# default_exp model

# Models

> LightningModule image classifiers (`BaselineModel`, `Model`) and the `fit_one_cycle` training helper.

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
%matplotlib inline

In [None]:
#export
import warnings
import re
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.metrics import functional as FM

In [None]:
#export
from isic.dataset import SkinDataModule, from_label_idx_to_key
from isic.layers import LabelSmoothingCrossEntropy
from isic.callback.hyperlogger import HyperparamsLogger
from isic.callback.logtable import LogTableMetricsCallback
from isic.callback.mixup import MixupDict
from isic.callback.cutmix import CutmixDict
from isic.callback.freeze import FreezeCallback, UnfreezeCallback
from isic.utils.core import reduce_loss, generate_val_steps
from isic.utils.model import apply_init, get_bias_batchnorm_params, apply_leaf, check_attrib_module, create_body, create_head, lr_find, freeze, unfreeze, log_metrics_per_key

In [None]:
#export
class BaselineModel(LightningModule):
    """Torchvision backbone with a replaced ``fc`` layer, trained end-to-end with Adam.

    Args:
        arch: name of a ``torchvision.models`` constructor; the resulting model
            must expose an ``fc`` attribute with ``in_features`` (e.g. ResNets).
        lr: learning rate used by Adam in ``configure_optimizers``.
        loss_func: callable ``(logits, targets) -> loss``; defaults to
            ``F.cross_entropy``.
        n_out: number of output classes, used for the new ``fc`` layer and for
            the metrics (default 7, the previously hard-coded value).
    """
    def __init__(self, arch='resnet50', lr=1e-2, loss_func=None, n_out=7):
        super().__init__()
        self.save_hyperparameters()
        self.model = getattr(models, arch)(pretrained=True)
        # Swap the ImageNet classifier for a fresh n_out-way linear layer.
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, n_out)
        self.loss_func = loss_func if loss_func is not None else F.cross_entropy

    def forward(self, x):
        """Return the logits of the wrapped torchvision model."""
        return self.model(x)

    def shared_step(self, batch, batch_id):
        """Run the model on ``batch['img']`` and score it against ``batch['label']``.

        Returns:
            ``(loss, (y_hat, y))`` so callers can reuse logits and targets.
        """
        x, y = batch['img'], batch['label']
        y_hat = self(x)
        return self.loss_func(y_hat, y), (y_hat, y)

    def training_step(self, batch, batch_idx):
        """Minimise the loss and log it as ``train_loss``."""
        loss, _ = self.shared_step(batch, batch_idx)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        """Keep targets, logits and loss on the result for ``validation_epoch_end``."""
        loss, (y_hat, y) = self.shared_step(batch, batch_idx)
        result = pl.EvalResult()
        result.y = y
        result.y_hat = y_hat
        result.loss = loss
        return result

    def calc_and_log_metrics(self, y_hat, y):
        """Compute and log classification metrics from logits ``y_hat`` and targets ``y``.

        Aggregate accuracy / precision / recall / F1 are logged to the progress
        bar; per-class precision and recall are logged through
        ``log_metrics_per_key``.  NOTE: the keys are ``val_*`` even when this is
        called from ``test_epoch_end``.

        Returns:
            A new ``pl.EvalResult`` carrying the logged metrics.
        """
        n_out = self.hparams.n_out
        acc = FM.accuracy(y_hat, y, num_classes=n_out)
        precision, recall = FM.precision_recall(y_hat, y, num_classes=n_out)
        f1 = FM.f1_score(y_hat, y, num_classes=n_out)
        # reduction='none' keeps one value per class instead of the mean.
        prec_arr, recall_arr = FM.precision_recall(y_hat, y, num_classes=n_out, reduction='none')

        result = pl.EvalResult()
        result.log('val_acc', acc, prog_bar=True)
        result.log('val_precision', precision, prog_bar=True)
        result.log('val_recall', recall, prog_bar=True)
        result.log('F1', f1, prog_bar=True)
        metrics = {
            "precision": prec_arr,
            "recall": recall_arr,
        }
        log_metrics_per_key(result, metrics)
        return result

    def validation_epoch_end(self, out):
        """Log epoch-level metrics plus the mean validation loss.

        ``out`` is presumably the EvalResult gathered by PL 0.9 across all
        validation steps (``loss``, ``y``, ``y_hat`` collected) — confirm.
        """
        avg_val_loss = out.loss.mean()

        result = self.calc_and_log_metrics(out.y_hat, out.y)
        result.log('val_loss', avg_val_loss, prog_bar=True)

        return result

    def test_step(self, batch, batch_idx):
        """Keep targets and logits on the result for ``test_epoch_end``."""
        _, (y_hat, y) = self.shared_step(batch, batch_idx)
        result = pl.EvalResult()
        result.y = y
        result.y_hat = y_hat
        return result

    def test_epoch_end(self, out):
        """Log metrics and save logits/labels to ``preds.pt``/``labels.pt`` in the CWD."""
        result = self.calc_and_log_metrics(out.y_hat, out.y)
        torch.save(out.y_hat.cpu(), 'preds.pt')
        torch.save(out.y.cpu(), 'labels.pt')

        return result

    def configure_optimizers(self):
        """Single Adam optimizer over all parameters at ``hparams.lr``."""
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return opt

In [None]:
#export
class Model(LightningModule):
    """Classifier built as ``nn.Sequential(body, head)`` with per-split parameter groups.

    ``create_body(arch)`` supplies the backbone, a ``split`` function that
    partitions the model's parameters into groups, and the backbone feature
    width. These groups drive per-group learning rates (``create_opt``) and
    ``freeze``/``unfreeze``.

    Args:
        lr: default learning rate for ``create_opt``.
        wd: default weight decay for ``create_opt``.
        n_out: number of output classes, used for the head and for the metrics.
        concat_pool: only stored in ``hparams``. NOTE(review): it is not
            forwarded to ``create_head`` — confirm whether it should be.
        arch: backbone name understood by ``create_body``.
        loss_func: callable ``(logits, targets) -> loss``; defaults to
            ``F.cross_entropy``.
        verbose: print parameter-group and learning-rate diagnostics.
    """
    def __init__(self, lr=1e-2, wd=0., n_out=7, concat_pool=True, arch='resnet50', loss_func=None, verbose=True):
        super().__init__()
        self.save_hyperparameters()
        # Backbone, its parameter-group splitter, and its output feature width.
        body, self.split, num_ftrs = create_body(arch)

        head = create_head(num_ftrs, n_out)

        self.model = nn.Sequential(body, head)
        # Only the new head is initialised here; the body is left as create_body returned it.
        apply_init(self.model[1])

        # Setup so that batchnorm will not be frozen (flag presumably read by `freeze`).
        for p in get_bias_batchnorm_params(self.model, False):
            p.force_train = True
        # Setup so that biases and batchnorm will skip weight decay (see create_opt).
        for p in get_bias_batchnorm_params(self.model, True):
            p.skip_wd = True

        # Install a default Adam `configure_optimizers`, then freeze by parameter group.
        n_groups = self.create_opt(torch.optim.Adam, None)
        freeze(self, n_groups)

        self.loss_func = loss_func if loss_func is not None else F.cross_entropy

    def forward(self, x):
        """Return the logits of ``body`` followed by ``head``."""
        return self.model(x)

    def exclude_params_with_attrib(self, splits, skip_list=('skip_wd',)):
        """Partition each parameter group by whether a parameter has a skip flag.

        Parameters with ``requires_grad == False`` are dropped entirely.

        Args:
            splits: iterable of parameter groups (each an iterable of parameters).
            skip_list: attribute names; a parameter with any of them set to a
                truthy value goes to the "excluded" list of its group.

        Returns:
            ``includes + excludes``: ``2 * len(splits)`` lists. The first half
            holds the kept parameters of each group, the second half the
            excluded ones, both in group order.
        """
        includes, excludes = [], []
        for param_group in splits:
            ins, exs = [], []
            for param in param_group:
                if not param.requires_grad:
                    continue
                if any(getattr(param, attrib, False) for attrib in skip_list):
                    exs.append(param)
                else:
                    ins.append(param)
            includes.append(ins)
            excludes.append(exs)

        if self.hparams.verbose:
            print('Total splits = ', len(excludes))
            for i, exs in enumerate(excludes, start=1):
                print(f'Split {i}: {len(exs)} layers are excluded.')

        return includes + excludes

    def get_params(self, split_bn=True):
        """Return the model's parameter groups.

        With ``split_bn`` the groups are doubled via ``exclude_params_with_attrib``
        (weight-decay params first, ``skip_wd`` params second, frozen params
        dropped); otherwise ``self.split(self.model)`` is returned as is.
        """
        if split_bn:
            splits = self.split(self.model)
            return self.exclude_params_with_attrib(splits)
        return self.split(self.model)

    def shared_step(self, batch, batch_id):
        """Run the model on ``batch['img']`` and score it against ``batch['label']``.

        Returns:
            ``(loss, (y_hat, y))`` so callers can reuse logits and targets.
        """
        x, y = batch['img'], batch['label']
        y_hat = self(x)
        return self.loss_func(y_hat, y), (y_hat, y)

    def training_step(self, batch, batch_idx):
        """Minimise the loss and log it as ``train_loss``."""
        loss, _ = self.shared_step(batch, batch_idx)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        """Keep targets, logits and loss on the result for ``validation_epoch_end``."""
        loss, (y_hat, y) = self.shared_step(batch, batch_idx)
        result = pl.EvalResult()
        result.y = y
        result.y_hat = y_hat
        result.loss = loss
        return result

    def calc_and_log_metrics(self, y_hat, y):
        """Compute and log classification metrics from logits ``y_hat`` and targets ``y``.

        Uses ``hparams.n_out`` as the class count, so the metrics stay correct
        for heads with other than 7 outputs. NOTE: the keys are ``val_*`` even
        when this is called from ``test_epoch_end``.

        Returns:
            A new ``pl.EvalResult`` carrying the logged metrics.
        """
        n_out = self.hparams.n_out
        acc = FM.accuracy(y_hat, y, num_classes=n_out)
        precision, recall = FM.precision_recall(y_hat, y, num_classes=n_out)
        f1 = FM.f1_score(y_hat, y, num_classes=n_out)
        # reduction='none' keeps one value per class instead of the mean.
        prec_arr, recall_arr = FM.precision_recall(y_hat, y, num_classes=n_out, reduction='none')

        result = pl.EvalResult()
        result.log('val_acc', acc, prog_bar=True)
        result.log('val_precision', precision, prog_bar=True)
        result.log('val_recall', recall, prog_bar=True)
        result.log('F1', f1, prog_bar=True)
        metrics = {
            "precision": prec_arr,
            "recall": recall_arr,
        }
        log_metrics_per_key(result, metrics)
        return result

    def validation_epoch_end(self, out):
        """Log epoch-level metrics plus the mean validation loss.

        ``out`` is presumably the EvalResult gathered by PL 0.9 across all
        validation steps (``loss``, ``y``, ``y_hat`` collected) — confirm.
        """
        avg_val_loss = out.loss.mean()

        result = self.calc_and_log_metrics(out.y_hat, out.y)
        result.log('val_loss', avg_val_loss, prog_bar=True)

        return result

    def test_step(self, batch, batch_idx):
        """Keep targets and logits on the result for ``test_epoch_end``."""
        _, (y_hat, y) = self.shared_step(batch, batch_idx)
        result = pl.EvalResult()
        result.y = y
        result.y_hat = y_hat
        return result

    def test_epoch_end(self, out):
        """Log metrics and save logits/labels to ``preds.pt``/``labels.pt`` in the CWD."""
        result = self.calc_and_log_metrics(out.y_hat, out.y)
        torch.save(out.y_hat.cpu(), 'preds.pt')
        torch.save(out.y.cpu(), 'labels.pt')

        return result

    def create_opt(self, opt_func, sched_func, lr=None, wd=None, skip_bn_wd=True):
        """Install a ``configure_optimizers`` closure on this instance.

        Parameter groups are collected now, when ``create_opt`` is called, not
        when the optimizer is built. Call it again after (un)freezing.

        Args:
            opt_func: optimizer factory called as ``opt_func(groups, lr=...)``.
            sched_func: ``None``, or a factory called as ``sched_func(opt, max_lr=lrs)``
                whose scheduler is stepped every training step.
            lr: learning-rate spec passed to ``generate_val_steps``; ``None`` uses ``hparams.lr``.
            wd: weight decay for non-``skip_wd`` params; ``None`` uses ``hparams.wd``.
            skip_bn_wd: put ``skip_wd`` params (biases/batchnorm) in their own
                groups with zero weight decay.

        Returns:
            Number of logical parameter groups (splits), not counting the
            duplicated zero-weight-decay groups.
        """
        if lr is None:
            lr = self.hparams.lr
        if wd is None:
            wd = self.hparams.wd

        param_groups = self.get_params(skip_bn_wd)
        n_groups = real_n_groups = len(param_groups)
        if skip_bn_wd:
            # get_params returned includes + excludes, i.e. every split twice.
            n_groups //= 2

        def _inner():
            if self.hparams.verbose:
                print('Overriding_configure_optimizer...')

            lrs = generate_val_steps(lr, n_groups)
            if skip_bn_wd:
                # Zero-wd halves reuse the same per-split LRs. list() makes this
                # concatenation (not element-wise addition) for any sequence type.
                lrs = list(lrs) * 2
            assert len(lrs) == real_n_groups, f"Trying to set {len(lrs)} values for LR but there are {real_n_groups} parameter groups."

            grps = [
                {
                    "params": pg,
                    "lr": pg_lr,
                    # Only the first n_groups (the "include" half) receive weight decay.
                    "weight_decay": wd if i < n_groups else 0.,
                }
                for i, (pg, pg_lr) in enumerate(zip(param_groups, lrs))
            ]

            if self.hparams.verbose:
                print('LRs for each layer:', lrs)

            # Every group sets its own "lr", so this default is only a fallback;
            # with a scheduler, the per-group LRs are then driven via max_lr.
            opt = opt_func(grps, lr=self.hparams.lr if isinstance(lr, slice) else lr)
            if sched_func is None:
                return [opt]
            scheduler = sched_func(opt, max_lr=lrs)
            sched = {
                'scheduler': scheduler, # The LR scheduler
                'interval': 'step', # The unit of the scheduler's step size
                'frequency': 1, # The frequency of the scheduler
                'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler
            }
            return [opt], [sched]

        self.configure_optimizers = _inner
        return n_groups

In [None]:
#export
def fit_one_cycle(epochs, model, datamodule, opt='Adam', max_lr=None, pct_start=0.25, 
                  div_factor=25., final_div_factor=1e5, wd=None, 
                  skip_bn_wd=True, max_momentum=0.95, base_momentum=0.85, **kwargs):
    """Train ``model`` with a ``OneCycleLR`` schedule and return the trainer.

    Args:
        epochs: used both as ``pl.Trainer(max_epochs=...)`` and as the
            schedule length.
        model: exposes ``create_opt(opt_func, sched_func, lr=, wd=, skip_bn_wd=)``
            (e.g. ``Model``).
        datamodule: its ``train_dataloader()`` length sets ``steps_per_epoch``.
        opt: name of a ``torch.optim`` optimizer, or an optimizer factory.
        max_lr: forwarded to ``create_opt`` as ``lr`` (``None`` lets
            ``create_opt`` choose its default).
        pct_start, div_factor, final_div_factor, max_momentum, base_momentum:
            forwarded to ``torch.optim.lr_scheduler.OneCycleLR``.
        wd: forwarded to ``create_opt`` (``None`` lets it choose its default).
        skip_bn_wd: forwarded to ``create_opt`` (for ``Model``: zero weight
            decay on bias/batchnorm parameters).
        **kwargs: forwarded to ``pl.Trainer``.

    Returns:
        The fitted ``pl.Trainer``.

    Raises:
        ValueError: if ``opt`` is a string that is not an attribute of ``torch.optim``.
    """
    if isinstance(opt, str):
        opt_func = getattr(torch.optim, opt, None)
        if opt_func is None:
            raise ValueError(f"Invalid optimizer {opt!r}, please pass correct name string as in torch.optim.")
    else:
        opt_func = opt
    # NOTE: steps_per_epoch is the full loader length. Trainer options such as
    # limit_train_batches run fewer steps, so the one-cycle schedule will not
    # reach its end within `epochs`.
    steps_epoch = len(datamodule.train_dataloader())
    sched = partial(torch.optim.lr_scheduler.OneCycleLR, epochs=epochs, steps_per_epoch=steps_epoch,
                    pct_start=pct_start, div_factor=div_factor, final_div_factor=final_div_factor,
                    base_momentum=base_momentum, max_momentum=max_momentum)
    model.create_opt(opt_func, sched, lr=max_lr, wd=wd, skip_bn_wd=skip_bn_wd)
    trainer = pl.Trainer(max_epochs=epochs, **kwargs)
    trainer.fit(model, datamodule)
    return trainer

In [None]:
message_formater = "You have set {0} number of classes if different from predicted {0} and target {0} number of classes"
warnings.filterwarnings("ignore", message_formater.format("(.*)"), category=UserWarning)

In [None]:
# Build the ISIC skin data module, run its data preparation, then set up the 'fit' stage.
dm = SkinDataModule()
dm.prepare_data()
dm.setup('fit')

In [None]:
# Epochs for the frozen stage (F_EPOCHS) and the unfrozen stage (U_EPOCHS).
F_EPOCHS, U_EPOCHS = 1, 1
LR = 0.01  # max_lr handed to fit_one_cycle

In [None]:
# Build a resnet18-based Model; with the default verbose=True it prints its split summary.
model = Model(lr=LR, arch='resnet18')

Total splits =  3
Split 1: 20 layers are excluded.
Split 2: 20 layers are excluded.
Split 3: 4 layers are excluded.


In [None]:
# Inspect per-module parameter attributes — presumably the force_train / skip_wd
# flags set in Model.__init__ (helper from isic.utils.model; verify its output).
check_attrib_module(model)

In [None]:
# Learning-rate range test; the keyword flags are presumably forwarded to pl.Trainer.
lr_find(
    model,
    dm,
    fast_dev_run=True,
    verbose=True,
)

In [None]:
# Frozen stage. The limit_* Trainer flags (forwarded via fit_one_cycle's **kwargs)
# skip validation and use 1% of the training batches: a smoke run, not real training.
cbs = [LogTableMetricsCallback(), HyperparamsLogger()]
trainer = fit_one_cycle(
    F_EPOCHS,
    model,
    dm,
    max_lr=LR,
    callbacks=cbs,
    fast_dev_run=False,
    limit_val_batches=0,
    limit_train_batches=0.01,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [None]:
# NOTE(review): 3 presumably is the number of parameter groups (the model printed
# "Total splits = 3" above); consider using the value returned by create_opt
# instead of a literal so other architectures work.
unfreeze(model, 3)

In [None]:
# Unfreeze training. fit_one_cycle requires epochs, model and datamodule; it
# calls model.create_opt again, so the optimizer's parameter groups are rebuilt
# after unfreeze().
trainer = fit_one_cycle(U_EPOCHS, model, dm, max_lr=LR, callbacks=cbs, fast_dev_run=False, limit_val_batches=0, limit_train_batches=0.01)

Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [None]:
# NOTE(review): fit_one_cycle already calls trainer.fit(model, datamodule), so this
# trains again with the same trainer. The saved output below ("override_called",
# "wtf") comes from an older code version. Confirm whether this cell is still needed.
trainer.fit(model, dm)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 25 M  
1 | loss_func | CrossEntropyLoss | 0     


override_called
wtf


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

epoch,train_loss,train_acc,val_loss,val_acc
1,3.256,0.15625,4.279556,0.0


Saving latest checkpoint..





1

In [None]:
%load_ext tensorboard

In [None]:
# Show TensorBoard inline for the runs logged under lightning_logs/.
%tensorboard --logdir=lightning_logs/

Reusing TensorBoard on port 6006 (pid 4636), started 4 days, 18:33:26 ago. (Use '!kill 4636' to kill it.)

In [None]:
# Export the `#export` cells of this notebook into the module named by `# default_exp` (here `model`).
from nbdev.export import notebook2script
notebook2script('model.ipynb')

Converted model.ipynb.
