In [None]:
# default_exp model

# Models

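> `Model`: a `LightningModule` that joins a `create_body` backbone to a `create_head` classifier and builds its Adam optimizer and one-cycle LR schedule through `create_opt`.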

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
%matplotlib inline

In [None]:
#export
import warnings
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.metrics import functional as FM



In [None]:
#export
from isic.dataset import SkinDataModule
from isic.layers import LabelSmoothingCrossEntropy
from isic.callback.hyperlogger import HyperparamsLogger
from isic.callback.logtable import LogTableMetricsCallback
from isic.callback.mixup import MixupDict
from isic.callback.cutmix import CutmixDict
from isic.callback.freeze import FreezeCallback, UnfreezeCallback
from isic.utils.core import reduce_loss, generate_val_steps
from isic.utils.model import apply_init, get_bias_batchnorm_params, apply_leaf, check_attrib_module, create_body, create_head, lr_find, freeze, unfreeze

In [None]:
#export
class Model(LightningModule):
    """Classifier built from a `create_body` backbone followed by a `create_head` head.

    Hyperparameters are stored through `save_hyperparameters()` and read back
    from `self.hparams`. The optimizer/scheduler pair is not defined statically:
    `create_opt` installs a `configure_optimizers` closure on the instance.

    Args:
        lr: default learning rate, used when `create_opt` receives no `lr`.
        wd: weight decay for parameter groups that are not flagged `skip_wd`.
        n_out: number of output classes; also passed to the accuracy metric.
        concat_pool: stored as a hyperparameter only; it is not forwarded to
            `create_head` here. NOTE(review): confirm whether it should be.
        arch: architecture name handed to `create_body`.
    """

    def __init__(self, lr=1e-2, wd=0., n_out=7, concat_pool=True, arch='resnet50'):
        super().__init__()
        self.save_hyperparameters()

        # Schedule length remembered between `create_opt` calls.
        self.epochs = None
        self.steps_epoch = None

        # `create_body` also returns the callable used to split the model into
        # parameter groups and the feature count the head must accept.
        body, self.split, num_ftrs = create_body(arch)
        head = create_head(num_ftrs, n_out)

        self.model = nn.Sequential(body, head)
        # Only the freshly created head is (re)initialised.
        apply_init(self.model[1])

        # Flag batchnorm parameters so they are not frozen.
        for p in get_bias_batchnorm_params(self.model, False):
            p.force_train = True
        # Flag biases and batchnorm parameters so they skip weight decay.
        for p in get_bias_batchnorm_params(self.model, True):
            p.skip_wd = True

        self.loss_func = LabelSmoothingCrossEntropy()

    def exclude_params_with_attrib(self, splits, skip_list=('skip_wd',)):
        """Split every parameter group into kept and excluded parameters.

        Parameters with `requires_grad == False` are dropped. A parameter is
        excluded when any attribute named in `skip_list` is truthy on it.

        Returns:
            A list of `2 * len(splits)` lists: first the kept parameters of
            each group (in order), then the excluded parameters of each group.
        """
        includes = []
        excludes = []
        for param_group in splits:
            ins, exs = [], []
            for param in param_group:
                if not param.requires_grad:
                    continue
                if any(getattr(param, attrib, False) for attrib in skip_list):
                    exs.append(param)
                else:
                    ins.append(param)
            includes.append(ins)
            excludes.append(exs)
        return includes + excludes

    def get_params(self, split_bn=True):
        """Return the parameter groups produced by `self.split`.

        With `split_bn=True` the groups are passed through
        `exclude_params_with_attrib`, which doubles the number of groups.
        """
        splits = self.split(self.model)
        if split_bn:
            return self.exclude_params_with_attrib(splits)
        return splits

    def forward(self, x):
        """Run `x` through body and head."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute loss and accuracy for a batch dict with 'img' and 'label' keys."""
        x, y = batch['img'], batch['label']
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)
        # Use the configured class count rather than a hard-coded 7.
        acc = FM.accuracy(y_hat, y, num_classes=self.hparams.n_out)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        result.log('train_acc', acc, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        """Same computation as `training_step`; the loss drives checkpointing."""
        x, y = batch['img'], batch['label']
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)
        acc = FM.accuracy(y_hat, y, num_classes=self.hparams.n_out)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss, prog_bar=True)
        result.log('val_acc', acc, prog_bar=True)
        return result

    def create_opt(self, lr=None, skip_bn_wd=True, epochs=None, steps_epoch=None):
        """Install a `configure_optimizers` that builds Adam + OneCycleLR.

        Args:
            lr: value forwarded to `generate_val_steps` to obtain one learning
                rate per group; defaults to `self.hparams.lr`.
            skip_bn_wd: when True, parameters flagged `skip_wd` are placed in
                separate groups that get `weight_decay=0`.
            epochs: number of epochs for the one-cycle schedule.
            steps_epoch: optimizer steps per epoch for the one-cycle schedule.
                `epochs` and `steps_epoch` are remembered on the instance, so
                they only have to be passed on the first call.

        Returns:
            The number of logical parameter groups (before the `skip_wd` split).

        Raises:
            ValueError: if `epochs` or `steps_epoch` was never provided.
        """
        if lr is None:
            lr = self.hparams.lr

        # Fall back to the values remembered from an earlier call; getattr
        # tolerates instances on which they were never set.
        if epochs is None:
            epochs = getattr(self, 'epochs', None)
        if steps_epoch is None:
            steps_epoch = getattr(self, 'steps_epoch', None)
        if epochs is None or steps_epoch is None:
            raise ValueError("You should pass epochs/steps_epoch at least one time in create_opt.")
        self.epochs, self.steps_epoch = epochs, steps_epoch

        param_groups = self.get_params(skip_bn_wd)

        n_groups = real_n_groups = len(param_groups)
        if skip_bn_wd:
            # Each logical group appears twice: once kept, once excluded from weight decay.
            n_groups //= 2

        def _inner():
            # Copy into a fresh list so the helper's return value is never mutated.
            lrs = list(generate_val_steps(lr, n_groups))
            if skip_bn_wd:
                # The skip_wd half reuses the learning rates of the kept half.
                lrs = lrs + lrs
            assert len(lrs) == real_n_groups, f"Trying to set {len(lrs)} values for LR but there are {real_n_groups} parameter groups."

            grps = [
                {
                    "params": pg,
                    "lr": group_lr,
                    # Groups past the first n_groups hold the skip_wd parameters.
                    "weight_decay": self.hparams.wd if i < n_groups else 0.,
                }
                for i, (pg, group_lr) in enumerate(zip(param_groups, lrs))
            ]

            opt = torch.optim.Adam(grps, lr=self.hparams.lr)
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                opt, max_lr=lrs, steps_per_epoch=steps_epoch, epochs=epochs
            )
            sched = {
                'scheduler': scheduler,  # The LR scheduler
                'interval': 'step',  # The unit of the scheduler's step size
                'frequency': 1,  # The frequency of the scheduler
                'reduce_on_plateau': False,  # For ReduceLROnPlateau scheduler
            }
            return [opt], [sched]

        # Shadow the class-level hook on this instance only.
        self.configure_optimizers = _inner
        return n_groups

In [None]:
# Template of the UserWarning to silence; every "{0}" slot becomes a regex
# wildcard below, so the filter matches whatever values the warning reports.
message_formater = (
    "You have set {0} number of classes if different from "
    "predicted {0} and target {0} number of classes"
)
warnings.filterwarnings(
    action="ignore",
    message=message_formater.format("(.*)"),
    category=UserWarning,
)

In [None]:
# Instantiate the project's data module and call its prepare_data() and
# setup('fit') hooks explicitly; `dm` is reused by every cell below.
dm = SkinDataModule()
dm.prepare_data()
dm.setup('fit')

In [None]:
# Notebook-level settings shared by the cells below.
EPOCHS = 10  # passed as max_epochs to both Trainers
STEPS_EPOCH = 1  # only referenced by the commented-out create_opt call in the next cell
lr = 1e-2  # learning rate handed to Model(...)

In [None]:
# Instantiate the model with the notebook-level learning rate and the resnet18 architecture.
model = Model(lr, arch='resnet18')

# Disabled freeze step, kept for reference:
# n_groups = model.create_opt(steps_epoch=STEPS_EPOCH, epochs=EPOCHS, lr=lr, skip_bn_wd=True)
# freeze(model, n_groups)

torch.Size([1, 512, 2, 2])


In [None]:
# Call the project's lr_find helper with lr_find=False and verbose=True. The recorded
# output below shows a fast_dev_run pass followed by a dump of the optimizer state.
# NOTE(review): the helper's exact behaviour lives in isic.utils.model — confirm there.
lr_find(model, dm,lr_find=False,verbose=True)

Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type                       | Params
---------------------------------------------------------
0 | model     | Sequential                 | 11 M  
1 | loss_func | LabelSmoothingCrossEntropy | 0     


override_called
[0.01, 0.01, 0.01, 0.01, 0.01, 0.01]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..



Adam (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8999999999999999, 0.999)
    eps: 1e-08
    initial_lr: 0.0004
    lr: 0.0052
    max_lr: 0.01
    max_momentum: 0.95
    min_lr: 4e-08
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8999999999999999, 0.999)
    eps: 1e-08
    initial_lr: 0.0004
    lr: 0.0052
    max_lr: 0.01
    max_momentum: 0.95
    min_lr: 4e-08
    weight_decay: 0.0

Parameter Group 2
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8999999999999999, 0.999)
    eps: 1e-08
    initial_lr: 0.0004
    lr: 0.0052
    max_lr: 0.01
    max_momentum: 0.95
    min_lr: 4e-08
    weight_decay: 0.0

Parameter Group 3
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8999999999999999, 0.999)
    eps: 1e-08
    initial_lr: 0.0004
    lr: 0.0052
    max_lr: 0.01
    max_momentum: 0.95
    min_lr: 4e-08
    weight_decay: 0.0

Parameter Group 4
    amsgrad: False
    base_momentum: 

In [None]:
# Trainer for the first run: limit_val_batches=0 requests no validation batches and
# limit_train_batches=0.01 restricts each epoch to 1% of the training batches.
trainer = pl.Trainer(max_epochs=EPOCHS, callbacks=[LogTableMetricsCallback(), HyperparamsLogger()], fast_dev_run=False, limit_val_batches=0, limit_train_batches=0.01)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [None]:
# First training run, using the Trainer created in the previous cell.
trainer.fit(model, dm)


  | Name      | Type                       | Params
---------------------------------------------------------
0 | model     | Sequential                 | 25 M  
1 | loss_func | LabelSmoothingCrossEntropy | 0     


override_called
[0.0001, 0.0001, 0.001]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..





1

In [None]:
# Call the project's unfreeze helper on the model.
# NOTE(review): 3 is presumably the number of parameter groups to unfreeze — confirm against isic.utils.model.
unfreeze(model, 3)

In [None]:
# Unfreeze training
# A fresh Trainer with fast_dev_run=True (a single-batch run, per the recorded output below).
trainer = pl.Trainer(max_epochs=EPOCHS, callbacks=[LogTableMetricsCallback(), HyperparamsLogger()], fast_dev_run=True, limit_val_batches=0, limit_train_batches=0.01)
# Re-create the optimizer setup with a slice of learning rates. No epochs/steps_epoch are
# passed, so this call depends on values remembered from an earlier create_opt call.
# NOTE(review): confirm such a call happens before this cell on a fresh kernel.
model.create_opt(slice(5e-7, 3e-4))

Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [None]:
# Second training run, using the fast_dev_run Trainer and the optimizer setup from the previous cell.
trainer.fit(model, dm)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 25 M  
1 | loss_func | CrossEntropyLoss | 0     


override_called
wtf


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

epoch,train_loss,train_acc,val_loss,val_acc
1,3.256,0.15625,4.279556,0.0


Saving latest checkpoint..





1

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir=lightning_logs/

Reusing TensorBoard on port 6006 (pid 4636), started 4 days, 18:33:26 ago. (Use '!kill 4636' to kill it.)

In [None]:
# Export the `#export` cells of this notebook to the library module.
# Import only the name that is used instead of star-importing nbdev.export.
from nbdev.export import notebook2script
notebook2script('model.ipynb')

Converted cb_mixup.ipynb.
