In [None]:
#@title Define if we are on Colab and mount drive { display-mode: "form" }
# Detect Google Colab by attempting the Colab-only import.
# Only a failed import means "not on Colab". A failure while mounting Drive
# (or a user interrupt) is no longer swallowed, because the following cells
# could not work without the mounted drive anyway.
try:
  from google.colab import drive
except ImportError:
  IN_COLAB = False
else:
  drive.mount('/content/gdrive')
  IN_COLAB = True

In [None]:
#@title (COLAB ONLY) Clone GitHub repo { display-mode: "form" }

if IN_COLAB:
  !git clone https://github.com/lluissalord/radiology_ai.git

  %cd radiology_ai

In [None]:
#@title Setup environment and Colab general variables { display-mode: "form" }
# %%capture
# Executes the companion notebook colab_setup.ipynb through the %run magic.
# NOTE(review): PATH_PREFIX and raw_preprocess_folder are used further down
# but never assigned in this notebook; presumably they are defined by this
# notebook or by move_raw_preprocess.ipynb -- confirm.
%run colab_setup.ipynb

In [None]:
#@title Move images from Drive to temporary folder here to be able to train models { display-mode: "form" }
# %%capture
# Executes the companion notebook move_raw_preprocess.ipynb through %run.
# NOTE(review): per the cell title this copies the images from Drive to a
# local temporary folder; what it actually does is not visible here -- confirm.
%run move_raw_preprocess.ipynb

In [None]:
import os

import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from fastai.basics import *
from fastai.callback.all import *
from fastai.data.block import *
from fastai.data.transforms import *
from fastai.vision import models
from fastai.vision.augment import *
from fastai.vision.core import PILImageBW
from fastai.vision.data import *
from fastai.vision.learner import create_cnn_model

from efficientnet_pytorch import EfficientNet

In [None]:
from semisupervised.fixmatch.losses import FixMatchLoss
from semisupervised.fixmatch.callback import FixMatchCallback
from semisupervised.ema import EMAModel

# Required to load DICOM on the fly
from preprocessing import PILDicom_scaled, init_bins, HistScaled

from utils import seed_everything, concat_templates, create_model, TestColSplitter

In [None]:
# Fraction of the labelled rows held out as the test split.
TEST_SIZE = 0.15
# Fraction of the labelled rows held out as the validation split. The split
# code divides it by (1 - TEST_SIZE) because validation rows are carved out
# of what remains after the test rows were removed.
VALID_SIZE = 0.15

In [None]:
# Hyperparameters

# When True, a histogram-scaling transform is appended to item_tfms.
HIST_SCALED = False
# Only read when HIST_SCALED is True: True passes bins=None to the transform,
# False computes the bins with init_bins from the preprocessed images.
HIST_SCALED_SELF = True

# Learning rate given to the Learner and used as the start of the cosine schedule.
LR = 0.002
# Only assigned to `moms` in the scheduling cell; it is not passed to the
# optimizer or the Learner anywhere in this notebook.
MOMENTUM = 0.9

# Batch size of the labelled DataLoaders.
BATCH_SIZE = 8
# Size given to the item-level Resize (padding with zeros).
RESIZE = 512
# Size given to RandomResizedCropGPU in the batch transforms.
RANDOM_RESIZE_CROP = 256

# Unlabelled batches hold BATCH_SIZE * MU items; also forwarded to FixMatchLoss.
MU = 7
# Forwarded to FixMatchLoss as label_threshold (presumably the confidence
# needed to accept a pseudo-label -- confirm in semisupervised.fixmatch.losses).
LABEL_THRESHOLD = 0.95

# Forwarded to FixMatchLoss as lambda_u (presumably the weight of the
# unlabelled loss term -- confirm).
LAMBDA_U = 1
# Forwarded to EMAModel as alpha.
EMA_DECAY = 0.999

# Architecture handed to create_model; the commented line is an alternative.
MODEL = models.resnet18
# MODEL = 'efficientnet-b0'

In [None]:
# Seed shared by seed_everything and by the train/valid/test splits
# (passed as random_state to train_test_split).
SEED = 42

# Project helper imported from utils; presumably seeds every RNG in use -- confirm.
seed_everything(SEED)

In [None]:
# Transformations

# Item-level transforms shared by the labelled and unlabelled DataBlocks
# (passed as 'item_tfms'). Images are resized to RESIZE, padding with zeros.
item_tfms = [
    Resize(RESIZE, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros),
    # RandomResizedCrop(RANDOM_RESIZE_CROP),
]

# Batch-level transforms for the labelled DataBlock (passed as 'batch_tfms').
label_transform = [
    RandomResizedCropGPU(RANDOM_RESIZE_CROP),
    Flip(),
    # Normalize()
]

class Multiply_255(Transform):
    """Transform whose encoding step multiplies its input by 255.

    Only referenced from commented-out entries of the transform lists in
    this notebook.
    """

    def encodes(self, o):
        # `o` is deliberately left unannotated, exactly as in the original.
        # NOTE(review): fastai presumably dispatches `encodes` on annotations,
        # so adding one could change which inputs are transformed.
        scaled = o * 255
        return scaled

# Batch transforms for the "weak" view of the unlabelled images
# (used as batch_tfms of weak_transform_dl).
weak_transform = [
    RandomResizedCropGPU(RANDOM_RESIZE_CROP),
    Flip(),
    # Multiply_255(),
    # Normalize()
]

# Batch transforms for the "strong" view of the unlabelled images
# (used as batch_tfms of strong_transform_dl): the weak pipeline plus
# rotation, brightness/contrast changes and random erasing.
strong_transform = [
    RandomResizedCropGPU(RANDOM_RESIZE_CROP),
    Flip(),
    Rotate(90),  # presumably the maximum rotation in degrees -- confirm
    Brightness(),
    Contrast(),
    RandomErasing(),
    # Multiply_255(),
    # Normalize()
]

In [None]:
# Callbacks
from fastai.callback.tensorboard import TensorBoardCallback

# `cbs = None` is the "no extra callbacks" default and is immediately
# overridden by the list below. The dataloader cell checks for None before
# appending the FixMatch callback.
cbs = None
cbs = [
    TensorBoardCallback(),
    MixUp,  # passed as the class itself, not as an instance
]

In [None]:
# Disabled: rebuild all.xlsx from the templates and write it under PATH_PREFIX.
# df = concat_templates(organize_folder, excel=True)
# df.to_excel(
#     os.path.join(PATH_PREFIX, 'all.xlsx'),
#     index=False
# )
# Load the spreadsheet, reading the ID and Target columns with the pandas
# 'string' dtype. PATH_PREFIX is not assigned in this notebook (presumably it
# comes from one of the %run setup notebooks -- confirm).
df = pd.read_excel(os.path.join(PATH_PREFIX, 'all.xlsx'), dtype={'ID':'string','Target':'string'})

In [None]:
# Data

def _split_with_stratify_fallback(frame, test_size):
    """Split `frame` in two, stratifying on its 'Target' column when possible.

    Args:
        frame: DataFrame with a 'Target' column.
        test_size: fraction of `frame` that goes to the second part.

    Returns:
        The two DataFrames returned by train_test_split (rows keep their
        index from `frame`).

    If the stratified split raises ValueError (for example because a class
    has too few rows), the split is retried without stratification and a
    message is printed so the fallback is not silent.
    """
    split_kwargs = dict(test_size=test_size, shuffle=True, random_state=SEED)
    try:
        return train_test_split(frame, stratify=frame['Target'], **split_kwargs)
    except ValueError:
        print('==> Stratified split not possible, falling back to a non-stratified split')
        return train_test_split(frame, **split_kwargs)


relation_df = pd.read_csv(os.path.join(PATH_PREFIX, 'relation.csv'))
relation_df = relation_df.set_index('Filename')

# Inner join of the labelling spreadsheet (indexed by ID) with the relation
# table (indexed by Filename); the joined key is kept as an 'ID' column.
final_df = df.set_index('ID').merge(relation_df, left_index=True, right_index=True)
final_df['ID'] = final_df.index.values
final_df = final_df.reset_index(drop=True)
final_df['Raw_preprocess'] = final_df['Original_Filename'].apply(lambda filename: os.path.join(raw_preprocess_folder, filename + '.png'))

# Build the mask from final_df itself. The mask used to come from `df`, whose
# index does not necessarily match final_df after the merge and reset_index,
# so rows could be misassigned or pandas could refuse the unaligned mask.
has_target = final_df['Target'].notnull()
unlabel_df = final_df[~has_target].reset_index(drop=True)
label_df = final_df[has_target].reset_index(drop=True)

# First carve out the test rows, then the validation rows from the remainder.
# VALID_SIZE is rescaled because it is expressed relative to all labelled rows.
train_df, test_df = _split_with_stratify_fallback(label_df, TEST_SIZE)
train_df, val_df = _split_with_stratify_fallback(train_df, VALID_SIZE/(1-TEST_SIZE))

# The split frames keep label_df's index, so it can be used to tag each row.
label_df.loc[train_df.index, 'Dataset'] = 'train'
label_df.loc[val_df.index, 'Dataset'] = 'valid'
label_df.loc[test_df.index, 'Dataset'] = 'test'

# Order the rows train -> valid -> test.
sort_dataset = {'train': 0, 'valid': 1, 'test': 2}
label_df = label_df.sort_values('Dataset', key=lambda x: x.map(sort_dataset)).reset_index(drop=True)

In [None]:
# Histogram scaling DICOM on the fly

if HIST_SCALED:
    # HIST_SCALED_SELF selects bins=None; otherwise the bins are computed
    # from the preprocessed PNGs (an earlier variant used the 'Original'
    # column instead of 'Raw_preprocess').
    bins = None if HIST_SCALED_SELF else init_bins(
        fnames=L(list(final_df['Raw_preprocess'].values)),
        n_samples=100,
        isDCM=False,
    )
    # NOTE(review): HistScaled_all is not among the names imported from
    # `preprocessing` in this notebook (only HistScaled is) -- confirm where
    # it is defined, otherwise enabling HIST_SCALED raises NameError.
    # Re-running this cell appends the transform to item_tfms again.
    item_tfms.append(HistScaled_all(bins))

In [None]:
# Arguments shared by the labelled and unlabelled DataBlocks: images are read
# from the preprocessed PNG folder (an alternative reader on the 'Original'
# column was used for raw DICOM files).
base_ds_params = dict(
    get_x=ColReader('Original_Filename', pref=raw_preprocess_folder+'/', suff='.png'),
    item_tfms=item_tfms,
)

# Labelled data: black-and-white image -> multi-category target, split by the
# 'Dataset' column. (PILDicom_scaled can replace PILImageBW for DICOM input.)
label_ds_params = {
    **base_ds_params,
    'blocks': (ImageBlock(cls=PILImageBW), MultiCategoryBlock),
    'get_y': ColReader('Target'),
    'splitter': TestColSplitter(col='Dataset'),
    'batch_tfms': label_transform,
}

# Unlabelled data: a single image block (not a tuple) and a random splitter
# with a validation fraction of 0.
unlabel_ds_params = {
    **base_ds_params,
    'blocks': ImageBlock(cls=PILImageBW),
    'splitter': RandomSplitter(0),
}

# DataLoader arguments for the labelled data.
dls_params = dict(
    bs=BATCH_SIZE,
    num_workers=0,
    shuffle_train=True,
    drop_last=True,
)

# Unlabelled batches are MU times larger than labelled ones.
unlabel_dls_params = {**dls_params, 'bs': BATCH_SIZE * MU}

In [None]:
# DataLoaders
print('==> Preparing label dataloaders')

label_dl = DataBlock(**label_ds_params).dataloaders(label_df, **dls_params)

# Calculate sample weights to balance the DataLoader: every sample is weighted
# by the inverse frequency of its 'Target' value among label_dl.items.
from collections import Counter

count = Counter(label_dl.items['Target'])
class_weights = {target: 1/n_samples for target, n_samples in count.items()}
wgts = label_dl.items['Target'].map(class_weights).values[:len(train_df)]

# Create weighted dataloader and use its training loader in place of the
# unweighted one; the other loaders of label_dl are left untouched.
weighted_dl = DataBlock(**label_ds_params).dataloaders(label_df, **dls_params, dl_type=WeightedDL, wgts=wgts)
label_dl.train = weighted_dl.train

print('==> Preparing unlabel dataloaders')

# Three loaders over the same unlabelled rows: untransformed, weakly
# augmented and strongly augmented.
unlabel_dl = DataBlock(
    **unlabel_ds_params,
).dataloaders(unlabel_df, **unlabel_dls_params)

weak_transform_dl = DataBlock(
    **unlabel_ds_params,
    batch_tfms = weak_transform
).dataloaders(unlabel_df, **unlabel_dls_params)

strong_transform_dl = DataBlock(
    **unlabel_ds_params,
    batch_tfms = strong_transform
).dataloaders(unlabel_df, **unlabel_dls_params)

# The message used to say "MixMatch" although the callback built here is
# FixMatchCallback.
print('==> Preparing FixMatch callback')

fix_match_cb = FixMatchCallback(unlabel_dl, weak_transform_dl, strong_transform_dl)
# cbs may be None when the callbacks cell is configured without extra callbacks.
if cbs is None:
    cbs = [fix_match_cb]
else:
    cbs.append(fix_match_cb)

In [None]:
# Scheduling
# Cosine schedule for 'lr' built from LR and LR*cos(7*pi/16), registered as a
# ParamScheduler callback. `math` is not imported explicitly in this notebook
# (presumably it arrives through the fastai star imports -- confirm).
sched = {'lr': SchedCos(LR, LR*math.cos(7*math.pi/16))}
cbs.append(ParamScheduler(sched))
# NOTE(review): parentheses without a trailing comma do not make a tuple, so
# `moms` is the plain float MOMENTUM; it is also not referenced again in this
# notebook, so the momentum setting currently has no effect -- confirm intent.
moms = (MOMENTUM) # 0.9 according to FixMatch paper

In [None]:
# Model
print("==> creating model")

# One output per distinct 'Target' value found in the labelled rows.
classes = label_df['Target'].unique()
n_out = len(classes)

# create_model is a project helper from utils; n_in=1 is passed for the
# single-channel (PILImageBW) inputs used by the DataBlocks.
model = create_model(MODEL, n_out, pretrained=True, n_in=1)

# Register the EMAModel callback with alpha=EMA_DECAY. Re-running this cell
# appends another instance to cbs.
cbs.append(EMAModel(alpha=EMA_DECAY))

In [None]:
# Loss
print("==> defining loss")

# Class weighting in the loss is disabled because the training DataLoader is
# already balanced through sample weights; the commented block shows how the
# weights would be computed.
# DO NOT USE CLASS WEIGHT IF WEIGHTED SAMPLER IS BEING USED
# class_weight = compute_class_weight(class_weight='balanced', classes=classes, y=train_df['Target'])
# class_weight = torch.as_tensor(class_weight).float()
# if torch.cuda.is_available():
#     class_weight = class_weight.cuda()
class_weight = None

# FixMatchLoss is the training loss; its Lx_criterion attribute is kept
# separately as `criterion` (presumably the labelled-data term -- confirm).
train_criterion = FixMatchLoss(unlabel_dl=unlabel_dl, n_out=n_out, bs=BATCH_SIZE, mu=MU, lambda_u=LAMBDA_U, label_threshold=LABEL_THRESHOLD, weight=class_weight)
criterion = train_criterion.Lx_criterion

In [None]:
# Learner
print("==> defining learner")

# Averaged versions of the two loss components. They are built here but are
# currently left out of the metric list handed to the Learner.
Lx_metric = AvgMetric(func=criterion)
Lu_metric = AvgMetric(func=train_criterion.Lu_criterion)

# Macro-averaged multi-label F1, precision and recall, created in that order.
f1_score, precision, recall = (
    metric_cls(average='macro')
    for metric_cls in (F1ScoreMulti, PrecisionMulti, RecallMulti)
)
fastai_metrics = [f1_score, precision, recall]

# SGD-based Learner trained with the FixMatch loss on the labelled loaders;
# the callbacks collected in `cbs` are attached here.
learn = Learner(
    label_dl,
    model,
    loss_func=train_criterion,
    opt_func=SGD,
    lr=LR,
    metrics=fastai_metrics,
    cbs=cbs,
)

In [None]:
# Learning-rate finder after calling freeze().
# NOTE(review): the Learner is created without a splitter argument, so how
# many parameters freeze() actually affects is not visible here -- confirm.
learn.freeze()
learn.lr_find()

In [None]:
# Learning-rate finder again, this time after unfreeze().
learn.unfreeze()
learn.lr_find()

In [None]:
# Train with fine_tune: 5 epochs, preceded by 1 frozen epoch (freeze_epochs=1).
# NOTE(review): the learning rate 0.001 is hard-coded and differs from
# LR (0.002) defined in the hyperparameter cell -- confirm this is intended.
learn.fine_tune(5, 0.001, freeze_epochs=1)

In [None]:
from fastai.interpret import ClassificationInterpretation

# Build the interpretation object from the trained learner for dataset
# index 1 (presumably the validation split -- confirm).
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)
# Plot the 8 samples requested through k=8.
interp.plot_top_losses(k=8)