In [None]:
#export

import copy
import math
import multiprocessing as mp
import os
import random
import time
import warnings
from collections import Counter
from pathlib import Path

import matplotlib.image as plt_image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as NN
from sklearn.metrics import roc_auc_score, roc_curve, auc, accuracy_score
from sklearn.model_selection import train_test_split
from torch.optim import Optimizer
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.io import read_image
from tqdm.notebook import tqdm

Path.ls = lambda p: list(p.iterdir())

In [None]:
#export
# Label string the dataset uses for images without any finding; it is skipped
# when building the reduced label matrix and used as the negative display label.
NO_FINDING = "No Finding"
PATHOLOGIC = "pathologic"
# Per-channel statistics passed to transforms.Normalize in get_transforms().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Scalar statistics used to undo normalisation in the batch plotting helpers.
# NOTE(review): presumably computed on a subset of the images (see calc_stats) — confirm.
SUBSET_MEAN = 0.50589985
SUBSET_STD = 0.23221017

# Module-level cache for the label table; populated on first use by load_label_df().
LABEL_DF = None

In [None]:
#export
def seed_everything(seed=92):
    """Seed the Python, NumPy and PyTorch random number generators.

    The previous version wrapped every call in a bare ``except: pass``, which
    silently hid the fact that ``random`` was never imported, so Python's own
    generator was never seeded. The calls are now made directly so that a
    failure is visible.

    Args:
        seed: integer seed applied to all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [None]:
#export
def ignore_warnings():
    """Install a global filter that suppresses every warning."""
    warnings.filterwarnings(action="ignore")

In [None]:
#export
def get_working_dir():
    """Return the project working directory, ``$HOME/work/crx8``, as a Path.

    The HOME environment variable is read without a fallback, so an unset
    HOME yields the literal first component ``None``.
    """
    home = os.environ.get("HOME")
    return Path("/".join((str(home), "work", "crx8")))

In [None]:
#export
def get_data_path():
    """Return the dataset root, ``$HOME/.datasets/CRX8``, as a Path.

    The HOME environment variable is read without a fallback, so an unset
    HOME yields the literal first component ``None``.
    """
    home = os.environ.get("HOME")
    return Path("/".join((str(home), ".datasets", "CRX8")))

In [None]:
#export
def get_model_path():
    """Return the directory (inside the working directory) that holds saved models."""
    working_dir = get_working_dir()
    return working_dir / "pt_models"

def save_model(model, model_name):
    """Serialise ``model`` with ``torch.save`` to ``<model dir>/<model_name>.pt``.

    The model directory is created when it does not exist yet; previously the
    first save on a fresh machine failed because the directory was missing.

    Args:
        model: object handed to ``torch.save`` unchanged.
        model_name: file stem; ``.pt`` is appended.
    """
    model_dir = get_model_path()
    model_dir.mkdir(parents=True, exist_ok=True)
    path = str(model_dir/f"{model_name}.pt")
    torch.save(model, path)

def load_model(model_name):
    """Load and return whatever ``torch.load`` reads from ``<model dir>/<model_name>.pt``."""
    model_file = get_model_path()/f"{model_name}.pt"
    return torch.load(str(model_file))

In [None]:
#export
def FERTIG():
    """Announce completion: print a message and write it to ``fertig.txt`` in the working directory."""
    message = "FERTIG! :D"
    print(message)
    marker_file = get_working_dir()/"fertig.txt"
    with open(marker_file, "w") as handle:
        handle.write(message)

In [None]:
#export
def get_image_path():
    """Return the ``images`` directory below the dataset root."""
    data_root = get_data_path()
    return data_root / "images"

In [None]:
#export
def get_train_val_list():
    """Return the entries of ``train_val_list.txt`` (one per line) with newlines removed."""
    list_file = get_data_path()/"train_val_list.txt"
    with open(list_file) as handle:
        raw_lines = handle.readlines()
    return [line.replace("\n", "") for line in raw_lines]

In [None]:
#export
def get_labels(reduced=True):
    """Return the distinct finding labels of the dataset in a stable order.

    Labels are collected from the ``Finding Labels`` column of
    ``Data_Entry_2017_v2020.csv``, where several labels are joined with ``|``.

    The result is sorted. The previous version returned ``list(set)``, whose
    order depends on Python's per-process string hash randomisation; code that
    maps label names to column indices could therefore disagree between
    processes or sessions.

    Args:
        reduced: when True the "No Finding" pseudo-label is left out.

    Returns:
        Alphabetically sorted list of label strings.
    """
    data = pd.read_csv(get_data_path()/"Data_Entry_2017_v2020.csv")
    labels = set()
    for entry in data["Finding Labels"].values:
        labels.update(entry.split("|"))
    if reduced:
        labels.discard(NO_FINDING)
    return sorted(labels)

In [None]:
#export
def get_dataframes(reduced=True, small=False, small_fraction=0.1, include_labels=get_labels()):
    """Build the train / validation / test label frames.

    Steps:
      1. Load the label frames (``get_label_dfs``) and split the train part
         into train / validation (``get_train_valid``).
      2. If ``small`` is set, keep only the leading ``small_fraction`` of the
         rows of every split.
      3. If ``include_labels`` has fewer entries than ``get_labels()``, keep the
         rows that are positive for at least one included label and add
         negative rows on top: for train / validation these are rows with no
         label at all, for test these are rows without any included label.
         At most as many negatives as positives are added.

    Fixes relative to the previous version: the test positives and negatives
    are now stacked row-wise (they were concatenated with ``axis=1``, which
    joins them side by side), and a leftover debug ``print`` was removed.

    Note: the default of ``include_labels`` is evaluated once, when the
    function is defined.

    Args:
        reduced: forwarded to ``get_label_dfs``.
        small: keep only a leading fraction of each split.
        small_fraction: the fraction kept when ``small`` is True.
        include_labels: list of label column names to focus on.

    Returns:
        Tuple ``(train_df, valid_df, test_df)``.
    """
    train_df, test_df = get_label_dfs(reduced=reduced)
    train_df, valid_df = get_train_valid(train_df)

    if small:
        train_df = train_df.iloc[:int(train_df.shape[0] * small_fraction), :]
        valid_df = valid_df.iloc[:int(valid_df.shape[0] * small_fraction), :]
        test_df  = test_df.iloc[:int(test_df.shape[0] * small_fraction), :]

    all_labels = get_labels()
    if len(include_labels) < len(all_labels):

        def included_count(frame):
            # Number of included findings per row.
            return frame[include_labels].sum(axis=1, numeric_only=True)

        tmp_train_df = train_df[included_count(train_df) > 0]
        tmp_valid_df = valid_df[included_count(valid_df) > 0]
        tmp_test_df  = test_df[included_count(test_df) > 0]

        # Train/validation negatives: rows without any finding at all.
        negative_df = pd.concat([train_df, valid_df], axis=0)
        negative_df = negative_df[negative_df[all_labels].sum(axis=1) == 0]
        # Test negatives: rows without any of the included findings.
        test_neg_df = test_df[included_count(test_df) == 0]

        n_train_pos, n_valid_pos = tmp_train_df.shape[0], tmp_valid_df.shape[0]
        n_negatives = negative_df.shape[0]
        pos_rows = n_train_pos + n_valid_pos
        if n_negatives < pos_rows:
            # Too few negatives: share them in proportion to the positives.
            train_idx = math.floor(n_negatives * n_train_pos / pos_rows)
            valid_idx = math.floor(train_idx + n_negatives * n_valid_pos / pos_rows)
        else:
            train_idx = n_train_pos
            valid_idx = train_idx + n_valid_pos

        train_df = pd.concat([tmp_train_df, negative_df.iloc[:train_idx, :]], axis=0)
        valid_df = pd.concat([tmp_valid_df, negative_df.iloc[train_idx:valid_idx, :]], axis=0)

        # A slice longer than the frame simply returns every row, so this one
        # expression covers both "fewer negatives than positives" and the
        # capped case.
        test_df = pd.concat([tmp_test_df, test_neg_df.iloc[:tmp_test_df.shape[0], :]], axis=0)

    return train_df, valid_df, test_df

In [None]:
#export
def get_label_dfs(reduced=True):
    """Return one-hot label frames for the train/val and the test image lists.

    Reads ``Data_Entry_2017_v2020.csv``, expands the ``|``-separated
    ``Finding Labels`` into one 0/1 column per label from ``get_labels`` and
    appends all remaining CSV columns. The first two CSV columns are taken to
    be the image name and the label string.

    The previous version repeated the same assignments once per output column
    inside an extra loop; that loop had no effect on the result and was removed.

    Args:
        reduced: when True, "No Finding" gets no column and is skipped.

    Returns:
        Tuple ``(train_label_df, test_label_df)`` selected through
        ``get_train_val_list`` and ``get_test_list``.
    """
    data = pd.read_csv(get_data_path()/"Data_Entry_2017_v2020.csv")
    additional_data = data.drop(columns=["Image Index", "Finding Labels"])
    values = data.values[:,:2]
    labels = get_labels(reduced=reduced)
    col2idx = {c:i for i, c in enumerate(labels)}
    arr = np.zeros((values.shape[0], len(labels)))

    for row in range(arr.shape[0]):
        for lbl in values[row, 1].split("|"):
            if reduced and lbl == NO_FINDING: continue
            arr[row, col2idx[lbl]] = 1

    new_data = pd.DataFrame({"Image Index": values[:,0]})
    new_data = pd.concat([new_data, pd.DataFrame(arr, columns=labels), additional_data], axis=1)

    train_label_df = new_data[new_data["Image Index"].isin(get_train_val_list())]
    test_label_df = new_data[new_data["Image Index"].isin(get_test_list())]

    return train_label_df, test_label_df

In [None]:
#export
def check_for_leakage(df1, df2):
    """Return True when at least one ``Patient ID`` value occurs in both frames."""
    first_ids, second_ids = (set(frame["Patient ID"].values) for frame in (df1, df2))
    shared_ids = first_ids & second_ids
    return len(shared_ids) > 0

In [None]:
#export
def add_rows(indices, q_d, data_d):
    """Append the selected rows of a 2-D array to a dict of column lists.

    For every row index in ``indices`` the value at column position ``c`` of
    ``q_d`` is appended to the list stored under the ``c``-th key of
    ``data_d``. The dict is modified in place and also returned.
    """
    column_keys = list(data_d.keys())
    for row_pos in indices:
        for col_pos, key in enumerate(column_keys):
            data_d[key].append(q_d[row_pos, col_pos])
    return data_d

def get_train_valid(df, val_size=0.2):
    """Split ``df`` into train and validation frames without patient overlap.

    If both cache files (``no_overlap_train_df.csv`` / ``no_overlap_valid_df.csv``
    in the working directory) can be read, they are returned as-is; ``df`` and
    ``val_size`` are ignored in that case. Otherwise patients are processed
    from the one with the most rows to the one with the fewest, and each
    patient's rows go, as a whole, to whichever split is currently less filled
    relative to its target size. The result is written to the cache files.

    Fixes relative to the previous version:
      * ``Counter`` is now imported at the top of the file.
      * ``np.object`` (removed in NumPy 1.24) is no longer needed; rows are
        selected positionally from ``df``.
      * only read/parse failures of the cache trigger a recomputation instead
        of every possible exception.
      * ``val_size`` of 0 or 1 no longer divides by zero.
      * rows are grouped per patient once instead of scanning the whole table
        for every patient.

    Args:
        df: frame with a ``Patient ID`` column.
        val_size: target fraction of rows for the validation split.

    Returns:
        Tuple ``(train_df, valid_df)``; the original index labels are kept.
    """
    train_cache = get_working_dir()/"no_overlap_train_df.csv"
    valid_cache = get_working_dir()/"no_overlap_valid_df.csv"

    try:
        train_df = pd.read_csv(train_cache, index_col="Unnamed: 0")
        valid_df = pd.read_csv(valid_cache, index_col="Unnamed: 0")
        return train_df, valid_df
    except (OSError, ValueError):
        # Missing, unreadable or malformed cache: rebuild it below.
        print("No precomputed dataframes found...\nComputing them now!")

    patient_ids = df["Patient ID"].values
    # Largest patients first; ties keep first-seen order.
    patients = Counter(patient_ids).most_common()

    train_size = round(df.shape[0] * (1 - val_size))
    valid_size = df.shape[0] - train_size

    # Row positions of every patient, in ascending order.
    rows_by_patient = {}
    for position, pid in enumerate(patient_ids):
        rows_by_patient.setdefault(pid, []).append(position)

    train_rows, valid_rows = [], []
    for pid, _ in tqdm(patients):
        # A split with target size 0 counts as already full.
        train_fill = len(train_rows) / train_size if train_size else float("inf")
        valid_fill = len(valid_rows) / valid_size if valid_size else float("inf")
        target = train_rows if train_fill <= valid_fill else valid_rows
        target.extend(rows_by_patient[pid])

    train_df = df.iloc[train_rows].copy()
    valid_df = df.iloc[valid_rows].copy()

    train_df.to_csv(train_cache, index=True)
    valid_df.to_csv(valid_cache, index=True)
    print("Dataframes saved!")

    return train_df, valid_df

In [None]:
#export
def old_get_train_valid(data, val_size=0.2, seed=92):
    """Deprecated random train/validation split (patients may end up in both parts).

    Emits two warnings, separates the label columns (``get_labels()``) from
    the remaining columns and splits both with ``train_test_split``.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.
    """
    # Currently with patient overlap!
    for message in ("Train-Val-Split with patient overlap!", "DEPRECATED"):
        warnings.warn(message)
    label_cols = get_labels()
    feature_cols = [c for c in data.columns if c not in label_cols]
    parts = train_test_split(
        data[feature_cols], data[label_cols], test_size=val_size, random_state=seed)
    return tuple(parts)

In [1]:
#export
def load_label_df():
    """Return the cached two-column label table, loading it on first use.

    The table is read from ``Data_Entry_2017_v2020.csv``, reduced to the
    ``Image Index`` and ``Finding Labels`` columns and stored in the
    module-level ``LABEL_DF``.
    """
    global LABEL_DF
    if LABEL_DF is not None:
        return LABEL_DF
    wanted = ["Image Index", "Finding Labels"]
    LABEL_DF = pd.read_csv(get_data_path()/"Data_Entry_2017_v2020.csv")
    unwanted = [col for col in LABEL_DF.columns if col not in wanted]
    LABEL_DF = LABEL_DF.drop(columns=unwanted)
    return LABEL_DF

In [None]:
#export
def image2label(fn):
    """Return the ``Finding Labels`` entries (a Series) of the rows whose ``Image Index`` equals ``fn``."""
    table = load_label_df()
    is_match = table["Image Index"] == fn
    return table[is_match]["Finding Labels"]

In [None]:
#export
def translate2label(arr, labels=None):
    """Turn a 0/1 label vector into a ``|``-joined string of label names.

    The previous version returned the placeholder "not implemented" before
    any work was done, leaving the real body unreachable; it also printed the
    intermediate list. Both are removed here.

    Args:
        arr: sequence of per-label indicators. A 2-D pandas object (anything
            whose ``.values`` is a 2-D array) is reduced to its first row.
        labels: label names in the same order as ``arr``; defaults to
            ``get_labels()``.

    Returns:
        The names of all positions equal to 1 joined by ``|``, or the
        "No Finding" label when there is none.
    """
    values = getattr(arr, "values", None)
    # Only unwrap real 2-D containers; tensors expose `values` as a method.
    if values is not None and not callable(values) and getattr(values, "ndim", 0) == 2:
        arr = values[0]
    if labels is None: labels = get_labels()
    labels = list(labels)
    pos_lbls = [labels[idx] for idx, v in enumerate(arr) if v == 1]
    if len(pos_lbls) == 0: return NO_FINDING
    return "|".join(pos_lbls)

In [None]:
#export
def print_image(x, label="(no value passed)"):
    """Show one normalised 3-D image with ``label`` as the plot title.

    The first axis is moved to the end, the ImageNet normalisation is undone
    (multiply by ``IMAGENET_STD``, add ``IMAGENET_MEAN`` along the last axis)
    and the result is drawn with ``plt.imshow``.
    """
    assert len(x.shape) == 3

    channels_last = np.transpose(np.array(x), (1, 2, 0))
    restored = channels_last * IMAGENET_STD + IMAGENET_MEAN

    plt.title(label)
    plt.imshow(restored, cmap="gray")

In [None]:
#export
def label_str(v, l):
    """Return label ``l`` when score ``v`` is above 0.5, otherwise the "No Finding" label."""
    if v > 0.5:
        return l
    return NO_FINDING

In [None]:
#export
def translate_one_label(label, values):
    """Map every score in ``values`` to ``label`` (score > 0.5) or the "No Finding" label."""
    return [label if score > 0.5 else NO_FINDING for score in values]

In [None]:
#export
def print_one_label_batch(X, y, label):
    """Plot every image of a single-label batch with its translated label as title.

    The normalisation is undone with ``SUBSET_STD`` / ``SUBSET_MEAN`` and the
    images are drawn on a grid with two columns.

    The previous version converted each channel-first image to channel-last
    with ``reshape``, which keeps the flat element order and therefore
    scrambles the pixels. The axes are now permuted instead.

    Args:
        X: batch of channel-first images; the batch size must be even and > 0.
        y: per-image label scores, passed to ``translate_one_label``.
        label: name of the positive label.
    """
    assert X.shape[0] > 0 and X.shape[0] % 2 == 0
    X = X * SUBSET_STD + SUBSET_MEAN
    x_dim = int(X.shape[0] / 2)
    y_dim = int(X.shape[0] / x_dim)

    axes = []
    figure = plt.figure(figsize=(2*x_dim, 20*y_dim))
    for i in range(x_dim * y_dim):
        axes.append(figure.add_subplot(x_dim, y_dim, i+1))
        axes[-1].set_title(translate_one_label(label, y[i]))
        image = X[i]
        # Tensors offer permute(); anything else goes through NumPy.
        if hasattr(image, "permute"): image = image.permute(1, 2, 0)
        else: image = np.transpose(image, (1, 2, 0))
        plt.imshow(image)
    figure.tight_layout()
    plt.show()

In [None]:
#export
def print_batch(X, y, labels=get_labels()):
    """Plot every image of a multi-label batch with its label string as title.

    The normalisation is undone with ``SUBSET_STD`` / ``SUBSET_MEAN`` and the
    images are drawn on a grid with two columns.

    Fixes relative to the previous version: channel-first images are converted
    to channel-last by permuting axes (``reshape`` scrambled the pixels), and
    ``labels`` is now actually forwarded to ``translate2label``.

    Args:
        X: batch of channel-first images; the batch size must be even and > 0.
        y: per-image label vectors.
        labels: label names matching the positions of each label vector.
    """
    assert X.shape[0] > 0 and X.shape[0] % 2 == 0
    X = X * SUBSET_STD + SUBSET_MEAN
    x_dim = int(X.shape[0] / 2)
    y_dim = int(X.shape[0] / x_dim)

    axes = []
    figure = plt.figure(figsize=(2*x_dim, 20*y_dim))
    for i in range(x_dim * y_dim):
        axes.append(figure.add_subplot(x_dim, y_dim, i+1))
        axes[-1].set_title(translate2label(y[i], labels))
        image = X[i]
        # Tensors offer permute(); anything else goes through NumPy.
        if hasattr(image, "permute"): image = image.permute(1, 2, 0)
        else: image = np.transpose(image, (1, 2, 0))
        plt.imshow(image)
    figure.tight_layout()
    plt.show()

In [None]:
#export
def get_test_list():
    """Return the entries of ``test_list.txt`` (one per line) with newlines removed."""
    list_file = get_data_path()/"test_list.txt"
    with open(list_file) as handle:
        raw_lines = handle.readlines()
    return [line.replace("\n", "") for line in raw_lines]

In [None]:
#export
def get_transforms(image_size=(224, 224)):
    """Build the torchvision pipelines for training and for evaluation.

    Both pipelines normalise with the ImageNet statistics and then resize to
    ``image_size``; the training pipeline additionally applies a random
    horizontal flip (p=0.5) and ``RandomRotation(1)``.

    Returns:
        Tuple ``(train_tfs, test_tfs)`` of ``transforms.Compose`` objects.
    """
    def shared_steps():
        # Fresh instances for each pipeline.
        return [
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
            transforms.Resize(image_size),
        ]

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(1),
    ]
    train_tfs = transforms.Compose(shared_steps() + augmentations)
    test_tfs = transforms.Compose(shared_steps())
    return train_tfs, test_tfs

In [None]:
#export
def compute_class_freqs(labels):
    """Return per-column positive and negative frequencies of a label matrix.

    The positive frequency is the column sum divided by the number of rows;
    the negative frequency is its complement.
    """
    n_rows = labels.shape[0]
    pos_freq = np.sum(labels, axis=0) / n_rows
    return pos_freq, 1 - pos_freq

In [None]:
#export
def compute_positive_class_weigths(labels):
    """Return, per column, the ratio of zero entries to one entries as a torch tensor."""
    positives_per_class = (labels == 1).sum(axis=0)
    negatives_per_class = (labels == 0).sum(axis=0)
    ratio = negatives_per_class / positives_per_class
    return torch.Tensor(ratio)

In [None]:
#export
def auroc(y_hat, y, model_name="model_name", with_chexnet=True, with_previous=True):
    """Compute the per-label AUROC of ``y_hat`` against ``y`` as a one-column frame.

    Column ``i`` of both arrays is matched with the ``i``-th entry of
    ``get_labels()``. A label for which ``roc_auc_score`` raises ``ValueError``
    is reported with a warning and scored 0.

    Args:
        y_hat: predicted scores, one column per label.
        y: ground truth, one column per label.
        model_name: name of the result column.
        with_chexnet: append the CheXNet reference column.
        with_previous: append the columns of the stored results that are not
            already present.
    """
    def score_column(position, label_name):
        try:
            return roc_auc_score(y[:, position], y_hat[:, position])
        except ValueError:
            warnings.warn(f"{label_name} only has one class. Returning 0!")
            return 0.

    aurocs = {}
    for position, label_name in enumerate(get_labels()):
        aurocs[label_name] = score_column(position, label_name)

    df = pd.DataFrame(aurocs.values(), index=aurocs.keys(), columns=[model_name])
    if with_previous:
        prev = load_results()
        for col_pos, col_name in enumerate(prev.columns):
            if col_name in df.columns:
                continue
            df = pd.concat([df, prev.iloc[:, col_pos]], axis=1)
    if with_chexnet:
        df = add_chexnet(df)
    return df

In [None]:
#export
def threshold_predictions(pred, t=0.5):
    """Return the element-wise result of ``pred >= t``."""
    is_positive = pred >= t
    return is_positive

In [None]:
#export
def chexnet_df():
    """Return the hard-coded CheXNet reference scores as a one-column frame.

    The frame is indexed by pathology name and has the single column ``CheXNet``.
    """
    reference_scores = {
        "Atelectasis": 0.8094,
        "Cardiomegaly": 0.9248,
        "Effusion": 0.8638,
        "Infiltration": 0.7345,
        "Mass": 0.8676,
        "Nodule": 0.7802,
        "Pneumonia": 0.7680,
        "Pneumothorax": 0.8887,
        "Consolidation": 0.7901,
        "Edema": 0.8878,
        "Emphysema": 0.9371,
        "Fibrosis": 0.8047,
        "Pleural_Thickening": 0.8062,
        "Hernia": 0.9164,
    }
    return pd.DataFrame(
        list(reference_scores.values()),
        index=list(reference_scores.keys()),
        columns=["CheXNet"],
    )

In [None]:
#export
def add_chexnet(df):
    """Return ``df`` with the CheXNet reference column appended unless it is already present."""
    if "CheXNet" in df.columns:
        return df
    return pd.concat([df, chexnet_df()], axis=1)

In [None]:
#export
def save_results(df):
    """Write ``df`` (including its index) to ``AUROC_results.csv`` in the working directory."""
    results_file = get_working_dir()/"AUROC_results.csv"
    df.to_csv(results_file, index=True)

In [None]:
#export
def load_results():
    """Read ``AUROC_results.csv`` from the working directory, using its unnamed first column as index."""
    results_file = get_working_dir()/"AUROC_results.csv"
    return pd.read_csv(results_file, index_col="Unnamed: 0")

In [None]:
#export
def array_info(arr, with_hist=False):
    """Print shape, mean, std, max and min of ``arr``; optionally draw a histogram of its values."""
    print("Shape:\t", arr.shape)
    # Each statistic is computed right before it is printed.
    for caption, method_name in (("Mean:\t", "mean"), ("Std:\t", "std"),
                                 ("Max:\t", "max"), ("Min:\t", "min")):
        print(caption, getattr(arr, method_name)())
    if with_hist == True:
        print("Histogram:")
        plt.hist(arr.flatten());

In [None]:
#export
def show_image(fn):
    """Read the image file ``fn`` and display it with the "bone" colour map."""
    pixels = plt_image.imread(fn)
    plt.imshow(pixels, cmap="bone");

In [None]:
#export
def calc_stat(fn):
    """Return ``[mean, std]`` of the pixel values of image file ``fn`` as an array."""
    pixels = plt_image.imread(fn)
    pixel_mean, pixel_std = pixels.mean(), pixels.std()
    return np.array([pixel_mean, pixel_std])
def calc_stats(df):
    """Average the per-image pixel mean and std over all images listed in ``df``.

    Every file named in the ``Image Index`` column is read from the image
    directory in a ``multiprocessing`` pool. Note that the second value is the
    mean of the per-image standard deviations, not the standard deviation of
    all pixels.

    ``mp`` (``multiprocessing``) used to be referenced without being imported;
    it is now part of the file's import block. An empty frame is rejected with
    a clear error instead of failing on an index lookup.

    Args:
        df: frame with an ``Image Index`` column of file names.

    Returns:
        Tuple ``(mean of image means, mean of image stds)``.

    Raises:
        ValueError: if ``df`` lists no images.
    """
    image_names = [get_image_path()/fn for fn in df.loc[:,"Image Index"]]
    if len(image_names) == 0:
        raise ValueError("calc_stats needs at least one image")
    with mp.Pool() as p:
        stats = p.map(calc_stat, image_names)
    stats = np.array(stats)
    return stats[:,0].mean(), stats[:,1].mean()

In [None]:
#export
class EmptyScheduler:
    """Scheduler stand-in whose ``step`` and ``reset`` do nothing."""

    def step(self):
        """Do nothing."""
        return None

    def reset(self):
        """Do nothing."""
        return None

    def is_empty(self):
        """Always True: marks this object as the no-op scheduler."""
        return True

In [None]:
#export
# From: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
def find_lr(net, dl, optimizer, criterion, init_value = 1e-8,
            final_value=10., beta = 0.98, device=torch.device('cuda:0')):
    """Learning-rate range test: train while raising the learning rate exponentially.

    The learning rate of the first parameter group starts at ``init_value``
    and is multiplied by a constant factor after every batch so that it
    reaches ``final_value`` on the last batch of ``dl``. The bias-corrected
    exponential moving average of the loss is recorded per batch. The run
    stops early once that smoothed loss exceeds ten times the best value seen.

    A loader with a single batch used to raise ``ZeroDivisionError`` while
    computing the multiplier; the step count is now at least 1.

    Args:
        net: model to train; it is moved to ``device`` and its weights are updated.
        dl: iterable of ``(inputs, labels)`` batches that supports ``len``.
        optimizer: optimizer over the parameters of ``net``.
        criterion: callable ``criterion(outputs, labels)`` returning the loss.
        init_value: first learning rate.
        final_value: learning rate reached on the last batch.
        beta: smoothing factor of the moving average.
        device: device used for the model and the batches.

    Returns:
        Tuple ``(log_lrs, losses)``: base-10 logs of the learning rates and
        the matching smoothed losses.
    """
    torch.cuda.empty_cache()
    # At least one step so that a single-batch loader does not divide by zero.
    num = max(len(dl) - 1, 1)
    mult = (final_value / init_value) ** (1/num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    batch_num = 0
    losses = []
    log_lrs = []

    net.to(device)

    for data in tqdm(dl):
        batch_num += 1
        # Loss of this mini-batch.
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # Smoothed loss with bias correction for the first batches.
        avg_loss = beta * avg_loss + (1-beta) *loss.item()
        smoothed_loss = avg_loss / (1 - beta**batch_num)
        # Stop if the loss is exploding.
        if batch_num > 1 and smoothed_loss > 10 * best_loss:
            return log_lrs, losses
        # Record the best loss.
        if smoothed_loss < best_loss or batch_num==1:
            best_loss = smoothed_loss
        # Store the values.
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        # Optimisation step.
        loss.backward()
        optimizer.step()
        # Raise the learning rate for the next batch.
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    return log_lrs, losses

In [None]:
#export
# Modified from: https://github.com/dkumazaw/onecyclelr/blob/master/onecyclelr.py
class OneCycleLR:
    """ Sets the learning rate of each parameter group by the one cycle learning rate policy
    proposed in https://arxiv.org/pdf/1708.07120.pdf.
    It is recommended that you set the max_lr to be the learning rate that achieves
    the lowest loss in the learning rate range test, and set min_lr to be 1/10 th of max_lr.
    So, the learning rate changes like min_lr -> max_lr -> min_lr -> final_lr,
    where final_lr = min_lr * reduce_factor.
    Note: Currently only supports one parameter group.
    Args:
        optimizer:             (Optimizer) against which we apply this scheduler
        num_steps:             (int) of total number of steps/iterations
        lr_range:              (tuple) of min and max values of learning rate
        momentum_range:        (tuple) of min and max values of momentum
        annihilation_frac:     (float), fraction of steps to annihilate the learning rate
        reduce_factor:         (float), denotes the factor by which we annihilate the learning rate at the end
        last_step:             (int), denotes the last step. Set to -1 to start training from the beginning
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = OneCycleLR(optimizer, num_steps=num_steps, lr_range=(0.1, 1.))
        >>> for epoch in range(epochs):
        >>>     for step in train_dataloader:
        >>>         train(...)
        >>>         scheduler.step()
    Useful resources:
        https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6
        https://medium.com/vitalify-asia/whats-up-with-deep-learning-optimizers-since-adam-5c1d862b9db0
    """

    def __init__(self,
                 optimizer: Optimizer,
                 num_steps: int,
                 lr_range: tuple = (0.1, 1.),
                 momentum_range: tuple = (0.85, 0.95),
                 annihilation_frac: float = 0.1,
                 reduce_factor: float = 0.01,
                 last_step: int = -1):
        # Sanity check
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer

        self.num_steps = num_steps

        self.min_lr, self.max_lr = lr_range[0], lr_range[1]
        assert self.min_lr < self.max_lr, \
            "Argument lr_range must be (min_lr, max_lr), where min_lr < max_lr"

        self.min_momentum, self.max_momentum = momentum_range[0], momentum_range[1]
        assert self.min_momentum < self.max_momentum, \
            "Argument momentum_range must be (min_momentum, max_momentum), where min_momentum < max_momentum"

        self.num_cycle_steps = int(num_steps * (1. - annihilation_frac))  # Total number of steps in the cycle
        self.final_lr = self.min_lr * reduce_factor

        self.last_step = last_step

        # A fresh scheduler immediately applies step 0 to the optimizer.
        if self.last_step == -1:
            self.step()

        # Snapshot taken after the initial step; reset() returns to it.
        self.ground_state = self.state_dict()

    def reset(self):
        """Restore the attributes captured at the end of ``__init__``.

        Only the scheduler's own attributes are restored; the optimizer's
        learning rate and momentum stay as they are until the next ``step``.
        """
        for k, v in self.ground_state.items():
            self.__dict__[k] = v

    def is_empty(self): return False

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer or the stored ground state.
        (Borrowed from _LRScheduler class in torch.optim.lr_scheduler.py)
        """
        return {key: value for key, value in self.__dict__.items() if key not in ['optimizer', 'ground_state']}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state. (Borrowed from _LRScheduler class in torch.optim.lr_scheduler.py)
        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Current learning rate of the first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def get_momentum(self):
        """Current ``momentum`` entry of the first parameter group."""
        return self.optimizer.param_groups[0]['momentum']

    def step(self):
        """Conducts one step of learning rate and momentum update.

        Steps beyond ``num_steps`` leave the optimizer untouched.
        """
        current_step = self.last_step + 1
        self.last_step = current_step

        half_cycle = self.num_cycle_steps // 2

        if current_step <= half_cycle:
            # Scale up phase. With a degenerate (zero-length) ramp the peak is
            # used directly instead of dividing by zero.
            scale = current_step / half_cycle if half_cycle > 0 else 1.
            lr = self.min_lr + (self.max_lr - self.min_lr) * scale
            momentum = self.max_momentum - (self.max_momentum - self.min_momentum) * scale
        elif current_step <= self.num_cycle_steps:
            # Scale down phase
            scale = (current_step - half_cycle) / (self.num_cycle_steps - half_cycle)
            lr = self.max_lr - (self.max_lr - self.min_lr) * scale
            momentum = self.min_momentum + (self.max_momentum - self.min_momentum) * scale
        elif current_step <= self.num_steps:
            # Annihilation phase: only change lr
            scale = (current_step - self.num_cycle_steps) / (self.num_steps - self.num_cycle_steps)
            lr = self.min_lr - (self.min_lr - self.final_lr) * scale
            momentum = None
        else:
            # Exceeded given num_steps: do nothing
            return

        self.optimizer.param_groups[0]['lr'] = lr
        # `is not None` so that a legitimate momentum of 0.0 is still applied.
        if momentum is not None:
            self.optimizer.param_groups[0]['momentum'] = momentum

In [None]:
#export
class CRX8_Data(Dataset):
    """Dataset that yields ``(image, label vector)`` pairs described by a label frame.

    Every image is read with ``read_image``, reduced to a single 2-D plane,
    replicated to three identical channels scaled to [0, 1] and optionally
    transformed. The label vector holds the row's values of the selected
    label columns.

    Fixes relative to the previous version: ``np.float`` (removed in
    NumPy 1.24) is replaced by the builtin ``float``; a tuple of labels is
    treated like a list instead of being wrapped as one label; and the
    three-channel image is built directly in torch.

    Args:
        df: frame with an ``Image Index`` column (file names) and the label
            columns. A helper column ``Index_2`` is added to this frame.
        image_path: directory (Path) containing the image files.
        labels: one label column name or a list/tuple of names.
        image_size: target size used by ``_resize``.
        transforms: optional callable applied to the 3-channel image.
    """

    def __init__(self, df, image_path, labels, image_size=None, transforms=None):
        self.df = df
        self.image_size = image_size
        self.image_path = image_path
        self.len = df.shape[0]
        self.labels = self._correct_labels(labels)
        self.transforms = transforms
        # Running row number; written into the frame that was passed in.
        self.df["Index_2"] = list(range(self.df.shape[0]))

    def _correct_labels(self, lbls):
        """Return the labels as a list; a single name becomes a one-element list."""
        if isinstance(lbls, (list, tuple)): return list(lbls)
        return [lbls]

    def __len__(self): return self.len

    def _resize(self, im):
        """Resize ``im`` to ``self.image_size`` with torchvision's ``Resize``."""
        return transforms.Resize(self.image_size)(im)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row position ``idx`` as float tensors."""
        img_path = self._get_image_path(idx)
        image = read_image(img_path)
        image = self._make3D(image)
        label = self.df.iloc[idx,:].loc[self.labels].values
        if self.transforms: image = self.transforms(image)
        return image.float(), torch.Tensor(label.astype(float)).float()

    def _correct_dims(self, t):
        """Reduce ``t`` to a 2-D plane: squeeze it and, if 3-D, keep the first slice."""
        t = t.squeeze()
        if len(t.shape) == 2: return t
        if len(t.shape) == 3: return t[0, :, :]
        assert False, "Check dimensions of loaded images!"

    def _make3D(self, t):
        """Return a float tensor with three identical channels, values divided by 255."""
        plane = self._correct_dims(t).detach()
        return torch.stack((plane, plane, plane), dim=0).float() / 255.

    def _get_image_path(self, idx):
        """Return the path of the image at row position ``idx`` as a string."""
        return str(self.image_path/self.df.iloc[idx].loc["Image Index"])

In [None]:
#export
def get_device(verbose=True):
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device.

    With ``verbose`` a short message naming the choice is printed.
    """
    if torch.cuda.is_available():
        device_name, message = "cuda:0", "Using the GPU!"
    else:
        device_name, message = "cpu", "Using the CPU!"
    device = torch.device(device_name)
    if verbose:
        print(message)
    return device

In [None]:
#export
def get_cpu(verbose=True):
    """Return the CPU device; with ``verbose`` a short message is printed first."""
    if verbose:
        print("Using the CPU!")
    cpu = torch.device("cpu")
    return cpu

In [None]:
#export
class Logger:
    """Collects per-phase metric values in ``self.state`` and plots them.

    Values are stored in lists keyed by ``"<phase>_<metric>"``.
    """

    def __init__(self):
        self.state = {}

    def get_name(self, phase, name):
        """Return the state key for a metric of a phase."""
        return f"{phase}_{name}"

    def add_metric(self, name, phases=("train", "val")):
        """Create (or reset) the empty value list of ``name`` for every phase."""
        for phase in phases: self.state[self.get_name(phase, name)] = []

    def add_value(self, name, value, phase):
        """Append ``value`` to the list of metric ``name`` in ``phase``."""
        self.state[self.get_name(phase, name)].append(value)

    def _plot_values(self, name):
        """Plot the train and val series of metric ``name`` against their index."""
        train_name = self.get_name("train", name)
        val_name = self.get_name("val", name)
        plt.plot(
            list(range(len(self.state[train_name]))),
            self.state[train_name]);
        plt.plot(
            list(range(len(self.state[val_name]))),
            self.state[val_name]);

    def plot_losses(self):
        self._plot_values("loss")

    def plot_acc(self):
        self._plot_values("acc")

    def plot_auroc(self, phase="val"):
        """Return all stored AUROC frames of ``phase`` side by side, plus the CheXNet column.

        Every stored entry is expected to be a pair whose first element is a
        DataFrame. The previous version used the first frame as the starting
        point and then concatenated every entry again, so the first frame
        appeared twice; each frame is now included exactly once.
        """
        name = self.get_name(phase, "auroc")
        frames = [data for data, _ in self.state[name]]
        return add_chexnet(pd.concat(frames, axis=1))

In [None]:
#export
def auroc_score(y_hat, y, model_name):
    """Return the ROC AUC of ``y_hat`` against ``y`` and a decision threshold.

    The threshold is the ROC threshold at which ``tpr - fpr`` is largest.
    ``model_name`` is accepted but not used.
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y, y_hat)
    area = auc(false_pos_rate, true_pos_rate)
    best_position = np.argmax(true_pos_rate - false_pos_rate)
    return area, cutoffs[best_position]

In [None]:
#export
def multi_auroc_score(y_hat, y, model_name, labels=get_labels()):
    """Compute ROC AUC and a decision threshold for every label column.

    Column ``i`` of ``y_hat`` / ``y`` belongs to ``labels[i]``; anything that
    is not exactly a list is wrapped into a one-element list. The threshold of
    a label is the ROC threshold at which ``tpr - fpr`` is largest.

    Returns:
        Tuple ``(df, label_thresholds)``: a frame indexed by label with the
        single column ``model_name``, and a dict from label to threshold.
    """
    if type(labels) is not list:
        labels = [labels]
    score = {}
    label_thresholds = {}
    for position, label_name in enumerate(labels):
        false_pos_rate, true_pos_rate, cutoffs = roc_curve(y[:, position], y_hat[:, position])
        score.setdefault(label_name, []).append(auc(false_pos_rate, true_pos_rate))
        label_thresholds[label_name] = cutoffs[np.argmax(true_pos_rate - false_pos_rate)]
    df = pd.DataFrame(score.values(), index=score.keys(), columns=[model_name])
    return df, label_thresholds

In [None]:
#export
def train_model(model, criterion, optimizer, scheduler, dataloaders, logger, model_name,
                labels=get_labels(), alpha=0.1, num_epochs=25, device=torch.device("cuda:0")):
    """Train ``model`` for ``num_epochs`` epochs, each with a train and a val phase.

    After every phase the loss, the accuracy and the per-label AUROC are added
    to ``logger``. The weights with the best validation accuracy are restored
    before the model is returned. ``scheduler.step()`` is called once per
    train phase and ``scheduler.reset()`` once per epoch.

    Fixes relative to the previous version:
      * the per-label AUROC is computed with ``multi_auroc_score``; the old
        call passed ``labels=`` to ``auroc_score``, which does not accept it.
      * accuracy is the share of correctly thresholded label entries; it used
        to sum per-batch fractions and divide by the number of samples.
      * the epoch loss is the sample-weighted mean of the batch losses; it
        used to be the moving average divided by the number of samples.
      * predictions are collected with ``extend`` instead of rebuilding the
        whole list for every batch.

    Args:
        model: network producing one logit per label.
        criterion: loss on ``(logits, targets)``.
            NOTE(review): the epoch loss assumes a mean-reduced loss — confirm.
        optimizer: optimizer over the model parameters.
        scheduler: object with ``step()`` and ``reset()``.
        dataloaders: dict with the loaders under ``'train'`` and ``'val'``.
        logger: object with ``add_value(name, value, phase)``; the metrics
            ``loss``, ``acc`` and ``auroc`` must have been registered.
        model_name: prefix of the AUROC column names.
        labels: label names in output-column order.
        alpha: smoothing factor of the loss shown in the progress bar.
        num_epochs: number of epochs.
        device: device used for the model and the batches.

    Returns:
        The model with the best validation weights loaded.
    """
    sigmoid = NN.Sigmoid()
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    epoch_looper = tqdm(range(num_epochs))

    for epoch in epoch_looper:
        for phase in ['train', 'val']:
            epoch_looper.set_description(f"Epoch {epoch} - {phase}")
            if phase == 'train': model.train()
            else: model.eval()

            smoothed_loss = 0.0      # moving average, only for the progress bar
            summed_loss = 0.0        # batch losses weighted by batch size
            correct_entries = 0
            seen_entries = 0
            seen_samples = 0

            y_hat, truth = [], []

            # Iterate over data.
            datalooper = tqdm(dataloaders[phase])
            for counter, (X, y) in enumerate(datalooper):
                datalooper.set_description(f"{phase}_loss: {smoothed_loss:.03f}")

                X = X.to(device)
                y = y.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(X)
                    sigmoid_outputs = sigmoid(outputs)
                    thresholded = sigmoid_outputs >= 0.5
                    loss = criterion(outputs, y)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                y_hat.extend(sigmoid_outputs.cpu().detach().numpy())
                truth.extend(y.cpu().detach().numpy())

                batch_loss = loss.item()
                if counter == 0:
                    smoothed_loss = batch_loss
                else:
                    smoothed_loss = batch_loss * alpha + (1 - alpha) * smoothed_loss

                summed_loss += batch_loss * y.size(0)
                correct_entries += (thresholded == y).sum().item()
                seen_entries += y.numel()
                seen_samples += y.size(0)

            if phase == 'train': scheduler.step()

            # max(..., 1) keeps an empty loader from dividing by zero.
            epoch_loss = summed_loss / max(seen_samples, 1)
            epoch_acc = correct_entries / max(seen_entries, 1)
            epoch_auroc = multi_auroc_score(np.array(y_hat), np.array(truth), f"{model_name}_e{epoch}", labels=labels)

            logger.add_value("loss",  epoch_loss,  phase)
            logger.add_value("acc",   epoch_acc,   phase)
            logger.add_value("auroc", epoch_auroc, phase)

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        scheduler.reset()

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model

In [None]:
#export
def get_layers(model):
    """List the distinct parameter-owning layers of ``model`` in definition order.

    Parameter names are taken from ``model.named_parameters()`` with the
    ``.weight`` / ``.bias`` parts removed; every resulting name is kept once.

    Returns:
        List of ``(position, layer_name)`` tuples, positions counting from 0.
    """
    stripped_names = (
        param_name.replace(".weight", "").replace(".bias", "")
        for param_name, _ in model.named_parameters()
    )
    # dict.fromkeys keeps the first occurrence of every name, in order.
    unique_names = dict.fromkeys(stripped_names)
    return list(enumerate(unique_names))
    

In [None]:
#export
def get_batch(idx, dl):
    """Return the ``idx``-th item produced by iterating ``dl`` (None if it is never reached)."""
    assert idx < len(dl)
    matching = (batch for position, batch in enumerate(dl) if position == idx)
    return next(matching, None)

In [None]:
#export
def plot_history(history, mode="val", metric="loss"):
    """Plot one metric over the whole training history.

    The ``metric`` sequences of all entries whose key contains "train" (or
    "val") are chained in dict order; the "val" series is plotted when
    ``mode == "val"``, otherwise the "train" series.
    """
    train_series, val_series = [], []
    for key, record in history.items():
        if "train" in key:
            train_series.extend(record[metric])
        if "val" in key:
            val_series.extend(record[metric])
    chosen = val_series if mode == "val" else train_series
    plt.plot(list(range(len(chosen))), chosen)
    

In [None]:
#export
def train_one_cycle(model,
                    criterion,
                    optimizer,
                    scheduler,
                    dataloader,
                    model_name,
                    labels=get_labels(),
                    device=torch.device("cuda:0")):
    """Train ``model`` for one pass over ``dataloader``, stepping the scheduler every batch.

    Per batch the loss and the ``accuracy_score`` of the predictions
    thresholded at 0.5 are recorded. After the pass, one AUROC over all
    collected predictions is computed with ``auroc_score`` and a summary line
    is printed.

    Predictions and targets are now collected with ``list.extend``; the
    previous version rebuilt both lists from scratch for every batch, which
    is quadratic in the number of samples.

    Args:
        model: network returning logits.
        criterion: loss on ``(logits, targets)``.
        optimizer: optimizer over the model parameters.
        scheduler: object whose ``step()`` is called after every optimizer step.
        dataloader: iterable of ``(X, y)`` batches.
        model_name: forwarded to ``auroc_score``.
        labels: accepted for interface compatibility; not used here.
        device: device used for the model and the batches.

    Returns:
        Tuple ``(model, history)`` where ``history`` holds the per-batch
        ``"loss"`` and ``"acc"`` arrays and the epoch-level ``"auroc"``.
    """
    sigmoid = NN.Sigmoid()
    model.to(device)
    model.train()

    running_loss, running_acc = [], []
    running_predictions_np, running_y_np = [], []

    datalooper = tqdm(dataloader)
    for X, y in datalooper:
        if len(running_loss) > 0: info = f"Loss: {running_loss[-1]:.03f}, Acc: {running_acc[-1]:.03f}, lr: {optimizer.param_groups[0]['lr']:.08f}"
        else: info = "no data yet"
        datalooper.set_description(info)

        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()

        logits = model(X)
        predictions = sigmoid(logits)
        loss = criterion(logits, y)

        loss.backward()
        optimizer.step()
        scheduler.step()

        predictions_np = predictions.cpu().detach().numpy()
        y_np = y.cpu().detach().numpy()
        # extend() appends in place; same content as the old list rebuild.
        running_predictions_np.extend(predictions_np)
        running_y_np.extend(y_np)

        # Loss
        running_loss.append(loss.item())

        # Accuracy
        thresholded = predictions_np >= 0.5
        running_acc.append(accuracy_score(y_np, thresholded))

    running_loss = np.array(running_loss)
    running_acc = np.array(running_acc)
    epoch_auroc, _ = auroc_score(np.array(running_predictions_np), np.array(running_y_np), model_name)
    epoch_info = f"Loss: {running_loss.mean():.03f}, Acc: {running_acc.mean():.03f}, AUROC: {epoch_auroc:.03f}"
    print("Train:", epoch_info)

    return model, {"loss": running_loss, "acc": running_acc, "auroc": epoch_auroc}

In [None]:
#export
def validate(model, 
             criterion, 
             dataloader, 
             model_name,
             device=torch.device("cuda:0")):
    """Run one evaluation pass over `dataloader` and report loss, accuracy and AUROC.

    The model is moved to `device`, switched to eval mode and run without
    gradient tracking. Probabilities are obtained by applying a sigmoid to the
    model's logits; the loss is computed on the raw logits.

    Args:
        model: module to evaluate.
        criterion: loss callable invoked as `criterion(logits, y)`.
        dataloader: iterable yielding `(X, y)` batches.
        model_name: forwarded to `auroc_score`.
        device: device on which the model and the batches are placed.

    Returns:
        dict with keys
        "loss" (np.ndarray of per-batch losses),
        "acc" (accuracy of `probabilities > threshold`),
        "auroc" and "threshold" (both as returned by `auroc_score`).
    """
    model.to(device)
    model.eval()
    sigmoid = NN.Sigmoid()

    running_loss, running_y_hat, running_y = [], [], []

    with torch.no_grad():
        datalooper = tqdm(dataloader)
        for X, y in datalooper:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            predictions = sigmoid(logits)
            loss = criterion(logits, y)

            running_loss.append(loss.item())
            # extend() appends the batch rows in place; rebuilding the whole
            # list for every batch would make the loop quadratic.
            running_y_hat.extend(predictions.cpu().detach().numpy())
            running_y.extend(y.cpu().detach().numpy())

    running_loss = np.array(running_loss)
    running_y_hat = np.array(running_y_hat)
    running_y = np.array(running_y)

    # The accuracy uses the threshold reported by auroc_score, not a fixed 0.5.
    auroc, threshold = auroc_score(running_y_hat, running_y, model_name)
    acc = accuracy_score(running_y, (running_y_hat>threshold))

    epoch_info = f"Loss: {running_loss.mean():.03f}, Acc: {acc:.03f}, AUROC: {auroc:.03f}"
    print("Val:", epoch_info)

    return {"loss": running_loss, "acc": acc, "auroc": auroc,"threshold": threshold}

In [None]:
#export
def fit(model, 
        criterion, 
        optimizer, 
        dataloaders, 
        model_name, 
        epochs,
        lr,
        sam=False,
        with_reset=False,
        patience=1,
        scheduler=EmptyScheduler(), 
        labels=get_labels(), 
        metric="loss",
        device=torch.device("cuda:0")):
    """Train `model` for up to `epochs` epochs with checkpointing and lr reduction on plateau.

    Each epoch runs one training pass (`train_SAM` when `sam` is true, otherwise
    `train_one_cycle`) followed by `validate`. The model is checkpointed via
    `save_model` whenever the monitored `metric` improves. Independently of
    `metric`, the mean validation loss drives the plateau logic: after
    `patience` consecutive epochs without a new minimum, the learning rate is
    divided by 10.

    Args:
        model: module to train.
        criterion: loss callable, forwarded to the train/validate functions.
        optimizer: optimizer forwarded to the training function. With `sam=True`
            it must provide what `train_SAM` calls on it.
        dataloaders: mapping with the keys "train" and "val".
        model_name: checkpoint name used by `save_model` / `load_model`.
        epochs: maximum number of epochs.
        lr: initial learning rate used as the starting point for reductions.
        sam: use `train_SAM` instead of `train_one_cycle`.
        with_reset: on plateau, restore the weights of the best checkpoint.
        patience: non-improving epochs tolerated before lowering the lr.
        scheduler: object offering `is_empty()`, `reset()` and whatever
            `train_one_cycle` calls on it.
        labels: unused; kept for interface compatibility.
        metric: "loss", "acc" or "auroc" (case-insensitive); selects what
            decides whether a checkpoint is written.
        device: device forwarded to the train/validate functions.

    Returns:
        (model, history): the model as loaded from the checkpoint and a dict
        with one "e<N>_train" and one "e<N>_val" entry per completed epoch.
    """
    metric = metric.lower()
    assert metric in ["auroc", "loss", "acc"]
    higher_is_better = metric in ["auroc", "acc"]
    best_metric = -1e12 if higher_is_better else 1e12

    patience_counter = 0
    patience_metric = 1e12 # monitors val_loss
    curr_lr = lr

    history = {}

    for e in range(epochs):
        print(f"Epoch {e+1}:")
        if sam:
            model, train_hist = train_SAM(model, criterion, optimizer,
                                          dataloaders["train"], model_name, 
                                          lr=curr_lr, device=device)
        else:
            model, train_hist = train_one_cycle(model, criterion, optimizer, scheduler, 
                                                dataloaders["train"], model_name, device=device)
        val_hist = validate(model, criterion, dataloaders["val"], model_name, device=device)
        val_loss = val_hist["loss"].mean()

        # Saving model. This has to happen before a possible reset below:
        # otherwise restored (older) weights could be written to disk under
        # the score that the current epoch's weights achieved.
        current = val_loss if metric == "loss" else val_hist[metric]
        improved = current > best_metric if higher_is_better else current < best_metric
        if improved:
            best_metric = current
            save_model(model, model_name)
            print(f"Saved model with {metric} {best_metric:.04f}")

        # Reducing lr on plateau
        if val_loss < patience_metric:
            patience_metric = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= patience:
            patience_counter = 0
            curr_lr /= 10.
            if not scheduler.is_empty():
                scheduler = get_one_cycle_scheduler(dataloaders["train"], curr_lr, optimizer)
            else:
                # Adjust the optimizer that is actually used for training.
                # This also covers SAM, which shares its param_groups with the
                # wrapped base optimizer; building a new optimizer object here
                # would leave the one passed to train_SAM untouched.
                # Only the first param group is adjusted.
                optimizer.param_groups[0]['lr'] = curr_lr
            if with_reset:
                # Copy the checkpoint weights into the existing module instead
                # of rebinding `model`: the optimizer holds references to this
                # module's parameters and would not see a new object.
                model.load_state_dict(load_model(model_name).state_dict())
                print("Resetted model to previous best.")
            print(f"Lowered lr to {curr_lr}")

        history[f"e{e+1}_train"] = train_hist
        history[f"e{e+1}_val"] = val_hist
        scheduler.reset()

        # Early Stopping
        if curr_lr < 1e-13: 
            print("Learning rate is basically zero. Stopping training.")
            break

    # Hand back the checkpointed weights rather than the last epoch's.
    model = load_model(model_name)
    return model, history

In [None]:
#export
def get_one_cycle_scheduler(dataloader, lr, optimizer):
    """Build a `OneCycleLR` for `optimizer` spanning one pass over `dataloader`.

    The scheduler is given `len(dataloader)` steps and the lr range
    `(lr / 10, lr)`.
    """
    return OneCycleLR(optimizer, len(dataloader), (lr / 10, lr))

In [None]:
#export
def get_binary_df(column, df):
    """Return `df` with every label column except `column` removed.

    `column` must be one of the names returned by `get_labels()`; all other
    names from that list are dropped. The input frame is not modified.
    """
    all_labels = get_labels()
    assert column in all_labels
    return df.drop(columns=[name for name in all_labels if name != column])

In [None]:
#export
def get_preclinic_df(df):
    """Collapse the per-label columns of `df` into a single PATHOLOGIC column.

    Two input layouts are handled:

    * `df` has a NO_FINDING column: PATHOLOGIC takes that column's values and
      NO_FINDING plus all `get_labels()` columns are dropped.
    * otherwise: the listed metadata columns are set aside, the remaining
      columns are summed per row and clipped to [0, 1] to form PATHOLOGIC, and
      the `get_labels()` columns are dropped.

    The caller's frame is left untouched; a new DataFrame is returned.
    """
    if NO_FINDING in df.columns:
        # NOTE(review): PATHOLOGIC is a plain copy of NO_FINDING here, whereas
        # the other branch yields 1 when any remaining column is non-zero.
        # This looks inverted - confirm the intended encoding.
        with_flag = df.assign(**{PATHOLOGIC: df[NO_FINDING].values})
        return with_flag.drop(columns=[NO_FINDING, *get_labels()])

    # Column names are kept exactly as they appear in the frame.
    metadata_columns = ["Image Index", "Follow-up #", "Height]", "y]", "Patient ID", "Patient Age", "Patient Gender", "OriginalImage[Width",  "OriginalImagePixelSpacing[x", "View Position"]
    remaining_values = df.drop(columns=metadata_columns).values
    patho_idx = np.clip(remaining_values.sum(axis=1), 0, 1)
    # assign() builds a new frame instead of writing into the caller's object.
    return df.assign(**{PATHOLOGIC: patho_idx}).drop(columns=get_labels())

In [None]:
#export
# From: https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81
import torch

class SAM(torch.optim.Optimizer):
    """Two-step "sharpness-aware" optimizer wrapper around a base optimizer.

    Usage per batch: backward pass, `first_step` (moves the weights along the
    scaled gradient), second backward pass at the moved weights, then
    `second_step` (undoes the move and lets the base optimizer update).

    Args:
        params: parameters or parameter groups to optimize.
        base_optimizer: optimizer class; instantiated with this object's
            parameter groups and `**kwargs`.
        rho: non-negative scale of the perturbation applied in `first_step`.
        **kwargs: forwarded to `base_optimizer` and stored as group defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        super().__init__(params, dict(rho=rho, **kwargs))

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share one list of groups so both optimizers see the same settings.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient by `rho * grad / ||grad||`."""
        total_norm = self._grad_norm()
        for group in self.param_groups:
            step_scale = group["rho"] / (total_norm + 1e-12)

            for param in group["params"]:
                if param.grad is not None:
                    perturbation = param.grad * step_scale.to(param)
                    param.add_(perturbation)  # move to "w + e(w)"
                    # Remembered so that second_step can undo the move.
                    self.state[param]["e_w"] = perturbation

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation from `first_step`, then step the base optimizer."""
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.sub_(self.state[param]["e_w"])  # back to "w"

        # The gradients in place now were computed at the perturbed weights.
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        """Return the 2-norm over the 2-norms of all existing parameter gradients."""
        # All partial norms are moved to the first parameter's device.
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = []
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    per_param_norms.append(param.grad.norm(p=2).to(shared_device))
        return torch.norm(torch.stack(per_param_norms), p=2)

In [None]:
#export
def train_SAM(model, 
              criterion, 
              optimizer, 
              dataloader, 
              model_name,
              lr=1e-3,
              scheduler=EmptyScheduler(),
              labels=get_labels(), 
              device=torch.device("cuda:0")):
    """Train `model` for one pass over `dataloader` using a two-step optimizer.

    For every batch the loss is back-propagated, `optimizer.first_step` is
    called, the loss is evaluated and back-propagated a second time, and
    `optimizer.second_step` is called. The reported loss, accuracy and AUROC
    are based on the first forward pass.

    Args:
        model: module to train; moved to `device` and set to train mode.
        criterion: loss callable invoked as `criterion(logits, y)`.
        optimizer: optimizer offering `zero_grad`, `first_step(zero_grad=...)`
            and `second_step(zero_grad=...)`.
        dataloader: iterable yielding `(X, y)` batches.
        model_name: forwarded to `auroc_score`.
        lr: only shown in the progress-bar text; not applied to the optimizer.
        scheduler: object whose `step()` is called once per batch.
        labels: unused; kept for interface compatibility.
        device: device on which the model and the batches are placed.

    Returns:
        (model, stats) where stats has the keys "loss" and "acc" (np.ndarrays
        with one entry per batch) and "auroc" (first value returned by
        `auroc_score` over all batches of the epoch).
    """
    sigmoid = NN.Sigmoid()
    model.to(device)
    model.train()

    running_loss, running_acc = [], []
    running_predictions_np, running_y_np = [], []

    datalooper = tqdm(dataloader)
    for X, y in datalooper:
        if len(running_loss) > 0: info = f"Loss: {running_loss[-1]:.03f}, Acc: {running_acc[-1]:.03f}, lr: {lr:.08f}"
        else: info = "no data yet"
        datalooper.set_description(info)

        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()

        logits = model(X)
        predictions = sigmoid(logits)
        loss = criterion(logits, y)

        # First pass: gradients at the current weights.
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # Second pass: a fresh forward/backward after first_step changed the weights.
        criterion(model(X), y).backward() 
        optimizer.second_step(zero_grad=True)

        scheduler.step()

        predictions_np = predictions.cpu().detach().numpy()
        y_np = y.cpu().detach().numpy()
        # extend() appends the batch rows in place; rebuilding the whole list
        # for every batch would make the loop quadratic.
        running_predictions_np.extend(predictions_np)
        running_y_np.extend(y_np)

        # Loss
        running_loss.append(loss.item())

        # Accuracy
        thresholded = predictions_np >= 0.5
        running_acc.append(accuracy_score(y_np, thresholded))

    running_loss = np.array(running_loss)
    running_acc = np.array(running_acc)
    epoch_auroc, _ = auroc_score(np.array(running_predictions_np), np.array(running_y_np), model_name)
    epoch_info = f"Loss: {running_loss.mean():.03f}, Acc: {running_acc.mean():.03f}, AUROC: {epoch_auroc:.03f}"
    print("Train:", epoch_info)

    return model, {"loss": running_loss, "acc": running_acc, "auroc": epoch_auroc}

In [None]:
#export
# File-name builders for the per-model insight arrays; `m` is the model name.

def pos_images_name(m):
    """File name of the saved positive images."""
    return f"{m}_pos_images.npy"

def neg_images_name(m):
    """File name of the saved negative images."""
    return f"{m}_neg_images.npy"


def pos_nt_name(m):
    """File name of the saved positive noise tunnels."""
    return f"{m}_pos_noisetunnnels.npy"

def neg_nt_name(m):
    """File name of the saved negative noise tunnels."""
    return f"{m}_neg_noisetunnnels.npy"


def pos_probs_name(m):
    """File name of the saved positive probabilities."""
    return f"{m}_pos_probs.npy"

def neg_probs_name(m):
    """File name of the saved negative probabilities."""
    return f"{m}_neg_probs.npy"


def pos_truths_name(m):
    """File name of the saved positive ground truths."""
    return f"{m}_pos_truths.npy"

def neg_truths_name(m):
    """File name of the saved negative ground truths."""
    return f"{m}_neg_truths.npy"

def get_insights_path(model_name):
    """Return `<working dir>/insights/<model_name>`, creating both levels if needed.

    Existence is left to `mkdir(exist_ok=True)` rather than compared by
    directory stem: a stem drops everything after the last dot, so a name such
    as "net_v1.5" would never be recognised as already present and the second
    call would fail with FileExistsError.

    Raises:
        FileNotFoundError: if the working directory itself does not exist.
        FileExistsError: if one of the two paths exists but is not a directory.
    """
    insights_dir = get_working_dir()/"insights"
    model_dir = insights_dir/model_name
    # Two separate calls so that both levels are created with the same mode.
    insights_dir.mkdir(mode=0o777, parents=False, exist_ok=True)
    model_dir.mkdir(mode=0o777, parents=False, exist_ok=True)
    return model_dir

def save_insights(pos_images, neg_images, 
                  pos_noise_tunnels, neg_noise_tunnels,
                  pos_probs, neg_probs, 
                  pos_truths, neg_truths,
                  model_name):
    """Write the eight insight arrays of a model as .npy files.

    The files go into the directory returned by `get_insights_path(model_name)`;
    each file name comes from the matching `*_name` helper.
    """
    ipath = get_insights_path(model_name)

    # (file-name builder, array) pairs, saved in this order.
    artifacts = (
        (pos_images_name, pos_images),
        (neg_images_name, neg_images),
        (pos_nt_name, pos_noise_tunnels),
        (neg_nt_name, neg_noise_tunnels),
        (pos_probs_name, pos_probs),
        (neg_probs_name, neg_probs),
        (pos_truths_name, pos_truths),
        (neg_truths_name, neg_truths),
    )
    for name_for, payload in artifacts:
        np.save(ipath/name_for(model_name), payload)

    print(f"Insights saved to '{ipath}'")

In [None]:
#export
def model_metrics_name(m):
    """File name of the metrics CSV for model name `m`."""
    return f"{m}_metrics.csv"

def load_model_metrics(model_name):
    """Load the metrics of every directory next to this model's insights directory.

    `model_name` is split at "_": everything before the last token is the
    model stem. For each directory beside the model's own insights directory,
    the last "_"-separated token of the directory name is combined with that
    stem to form the metrics file name looked up inside the directory.

    Returns:
        DataFrame with the individual metrics frames placed side by side
        (column-wise concatenation).
    """
    model_stem = "_".join(model_name.split("_")[:-1])
    parent_dir = get_insights_path(model_name).parent

    # Collect first and concatenate once instead of growing a frame per step.
    frames = []
    for stem_dir in parent_dir.ls():
        if not stem_dir.is_dir():
            continue
        extracted_label = stem_dir.name.split("_")[-1]
        metrics_file = stem_dir/model_metrics_name(f"{model_stem}_{extracted_label}")
        frames.append(load_metrics(metrics_file))
    return pd.concat(frames, axis=1)

def load_metrics(model_name):
    """Read a metrics CSV and return it with its first column as the index.

    Args:
        model_name: either a bare model name, in which case the file is looked
            up in that model's insights directory, or anything whose string
            form contains "/", which is read directly as the path of the CSV.

    Returns:
        DataFrame indexed by the first column of the CSV.
    """
    if "/" in str(model_name):
        csv_path = str(model_name)
    else:
        csv_path = str(get_insights_path(model_name)/model_metrics_name(model_name))
    # index_col=0 restores the index regardless of its header; looking the
    # column up as "Unnamed: 0" fails when the index was saved with a name.
    return pd.read_csv(csv_path, index_col=0)

def save_metrics(model_metrics, model_name):
    """Write `model_metrics` as CSV into the model's insights directory.

    The file name comes from `model_metrics_name(model_name)`; only that file
    name (not the full path) is echoed afterwards.
    """
    file_name = model_metrics_name(model_name)
    target = get_insights_path(model_name)/file_name
    model_metrics.to_csv(str(target))
    print(f"Saved metrics to '{file_name}'")