# Data setup

In [None]:
# Paths
# These strings are joined by plain concatenation (root + ..., data_root + ...), so keep a
# trailing '/' on folder names, as in the examples.
root = '' # full path to the root of the project (e.g. /media/MultiCaRe/)
data_root = root + '' # folder containing the dataset (e.g. data/)
images_folder = data_root + '' # folder containing the images (e.g. images/)
table_path = data_root + '' # table with image labels (e.g. table.csv)
save_path = root + '' # folder where to save files (e.g. models/)

taxonomy_node = '' # name of the table column used as the label for this submodel (e.g. image_type:radiology~anatomical_region:axial_region or image_type:endoscopy)

In [None]:
import albumentations, os, sys
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold
from fastai.vision.all import *

In [None]:
# Hyperparameters

# Image size in pixels: images are resized to (h, w), and the final random crop also outputs (h, w).
h, w = 224, 224

hyperparameters = {
    'MODEL DESCRIPTION': taxonomy_node,  # label column to train on (passed to get_data)
    'BS': 16,  # batch size
    'EPOCHS': 30,  # number of epochs passed to fine_tune
    'IMG_SIZE': (h, w),      # (height, width)
    'WD': 0.0,  # weight decay passed to vision_learner
    # Albumentations pipeline: composed in this order and applied through AlbumentationsTransform.
    'TRANSFORMS': [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(p=0.5),
        albumentations.Sharpen(p=0.5),
        albumentations.ColorJitter(brightness=0.3, contrast=0.5, saturation=0.5, hue=0.0, p=0.5),
        albumentations.RGBShift(p=0.5),
        albumentations.GaussianBlur(p=0.5),
        albumentations.GaussNoise(p=0.5),
        # NOTE(review): positional (min_max_height, height, width) form; newer albumentations
        # releases take size=(h, w) instead — check against the installed version.
        albumentations.RandomSizedCrop((int(0.75*h),h), h, w, p=1.0)
        ],
    'ARCH': 'convnext_tiny_in22k',  # backbone name passed to vision_learner
    'LOSS_FUNC': 'LabelSmoothingCrossEntropyFlat',  # looked up by name in this module's namespace
    'OPT_FUNC': 'Adam',  # looked up by name in this module's namespace
    'USE_OVERSAMPLING': True,  # enable class-weighted oversampling of the training DataLoader
    'SEED': 18,  # passed to set_seed and used to seed the DataLoaders' RNG
}

In [None]:
def get_data(table_path, column, image_root=None):
    '''
    Reads the label table and returns image paths, labels and groups for one label column.

    Rows whose `column` value is missing are dropped. Each image is expected at
    `<image_root>/<file[:4]>/<file[:5]>/<file>`.

    Args:
        table_path (str): path to the CSV table; it must contain the columns 'file', 'pmcid' and `column`.
        column (str): name of the label column.
        image_root (str, optional): folder containing the images. Defaults to the module-level
            `images_folder` setting.

    Returns:
        tuple of numpy arrays: (image_files, labels, groups), where groups are the 'pmcid' values.
    '''
    # Only fall back to the notebook-level setting when no folder is given explicitly.
    folder = images_folder if image_root is None else image_root

    table = pd.read_csv(table_path)
    # Keep only the rows that have a label for this column.
    table = table[table[column].notnull()]

    # Images live in nested subfolders named after the first 4 and 5 characters of the file name.
    image_files = table['file'].apply(lambda name: os.path.join(folder, name[:4], name[:5], name)).values
    labels = table[column].values
    groups = table['pmcid'].values
    return image_files, labels, groups

def create_df(image_files, labels, groups, n_splits=10, n_valid=2):
    '''
    Builds the training table and marks a group-aware validation split.

    The samples are split into `n_splits` folds with StratifiedGroupKFold: folds are stratified
    by `labels`, and all samples sharing a `groups` value land in the same fold. The samples of
    folds 0 .. n_valid-1 form the validation set; the rest are used for training.

    Args:
        image_files (array-like): path of each image.
        labels (array-like): label of each image (also used for stratification).
        groups (array-like): group id of each image; a group never spans both splits.
        n_splits (int): number of folds.
        n_valid (int): number of folds assigned to the validation set.

    Returns:
        df (pandas.core.frame.DataFrame): columns ['file_path', 'label', 'groups', 'is_valid'].
    '''
    df = pd.DataFrame()
    df['file_path'] = image_files
    df['label'] = labels
    df['groups'] = groups

    # Fold number of each sample, stored by position: cv.split yields positional indices,
    # so this does not depend on the dataframe's index labels.
    fold = np.full(len(df), -1)
    cv = StratifiedGroupKFold(n_splits=n_splits)
    for i, (_, valid_idxs) in enumerate(cv.split(image_files, labels, groups)):
        fold[valid_idxs] = i

    # The first `n_valid` folds form the validation set.
    df['is_valid'] = np.isin(fold, np.arange(n_valid))
    return df

In [None]:
# Dataframe: one row per image that has a value in the selected label column, with a
# group-aware train/validation split (see create_df).
image_files, labels, groups = get_data(table_path, hyperparameters['MODEL DESCRIPTION'])
df = create_df(image_files, labels, groups)

# Class distribution (shown as the cell output).
df['label'].value_counts()

In [None]:
# Seed the random generators fastai manages; reproducible=True additionally asks fastai for deterministic cuDNN behaviour.
set_seed(hyperparameters['SEED'], reproducible=True)

class AlbumentationsTransform(DisplayedTransform):
    '''
    Class that allows the use of Albumentations transforms in FastAI.

    Wraps an albumentations pipeline (`train_aug`, e.g. an `albumentations.Compose`) as a fastai
    item transform: the PIL image is converted to a numpy array, augmented, and converted back
    to a PILImage.
    '''
    # split_idx=0: per fastai's Transform semantics, applied only to the training split
    # (validation images are not augmented). order=2 sets its position among the item
    # transforms (presumably after fastai's Resize — confirm for the installed fastai version).
    split_idx,order=0,2
    # store_attr() saves the `train_aug` argument as self.train_aug.
    def __init__(self, train_aug): store_attr()

    # fastai type-dispatches on the annotation, so only PILImage inputs go through here.
    def encodes(self, img: PILImage):
        aug_img = self.train_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

# Determine the number of tasks: each label string holds one class per task, separated by ', '.
# Inferred from the first row; every row is assumed to have the same number of parts.
n_tasks = len(df['label'].iloc[0].split(', '))

# Datablock: one image input followed by one CategoryBlock per task.
block = DataBlock(
    blocks=(ImageBlock,) + (CategoryBlock,) * n_tasks,
    n_inp=1,
    get_x=ColReader('file_path'),
    # `i=i` binds the current task index into each lambda (avoids late binding in the comprehension).
    get_y=[lambda x, i=i: x['label'].split(', ')[i] for i in range(n_tasks)],
    splitter=ColSplitter(col='is_valid'),
    item_tfms=[
        Resize(hyperparameters['IMG_SIZE'], method='squish'),
        AlbumentationsTransform(albumentations.Compose(hyperparameters['TRANSFORMS']))])

# Dataloaders
dls = block.dataloaders(df, bs=hyperparameters['BS'], shuffle=True)
dls.rng.seed(hyperparameters['SEED'])

# Sanity check. Wrap the single-task case in lists so that n_classes / classes always hold one
# entry per task (used later as sum(n_classes) and classes[i]).
n_classes = dls.c if n_tasks>1 else [dls.c]
classes = dls.vocab if n_tasks>1 else [dls.vocab]
print('Number of classes: ', n_classes)
print('Names of classes: ', classes)

In [None]:
# Show batch: display up to 16 training samples with their labels.
dls.train.show_batch(max_n=16, figsize=(15,12))

In [None]:
# Show transforms: with unique=True, fastai repeats a single training image so the effect of
# the random augmentations can be compared side by side.
dls.train.show_batch(max_n=16, unique=True, figsize=(15,12))

In [None]:
# Loss function: looked up by name (hyperparameters['LOSS_FUNC']) in this module's namespace,
# where the fastai star import makes it available, and instantiated with weight=None.
# NOTE(review): with n_tasks > 1 the dataloaders yield one target per task; confirm that the
# loss and metrics accept that (they are standard single-target fastai objects).
loss = getattr(sys.modules[__name__], hyperparameters['LOSS_FUNC'])(weight=None)
# Accuracy and macro-averaged F1 are reported during training.
metrics = [accuracy, F1Score(average='macro')]
# SaveModelCallback monitors 'f1_score' (saving with optimizer state); ShowGraphCallback plots
# the losses live. These are passed to fine_tune in the training cell.
callbacks = [SaveModelCallback(monitor='f1_score', with_opt=True), ShowGraphCallback]

# Learner: pretrained backbone hyperparameters['ARCH'] with one output head of sum(n_classes)
# units (all tasks' classes concatenated), optimizer looked up by name, mixed precision (to_fp16).
learn = vision_learner(dls,
                        hyperparameters['ARCH'],
                        normalize=True,
                        pretrained=True,
                        n_out=sum(n_classes),
                        loss_func=loss,
                        opt_func=getattr(sys.modules[__name__], hyperparameters['OPT_FUNC']),
                        metrics=metrics,
                        wd=hyperparameters['WD']).to_fp16()

# Fix issue with pickling while calling learn.export:
# resolve the wrapped loss function's annotations to actual objects, then copy its metadata
# (__name__, __qualname__, __doc__, __annotations__, ...) onto the loss object itself.
import typing, functools
learn.loss_func.func.__annotations__ = typing.get_type_hints(learn.loss_func.func, globalns=globals(), localns=locals())
functools.update_wrapper(learn.loss_func, learn.loss_func.func)

In [None]:
def oversampled_epoch(self, class_weights = None):
    '''
    Draws one epoch of item positions with replacement, weighting each item by its class.

    Meant to be bound as the `get_idxs` method of the training DataLoader (see the
    USE_OVERSAMPLING branch), so `self` is that DataLoader.

    Args:
        self: object exposing `items` (a pandas DataFrame with a 'label' column) and `n`
            (number of positions to draw).
        class_weights (dict, optional): maps str(label) -> sampling weight. If None, every
            item has the same weight (plain sampling with replacement).

    Returns:
        list of int: positions (0-based) into `self.items`, one per drawn item.
    '''
    items = self.items
    weights = None if class_weights is None else items.label.apply(lambda x: class_weights[str(x)])
    sampled_labels = items.sample(n=self.n, weights=weights, replace=True).index

    # Map each index label to its first position once (O(n)) instead of scanning the whole
    # index for every drawn row (O(n^2)); keeping the first occurrence matches np.where(...)[0][0].
    first_position = {}
    for position, label in enumerate(items.index):
        first_position.setdefault(label, position)
    return [first_position[label] for label in sampled_labels]

# Oversampling: replace the training DataLoader's index generator so that each epoch draws items
# with replacement, weighted by 1/sqrt(class count) — rarer classes are drawn more often.
if hyperparameters['USE_OVERSAMPLING']:
    # Class counts come from learn.dls.items (presumably the training DataLoader's items — confirm).
    # A dict comprehension avoids relying on the result column being named 'count', which
    # only holds from pandas 2.0 on (older versions name it after the source column).
    class_weights = {str(label): 1 / np.sqrt(count) for label, count in learn.dls.items.label.value_counts().items()}
    learn.dls.train.get_idxs = types.MethodType(partial(oversampled_epoch, class_weights=class_weights), learn.dls.train)

# Training

In [None]:
# Find LR: run fastai's learning-rate range test to help choose the value set in
# hyperparameters['LR'] in the next cell.
learn.lr_find()

In [17]:
# Set LR: chosen by hand (e.g. from the lr_find output); used as base_lr for fine_tune.
hyperparameters['LR'] = 3e-3

In [None]:
# Train: fastai fine_tune for hyperparameters['EPOCHS'] epochs at base_lr, with the callbacks
# defined with the learner (SaveModelCallback on 'f1_score', ShowGraphCallback).
learn.fine_tune(hyperparameters['EPOCHS'], base_lr=hyperparameters['LR'], cbs=callbacks)

# Results

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(ground_truths, predictions, classes, path, figsize=(16,16), num_size=12, order_by_classes=False):
    '''
    Creates and plots a confusion matrix given the ground truths and the predictions of the classification model.

    Cells are colored by the row-normalized value (fraction of each true class) and annotated
    with the raw counts. The figure is also saved as 'confusion.png' inside `path`.

    Args:
        ground_truths (torch.tensor): ground truth (correct) target values.
        predictions (torch.tensor): estimated targets as returned by the model.
        classes (list): list of the classes labels; used as the rows/columns, in this order,
            when `order_by_classes` is True.
        path (str): folder where 'confusion.png' is written.
        figsize (tuple): matplotlib figure size.
        num_size (int): font size of the numbers inside the cells.
        order_by_classes (bool): if True, show every label in `classes`, in that order; if False,
            show only the labels present in `ground_truths` or `predictions`, sorted
            (scikit-learn's default label set).

    Returns:
        fig (matplotlib.figure.Figure): figure object.
    '''
    from sklearn.utils.multiclass import unique_labels

    # The same label list must drive both the matrix and the axis names. Otherwise rows and
    # columns are mislabeled (or the DataFrame shape check fails) whenever `classes` differs
    # from the label set scikit-learn infers on its own.
    labels = list(classes) if order_by_classes else list(unique_labels(ground_truths, predictions))
    cm = confusion_matrix(ground_truths, predictions, labels=labels)
    cm_norm = confusion_matrix(ground_truths, predictions, labels=labels, normalize='true')

    counts = pd.DataFrame(cm, index=labels, columns=labels)
    fractions = pd.DataFrame(cm_norm, index=labels, columns=labels)

    plt.figure(figsize = figsize)
    ax = sns.heatmap(fractions, annot=counts, fmt='d', linewidths=0.5, linecolor='black', cmap='YlGn', vmin=0, vmax=1, annot_kws={"color": "black", "size": num_size})

    # Draw a black frame around the heatmap and its colorbar.
    for _, spine in ax.spines.items():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(2)

    cbar = ax.collections[0].colorbar
    cbar.outline.set_edgecolor('black')
    cbar.outline.set_linewidth(1.5)

    ax.set_title('Confusion Matrix', fontdict={'fontsize': 32, 'fontweight': 'medium'})
    ax.set_xlabel('Predicted class', fontsize=18)
    ax.set_ylabel('True class', fontsize=18)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=12)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right', fontsize=12)
    fig = ax.get_figure()
    # os.path.join gives the same file whether or not `path` ends with a separator.
    plt.savefig(os.path.join(path, 'confusion.png'), bbox_inches='tight')

    return fig


def plot_losses(learn, path):
    '''
    Creates and plots a figure with the training and validation losses curves.

    The recorded training losses are spread evenly over [0, n_epoch]; validation losses are
    plotted once per epoch. The figure is also saved as 'losses.png' inside `path`.

    Args:
        learn (fastai.learner.Learner): trained learner object.
        path (str): folder where 'losses.png' is written.

    Returns:
        fig (matplotlib.figure.Figure): figure object.
    '''

    rec = learn.recorder
    train_losses = np.array(rec.losses)
    train_iters = np.linspace(0, learn.n_epoch, len(train_losses))
    # Index 1 of each recorded row is the validation loss (metrics follow from index 2).
    valid_losses = [v[1] for v in rec.values]
    valid_iters = np.arange(1, learn.n_epoch+1)

    plt.figure()
    sns.set(style="whitegrid")
    plot = sns.lineplot(x=train_iters, y=train_losses, label='Train', linestyle='-')
    sns.lineplot(x=valid_iters, y=valid_losses, label='Valid', marker='o', linestyle='--', color='green')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    fig = plot.figure
    # os.path.join gives the same file whether or not `path` ends with a separator.
    plt.savefig(os.path.join(path, 'losses.png'), bbox_inches='tight')

    return fig


def plot_metrics(learn, path):
    '''
    Creates and plots a figure with the curves of all metrics.

    One curve per metric, one point per epoch. The figure is also saved as 'metrics.png'
    inside `path`.

    Args:
        learn (fastai.learner.Learner): trained learner object.
        path (str): folder where 'metrics.png' is written.

    Returns:
        fig (matplotlib.figure.Figure): figure object.
    '''

    valid_iters = np.arange(1, learn.n_epoch+1)
    # Columns 2.. of each recorded row hold the metric values (0: train loss, 1: valid loss).
    met = np.array([v[2:] for v in learn.recorder.values])
    # Name each metric on its own: wrappers expose the wrapped function as `.func`, plain
    # callables carry `__name__` directly. A missing name now raises instead of being hidden
    # by a bare `except:`.
    metrics_names = [getattr(m, 'func', m).__name__ for m in learn.metrics]

    # Keep a handle on the figure directly, so an empty metric list doesn't leave it undefined.
    fig = plt.figure()
    sns.set(style="whitegrid")
    for i, name in enumerate(metrics_names):
        sns.lineplot(x=valid_iters, y=met[:, i], label=name, linestyle='-')
    plt.xlabel('Epochs')
    plt.ylabel('Metrics')
    plt.legend()
    # os.path.join gives the same file whether or not `path` ends with a separator.
    plt.savefig(os.path.join(path, 'metrics.png'), bbox_inches='tight')

    return fig


def get_predictions_table(learn, dl):
    '''
    Creates a table containing a row for each image stored in 'dl'.

    Args:
        learn (fastai.learner.Learner): trained learner object.
        dl (fastai.data.core.TfmdDL): dataloader with the images for making predictions.

    Returns:
        df (pandas.core.frame.DataFrame): table with columns=['file_name', 'ground_truth', 'prediction', 'loss', 'confidence'], sorted by 'loss' value in descending order.

    NOTE(review): assumes a single-task model (one target per item), since `learn.get_preds`
    is unpacked into exactly (probs, targets, losses).
    '''
    # Predict over an unshuffled copy that keeps the last partial batch, so predictions come
    # out in dataset order and line up one-to-one with `dl.dataset.items`. The training loader
    # shuffles, and here may also carry the oversampling `get_idxs`, which would misalign file
    # names and predictions. (fastai's DataLoader.new does not copy instance-bound method
    # overrides — confirm for the installed fastai version.)
    ordered_dl = dl.new(shuffle=False, drop_last=False)

    vocab = np.array(list(ordered_dl.vocab))
    file_paths = ordered_dl.dataset.items.file_path.values
    probs, ground_truths, losses = learn.get_preds(dl=ordered_dl, with_loss=True)
    probs = probs.numpy()

    # Built from a dict so 'loss' and 'confidence' keep numeric dtypes (sorting stays numeric).
    table = pd.DataFrame({
        'file_name': file_paths,
        'ground_truth': vocab[ground_truths.numpy()],
        'prediction': vocab[probs.argmax(axis=1)],
        'loss': losses.numpy(),
        'confidence': probs.max(axis=1),
    })

    return table.sort_values(by='loss', ascending=False)

def get_metrics(learn):
    '''
    Returns a dictionary with the names and values of the metrics.

    Runs `learn.validate()` and pairs each metric's function name with its value; the first
    value returned by validate (the loss) is left out.

    Args:
        learn (fastai.learner.Learner): trained learner object.

    Returns:
        dict: {metric_name: value}.
    '''
    # Name each metric on its own: wrappers expose the wrapped function as `.func`, plain
    # callables carry `__name__` directly. A missing name now raises instead of being hidden
    # by a bare `except:`.
    names = [getattr(m, 'func', m).__name__ for m in learn.metrics]
    values = learn.validate()[1:]
    return dict(zip(names, values))

In [None]:
import dill, itertools

# Export model: learn.export pickles the Learner with dill (presumably because the DataBlock's
# get_y lambdas cannot be pickled by the standard pickle module); learn.save stores the model
# under '<save_path>/model'. os.path.join gives the same files whether or not save_path ends
# with '/', and never writes to the filesystem root when save_path is empty.
learn.export(os.path.join(save_path, 'model.pkl'), pickle_module=dill)
learn.save(os.path.join(save_path, 'model'))

# Plot losses and metrics across training
_ = plot_losses(learn, save_path)
_ = plot_metrics(learn, save_path)

# Get confusion matrix
probs, ground_truths = learn.get_preds(ds_idx=1)        # DO NOT PREDICT BEFORE PLOTTING LOSSES AND METRICS
# One target tensor per task, so single- and multi-task runs share the code below.
ground_truths = ground_truths if n_tasks>1 else [ground_truths]
# The head outputs all tasks' classes concatenated: slice each task's n_classes[i] columns and take its argmax.
predictions = [np.argmax(probs[:,sum(n_classes[:i]):sum(n_classes[:i+1])], axis=1) for i in range(n_tasks)]
# Decode each sample's per-task class indices to names and join them with ' ' into one combined label.
decoded_preds = [' '.join([classes[i][p] for i, p in enumerate(tensor(g))]) for g in zip(*predictions)]
decoded_gts = [' '.join([classes[i][p] for i, p in enumerate(tensor(g))]) for g in zip(*ground_truths)]
# Every combination of per-task classes, in the same ' '-joined format.
new_vocab = [' '.join(i) for i in itertools.product(*classes)]
_ = plot_confusion_matrix(decoded_gts, decoded_preds, new_vocab, save_path)

# Get tables of predictions
train_table = get_predictions_table(learn, learn.dls.train)
valid_table = get_predictions_table(learn, learn.dls.valid)
train_table.to_csv(os.path.join(save_path, 'train_table.csv'), index=False)
valid_table.to_csv(os.path.join(save_path, 'valid_table.csv'), index=False)

# Get final metrics
results = get_metrics(learn)
print(results)