# CIFAR-100 & MNIST Results Analysis

In [None]:
import glob
import os

import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Collect every results CSV. glob returns matches in arbitrary
# (filesystem-dependent) order, so sort them to keep the order in which runs
# are loaded and plotted -- and therefore curve colours and legend order --
# reproducible between executions and machines.
files = sorted(glob.glob("../results/*.csv"))

In [None]:
# Read each results CSV into a DataFrame keyed by its path, loading only the
# columns that are needed. Files whose path contains "gadam" are read with
# their "r" column in addition to the common ones.
common_columns = ["Epoch", "Train_Loss", "Train_Accuracy", "Test_Accuracy", "Test_Loss"]
gadam_columns = ["Epoch", "r", "Train_Loss", "Train_Accuracy", "Test_Accuracy", "Test_Loss"]

data = {
    path: pd.read_csv(path, usecols=gadam_columns if "gadam" in path else common_columns)
    for path in files
}

In [None]:
# Key searched for in each results path -> name used in plot titles and
# output file names.
model_mapping = dict(lenet="MNIST", resnet18="CIFAR-100")

# CSV column of each plotted metric -> title shown on its subplot.
# Insertion order is the subplot order; the two parallel lists are derived
# from this single mapping so they cannot drift out of step.
_metric_labels = {
    "Train_Accuracy": "Training Accuracy",
    "Test_Accuracy": "Testing Accuracy",
    "Train_Loss": "Training Loss",
    "Test_Loss": "Testing Loss",
}
metrics = list(_metric_labels)
metric_titles = list(_metric_labels.values())

In [None]:
# Plot every metric for each model: one 2x2 figure per model, one subplot per
# metric, and one curve per results file (one per distinct "r" value for runs
# that carry an "r" column). Each figure is saved as a PNG and then displayed.
output_dir = "output"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(output_dir, exist_ok=True)

for model_key, model_name in model_mapping.items():
    # Select this model's runs once instead of re-filtering for every metric.
    # Working on the base name keeps directory components out of the match and
    # handles either path separator.
    model_runs = [
        (os.path.basename(path), df) for path, df in data.items() if model_key in os.path.basename(path)
    ]
    if not model_runs:
        # pd.concat([]) would raise ValueError below, and an empty figure is useless.
        print(f"No result files found for {model_name} (key {model_key!r}); skipping.")
        continue

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # axes is a 2x2 array; flatten it so it pairs up with the metric lists.
    for ax, metric, title in zip(axes.flatten(), metrics, metric_titles):
        for filename, df in model_runs:
            # The leading "_"-separated token of the file name labels the curve
            # (presumably the optimizer name -- confirm against the naming scheme).
            label = filename.split("_")[0]
            if "r" in df.columns:
                # Runs loaded with an "r" column get one curve per distinct value.
                for r_value in df["r"].unique():
                    subset = df[df["r"] == r_value]
                    ax.plot(subset["Epoch"], subset[metric], label=f"{label} (r={r_value})")
            else:
                ax.plot(df["Epoch"], df[metric], label=label)

        # Set plot details
        ax.set_title(f"{title} for {model_name}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(title)
        ax.grid()

        # Scale the y-axis to this model's data with 5% headroom on both sides.
        # NOTE: the multiplicative padding assumes non-negative metric values.
        all_values = pd.concat([df[metric] for _, df in model_runs])
        ax.set_ylim(all_values.min() * 0.95, all_values.max() * 1.05)

        ax.legend()

    # Save and display the plot
    fig.suptitle(f"Metrics for {model_name}", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room at the top for the suptitle
    fig.savefig(os.path.join(output_dir, f"{model_name}_metrics.png"))
    plt.show()