In [None]:
#@title Define if we are on Colab and mount drive { display-mode: "form" }
# Detect Google Colab by importing its `drive` helper; when on Colab, also mount Google Drive.
try:
  from google.colab import drive
except ImportError:
  # google.colab is not installed: we are not running on Colab.
  IN_COLAB = False
else:
  # On Colab. A failing mount now raises instead of being mistaken for "not on Colab"
  # (the former bare `except:` also swallowed KeyboardInterrupt).
  drive.mount('/content/gdrive')
  IN_COLAB = True

In [None]:
#@title (COLAB ONLY) Clone GitHub repo { display-mode: "form" }
# On Colab, clone the project repository and cd into it. Presumably this is what makes the
# `%run *.ipynb` cells and the local modules imported below (preprocessing, utils, semisupervised) resolvable.
if IN_COLAB:
  !git clone https://github.com/lluissalord/radiology_ai.git

  %cd radiology_ai

In [None]:
#@title Setup environment and Colab general variables { display-mode: "form" }
# %%capture
# Runs colab_setup.ipynb; the globals it defines become available in this notebook.
# NOTE(review): names used later but never defined here (e.g. PATH_PREFIX) presumably come from this notebook — confirm.
%run colab_setup.ipynb

In [None]:
#@title Move images from Drive to temporary folder here to be able to train models { display-mode: "form" }
# %%capture
# Runs move_raw_preprocess.ipynb. NOTE(review): `raw_preprocess_folder`, used by later cells, is not
# defined in this notebook; presumably it comes from this or the setup notebook — confirm.
%run move_raw_preprocess.ipynb

In [None]:
import os

import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from fastai.basics import *
from fastai.callback.all import *
from fastai.data.block import *
from fastai.data.transforms import *
from fastai.vision import models
from fastai.vision.augment import *
from fastai.vision.core import PILImageBW, PILImage
from fastai.vision.data import *

In [None]:
# Whether to pretrain a SimCLR encoder (pl_bolts / pytorch_lightning) and use it as the model body.
SELF_SUPERVISED = True

if SELF_SUPERVISED:
    import pl_bolts  # not referenced directly; the names actually used are imported below
    from pl_bolts.models.self_supervised import SimCLR
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform
    from pytorch_lightning import Trainer

In [None]:
from preprocessing import init_bins, HistScaled

from utils import seed_everything, concat_templates, create_model, TestColSplitter

In [None]:
SSL_MIX_MATCH = 'MixMatch'
SSL_FIX_MATCH = 'FixMatch'

# Semi-supervised algorithm to train with; selects the matching loss and callback classes.
SSL = SSL_FIX_MATCH

if SSL == SSL_FIX_MATCH:
    from semisupervised.fixmatch.losses import FixMatchLoss as SSLLoss
    from semisupervised.fixmatch.callback import FixMatchCallback as SSLCallback
elif SSL == SSL_MIX_MATCH:
    from semisupervised.mixmatch.losses import MixMatchLoss as SSLLoss
    from semisupervised.mixmatch.callback import MixMatchCallback as SSLCallback
else:
    # Fail here instead of with a NameError on SSLLoss / SSLCallback several cells later.
    raise ValueError(f'Unknown SSL method {SSL!r}; expected {SSL_FIX_MATCH!r} or {SSL_MIX_MATCH!r}')

from semisupervised.ema import EMAModel

In [None]:
# Fractions of the labelled data held out for the test and the validation splits.
TEST_SIZE, VALID_SIZE = 0.15, 0.15

In [None]:
# Hyperparameters

# Histogram scaling of the input images (consumed by the histogram-scaling cell)
HIST_SCALED = False
HIST_SCALED_SELF = True

# Class-imbalance handling: class weights in the loss and/or a weighted training sampler.
# WEIGTHED_SAMPLER keeps its original spelling because later cells read it.
CLASS_WEIGHT = False
WEIGTHED_SAMPLER = True

LR = 0.002

# Item-level resize size, then the GPU RandomResizedCrop output size
RESIZE = 128
RANDOM_RESIZE_CROP = 256

SELF_SUPERVISED_BATCH_SIZE = 2

# Algorithm-specific settings; cb_params / loss_params are forwarded to SSLCallback / SSLLoss.
if SSL == SSL_MIX_MATCH:
    BATCH_SIZE = 16
    LAMBDA_U = 75
    T = 0.5        # given to the MixMatch callback (sharpening temperature in the MixMatch paper)
    ALPHA = 0.75   # MixUp alpha
    cb_params = dict(T=T)
    loss_params = dict(bs=BATCH_SIZE, lambda_u=LAMBDA_U)
elif SSL == SSL_FIX_MATCH:
    BATCH_SIZE = 8
    MOMENTUM = 0.9
    LAMBDA_U = 1
    MU = 5                  # unlabelled batches are MU times larger than labelled ones
    LABEL_THRESHOLD = 0.95  # presumably the pseudo-label confidence threshold of FixMatch
    cb_params = {}
    loss_params = dict(bs=BATCH_SIZE, mu=MU, lambda_u=LAMBDA_U, label_threshold=LABEL_THRESHOLD)

EMA_DECAY = 0.999

MODEL = models.resnet18
# MODEL = 'efficientnet-b0'

In [None]:
# Global seed; also reused as random_state for the train/valid/test splits.
SEED = 42

# Project helper from utils; presumably seeds the python/numpy/torch RNGs — confirm in utils.
seed_everything(SEED)

In [None]:
# Transformations

# Item level: resize every image to RESIZE with zero padding.
item_tfms = [
    Resize(RESIZE, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros),
    # RandomResizedCrop(RANDOM_RESIZE_CROP),
]

# Batch level (GPU) augmentation of the labelled batches.
label_transform = [RandomResizedCropGPU(RANDOM_RESIZE_CROP), Flip()]

# Normalize() and a x255 rescale transform were tried in these lists and are currently disabled.

# One entry per unlabelled DataLoader; the first one gets no batch transforms.
unlabel_batch_tfms = [None]
if SSL == SSL_FIX_MATCH:
    # FixMatch: a weakly and a strongly augmented view of each unlabelled image.
    weak_transform = [RandomResizedCropGPU(RANDOM_RESIZE_CROP), Flip()]
    strong_transform = [
        RandomResizedCropGPU(RANDOM_RESIZE_CROP),
        Flip(),
        Rotate(90),
        Brightness(),
        Contrast(),
        RandomErasing(),
    ]
    unlabel_batch_tfms += [weak_transform, strong_transform]
elif SSL == SSL_MIX_MATCH:
    unlabel_transform = [
        RandomResizedCropGPU(RANDOM_RESIZE_CROP),
        Flip(),
        Rotate(180, p=1),
    ]
    unlabel_batch_tfms.append(unlabel_transform)

In [None]:
# Callbacks shared by the training run; later cells append the SSL, scheduler and EMA callbacks.
from fastai.callback.tensorboard import TensorBoardCallback

# (the former `cbs = None` was a dead store, overwritten on the next line)
cbs = [
    TensorBoardCallback(),
]

In [None]:
# all.xlsx is the merged labelling template. To regenerate it:
# df = concat_templates(organize_folder, excel=True)
# df.to_excel(os.path.join(PATH_PREFIX, 'all.xlsx'), index=False)
# ID and Target are read with pandas' string dtype.
df = pd.read_excel(
    io=os.path.join(PATH_PREFIX, 'all.xlsx'),
    dtype={'ID': 'string', 'Target': 'string'},
)

In [None]:
# Data

def _split_maybe_stratified(data, test_size, seed):
    """Shuffled train_test_split of `data`, stratified on 'Target' when possible.

    sklearn raises ValueError when stratification is impossible (e.g. a class with a single
    sample); in that case fall back to a plain shuffled split with the same seed.
    """
    try:
        return train_test_split(data, test_size=test_size, shuffle=True, stratify=data['Target'], random_state=seed)
    except ValueError:
        return train_test_split(data, test_size=test_size, shuffle=True, random_state=seed)

# Load DataFrame of relation between Original Filename and ID (IMG_XXX)
relation_df = pd.read_csv(os.path.join(PATH_PREFIX, 'relation.csv')).set_index('Filename')

# Merge data to be able to load directly from preprocessed PNG file
final_df = df.set_index('ID').merge(relation_df, left_index=True, right_index=True)
final_df['ID'] = final_df.index.values
final_df = final_df.reset_index(drop=True)
final_df['Raw_preprocess'] = final_df['Original_Filename'].apply(lambda filename: os.path.join(raw_preprocess_folder, filename + '.png'))

# Load DataFrame containing labels of OOS classifier ('ap', 'other')
metadata_labels_path = os.path.join(PATH_PREFIX, 'metadata_labels.csv')
metadata_labels = pd.read_csv(metadata_labels_path).set_index('Path')

# Left-join every image listed in the metadata with the labelled data; keep rows without a Target.
unlabel_all_df = metadata_labels.merge(final_df.set_index('Raw_preprocess'), how='left', left_index=True, right_index=True)
# .copy(): the boolean filter returns a slice, and assigning a column to it is chained assignment
# (SettingWithCopyWarning) whose effect pandas does not guarantee.
unlabel_all_df = unlabel_all_df[unlabel_all_df.Target.isnull()].copy()
unlabel_all_df['Raw_preprocess'] = unlabel_all_df.index.values

# Define which column to use as the prediction
pred_col = 'Final_pred' if 'Final_pred' in unlabel_all_df.columns else 'Pred'

# Conditions for AP radiographies on unlabel data
ap_match = (unlabel_all_df[pred_col] == 'ap') & (unlabel_all_df.Incorrect_image.isnull())

# Split between label_df (labelled data), `unlabel_df` (containing only AP) and `unlabel_other_df` (with the rest of unlabel data)
label_df = final_df[final_df['Target'].notnull()].reset_index(drop=True)
unlabel_df = unlabel_all_df[ap_match].reset_index(drop=True)
unlabel_other_df = unlabel_all_df[~ap_match].reset_index(drop=True)

print(f'Currently {len(label_df.index)} data have been labelled')
print(f'Remaining {len(unlabel_df.index)} data to be labelled')
print(f'Discarded {len(unlabel_other_df.index)} data')

# Split between train, valid and test. The indices of the returned frames are label_df indices.
train_df, test_df = _split_maybe_stratified(label_df, TEST_SIZE, SEED)
# Rescale so that VALID_SIZE is a fraction of the whole labelled set, not of the remainder.
train_df, val_df = _split_maybe_stratified(train_df, VALID_SIZE / (1 - TEST_SIZE), SEED)

label_df.loc[train_df.index, 'Dataset'] = 'train'
label_df.loc[val_df.index, 'Dataset'] = 'valid'
label_df.loc[test_df.index, 'Dataset'] = 'test'
# Every labelled row must end up in exactly one split.
assert len(train_df) + len(val_df) + len(test_df) == len(label_df)
assert label_df['Dataset'].notnull().all()

print('\nSplit of labelled data is:')
display(label_df['Dataset'].value_counts())

# Order rows train -> valid -> test (the weighted-sampler cell relies on train rows coming first).
sort_dataset = {'train': 0, 'valid': 1, 'test': 2}
label_df = label_df.sort_values('Dataset', key=lambda x: x.map(sort_dataset)).reset_index(drop=True)

In [None]:
# Histogram scaling DICOM on the fly
# Optionally append a histogram-scaling item transform (skipped while HIST_SCALED is False).
if HIST_SCALED:
    if HIST_SCALED_SELF:
        # No shared bins are passed; presumably each image is scaled from its own histogram — confirm in preprocessing.
        bins = None
    else:
        # Shared bins estimated with init_bins (n_samples=100) over the PNG paths of all unlabelled
        # images (AP and non-AP) plus the labelled ones.
        # bins = init_bins(fnames=L(list(final_df['Original'].values)), n_samples=100)
        all_valid_raw_preprocess = pd.concat([pd.Series(unlabel_all_df.index), label_df['Raw_preprocess']])
        bins = init_bins(fnames=L(list(all_valid_raw_preprocess.values)), n_samples=100, isDCM=False)
    # item_tfms.append(HistScaled(bins))
    # NOTE(review): HistScaled_all is not imported in this notebook (only HistScaled is, from preprocessing);
    # unless a %run notebook defines it, this raises NameError when HIST_SCALED is True.
    # Appending mutates item_tfms, so re-running this cell adds the transform a second time.
    item_tfms.append(HistScaled_all(bins))

In [None]:
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import PIL.Image  # `import PIL` alone does not guarantee the PIL.Image submodule is loaded; this also binds `PIL`
from tqdm import tqdm

class MyDataset(Dataset):
    """Map-style dataset of the preprocessed PNGs for the SimCLR self-supervised step.

    `df` is split with train_test_split (default 80/20, fixed seed). The training part is kept when
    `validation` is False and the held-out part when it is True, so two instances built from the same
    `df` are disjoint. Items are `(sample, 0)`; the 0 is a dummy label.

    Args:
        df: DataFrame with an 'Original_Filename' column (file name without extension).
        validation: keep the held-out part instead of the training part.
        transform: optional callable applied to each RGB PIL image.
        src_folder: prefix (with trailing '/') joined to every filename.
        test_size: held-out fraction (default 0.20, the previous hard-coded value).
        random_state: split seed (default 42, the previous hard-coded value).
        cache_transform: if True (default, previous behaviour) the transform is applied once at
            construction and its output cached, so random augmentations are identical every epoch.
            If False, only the RGB images are cached and the transform runs on every access.
    """
    def __init__(self, df, validation = False, transform=None, src_folder=raw_preprocess_folder+'/',
                 test_size=0.20, random_state=42, cache_transform=True):
        suffix = '.png'
        self.transform = transform
        self.cache_transform = cache_transform

        # Use sklearn's module to split into the training part and the held-out part
        train_part, held_out_part = train_test_split(df, test_size=test_size, random_state=random_state)
        self.df = held_out_part if validation else train_part

        # Decoding every image up front is slow at construction but keeps __getitem__ cheap.
        self.image_pairs = []
        for filename in tqdm(self.df['Original_Filename'], total=len(self.df.index)):
            # The context manager closes the file handle; convert() returns an in-memory image.
            with PIL.Image.open(src_folder + filename + suffix) as img:
                sample = img.convert('RGB')
            if self.transform and self.cache_transform:
                sample = self.transform(sample)  # the SimCLR transform produces the augmented views
            self.image_pairs.append(sample)

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        sample = self.image_pairs[idx]
        if self.transform and not self.cache_transform:
            sample = self.transform(sample)
        return (sample, 0)

In [None]:
# Self-supervised (SimCLR) DataLoaders built from plain PyTorch datasets.
# (A fastai DataBlock-based variant, RandomSplitter(0.2) with SimCLRTrainDataTransform as batch_tfms,
# was tried before and is not used.)

# Rows of final_df used for SimCLR pretraining; set to None to use every row.
SELF_SUPERVISED_N_SAMPLES = 100

self_sup_df = final_df if SELF_SUPERVISED_N_SAMPLES is None else final_df[:SELF_SUPERVISED_N_SAMPLES]

# Both datasets split the same rows with the same seed, so training and validation items are disjoint.
dataset = MyDataset(self_sup_df, validation = False, transform = SimCLRTrainDataTransform(RESIZE))
val_dataset = MyDataset(self_sup_df, validation = True, transform = SimCLREvalDataTransform(RESIZE))

# shuffle=True: the contrastive loss takes its negatives from the rest of the batch, so training
# batches should not have the same composition every epoch (they previously never shuffled).
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=SELF_SUPERVISED_BATCH_SIZE,
                                          shuffle=True,
                                          num_workers=0)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=SELF_SUPERVISED_BATCH_SIZE,
                                         num_workers=0)

In [None]:
num_samples = len(dataset)

# SimCLR pretraining epochs. SimCLR also receives max_epochs (presumably for its LR schedule), so it
# must match the Trainer's; previously only the model got it and the Trainer used its own default.
SELF_SUPERVISED_EPOCHS = 1
# Use a GPU only when one is available, so CPU-only runtimes do not fail.
SELF_SUPERVISED_GPUS = 1 if torch.cuda.is_available() else 0

# SimCLR module: batch size and dataset size are given explicitly because the data is not a bolts dataset.
model_self_sup = SimCLR(gpus=SELF_SUPERVISED_GPUS, max_epochs=SELF_SUPERVISED_EPOCHS, arch='resnet50', dataset='',
                        batch_size=SELF_SUPERVISED_BATCH_SIZE, num_samples=num_samples)

trainer = Trainer(gpus=SELF_SUPERVISED_GPUS, max_epochs=SELF_SUPERVISED_EPOCHS)
try:
  trainer.fit(model_self_sup, data_loader, val_loader)
except IndexError as e:
  # NOTE(review): kept from the original — an IndexError ends training early. Confirm what raises it;
  # this can also hide genuine indexing bugs.
  print('Finish training due to IndexError: ', e)

In [None]:
# Parameters shared by the labelled and unlabelled DataBlocks: read the PNG path from 'Raw_preprocess'.
# (Alternatives tried before: ColReader on 'Original_Filename' or 'Original', PILImageBW / PILDicom_scaled
# blocks, MultiCategoryBlock targets.)
base_ds_params = {
    'get_x': ColReader('Raw_preprocess'),
    'item_tfms': item_tfms,
}

# Labelled data: RGB image -> single category, split by the 'Dataset' column.
label_ds_params = {
    **base_ds_params,
    'blocks': (ImageBlock(cls=PILImage), CategoryBlock),
    'get_y': ColReader('Target'),
    'splitter': TestColSplitter(col='Dataset'),
    'batch_tfms': label_transform,
}

# Unlabelled data: a single image block (the original parentheses did not build a tuple);
# RandomSplitter(0) leaves the validation split empty.
unlabel_ds_params = {
    **base_ds_params,
    'blocks': ImageBlock(cls=PILImage),
    'splitter': RandomSplitter(0),
}

dls_params = dict(bs=BATCH_SIZE, num_workers=0, shuffle_train=True, drop_last=True)

# FixMatch draws MU unlabelled images per labelled one.
unlabel_dls_params = {**dls_params, 'bs': BATCH_SIZE * MU} if SSL == SSL_FIX_MATCH else dict(dls_params)

In [None]:
# DataLoaders
print(f'==> Preparing label dataloaders')

label_dl = DataBlock(**label_ds_params).dataloaders(label_df, **dls_params)

if WEIGTHED_SAMPLER:
    # Balance classes in the training DataLoader: each sample is weighted 1 / (count of its class).
    from collections import Counter

    count = Counter(label_dl.items['Target'])
    class_weights = {label: 1 / count[label] for label in count}
    # NOTE(review): slicing to len(train_df) assumes training rows come first in label_dl.items.
    wgts = label_dl.items['Target'].map(class_weights).values[:len(train_df)]

    # Build a weighted copy and keep only its training DataLoader.
    weighted_dl = DataBlock(**label_ds_params).dataloaders(label_df, **dls_params, dl_type=WeightedDL, wgts=wgts)
    label_dl.train = weighted_dl.train

print(f'==> Preparing unlabel dataloaders')

# One unlabelled DataLoaders per entry of unlabel_batch_tfms (None = no batch transforms).
unlabel_dls = [
    DataBlock(**unlabel_ds_params, batch_tfms=tfms).dataloaders(unlabel_df, **unlabel_dls_params)
    for tfms in unlabel_batch_tfms
]
print(f'==> Preparing SSL callback')

ssl_cb = SSLCallback(*unlabel_dls, **cb_params)
if cbs is not None:
    cbs.append(ssl_cb)
else:
    cbs = [ssl_cb]

# MixMatch additionally mixes samples with MixUp.
if SSL == SSL_MIX_MATCH:
    cbs.append(MixUp(alpha=ALPHA))

In [None]:
# Scheduling / optimiser
from functools import partial
import math

if SSL == SSL_FIX_MATCH:
    # Cosine LR decay from LR down to LR*cos(7*pi/16).
    # NOTE(review): fine_tune/fit_one_cycle also schedule the LR; two lr schedulers will conflict.
    sched = {'lr': SchedCos(LR, LR*math.cos(7*math.pi/16))}
    cbs.append(ParamScheduler(sched))
    # 0.9 according to FixMatch paper. fastai's one-cycle training expects `moms` as a 3-tuple;
    # the former `(MOMENTUM)` was just a float. Pass `moms=moms` to keep momentum constant.
    moms = (MOMENTUM, MOMENTUM, MOMENTUM)
    # fastai's SGD defaults to mom=0, so bind the momentum explicitly or it is never applied.
    opt_func = partial(SGD, mom=MOMENTUM)
else:
    opt_func = Adam

In [None]:
# Model
# from fastai.vision.learner import create_head
from fastai.layers import *

print("==> creating model")

classes = label_df['Target'].unique()
n_out = len(classes)

if SELF_SUPERVISED:
    concat_pool = True
    # Read out_features of the 4th sub-layer inside the SimCLR module's second child.
    # NOTE(review): `nf` is not used below — the head input size is hard-coded to 512. Confirm this matches
    # the feature size the SimCLR forward pass returns (a resnet50 encoder was requested).
    for i, layer_block in enumerate(model_self_sup.children()):
      if i == 1:
        for layer in layer_block.children():
          for j, layer_ in enumerate(layer.children()):
            if j == 3:
              nf = layer_.out_features
    # nf = num_features_model(nn.Sequential(*model_self_sup.children())) * (2 if concat_pool else 1)
    # head = create_head(nf, n_out, lin_ftrs=[512], ps=0.5, concat_pool=concat_pool, bn_final=True)
    layers = [
        nn.Dropout(p=0.5),
        nn.Linear(512, n_out),
        nn.BatchNorm1d(n_out, momentum=0.01)
    ]
    head = nn.Sequential(*layers)
    model = nn.Sequential(model_self_sup, head)
else:
    model = create_model(MODEL, n_out, pretrained=True, n_in=1, bn_final=True)

# Initialise the last BatchNorm bias with the log class priors, so the initial softmax reflects the
# training class frequencies. The priors must follow the model's output order (the DataLoaders vocab);
# value_counts() orders by frequency, which generally differs, so reindex by the vocab.
class_priors = train_df['Target'].value_counts(normalize=True).reindex(list(label_dl.vocab))
with torch.no_grad():
    for name, param in model[-1][-1].named_parameters():
        if 'bias' in name:
            param.copy_(torch.as_tensor(np.log(class_priors.values)))

if SSL == SSL_MIX_MATCH:
    loss_params['model'] = model

cbs.append(EMAModel(alpha=EMA_DECAY))

In [None]:
# Loss
print("==> defining loss")

if CLASS_WEIGHT:
    # 'balanced' weights, ordered like the model outputs (the DataLoaders vocab). `classes` above is in
    # order of first appearance, so using it misaligned the weights with the output columns.
    vocab_classes = np.array(list(label_dl.vocab))
    class_weight = compute_class_weight(class_weight='balanced', classes=vocab_classes, y=train_df['Target'])
    class_weight = torch.as_tensor(class_weight).float()
    if torch.cuda.is_available():
        class_weight = class_weight.cuda()
else:
    class_weight = None

# Semi-supervised loss; its labelled term (Lx_criterion) is also used on its own as the criterion.
train_criterion = SSLLoss(unlabel_dl=unlabel_dls[0], n_out=n_out, weight=class_weight, **loss_params)
criterion = train_criterion.Lx_criterion

In [None]:
# Learner
print("==> defining learner")

# Labelled / unlabelled loss terms wrapped as metrics.
# NOTE(review): neither is included in `metrics` below, so they are not reported.
Lx_metric = AvgMetric(func=criterion)
Lu_metric = AvgMetric(func=train_criterion.Lu_criterion)

# Binary problems use binary averaging; multi-class problems use macro averaging.
average, roc_auc = ('binary', RocAucBinary()) if n_out == 2 else ('macro', RocAuc())

metrics = [
    error_rate,
    BalancedAccuracy(),
    # roc_auc,
    FBeta(0.5, average=average),
    F1Score(average=average),
    FBeta(2, average=average),
    Precision(average=average),
    Recall(average=average),
]

learn = Learner(dls=label_dl, model=model, loss_func=train_criterion, opt_func=opt_func,
                lr=LR, metrics=metrics, cbs=cbs)
# Also compute the metrics on the training set.
learn.recorder.train_metrics = True

In [None]:
# LR finder with the model frozen.
# NOTE(review): the Learner has no `splitter`, so with fastai's default there may be a single parameter
# group and freeze() would leave everything trainable — confirm.
learn.freeze()
learn.lr_find()

In [None]:
# LR finder again with every parameter trainable.
learn.unfreeze()
learn.lr_find()

In [None]:
# 1 frozen epoch, then 10 epochs at base LR 0.01 (note: not the LR constant from the config cell).
learn.fine_tune(epochs=10, base_lr=0.01, freeze_epochs=1)

In [None]:
# Show predictions on the validation set (ds_idx=1).
learn.show_results(ds_idx=1)

In [None]:
# Inputs, predictions, targets, decoded predictions and per-item losses on the validation DataLoader (dls[1]).
# act=None leaves the activation choice to fastai's default. NOTE(review): with_loss needs a loss that supports
# per-item reduction — confirm SSLLoss does.
inputs,preds,targs,decoded,losses = learn.get_preds(dl=learn.dls[1], with_input=True, with_loss=True, with_decoded=True, act=None)

In [None]:
from fastai.interpret import ClassificationInterpretation

# Interpretation on the validation set (ds_idx=1); plot the 8 samples with the highest loss.
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)
interp.plot_top_losses(k=8)

In [None]:
# Step-by-step copy of fastai's plot_top_losses, used to debug what gets plotted.
k = 8
largest = True
# Use the variables above (they were defined but the call repeated the literals).
losses, idx = interp.top_losses(k=k, largest=largest)
if not isinstance(interp.inputs, tuple): interp.inputs = (interp.inputs,)
if isinstance(interp.inputs[0], Tensor): inps = tuple(o[idx] for o in interp.inputs)
else: inps = interp.dl.create_batch(interp.dl.before_batch([tuple(o[i] for o in interp.inputs) for i in idx]))
# Batch of the top-loss items: inputs followed by targets.
b = inps + tuple(o[idx] for o in (interp.targs if is_listy(interp.targs) else (interp.targs,)))
x,y,its = interp.dl._pre_show_batch(b, max_n=k)
# NOTE(review): with b_out = b, `outs` is built from the targets rather than from interp.decoded
# (the decoded variant is commented out below). Presumably deliberate while debugging — confirm.
# b_out = inps + tuple(o[idx] for o in (interp.decoded if is_listy(interp.decoded) else (interp.decoded,)))
b_out = b
x1,y1,outs = interp.dl._pre_show_batch(b_out, max_n=k)
if its is not None:
    plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), interp.preds[idx], losses)

In [None]:
# Element 1 of b_out: with a single input tensor, these are the targets of the top-loss items.
b_out[1]

In [None]:
# Indices (into the validation set) of the top-k losses.
idx

In [None]:
# Print the decoded predictions of the top-loss items (what the commented-out b_out variant would use).
a = (interp.decoded if is_listy(interp.decoded) else (interp.decoded,))
for o in a:
    print(o[idx])

In [None]:
# Batch tuple of the top-loss items: inputs followed by targets.
b

In [None]:
# Tuple used to build `outs` for plotting; currently the same object as b.
b_out