In [None]:
import pandas as pd
import os
import glob
import scipy.stats as stats
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
import json
import metrics

# Precision argument handed to metrics.safe_round when the correlation
# results are reported (presumably the number of decimal places — confirm
# against the metrics module).
round_precision = 4

In [None]:
def calculate_correlation(csv_files, column_name):
    """Correlate per-file mean performance with its range and save a scatter plot.

    Every CSV in ``csv_files`` that contains ``column_name`` contributes one
    point: the mean and the range of that column, both computed by the
    project ``metrics`` module. The correlation between the two (from
    ``metrics.correlation``) is printed, and a scatter plot is written to a
    task-specific PDF (marker = model, colour = dataset).

    Parameters
    ----------
    csv_files : iterable of str
        Paths to result CSVs. Model and dataset are taken from the first two
        fields of the file name, split on ``"-"`` for ``"test_accuracy"`` and
        on ``"_"`` for ``"mae"``.
    column_name : str
        ``"test_accuracy"`` (mean is reported as ``(1 - mean) * 100`` and the
        range is multiplied by 100) or ``"mae"`` (values used as-is).

    Returns
    -------
    None
        Results are printed and the figure is saved to disk. The figure is
        left open so that a notebook backend can still display it.

    Raises
    ------
    ValueError
        If ``column_name`` is not supported, if a file name cannot be split
        into model and dataset, or if no file contains ``column_name``.
    """
    # Everything that depends only on the task is resolved up front, so it is
    # defined even when no file ends up matching.
    task_settings = {
        "test_accuracy": {
            "plot_file_path": "plots/rq4/image.pdf",
            "separator": "-",
            "title": "Correlation Plot for Mean Top-1 Error Rate\nand Range for Image Classification",
            "x_label": "Mean Top-1 Error Percentage",
        },
        "mae": {
            "plot_file_path": "plots/rq4/time_series.pdf",
            "separator": "_",
            "title": "Correlation Plot for Mean MAE\nand Range for Time Series Forecasting",
            "x_label": "Mean MAE",
        },
    }
    if column_name not in task_settings:
        raise ValueError(
            f"Unsupported column_name {column_name!r}; "
            f"expected one of {sorted(task_settings)}"
        )
    settings = task_settings[column_name]
    is_accuracy = column_name == "test_accuracy"

    # Display names for the image-classification datasets; unknown names are
    # kept unchanged.
    dataset_display_names = {
        "cifar10": "CIFAR-10",
        "cifar100": "CIFAR-100",
        "oxford_flowers102": "Oxford Flowers",
        "uc_merced": "UC Merced",
    }

    # Store performance and robustness metrics, one record per usable CSV
    metric_results = []
    for path in csv_files:
        df = pd.read_csv(path)
        if column_name not in df.columns:
            continue

        file = os.path.basename(path)
        name_parts = file.split(settings["separator"])
        if len(name_parts) < 2:
            raise ValueError(
                f"Cannot parse model and dataset from file name {file!r}"
            )
        model, dataset = name_parts[0], name_parts[1]

        mean_value = metrics.mean(df[column_name].to_list())
        range_value = metrics.range(df[column_name].to_list())

        if is_accuracy:
            dataset = dataset_display_names.get(dataset, dataset)
            # Mark whether the run started from pretrained or random weights
            dataset = dataset + (" (P)" if "pretrained" in file else " (R)")
            # Report the error rate in percent rather than the accuracy
            mean_value = (1 - mean_value) * 100
            range_value = range_value * 100

        metric_results.append({
            "file": file,
            "model": model,
            "dataset": dataset,
            "mean": mean_value,
            "range": range_value,
        })

    if not metric_results:
        raise ValueError(
            f"None of the given CSV files contains a {column_name!r} column; "
            "nothing to correlate"
        )

    metrics_df = pd.DataFrame(metric_results)

    # Correlation between mean and range, rounded for display
    r_value, p_value = metrics.correlation(metrics_df["mean"], metrics_df["range"])
    r_value = metrics.safe_round(r_value, round_precision)
    p_value = metrics.safe_round(p_value, round_precision)

    print("Correlation between performance (mean) and robustness (range):")
    print("Pearson r =", r_value)
    print("p-value   =", p_value)

    # One marker per model; wrap around instead of dropping models when there
    # are more models than marker shapes.
    marker_shapes = ['o', 's', '^', 'D', 'P', 'X', 'v', '<', '>', 'h', 'H', '*', '+', 'x']
    model_markers = {
        model: marker_shapes[index % len(marker_shapes)]
        for index, model in enumerate(metrics_df["model"].unique())
    }

    # One colour per dataset
    dataset_names = metrics_df["dataset"].unique()
    dataset_colors = dict(zip(dataset_names, sns.color_palette("bright", len(dataset_names))))

    # Plot correlation
    fig, ax = plt.subplots(figsize=(8, 6))
    for _, row in metrics_df.iterrows():
        ax.scatter(
            row["mean"], row["range"],
            color=dataset_colors[row["dataset"]],
            marker=model_markers[row["model"]],
            label=f'{row["model"]}/{row["dataset"]}',
            s=200,
        )

    # The legend uses proxy handles so each model and each dataset is listed
    # once, rather than once per plotted point.
    model_legend = [
        Line2D([0], [0], marker=marker, color='black', linestyle='', label=model)
        for model, marker in sorted(model_markers.items())
    ]
    dataset_legend = [
        Line2D([0], [0], marker='o', color=color, linestyle='', label=dataset)
        for dataset, color in sorted(dataset_colors.items())
    ]
    ax.legend(
        handles=model_legend + dataset_legend,
        title='Model / Dataset',
        fontsize=12,
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
    )

    ax.set_title(settings["title"], fontsize=16)
    ax.set_xlabel(settings["x_label"], fontsize=13)
    ax.set_ylabel("Range", fontsize=13)
    ax.tick_params(labelsize=13)
    ax.grid(True)
    fig.tight_layout()

    # savefig does not create missing directories, so make sure the target exists
    plot_file_path = settings["plot_file_path"]
    output_dir = os.path.dirname(plot_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # bbox_inches='tight' keeps the legend placed outside the axes in the saved file
    fig.savefig(plot_file_path, dpi=600, pad_inches=0.1, bbox_inches='tight')


In [None]:
# Directory holding the image-classification result CSVs
data_folder = "../results/image_classification/"

# Collect every CSV in the directory whose file name contains 'PyTorch'.
# glob returns paths in arbitrary order, and markers/colours are assigned in
# order of first appearance, so the list is sorted to keep the plot
# reproducible across machines.
csv_files = sorted(
    f for f in glob.glob(os.path.join(data_folder, "*.csv"))
    if "PyTorch" in os.path.basename(f)
)

calculate_correlation(csv_files, "test_accuracy")

In [None]:
# Directory holding the time-series forecasting result CSVs
data_folder = "../results/time_series/"

# Collect every CSV in the directory (no file-name filter is applied).
# glob returns paths in arbitrary order, and markers/colours are assigned in
# order of first appearance, so the list is sorted to keep the plot
# reproducible across machines.
csv_files = sorted(glob.glob(os.path.join(data_folder, "*.csv")))

calculate_correlation(csv_files, "mae")