# Creating custom datablocks for 3D images and fastai

In [None]:
import SimpleITK as sitk
import re
import pathlib
import torchvision

from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *

In [None]:
from faimed3d.basics import *
from faimed3d.augment import *
from faimed3d.data import *
from faimed3d.models import *

## Getting the data
Same approach as in Notebook 3. 

In [None]:
# Root folders of the two pre-defined splits.
train, valid = (pathlib.Path('../../dl-prostate-mapping/data/train'),
                pathlib.Path('../../dl-prostate-mapping/data/valid'))

# Recursively collect every entry named 'DICOM' below each split root.
train_files = [*train.rglob('DICOM')]
valid_files = [*valid.rglob('DICOM')]

To reduce complexity of the data, only the T2 map will be used for the first runs. 

In [None]:
# The validation files need to be shuffled in place: otherwise the scores will not work on the
# valid ds, because with a bs of 10 the first valid batch would contain only positive cases and
# the second only negative ones.
random.shuffle(valid_files)

In [None]:
# Pool both splits into one item list, training files first.
files = [*train_files, *valid_files]

In [None]:
# Take only T2 images for now: keep every file whose path contains 'T2'.
# A plain substring test selects exactly the paths the former regex search matched.
subset_files = [Path(f) for f in files if 'T2' in str(f)]

In [None]:
# Oversampling: the first 34 T2 series appear a second time in the item list.
# NOTE(review): presumably these 34 entries belong to the under-represented class — confirm.
oversampled = [*subset_files[:34], *subset_files]

Whether or not the patient has prostate cancer can be extracted from the file path.

In [None]:
labels = ['Gesund', 'ProstataCa']


def label_func(fn):
    "Return the first class name (`Gesund` or `ProstataCa`) occurring in the path `fn`."
    # Indexing the findall result raises IndexError when the path holds neither class name.
    found = re.findall(r'(Gesund|ProstataCa)', str(fn))
    return found[0]

## Construct the dataloaders

In [None]:
def GreatGreatGrandparentSplitter(train_name='train', valid_name='valid'):
    "Build a splitter that assigns items to train/valid by the name of their great great grandparent folder."
    def _split(items):
        # Returns the pair (train indices, valid indices) for `items`.
        train_idxs = _great_great_grandparent_idxs(items, train_name)
        valid_idxs = _great_great_grandparent_idxs(items, valid_name)
        return train_idxs, valid_idxs
    return _split

In [None]:
def _great_great_grandparent_idxs(items, name):
    "Indices of `items` whose great great grandparent folder is called `name` (one name or several)."
    def _idxs_for(folder):
        # Boolean mask over `items`, turned into positions by `mask2idxs`.
        return mask2idxs(Path(o).parent.parent.parent.parent.name == folder for o in items)
    # `L(name)` lets the loop handle a single name and a collection of names alike.
    return [idx for folder in L(name) for idx in _idxs_for(folder)]

In [None]:
# Item-level transform, applied to each volume before batching.
# NOTE(review): the exact meaning of `crop_by` / `perc_crop` is defined in faimed3d — confirm there.
mri_item_tfms = ResizeCrop3D(crop_by = (0., 0.1, 0.1), resize_to = (14, 145, 145), perc_crop = True)

# Batch-level transforms: the faimed3d default 3D augmentations, a random crop and pseudo-colouring.
mri_batch_tfms = [
    *aug_transforms_3d(),
    RandomCrop3D(((2, 2), (25,25), (25,25)), (0, 10, 10)),
    PseudoColor,
]

mris = DataBlock(
    blocks = (ImageBlock3D(cls=TensorDicom3D), CategoryBlock),
    get_x = lambda x: x,  # the items are passed through unchanged as inputs
    get_y = label_func,   # class name is parsed from the file path
    item_tfms = mri_item_tfms,
    batch_tfms = mri_batch_tfms,
    splitter = GreatGreatGrandparentSplitter())

In [None]:
# Build the train/valid dataloaders from the oversampled item list.
dls = mris.dataloaders(oversampled, batch_size = 10, num_workers = 0)

# The validation batch size defaults to 64 and will cause Cuda out of Memory errors, so lower it.
dls.valid.bs = 20

## Train a simple 3D ResNet

The data is unbalanced, so I will use the Matthews correlation coefficient (MCC) both as loss function and as metric, as it strongly penalizes wrong predictions.

In [None]:
def mcc_loss(out, targ):
    "Mean of the unsmoothed `MCCLossBinary` between `out` and the two-column encoding of `targ`."
    # Two classes are predicted (0 = No Cancer, 1 = Cancer), so the target gets one column per class.
    two_col_targ = torch.stack((1 - targ, targ), 1)
    criterion = MCCLossBinary(smooth = 0.)
    return torch.mean(criterion(out, two_col_targ.long()))


def mcc_score(out, targ):
    "`mcc_binary` between `out` and the two-column encoding of `targ`, for use as a metric."
    # Same target layout as in `mcc_loss`: column 0 = No Cancer, column 1 = Cancer.
    two_col_targ = torch.stack((1 - targ, targ), 1)
    return mcc_binary(out, two_col_targ)


In [None]:
# Learner around a 3D ResNet (one input channel, two output classes), optimised with SGD
# on the MCC loss; `to_fp16` switches the learner to fp16 training.
learn = Learner(
    dls,
    resnet_3d(n_input = 1, n_classes = 2),
    opt_func = SGD,
    loss_func = mcc_loss,
    metrics = [error_rate, mcc_score],
    model_dir = '../models/',
).to_fp16()

In [None]:
#learn.lr_find()

All weights are randomly initialized; we are not doing any transfer learning here. So for the first 10 epochs the learning rate is very high and all layers are unfrozen. The `fit_one_cycle` function will still make sure we start with a reasonably small lr as warmup and decrease it again after some epochs.

In [None]:
# Train for 10 epochs on the one-cycle schedule with a peak learning rate of 0.1.
learn.fit_one_cycle(n_epoch = 10, lr_max = 0.1, wd = 1e-3)

In [None]:
from sklearn.metrics import roc_curve, auc

# Model outputs and targets; `get_preds` evaluates the validation set by default.
x, y = learn.get_preds()

# Class 1 (cancer) is the positive class, so `roc_curve` needs the score of output
# column 1. Column 0 is the score of the negative class and would mirror the curve
# (reporting 1 - AUC).
ns_fpr, ns_tpr, _ = roc_curve(y.numpy(), x[:, 1])
roc_auc = auc(ns_fpr, ns_tpr)

plt.plot(ns_fpr, ns_tpr, linestyle='--', label='ROC Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.annotate(f'AUC: {roc_auc}', (0.75, 0.05))
plt.legend()
plt.show()