In [None]:
# === CELL 1 (Complete Experiment 5a Code) ===
import os
import json
import random
import time
from pathlib import Path
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. KONFIGURASJON & OPPSETT
# ==========================================
# Experiment identifier; it names the artifact sub-directory and is part of the
# checkpoint file names written by the run loop.
EXPERIMENT_NAME = "08_exp5a_capped"
DATA_DIR = Path("./data")
BASE_ARTIFACTS_DIR = Path(f"./artifacts/{EXPERIMENT_NAME}")
CKPT_DIR = Path("./checkpoints")

# Starting point: we use the logits learned by the probe (from 01_setup_probe)
TARGETS_PATH = Path("./artifacts/01_probe_baseline/probe_targets.pt")

# Create output directories (idempotent)
BASE_ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# --- Hyperparameters for the experiment ---
# Loop 1: training length (one full run per entry)
RUNS_EPOCHS = [50, 300]

# Loop 2: probability caps (how confident is a class allowed to be in itself?)
# 0.60 = very uncertain (soft), 0.80 = balanced (original exp 5), 0.95 = almost one-hot
CAP_SETTINGS = [0.60, 0.80, 0.95] 

# General training parameters
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True
GRAD_CLIP_NORM = 1.0

# Parameters for the target dynamics (inherited from Exp 4b)
MIN_ACC_DIFF = 0.05         # Minimum accuracy gap before one class has authority over another
AUTH_BETA = 10.0            # Sigmoid sharpness
TARGET_UPDATE_LR = 0.10     # Step size for target updates
KAPPA_PUSHBACK = 0.10       # Repulsion (pushback on the expert's entry)
TARGET_ANCHOR_GAMMA = 0.0   # No anchor, let targets float (still bounded by the cap)
TARGET_CLAMP_M = 10.0       # Numerical stability (logits are clamped to [-M, M])

# Parameters for weighting & boost (inherited from Exp 4b)
WEIGHT_GAMMA = 1.0
WEIGHT_EPS = 1e-6
TRUE_CLASS_BOOST = 8.0      # The true class is weighted 8x more in the loss
CE_ANCHOR = 0.05            # Small cross-entropy anchor for stability
LOGIT_L2 = 0.001            # L2 penalty on the logits

# Data & classes
NUM_CLASSES = 10
CIFAR10_CLASSES = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

print(f"Konfigurasjon satt for {EXPERIMENT_NAME}.")
print(f"Artefakter lagres til: {BASE_ARTIFACTS_DIR}")

# Fail fast: every run starts from the probe targets, so abort before any
# training if the file is missing.
if not TARGETS_PATH.exists():
    raise FileNotFoundError(f"Fant ikke teacher targets på {TARGETS_PATH}. Har du kjørt 01_setup_probe.ipynb?")

# ==========================================
# 2. REPRODUSERBARHET
# ==========================================
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also configures cuDNN with ``deterministic = False`` and
    ``benchmark = True``, i.e. the cuDNN flags are set for speed rather than
    for deterministic kernels, so seeding alone does not make GPU runs
    bit-for-bit repeatable.

    Args:
        seed: Integer seed handed to every RNG.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed once, up front, before any dataset, loader or model is constructed.
set_seed(SEED)
# Use the GPU when one is available; models and batches are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Kjører på: {device}")

# ==========================================
# 3. DATA LOADING
# ==========================================
# Training-time pipeline: random crop (with 4 px padding) and horizontal flip
# as augmentation, then tensor conversion and normalization with fixed
# per-channel constants.
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: same normalization, no augmentation.
eval_tf = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_ds_aug = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_tf)
# Eval view of the training set (no augmentation), used to measure per-class accuracy ("authority")
train_ds_eval = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=eval_tf)
test_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=eval_tf)

# Only the augmented training loader is shuffled; both evaluation loaders keep dataset order.
train_loader = DataLoader(train_ds_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
train_eval_loader = DataLoader(train_ds_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# ==========================================
# 4. MODELL DEFINISJON
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build a randomly initialised ResNet-18 adapted to 32x32 inputs.

    Starting from ``resnet18(weights=None)``, the stem convolution is replaced
    by a 3x3, stride-1, bias-free convolution, the max-pool is replaced by an
    identity, and the final fully connected layer is replaced by one with
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits.

    Returns:
        The modified ``nn.Module``.
    """
    net = resnet18(weights=None)

    # Stem: keep full spatial resolution (no stride, no pooling).
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    # Head: reuse the backbone's feature width, change only the output size.
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

# ==========================================
# 5. KJERNE-LOGIKK: CAPPING OG DYNAMIKK
# ==========================================

@torch.no_grad()
def apply_probability_cap(logits_matrix, max_p=0.80):
    """Cap each class's self-probability (the diagonal) at ``max_p``.

    The matrix is converted to row-wise probabilities with a softmax. For every
    row whose diagonal probability exceeds ``max_p``, the diagonal is set to
    ``max_p`` and the off-diagonal entries are rescaled proportionally so they
    sum to ``1 - max_p``; if their mass is not above 1e-9 they are filled
    uniformly instead. Rows at or below the cap are left as they are.

    The probabilities are then mapped back with ``log(p + 1e-9)`` and each row
    is shifted so that its maximum equals the maximum of the corresponding
    input row, keeping the logits on the input's scale.

    Args:
        logits_matrix: Square ``(C, C)`` tensor of logits; it is not modified.
        max_p: Upper bound for the diagonal probability.

    Returns:
        A new ``(C, C)`` tensor of logits.
    """
    n_classes = logits_matrix.shape[0]
    class_ids = torch.arange(n_classes, device=logits_matrix.device)
    capped = torch.softmax(logits_matrix, dim=1)
    leftover = 1.0 - max_p

    # Rows are edited independently, so the over-confident ones can be
    # identified once before any of them is touched.
    over_rows = torch.nonzero(capped.diagonal() > max_p).flatten().tolist()
    for row in over_rows:
        capped[row, row] = max_p

        off_diag = class_ids != row
        off_mass = capped[row, off_diag].sum()
        if off_mass > 1e-9:
            # Preserve the relative ordering of the other classes.
            capped[row, off_diag] *= leftover / off_mass
        else:
            # Nothing to scale: spread the leftover mass evenly.
            capped[row, off_diag] = leftover / (n_classes - 1)

    # Back to logit space, re-aligned to the original per-row peak.
    log_probs = torch.log(capped + 1e-9)
    input_peak = logits_matrix.max(dim=1, keepdim=True).values
    capped_peak = log_probs.max(dim=1, keepdim=True).values
    return log_probs - capped_peak + input_peak

@torch.no_grad()
def update_target_logits(
    T, acc, T0=None, beta=10.0, lr=0.10, kappa=0.10,
    min_diff=0.05, anchor_gamma=0.0, clamp_M=10.0
):
    """Update the target-logit matrix from per-class accuracies.

    For every ordered pair of distinct classes (novice ``i``, expert ``j``)
    where ``acc[j] - acc[i]`` is greater than ``min_diff``, the novice's entry
    ``T[i, j]`` is moved towards the expert's entry ``T[j, i]`` by
    ``lr * sigmoid(beta * gap) * (T[j, i] - T[i, j])``. When ``kappa > 0`` the
    expert's entry is moved the opposite way by ``kappa`` times that step.
    All steps are computed from the input ``T`` and accumulated on a copy.

    Args:
        T: Square ``(C, C)`` tensor of target logits; it is not modified.
        acc: Tensor of ``C`` per-class accuracies.
        T0: Optional anchor matrix blended in when ``anchor_gamma > 0``.
        beta: Sharpness of the sigmoid applied to the accuracy gap.
        lr: Step size of each pairwise update.
        kappa: Pushback factor applied to the expert's entry.
        min_diff: Accuracy gap that must be exceeded for a pair to update.
        anchor_gamma: Blend weight of ``T0`` in the result.
        clamp_M: If not ``None`` and positive, clamp the result to
            ``[-clamp_M, clamp_M]``.

    Returns:
        Tuple ``(new_matrix, n_pairs_updated)``.
    """
    n_classes = T.shape[0]
    result = T.clone()
    pushback_enabled = kappa > 0
    n_updates = 0

    ordered_pairs = [
        (novice, expert)
        for novice in range(n_classes)
        for expert in range(n_classes)
        if novice != expert
    ]
    for novice, expert in ordered_pairs:
        gap = acc[expert] - acc[novice]
        if gap <= min_diff:
            continue

        authority = torch.sigmoid(beta * gap)
        # Both views are read from the original matrix, not the running copy.
        step = lr * authority * (T[expert, novice] - T[novice, expert])

        result[novice, expert] += step
        if pushback_enabled:
            result[expert, novice] -= kappa * step
        n_updates += 1

    anchoring = T0 is not None and anchor_gamma > 0
    if anchoring:
        result = (1.0 - anchor_gamma) * result + anchor_gamma * T0

    clamping = clamp_M is not None and clamp_M > 0
    if clamping:
        result = torch.clamp(result, min=-clamp_M, max=clamp_M)

    return result, n_updates

# ==========================================
# 6. CUSTOM LOSS
# ==========================================
class DynamicTargetWMSE(nn.Module):
    """Class-weighted MSE between logits and per-class target logit rows.

    Each sample is regressed onto the row of ``current_targets`` selected by
    its label. The squared error of every output column is scaled by the
    matching entry of ``current_weights``, and the column of the sample's own
    label is additionally scaled by ``true_class_boost``. Two optional
    stabilisers are added on top: a cross-entropy term scaled by ``ce_anchor``
    and the mean squared logit scaled by ``logit_l2``.
    """

    def __init__(self, ce_anchor=0.0, logit_l2=0.0, true_class_boost=1.0):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.true_class_boost = float(true_class_boost)
        self.logit_l2 = logit_l2
        self.ce_anchor = ce_anchor

    def forward(self, logits, y, current_targets, current_weights):
        """Return the scalar loss for one batch.

        Args:
            logits: ``(B, C)`` model outputs.
            y: Integer class labels used to index target rows and columns.
            current_targets: Target logit matrix indexed by label.
            current_weights: Per-column weights (viewed as one row of ``C``).
        """
        batch_size, _ = logits.shape
        squared_error = torch.pow(logits - current_targets[y], 2)

        # Private per-sample copy of the column weights, so boosting below
        # never writes into the caller's tensor.
        column_weights = current_weights.view(1, -1).repeat(batch_size, 1)

        if self.true_class_boost != 1.0:
            rows = torch.arange(batch_size, device=logits.device)
            column_weights[rows, y] = column_weights[rows, y] * self.true_class_boost

        loss = (squared_error * column_weights).mean()

        # Optional stabilisers; each is only evaluated when its coefficient is positive.
        stabilisers = (
            (self.ce_anchor, lambda: self.ce(logits, y)),
            (self.logit_l2, lambda: torch.pow(logits, 2).mean()),
        )
        for coefficient, term in stabilisers:
            if coefficient > 0:
                loss += coefficient * term()

        return loss

# ==========================================
# 7. HJELPEFUNKSJONER (Metrics & Viz)
# ==========================================
@torch.no_grad()
def per_class_accuracy(model, loader, num_classes=10):
    """Return the accuracy of ``model`` on ``loader`` separately for each class.

    The model is switched to eval mode (and left there). Batches are moved to
    the module-level ``device``. Labels outside ``[0, num_classes)`` are
    ignored. A class with no samples gets accuracy 0.

    Args:
        model: Module mapping a batch of inputs to ``(B, num_classes)`` logits.
        loader: Iterable of ``(inputs, labels)`` batches; labels are expected
            to be 1-D integer class ids.
        num_classes: Number of classes to report.

    Returns:
        Float tensor of length ``num_classes`` on ``device``.
    """
    model.eval()
    correct = torch.zeros(num_classes, device=device)
    counts = torch.zeros(num_classes, device=device)
    class_ids = torch.arange(num_classes, device=device)

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        pred = model(x).argmax(dim=1)

        # (B, C) membership table: entry [b, c] is True iff sample b has label c.
        # Reducing it over the batch gives all per-class tallies at once,
        # without a Python loop (and a host sync) per class.
        is_class = y.unsqueeze(1) == class_ids
        is_hit = (pred == y).unsqueeze(1)

        counts += is_class.sum(dim=0)
        correct += (is_class & is_hit).sum(dim=0)

    # clamp_min(1) avoids 0/0 for classes that never occurred.
    return correct / counts.clamp_min(1)

@torch.no_grad()
def evaluate(model, loader):
    """Return the overall top-1 accuracy of ``model`` on ``loader``.

    The model is switched to eval mode (and left there); batches are moved to
    the module-level ``device``.

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    n_hits, n_seen = 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        predicted = model(inputs).argmax(dim=1)
        n_hits += int(predicted.eq(labels).sum())
        n_seen += inputs.shape[0]
    return n_hits / n_seen

@torch.no_grad()
def weights_from_acc(acc, gamma=1.0, eps=1e-6):
    """Turn per-class accuracies into loss weights with mean 1.

    Accuracies are clamped at 0 from below, offset by ``eps`` and raised to
    the power ``gamma``; the result is divided by its mean (the mean itself is
    clamped to at least 1e-8 to avoid dividing by zero).

    Args:
        acc: Tensor of per-class accuracies.
        gamma: Exponent applied to the offset accuracies.
        eps: Small offset added before exponentiation.

    Returns:
        Tensor of the same shape as ``acc``.
    """
    raw = torch.pow(torch.clamp(acc, min=0.0) + eps, gamma)
    normaliser = torch.clamp(raw.mean(), min=1e-8)
    return raw / normaliser

def save_target_heatmap(logits_matrix, classes, filename, title_suffix=""):
    """Render the row-wise softmax of a target matrix as an annotated heatmap.

    Args:
        logits_matrix: 2-D tensor of target logits; rows are softmax-normalised
            before plotting. May live on any device and may require grad.
        classes: Tick labels used for both axes.
        filename: Destination passed to ``plt.savefig``.
        title_suffix: Text appended to the plot title.
    """
    # detach() so a matrix that is still attached to an autograd graph can be
    # converted to NumPy (Tensor.numpy() refuses tensors that require grad).
    probs = torch.softmax(logits_matrix.detach(), dim=1).cpu().numpy()

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(
            probs, xticklabels=classes, yticklabels=classes,
            annot=True, fmt=".2f", cmap="viridis",
            cbar_kws={'label': 'Target Probability'}
        )
        plt.title(f"Final Targets {title_suffix}")
        plt.xlabel("Predicted Class")
        plt.ylabel("True Class")
        plt.tight_layout()
        plt.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls in the run grid do not accumulate open figures.
        plt.close(fig)
    print(f"Heatmap lagret til: {filename}")

# ==========================================
# 8. HOVEDLØKKE (RUNNER)
# ==========================================
# Vi itererer over både Epoker og Probability Caps
# Grid over (training length x probability cap). Every combination trains a
# freshly constructed model and writes its own artifacts under
# BASE_ARTIFACTS_DIR / run_name.
# NOTE(review): set_seed is not called again inside the grid, so each run's
# initialisation depends on the RNG state left by the previous runs - confirm
# this is intended.
for max_epochs in RUNS_EPOCHS:
    for cap in CAP_SETTINGS:
        
        run_name = f"run_{max_epochs}ep_cap{cap}"
        print(f"\n{'='*60}")
        print(f"  STARTER KJØRING: {max_epochs} EPOKER | CAP: {cap} (Max Self-Prob)")
        print(f"{'='*60}")
        
        # Create the directory for this specific combination
        run_dir = BASE_ARTIFACTS_DIR / run_name
        run_dir.mkdir(exist_ok=True, parents=True)
        
        # 1. Load the starting targets (probe); reloaded for every run so each
        #    run starts from the same matrix.
        payload = torch.load(TARGETS_PATH, map_location=device)
        target_logits_matrix = payload["class_avg_logits"].to(device)
        
        # 2. APPLY THE CAP TO THE STARTING MATRIX
        # This matters so that the starting point is "valid" under the rules of the experiment
        print(f"Applierer Probability Cap ({cap}) på start-matrisen...")
        target_logits_matrix = apply_probability_cap(target_logits_matrix, max_p=cap)
        target_logits_matrix_0 = target_logits_matrix.clone() # Anchor (only used if anchor_gamma > 0; here 0.0)

        # Model & optimizer
        model = make_cifar_resnet18(NUM_CLASSES).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
        
        # Criterion
        criterion = DynamicTargetWMSE(
            ce_anchor=CE_ANCHOR, 
            logit_l2=LOGIT_L2, 
            true_class_boost=TRUE_CLASS_BOOST
        )
        
        best_acc = 0.0
        # One entry per epoch in each list.
        history = {"train_loss": [], "test_acc": [], "target_updates": []}
        start_time = time.time()
        
        for ep in range(1, max_epochs + 1):
            # --- BEFORE THE EPOCH: evaluate & update targets ---
            if ep == 1:
                # No trained model yet: use a constant placeholder accuracy,
                # which yields equal weights for all classes, and skip the
                # target update.
                current_acc = torch.ones(NUM_CLASSES, device=device) * 0.5
                updates_count = 0
            else:
                # Per-class accuracy on the un-augmented training set.
                current_acc = per_class_accuracy(model, train_eval_loader, NUM_CLASSES)
                
                # A. Dynamic update (based on per-class accuracy)
                target_logits_matrix, updates_count = update_target_logits(
                    T=target_logits_matrix,
                    acc=current_acc,
                    T0=target_logits_matrix_0,
                    beta=AUTH_BETA,
                    lr=TARGET_UPDATE_LR,
                    kappa=KAPPA_PUSHBACK,
                    min_diff=MIN_ACC_DIFF,
                    anchor_gamma=TARGET_ANCHOR_GAMMA,
                    clamp_M=TARGET_CLAMP_M
                )
                
                # B. Forced capping (the dynamics may have changed things; enforce the cap again)
                target_logits_matrix = apply_probability_cap(target_logits_matrix, max_p=cap)
            
            history["target_updates"].append(updates_count)
            
            # Compute the per-class weights for the WMSE
            weights = weights_from_acc(current_acc, gamma=WEIGHT_GAMMA, eps=WEIGHT_EPS)
            
            # --- TRAINING ---
            # per_class_accuracy leaves the model in eval mode, so switch back.
            model.train()
            total_loss = 0
            count = 0
            
            for x, y in train_loader:
                x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)
                
                # NOTE(review): with USE_AMP=True and no GPU, autocast is still
                # enabled with device_type "cpu" - confirm this is intended.
                with torch.autocast(device_type="cuda" if device.type=="cuda" else "cpu", enabled=USE_AMP):
                    logits = model(x)
                    loss = criterion(logits, y, target_logits_matrix, weights)
                
                scaler.scale(loss).backward()
                
                if GRAD_CLIP_NORM > 0:
                    # Unscale first so the clip threshold applies to the true gradient norm.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
                
                scaler.step(optimizer)
                scaler.update()
                
                total_loss += loss.item()
                count += 1
            
            # Mean of the per-batch losses.
            avg_loss = total_loss / count
            
            # --- EVALUATION ---
            val_acc = evaluate(model, test_loader)
            history["train_loss"].append(avg_loss)
            history["test_acc"].append(val_acc)
            
            # NOTE(review): the "best" checkpoint is selected on test_loader
            # accuracy, so best_acc is not an unbiased test estimate.
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_{run_name}_best.pth")
            
            # Logging (first epoch and every 10th)
            if ep % 10 == 0 or ep == 1:
                elapsed = time.time() - start_time
                print(f"Ep {ep:03d} | Loss: {avg_loss:.4f} | Test Acc: {val_acc:.4f} | Upd: {updates_count} | T: {elapsed:.0f}s")
                
                # Check that the capping actually works (print the max self-probability)
                if ep % 20 == 0:
                    probs = torch.softmax(target_logits_matrix, dim=1)
                    max_observed = probs.diagonal().max().item()
                    print(f"   -> Max Observed Self-Prob: {max_observed:.4f} (Cap: {cap})")

        # --- AFTER THE RUN: SAVE EVERYTHING ---
        
        # 1. Results JSON (a subset of the configuration, best accuracy, per-epoch history)
        res_file = run_dir / "results.json"
        with open(res_file, "w") as f:
            json.dump({
                "config": {
                    "epochs": max_epochs,
                    "max_self_prob": cap,
                    "auth_beta": AUTH_BETA,
                    "true_class_boost": TRUE_CLASS_BOOST,
                    "lr": LR
                },
                "best_acc": best_acc,
                "history": history
            }, f, indent=2)
        
        # 2. Heatmap of the final target distribution
        save_target_heatmap(
            target_logits_matrix, 
            CIFAR10_CLASSES, 
            run_dir / "final_learned_distribution.png",
            title_suffix=f"(Run {max_epochs}ep, Cap {cap})"
        )
        
        # 3. Save the matrix itself
        torch.save(target_logits_matrix, run_dir / "final_target_matrix.pt")
        
        # 4. Training plot (test accuracy per epoch)
        plt.figure(figsize=(10, 5))
        plt.plot(history["test_acc"], label=f"Test Acc (Cap {cap})")
        plt.title(f"Training Progress ({max_epochs} epochs)")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.grid(True)
        plt.legend()
        plt.savefig(run_dir / "training_metrics.png")
        plt.close()
        
        print(f"Ferdig med {run_name}. Beste Acc: {best_acc:.4f}. Resultater i {run_dir}")

print(f"\nAlle kjøringer fullført for {EXPERIMENT_NAME}!")