In [None]:
import os
import json
import re
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Input/output locations, all relative to the current working directory.
DATA_DIR = Path("./data")
ARTIFACTS_DIR = Path("./artifacts/analysis")
CKPT_DIR = Path("./checkpoints")
JSON_PATH = Path("human_similarity_ranking.json")

# Create the output folder up front so the later CSV/plot writes have somewhere to go.
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
NUM_CLASSES = 10
# Class names in label-index order; used to map names in the JSON file to matrix indices.
CIFAR10_CLASSES = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

# All candidate models (50-epoch and 300-epoch runs): display name -> checkpoint file name.
# Keys are formatted as "Name - Epochs" so they can be split on " - " later.
POSSIBLE_MODELS = {
    "Probe (Teacher)":                    "probe_best.pth",
    
    "Baseline (Exp 1) - 50ep":            "02_exp1_baseline_epoch50.pth",
    "Baseline (Exp 1) - 300ep":           "02_exp1_baseline_best.pth",
    
    "SBLS (Exp 2) - 50ep":                "03_exp2_sbls_50ep_best.pth",
    "SBLS (Exp 2) - 300ep":               "03_exp2_sbls_300ep_best.pth",
    
    "Static KL (Exp 3a) - 50ep":          "04_exp3a_static_probs_50ep_best.pth",
    "Static KL (Exp 3a) - 300ep":         "04_exp3a_static_probs_300ep_best.pth",
    
    "Static WMSE (Exp 3b) - 50ep":        "05_exp3b_static_logits_50ep_best.pth",
    "Static WMSE (Exp 3b) - 300ep":       "05_exp3b_static_logits_300ep_best.pth",
    
    "Dynamic Basic (Exp 4a) - 50ep":      "06_exp4a_dynamic_basic_50ep_best.pth",
    "Dynamic Basic (Exp 4a) - 300ep":     "06_exp4a_dynamic_basic_300ep_best.pth",
    
    "Dynamic Boost (Exp 4b) - 50ep":      "07_exp4b_dynamic_boost_50ep_best.pth",
    "Dynamic Boost (Exp 4b) - 300ep":     "07_exp4b_dynamic_boost_300ep_best.pth",
    
    "Capped 0.8 (Exp 5a) - 50ep":         "08_exp5a_capped_run_50ep_cap0.8_best.pth",
    "Capped 0.8 (Exp 5a) - 300ep":        "08_exp5a_capped_run_300ep_cap0.8_best.pth",
    
    "Swap+Capped 0.8 (Exp 5b) - 50ep":    "09_exp5b_capped_swap_run_50ep_cap0.8_best.pth",
    "Swap+Capped 0.8 (Exp 5b) - 300ep":   "09_exp5b_capped_swap_run_300ep_cap0.8_best.pth"
}

# ==========================================
# 2. DATA & MODELL DEFINISJONER
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build an untrained torchvision ResNet-18 with a replaced stem and head.

    Three layers of the stock model are swapped out:
      * ``conv1``   -> 3x3 convolution, stride 1, padding 1, no bias
      * ``maxpool`` -> identity (no pooling in the stem)
      * ``fc``      -> linear layer with ``num_classes`` outputs

    NOTE(review): presumably these changes adapt the network to small
    CIFAR-sized images; the motivation is not stated in this file.
    """
    net = resnet18(weights=None)
    feature_dim = net.fc.in_features

    net.conv1 = nn.Conv2d(
        3,
        64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

def get_test_loader():
    """Return a non-shuffling DataLoader over the CIFAR-10 test split.

    The dataset is stored under DATA_DIR (downloaded if missing). Each image is
    converted to a tensor and normalized per channel with fixed mean/std
    constants (presumably CIFAR-10 training-set statistics; not verifiable here).
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocessing = T.Compose([T.ToTensor(), T.Normalize(channel_mean, channel_std)])

    dataset = torchvision.datasets.CIFAR10(
        root=DATA_DIR,
        train=False,
        download=True,
        transform=preprocessing,
    )

    # shuffle=False is CRITICAL for the "fair comparison": predictions from
    # different models are matched to each other by sample index.
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ==========================================
# 3. HUMAN SIMILARITY MATRIX LOADER
# ==========================================
def load_human_similarity_matrix(json_path):
    """Load the human similarity scores into a (NUM_CLASSES x NUM_CLASSES) tensor.

    The JSON file is expected to hold a top-level ``"wins"`` object mapping a
    true class name to an object mapping a predicted class name to points.
    Entry ``[true_idx, pred_idx]`` of the result is that point value; the
    diagonal (correct answer) and any missing name stay at 0.

    Args:
        json_path: Path (``str`` or ``pathlib.Path``) of the ranking JSON file.

    Returns:
        A tensor on DEVICE. If the file does not exist, a warning is printed
        and an all-zero matrix is returned.
    """
    # Accept plain strings as well; the original required a Path for .exists().
    json_path = Path(json_path)

    if not json_path.exists():
        print(f"‚ö†Ô∏è  ADVARSEL: Fant ikke {json_path}. Lager dummy-matrise (0 poeng).")
        return torch.zeros((NUM_CLASSES, NUM_CLASSES), device=DEVICE)

    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    wins_data = data.get("wins", {})

    # Fill the matrix on the CPU and transfer it once, instead of issuing one
    # tiny device write per cell.
    score_matrix = torch.zeros((NUM_CLASSES, NUM_CLASSES))

    for true_idx, true_name in enumerate(CIFAR10_CLASSES):
        if true_name not in wins_data:
            continue
        row = wins_data[true_name]
        for pred_idx, pred_name in enumerate(CIFAR10_CLASSES):
            if true_idx == pred_idx:
                continue  # Skip the diagonal (correct answer).
            score_matrix[true_idx, pred_idx] = row.get(pred_name, 0)

    print(f"Human Similarity Matrix lastet fra {json_path}.")
    return score_matrix.to(DEVICE)

# ==========================================
# 4. PREDICTION COLLECTOR
# ==========================================
def collect_predictions(models_dict, loader):
    """Run every available checkpoint over ``loader`` and gather class predictions.

    Args:
        models_dict: Mapping of display name -> checkpoint file name. Each file is
            looked up under CKPT_DIR first, then as a path on its own.
        loader: DataLoader yielding ``(inputs, labels)`` batches. It is iterated
            once for the labels and once more per model, so its order must be stable.

    Returns:
        ``(all_preds, targets)``: a dict of display name -> 1-D tensor of argmax
        predictions (only for models that were found and loaded), and the 1-D
        tensor of ground-truth labels. Both live on DEVICE.
    """
    # Gather the ground-truth labels first.
    print("Laster testdata (fasit)...")
    targets = torch.cat([labels.to(DEVICE) for _, labels in loader])

    predictions = {}

    # Loop over the models.
    print("\nStarter inferens p√• alle modeller...")
    for name, filename in tqdm(models_dict.items(), desc="Modeller"):
        ckpt_path = CKPT_DIR / filename

        if not ckpt_path.exists():
            # Fallback: the entry may already be a full/relative path of its own.
            fallback = Path(filename)
            if not fallback.exists():
                print(f"‚ö†Ô∏è  Hopper over {name}: Fant ikke {ckpt_path}")
                continue
            ckpt_path = fallback

        # Instantiate the architecture and load the weights.
        model = make_cifar_resnet18(NUM_CLASSES).to(DEVICE)
        try:
            state = torch.load(ckpt_path, map_location=DEVICE)
            model.load_state_dict(state)
        except Exception as e:
            print(f"‚ö†Ô∏è  Kunne ikke laste {name}: {e}")
            continue

        model.eval()
        with torch.no_grad():
            batch_preds = [model(x.to(DEVICE)).argmax(dim=1) for x, _ in loader]

        predictions[name] = torch.cat(batch_preds)

    return predictions, targets

# ==========================================
# 5. ANALYSE-LOGIKK
# ==========================================
def run_full_analysis():
    """Compare top-1 accuracy against a "fair" human-similarity score for all models.

    Steps:
      1. Load the human similarity matrix and the CIFAR-10 test loader.
      2. Collect test-set predictions for every checkpoint that can be loaded.
      3. Build the mask of samples that ALL evaluated models misclassify.
      4. Per model, compute accuracy (percent) and the mean similarity score of
         its predictions on those common failures.
      5. Print the ranking, write it to CSV and, when common failures exist,
         save and show a scatter plot.

    Returns:
        None. Results go to stdout and to files under ARTIFACTS_DIR.
    """
    # 1. Setup
    score_matrix = load_human_similarity_matrix(JSON_PATH)
    loader = get_test_loader()

    # 2. Collect data
    model_preds, targets = collect_predictions(POSSIBLE_MODELS, loader)

    if not model_preds:
        print("Ingen modeller ble evaluert. Sjekk filstier.")
        return

    # --- 3. FAIR COMPARISON MASK (intersection of failures) ---
    # Keep only the images that ALL models got wrong, so every model is scored
    # on exactly the same set of hard images.
    n_samples = targets.size(0)
    common_failure_mask = torch.ones(n_samples, dtype=torch.bool, device=DEVICE)

    for preds in model_preds.values():
        common_failure_mask = common_failure_mask & (preds != targets)

    num_common = common_failure_mask.sum().item()
    print(f"\nAntall bilder der ALLE {len(model_preds)} modeller tok feil: {num_common}")

    hard_targets = None
    if num_common > 0:
        hard_targets = targets[common_failure_mask]
    else:
        print("‚ö†Ô∏è Ingen felles feil funnet! 'Fair Score' vil v√¶re tom.")

    # --- 4. COLLECT RESULTS ---
    results_list = []

    for name, preds in model_preds.items():
        # A. Overall accuracy in percent
        acc = (preds == targets).float().mean().item() * 100

        # B. Fair human score (common failures only).
        # NaN, not 0.0, when there is no common failure set: a 0.0 would look
        # like a real (worst possible) score in the table and the CSV.
        fair_score = float("nan")
        if num_common > 0:
            hard_preds = preds[common_failure_mask]
            # Look up the points for each (true class, predicted class) pair.
            scores = score_matrix[hard_targets, hard_preds]
            fair_score = scores.mean().item()

        # Split the display name into family and epochs for plotting.
        if "Probe" in name:
            family = "Probe"
            epochs = "Teacher"
        else:
            # Example: "Baseline (Exp 1) - 300ep"
            parts = name.split(" - ")
            family = parts[0]
            epochs = parts[1] if len(parts) > 1 else "Unknown"

        results_list.append({
            "Full Name": name,
            "Family": family,
            "Epochs": epochs,
            "Accuracy": acc,
            "Fair Human Score": fair_score
        })

    df = pd.DataFrame(results_list)

    # Sort by fair human score (higher is better).
    df_sorted = df.sort_values(by="Fair Human Score", ascending=False)

    # --- 5. REPORT ---
    print("\n" + "="*80)
    print("RESULTAT: ACCURACY VS FAIR HUMAN SCORE")
    print("-" * 80)
    print(df_sorted[["Full Name", "Accuracy", "Fair Human Score"]].to_string(index=False, float_format="%.2f"))
    print("="*80)

    # Save to CSV
    df_sorted.to_csv(ARTIFACTS_DIR / "analysis_full_comparison.csv", index=False)
    print(f"Data lagret til {ARTIFACTS_DIR / 'analysis_full_comparison.csv'}")

    # --- 6. PLOT: ACCURACY VS FAIR HUMAN SCORE ---
    if not df.empty and num_common > 0:
        fig = plt.figure(figsize=(12, 8))

        # Seaborn assigns colors (hue) and marker shapes (style) automatically.
        sns.scatterplot(
            data=df,
            x="Accuracy",
            y="Fair Human Score",
            hue="Family",       # Color by model family (Baseline, SBLS, ...)
            style="Epochs",     # Marker shape by 50ep vs 300ep
            s=150,              # Marker size
            alpha=0.9,
            palette="bright"
        )

        # Label the probe (teacher) point explicitly if it was evaluated.
        probe_row = df[df["Family"] == "Probe"]
        if not probe_row.empty:
            plt.text(
                probe_row["Accuracy"].values[0],
                probe_row["Fair Human Score"].values[0] + 0.01,
                "TEACHER",
                fontweight='bold'
            )

        plt.title(f"Trade-off: Accuracy vs Human Alignment\n(Fair Score based on {num_common} common failures)", fontsize=14)
        plt.xlabel("Top-1 Accuracy (%)", fontsize=12)
        plt.ylabel("Fair Human Score (Avg on Common Failures)", fontsize=12)
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
        plt.tight_layout()

        plot_path = ARTIFACTS_DIR / "tradeoff_acc_vs_fairscore.png"
        plt.savefig(plot_path, dpi=150)
        print(f"Plot lagret til: {plot_path}")
        plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)
    else:
        print("Kan ikke plotte: Mangler data eller ingen felles feil.")

# Entry point: run the analysis when this cell/file is executed as the main program
# (not when it is imported as a module).
if __name__ == "__main__":
    run_full_analysis()

In [None]:
# ==========================================
# CLEANUP & RESTART
# ==========================================
import torch
import gc
from IPython import get_ipython

# 1. Drop the large objects first, one name at a time. A single
#    `del a, b, c` inside try/except stops at the first undefined name and
#    leaves every later object alive.
for _name in ("model", "optimizer", "scaler", "train_loader", "test_loader"):
    globals().pop(_name, None)
del _name

# 2. Run garbage collection and empty the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

# 3. Wipe ALL variables from the interactive namespace so the next cell starts
#    from a clean slate. get_ipython() returns None outside IPython, so guard
#    the call instead of failing with an AttributeError.
_shell = get_ipython()
if _shell is not None:
    _shell.run_line_magic('reset', '-sf')
    print("‚úÖ Minne t√∏mt og variabler nullstilt. Klar for neste eksperiment.")
else:
    print("Ikke i IPython: hoppet over %reset. Kun de navngitte objektene ble frigjort.")

In [None]:
import os
import json
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Input/output locations, all relative to the current working directory.
DATA_DIR = Path("./data")
ARTIFACTS_DIR = Path("./artifacts/analysis")
CKPT_DIR = Path("./checkpoints")
JSON_PATH = Path("human_similarity_ranking.json")

# Create the output folder up front so the later CSV/plot writes have somewhere to go.
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
NUM_CLASSES = 10
# Class names in label-index order; used to map names in the JSON file to matrix indices.
CIFAR10_CLASSES = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

# ==========================================
# DEFINITION OF ALL MODELS
# ==========================================
# Format: "Display name - Epochs" : "File name inside the checkpoints folder"
POSSIBLE_MODELS = {
    # --- References ---
    "Probe (Teacher)":                    "probe_best.pth",
    
    # --- Exp 1: Baseline ---
    "Baseline (Exp 1) - 50ep":            "02_exp1_baseline_epoch50.pth",
    "Baseline (Exp 1) - 300ep":           "02_exp1_baseline_best.pth",
    
    # --- Exp 2: SBLS ---
    "SBLS (Exp 2) - 50ep":                "03_exp2_sbls_50ep_best.pth",
    "SBLS (Exp 2) - 300ep":               "03_exp2_sbls_300ep_best.pth",
    
    # --- Exp 3: Static Targets ---
    "Static KL (Exp 3a) - 50ep":          "04_exp3a_static_probs_50ep_best.pth",
    "Static KL (Exp 3a) - 300ep":         "04_exp3a_static_probs_300ep_best.pth",
    "Static WMSE (Exp 3b) - 50ep":        "05_exp3b_static_logits_50ep_best.pth",
    "Static WMSE (Exp 3b) - 300ep":       "05_exp3b_static_logits_300ep_best.pth",
    
    # --- Exp 4: Dynamic Targets ---
    "Dynamic Basic (Exp 4a) - 50ep":      "06_exp4a_dynamic_basic_50ep_best.pth",
    "Dynamic Basic (Exp 4a) - 300ep":     "06_exp4a_dynamic_basic_300ep_best.pth",
    "Dynamic Boost (Exp 4b) - 50ep":      "07_exp4b_dynamic_boost_50ep_best.pth",
    "Dynamic Boost (Exp 4b) - 300ep":     "07_exp4b_dynamic_boost_300ep_best.pth",
    
    # --- Exp 5a: Capped (all cap values) ---
    "Capped 0.6 (Exp 5a) - 50ep":         "08_exp5a_capped_run_50ep_cap0.6_best.pth",
    "Capped 0.6 (Exp 5a) - 300ep":        "08_exp5a_capped_run_300ep_cap0.6_best.pth",
    
    "Capped 0.8 (Exp 5a) - 50ep":         "08_exp5a_capped_run_50ep_cap0.8_best.pth",
    "Capped 0.8 (Exp 5a) - 300ep":        "08_exp5a_capped_run_300ep_cap0.8_best.pth",
    
    "Capped 0.95 (Exp 5a) - 50ep":        "08_exp5a_capped_run_50ep_cap0.95_best.pth",
    "Capped 0.95 (Exp 5a) - 300ep":       "08_exp5a_capped_run_300ep_cap0.95_best.pth",
    
    # --- Exp 5b: Capped + Swap (all cap values) ---
    "Swap 0.6 (Exp 5b) - 50ep":           "09_exp5b_capped_swap_run_50ep_cap0.6_best.pth",
    "Swap 0.6 (Exp 5b) - 300ep":          "09_exp5b_capped_swap_run_300ep_cap0.6_best.pth",
    
    "Swap 0.8 (Exp 5b) - 50ep":           "09_exp5b_capped_swap_run_50ep_cap0.8_best.pth",
    "Swap 0.8 (Exp 5b) - 300ep":          "09_exp5b_capped_swap_run_300ep_cap0.8_best.pth",
    
    "Swap 0.95 (Exp 5b) - 50ep":          "09_exp5b_capped_swap_run_50ep_cap0.95_best.pth",
    "Swap 0.95 (Exp 5b) - 300ep":         "09_exp5b_capped_swap_run_300ep_cap0.95_best.pth"
}

# ==========================================
# 2. DATA & MODELL DEFINISJONER
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build an untrained torchvision ResNet-18 with a replaced stem and head.

    Three layers of the stock model are swapped out:
      * ``conv1``   -> 3x3 convolution, stride 1, padding 1, no bias
      * ``maxpool`` -> identity (no pooling in the stem)
      * ``fc``      -> linear layer with ``num_classes`` outputs

    NOTE(review): presumably these changes adapt the network to small
    CIFAR-sized images; the motivation is not stated in this file.
    """
    net = resnet18(weights=None)
    feature_dim = net.fc.in_features

    net.conv1 = nn.Conv2d(
        3,
        64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

def get_test_loader():
    """Return a non-shuffling DataLoader over the CIFAR-10 test split.

    The dataset is stored under DATA_DIR (downloaded if missing). Each image is
    converted to a tensor and normalized per channel with fixed mean/std
    constants (presumably CIFAR-10 training-set statistics; not verifiable here).
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocessing = T.Compose([T.ToTensor(), T.Normalize(channel_mean, channel_std)])

    dataset = torchvision.datasets.CIFAR10(
        root=DATA_DIR,
        train=False,
        download=True,
        transform=preprocessing,
    )

    # shuffle=False is CRITICAL for the "fair comparison": predictions from
    # different models are matched to each other by sample index.
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ==========================================
# 3. HUMAN SIMILARITY MATRIX LOADER
# ==========================================
def load_human_similarity_matrix(json_path):
    """Load the human similarity scores into a (NUM_CLASSES x NUM_CLASSES) tensor.

    The JSON file is expected to hold a top-level ``"wins"`` object mapping a
    true class name to an object mapping a predicted class name to points.
    Entry ``[true_idx, pred_idx]`` of the result is that point value; the
    diagonal (correct answer) and any missing name stay at 0.

    Args:
        json_path: Path (``str`` or ``pathlib.Path``) of the ranking JSON file.

    Returns:
        A tensor on DEVICE. If the file does not exist, a warning is printed
        and an all-zero matrix is returned.
    """
    # Accept plain strings as well; the original required a Path for .exists().
    json_path = Path(json_path)

    if not json_path.exists():
        print(f"‚ö†Ô∏è  ADVARSEL: Fant ikke {json_path}. Lager dummy-matrise (0 poeng).")
        return torch.zeros((NUM_CLASSES, NUM_CLASSES), device=DEVICE)

    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    wins_data = data.get("wins", {})

    # Fill the matrix on the CPU and transfer it once, instead of issuing one
    # tiny device write per cell.
    score_matrix = torch.zeros((NUM_CLASSES, NUM_CLASSES))

    for true_idx, true_name in enumerate(CIFAR10_CLASSES):
        if true_name not in wins_data:
            continue
        row = wins_data[true_name]
        for pred_idx, pred_name in enumerate(CIFAR10_CLASSES):
            if true_idx == pred_idx:
                continue  # Skip the diagonal (correct answer).
            score_matrix[true_idx, pred_idx] = row.get(pred_name, 0)

    print(f"Human Similarity Matrix lastet fra {json_path}.")
    return score_matrix.to(DEVICE)

# ==========================================
# 4. PREDICTION COLLECTOR
# ==========================================
def collect_predictions(models_dict, loader):
    """Run every available checkpoint over ``loader`` and gather class predictions.

    Args:
        models_dict: Mapping of display name -> checkpoint file name. Each file is
            looked up under CKPT_DIR first, then as a path on its own.
        loader: DataLoader yielding ``(inputs, labels)`` batches. It is iterated
            once for the labels and once more per model, so its order must be stable.

    Returns:
        ``(all_preds, targets)``: a dict of display name -> 1-D tensor of argmax
        predictions (only for models that were found and loaded), and the 1-D
        tensor of ground-truth labels. Both live on DEVICE.
    """
    # Gather the ground-truth labels first.
    print("Laster testdata (fasit)...")
    targets = torch.cat([labels.to(DEVICE) for _, labels in loader])

    predictions = {}

    # Loop over the models.
    print("\nStarter inferens p√• alle modeller...")
    for name, filename in tqdm(models_dict.items(), desc="Modeller"):
        ckpt_path = CKPT_DIR / filename

        if not ckpt_path.exists():
            # Fallback: the entry may already be a full/relative path of its own.
            fallback = Path(filename)
            if not fallback.exists():
                print(f"‚ö†Ô∏è  Hopper over {name}: Fant ikke {ckpt_path}")
                continue
            ckpt_path = fallback

        # Instantiate the architecture and load the weights.
        model = make_cifar_resnet18(NUM_CLASSES).to(DEVICE)
        try:
            state = torch.load(ckpt_path, map_location=DEVICE)
            model.load_state_dict(state)
        except Exception as e:
            print(f"‚ö†Ô∏è  Kunne ikke laste {name}: {e}")
            continue

        model.eval()
        with torch.no_grad():
            batch_preds = [model(x.to(DEVICE)).argmax(dim=1) for x, _ in loader]

        predictions[name] = torch.cat(batch_preds)

    return predictions, targets

# ==========================================
# 5. ANALYSE-LOGIKK
# ==========================================
def run_full_analysis():
    """Compare top-1 accuracy against a "fair" human-similarity score for all models.

    Steps:
      1. Load the human similarity matrix and the CIFAR-10 test loader.
      2. Collect test-set predictions for every checkpoint that can be loaded.
      3. Build the mask of samples that ALL evaluated models misclassify.
      4. Per model, compute accuracy (percent) and the mean similarity score of
         its predictions on those common failures.
      5. Print the ranking, write it to CSV and, when common failures exist,
         save and show a scatter plot.

    Returns:
        None. Results go to stdout and to files under ARTIFACTS_DIR.
    """
    # 1. Setup
    score_matrix = load_human_similarity_matrix(JSON_PATH)
    loader = get_test_loader()

    # 2. Collect data
    model_preds, targets = collect_predictions(POSSIBLE_MODELS, loader)

    if not model_preds:
        print("Ingen modeller ble evaluert. Sjekk filstier.")
        return

    # --- 3. FAIR COMPARISON MASK (intersection of failures) ---
    # Keep only the images that ALL models got wrong.
    n_samples = targets.size(0)
    common_failure_mask = torch.ones(n_samples, dtype=torch.bool, device=DEVICE)

    for preds in model_preds.values():
        common_failure_mask = common_failure_mask & (preds != targets)

    num_common = common_failure_mask.sum().item()
    print(f"\nAntall bilder der ALLE {len(model_preds)} modeller tok feil: {num_common}")

    hard_targets = None
    if num_common > 0:
        hard_targets = targets[common_failure_mask]
    else:
        print("‚ö†Ô∏è Ingen felles feil funnet! 'Fair Score' vil v√¶re tom.")

    # --- 4. COLLECT RESULTS ---
    results_list = []

    for name, preds in model_preds.items():
        # A. Overall accuracy in percent
        acc = (preds == targets).float().mean().item() * 100

        # B. Fair human score (common failures only).
        # NaN, not 0.0, when there is no common failure set: a 0.0 would look
        # like a real (worst possible) score in the table and the CSV.
        fair_score = float("nan")
        if num_common > 0:
            hard_preds = preds[common_failure_mask]
            # Look up the points for each (true class, predicted class) pair.
            scores = score_matrix[hard_targets, hard_preds]
            fair_score = scores.mean().item()

        # Split the display name into family and epochs for plotting.
        # Example: "Capped 0.8 (Exp 5a) - 300ep"
        if "Probe" in name:
            family = "Probe"
            epochs = "Teacher"
        else:
            parts = name.split(" - ")
            family = parts[0]
            epochs = parts[1] if len(parts) > 1 else "Unknown"

        results_list.append({
            "Full Name": name,
            "Family": family,
            "Epochs": epochs,
            "Accuracy": acc,
            "Fair Human Score": fair_score
        })

    df = pd.DataFrame(results_list)

    # Sort by fair human score (higher is better).
    df_sorted = df.sort_values(by="Fair Human Score", ascending=False)

    # --- 5. REPORT ---
    print("\n" + "="*80)
    print("RESULTAT: ACCURACY VS FAIR HUMAN SCORE")
    print(f"Basert p√• {num_common} vanskelige bilder som ingen modeller klarte.")
    print("-" * 80)
    print(df_sorted[["Full Name", "Accuracy", "Fair Human Score"]].to_string(index=False, float_format="%.2f"))
    print("="*80)

    # Save to CSV
    df_sorted.to_csv(ARTIFACTS_DIR / "analysis_full_comparison.csv", index=False)
    print(f"Data lagret til {ARTIFACTS_DIR / 'analysis_full_comparison.csv'}")

    # --- 6. PLOT: ACCURACY VS FAIR HUMAN SCORE ---
    if not df.empty and num_common > 0:
        fig = plt.figure(figsize=(14, 10))  # Larger figure so everything fits

        sns.scatterplot(
            data=df,
            x="Accuracy",
            y="Fair Human Score",
            hue="Family",       # Color = model type
            style="Epochs",     # Marker shape = 50ep vs 300ep
            s=200,              # Larger markers
            alpha=0.8,
            palette="tab20"     # Palette with many distinct colors
        )

        # Label the probe (teacher) point explicitly if it was evaluated.
        probe_row = df[df["Family"] == "Probe"]
        if not probe_row.empty:
            plt.text(
                probe_row["Accuracy"].values[0],
                probe_row["Fair Human Score"].values[0] + 0.01,
                "TEACHER",
                fontweight='bold',
                color='black'
            )

        plt.title(f"Trade-off: Accuracy vs Human Alignment\n(Fair Score based on {num_common} common failures)", fontsize=16)
        plt.xlabel("Top-1 Accuracy (%)", fontsize=14)
        plt.ylabel("Fair Human Score (Avg on Common Failures)", fontsize=14)
        plt.grid(True, linestyle="--", alpha=0.5)

        # Move the legend outside the plotting area.
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0., title="Models")
        plt.tight_layout()

        plot_path = ARTIFACTS_DIR / "tradeoff_acc_vs_fairscore_ALL.png"
        plt.savefig(plot_path, dpi=150)
        print(f"Plot lagret til: {plot_path}")
        plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)
    else:
        print("Kan ikke plotte: Mangler data eller ingen felles feil.")

# Entry point: run the analysis when this cell/file is executed as the main program
# (not when it is imported as a module).
if __name__ == "__main__":
    run_full_analysis()

In [None]:
# ==========================================
# CLEANUP & RESTART
# ==========================================
import torch
import gc
from IPython import get_ipython

# 1. Drop the large objects first, one name at a time. A single
#    `del a, b, c` inside try/except stops at the first undefined name and
#    leaves every later object alive.
for _name in ("model", "optimizer", "scaler", "train_loader", "test_loader"):
    globals().pop(_name, None)
del _name

# 2. Run garbage collection and empty the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

# 3. Wipe ALL variables from the interactive namespace so the next cell starts
#    from a clean slate. get_ipython() returns None outside IPython, so guard
#    the call instead of failing with an AttributeError.
_shell = get_ipython()
if _shell is not None:
    _shell.run_line_magic('reset', '-sf')
    print("‚úÖ Minne t√∏mt og variabler nullstilt. Klar for neste eksperiment.")
else:
    print("Ikke i IPython: hoppet over %reset. Kun de navngitte objektene ble frigjort.")

In [None]:
import os
import json
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. CONFIGURATION
# ==========================================
DATA_DIR = Path("./data")
# Separate output folder for this experiment (correct answer = 9 points)
ARTIFACTS_DIR = Path("./artifacts/human_weighted_accuracy_9p")
CKPT_DIR = Path("./checkpoints")
JSON_PATH = Path("human_similarity_ranking.json")

# Create the output folder up front so the later CSV/plot writes have somewhere to go.
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
NUM_CLASSES = 10
# Class names in label-index order; used to map names in the JSON file to matrix indices.
CIFAR10_CLASSES = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

# ==========================================
# 2. DEFINITION OF ALL MODELS
# ==========================================
# Display name ("Name - Epochs") -> checkpoint file name under CKPT_DIR.
POSSIBLE_MODELS = {
    # --- References ---
    "Probe (Teacher)":                    "probe_best.pth",
    
    # --- Exp 1: Baseline ---
    "Baseline (Exp 1) - 50ep":            "02_exp1_baseline_epoch50.pth",
    "Baseline (Exp 1) - 300ep":           "02_exp1_baseline_best.pth",
    
    # --- Exp 2: SBLS ---
    "SBLS (Exp 2) - 50ep":                "03_exp2_sbls_50ep_best.pth",
    "SBLS (Exp 2) - 300ep":               "03_exp2_sbls_300ep_best.pth",
    
    # --- Exp 3: Static Targets ---
    "Static KL (Exp 3a) - 50ep":          "04_exp3a_static_probs_50ep_best.pth",
    "Static KL (Exp 3a) - 300ep":         "04_exp3a_static_probs_300ep_best.pth",
    "Static WMSE (Exp 3b) - 50ep":        "05_exp3b_static_logits_50ep_best.pth",
    "Static WMSE (Exp 3b) - 300ep":       "05_exp3b_static_logits_300ep_best.pth",
    
    # --- Exp 4: Dynamic Targets ---
    "Dynamic Basic (Exp 4a) - 50ep":      "06_exp4a_dynamic_basic_50ep_best.pth",
    "Dynamic Basic (Exp 4a) - 300ep":     "06_exp4a_dynamic_basic_300ep_best.pth",
    "Dynamic Boost (Exp 4b) - 50ep":      "07_exp4b_dynamic_boost_50ep_best.pth",
    "Dynamic Boost (Exp 4b) - 300ep":     "07_exp4b_dynamic_boost_300ep_best.pth",
    
    # --- Exp 5a: Capped (all cap values) ---
    "Capped 0.6 (Exp 5a) - 50ep":         "08_exp5a_capped_run_50ep_cap0.6_best.pth",
    "Capped 0.6 (Exp 5a) - 300ep":        "08_exp5a_capped_run_300ep_cap0.6_best.pth",
    
    "Capped 0.8 (Exp 5a) - 50ep":         "08_exp5a_capped_run_50ep_cap0.8_best.pth",
    "Capped 0.8 (Exp 5a) - 300ep":        "08_exp5a_capped_run_300ep_cap0.8_best.pth",
    
    "Capped 0.95 (Exp 5a) - 50ep":        "08_exp5a_capped_run_50ep_cap0.95_best.pth",
    "Capped 0.95 (Exp 5a) - 300ep":       "08_exp5a_capped_run_300ep_cap0.95_best.pth",
    
    # --- Exp 5b: Capped + Swap (all cap values) ---
    "Swap 0.6 (Exp 5b) - 50ep":           "09_exp5b_capped_swap_run_50ep_cap0.6_best.pth",
    "Swap 0.6 (Exp 5b) - 300ep":          "09_exp5b_capped_swap_run_300ep_cap0.6_best.pth",
    
    "Swap 0.8 (Exp 5b) - 50ep":           "09_exp5b_capped_swap_run_50ep_cap0.8_best.pth",
    "Swap 0.8 (Exp 5b) - 300ep":          "09_exp5b_capped_swap_run_300ep_cap0.8_best.pth",
    
    "Swap 0.95 (Exp 5b) - 50ep":          "09_exp5b_capped_swap_run_50ep_cap0.95_best.pth",
    "Swap 0.95 (Exp 5b) - 300ep":         "09_exp5b_capped_swap_run_300ep_cap0.95_best.pth"
}

# ==========================================
# 3. OPPSETT AV DATA & MODELL
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build an untrained torchvision ResNet-18 with a replaced stem and head.

    Three layers of the stock model are swapped out:
      * ``conv1``   -> 3x3 convolution, stride 1, padding 1, no bias
      * ``maxpool`` -> identity (no pooling in the stem)
      * ``fc``      -> linear layer with ``num_classes`` outputs

    NOTE(review): presumably these changes adapt the network to small
    CIFAR-sized images; the motivation is not stated in this file.
    """
    net = resnet18(weights=None)
    feature_dim = net.fc.in_features

    net.conv1 = nn.Conv2d(
        3,
        64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

def get_test_loader():
    """Return a non-shuffling DataLoader over the CIFAR-10 test split.

    The dataset is stored under DATA_DIR (downloaded if missing). Each image is
    converted to a tensor and normalized per channel with fixed mean/std
    constants (presumably CIFAR-10 training-set statistics; not verifiable here).
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocessing = T.Compose([T.ToTensor(), T.Normalize(channel_mean, channel_std)])

    dataset = torchvision.datasets.CIFAR10(
        root=DATA_DIR,
        train=False,
        download=True,
        transform=preprocessing,
    )

    # No shuffling: labels and predictions are gathered in separate passes and
    # must line up by sample index.
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ==========================================
# 4. LASTE SCORINGSMATRISE (MED 9 POENG FOR RETT SVAR)
# ==========================================
def load_weighted_score_matrix(json_path):
    """Build the (NUM_CLASSES x NUM_CLASSES) scoring matrix for weighted accuracy.

    - The diagonal (correct answer) is set to 9.0 points.
    - Off-diagonal entries ``[true_idx, pred_idx]`` take their points from the
      ``"wins"`` object of the JSON file (true class name -> predicted class
      name -> points); names missing from the file stay at 0. Per the original
      author's note the points are in the 0-10 range and lower in practice;
      this is not validated here.

    Args:
        json_path: Path (``str`` or ``pathlib.Path``) of the ranking JSON file.

    Returns:
        A tensor on DEVICE. If the file does not exist, a warning is printed
        and the diagonal-only matrix is returned.
    """
    # Accept plain strings as well; the original required a Path for .exists().
    json_path = Path(json_path)

    # Build on the CPU and transfer once at the end, instead of issuing one
    # tiny device write per cell.
    score_matrix = torch.zeros((NUM_CLASSES, NUM_CLASSES))

    # 1. Diagonal = 9 points (maximum score, awarded for a correct answer).
    for i in range(NUM_CLASSES):
        score_matrix[i, i] = 9.0

    # 2. Fill in the points for wrong answers from the JSON file.
    if not json_path.exists():
        print(f"‚ö†Ô∏è  ADVARSEL: Fant ikke {json_path}. Bruker kun diagonal (Acc).")
        return score_matrix.to(DEVICE)

    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    wins_data = data.get("wins", {})

    for true_idx, true_name in enumerate(CIFAR10_CLASSES):
        if true_name not in wins_data:
            continue
        row = wins_data[true_name]
        for pred_idx, pred_name in enumerate(CIFAR10_CLASSES):
            if true_idx == pred_idx:
                continue  # Already set to 9.
            score_matrix[true_idx, pred_idx] = row.get(pred_name, 0)

    # Plain string: the original f-string had no placeholders.
    print("‚úÖ Weighted Score Matrix lastet. Rett svar = 9p, Feil svar = JSON poeng.")
    return score_matrix.to(DEVICE)

# ==========================================
# 5. EVALUERINGSLØKKE
# ==========================================
def evaluate_models(models_dict, loader, score_matrix):
    """Evaluate every available checkpoint and score it with ``score_matrix``.

    Args:
        models_dict: Mapping of display name -> checkpoint file name. Each file is
            looked up under CKPT_DIR first, then as a path on its own.
        loader: DataLoader yielding ``(inputs, labels)`` batches. It is iterated
            once for the labels and once more per model, so its order must be stable.
        score_matrix: Square tensor indexed as ``[true_class, predicted_class]``
            giving the points awarded for each prediction.

    Returns:
        A DataFrame with one row per successfully evaluated model and the columns
        "Full Name", "Family", "Epochs", "Standard Accuracy (%)" and
        "Weighted Human Score (Max 9)". Models whose checkpoint is missing or
        cannot be loaded are skipped with a printed warning.
    """
    results = []

    print(f"\nStarter evaluering av {len(models_dict)} modeller...")
    print("Metrikk: Total Weighted Score (Avg per bilde, Max=9)")

    # Gather the ground-truth labels once for all models.
    all_targets = torch.cat([y.to(DEVICE) for _, y in loader])

    for name, filename in tqdm(models_dict.items(), desc="Evaluerer"):
        ckpt_path = CKPT_DIR / filename

        # Resolve the checkpoint; report (rather than silently drop) missing files
        # so an incomplete results table can be explained.
        if not ckpt_path.exists():
            if Path(filename).exists():
                ckpt_path = Path(filename)
            else:
                print(f"ADVARSEL: Hopper over {name}: Fant ikke {ckpt_path}")
                continue

        # Load the model; a failed load is reported with its cause and skipped.
        model = make_cifar_resnet18(NUM_CLASSES).to(DEVICE)
        try:
            model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
        except Exception as e:
            print(f"ADVARSEL: Kunne ikke laste {name}: {e}")
            continue

        model.eval()
        batch_preds = []

        with torch.no_grad():
            for x, _ in loader:
                x = x.to(DEVICE)
                logits = model(x)
                batch_preds.append(logits.argmax(dim=1))

        all_preds = torch.cat(batch_preds)

        # --- COMPUTE SCORE ---
        # score_matrix[true_class, pred_class] gives the points for each image.
        scores = score_matrix[all_targets, all_preds]

        # Mean points per image (weighted accuracy).
        avg_score = scores.mean().item()

        # Standard top-1 accuracy in percent, for reference.
        acc = (all_preds == all_targets).float().mean().item() * 100

        # Split the display name into family and epochs for grouping.
        if "Probe" in name:
            family = "Probe"
            epochs = "Teacher"
        else:
            parts = name.split(" - ")
            family = parts[0]
            epochs = parts[1] if len(parts) > 1 else "Unknown"

        results.append({
            "Full Name": name,
            "Family": family,
            "Epochs": epochs,
            "Standard Accuracy (%)": acc,
            "Weighted Human Score (Max 9)": avg_score
        })

    return pd.DataFrame(results)

# ==========================================
# 6. PLOTTING OG LAGRING
# ==========================================
def run_weighted_analysis():
    """Score all models with the weighted matrix, then report and plot the results.

    Loads the scoring matrix and test loader, evaluates every model, and, unless
    nothing could be evaluated, writes a CSV, prints a table, and saves/shows
    two figures under ARTIFACTS_DIR: a grouped bar chart (per family, by epochs,
    with the probe/teacher score as a reference line) and an accuracy-vs-score
    scatter plot. Returns None.
    """
    score_col = "Weighted Human Score (Max 9)"
    acc_col = "Standard Accuracy (%)"
    rule = "=" * 80

    # Load inputs and evaluate.
    weights = load_weighted_score_matrix(JSON_PATH)
    test_loader = get_test_loader()
    df = evaluate_models(POSSIBLE_MODELS, test_loader, weights)

    if df.empty:
        print("Ingen resultater √• vise.")
        return

    # Best score first.
    df = df.sort_values(by=score_col, ascending=False)

    # Persist the table.
    csv_path = ARTIFACTS_DIR / "weighted_human_accuracy_9p.csv"
    df.to_csv(csv_path, index=False)
    print(f"\nResultater lagret til: {csv_path}")

    # Print the table.
    print("\n" + rule)
    print("RESULTAT: WEIGHTED HUMAN ACCURACY (Rett=9p, Feil=Likhetspoeng)")
    print("-" * 80)
    table = df[["Full Name", acc_col, score_col]]
    print(table.to_string(index=False, float_format="%.3f"))
    print(rule)

    is_probe = df["Family"] == "Probe"

    # --- PLOT 1: GROUPED BAR CHART (epoch comparison) ---
    plt.figure(figsize=(14, 8))

    # The probe is left out of the bars so the training methods compare cleanly.
    students = df[~is_probe]
    sns.barplot(
        data=students,
        x="Family",
        y=score_col,
        hue="Epochs",
        palette="viridis",
        edgecolor="black",
    )

    # Reference line at the probe (teacher) score, when it was evaluated.
    teacher = df[is_probe]
    if not teacher.empty:
        probe_score = teacher[score_col].values[0]
        plt.axhline(
            y=probe_score,
            color='red',
            linestyle='--',
            label=f'Teacher Score ({probe_score:.2f})',
        )

    plt.title("Weighted Human Accuracy (Correct=9p)", fontsize=16)
    plt.ylabel("Average Weighted Score (Max 9)", fontsize=14)
    plt.xlabel("Model Variant", fontsize=14)
    plt.xticks(rotation=45, ha='right')
    plt.legend(title="Training Epochs")
    plt.grid(axis='y', linestyle='--', alpha=0.5)

    # Zoom the y-axis in on the top of the range, capped just above the maximum of 9.
    lowest = df[score_col].min()
    plt.ylim(bottom=lowest - 0.5, top=9.1)
    plt.tight_layout()

    plot_path_bar = ARTIFACTS_DIR / "weighted_score_comparison_bar.png"
    plt.savefig(plot_path_bar, dpi=150)
    print(f"Barplot lagret til: {plot_path_bar}")
    plt.show()

    # --- PLOT 2: SCATTER (accuracy vs weighted score) ---
    plt.figure(figsize=(12, 8))
    sns.scatterplot(
        data=df,
        x=acc_col,
        y=score_col,
        hue="Family",
        style="Epochs",
        s=150,
        palette="tab20",
        alpha=0.9,
    )

    plt.title("Correlation: Standard Accuracy vs Weighted Score (Correct=9p)", fontsize=16)
    plt.ylabel(score_col, fontsize=12)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
    plt.tight_layout()

    plot_path_scatter = ARTIFACTS_DIR / "acc_vs_weighted_score_scatter.png"
    plt.savefig(plot_path_scatter, dpi=150)
    print(f"Scatterplot lagret til: {plot_path_scatter}")
    plt.show()

# Entry point: run the weighted analysis when this cell/file is executed as the
# main program (not when it is imported as a module).
if __name__ == "__main__":
    run_weighted_analysis()