In [None]:
# Root directory that holds the downloaded / prepared datasets.
DATA_PATH = '/data/datasets'
# Root directory for experiment outputs; runs are written below `{OUTPUT_PATH}/results/`.
OUTPUT_PATH = '/data/tlbn/output'
# Upper bound on the number of epochs passed to `model.fit` (early stopping may end training sooner).
MAX_EPOCHS = 500
# NOTE(review): presumably the number of repeated runs per experiment; it is not
# referenced in the visible part of the notebook — confirm against the cells that use it.
REPETITIONS = 3

In [None]:
import datetime
import json
import shutil
from functools import partial
from pathlib import Path

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Turn on memory growth for every visible GPU so TensorFlow allocates device
# memory on demand instead of reserving it all when the runtime starts.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Best effort: report the problem and keep going with the default
        # allocation strategy. Per the TensorFlow GPU guide this error is raised
        # when memory growth is changed after the GPUs were already initialized.
        print(e)

## Datasets

### CheXpert

In [None]:
def load_chexpert(**kwargs):
    """
    Build the CheXpert train / validation / test generators (frontal views only).

    Download CheXpert-small from https://stanfordmlgroup.github.io/competitions/chexpert/
    and unpack it under ``{DATA_PATH}/chexpert/downloads/manual/``.

    Studies with an uncertain (-1) annotation for any of the used labels are
    dropped. The remaining images are split *by patient* into disjoint train /
    validation / test sets, and a fixed number of images per label is then
    sampled from each split.

    Parameters
    ----------
    **kwargs
        Accepted so that all loaders share one call signature (for example
        ``mode='small'``); currently ignored.

    Returns
    -------
    tuple
        ``(train_ds, val_ds, test_ds, params)`` where the datasets are Keras
        dataframe iterators and ``params`` describes the dataset. On return
        ``params.labels`` no longer contains 'No Finding'.

    Raises
    ------
    ValueError
        Propagated from ``DataFrame.sample`` when a split holds fewer positive
        images for a label than requested.
    """
    class params:
        batch_size = 32
        input_shape = (224, 224, 3)
        labels = ['No Finding', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
        mode = 'multi-binary'
        dropout = 0
        loss = 'binary_crossentropy'
        flipping = "horizontal"

    labels = params.labels
    image_root = f'{DATA_PATH}/chexpert/downloads/manual/'
    df = pd.read_csv(f'{DATA_PATH}/chexpert/downloads/manual/CheXpert-v1.0-small/train.csv', index_col=0)
    # drop lateral
    df = df[df['Frontal/Lateral'] == 'Frontal']

    # drop uncertain
    for label in labels:
        df = df[(df[label] != -1)]

    # Work on an explicit copy so the column assignment below does not write into
    # a slice of the original frame.
    # NOTE(review): blank cells in the CSV are read as NaN; they are mapped to 0
    # (negative) here so the raw targets fed to the loss never contain NaN.
    # Confirm this matches the intended labelling policy.
    df = df[labels].fillna(0)
    # The index is the image path; its third component is used as the patient id.
    df['Patient'] = [image_path.split('/')[2] for image_path in df.index]

    # Patient-level split: 30% test, then 30% of the remainder for validation.
    patient_ids = pd.Series(df['Patient'].unique())
    test_patients = patient_ids.sample(frac=0.3, random_state=0)
    train_patients = patient_ids.drop(test_patients.index)
    val_patients = train_patients.sample(frac=0.3, random_state=0)
    train_patients = train_patients.drop(val_patients.index)
    train = df[df['Patient'].isin(train_patients)]
    # Fix: validation must use the held-out validation patients. It previously
    # re-used `train_patients`, so validation images overlapped the training set.
    val = df[df['Patient'].isin(val_patients)]
    test = df[df['Patient'].isin(test_patients)]

    def subsample(df, labels, num_per_label=1000):
        """Sample `num_per_label` positives per label (plus an extra 'No Finding' draw) without duplicates."""
        data = [df[df[label] == 1].sample(num_per_label, random_state=seed)
                for seed, label in enumerate(labels)]
        # Second 'No Finding' draw, seeded with the last label index as before.
        data.append(df[df['No Finding'] == 1].sample(num_per_label, random_state=len(labels) - 1))
        return pd.concat(data).drop_duplicates()

    test = subsample(test, labels, num_per_label=200)
    train = subsample(train, labels, num_per_label=2000)
    val = subsample(val, labels, num_per_label=150)

    # 'No Finding' is only used for sampling; it is not a prediction target.
    params.labels = params.labels[1:]

    def make_flow(frame, generator, shuffle):
        """Create a dataframe iterator yielding (image batch, raw label matrix)."""
        return generator.flow_from_dataframe(frame.reset_index(),
                                             directory=image_root,
                                             x_col='Path',
                                             class_mode='raw',
                                             y_col=params.labels,
                                             batch_size=params.batch_size,
                                             shuffle=shuffle,
                                             target_size=params.input_shape[:2]
                                             )

    image_generator = tf.keras.preprocessing.image.ImageDataGenerator
    train_ds = make_flow(train, image_generator(horizontal_flip=True), shuffle=True)
    # Evaluation splits keep a fixed order and get no augmentation.
    val_ds = make_flow(val, image_generator(), shuffle=False)
    test_ds = make_flow(test, image_generator(), shuffle=False)

    return train_ds, val_ds, test_ds, params

### Camelyon 17

In [None]:
def process_patch_camelyon17_download():
    """
    Re-arrange the Camelyon17 (WILDS) patches into ``splits/hosp_<center>/<tumor>/``.

    Download the dataset manually from https://wilds.stanford.edu/datasets/#camelyon17
    into ``{DATA_PATH}/camelyon17_v1.0``. Every patch listed in ``metadata.csv`` is
    moved from ``patches/`` into a per-hospital, per-label folder.

    The move is best effort: patches that cannot be moved (for example because a
    previous run already moved them) are skipped, and a summary is printed at
    the end.
    """
    path = Path(f'{DATA_PATH}/camelyon17_v1.0')
    df = pd.read_csv(f'{str(path)}/metadata.csv', index_col=0, dtype={'patient': 'str'})

    # One folder per hospital (0-4) and per binary tumor label.
    for hosp_idx in range(5):
        for tumor_label in (0, 1):
            (path / 'splits' / f'hosp_{hosp_idx}' / str(tumor_label)).mkdir(parents=True, exist_ok=True)

    moved = 0
    skipped = 0
    columns = ['patient', 'node', 'x_coord', 'y_coord', 'center', 'tumor']
    for patient, node, x, y, h, label in df.loc[:, columns].itertuples(index=False, name=None):
        file_name = f'patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
        src = path / 'patches' / f'patient_{patient}_node_{node}' / file_name
        dest = path / 'splits' / f'hosp_{h}' / str(label) / file_name
        try:
            shutil.move(str(src), str(dest))
        except OSError:
            # Missing source / target folder: keep going, but only for file-system
            # errors (a bare `except` would also swallow KeyboardInterrupt).
            skipped += 1
        else:
            moved += 1
    print(f'moved {moved} patches, skipped {skipped}')

In [None]:
def load_patch_camelyon17(mode='full'):
    """
    Load the Camelyon17 patches prepared by `process_patch_camelyon17_download`.

    Hospital roles in the original benchmark:
      * hosp 0, 3, 4: original training / IID validation
      * hosp 2: original test set
      * hosp 1: OOD validation

    Here hosp 4 is the training set, the 'validation' half of hosp 1 is the
    validation set, and hosp 2, 3 and 0 are returned as three separate test sets.

    With ``mode='small'`` only the first 100 training and 15 validation batches
    are kept.

    Returns ``(train_ds, val_ds, test_ds, params)`` where ``test_ds`` is a dict
    keyed by hospital folder name.
    """
    class params:
        batch_size = 64
        input_shape = (96, 96, 3)
        num_classes = 2
        mode = 'binary'
        dropout = 0.5
        labels = ['0', '1']
        loss = tf.keras.losses.binary_crossentropy
        flipping = "horizontal_and_vertical"

    root_path = Path(f'{DATA_PATH}/camelyon17_v1.0/splits')

    def from_hospital(folder, seed, **split_options):
        # All splits share image size and batch size; only folder, seed and
        # the optional validation split differ.
        return tf.keras.preprocessing.image_dataset_from_directory(
            root_path / folder,
            seed=seed,
            image_size=params.input_shape[:2],
            batch_size=params.batch_size,
            **split_options
        )

    train_ds = from_hospital('hosp_4', 1337)
    val_ds = from_hospital('hosp_1', 1338, validation_split=0.5, subset='validation')

    if mode == 'small':
        train_ds, val_ds = train_ds.take(100), val_ds.take(15)

    test_ds = {hospital: from_hospital(hospital, 1339)
               for hospital in ['hosp_2', 'hosp_3', 'hosp_0']}

    return train_ds, val_ds, test_ds, params

### Camelyon 16

In [None]:
def load_patch_camelyon16(mode='full'):
    """
    Load PatchCamelyon (Camelyon16 patches) from TensorFlow Datasets.

    https://www.tensorflow.org/datasets/catalog/patch_camelyon

    Parameters
    ----------
    mode : {'full', 'small'}
        'full' uses the complete train / validation splits; 'small' uses the
        first 5000 training and first 1000 validation examples. The test split
        is always loaded in full.

    Returns
    -------
    tuple
        ``(train_ds, val_ds, test_ds, params)`` with batched, prefetched
        ``(image, label)`` datasets.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported values (previously this surfaced
        as an UnboundLocalError further down).
    """
    class params:
        batch_size = 64
        input_shape = (96, 96, 3)
        num_classes = 2
        mode = 'binary'
        dropout = 0.5
        labels = ['0', '1']
        loss = 'binary_crossentropy'
        flipping = "horizontal_and_vertical"

    split_names = {'full': ('train', 'validation'),
                   'small': ('train[:5000]', 'validation[:1000]')}
    if mode not in split_names:
        raise ValueError(f"mode must be one of {sorted(split_names)}, got {mode!r}")
    train_split, val_split = split_names[mode]

    def load_split(split, **extra):
        """Load one split; `extra` carries the optional read configuration."""
        return tfds.load('patch_camelyon', split=split,
                         shuffle_files=True, with_info=True,
                         data_dir=DATA_PATH, **extra)

    def reshuffling_config():
        return tfds.ReadConfig(shuffle_seed=0, shuffle_reshuffle_each_iteration=True)

    train_ds, ds_info = load_split(train_split, read_config=reshuffling_config())
    val_ds, ds_info = load_split(val_split, read_config=reshuffling_config())
    test_ds, ds_info = load_split('test')

    print(f"total of training samples: {tf.data.experimental.cardinality(train_ds)}")
    print(f"total of validation samples: {tf.data.experimental.cardinality(val_ds)}")
    print(f"total of test samples: {tf.data.experimental.cardinality(test_ds)}")
    tfds.visualization.show_examples(train_ds, ds_info)

    def load_image(image):
        # Turn the TFDS feature dict into the (input, target) pair Keras expects.
        return image['image'], image['label']

    def to_batches(ds):
        return ds.map(load_image, num_parallel_calls=tf.data.AUTOTUNE).batch(params.batch_size).prefetch(
            tf.data.AUTOTUNE)

    train_ds = to_batches(train_ds)
    val_ds = to_batches(val_ds)
    test_ds = to_batches(test_ds)

    print(f"Number of training batches: {tf.data.experimental.cardinality(train_ds)}")
    print(f"Number of validation batches: {tf.data.experimental.cardinality(val_ds)}")
    print(f"Number of test batches: {tf.data.experimental.cardinality(test_ds)}")
    return train_ds, val_ds, test_ds, params

### Chest x-ray

In [None]:
def load_chestxray(**kwargs):
    """
    Load the pediatric chest X-ray pneumonia dataset from TFRecords.

    https://www.cell.com/cell/fulltext/S0092-8674(18)30154-5
    TF records from https://keras.io/examples/vision/xray_classification_with_tpus/

    The training records are shuffled once in a fixed order and split into a
    validation set (5 batches' worth of images) and a training set. Only the
    training part is reshuffled every epoch, so validation images never appear
    in the training stream.

    Parameters
    ----------
    **kwargs
        Accepted so that all loaders share one call signature; currently ignored.

    Returns
    -------
    tuple
        ``(train_ds, val_ds, test_ds, params)`` with batched ``(image, label)``
        datasets; the label is 1.0 for PNEUMONIA and 0.0 otherwise.
    """
    class params:
        batch_size = 64
        input_shape = (180, 180, 3)
        num_classes = 2
        mode = 'binary'
        dropout = 0.5
        loss = 'binary_crossentropy'
        flipping = "horizontal"
        labels = ["NORMAL", "PNEUMONIA"]

    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_images = tf.data.TFRecordDataset(
        f'{DATA_PATH}/ChestXRay2017/train/images.tfrec'
    )
    train_paths = tf.data.TFRecordDataset(
        f'{DATA_PATH}/ChestXRay2017/train/paths.tfrec'
    )
    ds = tf.data.Dataset.zip((train_images, train_paths))

    # Count both classes in a single pass over the path records.
    COUNT_NORMAL = 0
    COUNT_PNEUMONIA = 0
    for filename in train_paths:
        decoded_path = filename.numpy().decode("utf-8")
        if "NORMAL" in decoded_path:
            COUNT_NORMAL += 1
        if "PNEUMONIA" in decoded_path:
            COUNT_PNEUMONIA += 1
    print("Normal images count in training set: " + str(COUNT_NORMAL))
    print("Pneumonia images count in training set: " + str(COUNT_PNEUMONIA))

    def get_label(file_path):
        # convert the path to a list of path components
        parts = tf.strings.split(file_path, "/")
        # The second to last is the class-directory
        return parts[-2] == "PNEUMONIA"

    def decode_img(img):
        # convert the compressed string to a 3D uint8 tensor
        img = tf.image.decode_jpeg(img, channels=3)

        # resize the image to the desired size.
        return tf.image.resize(img, params.input_shape[:2])

    def process_path(image, path):
        label = get_label(path)
        # `image` already holds the encoded JPEG bytes read from the TFRecord
        img = decode_img(image)
        return img, tf.cast(label, tf.float32)

    def prepare(ds, cache=False):
        # Optionally cache (in memory, or on disk when `cache` is a file name)
        # and always prefetch so batches are produced while the model trains.
        if cache:
            if isinstance(cache, str):
                ds = ds.cache(cache)
            else:
                ds = ds.cache()

        ds = ds.prefetch(buffer_size=AUTOTUNE)
        return ds

    # Fix: the split used to be taken with take(5)/skip(5) on batches *after* a
    # shuffle that reshuffles on every iteration, so the validation batches
    # changed each epoch and overlapped with training data. Shuffle once in a
    # fixed order, split, and reshuffle only the training part.
    num_val_images = 5 * params.batch_size
    ds = ds.shuffle(10000, seed=0, reshuffle_each_iteration=False)
    val_ds = (ds.take(num_val_images)
              .map(process_path, num_parallel_calls=AUTOTUNE)
              .batch(params.batch_size))
    train_ds = (ds.skip(num_val_images)
                .shuffle(10000, seed=0, reshuffle_each_iteration=True)
                .map(process_path, num_parallel_calls=AUTOTUNE)
                .batch(params.batch_size))
    val_ds = prepare(val_ds)
    train_ds = prepare(train_ds)

    test_images = tf.data.TFRecordDataset(
        f'{DATA_PATH}/ChestXRay2017/test/images.tfrec'
    )
    test_paths = tf.data.TFRecordDataset(
        f'{DATA_PATH}/ChestXRay2017/test/paths.tfrec'
    )
    test_ds = tf.data.Dataset.zip((test_images, test_paths))
    test_ds = test_ds.map(process_path, num_parallel_calls=AUTOTUNE)
    test_ds = test_ds.batch(params.batch_size)

    print(f"Number of training batches: {tf.data.experimental.cardinality(train_ds)}")
    print(f"Number of validation batches: {tf.data.experimental.cardinality(val_ds)}")
    print(f"Number of test batches: {tf.data.experimental.cardinality(test_ds)}")

    return train_ds, val_ds, test_ds, params

### Malaria


In [None]:
def load_malaria(**kwargs):
    """
    Load the malaria cell-image dataset from TensorFlow Datasets.

    https://www.tensorflow.org/datasets/catalog/malaria

    The single 'train' split is divided 70% / 15% / 15% into training,
    validation and test parts; images are resized to ``params.input_shape``.
    Extra keyword arguments are accepted but ignored.

    Returns ``(train_ds, val_ds, test_ds, params)``.
    """
    class params:
        input_shape = (120, 120, 3)
        batch_size = 32
        num_classes = 2
        mode = 'binary'
        dropout = 0.5
        labels = list(range(num_classes))
        loss = 'binary_crossentropy'
        flipping = "horizontal_and_vertical"

    split_spec = ['train[:70%]', 'train[70%:85%]', 'train[85%:]']
    read_config = tfds.ReadConfig(shuffle_seed=0,
                                  shuffle_reshuffle_each_iteration=True)
    (train_ds, val_ds, test_ds), ds_info = tfds.load('malaria',
                                                     split=split_spec,
                                                     shuffle_files=True,
                                                     with_info=True,
                                                     read_config=read_config)

    tfds.visualization.show_examples(train_ds, ds_info)
    print(f"total of training samples: {tf.data.experimental.cardinality(train_ds)}")
    print(f"total of validation samples: {tf.data.experimental.cardinality(val_ds)}")
    print(f"total of test samples: {tf.data.experimental.cardinality(test_ds)}")

    def load_image(image):
        # Feature dict -> (resized image, label)
        resized = tf.image.resize(image['image'], params.input_shape[:2])
        return resized, image['label']

    def to_batches(ds):
        return ds.map(load_image).batch(params.batch_size).prefetch(tf.data.AUTOTUNE)

    train_ds, val_ds, test_ds = to_batches(train_ds), to_batches(val_ds), to_batches(test_ds)

    print(f"Number of training batches: {tf.data.experimental.cardinality(train_ds)}")
    print(f"Number of validation batches: {tf.data.experimental.cardinality(val_ds)}")
    print(f"Number of test batches: {tf.data.experimental.cardinality(test_ds)}")
    return train_ds, val_ds, test_ds, params

### Colorectal histology

In [None]:
def load_colorectal_histology(**kwargs):
    """
    Load the colorectal histology dataset (8 tissue classes) from TensorFlow Datasets.

    https://www.tensorflow.org/datasets/catalog/colorectal_histology

    The single 'train' split is divided 70% / 15% / 15% into training,
    validation and test parts. Extra keyword arguments are accepted but ignored.

    Returns ``(train_ds, val_ds, test_ds, params)`` with batched, prefetched
    ``(image, label)`` datasets.
    """
    (train_ds, val_ds, test_ds), ds_info = tfds.load('colorectal_histology',
                                                     split=['train[:70%]',
                                                            'train[70%:85%]',
                                                            'train[85%:]'],
                                                     shuffle_files=True,
                                                     read_config=tfds.ReadConfig(shuffle_seed=0,
                                                                                 shuffle_reshuffle_each_iteration=True),
                                                     with_info=True)

    class params:
        input_shape = (150, 150, 3)
        batch_size = 32
        num_classes = 8
        mode = 'categorical'
        dropout = 0.5
        labels = ["tumor",
                  "stroma",
                  "complex",
                  "lympho",
                  "debris",
                  "mucosa",
                  "adipose",
                  "empty"]

        loss = tf.keras.losses.sparse_categorical_crossentropy
        flipping = "horizontal_and_vertical"

    tfds.visualization.show_examples(train_ds, ds_info)

    print(f"total of training samples: {tf.data.experimental.cardinality(train_ds)}")
    print(f"total of validation samples: {tf.data.experimental.cardinality(val_ds)}")
    print(f"total of test samples: {tf.data.experimental.cardinality(test_ds)}")

    def load_image(image):
        # Feature dict -> (image, label). The leftover debugging print of the
        # dict keys was removed.
        return image['image'], image['label']

    def to_batches(ds):
        return ds.map(load_image).batch(params.batch_size).prefetch(tf.data.AUTOTUNE)

    train_ds = to_batches(train_ds)
    val_ds = to_batches(val_ds)
    test_ds = to_batches(test_ds)

    print(f"Number of training batches: {tf.data.experimental.cardinality(train_ds)}")
    print(f"Number of validation batches: {tf.data.experimental.cardinality(val_ds)}")
    print(f"Number of test batches: {tf.data.experimental.cardinality(test_ds)}")

    return train_ds, val_ds, test_ds, params

### OCT

In [None]:
def load_oct(mode='full'):
    """
    Load the retinal OCT dataset from image folders under ``{DATA_PATH}/OCT``.

    https://data.mendeley.com/datasets/rscbjbr9sj/3
    http://dx.doi.org/10.17632/rscbjbr9sj.3

    The 'train' folder is split 80% / 20% into training and validation data
    (same seed for both subsets); the 'test' folder is the test set. With
    ``mode='small'`` only the first 100 training and 15 validation batches are
    kept.

    Returns ``(train_ds, val_ds, test_ds, params)``.
    """
    class params:
        batch_size = 32
        input_shape = (496 // 3, 1024 // 3, 3)
        num_classes = 4
        mode = 'categorical'
        dropout = 0.5
        labels = ['0', '1', '2', '3']
        loss = tf.keras.losses.sparse_categorical_crossentropy
        flipping = None

    root_path = Path(f'{DATA_PATH}/OCT')

    def from_folder(folder, seed, **split_options):
        # Shared image size / batch size; folder, seed and split vary per call.
        return tf.keras.preprocessing.image_dataset_from_directory(
            root_path / folder,
            seed=seed,
            image_size=params.input_shape[:2],
            batch_size=params.batch_size,
            **split_options
        )

    train_ds = from_folder('train', 1337, validation_split=0.2, subset='training')
    val_ds = from_folder('train', 1337, validation_split=0.2, subset='validation')

    if mode == 'small':
        train_ds, val_ds = train_ds.take(100), val_ds.take(15)

    test_ds = from_folder('test', 1339)

    return train_ds, val_ds, test_ds, params

## Functions


In [None]:
class MulticlassAUC(tf.keras.metrics.AUC):
    """AUC for a single class in a multiclass (or multi-label) problem.

    Parameters
    ----------
    pos_label : int
        Index of the positive class (the one whose AUC is being computed).
    from_logits : bool, optional (default: False)
        If True, assume predictions are not standardized to be between 0 and 1.
        In this case, predictions will be squeezed into probabilities using the
        softmax function.
    sparse : bool, optional (default: False)
        If True, ground truth labels should be encoded as integer indices in the
        range [0, n_classes-1]. Otherwise, ground truth labels should be
        indicator vectors with one column per class (the `pos_label` column is
        used).
    **kwargs : keyword arguments
        Keyword arguments for tf.keras.metrics.AUC.__init__(). For example, the
        curve type (curve='ROC' or curve='PR').
    """

    def __init__(self, pos_label, from_logits=False, sparse=False, **kwargs):
        super().__init__(**kwargs)
        self.pos_label = pos_label
        self.from_logits = from_logits
        self.sparse = sparse

    def update_state(self, y_true, y_pred, **kwargs):
        """Accumulates confusion matrix statistics.

        Parameters
        ----------
        y_true : tf.Tensor
            The ground truth values. Either an integer tensor of shape
            (n_examples,) or (n_examples, 1) (if sparse=True) or an indicator
            tensor of shape (n_examples, n_classes) (if sparse=False).
        y_pred : tf.Tensor
            The predicted values, a tensor of shape (n_examples, n_classes).
        **kwargs : keyword arguments
            Extra keyword arguments for tf.keras.metrics.AUC.update_state
            (e.g., sample_weight).
        """
        if self.sparse:
            y_true = tf.math.equal(y_true, self.pos_label)
            # Match the leading shape of the predictions. A bare tf.squeeze
            # would also drop the batch axis when a batch holds one example.
            y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])
        else:
            y_true = y_true[..., self.pos_label]
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        y_pred = y_pred[..., self.pos_label]
        super().update_state(y_true, y_pred, **kwargs)

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created from a saved config."""
        config = super().get_config()
        config.update({'pos_label': self.pos_label,
                       'from_logits': self.from_logits,
                       'sparse': self.sparse})
        return config

In [None]:
def load_dataset(name, **kwargs):
    """
    Dispatch to the loader registered under `name`.

    Parameters
    ----------
    name : str
        Dataset identifier. A ``*_small`` name selects the reduced variant of
        the corresponding loader.
    **kwargs
        Forwarded to the loader, except for 'patch_camelyon17_small' and
        'oct_small', which call their loader with ``mode='small'`` only.

    Returns
    -------
    tuple
        Whatever the selected loader returns: ``(train_ds, val_ds, test_ds, params)``.

    Raises
    ------
    NotImplementedError
        If `name` is not a known dataset.
    """
    if name == 'chexpert':
        return load_chexpert(**kwargs)
    elif name == 'chexpert_small':
        return load_chexpert(**kwargs, mode='small')
    elif name == 'chest_xray':
        return load_chestxray(**kwargs)
    elif name == 'patch_camelyon17':
        return load_patch_camelyon17(**kwargs)
    elif name == 'patch_camelyon17_small':
        return load_patch_camelyon17(mode='small')
    elif name == 'patch_camelyon16':
        return load_patch_camelyon16(**kwargs)
    elif name == 'patch_camelyon16_small':
        return load_patch_camelyon16(**kwargs, mode='small')
    elif name == 'colorectal_histology':
        return load_colorectal_histology(**kwargs)
    elif name == 'malaria':
        return load_malaria(**kwargs)
    elif name == 'oct':
        # The full OCT dataset had no entry; only the reduced variant was reachable.
        return load_oct(**kwargs)
    elif name == 'oct_small':
        return load_oct(mode='small')
    else:
        raise NotImplementedError(f'invalid dataset: {name!r}')

In [None]:
def build_model(input_shape, bn_train=False, dropout=0.5, labels=('0', '1'),
                weights='imagenet',
                mode='binary', model_name='efficientnetb3',
                flipping="horizontal_and_vertical", do_augmentation=True):
    """
    Build a classifier: optional augmentation -> preprocessing -> backbone -> pooling -> dropout -> dense head.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input image, passed to ``keras.Input``.
    bn_train : bool or None
        Passed as the ``training`` argument when calling the backbone. ``True``
        is converted to ``None`` so Keras chooses the phase itself.
    dropout : float
        Dropout rate applied before the output layer.
    labels : sequence
        Class labels; their number sets the output width for 'categorical' and
        'multi-binary' modes.
    weights : str or None
        Forwarded to the backbone constructor.
    mode : {'binary', 'categorical', 'multi-binary'}
        Selects output activation and number of units.
    model_name : str
        One of the backbone keys defined below.
    flipping : str or None
        Mode for the random-flip augmentation; ``None`` disables flipping (for
        example for the OCT dataset, which declares ``flipping = None``).
    do_augmentation : bool
        Whether the augmentation block is inserted in front of the backbone.

    Returns
    -------
    tuple
        ``(model, base_model)``: the full Keras model and its backbone.

    Raises
    ------
    NotImplementedError
        If `mode` is not supported.
    KeyError
        If `model_name` is not a known backbone.
    """
    # Resolve the head configuration first so an unsupported mode fails before
    # any backbone weights are loaded.
    if mode == 'binary':
        activation = "sigmoid"
        units = 1
    elif mode == 'categorical':
        activation = "softmax"
        units = len(labels)
    elif mode == 'multi-binary':
        activation = "sigmoid"
        units = len(labels)
    else:
        raise NotImplementedError()

    bn_train = None if bn_train == True else bn_train

    arch = {'densenet121': dict(model=tf.keras.applications.densenet.DenseNet121,
                                preprocess=tf.keras.applications.densenet.preprocess_input),
            'densenet169': dict(model=tf.keras.applications.densenet.DenseNet169,
                                preprocess=tf.keras.applications.densenet.preprocess_input),
            'efficientnetb1': dict(model=tf.keras.applications.efficientnet.EfficientNetB1,
                                   preprocess=tf.keras.applications.efficientnet.preprocess_input),
            'efficientnetb3': dict(model=tf.keras.applications.efficientnet.EfficientNetB3,
                                   preprocess=tf.keras.applications.efficientnet.preprocess_input),
            'efficientnetb5': dict(model=tf.keras.applications.efficientnet.EfficientNetB5,
                                   preprocess=tf.keras.applications.efficientnet.preprocess_input),
            'resnet50v2': dict(model=tf.keras.applications.ResNet50V2,
                               preprocess=tf.keras.applications.resnet_v2.preprocess_input),
            'inceptionv3': dict(model=tf.keras.applications.inception_v3.InceptionV3,
                                preprocess=tf.keras.applications.inception_v3.preprocess_input)
            }

    base_model = arch[model_name]['model'](include_top=False, weights=weights)
    preprocess_input = arch[model_name]['preprocess']
    inputs = keras.Input(shape=input_shape)

    # Image augmentation block
    if do_augmentation:
        augmentation_layers = []
        if flipping is not None:
            # RandomFlip rejects None, so the layer is left out when a dataset
            # asks for no flipping.
            augmentation_layers.append(layers.experimental.preprocessing.RandomFlip(flipping))
        augmentation_layers.append(layers.experimental.preprocessing.RandomTranslation(0.1, 0.1))
        augmentation_layers.append(layers.experimental.preprocessing.RandomZoom(0.1))
        data_augmentation = tf.keras.Sequential(augmentation_layers, name='data_augmentation')
        x = data_augmentation(inputs)
    else:
        x = inputs
    x = layers.Lambda(lambda x: preprocess_input(x))(x)
    x = base_model(x, training=bn_train)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout)(x)

    outputs = layers.Dense(units, activation=activation)(x)

    model = keras.Model(inputs, outputs)
    return model, base_model

In [None]:
def get_callbacks(path):
    """
    Create the training callbacks for one run.

    Returns, in this order: early stopping on ``val_loss`` (patience 6, best
    weights restored), learning-rate reduction on plateau (factor 0.2 after one
    stagnant epoch, floor 1e-6), and a CSV logger writing ``{path}/log.csv``.
    TensorBoard and checkpoint callbacks are not enabled.
    """
    callback_api = tf.keras.callbacks
    return [
        callback_api.EarlyStopping(
            monitor="val_loss",
            min_delta=0,
            patience=6,
            verbose=True,
            mode="auto",
            restore_best_weights=True,
        ),
        callback_api.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=1,
            min_delta=0.0001,
            min_lr=1e-6,
            verbose=True,
        ),
        callback_api.CSVLogger(filename=f'{path}/log.csv'),
    ]

In [None]:
def get_metrics(labels=None, mode='binary'):
    """
    Build the list of Keras metrics for the given problem type.

    Parameters
    ----------
    labels : sequence or None
        Class labels; used to name the per-class AUC metrics. Non-string labels
        (for example integers) are converted with ``str``.
    mode : {'binary', 'categorical', 'multi-binary'}
        * 'categorical' with more than two labels: accuracy plus one sparse
          per-class AUC.
        * 'multi-binary': accuracy, a multi-label AUC, plus one AUC per label.
        * anything else: accuracy plus a single AUC.

    Returns
    -------
    list
        Metric identifiers / instances suitable for ``model.compile``.
    """
    labels = [] if labels is None else list(labels)

    def auc_name(label):
        # Metric names must not contain spaces; labels may be ints.
        return f"AUC_{str(label).replace(' ', '_')}"

    metrics = ['accuracy']
    if mode == 'categorical' and len(labels) > 2:
        for i, label in enumerate(labels):
            metrics.append(MulticlassAUC(pos_label=i, sparse=True,
                                         name=auc_name(label)))
    else:
        multi_label = mode == 'multi-binary'
        metrics.append(tf.keras.metrics.AUC(multi_label=multi_label))
    if mode == 'multi-binary':
        for i, label in enumerate(labels):
            metrics.append(MulticlassAUC(pos_label=i,
                                         name=auc_name(label)))
    return metrics

In [None]:
def make_bn_gamma_bias_trainable(base_model, momentum=0.99):
    """
    Freeze every direct layer of `base_model` except its BatchNormalization layers.

    The backbone itself is flagged trainable; each layer in
    ``base_model.layers`` is then frozen, and BatchNormalization layers are
    re-enabled and given the supplied `momentum`.
    """
    base_model.trainable = True
    for current in base_model.layers:
        current.trainable = False
        if not isinstance(current, keras.layers.BatchNormalization):
            continue
        current.trainable = True
        current.momentum = momentum

In [None]:
def make_bn_moving_stats_trainable_only(base_model, momentum=0.99):
    """
    Freeze every direct layer of `base_model`; keep BatchNormalization layers
    flagged trainable but with an empty trainable-weight list.

    Only the layers listed in ``base_model.layers`` are visited (nested models
    are not walked explicitly here).

    NOTE(review): judging by the function name, the intent is that only the
    moving mean / variance keep being updated while the other BN weights
    (presumably gamma and beta) receive no gradient updates. This relies on
    the private Keras attributes `_trainable_weights` /
    `_non_trainable_weights` — confirm it behaves as intended on the TensorFlow
    version in use.
    """
    base_model.trainable = True
    for layer in base_model.layers:
        layer.trainable = False
        if isinstance(layer, keras.layers.BatchNormalization):
            layer.trainable = True
            layer.momentum = momentum
            # Re-file everything the layer currently lists as trainable under
            # its non-trainable weights, then clear the trainable list.
            layer._non_trainable_weights += layer._trainable_weights
            layer._trainable_weights = []

In [None]:
def evaluate_and_save_log(path, model, test_ds, history, callbacks,
                          method, name='run_log'):
    """
    Evaluate `model` on the test data and write a JSON run log to ``{path}/{name}.json``.

    Parameters
    ----------
    path : pathlib.Path or str
        Directory the log file is written to.
    model
        Compiled model exposing ``evaluate`` and ``metrics_names``.
    test_ds
        Either a single test dataset (stored under 'test') or a dict of named
        datasets (each stored under ``test_<name>``).
    history
        Object returned by ``model.fit`` (or a falsy value). When given, its
        ``history`` dict is stored under 'train' and the epoch with the lowest
        ``val_loss`` under 'best_epoch' (1-based).
    callbacks : list or None
        Training callbacks; ``stopped_epoch`` is taken from the first callback
        that has that attribute.
    method : str
        Name of the training method, stored under 'method'.
    name : str
        File name (without extension) of the log.
    """
    def convert(o):
        # JSON fallback for values the encoder cannot handle natively.
        # Returning None (as before for anything but np.float32) silently wrote
        # `null`; convert NumPy values properly and keep a readable string for
        # anything else so no information is lost.
        if isinstance(o, np.generic):
            return o.item()
        if isinstance(o, np.ndarray):
            return o.tolist()
        return str(o)

    logs = {'method': method}
    if history:
        logs.update({'train': history.history})
        logs['best_epoch'] = int(np.argmin(history.history['val_loss']) + 1)
    if callbacks:
        for callback in callbacks:
            if hasattr(callback, 'stopped_epoch'):
                logs['stopped_epoch'] = callback.stopped_epoch
                break
    log_path = Path(path) / f'{name}.json'
    if isinstance(test_ds, dict):
        # A distinct loop variable avoids shadowing the `name` parameter.
        for test_name, ds in test_ds.items():
            result = model.evaluate(ds)
            logs[f'test_{test_name}'] = dict(zip(model.metrics_names, result))
    else:
        result = model.evaluate(test_ds)
        logs['test'] = dict(zip(model.metrics_names, result))
    with log_path.open('w') as f:
        json.dump(logs, f, default=convert)

In [None]:
def save_params_to_json(params, path, name='params'):
    """
    Write the public attributes of `params` to ``{path}/{name}.json``.

    Attributes whose name starts with a double underscore and the 'loss'
    attribute are left out.
    """
    serializable = {}
    for key, value in params.__dict__.items():
        if key.startswith('__') or key == 'loss':
            continue
        serializable[key] = value
    with (path / f'{name}.json').open('w') as handle:
        json.dump(serializable, handle)

In [None]:
def load_results():
    """
    Collect every run log below ``{OUTPUT_PATH}/results/`` into data frames.

    Logs are expected at
    ``results/<dataset>/<model>/<method>/runs/<timestamp>/<*log*>.json``; dataset
    and model names are taken from the directory names. For each log the
    training history is dropped, the test metrics are flattened into the row
    (per-hospital results are merged with a ``_<test_hosp_x>`` suffix), and
    'auc', 'accuracy' and 'loss' are set to the mean of all test metrics whose
    name starts with that word.

    Returns
    -------
    tuple
        ``(df, df_g)``: one row per run, and the mean / std / count of the
        numeric columns grouped by dataset, model and method. ``df_g`` is
        ``None`` when no aggregation is possible (no runs, or no numeric columns).
    """
    logs_without_history = []
    logs_without_test = []
    rows = []
    for f in Path(f'{OUTPUT_PATH}/results/').rglob('*log*.json'):
        # f = .../<dataset>/<model>/<method>/runs/<timestamp>/<file>.json
        model_dir = f.parent.parent.parent.parent
        # `.name` rather than `.stem`: these are directory names, and `.stem`
        # would cut everything after the last dot.
        model_name = model_dir.name
        dataset_name = model_dir.parent.name
        with f.open() as c:
            log_data = json.load(c)

        if 'train' in log_data:
            del log_data['train']
        else:
            logs_without_history.append(f)

        if 'test_hosp_0' in log_data:
            # Merge the per-hospital test results into one flat 'test' dict.
            merged = {}
            for test_hosp in ['test_hosp_0', 'test_hosp_2', 'test_hosp_3']:
                for k, v in log_data.get(test_hosp, {}).items():
                    merged[f"{k.replace('_1', '')}_{test_hosp}"] = v
            log_data['test'] = merged

        test_metrics = log_data.pop('test', None)
        if test_metrics is None:
            # Not a finished run log; skip it instead of aborting the whole scan.
            logs_without_test.append(f)
            continue

        for metric in ['auc', 'accuracy', 'loss']:
            values = [v for k, v in test_metrics.items() if k.lower().startswith(metric)]
            if values:
                test_metrics[metric] = sum(values) / len(values)
        for met in ['auc_1', 'TPR', 'TNR']:
            test_metrics.pop(met, None)
        log_data.update(test_metrics)

        if f.stem == 'run_log_fc':
            log_data['method'] = 'fc'
        log_data['model'] = model_name
        log_data['dataset'] = dataset_name
        log_data['path'] = f
        rows.append(log_data)

    if logs_without_history:
        print(f'{len(logs_without_history)} log(s) without training history')
    if logs_without_test:
        print(f'{len(logs_without_test)} log(s) without test results were skipped')

    df = pd.DataFrame(rows)
    df_g = None
    group_keys = ['dataset', 'model', 'method']
    if not df.empty and all(key in df.columns for key in group_keys):
        # Aggregate numeric columns only: object columns such as 'path' make
        # mean/std fail on recent pandas versions.
        numeric_cols = [col for col in df.select_dtypes(include='number').columns
                        if col not in group_keys]
        if numeric_cols:
            df_g = df.groupby(group_keys)[numeric_cols].agg(['mean', 'std', 'count']).round(3)
    return df, df_g

## Experiments

### BN gamma and bias only + last layer

In [None]:
def run_bn_and_fc(expParams):
    """Single-phase experiment for the "BN gamma and bias only + last layer" setup.

    The model comes from ``build_model`` (with ``bn_train=False``) and
    ``make_bn_gamma_bias_trainable`` is applied to the base network before
    compiling; which weights end up trainable is decided by those helpers.

    Args:
        expParams: Object exposing ``dataset_name``, ``model_name``,
            ``do_augmentation`` and ``save_model``. An optional ``method``
            attribute overrides the default results sub-directory
            ``'bn_and_fc'``.

    Side effects:
        Creates ``{OUTPUT_PATH}/results/<dataset>/<model>/<method>/runs/<timestamp>``,
        writes the dataset params there, optionally saves the model under
        ``checkpoint`` and passes the test set to ``evaluate_and_save_log``.
    """
    # getattr also covers instances and inherited attributes.
    method = getattr(expParams, 'method', 'bn_and_fc')
    tf.keras.backend.clear_session()
    datenow = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    path = Path(f'{OUTPUT_PATH}/results/{expParams.dataset_name}/{expParams.model_name}/{method}/runs/{datenow}')
    print(path)
    path.mkdir(exist_ok=True, parents=True)

    train_ds, val_ds, test_ds, params = load_dataset(expParams.dataset_name)

    save_params_to_json(params, path, name='params')

    model, base_model = build_model(params.input_shape,
                                    model_name=expParams.model_name,
                                    bn_train=False, dropout=params.dropout,
                                    mode=params.mode,
                                    do_augmentation=expParams.do_augmentation,
                                    labels=params.labels)
    make_bn_gamma_bias_trainable(base_model)
    model.compile(loss=params.loss,
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  metrics=get_metrics(params.labels, mode=params.mode))
    callbacks = get_callbacks(str(path))
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    history = model.fit(
        train_ds, epochs=MAX_EPOCHS,
        initial_epoch=0,
        validation_data=val_ds,
        callbacks=callbacks
    )
    if expParams.save_model:
        model.save(f'{path}/checkpoint')
    evaluate_and_save_log(path, model, test_ds, history, callbacks, method=method, name='run_log')

### FC then BN

In [None]:
def run_fc_then_bn(expParams, bn_training=False, fc_epochs=None, learning_rate_2=1e-3):
    """Two-phase experiment: train with the base frozen, then unfreeze and retrain.

    Phase 1 sets ``base_model.trainable = False`` and fits with Adam (lr 1e-3).
    Phase 2 sets ``base_model.trainable = True``, applies
    ``make_bn_gamma_bias_trainable`` to the base network, recompiles with
    ``learning_rate_2`` and continues the epoch numbering from phase 1.

    Args:
        expParams: Object exposing ``dataset_name``, ``model_name``,
            ``do_augmentation`` and ``save_model``. An optional ``method``
            attribute overrides the default results sub-directory
            ``'fc_then_bn'``.
        bn_training: Forwarded to ``build_model`` as ``bn_train``.
        fc_epochs: Epoch budget of phase 1; ``None`` means up to 100 epochs.
        learning_rate_2: Adam learning rate used in phase 2.

    Side effects:
        Creates the run directory under ``{OUTPUT_PATH}/results``, writes the
        dataset params there, optionally saves the model under ``checkpoint``
        and passes the test set to ``evaluate_and_save_log``.
    """
    # getattr also covers instances and inherited attributes.
    method = getattr(expParams, 'method', 'fc_then_bn')
    tf.keras.backend.clear_session()
    datenow = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    path = Path(f'{OUTPUT_PATH}/results/{expParams.dataset_name}/{expParams.model_name}/{method}/runs/{datenow}')
    print(path)
    path.mkdir(exist_ok=True, parents=True)

    train_ds, val_ds, test_ds, params = load_dataset(expParams.dataset_name)

    save_params_to_json(params, path, name='params')

    model, base_model = build_model(params.input_shape,
                                    model_name=expParams.model_name,
                                    bn_train=bn_training, dropout=params.dropout,
                                    mode=params.mode,
                                    do_augmentation=expParams.do_augmentation,
                                    labels=params.labels)

    # Train FC first
    ## freeze base
    base_model.trainable = False
    model.compile(loss=params.loss,
                  optimizer=keras.optimizers.Adam(learning_rate=1e-3),
                  metrics=get_metrics(params.labels, mode=params.mode))
    callbacks = get_callbacks(str(path))
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    history = model.fit(
        train_ds, epochs=100 if fc_epochs is None else fc_epochs,
        initial_epoch=0,
        validation_data=val_ds,
        callbacks=callbacks
    )
    # Number of epochs actually run in phase 1. Unlike reading `stopped_epoch`
    # from a callback, this is correct whether or not training stopped early
    # and does not depend on the order of the callbacks list.
    stopped_epoch = len(history.epoch)
    base_model.trainable = True
    make_bn_gamma_bias_trainable(base_model)
    model.compile(loss=params.loss,
                  optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate_2),
                  metrics=get_metrics(params.labels, mode=params.mode))
    # Fresh callbacks for the second phase.
    callbacks = get_callbacks(str(path))
    model.summary()
    history = model.fit(
        train_ds, epochs=MAX_EPOCHS,
        initial_epoch=stopped_epoch,
        validation_data=val_ds,
        callbacks=callbacks
    )
    if expParams.save_model:
        model.save(f'{path}/checkpoint')
    evaluate_and_save_log(path, model, test_ds, history, callbacks, method=method, name='run_log')

### FC, then full base with lower lr

In [None]:
def run_fc_then_full(expParams, bn_training=False):
    """Two-phase experiment: train with the base frozen, then fine-tune everything.

    Phase 1 sets ``base_model.trainable = False`` and fits for up to 100 epochs
    with Adam (lr 1e-3); its result is logged separately as method ``'fc'``
    under the name ``run_log_fc``. Phase 2 sets ``base_model.trainable = True``,
    recompiles with Adam (lr 1e-5) and continues the epoch numbering from
    phase 1; its result is logged as ``run_log_fc_then_full``.

    Args:
        expParams: Object exposing ``dataset_name``, ``model_name``,
            ``do_augmentation`` and ``save_model``. An optional ``method``
            attribute overrides the default results sub-directory
            ``'fc_then_full'``.
        bn_training: Forwarded to ``build_model`` as ``bn_train``.

    Side effects:
        Creates the run directory under ``{OUTPUT_PATH}/results``, writes the
        dataset params there, and optionally saves the model after each phase
        (``checkpoint_fc`` and ``checkpoint_fc_then_full``).
    """
    # getattr also covers instances and inherited attributes.
    method = getattr(expParams, 'method', 'fc_then_full')
    tf.keras.backend.clear_session()
    datenow = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    path = Path(f'{OUTPUT_PATH}/results/{expParams.dataset_name}/{expParams.model_name}/{method}/runs/{datenow}')

    print(path)
    path.mkdir(exist_ok=True, parents=True)

    train_ds, val_ds, test_ds, params = load_dataset(expParams.dataset_name)

    save_params_to_json(params, path, name='params')
    model, base_model = build_model(params.input_shape,
                                    model_name=expParams.model_name,
                                    bn_train=bn_training, dropout=params.dropout,
                                    mode=params.mode,
                                    do_augmentation=expParams.do_augmentation,
                                    labels=params.labels)

    # Train FC first
    ## freeze base
    base_model.trainable = False
    model.compile(loss=params.loss,
                  optimizer=keras.optimizers.Adam(learning_rate=1e-3),
                  metrics=get_metrics(params.labels, mode=params.mode))
    callbacks = get_callbacks(str(path))
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    history = model.fit(
        train_ds, epochs=100,
        initial_epoch=0,
        validation_data=val_ds,
        callbacks=callbacks
    )
    # Number of epochs actually run in phase 1. Unlike reading `stopped_epoch`
    # from a callback, this is correct whether or not training stopped early
    # and does not depend on the order of the callbacks list.
    stopped_epoch = len(history.epoch)
    if expParams.save_model:
        model.save(f'{path}/checkpoint_fc')
    evaluate_and_save_log(path, model, test_ds, history, callbacks, method='fc',
                          name='run_log_fc')
    # unfreeze rest of model
    base_model.trainable = True
    model.compile(loss=params.loss,
                  optimizer=keras.optimizers.Adam(learning_rate=1e-5),
                  metrics=get_metrics(params.labels, mode=params.mode))
    # Fresh callbacks for the second phase.
    callbacks = get_callbacks(str(path))
    model.summary()
    history = model.fit(
        train_ds, epochs=MAX_EPOCHS,
        initial_epoch=stopped_epoch,
        validation_data=val_ds,
        callbacks=callbacks
    )
    if expParams.save_model:
        model.save(f'{path}/checkpoint_fc_then_full')
    evaluate_and_save_log(path, model, test_ds, history,
                          callbacks, method=method, name='run_log_fc_then_full')

### Moving BN + FC

In [None]:
def run_moving_bn_and_fc(expParams):
    """Single-phase experiment for the "Moving BN + FC" setup.

    The model comes from ``build_model`` (with ``bn_train=None``) and
    ``make_bn_moving_stats_trainable_only(base_model, momentum=0.95)`` is
    applied before compiling; what that leaves trainable is decided by the
    helper.

    Args:
        expParams: Object exposing ``dataset_name``, ``model_name``,
            ``do_augmentation`` and ``save_model``. An optional ``method``
            attribute overrides the default results sub-directory
            ``'moving_bn_and_fc'``.

    Side effects:
        Creates the run directory under ``{OUTPUT_PATH}/results``, writes the
        dataset params there, optionally saves the model under ``checkpoint``
        and passes the test set to ``evaluate_and_save_log``.
    """
    # getattr also covers instances and inherited attributes.
    method = getattr(expParams, 'method', 'moving_bn_and_fc')
    tf.keras.backend.clear_session()
    datenow = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    path = Path(f'{OUTPUT_PATH}/results/{expParams.dataset_name}/{expParams.model_name}/{method}/runs/{datenow}')
    print(path)
    path.mkdir(exist_ok=True, parents=True)
    train_ds, val_ds, test_ds, params = load_dataset(expParams.dataset_name)
    save_params_to_json(params, path, name='params')
    model, base_model = build_model(params.input_shape,
                                    model_name=expParams.model_name,
                                    bn_train=None, dropout=params.dropout,
                                    mode=params.mode,
                                    do_augmentation=expParams.do_augmentation,
                                    labels=params.labels)
    make_bn_moving_stats_trainable_only(base_model, momentum=0.95)
    model.compile(loss=params.loss,
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  metrics=get_metrics(params.labels, mode=params.mode))
    callbacks = get_callbacks(str(path))
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    history = model.fit(
        train_ds, epochs=MAX_EPOCHS,
        initial_epoch=0,
        validation_data=val_ds,
        callbacks=callbacks
    )
    if expParams.save_model:
        model.save(f'{path}/checkpoint')
    evaluate_and_save_log(path, model, test_ds, history, callbacks, method=method, name='run_log')

### Random weights

In [None]:
def run_random_weights_bn_and_fc(expParams, bn_train=None):
    """Single-phase experiment on a base network built with ``weights=None``.

    ``make_bn_gamma_bias_trainable(base_model, momentum=0.99)`` is applied
    before compiling; which weights end up trainable is decided by that
    helper. After training, the moving mean of the first BatchNormalization
    layer found in the base network is printed as a diagnostic.

    Args:
        expParams: Object exposing ``dataset_name``, ``model_name``,
            ``do_augmentation`` and ``save_model``. An optional ``method``
            attribute overrides the default results sub-directory
            ``'random_weights_bn_and_fc'``.
        bn_train: Forwarded to ``build_model``.

    Side effects:
        Creates the run directory under ``{OUTPUT_PATH}/results``, writes the
        dataset params there, optionally saves the model under ``checkpoint``
        and passes the test set to ``evaluate_and_save_log``.
    """
    # getattr also covers instances and inherited attributes.
    method = getattr(expParams, 'method', 'random_weights_bn_and_fc')
    tf.keras.backend.clear_session()
    datenow = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    path = Path(f'{OUTPUT_PATH}/results/{expParams.dataset_name}/{expParams.model_name}/{method}/runs/{datenow}')
    print(path)
    path.mkdir(exist_ok=True, parents=True)
    train_ds, val_ds, test_ds, params = load_dataset(expParams.dataset_name)
    save_params_to_json(params, path, name='params')
    model, base_model = build_model(params.input_shape,
                                    model_name=expParams.model_name,
                                    bn_train=bn_train, dropout=params.dropout,
                                    weights=None,
                                    mode=params.mode,
                                    do_augmentation=expParams.do_augmentation,
                                    labels=params.labels)
    make_bn_gamma_bias_trainable(base_model, momentum=0.99)
    metrics = get_metrics(params.labels, mode=params.mode)
    model.compile(loss=params.loss,
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  metrics=metrics)
    callbacks = get_callbacks(str(path))
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    history = model.fit(
        train_ds, epochs=MAX_EPOCHS,
        initial_epoch=0,
        validation_data=val_ds,
        callbacks=callbacks
    )
    # Diagnostic: show the moving mean of the first BN layer only.
    for layer in base_model.layers:
        if isinstance(layer, keras.layers.BatchNormalization):
            print(layer.moving_mean)
            break
    if expParams.save_model:
        model.save(f'{path}/checkpoint')
    evaluate_and_save_log(path, model, test_ds, history, callbacks, method=method, name='run_log')

## Run

In [None]:
def should_run(dataset_name, model_name, method, df_g, max_count=3):
    """Tell whether another repetition of an experiment is still needed.

    Looks up the ``('loss', 'count')`` entry of the grouped results for the
    ``(dataset_name, model_name, method)`` row and prints it.

    Returns:
        bool: ``True`` when fewer than ``max_count`` runs are recorded, or
        when the lookup fails for any reason (the error is printed);
        ``False`` otherwise.
    """
    try:
        finished_runs = df_g.loc[(dataset_name, model_name, method)][('loss', 'count')]
        print(finished_runs)
        return bool(finished_runs < max_count)
    except Exception as lookup_error:
        # No usable entry (missing row, no grouped table, ...): run it.
        print(lookup_error)
        return True

In [None]:
# Display names used in tables and figures, keyed by the identifiers that
# appear in the results (method, model and dataset columns).
methods_map = dict(
    fc_then_full='FC-then-full',
    fc_then_bn='FC-then-BN',
    bn_and_fc='FC-BN',
    fc_then_bn_lr='FC-then-BN-lr',
    fc='FC',
    moving_bn_and_fc='FC-BMA',
    fc_then_full_bnt='FC-then-full-BMA',
    fc_then_bn_bnt='FC-then-BN-BMA',
    random_weights_bn_and_fc='RND-FC-BN',
    random_weights_bn_training_and_fc='RND-FC-BN-BMA',
)
models_map = dict(
    densenet121='DenseNet121',
    resnet50v2='ResNet50v2',
    inceptionv3='InceptionV3',
    efficientnetb3='EfficientNetB3',
)
datasets_map = dict(
    chest_xray='Chest X-ray',
    chexpert='CheXpert',
    malaria='Malaria',
    oct_small='OCT small',
    colorectal_histology='Colorectal Histology',
    patch_camelyon16_small='Patch Camelyon 16 small',
    patch_camelyon16='Patch Camelyon16',
    patch_camelyon17_small='Patch Camelyon 17 small',
    patch_camelyon17='Patch Camelyon 17',
)

In [None]:
# Experiment registry: method identifier -> callable taking `expParams`.
# Variants of the same run function differ only in the keyword arguments
# bound here with `partial`.
method_dict = dict(
    bn_and_fc=run_bn_and_fc,
    fc_then_bn=partial(run_fc_then_bn, fc_epochs=1),
    fc_then_bn_bnt=partial(run_fc_then_bn, fc_epochs=1, bn_training=None),
    fc_then_bn_lr=partial(run_fc_then_bn, learning_rate_2=1e-5),
    fc_then_full=run_fc_then_full,
    fc_then_full_bnt=partial(run_fc_then_full, bn_training=None),
    moving_bn_and_fc=run_moving_bn_and_fc,
    random_weights_bn_and_fc=partial(run_random_weights_bn_and_fc, bn_train=False),
    random_weights_bn_training_and_fc=partial(run_random_weights_bn_and_fc, bn_train=None),
)

In [None]:
# Snapshot of the results gathered so far; df_g is what should_run consults
# in the run loop to decide whether a configuration still needs repetitions.
df, df_g = load_results()

In [None]:
# Run every (method, model, dataset) combination until REPETITIONS results
# exist for it. Each outer pass adds at most one run per combination.
for run_id in range(REPETITIONS):
    for method, func in method_dict.items():
        for model_name in models_map.keys():
            for dataset_name in datasets_map.keys():
                # Parameter bundle read by the run_* functions. The
                # `name = name` assignments pick up the loop variables from
                # module scope, so this only works at cell (top) level.
                class expParams:
                    dataset_name = dataset_name
                    model_name = model_name
                    save_model = False
                    do_augmentation = True
                    method = method

                if should_run(dataset_name, model_name, method, df_g, max_count=REPETITIONS):
                    print(dataset_name, model_name, method)
                    try:
                        func(expParams)
                    except Exception as e:
                        # Keep going with the remaining configurations, but
                        # say which one failed and with what error type.
                        print(f'{dataset_name} {model_name} {method} failed: {e!r}')
                    # Refresh the counts so the new run is taken into account.
                    df, df_g = load_results()
                else:
                    print('already exists')

## Process results

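Load the results from a saved CSV file. Note that this cell redefines `load_results`, replacing the version above that scans the run logs.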

In [None]:
def load_results(path='data/results.csv'):
    """Load saved results from a CSV file and build the grouped summary.

    NOTE: this definition shadows the log-scanning ``load_results`` defined
    earlier in the notebook; cells run after this one read from the CSV.

    Args:
        path: Location of the results CSV. It must contain the ``dataset``,
            ``model`` and ``method`` columns.

    Returns:
        tuple: ``(df, df_g)`` where ``df`` is the CSV content and ``df_g``
        holds mean/std/count of the numeric columns grouped by
        ``(dataset, model, method)``, rounded to 3 decimals.
    """
    df = pd.read_csv(path)
    group_keys = ['dataset', 'model', 'method']
    # Aggregate numeric columns only: text columns cannot be averaged and
    # make the aggregation raise on recent pandas versions.
    numeric_cols = [c for c in df.select_dtypes(include='number').columns
                    if c not in group_keys]
    df_g = df.groupby(group_keys)[numeric_cols].agg(['mean', 'std', 'count']).round(3)
    return df, df_g

In [None]:
# Raise the row display limit so large tables (such as the grouped results)
# are shown in full instead of being truncated.
pd.set_option('display.max_rows', 2000)

In [None]:
# Overview of the grouped results, restricted to the loss and AUC statistics.
df, df_g = load_results()
summary_columns = (
    [('loss', stat) for stat in ('mean', 'std', 'count')]
    + [('auc', stat) for stat in ('mean', 'std')]
)
df_g = df_g[summary_columns]
df_g

### Tables

In [None]:
# CheXpert / DenseNet121: per-label AUC per method, as "mean (std)" in percent.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)

methods = [methods_map[key] for key in
           ('fc_then_full', 'fc_then_bn', 'fc', 'moving_bn_and_fc', 'random_weights_bn_and_fc')]
models = ['DenseNet121']
datasets = [datasets_map[key] for key in ('chexpert',)]
selected = (df['method'].isin(methods)
            & df['model'].isin(models)
            & df['dataset'].isin(datasets))
df = df[selected]
print(df.columns)
auc_columns = ['AUC_Atelectasis', 'AUC_Cardiomegaly', 'AUC_Consolidation',
               'AUC_Edema', 'AUC_Pleural_Effusion']
df = df[['method'] + auc_columns]

# Mean and standard deviation over the runs, turned into text.
per_method = df.groupby(['method'])
mean = (per_method.mean() * 100).round(1).astype('str')
std = (per_method.std() * 100).round(1).astype('str')
for column in std.columns:
    std[column] = std[column].map(lambda value: f' ({value})')
# String concatenation cell by cell: "<mean> (<std>)".
df = mean + std
df.columns = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']

In [None]:
# Emit the table held in `df` as LaTeX source.
print(df.to_latex())

In [None]:
# OCT small / DenseNet121: per-class AUC per method, as "mean (std)" in percent.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)

methods = [methods_map[key] for key in
           ('fc_then_full', 'fc_then_bn', 'fc', 'moving_bn_and_fc', 'random_weights_bn_and_fc')]
models = ['DenseNet121']
datasets = [datasets_map[key] for key in ('oct_small',)]
selected = (df['method'].isin(methods)
            & df['model'].isin(models)
            & df['dataset'].isin(datasets))
df = df[selected]
print(df.columns)
auc_columns = ['AUC_0', 'AUC_1', 'AUC_2', 'AUC_3']
df = df[['method'] + auc_columns]

# Mean and standard deviation over the runs, turned into text.
per_method = df.groupby(['method'])
mean = (per_method.mean() * 100).round(1).astype('str')
std = (per_method.std() * 100).round(1).astype('str')
for column in std.columns:
    std[column] = std[column].map(lambda value: f' ({value})')
# String concatenation cell by cell: "<mean> (<std>)".
df = mean + std
df.columns = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

print(df.to_latex())

In [None]:
# Colorectal Histology / DenseNet121: per-class AUC per method,
# as "mean (std)" in percent.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)

methods = [methods_map[key] for key in
           ('fc_then_full', 'fc_then_bn', 'fc', 'moving_bn_and_fc', 'random_weights_bn_and_fc')]
models = ['DenseNet121']
datasets = [datasets_map[key] for key in ('colorectal_histology',)]
selected = (df['method'].isin(methods)
            & df['model'].isin(models)
            & df['dataset'].isin(datasets))
df = df[selected]
print(df.columns)
auc_columns = ['AUC_tumor', 'AUC_stroma', 'AUC_complex', 'AUC_lympho', 'AUC_debris',
               'AUC_mucosa', 'AUC_adipose', 'AUC_empty']
df = df[['method'] + auc_columns]

# Mean and standard deviation over the runs, turned into text.
per_method = df.groupby(['method'])
mean = (per_method.mean() * 100).round(1).astype('str')
std = (per_method.std() * 100).round(1).astype('str')
for column in std.columns:
    std[column] = std[column].map(lambda value: f' ({value})')
# String concatenation cell by cell: "<mean> (<std>)".
df = mean + std
# Header is the class name without the "AUC_" prefix, capitalized.
df.columns = [name.split('_')[1].capitalize() for name in auc_columns]

print(df.to_latex())

In [None]:
# Patch Camelyon 17 (full and small) / DenseNet121: AUC per test hospital
# and method, as "mean (std)" in percent.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)

methods = [methods_map[m] for m in ['fc_then_full', 'fc_then_bn', 'fc', 'moving_bn_and_fc', 'random_weights_bn_and_fc']]
models = ['DenseNet121']
datasets = [datasets_map[m] for m in ['patch_camelyon17', 'patch_camelyon17_small']]
df = df[df['method'].isin(methods)]
df = df[df['model'].isin(models)]
df = df[df['dataset'].isin(datasets)]

# Header for each result column. Renaming by name (rather than assigning a
# positional list) guarantees that every header matches the hospital whose
# AUC the column holds.
hospital_labels = {'auc_test_hosp_0': 'Hosp. 0',
                   'auc_test_hosp_2': 'Hosp. 2',
                   'auc_test_hosp_3': 'Hosp. 3'}
df = df[['method', 'dataset'] + list(hospital_labels)]

mean = (df.groupby(['dataset', 'method']).agg(['mean']) * 100).round(1).astype('str')
mean.columns = mean.columns.droplevel(1)
std = (df.groupby(['dataset', 'method']).agg(['std']) * 100).round(1).astype('str')
std.columns = std.columns.droplevel(1)
for c in std.columns:
    std[c] = std[c].apply(lambda x: f' ({x})')
# String concatenation cell by cell: "<mean> (<std>)".
df = (mean + std).rename(columns=hospital_labels)
print(df.to_latex())

### Figures

In [None]:
# Bar chart of the AUC per model, one panel per dataset, for the main methods.
df, df_g = load_results()

df = df.replace(datasets_map).replace(models_map).replace(methods_map)
methods = [methods_map[key] for key in
           ('fc_then_full', 'fc_then_bn', 'fc', 'moving_bn_and_fc', 'random_weights_bn_and_fc')]
models = models_map.values()
datasets = datasets_map.values()
df = df[df['method'].isin(methods)]
sns.set(font_scale=2, rc={'text.usetex': True})

catplot_options = dict(
    x='model', y='auc', hue='method',
    hue_order=methods,
    order=models,
    data=df, col='dataset', col_order=datasets, ci="sd", kind='bar', col_wrap=3,
    height=5, aspect=1.5,
    legend=False, legend_out=False,
)
g = sns.catplot(**catplot_options)

g.set_axis_labels("", "AUC")
g.set_titles("{col_name}")
g.set(ylim=(0.45, 1))
g.despine(left=True)
g.set_xticklabels(rotation=90)

# One shared legend above the panels instead of the per-grid legend.
handles = g._legend_data.values()
labels = g._legend_data.keys()
g.fig.legend(handles=handles, labels=labels, loc='upper center', ncol=5)
g.fig.subplots_adjust(top=0.92, bottom=0.08)
# g.savefig(f'{OUTPUT_PATH}/auc_all.png')

In [None]:
# DenseNet121 only: AUC per dataset for each method next to its "-BMA" variant.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)

methods = [methods_map[key] for key in
           ('fc_then_full', 'fc_then_full_bnt', 'fc_then_bn', 'fc_then_bn_bnt', 'fc', 'moving_bn_and_fc')]
models = ['DenseNet121']
datasets = datasets_map.values()
df = df[df['method'].isin(methods)]

sns.set(font_scale=2, rc={'text.usetex': True})
catplot_options = dict(
    x='model', y='auc', hue='method',
    hue_order=methods,
    order=models,
    data=df, col='dataset', col_order=datasets, ci="sd", kind='bar', col_wrap=3,
    height=6, aspect=1.5, legend=False,
)
g = sns.catplot(**catplot_options)

g.set_axis_labels("", "AUC")
g.set_titles("{col_name}")
g.set(ylim=(0.45, 1))
g.despine(left=True)
g.set_xticklabels(rotation=90)

# One shared legend above the panels instead of the per-grid legend.
handles = g._legend_data.values()
labels = g._legend_data.keys()
g.fig.legend(handles=handles, labels=labels, loc='upper center', ncol=6)
g.fig.subplots_adjust(top=0.92, bottom=0.08)

In [None]:
# Random-weight experiments: AUC per dataset for DenseNet121 and EfficientNetB3.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)
methods = [methods_map[key] for key in
           ('random_weights_bn_and_fc', 'random_weights_bn_training_and_fc')]
models = [models_map[key] for key in ('densenet121', 'efficientnetb3')]
datasets = datasets_map.values()
df = df[df['method'].isin(methods)]

sns.set(font_scale=2)
catplot_options = dict(
    x='model', y='auc', hue='method',
    hue_order=methods,
    order=models,
    data=df, col='dataset', col_order=datasets, ci="sd", kind='bar', col_wrap=3,
    height=5, aspect=1.5, legend=False,
)
g = sns.catplot(**catplot_options)

g.set_axis_labels("", "AUC")
g.set_titles("{col_name}")
g.set(ylim=(0.45, 1))
g.despine(left=True)
g.set_xticklabels(rotation=90)

# One shared legend above the panels instead of the per-grid legend.
handles = g._legend_data.values()
labels = g._legend_data.keys()
g.fig.legend(handles=handles, labels=labels, loc='upper center', ncol=6)
g.fig.subplots_adjust(top=0.92, bottom=0.08)

In [None]:
# BN fine-tuning variants: AUC per dataset for DenseNet121 and EfficientNetB3.
df, df_g = load_results()
df = df.replace(datasets_map).replace(models_map).replace(methods_map)
methods = [methods_map[key] for key in ('fc_then_bn', 'bn_and_fc', 'fc_then_bn_lr')]
models = [models_map[key] for key in ('densenet121', 'efficientnetb3')]
datasets = datasets_map.values()
df = df[df['method'].isin(methods)]

sns.set(font_scale=2)
catplot_options = dict(
    x='model', y='auc', hue='method',
    hue_order=methods,
    order=models,
    data=df, col='dataset', col_order=datasets, ci="sd", kind='bar', col_wrap=3,
    height=5, aspect=1.5, legend=False,
)
g = sns.catplot(**catplot_options)

g.set_axis_labels("", "AUC")
g.set_titles("{col_name}")
g.set(ylim=(0.45, 1))
g.despine(left=True)
g.set_xticklabels(rotation=90)

# One shared legend above the panels instead of the per-grid legend.
handles = g._legend_data.values()
labels = g._legend_data.keys()
g.fig.legend(handles=handles, labels=labels, loc='upper center', ncol=6)
g.fig.subplots_adjust(top=0.92, bottom=0.08)