# In [None]:
# === CELL 1 (Imports) ===
import os
import json
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# === CELL 2 (Configuration) ===
# --- Configuration: paths, hyperparameters and class metadata ---
EXPERIMENT_NAME = "01_probe_baseline"
DATA_DIR = Path("./data")
ARTIFACTS_DIR = Path("./artifacts") / EXPERIMENT_NAME
CKPT_DIR = Path("./checkpoints")

# Create the output directories up front (no-op if they already exist).
for _out_dir in (ARTIFACTS_DIR, CKPT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Hyperparameters.
# 50 epochs are used to obtain a very strong "teacher" (the probe); this corresponds
# to the "strong probe" (v2) from the old repo.
EPOCHS = 50
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True

# Data & classes. Class names are treated as indexed by label id
# (they label the heatmap axes and are saved alongside the targets).
CIFAR10_CLASSES = "airplane automobile bird cat deer dog frog horse ship truck".split()
NUM_CLASSES = len(CIFAR10_CLASSES)

print(f"Konfigurasjon satt. Artefakter lagres til: {ARTIFACTS_DIR}")

# === CELL 3 (Reproducibility) ===
def set_seed(seed, *, deterministic=False):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Args:
        seed: Seed value passed to every RNG.
        deterministic: If False (the default, matching the previous behaviour),
            cuDNN autotuning is enabled (``benchmark=True``, ``deterministic=False``)
            for speed, so GPU results are NOT bitwise reproducible across runs.
            If True, cuDNN is restricted to deterministic algorithms and
            autotuning is disabled, trading speed for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN pick (possibly non-deterministic) fastest kernels,
    # so the two flags are always set as a pair.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic

set_seed(SEED)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Kjører på: {device}")

# === CELL 4 (Data Loading) ===
# Same transforms / normalisation as the original repo.
# NOTE(review): the std tuple below appears to differ from the commonly quoted
# per-channel CIFAR-10 std (~0.247, 0.243, 0.261). It is kept unchanged so the
# generated targets stay comparable with the original repo; confirm before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Eval transform (no augmentation), used both for validation AND for extracting targets.
eval_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Datasets
train_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_tf)
# The training set again, without augmentation, to read out what the model has learned.
train_ds_eval = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=eval_tf)
test_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=eval_tf)

# Loaders. Pinned host memory only speeds up host->GPU copies; without CUDA it is
# useless and PyTorch warns about it, so only pin when CUDA is available.
_loader_kwargs = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "pin_memory": torch.cuda.is_available(),
}
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
train_eval_loader = DataLoader(train_ds_eval, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

print(f"Data lastet. Train size: {len(train_ds)}, Test size: {len(test_ds)}")

# === CELL 5 (Model Definition) ===
def make_cifar_resnet18(num_classes=10):
    """Return a randomly initialised torchvision ResNet-18 adapted to CIFAR-10 (as in the original repo).

    The first convolution is replaced by a 3x3, stride-1, padding-1 conv and the
    max-pool layer is removed; the final linear layer is resized to ``num_classes``.
    """
    net = resnet18(weights=None)
    # 3x3 / stride 1 / padding 1 preserves the input's spatial size in the stem.
    net.conv1 = nn.Conv2d(
        in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
    )
    net.maxpool = nn.Identity()
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

# Build the probe model, its optimizer, the AMP gradient scaler and the loss.
model = make_cifar_resnet18(NUM_CLASSES).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# The CUDA GradScaler only has an effect with CUDA float16 autocast. On a CPU-only
# machine it would warn and disable itself, so enable it explicitly only on CUDA.
# NOTE(review): newer PyTorch deprecates torch.cuda.amp.GradScaler in favour of
# torch.amp.GradScaler("cuda", ...); kept for compatibility with older versions.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")
criterion = nn.CrossEntropyLoss()

# === CELL 6 (Training & Eval Functions) ===
def train_one_epoch(model, loader, optimizer, scaler):
    """Train ``model`` for one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Uses the module-level ``device``, ``criterion`` and ``USE_AMP``. Loss and
    accuracy are averaged over all samples seen, measured on the (augmented)
    training batches in train mode.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    # Mixed precision is only used on CUDA: the scaler is a CUDA float16 GradScaler,
    # and on CPU torch.autocast would silently switch to bfloat16 instead.
    use_amp = USE_AMP and device.type == "cuda"
    amp_device_type = "cuda" if device.type == "cuda" else "cpu"
    total_loss, correct, total = 0.0, 0, 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=amp_device_type, enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Runs in eval mode without gradients (and without autocast), moving each
    batch to the module-level ``device``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        predictions = model(inputs).argmax(dim=1)
        n_correct += int((predictions == labels).sum())
        n_seen += inputs.size(0)
    return n_correct / n_seen

# === CELL 7 (Main Training Loop) ===
# Train the probe and keep the checkpoint with the highest test accuracy.
# NOTE(review): the "best" checkpoint is selected on the CIFAR-10 *test* split, so
# best_acc is an optimistically biased estimate and the test set is no longer a clean
# hold-out. A validation split carved from the training data would avoid this; left
# unchanged here because downstream experiments consume this checkpoint.
print(f"Starter trening av Probe i {EPOCHS} epoker...")
best_acc = 0.0
stats = {"train_loss": [], "train_acc": [], "test_acc": []}
# perf_counter is monotonic, so the elapsed time is immune to wall-clock adjustments.
start_time = time.perf_counter()

for epoch in range(1, EPOCHS + 1):
    t_loss, t_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    val_acc = evaluate(model, test_loader)

    stats["train_loss"].append(t_loss)
    stats["train_acc"].append(t_acc)
    stats["test_acc"].append(val_acc)

    # Save the best model. Always write on the first epoch so the checkpoint loaded
    # later for target extraction is guaranteed to come from this run, never from a
    # stale file left in CKPT_DIR by an earlier run.
    if epoch == 1 or val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), CKPT_DIR / "probe_best.pth")

    print(f"Epoke {epoch:02d} | Loss: {t_loss:.4f} | Train Acc: {t_acc:.4f} | Test Acc: {val_acc:.4f}")

total_time = time.perf_counter() - start_time
print(f"Ferdig! Beste Test Acc: {best_acc:.4f}. Tid: {total_time/60:.1f} min.")

# Save the per-epoch training statistics.
with open(ARTIFACTS_DIR / "training_stats.json", "w", encoding="utf-8") as f:
    json.dump(stats, f, indent=2)

# === CELL 8 (Extract Stats / Targets) ===
# Generate the per-class "ground truth" that the other experiments consume,
# exactly as in the original repo:
# - avg_probs: mean softmax output per true class (for Exp 3a)
# - avg_logits: mean raw logits per true class (for Exp 3b, 4, 5)

print("Laster beste modell for å generere targets...")
model.load_state_dict(torch.load(CKPT_DIR / "probe_best.pth", map_location=device))
model.eval()

# Row c accumulates the outputs of all samples whose true label is c.
sum_probs = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.float64, device=device)
sum_logits = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.float64, device=device)
counts = torch.zeros(NUM_CLASSES, dtype=torch.long, device=device)

print("Analyserer treningssettet (eval-modus)...")
with torch.no_grad():
    for x, y in train_eval_loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        probs = F.softmax(logits, dim=1)

        # Per-class sums for the whole batch in one shot: one_hot.T @ values adds up
        # the rows of `values` belonging to each true class. This replaces a Python
        # loop over classes (with a host sync per class via mask.any()), stays
        # deterministic on GPU, and accumulates in float64.
        one_hot = F.one_hot(y, num_classes=NUM_CLASSES).to(torch.float64)
        sum_probs += one_hot.T @ probs.to(torch.float64)    # used by Exp 3a (KL div)
        sum_logits += one_hot.T @ logits.to(torch.float64)  # used by Exp 3b, 4, 5 (WMSE)
        counts += one_hot.sum(dim=0).to(torch.long)

# Per-class means; clamp_min(1) avoids division by zero for a class with no samples.
avg_probs = (sum_probs / counts.view(-1, 1).clamp_min(1)).float().cpu()
avg_logits = (sum_logits / counts.view(-1, 1).clamp_min(1)).float().cpu()

# Save the targets.
targets_path = ARTIFACTS_DIR / "probe_targets.pt"
torch.save({
    "class_avg_probs": avg_probs,
    "class_avg_logits": avg_logits,
    "counts": counts.cpu(),
    "classes": CIFAR10_CLASSES
}, targets_path)

print(f"Targets (både logits og probs) lagret til: {targets_path}")

# === CELL 9 (Visualization) ===
# Heatmap of the per-class average softmax (the final target distribution):
# row = true class, column = predicted class, i.e. what the teacher thinks each
# class "resembles". The title shows the best test accuracy recorded in training.

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    avg_probs.numpy(),
    ax=ax,
    xticklabels=CIFAR10_CLASSES,
    yticklabels=CIFAR10_CLASSES,
    annot=True,
    fmt=".2f",
    cmap="viridis",
    cbar_kws={"label": "Teacher Confidence (Softmax Avg)"},
)
ax.set_title(f"Probe Learned Distribution (Acc: {best_acc:.2%})")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
fig.tight_layout()

heatmap_path = ARTIFACTS_DIR / "probe_distribution_heatmap.png"
fig.savefig(heatmap_path)
print(f"Heatmap lagret til: {heatmap_path}")
plt.show()