# In [None]:
# === CELL 1 (Complete Experiment 4b Code) ===
import os
import json
import random
import time
from pathlib import Path
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. KONFIGURASJON & OPPSETT
# ==========================================
# --- Experiment identity and on-disk layout ---------------------------------
EXPERIMENT_NAME = "07_exp4b_dynamic_boost"
DATA_DIR = Path("./data")
BASE_ARTIFACTS_DIR = Path("./artifacts") / EXPERIMENT_NAME
CKPT_DIR = Path("./checkpoints")

# Starting point: the logits learned by the probe (produced by 01_setup_probe).
TARGETS_PATH = Path("./artifacts") / "01_probe_baseline" / "probe_targets.pt"

# Make sure the output directories exist before anything tries to write to them.
for _out_dir in (BASE_ARTIFACTS_DIR, CKPT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# --- General hyperparameters -------------------------------------------------
RUNS = [50, 300]            # one full run per entry: 50 epochs and 300 epochs
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True
GRAD_CLIP_NORM = 1.0

# --- Target dynamics (carried over from eks4_3) ------------------------------
MIN_ACC_DIFF = 0.05         # minimum accuracy gap before the dynamics kick in
AUTH_BETA = 10.0            # sigmoid sharpness for the authority term
TARGET_UPDATE_LR = 0.10     # step size of the target update
KAPPA_PUSHBACK = 0.10       # "push" on the expert (repulsion)
TARGET_ANCHOR_GAMMA = 0.0   # IMPORTANT: 0.0 in eks4_3 (no anchor to the probe; targets float freely)
TARGET_CLAMP_M = 10.0       # targets are clamped to [-M, M]

# --- Weighting & boost -------------------------------------------------------
WEIGHT_GAMMA = 1.0
WEIGHT_EPS = 1e-6
TRUE_CLASS_BOOST = 8.0      # IMPORTANT: the true class is weighted 8x in the loss
CE_ANCHOR = 0.05            # small cross-entropy anchor for stability
LOGIT_L2 = 0.001            # L2 on the logits to keep them from exploding

# --- Data & classes ----------------------------------------------------------
NUM_CLASSES = 10
CIFAR10_CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

print(f"Konfigurasjon satt for {EXPERIMENT_NAME}.")
print(f"Artefakter lagres til: {BASE_ARTIFACTS_DIR}")

# Fail early: every run below starts from the probe targets file.
if not TARGETS_PATH.exists():
    raise FileNotFoundError(f"Fant ikke teacher targets på {TARGETS_PATH}. Har du kjørt 01_setup_probe.ipynb?")

# ==========================================
# 2. REPRODUSERBARHET
# ==========================================
def set_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied to every generator.
        deterministic: When True, ask cuDNN for deterministic kernels and turn
            off its autotuner, trading speed for repeatable GPU runs. The
            default (False) keeps the previous behaviour: autotuner on,
            deterministic mode off, so GPU runs may differ between
            executions even with the same seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are set as a pair: benchmark mode picks kernels by
    # timing them, which is what makes results vary from run to run.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic

# Seed all RNGs before any model or loader exists, then pick the compute device.
set_seed(SEED)
_backend_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend_name)
print(f"Kjører på: {device}")

# ==========================================
# 3. DATA LOADING
# ==========================================
# Per-channel normalisation constants shared by the train and eval pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip, then tensor + normalise.
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: no augmentation.
eval_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(_NORM_MEAN, _NORM_STD),
])

_cifar10 = torchvision.datasets.CIFAR10
train_ds_aug = _cifar10(root=DATA_DIR, train=True, download=True, transform=train_tf)
# Same training images without augmentation; used to measure per-class accuracy ("authority").
train_ds_eval = _cifar10(root=DATA_DIR, train=True, download=True, transform=eval_tf)
test_ds = _cifar10(root=DATA_DIR, train=False, download=True, transform=eval_tf)

# Only the augmented training loader shuffles; both evaluation loaders keep dataset order.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds_aug, shuffle=True, **_loader_opts)
train_eval_loader = DataLoader(train_ds_eval, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

# ==========================================
# 4. MODELL DEFINISJON
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build an untrained ResNet-18 with a modified stem and classifier head.

    Starting from ``resnet18(weights=None)``, the first convolution is
    replaced by a 3x3, stride-1, padding-1 convolution without bias, the
    max-pool layer is replaced by an identity, and the final fully connected
    layer is replaced by one with ``num_classes`` outputs.
    """
    net = resnet18(weights=None)
    head_in_features = net.fc.in_features

    # Stem: 3x3 stride-1 convolution, no pooling afterwards.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    # Classifier head sized for the requested number of classes.
    net.fc = nn.Linear(head_in_features, num_classes)
    return net

# ==========================================
# 5. KJERNE-LOGIKK: DYNAMISK OPPDATERING
# ==========================================
@torch.no_grad()
def update_target_logits(
    T, acc, T0=None, beta=10.0, lr=0.10, kappa=0.10,
    min_diff=0.05, anchor_gamma=0.0, clamp_M=10.0
):
    """
    Update the target matrix T based on asymmetric per-class "authority".

    For every ordered pair (i, j) with i != j where class j (the "expert") is
    more accurate than class i (the "novice") by more than ``min_diff``:

        p       = sigmoid(beta * (acc[j] - acc[i]))
        change  = lr * p * (T[j, i] - T[i, j])
        T_new[i, j] += change            # pull the novice towards the expert
        T_new[j, i] -= kappa * change    # pushback on the expert (kappa > 0)

    Every pair reads from the input ``T`` and writes to a copy, so the result
    does not depend on pair order. The update is computed for all pairs at
    once with tensor operations instead of a Python double loop, which avoids
    one host/device round trip per pair when the tensors live on a GPU.

    Args:
        T: Target matrix with C = ``T.shape[0]`` rows and at least C columns;
            only the leading (C, C) part is updated. Not modified in place.
        acc: 1-D tensor with at least C per-class accuracies.
        T0: Optional anchor matrix, blended in when ``anchor_gamma > 0``.
        beta: Sharpness of the sigmoid that turns an accuracy gap into authority.
        lr: Step size of the novice update.
        kappa: Fraction of the change applied as pushback on the expert entry.
        min_diff: Accuracy gap that must be exceeded for a pair to update.
        anchor_gamma: Blend factor towards ``T0`` (0 disables the anchor).
        clamp_M: If not None and > 0, the result is clamped to [-clamp_M, clamp_M].

    Returns:
        Tuple ``(T_new, updates_made)``: the updated matrix and the number of
        ordered pairs that triggered an update (Python int).

    Raises:
        IndexError: If ``T`` is not 2-D with at least C columns, or ``acc``
            has fewer than C entries.
    """
    C = T.shape[0]
    if T.dim() != 2 or T.shape[1] < C or acc.shape[0] < C:
        raise IndexError(
            f"expected T with at least ({C}, {C}) entries and acc with at least "
            f"{C} entries, got T {tuple(T.shape)} and acc {tuple(acc.shape)}"
        )

    T_new = T.clone()
    src = T[:, :C]        # values every pair reads from (the input, never the copy)
    dst = T_new[:, :C]    # view into T_new: in-place ops below write through

    cls_acc = acc[:C]
    # gap[i, j] = acc[j] - acc[i]: how much better expert j is than novice i.
    gap = cls_acc.unsqueeze(0) - cls_acc.unsqueeze(1)

    off_diagonal = ~torch.eye(C, dtype=torch.bool, device=gap.device)
    # Written as "not (gap <= min_diff)" to mirror the skip condition exactly.
    active = off_diagonal & ~(gap <= min_diff)

    authority = torch.sigmoid(beta * gap)
    # src.t()[i, j] is the expert's view T[j, i]; src[i, j] is the novice's view.
    change = lr * authority * (src.t() - src)
    change = torch.where(active, change, torch.zeros_like(change))

    # Pull the novice entry towards the expert ...
    dst.add_(change)
    # ... and push the mirrored expert entry the opposite way.
    if kappa > 0:
        dst.sub_(kappa * change.t())

    # Optional anchor towards the initial targets.
    if T0 is not None and anchor_gamma > 0:
        T_new = (1.0 - anchor_gamma) * T_new + anchor_gamma * T0

    # Keep the targets bounded.
    if clamp_M is not None and clamp_M > 0:
        T_new = T_new.clamp(min=-clamp_M, max=clamp_M)

    return T_new, int(active.sum().item())

# ==========================================
# 6. CUSTOM LOSS: BOOSTED DYNAMIC WMSE
# ==========================================
class DynamicTargetWMSE(nn.Module):
    """Weighted MSE between logits and per-class target logits, with extras.

    The loss is the mean over all (sample, class) entries of
    ``weight * (logit - target) ** 2``, where the target row is looked up by
    the sample's label and the per-class weight of the sample's own label is
    multiplied by ``true_class_boost``. Two optional stabilisers are added:
    ``ce_anchor`` times the cross-entropy of the logits, and ``logit_l2``
    times the mean squared logit.
    """

    def __init__(self, ce_anchor=0.0, logit_l2=0.0, true_class_boost=1.0):
        super().__init__()
        self.ce_anchor = ce_anchor
        self.logit_l2 = logit_l2
        self.true_class_boost = float(true_class_boost)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, y, current_targets, current_weights):
        """Return the scalar loss for one batch.

        Args:
            logits: Model outputs; must be 2-D (batch, classes).
            y: Integer labels used to index ``current_targets`` rows and the
                boosted weight column.
            current_targets: Target-logit matrix indexed by label.
            current_weights: Per-class weights, viewed as one row and
                repeated for every sample.
        """
        batch_size, _ = logits.shape

        # Squared error against the target row selected by each label.
        squared_error = (logits - current_targets[y]).pow(2)

        # One private copy of the class weights per sample, so boosting a
        # sample's true class never touches the caller's weight tensor.
        per_sample_w = current_weights.view(1, -1).expand(batch_size, -1).clone()
        if self.true_class_boost != 1.0:
            rows = torch.arange(batch_size, device=logits.device)
            per_sample_w[rows, y] *= self.true_class_boost

        total = (squared_error * per_sample_w).mean()

        # Optional stabilisers.
        if self.ce_anchor > 0:
            total = total + self.ce_anchor * self.ce(logits, y)
        if self.logit_l2 > 0:
            total = total + self.logit_l2 * logits.pow(2).mean()

        return total

# ==========================================
# 7. HJELPEFUNKSJONER (Metrics & Viz)
# ==========================================
@torch.no_grad()
def per_class_accuracy(model, loader, num_classes=10):
    """Return per-class accuracy of ``model`` over ``loader`` as a 1-D tensor.

    The model is switched to eval mode. Batches are moved to the device of
    the model's first parameter (CPU if the model has no parameters), so the
    function does not depend on a module-level ``device`` variable. All
    classes are counted with one broadcast comparison per batch rather than a
    Python loop over classes.

    Args:
        model: Module mapping a batch ``x`` to (batch, classes) scores.
        loader: Iterable of ``(x, y)`` batches; ``y`` holds integer labels.
        num_classes: Number of classes to report. Labels outside
            ``[0, num_classes)`` are ignored.

    Returns:
        Float tensor of length ``num_classes``; classes with no samples get 0.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    class_ids = torch.arange(num_classes, device=run_device)
    correct = torch.zeros(num_classes, device=run_device)
    counts = torch.zeros(num_classes, device=run_device)

    for x, y in loader:
        x = x.to(run_device, non_blocking=True)
        y = y.to(run_device, non_blocking=True)
        pred = model(x).argmax(dim=1)

        # is_class[b, c] is True when sample b carries label c.
        is_class = y.unsqueeze(1) == class_ids.unsqueeze(0)
        is_hit = (pred == y).unsqueeze(1)

        counts += is_class.sum(dim=0)
        correct += (is_class & is_hit).sum(dim=0)

    # clamp_min(1) avoids 0/0 for classes that never appeared.
    return correct / counts.clamp_min(1)

@torch.no_grad()
def evaluate(model, loader):
    """Return the overall top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode. Batches are moved to the device of
    the model's first parameter (CPU if the model has no parameters), so the
    function does not depend on a module-level ``device`` variable. The
    correct-prediction count is accumulated on that device and read back once
    at the end instead of once per batch.

    Args:
        model: Module mapping a batch ``x`` to (batch, classes) scores.
        loader: Iterable of ``(x, y)`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float; 0.0 when
        the loader yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = torch.zeros((), dtype=torch.long, device=run_device)
    total = 0
    for x, y in loader:
        x = x.to(run_device, non_blocking=True)
        y = y.to(run_device, non_blocking=True)
        correct += (model(x).argmax(dim=1) == y).sum()
        total += x.size(0)

    # An empty loader has no defined accuracy; report 0.0 rather than divide by zero.
    if total == 0:
        return 0.0
    return correct.item() / total

@torch.no_grad()
def weights_from_acc(acc, gamma=1.0, eps=1e-6):
    """Turn per-class accuracies into class weights with mean 1.

    Each accuracy is floored at 0, shifted by ``eps`` and raised to the power
    ``gamma``; the result is then divided by its own mean (floored at 1e-8).
    """
    raw = torch.pow(acc.clamp_min(0.0) + eps, gamma)
    normaliser = raw.mean().clamp_min(1e-8)
    return raw / normaliser

def save_target_heatmap(logits_matrix, classes, filename, title_suffix=""):
    """Save an annotated heatmap of the row-wise softmax of ``logits_matrix``.

    Args:
        logits_matrix: 2-D tensor of logits; softmax is taken along dim 1.
        classes: Labels used for both the x and y tick labels.
        filename: Destination passed to ``savefig``.
        title_suffix: Text appended to the "Dynamic Targets" title.
    """
    row_probs = torch.softmax(logits_matrix, dim=1).cpu().numpy()

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(
        row_probs,
        ax=ax,
        xticklabels=classes,
        yticklabels=classes,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        cbar_kws={'label': 'Target Probability'},
    )
    ax.set_title(f"Dynamic Targets {title_suffix}")
    ax.set_xlabel("Predicted Class (Target)")
    ax.set_ylabel("True Class (Source)")
    fig.tight_layout()
    fig.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Heatmap lagret til: {filename}")

# ==========================================
# 8. HOVEDLØKKE (RUNNER)
# ==========================================
# Run the full experiment once per epoch budget in RUNS. Every run starts from
# a fresh model, a fresh optimizer and freshly loaded probe targets.
for max_epochs in RUNS:
    print(f"\n{'='*40}")
    print(f"  STARTER KJØRING: {max_epochs} EPOKER (Exp 4b - Boosted)")
    print(f"{'='*40}")
    
    # Per-run output directory (the parent directory was created during setup).
    run_dir = BASE_ARTIFACTS_DIR / f"run_{max_epochs}ep"
    run_dir.mkdir(exist_ok=True)
    
    # Load the starting targets produced by the probe; the payload is read
    # through its "class_avg_logits" entry.
    payload = torch.load(TARGETS_PATH, map_location=device)
    target_logits_matrix = payload["class_avg_logits"].to(device)
    # Untouched copy, passed to update_target_logits as the anchor T0.
    target_logits_matrix_0 = target_logits_matrix.clone()
    
    print(f"Start-targets lastet. Form: {target_logits_matrix.shape}")

    # Model, optimizer and AMP gradient scaler
    model = make_cifar_resnet18(NUM_CLASSES).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
    
    # Criterion with TRUE CLASS BOOST
    criterion = DynamicTargetWMSE(
        ce_anchor=CE_ANCHOR, 
        logit_l2=LOGIT_L2, 
        true_class_boost=TRUE_CLASS_BOOST
    )
    
    best_acc = 0.0
    history = {"train_loss": [], "test_acc": [], "target_updates": []}
    start_time = time.time()
    
    for ep in range(1, max_epochs + 1):
        # --- STEP 1: measure per-class accuracy & update the targets ---
        if ep == 1:
            # No measurement exists before the first epoch: use a uniform 0.5
            # for every class and leave the targets untouched.
            current_acc = torch.ones(NUM_CLASSES, device=device) * 0.5
            updates_count = 0
        else:
            # Accuracy on the un-augmented training set, using the model as
            # it stands after the previous epoch.
            current_acc = per_class_accuracy(model, train_eval_loader, NUM_CLASSES)
            
            # Update the target matrix
            target_logits_matrix, updates_count = update_target_logits(
                T=target_logits_matrix,
                acc=current_acc,
                T0=target_logits_matrix_0,
                beta=AUTH_BETA,
                lr=TARGET_UPDATE_LR,
                kappa=KAPPA_PUSHBACK,
                min_diff=MIN_ACC_DIFF,
                anchor_gamma=TARGET_ANCHOR_GAMMA, # 0.0 in this experiment
                clamp_M=TARGET_CLAMP_M
            )
        
        history["target_updates"].append(updates_count)
        
        # Per-class loss weights derived from the current accuracies
        weights = weights_from_acc(current_acc, gamma=WEIGHT_GAMMA, eps=WEIGHT_EPS)
        
        # --- STEP 2: training ---
        model.train()
        total_loss = 0
        count = 0
        
        for x, y in train_loader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            
            with torch.autocast(device_type="cuda" if device.type=="cuda" else "cpu", enabled=USE_AMP):
                logits = model(x)
                # Boosted loss against the current targets and weights
                loss = criterion(logits, y, target_logits_matrix, weights)
                
            scaler.scale(loss).backward()
            
            if GRAD_CLIP_NORM > 0:
                # Unscale first so the clipping threshold applies to the real
                # gradient magnitudes rather than the scaled ones.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
                
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
            count += 1
            
        # Mean of the per-batch losses for this epoch
        avg_loss = total_loss / count
        
        # --- STEP 3: evaluation ---
        val_acc = evaluate(model, test_loader)
        
        history["train_loss"].append(avg_loss)
        history["test_acc"].append(val_acc)
        
        # Keep the weights of the best epoch so far.
        # NOTE(review): "best" is chosen on test_loader, so best_acc is not an
        # unbiased held-out estimate; confirm this is intended.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_{max_epochs}ep_best.pth")
            
        # Logging: first epoch and every 10th epoch
        if ep % 10 == 0 or ep == 1:
            elapsed = time.time() - start_time
            print(f"Ep {ep:03d} | Loss: {avg_loss:.4f} | Test Acc: {val_acc:.4f} | "
                  f"Updates: {updates_count} | T: {elapsed:.0f}s")
            
            # Show the dynamics: most vs least accurate class (skipped at
            # epoch 1, where the accuracies are the uniform placeholder).
            if ep > 1:
                top_cls = CIFAR10_CLASSES[current_acc.argmax()]
                bot_cls = CIFAR10_CLASSES[current_acc.argmin()]
                print(f"   -> Top: {top_cls} ({current_acc.max():.2f}), Bot: {bot_cls} ({current_acc.min():.2f})")

    # Save the run's configuration, best accuracy and per-epoch history
    res_file = run_dir / "results.json"
    with open(res_file, "w") as f:
        json.dump({
            "config": {
                "epochs": max_epochs,
                "auth_beta": AUTH_BETA,
                "target_update_lr": TARGET_UPDATE_LR,
                "true_class_boost": TRUE_CLASS_BOOST,
                "anchor_gamma": TARGET_ANCHOR_GAMMA
            },
            "best_acc": best_acc,
            "history": history
        }, f, indent=2)
        
    # Heatmap of the final target matrix
    save_target_heatmap(
        target_logits_matrix, 
        CIFAR10_CLASSES, 
        run_dir / "final_learned_distribution.png",
        title_suffix=f"(Run {max_epochs} eps - Boost {TRUE_CLASS_BOOST}x)"
    )
    
    # Save the final target matrix itself
    torch.save(target_logits_matrix, run_dir / "final_target_matrix.pt")
    
    # Test-accuracy curve over epochs
    plt.figure(figsize=(10, 5))
    plt.plot(history["test_acc"], label="Test Acc")
    plt.title(f"Training Progress (Boost={TRUE_CLASS_BOOST})")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.grid(True)
    plt.savefig(run_dir / "training_metrics.png")
    plt.close()

    print(f"Ferdig med {max_epochs} epoker. Beste Acc: {best_acc:.4f}. Data lagret i {run_dir}")

print(f"\nAlle kjøringer fullført for {EXPERIMENT_NAME}!")