In [None]:
# === CELL 1 (Imports) ===
import os
import json
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# === CELL 2 (Configuration) ===
# --- Configuration ---
# Experiment identity and where inputs/outputs live.
EXPERIMENT_NAME = "02_exp1_baseline"
DATA_DIR = Path("./data")
ARTIFACTS_DIR = Path("./artifacts") / EXPERIMENT_NAME
CKPT_DIR = Path("./checkpoints")

# Make sure both output folders exist (no-op when they already do).
for _out_dir in (ARTIFACTS_DIR, CKPT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

# Optimisation hyperparameters.
# 300 epochs in total; the training loop also snapshots epoch 50 so a
# 50-epoch comparison does not need a separate run.
MAIN_EPOCHS = 300
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True

# Label space (CIFAR-10 class names in label-index order).
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
NUM_CLASSES = len(CIFAR10_CLASSES)

print(f"Konfigurasjon satt for {EXPERIMENT_NAME}.")
print(f"Artefakter lagres til: {ARTIFACTS_DIR}")

# === CELL 3 (Reproducibility) ===
def set_seed(seed, deterministic=False):
    """
    Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs.

    Args:
        seed: integer seed handed to every generator.
        deterministic: when False (the default, matching the previous
            behaviour) cuDNN autotuning is enabled (``benchmark=True``) and
            cuDNN is allowed to pick non-deterministic kernels. That is faster,
            but GPU runs are then NOT bit-for-bit reproducible even with a
            fixed seed. Pass True to request deterministic cuDNN kernels and
            disable autotuning. This only controls cuDNN; other
            non-deterministic CUDA ops are not restricted here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    deterministic = bool(deterministic)
    torch.backends.cudnn.deterministic = deterministic
    # Autotuning benchmarks candidate kernels per input shape and may choose
    # different ones from run to run, so it is incompatible with determinism.
    torch.backends.cudnn.benchmark = not deterministic

# Seed every RNG before any model weights are created further down.
set_seed(SEED)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Kjører på: {device}")

# === CELL 4 (Data Loading) ===
# Same transforms as the old repository (and the probe setup), so results stay comparable.
# NOTE(review): the std triple below appears to be the widely copied CIFAR-10
# value rather than the per-pixel dataset std (~0.247, 0.243, 0.261). Confirm
# before changing it; it is kept as-is for comparability with earlier runs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

eval_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Datasets (downloaded into DATA_DIR if not already present)
train_ds_aug = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_tf)
test_ds = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=eval_tf)

# Pinned host memory only speeds up host->GPU copies. Without CUDA the
# DataLoader warns and ignores it, so request it only when a GPU exists.
pin_memory = torch.cuda.is_available()

# Loaders
train_loader = DataLoader(train_ds_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin_memory)

print(f"Data lastet. Train size: {len(train_ds_aug)}, Test size: {len(test_ds)}")

# === CELL 5 (Model Definition) ===
def make_cifar_resnet18(num_classes=10):
    """
    Build an untrained torchvision ResNet-18 adapted to 32x32 CIFAR inputs.

    The ImageNet stem (7x7 stride-2 conv followed by max-pooling) is replaced
    by a 3x3 stride-1 conv with no pooling, and the classifier head is resized
    to ``num_classes`` outputs. This is the same architecture used by the probe
    and the earlier experiments.
    """
    net = resnet18(weights=None)

    # CIFAR-friendly stem: keep full spatial resolution at the first stage.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    # New linear head with the requested number of classes.
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

model = make_cifar_resnet18(NUM_CLASSES).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Loss scaling only applies to fp16 autocast on CUDA. Without CUDA the scaler
# would warn and disable itself anyway, so make that explicit up front.
# NOTE(review): recent PyTorch deprecates torch.cuda.amp.GradScaler in favour of
# torch.amp.GradScaler("cuda", ...); switch once the minimum version allows it.
_scaler_enabled = USE_AMP and device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=_scaler_enabled)
criterion = nn.CrossEntropyLoss()

# === CELL 6 (Evaluation Utilities) ===
@torch.no_grad()
def evaluate(model, loader, device=None):
    """
    Evaluate a classifier over every batch yielded by ``loader``.

    Args:
        model: module mapping an input batch to per-class scores (logits),
            one row per sample. It is left in eval mode afterwards.
        loader: iterable of ``(inputs, integer class targets)`` batches.
        device: where batches are moved. Defaults to the device of the model's
            first parameter.

    Returns:
        dict with
        - "acc1": top-1 accuracy,
        - "acc2": top-2 accuracy (true class among the two highest logits),
        - "mean_rank": mean 1-based rank of the true class when classes are
          sorted by descending logit (1.0 means always ranked first).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    total = 0
    correct_top1 = 0
    correct_top2 = 0
    ranks_sum = 0.0

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)

        # One descending sort of the logits drives all three metrics, so they
        # always agree. Ranking softmax outputs instead can reorder classes
        # whose probabilities underflow to the same value.
        sorted_indices = logits.argsort(dim=1, descending=True)
        # Exactly one True per row (the position of the true class). nonzero()
        # returns matches in row-major order, so column i lines up with sample i.
        hits = sorted_indices == y.view(-1, 1)
        ranks = hits.nonzero()[:, 1] + 1  # 1-based rank of the true class

        correct_top1 += (ranks == 1).sum().item()
        correct_top2 += (ranks <= 2).sum().item()
        ranks_sum += ranks.float().sum().item()

        total += x.size(0)

    if total == 0:
        raise ValueError("evaluate() received an empty loader; no samples to score")

    return {
        "acc1": correct_top1 / total,
        "acc2": correct_top2 / total,
        "mean_rank": ranks_sum / total
    }

# === CELL 7 (Training Loop) ===
def train_one_epoch(model, loader, optimizer, scaler, *, loss_fn=None, use_amp=None):
    """
    Train ``model`` for one full pass over ``loader``.

    Args:
        model: network being trained. Batches are moved to the device of its
            first parameter, and the model is left in train mode.
        loader: iterable of ``(inputs, integer class targets)`` batches.
        optimizer: optimizer stepped once per batch (through ``scaler``).
        scaler: object exposing ``scale``/``step``/``update``, e.g. a
            ``GradScaler`` (a disabled one acts as a pass-through).
        loss_fn: callable ``loss_fn(logits, targets)``. Defaults to the
            module-level ``criterion``.
        use_amp: request mixed-precision autocast. Defaults to the module-level
            ``USE_AMP``. Autocast is only actually enabled when the model is on
            CUDA.

    Returns:
        ``(mean per-sample loss, accuracy)`` over the epoch, measured on the
        fly, i.e. with the weights as they were at each batch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion
    if use_amp is None:
        use_amp = USE_AMP
    run_device = next(model.parameters()).device
    device_type = "cuda" if run_device.type == "cuda" else "cpu"
    # The fp16 autocast + GradScaler recipe is CUDA-only. On CPU, autocast
    # would silently run the model in bfloat16 while the CUDA GradScaler is
    # disabled, so CPU runs stay in full precision.
    autocast_on = bool(use_amp) and device_type == "cuda"

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x, y = x.to(run_device, non_blocking=True), y.to(run_device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device_type, enabled=autocast_on):
            logits = model(x)
            loss = loss_fn(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = x.size(0)
        # Assumes loss_fn returns the batch-mean loss (true for the default
        # CrossEntropyLoss), so re-weight by batch size for a per-sample mean.
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch() received an empty loader; nothing to train on")

    return total_loss / total, correct / total

# === CELL 8 (Run Experiment) ===
print("Starter Eksperiment 1: Baseline (Cross-Entropy)")
print(f"Total Epoker: {MAIN_EPOCHS}")

# NOTE(review): the "best" checkpoint is selected on the test split, so best_acc
# is an optimistic estimate. Quote the final-epoch accuracy (or select on a
# held-out validation split) when reporting generalisation numbers.
best_acc = 0.0
best_epoch = 0
history = {"train_loss": [], "train_acc": [], "test_acc1": [], "test_acc2": [], "mean_rank": []}
start_time = time.time()

for epoch in range(1, MAIN_EPOCHS + 1):
    t_loss, t_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    val_metrics = evaluate(model, test_loader)

    # Record per-epoch history
    history["train_loss"].append(t_loss)
    history["train_acc"].append(t_acc)
    history["test_acc1"].append(val_metrics["acc1"])
    history["test_acc2"].append(val_metrics["acc2"])
    history["mean_rank"].append(val_metrics["mean_rank"])

    # Keep the best model so far (by test top-1 accuracy)
    if val_metrics["acc1"] > best_acc:
        best_acc = val_metrics["acc1"]
        best_epoch = epoch
        torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_best.pth")

    # Snapshot at epoch 50 (for comparison with shorter runs)
    if epoch == 50:
        torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_epoch50.pth")
        print("   -> Sjekkpunkt lagret ved epoke 50.")

    # Logging: first epoch and every 10th epoch
    if epoch % 10 == 0 or epoch == 1:
        dt = time.time() - start_time
        print(f"Ep {epoch:03d} | Loss: {t_loss:.4f} | TrAcc: {t_acc:.3f} | "
              f"TeAcc1: {val_metrics['acc1']:.3f} | Rank: {val_metrics['mean_rank']:.2f} | T: {dt:.0f}s")

total_time = time.time() - start_time
print(f"\nFerdig! Beste Test Acc: {best_acc:.4f}. Tid: {total_time/60:.1f} min.")

# Also persist the fully trained weights. The loop only saves the best-by-test
# and epoch-50 states, so without this the last-epoch model would be lost.
final_ckpt = CKPT_DIR / f"{EXPERIMENT_NAME}_final.pth"
torch.save(model.state_dict(), final_ckpt)
print(f"Beste epoke: {best_epoch}. Siste modell lagret til: {final_ckpt}")

# === CELL 9 (Save Results & Artifacts) ===
# 1. Save the training history together with the full run configuration
results_file = ARTIFACTS_DIR / "results.json"
with open(results_file, "w", encoding="utf-8") as f:
    json.dump({
        "config": {
            "epochs": MAIN_EPOCHS,
            "lr": LR,
            "batch_size": BATCH_SIZE,
            "seed": SEED,
            # Previously missing, but needed to reproduce the run:
            "weight_decay": WEIGHT_DECAY,
            "optimizer": type(optimizer).__name__,
            "use_amp": USE_AMP,
        },
        "best_acc": best_acc,
        # Last-epoch test accuracy, for comparison with the test-selected best_acc.
        "final_acc": history["test_acc1"][-1] if history["test_acc1"] else None,
        "history": history
    }, f, indent=2)
print(f"Resultater lagret til {results_file}")

# 2. Plot the target-distribution heatmap.
# This is the cross-entropy baseline, so the training targets are one-hot:
# 'airplane' is 100% 'airplane' and 0% everything else. The target
# distribution is therefore the identity matrix.

print("Genererer Heatmap for Target Distribution...")
target_matrix = torch.eye(NUM_CLASSES).numpy()

plt.figure(figsize=(12, 10))
sns.heatmap(
    target_matrix,
    xticklabels=CIFAR10_CLASSES,
    yticklabels=CIFAR10_CLASSES,
    annot=True,
    fmt=".1f",  # shows 1.0 or 0.0
    cmap="Blues",
    cbar_kws={'label': 'Target Probability (One-Hot)'}
)
plt.title("Baseline Targets (One-Hot / Hard Labels)")
plt.xlabel("Target Class")
plt.ylabel("Source Class")
plt.tight_layout()

heatmap_path = ARTIFACTS_DIR / "target_distribution_heatmap.png"
plt.savefig(heatmap_path)
print(f"Heatmap lagret til: {heatmap_path}")
plt.show()