# Callbacks


The classes and functions in this notebook are highly specific and probably not useful for most tasks.


In [1]:
# export
# default_exp callback

from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *

import torch.nn.functional as F

In [2]:
# export

from faimed3d.basics import *
from faimed3d.augment import *
from faimed3d.data import *
from faimed3d.models import *

## Basic training for testing

In [3]:
# relative paths; ../data/train would work too, but with this longer path the notebook can also be tested from within faimed3d
train = Path('../../dl-prostate-mapping/data/train')
valid = Path('../../dl-prostate-mapping/data/valid')

# sort for a reproducible order: rglob's order depends on the file system, and the T2/ADC/DWI lists
# below are combined row-wise into one DataFrame, so they must list the cases in the same order
files = sorted(train.rglob('DICOM')) + sorted(valid.rglob('DICOM'))

# collect the T2, ADC and DWI volumes (a path is added to every list whose sequence name it contains)
t2_files, adc_files, dwi_files = [], [], []
for f in files:
    path = str(f)
    if 'T2' in path: t2_files.append(f/'cropped_volume.nii.gz')
    if 'ADC' in path: adc_files.append(f/'cropped_volume.nii.gz')
    if 'DWI' in path: dwi_files.append(f/'cropped_volume.nii.gz')

assert len(t2_files) == len(adc_files) == len(dwi_files), \
    f'Unequal number of sequences: T2={len(t2_files)}, ADC={len(adc_files)}, DWI={len(dwi_files)}'

t2_segmentation = [(p.parent.parent/'Annotation/cropped_mask_adapted.nii.gz') for p in t2_files]
adc_segmentation = [(p.parent.parent/'Annotation/cropped_mask_adapted.nii.gz') for p in adc_files]

# three image inputs (T2, ADC, DWI) and two masks (T2 and ADC segmentation)
mris = DataBlock(
    blocks = (ImageBlock3D(cls=TensorDicom3D),
              ImageBlock3D(cls=TensorDicom3D),
              ImageBlock3D(cls=TensorDicom3D),
              MaskBlock3D(codes = ['void', "peripheral", 'transitional', 'cancer']),
              MaskBlock3D(codes = ['void', "peripheral", 'transitional', 'cancer']),
             ),
    get_x = [lambda x: x[0], lambda x: x[1], lambda x: x[2]],
    get_y = [lambda x: x[3], lambda x: x[4]],
    item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), resize_to = (16, 80, 80)),
    batch_tfms = [
        *aug_transforms_3d(p_all=0.15, do_rotate_by = False),
        PseudoColor],
    splitter = ColSplitter(),
    n_inp = 3)

d = pd.DataFrame({'T2' : t2_files,
                  'ADC': adc_files,
                  'DWI': dwi_files,
                  'mask_t2' : t2_segmentation,
                  'mask_adc': adc_segmentation,
                  'is_valid': [1 if 'valid' in str(o) else 0 for o in t2_files]})

# set the validation batch size explicitly: it otherwise defaults to 64 and causes CUDA out of memory errors
dls = mris.dataloaders(d,
                       bs = 10,
                       val_bs = 20,
                       num_workers = 0,
                      )

In [4]:
# export
class SplitVolumes(Callback):
    """
        Separates a 3D tensor into smaller equal sized subvolumes.

        `before_batch` cuts every item of a [B, C, D, H, W] batch (`xb` and `yb`) into subvolumes which
        are stacked along the batch dimension; `after_pred` patches `xb`, `pred` and `yb` back together.
        If a split dimension is not divisible by its number of splits, the volume is first resized
        (trilinear interpolation) to the next smaller divisible size.

        Args:
            n_subvol = number of subvolumes per item; must be a perfect cube if `split_along_depth`,
                       otherwise a perfect square
            split_along_depth = whether volumes should also be split along the D dimension for a [B, C, D, H, W] tensor

        Raises:
            ValueError if `xb` holds more than one input, or if `n_subvol` is not a perfect cube/square
    """
    def __init__(self, n_subvol = 2**3, split_along_depth = True):
        store_attr()

    def before_batch(self):
        xb = self.learn.xb
        if len(xb) > 1: raise ValueError('Got multiple items in x batch. You need to concatenate the batch first.')
        self.learn.xb = self.split_volume(xb)
        self.learn.yb = self.split_volume(self.learn.yb)

    def after_pred(self):
        self.learn.xb = self.patch_volume(self.learn.xb)
        self.learn.pred = detuplify(self.patch_volume(self.learn.pred))
        self.learn.yb = self.patch_volume(self.learn.yb)

    def _splits_per_dim(self):
        "Number of splits along each split dimension; raises if `n_subvol` is not a perfect cube/square."
        n_dims = 3 if self.split_along_depth else 2
        # round, don't truncate: floating point roots are inexact, e.g. 64**(1/3) == 3.9999999999999996
        n = int(round(self.n_subvol ** (1 / n_dims)))
        if n < 1 or n ** n_dims != self.n_subvol:
            raise ValueError(f'n_subvol must be a perfect {"cube" if self.split_along_depth else "square"}, got {self.n_subvol}')
        return n

    def split_volume(self, xb:(Tensor, TensorDicom3D, TensorMask3D)):
        "splits a large tensor into multiple smaller tensors"
        xb = detuplify(xb) # xb is always a tuple
        self.n = self._splits_per_dim()
        n_d = self.n if self.split_along_depth else 1 # number of splits along D

        # trim each dim to a multiple of its number of splits; only resize if necessary
        # NOTE(review): for masks, trilinear interpolation yields non-binary values at the resized borders
        d, h, w = xb.shape[-3:]
        shape = [d - d % n_d, h - h % self.n, w - w % self.n]
        if list(xb.shape[-3:]) != shape:
            xb = F.interpolate(xb, size = shape, mode = "trilinear", align_corners=True)

        # split each dim into (number of splits, size of a subvolume)
        d, h, w = shape
        xb = xb.reshape(xb.size(0), xb.size(1), n_d, d // n_d, self.n, h // self.n, self.n, w // self.n)

        # [B, C, n_d, d', n, h', n, w'] -> [C, d', h', w', B, n_d, n, n] -> flatten the last four dims
        # and move them to the front: [B * n_d * n * n, C, d', h', w']
        # return a tuple as xb is always a tuple
        return (xb.permute(1, 3, 5, 7, 0, 2, 4, 6).flatten(-4).permute(4, 0, 1, 2, 3), )

    def patch_volume(self, p:(Tensor, TensorDicom3D, TensorMask3D)):
        "patches a prior split volume back together"
        p = detuplify(p)
        n = self.n # set by the preceding `split_volume` call
        n_d = n if self.split_along_depth else 1
        c, d, h, w = p.shape[1:]
        bs = p.shape[0] // (n_d * n * n)
        # undo the flattening of `split_volume`: [B * n_d * n * n, C, d', h', w'] -> [B, n_d, n, n, C, d', h', w']
        p = p.reshape(bs, n_d, n, n, c, d, h, w)
        # -> [B, C, n_d, d', n, h', n, w'] -> [B, C, D, H, W]
        return (p.permute(0, 4, 1, 5, 2, 6, 3, 7).reshape(bs, c, n_d * d, n * h, n * w), )

In [5]:
# export
class StackVolumesCallback(Callback):
    """
        Merges a multi-input batch into the single-input format expected by `SplitVolumes`.

        `before_batch` concatenates all inputs of `xb` along dim 1 and merges all targets of `yb` into a
        single one-hot target: each target is one-hot encoded and the encodings are combined with an
        element-wise maximum, so a voxel is positive for a class if any target assigns it that class.

        Args:
            n_classes = number of classes for the one-hot encoding; target values outside
                        0 .. n_classes-1 get an all-zero encoding
    """
    run_before = SplitVolumes

    def __init__(self, n_classes = 4):
        store_attr()

    def before_batch(self):
        self.learn.xb = (torch.cat(self.learn.xb, dim=1), )

        y_one_hot = [self.to_one_hot(y, n_classes = self.n_classes) for y in self.learn.yb]
        # element-wise maximum over all targets = union of the per-target class masks
        y_cat = torch.stack(y_one_hot).max(dim = 0)[0]
        self.learn.yb = (y_cat, )

    def make_binary(self, target, set_to_one):
        "float mask that is 1 where `target == set_to_one` and 0 elsewhere"
        return (target == set_to_one).float()

    def to_one_hot(self, target, n_classes):
        "one-hot encodes `target`, adding the class dimension at dim 1"
        target = target.squeeze(1).long() # remove the solitary color channel (if there is one) and set type to int64
        one_hot = [self.make_binary(target, set_to_one=i) for i in range(0, n_classes)]
        return torch.stack(one_hot, 1)

In [6]:
# the three inputs are stacked into 3 channels by StackVolumesCallback; 4 output classes
learn = Learner(
    dls,
    UResNet3D(n_channels = 3, n_classes = 4),
    opt_func = SGD,
    loss_func = MCCLossMulti(4),
    cbs = [StackVolumesCallback, SplitVolumes],
).to_fp16()

In [7]:
# a single epoch as a quick smoke test of the callbacks
learn.fit_one_cycle(n_epoch = 1)

epoch,train_loss,valid_loss,time
0,0.91404,1.000124,00:37


## 3D MixUp