In [None]:
#@title Define if we are on Colab and mount drive { display-mode: "form" }
# Detect whether the notebook runs on Google Colab and, if so, mount Google Drive.
# Only a failed import means "not on Colab": a bare `except:` would also swallow
# KeyboardInterrupt and genuine mount failures, silently sending later cells down
# the non-Colab code paths while actually running on Colab.
run_params = {}
try:
  from google.colab import drive
except ImportError:
  run_params['IN_COLAB'] = False
else:
  drive.mount('/content/gdrive')
  run_params['IN_COLAB'] = True

In [None]:
#@title (COLAB ONLY) Clone GitHub repo { display-mode: "form" }
# On Colab, fetch the project sources and make the repo root the working
# directory, so the relative `%run *.ipynb` and `utils.*` / `preprocessing.*`
# imports in later cells resolve.
# NOTE(review): `%cd` changes the cwd, so re-running this cell clones a nested
# radiology_ai/radiology_ai copy and descends into it — guard if re-runs matter.
if run_params['IN_COLAB']:
  !git clone https://github.com/lluissalord/radiology_ai.git

  %cd radiology_ai

In [None]:
#@title Setup environment and Colab general variables { display-mode: "form" }
%%capture
# Run the shared setup notebook in this kernel with its output captured.
# Later cells read run_params keys this notebook never sets (MODELS_FOLDER,
# PATH_PREFIX, RAW_PREPROCESS_FOLDER, ORGANIZE_FOLDER, KNEE_SVM_MODEL_PATH);
# presumably they are defined by colab_setup.ipynb — verify there.
# NOTE(review): plain IPython only treats `%%capture` as a cell magic when it is
# the first line of the cell; confirm it is honoured after the `#@title` line.
%run colab_setup.ipynb

In [None]:
#@title Move images from Drive to temporary folder here to be able to train models { display-mode: "form" }
%%capture
# Run move_raw_preprocess.ipynb with its output captured; per the cell title it
# copies images from Drive to a local temporary folder used for training.
# NOTE(review): plain IPython only treats `%%capture` as a cell magic when it is
# the first line of the cell; confirm it is honoured after the `#@title` line.
%run move_raw_preprocess.ipynb

In [None]:
import pandas as pd
import numpy as np
import sklearn
from sklearn.model_selection import train_test_split

import os
import gc

import matplotlib.pyplot as plt

from fastai.basics import *
from fastai.callback import *
from fastai.data.block import *
from fastai.data.transforms import *
from fastai.medical.imaging import *
from fastai.vision.data import *
from fastai.vision.learner import *
from fastai.vision.augment import *
from fastai.vision.all import *
from fastai.vision.widgets import *

In [None]:
from utils.organize import *
from utils.misc import *

from preprocessing.transforms import *
# from preprocessing.dicom import *

In [None]:
# Hold-out fractions, both expressed relative to the whole labelled set
# (the validation fraction is rescaled after the test split is removed).
run_params.update(TEST_SIZE=0.15, VALID_SIZE=0.15)

In [None]:
run_params['BATCH_SIZE'] = 32

# Target size handed to KneeLocalizer / Resize in the transforms cell.
run_params['RESIZE'] = 512

# Histogram clipping (XRayPreprocess) and its lower/upper cut values.
run_params['HIST_CLIPPING'] = True
run_params['HIST_CLIPPING_CUT_MIN'] = 5.
run_params['HIST_CLIPPING_CUT_MAX'] = 99.

run_params['KNEE_LOCALIZER'] = True
# CLAHE_SCALED is checked before HIST_SCALED, so it wins when both are True.
run_params['CLAHE_SCALED'] = True
run_params['HIST_SCALED'] = False
# Only used with HIST_SCALED: True -> bins=None, False -> bins estimated from sampled images.
run_params['HIST_SCALED_SELF'] = True

# Collapse every target other than '0' into a single positive class '1'.
run_params['BINARY_CLASSIFICATION'] = True

# USE_SAVED_MODEL: rebuild the learner's model around a body loaded from the pretrained file.
run_params['USE_SAVED_MODEL'] = True
run_params['SAVE_MODEL'] = False

run_params['MODEL'] = resnet18
run_params['MODEL_VERSION'] = 0
run_params['MODEL_DESCRIPTION'] = f'SUP_sz{run_params["RESIZE"]}'
run_params['MODEL_SAVE_NAME'] = f'{run_params["MODEL"].__name__}_{run_params["MODEL_DESCRIPTION"]}_v{run_params["MODEL_VERSION"]}.pkl'
run_params['MODEL_SAVE_PATH'] = os.path.join(run_params['MODELS_FOLDER'], run_params['MODEL_SAVE_NAME'])

# The pretrained weights are loaded into a body built from run_params['MODEL'], so
# derive the file name from that architecture instead of hard-coding
# 'resnet18_v0.pkl' (the derived name is identical for resnet18 / version 0).
run_params['PRETRAINED_MODEL_VERSION'] = 0
run_params['PRETRAINED_MODEL_SAVE_PATH'] = os.path.join(
    run_params['MODELS_FOLDER'],
    f'{run_params["MODEL"].__name__}_v{run_params["PRETRAINED_MODEL_VERSION"]}.pkl',
)
# Kept for backward compatibility: the model-loading cell reads the full path from this key.
run_params['PRETRAINED_MODEL_SAVE_NAME'] = run_params['PRETRAINED_MODEL_SAVE_PATH']

In [None]:
# One seed for the whole run: fed to seed_everything here and reused as
# random_state for the train/valid/test splits.
run_params.update(SEED=42)
seed_everything(run_params['SEED'])

In [None]:
# Transformations
# Per-item transforms, applied in list order. The custom transforms here emit
# numpy (np_output=True), so np_input is True exactly when an earlier transform
# is already in the list.
item_tfms = []

if run_params['HIST_CLIPPING']:
    item_tfms.append(XRayPreprocess(
        PIL_cls=PILImageBW,
        cut_min=run_params['HIST_CLIPPING_CUT_MIN'],
        cut_max=run_params['HIST_CLIPPING_CUT_MAX'],
        np_input=bool(item_tfms),
        np_output=True,
    ))

# Either localise the knee (which also resizes), or zero-pad-resize the whole image.
size_tfm = (
    KneeLocalizer(
        run_params['KNEE_SVM_MODEL_PATH'],
        PIL_cls=PILImageBW,
        resize=run_params['RESIZE'],
        np_input=bool(item_tfms),
        np_output=True,
    )
    if run_params['KNEE_LOCALIZER']
    else Resize(run_params['RESIZE'], method=ResizeMethod.Pad, pad_mode=PadMode.Zeros)
)
item_tfms.append(size_tfm)

# Batch-level transforms: fastai's default augmentations, then Normalize().
batch_tfms = [*aug_transforms(), Normalize()]

In [None]:
# Labels spreadsheet: on Colab rebuild it from the organize templates and cache
# it as all.xlsx; elsewhere read that cached copy back.
# NOTE(review): only the non-Colab branch forces ID/Target to the 'string' dtype,
# and later cells compare Target against '0' — confirm concat_templates yields strings.
all_xlsx_path = os.path.join(run_params['PATH_PREFIX'], 'all.xlsx')
if not run_params['IN_COLAB']:
  df = pd.read_excel(all_xlsx_path, dtype={'ID':'string','Target':'string'})
else:
  df = concat_templates(run_params['ORGANIZE_FOLDER'], excel=True)
  df.to_excel(all_xlsx_path, index=False)

In [None]:
# Data
relation_df = pd.read_csv(os.path.join(run_params['PATH_PREFIX'], 'relation.csv'))
relation_df = relation_df.set_index('Filename')

# Attach relation.csv's columns to each ID; the inner join keeps only IDs
# present in both tables.
final_df = df.set_index('ID').merge(relation_df, left_index=True, right_index=True)
final_df['ID'] = final_df.index.values
final_df = final_df.reset_index(drop=True)
final_df['Raw_preprocess'] = final_df['Original_Filename'].apply(lambda filename: os.path.join(run_params['RAW_PREPROCESS_FOLDER'], filename + '.png'))

# Build the labelled/unlabelled mask from final_df itself. A mask taken from `df`
# is indexed by df's old row positions, which stop matching final_df's rows as soon
# as the inner join drops or reorders any ID (pandas then reindexes or raises).
has_target = final_df['Target'].notnull()
unlabel_df = final_df[~has_target].reset_index(drop=True)
label_df = final_df[has_target].reset_index(drop=True)
if run_params['BINARY_CLASSIFICATION']:
  # Any target other than '0' becomes the positive class '1'.
  label_df['Target'] = (label_df['Target'] != '0').astype(int).astype('string')


def _split_maybe_stratified(frame, test_size, seed):
  """Shuffle-split `frame`, stratified on its 'Target' column when possible.

  train_test_split raises ValueError when stratification is impossible
  (e.g. a class with fewer than two members); the split is then redone
  without stratification.

  Returns:
      The (kept, held_out) pair returned by train_test_split; both keep the
      original index labels of `frame`.
  """
  try:
    return train_test_split(frame, test_size=test_size, shuffle=True, stratify=frame['Target'], random_state=seed)
  except ValueError:
    return train_test_split(frame, test_size=test_size, shuffle=True, random_state=seed)


train_df, test_df = _split_maybe_stratified(label_df, run_params['TEST_SIZE'], run_params['SEED'])
# VALID_SIZE is a fraction of the whole labelled set; rescale it to what remains after the test split.
train_df, val_df = _split_maybe_stratified(train_df, run_params['VALID_SIZE'] / (1 - run_params['TEST_SIZE']), run_params['SEED'])

# The split frames keep label_df's index, so these write back into the matching rows.
label_df.loc[train_df.index, 'Dataset'] = 'train'
label_df.loc[val_df.index, 'Dataset'] = 'valid'
label_df.loc[test_df.index, 'Dataset'] = 'test'
assert label_df['Dataset'].notnull().all(), 'every labelled row must be assigned to a split'

In [None]:
# Intensity scaling applied on the fly, appended after the transforms built earlier.
# NOTE: this mutates `item_tfms`; re-running this cell without first re-running
# the transforms cell appends the scaling step a second time.
if run_params['CLAHE_SCALED']:
    item_tfms.append(
        CLAHE_Transform(PIL_cls=PILImageBW, grayscale=True, np_input=bool(item_tfms), np_output=False)
    )
elif run_params['HIST_SCALED']:
    # HIST_SCALED_SELF -> no precomputed bins; otherwise estimate bins from the
    # preprocessed PNGs (n_samples=100).
    bins = None if run_params['HIST_SCALED_SELF'] else init_bins(
        fnames=L(list(final_df['Raw_preprocess'].values)), n_samples=100, isDCM=False
    )
    item_tfms.append(HistScaled(bins))

In [None]:
# DataLoaders: 1-channel images read from RAW_PREPROCESS_FOLDER/<Original_Filename>.png,
# multi-label targets from 'Target', and the train/valid/test split from 'Dataset'.
label_data = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), MultiCategoryBlock),
    get_x=ColReader('Original_Filename', pref=run_params['RAW_PREPROCESS_FOLDER']+'/', suff='.png'),
    # MultiCategoryBlock iterates over each target, so a bare string like '10' would
    # be split into the characters '1' and '0'. With label_delim set, ColReader returns
    # a list (e.g. ['10']). Single-character labels ('0'/'1') behave exactly as before.
    # '|' is assumed never to appear inside a label.
    get_y=ColReader('Target', label_delim='|'),
    splitter=TestColSplitter(col='Dataset'),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
).dataloaders(label_df, bs=run_params['BATCH_SIZE'], num_workers=0)

label_data.show_batch()

In [None]:
import torch
from torch import nn

def focal_loss(input, target, reduction='mean', beta=0.5, gamma=2., eps=1e-7, **kwargs):
    """Sigmoid (multi-label) focal loss, averaged over the label dimension.

    Args:
        input: raw logits; dim 0 is the batch, all remaining dims are flattened
            into labels.
        target: 0/1 (or MixUp-blended) targets with the same number of elements
            as `input`.
        reduction: 'mean' or 'sum' over the batch; any other value (fastai uses
            'none') returns the per-sample losses.
        beta: weight of the positive term; ``1 - beta`` weights the negative term.
        gamma: focusing exponent that down-weights well-classified labels.
        eps: probabilities are clamped to ``[eps, 1 - eps]`` before taking logs.
        **kwargs: accepted and ignored.

    Returns:
        A scalar tensor for 'mean'/'sum', otherwise a ``(batch,)`` tensor.
    """
    n = input.size(0)
    # Work in float32: in half precision 1 - eps rounds to 1.0 and the clamp no
    # longer keeps log(1 - p) finite. `reshape` (unlike `view`) accepts
    # non-contiguous tensors.
    iflat = torch.sigmoid(input.float()).reshape(n, -1).clamp(eps, 1 - eps)
    tflat = target.float().reshape(n, -1)
    focal = -(beta * tflat * (1 - iflat).pow(gamma) * iflat.log()
              + (1 - beta) * (1 - tflat) * iflat.pow(gamma) * (1 - iflat).log()).mean(-1)
    if not torch.isfinite(focal).all():
        # Warn instead of dropping into a debugger, which would hang unattended
        # training; fastai's TerminateOnNaNCallback can stop a run on a NaN loss.
        import warnings
        warnings.warn('focal_loss produced non-finite values', RuntimeWarning)
    if reduction == 'mean':
        return focal.mean()
    elif reduction == 'sum':
        return focal.sum()
    else:
        return focal

class FocalLoss(nn.Module):
    """`nn.Module` wrapper around `focal_loss`, usable as a fastai `loss_func`.

    `activation` and `decodes` follow the hooks fastai looks up on the loss
    function in `Learner.get_preds` (and so in `show_results` / interpretation).
    They turn raw logits into probabilities and thresholded label predictions;
    without them those predictions would stay raw logits.
    """
    def __init__(self, beta=0.5, gamma=2., reduction='mean', thresh=0.5):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        # Presumably switched to 'none' temporarily by fastai (MixUp, get_preds with_loss).
        self.reduction = reduction
        # Probability threshold used by `decodes`.
        self.thresh = thresh

    def forward(self, input, target, **kwargs):
        return focal_loss(input, target, beta=self.beta, gamma=self.gamma, reduction=self.reduction, **kwargs)

    def activation(self, x):
        """Map logits to per-label probabilities."""
        return torch.sigmoid(x)

    def decodes(self, x):
        """Turn per-label probabilities into boolean predictions."""
        return x > self.thresh

In [None]:
# Callbacks used during training. fastai v2 takes them through `cbs`;
# `callback_fns` is the fastai v1 argument name and is not a parameter of
# v2's `cnn_learner` / `Learner`, so MixUp would never reach the training loop.
callback_fns = [
    MixUp(),
]
# Defined but not used as a metric: it cannot be computed when a step/epoch
# contains only one class.
roc_auc = RocAuc()
f1_score = F1ScoreMulti(average='macro')
precision = PrecisionMulti(average='macro')
recall = RecallMulti(average='macro')
# NOTE(review): cnn_learner builds run_params['MODEL'] with its default input
# channels while the images are PILImageBW; the USE_SAVED_MODEL cell rebuilds the
# model with n_in=1 — confirm the USE_SAVED_MODEL=False path works too.
learn = cnn_learner(
    label_data,
    run_params['MODEL'],
    loss_func=FocalLoss(),
    metrics=[
        accuracy_multi,
        f1_score,
        precision,
        recall,
    ],
    cbs=callback_fns,
)

# Train in mixed precision (fastai `to_fp16`).
learn = learn.to_fp16()

In [None]:
if run_params['USE_SAVED_MODEL']:
    # Size the head from the DataLoaders' vocab so its width always matches the
    # one-hot targets the loss receives.
    classes = learn.dls.vocab
    n_out = len(classes)

    body = create_model(run_params['MODEL'], n_out, pretrained=True, n_in=1, bn_final=True)

    # torch.cuda.current_device() fails when CUDA is unavailable; map to CPU instead.
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    # strict=False: keys missing from the checkpoint keep their current values.
    load_model(file=run_params['PRETRAINED_MODEL_SAVE_NAME'], model=body, opt=None, with_opt=False, device=device, strict=False)
    # Keep only the first sub-module (presumably the backbone); a fresh head is built below.
    body = body[0]

    # x2 for concat pooling (avg + max).
    # NOTE(review): newer fastai create_head doubles nf itself when concat_pool=True — confirm for the installed version.
    nf = num_features_model(nn.Sequential(*body.children())) * 2
    head = create_head(nf, n_out, concat_pool=True, bn_final=True)

    model = nn.Sequential(body, head)
    apply_init(model[1], nn.init.kaiming_normal_)

    learn.model = model

In [None]:
# Learning-rate finder, for reference only: the fine_tune call below uses a fixed base LR of 0.05.
learn.lr_find()

In [None]:
# 3 frozen epochs, then 10 more, with base learning rate 0.05.
learn.fine_tune(epochs=10, base_lr=0.05, freeze_epochs=3)

In [None]:
# Display up to 25 samples alongside the model's outputs.
learn.show_results(max_n=25)

In [None]:
# When enabled, write the trained model (plus learn.opt) to MODEL_SAVE_PATH,
# creating MODELS_FOLDER first if necessary.
if run_params['SAVE_MODEL']:
    os.makedirs(run_params['MODELS_FOLDER'], exist_ok=True)
    save_model(file=run_params['MODEL_SAVE_PATH'], model=learn.model, opt=learn.opt)

In [None]:
# Select only the top K images with largest loss

from fastai.interpret import ClassificationInterpretation

k = 9
largest = True
dls_idx = 1  # presumably the validation DataLoader

# with_decoded=True is required for the four-way unpacking: without it get_preds
# returns only (preds, targs, losses).
preds, targs, decoded, all_losses = learn.get_preds(dls_idx, with_decoded=True, with_loss=True)
# Clamp k so topk does not fail when the DataLoader holds fewer items than requested.
n_top = min(ifnone(k, len(all_losses)), len(all_losses))
losses, idx = all_losses.topk(n_top, largest=largest)

# One batch holding exactly the selected items, in descending-loss order.
top_losses_dl = learn.dls.test_dl(learn.dls[dls_idx].items.iloc[idx.numpy()])
top_losses_dl.bs = len(idx)

interp = ClassificationInterpretation(
    learn.dls[dls_idx],
    inputs=first(top_losses_dl),
    preds=preds[idx],
    targs=targs[idx],
    decoded=decoded[idx],
    losses=losses,
)
interp.plot_top_losses(k=k, cmap=plt.cm.bone)

In [None]:
# GradCAM for each top-loss image selected in the previous cell
# (same DataLoader `dls_idx`, same positions `idx`).

from fastai_amalgam.interpret.gradcam import gradcam

source_items = learn.dls[dls_idx].items
for row_pos in idx.tolist():
    image_path = source_items.iloc[row_pos]['Raw_preprocess']
    gcam = gradcam(learn, image_path, labels=['0', '1'], show_original=True, cmap=plt.cm.bone)
    display(gcam)
    print()

In [None]:
# GradCAM for every image in DataLoader 0 whose ground-truth Target is not '0'.
# These are labelled positives; the model's predictions are not checked, so
# they are not necessarily true positives.

from fastai_amalgam.interpret.gradcam import gradcam

dls_idx = 0
train_items = learn.dls[dls_idx].items
label_idxs = train_items.index[train_items['Target'] != '0']

for i in label_idxs:
    gcam = gradcam(learn, train_items.loc[i, 'Raw_preprocess'], labels=['0', '1'], show_original=True, cmap=plt.cm.bone)
    display(gcam)
    print()