In [None]:
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import skew, kurtosis, variation

In [None]:
# Directory prefix that is prepended to every experiment file name before the CSV is read
base_directory = "../results/image_classification/"
# Name of the CSV column whose values are analysed and plotted
results_column = "test_accuracy"
# Number of decimal places the summary statistics are rounded to
round_digits = 4
# Tail probability handed to calculate_cvar; 0.05 is reported as "CVaR 95%" in the summary CSV
cvar_alpha = 0.05

In [None]:
# Experiment result files, grouped so that every inner list is loaded together and
# drawn on the same figures. Paths are relative to base_directory and the file names
# are parsed as "<model>-<dataset>-..." by splitting on "-".
# NOTE: the processing loop derives the "EX1".."EX6" title prefix from the position of
# each group in this list, so adding, removing or reordering groups requires updating
# that index mapping as well.
experiments_by_model_and_dataset = [
    # EX1 - Small Image Classification
    [
        "ResNet20-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ResNet56-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ResNet110-cifar10-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ViTS8-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ViTB8-cifar10-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ResNet20-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ResNet56-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ResNet110-cifar100-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ViTS8-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ViTB8-cifar100-idun-A100-PyTorch-ngc2312.csv",
    ],
    # EX2 - Large Image Classification
    [
        "ResNet18-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "ResNet50-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "ResNet101-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ViTTiny16-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "ViTS16-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "ViTB16-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ViTTiny16-oxford_flowers102-idun-A100-PyTorch-ngc2312-pretrained.csv",
        "ViTS16-oxford_flowers102-idun-A100-PyTorch-ngc2312-pretrained.csv",
        "ViTB16-oxford_flowers102-idun-A100-PyTorch-ngc2312-pretrained.csv",
    ],

    [
        "ResNet18-uc_merced-idun-A100-PyTorch-ngc2312.csv",
        "ResNet50-uc_merced-idun-A100-PyTorch-ngc2312.csv",
        "ResNet101-uc_merced-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ViTTiny16-uc_merced-idun-A100-PyTorch-ngc2312.csv",
        "ViTS16-uc_merced-idun-A100-PyTorch-ngc2312.csv",
        "ViTB16-uc_merced-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ViTTiny16-uc_merced-idun-A100-PyTorch-ngc2312-pretrained.csv",
        "ViTS16-uc_merced-idun-A100-PyTorch-ngc2312-pretrained.csv",
        "ViTB16-uc_merced-idun-A100-PyTorch-ngc2312-pretrained.csv",
    ],
    # EX3 - Learning Rate Warmup Comparison
    [
        "ResNet20-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "learning_rate/ResNet20LR-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ResNet56-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "learning_rate/ResNet56LR-cifar10-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ResNet20-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "learning_rate/ResNet20LR-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ResNet56-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "learning_rate/ResNet56LR-cifar100-idun-A100-PyTorch-ngc2312.csv",
    ],
    # EX4 - Random train/val/test splits
    [
        "ResNet20-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_01-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_02-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_03-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_04-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_05-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_06-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_07-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_08-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet20_09-cifar10-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "ResNet50-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_01-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_02-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_03-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_04-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_05-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_06-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_07-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_08-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
        "dataset_splits/ResNet50_09-oxford_flowers102-idun-A100-PyTorch-ngc2312.csv",
    ],
    # EX5 - PyTorch vs TensorFlow
    [
        "ResNet20-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ResNet56-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ResNet110-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "ResNet20-cifar10-idun-A100-TensorFlow-ngc2312.csv",
        "ResNet56-cifar10-idun-A100-TensorFlow-ngc2312.csv",
        "ResNet110-cifar10-idun-A100-TensorFlow-ngc2312.csv",
    ],
    [
        "ResNet20-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ResNet56-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ResNet110-cifar100-idun-A100-PyTorch-ngc2312.csv",
        "ResNet20-cifar100-idun-A100-TensorFlow-ngc2312.csv",
        "ResNet56-cifar100-idun-A100-TensorFlow-ngc2312.csv",
        "ResNet110-cifar100-idun-A100-TensorFlow-ngc2312.csv",
    ],
    # EX6 - Increasing Epochs
    [
        "epochs/ResNet20_e010-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e020-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e030-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e040-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e050-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e075-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e100-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e125-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e150-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e175-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ResNet20_e200-cifar10-idun-A100-PyTorch-ngc2312.csv",
    ],
    [
        "epochs/ViTS8_e010-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e020-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e030-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e040-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e050-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e075-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e100-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e125-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e150-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e175-cifar10-idun-A100-PyTorch-ngc2312.csv",
        "epochs/ViTS8_e200-cifar10-idun-A100-PyTorch-ngc2312.csv",
    ]
]

In [None]:
def calculate_cvar(dataset, alpha, digits=None):
    """Return the mean of the lower ``alpha`` tail of ``dataset``.

    The cut-off is the ``alpha`` quantile of the data; every value less than
    or equal to that cut-off is averaged.

    Args:
        dataset: Array-like of numbers (NumPy array, list, pandas Series, ...).
            It is never modified, so read-only arrays are accepted as well.
        alpha: Tail probability in [0, 1].
            alpha = 0.1  -> 90%
            alpha = 0.05 -> 95%
            alpha = 0.01 -> 99%
        digits: Number of decimals the result is rounded to. Defaults to the
            notebook-level ``round_digits`` setting.

    Returns:
        The rounded tail mean as a NumPy float.
    """
    if digits is None:
        digits = round_digits

    # np.quantile does not need sorted input, so work on a non-mutating array
    # view instead of sorting the caller's data in place.
    values = np.asarray(dataset, dtype=float)
    var = np.quantile(values, alpha)
    return np.round(values[values <= var].mean(), digits)

In [None]:
def save_kde(data, title):
    """Plot a histogram with KDE overlay of ``data``, save it as PNG and show it.

    Args:
        data: Anything ``seaborn.histplot`` accepts as its data argument. The
            notebook passes a dict that maps a run name to its array of results.
        title: Figure title. It is also used to build the output file name:
            ":", "(" and ")" are dropped, spaces become "_" and "_kde.png" is
            appended.
    """
    fig, ax = plt.subplots(figsize=(20, 10))

    # Draw on the axes created above instead of relying on the implicit
    # "current axes" of pyplot.
    sns.histplot(
        data,
        kde=True,
        stat="proportion",
        kde_kws=dict(cut=3),
        legend=True,
        ax=ax,
    )

    ax.set_title(title, fontsize=25)

    # Increase the legend font size. There is no legend when the data carries
    # no labels, in which case get_legend() returns None.
    legend = ax.get_legend()
    if legend is not None:
        plt.setp(legend.get_texts(), fontsize="25")
    ax.set_xlabel("Top-1 Accuracy", fontsize=25)
    ax.set_ylabel("Proportion", fontsize=25)
    ax.tick_params(axis="both", labelsize=25)

    file_stem = title.translate(str.maketrans({":": "", "(": "", ")": "", " ": "_"}))
    fig.savefig(file_stem + "_kde.png", pad_inches=0.1, bbox_inches='tight')

    plt.show()

    # This function is called once per experiment group; release the figure
    # so open figures do not pile up.
    plt.close(fig)

In [None]:
def save_boxplot(data, title):
    """Draw one box per run in ``data``, save the figure as PNG and show it.

    Args:
        data: Mapping of run name to a 1-D array-like of results. Runs may
            have different lengths.
        title: Figure title. It is also used to build the output file name:
            ":", "(" and ")" are dropped, spaces become "_" and "_box.png" is
            appended.
    """
    fig, ax = plt.subplots(figsize=(15, 7))

    # Wrapping every run in a Series lets runs of unequal length share one
    # frame: shorter runs are padded with NaN instead of raising an error.
    frame = pd.DataFrame({name: pd.Series(np.asarray(data[name])) for name in data})
    sns.boxplot(data=frame, ax=ax)

    ax.set_title(title, fontsize=16)
    ax.set_ylabel("Accuracy", fontsize=16)
    ax.tick_params(labelsize=14)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Run names are long, so print them vertically
    ax.tick_params(axis="x", labelrotation=90)

    file_stem = title.translate(str.maketrans({":": "", "(": "", ")": "", " ": "_"}))
    fig.savefig(file_stem + "_box.png", pad_inches=0.1, bbox_inches='tight')

    plt.show()

    # This function is called once per experiment group; release the figure
    # so open figures do not pile up.
    plt.close(fig)

In [None]:
def save_cvar(data, title):
    """Plot one histogram per run with its mean and CVaR marked, save and show it.

    All panels share the same x-range so the runs can be compared visually.
    The CVaR is computed with ``calculate_cvar`` using the notebook-level
    ``cvar_alpha`` setting.

    Args:
        data: Mapping of run name to a 1-D array-like of results. The arrays
            are not modified.
        title: Figure title. It is also used to build the output file name:
            ":", "(" and ")" are dropped, spaces become "_" and "_cvar.png" is
            appended.

    Raises:
        ValueError: If ``data`` contains no runs.
    """
    if len(data) == 0:
        raise ValueError("save_cvar needs at least one run to plot")

    # Calculate the min and max values for the x-axis over all runs.
    # np.min / np.max are used on purpose: other notebook cells may rebind
    # the names of the builtins at notebook scope.
    x_low = np.min([np.min(data[run_name]) for run_name in data])
    x_high = np.max([np.max(data[run_name]) for run_name in data])

    # Leave a 1% margin on both sides
    x_low = x_low - (x_low * 0.01)
    x_high = x_high + (x_high * 0.01)

    # Set the number of columns and rows for the subplots
    # 2 columns and as many rows as needed to fit all the data
    ncols = 2
    nrows = math.ceil((len(data) / ncols))

    # Keep at least two rows so the layout matches for small groups
    if nrows == 1:
        nrows = 2

    # squeeze=False guarantees a 2-D array of axes for [row, column] indexing
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10), squeeze=False)
    fig.tight_layout(pad=7)

    fig.suptitle(title, fontsize=25)

    for idx, run_name in enumerate(data):
        # Find the row and column for the current plot
        panel = axes[idx // ncols, idx % ncols]

        values = np.asarray(data[run_name])

        # Calculate the summary statistics. calculate_cvar receives a copy so
        # the caller's data can never be reordered or rejected as read-only.
        dataset_mean = values.mean()
        cvar = calculate_cvar(values.copy(), cvar_alpha)

        sns.histplot(
            values,
            kde=True,
            stat="proportion",
            kde_kws=dict(cut=3),
            legend=True,
            ax=panel,
        )

        # Add the mean and CVaR to the plot
        panel.axvline(dataset_mean, color='red', linestyle='solid', label="Mean: %.2f%% " % (dataset_mean *100) )
        panel.axvline(cvar, color='red', linestyle='dashed', label="CVaR: %.2f%% " % (cvar *100)  )

        # The legend is created here, after the labelled lines exist
        panel.legend(fontsize=25)

        panel.set_title(run_name, fontsize=25)

        panel.set_xlabel("Top-1 Accuracy", fontsize=25)
        panel.set_ylabel("proportion", fontsize=25)

        panel.tick_params(labelsize=20)

        panel.set_xlim([x_low, x_high])
        panel.set_ylim([0, .30])

    file_stem = title.translate(str.maketrans({":": "", "(": "", ")": "", " ": "_"}))
    fig.savefig(file_stem + "_cvar.png", pad_inches=0.1, bbox_inches='tight')

    plt.show()

    # This function is called once per experiment group; release the figure
    # so open figures do not pile up.
    plt.close(fig)

In [None]:
# Save the summary statistics so they can be saved to a CSV file
# (one row per experiment file)
summary_statistics = []

# Title prefix of every group, keyed by the group's position in
# experiments_by_model_and_dataset. Keep this in sync with the order of that list.
experiment_labels = (
    ("EX1:", range(0, 4)),
    ("EX2:", range(4, 10)),
    ("EX3:", range(10, 12)),
    ("EX4:", range(12, 14)),
    ("EX5:", range(14, 16)),
    ("EX6:", range(16, 18)),
)

for idx, experiments in enumerate(experiments_by_model_and_dataset):

    # Set title
    title = None
    for label, group_indices in experiment_labels:
        if idx in group_indices:
            title = label
            break
    if title is None:
        # Fail loudly instead of silently reusing the title of the previous group
        raise ValueError("No experiment label is defined for group index %d" % idx)

    # Include model in title
    first_experiment = experiments[0].lower()
    if "resnet" in first_experiment:
        title += " ResNet"
    elif "vit" in first_experiment:
        title += " ViT"

    # Save the individual results for the plots
    results_values = {}

    # Loop through the individual experiments
    for experiment in experiments:
        df = pd.read_csv(base_directory + experiment)

        # Get the results as an independent, writable array: the array backing
        # the DataFrame can be read-only (pandas Copy-on-Write) and the values
        # are handed to helpers that may sort them in place.
        results = df[results_column].to_numpy(copy=True)

        # Calculate the summary statistics.
        # The names min_value / max_value keep the builtins min() and max()
        # usable in the rest of the notebook.
        mean = np.mean(results).round(round_digits)
        median = np.median(results).round(round_digits)
        min_value = np.min(results).round(round_digits)
        max_value = np.max(results).round(round_digits)
        q_1 = np.quantile(results, 0.25).round(round_digits)
        q_3 = np.quantile(results, 0.75).round(round_digits)
        std = np.std(results).round(round_digits)
        c_v = variation(results).round(round_digits)
        cvar = calculate_cvar(results, cvar_alpha)
        skewness = skew(results).round(round_digits)
        k_value = kurtosis(results).round(round_digits)

        # Get the model and dataset from the experiment name
        name_parts = experiment.split("-")
        dataset = name_parts[1]
        # Drop any subdirectory ("learning_rate/", "dataset_splits/", "epochs/")
        # so only the model name is used as label
        model = name_parts[0].rsplit("/", 1)[-1]

        # Add the dataset to the title if it is not already there
        if dataset not in title:
            title = title + " " + dataset

        # Add whether the model is pretrained or not for the ViT 16 models
        if "vit" in model.lower() and "16" in model.lower():
            if "pretrained" in experiment:
                model += " (Pretrained)"
                if "pretrained" not in title.lower():
                    title += " (Pretrained)"
            else:
                model += " (Random)"
                if "random" not in title.lower():
                    title += " (Random)"

        # Add whether the model is TensorFlow or PyTorch for the EX5 experiments
        if "ex5" in title.lower():
            if "tensorflow" in experiment.lower():
                model += " (TensorFlow)"
            else:
                model += " (PyTorch)"

        # Save the results for the plot
        results_values[model] = results

        # Save the summary statistics
        summary_statistics.append([dataset, model, mean, median, min_value, max_value, q_1, q_3, std, c_v, cvar, skewness, k_value])

    # Plot the results for the experiment as a KDE histogram
    save_kde(results_values, title)
    # Plot the results for the experiment as a boxplot
    save_boxplot(results_values, title)
    # Plot the results for the experiment as a CVaR histogram
    save_cvar(results_values, title)

In [None]:
# Column headers of the CSV, in the same order as the values stored in every
# row of summary_statistics.
# NOTE: "Kertosis" is misspelled but kept verbatim so the header of the written
# file stays the same for anything that already reads it.
summary_statistics_columns = [
    "Dataset",
    "Model",
    "Mean",
    "Median",
    "Min",
    "Max",
    "25th_percentile",
    "75th_percentile",
    "Std",
    "Coefficient of Variation",
    "CVaR 95%",
    "Skewness",
    "Kertosis",
]

# Build the table once from the collected rows and write it without the index
summary_statistics_df = pd.DataFrame(
    summary_statistics,
    columns=summary_statistics_columns,
)
summary_statistics_df.to_csv("image_summary_statistics.csv", index=False)