In [None]:
# default_exp utils

# Utilities

> The utility functions here can be used for training and evaluation of the model.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import logging
import numpy as np
import os
import time
import torch

# Module-level logger used by the training, validation and optimizer-setup helpers below.
log = logging.getLogger("Utilities for training classification models on the BreaKHis dataset.")

In [None]:
#export
def mixup_data(x, y, criterion, alpha=1.0):
    """Compute the mixup data for batch `x, y`.

    Each sample is blended with another sample of the same batch (chosen by a random
    permutation) using its own mixing coefficient drawn from Beta(`alpha`, `alpha`).
    With `alpha <= 0` the coefficient is the constant 1.0, i.e. the batch is left unmixed.

    Returns `(mixed_x, y_a, y_b, lam, mixup_criterion, mixup_acc)` where
    * `mixed_x` is the blended input batch,
    * `y_a`, `y_b` are the original and the permuted targets,
    * `lam` is the mixing coefficient: a float tensor with one entry per sample and
      singleton trailing axes matching `x` (the float 1.0 when `alpha <= 0`),
    * `mixup_criterion(pred)` is the lambda-weighted loss, averaged to a scalar,
    * `mixup_acc(pred)` is the lambda-weighted number of correct predictions in the
      batch, given class predictions `pred`.
    """
    batch_size = x.size()[0]
    if alpha > 0:
        lam_np = np.random.beta(alpha, alpha, batch_size)
        # Fold every coefficient to max(lam, 1 - lam) so each mixed sample stays
        # dominated by its un-permuted source.
        lam_np = np.maximum(lam_np, 1 - lam_np)
        # Follow the input's device rather than guessing from CUDA availability.
        lam_flat = torch.from_numpy(lam_np).float().to(x.device)
        # Singleton trailing axes let lam broadcast over every non-batch axis of x.
        lam = lam_flat.view(batch_size, *([1] * (x.dim() - 1)))
    else:
        lam_flat = lam = 1.
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    def mixup_criterion(pred):
        # Use the flat (batch,) coefficients: the broadcast-shaped `lam` would pair
        # every sample's coefficient with every other sample's loss.
        return (lam_flat * criterion(pred, y_a) + (1 - lam_flat) * criterion(pred, y_b)).mean()

    def mixup_acc(pred):
        # A prediction earns credit lam for matching y_a and (1 - lam) for matching y_b.
        hits_a = (pred == y_a).float()
        hits_b = (pred == y_b).float()
        return (lam_flat * hits_a + (1 - lam_flat) * hits_b).sum().item()

    return mixed_x, y_a, y_b, lam, mixup_criterion, mixup_acc

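This function is called by `train` on every batch. When `mixup` is false, `train` passes `alpha=0`, so the mixing coefficient is fixed at 1 and the batch is returned unmixed.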
* `x`, `y` should be `torch.Tensor`
* `criterion` should be a `torch` loss function, e.g. `nn.CrossEntropyLoss`
* `alpha` is a float defining the distribution for sampling the mixing value (see the Mixup paper for details)

In [None]:
#export
def train(
    model, epoch, dataloader, criterion, optimizer, scheduler=None, mixup=False, alpha=0.4,
    logging_frequency=50
):
    """ Trains `model` on data in `dataloader` with loss `criterion` and optimization scheme
        defined by `optimizer`, with optional learning schedule defined by `scheduler`. This
        function performs 1 epoch - passing in `epoch` is purely for logging clarity.
        Logs every `logging_frequency` iterations.

        Every batch goes through `mixup_data`; when `mixup` is False it is called with
        `alpha=0.0`. `scheduler.step()`, if a scheduler is given, is called once per batch.

        Returns `(final_loss, final_acc)`: the per-sample average of the batch losses and
        the summed `mixup_acc` counts divided by the number of samples seen. Raises
        `ZeroDivisionError` if `dataloader` yields no samples."""

    model.train()
    total, total_loss, total_correct = 0, 0., 0.

    for i, (x, y) in enumerate(dataloader):
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()
        mixed_x, y_a, y_b, lam, mixup_criterion, mixup_acc = mixup_data(
            x, y, criterion, alpha=alpha if mixup else 0.0
        )
        optimizer.zero_grad()
        output = model(mixed_x)
        prediction = torch.argmax(output, -1)
        loss = mixup_criterion(output)
        # Weight the batch loss by batch size so the epoch average is per-sample.
        total_loss += loss.item() * len(y)
        total_correct += mixup_acc(prediction)
        total += len(y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if i % logging_frequency == 0 and i > 0:
            # TODO: Add Tensorboard functionality here - mainly writer.add_scalar for
            # overall loss, accuracy (i.e. over all epochs).
            log.error(
                "[Epoch %d, Iteration %d / %d] Training Loss: %.5f, "
                "Training Accuracy: %.5f [Projected Accuracy: %.5f]"
                % (
                    epoch,
                    i,
                    len(dataloader),
                    total_loss / total,
                    # Correct predictions so far, as a fraction of the whole dataset.
                    total_correct / len(dataloader.dataset),
                    # Accuracy over the samples actually seen. Extrapolating with
                    # i / len(dataloader) ignored that i + 1 batches have been processed
                    # and over-estimated the projection.
                    total_correct / total
                )
            )
    final_loss, final_acc = total_loss / total, total_correct / total
    log.info(
        "Reporting %.5f training loss, %.5f training accuracy for epoch %d." %
        (final_loss, final_acc, epoch)
    )
    return final_loss, final_acc

This function performs 1 epoch of training.
* `model` should be a `torch.nn.Module`
* `epoch` should indicate the current epoch of training, and is only really necessary for logging purposes.
* `dataloader` should be a `torch.utils.data.DataLoader` wrapping a `BreaKHisDataset` object
* `criterion` should be a `torch` loss function
* `optimizer` should be a `torch.optim.Optimizer`, e.g. Adam
* `scheduler` is optional, but when included, should be a `torch.optim.lr_scheduler` scheduler that is stepped once per batch, e.g. `CyclicLR` or `OneCycleLR`
* `mixup` is a boolean indicating whether to use mixup augmentation for training (default is False)
* `alpha` is a float determining the distribution for sampling the mixing ratio
* `logging_frequency` determines the cycle of iterations before logging metrics

In [None]:
#export
def validate(model, epoch, dataloader, criterion, tta=False, tta_mixing=0.6, logging_frequency=50):
    """Validates `model` on data in `dataloader` for epoch `epoch` using objective `criterion`.

    The model is put in eval mode and run under `torch.no_grad()`. With `tta=True` every
    batch is expected to be 5-D, `(batch, n_views, channels, height, width)`: all views
    are scored, then the last view's output is weighted by `1 - tta_mixing` and the mean
    output of the remaining views by `tta_mixing`. Progress is logged every
    `logging_frequency` iterations; `epoch` is used only in log messages.

    Returns `(final_loss, final_acc)` averaged over the samples seen. `criterion` must
    return a scalar per batch. Raises `ZeroDivisionError` if `dataloader` yields no samples.
    """
    model.eval()
    total, total_loss, total_correct = 0, 0., 0.

    for i, (x, y) in enumerate(dataloader):
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()
        with torch.no_grad():
            if tta:
                bs, n_aug, c, h, w = x.size()
                # Fold the views into the batch axis for one forward pass, then unfold.
                output = model(x.view(-1, c, h, w)).view(bs, n_aug, -1)
                output = (
                    ((1 - tta_mixing) * output[:, -1, :]) + (tta_mixing * output[:, :-1, :].mean(1))
                )
            else:
                output = model(x)
            prediction = torch.argmax(output, -1)
            loss = criterion(output, y)
            # Weight the batch loss by batch size so the epoch average is per-sample.
            total_loss += loss.item() * len(y)
            total_correct += (prediction == y).sum().item()
            total += len(y)

        if i % logging_frequency == 0 and i > 0:
            # TODO: Add Tensorboard functionality here - mainly writer.add_scalar for
            # overall loss, accuracy (i.e. over all epochs).
            log.error(
                "[Epoch %d, Iteration %d / %d] Validation Loss: %.5f, "
                "Validation Accuracy: %.5f [Projected Accuracy: %.5f]"
                % (
                    epoch,
                    i,
                    len(dataloader),
                    total_loss / total,
                    # Correct predictions so far, as a fraction of the whole dataset.
                    total_correct / len(dataloader.dataset),
                    # Accuracy over the samples actually seen. Extrapolating with
                    # i / len(dataloader) ignored that i + 1 batches have been processed
                    # and could report values above 1.
                    total_correct / total
                )
            )
    final_loss, final_acc = total_loss / total, total_correct / total
    log.info(
        "Reporting %.5f validation loss, %.5f validation accuracy for epoch %d." %
        (final_loss, final_acc, epoch)
    )
    return final_loss, final_acc

This function performs 1 epoch of validation.
* `model` should be a `torch.nn.Module`
* `epoch` should indicate the current epoch of training, and is only really necessary for logging purposes.
* `dataloader` should be a `torch.utils.data.DataLoader` wrapping a `BreaKHisDataset` object
* `criterion` should be a `torch` loss function
* Unlike `train`, `validate` takes no `optimizer` or `scheduler`: the model is evaluated under `torch.no_grad()` and no parameters are updated
* `tta` is a boolean indicating whether to use test-time augmentation (default is False)
* `tta_mixing` determines how much of the test-time augmented data to use in determining the final output (default is 0.6)
* `logging_frequency` determines the cycle of iterations before logging metrics

Here are some toy examples using the functions defined above. For brevity, we use a small subset of the dataset.

In [None]:
from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
from torch import nn
from torchvision import transforms

def get_tta_transforms(resize_shape, normalize_transform, n=5):
    """Build a test-time-augmentation transform for a single image.

    The returned transform stacks `n` randomly augmented views of the image followed by
    one resized, un-augmented view, then applies `normalize_transform` to every view
    and stacks the results again.
    """
    target_size = (resize_shape, resize_shape)
    augment = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    def make_views(image):
        # Augmented views first, the un-augmented view last.
        views = [augment(image) for _ in range(n)]
        views.append(plain(image))
        return torch.stack(views)

    def normalize_views(views):
        return torch.stack([normalize_transform(view) for view in views])

    return transforms.Compose([
        transforms.Lambda(make_views),
        transforms.Lambda(normalize_views),
    ])

def get_transforms(resize_shape, tta=False, tta_n=5):
    """Return `(train_transforms, val_transforms)` for square inputs of side `resize_shape`.

    Training uses a random resized crop and a random horizontal flip. Validation uses a
    plain resize, or the multi-view transform from `get_tta_transforms` (with `tta_n`
    augmented views) when `tta` is true. Both pipelines end with the same normalization.
    """
    target_size = (resize_shape, resize_shape)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    if tta:
        val_transforms = get_tta_transforms(resize_shape, normalize, n=tta_n)
    else:
        val_transforms = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            normalize,
        ])

    return train_transforms, val_transforms
    
# 224x224 inputs; validation uses the multi-view test-time-augmentation transform.
train_transform, val_transform = get_transforms(224, tta=True)

In [None]:
# Root of the BreaKHis data. Set the BREAKHIS_DATA_DIR environment variable to point at
# a different copy instead of editing this cell; the original location is the default.
ds_mapping = initialize_datasets(
    os.environ.get('BREAKHIS_DATA_DIR', '/share/nikola/export/dt372/BreaKHis_v1/'),
    label='tumor_class', criterion=['tumor_type', 'magnification'],
    split_transforms={'train': train_transform, 'val': val_transform}
)

In [None]:
# Pull the two splits out of the mapping returned by `initialize_datasets`.
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']

In [None]:
# Only the training loader shuffles; the validation loader keeps the dataset order.
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)

In [None]:
# Build everything this smoke test needs right here so the cell also runs on a fresh
# kernel (Restart & Run All); the fine-tuning setup further down re-creates these names.
mixup = True
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = {
    # Mixup weights the per-sample losses itself, so the training loss must not be reduced.
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}

tr_loss, tr_acc = train(
    model, 0, tr_dl, criterion['train'], optimizer, mixup=mixup, alpha=0.4,
    logging_frequency=25
)
val_loss, val_acc = validate(
    model, 0, val_dl, criterion['val'], tta=True,
    logging_frequency=25
)

[Epoch 0, Iteration 25 / 198] Training Loss: 0.46333, Training Accuracy: 2.17543 [Projected Accuracy: 17.22939]
[Epoch 0, Iteration 50 / 198] Training Loss: 0.47674, Training Accuracy: 4.24937 [Projected Accuracy: 16.82749]
[Epoch 0, Iteration 75 / 198] Training Loss: 0.48411, Training Accuracy: 6.35830 [Projected Accuracy: 16.78590]
[Epoch 0, Iteration 100 / 198] Training Loss: 0.47731, Training Accuracy: 8.50142 [Projected Accuracy: 16.83282]
[Epoch 0, Iteration 125 / 198] Training Loss: 0.47459, Training Accuracy: 10.64962 [Projected Accuracy: 16.86900]
[Epoch 0, Iteration 150 / 198] Training Loss: 0.47411, Training Accuracy: 12.74208 [Projected Accuracy: 16.81955]
[Epoch 0, Iteration 175 / 198] Training Loss: 0.47475, Training Accuracy: 14.79908 [Projected Accuracy: 16.74410]
[Epoch 0, Iteration 25 / 50] Validation Loss: 0.42023, Validation Accuracy: 0.44005 [Validation Accuracy: 0.88010]


In [None]:
#export
def get_param_lr_maps(model, base_lr, finetune_body_factor):
    """ Output parameter LR mappings for setting up an optimizer for `model`.

    The head is `model.out_fc` and always gets `base_lr`. Every other parameter belongs
    to the body:
    * if `finetune_body_factor` is a single number, the whole body forms one group with
      LR `base_lr * finetune_body_factor`;
    * otherwise it must unpack into `(lower_bound_factor, upper_bound_factor)` and each
      body parameter gets its own group, with LRs geometrically spaced from
      `base_lr * lower_bound_factor` to `base_lr * upper_bound_factor` in the order
      given by `model.named_parameters()`.

    Returns a list of parameter-group dicts (`'params'`, `'lr'`), head group last.
    """
    # Keep only the tensors: optimizers expect parameters, not (name, parameter) pairs.
    body_parameters = [
        param for (name, param) in model.named_parameters() if name.split('.')[0] != 'out_fc'
    ]
    # Materialize the generator so the returned groups can be read more than once.
    head_parameters = list(model.out_fc.parameters())

    if isinstance(finetune_body_factor, (int, float)):
        log.error(
            "Setting up optimizer to fine-tune body with LR %.8f and head with LR %.5f" %
            (base_lr * finetune_body_factor, base_lr)
        )
        return [
            {'params': body_parameters, 'lr': base_lr * finetune_body_factor},
            {'params': head_parameters, 'lr': base_lr}
        ]

    lower_bound_factor, upper_bound_factor = finetune_body_factor
    log.error(
        "Setting up optimizer to fine-tune body with LR in range [%.8f, %.8f]"
        " and head with LR %.5f" %
        (base_lr * lower_bound_factor, base_lr * upper_bound_factor, base_lr)
    )
    lrs = np.geomspace(
        base_lr * lower_bound_factor, base_lr * upper_bound_factor,
        len(body_parameters)
    )
    # Plain Python floats keep numpy scalars out of the optimizer state.
    param_lr_maps = [
        {'params': param, 'lr': float(lr)} for (param, lr) in zip(body_parameters, lrs)
    ]
    param_lr_maps.append({'params': head_parameters, 'lr': base_lr})
    return param_lr_maps

This function is useful for setting up parameter to LR mappings for fine-tuning the model. Specifically:
* `model` should be a `torch.nn.Module`
* `base_lr` should be a float, defining the LR for the linear head
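* `finetune_body_factor` should be either a single float, in which case the whole body is trained with the learning rate `base_lr` * `finetune_body_factor`, or a list of two floats: a lower bound factor and an upper bound factor. In the second case the learning rates for the parameters of the body are log-spaced (geometrically spaced) between (`base_lr` * `lower_bound_factor`) and (`base_lr` * `upper_bound_factor`)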

In [None]:
#export
def setup_optimizer_and_scheduler(param_lr_maps, base_lr, epochs, steps_per_epoch):
    """Create a PyTorch AdamW optimizer and OneCycleLR scheduler with `param_lr_maps` parameter mapping,
       with base LR `base_lr`, for training for `epochs` epochs, with `steps_per_epoch` iterations
       per epoch.

       `base_lr` is the optimizer default for groups that do not set their own `'lr'`.
       Each group's LR is used as that group's peak LR in the one-cycle schedule.
       Returns `(optimizer, scheduler)`."""
    optimizer = torch.optim.AdamW(param_lr_maps, lr=base_lr)
    # A scalar max_lr would be applied to every parameter group and silently replace the
    # per-group learning rates, so pass one peak LR per group instead.
    max_lrs = [group['lr'] for group in optimizer.param_groups]
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lrs, epochs=epochs, steps_per_epoch=steps_per_epoch
    )
    return optimizer, scheduler

def checkpoint_state(
    model, epoch, optimizer, scheduler, train_loss, train_acc, val_loss, val_acc,
    model_dir='/share/nikola/export/dt372/breakhis_gradcam/models'
):
    """Checkpoint the state of the system, including `model` state, `optimizer` state, `scheduler`
       state, for `epoch`, saving the metrics as well.

       On the first call for a given `model`, a directory named after the current UTC
       time is created under `model_dir` and remembered as `model.save_dir`; later calls
       reuse it. The checkpoint is written to `<model.save_dir>/epoch_<epoch>.pth`.
       `scheduler` may be None, in which case None is stored for its state."""
    # makedirs tolerates missing parent directories and directories that already exist.
    os.makedirs(model_dir, exist_ok=True)
    if not hasattr(model, 'save_dir'):
        model.save_dir = os.path.join(
            model_dir, time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
        )
    # Also covers a `save_dir` that was set on the model but never created on disk.
    os.makedirs(model.save_dir, exist_ok=True)
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': None if scheduler is None else scheduler.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'epoch': epoch
        },
        os.path.join(model.save_dir, 'epoch_%d.pth' % epoch)
    )

In the example below, you can see how to set up the optimizer and scheduler to fine-tune using the one-cycle LR scheme. The linear head is fine-tuned with a learning rate of $10^{-3}$, and the body is fine-tuned with learning rates log-spaced between $10^{-8}$ and $10^{-5}$.

In [None]:
# Hyperparameters for the fine-tuning run.
mixup = True
num_epochs = 5
base_lr = 1e-3
finetune_body_factor = [1e-5, 1e-2]

# Pretrained ResNet-18 with a two-class head, moved to the GPU when one is available.
model = resnet18(pretrained=True, num_classes=2)
model = model.cuda() if torch.cuda.is_available() else model

# One scheduler step per training batch, hence steps_per_epoch = len(tr_dl).
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(
    param_lr_maps, base_lr, num_epochs, len(tr_dl)
)

# With mixup the training loss stays un-reduced; validation uses the default reduction.
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss(),
}

Setting up optimizer to fine-tune body with LR in range [0.00000001, 0.00001000] and head with LR 0.00100


A simple training loop would look like the following. Note that:
* The one-cycle LR scheduler is passed in, and the logic for updating that is handled in `train`
* Different criteria are used for training and validation. With mixup, each batch's loss is weighted by that batch's mixing factors, so the training criterion is created with `reduction='none'` and the reduction is applied inside `mixup_data`. Validation uses the standard mean reduction.
* Test-time augmentation is done in validation. Note that this will require having a special augmentation scheme, so validation transforms will need to be set appropriately. You can see above for an example of how to do that.
* The model state is checkpointed each epoch. After checkpointing the state of the model and system, the directory where the state was saved can be accessed by inspecting `model.save_dir`.

In [None]:
# One pass of training, validation and checkpointing per epoch. Epochs are reported
# 1-based (`epoch + 1`) in the logs and in the checkpoint file names.
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(
        model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
        mixup=mixup, alpha=0.4, logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, epoch + 1, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
    # Writes epoch_<n>.pth under model.save_dir (created on the first call).
    checkpoint_state(
        model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
    )

[Epoch 1, Iteration 25 / 198] Training Loss: 0.61779, Training Accuracy: 1.66688 [Projected Accuracy: 13.20167]
[Epoch 1, Iteration 50 / 198] Training Loss: 0.53964, Training Accuracy: 3.76124 [Projected Accuracy: 14.89452]
[Epoch 1, Iteration 75 / 198] Training Loss: 0.50664, Training Accuracy: 5.83344 [Projected Accuracy: 15.40028]
[Epoch 1, Iteration 100 / 198] Training Loss: 0.50104, Training Accuracy: 7.81111 [Projected Accuracy: 15.46601]
[Epoch 1, Iteration 125 / 198] Training Loss: 0.49492, Training Accuracy: 10.04433 [Projected Accuracy: 15.91022]
[Epoch 1, Iteration 150 / 198] Training Loss: 0.48815, Training Accuracy: 12.14424 [Projected Accuracy: 16.03039]
[Epoch 1, Iteration 175 / 198] Training Loss: 0.48081, Training Accuracy: 14.39566 [Projected Accuracy: 16.28766]
[Epoch 1, Iteration 25 / 50] Validation Loss: 0.50810, Validation Accuracy: 0.38920 [Validation Accuracy: 0.77841]
[Epoch 2, Iteration 25 / 198] Training Loss: 0.44745, Training Accuracy: 2.18651 [Projected Ac

In [None]:
# `checkpoint_state` stored the checkpoint directory on the model as `save_dir`.
os.listdir(model.save_dir)

['epoch_1.pth', 'epoch_2.pth', 'epoch_3.pth', 'epoch_4.pth', 'epoch_5.pth']

We can reuse the `validate` function on the training data loader, with test-time augmentation turned off, to get the standard training accuracy (not the mixup accuracy, which might not be as representative).

In [None]:
# Score the training loader with `validate`: no mixup and no test-time augmentation.
# `epoch` still holds the last value from the training loop above.
_, tr_acc_no_mixup = validate(model, epoch + 1, tr_dl, criterion['val'], tta=False, logging_frequency=25)
print("Training accuracy after %d epochs is %.5f" % (epoch + 1, tr_acc_no_mixup))

[Epoch 5, Iteration 25 / 198] Validation Loss: 0.16696, Validation Accuracy: 0.12666 [Validation Accuracy: 1.00317]
[Epoch 5, Iteration 50 / 198] Validation Loss: 0.17132, Validation Accuracy: 0.24731 [Validation Accuracy: 0.97934]
[Epoch 5, Iteration 75 / 198] Validation Loss: 0.17119, Validation Accuracy: 0.36859 [Validation Accuracy: 0.97307]
[Epoch 5, Iteration 100 / 198] Validation Loss: 0.16980, Validation Accuracy: 0.49003 [Validation Accuracy: 0.97025]
[Epoch 5, Iteration 125 / 198] Validation Loss: 0.16848, Validation Accuracy: 0.61257 [Validation Accuracy: 0.97031]
[Epoch 5, Iteration 150 / 198] Validation Loss: 0.16889, Validation Accuracy: 0.73385 [Validation Accuracy: 0.96868]
[Epoch 5, Iteration 175 / 198] Validation Loss: 0.16793, Validation Accuracy: 0.85529 [Validation Accuracy: 0.96770]


Training accuracy after 5 epochs is 0.95836
