# SAT on Long-Tailed CIFAR-10: Balanced and Worst-Group Analysis

This notebook evaluates Self-Adaptive Training (SAT) on long-tailed (imbalanced) CIFAR-10 using the following metrics:
- **Balanced Error Rate**: Average error across all classes (unweighted)
- **Worst-Group Error Rate**: Error of the worst-performing class
- **AURC (Balanced)**: Area Under Risk-Coverage curve using balanced error
- **AURC (Worst-Group)**: Area Under Risk-Coverage curve using worst-group error

**Settings:**
- Architecture: VGG-16 with Batch Normalization
- Dataset: CIFAR-10 Long-Tailed (Imbalance Ratio configurable)
- Loss: SAT with momentum 0.99
- Training: 300 epochs

## 1. Setup and Imports

In [None]:
import os
import time
import random
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from PIL import Image

# For loading CIFAR-10-LT dataset
import datasets
from datasets import load_dataset

# Local imports
import models.cifar as models
from loss import SelfAdativeTraining
from utils import AverageMeter, accuracy

# Report library versions and GPU availability for this run.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # NOTE(review): querying the device name presumably initialises CUDA in this
    # process, so the CUDA_VISIBLE_DEVICES assignment in the configuration cell
    # may have no effect; set it before any torch.cuda call if GPU selection matters.
    print("CUDA device:", torch.cuda.get_device_name(0))

# Visualization settings: dark-grid style and "husl" palette for every figure below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

## 2. Long-Tailed Dataset Loading with HuggingFace Datasets

In [None]:
# Wrapper for CIFAR-10-LT to work with PyTorch DataLoader
class CIFAR10LTWrapper(torch.utils.data.Dataset):
    """Expose a HuggingFace CIFAR-10-LT split as a PyTorch ``Dataset``.

    Each item is returned as ``(image, label, index)``. The sample index is
    forwarded to the SAT criterion together with the batch (see ``train_epoch``).
    """

    def __init__(self, hf_dataset, transform=None):
        """
        Args:
            hf_dataset: HuggingFace dataset (or any indexable sequence of dicts)
                whose rows have ``'img'`` and ``'label'`` fields.
            transform: optional torchvision transform applied to each image.
        """
        self.dataset = hf_dataset
        self.transform = transform
        self.targets = self._load_labels(hf_dataset)

    @staticmethod
    def _load_labels(hf_dataset):
        """Return every label of ``hf_dataset`` as a NumPy array."""
        try:
            # Column access reads only the label column; iterating rows would
            # materialise every row, including its image, just to read the label.
            labels = hf_dataset['label']
        except (KeyError, TypeError):
            # Plain sequences of dicts (e.g. a list) have no column access.
            labels = [item['label'] for item in hf_dataset]
        return np.asarray(list(labels))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_image, label, index)`` for sample ``index``."""
        item = self.dataset[index]
        img = item['img']
        target = item['label']

        # Convert to PIL Image if needed (torchvision transforms expect PIL input)
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))

        if self.transform is not None:
            img = self.transform(img)

        return img, target, index

    def get_class_distribution(self):
        """Return ``{class_label: sample_count}`` for the classes present in the split."""
        unique, counts = np.unique(self.targets, return_counts=True)
        return {int(label): int(count) for label, count in zip(unique, counts)}

print("CIFAR-10-LT wrapper class defined.")

## 3. Configuration

In [None]:
# Configuration
class Config:
    """Hyper-parameters and paths for the long-tailed SAT experiment.

    Edit the class attributes (not the instance) before running the notebook:
    ``dataset_name`` and ``save_dir`` are derived from ``imbalance_ratio`` when
    the class body executes, so the dataset, the checkpoint directory and the
    reported IR always agree.
    """
    # Dataset settings
    dataset = 'cifar10'
    imbalance_ratio = 100  # Majority/minority class ratio (10, 20, 50, 100, 200)
    # HuggingFace config name: r-10, r-20, r-50, r-100, r-200 (derived from the ratio)
    dataset_name = f'r-{imbalance_ratio}'

    # Architecture
    arch = 'vgg16_bn'

    # Training settings
    loss_type = 'sat'
    epochs = 300
    pretrain_epochs = 0  # epochs of plain cross-entropy before switching to the SAT criterion

    # Hyperparameters (same as balanced SAT)
    batch_size_train = 128
    batch_size_test = 200
    lr = 0.1  # base learning rate
    momentum = 0.9  # SGD momentum
    sat_momentum = 0.99  # passed to the SAT criterion as ``mom``
    weight_decay = 5e-4
    gamma = 0.5  # LR multiplier applied at each milestone in ``schedule``
    schedule = [25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275]

    # Evaluation settings: coverage levels (percent) for the selective-classification analysis
    expected_coverage = [100., 99., 98., 97., 95., 90., 85., 80., 75., 70., 60., 50., 40., 30., 20., 10.]

    # System settings
    gpu_id = '0'
    num_workers = 4
    manual_seed = 42
    save_dir = f'./checkpoints/sat_longtailed_ir{imbalance_ratio}'

config = Config()

# Set random seeds for Python, NumPy, PyTorch CPU and (if present) all CUDA devices.
# NOTE(review): cudnn.benchmark is enabled in the model cell and cudnn.deterministic
# is never set, so GPU runs are presumably not bit-for-bit reproducible — confirm if needed.
random.seed(config.manual_seed)
torch.manual_seed(config.manual_seed)
np.random.seed(config.manual_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config.manual_seed)

# Set GPU
# NOTE(review): this only takes effect if CUDA has not been initialised yet. The
# setup cell already calls torch.cuda.get_device_name(0), which presumably
# initialises CUDA, so this is likely a no-op; move it before any torch.cuda call.
os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id
use_cuda = torch.cuda.is_available()

# Create save directory (checkpoints, figures and reports are written here)
os.makedirs(config.save_dir, exist_ok=True)

print(f"Configuration:")
print(f"  Dataset: Long-Tailed CIFAR-10 (dataset_name={config.dataset_name}, IR={config.imbalance_ratio})")
print(f"  Architecture: {config.arch}")
print(f"  Loss: {config.loss_type}")
print(f"  Epochs: {config.epochs}")
print(f"  SAT Momentum: {config.sat_momentum}")
print(f"  Save directory: {config.save_dir}")

## 4. Dataset Preparation

In [None]:
# Prepare datasets
# Defines the globals num_classes, input_size, trainset/testset, trainloader/testloader,
# train_dist and class_names that later cells rely on.
print(f'Preparing Long-Tailed CIFAR-10 from HuggingFace (dataset={config.dataset_name})...')

num_classes = 10
input_size = 32

# Training augmentation: random 32x32 crop from a 4-pixel padded image, random
# horizontal flip, then per-channel (mean, std) normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Test preprocessing: same normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load HuggingFace CIFAR-10-LT dataset
# Assumes the loading script ./cifar10-lt.py sits next to the notebook. Both splits
# use the same config name; presumably only the train split is long-tailed — confirm
# in cifar10-lt.py.
print(f"Loading dataset from cifar10-lt.py script...")
hf_train = load_dataset('./cifar10-lt.py', name=config.dataset_name, split='train')
hf_test = load_dataset('./cifar10-lt.py', name=config.dataset_name, split='test')

# Wrap with PyTorch-compatible wrapper (items are (image, label, index))
trainset = CIFAR10LTWrapper(hf_train, transform=transform_train)
testset = CIFAR10LTWrapper(hf_test, transform=transform_test)

# Create data loaders
trainloader = DataLoader(trainset, batch_size=config.batch_size_train, 
                        shuffle=True, num_workers=config.num_workers)
testloader = DataLoader(testset, batch_size=config.batch_size_test, 
                       shuffle=False, num_workers=config.num_workers)

print(f"\nTraining samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")
print(f"Number of classes: {num_classes}")

# Visualize distribution
# train_dist only contains classes present in the split, so train_dist[i] raises
# KeyError if some class has no training samples.
train_dist = trainset.get_class_distribution()
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

fig, ax = plt.subplots(figsize=(12, 5))
ax.bar(range(num_classes), [train_dist[i] for i in range(num_classes)], alpha=0.7)
ax.set_xlabel('Class', fontweight='bold')
ax.set_ylabel('Number of Training Samples', fontweight='bold')
ax.set_title(f'Long-Tailed CIFAR-10 Distribution (dataset={config.dataset_name}, IR={config.imbalance_ratio})', fontweight='bold')
ax.set_xticks(range(num_classes))
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig(os.path.join(config.save_dir, 'dataset_distribution.png'), dpi=150)
plt.show()

print(f"\nClass distribution saved to {config.save_dir}/dataset_distribution.png")

## 5. Model Setup

In [None]:
# Create model
print(f"Creating model: {config.arch}")

# One output per class plus one extra logit whose softmax probability is used as
# the abstention ("reservation") score in test_epoch.
model_num_classes = num_classes + 1  # Extra dimension for abstention
model = models.__dict__[config.arch](num_classes=model_num_classes, input_size=input_size)

if use_cuda:
    model = torch.nn.DataParallel(model.cuda())
    # Let cuDNN benchmark convolution algorithms; presumably trades determinism for speed.
    cudnn.benchmark = True

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params/1e6:.2f}M")

# Setup loss function
# The SAT criterion is sized by the number of training examples and the number of
# real classes (without the abstention logit); train_epoch calls it with the
# per-sample indices of each batch.
criterion = SelfAdativeTraining(num_examples=len(trainset), num_classes=num_classes, mom=config.sat_momentum)

# Setup optimizer
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)

print(f"Loss function: {config.loss_type}")
print(f"Optimizer: SGD (lr={config.lr}, momentum={config.momentum}, weight_decay={config.weight_decay})")

## 6. Training Functions with Group-Aware Metrics

In [None]:
def train_epoch(trainloader, model, criterion, optimizer, epoch, use_cuda, config):
    """Train ``model`` for one epoch; report overall, balanced and worst-class accuracy.

    Args:
        trainloader: DataLoader yielding ``(inputs, targets, indices)`` batches.
        model: network with ``num_classes + 1`` outputs; the last output is the
            abstention logit.
        criterion: SAT loss, called as ``criterion(outputs, targets, indices)``
            once ``epoch >= config.pretrain_epochs``.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: zero-based epoch index.
        use_cuda: move inputs/targets to the GPU when True.
        config: run configuration (reads ``pretrain_epochs``).

    Returns:
        tuple ``(avg_loss, top1_acc, balanced_acc, worst_class_acc)``. All
        accuracies are percentages computed over the ``num_classes`` class logits
        only (the abstention logit is excluded), so the overall and per-class
        figures are mutually consistent and comparable with ``test_epoch``.
        Classes absent from the loader count as 0% accuracy.
    """
    model.train()

    losses = AverageMeter()
    top1 = AverageMeter()

    # Per-class hit / sample counters (sized by the notebook-level ``num_classes``).
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)

    for inputs, targets, indices in trainloader:
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # Forward pass
        outputs = model(inputs)

        # Calculate loss
        if epoch >= config.pretrain_epochs:
            loss = criterion(outputs, targets, indices)
        else:
            # Warm-up phase: plain cross-entropy on the class logits only.
            loss = F.cross_entropy(outputs[:, :-1], targets)

        # Accuracy is measured on the class logits only, matching test_epoch and
        # the per-class statistics below (an argmax on the abstention logit is
        # not a class prediction).
        class_logits = outputs[:, :-1].detach()
        prec1 = accuracy(class_logits, targets, topk=(1,))[0]
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        # Per-class statistics via bincount: two transfers per batch instead of
        # 2 * num_classes device synchronisations. Labels >= num_classes are
        # dropped by the slice, as the original per-class loop ignored them.
        hits = class_logits.argmax(dim=1) == targets
        class_total += torch.bincount(targets, minlength=num_classes)[:num_classes].cpu().numpy()
        class_correct += torch.bincount(targets[hits], minlength=num_classes)[:num_classes].cpu().numpy()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Calculate balanced (mean over classes) and worst-class accuracy
    class_acc = [
        100.0 * n_correct / n_total if n_total > 0 else 0.0
        for n_correct, n_total in zip(class_correct, class_total)
    ]

    balanced_acc = np.mean(class_acc)
    worst_acc = np.min(class_acc)

    return losses.avg, top1.avg, balanced_acc, worst_acc

def test_epoch(testloader, model, criterion, epoch, use_cuda, config):
    """Evaluate ``model`` on ``testloader`` with per-class and selective metrics.

    Args:
        testloader: DataLoader yielding ``(inputs, targets, indices)`` batches.
        model: network with ``num_classes + 1`` outputs (the last one is the
            abstention logit).
        criterion: unused; kept so the signature mirrors ``train_epoch``.
        epoch: zero-based epoch index. Before ``config.pretrain_epochs`` every
            sample gets an abstention score of 0.
        use_cuda: run the forward pass on the GPU when True.
        config: run configuration (reads ``pretrain_epochs`` and
            ``expected_coverage``).

    Returns:
        tuple ``(avg_loss, top1_acc, balanced_acc, worst_class_acc,
        per_class_acc, coverage_results)``. Accuracies are percentages over the
        class logits; ``per_class_acc`` has one entry per class (0.0 for classes
        absent from the loader); ``coverage_results`` is the output of
        ``calculate_group_coverage_metrics``.
    """
    model.eval()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    hits_per_class = np.zeros(num_classes)
    seen_per_class = np.zeros(num_classes)

    # Flat per-sample records for the risk-coverage analysis.
    reservation_log, prediction_log, target_log, correct_log = [], [], [], []

    use_reservation = epoch >= config.pretrain_epochs

    with torch.no_grad():
        for inputs, targets, _indices in testloader:
            if use_cuda:
                inputs = inputs.cuda()

            logits = model(inputs).cpu()
            class_logits = logits[:, :-1]

            # Reported loss: cross-entropy over the class logits in both phases.
            loss = F.cross_entropy(class_logits, targets)

            if use_reservation:
                # Softmax over all outputs; the last column is the abstention score.
                probs = F.softmax(logits, dim=1)
                reservation = probs[:, -1]
                predictions = probs[:, :-1].max(1)[1]
            else:
                reservation = torch.zeros(len(targets))
                predictions = class_logits.max(1)[1]

            batch_size = inputs.size(0)
            batch_top1 = accuracy(class_logits.data, targets.data, topk=(1,))[0]
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(batch_top1.item(), batch_size)

            correct = predictions == targets
            for cls in range(num_classes):
                members = targets == cls
                n_members = members.sum().item()
                if n_members > 0:
                    hits_per_class[cls] += correct[members].sum().item()
                    seen_per_class[cls] += n_members

            reservation_log.extend(reservation.numpy().tolist())
            prediction_log.extend(predictions.numpy().tolist())
            target_log.extend(targets.numpy().tolist())
            correct_log.extend(correct.numpy().tolist())

    per_class_acc = [
        100.0 * hits / seen if seen > 0 else 0.0
        for hits, seen in zip(hits_per_class, seen_per_class)
    ]

    coverage_results = calculate_group_coverage_metrics(
        np.array(reservation_log),
        np.array(prediction_log),
        np.array(target_log),
        np.array(correct_log),
        config.expected_coverage
    )

    return (loss_meter.avg, acc_meter.avg, np.mean(per_class_acc), np.min(per_class_acc),
            per_class_acc, coverage_results)

def calculate_group_coverage_metrics(reservations, predictions, targets, correct, coverages,
                                     n_classes=None):
    """Compute overall, balanced and worst-class error at several coverage levels.

    Samples are ranked by their abstention ("reservation") score. At coverage
    ``c`` (percent), the ``int(N * c / 100)`` samples with the *lowest*
    reservation score are kept; the rest are abstained on.

    Args:
        reservations: per-sample abstention scores, length N (higher = abstained
            on earlier).
        predictions: per-sample predicted classes, length N. Kept for interface
            compatibility; correctness is read from ``correct``.
        targets: per-sample integer labels, length N.
        correct: per-sample bool (or 0/1) flags, prediction == target.
        coverages: iterable of coverage levels in percent (e.g. 100., 95., ...).
        n_classes: number of classes; defaults to the notebook-level
            ``num_classes``.

    Returns:
        list of dicts, one per coverage level that keeps at least one sample,
        with keys ``coverage``, ``overall_error``, ``balanced_error``,
        ``worst_group_error`` (all percent) and ``class_errors`` (per-class
        percent error).

    A class with no retained samples at a given coverage has no defined
    selective error: it is reported as ``nan`` in ``class_errors`` and excluded
    from the balanced / worst-group aggregates. Counting it as 0% error (as
    before) would reward abstaining on an entire minority class.
    """
    if n_classes is None:
        n_classes = num_classes

    reservations = np.asarray(reservations)
    targets = np.asarray(targets)
    correct = np.asarray(correct)

    # Sort by reservation, high to low, so the most confident samples sit at the tail.
    sorted_indices = np.argsort(reservations)[::-1]
    sorted_correct = correct[sorted_indices]
    sorted_targets = targets[sorted_indices]

    results = []
    for coverage in coverages:
        n_samples = int(len(reservations) * coverage / 100)
        if n_samples == 0:
            continue

        # Keep the n_samples lowest-reservation (most confident) samples.
        selected_correct = sorted_correct[-n_samples:]
        selected_targets = sorted_targets[-n_samples:]

        class_errors = []
        for cls in range(n_classes):
            mask = selected_targets == cls
            if mask.any():
                class_errors.append(100.0 * (1 - selected_correct[mask].mean()))
            else:
                class_errors.append(float('nan'))  # fully abstained: error undefined

        observed = [err for err in class_errors if not np.isnan(err)]
        balanced_error = np.mean(observed) if observed else float('nan')
        worst_error = np.max(observed) if observed else float('nan')
        overall_error = 100.0 * (1 - selected_correct.mean())

        results.append({
            'coverage': coverage,
            'overall_error': overall_error,
            'balanced_error': balanced_error,
            'worst_group_error': worst_error,
            'class_errors': class_errors
        })

    return results

print("Training and evaluation functions defined.")

## 7. Training Loop

In [None]:
# Training history: one entry per epoch for every tracked metric.
history = {
    'epoch': [],
    'train_loss': [],
    'train_acc': [],
    'train_balanced_acc': [],
    'train_worst_acc': [],
    'test_loss': [],
    'test_acc': [],
    'test_balanced_acc': [],
    'test_worst_acc': [],
    'test_class_acc': [],
    'coverage_results': []
}

print(f"Starting training for {config.epochs} epochs...")
print("=" * 100)
print(f"{'Epoch':>5} | {'LR':>10} | {'Train Acc':>10} | {'Bal Acc':>10} | {'Worst Acc':>10} | {'Test Acc':>10} | {'Bal Acc':>10} | {'Worst Acc':>10} | {'Time':>6} | {'Status'}")
print("=" * 100)

best_balanced_acc = 0.0
best_worst_acc = 0.0
start_time = time.time()

for epoch in range(config.epochs):
    epoch_start = time.time()

    # Step-decay schedule: base LR times gamma for every milestone already reached.
    # Derived from the epoch index instead of mutating config.lr, so the configured
    # base LR stays intact and re-running this cell restarts the schedule correctly
    # (the optimizer LR is set explicitly every epoch).
    lr = config.lr * config.gamma ** sum(1 for milestone in config.schedule if epoch >= milestone)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Train
    train_loss, train_acc, train_bal_acc, train_worst_acc = train_epoch(
        trainloader, model, criterion, optimizer, epoch, use_cuda, config
    )

    # Test
    test_loss, test_acc, test_bal_acc, test_worst_acc, test_class_acc, coverage_results = test_epoch(
        testloader, model, criterion, epoch, use_cuda, config
    )

    # Save history
    history['epoch'].append(epoch + 1)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['train_balanced_acc'].append(train_bal_acc)
    history['train_worst_acc'].append(train_worst_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    history['test_balanced_acc'].append(test_bal_acc)
    history['test_worst_acc'].append(test_worst_acc)
    history['test_class_acc'].append(test_class_acc)
    history['coverage_results'].append(coverage_results)

    epoch_time = time.time() - epoch_start

    # Track best models (selected on the test split; note: no separate validation set)
    status = ""
    if test_bal_acc > best_balanced_acc:
        best_balanced_acc = test_bal_acc
        checkpoint_path = os.path.join(config.save_dir, 'best_balanced_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'balanced_acc': best_balanced_acc,
        }, checkpoint_path)
        status += "*BAL* "

    if test_worst_acc > best_worst_acc:
        best_worst_acc = test_worst_acc
        checkpoint_path = os.path.join(config.save_dir, 'best_worst_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'worst_acc': best_worst_acc,
        }, checkpoint_path)
        status += "*WORST*"

    # Compact output
    print(f"{epoch+1:5d} | {lr:10.6f} | {train_acc:9.2f}% | {train_bal_acc:9.2f}% | {train_worst_acc:9.2f}% | "
          f"{test_acc:9.2f}% | {test_bal_acc:9.2f}% | {test_worst_acc:9.2f}% | {epoch_time:5.1f}s | {status}")

    # Save periodic checkpoint (every 50 epochs and at the end).
    # NOTE: this pickles the whole (possibly DataParallel-wrapped) module, unlike the
    # state_dict checkpoints above; loading it requires the same model code.
    if (epoch + 1) % 50 == 0 or (epoch + 1) == config.epochs:
        checkpoint_path = os.path.join(config.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(model, checkpoint_path)

total_time = time.time() - start_time
print("\n" + "=" * 100)
print(f"Training completed in {total_time/3600:.2f} hours")
print(f"Best balanced accuracy: {best_balanced_acc:.2f}%")
print(f"Best worst-group accuracy: {best_worst_acc:.2f}%")
print("=" * 100)

## 8. Visualization: Training Dynamics

In [None]:
# Training curves
# 2x2 grid: overall, balanced and worst-class accuracy per epoch (train vs test),
# plus all three test metrics on one panel. Saved to save_dir/training_curves.png.
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# 1. Overall accuracy
ax = axes[0, 0]
ax.plot(history['epoch'], history['train_acc'], label='Train', linewidth=2, marker='o', markersize=2)
ax.plot(history['epoch'], history['test_acc'], label='Test', linewidth=2, marker='s', markersize=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Overall Accuracy', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Balanced accuracy (unweighted mean of per-class accuracies)
ax = axes[0, 1]
ax.plot(history['epoch'], history['train_balanced_acc'], label='Train Balanced', linewidth=2, marker='o', markersize=2)
ax.plot(history['epoch'], history['test_balanced_acc'], label='Test Balanced', linewidth=2, marker='s', markersize=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Balanced Accuracy (%)')
ax.set_title('Balanced Accuracy (Average Across Classes)', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 3. Worst-group accuracy (minimum per-class accuracy)
ax = axes[1, 0]
ax.plot(history['epoch'], history['train_worst_acc'], label='Train Worst', linewidth=2, marker='o', markersize=2, color='red')
ax.plot(history['epoch'], history['test_worst_acc'], label='Test Worst', linewidth=2, marker='s', markersize=2, color='darkred')
ax.set_xlabel('Epoch')
ax.set_ylabel('Worst-Group Accuracy (%)')
ax.set_title('Worst-Group Accuracy (Minimum Across Classes)', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 4. All metrics together (test split only)
ax = axes[1, 1]
ax.plot(history['epoch'], history['test_acc'], label='Overall', linewidth=2)
ax.plot(history['epoch'], history['test_balanced_acc'], label='Balanced', linewidth=2)
ax.plot(history['epoch'], history['test_worst_acc'], label='Worst-Group', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Accuracy (%)')
ax.set_title('Comparison of Test Metrics', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(config.save_dir, 'training_curves.png'), dpi=150)
plt.show()

print(f"Training curves saved to {config.save_dir}/training_curves.png")

## 9. Per-Class Performance Analysis

In [None]:
# Analyze final per-class performance
# Uses the per-class test accuracies of the LAST epoch (not of a best checkpoint).
final_class_acc = history['test_class_acc'][-1]
train_dist = trainset.get_class_distribution()

# Create DataFrame: one row per class with training-set size and final test accuracy/error.
df_class_perf = pd.DataFrame({
    'Class': class_names,
    'Train Samples': [train_dist[i] for i in range(num_classes)],
    'Test Accuracy (%)': final_class_acc,
    'Test Error (%)': [100 - acc for acc in final_class_acc]
})

print("\nFinal Per-Class Performance:")
display(df_class_perf)  # IPython/Jupyter display helper (not available in plain Python)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# 1. Accuracy by class
ax = axes[0]
# Bar colours from the red-yellow-green colormap, indexed by accuracy scaled to [0, 1].
colors = plt.cm.RdYlGn(np.array(final_class_acc) / 100)
bars = ax.bar(range(num_classes), final_class_acc, color=colors, alpha=0.7, edgecolor='black')
ax.set_xlabel('Class', fontweight='bold')
ax.set_ylabel('Test Accuracy (%)', fontweight='bold')
ax.set_title('Per-Class Test Accuracy', fontweight='bold', fontsize=14)
ax.set_xticks(range(num_classes))
ax.set_xticklabels(class_names, rotation=45, ha='right')
# Reference lines: balanced (mean over classes) and worst-class accuracy.
ax.axhline(y=np.mean(final_class_acc), color='blue', linestyle='--', linewidth=2, label=f'Balanced: {np.mean(final_class_acc):.1f}%')
ax.axhline(y=np.min(final_class_acc), color='red', linestyle='--', linewidth=2, label=f'Worst: {np.min(final_class_acc):.1f}%')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Add value labels
for i, acc in enumerate(final_class_acc):
    ax.text(i, acc + 1, f'{acc:.1f}', ha='center', fontsize=9)

# 2. Training samples vs accuracy (log x-axis: counts span the imbalance ratio)
ax = axes[1]
train_samples = [train_dist[i] for i in range(num_classes)]
scatter = ax.scatter(train_samples, final_class_acc, s=200, alpha=0.6, c=range(num_classes), 
                     cmap='tab10', edgecolors='black', linewidth=2)
ax.set_xlabel('Number of Training Samples', fontweight='bold')
ax.set_ylabel('Test Accuracy (%)', fontweight='bold')
ax.set_title('Training Sample Count vs Test Accuracy', fontweight='bold', fontsize=14)
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

# Add class labels
for i, name in enumerate(class_names):
    ax.annotate(name, (train_samples[i], final_class_acc[i]), 
                fontsize=8, alpha=0.7, xytext=(5, 5), textcoords='offset points')

plt.tight_layout()
plt.savefig(os.path.join(config.save_dir, 'per_class_performance.png'), dpi=150)
plt.show()

print(f"\nPer-class analysis saved to {config.save_dir}/per_class_performance.png")

## 10. Coverage-Based Analysis: Balanced and Worst-Group Error

In [None]:
# Extract final coverage results (from the last epoch's test_epoch call)
final_coverage = history['coverage_results'][-1]

# Create DataFrame: one row per coverage level, in config.expected_coverage order.
df_coverage = pd.DataFrame(final_coverage)

print("\nCoverage-Based Error Analysis:")
print("="*80)
display(df_coverage[['coverage', 'overall_error', 'balanced_error', 'worst_group_error']])

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# 1. Coverage vs Error curves (x-axis inverted so coverage decreases left to right)
ax = axes[0]
ax.plot(df_coverage['coverage'], df_coverage['overall_error'], 
        marker='o', markersize=8, linewidth=2.5, label='Overall Error', color='blue')
ax.plot(df_coverage['coverage'], df_coverage['balanced_error'], 
        marker='s', markersize=8, linewidth=2.5, label='Balanced Error', color='green')
ax.plot(df_coverage['coverage'], df_coverage['worst_group_error'], 
        marker='^', markersize=8, linewidth=2.5, label='Worst-Group Error', color='red')
ax.set_xlabel('Coverage (%)', fontweight='bold', fontsize=12)
ax.set_ylabel('Error Rate (%)', fontweight='bold', fontsize=12)
ax.set_title('Selective Classification: Coverage vs Error (Long-Tailed)', fontweight='bold', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.invert_xaxis()

# Highlight key coverage points (B = balanced error, W = worst-group error)
for cov in [100, 95, 90, 80]:
    row = df_coverage[df_coverage['coverage'] == cov]
    if len(row) > 0:
        bal_err = row['balanced_error'].values[0]
        worst_err = row['worst_group_error'].values[0]
        ax.annotate(f'B:{bal_err:.1f}%\nW:{worst_err:.1f}%', 
                    xy=(cov, worst_err), xytext=(10, 10), 
                    textcoords='offset points', fontsize=8,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.5))

# 2. Error reduction comparison
# Reductions are measured against the FIRST row, which is the 100% coverage level
# only because config.expected_coverage starts at 100.
ax = axes[1]
overall_reduction = df_coverage['overall_error'].iloc[0] - df_coverage['overall_error']
balanced_reduction = df_coverage['balanced_error'].iloc[0] - df_coverage['balanced_error']
worst_reduction = df_coverage['worst_group_error'].iloc[0] - df_coverage['worst_group_error']

# Grouped bars: three metrics side by side for each coverage level.
x = np.arange(len(df_coverage))
width = 0.25

ax.bar(x - width, overall_reduction, width, label='Overall', alpha=0.8, color='blue')
ax.bar(x, balanced_reduction, width, label='Balanced', alpha=0.8, color='green')
ax.bar(x + width, worst_reduction, width, label='Worst-Group', alpha=0.8, color='red')

ax.set_xlabel('Coverage Level', fontweight='bold', fontsize=12)
ax.set_ylabel('Error Reduction from 100% Coverage (%)', fontweight='bold', fontsize=12)
ax.set_title('Error Reduction at Different Coverage Levels', fontweight='bold', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels([f"{int(c)}%" for c in df_coverage['coverage']], rotation=45)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

plt.tight_layout()
plt.savefig(os.path.join(config.save_dir, 'coverage_analysis.png'), dpi=150)
plt.show()

print(f"\nCoverage analysis saved to {config.save_dir}/coverage_analysis.png")

## 11. AURC Calculation: Balanced and Worst-Group Risk

In [None]:
def calculate_aurc(coverages, errors):
    """
    Calculate the Area Under the Risk-Coverage curve (AURC). Lower is better.

    Coverage and error (both in percent) are rescaled to fractions and the risk
    is integrated over coverage with the trapezoidal rule. The integral only
    spans the sampled coverage range (e.g. 0.1 to 1.0 for the default
    ``expected_coverage``), not the whole [0, 1] interval.

    Args:
        coverages: sequence of coverage levels in percent, in any order.
        errors: sequence of error rates in percent, aligned with ``coverages``.

    Returns:
        float: the area; non-negative for non-negative errors, 0.0 when fewer
        than two points are given.
    """
    cov = np.asarray(coverages, dtype=float)
    err = np.asarray(errors, dtype=float)

    # Integrate with coverage ASCENDING: on a decreasing x-axis the trapezoidal
    # rule returns the negated area (the previous descending sort made every
    # AURC negative).
    order = np.argsort(cov, kind='stable')
    cov_normalized = cov[order] / 100.0
    err_normalized = err[order] / 100.0

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (np.trapz is deprecated).
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return float(trapezoid(err_normalized, cov_normalized))

# Calculate AURC for different metrics
# Arrays follow df_coverage row order (config.expected_coverage, descending coverage).
coverages = df_coverage['coverage'].values
overall_errors = df_coverage['overall_error'].values
balanced_errors = df_coverage['balanced_error'].values
worst_errors = df_coverage['worst_group_error'].values

aurc_overall = calculate_aurc(coverages, overall_errors)
aurc_balanced = calculate_aurc(coverages, balanced_errors)
aurc_worst = calculate_aurc(coverages, worst_errors)

print("\n" + "="*80)
print("AURC (Area Under Risk-Coverage Curve) - Lower is Better")
print("="*80)
print(f"AURC (Overall):     {aurc_overall:.6f}")
print(f"AURC (Balanced):    {aurc_balanced:.6f}")
print(f"AURC (Worst-Group): {aurc_worst:.6f}")
print("="*80)

# Visualize AURC: one panel per risk definition, with the area under the curve shaded.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, (ax, errors, title, color, aurc) in enumerate([
    (axes[0], overall_errors, 'Overall Error', 'blue', aurc_overall),
    (axes[1], balanced_errors, 'Balanced Error', 'green', aurc_balanced),
    (axes[2], worst_errors, 'Worst-Group Error', 'red', aurc_worst)
]):
    # Normalize for visualization (percent -> fraction)
    cov_norm = coverages / 100.0
    err_norm = errors / 100.0

    # Plot curve
    ax.plot(cov_norm, err_norm, linewidth=3, color=color, marker='o', markersize=6)

    # Fill area under curve
    ax.fill_between(cov_norm, 0, err_norm, alpha=0.3, color=color)

    ax.set_xlabel('Coverage (fraction)', fontweight='bold', fontsize=11)
    ax.set_ylabel('Risk (error rate)', fontweight='bold', fontsize=11)
    ax.set_title(f'{title}\nAURC = {aurc:.6f}', fontweight='bold', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_xlim([0, 1])
    # NOTE(review): if every error is 0 this sets identical y-limits [0, 0].
    ax.set_ylim([0, max(err_norm) * 1.1])

    # Add text box with AURC value
    ax.text(0.05, 0.95, f'AURC: {aurc:.6f}', transform=ax.transAxes,
            verticalalignment='top', fontsize=12, fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

plt.tight_layout()
plt.savefig(os.path.join(config.save_dir, 'aurc_analysis.png'), dpi=150)
plt.show()

print(f"\nAURC analysis saved to {config.save_dir}/aurc_analysis.png")

## 12. Summary Report

In [None]:
# Generate comprehensive summary
# Prints a console report, then writes summary_report.txt and coverage_results.csv
# into config.save_dir. Relies on globals from the previous cells (history,
# df_coverage, final_class_acc, train_dist, aurc_* and best_* values).
print("\n" + "="*100)
print(f"SAT ON LONG-TAILED CIFAR-10 - FINAL REPORT (IR={config.imbalance_ratio})")
print("="*100)

print("\n1. OVERALL PERFORMANCE:")
print("-" * 100)
print(f"Final Test Accuracy (Overall):     {history['test_acc'][-1]:.2f}%")
print(f"Final Test Accuracy (Balanced):    {history['test_balanced_acc'][-1]:.2f}%")
print(f"Final Test Accuracy (Worst-Group): {history['test_worst_acc'][-1]:.2f}%")
print(f"\nBest Balanced Accuracy:             {best_balanced_acc:.2f}%")
print(f"Best Worst-Group Accuracy:          {best_worst_acc:.2f}%")

print("\n2. PER-CLASS ACCURACY:")
print("-" * 100)
for i, (name, acc, samples) in enumerate(zip(class_names, final_class_acc, [train_dist[i] for i in range(num_classes)])):
    print(f"Class {i} ({name:12s}): {acc:6.2f}%  (train samples: {samples:5d})")

print("\n3. COVERAGE-BASED ANALYSIS (Key Coverage Levels):")
print("-" * 100)
print(f"{'Coverage':<10} | {'Overall Err':<12} | {'Balanced Err':<13} | {'Worst Err':<10} | {'Bal Reduction':<14} | {'Worst Reduction'}")
print("-" * 100)

# Reductions are relative to the first df_coverage row (100% coverage, because
# config.expected_coverage starts at 100). Levels missing from df_coverage are skipped.
for cov in [100, 99, 95, 90, 85, 80, 75, 70]:
    row = df_coverage[df_coverage['coverage'] == cov]
    if len(row) > 0:
        overall_err = row['overall_error'].values[0]
        bal_err = row['balanced_error'].values[0]
        worst_err = row['worst_group_error'].values[0]

        bal_reduction = df_coverage['balanced_error'].iloc[0] - bal_err
        worst_reduction = df_coverage['worst_group_error'].iloc[0] - worst_err

        print(f"{cov:<10.0f} | {overall_err:<12.2f} | {bal_err:<13.2f} | {worst_err:<10.2f} | "
              f"{bal_reduction:<14.2f} | {worst_reduction:.2f}")

print("\n4. AURC (Area Under Risk-Coverage Curve):")
print("-" * 100)
print(f"AURC (Overall):     {aurc_overall:.6f}")
print(f"AURC (Balanced):    {aurc_balanced:.6f}  [Balanced risk metric]")
print(f"AURC (Worst-Group): {aurc_worst:.6f}  [Worst-case risk metric]")

print("\n5. KEY INSIGHTS:")
print("-" * 100)

# Calculate insights (gaps are in percentage points, final epoch only)
bal_gap = history['test_acc'][-1] - history['test_balanced_acc'][-1]
worst_gap = history['test_acc'][-1] - history['test_worst_acc'][-1]
class_variance = np.var(final_class_acc)

# Error reduction at 80% coverage
# NOTE: .iloc[0] raises IndexError if 100 or 80 is removed from config.expected_coverage.
row_100 = df_coverage[df_coverage['coverage'] == 100].iloc[0]
row_80 = df_coverage[df_coverage['coverage'] == 80].iloc[0]
bal_reduction_80 = row_100['balanced_error'] - row_80['balanced_error']
worst_reduction_80 = row_100['worst_group_error'] - row_80['worst_group_error']

print(f"• Performance Gap (Overall vs Balanced):   {bal_gap:.2f}%")
print(f"• Performance Gap (Overall vs Worst):     {worst_gap:.2f}%")
print(f"• Class Accuracy Variance:                {class_variance:.2f}")
print(f"• Balanced Error Reduction (100% → 80%):  {bal_reduction_80:.2f}%")
print(f"• Worst Error Reduction (100% → 80%):     {worst_reduction_80:.2f}%")

# Heuristic warnings; thresholds are in percentage points.
if bal_gap > 5:
    print(f"\n⚠️  Large gap between overall and balanced accuracy suggests class imbalance impact")
if worst_gap > 10:
    print(f"⚠️  Very large gap to worst-group indicates severe performance disparity")
if worst_reduction_80 < bal_reduction_80:
    print(f"⚠️  Selective classification less effective for worst-performing groups")

print("\n" + "="*100)

# Save report to file (headline accuracies and AURC values only)
report_path = os.path.join(config.save_dir, 'summary_report.txt')
with open(report_path, 'w') as f:
    f.write(f"SAT ON LONG-TAILED CIFAR-10 - SUMMARY REPORT (IR={config.imbalance_ratio})\n")
    f.write("="*100 + "\n\n")
    f.write(f"Final Test Accuracy (Overall):     {history['test_acc'][-1]:.2f}%\n")
    f.write(f"Final Test Accuracy (Balanced):    {history['test_balanced_acc'][-1]:.2f}%\n")
    f.write(f"Final Test Accuracy (Worst-Group): {history['test_worst_acc'][-1]:.2f}%\n\n")
    f.write(f"AURC (Overall):     {aurc_overall:.6f}\n")
    f.write(f"AURC (Balanced):    {aurc_balanced:.6f}\n")
    f.write(f"AURC (Worst-Group): {aurc_worst:.6f}\n")

print(f"\nSummary report saved to: {report_path}")

# Save coverage results to CSV (per-class errors are not included)
csv_path = os.path.join(config.save_dir, 'coverage_results.csv')
df_coverage[['coverage', 'overall_error', 'balanced_error', 'worst_group_error']].to_csv(csv_path, index=False)
print(f"Coverage results saved to: {csv_path}")

print("\n✅ Analysis complete!")