# Self-Adaptive Training for Selective Classification

This notebook reproduces the SAT paper results on selective classification.

**Settings:**
- Architecture: VGG-16 with Batch Normalization
- Dataset: CIFAR-10
- Epochs: 300
- SAT Momentum: 0.99
- Learning Rate: Decays every 25 epochs by factor of 0.5

## 1. Setup and Imports

In [1]:
import os
import time
import random
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Local imports
import models.cifar as models
import dataset_utils
from loss import SelfAdativeTraining, deep_gambler_loss
from utils import AverageMeter, accuracy

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))

PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4060 Laptop GPU
CUDA device: NVIDIA GeForce RTX 4060 Laptop GPU


## 2. Configuration

In [2]:
# Configuration
class Config:
    # Training settings
    dataset = 'cifar10'
    arch = 'vgg16_bn'
    loss_type = 'sat'  # 'sat', 'gambler', or 'ce'
    epochs = 3
    pretrain_epochs = 0  # Number of epochs to pretrain with cross-entropy
    
    # Hyperparameters
    batch_size_train = 128
    batch_size_test = 200
    lr = 0.1
    momentum = 0.9
    sat_momentum = 0.99  # Paper uses 0.99 for selective classification
    weight_decay = 5e-4
    gamma = 0.5  # LR decay factor
    schedule = [25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275]  # LR decay epochs
    
    # Evaluation settings
    expected_coverage = [100., 99., 98., 97., 95., 90., 85., 80., 75., 70., 60., 50., 40., 30., 20., 10.]
    reward = 2.2  # For gambler loss
    
    # System settings
    gpu_id = '0'
    num_workers = 4
    manual_seed = 42
    save_dir = './checkpoints/sat_notebook'
    
config = Config()

# Set random seeds
random.seed(config.manual_seed)
torch.manual_seed(config.manual_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config.manual_seed)

# Set GPU
os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id
use_cuda = torch.cuda.is_available()

# Create save directory
os.makedirs(config.save_dir, exist_ok=True)

print(f"Configuration:")
print(f"  Dataset: {config.dataset}")
print(f"  Architecture: {config.arch}")
print(f"  Loss: {config.loss_type}")
print(f"  Epochs: {config.epochs}")
print(f"  SAT Momentum: {config.sat_momentum}")
print(f"  Save directory: {config.save_dir}")

Configuration:
  Dataset: cifar10
  Architecture: vgg16_bn
  Loss: sat
  Epochs: 3
  SAT Momentum: 0.99
  Save directory: ./checkpoints/sat_notebook


## 3. Dataset Preparation

In [3]:
# Prepare dataset
print(f'Preparing dataset: {config.dataset}')

if config.dataset == 'cifar10':
    num_classes = 10
    input_size = 32
    
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    trainset = dataset_utils.C10(root='~/datasets/CIFAR10', train=True, download=True, transform=transform_train)
    testset = dataset_utils.C10(root='~/datasets/CIFAR10', train=False, download=True, transform=transform_test)

elif config.dataset == 'svhn':
    num_classes = 10
    input_size = 32
    
    transform_train = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    trainset = dataset_utils.SVHN(root='~/datasets/SVHN', split='train', download=True, transform=transform_train)
    testset = dataset_utils.SVHN(root='~/datasets/SVHN', split='test', download=True, transform=transform_test)

# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size_train, 
                                         shuffle=True, num_workers=config.num_workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size_test, 
                                        shuffle=False, num_workers=config.num_workers)

print(f"Training samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")
print(f"Number of classes: {num_classes}")

Preparing dataset: cifar10


  entry = pickle.load(f, encoding="latin1")


Training samples: 50000
Test samples: 10000
Number of classes: 10


## 4. Model Setup

In [4]:
# Create model
print(f"Creating model: {config.arch}")

# Model has num_classes+1 outputs for selective classification (extra dimension for abstention)
model_num_classes = num_classes if config.loss_type == 'ce' else num_classes + 1
model = models.__dict__[config.arch](num_classes=model_num_classes, input_size=input_size)

if use_cuda:
    model = torch.nn.DataParallel(model.cuda())
    cudnn.benchmark = True

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params/1e6:.2f}M")

# Setup loss function
if config.loss_type == 'ce':
    criterion = nn.CrossEntropyLoss()
elif config.loss_type == 'gambler':
    criterion = deep_gambler_loss
elif config.loss_type == 'sat':
    criterion = SelfAdativeTraining(num_examples=len(trainset), num_classes=num_classes, mom=config.sat_momentum)

# Setup optimizer
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)

print(f"Loss function: {config.loss_type}")
print(f"Optimizer: SGD (lr={config.lr}, momentum={config.momentum}, weight_decay={config.weight_decay})")

Creating model: vgg16_bn
Total parameters: 14.99M
Loss function: sat
Optimizer: SGD (lr=0.1, momentum=0.9, weight_decay=0.0005)
Total parameters: 14.99M
Loss function: sat
Optimizer: SGD (lr=0.1, momentum=0.9, weight_decay=0.0005)


## 5. Training Functions

In [None]:
def adjust_learning_rate(optimizer, epoch, config):
    """Adjust learning rate according to schedule"""
    if epoch in config.schedule:
        config.lr *= config.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = config.lr
        print(f"Learning rate adjusted to: {config.lr}")

def train_epoch(trainloader, model, criterion, optimizer, epoch, use_cuda, config, verbose=False):
    """Train for one epoch"""
    model.train()
    
    losses = AverageMeter()
    top1 = AverageMeter()
    
    for batch_idx, batch_data in enumerate(trainloader):
        inputs, targets, indices = batch_data
        
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        
        # Forward pass
        outputs = model(inputs)
        
        # Calculate loss
        if epoch >= config.pretrain_epochs:
            if config.loss_type == 'gambler':
                loss = criterion(outputs, targets, config.reward)
            elif config.loss_type == 'sat':
                loss = criterion(outputs, targets, indices)
            else:
                loss = criterion(outputs, targets)
        else:
            # Pretrain with cross-entropy on class logits only
            loss = F.cross_entropy(outputs[:, :-1], targets)
        
        # Measure accuracy
        prec1 = accuracy(outputs.data, targets.data, topk=(1,))[0]
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    return losses.avg, top1.avg

def test_epoch(testloader, model, criterion, epoch, use_cuda, config):
    """Test for one epoch and evaluate coverage"""
    model.eval()
    
    losses = AverageMeter()
    top1 = AverageMeter()
    abstention_results = []
    
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(testloader):
            inputs, targets, indices = batch_data
            
            if use_cuda:
                inputs = inputs.cuda()
            
            outputs = model(inputs).cpu()
            values, predictions = outputs.data.max(1)
            
            # Calculate loss and abstention results
            if epoch >= config.pretrain_epochs:
                if config.loss_type == 'gambler':
                    loss = criterion(outputs, targets, config.reward)
                elif config.loss_type == 'sat':
                    loss = F.cross_entropy(outputs[:, :-1], targets)
                else:
                    loss = criterion(outputs, targets)
                
                outputs = F.softmax(outputs, dim=1)
                outputs, reservation = outputs[:, :-1], outputs[:, -1]
                abstention_results.extend(zip(list(reservation.numpy()), 
                                            list(predictions.eq(targets.data).numpy())))
            else:
                loss = F.cross_entropy(outputs[:, :-1], targets)
            
            prec1 = accuracy(outputs.data, targets.data, topk=(1,))[0]
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
    
    # Calculate coverage-based accuracy
    coverage_results = {}
    if epoch >= config.pretrain_epochs and len(abstention_results) > 0:
        abstention_results.sort(key=lambda x: x[0], reverse=True)
        sorted_correct = list(map(lambda x: int(x[1]), abstention_results))
        size = len(abstention_results)
        
        for coverage in config.expected_coverage:
            num_samples = size - int(size * (100 - coverage) / 100)
            if num_samples > 0:
                acc = sum(sorted_correct[:num_samples]) / num_samples * 100
                coverage_results[coverage] = acc
    
    return losses.avg, top1.avg, coverage_results

print("Training functions defined.")

Training functions defined.


## 6. Training Loop

In [None]:
# Training history
history = {
    'epoch': [],
    'train_loss': [],
    'train_acc': [],
    'test_loss': [],
    'test_acc': [],
    'coverage_results': []
}

print(f"Starting training for {config.epochs} epochs...")
print("=" * 80)

best_acc = 0.0
start_time = time.time()

for epoch in range(config.epochs):
    epoch_start = time.time()
    
    # Adjust learning rate (silently)
    if epoch in config.schedule:
        config.lr *= config.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = config.lr
    
    # Train
    train_loss, train_acc = train_epoch(trainloader, model, criterion, optimizer, epoch, use_cuda, config)
    
    # Test
    test_loss, test_acc, coverage_results = test_epoch(testloader, model, criterion, epoch, use_cuda, config)
    
    # Save history
    history['epoch'].append(epoch + 1)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    history['coverage_results'].append(coverage_results)
    
    epoch_time = time.time() - epoch_start
    
    # One-line summary
    is_best = ""
    if test_acc > best_acc:
        best_acc = test_acc
        checkpoint_path = os.path.join(config.save_dir, 'best_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
        }, checkpoint_path)
        is_best = " *BEST*"
    
    # Compact one-line output
    cov_str = ""
    if coverage_results and 95 in coverage_results:
        cov_str = f" | Cov@95: {coverage_results[95]:.2f}%"
    
    print(f"Epoch {epoch+1:3d}/{config.epochs} | LR: {config.lr:.6f} | Train: {train_acc:5.2f}% | Test: {test_acc:5.2f}%{cov_str} | {epoch_time:5.1f}s{is_best}")
    
    # Save periodic checkpoint
    if (epoch + 1) % 50 == 0 or (epoch + 1) == config.epochs:
        checkpoint_path = os.path.join(config.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(model, checkpoint_path)

total_time = time.time() - start_time
print("\n" + "=" * 80)
print(f"Training completed in {total_time/3600:.2f} hours")
print(f"Best test accuracy: {best_acc:.2f}%")

Starting training for 3 epochs...




  Batch [50/391] Loss: 1.8990 Acc: 30.47%


## 7. Visualization

In [11]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
axes[0].plot(history['epoch'], history['train_loss'], label='Train Loss', marker='o', markersize=3)
axes[0].plot(history['epoch'], history['test_loss'], label='Test Loss', marker='s', markersize=3)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Test Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['epoch'], history['train_acc'], label='Train Accuracy', marker='o', markersize=3)
axes[1].plot(history['epoch'], history['test_acc'], label='Test Accuracy', marker='s', markersize=3)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Test Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
# plt.savefig(os.path.join(config.save_dir, 'training_curves.png'), dpi=150)
plt.show()

print(f"Training curves saved to {config.save_dir}/training_curves.png")

Training curves saved to ./checkpoints/sat_notebook/training_curves.png


  plt.show()


## 8. Coverage vs Error Analysis

In [8]:
# Get final coverage results
if history['coverage_results'][-1]:
    final_coverage = history['coverage_results'][-1]
    
    coverages = sorted(final_coverage.keys(), reverse=True)
    accuracies = [final_coverage[c] for c in coverages]
    errors = [100 - acc for acc in accuracies]
    
    # Create DataFrame
    df = pd.DataFrame({
        'Coverage (%)': coverages,
        'Accuracy (%)': accuracies,
        'Error (%)': errors
    })
    
    print("\nFinal Coverage vs Error Analysis:")
    print(df.to_string(index=False))
    
    # Save to CSV
    csv_path = os.path.join(config.save_dir, 'coverage_vs_error.csv')
    df.to_csv(csv_path, index=False)
    print(f"\nResults saved to {csv_path}")
    
    # Plot coverage vs error
    plt.figure(figsize=(10, 6))
    plt.plot(coverages, errors, marker='o', linewidth=2, markersize=8)
    plt.xlabel('Coverage (%)', fontsize=12)
    plt.ylabel('Error Rate (%)', fontsize=12)
    plt.title('Coverage vs Error Rate (Final Epoch)', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(config.save_dir, 'coverage_vs_error.png'), dpi=150)
    plt.show()
else:
    print("No coverage results available (model still in pretraining phase or using CE loss)")


Final Coverage vs Error Analysis:
 Coverage (%)  Accuracy (%)  Error (%)
        100.0     32.600000  67.400000
         99.0     32.222222  67.777778
         98.0     31.887755  68.112245
         97.0     31.628866  68.371134
         95.0     31.336842  68.663158
         90.0     30.544444  69.455556
         85.0     29.717647  70.282353
         80.0     28.925000  71.075000
         75.0     28.373333  71.626667
         70.0     27.785714  72.214286
         60.0     26.400000  73.600000
         50.0     25.480000  74.520000
         40.0     24.775000  75.225000
         30.0     24.633333  75.366667
         20.0     25.150000  74.850000
         10.0     23.200000  76.800000

Results saved to ./checkpoints/sat_notebook/coverage_vs_error.csv


  plt.show()


## 9. Comprehensive Evaluation

In [9]:
def evaluate_with_abstention(model, testloader, use_cuda, num_classes, loss_type):
    """Comprehensive evaluation with abstention analysis"""
    model.eval()
    
    all_logits = []
    all_labels = []
    all_reservations = []
    all_predictions = []
    
    with torch.no_grad():
        for batch_data in testloader:
            inputs, targets = batch_data[:2]
            if use_cuda:
                inputs = inputs.cuda()
            
            outputs = model(inputs)
            outputs_soft = F.softmax(outputs, dim=1).cpu()
            
            if loss_type != 'ce':
                class_probs = outputs_soft[:, :-1]
                reservation = outputs_soft[:, -1]
            else:
                class_probs = outputs_soft
                reservation = 1 - outputs_soft.max(1)[0]
            
            predictions = class_probs.max(1)[1]
            
            all_logits.append(class_probs.numpy())
            all_labels.append(targets.numpy())
            all_reservations.append(reservation.numpy())
            all_predictions.append(predictions.numpy())
    
    all_logits = np.concatenate(all_logits)
    all_labels = np.concatenate(all_labels)
    all_reservations = np.concatenate(all_reservations)
    all_predictions = np.concatenate(all_predictions)
    
    # Calculate metrics at different coverage levels
    coverages = [100, 99, 98, 97, 95, 90, 85, 80, 75, 70, 60, 50, 40, 30, 20, 10]
    results = []
    
    # Sort by reservation (ascending - lower reservation = higher confidence)
    sorted_indices = np.argsort(all_reservations)
    
    for coverage in coverages:
        n_samples = int(len(all_labels) * coverage / 100)
        if n_samples == 0:
            continue
        
        # Take most confident samples
        selected_indices = sorted_indices[:n_samples]
        selected_predictions = all_predictions[selected_indices]
        selected_labels = all_labels[selected_indices]
        
        accuracy = (selected_predictions == selected_labels).mean() * 100
        error = 100 - accuracy
        
        results.append({
            'Coverage (%)': coverage,
            'Accuracy (%)': accuracy,
            'Error (%)': error,
            'Num Samples': n_samples
        })
    
    return pd.DataFrame(results)

# Load best model
best_checkpoint = os.path.join(config.save_dir, 'best_model.pth')
if os.path.exists(best_checkpoint):
    checkpoint = torch.load(best_checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
    print(f"Best accuracy: {checkpoint['best_acc']:.2f}%")

# Comprehensive evaluation
print("\nRunning comprehensive evaluation...")
eval_results = evaluate_with_abstention(model, testloader, use_cuda, num_classes, config.loss_type)
print("\nDetailed Evaluation Results:")
print(eval_results.to_string(index=False))

# Save detailed results
eval_csv_path = os.path.join(config.save_dir, 'detailed_evaluation.csv')
eval_results.to_csv(eval_csv_path, index=False)
print(f"\nDetailed evaluation saved to {eval_csv_path}")

Loaded best model from epoch 3
Best accuracy: 32.60%

Running comprehensive evaluation...





Detailed Evaluation Results:
 Coverage (%)  Accuracy (%)  Error (%)  Num Samples
          100     32.600000  67.400000        10000
           99     32.727273  67.272727         9900
           98     32.857143  67.142857         9800
           97     33.000000  67.000000         9700
           95     33.178947  66.821053         9500
           90     33.644444  66.355556         9000
           85     34.023529  65.976471         8500
           80     34.462500  65.537500         8000
           75     35.160000  64.840000         7500
           70     36.014286  63.985714         7000
           60     37.816667  62.183333         6000
           50     39.720000  60.280000         5000
           40     41.900000  58.100000         4000
           30     43.833333  56.166667         3000
           20     47.300000  52.700000         2000
           10     51.100000  48.900000         1000

Detailed evaluation saved to ./checkpoints/sat_notebook/detailed_evaluation.csv


## 10. Summary and Analysis

In [10]:
# Print summary
print("=" * 80)
print("TRAINING SUMMARY")
print("=" * 80)
print(f"Dataset: {config.dataset}")
print(f"Architecture: {config.arch}")
print(f"Loss function: {config.loss_type}")
print(f"Total epochs: {config.epochs}")
print(f"SAT Momentum: {config.sat_momentum}")
print(f"\nFinal Training Accuracy: {history['train_acc'][-1]:.2f}%")
print(f"Final Test Accuracy: {history['test_acc'][-1]:.2f}%")
print(f"Best Test Accuracy: {best_acc:.2f}%")

if eval_results is not None and len(eval_results) > 0:
    print("\n" + "=" * 80)
    print("SELECTIVE CLASSIFICATION PERFORMANCE")
    print("=" * 80)
    
    for coverage in [100, 95, 90, 85, 80, 75, 70]:
        row = eval_results[eval_results['Coverage (%)'] == coverage]
        if len(row) > 0:
            acc = row['Accuracy (%)'].values[0]
            err = row['Error (%)'].values[0]
            print(f"Coverage {coverage:3.0f}% → Accuracy: {acc:6.2f}% | Error: {err:5.2f}%")

print("\n" + "=" * 80)
print("All results and checkpoints saved to:", config.save_dir)
print("=" * 80)

TRAINING SUMMARY
Dataset: cifar10
Architecture: vgg16_bn
Loss function: sat
Total epochs: 3
SAT Momentum: 0.99

Final Training Accuracy: 28.40%
Final Test Accuracy: 32.60%
Best Test Accuracy: 32.60%

SELECTIVE CLASSIFICATION PERFORMANCE
Coverage 100% → Accuracy:  32.60% | Error: 67.40%
Coverage  95% → Accuracy:  33.18% | Error: 66.82%
Coverage  90% → Accuracy:  33.64% | Error: 66.36%
Coverage  85% → Accuracy:  34.02% | Error: 65.98%
Coverage  80% → Accuracy:  34.46% | Error: 65.54%
Coverage  75% → Accuracy:  35.16% | Error: 64.84%
Coverage  70% → Accuracy:  36.01% | Error: 63.99%

All results and checkpoints saved to: ./checkpoints/sat_notebook
