## Options

In [2]:
# Paths (placeholders: replace them with real locations before running).
# The sub-paths below are appended to `root` by plain string concatenation,
# so `root` must end with a path separator.
root = 'path to the root folder'
# get_data() builds image paths as f'{images_folder}{s[:4]}/{s[:6]}/{s}',
# so this value must end with a path separator as well.
images_folder = root + 'subpath to the folder with the images'
# Folder that receives the exported model; it is also handed to the plotting helpers.
save_path = root + 'models/'
# CSV read by get_data(); the columns 'file', 'label_list_with_negative_classes'
# and 'patient_id' are read from it.
table_path = root + 'subpath to multicare_multiplex.csv'

# Settings
# When True, the index sampling of the training dataloader is replaced by `oversampled_epoch`.
use_oversampling = False # select True or False

# Data setup

In [None]:
from fastai.vision.all import *
import albumentations
from DLOlympus.training.transforms import AlbumentationsTransform
from DLOlympus.training.utils import get_model
from DLOlympus.training.unbalanced import oversampled_epoch

In [4]:
from sklearn.metrics import f1_score

# Class ids of the two label families that are combined into composed classes.
# NOTE(review): presumably these are positions in the dataloaders' vocab
# (new_vocab() deletes exactly these positions) — confirm against dls.vocab.
_RADIOLOGY_CLASS_IDS = [7, 9, 22, 24, 28]
_ANGIOGRAPHY_CLASS_IDS = [2, 23]

# Kept as float tensors; callers cast them with .int() before indexing.
radiology_ids = torch.Tensor(_RADIOLOGY_CLASS_IDS)
angiography_ids = torch.Tensor(_ANGIOGRAPHY_CLASS_IDS)

# Every (angiography id, radiology id) pair, stored as a sorted tuple, gets its own
# class id. Ids are handed out consecutively from 29, iterating radiology ids in the
# outer loop and angiography ids in the inner loop.
mapping = {
    tuple(sorted((ang_id, rad_id))): 29 + len(_ANGIOGRAPHY_CLASS_IDS) * rad_pos + ang_pos
    for rad_pos, rad_id in enumerate(_RADIOLOGY_CLASS_IDS)
    for ang_pos, ang_id in enumerate(_ANGIOGRAPHY_CLASS_IDS)
}

def new_classes(x):
    """Translate a tensor of class ids into a single class id.

    The ids in `x` are read as a tuple of ints and looked up in `mapping`.
    When the tuple is not a key of `mapping`, the first id of `x` is returned
    as a plain int.
    """
    ids = tuple(x.int().tolist())
    if ids in mapping:
        return mapping[ids]
    # No composed class for this combination: keep the first id
    return int(x[0].item())

def multilabel2multiclass(probs, ground_truths):
    """Convert multilabel scores and targets into single multiclass ids.

    For every sample the top-scoring class is taken. If it belongs to
    `angiography_ids`, it is paired with the best-scoring class among
    `radiology_ids` (and the other way round); the sorted pair is then translated
    by `new_classes`. Any other top class is passed to `new_classes` on its own.
    Ground truths are converted by passing the indices of their non-zero entries
    to `new_classes`.

    Args:
        probs: per-sample class scores, iterated along the first dimension
            (assumes shape (n_samples, n_classes) — TODO confirm).
        ground_truths: per-sample targets with the same first dimension. Every row
            needs at least one non-zero entry, because `new_classes` indexes the
            first element of the index tensor it receives.

    Returns:
        Tuple `(predictions, targets)` of 1-D `torch.Tensor`s with the new ids.
    """
    probs, ground_truths = probs.cpu(), ground_truths.cpu()
    new_preds, new_gts = [], []
    for sample_scores, sample_gt in zip(probs, ground_truths):
        # Id of the top prediction
        top = sample_scores.argmax()
        if top in angiography_ids:
            # Angiography type predicted: add the most likely radiology type
            partner = radiology_ids[sample_scores[radiology_ids.int()].argmax()]
            pred_ids = torch.stack((top, partner))
        elif top in radiology_ids:
            # Radiology type predicted: add the most likely angiography type
            partner = angiography_ids[sample_scores[angiography_ids.int()].argmax()]
            pred_ids = torch.stack((top, partner))
        else:
            # Class outside both families: keep it as a one-element tensor
            pred_ids = top[None]
        # Sorted ids are the key format used by `mapping`
        pred_ids, _ = torch.sort(pred_ids)
        new_preds.append(new_classes(pred_ids))
        # Same conversion for the ground truth, starting from its active indices
        gt_ids, _ = torch.sort(sample_gt.nonzero())
        new_gts.append(new_classes(gt_ids[:, 0]))
    return torch.Tensor(new_preds), torch.Tensor(new_gts)

def new_accuracy(probs, ground_truths):
    """Accuracy after converting predictions and targets with `multilabel2multiclass`.

    Returns the mean, as a tensor, of the element-wise equality between the
    converted predictions and the converted targets.
    """
    pred_ids, target_ids = multilabel2multiclass(probs, ground_truths)
    hits = pred_ids == target_ids
    return hits.float().mean()

def _accumulate(self, learn):
    """Replacement for `AccumMetric.accumulate` that stores multiclass ids.

    The batch predictions (`learn.pred`) and targets (`learn.y`) are detached,
    converted with `multilabel2multiclass`, and appended to `self.preds` and
    `self.targs`.

    No activation is applied to the predictions: `multilabel2multiclass` only
    uses `argmax` on the scores, and a sigmoid would not change which class
    scores highest.
    """
    pred, targ = to_detach(learn.pred), to_detach(learn.y)
    pred, targ = multilabel2multiclass(pred, targ)
    self.preds.append(pred)
    self.targs.append(targ)
# The method is replaced on the class itself, so every AccumMetric instance created in
# this session (not only the F1 metric below) accumulates converted ids from now on.
AccumMetric.accumulate = _accumulate
def NewF1Score():
    """Return scikit-learn's macro-averaged `f1_score` wrapped as a fastai metric.

    NOTE(review): presumably the wrapper is an `AccumMetric`, so it receives the
    converted ids through the patched `AccumMetric.accumulate` — confirm.
    """
    macro_f1 = skm_to_fastai(f1_score, average='macro')
    return macro_f1

In [5]:
# Hyperparameters

# Image height and width used for resizing and for the final random crop
h, w = 224, 224

# Single dictionary with every training setting. Entries are read when building the
# dataloaders ('BS', 'IMG_SIZE', 'TRANSFORMS', 'SEED'), the learner ('LOSS_FUNC',
# 'OPT_FUNC', 'WD'; the whole dict is also passed to get_model) and for training
# ('EPOCHS'). An 'LR' entry is added later, after the learning-rate search.
hyperparameters = {
    'model_description': 'multilabel',
    'BS': 16,
    'EPOCHS': 30,
    'IMG_SIZE': (h, w),      # (height, width)
    'WD': 0.0,
    # Albumentations augmentations; they are wrapped in albumentations.Compose and
    # applied as an item transform right after the Resize to IMG_SIZE.
    'TRANSFORMS': [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(p=0.5),
        albumentations.Sharpen(p=0.5),
        albumentations.ColorJitter(brightness=0.3, contrast=0.5, saturation=0.5, hue=0.0, p=0.5),
        albumentations.RGBShift(p=0.5),
        albumentations.GaussianBlur(p=0.5),
        albumentations.GaussNoise(p=0.5),
        # Always applied (p=1.0); positional arguments follow the RandomSizedCrop
        # signature of the installed albumentations version — TODO confirm on upgrade
        albumentations.RandomSizedCrop((int(0.75*h),h), h, w, p=1.0)
        ],
    'ARCH': 'resnet50',
    'ARCH_TYPE': 'torchvision',
    # Loss and optimizer are given by name and resolved with getattr() on this
    # notebook's module namespace when the learner is created.
    'LOSS_FUNC': 'BCEWithLogitsLossFlat',
    'OPT_FUNC': 'Adam',
    'USE_OVERSAMPLING': use_oversampling,
    'SEED': 18,
}

# Metrics and callbacks
# Both metrics are computed on the multiclass ids produced by multilabel2multiclass.
metrics = [new_accuracy, NewF1Score()]
# NOTE(review): 'f1_score' is presumably the name fastai gives to the metric built by
# NewF1Score() — confirm it matches the recorder column. ShowGraphCallback is passed
# as a class, not as an instance.
callbacks = [SaveModelCallback(monitor='f1_score', with_opt=True), ShowGraphCallback]

In [6]:
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

def get_gt(x):
    """Parse the string form of a label list and drop the 'ultrasound' label.

    Args:
        x: string such as "['ct', 'angiography']", i.e. label names wrapped in
            single quotes.

    Returns:
        List of label names, in their original order. As before, names that are a
        single character long or that contain a space are discarded, and so is
        'ultrasound'. An empty list string ("[]") yields an empty list.
    """
    # After splitting on the quote character, odd positions hold the quoted names and
    # even positions hold brackets and separators. Taking only the odd positions keeps
    # bracket text such as "[]" from being mistaken for a label.
    quoted = x.split("'")[1::2]
    return [s for s in quoted if len(s) > 1 and ' ' not in s and s != 'ultrasound']

def get_data(table_path):
    """Load the metadata table and return image paths, labels and groups.

    Args:
        table_path: path of the CSV file to read.

    Returns:
        Tuple `(image_files, labels, groups)`:
        - image_files: numpy array of paths built from the 'file' column,
        - labels: pandas Series with the 'label_list_with_negative_classes' column
          parsed by `get_gt`,
        - groups: numpy array with the 'patient_id' column.
    """
    table = pd.read_csv(table_path)
    # Each image sits in nested folders named after the first 4 and the first 6
    # characters of its file name, below the global `images_folder`.
    paths = []
    for file_name in table['file'].values:
        paths.append(f'{images_folder}{file_name[:4]}/{file_name[:6]}/{file_name}')
    image_files = np.array(paths)
    labels = table['label_list_with_negative_classes'].apply(get_gt)
    groups = table['patient_id'].values
    return image_files, labels, groups

def create_df(image_files, labels, groups, n_splits=10, n_valid=2):
    """Build the training dataframe with a grouped, stratified validation split.

    Samples are distributed over `n_splits` folds with `StratifiedGroupKFold`,
    stratifying on the string form of each label list and keeping all samples of
    a group in the same fold. The first `n_valid` folds become the validation set.

    Args:
        image_files: array of image paths.
        labels: pandas Series with one label list per image.
        groups: array with the group (patient) of each image.
        n_splits: number of folds to create.
        n_valid: number of folds (0 .. n_valid-1) flagged as validation.

    Returns:
        DataFrame with the columns 'file_path', 'label', 'groups' and the boolean
        column 'is_valid'.
    """
    # Initiate dataframe
    df = pd.DataFrame()
    df['file_path'] = image_files
    df['label'] = labels.values
    df['groups'] = groups
    # Fold number of each row; every row is overwritten by exactly one fold below
    df['fold'] = -1
    cv = StratifiedGroupKFold(n_splits=n_splits)
    # Label lists are stratified through their string form
    for i, (_, valid_idxs) in enumerate(cv.split(image_files, labels.apply(str), groups)):
        df.loc[valid_idxs, 'fold'] = i
    # Rows in folds 0 .. n_valid-1 form the validation set (bounds are inclusive)
    df['is_valid'] = df['fold'].between(0, n_valid - 1)
    del df['fold']
    return df

In [None]:
# Dataframe
# Read the metadata table, then build the dataframe with the 'is_valid' split flag.
image_files, labels, groups = get_data(table_path)
df = create_df(image_files, labels, groups)

# Cell output: number of images per label combination
df['label'].value_counts()

In [None]:
# Seed the random number generators before any random operation takes place
# (the second argument is fastai's `reproducible` flag).
set_seed(hyperparameters['SEED'], True)

# Datablock
# Images and label lists are read from the dataframe columns; the train/validation
# split comes from the boolean 'is_valid' column.
block = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    get_x=ColReader('file_path'),
    get_y=ColReader('label'),
    splitter=ColSplitter(col='is_valid'),
    item_tfms=[
        # Images are first squished to IMG_SIZE, then augmented with albumentations
        Resize(hyperparameters['IMG_SIZE'], method='squish'),
        AlbumentationsTransform(albumentations.Compose(hyperparameters['TRANSFORMS']))])

# Dataloaders
dls = block.dataloaders(df, bs=hyperparameters['BS'], shuffle=True)
# Seed the dataloaders' own random generator as well
dls.rng.seed(hyperparameters['SEED'])

# Sanity check
num_classes = dls.c
classes = dls.vocab
print('Number of classes: ', num_classes)
print('Names of classes: ', classes)

In [None]:
# Show batch
# Display up to 16 samples taken from the training dataloader.
dls.train.show_batch(max_n=16, figsize=(15,12))

In [None]:
# Show transforms
# NOTE(review): per the fastai docs, unique=True repeats a single item, so the grid
# only differs by the random augmentations — confirm for the installed version.
dls.train.show_batch(max_n=16, unique=True, figsize=(15,12))

In [None]:
# Learner
# The loss class and the optimizer function are looked up by name in this notebook's
# module namespace; the loss is instantiated here, the optimizer is passed uncalled.
# NOTE(review): `sys` and both names are presumably provided by the fastai star
# import — confirm.
learn = vision_learner(dls,
                        get_model(hyperparameters),
                        normalize=True,
                        pretrained=True,
                        loss_func=getattr(sys.modules[__name__], hyperparameters['LOSS_FUNC'])(),
                        opt_func=getattr(sys.modules[__name__], hyperparameters['OPT_FUNC']),
                        metrics=metrics,
                        wd=hyperparameters['WD']).to_fp16()

# Fix issue with pickling while calling learn.export
# The annotations of the wrapped loss function are replaced by their resolved types,
# and the wrapped function's metadata is then copied onto the loss object.
import typing, functools
learn.loss_func.func.__annotations__ = typing.get_type_hints(learn.loss_func.func, globalns=globals(), localns=locals())
functools.update_wrapper(learn.loss_func, learn.loss_func.func)

In [16]:
# Oversampling
# Optional: replace the index sampling of the training dataloader with
# `oversampled_epoch`, weighting each label combination by 1 / sqrt(its frequency).
if hyperparameters['USE_OVERSAMPLING']:
    label_counts = learn.dls.items.label.value_counts()
    # Keys are the string form of each label combination. The dict is built straight
    # from the Series so it does not depend on the name pandas gives to the
    # value_counts() result ('count' only exists from pandas 2.0 on).
    class_weights = {
        str(label): float(weight)
        for label, weight in (1 / np.sqrt(label_counts)).items()
    }
    # Bind the sampler to the training dataloader so it is called as a method
    learn.dls.train.get_idxs = types.MethodType(partial(oversampled_epoch, class_weights=class_weights), learn.dls.train)

# Training

In [None]:
# Find LR
# Run the learning-rate finder; the value chosen from its output is typed in by hand
# in the next cell.
learn.lr_find()

In [18]:
# Set LR
# Stored with the other hyperparameters and used as base_lr for fine-tuning.
hyperparameters['LR'] = 3e-3

In [None]:
# Train
# Fine-tune for EPOCHS epochs with the callbacks defined alongside the metrics.
# NOTE(review): fastai's fine_tune presumably adds a frozen warm-up phase
# (freeze_epochs) before these epochs — confirm for the installed version.
learn.fine_tune(hyperparameters['EPOCHS'], base_lr=hyperparameters['LR'], cbs=callbacks)

# Results and logs

In [16]:
def new_vocab(vocab):
    """Build the class list that goes with the converted multiclass ids.

    A copy of `vocab` is extended with the ten composed class names, then the
    entries at the positions listed in `angiography_ids` and `radiology_ids` are
    removed. The result is returned as a `CategoryMap` that keeps this order.
    """
    # NOTE(review): the order of these names presumably follows ids 29-38 of
    # `mapping` — confirm before changing either of them.
    composed_names = (
        'ct + angiography',
        'ct + not_angiography',
        'echocardiogram + angiography',
        'echocardiogram + not_angiography',
        'mri + angiography',
        'mri + not_angiography',
        'other_ultrasound + angiography',
        'other_ultrasound + not_angiography',
        'x_ray + angiography',
        'x_ray + not_angiography',
    )
    names = list(copy(vocab))
    names.extend(composed_names)
    # Positions of the single classes that only exist as part of a composed class
    single_ids = np.concatenate((angiography_ids, radiology_ids)).astype(int)
    # Delete from the highest position down so the remaining positions stay valid
    for position in sorted(single_ids.tolist(), reverse=True):
        del names[position]
    return CategoryMap(names, sort=False)

In [None]:
# Persist the trained learner: export() writes the learner to model.pkl, and save()
# writes a checkpoint named 'model'. `save_path` already ends with '/', so these
# paths contain a double slash.
learn.export(f'{save_path}/model.pkl')
learn.save(f'{save_path}/model')

# Training curves are plotted first, from the state recorded during training
from DLOlympus.training.plots import plot_confusion_matrix, plot_losses, plot_metrics
_ = plot_losses(learn, save_path)
_ = plot_metrics(learn, save_path)
# ds_idx=1 selects the validation set (fastai convention)
probs, ground_truths = learn.get_preds(ds_idx=1)        # DO NOT PREDICT BEFORE PLOTTING LOSSES AND METRICS
# Convert the multilabel outputs and targets to multiclass ids, and build the matching
# class names, for the confusion matrix
predictions, ground_truths_trans = multilabel2multiclass(probs, ground_truths)
vocab = new_vocab(learn.dls.vocab)
_ = plot_confusion_matrix(ground_truths_trans, predictions, vocab, save_path, figsize=(22,16))

# Final metrics, computed without test-time augmentation
from DLOlympus.training.utils import get_metrics
results = get_metrics(learn, with_tta=False)