# Callbacks

Special callbacks for 3D data or training.

In [None]:
# hide
import sys
# presumably makes the local faimed3d package (one directory up) importable when running this notebook
sys.path.append("..")

In [None]:
# export
# default_exp callback

from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *
import torch.nn.functional as F

In [None]:
# export
from faimed3d.basics import *
from faimed3d.preprocess import *
from faimed3d.augment import *
from faimed3d.data import *

## Channel manipulation

To process multiple 3D volumes together, the volumes can be stacked along the color (channel) dimension. This needs to be implemented as a callback that runs before the batch is presented to the model. 

In [None]:
# export
class StackVolumes(Callback):
    """
    Concatenates all tensors of the x batch along the channel dimension (dim 1).
    Useful for multi-sequence medical data, where each sequence arrives as its own 3D volume.

    Example:
        Three input tensors of size (10, 1, 5, 25, 25) are merged into a single tensor of
        size (10, 3, 5, 25, 25).
    """

    def before_batch(self):
        inputs = self.learn.xb
        stacked = torch.cat(inputs, dim=1)
        # the learner expects xb to be a tuple, so wrap the single stacked tensor
        self.learn.xb = (stacked,)


## Callbacks for volume manipulation

In medical imaging, small regions of the image are often decisive for the diagnosis. This means that, given only a smaller subregion of the image, the model could still correctly detect the pathology. Splitting the volumes into subvolumes can therefore be used to augment the data. 

In [None]:
# export
class SplitVolumes(Callback):
    """
        Separates a 3D tensor into smaller equal-sized sub-volumes.

         o---o---o       o---o---o
         | A | A |       | B | B |        o---o  o---o  o---o  o---o  o---o  o---o  o---o  o---o
         o---o---o   +   o---o---o  ==>   | A | +| A | +| A | +| A | +| B | +| B | +| B | +| B |
         | A | A |       | B | B |        o---o  o---o  o---o  o---o  o---o  o---o  o---o  o---o
         o---o---o       o---o---o

        The subvolumes are stacked along the batch dimension, grouped per original volume.
        Spatial dims that are not divisible by the number of splits are first resized
        (trilinear interpolation) down to the next smaller multiple.

        Args:
            n_subvol = number of subvolumes per volume; should be a perfect cube (8, 27, 64, ...)
                       if `split_along_depth`, otherwise a perfect square (4, 9, 16, ...)
            split_along_depth = whether volumes should also be split along the D dimension for a [B, C, D, H, W] tensor
    """
    run_after = StackVolumes
    def __init__(self, n_subvol = 2**3, split_along_depth = True):
        store_attr()

    def before_batch(self):
        xb = self.learn.xb
        if len(xb) > 1: raise ValueError('Got multiple items in x batch. You need to concatenate the batch first.')
        # split x and y the same way so targets stay aligned with inputs
        self.learn.xb = self.split_volume(xb)
        self.learn.yb = self.split_volume(self.learn.yb)

    def after_pred(self):
        # reassemble split inputs, predictions and targets into whole volumes
        self.learn.xb = self.patch_volume(self.learn.xb)
        self.learn.pred = detuplify(self.patch_volume(self.learn.pred))
        self.learn.yb = self.patch_volume(self.learn.yb)

    def _splits_per_dim(self):
        "Number of splits along each split dimension (cube root or square root of `n_subvol`)."
        k = 3 if self.split_along_depth else 2
        root = self.n_subvol**(1/k)
        n = round(root)
        # floating point roots can undershoot (64**(1/3) == 3.9999999999999996), so prefer the
        # exact integer root when one exists; otherwise truncate as before
        return n if n**k == self.n_subvol else int(root)

    def split_volume(self, xb:(Tensor, TensorDicom3D, TensorMask3D)):
        "splits a large tensor into multiple smaller tensors, returned as a 1-tuple"
        xb = detuplify(xb) # xb is always a tuple
        self.n = n = self._splits_per_dim()
        nd = n if self.split_along_depth else 1  # number of splits along D (1 = D is kept whole)

        # shrink each split dim to a multiple of its number of splits
        d, h, w = xb.shape[-3:]
        shape = [d - d % nd, h - h % n, w - w % n]
        if list(xb.shape[-3:]) != shape:
            # NOTE(review): y is resized with trilinear interpolation too, which blends label
            # values of integer masks -- confirm this is intended
            xb = F.interpolate(xb, size = shape, mode = "trilinear", align_corners=True)

        # (B, C, D, H, W) -> (B, C, nd, D/nd, n, H/n, n, W/n)
        b, c = xb.shape[:2]
        d, h, w = shape
        xb = xb.reshape(b, c, nd, d // nd, n, h // n, n, w // n)

        # move the split indices behind the batch dim and fold them into it:
        # -> (B*nd*n*n, C, D/nd, H/n, W/n), ordered per volume, then depth, height, width split
        return (xb.permute(1, 3, 5, 7, 0, 2, 4, 6).flatten(-4).permute(4, 0, 1, 2, 3), )

    def patch_volume(self, p:(Tensor, TensorDicom3D, TensorMask3D)):
        "patches a prior split volume back together, returned as a 1-tuple"
        p = detuplify(p)
        n = self._splits_per_dim()
        nd = n if self.split_along_depth else 1

        c, d, h, w = p.shape[1:]
        b = p.shape[0] // (nd * n * n)
        # (B*nd*n*n, C, d, h, w) -> (B, nd, n, n, C, d, h, w), inverting the flatten in `split_volume`
        p = p.reshape(b, nd, n, n, c, d, h, w)
        # pair every split index with its spatial dim -> (B, C, nd, d, n, h, n, w), then merge the pairs
        return (p.permute(0, 4, 1, 5, 2, 6, 3, 7).reshape(b, c, nd * d, n * h, n * w), )

Subsampling the subvolumes allows for more variability in the images and also enables training with an effective batch size < 1, i.e. on fewer subvolumes than make up one full volume. 

In [None]:
# export
class SubsampleShuffle(SplitVolumes):
    """
        After splitting the volume into multiple subvolumes, draws a random subset of the subvolumes for training.
        This allows training on an effective batch size < 1 (fewer subvolumes than one full volume).
        Outside of training all subvolumes are kept and patched back together after the prediction.

        o---o---o        o---o---o
        | A | A |        | B | B |        o---o  o---o  o---o  o---o  o---o  o---o
        o---o---o    +   o---o---o  ==>   | B | +| A | +| A | +| A | +| B | +| A |
        | A | A |        | B | B |        o---o  o---o  o---o  o---o  o---o  o---o
        o---o---o        o---o---o

        Args:
            p: fraction of subvolumes to train on; at least one subvolume is always drawn
            n_subvol: number of subvolumes to create per volume
            split_along_depth: whether the depth dimension should be split as well
    """
    run_after = [StackVolumes]

    def __init__(self, p = 0.5, n_subvol=2**3, split_along_depth = True):
        store_attr()

    def before_batch(self):
        # split x and y exactly like `SplitVolumes`
        super().before_batch()
        if not self.training: return

        xb = detuplify(self.learn.xb)
        yb = detuplify(self.learn.yb)
        n_total = xb.size(0)
        # int() truncates, so a small `p` or few subvolumes could give an empty batch; keep at least one
        n_draw = max(1, int(n_total*self.p))
        draw = random.sample(range(n_total), n_draw)
        # use the same indices for x and y so targets stay aligned with inputs
        self.learn.xb = (xb[draw], )
        self.learn.yb = (yb[draw], )

    def after_pred(self):
        # during training the subsampled subvolumes cannot be reassembled into full volumes,
        # so predictions and targets stay split; otherwise patch back like `SplitVolumes`
        if not self.training:
            super().after_pred()

If a small finding is predominantly located in one image region, e.g. the upper left, the model might wrongly learn the location as an important feature of the finding. Mixing subvolumes between images might help. 

In [None]:
# export
class MixSubvol(SplitVolumes):
    """
        After splitting the volumes of a batch into subvolumes, shuffles the subvolumes across the whole
        batch and patches them back together into full-sized volumes. x and y are shuffled identically.
        Only applied during training.

        o---o---o        o---o---o        o---o---o        o---o---o
        | A | A |        | B | B |        | B | B |        | A | B |
        o---o---o    +   o---o---o  ==>   o---o---o    +   o---o---o
        | A | A |        | B | B |        | A | A |        | B | A |
        o---o---o        o---o---o        o---o---o        o---o---o


        Args:
            p: probability that the callback will be applied to a training batch
            n_subvol: number of subvolumes to create per volume
            split_along_depth: whether the depth dimension should be split as well

    """
    run_after = [StackVolumes]

    def __init__(self, p = 0.25, n_subvol=2**3, split_along_depth = True):
        store_attr()

    def before_batch(self):
        # guard clause; `random.random()` is only drawn for training batches
        if not self.training or random.random() >= self.p: return
        if len(self.learn.xb) > 1:
            raise ValueError('Got multiple items in x batch. You need to concatenate the batch first.')
        x_parts = detuplify(self.split_volume(self.learn.xb))
        y_parts = detuplify(self.split_volume(self.learn.yb))
        n_parts = x_parts.size(0)
        # one random permutation shared by x and y, so inputs and targets stay aligned
        order = tuple(random.sample(range(n_parts), n_parts))
        self.learn.xb = self.patch_volume((x_parts[order, :], ))
        self.learn.yb = self.patch_volume((y_parts[order, :], ))

    def after_pred(self):
        # overrides `SplitVolumes.after_pred`: the batch was already patched back in
        # `before_batch` (or never split), so there is nothing to undo here
        pass

Implementation for MixUp on 3D data

## Tracker callbacks

In [None]:
# export
class ReloadBestFit(TrackerCallback):
    "A `TrackerCallback` that reloads the previous best model if no improvement happened for `patience` epochs"
    def __init__(self, fname,  monitor='valid_loss', comp=None, min_delta=0., patience=1):
        # fname: model name passed to `Learner.load`. This callback never saves that file itself, so it
        # presumably has to be written elsewhere (e.g. by `SaveModelCallback` with the same fname) -- confirm
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience = patience
        self.fname = fname

    def before_fit(self):
        "Reset the counter of epochs without improvement at the start of every fit."
        self.wait = 0
        super().before_fit()

    def after_epoch(self):
        "Compare the monitored value to its best score and reload the best model after `patience` epochs without improvement."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            print(f'No improvement since epoch {self.epoch-self.wait}: reloading previous best model.')
            # `Learner.load` restores the weights into the existing learner, so `self.learn` need not be rebound
            self.learn.load(self.fname)
            self.wait = 0

In [None]:
# hide
from nbdev.export import *
# convert the project's notebooks into Python modules (produces the "Converted ..." output below)
notebook2script()

Converted 01_basics.ipynb.
Converted 02_preprocessing.ipynb.
Converted 03_transforms.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_layers.ipynb.
Converted 06_learner.ipynb.
Converted 06a_models.alexnet.ipynb.
Converted 06b_models.resnet.ipynb.
Converted 06d_models.unet.ipynb.
Converted 06f_models.losses.ipynb.
Converted 07_callback.ipynb.
Converted index.ipynb.
