In [1]:

import os
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path

# Notebook-level configuration.
# Path.parents[0] is the immediate parent directory, so parents[1] is TWO
# levels above the current working directory. This assumes the notebook is
# launched from a directory nested two levels below the project root —
# adjust the index if the notebook is moved.
cwd = Path(os.getcwd())
PROJECT_ROOT = cwd.parents[1]

print("PROJECT_ROOT:", PROJECT_ROOT)

SEED = 242  # Default seed for reproducibility
# Root directory of the GAST ablation runs; used as the default base_dir of
# summarize_ablation().
base_path = PROJECT_ROOT / "models" / "gast_ablation"

def _metric_percent(data, key):
    """Return ``data[key]`` scaled by 100, or NaN when the key is absent or null."""
    value = data.get(key)
    if value is None:
        # A missing key or a JSON ``null`` both mean "no value"; the original
        # ``None * 100`` would raise TypeError here.
        return np.nan
    return float(value) * 100


def _load_variant_row(metrics_path, label, metric_columns):
    """Read one variant's metrics JSON into a table row; return None if unavailable.

    Args:
        metrics_path: Path of the ``metrics_seed_<seed>.json`` file.
        label: Display name stored in the "Variant" column.
        metric_columns: Mapping of table column name -> key in the JSON file.

    Returns:
        A dict with "Variant" plus one entry per metric column, or None when
        the file is missing or is not valid JSON (a warning is printed).
    """
    if not metrics_path.exists():
        print(f"Warning: {metrics_path} not found, skipping.")
        return None
    try:
        with open(metrics_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as err:
        # Treat a corrupt file like a missing one so one bad run does not
        # abort the summary for every other variant/dataset.
        print(f"Warning: could not parse {metrics_path} ({err}), skipping.")
        return None
    row = {"Variant": label}
    row.update(
        {column: _metric_percent(data, key) for column, key in metric_columns.items()}
    )
    return row


def _plot_ablation_bars(df, metric_columns, dataset, out_path, width=0.25):
    """Save a grouped bar chart with one group per variant and one bar per metric."""
    x = np.arange(len(df))
    n_metrics = len(metric_columns)
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        for i, column in enumerate(metric_columns):
            # Centre each group on its tick; for 3 metrics the offsets are
            # -width, 0 and +width.
            offset = (i - (n_metrics - 1) / 2) * width
            ax.bar(x + offset, df[column], width, label=column)

        ax.set_xticks(x)
        ax.set_xticklabels(df["Variant"])
        ax.set_ylabel("Score (%)")
        ax.set_title(f"GAST Ablation Results: {dataset.replace('_', ' ')}")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Close exactly this figure (even on error) so repeated calls across
        # datasets do not accumulate open figures.
        plt.close(fig)


def summarize_ablation(dataset, base_dir=base_path, seed=SEED):
    """Summarize GAST ablation test metrics for one dataset.

    For each variant ("full", "no_spectral", "no_spatial", "concat") this reads
    ``<base_dir>/<dataset>/<variant>_<dataset>/test_results/metrics_seed_<seed>.json``
    and scales "Overall Accuracy", "Average Accuracy" and "Kappa" by 100.
    Variants whose file is missing or unparsable are skipped with a warning.

    If at least one variant was found, it writes these files to
    ``<base_dir>/<dataset>/ablation_figures``:

    - ``ablation_summary.csv``
    - ``ablation_summary.txt``
    - ``ablation_barplot.png``

    Args:
        dataset: Dataset directory name, e.g. "Indian_Pines".
        base_dir: Root directory of the ablation runs.
        seed: Seed used in the metrics file name.

    Returns:
        The summary ``pandas.DataFrame`` (columns Variant, OA, AA, Kappa), or
        None when no variant had results.
    """
    # Variant directory prefix -> display label, in plotting order.
    variant_labels = {
        "full": "Gate",
        "no_spectral": "No Spectral",
        "no_spatial": "No Spatial",
        "concat": "Concat",
    }
    # Table column -> key in the metrics JSON, in plotting order.
    metric_columns = {
        "OA": "Overall Accuracy",
        "AA": "Average Accuracy",
        "Kappa": "Kappa",
    }

    dataset_dir = Path(base_dir) / dataset
    results = []
    for variant, label in variant_labels.items():
        metrics_path = (
            dataset_dir
            / f"{variant}_{dataset}"
            / "test_results"
            / f"metrics_seed_{seed}.json"
        )
        row = _load_variant_row(metrics_path, label, metric_columns)
        if row is not None:
            results.append(row)

    if not results:
        print(f"No results found for {dataset}.")
        return None

    # Save summary table
    out_dir = dataset_dir / "ablation_figures"
    out_dir.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(results)
    df.to_csv(out_dir / "ablation_summary.csv", index=False, float_format="%.2f")
    with open(out_dir / "ablation_summary.txt", "w", encoding="utf-8") as f:
        f.write(df.to_string(index=False, float_format="%.2f"))

    # Plot grouped bar chart
    _plot_ablation_bars(df, metric_columns, dataset, out_dir / "ablation_barplot.png")

    print(f"Summary and plot saved to {out_dir}")
    return df

if __name__ == "__main__":
    # Summarize every dataset below. To summarize a single dataset instead,
    # call e.g. summarize_ablation("Indian_Pines").
    datasets = (
        "Botswana",
        "Houston13",
        "Indian_Pines",
        "Kennedy_Space_Center",
        "Pavia_Centre",
        "Pavia_University",
        "Salinas",
        "SalinasA",
    )
    for d in datasets:
        summarize_ablation(d)

PROJECT_ROOT: /home/fesih/Desktop/ubuntu_projects/GAST2
No results found for Botswana.
No results found for Houston13.
No results found for Indian_Pines.
No results found for Kennedy_Space_Center.
No results found for Pavia_Centre.
No results found for Pavia_University.
No results found for Salinas.
No results found for SalinasA.
