# Training a 3D CNN

In [None]:
from faimed3d.basics import *
from faimed3d.augment import *
from faimed3d.models import *
from faimed3d.data import *

import pathlib
import re
from fastai.basics import *
from fastai.vision.data import *

## Loading the data

In [None]:
# Root of the prostate MRI dataset; each split lives in its own sub-folder.
DATA_ROOT = pathlib.Path('../../dl-prostate-mapping/data')
train = DATA_ROOT / 'train'
valid = DATA_ROOT / 'valid'
test = DATA_ROOT / 'test'

# Collect series folders named 'DICOM' in every split; for the training split,
# folders named 'T2W' are collected as well.
train_files = list(train.rglob('DICOM')) + list(train.rglob('T2W'))
valid_files = list(valid.rglob('DICOM'))
test_files = list(test.rglob('DICOM'))


def filter_by_sequence(paths, pattern=r'T2'):
    """Return the entries of `paths` whose string form contains a match for `pattern`.

    The regex is searched anywhere in the full path, so parent folder names
    count as well. Order is preserved and the original path objects are
    returned unchanged.
    """
    regex = re.compile(pattern)
    return [p for p in paths if regex.search(str(p))]


# Keep only series whose path contains 'T2' for now (presumably the
# T2-weighted images).
subset_train = filter_by_sequence(train_files)
subset_valid = filter_by_sequence(valid_files)
subset_test = filter_by_sequence(test_files)
        
        
def label_func(fn):
    """Return the diagnosis label ('Gesund' or 'ProstataCa') found in the path `fn`.

    The first occurrence of either label anywhere in ``str(fn)`` wins.

    Raises:
        ValueError: if neither label occurs in the path.
    """
    match = re.search(r'(Gesund|ProstataCa)', str(fn))
    if match is None:
        raise ValueError(f'no label (Gesund/ProstataCa) found in path: {fn}')
    return match.group(1)
labels = ['Gesund', 'ProstataCa']

In [None]:
# Despite the name, nothing is oversampled here: the T2 series from the train
# and valid folders are simply concatenated into one pool, which the
# DataBlock's RandomSplitter below re-splits. The test subset is not included.
oversampled_train = [*subset_train, *subset_valid]

## Setting up dataloaders

In [None]:
labels = ['Gesund', 'ProstataCa']


def label_func(fn):
    """Return the diagnosis label ('Gesund' or 'ProstataCa') found in the path `fn`.

    The first occurrence of either label anywhere in ``str(fn)`` wins.

    Raises:
        ValueError: if neither label occurs in the path.
    """
    match = re.search(r'(Gesund|ProstataCa)', str(fn))
    if match is None:
        raise ValueError(f'no label (Gesund/ProstataCa) found in path: {fn}')
    return match.group(1)

In [None]:
# Inputs are 3D DICOM volumes, targets are the categorical diagnosis labels.
_mri_blocks = (ImageBlock3D(cls=TensorDicom3D), CategoryBlock)
# Per-item: percentage crop (perc_crop=True) of 0 / 0.15 / 0.15 along the three
# axes, then resize to 20 x 200 x 200.
_mri_item_tfms = ResizeCrop3D(crop_by=(0., 0.15, 0.15), resize_to=(20, 200, 200), perc_crop=True)
# Per-batch: faimed3d 3D augmentations (p_all=0.15), a random crop, then PseudoColor.
_mri_batch_tfms = [
    *aug_transforms_3d(p_all=0.15),
    RandomCrop3D((0, 25, 25), (0, 10, 10)),
    PseudoColor,
]
mris = DataBlock(
    blocks=_mri_blocks,
    get_x=lambda x: x,  # items are already the series paths; use them as-is
    get_y=label_func,
    item_tfms=_mri_item_tfms,
    batch_tfms=_mri_batch_tfms,
    # NOTE(review): no seed is given, so the train/valid split presumably
    # differs on every run.
    splitter=RandomSplitter(),
)

In [None]:
# 3D volumes are memory-hungry, so both loaders use 8 samples per batch.
BATCH_SIZE_3D = 8
dls = mris.dataloaders(source=oversampled_train, num_workers=0, batch_size=BATCH_SIZE_3D)
# fastai may use a larger batch size for the validation set; for 3D volumes that
# is too large, so pin it explicitly.
dls.valid.bs = BATCH_SIZE_3D

In [None]:
class Sequential_(nn.Sequential):
    """`nn.Sequential` that moves its input onto each child's device before calling it.

    Workaround: in the 3D pipeline the input batch was observed not to arrive
    on the GPU (the cause is suspected to be in the transforms). Instead of
    hard-coding ``.cuda()`` (which fails on CPU-only machines and for models
    kept on the CPU), the input follows the device of each child module's
    parameters. Parameter-free children (pooling, flatten, ...) receive the
    input unchanged.
    """

    def forward(self, input):
        for module in self:
            # Parameter-free modules have no device to follow; skip the move.
            param = next(module.parameters(), None)
            if param is not None:
                input = input.to(param.device)
            input = module(input)
        return input

def block(ni, nf, **kwargs):
    """Chain two 3D ``ResBlock``s, going from `ni` to `nf` channels.

    The first ResBlock uses stride (2, 2, 1) and receives any extra `kwargs`
    (e.g. ``ks``). The second keeps `nf` channels and uses its default
    settings. Both are wrapped in ``Sequential_``.
    """
    strided = ResBlock(1, ni, nf, stride=(2, 2, 1), ndim=3, **kwargs)
    same_width = ResBlock(1, nf, nf, ndim=3)
    return Sequential_(strided, same_width)

#def block(ni, nf, **kwargs): return ConvLayer(ni, nf, ndim = 3, **kwargs)

def get_model(n_in=20, n_out=None):
    """Build the 3D residual CNN classifier.

    Args:
        n_in: number of input channels expected by the first ``block``
            (default 20). NOTE(review): presumably tied to the 20-slice depth
            from ``ResizeCrop3D(resize_to=(20, ...))``; confirm against the
            actual batch shape.
        n_out: number of output logits. Defaults to ``dls.c`` (the class count
            of the module-level DataLoaders), read when the function is called.

    Returns:
        A ``Sequential_`` with four ``block`` stages (n_in -> 128 -> 256 ->
        512 -> 768 channels), 3D global average pooling, ``Flatten`` and a
        linear head.
    """
    if n_out is None:
        n_out = dls.c
    n_feat = 768  # width of the last stage; must match the head's input size
    return Sequential_(
        block(n_in, 128, ks=7),
        block(128, 256, ks=5),
        block(256, 512),
        block(512, n_feat),
        nn.AdaptiveAvgPool3d(1),
        Flatten(),
        nn.Linear(n_feat, n_out))

In [None]:
def loss_func(out, targ):
    """Flattened cross-entropy between the logits `out` and the class targets `targ`.

    `targ` is cast with ``.long()`` because cross-entropy expects integer
    class indices. NOTE(review): wrapping fastai's loss in a plain function
    hides the loss object's own attributes (e.g. an activation used at
    prediction time); confirm that ``learn.get_preds``/``predict`` behave as
    intended.
    """
    criterion = CrossEntropyLossFlat()
    targets = targ.long()
    return criterion(out, targets)

In [None]:
# ROC-AUC metric for the binary Gesund-vs-ProstataCa task; passed to the
# Learner's metrics together with error_rate.
auc = RocAucBinary()

In [None]:
from fastai.callback.hook import *

In [None]:
learn = Learner(
    dls,
    get_model(),
    opt_func=SGD,
    loss_func=loss_func,
    metrics=[error_rate, auc],
    # Record per-layer activation statistics, including histograms.
    cbs=ActivationStats(with_hist=True),
).to_fp16()  # switch the Learner to mixed-precision training

In [None]:
#learn.lr_find()

In [None]:
# One-cycle schedule: 100 epochs with a peak learning rate of 1e-3.
N_EPOCHS, LR_MAX = 100, 1e-3
learn.fit_one_cycle(n_epoch=N_EPOCHS, lr_max=LR_MAX)

In [None]:
# Plot the hyper-parameter schedule (presumably learning rate and momentum)
# recorded during fit_one_cycle.
learn.recorder.plot_sched()

In [None]:
# Plot the ActivationStats statistics for one hooked layer; -4 counts from the
# end of the hooked layers. NOTE(review): which module index -4 refers to
# depends on which modules ActivationStats hooks — confirm.
learn.activation_stats.plot_layer_stats(-4)

In [None]:
# Activation-histogram ("colorful dimension") plot for the last hooked layer.
# Presumably relies on ActivationStats having been created with with_hist=True.
learn.activation_stats.color_dim(-1)