# In [None]:
# %% 00_core.ipynb - Unified Framework for CIFAR-10 Experiments
import os
import time
import json
import random
import copy
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
import numpy as np

# ==========================================
# 1. GLOBAL CONFIGURATION & UTILS
# ==========================================
# Default RNG seed shared by every experiment.
SEED = 42
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class names listed by label index (0..9).
CIFAR10_CLASSES = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
NUM_CLASSES = 10

def set_seed(seed=SEED, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (CPU generator and all CUDA devices).
        deterministic: When False (default, identical to the previous
            hard-coded behaviour) cuDNN autotuning is enabled and
            deterministic kernels are not requested, which favours speed;
            per the PyTorch reproducibility notes, GPU runs may then differ
            between executions even with a fixed seed. When True, cuDNN
            deterministic mode is requested and autotuning is disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    deterministic = bool(deterministic)
    torch.backends.cudnn.deterministic = deterministic
    # Autotuning may pick different kernels from run to run, so it is only
    # left on when determinism was not requested.
    torch.backends.cudnn.benchmark = not deterministic

# Seed once at import time so that model initialisation is repeatable.
set_seed()

# ==========================================
# 2. MODEL (ResNet18 for CIFAR-10)
# ==========================================
def make_cifar_resnet18(num_classes=10):
    """Build an untrained ResNet18 adapted to CIFAR-10 sized inputs.

    The stem convolution is replaced by a smaller 3x3, stride-1 convolution,
    the max-pool layer is replaced by an identity, and the final fully
    connected layer is replaced by one with ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits.

    Returns:
        The modified torchvision ResNet18 module (no pretrained weights).
    """
    net = resnet18(weights=None)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net

# ==========================================
# 3. SPECIALIZED LOGIC (SBLS, DYNAMIC ETC.)
# ==========================================

# --- For Exp 2: SBLS ---
def build_similarity_matrix_cifar10():
    """Return the hand-crafted, row-normalised class similarity matrix for SBLS.

    Two distinct classes are considered similar when both belong to the same
    group (animals or vehicles). Each row is divided by its sum, so a class
    spreads its mass uniformly over the other members of its group; the
    diagonal is always zero.

    Returns:
        A ``(C, C)`` float tensor with ``C = len(CIFAR10_CLASSES)``.
    """
    n = len(CIFAR10_CLASSES)
    groups = (
        (2, 3, 4, 5, 6, 7),  # animals: bird, cat, deer, dog, frog, horse
        (0, 1, 8, 9),        # vehicles: airplane, automobile, ship, truck
    )
    sim = torch.zeros(n, n)
    for members in groups:
        present = [k for k in members if k < n]
        for a in present:
            for b in present:
                if a != b:
                    sim[a, b] = 1.0
    # clamp_min keeps rows without any neighbour at zero instead of dividing by 0.
    return sim / sim.sum(dim=1, keepdim=True).clamp_min(1e-8)

def make_soft_targets_sbls(y, S, alpha=0.2, num_classes=10):
    """Build soft targets on the fly from the similarity matrix ``S``.

    Each target row keeps ``1 - alpha`` on the true class and distributes
    ``alpha`` according to row ``S[y]``. Rows of ``S`` that sum to (almost)
    zero fall back to uniform smoothing over all other classes.

    Args:
        y: 1-D integer tensor of class labels.
        S: Similarity matrix indexed by label; rows are used as given.
        alpha: Smoothing mass moved away from the true class; ``alpha <= 0``
            returns plain one-hot targets.
        num_classes: Width of the produced target rows.

    Returns:
        A ``(len(y), num_classes)`` tensor on ``y``'s device.
    """
    label_col = y.view(-1, 1)
    one_hot = torch.zeros(y.shape[0], num_classes, device=y.device)
    one_hot.scatter_(1, label_col, 1.0)
    if alpha <= 0:
        return one_hot

    neighbours = S[y]
    lacks_neighbours = neighbours.sum(dim=1, keepdim=True) < 1e-7
    if lacks_neighbours.any():
        # No similar class: spread the mass evenly over every other class.
        spread = torch.ones_like(neighbours) / (num_classes - 1)
        spread.scatter_(1, label_col, 0.0)
        neighbours = torch.where(lacks_neighbours, spread, neighbours)

    return one_hot * (1.0 - alpha) + alpha * neighbours

# --- For Exp 4/5: Dynamic Target Update ---
@torch.no_grad()
def update_target_logits_dynamic(T, acc, T0=None, beta=10.0, lr=0.10, kappa=0.10,
                                 min_diff=0.05, anchor_gamma=0.0, clamp_M=10.0):
    """Update the target matrix ``T`` from per-class accuracy differences.

    For every ordered pair (novice ``i``, expert ``j``) with
    ``acc[j] - acc[i] > min_diff`` the entry ``T[i, j]`` is pulled towards
    ``T[j, i]`` by ``lr * sigmoid(beta * diff)``; when ``kappa > 0`` the
    expert entry ``T[j, i]`` is pushed back by ``kappa`` times that change.
    All reads use the incoming ``T``, so the result does not depend on the
    order in which pairs are visited; the whole update is therefore done with
    tensor operations instead of a Python double loop.

    Args:
        T: Square ``(C, C)`` floating-point target matrix. Not modified.
        acc: Per-class accuracies (tensor or sequence); the first ``C``
            entries are used.
        T0: Optional anchor matrix blended in when ``anchor_gamma > 0``.
        beta: Sharpness of the authority sigmoid.
        lr: Step size of the pull towards the expert view.
        kappa: Relative strength of the push-back on the expert entry.
        min_diff: Minimum accuracy gap required for an update.
        anchor_gamma: Blend factor towards ``T0``.
        clamp_M: If positive, the result is clamped to ``[-clamp_M, clamp_M]``.

    Returns:
        ``(T_new, updates_made)``: a new tensor and the number of
        (novice, expert) pairs that triggered an update, as a Python int.
    """
    C = T.shape[0]
    # as_tensor accepts both tensors and plain sequences and, unlike
    # torch.tensor(<tensor>), does not emit a copy-construct warning.
    acc = torch.as_tensor(acc, dtype=T.dtype, device=T.device)[:C]

    # gap[i, j] = acc[j] - acc[i]  (row = novice, column = expert)
    gap = acc.unsqueeze(0) - acc.unsqueeze(1)
    off_diagonal = ~torch.eye(C, dtype=torch.bool, device=T.device)
    teaches = (gap > min_diff) & off_diagonal

    # Authority of the expert over the novice.
    authority = torch.sigmoid(beta * gap)

    # Pull the novice entry T[i, j] towards the expert entry T[j, i].
    pull = lr * authority * (T.t() - T)
    pull = torch.where(teaches, pull, torch.zeros_like(pull))

    T_new = T + pull
    # Push-back (repulsion) on the expert entry T[j, i].
    if kappa > 0:
        T_new = T_new - kappa * pull.t()

    updates_made = int(teaches.sum())

    # Anchor towards the initial targets (T0).
    if T0 is not None and anchor_gamma > 0:
        T_new = (1.0 - anchor_gamma) * T_new + anchor_gamma * T0

    # Clamp for stability.
    if clamp_M > 0:
        T_new = T_new.clamp(min=-clamp_M, max=clamp_M)

    return T_new, updates_made

@torch.no_grad()
def enforce_target_dominance(logits_matrix):
    """Exp 5b: make the diagonal entry the largest value of every row.

    For each row whose maximum is not on the diagonal, the diagonal value and
    the row maximum are swapped. The matrix is modified in place.

    Args:
        logits_matrix: ``(C, C)`` tensor; row ``i`` belongs to class ``i``.

    Returns:
        ``(logits_matrix, swaps_count)``: the same tensor object and the
        number of rows in which a swap was performed.
    """
    swapped_rows = 0
    for cls in range(logits_matrix.shape[0]):
        top_val, top_idx = logits_matrix[cls].max(dim=0)
        if top_idx == cls:
            continue
        # Copy before overwriting: indexing returns a view into the matrix.
        own_val = logits_matrix[cls, cls].clone()
        logits_matrix[cls, cls] = top_val
        logits_matrix[cls, top_idx] = own_val
        swapped_rows += 1
    return logits_matrix, swapped_rows

@torch.no_grad()
def apply_probability_cap(logits_matrix, max_p=0.80):
    """Exp 5: cap the softmax probability of the diagonal and redistribute the rest.

    For every row whose diagonal softmax probability exceeds ``max_p`` the
    diagonal is set to ``max_p`` and the remaining probabilities are rescaled
    (or, if they carry no mass, set uniformly) so that they sum to
    ``1 - max_p``. The row is then converted back to logits and shifted so
    that its maximum equals the row's original maximum.

    Rows that do not exceed the cap are returned unchanged. Previously every
    row went through a softmax -> ``log(p + 1e-9)`` round trip, which altered
    low logits of untouched rows on each call.

    Args:
        logits_matrix: Square ``(C, C)`` floating-point logit matrix. Not modified.
        max_p: Maximum allowed softmax probability on the diagonal.

    Returns:
        A new tensor with the same shape as ``logits_matrix``.
    """
    C = logits_matrix.shape[0]
    probs = torch.softmax(logits_matrix, dim=1)
    # Smallest positive normal number: guards log(0) without distorting
    # small but non-zero probabilities the way an additive epsilon does.
    floor = torch.finfo(probs.dtype).tiny
    final_logits = logits_matrix.clone()

    for i in range(C):
        if not (probs[i, i] > max_p):
            continue

        row = probs[i].clone()
        others_mask = torch.ones(C, dtype=torch.bool, device=logits_matrix.device)
        others_mask[i] = False

        row[i] = max_p
        others_sum = row[others_mask].sum()
        if others_sum > 1e-9:
            row[others_mask] *= (1.0 - max_p) / others_sum
        elif C > 1:
            # The other classes carry no mass: distribute the remainder evenly.
            row[others_mask] = (1.0 - max_p) / (C - 1)

        new_row = torch.log(row.clamp_min(floor))
        # Preserve the magnitude range for stability.
        final_logits[i] = new_row - new_row.max() + logits_matrix[i].max()

    return final_logits

@torch.no_grad()
def weights_from_acc(acc, gamma=1.0, eps=1e-6):
    """Exp 3b/4/5: derive class weights from accuracy (higher acc = higher weight).

    Args:
        acc: Tensor of per-class accuracies; negative values are treated as 0.
        gamma: Exponent applied to the shifted accuracies.
        eps: Small offset added before exponentiation.

    Returns:
        Tensor of weights normalised so that their mean is 1.
    """
    raw = torch.pow(acc.clamp_min(0.0) + eps, gamma)
    return raw / raw.mean().clamp_min(1e-8)

# ==========================================
# 4. LOSS FUNCTIONS
# ==========================================
class SoftTargetKLLoss(nn.Module):
    """Exp 3a: KL divergence against static per-class probability targets.

    Args:
        targets_probs: Tensor indexed by class label whose rows are the target
            probability vectors; stored as the buffer ``T``.
        ce_anchor: Weight of an additional cross-entropy term on the hard
            labels; the term is skipped when the weight is not positive.
    """

    def __init__(self, targets_probs, ce_anchor=0.0):
        super().__init__()
        self.ce_anchor = ce_anchor
        self.register_buffer("T", targets_probs)
        self.kl = nn.KLDivLoss(reduction="batchmean")
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, y):
        """Return the batch-mean KL loss, plus the optional CE anchor term.

        Args:
            logits: ``(B, C)`` model outputs.
            y: ``(B,)`` integer class labels used to look up target rows.
        """
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        kl_term = self.kl(F.log_softmax(logits, dim=1), self.T[y])
        if self.ce_anchor > 0:
            return kl_term + self.ce_anchor * self.ce(logits, y)
        return kl_term

class LogitTargetWMSENLoss(nn.Module):
    """Exp 3b/4/5: weighted MSE between model logits and per-class logit targets.

    Args:
        ce_anchor: Weight of an additional cross-entropy term (used if > 0).
        logit_l2: Weight of a mean-squared-logit penalty (used if > 0).
        true_class_boost: Factor applied to the weight of each sample's true
            class column; 1.0 leaves the weights untouched.
    """

    def __init__(self, ce_anchor=0.0, logit_l2=0.0, true_class_boost=1.0):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.ce_anchor = ce_anchor
        self.logit_l2 = logit_l2
        self.true_class_boost = true_class_boost

    def forward(self, logits, y, current_targets, current_weights):
        """Compute the loss for one batch.

        Args:
            logits: ``(B, C)`` model outputs.
            y: ``(B,)`` integer class labels.
            current_targets: ``(C, C)`` target logits, row-indexed by label.
            current_weights: ``(C,)`` per-output-class weights. Not modified.

        Returns:
            Scalar loss tensor.
        """
        batch_size, _ = logits.shape
        squared_error = (logits - current_targets[y]).pow(2)  # (B, C)

        # Materialise one weight row per sample so rows can be edited independently.
        per_elem_w = current_weights.view(1, -1).repeat(batch_size, 1)
        if self.true_class_boost != 1.0:
            sample_idx = torch.arange(batch_size, device=logits.device)
            per_elem_w[sample_idx, y] *= self.true_class_boost

        total = (squared_error * per_elem_w).mean()
        if self.ce_anchor > 0:
            total = total + self.ce_anchor * self.ce(logits, y)
        if self.logit_l2 > 0:
            total = total + self.logit_l2 * logits.pow(2).mean()
        return total

# ==========================================
# 5. UNIFIED TRAINER CLASS
# ==========================================
class UnifiedTrainer:
    def __init__(self, config, run_name):
        """Create output folders, data loaders, model, optimizer and mode state.

        Args:
            config: Mapping with the run settings. This method reads ``'lr'``
                and ``'use_amp'``; further keys are read by the methods it calls.
            run_name: Name used for the artifact folder of this run.
        """
        self.cfg = config
        self.run_name = run_name

        # Output locations (created if missing).
        self.artifact_dir = Path(f"./artifacts/{run_name}")
        self.ckpt_dir = Path("./checkpoints")
        for out_dir in (self.artifact_dir, self.ckpt_dir):
            out_dir.mkdir(parents=True, exist_ok=True)

        # Data loaders.
        self.setup_data()

        # Model.
        self.model = make_cifar_resnet18(NUM_CLASSES).to(DEVICE)

        # Optimizer and AMP gradient scaler.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.cfg['lr'],
            weight_decay=1e-4,
        )
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg['use_amp'])

        # Plain cross-entropy, used by the baseline mode.
        self.ce_loss = nn.CrossEntropyLoss()

        # Experiment-specific state, filled in by _init_experiment_logic().
        self.sbls_matrix = None
        self.targets_matrix = None    # probabilities (3a) or logits (3b/4/5)
        self.targets_matrix_0 = None  # anchor for the dynamic update
        self.weights = torch.ones(NUM_CLASSES, device=DEVICE)  # for WMSE

        # Per-epoch metrics.
        self.history = {
            key: [] for key in ("loss", "train_acc", "test_acc", "epoch_times")
        }

        self._init_experiment_logic()

    def setup_data(self):
        """Create the CIFAR-10 datasets (downloaded to ./data) and their loaders.

        Sets ``train_loader`` (augmented, shuffled), ``train_eval_loader``
        (training split without augmentation, unshuffled) and ``test_loader``.
        All loaders use ``cfg['batch_size']``.
        """
        normalize = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_tf = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
        eval_tf = T.Compose([T.ToTensor(), normalize])

        def cifar10(train, transform):
            return torchvision.datasets.CIFAR10(
                root="./data", train=train, download=True, transform=transform
            )

        train_ds_aug = cifar10(True, train_tf)
        train_ds_eval = cifar10(True, eval_tf)
        test_ds = cifar10(False, eval_tf)

        shared = dict(batch_size=self.cfg['batch_size'], num_workers=2, pin_memory=True)
        self.train_loader = DataLoader(train_ds_aug, shuffle=True, **shared)
        self.train_eval_loader = DataLoader(train_ds_eval, shuffle=False, **shared)
        self.test_loader = DataLoader(test_ds, shuffle=False, **shared)

    def _init_experiment_logic(self):
        """Load targets or build matrices depending on ``cfg['mode']``.

        * ``'sbls'``: builds the similarity matrix.
        * ``'kl'`` (Exp 3a): loads ``'class_avg_probs'`` from
          ``cfg['targets_path']`` and creates the KL criterion.
        * ``'wmse_static'`` / ``'wmse_dynamic'`` (Exp 3b, 4, 5): loads
          ``'class_avg_logits'``, optionally applies swap/cap, stores a copy
          as anchor and creates the weighted-MSE criterion.
        * Any other mode: nothing is initialised here.

        Raises:
            ValueError: If a target-based mode is selected without
                ``'targets_path'`` in the config.
        """
        mode = self.cfg['mode']

        if mode == 'sbls':
            self.sbls_matrix = build_similarity_matrix_cifar10().to(DEVICE)
            return

        if mode == 'kl':  # Exp 3a
            if 'targets_path' not in self.cfg:
                raise ValueError("Mode 'kl' requires 'targets_path'")
            payload = torch.load(self.cfg['targets_path'], map_location=DEVICE)
            self.targets_matrix = payload['class_avg_probs']
            self.criterion_kl = SoftTargetKLLoss(
                self.targets_matrix, ce_anchor=self.cfg.get('ce_anchor', 0)
            )
            return

        if mode not in ('wmse_static', 'wmse_dynamic'):
            return

        # Exp 3b, 4, 5
        if 'targets_path' not in self.cfg:
            raise ValueError("WMSE modes require 'targets_path'")
        payload = torch.load(self.cfg['targets_path'], map_location=DEVICE)
        self.targets_matrix = payload['class_avg_logits']

        # Optional swap/cap of the initial targets (Exp 5 start state).
        if self.cfg.get('swap_enabled', False):
            self.targets_matrix, _ = enforce_target_dominance(self.targets_matrix)
        if self.cfg.get('cap_enabled', False):
            self.targets_matrix = apply_probability_cap(
                self.targets_matrix, max_p=self.cfg['max_prob']
            )

        # Anchor copy used by the dynamic update.
        self.targets_matrix_0 = self.targets_matrix.clone()

        self.criterion_wmse = LogitTargetWMSENLoss(
            ce_anchor=self.cfg.get('ce_anchor', 0.05),
            logit_l2=self.cfg.get('logit_l2', 0.0),
            true_class_boost=self.cfg.get('true_class_boost', 1.0),
        )

    def train_epoch(self, epoch):
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        
        for x, y in self.train_loader:
            x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)
            
            with torch.cuda.amp.autocast(enabled=self.cfg['use_amp']):
                logits = self.model(x)
                
                # --- VELG LOSS BASERT PÅ MODE ---
                if self.cfg['mode'] == 'baseline':
                    loss = self.ce_loss(logits, y)
                    
                elif self.cfg['mode'] == 'sbls':
                    soft_T = make_soft_targets_sbls(y, self.sbls_matrix, alpha=self.cfg['sbls_alpha'])
                    log_probs = F.log_softmax(logits, dim=1)
                    loss = -(soft_T * log_probs).sum(dim=1).mean()
                    
                elif self.cfg['mode'] == 'kl':
                    loss = self.criterion_kl(logits, y)
                    
                elif self.cfg['mode'] in ['wmse_static', 'wmse_dynamic']:
                    loss = self.criterion_wmse(logits, y, self.targets_matrix, self.weights)
            
            self.scaler.scale(loss).backward()
            
            if self.cfg.get('grad_clip', 0) > 0:
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg['grad_clip'])
                
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            total_loss += loss.item() * x.size(0)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
            
        return total_loss / total, correct / total

    @torch.no_grad()
    def evaluate(self, loader):
        """Return the top-1 accuracy of the model on ``loader``.

        The model is switched to eval mode and left in that mode.

        Args:
            loader: Iterable yielding ``(inputs, labels)`` batches.
        """
        self.model.eval()
        n_correct = 0
        n_seen = 0
        for inputs, labels in loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            predictions = self.model(inputs).argmax(1)
            n_correct += predictions.eq(labels).sum().item()
            n_seen += inputs.size(0)
        return n_correct / n_seen

    @torch.no_grad()
    def get_per_class_accuracy(self):
        """Return per-class accuracy over ``self.train_eval_loader``.

        The model is switched to eval mode and left in that mode.

        Returns:
            ``(NUM_CLASSES,)`` float tensor on ``DEVICE``; classes without
            any sample get accuracy 0.
        """
        self.model.eval()
        hits = torch.zeros(NUM_CLASSES, device=DEVICE)
        seen = torch.zeros(NUM_CLASSES, device=DEVICE)
        class_ids = torch.arange(NUM_CLASSES, device=DEVICE)

        for inputs, labels in self.train_eval_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            predicted = self.model(inputs).argmax(1)
            # is_class[b, c] is True when sample b has label c.
            is_class = labels.unsqueeze(1) == class_ids
            is_right = (predicted == labels).unsqueeze(1)
            seen += is_class.sum(dim=0)
            hits += (is_class & is_right).sum(dim=0)

        # clamp_min avoids dividing by zero for classes that never occurred.
        return hits / seen.clamp_min(1)

    def run(self):
        """Run the full training loop and return the best test accuracy.

        Each epoch trains on the training loader, evaluates on the test
        loader, appends the metrics to ``self.history`` and saves the model
        weights whenever the test accuracy improves. In ``'wmse_dynamic'``
        mode the targets (from the second epoch on) and the class weights are
        refreshed before training. Results are written once after the last epoch.
        """
        print(f"Starter trening: {self.run_name} ({self.cfg['epochs']} epoker)")
        print(f"Mode: {self.cfg['mode']}")

        best_acc = 0.0
        total_epochs = self.cfg['epochs']

        for ep in range(1, total_epochs + 1):
            started = time.time()

            # --- Pre-epoch dynamics (Exp 4/5) ---
            if self.cfg['mode'] == 'wmse_dynamic':
                if ep > 1:
                    current_acc = self.get_per_class_accuracy()

                    # 1. Update the targets from the per-class accuracy.
                    self.targets_matrix, _ = update_target_logits_dynamic(
                        self.targets_matrix,
                        current_acc,
                        T0=self.targets_matrix_0,
                        beta=self.cfg.get('auth_beta', 10.0),
                        lr=self.cfg.get('target_update_lr', 0.1),
                        kappa=self.cfg.get('kappa', 0.1),
                        anchor_gamma=self.cfg.get('anchor_gamma', 0.0),
                        clamp_M=self.cfg.get('clamp_m', 10.0),
                    )

                    # 2. Swap (optional)
                    if self.cfg.get('swap_enabled', False):
                        self.targets_matrix, _ = enforce_target_dominance(self.targets_matrix)

                    # 3. Cap (optional)
                    if self.cfg.get('cap_enabled', False):
                        self.targets_matrix = apply_probability_cap(
                            self.targets_matrix, max_p=self.cfg['max_prob']
                        )
                else:
                    # No accuracy measured yet: start from a neutral 0.5 per class.
                    current_acc = torch.full((NUM_CLASSES,), 0.5, device=DEVICE)

                # Refresh the WMSE class weights.
                self.weights = weights_from_acc(
                    current_acc, gamma=self.cfg.get('weight_gamma', 1.0)
                )

            # --- Train ---
            tr_loss, tr_acc = self.train_epoch(ep)

            # --- Eval ---
            te_acc = self.evaluate(self.test_loader)
            dt = time.time() - started

            epoch_record = {
                "loss": tr_loss,
                "train_acc": tr_acc,
                "test_acc": te_acc,
                "epoch_times": dt,
            }
            for key, value in epoch_record.items():
                self.history[key].append(value)

            print(f"Ep {ep:03d} | Loss: {tr_loss:.4f} | TrAcc: {tr_acc:.3f} | TeAcc: {te_acc:.3f} | T: {dt:.1f}s")

            # Keep the weights of the best epoch so far.
            if te_acc > best_acc:
                best_acc = te_acc
                torch.save(self.model.state_dict(), self.ckpt_dir / f"{self.run_name}_best.pth")

        # Persist the final results.
        self._save_results()
        print(f"Ferdig! Best Test Acc: {best_acc:.4f}")
        return best_acc

    def _save_results(self):
        """Write the training history and, if present, the final target matrix.

        ``history.json`` and ``final_targets.pt`` are written into
        ``self.artifact_dir``.
        """
        # Save the history as JSON
        with open(self.artifact_dir / "history.json", "w") as f:
            json.dump(self.history, f, indent=2)
            
        # Save the targets matrix if relevant
        if self.targets_matrix is not None:
            torch.save(self.targets_matrix, self.artifact_dir / "final_targets.pt")
            # Also make a heatmap if possible (saved as pt, plotting happens in utils)