# Mount Google Drive to access the data

In [None]:
# Mount Google Drive at /content/drive so the data archive and saved
# checkpoints can be read from / written to it by later cells.
from google.colab import drive
drive.mount('/content/drive')

Sometimes the machine needs a hard reset; use the command below.

In [None]:
# Sends SIGKILL to every process this user may signal (this includes the
# notebook kernel). Only run it when the VM is stuck.
!kill -9 -1

Check how much GPU RAM is available.

In [None]:
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()

# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]

def printm():
    process = psutil.Process(os.getpid())
    print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))

    printm()

# Install dependencies

In [None]:
# Run fast.ai's Colab setup script fetched straight from the network.
# NOTE(review): piping a remote script into bash executes unreviewed code;
# consider downloading and inspecting (or pinning) it first.
!curl https://course-v3.fast.ai/setup/colab | bash

In [None]:
# pydicom is used to read the .dcm slices (open_dcm_image / open_dcm_mask).
!pip install pydicom

In [None]:
import os
import re

import numpy as np
import pydicom
import matplotlib.pyplot as plt

from fastai import *
from fastai.vision import *
import cv2

from fastai.layers import FlattenedLoss

Copy the data from Drive to the local filesystem; reading locally may make loading faster.
Note: re-create the compressed archive on Drive whenever the data changes.

In [None]:
# NOTE(review): drive_data_path is not used by the visible cells; the copy
# step below hard-codes the archive location instead.
drive_data_path = './drive/My Drive/Data/'
# Local root walked when collecting scans (presumably the directory the
# archive extracts to — confirm against the archive layout).
data_path = './data'

In [None]:
# Copy the archive from Drive into the working directory.
!cp ./drive/My\ Drive/data.tar.gz .

In [None]:
!tar xzf data.tar.gz

In [None]:
# Collect the paths of all files in CT directories under data_path.
scans = []
for root, _dirs, files in os.walk(data_path):
    # Keep CT directories only and skip the wrongly labeled P1B1 series.
    # NOTE(review): both are substring tests on the whole directory path.
    if 'CT' not in root or 'P1B1' in root:
        continue
    scans.extend(os.path.join(root, name) for name in files)
# os.walk yields entries in whatever order the filesystem returns them, which
# can differ between machines; sort so the seeded shuffle is reproducible.
scans.sort()

In [None]:
# Fixed seed: given the same input order, the shuffled order (and therefore
# preview indices such as scans[1003]) is repeatable.
np.random.seed(42)
np.random.shuffle(scans)

Override fastai's image-loading functions so they read DICOM files.

In [None]:
def _window_to_uint8(array, window_min, window_max):
    """Clip ``array`` to the window and stretch it linearly to uint8 0-255.

    A slice whose clipped values are all equal (e.g. entirely outside the
    window) maps to all zeros instead of dividing by zero.
    """
    # float64 keeps a negative lower bound representable even when the pixel
    # dtype is unsigned, and matches the precision of int/int true division.
    array = np.clip(np.asarray(array, dtype=np.float64), window_min, window_max)
    lo, hi = array.min(), array.max()
    if hi == lo:
        return np.zeros(array.shape, dtype=np.uint8)
    return ((array - lo) / (hi - lo) * 255).astype(np.uint8)


def open_dcm_image(fn, *args, window_min=-100, window_max=100, **kwargs)->Image:
    """Load a DICOM slice as a 3-channel fastai ``Image`` with values in [0, 1].

    The raw ``pixel_array`` is clipped to ``[window_min, window_max]``,
    stretched linearly to 0-255, histogram-equalized with OpenCV and repeated
    to three channels. A wider window of (-100, 400) was tried before; pass
    it through the keyword arguments to experiment.

    Extra positional/keyword arguments are accepted so this can stand in for
    fastai's ``open_image``; they are ignored.

    NOTE(review): the window is applied to ``pixel_array`` as stored, without
    the DICOM RescaleSlope/RescaleIntercept; if the files are not already in
    Hounsfield units the bounds may not mean what they appear to — confirm.
    """
    array = _window_to_uint8(pydicom.dcmread(fn).pixel_array, window_min, window_max)
    array = cv2.equalizeHist(array)
    array = np.repeat(array[:, :, None], 3, axis=2)
    # Storing slices pre-processed in this format would make loading faster.
    return Image(pil2tensor(array, np.float32).div_(255))

def open_dcm_mask(fn, *args, **kwargs)->Image:
    """Read a DICOM label slice and wrap it as a fastai ``ImageSegment``.

    Extra positional/keyword arguments are accepted only so this can stand in
    for fastai's ``open_mask``; they are ignored.
    """
    # NOTE(review): stored label values are used as class ids unchanged; this
    # presumably expects 0/1 to match ``codes`` — confirm against the data.
    return ImageSegment(pil2tensor(pydicom.dcmread(fn).pixel_array, np.float32))


def annotate_metadata(fn, ax):
    """Write ``"<patient id> [<slice number>]"`` onto a matplotlib axis.

    Args:
        fn: path of a slice file, as ``str`` or ``Path``. The patient id is
            the third-to-last path component.
        ax: axis object providing ``annotate``.
    """
    from pathlib import Path  # local import keeps this cell self-contained

    parts = Path(fn).parts
    patient_id = parts[-3]
    # First run of digits in the file name; fall back to the whole file name
    # when it contains none rather than raising IndexError.
    digits = re.findall(r'\d+', parts[-1])
    slice_number = digits[0] if digits else parts[-1]
    ax.annotate(
        '{} [{}]'.format(patient_id, slice_number),
        xy=(.25, .25),
        xycoords='data',
        xytext=(30, 10),
        fontsize=20,
        textcoords='offset points',
    )

                                  
# monkey patch
# Route fastai's image/mask loaders to the DICOM readers defined above. The
# defining module (fastai.vision.image) and fastai.vision.data each hold their
# own reference, so both are rebound (presumably the data module's names are
# the ones looked up when items are opened — confirm for the fastai version).
# The last two lines rebind the notebook's own names.
fastai.vision.image.open_image = open_dcm_image
fastai.vision.image.open_mask = open_dcm_mask
fastai.vision.data.open_image = open_dcm_image
fastai.vision.data.open_mask = open_dcm_mask
open_image = open_dcm_image
open_mask = open_dcm_mask

# Look at the data

In [None]:
# Preview one windowed CT slice through the patched loader.
open_image(scans[1003])

In [None]:
def get_y_fn(path):
    """Return the label-file path paired with image ``path``.

    Labels live in a ``label`` directory that is a sibling of the image's
    directory and share the image's file name, e.g.
    ``data/P1/CT/IMG-1.dcm`` -> ``data/P1/CT/../label/IMG-1.dcm``.
    """
    path = Path(path)
    return str(path.parent / '../label' / path.name)

# Preview the label mask paired with the same slice.
open_mask(get_y_fn(scans[1003]))

In [None]:
# Segmentation class names; presumably mask value i maps to codes[i].
codes = ['void', 'water']

In [None]:
# Labelled item list over the scan paths. Slices whose path contains 'P6' or
# 'P7' form the validation set; everything else is training data.
# NOTE(review): these are substring checks on the whole path, so e.g. 'P70'
# would also match — confirm patient ids cannot collide.
src = (
    SegmentationItemList.from_df(pd.DataFrame(scans, columns=['files']), '.')
    .split_by_valid_func(lambda img_src: 'P7' in str(img_src) or 'P6' in str(img_src))
    .label_from_func(get_y_fn, classes=codes)
)
src

In [None]:
# Native spatial size read from one slice (tensor is channels-first, so
# shape[1:] is the 2-D size); assumes all slices share it. Train first at 1/4.
img = open_image(scans[600]).data
src_size = np.array(img.shape[1:])
size = src_size // 4
size

In [None]:
bs = 80

In [None]:
# tfm_y=True applies the same spatial transforms to the masks as to images.
data = (
    # Note: warping might deform images, so max_warp is 0 for now; maybe we can use it later.
    src.transform(get_transforms(max_rotate=5., max_lighting=0, p_lighting=0, max_warp=0), size=size, tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats)
)

In [None]:
# Visual sanity check of training samples.
data.show_batch(2, figsize=(10,7))

In [None]:
# Same check for the validation split.
data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)

# Choose metrics to evaluate 

In [None]:
from fastai.metrics import accuracy, dice

In [None]:
def acc(input, target):
    """Fraction of pixels whose argmax class over dim 1 equals the target.

    ``target`` is expected to carry a singleton axis at dim 1, which is
    squeezed away before comparing.
    """
    predicted = input.argmax(dim=1)
    hits = predicted == target.squeeze(1)
    return hits.float().mean()

In [None]:
# Pixel accuracy (acc above) plus fastai's dice metric.
metrics=[acc, dice]

# Implement new loss functions

In [None]:
from torch.nn.modules.loss import _Loss

In [None]:
class DiceLoss(_Loss):
    """Soft Dice loss on the class-1 softmax probability of flattened logits.

    The score is taken over the whole batch at once, with an additive
    smoothing term of 1 in both numerator and denominator.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.softmax = nn.Softmax(1)

    def forward(self, input, target):
        """Return ``1 - (2*sum(P*T) + 1) / (sum(P) + sum(T) + 1)``.

        P is the softmax probability of class 1 (column 1 of ``input``) and T
        is ``target`` cast to float.
        """
        smoothing = 1.
        foreground = self.softmax(input)[:, 1]
        truth = target.float()
        overlap = (foreground * truth).sum()
        denominator = foreground.sum() + truth.sum() + smoothing
        return 1 - (2. * overlap + smoothing) / denominator


class GeneralizedDiceLoss(_Loss):
    """Generalised Dice loss with squared-reciprocal class weights.

    Follows the NiftyNet reference implementation:
    https://niftynet.readthedocs.io/en/dev/_modules/niftynet/layer/loss_segmentation.html#generalised_dice_loss

    Each class is weighted by ``1 / (target volume)**2``. Classes absent from
    the target get weight 0.
    """

    def __init__(self, **kwargs):
        super(GeneralizedDiceLoss, self).__init__(**kwargs)
        self.softmax = nn.Softmax(1)

    def forward(self, input, target):
        """Return ``1 - generalised Dice score``.

        Args:
            input: logits of shape ``(M, n_classes)``, classes on dim 1.
            target: integer class ids, one per row of ``input`` (flattened).
        """
        n_classes = input.shape[1]
        prediction = self.softmax(input)
        # One-hot targets built on the logits' device so the loss also runs on
        # CPU; kept float32 so squared class volumes cannot overflow under
        # half-precision logits.
        one_hot = (
            torch.eye(n_classes, device=input.device)
            .index_select(0, target.long().reshape(-1))
        )

        ref_vol = torch.sum(one_hot, 0)                 # target pixels per class
        seg_vol = torch.sum(prediction, 0)              # soft predicted volume
        intersect = torch.sum(one_hot * prediction, 0)

        weights = torch.reciprocal(ref_vol ** 2)
        # Classes missing from the target would get infinite weight.
        weights[torch.isinf(weights)] = 0

        numerator = 2 * torch.sum(weights * intersect)
        denominator = torch.sum(
            weights * torch.max(seg_vol + ref_vol, torch.ones_like(weights))
        )
        score = numerator / denominator
        # 0/0 (e.g. an empty batch) is treated as a perfect score.
        score = torch.where(torch.isnan(score), torch.ones_like(score), score)
        return 1 - score
            
# FlattenedLoss(axis=1) moves the class axis last and flattens the logits to
# (-1, n_classes) and the target to (-1,) before calling the wrapped loss
# (fastai v1 behaviour — confirm for the installed version).
dice_loss = FlattenedLoss(DiceLoss, axis=1)
generalized_dice_loss = FlattenedLoss(GeneralizedDiceLoss, axis=1)

In [None]:
# Sanity check: two "pixels" whose logits favour class 0 while the target is 1.
dice_loss(torch.Tensor([[10, 1], [10, 0]]), torch.Tensor([[1], [1]]))

In [None]:
# Same inputs on the GPU for the generalized loss.
generalized_dice_loss(torch.Tensor([[10, 1], [10, 0]]).cuda(), torch.Tensor([[1], [1]]).cuda())

# Train model

In [None]:
# U-Net with a ResNet-34 encoder trained with the generalized Dice loss.
learn = unet_learner(
    data, models.resnet34, metrics=metrics, 
    self_attention=False,
    loss_func=generalized_dice_loss,
)

In [None]:
# First stage (encoder is unfrozen only later): pick a learning rate.
lr_find(learn)
learn.recorder.plot()

In [None]:
# Learning rate chosen from the plot above.
lr=3e-5

In [None]:
learn.fit_one_cycle(10, slice(lr), pct_start=0.9)

In [None]:
# Written to ./models/3_1.pth (copied to Drive below).
learn.save('3_1')

In [None]:
learn.load('3_1');

In [None]:
# Back the checkpoint up to Drive so it survives a runtime reset.
!cp ./models/3_1.pth ./drive/My\ Drive/

In [None]:
learn.show_results(rows=20)

In [None]:
# Second stage: fine-tune the whole network, encoder included.
learn.unfreeze()

In [None]:
lr_find(learn)
learn.recorder.plot()

In [None]:
# slice(lo, hi) spreads learning rates across the layer groups.
lrs = slice(1e-6, 8e-5)

In [None]:
learn.fit_one_cycle(12, lrs, pct_start=0.8)

In [None]:
learn.recorder.plot_losses()

In [None]:
learn.recorder.plot_lr()

In [None]:
learn.save('3_2');

In [None]:
# NOTE(review): this creates My Drive/Code/ but copies the checkpoint to the
# Drive root — confirm which destination is intended.
!mkdir -p ./drive/My\ Drive/Code/
!cp ./models/3_2.pth ./drive/My\ Drive/

In [None]:
learn = learn.load('3_2')

In [None]:
learn.show_results(rows=24)

# Go big: train on full-size images

In [None]:
# Fetch a previously trained 'stage-1' checkpoint from Drive.
# NOTE(review): this is not the '3_1'/'3_2' checkpoint trained above — confirm
# it matches the self_attention=True learner built below.
!mkdir -p models
!cp ./drive/My\ Drive/Code/Mateusz/stage-1.pth ./models/stage-1.pth 

In [None]:
# Full resolution with a much smaller batch (presumably to fit GPU memory).
size = src_size
bs = 5

In [None]:
# NOTE(review): unlike the low-res stage, max_warp is not set to 0 here, so
# fastai's default (non-zero) warp is applied — confirm that is intended.
data = (
    src.transform(get_transforms(max_rotate=5., max_lighting=0, p_lighting=0), size=size, tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats)
)

In [None]:
# NOTE(review): no loss_func is passed, so fastai's default loss is used
# instead of the Dice losses defined above.
learn = unet_learner(
    data, models.resnet34, metrics=metrics, self_attention=True,
)

In [None]:
# Start from the 'stage-1' weights copied from Drive.
learn.load('stage-1');

In [None]:
lr_find(learn)
learn.recorder.plot()

In [None]:
lr=1e-5

In [None]:
# Short run before unfreezing.
learn.fit_one_cycle(3, slice(lr))

In [None]:
learn.save('stage-1-big')

In [None]:
learn.show_results()

In [None]:
!cp ./models/stage-1-big.pth ./drive/My\ Drive/

In [None]:
learn.load('stage-1-big');

In [None]:
# Fine-tune the whole network at full resolution.
learn.unfreeze()

In [None]:
lrs = slice(1e-6,1e-4)

In [None]:
learn.fit_one_cycle(10, lrs, wd=1e-3)

In [None]:
learn.save('stage-2-big')

In [None]:
learn.load('stage-2-big')

In [None]:
learn.show_results()

In [None]:
# Back the final checkpoint up to Drive.
!cp ./models/stage-2-big.pth ./drive/My\ Drive/