# In [None]:
# === CELL 1 (Complete Experiment 3b Code) ===
import os
import json
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. CONFIGURATION & SETUP
# ==========================================
EXPERIMENT_NAME = "05_exp3b_static_logits"
DATA_DIR = Path("./data")
BASE_ARTIFACTS_DIR = Path(f"./artifacts/{EXPERIMENT_NAME}")
CKPT_DIR = Path("./checkpoints")

# Path to the teacher targets (produced by 01_setup_probe).
# That file holds the "class_avg_logits" table this experiment needs.
TARGETS_PATH = Path("./artifacts/01_probe_baseline/probe_targets.pt")

# Create the output directories
BASE_ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Hyperparameters
RUNS = [50, 300]        # One training run per entry: 50 epochs, then 300 epochs
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True
GRAD_CLIP_NORM = 1.0    # Guards against exploding gradients with the WMSE loss

# Exp 3b specific parameters (carried over from the older code)
# WMSE weighting from per-class accuracy
WEIGHT_GAMMA = 1.0      # >1 emphasizes strong classes even more
WEIGHT_EPS = 1e-6       # numerical floor before normalization
REWEIGHT_EVERY_EPOCH = True # from epoch 2 on, recompute weights each epoch from per-class accuracy on the un-augmented training set

# Stabilizers
CE_ANCHOR = 0.05        # small CE anchor (5%) for stability
LOGIT_L2 = 0.001        # small L2 on logits to avoid value drift/explosion
CENTER_INPUTS = False   # Keep False as per original best runs

# Data & classes
NUM_CLASSES = 10
CIFAR10_CLASSES = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

print(f"Konfigurasjon satt for {EXPERIMENT_NAME}.")
print(f"Artefakter lagres til: {BASE_ARTIFACTS_DIR}")

# Fail early, before any data loading or training, when the teacher targets are missing.
if not TARGETS_PATH.exists():
    raise FileNotFoundError(f"Fant ikke teacher targets på {TARGETS_PATH}. Har du kjørt 01_setup_probe.ipynb?")

# ==========================================
# 2. REPRODUCIBILITY
# ==========================================
def set_seed(seed):
    """Seed every RNG used by this script and set the cuDNN flags.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU generator
    plus all CUDA devices) with the same ``seed``.

    cuDNN is configured with ``deterministic = False`` and ``benchmark = True``.
    NOTE(review): per the PyTorch reproducibility notes this combination favours
    speed and may make GPU runs differ between executions even with a fixed seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Seed once up front, then pick the compute device: CUDA when available, otherwise CPU.
# `device` is read as a module-level global by the helper functions below.
set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Kjører på: {device}")

# ==========================================
# 3. DATA LOADING
# ==========================================
# Per-channel mean / std handed to T.Normalize for every split.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop (with 4px padding) + horizontal flip, then tensor + normalize.
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: tensor + normalize only, no augmentation.
eval_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(_NORM_MEAN, _NORM_STD),
])


def _cifar10(train, transform):
    """Build a CIFAR-10 split rooted at DATA_DIR, downloading it when needed."""
    return torchvision.datasets.CIFAR10(root=DATA_DIR, train=train, download=True, transform=transform)


def _loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


train_ds_aug = _cifar10(True, train_tf)
# Same training images without augmentation: used to measure per-class
# accuracy precisely when computing the loss weights.
train_ds_eval = _cifar10(True, eval_tf)
test_ds = _cifar10(False, eval_tf)

train_loader = _loader(train_ds_aug, True)
train_eval_loader = _loader(train_ds_eval, False)
test_loader = _loader(test_ds, False)

# ==========================================
# 4. MODEL DEFINITION
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build an untrained ResNet-18 with a modified stem and classifier head.

    Starting from ``resnet18(weights=None)``:
      * ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution without bias,
      * ``maxpool`` is replaced by an identity layer,
      * ``fc`` is replaced by a linear layer with ``num_classes`` outputs.

    Returns the modified model (not moved to any device).
    """
    net = resnet18(weights=None)
    feature_dim = net.fc.in_features

    # Replacement layers are created in the same order as before (stem conv,
    # then head) so random initialisation consumes the RNG identically.
    stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1 = stem
    net.maxpool = nn.Identity()
    head = nn.Linear(feature_dim, num_classes)
    net.fc = head
    return net

# ==========================================
# 5. CUSTOM LOSS: LOGIT WMSE (Exp 3b)
# ==========================================
class LogitTargetWMSENLoss(nn.Module):
    """
    Weighted MSE between model logits and class-conditional average logit targets.

    For logits ``x`` of shape (B, C) and integer labels ``y`` of shape (B,), the
    target of sample ``i`` is row ``y[i]`` of the target table, and the loss is

        mean over (i, c) of  w_c * (x_ic - T[y_i, c])^2
        + logit_l2  * mean(x^2)                  (only if logit_l2 > 0)
        + ce_anchor * CrossEntropy(logits, y)    (only if ce_anchor > 0)

    Args:
        target_table: (C, C) average logits from the probe (row = true class,
            col = logit dimension). Stored as the buffer ``T``.
        center_inputs: if True, subtract each sample's mean logit before the
            MSE and L2 terms. The cross-entropy term always sees the raw logits.
        ce_anchor: coefficient of the cross-entropy term (0 disables it).
        logit_l2: coefficient of the mean-squared-logit penalty (0 disables it).

    The per-dimension weights ``w`` (C,) are supplied externally through
    ``set_weights`` and must be set before the first ``forward`` call.
    """

    def __init__(self, target_table: torch.Tensor,
                 center_inputs: bool = False,
                 ce_anchor: float = 0.0,
                 logit_l2: float = 0.0):
        super().__init__()
        self.register_buffer("T", target_table)  # C x C
        self.center_inputs = bool(center_inputs)
        self.ce_anchor = float(ce_anchor)
        self.logit_l2 = float(logit_l2)
        self._weights = None  # (C,) tensor on T's device, set by set_weights()
        self.ce_loss_fn = nn.CrossEntropyLoss()

    def set_weights(self, w: torch.Tensor):
        """Store the per-dimension weights, normalized to mean 1.

        The weights are detached, flattened, and moved to the device of the
        target table so ``forward`` cannot hit a device mismatch.

        Raises:
            ValueError: if the number of weights differs from the number of
                logit dimensions (columns of the target table).
        """
        w = w.detach().reshape(-1).to(self.T.device)
        num_dims = self.T.shape[-1]
        if w.numel() != num_dims:
            # Without this check a wrong-length vector fails later with an
            # opaque broadcast error, and a length-1 vector broadcasts silently.
            raise ValueError(
                f"Expected {num_dims} weights (one per logit dimension), got {w.numel()}."
            )
        # Normalize weights to mean 1 (keeps scale consistent across epochs)
        self._weights = w / w.mean().clamp_min(1e-8)

    def forward(self, logits, y):
        """Return the scalar loss for ``logits`` (B, C) and labels ``y`` (B,).

        Raises:
            RuntimeError: if ``set_weights`` has not been called yet.
        """
        if self._weights is None:
            # Explicit raise rather than `assert`, which is stripped under `python -O`.
            raise RuntimeError("Call set_weights(w) before forward.")

        # 1. Look up the target row for every sample in the batch
        targets = self.T[y]  # (B, C)
        x = logits

        if self.center_inputs:
            x = x - x.mean(dim=1, keepdim=True)

        # 2. Weighted MSE across dimensions (classes)
        diff2 = (x - targets) ** 2
        loss = (diff2 * self._weights.view(1, -1)).mean()

        # 3. Logit L2 regularization (keeps logit magnitudes from drifting)
        if self.logit_l2 > 0.0:
            loss = loss + self.logit_l2 * x.pow(2).mean()

        # 4. CE anchor (stabilization), always on the un-centered logits
        if self.ce_anchor > 0.0:
            loss = loss + self.ce_anchor * self.ce_loss_fn(logits, y)

        return loss

# ==========================================
# 6. HELPER FUNCTIONS (Weighting & Metrics)
# ==========================================
@torch.no_grad()
def per_class_accuracy(model, loader, num_classes=NUM_CLASSES):
    """Compute top-1 accuracy per class; used to weight the WMSE loss.

    Puts ``model`` in eval mode (it is not switched back), runs it over every
    batch of ``loader`` on the module-level ``device``, and returns a float
    tensor of length ``num_classes`` where entry ``c`` is the fraction of
    samples labelled ``c`` that were predicted as ``c``. Classes that never
    occur get accuracy 0. Labels outside ``[0, num_classes)`` are ignored.
    """
    model.eval()
    correct = torch.zeros(num_classes, device=device)
    counts = torch.zeros(num_classes, device=device)
    class_ids = torch.arange(num_classes, device=device)

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        pred = model(x).argmax(dim=1)

        # (B, C) membership mask: row i is True only in column y[i]. Doing this
        # in one broadcast avoids a Python loop over classes, whose
        # `mask.any()` check forced a host sync per class per batch.
        is_class = y.unsqueeze(1) == class_ids.unsqueeze(0)
        is_hit = (pred == y).unsqueeze(1)

        counts += is_class.sum(dim=0)
        correct += (is_class & is_hit).sum(dim=0)

    # clamp_min(1) avoids 0/0 for classes with no samples
    return correct / counts.clamp_min(1)

@torch.no_grad()
def weights_from_acc(acc: torch.Tensor, gamma: float = 1.0, eps: float = 1e-6):
    """
    Turn per-class accuracies into weights: higher accuracy -> higher weight.

    Each accuracy is floored at 0, shifted by ``eps``, raised to ``gamma``, and
    the result is rescaled so the weights average to 1. The effect is that the
    loss trusts the logit dimensions of classes the model already handles well.
    """
    floored = torch.clamp(acc, min=0.0) + eps
    powered = torch.pow(floored, gamma)
    scale = torch.clamp(powered.mean(), min=1e-8)
    return powered / scale

@torch.no_grad()
def uniform_weights(num_classes=NUM_CLASSES):
    """Return a float32 vector of ``num_classes`` ones on the module-level ``device``."""
    shape = (num_classes,)
    return torch.full(shape, 1.0, dtype=torch.float32, device=device)

# ==========================================
# 7. LOAD AND VISUALIZE TARGETS
# ==========================================
# NOTE(review): depending on the PyTorch version's `weights_only` default,
# torch.load may unpickle arbitrary objects — only load target files you trust.
payload = torch.load(TARGETS_PATH, map_location=device)
# IMPORTANT: Exp 3b uses LOGITS, not PROBS
class_avg_logits = payload["class_avg_logits"] 

print(f"Lastet logit-targets med form: {class_avg_logits.shape}")

# Visualize the logits (note: logits can be negative)
plt.figure(figsize=(12, 10))
sns.heatmap(
    class_avg_logits.cpu().numpy(),
    xticklabels=CIFAR10_CLASSES,
    yticklabels=CIFAR10_CLASSES,
    annot=True,
    fmt=".2f",
    cmap="coolwarm", # diverging colormap, centered at 0, suits positive/negative values
    center=0.0,
    cbar_kws={'label': 'Target Logits (Teacher Avg)'}
)
plt.title(f"Experiment 3b Targets: Static Logits")
plt.xlabel("Predicted Logit Dimension")
plt.ylabel("True Class")
plt.tight_layout()

# Save the figure before plt.show() so the file is written from the live figure.
heatmap_path = BASE_ARTIFACTS_DIR / "target_logits_heatmap.png"
plt.savefig(heatmap_path)
print(f"Heatmap lagret til: {heatmap_path}")
plt.show()

# ==========================================
# 8. TRAINING FUNCTIONS
# ==========================================
def train_one_epoch(model, loader, optimizer, scaler, criterion, grad_clip=0.0):
    """Run one training pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Each batch is moved to the module-level ``device``, run forward under
    ``torch.autocast`` (enabled according to ``USE_AMP``), and back-propagated
    through ``scaler``. When ``grad_clip > 0`` the gradients are unscaled and
    clipped to that norm before the optimizer step.

    ``mean_loss`` is the sample-weighted average of the batch losses and
    ``accuracy`` the fraction of samples whose arg-max logit equals the label,
    both measured in training mode on the batches as they were trained on.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        amp_device = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=amp_device, enabled=USE_AMP):
            logits = model(inputs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()

        if grad_clip > 0:
            # Gradients must be unscaled before their norm is measured and clipped.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        scaler.step(optimizer)
        scaler.update()

        batch_size = inputs.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

@torch.no_grad()
def evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` and return top-1, top-2 and mean-rank metrics.

    Puts ``model`` in eval mode and runs every batch on the module-level
    ``device``. For each sample the 1-based rank of the true class in the
    descending logit ordering is computed once; all three metrics derive from
    that single ranking, so they cannot disagree with each other on ties.
    No softmax is applied: it does not change the ordering, and ranking the raw
    logits also works when there are fewer than two classes.

    Returns:
        dict with keys
          "acc1": fraction of samples whose true class has rank 1,
          "acc2": fraction of samples whose true class has rank <= 2,
          "mean_rank": average rank of the true class.
    """
    model.eval()
    top1_hits = 0
    top2_hits = 0
    ranks_sum = 0.0
    total = 0

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)

        # Column index of the true label within each row's descending ordering, +1.
        order = logits.argsort(dim=1, descending=True)
        ranks = (order == y.view(-1, 1)).nonzero()[:, 1] + 1

        top1_hits += (ranks == 1).sum().item()
        top2_hits += (ranks <= 2).sum().item()
        ranks_sum += ranks.float().sum().item()

        total += x.size(0)

    return {
        "acc1": top1_hits / total,
        "acc2": top2_hits / total,
        "mean_rank": ranks_sum / total
    }

# ==========================================
# 9. MAIN LOOP (RUNNER)
# ==========================================
# Trains one fresh model per entry in RUNS and writes, for each run, the
# metric history, the weight history, a loss curve and the best checkpoint.
for max_epochs in RUNS:
    print(f"\n{'='*40}")
    print(f"  STARTER KJØRING: {max_epochs} EPOKER")
    print(f"{'='*40}")
    
    # Per-run setup
    # NOTE(review): set_seed is not called again here, so every run after the
    # first starts from whatever RNG state the previous run left behind.
    run_dir = BASE_ARTIFACTS_DIR / f"run_{max_epochs}ep"
    run_dir.mkdir(exist_ok=True)
    
    model = make_cifar_resnet18(NUM_CLASSES).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
    
    # Build the criterion around the teacher target table
    criterion = LogitTargetWMSENLoss(
        target_table=class_avg_logits,
        center_inputs=CENTER_INPUTS,
        ce_anchor=CE_ANCHOR,
        logit_l2=LOGIT_L2
    )
    
    best_acc = 0.0
    history = {"train_loss": [], "train_acc": [], "test_acc1": [], "test_acc2": [], "mean_rank": []}
    weights_history = []
    
    start_time = time.time()
    
    for epoch in range(1, max_epochs + 1):
        # 1. Refresh the loss weights from per-class accuracy
        if REWEIGHT_EVERY_EPOCH and epoch > 1:
            # train_eval_loader (no augmentation) shows what the model actually
            # gets right on the training data, without augmentation noise
            acc_c = per_class_accuracy(model, train_eval_loader, NUM_CLASSES)
            w = weights_from_acc(acc_c, gamma=WEIGHT_GAMMA, eps=WEIGHT_EPS)
        else:
            # First epoch (or re-weighting disabled): all dimensions weigh the same
            w = uniform_weights(NUM_CLASSES)
            
        criterion.set_weights(w)
        weights_history.append(w.detach().cpu().tolist())
        
        # 2. Train
        t_loss, t_acc = train_one_epoch(model, train_loader, optimizer, scaler, criterion, grad_clip=GRAD_CLIP_NORM)
        
        # 3. Evaluate
        metrics = evaluate(model, test_loader)
        
        history["train_loss"].append(t_loss)
        history["train_acc"].append(t_acc)
        history["test_acc1"].append(metrics["acc1"])
        history["test_acc2"].append(metrics["acc2"])
        history["mean_rank"].append(metrics["mean_rank"])
        
        # NOTE(review): the "best" checkpoint is selected on test_loader metrics,
        # so best_acc is chosen on the same split it is reported on.
        if metrics["acc1"] > best_acc:
            best_acc = metrics["acc1"]
            torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_{max_epochs}ep_best.pth")
            
        if epoch % 10 == 0 or epoch == 1:
            elapsed = time.time() - start_time
            # Also log the mean weight as a sanity check. Both weight sources
            # produce mean-1 vectors, so this is expected to print ~1.00.
            mean_w = w.mean().item()
            print(f"Ep {epoch:03d} | Loss: {t_loss:.4f} | TrAcc: {t_acc:.3f} | "
                  f"TeAcc1: {metrics['acc1']:.3f} | Rank: {metrics['mean_rank']:.2f} | "
                  f"W_mean: {mean_w:.2f} | T: {elapsed:.0f}s")
            
    # Save results
    res_file = run_dir / "results.json"
    with open(res_file, "w") as f:
        json.dump({
            "config": {
                "epochs": max_epochs,
                "weight_gamma": WEIGHT_GAMMA,
                "ce_anchor": CE_ANCHOR,
                "logit_l2": LOGIT_L2
            },
            "best_acc": best_acc,
            "history": history,
            "final_weights": weights_history[-1] # weights used in the last epoch
        }, f, indent=2)
        
    # Save the full weight history separately
    with open(run_dir / "weights_history.json", "w") as f:
        json.dump(weights_history, f)
        
    # Save the loss-curve plot; the figure is closed so runs do not accumulate open figures
    plt.figure(figsize=(10, 5))
    plt.plot(history["train_loss"], label="Train Loss (WMSE)")
    plt.title(f"Training Loss ({max_epochs} epochs) - Exp 3b")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(run_dir / "loss_curve.png")
    plt.close()
    
    print(f"Ferdig med {max_epochs} epoker. Beste Acc: {best_acc:.4f}. Data lagret i {run_dir}")

print(f"\nAlle kjøringer fullført for {EXPERIMENT_NAME}!")