# In [None]:
# === CELL 1 (Imports) ===
import os
import json
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# === CELL 2 (Configuration) ===
# --- Configuration ---
EXPERIMENT_NAME = "04_exp3a_static_probs"

# Filesystem layout: dataset cache, per-experiment artifacts, checkpoints.
DATA_DIR = Path("./data")
BASE_ARTIFACTS_DIR = Path("./artifacts") / EXPERIMENT_NAME
CKPT_DIR = Path("./checkpoints")

# Path to the teacher targets (from 01_setup_probe).
TARGETS_PATH = Path("./artifacts/01_probe_baseline/probe_targets.pt")

# Create the output directories up front.
for _out_dir in (BASE_ARTIFACTS_DIR, CKPT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Hyperparameters.
RUNS = [50, 300]  # Two variants are trained: 50 epochs and 300 epochs.
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True

# Exp 3a specific parameters (carried over from the old code).
KL_EPS = 1e-8           # Floor used for numerical stability.
ONEHOT_MIX_EPS = 0.10   # Share of "hard" one-hot mixed into the soft targets (0.10 = 10%).
CE_ANCHOR = 0.0         # Small anchor towards the standard CE loss (0.0 = KL only).

# Data & classes.
NUM_CLASSES = 10
CIFAR10_CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

print(f"Konfigurasjon satt for {EXPERIMENT_NAME}.")
print(f"Artefakter lagres til: {BASE_ARTIFACTS_DIR}")

# Fail early when the teacher targets are missing; nothing below can run without them.
if not TARGETS_PATH.exists():
    raise FileNotFoundError(f"Fant ikke teacher targets på {TARGETS_PATH}. Har du kjørt 01_setup_probe.ipynb?")

# === CELL 3 (Reproducibility) ===
def set_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed passed to ``random``, ``numpy.random`` and
            ``torch`` (CPU and all CUDA devices).
        deterministic: When ``False`` (the default, identical to the previous
            hard-coded behaviour) cuDNN autotuning is enabled and
            deterministic kernels are not requested, which favours speed.
            When ``True`` cuDNN autotuning is switched off and deterministic
            kernels are requested instead.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are always set as a pair so that a later call
    # with a different ``deterministic`` value fully overrides an earlier one.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic

# Seed all RNGs once, then prefer the GPU whenever CUDA is available.
set_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Kjører på: {device}")

# === CELL 4 (Data Loading) ===
# Per-channel normalisation statistics shared by the train and eval pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and random horizontal flip, then
# tensor conversion and normalisation.
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
eval_tf = T.Compose([T.ToTensor(), T.Normalize(_NORM_MEAN, _NORM_STD)])

# CIFAR-10 splits; both are downloaded into DATA_DIR when not already present.
train_ds_aug = torchvision.datasets.CIFAR10(
    root=DATA_DIR, train=True, download=True, transform=train_tf
)
test_ds = torchvision.datasets.CIFAR10(
    root=DATA_DIR, train=False, download=True, transform=eval_tf
)

# Only the training loader shuffles; every other loader setting is shared.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds_aug, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

# === CELL 5 (Model Definition) ===
def make_cifar_resnet18(num_classes=10):
    """Build an untrained ResNet-18 with its stem and head replaced.

    The first convolution is swapped for a 3x3, stride-1 convolution, the
    max-pool layer is replaced by an identity, and the final fully connected
    layer is resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits of the classifier head.

    Returns:
        The modified ``resnet18`` module (created with ``weights=None``).
    """
    net = resnet18(weights=None)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

# === CELL 6 (Custom Loss Function - KL Divergence) ===
class ProbabilityTargetLoss(nn.Module):
    """KL-divergence loss against per-class soft target distributions.

    For every sample the target distribution is the row of ``T`` selected by
    the sample's true label. Optionally a one-hot vector of the true label is
    blended into that row, and optionally a weighted cross-entropy term is
    added on top of the KL term.

    Args:
        class_avg_probs: Tensor indexed by class label along its first
            dimension; stored as the buffer ``T``. Per the original author's
            note it is a ``[NumClasses, NumClasses]`` matrix whose row ``c``
            is the target distribution for class ``c``.
        onehot_mix_eps: Weight of the one-hot component in the blend
            (``0.0`` disables blending, clamping and renormalisation).
        ce_anchor: Weight of the additional cross-entropy term
            (``0.0`` disables it).
    """

    def __init__(self, class_avg_probs: torch.Tensor, onehot_mix_eps: float = 0.0, ce_anchor: float = 0.0):
        super().__init__()
        self.register_buffer("T", class_avg_probs)
        self.onehot_mix_eps = float(onehot_mix_eps)
        self.ce_anchor = float(ce_anchor)
        self.kldiv = nn.KLDivLoss(reduction="batchmean")
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, y):
        """Return the scalar loss for a batch of ``logits`` and integer labels ``y``."""
        # Look up the target row for each sample's true class.
        soft_targets = self.T[y]

        mix = self.onehot_mix_eps
        if mix > 0.0:
            # Blend in the hard label, floor the result at the module-level
            # KL_EPS, then renormalise every row so it sums to one again.
            hard_targets = torch.zeros_like(soft_targets).scatter_(1, y.view(-1, 1), 1.0)
            blended = (1.0 - mix) * soft_targets + mix * hard_targets
            blended = blended.clamp(min=KL_EPS)
            soft_targets = blended / blended.sum(dim=1, keepdim=True)

        # nn.KLDivLoss takes log-probabilities as input and probabilities as target.
        kl_term = self.kldiv(F.log_softmax(logits, dim=1), soft_targets)

        if self.ce_anchor > 0.0:
            return kl_term + self.ce_anchor * self.ce(logits, y)
        return kl_term

# === CELL 7 (Load & Visualize Targets) ===
# Load the data generated by the probe model and pick out the soft targets.
payload = torch.load(TARGETS_PATH, map_location=device)
class_avg_probs = payload["class_avg_probs"]

print(f"Lastet targets med form: {class_avg_probs.shape}")

# Visualise what the model is asked to learn (target heatmap).
_heatmap_options = dict(
    xticklabels=CIFAR10_CLASSES,
    yticklabels=CIFAR10_CLASSES,
    annot=True,
    fmt=".2f",
    cmap="viridis",
    cbar_kws={"label": "Target Probability"},
)
plt.figure(figsize=(12, 10))
sns.heatmap(class_avg_probs.cpu().numpy(), **_heatmap_options)
plt.title(f"Experiment 3a Targets: Soft Probabilities (Probe-Learned)")
plt.xlabel("Predicted Class")
plt.ylabel("True Class")
plt.tight_layout()

# Write the figure to disk before showing it.
heatmap_path = BASE_ARTIFACTS_DIR.joinpath("target_distribution_heatmap.png")
plt.savefig(heatmap_path)
print(f"Heatmap lagret til: {heatmap_path}")
plt.show()

# === CELL 8 (Training & Eval Functions) ===
def train_one_epoch(model, loader, optimizer, scaler, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; batches are moved to the device of its first
            parameter (CPU if it has no parameters).
        loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler. Mixed-precision autocast is enabled exactly
            when ``scaler.is_enabled()`` is true, so autocast and loss
            scaling can never disagree (for example CPU bfloat16 autocast
            paired with a disabled CUDA scaler).
        criterion: Callable ``criterion(logits, labels)`` returning a scalar loss.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()

    # Follow the model's own device instead of relying on a module-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    autocast_device = "cuda" if run_device.type == "cuda" else "cpu"
    amp_enabled = scaler.is_enabled()

    total_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x = x.to(run_device, non_blocking=True)
        y = y.to(run_device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=autocast_device, enabled=amp_enabled):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = x.size(0)
        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch received an empty loader; nothing to train on.")

    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader):
    """Compute top-1 accuracy, top-2 accuracy and mean rank of the true class.

    All three metrics are derived from a single descending sort of the raw
    logits, so they are always mutually consistent. Ranking logits directly
    (rather than softmax probabilities) avoids ties introduced when small
    probabilities underflow to zero.

    Args:
        model: Module to evaluate; it is switched to eval mode and batches
            are moved to the device of its first parameter (CPU if it has no
            parameters).
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Dict with keys ``"acc1"`` and ``"acc2"`` (fractions of samples whose
        true class is ranked first / within the first two) and
        ``"mean_rank"`` (average 1-based rank of the true class).

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()

    # Follow the model's own device instead of relying on a module-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    correct_top1 = 0
    correct_top2 = 0
    ranks_sum = 0
    total = 0

    for x, y in loader:
        x = x.to(run_device, non_blocking=True)
        y = y.to(run_device, non_blocking=True)
        logits = model(x)

        # 1-based position of the true class in the descending logit order.
        order = logits.argsort(dim=1, descending=True)
        ranks = (order == y.view(-1, 1)).int().argmax(dim=1) + 1

        correct_top1 += (ranks == 1).sum().item()
        correct_top2 += (ranks <= 2).sum().item()
        ranks_sum += ranks.sum().item()
        total += x.size(0)

    if total == 0:
        raise ValueError("evaluate received an empty loader; no samples to score.")

    return {
        "acc1": correct_top1 / total,
        "acc2": correct_top2 / total,
        "mean_rank": ranks_sum / total,
    }

# === CELL 9 (Main Experiment Runner) ===
# Run the full train/evaluate loop once per entry in RUNS (50 and 300 epochs).
for max_epochs in RUNS:
    print(f"\n{'='*40}")
    print(f"  STARTER KJØRING: {max_epochs} EPOKER")
    print(f"{'='*40}")

    # Re-seed at the start of every run so that each run is reproducible on
    # its own and does not depend on how much randomness earlier runs consumed.
    set_seed(SEED)

    # Per-run output directory.
    run_dir = BASE_ARTIFACTS_DIR / f"run_{max_epochs}ep"
    run_dir.mkdir(parents=True, exist_ok=True)

    model = make_cifar_resnet18(NUM_CLASSES).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # Loss scaling only makes sense on CUDA. Prefer the device-agnostic
    # torch.amp.GradScaler when this torch version provides it and fall back
    # to the older torch.cuda.amp.GradScaler otherwise.
    amp_on_gpu = bool(USE_AMP) and device.type == "cuda"
    _grad_scaler_cls = getattr(getattr(torch, "amp", None), "GradScaler", None)
    if _grad_scaler_cls is not None:
        scaler = _grad_scaler_cls("cuda", enabled=amp_on_gpu)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=amp_on_gpu)

    # KL loss over the loaded targets; moved to the training device so the
    # target buffer always lives alongside the logits and labels.
    criterion = ProbabilityTargetLoss(
        class_avg_probs=class_avg_probs,
        onehot_mix_eps=ONEHOT_MIX_EPS,
        ce_anchor=CE_ANCHOR,
    ).to(device)

    best_acc = 0.0
    best_epoch = 0
    history = {"train_loss": [], "train_acc": [], "test_acc1": [], "test_acc2": [], "mean_rank": []}
    start_time = time.time()

    for epoch in range(1, max_epochs + 1):
        t_loss, t_acc = train_one_epoch(model, train_loader, optimizer, scaler, criterion)
        metrics = evaluate(model, test_loader)

        history["train_loss"].append(t_loss)
        history["train_acc"].append(t_acc)
        history["test_acc1"].append(metrics["acc1"])
        history["test_acc2"].append(metrics["acc2"])
        history["mean_rank"].append(metrics["mean_rank"])

        # Keep a checkpoint of the weights with the best top-1 accuracy so far.
        if metrics["acc1"] > best_acc:
            best_acc = metrics["acc1"]
            best_epoch = epoch
            torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_{max_epochs}ep_best.pth")

        if epoch % 10 == 0 or epoch == 1:
            elapsed = time.time() - start_time
            print(f"Ep {epoch:03d} | Loss: {t_loss:.4f} | TrAcc: {t_acc:.3f} | "
                  f"TeAcc1: {metrics['acc1']:.3f} | Rank: {metrics['mean_rank']:.2f} | T: {elapsed:.0f}s")

    # Persist the results of this run together with the settings that produced them.
    res_file = run_dir / "results.json"
    with open(res_file, "w", encoding="utf-8") as f:
        json.dump({
            "config": {
                "epochs": max_epochs,
                "onehot_mix_eps": ONEHOT_MIX_EPS,
                "ce_anchor": CE_ANCHOR,
                "seed": SEED,
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "batch_size": BATCH_SIZE,
            },
            "best_acc": best_acc,
            "best_epoch": best_epoch,
            "history": history
        }, f, indent=2)

    # Save the training-loss curve; the x-axis is 1-based to match the epoch numbers.
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(history["train_loss"]) + 1), history["train_loss"], label="Train Loss")
    plt.title(f"Training Loss ({max_epochs} epochs)")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(run_dir / "loss_curve.png")
    plt.close()

    print(f"Ferdig med {max_epochs} epoker. Beste Acc: {best_acc:.4f}. Data lagret i {run_dir}")

print(f"\nAlle kjøringer fullført for {EXPERIMENT_NAME}!")