# Hybrid Geometric–Entropy Gating — CIFAR-10-C Benchmark

**Paper:** Huseyin Aydin (2025) — [https://doi.org/10.5281/zenodo.18055798](https://doi.org/10.5281/zenodo.18055798)

**What this notebook does:**
1. Trains a ResNet-18 baseline (standard cross-entropy)
2. Trains a Gated ResNet-18 (Hybrid Geometric–Entropy loss)
3. Evaluates both on **CIFAR-10-C** — 19 corruption types × 5 severity levels
4. Compares the accuracy and ECE (expected calibration error) of the two models

**Key design:** Gate distances are computed in ResNet-18 **feature space** (the 512-dim avgpool output)
rather than in raw pixel space — this makes the geometric gate meaningful for images.

> ⚡ Enable GPU: Runtime → Change runtime type → T4 GPU

In [None]:
# ── Setup ─────────────────────────────────────────────────────────────────────
!pip install -q git+https://github.com/hsynposta/entropy-gate.git

# CIFAR-10-C (corrupted test sets) — ~2.5 GB
import os
if not os.path.exists('CIFAR-10-C'):
    print('Downloading CIFAR-10-C (~2.5 GB)...')
    !wget -q https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
    !tar -xf CIFAR-10-C.tar
    print('Done.')
else:
    print('CIFAR-10-C already downloaded.')

In [None]:
# ── Imports ───────────────────────────────────────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet18
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm.notebook import tqdm

from entropy_gate.gates import GeometricGate, EntropyGate, HybridGate
from entropy_gate.loss  import GatedLoss

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

# Fix the seeds of NumPy's and PyTorch's global random number generators.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# ── Config ────────────────────────────────────────────────────────────────────
# Training schedule
EPOCHS = 60
BATCH = 128
# NOTE(review): the optimizers in the training cells visible in this notebook
# hard-code lr=0.1 and never read LR — confirm which value is intended.
LR = 1e-3
WARMUP = 10  # number of epochs that run before the gate activates

# Gate hyper-parameters
ALPHA = 4.0  # steepness of the geometric gate
BETA = 4.0   # steepness of the entropy gate
N_CLASSES = 10

# Corruption names; each one is read from `CIFAR-10-C/<name>.npy`.
CORRUPTIONS = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
    'speckle_noise',
    'gaussian_blur',
    'spatter',
    'saturate',
]
SEVERITIES = list(range(1, 6))

In [None]:
# ── Data loading ──────────────────────────────────────────────────────────────
# Per-channel statistics used to standardise every image fed to the models.
MEAN = (0.4914, 0.4822, 0.4465)
STD  = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop + random horizontal flip, then
# tensor conversion and normalisation. The test pipeline skips the augmentation.
_augmentation = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
train_tfm = T.Compose(_augmentation + [T.ToTensor(), T.Normalize(MEAN, STD)])
test_tfm = T.Compose([T.ToTensor(), T.Normalize(MEAN, STD)])

# Both CIFAR-10 splits live in (and are downloaded to) the working directory.
train_ds, test_ds = (
    torchvision.datasets.CIFAR10('.', train=is_train, download=True, transform=tfm)
    for is_train, tfm in ((True, train_tfm), (False, test_tfm))
)

# Only the training loader shuffles; the worker settings are shared.
_loader_opts = dict(num_workers=2, pin_memory=True)
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, **_loader_opts)
test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, **_loader_opts)

print(f'Train: {len(train_ds):,}  |  Test (clean): {len(test_ds):,}')

In [None]:
# ── ResNet-18 that also returns its pooled features ────────────────────────────
class GatableResNet18(nn.Module):
    """ResNet-18 classifier whose pooled 512-dim features can also be read out.

    The torchvision ResNet-18 is split into ``backbone`` (every child module
    except the last one, i.e. everything up to and including avgpool) and a
    new linear ``classifier`` head. The pooled features are what the notebook
    uses to compute geometric-gate distances in feature space.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # weights=None: the network is built without loading pretrained weights.
        stages = list(resnet18(weights=None).children())
        stages.pop()  # drop the final child so the stack ends at avgpool
        self.backbone = nn.Sequential(*stages)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x, return_features=False):
        """Return logits (N, C); with ``return_features`` also the (N, 512) features."""
        pooled = torch.flatten(self.backbone(x), 1)
        scores = self.classifier(pooled)
        return (scores, pooled) if return_features else scores

    @torch.no_grad()
    def features(self, x):
        """Return the flattened backbone output for ``x`` with gradients disabled."""
        return torch.flatten(self.backbone(x), 1)

In [None]:
# ── Compute training feature center (for geometric gate) ──────────────────────
@torch.no_grad()
def compute_feature_center(model, loader, device):
    """Return the feature-space centroid and 75th-percentile radius of a dataset.

    Every batch from ``loader`` is passed through ``model.features`` with the
    model in eval mode. The centroid is the mean feature vector; the radius is
    the 0.75 quantile of the Euclidean distances of all features from it.

    The model's previous train/eval mode is restored before returning, so
    calling this in the middle of a training loop does not leave the model
    (and e.g. its BatchNorm layers) stuck in eval mode.

    Args:
        model: module exposing ``training``, ``eval()``, ``train(mode)`` and a
            ``features(x)`` method returning one feature row per sample.
        loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: device the inputs are moved to and the center is returned on.

    Returns:
        ``(center, radius)`` — ``center`` is a tensor on ``device`` and
        ``radius`` is a Python float.
    """
    was_training = model.training
    model.eval()
    try:
        feats = [
            model.features(x.to(device)).cpu()
            for x, _ in tqdm(loader, desc='Computing feature center', leave=False)
        ]
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    feats  = torch.cat(feats)
    center = feats.mean(0)
    dists  = torch.norm(feats - center, dim=1)
    radius = float(torch.quantile(dists, 0.75))
    return center.to(device), radius

In [None]:
# ── Evaluation helpers ─────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate(model, loader, device):
    """Return ``(accuracy, ece)`` of ``model`` over every batch in ``loader``.

    The model is switched to eval mode. Confidence is the maximum softmax
    probability of each sample. ECE uses 10 equal-width confidence bins
    ``(lo, hi]`` on [0, 1]: each non-empty bin contributes
    ``count * |mean accuracy - mean confidence|`` and the sum is divided by the
    number of samples. Both values are fractions, not percentages.
    """
    model.eval()
    confidences, hits = [], []
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        top_prob, top_class = torch.softmax(model(inputs), dim=1).max(dim=1)
        confidences.append(top_prob.cpu())
        hits.append((top_class == targets).float().cpu())
    confidences = torch.cat(confidences)
    hits = torch.cat(hits)

    n_samples = hits.numel()
    n_correct = hits.sum().item()

    # Expected calibration error, accumulated bin by bin.
    edges = torch.linspace(0, 1, 11)
    weighted_gap = 0.0
    for i in range(len(edges) - 1):
        in_bin = (confidences > edges[i]) & (confidences <= edges[i + 1])
        count = in_bin.sum().item()
        if count == 0:
            continue
        gap = (hits[in_bin].mean() - confidences[in_bin].mean()).abs().item()
        weighted_gap += count * gap
    return n_correct / n_samples, weighted_gap / n_samples

def load_cifar10c(corruption, severity, normalize=True):
    """Load one corruption/severity split of CIFAR-10-C as a DataLoader.

    Args:
        corruption: corruption name; images are read from
            ``CIFAR-10-C/<corruption>.npy``.
        severity: severity level from 1 to 5. Each level occupies one
            consecutive block of 10,000 samples in the file.
        normalize: when true, standardise the images with the module-level
            ``MEAN`` / ``STD`` channel statistics.

    Returns:
        An unshuffled DataLoader (batch size 256) over ``(image, label)``
        pairs, with images as float tensors in (N, 3, 32, 32) layout.

    Raises:
        ValueError: if ``severity`` is outside 1..5. Without this check an
            out-of-range level silently yields an empty dataset.
    """
    if not 1 <= severity <= 5:
        raise ValueError(f'severity must be between 1 and 5, got {severity!r}')
    # Each severity is 10000 samples
    idx    = slice((severity - 1) * 10000, severity * 10000)
    # Memory-map the (50000, 32, 32, 3) image file so only the requested block
    # is read from disk; np.array() then copies that block into writable memory.
    data   = np.load(f'CIFAR-10-C/{corruption}.npy', mmap_mode='r')
    labels = np.load('CIFAR-10-C/labels.npy')          # (50000,)
    X = torch.from_numpy(np.array(data[idx])).permute(0, 3, 1, 2).float() / 255.0
    if normalize:
        mean = torch.tensor(MEAN).view(3, 1, 1)
        std  = torch.tensor(STD).view(3, 1, 1)
        X    = (X - mean) / std
    y = torch.from_numpy(labels[idx]).long()
    return DataLoader(TensorDataset(X, y), batch_size=256, num_workers=2)

In [None]:
# ── Training: Baseline ─────────────────────────────────────────────────────────
print('=' * 60, '  Training BASELINE (standard cross-entropy)', '=' * 60, sep='\n')

# Model, optimizer, cosine learning-rate schedule and plain cross-entropy loss.
baseline = GatableResNet18(N_CLASSES).to(DEVICE)
opt_b = torch.optim.SGD(
    baseline.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
)
sched_b = torch.optim.lr_scheduler.CosineAnnealingLR(opt_b, T_max=EPOCHS)
ce_loss = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    # evaluate() below leaves the model in eval mode, so re-enable training
    # mode at the start of every epoch.
    baseline.train()
    for inputs, targets in train_dl:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        opt_b.zero_grad()
        loss = ce_loss(baseline(inputs), targets)
        loss.backward()
        opt_b.step()
    sched_b.step()  # the schedule advances once per epoch

    # Report clean-test accuracy and ECE every 10th epoch.
    done = epoch + 1
    if done % 10 == 0:
        acc, ece = evaluate(baseline, test_dl, DEVICE)
        print(f'  Epoch {done:3d}/{EPOCHS} | acc={acc:.3f} | ece={ece*100:.2f}%')

torch.save(baseline.state_dict(), 'baseline.pth')
print('Baseline saved.')

In [None]:
# ── Training: Gated ResNet-18 ──────────────────────────────────────────────────
print('=' * 60)
print('  Training GATED ResNet-18 (Hybrid Geometric-Entropy Gate)')
print('=' * 60)

# Same architecture, optimizer settings and schedule as the baseline run.
gated   = GatableResNet18(N_CLASSES).to(DEVICE)
opt_g   = torch.optim.SGD(gated.parameters(), lr=0.1, momentum=0.9,
                           weight_decay=5e-4, nesterov=True)
sched_g = torch.optim.lr_scheduler.CosineAnnealingLR(opt_g, T_max=EPOCHS)

# hybrid_gate.step() is called once per batch below, so the warmup length is
# expressed in optimizer steps: WARMUP epochs × batches per epoch.
geo_gate    = GeometricGate(alpha=ALPHA).to(DEVICE)
ent_gate    = EntropyGate(beta=BETA, num_classes=N_CLASSES).to(DEVICE)
hybrid_gate = HybridGate(geo_gate, ent_gate,
                          warmup_steps=WARMUP * len(train_dl)).to(DEVICE)
gate_loss   = GatedLoss(hybrid_gate)

# Feature center/radius stay None until they are measured at epoch == WARMUP.
train_center = None
train_radius = None

for epoch in range(EPOCHS):
    # Calibrate the geometric gate once, when warmup ends, from the features
    # the network produces at that point.
    if epoch == WARMUP:
        print(f'  Calibrating geometric gate at epoch {epoch}...')
        train_center, train_radius = compute_feature_center(gated, train_dl, DEVICE)
        geo_gate.log_d0.data = torch.tensor(train_radius).log().to(DEVICE)
        print(f'  Gate radius d0 = {train_radius:.3f}')

    # Enter training mode AFTER calibration: compute_feature_center() switches
    # the model to eval mode, and calling train() before it would leave this
    # whole epoch training in eval mode.
    gated.train()

    for x, y in train_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, feats = gated(x, return_features=True)

        # Feature-space distances for geometric gate (detached: the distance
        # itself is not back-propagated through).
        if train_center is not None:
            dists = torch.norm(feats - train_center, dim=1).detach()
        else:
            dists = torch.ones(len(x), device=DEVICE)  # during warmup: all weight=1

        loss = gate_loss(logits, y, dists)
        opt_g.zero_grad(); loss.backward(); opt_g.step()
        hybrid_gate.step()

    sched_g.step()
    if (epoch + 1) % 10 == 0:
        acc, ece = evaluate(gated, test_dl, DEVICE)
        status = 'WARMUP' if epoch < WARMUP else 'ACTIVE'
        print(f'  Epoch {epoch+1:3d}/{EPOCHS} | acc={acc:.3f} | ece={ece*100:.2f}% | gate={status}')

torch.save(gated.state_dict(), 'gated.pth')
print('Gated model saved.')

In [None]:
# ── CIFAR-10-C Evaluation ─────────────────────────────────────────────────────
print('Evaluating on CIFAR-10-C (19 corruptions × 5 severities)...')
print('This takes ~5 minutes.')

# results[corruption][metric] is a list with one entry per severity, in
# SEVERITIES order.
results = {}
for corr in tqdm(CORRUPTIONS, desc='Corruptions'):
    per_severity = {'base_acc': [], 'gate_acc': [], 'base_ece': [], 'gate_ece': []}
    results[corr] = per_severity
    for sev in SEVERITIES:
        dl_c = load_cifar10c(corr, sev)
        # Both models see the same loader; baseline is evaluated first.
        for tag, net in (('base', baseline), ('gate', gated)):
            acc, ece = evaluate(net, dl_c, DEVICE)
            per_severity[f'{tag}_acc'].append(acc)
            per_severity[f'{tag}_ece'].append(ece)

# Clean test set
(clean_base_acc, clean_base_ece), (clean_gate_acc, clean_gate_ece) = (
    evaluate(net, test_dl, DEVICE) for net in (baseline, gated)
)

print(f'\nClean test: Baseline {clean_base_acc:.3f} | Gated {clean_gate_acc:.3f}')

In [None]:
# ── Results table ─────────────────────────────────────────────────────────────
# Header: one left-aligned name column followed by right-aligned value columns.
print('\n' + '=' * 70)
print('  ' + '  '.join([
    'Corruption'.ljust(25),
    'Base Acc'.rjust(10),
    'Gate Acc'.rjust(10),
    'Δ Acc'.rjust(8),
]))
print('  ' + '─' * 65)

# Mean accuracy over the severities of each corruption, as fractions.
base_avg = {c: np.mean(results[c]['base_acc']) for c in CORRUPTIONS}
gate_avg = {c: np.mean(results[c]['gate_acc']) for c in CORRUPTIONS}
all_deltas = [gate_avg[c] - base_avg[c] for c in CORRUPTIONS]

for corr, d in zip(CORRUPTIONS, all_deltas):
    # Differences within ±0.5 percentage points are shown as unchanged.
    if d > 0.005:
        arrow = '▲'
    elif d < -0.005:
        arrow = '▼'
    else:
        arrow = '─'
    print(f'  {corr:25s}  {base_avg[corr]*100:>8.1f}%  {gate_avg[corr]*100:>8.1f}%  {arrow}{abs(d)*100:>5.1f}pp')

print('  ' + '─' * 65)
overall_base = np.mean(list(base_avg.values()))
overall_gate = np.mean(list(gate_avg.values()))
print(f'  {"AVERAGE (all corruptions)":25s}  '
      f'{overall_base*100:>8.1f}%  '
      f'{overall_gate*100:>8.1f}%  '
      f'{np.mean(all_deltas)*100:>+6.2f}pp')
print(f'  {"Clean test":25s}  {clean_base_acc*100:>8.1f}%  {clean_gate_acc*100:>8.1f}%')

In [None]:
# ── Visualization ─────────────────────────────────────────────────────────────
# Colour palette of the dark-themed figure.
DARK = "#0f1117"
GRID = "#1e2130"
BASE = "#e74c3c"
GATE = "#2ecc71"
GOLD = "#f1c40f"
TEXT = "#ecf0f1"
SUB = "#95a5a6"


def _per_corruption_pct(key):
    """Mean of ``results[c][key]`` over severities, in percent, per corruption."""
    return [np.mean(results[c][key]) * 100 for c in CORRUPTIONS]


def _per_severity_mean(key):
    """Mean of ``results[c][key][s]`` over corruptions, per severity index."""
    return [np.mean([results[c][key][s] for c in CORRUPTIONS]) for s in range(5)]


# Per-corruption mean accuracies
base_means = _per_corruption_pct('base_acc')
gate_means = _per_corruption_pct('gate_acc')
deltas_pp = [gated_pct - base_pct for base_pct, gated_pct in zip(base_means, gate_means)]

# Per-severity averages (fractions; scaled to percent when plotted)
sev_base = _per_severity_mean('base_acc')
sev_gate = _per_severity_mean('gate_acc')
sev_base_ece = _per_severity_mean('base_ece')
sev_gate_ece = _per_severity_mean('gate_ece')

# 2×2 grid: the first panel spans the top row, two panels share the bottom row.
fig = plt.figure(figsize=(18, 12))
fig.patch.set_facecolor(DARK)
gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.35)

def style(ax, title):
    """Apply the dark theme to ``ax`` and give it a bold ``title``."""
    ax.set_facecolor(GRID)
    ax.tick_params(colors=SUB, labelsize=9)
    ax.spines[:].set_color('#2c3e50')
    ax.set_title(title, color=TEXT, fontsize=10, fontweight='bold', pad=8)
    # Axis labels use the same muted colour as the ticks.
    for axis in (ax.xaxis, ax.yaxis):
        axis.label.set_color(SUB)

# Legend styling shared by all three panels.
_legend_kw = dict(facecolor=DARK, edgecolor='#2c3e50', labelcolor=TEXT)

# A — per-corruption bars (baseline left of each tick, gated right)
ax1 = fig.add_subplot(gs[0, :])
x = np.arange(len(CORRUPTIONS))
w = 0.35
for offset, heights, colour, name in (
    (-w/2, base_means, BASE, 'Baseline'),
    (w/2, gate_means, GATE, 'Gated'),
):
    ax1.bar(x + offset, heights, w, color=colour, label=name, alpha=0.85)
ax1.set_xticks(x)
# Break corruption names at underscores so the tick labels stack vertically.
ax1.set_xticklabels(['\n'.join(c.split('_')) for c in CORRUPTIONS],
                    fontsize=7, color=SUB)
ax1.set_ylabel('Accuracy (%)')
ax1.legend(**_legend_kw)
style(ax1, f'A  Accuracy per Corruption Type (avg over 5 severities)  |  Avg Δ = {np.mean(deltas_pp):+.2f} pp')

# B — accuracy vs severity
ax2 = fig.add_subplot(gs[1, 0])
sevs = [1, 2, 3, 4, 5]
sev_base_pct = [v * 100 for v in sev_base]
sev_gate_pct = [v * 100 for v in sev_gate]
ax2.plot(sevs, sev_base_pct, 'o-', color=BASE, lw=2.5, ms=7, label='Baseline')
ax2.plot(sevs, sev_gate_pct, 's-', color=GATE, lw=2.5, ms=7, label='Gated')
# Shade the gap between the two accuracy curves.
ax2.fill_between(sevs, sev_base_pct, sev_gate_pct, alpha=0.1, color=GATE)
ax2.set_xlabel('Corruption Severity')
ax2.set_ylabel('Accuracy (%)')
ax2.legend(fontsize=9, **_legend_kw)
style(ax2, 'B  Accuracy vs Severity (avg over all corruptions)')

# C — ECE vs severity
ax3 = fig.add_subplot(gs[1, 1])
for marker, values, colour, name in (
    ('o-', sev_base_ece, BASE, 'Baseline ECE'),
    ('s-', sev_gate_ece, GATE, 'Gated ECE'),
):
    ax3.plot(sevs, [v * 100 for v in values], marker, color=colour, lw=2.5, ms=7, label=name)
ax3.set_xlabel('Corruption Severity')
ax3.set_ylabel('ECE (%) — lower is better')
ax3.legend(fontsize=9, **_legend_kw)
style(ax3, 'C  Calibration Error vs Severity')

# Overall corrupted-data accuracy (mean of the per-corruption means), as fractions.
avg_ood_base = np.mean([np.mean(results[c]['base_acc']) for c in CORRUPTIONS])
avg_ood_gate = np.mean([np.mean(results[c]['gate_acc']) for c in CORRUPTIONS])

_headline = (
    'Hybrid Geometric–Entropy Gating  ·  Aydin 2025\n'
    f'CIFAR-10-C Benchmark  |  ResNet-18  |  '
    f'OOD acc: {avg_ood_base*100:.1f}% → {avg_ood_gate*100:.1f}% '
    f'({np.mean(deltas_pp):+.2f} pp)  |  '
    f'github.com/hsynposta/entropy-gate'
)
fig.suptitle(_headline, color=TEXT, fontsize=11, fontweight='bold', y=0.98)

# Save with the figure's own face colour so the dark background is kept.
plt.savefig('cifar10c_benchmark.png', dpi=150, bbox_inches='tight',
            facecolor=fig.get_facecolor())
plt.show()
print('Figure saved: cifar10c_benchmark.png')