In [None]:
import os
import glob
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ------------------------------------------------------------------------------------------------
# IMPORTANT: CHANGE WITH THE MODEL WHOSE RESULTS YOU WANT TO SHOW
model_name = "ResNet18"
# ------------------------------------------------------------------------------------------------

# Folder holding the per-epoch log files of the selected model
checkpoint_dir = f"../checkpoints/{model_name}"

# Collect every per-epoch log file, in lexicographic path order
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "log_epoch_*.pt"))
checkpoint_files.sort()

In [None]:
# ---------------------------------------------------------------------------
# 1. LOADING DATA
# ---------------------------------------------------------------------------

# Fail early with an actionable message: with no log files the DataFrame built
# below would be empty and sort_values("epoch") would raise an opaque KeyError.
if not checkpoint_files:
    raise FileNotFoundError(
        f"No 'log_epoch_*.pt' files found in {checkpoint_dir!r}; "
        "check `model_name` and the checkpoints path."
    )

# Read every per-epoch log and flatten it into one record per epoch
logs = []
for file in checkpoint_files:
    # NOTE(review): depending on the torch version / `weights_only` setting,
    # torch.load may unpickle arbitrary objects - only load logs you trust.
    ckpt = torch.load(file, map_location="cpu")
    logs.append({
        # Stored epoch index shifted by one (presumably 0-based on disk - confirm)
        "epoch": ckpt["epoch"] + 1,

        # ACCURACY
        "val_acc_species" : ckpt["current_acc_species"],
        "val_acc_disease" : ckpt["current_acc_disease"],
        "val_acc_avg"     : ckpt["current_acc_avg"],

        "best_val_acc_species" : ckpt["best_acc_species"],
        "best_val_acc_disease" : ckpt["best_acc_disease"],
        "best_val_acc_avg"     : ckpt["best_acc_avg"],

        # F1 MACRO
        "val_f1_species" : ckpt["current_f1_species"],
        "val_f1_disease" : ckpt["current_f1_disease"],
        "val_f1_avg"     : ckpt["current_f1_macro"],

        "best_val_f1_species" : ckpt["best_f1_species"],
        "best_val_f1_disease" : ckpt["best_f1_disease"],
        "best_val_f1_avg"     : ckpt["best_f1_macro"],

        # LOSS
        "loss_species" : ckpt["current_loss_species"],
        "loss_disease" : ckpt["current_loss_disease"],
        "loss_avg"     : ckpt["current_loss"],
    })

# Order rows numerically by epoch: the file list is sorted as strings, so
# e.g. "log_epoch_10" comes before "log_epoch_2" in `checkpoint_files`.
df = pd.DataFrame(logs).sort_values("epoch").reset_index(drop=True)

# Visualise raw data
df.head()


In [None]:
# ---------------------------------------------------------------------------
# 2a. ACCURACY PLOT
# ---------------------------------------------------------------------------

# One validation-accuracy curve per task, plus the logged average (dashed)
fig, ax = plt.subplots(figsize=(10, 6))

epochs = df["epoch"]
ax.plot(epochs, df["val_acc_species"], label="Species Accuracy")
ax.plot(epochs, df["val_acc_disease"], label="Disease Accuracy")
ax.plot(epochs, df["val_acc_avg"], label="Average Accuracy", linestyle='--')

ax.set(
    xlabel="Epoch",
    ylabel="Validation Accuracy",
    title=f"Validation Accuracy Over Epochs - {model_name}",
)
ax.legend()
ax.grid(True)

plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 2b. F1 PLOT
# ---------------------------------------------------------------------------
# (matplotlib is already imported in the notebook's first cell; the duplicate
# mid-notebook import was dropped so all dependencies stay in one place.)

# Validation F1 per task, plus the logged macro average (dashed)
plt.figure(figsize=(10,6))
plt.plot(df["epoch"], df["val_f1_species"], label="F1 Species")
plt.plot(df["epoch"], df["val_f1_disease"], label="F1 Disease")
plt.plot(df["epoch"], df["val_f1_avg"], label="F1 Average", linestyle='--')

plt.xlabel("Epoch")
plt.ylabel("F1 Macro")
plt.title(f"Validation F1 Macro Over Epochs - {model_name}")
# Fixed y-range; the upper bound sits slightly above 1 so a curve at 1.0 is not clipped
plt.ylim(0.0, 1.01)
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 2c. LOSS PLOT
# ---------------------------------------------------------------------------

# Column to draw -> keyword arguments for its curve (the average is dashed)
loss_curves = [
    ("loss_species", {"label": "Loss Species"}),
    ("loss_disease", {"label": "Loss Disease"}),
    ("loss_avg", {"label": "Loss Average", "linestyle": '--'}),
]

plt.figure(figsize=(10,6))
for column, plot_kwargs in loss_curves:
    plt.plot(df["epoch"], df[column], **plot_kwargs)

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title(f"Validation Loss Over Epochs - {model_name}")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# BONUS F1 vs LOSS CORRELATION
# ---------------------------------------------------------------------------

# One point per epoch: x = average loss, y = average F1, colour = epoch number
fig, ax = plt.subplots(figsize=(6,6))
points = ax.scatter(
    df["loss_avg"],
    df["val_f1_avg"],
    c=df["epoch"],
    cmap="viridis",
    s=60,
)
fig.colorbar(points, ax=ax, label="Epoch")

ax.set_xlabel("Validation Loss (avg)")
ax.set_ylabel("F1 Macro (avg)")
ax.set_title("Loss vs F1 Macro")
ax.grid(True)
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 3. AVERAGE ACCURACY AND BEST ACCURACY COMPARISON
# ---------------------------------------------------------------------------

# Row where the "best so far" average accuracy first reaches its maximum
best_idx = df["best_val_acc_avg"].idxmax()
best_row = df.loc[best_idx]
best_epoch = best_row["epoch"]
best_acc = best_row["best_val_acc_avg"]

fig, ax = plt.subplots(figsize=(10, 6))

# Per-epoch average accuracy against the logged best-so-far value
ax.plot(df["epoch"], df["val_acc_avg"], label="Validation Accuracy (Avg)", marker='o')
ax.plot(
    df["epoch"],
    df["best_val_acc_avg"],
    label="Best Accuracy (Avg so far)",
    linestyle='--',
    marker='x',
)

# Highlight the maximum with a vertical line and a red dot
ax.axvline(
    x=best_epoch,
    color='red',
    linestyle='--',
    alpha=0.7,
    label=f"Max Accuracy at Epoch {int(best_epoch)}",
)
ax.scatter(best_epoch, best_acc, color='red', zorder=5)

ax.set(
    title="Average accuracy and best accuracy comparison",
    xlabel="Epoch",
    ylabel="Accuracy",
    ylim=(0.5, 1.01),
)
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()

print(f"Best Accuracy Mean: {best_acc:.4f} at Epoch {int(best_epoch)}")

In [None]:
# ---------------------------------------------------------------------------
# 4a. MEAN and STD Accuracy
# ---------------------------------------------------------------------------

metrics = ["val_acc_species", "val_acc_disease", "val_acc_avg"]

# Mean and standard deviation across epochs, one entry per metric
means, stds = [], []
for metric in metrics:
    column = df[metric]
    means.append(column.mean())
    stds.append(column.std())

fig, ax = plt.subplots(figsize=(8, 4))
ax.barh(
    metrics,
    means,
    xerr=stds,
    color=["#4c72b0", "#55a868", "#c44e52"],
    capsize=8,
)
ax.set_xlabel("Accuracy")
ax.set_title("Mean and standard deviation of accuracy")
ax.set_xlim(0.5, 1.01)
ax.grid(axis='x', linestyle='--', alpha=0.5)

# Annotate each bar with "mean ± std", just to the right of the bar end
for i, (m, s) in enumerate(zip(means, stds)):
    ax.text(m + 0.005, i, f"{m:.3f} ± {s:.3f}", va="bottom")

plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 4b. MEAN & STD F1
# ---------------------------------------------------------------------------
metrics_f1 = ["val_f1_species", "val_f1_disease", "val_f1_avg"]

# Mean and standard deviation across epochs, one entry per F1 metric
means_f1 = []
stds_f1 = []
for metric in metrics_f1:
    means_f1.append(df[metric].mean())
    stds_f1.append(df[metric].std())

bar_colors = ["#4c72b0","#55a868","#c44e52"]

plt.figure(figsize=(8,4))
plt.barh(metrics_f1, means_f1, xerr=stds_f1, color=bar_colors, capsize=8)
plt.xlabel("F1 Macro")
plt.title("Mean ± Std of F1 Macro")
plt.xlim(0.0, 1.01)
plt.grid(axis='x', linestyle='--', alpha=.4)

# Annotate each bar with "mean ± std", vertically centred on the bar
for i, (m, s) in enumerate(zip(means_f1, stds_f1)):
    plt.text(m+0.01, i, f"{m:.3f} ± {s:.3f}", va="center")

plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 5a. ACCURACY IMPROVEMENT OVER EPOCHS
# ---------------------------------------------------------------------------

# Epoch-to-epoch change of the average validation accuracy
df["delta_avg_acc"] = df["val_acc_avg"].diff()  # difference from the previous epoch

# Drop rows without a difference (the first epoch has no predecessor, so it is NaN)
delta_df = df.dropna(subset=["delta_avg_acc"])

# Descriptive statistics
mean_improvement = delta_df["delta_avg_acc"].mean()
std_improvement = delta_df["delta_avg_acc"].std()

print(f"Average improvement over epochs: {mean_improvement:.4f}")
print(f"Standard deviation: {std_improvement:.4f}")

plt.figure(figsize=(8, 4))
plt.plot(delta_df["epoch"], delta_df["delta_avg_acc"], marker="o", linestyle="-", color="steelblue")
plt.axhline(mean_improvement, color="red", linestyle="--", label=f"Average improvement: ({mean_improvement:.4f})")
# Zero line: points above it are improvements, points below are regressions
plt.axhline(0, color="gray", linestyle=":")
plt.title("Accuracy improvement over epochs")
# Axis labels fixed: the x label was a garbled string, and the y axis shows the
# per-epoch change of the accuracy, not the accuracy itself.
plt.xlabel("Epoch")
plt.ylabel("Δ Average accuracy")
plt.legend()
plt.grid(True, linestyle="--", alpha=0.5)
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 5b. F1 MACRO IMPROVEMENT OVER EPOCHS
# ---------------------------------------------------------------------------

# Epoch-to-epoch change of the average F1; rows without a difference are dropped
df["delta_f1_avg"] = df["val_f1_avg"].diff()
delta_f1_df = df.dropna(subset=["delta_f1_avg"])

f1_deltas = delta_f1_df["delta_f1_avg"]
mean_imp = f1_deltas.mean()
std_imp  = f1_deltas.std()

print(f"Average F1 improvement: {mean_imp:.4f}  (std {std_imp:.4f})")

fig, ax = plt.subplots(figsize=(8,4))
ax.plot(delta_f1_df["epoch"], f1_deltas, marker="o")
# Reference lines: mean change (red, dashed) and zero (gray, dotted)
ax.axhline(mean_imp, color="red", linestyle="--", label=f"Mean Δ {mean_imp:.4f}")
ax.axhline(0, color="gray", linestyle=":")
ax.set_title("F1 Macro improvement over epochs")
ax.set_xlabel("Epoch")
ax.set_ylabel("Δ F1 Macro")
ax.legend()
ax.grid(True, linestyle="--", alpha=.4)
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 6. COMPARISON BETWEEN ALL MODELS - LOAD
# ---------------------------------------------------------------------------

model_names = ["ResNet18","ViT","CLIPResNet","CLIPViT","DINOv2"]
all_logs = []
missing_models = []  # models whose checkpoint folder contains no epoch logs

for mdl in model_names:
    ckpt_dir = f"../checkpoints/{mdl}"
    files = sorted(glob.glob(os.path.join(ckpt_dir, "log_epoch_*.pt")))
    if not files:
        # Keep going so the comparison still works with the models that do have logs
        missing_models.append(mdl)
        continue
    for f in files:
        # NOTE(review): depending on the torch version / `weights_only` setting,
        # torch.load may unpickle arbitrary objects - only load logs you trust.
        ck = torch.load(f, map_location="cpu")
        all_logs.append({
            "model"   : mdl,
            # Stored epoch index shifted by one (presumably 0-based on disk - confirm)
            "epoch"   : ck["epoch"]+1,
            "val_f1_avg" : ck["current_f1_macro"],
            "loss_avg"   : ck["current_loss"],
            "val_acc_avg": ck["current_acc_avg"]
        })

# Make skipped models visible instead of silently leaving them out of the plots
if missing_models:
    print(f"No epoch logs found for: {', '.join(missing_models)} (skipped)")

# With no records at all, sort_values below would raise an opaque KeyError
if not all_logs:
    raise FileNotFoundError(
        "No 'log_epoch_*.pt' files found for any model under '../checkpoints/'."
    )

# One row per (model, epoch), ordered numerically by epoch within each model
df_all = pd.DataFrame(all_logs).sort_values(["model","epoch"]).reset_index(drop=True)

In [None]:
# ---------------------------------------------------------------------------
# 7a. COMPARISON BETWEEN ALL MODELS - PLOT Accuracy
# ---------------------------------------------------------------------------

# One line per model: average validation accuracy against epoch
plt.figure(figsize=(10, 6))
ax = sns.lineplot(
    data=df_all,
    x="epoch",
    y="val_acc_avg",
    hue="model",
    marker="o",
)
ax.set_title("Models comparison: average accuracy over time")
ax.set_xlabel("Epoch")
ax.set_ylabel("Average accuracy")
ax.set_ylim(0.5, 1.01)
ax.grid(True, linestyle="--", alpha=0.5)
ax.legend(title="Model")
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 7b. COMPARISON BETWEEN ALL MODELS – PLOT F1
# ---------------------------------------------------------------------------
# (seaborn is already imported in the notebook's first cell; the duplicate
# mid-notebook import was dropped so all dependencies stay in one place.)

# One line per model: average F1 against epoch
plt.figure(figsize=(10,6))
sns.lineplot(data=df_all, x="epoch", y="val_f1_avg", hue="model", marker="o")
plt.title("Models comparison: F1 Macro average over time")
plt.xlabel("Epoch")
plt.ylabel("F1 Macro (avg)")
plt.ylim(0.0, 1.01)
plt.grid(True, linestyle="--", alpha=.4)
plt.legend(title="Model")
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 8a. SUMMARY FOR ALL MODELS Accuracy
# ---------------------------------------------------------------------------

# Per-model max / mean / std of the average accuracy, best model first
summary = (
    df_all.groupby("model")["val_acc_avg"]
    .agg(["max", "mean", "std"])
    .sort_values("max", ascending=False)
    .rename(columns={"max": "Max Accuracy", "mean": "Mean Accuracy", "std": "Std Dev"})
)
summary

In [None]:
# ---------------------------------------------------------------------------
# 8b. SUMMARY FOR ALL MODELS F1
# ---------------------------------------------------------------------------

# Per-model max / mean / std of the average F1, best model first
f1_by_model = df_all.groupby("model")["val_f1_avg"]
f1_stats = f1_by_model.agg(["max","mean","std"])
f1_ranked = f1_stats.sort_values("max", ascending=False)
summary_f1 = f1_ranked.rename(columns={"max":"Max F1","mean":"Mean F1","std":"Std F1"})
summary_f1