In [None]:
#@title Define if we are on Colab and mount drive { display-mode: "form" }
# IN_COLAB is read by later cells to decide whether to clone the repo and
# whether the labels table is rebuilt from the templates or read from all.xlsx.
try:
  from google.colab import drive
  drive.mount('/content/gdrive')
  IN_COLAB = True
except Exception:
  # Either google.colab cannot be imported (not on Colab) or mounting Drive
  # failed; both fall back to the local workflow. Catching `Exception` instead
  # of using a bare `except` lets KeyboardInterrupt / SystemExit propagate, so
  # interrupting the cell is not silently recorded as "not on Colab".
  IN_COLAB = False

In [None]:
#@title (COLAB ONLY) Clone GitHub repo { display-mode: "form" }

if IN_COLAB:
  !git clone https://github.com/lluissalord/radiology_ai.git

  %cd radiology_ai

In [None]:
#@title Setup environment and Colab general variables { display-mode: "form" }
%%capture
# Executes colab_setup.ipynb from this notebook.
# Later cells use names that are not assigned in this notebook (e.g. PATH_PREFIX,
# KNEE_SVM_MODEL_PATH, organize_folder); presumably they are defined here — TODO confirm.
# NOTE(review): IPython normally recognises a cell magic only on the first line of a
# cell; confirm that %%capture still takes effect when it follows the #@title line.
%run colab_setup.ipynb

In [None]:
#@title Move images from Drive to temporary folder here to be able to train models { display-mode: "form" }
%%capture
# Executes move_raw_preprocess.ipynb from this notebook.
# Later cells read images through raw_preprocess_folder, which is not assigned in this
# notebook; presumably it is defined by one of the %run notebooks — TODO confirm.
# NOTE(review): IPython normally recognises a cell magic only on the first line of a
# cell; confirm that %%capture still takes effect when it follows the #@title line.
%run move_raw_preprocess.ipynb

In [None]:
import pandas as pd
import numpy as np
import sklearn
from sklearn.model_selection import train_test_split

import os
import gc

import matplotlib.pyplot as plt

from fastai.basics import *
from fastai.callback import *
from fastai.data.block import *
from fastai.data.transforms import *
from fastai.medical.imaging import *
from fastai.vision.data import *
from fastai.vision.augment import *
from fastai.vision.all import *
from fastai.vision.widgets import *

In [None]:
from utils import concat_templates, TestColSplitter

# Required to load DICOM on the fly
from preprocessing import *
from utils import seed_everything

In [None]:
# Fractions of the labelled data held out for the test and validation sets.
# Both are expressed relative to the whole labelled set: the split code divides
# VALID_SIZE by (1 - TEST_SIZE) because it splits after the test rows are removed.
TEST_SIZE = 0.15
VALID_SIZE = 0.15

In [None]:
# Batch size handed to the dataloaders (bs=BATCH_SIZE)
BATCH_SIZE = 32

# Size passed to the Resize item transform (images are zero-padded to it)
RESIZE = 512

# Preprocessing switches, consumed when the item transforms are assembled:
#   HIST_CLIPPING    -> adds XRayPreprocess()
#   KNEE_LOCALIZER   -> adds KneeLocalizer(KNEE_SVM_MODEL_PATH)
#   CLAHE_SCALED     -> adds CLAHE_Transform(...)
#   HIST_SCALED      -> adds HistScaled_all(bins); only looked at when CLAHE_SCALED is False
#   HIST_SCALED_SELF -> with HIST_SCALED, uses bins=None instead of bins from init_bins(...)
HIST_CLIPPING = True
KNEE_LOCALIZER = True
CLAHE_SCALED = True
HIST_SCALED = False
HIST_SCALED_SELF = True

# When True, 'Target' is collapsed to '0' vs. everything else ('1')
BINARY_CLASSIFICATION = True

In [None]:
# Single seed shared by seed_everything and by the train/valid/test splits
# (passed there as random_state=SEED).
SEED = 42

seed_everything(SEED)

In [None]:
# Item-level and batch-level transforms

# Optional preprocessing steps, in the order they must be applied. Each entry
# pairs its config switch with a factory so a step is only built when enabled.
optional_steps = (
    (HIST_CLIPPING, lambda: XRayPreprocess()),
    (KNEE_LOCALIZER, lambda: KneeLocalizer(KNEE_SVM_MODEL_PATH)),
)
item_tfms = [build() for enabled, build in optional_steps if enabled]

# The zero-padded resize is always applied after the optional steps.
item_tfms.append(Resize(RESIZE, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros))

# Default fastai augmentations followed by normalisation, applied per batch.
batch_tfms = [*aug_transforms(), Normalize()]

In [None]:
# Labels table: on Colab it is rebuilt from the templates and written to
# all.xlsx; elsewhere it is read back from that same file.
all_labels_path = os.path.join(PATH_PREFIX, 'all.xlsx')

if not IN_COLAB:
  df = pd.read_excel(all_labels_path, dtype={'ID':'string','Target':'string'})
else:
  df = concat_templates(organize_folder, excel=True)
  df.to_excel(all_labels_path, index=False)

In [None]:
# Data

def _split_with_optional_stratify(frame, test_size):
    """Split `frame` in two with train_test_split, stratified on 'Target' when possible.

    Falls back to a plain shuffled split when sklearn rejects the stratified
    one with a ValueError (e.g. a class with too few members). Both attempts
    use random_state=SEED.

    Returns the (larger, held-out) pair of DataFrames from train_test_split.
    """
    split_kwargs = dict(test_size=test_size, shuffle=True, random_state=SEED)
    try:
        return train_test_split(frame, stratify=frame['Target'], **split_kwargs)
    except ValueError:
        return train_test_split(frame, **split_kwargs)


relation_df = pd.read_csv(os.path.join(PATH_PREFIX, 'relation.csv'))
relation_df = relation_df.set_index('Filename')

# Join the labels (indexed by ID) with the relation table (indexed by Filename).
# merge() defaults to an inner join, so rows without a match on either side are
# dropped and final_df does not necessarily line up row-by-row with df.
final_df = df.set_index('ID').merge(relation_df, left_index=True, right_index=True)
final_df['ID'] = final_df.index.values
final_df = final_df.reset_index(drop=True)
final_df['Raw_preprocess'] = final_df['Original_Filename'].apply(lambda filename: os.path.join(raw_preprocess_folder, filename + '.png'))

# The mask must come from final_df itself: a mask built from df is indexed by
# df's rows, which no longer correspond to final_df's rows after the merge and
# reset_index above.
has_target = final_df['Target'].notnull()
unlabel_df = final_df[~has_target].reset_index(drop=True)
label_df = final_df[has_target].reset_index(drop=True)
if BINARY_CLASSIFICATION:
  # NOTE(review): the comparison is against the string '0', so it assumes
  # 'Target' holds strings on both the Colab and the local loading path — confirm.
  label_df['Target'] = (label_df['Target'] != '0').astype(int).astype('string')

# First carve out the test set, then take the validation set from what is left.
# VALID_SIZE is rescaled so it stays a fraction of the full labelled set.
train_df, test_df = _split_with_optional_stratify(label_df, TEST_SIZE)
train_df, val_df = _split_with_optional_stratify(train_df, VALID_SIZE/(1-TEST_SIZE))

# Record the split in a 'Dataset' column, which TestColSplitter reads later.
label_df.loc[train_df.index, 'Dataset'] = 'train'
label_df.loc[val_df.index, 'Dataset'] = 'valid'
label_df.loc[test_df.index, 'Dataset'] = 'test'

assert label_df['Dataset'].notnull().all(), "every labelled row must belong to train, valid or test"

In [None]:
# Histogram scaling DICOM on the fly
# NOTE(review): this cell appends to item_tfms, so running it twice in the same
# kernel adds the transform twice; re-run the transforms cell first.

if CLAHE_SCALED:
    # SELF_SUPERVISED is not assigned in any cell visible in this notebook;
    # presumably it comes from a %run setup notebook — TODO confirm. Fall back
    # to False (grayscale output) instead of failing with a NameError when it
    # is missing.
    self_supervised = globals().get('SELF_SUPERVISED', False)
    item_tfms.append(CLAHE_Transform(grayscale=not self_supervised))
elif HIST_SCALED:
    if HIST_SCALED_SELF:
        # No precomputed bins are passed in this mode.
        bins = None
    else:
        # Bins are estimated from a sample of the preprocessed PNG files.
        bins = init_bins(fnames=L(list(final_df['Raw_preprocess'].values)), n_samples=100, isDCM=False)
    item_tfms.append(HistScaled_all(bins))

In [None]:
# How one labelled sample is read from a row of label_df:
#   x: '<raw_preprocess_folder>/<Original_Filename>.png', loaded as a B/W image
#   y: the 'Target' column
# NOTE(review): ColReader('Target') hands MultiCategoryBlock a plain string rather
# than a list of labels; confirm multi-character labels are handled as intended.
image_reader = ColReader('Original_Filename', pref=raw_preprocess_folder+'/', suff='.png')
target_reader = ColReader('Target')

label_block = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), MultiCategoryBlock),
    get_x=image_reader,
    get_y=target_reader,
    # rows are assigned to train/valid by the 'Dataset' column of label_df
    splitter=TestColSplitter(col='Dataset'),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)

# num_workers=0 keeps data loading in the main process.
label_data = label_block.dataloaders(label_df, bs=BATCH_SIZE, num_workers=0)

label_data.show_batch()

In [None]:
import torch

def focal_loss(input, target, reduction='mean', beta=0.5, gamma=2., eps=1e-7, **kwargs):
    """Binary focal loss computed from raw logits.

    Args:
        input: logits; a sigmoid is applied and the result is flattened to
            (batch, -1).
        target: tensor flattened to (batch, -1); 1 selects the positive term
            and 0 the negative term (values in between weight both).
        reduction: 'mean' or 'sum' reduces over the batch; any other value
            returns the per-sample losses.
        beta: weight of the positive term; the negative term gets 1 - beta.
        gamma: exponent that down-weights well-classified elements.
        eps: probabilities are clamped to [eps, 1 - eps] before the logs.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        A scalar tensor for 'mean' / 'sum', otherwise a tensor of shape
        (batch,) holding each sample's loss averaged over its elements.

    A non-finite loss emits a RuntimeWarning (instead of dropping into a
    debugger) and the value is still returned to the caller.
    """
    import warnings

    n = input.size(0)
    # reshape (rather than view) also accepts non-contiguous tensors.
    probs = torch.sigmoid(input).reshape(n, -1).clamp(eps, 1 - eps)
    tflat = target.reshape(n, -1)

    pos_term = beta * tflat * (1 - probs).pow(gamma) * probs.log()
    neg_term = (1 - beta) * (1 - tflat) * probs.pow(gamma) * (1 - probs).log()
    focal = -(pos_term + neg_term).mean(-1)

    if not torch.isfinite(focal).all():
        warnings.warn("focal_loss produced a non-finite value (NaN or inf)", RuntimeWarning)

    if reduction == 'mean':
        return focal.mean()
    elif reduction == 'sum':
        return focal.sum()
    else:
        return focal

class FocalLoss(nn.Module):
    """Module wrapper around `focal_loss` that stores its hyper-parameters.

    `beta`, `gamma` and `reduction` are kept as plain attributes and forwarded
    to `focal_loss` on every call.
    """

    def __init__(self, beta=0.5, gamma=2., reduction='mean'):
        super().__init__()
        self.beta, self.gamma, self.reduction = beta, gamma, reduction

    def forward(self, input, target, **kwargs):
        """Return `focal_loss(input, target)` using the stored settings plus any extra kwargs."""
        stored = dict(beta=self.beta, gamma=self.gamma, reduction=self.reduction)
        return focal_loss(input, target, **stored, **kwargs)

In [None]:
# Define the callbacks that will be used during training
callback_fns = [
        MixUp(),
    ]

# Metric objects. roc_auc is created but left out of `metrics` below because it
# cannot be computed when a step/epoch contains only one class.
roc_auc = RocAuc()
f1_score = F1ScoreMulti(average='macro')
precision = PrecisionMulti(average='macro')
recall = RecallMulti(average='macro')

learn = cnn_learner(
    label_data,
    resnet18,
    loss_func=FocalLoss(),
    metrics=[
        accuracy_multi,
        f1_score,
        precision,
        recall
    ],
    # fastai v2 (the API used throughout this notebook) takes instantiated
    # callbacks through `cbs`; `callback_fns` was the fastai v1 keyword and is
    # not a parameter of the v2 learner.
    cbs=callback_fns
)

# Train in 16-bit precision. The stated intent is regularisation (harder to
# "memorise" images while still being enough to learn).
# NOTE(review): mixed precision is normally used for speed/memory; confirm the
# regularisation effect empirically.
learn = learn.to_fp16()

In [None]:
# Run fastai's learning-rate finder on the learner to help choose a base learning rate.
learn.lr_find()

In [None]:
# Fine-tune: 3 epochs with the pretrained body frozen (freeze_epochs), then 10 epochs
# unfrozen, using 0.05 as the base learning rate.
learn.fine_tune(10, 0.05, freeze_epochs=3)

In [None]:
# Display up to 25 samples together with the model's predictions.
learn.show_results(max_n=25)

In [None]:
# Build an interpretation object for dataset index 1
# (NOTE(review): presumably the validation split — confirm).
interp = Interpretation.from_learner(learn, 1)
# Losses and the matching item indices as returned by top_losses(); kept for inspection.
losses, idx = interp.top_losses()
# Plot the 25 items with the highest loss.
interp.plot_top_losses(25, figsize=(15,10))