# In [None]:
# === CELL 1 (Complete Experiment Code) ===
import os
import json
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. KONFIGURASJON & OPPSETT
# ==========================================
EXPERIMENT_NAME = "03_exp2_sbls"
DATA_DIR = Path("./data")
BASE_ARTIFACTS_DIR = Path("./artifacts") / EXPERIMENT_NAME
CKPT_DIR = Path("./checkpoints")

# Make sure both output directories exist before anything is written to them.
for _out_dir in (BASE_ARTIFACTS_DIR, CKPT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Training hyperparameters.
RUNS = [50, 300]  # epoch budgets: one short (50) and one long (300) training run
BATCH_SIZE = 128
LR = 1e-3
WEIGHT_DECAY = 1e-4
SEED = 42
NUM_WORKERS = 2
USE_AMP = True

# SBLS-specific parameters.
ALPHA = 0.2  # smoothing mass: how much probability is taken away from the true class
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
NUM_CLASSES = len(CIFAR10_CLASSES)

print(f"Konfigurasjon satt for {EXPERIMENT_NAME}.")
print(f"Artefakter lagres til: {BASE_ARTIFACTS_DIR}")

# ==========================================
# 2. REPRODUSERBARHET
# ==========================================
def set_seed(seed, deterministic=False):
    """Seed every random number generator the experiment draws from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then configures cuDNN.

    Args:
        seed: Integer seed applied to all three libraries.
        deterministic: When False (the default, identical to the previous
            behaviour) cuDNN autotuning is enabled and cuDNN is not restricted
            to deterministic algorithms; this favours speed, and GPU results
            may differ slightly between runs. When True, autotuning is
            disabled and deterministic cuDNN algorithms are requested, so
            repeated GPU runs with the same seed are comparable.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Kept explicit for clarity; this call is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN flags are opposites: benchmark mode autotunes kernels
    # per input shape, which is what makes algorithm choice non-deterministic.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic

# Seed all RNGs once, then pick the GPU when one is available.
set_seed(SEED)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Kjører på: {device}")

# ==========================================
# 3. DATA LOADING
# ==========================================
# Per-channel normalisation shared by the training and evaluation pipelines.
_normalize = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random 32x32 crop (4-pixel padding) and random horizontal
# flip, followed by tensor conversion and normalisation.
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    _normalize,
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
eval_tf = T.Compose([T.ToTensor(), _normalize])

train_ds_aug = torchvision.datasets.CIFAR10(
    root=DATA_DIR, train=True, download=True, transform=train_tf
)
test_ds = torchvision.datasets.CIFAR10(
    root=DATA_DIR, train=False, download=True, transform=eval_tf
)

# Settings common to both loaders; only shuffling differs between them.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds_aug, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

# ==========================================
# 4. MODELL DEFINISJON
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build a randomly initialised ResNet-18 with a modified stem and head.

    The first convolution is replaced by a 3x3, stride-1, padding-1
    convolution without bias, the max-pool layer is replaced by an identity,
    and the final fully connected layer is resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits of the classifier head.

    Returns:
        The modified ``resnet18`` module (no pretrained weights are loaded).
    """
    net = resnet18(weights=None)
    feature_dim = net.fc.in_features

    # Layers are created in the same order as before (conv, then linear) so
    # the random initialisation consumes the RNG identically.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

# ==========================================
# 5. SBLS LOGIKK (Likhetsmatrise & Targets)
# ==========================================
def build_similarity_matrix_cifar10():
    """Build a hand-crafted class-similarity matrix from CIFAR-10 superclasses.

    Two distinct classes count as similar when they belong to the same
    superclass:

    * animals: bird, cat, deer, dog, frog, horse (indices 2, 3, 4, 5, 6, 7)
    * vehicles: airplane, automobile, ship, truck (indices 0, 1, 8, 9)

    The diagonal is zero (self-similarity is handled through ``alpha`` when
    the targets are generated). Each row is divided by its sum, so the
    weights of a class's neighbours add up to 1.

    Returns:
        A (C, C) float tensor, where C is ``len(CIFAR10_CLASSES)``.
    """
    num_classes = len(CIFAR10_CLASSES)
    superclasses = (
        {2, 3, 4, 5, 6, 7},  # animals
        {0, 1, 8, 9},        # vehicles
    )

    # related[i, j] is True when i and j share a superclass.
    related = torch.zeros(num_classes, num_classes, dtype=torch.bool)
    for members in superclasses:
        in_group = torch.tensor(
            [idx in members for idx in range(num_classes)], dtype=torch.bool
        )
        related |= in_group.unsqueeze(1) & in_group.unsqueeze(0)
    related.fill_diagonal_(False)  # a class is never its own neighbour

    sim = torch.zeros(num_classes, num_classes)
    sim[related] = 1.0

    # Clamp guards against division by zero for a row without neighbours.
    neighbour_mass = sim.sum(dim=1, keepdim=True).clamp_min(1e-8)
    return sim / neighbour_mass

# Build the similarity matrix once and move it to the active device
S_MAT = build_similarity_matrix_cifar10().to(device)

def make_soft_targets(y, S, alpha=0.2, num_classes=10):
    """Build similarity-based soft targets for a batch of labels.

    Every row starts as the one-hot vector of its label. A probability mass
    of ``alpha`` is then taken from the true class and redistributed over the
    other classes according to row ``S[label]``. Labels whose row in ``S``
    sums to (almost) zero have no neighbours; for those the mass is spread
    uniformly over all classes except the label itself.

    Args:
        y: (B,) tensor of integer class labels.
        S: (C, C) similarity matrix; row ``i`` holds the neighbour weights of
            class ``i``. It is moved to ``y``'s device if necessary.
        alpha: Mass moved from the true class to its neighbours. Values
            ``<= 0`` disable smoothing and return plain one-hot targets.
        num_classes: Number of columns of the returned target matrix.

    Returns:
        A (B, num_classes) float tensor on ``y``'s device.

    Raises:
        ValueError: If ``alpha`` is greater than 1, which would give the true
            class a negative probability.
    """
    if alpha > 1:
        raise ValueError(f"alpha must not exceed 1, got {alpha}")

    label_col = y.reshape(-1, 1)

    # Start from standard one-hot targets.
    targets = torch.zeros(y.shape[0], num_classes, device=y.device)
    targets.scatter_(1, label_col, 1.0)

    if alpha <= 0:
        return targets

    # Reduce the probability of the true class.
    targets = targets * (1.0 - alpha)

    # S[y] selects the row belonging to each label in the batch -> (B, C).
    neighbor_dist = S.to(y.device)[y]

    # Rows without neighbours (row sum ~ 0) fall back to a uniform
    # distribution over every class except the label itself. The selection is
    # done with an unconditional torch.where so that no device-to-host
    # synchronisation is needed to decide whether a fallback row exists.
    has_no_neighbors = neighbor_dist.sum(dim=1, keepdim=True) < 1e-7
    uniform = torch.ones_like(neighbor_dist) / (num_classes - 1)
    uniform.scatter_(1, label_col, 0.0)
    neighbor_dist = torch.where(has_no_neighbors, uniform, neighbor_dist)

    # Add the alpha mass, distributed over the neighbours.
    return targets + alpha * neighbor_dist

def soft_ce_loss(logits, soft_targets):
    """Return the batch-mean cross entropy against probability targets.

    Args:
        logits: (B, C) unnormalised class scores.
        soft_targets: (B, C) target probabilities for each sample.

    Returns:
        Scalar tensor: the negative target-weighted log-probability, summed
        over classes and averaged over the batch.
    """
    weighted_log_probs = soft_targets * F.log_softmax(logits, dim=1)
    batch_mean = weighted_log_probs.sum(dim=1).mean()
    return -batch_mean

# ==========================================
# 6. VISUALISERING AV TARGETS
# ==========================================
# Vi visualiserer hvordan target-distribusjonen ser ut for hver klasse
# Dette tilsvarer "Target Distribution Heatmap" kravet.
print("Genererer Heatmap for SBLS Targets...")

# One label per class, so row i of the result is the target distribution
# used for class i.
dummy_y = torch.arange(NUM_CLASSES, device=device)
dummy_targets = make_soft_targets(dummy_y, S_MAT, alpha=ALPHA, num_classes=NUM_CLASSES)
heatmap_path = BASE_ARTIFACTS_DIR / "target_distribution_heatmap.png"

# Annotated heatmap: rows are source classes, columns are target classes.
_heatmap_style = dict(
    annot=True,
    fmt=".2f",
    cmap="viridis",
    cbar_kws={'label': f'Target Probability (Alpha={ALPHA})'},
)
plt.figure(figsize=(12, 10))
sns.heatmap(
    dummy_targets.cpu().numpy(),
    xticklabels=CIFAR10_CLASSES,
    yticklabels=CIFAR10_CLASSES,
    **_heatmap_style,
)
plt.title("Experiment 2: SBLS Target Distribution (Manual Groups)")
plt.xlabel("Target Class")
plt.ylabel("Source Class")
plt.tight_layout()

plt.savefig(heatmap_path)
print(f"Heatmap lagret til: {heatmap_path}")
plt.show()

# Store the target matrix itself alongside the figure.
torch.save(dummy_targets.cpu(), BASE_ARTIFACTS_DIR / "sbls_target_matrix.pt")

# ==========================================
# 7. TRENINGS- OG EVALUERINGSFUNKSJONER
# ==========================================
def train_one_epoch(model, loader, optimizer, scaler):
    """Run one pass over ``loader`` and update ``model`` with the SBLS loss.

    Args:
        model: Network to train; it is put into training mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimiser stepped once per batch through ``scaler``.
        scaler: Gradient scaler used for the backward pass and the step.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sample-weighted mean loss and
        the fraction of samples whose arg-max logit equals the hard label.
    """
    model.train()
    autocast_device = "cuda" if device.type == "cuda" else "cpu"
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        batch_size = images.size(0)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=autocast_device, enabled=USE_AMP):
            logits = model(images)
            # Targets are generated on the fly from the similarity matrix.
            soft_targets = make_soft_targets(labels, S_MAT, alpha=ALPHA, num_classes=NUM_CLASSES)
            loss = soft_ce_loss(logits, soft_targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The loss is a batch mean, so weight it by the batch size.
        loss_sum += batch_size * loss.item()
        predictions = logits.argmax(dim=1)
        n_correct += predictions.eq(labels).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

@torch.no_grad()
def evaluate(model, loader):
    """Compute top-1 accuracy, top-2 accuracy and mean rank of the true class.

    Args:
        model: Network to evaluate; it is put into evaluation mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Dict with ``"acc1"`` and ``"acc2"`` (fraction of samples whose label
        is the best / among the two best predictions) and ``"mean_rank"``
        (average 1-based position of the label when classes are sorted by
        descending logit).
    """
    model.eval()
    n_top1 = 0
    n_top2 = 0
    rank_total = 0.0
    n_seen = 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)

        # Top-k hits, taken from the two most probable classes.
        best_two = F.softmax(logits, dim=1).topk(2, dim=1).indices
        first_hit = best_two[:, 0] == labels
        second_hit = best_two[:, 1] == labels
        n_top1 += first_hit.sum().item()
        n_top2 += (first_hit | second_hit).sum().item()

        # Mean rank: column where the label appears in the sorted class order.
        class_order = logits.argsort(dim=1, descending=True)
        _, hit_cols = torch.nonzero(class_order == labels.unsqueeze(1), as_tuple=True)
        rank_total += (hit_cols + 1).float().sum().item()

        n_seen += images.size(0)

    return {
        "acc1": n_top1 / n_seen,
        "acc2": n_top2 / n_seen,
        "mean_rank": rank_total / n_seen,
    }

# ==========================================
# 8. HOVEDLØKKE (RUNNER)
# ==========================================
# Train one fresh model per epoch budget in RUNS and store its artefacts.
for max_epochs in RUNS:
    print(f"\n{'='*40}")
    print(f"  STARTER KJØRING: {max_epochs} EPOKER")
    print(f"{'='*40}")

    # Re-seed at the start of every run so that a run's weight initialisation
    # and shuffling do not depend on which runs were executed before it.
    set_seed(SEED)

    # Per-run output directory.
    run_dir = BASE_ARTIFACTS_DIR / f"run_{max_epochs}ep"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Fresh model, optimiser and gradient scaler for this run.
    model = make_cifar_resnet18(NUM_CLASSES).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # NOTE(review): recent PyTorch releases deprecate torch.cuda.amp.GradScaler
    # in favour of torch.amp.GradScaler; kept to stay compatible with the
    # version this file targets - confirm before changing.
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    best_acc = 0.0
    history = {"train_loss": [], "train_acc": [], "test_acc1": [], "test_acc2": [], "mean_rank": []}
    start_time = time.time()

    for epoch in range(1, max_epochs + 1):
        t_loss, t_acc = train_one_epoch(model, train_loader, optimizer, scaler)
        metrics = evaluate(model, test_loader)

        history["train_loss"].append(t_loss)
        history["train_acc"].append(t_acc)
        history["test_acc1"].append(metrics["acc1"])
        history["test_acc2"].append(metrics["acc2"])
        history["mean_rank"].append(metrics["mean_rank"])

        # Keep the weights of the epoch with the best top-1 accuracy on the
        # test loader.
        if metrics["acc1"] > best_acc:
            best_acc = metrics["acc1"]
            torch.save(model.state_dict(), CKPT_DIR / f"{EXPERIMENT_NAME}_{max_epochs}ep_best.pth")

        # Progress line on the first epoch and every tenth epoch.
        if epoch % 10 == 0 or epoch == 1:
            elapsed = time.time() - start_time
            print(f"Ep {epoch:03d} | Loss: {t_loss:.4f} | TrAcc: {t_acc:.3f} | "
                  f"TeAcc1: {metrics['acc1']:.3f} | Rank: {metrics['mean_rank']:.2f} | T: {elapsed:.0f}s")

    # Store the results of this run; the encoding is explicit so the file
    # does not depend on the platform default.
    res_file = run_dir / "results.json"
    with open(res_file, "w", encoding="utf-8") as f:
        json.dump({
            "config": {
                "epochs": max_epochs,
                "alpha": ALPHA,
                "lr": LR
            },
            "best_acc": best_acc,
            "history": history
        }, f, indent=2)

    # Save the training-loss curve and close the figure so figures do not
    # accumulate across runs.
    plt.figure(figsize=(10, 5))
    plt.plot(history["train_loss"], label="Train Loss")
    plt.title(f"Training Loss ({max_epochs} epochs) - SBLS")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(run_dir / "loss_curve.png")
    plt.close()

    print(f"Ferdig med {max_epochs} epoker. Beste Acc: {best_acc:.4f}. Data lagret i {run_dir}")

print(f"\nAlle kjøringer fullført for {EXPERIMENT_NAME}!")