# CIFAR-10 – Step 2: Data → f(x), g(x)

This notebook walks through:
1. Loading CIFAR-10 with train/val/cal/test splits
2. Training a small ResNet18
3. Extracting logits (f(x))
4. Computing selection scores g(x): MSP, margin, entropy, energy
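5. (Optional) Temperature scaling on the calibration split (run before step 4 in the cells below, because the scores are computed with the fitted temperature)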

Use the companion module `cifar10_selective.py` for reusable functions.


In [None]:
# Command-line invocation of the companion script. Run it from a terminal,
# not from this cell: a bare shell command is not valid Python, so leaving it
# uncommented makes "Restart & Run All" stop here with a SyntaxError.
# NOTE(review): the flags mirror the config used below, so this presumably runs
# the same pipeline end-to-end — confirm against cifar10_selective.py.
#
#   python cifar10_selective.py \
#     --data_root ./data \
#     --epochs 5 \
#     --batch_size 256 \
#     --device cuda \
#     --fit_temperature \
#     --save_npz ./artifacts/cifar10_scores.npz

In [1]:
# !pip install torch torchvision scikit-learn  # Uncomment if needed
import sys, os

import numpy as np
import torch

sys.path.append('/mnt/data')  # so we can import the helper module directly here during preview
from cifar10_selective import (
    TrainConfig, set_seed, get_cifar10_dataloaders, create_resnet18_cifar,
    train_model, evaluate, get_logits, compute_scores_from_logits, fit_temperature
)


In [2]:
# Config
# Single source of truth for the data, optimiser and runtime settings read by
# the cells below (they access these values as cfg.<field>).
cfg = TrainConfig(
    data_root='./data',
    batch_size=256,
    epochs=5,            # start small; bump later
    lr=0.1,
    weight_decay=5e-4,
    momentum=0.9,
    seed=0,
    val_ratio=0.1,
    cal_ratio=0.2,
    num_workers=4,
    # Prefer the GPU, but fall back to the CPU so the notebook still runs
    # (slowly) on machines where PyTorch cannot see a CUDA device.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
# Seed before any data splitting or model creation happens in later cells.
set_seed(cfg.seed)


In [3]:
# Data
# Build the data loaders from the shared config, then print the number of
# examples behind each one as a quick sanity check on the split sizes.
loader_kwargs = {
    'data_root': cfg.data_root,
    'batch_size': cfg.batch_size,
    'val_ratio': cfg.val_ratio,
    'cal_ratio': cfg.cal_ratio,
    'seed': cfg.seed,
    'num_workers': cfg.num_workers,
}
loaders = get_cifar10_dataloaders(**loader_kwargs)

for split_name, split_loader in loaders.items():
    print(split_name, len(split_loader.dataset))


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
train 35000
val 5000
cal 10000
test 10000


In [4]:
# Model + quick training
# Create the network, train it with the shared config, then measure
# loss/accuracy on the validation and test loaders.
model = train_model(create_resnet18_cifar(num_classes=10), loaders, cfg)

# Evaluated in order: 'val' first, then 'test'; each call yields (loss, acc).
(val_loss, val_acc), (test_loss, test_acc) = (
    evaluate(model, loaders[split_name], device=cfg.device)
    for split_name in ('val', 'test')
)

metrics = dict(
    val_loss=val_loss,
    val_acc=val_acc,
    test_loss=test_loss,
    test_acc=test_acc,
)
print(metrics)


[Epoch 01] train_loss=2.2233  val_loss=1.6820  val_acc=0.3832
[Epoch 02] train_loss=1.6769  val_loss=1.5713  val_acc=0.4488
[Epoch 03] train_loss=1.4396  val_loss=1.3499  val_acc=0.5106
[Epoch 04] train_loss=1.2977  val_loss=1.2622  val_acc=0.5516
[Epoch 05] train_loss=1.1823  val_loss=1.1284  val_acc=0.5922
{'val_loss': 1.1283957069396973, 'val_acc': 0.5922, 'test_loss': 1.1443219212532043, 'test_acc': 0.5851}


In [5]:
# (Optional) Temperature scaling using calibration split
# Fit the temperature on the calibration loader only; the resulting value is
# what the scoring cell below passes to compute_scores_from_logits.
calibration_loader = loaders['cal']
temperature = fit_temperature(
    model,
    calibration_loader,
    init_temp=1.0,
    lr=0.01,
    epochs=50,
    device=cfg.device,
)
print('Fitted temperature:', temperature)


Fitted temperature: 1.0050512552261353


In [6]:
# Inference: f(x) logits + selection scores g(x)
# For every held-out split, collect the logits and labels, derive the scores
# from the logits, and store everything in one flat dict keyed '<name>_<split>'.
# Requires `temperature` from the temperature-scaling cell above.
out = {}
for split in ('val', 'cal', 'test'):
    logits, y = get_logits(model, loaders[split], device=cfg.device)
    out[f'logits_{split}'] = logits
    out[f'y_{split}'] = y

    scores = compute_scores_from_logits(logits, temperature=temperature)
    out.update({f'{score_name}_{split}': values for score_name, values in scores.items()})

# Show only the first few keys (alphabetical) plus the total count.
print('Keys:', sorted(out)[:8], '... total=', len(out))


Keys: ['energy_cal', 'energy_test', 'energy_val', 'entropy_cal', 'entropy_test', 'entropy_val', 'logits_cal', 'logits_test'] ... total= 18


In [7]:
# Save artifacts for Algorithm 1 + evaluation later
# Write every array collected in `out` into one compressed .npz archive,
# creating the target directory first if it does not exist yet.
save_path = './artifacts/cifar10_calpool_test.npz'
artifact_dir = os.path.split(save_path)[0]
os.makedirs(artifact_dir, exist_ok=True)
np.savez_compressed(save_path, **out)
print('Saved:', save_path)


Saved: ./artifacts/cifar10_calpool_test.npz
