In [None]:
!pip install -q torch torchvision scikit-learn tqdm pandas matplotlib seaborn


In [1]:
# Single-cell: strict dataset download + full-dataset update evaluation on Kaggle
# Paste this whole cell into a Kaggle notebook and run.


import os, time, math, random, gc, sys
import numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score
from tqdm import tqdm

print("Torch", torch.__version__)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE, torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


Torch 2.6.0+cu124
Device: cuda Tesla P100-PCIE-16GB


In [13]:

# ----------------------
# Strict download helper ‚Äî HALTS on failure
# ----------------------
def try_download_torchvision(name, root='./data', retries=5, delay=2):
    root = os.path.abspath(root)
    os.makedirs(root, exist_ok=True)
    for attempt in range(1, retries+1):
        try:
            if name == 'mnist':
                datasets.MNIST(root, download=True, train=True, transform=transforms.ToTensor())
                datasets.MNIST(root, download=True, train=False, transform=transforms.ToTensor())
            elif name == 'fashionmnist':
                datasets.FashionMNIST(root, download=True, train=True, transform=transforms.ToTensor())
                datasets.FashionMNIST(root, download=True, train=False, transform=transforms.ToTensor())
            elif name == 'svhn':
                datasets.SVHN(root, download=True, split='train', transform=transforms.ToTensor())
                datasets.SVHN(root, download=True, split='test', transform=transforms.ToTensor())
            elif name == 'cifar10':
                datasets.CIFAR10(root, download=True, train=True, transform=transforms.ToTensor())
                datasets.CIFAR10(root, download=True, train=False, transform=transforms.ToTensor())
            elif name == 'cifar100':
                datasets.CIFAR100(root, download=True, train=True, transform=transforms.ToTensor())
                datasets.CIFAR100(root, download=True, train=False, transform=transforms.ToTensor())
            else:
                raise ValueError(f"Unknown dataset: {name}")
            print(f"[‚úÖ OK] {name} downloaded and saved under {root}")
            return True
        except Exception as e:
            print(f"[‚ö†Ô∏è attempt {attempt}/{retries}] {name} failed: {e}")
            if attempt < retries:
                time.sleep(delay)
    print(f"[‚ùå FAILED] Could not download '{name}' after {retries} attempts.")
    print("üõë Halting execution. Kaggle may block external hosts (e.g., SVHN).")
    sys.exit(1)

# ----------------------
# Loss functions (unchanged)
# ----------------------
class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma; self.eps = eps
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        return (-((1-pt)**self.gamma) * torch.log(pt)).mean()

def effective_num_weights(counts, beta):
    counts = np.array(counts, dtype=np.float64)
    eff = (1.0 - np.power(beta, counts)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(w)
    return w.astype(np.float32)

class CB_Focal(nn.Module):
    def __init__(self, counts, beta=0.999, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma; self.eps = eps
        self.register_buffer('weights', torch.tensor(effective_num_weights(counts, beta), dtype=torch.float32))
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        w = self.weights[y].to(logits.device)
        return (- w * ((1-pt)**self.gamma) * torch.log(pt)).mean()

class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)

        # -----------------------------------------------------------------
        # ‚úÖ **THE FIX**: Convert the float (1.0 / tau) to a tensor 
        # *before* passing it to torch.log.
        log_val = torch.log(torch.tensor(1.0 / tau, dtype=torch.float32))
        # -----------------------------------------------------------------

        raw = torch.where(p > tau,
                          torch.log(1.0 / p),
                          log_val + k * (p - tau)) # Use the new tensor variable
                          
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
        
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
        
    def forward(self, logits, y, epoch=None, Ew=5):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        
        # Ensure gamma is on the same device as logits
        g = self.gamma_per_class.to(logits.device)[y] 
        
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        return ( - ((1-pt)**g) * torch.log(pt) ).mean()

# ----------------------
# Long-tail helper
# ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

# ----------------------
# Dataset loader ‚Äî STRICT, REAL, 3-CHANNEL
# ----------------------
def prepare_dataset(name, root='./data', imb_factor=1, seed=0):
    name_l = name.lower()
    try_download_torchvision(name_l, root=root)

    if name_l == 'mnist':
        tr = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1))
        ])
        train = datasets.MNIST(root, train=True, download=False, transform=tr)
        test = datasets.MNIST(root, train=False, download=False, transform=tr)
        C = 10

    elif name_l == 'fashionmnist':
        tr = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1))
        ])
        train = datasets.FashionMNIST(root, train=True, download=False, transform=tr)
        test = datasets.FashionMNIST(root, train=False, download=False, transform=tr)
        C = 10

    elif name_l == 'svhn':
        tr = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x if x.size(0) == 3 else x.repeat(3, 1, 1))
        ])
        train = datasets.SVHN(root, split='train', download=False, transform=tr)
        test = datasets.SVHN(root, split='test', download=False, transform=tr)
        C = 10

    elif name_l == 'cifar10':
        train_tr = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        test_tr = transforms.ToTensor()
        base_train = datasets.CIFAR10(root, train=True, download=False, transform=train_tr)
        test = datasets.CIFAR10(root, train=False, download=False, transform=test_tr)
        C = 10
        # ‚úÖ Critical fix: assign train = base_train for balanced case
        train = base_train

    elif name_l == 'cifar100':
        train_tr = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        test_tr = transforms.ToTensor()
        base_train = datasets.CIFAR100(root, train=True, download=False, transform=train_tr)
        test = datasets.CIFAR100(root, train=False, download=False, transform=test_tr)
        C = 100
        # ‚úÖ Critical fix: assign train = base_train for balanced case
        train = base_train

    else:
        raise ValueError(f"Unsupported dataset: {name}")

    # Handle imbalance (CIFAR only)
    if name_l in ('cifar10', 'cifar100') and imb_factor > 1:
        targets = np.array(base_train.targets)
        indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
        train = Subset(base_train, indices)
        counts = cls_counts
    else:
        # Extract labels
        if hasattr(train, 'targets'):
            targets = np.array(train.targets)
        elif hasattr(train, 'labels'):  # SVHN
            targets = np.array(train.labels)
        else:
            targets = np.array([train[i][1] for i in range(len(train))])
        counts = np.bincount(targets, minlength=C).tolist()

    print(f"  ‚Üí Train: {len(train)}, Test: {len(test)}, Classes: {C}, Counts (first 10): {counts[:min(10, len(counts))]}")
    return train, test, counts

# ----------------------
# Exact full-dataset update via gradient accumulation
# ----------------------
def train_batchwise_emulate_full(model, loss_fn, loader, opt, epoch=None, Ew=5, loss_tag=None):
    model.train()
    opt.zero_grad(set_to_none=True)
    total_samples = 0
    
    # Use tqdm for progress, as this loop is now the main time sink
    for xb, yb in tqdm(loader, desc=f"Epoch {epoch+1} Train", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        
        logits = model(xb)
        
        if loss_tag == 'cdg':
            loss = loss_fn(logits, yb, epoch=epoch, Ew=Ew)
        else:
            loss = loss_fn(logits, yb)
            
        bs = xb.size(0)
        total_samples += bs
        
        (loss * bs).backward()  # accumulate scaled grads
        
        # -----------------------------------------------------------------
        # ‚úÖ **THE FIX**: Explicitly free memory after each backward pass
        # This is crucial to prevent the computation graph from
        # holding onto memory across mini-batches.
        del xb, yb, logits, loss
        torch.cuda.empty_cache()
        # -----------------------------------------------------------------

    # Normalize to full-batch average
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(total_samples)
            
    opt.step()

def evaluate_model(model, loader):
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE, non_blocking=True)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    preds, tg = np.array(preds), np.array(tg)
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

# ----------------------
# Model builder
# ----------------------
def get_model(num_classes):
    model = models.resnet18(weights=None)  # avoids deprecation warning
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model



In [14]:
# ----------------------
# Experiment config
# ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

DATASETS = [
    ('cifar10', 1),
    ('cifar10', 100),
    ('cifar100', 1),
    ('cifar100', 100)
]

LOSSES = {
    'CE': lambda counts: nn.CrossEntropyLoss(),
    'Focal_g1': lambda counts: FocalLoss(gamma=1.0),
    'CBF_b0.999_g1': lambda counts: CB_Focal(counts, beta=0.999, gamma=1.0),
    'CDG': lambda counts: CDG_Focal(counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5)
}

EPOCHS = 1
LR = 0.1
TRAIN_BATCH = 256
# larger = faster, still exact
SEED = 0
Ew = 5

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

summary_rows = []
print("üöÄ Starting FULL-DATASET exact update evaluation (batchwise emulation)...")
print(f"Datasets: {[f'{d} (IF={i})' for d,i in DATASETS]}")

for ds_name, IF in DATASETS:
    print("\n" + "="*65)
    print(f"üìÅ Dataset: {ds_name} | Imbalance Factor: {IF}")
    
    train_ds, test_ds, counts = prepare_dataset(ds_name, root='./data', imb_factor=IF, seed=SEED)
    num_classes = len(counts)
    
    # Dataloaders (exact full-update compatible)
    train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH, 
                              shuffle=True, num_workers=2, pin_memory=True)
    train_eval_loader = DataLoader(train_ds, batch_size=512, 
                                   shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=512, 
                             shuffle=False, num_workers=2, pin_memory=True)

    for loss_name, loss_ctor in LOSSES.items():
        print(f"\nüîç Loss: {loss_name}")
        torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
        
        model = get_model(num_classes).to(DEVICE)
        loss_fn = loss_ctor(counts)
        loss_fn = loss_fn.to(DEVICE)
        opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

        # Save CDG gamma if applicable
        if loss_name == 'CDG' and hasattr(loss_fn, 'gamma_per_class'):
            gamma_path = os.path.join(OUT, f"{ds_name}_IF{IF}_CDG_gamma.npy")
            np.save(gamma_path, loss_fn.gamma_per_class.cpu().numpy())
            print(f"  ‚Üí Saved CDG gamma to {gamma_path}")

        rows = []
        for ep in range(EPOCHS):
            t0 = time.time()
            
            # ‚úÖ EXACT full-dataset update (gradient accumulation)
            train_batchwise_emulate_full(
                model, loss_fn, train_loader, opt,
                epoch=ep, Ew=Ew, loss_tag=('cdg' if loss_name == 'CDG' else None)
            )
            
            # Compute train loss (accurate)
            model.eval()
            train_loss = 0.0
            with torch.no_grad():
                for xb, yb in train_eval_loader:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb)
                    l = loss_fn(logits, yb, epoch=ep, Ew=Ew) if loss_name == 'CDG' else loss_fn(logits, yb)
                    train_loss += l.item() * xb.size(0)
            train_loss /= len(train_ds)
            model.train()
            
            # Evaluate
            val_acc, macrof1 = evaluate_model(model, test_loader)
            elapsed = time.time() - t0
            scheduler.step()
            
            print(f"  Epoch {ep+1}/{EPOCHS} | Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f} | Macro F1: {macrof1:.4f} | {elapsed:.1f}s")
            rows.append({
                'epoch': ep,
                'train_loss': train_loss,
                'val_acc': val_acc,
                'macro_f1': macrof1
            })

        # Save results
        pd.DataFrame(rows).to_csv(os.path.join(OUT, f"{ds_name}_IF{IF}_{loss_name}.csv"), index=False)
        summary_rows.append({
            'dataset': ds_name,
            'IF': IF,
            'loss': loss_name,
            'val_acc': rows[-1]['val_acc'],
            'macro_f1': rows[-1]['macro_f1']
        })

# ----------------------
# Final summary & plots
# ----------------------
summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "summary_table.csv")
summary_df.to_csv(summary_path, index=False)
print("\n" + "="*65)
print("‚úÖ All experiments completed!")
print(f"üìä Summary saved to: {summary_path}")

# Plot CDG gamma curves
gamma_files = [f for f in os.listdir(OUT) if f.endswith('_CDG_gamma.npy')]
for gf in gamma_files:
    arr = np.load(os.path.join(OUT, gf))
    plt.figure(figsize=(6,2.5))
    plt.plot(arr, marker='o', markersize=3, linewidth=1.2)
    plt.title(f"CDG Gamma ‚Äî {gf.replace('_CDG_gamma.npy', '')}", fontsize=10)
    plt.xlabel("Class"); plt.ylabel("Œ≥")
    plt.grid(True, ls=':', alpha=0.6)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT, gf.replace('.npy', '.png')), dpi=150)
    plt.close()

# Summary plot
plt.figure(figsize=(12,5))
sns.barplot(data=summary_df, x='loss', y='val_acc', hue='dataset')
plt.title("Final Validation Accuracy (Exact Full-Dataset Updates)", fontsize=14)
plt.xticks(rotation=30)
plt.ylabel("Accuracy")
plt.legend(title='Dataset (IF)', bbox_to_anchor=(1.02, 1), loc='upper left')
plt.tight_layout()
plt.savefig(os.path.join(OUT, 'summary_valacc_bar.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nüìÅ All outputs saved in:\n{OUT}")
print("üí° Note: All updates are EXACT full-dataset gradients (via accumulation), not mini-batch SGD.")

NameError: name 'DATASET' is not defined

In [18]:
# ‚úÖ Minimal code to resume training from CIFAR-10 onward
# Assumes MNIST, FashionMNIST, SVHN already completed.
# Uses existing /kaggle/working/data and /kaggle/working/loss_eval_results

import os, sys, time, math, random
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ----------------------
# Required helpers (minimal)
# ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        return (-((1-pt)**self.gamma) * torch.log(pt)).mean()

def effective_num_weights(counts, beta):
    counts = np.array(counts, dtype=np.float64)
    eff = (1.0 - np.power(beta, counts)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(w)
    return w.astype(np.float32)

class CB_Focal(nn.Module):
    def __init__(self, counts, beta=0.999, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.register_buffer('weights', torch.tensor(effective_num_weights(counts, beta), dtype=torch.float32))
    
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        w = self.weights.to(logits.device)[y]
        return (- w * ((1-pt)**self.gamma) * torch.log(pt)).mean()

class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)
        # Fix: ensure scalar is a tensor
        log_tau_inv = torch.log(torch.tensor(1.0 / tau, dtype=p.dtype))
        branch2 = log_tau_inv + k * (p - tau)
        raw = torch.where(p > tau, torch.log(1.0 / p), branch2)
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
    
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
    
    def forward(self, logits, y, epoch=None, Ew=5):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        g = self.gamma_per_class.to(logits.device)[y]
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        return (- ((1-pt)**g) * torch.log(pt)).mean()

def get_model(num_classes):
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_batchwise_emulate_full(model, loss_fn, loader, opt, epoch=None, Ew=5, loss_tag=None):
    model.train()
    opt.zero_grad(set_to_none=True)
    total_samples = 0
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        logits = model(xb)
        if loss_tag == 'cdg':
            loss = loss_fn(logits, yb, epoch=epoch, Ew=Ew)
        else:
            loss = loss_fn(logits, yb)
        bs = xb.size(0)
        total_samples += bs
        (loss * bs).backward()
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(total_samples)
    opt.step()

def evaluate_model(model, loader):
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE, non_blocking=True)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    preds, tg = np.array(preds), np.array(tg)
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

def prepare_dataset(name, root='./data', imb_factor=1, seed=0):
    name_l = name.lower()

    if name_l == 'mnist':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x.repeat(3,1,1)])
        train = datasets.MNIST(root, train=True, download=False, transform=tr)
        test = datasets.MNIST(root, train=False, download=False, transform=tr)
        C = 10

    elif name_l == 'fashionmnist':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x.repeat(3,1,1)])
        train = datasets.FashionMNIST(root, train=True, download=False, transform=tr)
        test = datasets.FashionMNIST(root, train=False, download=False, transform=tr)
        C = 10

    elif name_l == 'svhn':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x if x.size(0)==3 else x.repeat(3,1,1)])
        train = datasets.SVHN(root, split='train', download=False, transform=tr)
        test = datasets.SVHN(root, split='test', download=False, transform=tr)
        C = 10

    elif name_l == 'cifar10':
        train_tr = transforms.Compose([transforms.RandomCrop(32,4), transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_tr = transforms.ToTensor()
        base_train = datasets.CIFAR10(root, train=True, download=False, transform=train_tr)
        test = datasets.CIFAR10(root, train=False, download=False, transform=test_tr)
        C = 10
        train = base_train

    elif name_l == 'cifar100':
        train_tr = transforms.Compose([transforms.RandomCrop(32,4), transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_tr = transforms.ToTensor()
        base_train = datasets.CIFAR100(root, train=True, download=False, transform=train_tr)
        test = datasets.CIFAR100(root, train=False, download=False, transform=test_tr)
        C = 100
        train = base_train

    else:
        raise ValueError(f"Unsupported: {name}")

    if name_l in ('cifar10','cifar100') and imb_factor > 1:
        targets = np.array(base_train.targets)
        indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
        train = Subset(base_train, indices)
        counts = cls_counts
    else:
        if hasattr(train, 'targets'):
            targets = np.array(train.targets)
        elif hasattr(train, 'labels'):
            targets = np.array(train.labels)
        else:
            targets = np.array([train[i][1] for i in range(len(train))])
        counts = np.bincount(targets, minlength=C).tolist()

    print(f"  ‚Üí {name} (IF={imb_factor}): train={len(train)}, test={len(test)}, classes={C}")
    return train, test, counts

# ----------------------
# Config
# ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

LOSSES = {
    'CE': lambda counts: nn.CrossEntropyLoss(),
    'Focal_g1': lambda counts: FocalLoss(gamma=1.0),
    'CBF_b0.999_g1': lambda counts: CB_Focal(counts, beta=0.999, gamma=1.0),
    'CDG': lambda counts: CDG_Focal(counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5)
}

EPOCHS = 1
LR = 0.1
TRAIN_BATCH = 512
SEED = 0
Ew = 5

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

DATASETS_TO_RUN = [
    ('cifar10', 1),
    ('cifar10', 100),
    ('cifar100', 1),
    ('cifar100', 100)
]

summary_rows = []

print("üöÄ Resuming training for remaining datasets...")
for ds_name, IF in DATASETS_TO_RUN:
    print("\n" + "="*60)
    print(f"üìÅ {ds_name} | IF={IF}")
    
    try:
        train_ds, test_ds, counts = prepare_dataset(ds_name, root='./data', imb_factor=IF, seed=SEED)
    except Exception as e:
        print(f"‚ùå Failed to load {ds_name}: {e}")
        continue

    num_classes = len(counts)
    train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=True)

    for loss_name, loss_ctor in LOSSES.items():
        print(f"  üîç {loss_name}")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        
        model = get_model(num_classes).to(DEVICE)
        loss_fn = loss_ctor(counts)
        opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

        if loss_name == 'CDG' and hasattr(loss_fn, 'gamma_per_class'):
            np.save(os.path.join(OUT, f"{ds_name}_IF{IF}_CDG_gamma.npy"), loss_fn.gamma_per_class.cpu().numpy())

        for ep in range(EPOCHS):
            t0 = time.time()
            train_batchwise_emulate_full(model, loss_fn, train_loader, opt, 
                                         epoch=ep, Ew=Ew, loss_tag=('cdg' if loss_name=='CDG' else None))
            
            model.eval()
            train_loss = 0.0
            train_eval_loader = DataLoader(train_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=True)
            with torch.no_grad():
                for xb, yb in train_eval_loader:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb)
                    if loss_name == 'CDG':
                        l = loss_fn(logits, yb, epoch=ep, Ew=Ew)
                    else:
                        l = loss_fn(logits, yb)
                    train_loss += l.item() * len(xb)
            train_loss /= len(train_ds)
            model.train()
            
            val_acc, macrof1 = evaluate_model(model, test_loader)
            scheduler.step()
            
            print(f"    Ep {ep+1}/{EPOCHS} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f} | F1: {macrof1:.4f} | {time.time()-t0:.1f}s")

            if ep == EPOCHS - 1:
                summary_rows.append({
                    'dataset': ds_name,
                    'IF': IF,
                    'loss': loss_name,
                    'val_acc': val_acc,
                    'macro_f1': macrof1
                })

summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "summary_cifar_remaining.csv")
summary_df.to_csv(summary_path, index=False)
print(f"\n‚úÖ Remaining datasets completed. Summary saved to:\n{summary_path}")

full_summary_path = os.path.join(OUT, "summary_table.csv")
if os.path.exists(full_summary_path):
    prev = pd.read_csv(full_summary_path)
    full = pd.concat([prev, summary_df], ignore_index=True)
    full.to_csv(full_summary_path, index=False)
    print(f"‚úÖ Appended to full summary: {full_summary_path}")

Device: cuda
üöÄ Resuming training for remaining datasets...

üìÅ cifar10 | IF=1
  ‚Üí cifar10 (IF=1): train=50000, test=10000, classes=10
  üîç CE
    Ep 1/1 | Loss: 4.3650 | Acc: 0.1262 | F1: 0.0580 | 15.0s
  üîç Focal_g1
    Ep 1/1 | Loss: 4.8829 | Acc: 0.1255 | F1: 0.0548 | 14.7s
  üîç CBF_b0.999_g1
    Ep 1/1 | Loss: 4.8829 | Acc: 0.1255 | F1: 0.0548 | 15.5s
  üîç CDG
    Ep 1/1 | Loss: 4.3435 | Acc: 0.1262 | F1: 0.0580 | 15.1s

üìÅ cifar10 | IF=100
  ‚Üí cifar10 (IF=100): train=12408, test=10000, classes=10
  üîç CE
    Ep 1/1 | Loss: 3.2030 | Acc: 0.1000 | F1: 0.0182 | 4.7s
  üîç Focal_g1
    Ep 1/1 | Loss: 3.5244 | Acc: 0.1000 | F1: 0.0182 | 4.6s
  üîç CBF_b0.999_g1
    Ep 1/1 | Loss: 0.5366 | Acc: 0.0983 | F1: 0.0202 | 4.8s
  üîç CDG
    Ep 1/1 | Loss: 3.2030 | Acc: 0.1000 | F1: 0.0182 | 4.9s

üìÅ cifar100 | IF=1
‚ùå Failed to load cifar100: Dataset not found or corrupted. You can use download=True to download it

üìÅ cifar100 | IF=100
‚ùå Failed to load cifar100:

In [22]:
# ‚úÖ Minimal code to resume training from CIFAR-10 onward
# Assumes MNIST, FashionMNIST, SVHN already completed.
# Downloads CIFAR-10/100 if missing.

import os, sys, time, math, random
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ----------------------
# Required helpers (minimal)
# ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        return (-((1-pt)**self.gamma) * torch.log(pt)).mean()

def effective_num_weights(counts, beta):
    counts = np.array(counts, dtype=np.float64)
    eff = (1.0 - np.power(beta, counts)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(w)
    return w.astype(np.float32)

class CB_Focal(nn.Module):
    def __init__(self, counts, beta=0.999, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.register_buffer('weights', torch.tensor(effective_num_weights(counts, beta), dtype=torch.float32))
    
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        w = self.weights.to(logits.device)[y]  # ‚úÖ Fix: move to device before indexing
        return (- w * ((1-pt)**self.gamma) * torch.log(pt)).mean()

class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)
        # ‚úÖ Fix: log() requires tensor, not float
        log_tau_inv = torch.log(torch.tensor(1.0 / tau, dtype=p.dtype))
        branch2 = log_tau_inv + k * (p - tau)
        raw = torch.where(p > tau, torch.log(1.0 / p), branch2)
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
    
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
    
    def forward(self, logits, y, epoch=None, Ew=5):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        g = self.gamma_per_class.to(logits.device)[y]
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        return (- ((1-pt)**g) * torch.log(pt)).mean()

def get_model(num_classes):
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_batchwise_emulate_full(model, loss_fn, loader, opt, epoch=None, Ew=5, loss_tag=None):
    model.train()
    opt.zero_grad(set_to_none=True)
    total_samples = 0
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        logits = model(xb)
        if loss_tag == 'cdg':
            loss = loss_fn(logits, yb, epoch=epoch, Ew=Ew)
        else:
            loss = loss_fn(logits, yb)
        bs = xb.size(0)
        total_samples += bs
        (loss * bs).backward()
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(total_samples)
    opt.step()

def evaluate_model(model, loader):
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE, non_blocking=True)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    preds, tg = np.array(preds), np.array(tg)
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

def prepare_dataset(name, root='./data', imb_factor=1, seed=0):
    name_l = name.lower()

    if name_l == 'mnist':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x.repeat(3,1,1)])
        train = datasets.MNIST(root, train=True, download=False, transform=tr)
        test = datasets.MNIST(root, train=False, download=False, transform=tr)
        C = 10

    elif name_l == 'fashionmnist':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x.repeat(3,1,1)])
        train = datasets.FashionMNIST(root, train=True, download=False, transform=tr)
        test = datasets.FashionMNIST(root, train=False, download=False, transform=tr)
        C = 10

    elif name_l == 'svhn':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x if x.size(0)==3 else x.repeat(3,1,1)])
        train = datasets.SVHN(root, split='train', download=False, transform=tr)
        test = datasets.SVHN(root, split='test', download=False, transform=tr)
        C = 10

    elif name_l == 'cifar10':
        train_tr = transforms.Compose([transforms.RandomCrop(32,4), transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_tr = transforms.ToTensor()
        # ‚úÖ Enable download
        base_train = datasets.CIFAR10(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR10(root, train=False, download=True, transform=test_tr)
        C = 10
        train = base_train

    elif name_l == 'cifar100':
        train_tr = transforms.Compose([transforms.RandomCrop(32,4), transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_tr = transforms.ToTensor()
        # ‚úÖ Enable download
        base_train = datasets.CIFAR100(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR100(root, train=False, download=True, transform=test_tr)
        C = 100
        train = base_train

    else:
        raise ValueError(f"Unsupported: {name}")

    # Handle imbalance
    if name_l in ('cifar10','cifar100') and imb_factor > 1:
        targets = np.array(base_train.targets)
        indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
        train = Subset(base_train, indices)
        counts = cls_counts
    else:
        if hasattr(train, 'targets'):
            targets = np.array(train.targets)
        elif hasattr(train, 'labels'):
            targets = np.array(train.labels)
        else:
            targets = np.array([train[i][1] for i in range(len(train))])
        counts = np.bincount(targets, minlength=C).tolist()

    print(f"  ‚Üí {name} (IF={imb_factor}): train={len(train)}, test={len(test)}, classes={C}")
    return train, test, counts

# ----------------------
# Config
# ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

LOSSES = {
    'CE': lambda counts: nn.CrossEntropyLoss(),
    'Focal_g1': lambda counts: FocalLoss(gamma=1.0),
    'CBF_b0.999_g1': lambda counts: CB_Focal(counts, beta=0.999, gamma=1.0),
    'CDG': lambda counts: CDG_Focal(counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5)
}

EPOCHS = 6
LR = 0.1
TRAIN_BATCH = 512
SEED = 0
Ew = 5

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

DATASETS_TO_RUN = [
   ('mnist', 1),
  ('fashionmnist', 1),
  ('svhn', 1),
  ('cifar10', 1),
  ('cifar10', 100),
  ('cifar100', 1),
  ('cifar100', 100)
]

summary_rows = []

print("üöÄ Resuming training for remaining datasets...")
for ds_name, IF in DATASETS_TO_RUN:
    print("\n" + "="*60)
    print(f"üìÅ {ds_name} | IF={IF}")
    
    try:
        train_ds, test_ds, counts = prepare_dataset(ds_name, root='./data', imb_factor=IF, seed=SEED)
    except Exception as e:
        print(f"‚ùå Failed to load {ds_name}: {e}")
        continue

    num_classes = len(counts)
    train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=True)

    for loss_name, loss_ctor in LOSSES.items():
        print(f"  üîç {loss_name}")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        
        model = get_model(num_classes).to(DEVICE)
        loss_fn = loss_ctor(counts)
        opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

        if loss_name == 'CDG' and hasattr(loss_fn, 'gamma_per_class'):
            np.save(os.path.join(OUT, f"{ds_name}_IF{IF}_CDG_gamma.npy"), loss_fn.gamma_per_class.cpu().numpy())

        for ep in range(EPOCHS):
            t0 = time.time()
            train_batchwise_emulate_full(model, loss_fn, train_loader, opt, 
                                         epoch=ep, Ew=Ew, loss_tag=('cdg' if loss_name=='CDG' else None))
            
            # Evaluate train loss accurately
            model.eval()
            train_loss = 0.0
            train_eval_loader = DataLoader(train_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=True)
            with torch.no_grad():
                for xb, yb in train_eval_loader:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb)
                    if loss_name == 'CDG':
                        l = loss_fn(logits, yb, epoch=ep, Ew=Ew)
                    else:
                        l = loss_fn(logits, yb)
                    train_loss += l.item() * len(xb)
            train_loss /= len(train_ds)
            model.train()
            
            val_acc, macrof1 = evaluate_model(model, test_loader)
            scheduler.step()
            
            print(f"    Ep {ep+1}/{EPOCHS} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f} | F1: {macrof1:.4f} | {time.time()-t0:.1f}s")

            if ep == EPOCHS - 1:
                summary_rows.append({
                    'dataset': ds_name,
                    'IF': IF,
                    'loss': loss_name,
                    'val_acc': val_acc,
                    'macro_f1': macrof1
                })

# Save summary
summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "summary_cifar_remaining.csv")
summary_df.to_csv(summary_path, index=False)
print(f"\n‚úÖ Remaining datasets completed. Summary saved to:\n{summary_path}")

# Optional: merge with full summary
full_summary_path = os.path.join(OUT, "summary_table.csv")
if os.path.exists(full_summary_path):
    prev = pd.read_csv(full_summary_path)
    full = pd.concat([prev, summary_df], ignore_index=True)
    full.to_csv(full_summary_path, index=False)
    print(f"‚úÖ Appended to full summary: {full_summary_path}")

Device: cuda
üöÄ Resuming training for remaining datasets...

üìÅ mnist | IF=1
  ‚Üí mnist (IF=1): train=60000, test=10000, classes=10
  üîç CE
    Ep 1/6 | Loss: 5.4154 | Acc: 0.1845 | F1: 0.0903 | 14.3s
    Ep 2/6 | Loss: 12.8347 | Acc: 0.1773 | F1: 0.0955 | 14.3s
    Ep 3/6 | Loss: 2.1353 | Acc: 0.5089 | F1: 0.5159 | 14.3s
    Ep 4/6 | Loss: 0.4306 | Acc: 0.8829 | F1: 0.8824 | 14.4s


KeyboardInterrupt: 

In [25]:
# ‚úÖ FULL WORKING CODE: All datasets, 100 epochs, early stopping, CDG fixed
import os, sys, time, math, random
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---------------------- HELPER FUNCTIONS ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        return (-((1-pt)**self.gamma) * torch.log(pt)).mean()

def effective_num_weights(counts, beta):
    counts = np.array(counts, dtype=np.float64)
    eff = (1.0 - np.power(beta, counts)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(w)
    return w.astype(np.float32)

class CB_Focal(nn.Module):
    def __init__(self, counts, beta=0.999, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.register_buffer('weights', torch.tensor(effective_num_weights(counts, beta), dtype=torch.float32))
    
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        w = self.weights.to(logits.device)[y]
        return (- w * ((1-pt)**self.gamma) * torch.log(pt)).mean()

# ‚úÖ FIXED CDG_Focal: Numerically stable + device-safe
class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)
        log_tau_inv = torch.log(torch.tensor(1.0 / tau, dtype=p.dtype, device=p.device))
        branch2 = log_tau_inv + k * (p - tau)
        raw = torch.where(p > tau, torch.log(1.0 / p), branch2)
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
    
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
    
    def forward(self, logits, y, epoch=None, Ew=5):
        # ‚úÖ Numerically stable: use log_softmax directly
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, y[:, None]).squeeze()
        pt = torch.exp(log_pt).clamp(min=self.eps, max=1.0 - self.eps)
        g = self.gamma_per_class.to(logits.device)[y]
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        loss = - ((1.0 - pt) ** g) * log_pt
        return loss.mean()

def get_model(num_classes):
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# ‚úÖ FIXED: Added gradient clipping to prevent explosion
def train_one_epoch(model, loss_fn, loader, opt, epoch=None, Ew=5, loss_tag=None):
    model.train()
    total_loss = 0.0
    total_samples = 0
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        logits = model(xb)
        if loss_tag == 'cdg':
            loss = loss_fn(logits, yb, epoch=epoch, Ew=Ew)
        else:
            loss = loss_fn(logits, yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        # üîë CRITICAL: Prevent gradient explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        total_loss += loss.item() * xb.size(0)
        total_samples += xb.size(0)
    return total_loss / total_samples

def evaluate_model(model, loader):
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE, non_blocking=True)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    preds, tg = np.array(preds), np.array(tg)
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

def prepare_dataset(name, root='./data', imb_factor=1, seed=0):
    name_l = name.lower()

    if name_l == 'mnist':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x.repeat(3,1,1)])
        train = datasets.MNIST(root, train=True, download=True, transform=tr)
        test = datasets.MNIST(root, train=False, download=True, transform=tr)
        C = 10

    elif name_l == 'fashionmnist':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x.repeat(3,1,1)])
        train = datasets.FashionMNIST(root, train=True, download=True, transform=tr)
        test = datasets.FashionMNIST(root, train=False, download=True, transform=tr)
        C = 10

    elif name_l == 'svhn':
        tr = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), lambda x: x if x.size(0)==3 else x.repeat(3,1,1)])
        train = datasets.SVHN(root, split='train', download=True, transform=tr)
        test = datasets.SVHN(root, split='test', download=True, transform=tr)
        C = 10

    elif name_l == 'cifar10':
        train_tr = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
        test_tr = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
        base_train = datasets.CIFAR10(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR10(root, train=False, download=True, transform=test_tr)
        C = 10
        train = base_train

    elif name_l == 'cifar100':
        train_tr = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        ])
        test_tr = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        ])
        base_train = datasets.CIFAR100(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR100(root, train=False, download=True, transform=test_tr)
        C = 100
        train = base_train

    else:
        raise ValueError(f"Unsupported: {name}")

    if name_l in ('cifar10','cifar100') and imb_factor > 1:
        targets = np.array(base_train.targets)
        indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
        train = Subset(base_train, indices)
        counts = cls_counts
    else:
        if hasattr(train, 'targets'):
            targets = np.array(train.targets)
        elif hasattr(train, 'labels'):
            targets = np.array(train.labels)
        else:
            targets = np.array([train[i][1] for i in range(len(train))])
        counts = np.bincount(targets, minlength=C).tolist()

    print(f"  ‚Üí {name} (IF={imb_factor}): train={len(train)}, test={len(test)}, classes={C}")
    return train, test, counts

# ---------------------- CONFIG ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

LOSSES = {
    'CE': lambda counts: nn.CrossEntropyLoss(),
    'Focal_g1': lambda counts: FocalLoss(gamma=1.0),
    'CBF_b0.999_g1': lambda counts: CB_Focal(counts, beta=0.999, gamma=1.0),
    'CDG': lambda counts: CDG_Focal(counts, tau=1.0, k=0.0, gamma_min=0.75, gamma_max=2.5)
}

EPOCHS = 100
PATIENCE = 5
LR = 0.01  # Stable LR for all losses
TRAIN_BATCH = 256
SEED = 0
Ew = 5

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# ‚úÖ ALL DATASETS (as per your current run)
DATASETS_TO_RUN = [
    ('mnist', 1),
    ('fashionmnist', 1),
    ('svhn', 1),
    ('cifar10', 1),
    ('cifar10', 100),
    ('cifar100', 1),
    ('cifar100', 100)
]

summary_rows = []
print("üöÄ Starting full training with early stopping (100 epochs, patience=5)...")

# ---------------------- MAIN LOOP ----------------------
for ds_name, IF in DATASETS_TO_RUN:
    print("\n" + "="*60)
    print(f"üìÅ {ds_name} | IF={IF}")
    
    try:
        train_ds, test_ds, counts = prepare_dataset(ds_name, root='./data', imb_factor=IF, seed=SEED)
    except Exception as e:
        print(f"‚ùå Failed to load {ds_name}: {e}")
        continue

    num_classes = len(counts)
    train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    for loss_name, loss_ctor in LOSSES.items():
        print(f"\n  üîç Training with {loss_name}")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        
        model = get_model(num_classes).to(DEVICE)
        loss_fn = loss_ctor(counts)
        opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

        best_val_acc = 0.0
        epochs_no_improve = 0
        best_model_path = os.path.join(OUT, f"{ds_name}_IF{IF}_{loss_name}_best.pth")

        # Save CDG gamma if applicable
        if loss_name == 'CDG' and hasattr(loss_fn, 'gamma_per_class'):
            np.save(os.path.join(OUT, f"{ds_name}_IF{IF}_CDG_gamma.npy"), loss_fn.gamma_per_class.cpu().numpy())

        for ep in range(EPOCHS):
            t0 = time.time()
            train_loss = train_one_epoch(
                model, loss_fn, train_loader, opt,
                epoch=ep, Ew=Ew, loss_tag=('cdg' if loss_name=='CDG' else None)
            )
            
            val_acc, macrof1 = evaluate_model(model, test_loader)
            scheduler.step()
            
            # Early stopping
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_macrof1 = macrof1
                best_epoch = ep
                epochs_no_improve = 0
                torch.save(model.state_dict(), best_model_path)
                improved = "‚úÖ"
            else:
                epochs_no_improve += 1
                improved = "  "
            
            print(f"    Ep {ep+1:3d}/{EPOCHS} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f} | F1: {macrof1:.4f} | {time.time()-t0:.1f}s {improved}")
            
            if epochs_no_improve >= PATIENCE:
                print(f"    ‚èπÔ∏è  Early stopping at epoch {ep+1}. Best: {best_epoch+1}")
                break

        # Final evaluation with best model
        model.load_state_dict(torch.load(best_model_path))
        final_acc, final_f1 = evaluate_model(model, test_loader)
        print(f"    üèÜ Final (best) | Acc: {final_acc:.4f} | F1: {final_f1:.4f}")

        summary_rows.append({
            'dataset': ds_name,
            'IF': IF,
            'loss': loss_name,
            'val_acc': final_acc,
            'macro_f1': final_f1,
            'best_epoch': best_epoch + 1
        })

# ---------------------- SAVE RESULTS ----------------------
summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "summary_full_results.csv")
summary_df.to_csv(summary_path, index=False)
print(f"\n‚úÖ Training completed! Results saved to:\n{summary_path}")

# Append to master summary if exists
full_summary_path = os.path.join(OUT, "summary_table.csv")
if os.path.exists(full_summary_path):
    prev = pd.read_csv(full_summary_path)
    full = pd.concat([prev, summary_df], ignore_index=True)
    full.to_csv(full_summary_path, index=False)
    print(f"‚úÖ Appended to master summary: {full_summary_path}")

Device: cuda
üöÄ Starting full training with early stopping (100 epochs, patience=5)...

üìÅ mnist | IF=1
  ‚Üí mnist (IF=1): train=60000, test=10000, classes=10

  üîç Training with CE
    Ep   1/100 | Loss: 0.3073 | Acc: 0.9743 | F1: 0.9742 | 11.0s ‚úÖ
    Ep   2/100 | Loss: 0.0452 | Acc: 0.9852 | F1: 0.9851 | 10.8s ‚úÖ
    Ep   3/100 | Loss: 0.0276 | Acc: 0.9903 | F1: 0.9903 | 10.8s ‚úÖ
    Ep   4/100 | Loss: 0.0166 | Acc: 0.9899 | F1: 0.9899 | 10.7s   
    Ep   5/100 | Loss: 0.0122 | Acc: 0.9908 | F1: 0.9908 | 10.9s ‚úÖ
    Ep   6/100 | Loss: 0.0086 | Acc: 0.9908 | F1: 0.9907 | 10.7s   
    Ep   7/100 | Loss: 0.0058 | Acc: 0.9917 | F1: 0.9916 | 10.9s ‚úÖ
    Ep   8/100 | Loss: 0.0045 | Acc: 0.9846 | F1: 0.9844 | 10.7s   
    Ep   9/100 | Loss: 0.0028 | Acc: 0.9924 | F1: 0.9923 | 10.9s ‚úÖ
    Ep  10/100 | Loss: 0.0027 | Acc: 0.9917 | F1: 0.9916 | 10.7s   
    Ep  11/100 | Loss: 0.0011 | Acc: 0.9925 | F1: 0.9924 | 10.8s ‚úÖ
    Ep  12/100 | Loss: 0.0009 | Acc: 0.9924 | F1: 0.9923

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^if w.is_alive():^^
^ ^^ ^ ^ ^ ^
   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
 ^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^ ^ ^ ^ ^ ^ ^ 
   File "/usr/

    Ep   3/100 | Loss: 0.2406 | Acc: 0.8681 | F1: 0.8653 | 12.0s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
     Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^if w.is_alive():^^
^ ^^ ^  ^ ^  ^^^^^^^^^^^^

    Ep   4/100 | Loss: 0.1870 | Acc: 0.8788 | F1: 0.8792 | 11.8s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
Exception ignored in:   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>    
assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive():  
     ^ ^^ ^ ^ ^^^^^^^^^^^^^^^^^^^^

    Ep   5/100 | Loss: 0.1480 | Acc: 0.8881 | F1: 0.8874 | 11.7s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
Exception ignored in:   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>    assert self._parent_pid == os.getpid(), 'can only test a child process'
Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers()  
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
       if w.is_alive(): 
    ^ ^ ^ ^^  ^^^^^^^^^^^^^^^^^^^^

    Ep   6/100 | Loss: 0.1223 | Acc: 0.8927 | F1: 0.8925 | 11.7s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
 Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
     ^if w.is_alive():^
^ ^ ^ ^ ^ ^  ^^^^^^^^^^^^^^^^^

    Ep   7/100 | Loss: 0.1047 | Acc: 0.8947 | F1: 0.8931 | 11.7s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
Exception ignored in:   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>    
assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive():  
     ^ ^ ^ ^^ ^^^^^^^^^^^^^^^^^^^^

    Ep   8/100 | Loss: 0.0913 | Acc: 0.8947 | F1: 0.8947 | 11.7s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>  
Traceback (most recent call last):
    File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers()  
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive():^
^ ^^ ^ ^ ^ ^ ^^ ^^^^^^^^^^^^^^^

    Ep   9/100 | Loss: 0.0798 | Acc: 0.8921 | F1: 0.8922 | 11.6s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
    File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
     ^if w.is_alive():^
^ ^^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^

    Ep  10/100 | Loss: 0.0716 | Acc: 0.8954 | F1: 0.8940 | 11.7s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^Exception ignored in: ^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive

    assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers()  
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
     if w.is_alive(): 
        ^^  ^ ^^^^^^^^^^^^^^^^^^^^^^

    Ep  11/100 | Loss: 0.0631 | Acc: 0.8970 | F1: 0.8961 | 11.7s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
  Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers()  
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive():^
^ ^ ^ ^ ^ ^^  ^^^^^^^^^^^^^^^^^

    Ep  12/100 | Loss: 0.0558 | Acc: 0.8921 | F1: 0.8898 | 11.7s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>

  Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive(): 
^ ^ ^^  ^^ ^ ^ ^^^^^^^^^^^^^^^^^

    Ep  13/100 | Loss: 0.0508 | Acc: 0.8947 | F1: 0.8945 | 11.6s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>

 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers()  
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive():  
  ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^

    Ep  14/100 | Loss: 0.0479 | Acc: 0.8960 | F1: 0.8957 | 11.6s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
Exception ignored in:   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>    
assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):

   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive(): 
     ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^

    Ep  15/100 | Loss: 0.0422 | Acc: 0.8976 | F1: 0.8975 | 11.9s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
  Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
       if w.is_alive():^
^^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^

    Ep  16/100 | Loss: 0.0386 | Acc: 0.8942 | F1: 0.8951 | 11.7s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: 
 <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      ^if w.is_alive():
^ ^ ^ ^^  ^ ^ ^^^^^^^^^^^^^^^^^

    Ep  17/100 | Loss: 0.0351 | Acc: 0.9001 | F1: 0.9003 | 11.8s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>  
Traceback (most recent call last):
    File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
       if w.is_alive():
^ ^^ ^ ^  ^ ^ ^^^^^^^^^^^^^^^^^^

    Ep  18/100 | Loss: 0.0326 | Acc: 0.8999 | F1: 0.9002 | 11.7s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

    Ep  19/100 | Loss: 0.0311 | Acc: 0.8987 | F1: 0.8983 | 11.9s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
Exception ignored in:   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>    
assert self._parent_pid == os.getpid(), 'can only test a child process'
Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive():
       ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^^

    Ep  20/100 | Loss: 0.0279 | Acc: 0.9011 | F1: 0.9010 | 11.6s ‚úÖ


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
  Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
  Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^

    Ep  21/100 | Loss: 0.0272 | Acc: 0.8982 | F1: 0.8976 | 12.0s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

    Ep  22/100 | Loss: 0.0256 | Acc: 0.8988 | F1: 0.8986 | 12.0s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers()
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if w.is_alive(): 
^  ^ ^  ^ ^^ ^^^^^^^^^^^^^^^^^^^

    Ep  23/100 | Loss: 0.0227 | Acc: 0.8939 | F1: 0.8928 | 11.6s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
       self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
       if w.is_alive():^
^^ ^ ^ ^  ^ ^ ^^^^^^^^^^^^^^^^^

    Ep  24/100 | Loss: 0.0213 | Acc: 0.8957 | F1: 0.8953 | 12.2s   


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
          Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>^^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^  ^^  ^^ ^^^^^^^

    Ep  25/100 | Loss: 0.0216 | Acc: 0.8898 | F1: 0.8880 | 12.1s   
    ‚èπÔ∏è  Early stopping at epoch 25. Best: 20


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7aaa59286d40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

    üèÜ Final (best) | Acc: 0.9011 | F1: 0.9010

üìÅ svhn | IF=1
  ‚Üí svhn (IF=1): train=73257, test=26032, classes=10

  üîç Training with CE
    Ep   1/100 | Loss: 1.4814 | Acc: 0.7157 | F1: 0.6838 | 14.5s ‚úÖ
    Ep   2/100 | Loss: 0.5905 | Acc: 0.7800 | F1: 0.7566 | 14.4s ‚úÖ
    Ep   3/100 | Loss: 0.4290 | Acc: 0.8430 | F1: 0.8261 | 14.4s ‚úÖ
    Ep   4/100 | Loss: 0.3465 | Acc: 0.8537 | F1: 0.8384 | 14.4s ‚úÖ
    Ep   5/100 | Loss: 0.2876 | Acc: 0.8632 | F1: 0.8497 | 14.5s ‚úÖ
    Ep   6/100 | Loss: 0.2392 | Acc: 0.8630 | F1: 0.8478 | 14.4s   
    Ep   7/100 | Loss: 0.1973 | Acc: 0.8733 | F1: 0.8617 | 14.5s ‚úÖ
    Ep   8/100 | Loss: 0.1633 | Acc: 0.8627 | F1: 0.8488 | 14.4s   
    Ep   9/100 | Loss: 0.1330 | Acc: 0.8752 | F1: 0.8638 | 14.4s ‚úÖ
    Ep  10/100 | Loss: 0.1093 | Acc: 0.8726 | F1: 0.8619 | 14.4s   
    Ep  11/100 | Loss: 0.0901 | Acc: 0.8763 | F1: 0.8633 | 14.4s ‚úÖ
    Ep  12/100 | Loss: 0.0746 | Acc: 0.8747 | F1: 0.8641 | 14.4s   
    Ep  13/100 | Loss: 0.0625

In [26]:
# ‚úÖ CDG with Head/Mid/Tail accuracy for CIFAR-10-LT
import os, sys, time, math, random
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---------------------- CDG LOSS (TUNED) ----------------------
class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=0.01, k=1.0, gamma_min=0.5, gamma_max=4.0, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)
        log_tau_inv = torch.log(torch.tensor(1.0 / tau, dtype=p.dtype, device=p.device))
        branch2 = log_tau_inv + k * (p - tau)
        raw = torch.where(p > tau, torch.log(1.0 / p), branch2)
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
    
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
    
    def forward(self, logits, y, epoch=None, Ew=10):
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, y[:, None]).squeeze()
        pt = torch.exp(log_pt).clamp(min=self.eps, max=1.0 - self.eps)
        g = self.gamma_per_class.to(logits.device)[y]
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        loss = - ((1.0 - pt) ** g) * log_pt
        return loss.mean()

# ---------------------- SPLIT ACCURACY ----------------------
def compute_head_mid_tail_acc(preds, targets, dataset_name, imb_factor):
    """
    Compute Head/Mid/Tail accuracy for CIFAR-10-LT (IF=100)
    Standard split: Head=0-2, Mid=3-5, Tail=6-9
    """
    if dataset_name.lower() != 'cifar10' or imb_factor != 100:
        return None, None, None

    preds, targets = np.array(preds), np.array(targets)
    
    head_classes = [0, 1, 2]
    mid_classes = [3, 4, 5]
    tail_classes = [6, 7, 8, 9]
    
    def acc_for_classes(classes):
        mask = np.isin(targets, classes)
        if mask.sum() == 0:
            return np.nan
        return (preds[mask] == targets[mask]).mean()
    
    head_acc = acc_for_classes(head_classes)
    mid_acc = acc_for_classes(mid_classes)
    tail_acc = acc_for_classes(tail_classes)
    
    return head_acc, mid_acc, tail_acc

# ---------------------- DATASET PREP ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

def get_sampler_for_imbalance(targets, num_classes):
    counts = np.bincount(targets, minlength=num_classes)
    weight_per_class = 1.0 / (counts + 1e-6)
    weights = weight_per_class[targets]
    return WeightedRandomSampler(weights, len(weights), replacement=True)

def prepare_dataset(name, root='./data', imb_factor=1, seed=0):
    name_l = name.lower()
    if name_l in ('cifar10', 'cifar100'):
        mean = (0.4914, 0.4822, 0.4465) if name_l == 'cifar10' else (0.5071, 0.4867, 0.4408)
        std = (0.2023, 0.1994, 0.2010) if name_l == 'cifar10' else (0.2675, 0.2565, 0.2761)
        train_tr = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_tr = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        train_tr = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            lambda x: x.repeat(3,1,1) if x.size(0) == 1 else x
        ])
        test_tr = train_tr

    if name_l == 'mnist':
        base_train = datasets.MNIST(root, train=True, download=True, transform=train_tr)
        test = datasets.MNIST(root, train=False, download=True, transform=test_tr)
        C = 10
    elif name_l == 'fashionmnist':
        base_train = datasets.FashionMNIST(root, train=True, download=True, transform=train_tr)
        test = datasets.FashionMNIST(root, train=False, download=True, transform=test_tr)
        C = 10
    elif name_l == 'svhn':
        base_train = datasets.SVHN(root, split='train', download=True, transform=train_tr)
        test = datasets.SVHN(root, split='test', download=True, transform=test_tr)
        C = 10
    elif name_l == 'cifar10':
        base_train = datasets.CIFAR10(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR10(root, train=False, download=True, transform=test_tr)
        C = 10
    elif name_l == 'cifar100':
        base_train = datasets.CIFAR100(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR100(root, train=False, download=True, transform=test_tr)
        C = 100
    else:
        raise ValueError(f"Unsupported: {name}")

    if name_l in ('cifar10','cifar100') and imb_factor > 1:
        targets = np.array(base_train.targets)
        indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
        train = Subset(base_train, indices)
        train_targets = np.array([base_train.targets[i] for i in indices])
        counts = cls_counts
    else:
        train = base_train
        if hasattr(base_train, 'targets'):
            train_targets = np.array(base_train.targets)
        elif hasattr(base_train, 'labels'):
            train_targets = np.array(base_train.labels)
        else:
            train_targets = np.array([base_train[i][1] for i in range(len(base_train))])
        counts = np.bincount(train_targets, minlength=C).tolist()

    print(f"  ‚Üí {name} (IF={imb_factor}): train={len(train)}, test={len(test)}, classes={C}")
    return train, test, counts, train_targets

# ---------------------- MODEL & TRAINING ----------------------
def get_model(num_classes):
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_one_epoch(model, loss_fn, loader, opt, epoch=None, Ew=10):
    model.train()
    total_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad()
        logits = model(xb)
        loss = loss_fn(logits, yb, epoch=epoch, Ew=Ew)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate_model_with_preds(model, loader):
    """Returns predictions and targets for detailed analysis"""
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    return np.array(preds), np.array(tg)

def evaluate_model(preds, tg):
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

# ---------------------- CONFIG ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

EPOCHS = 100
PATIENCE = 10
LR = 0.005
BATCH = 128
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# ‚úÖ Focus: Include CIFAR-10-LT (IF=100) for Head/Mid/Tail
DATASETS_TO_RUN = [
    ('cifar10', 100),  # ‚Üê Main focus
    # Add others if needed, but Head/Mid/Tail only computed for this
]

summary_rows = []
print("üöÄ CDG training with Head/Mid/Tail accuracy for CIFAR-10-LT...")

# ---------------------- MAIN LOOP ----------------------
for ds_name, IF in DATASETS_TO_RUN:
    print("\n" + "="*60)
    print(f"üìÅ {ds_name} | IF={IF}")
    
    try:
        train_ds, test_ds, counts, train_targets = prepare_dataset(ds_name, root='./data', imb_factor=IF, seed=SEED)
    except Exception as e:
        print(f"‚ùå Failed to load {ds_name}: {e}")
        continue

    num_classes = len(counts)
    
    if IF > 1:
        sampler = get_sampler_for_imbalance(train_targets, num_classes)
        train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler, num_workers=2, pin_memory=True)
    else:
        train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
    
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    print(f"  üîç Training with CDG (tau=0.01, k=1.0, gamma‚àà[0.5,4.0])")
    torch.manual_seed(SEED)
    
    model = get_model(num_classes).to(DEVICE)
    loss_fn = CDG_Focal(counts, tau=0.01, k=1.0, gamma_min=0.5, gamma_max=4.0)
    opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

    best_val_acc = 0.0
    epochs_no_improve = 0
    best_model_path = os.path.join(OUT, f"{ds_name}_IF{IF}_CDG_tuned_best.pth")
    np.save(os.path.join(OUT, f"{ds_name}_IF{IF}_CDG_tuned_gamma.npy"), loss_fn.gamma_per_class.cpu().numpy())

    for ep in range(EPOCHS):
        t0 = time.time()
        train_loss = train_one_epoch(model, loss_fn, train_loader, opt, epoch=ep, Ew=10)
        
        # Evaluate with predictions
        preds, tg = evaluate_model_with_preds(model, test_loader)
        val_acc, macrof1 = evaluate_model(preds, tg)
        
        # Compute Head/Mid/Tail only for CIFAR-10-LT
        head_acc, mid_acc, tail_acc = compute_head_mid_tail_acc(preds, tg, ds_name, IF)
        
        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_macrof1 = macrof1
            best_head, best_mid, best_tail = head_acc, mid_acc, tail_acc
            best_epoch = ep
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            improved = "‚úÖ"
        else:
            epochs_no_improve += 1
            improved = "  "
        
        # Log Head/Mid/Tail if available
        hmt_log = ""
        if head_acc is not None:
            hmt_log = f" | H:{head_acc:.3f} M:{mid_acc:.3f} T:{tail_acc:.3f}"
        
        print(f"    Ep {ep+1:3d}/{EPOCHS} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f} | F1: {macrof1:.4f}{hmt_log} | {time.time()-t0:.1f}s {improved}")
        
        if epochs_no_improve >= PATIENCE:
            print(f"    ‚èπÔ∏è Early stopping at epoch {ep+1}. Best: {best_epoch+1}")
            break

    # Final evaluation
    model.load_state_dict(torch.load(best_model_path))
    final_preds, final_tg = evaluate_model_with_preds(model, test_loader)
    final_acc, final_f1 = evaluate_model(final_preds, final_tg)
    head_acc, mid_acc, tail_acc = compute_head_mid_tail_acc(final_preds, final_tg, ds_name, IF)
    
    print(f"    üèÜ Final (best) | Acc: {final_acc:.4f} | F1: {final_f1:.4f} | H:{head_acc:.3f} M:{mid_acc:.3f} T:{tail_acc:.3f}")

    summary_rows.append({
        'dataset': ds_name,
        'IF': IF,
        'loss': 'CDG_tuned',
        'val_acc': final_acc,
        'macro_f1': final_f1,
        'head_acc': head_acc,
        'mid_acc': mid_acc,
        'tail_acc': tail_acc,
        'best_epoch': best_epoch + 1
    })

# ---------------------- SAVE ----------------------
summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "summary_cifar10_lt_headmidtail.csv")
summary_df.to_csv(summary_path, index=False)
print(f"\n‚úÖ Head/Mid/Tail results saved to:\n{summary_path}")

Device: cuda
üöÄ CDG training with Head/Mid/Tail accuracy for CIFAR-10-LT...

üìÅ cifar10 | IF=100
  ‚Üí cifar10 (IF=100): train=12408, test=10000, classes=10
  üîç Training with CDG (tau=0.01, k=1.0, gamma‚àà[0.5,4.0])
    Ep   1/100 | Loss: 2.2315 | Acc: 0.2588 | F1: 0.2502 | H:0.258 M:0.193 T:0.309 | 7.7s ‚úÖ
    Ep   2/100 | Loss: 1.8858 | Acc: 0.3062 | F1: 0.3011 | H:0.412 M:0.208 T:0.300 | 7.6s ‚úÖ
    Ep   3/100 | Loss: 1.6772 | Acc: 0.3528 | F1: 0.3492 | H:0.399 M:0.288 T:0.367 | 7.6s ‚úÖ
    Ep   4/100 | Loss: 1.4726 | Acc: 0.3735 | F1: 0.3658 | H:0.431 M:0.294 T:0.390 | 7.7s ‚úÖ
    Ep   5/100 | Loss: 1.2771 | Acc: 0.3930 | F1: 0.3891 | H:0.441 M:0.313 T:0.417 | 7.6s ‚úÖ
    Ep   6/100 | Loss: 1.1200 | Acc: 0.4046 | F1: 0.3989 | H:0.518 M:0.322 T:0.382 | 7.8s ‚úÖ
    Ep   7/100 | Loss: 0.9828 | Acc: 0.3967 | F1: 0.3860 | H:0.564 M:0.275 T:0.363 | 7.6s   
    Ep   8/100 | Loss: 0.8632 | Acc: 0.4110 | F1: 0.4031 | H:0.557 M:0.315 T:0.374 | 7.6s ‚úÖ
    Ep   9/100 | Loss: 0.7

In [27]:
# ‚úÖ All losses on CIFAR-10-LT with Head/Mid/Tail accuracy
import os, sys, time, math, random
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---------------------- LOSS FUNCTIONS ----------------------
class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        return (-((1-pt)**self.gamma) * torch.log(pt)).mean()

def effective_num_weights(counts, beta):
    counts = np.array(counts, dtype=np.float64)
    eff = (1.0 - np.power(beta, counts)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(w)
    return w.astype(np.float32)

class CB_Focal(nn.Module):
    def __init__(self, counts, beta=0.999, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.register_buffer('weights', torch.tensor(effective_num_weights(counts, beta), dtype=torch.float32))
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        w = self.weights.to(logits.device)[y]
        return (- w * ((1-pt)**self.gamma) * torch.log(pt)).mean()

class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=0.01, k=1.0, gamma_min=0.5, gamma_max=4.0, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)
        log_tau_inv = torch.log(torch.tensor(1.0 / tau, dtype=p.dtype, device=p.device))
        branch2 = log_tau_inv + k * (p - tau)
        raw = torch.where(p > tau, torch.log(1.0 / p), branch2)
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
    
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
    
    def forward(self, logits, y, epoch=None, Ew=10):
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, y[:, None]).squeeze()
        pt = torch.exp(log_pt).clamp(min=self.eps, max=1.0 - self.eps)
        g = self.gamma_per_class.to(logits.device)[y]
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        loss = - ((1.0 - pt) ** g) * log_pt
        return loss.mean()

# ---------------------- HEAD/MID/TAIL ACCURACY ----------------------
def compute_head_mid_tail_acc(preds, targets):
    """
    CIFAR-10-LT (IF=100) standard split:
    - Head: classes 0,1,2 (most frequent)
    - Mid:  classes 3,4,5
    - Tail: classes 6,7,8,9 (rarest)
    """
    preds, targets = np.array(preds), np.array(targets)
    
    head_classes = [0, 1, 2]
    mid_classes = [3, 4, 5]
    tail_classes = [6, 7, 8, 9]
    
    def acc_for_classes(classes):
        mask = np.isin(targets, classes)
        if mask.sum() == 0:
            return np.nan
        return (preds[mask] == targets[mask]).mean()
    
    return (
        acc_for_classes(head_classes),
        acc_for_classes(mid_classes),
        acc_for_classes(tail_classes)
    )

# ---------------------- DATASET PREP ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

def get_sampler_for_imbalance(targets, num_classes):
    counts = np.bincount(targets, minlength=num_classes)
    weight_per_class = 1.0 / (counts + 1e-6)
    weights = weight_per_class[targets]
    return WeightedRandomSampler(weights, len(weights), replacement=True)

def prepare_cifar10_lt(root='./data', imb_factor=100, seed=0):
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    train_tr = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_tr = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    
    base_train = datasets.CIFAR10(root, train=True, download=True, transform=train_tr)
    test = datasets.CIFAR10(root, train=False, download=True, transform=test_tr)
    
    targets = np.array(base_train.targets)
    indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
    train = Subset(base_train, indices)
    train_targets = np.array([base_train.targets[i] for i in indices])
    counts = cls_counts
    
    print(f"  ‚Üí CIFAR-10-LT (IF={imb_factor}): train={len(train)}, test={len(test)}, classes=10")
    return train, test, counts, train_targets

# ---------------------- MODEL & TRAINING ----------------------
def get_model(num_classes):
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_one_epoch(model, loss_fn, loader, opt, epoch=None, loss_name=None):
    model.train()
    total_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad()
        logits = model(xb)
        if loss_name == 'CDG':
            loss = loss_fn(logits, yb, epoch=epoch, Ew=10)
        else:
            loss = loss_fn(logits, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate_model_with_preds(model, loader):
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    return np.array(preds), np.array(tg)

def evaluate_model(preds, tg):
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

# ---------------------- CONFIG ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

LOSSES = {
    'CE': lambda counts: nn.CrossEntropyLoss(),
    'Focal_g1': lambda counts: FocalLoss(gamma=1.0),
    'CBF_b0.999_g1': lambda counts: CB_Focal(counts, beta=0.999, gamma=1.0),
    'CDG_tuned': lambda counts: CDG_Focal(counts, tau=0.01, k=1.0, gamma_min=0.5, gamma_max=4.0)
}

EPOCHS = 100
PATIENCE = 10
LR = 0.005  # Stable for all
BATCH = 128
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

summary_rows = []
print("üöÄ Evaluating ALL losses on CIFAR-10-LT (IF=100) with Head/Mid/Tail accuracy...")

# ---------------------- MAIN LOOP ----------------------
ds_name, IF = 'cifar10', 100
print("\n" + "="*60)
print(f"üìÅ {ds_name} | IF={IF}")

train_ds, test_ds, counts, train_targets = prepare_cifar10_lt(root='./data', imb_factor=IF, seed=SEED)
num_classes = len(counts)

# Use class-balanced sampler for fair comparison
sampler = get_sampler_for_imbalance(train_targets, num_classes)
train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

for loss_name, loss_ctor in LOSSES.items():
    print(f"\n  üîç Training with {loss_name}")
    torch.manual_seed(SEED)
    
    model = get_model(num_classes).to(DEVICE)
    loss_fn = loss_ctor(counts)
    opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

    best_val_acc = 0.0
    epochs_no_improve = 0
    best_model_path = os.path.join(OUT, f"cifar10_IF100_{loss_name}_best.pth")

    # Save CDG gamma if applicable
    if 'CDG' in loss_name and hasattr(loss_fn, 'gamma_per_class'):
        np.save(os.path.join(OUT, f"cifar10_IF100_{loss_name}_gamma.npy"), loss_fn.gamma_per_class.cpu().numpy())

    for ep in range(EPOCHS):
        t0 = time.time()
        train_loss = train_one_epoch(model, loss_fn, train_loader, opt, epoch=ep, loss_name=loss_name)
        
        preds, tg = evaluate_model_with_preds(model, test_loader)
        val_acc, macrof1 = evaluate_model(preds, tg)
        head_acc, mid_acc, tail_acc = compute_head_mid_tail_acc(preds, tg)
        
        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_macrof1 = macrof1
            best_head, best_mid, best_tail = head_acc, mid_acc, tail_acc
            best_epoch = ep
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            improved = "‚úÖ"
        else:
            epochs_no_improve += 1
            improved = "  "
        
        print(f"    Ep {ep+1:3d}/{EPOCHS} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f} | F1: {macrof1:.4f} | "
              f"H:{head_acc:.3f} M:{mid_acc:.3f} T:{tail_acc:.3f} | {time.time()-t0:.1f}s {improved}")
        
        if epochs_no_improve >= PATIENCE:
            print(f"    ‚èπÔ∏è Early stopping at epoch {ep+1}. Best: {best_epoch+1}")
            break

    # Final evaluation
    model.load_state_dict(torch.load(best_model_path))
    final_preds, final_tg = evaluate_model_with_preds(model, test_loader)
    final_acc, final_f1 = evaluate_model(final_preds, final_tg)
    head_acc, mid_acc, tail_acc = compute_head_mid_tail_acc(final_preds, final_tg)
    
    print(f"    üèÜ Final (best) | Acc: {final_acc:.4f} | F1: {final_f1:.4f} | "
          f"H:{head_acc:.3f} M:{mid_acc:.3f} T:{tail_acc:.3f}")

    summary_rows.append({
        'loss': loss_name,
        'val_acc': final_acc,
        'macro_f1': final_f1,
        'head_acc': head_acc,
        'mid_acc': mid_acc,
        'tail_acc': tail_acc,
        'best_epoch': best_epoch + 1
    })

# ---------------------- SAVE RESULTS ----------------------
summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "cifar10_lt_all_losses_headmidtail.csv")
summary_df.to_csv(summary_path, index=False)
print(f"\n‚úÖ All losses evaluated! Results saved to:\n{summary_path}")

# Display summary
print("\n" + "="*80)
print("SUMMARY: CIFAR-10-LT (IF=100) Head/Mid/Tail Accuracies")
print("="*80)
print(summary_df.to_string(index=False, float_format="%.4f"))

Device: cuda
üöÄ Evaluating ALL losses on CIFAR-10-LT (IF=100) with Head/Mid/Tail accuracy...

üìÅ cifar10 | IF=100
  ‚Üí CIFAR-10-LT (IF=100): train=12408, test=10000, classes=10

  üîç Training with CE
    Ep   1/100 | Loss: 2.2343 | Acc: 0.2560 | F1: 0.2465 | H:0.270 M:0.187 T:0.297 | 7.7s ‚úÖ
    Ep   2/100 | Loss: 1.9296 | Acc: 0.3075 | F1: 0.3030 | H:0.401 M:0.212 T:0.309 | 7.6s ‚úÖ
    Ep   3/100 | Loss: 1.7913 | Acc: 0.3400 | F1: 0.3365 | H:0.396 M:0.274 T:0.348 | 7.7s ‚úÖ
    Ep   4/100 | Loss: 1.6907 | Acc: 0.3682 | F1: 0.3600 | H:0.407 M:0.285 T:0.401 | 7.6s ‚úÖ
    Ep   5/100 | Loss: 1.5869 | Acc: 0.3792 | F1: 0.3744 | H:0.406 M:0.296 T:0.422 | 7.6s ‚úÖ
    Ep   6/100 | Loss: 1.5057 | Acc: 0.3952 | F1: 0.3913 | H:0.512 M:0.302 T:0.378 | 7.5s ‚úÖ
    Ep   7/100 | Loss: 1.4343 | Acc: 0.4052 | F1: 0.3965 | H:0.524 M:0.290 T:0.402 | 7.5s ‚úÖ
    Ep   8/100 | Loss: 1.3613 | Acc: 0.4195 | F1: 0.4101 | H:0.521 M:0.306 T:0.429 | 7.6s ‚úÖ
    Ep   9/100 | Loss: 1.2917 | Acc: 0.42

KeyboardInterrupt: 

In [None]:
# ‚úÖ FULL DATASETS + All Losses + Head/Mid/Tail for CIFAR-10-LT only
import os, sys, time, math, random
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---------------------- LOSS FUNCTIONS ----------------------
class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        return (-((1-pt)**self.gamma) * torch.log(pt)).mean()

def effective_num_weights(counts, beta):
    counts = np.array(counts, dtype=np.float64)
    eff = (1.0 - np.power(beta, counts)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(w)
    return w.astype(np.float32)

class CB_Focal(nn.Module):
    def __init__(self, counts, beta=0.999, gamma=1.0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.register_buffer('weights', torch.tensor(effective_num_weights(counts, beta), dtype=torch.float32))
    def forward(self, logits, y):
        p = F.softmax(logits, dim=1)
        pt = p.gather(1, y[:,None]).squeeze().clamp(min=self.eps)
        w = self.weights.to(logits.device)[y]
        return (- w * ((1-pt)**self.gamma) * torch.log(pt)).mean()

class CDG_Focal(nn.Module):
    def __init__(self, counts, tau=0.01, k=1.0, gamma_min=0.5, gamma_max=4.0, eps=1e-7):
        super().__init__()
        self.eps = eps
        counts = torch.tensor(counts, dtype=torch.float32)
        p = (counts / counts.sum()).clamp(min=1e-12)
        log_tau_inv = torch.log(torch.tensor(1.0 / tau, dtype=p.dtype, device=p.device))
        branch2 = log_tau_inv + k * (p - tau)
        raw = torch.where(p > tau, torch.log(1.0 / p), branch2)
        gamma = raw.clamp(min=gamma_min, max=gamma_max)
        self.register_buffer('gamma_per_class', gamma)
    
    @staticmethod
    def cosine_warmup_weight(epoch, Ew):
        if Ew <= 0: return 1.0
        e = float(epoch)
        if e <= 0.0: return 0.0
        if e >= Ew: return 1.0
        return 0.5 * (1.0 - math.cos(math.pi * e / Ew))
    
    def forward(self, logits, y, epoch=None, Ew=10):
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, y[:, None]).squeeze()
        pt = torch.exp(log_pt).clamp(min=self.eps, max=1.0 - self.eps)
        g = self.gamma_per_class.to(logits.device)[y]
        if epoch is not None:
            w = self.cosine_warmup_weight(epoch, Ew)
            g = g * w
        loss = - ((1.0 - pt) ** g) * log_pt
        return loss.mean()

# ---------------------- HEAD/MID/TAIL (CIFAR-10-LT ONLY) ----------------------
def compute_head_mid_tail_acc(preds, targets, dataset_name, imb_factor):
    """
    Only compute for CIFAR-10-LT (IF=100)
    Head: 0-2, Mid: 3-5, Tail: 6-9
    """
    if dataset_name.lower() == 'cifar10' and imb_factor == 100:
        preds, targets = np.array(preds), np.array(targets)
        head_classes = [0, 1, 2]
        mid_classes = [3, 4, 5]
        tail_classes = [6, 7, 8, 9]
        
        def acc_for_classes(classes):
            mask = np.isin(targets, classes)
            return (preds[mask] == targets[mask]).mean() if mask.sum() > 0 else np.nan
        
        return (
            acc_for_classes(head_classes),
            acc_for_classes(mid_classes),
            acc_for_classes(tail_classes)
        )
    return None, None, None

# ---------------------- DATASET PREP (ALL) ----------------------
def make_lt_indices(targets, imb_factor, seed=0):
    np.random.seed(seed)
    targets = np.array(targets)
    C = int(targets.max()) + 1
    cls_counts = np.bincount(targets, minlength=C)
    N_max = cls_counts.max()
    r = 1.0 / float(imb_factor)
    cls_num = [int(max(1, round(N_max * (r ** (i / (C - 1.0)))))) for i in range(C)]
    indices = []
    for c in range(C):
        idxs = np.where(targets == c)[0]
        chosen = np.random.choice(idxs, cls_num[c], replace=False)
        indices.extend(chosen.tolist())
    random.shuffle(indices)
    return indices, cls_num

def get_sampler_for_imbalance(targets, num_classes):
    counts = np.bincount(targets, minlength=num_classes)
    weight_per_class = 1.0 / (counts + 1e-6)
    weights = weight_per_class[targets]
    return WeightedRandomSampler(weights, len(weights), replacement=True)

def prepare_dataset(name, root='./data', imb_factor=1, seed=0):
    name_l = name.lower()
    if name_l in ('cifar10', 'cifar100'):
        mean = (0.4914, 0.4822, 0.4465) if name_l == 'cifar10' else (0.5071, 0.4867, 0.4408)
        std = (0.2023, 0.1994, 0.2010) if name_l == 'cifar10' else (0.2675, 0.2565, 0.2761)
        train_tr = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_tr = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        train_tr = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            lambda x: x.repeat(3,1,1) if x.size(0) == 1 else x
        ])
        test_tr = train_tr

    if name_l == 'mnist':
        base_train = datasets.MNIST(root, train=True, download=True, transform=train_tr)
        test = datasets.MNIST(root, train=False, download=True, transform=test_tr)
        C = 10
    elif name_l == 'fashionmnist':
        base_train = datasets.FashionMNIST(root, train=True, download=True, transform=train_tr)
        test = datasets.FashionMNIST(root, train=False, download=True, transform=test_tr)
        C = 10
    elif name_l == 'svhn':
        base_train = datasets.SVHN(root, split='train', download=True, transform=train_tr)
        test = datasets.SVHN(root, split='test', download=True, transform=test_tr)
        C = 10
    elif name_l == 'cifar10':
        base_train = datasets.CIFAR10(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR10(root, train=False, download=True, transform=test_tr)
        C = 10
    elif name_l == 'cifar100':
        base_train = datasets.CIFAR100(root, train=True, download=True, transform=train_tr)
        test = datasets.CIFAR100(root, train=False, download=True, transform=test_tr)
        C = 100
    else:
        raise ValueError(f"Unsupported: {name}")

    if name_l in ('cifar10','cifar100') and imb_factor > 1:
        targets = np.array(base_train.targets)
        indices, cls_counts = make_lt_indices(targets, imb_factor, seed)
        train = Subset(base_train, indices)
        train_targets = np.array([base_train.targets[i] for i in indices])
        counts = cls_counts
    else:
        train = base_train
        if hasattr(base_train, 'targets'):
            train_targets = np.array(base_train.targets)
        elif hasattr(base_train, 'labels'):
            train_targets = np.array(base_train.labels)
        else:
            train_targets = np.array([base_train[i][1] for i in range(len(base_train))])
        counts = np.bincount(train_targets, minlength=C).tolist()

    print(f"  ‚Üí {name} (IF={imb_factor}): train={len(train)}, test={len(test)}, classes={C}")
    return train, test, counts, train_targets

# ---------------------- MODEL & TRAINING ----------------------
def get_model(num_classes):
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def train_one_epoch(model, loss_fn, loader, opt, epoch=None, loss_name=None):
    model.train()
    total_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad()
        logits = model(xb)
        if loss_name == 'CDG':
            loss = loss_fn(logits, yb, epoch=epoch, Ew=10)
        else:
            loss = loss_fn(logits, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate_model_with_preds(model, loader):
    model.eval()
    preds, tg = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE)
            out = model(xb)
            preds.extend(out.argmax(1).cpu().numpy())
            tg.extend(yb.numpy())
    return np.array(preds), np.array(tg)

def evaluate_model(preds, tg):
    acc = (preds == tg).mean()
    macro_f1 = f1_score(tg, preds, average='macro', zero_division=0)
    return acc, macro_f1

# ---------------------- CONFIG ----------------------
OUT = "/kaggle/working/loss_eval_results"
os.makedirs(OUT, exist_ok=True)

LOSSES = {
    'CE': lambda counts: nn.CrossEntropyLoss(),
    'Focal_g1': lambda counts: FocalLoss(gamma=1.0),
    'CBF_b0.999_g1': lambda counts: CB_Focal(counts, beta=0.999, gamma=1.0),
    'CDG_tuned': lambda counts: CDG_Focal(counts, tau=0.01, k=1.0, gamma_min=0.5, gamma_max=4.0)
}

EPOCHS = 100
PATIENCE = 10
LR = 0.005
BATCH = 128
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# ‚úÖ FULL DATASET LIST
DATASETS_TO_RUN = [
    ('mnist', 1),
    ('mnist', 100),
    ('fashionmnist', 1),
    ('fashionmnist', 100),
    ('svhn', 1),
    ('svhn', 100),
    ('cifar10', 1),
    ('cifar10', 100),   # ‚Üê Head/Mid/Tail computed here
    ('cifar100', 1),
    ('cifar100', 100)
]

summary_rows = []
print("üöÄ Full evaluation: all datasets + all losses + H/M/T for CIFAR-10-LT...")

# ---------------------- MAIN LOOP ----------------------
for ds_name, IF in DATASETS_TO_RUN:
    print("\n" + "="*60)
    print(f"üìÅ {ds_name} | IF={IF}")
    
    try:
        train_ds, test_ds, counts, train_targets = prepare_dataset(ds_name, root='./data', imb_factor=IF, seed=SEED)
    except Exception as e:
        print(f"‚ùå Failed to load {ds_name}: {e}")
        continue

    num_classes = len(counts)
    
    # Use class-balanced sampler only for imbalanced cases (IF > 1)
    if IF > 1:
        sampler = get_sampler_for_imbalance(train_targets, num_classes)
        train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler, num_workers=2, pin_memory=True)
    else:
        train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
    
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    for loss_name, loss_ctor in LOSSES.items():
        print(f"\n  üîç {loss_name}")
        torch.manual_seed(SEED)
        
        model = get_model(num_classes).to(DEVICE)
        loss_fn = loss_ctor(counts)
        opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

        best_val_acc = 0.0
        epochs_no_improve = 0
        best_model_path = os.path.join(OUT, f"{ds_name}_IF{IF}_{loss_name}_best.pth")

        if 'CDG' in loss_name and hasattr(loss_fn, 'gamma_per_class'):
            np.save(os.path.join(OUT, f"{ds_name}_IF{IF}_{loss_name}_gamma.npy"), loss_fn.gamma_per_class.cpu().numpy())

        for ep in range(EPOCHS):
            t0 = time.time()
            train_loss = train_one_epoch(model, loss_fn, train_loader, opt, epoch=ep, loss_name=loss_name)
            
            preds, tg = evaluate_model_with_preds(model, test_loader)
            val_acc, macrof1 = evaluate_model(preds, tg)
            head_acc, mid_acc, tail_acc = compute_head_mid_tail_acc(preds, tg, ds_name, IF)
            
            scheduler.step()

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_macrof1 = macrof1
                best_head, best_mid, best_tail = head_acc, mid_acc, tail_acc
                best_epoch = ep
                epochs_no_improve = 0
                torch.save(model.state_dict(), best_model_path)
                improved = "‚úÖ"
            else:
                epochs_no_improve += 1
                improved = "  "
            
            # Log H/M/T only for CIFAR-10-LT
            hmt_log = ""
            if head_acc is not None:
                hmt_log = f" | H:{head_acc:.3f} M:{mid_acc:.3f} T:{tail_acc:.3f}"
            
            print(f"    Ep {ep+1:3d}/{EPOCHS} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f} | F1: {macrof1:.4f}{hmt_log} | {time.time()-t0:.1f}s {improved}")
            
            if epochs_no_improve >= PATIENCE:
                print(f"    ‚èπÔ∏è Early stopping at epoch {ep+1}. Best: {best_epoch+1}")
                break

        # Final evaluation
        model.load_state_dict(torch.load(best_model_path))
        final_preds, final_tg = evaluate_model_with_preds(model, test_loader)
        final_acc, final_f1 = evaluate_model(final_preds, final_tg)
        head_acc, mid_acc, tail_acc = compute_head_mid_tail_acc(final_preds, final_tg, ds_name, IF)

        # Build result row
        result = {
            'dataset': ds_name,
            'IF': IF,
            'loss': loss_name,
            'val_acc': final_acc,
            'macro_f1': final_f1,
            'best_epoch': best_epoch + 1
        }
        # Add H/M/T only if available
        if head_acc is not None:
            result.update({
                'head_acc': head_acc,
                'mid_acc': mid_acc,
                'tail_acc': tail_acc
            })
        
        summary_rows.append(result)

# ---------------------- SAVE ----------------------
summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(OUT, "full_results_with_headmidtail.csv")
summary_df.to_csv(summary_path, index=False)
print(f"\n‚úÖ Full evaluation complete! Results saved to:\n{summary_path}")

# Optional: Print only CIFAR-10-LT H/M/T summary
cifar10_lt = summary_df[(summary_df['dataset'] == 'cifar10') & (summary_df['IF'] == 100)]
if not cifar10_lt.empty:
    print("\n" + "="*80)
    print("CIFAR-10-LT (IF=100) Head/Mid/Tail Summary")
    print("="*80)
    cols = ['loss', 'val_acc', 'macro_f1', 'head_acc', 'mid_acc', 'tail_acc']
    print(cifar10_lt[cols].to_string(index=False, float_format="%.4f"))

Device: cuda
üöÄ Full evaluation: all datasets + all losses + H/M/T for CIFAR-10-LT...

üìÅ mnist | IF=1


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9.91M/9.91M [00:00<00:00, 17.8MB/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 28.9k/28.9k [00:00<00:00, 484kB/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4.54k/4.54k [00:00<00:00, 6.53MB/s]


  ‚Üí mnist (IF=1): train=60000, test=10000, classes=10

  üîç CE
    Ep   1/100 | Loss: 0.3254 | Acc: 0.9818 | F1: 0.9816 | 30.2s ‚úÖ
    Ep   2/100 | Loss: 0.0502 | Acc: 0.9862 | F1: 0.9861 | 29.4s ‚úÖ
    Ep   3/100 | Loss: 0.0301 | Acc: 0.9881 | F1: 0.9881 | 29.3s ‚úÖ
    Ep   4/100 | Loss: 0.0177 | Acc: 0.9868 | F1: 0.9866 | 29.5s   
    Ep   5/100 | Loss: 0.0123 | Acc: 0.9912 | F1: 0.9911 | 29.4s ‚úÖ
    Ep   6/100 | Loss: 0.0072 | Acc: 0.9911 | F1: 0.9910 | 29.2s   
    Ep   7/100 | Loss: 0.0043 | Acc: 0.9920 | F1: 0.9919 | 29.3s ‚úÖ
    Ep   8/100 | Loss: 0.0029 | Acc: 0.9910 | F1: 0.9909 | 29.0s   
    Ep   9/100 | Loss: 0.0023 | Acc: 0.9917 | F1: 0.9917 | 29.2s   
    Ep  10/100 | Loss: 0.0013 | Acc: 0.9925 | F1: 0.9924 | 29.3s ‚úÖ
    Ep  11/100 | Loss: 0.0009 | Acc: 0.9927 | F1: 0.9927 | 29.3s ‚úÖ
    Ep  12/100 | Loss: 0.0006 | Acc: 0.9918 | F1: 0.9917 | 29.3s   
    Ep  13/100 | Loss: 0.0005 | Acc: 0.9925 | F1: 0.9924 | 29.1s   
    Ep  14/100 | Loss: 0.0004 | Acc: 0.992