# Hybrid Geometric–Entropy Gating — Contaminated Training Benchmark

**Paper:** Huseyin Aydin (2025) — [https://doi.org/10.5281/zenodo.18055798](https://doi.org/10.5281/zenodo.18055798)

## The key scenario

Real-world training sets are rarely clean. This notebook tests whether the Hybrid Gate
can **automatically identify and downweight noisy training samples** — without any
manual data cleaning.

- **Training data**: CIFAR-10 training set with 30% of its samples contaminated in place (labels redrawn uniformly at random + heavy Gaussian pixel noise)
- **Both models see identical (dirty) training data**
- **Test**: clean CIFAR-10 test set + CIFAR-10-C corruptions
- **Hypothesis**: Gate should downweight noisy samples → better generalization

> ⚡ Runtime → Change runtime type → **T4 GPU**

In [None]:
# ── Setup ─────────────────────────────────────────────────────────────────────
# Install the entropy-gate package; it provides the `entropy_gate.gates` and
# `entropy_gate.loss` modules imported in the next cell.
!pip install -q git+https://github.com/hsynposta/entropy-gate.git

import os
# Download and unpack CIFAR-10-C only when the target directory is missing.
# NOTE(review): directory existence is the only check, so a partially
# extracted archive from an interrupted run would be treated as complete.
if not os.path.exists('CIFAR-10-C'):
    print('Downloading CIFAR-10-C (~2.5 GB)...')
    # -q silences wget's progress output.
    !wget -q https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
    !tar -xf CIFAR-10-C.tar
    print('Done.')
else:
    print('CIFAR-10-C already present.')

In [None]:
# ── Imports ───────────────────────────────────────────────────────────────────
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision, torchvision.transforms as T
from torchvision.models import resnet18
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm.notebook import tqdm

from entropy_gate.gates import GeometricGate, EntropyGate, HybridGate
from entropy_gate.loss  import GatedLoss

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

# Seed the global PyTorch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)

In [None]:
# ── Config ────────────────────────────────────────────────────────────────────
# Optimisation schedule.
EPOCHS = 80
BATCH = 128
LR = 0.1
WARMUP = 12  # epochs before gate activates

# Gate hyper-parameters: ALPHA is passed to GeometricGate, BETA to EntropyGate.
ALPHA = 4.0
BETA = 4.0

# Training-set contamination.
NOISE_RATIO = 0.30  # 30% of training samples contaminated
NOISE_SIGMA = 0.8   # noise magnitude (in normalized space)
N_CLASSES = 10

# CIFAR-10-C corruption names; each one is the stem of a `<name>.npy` file.
CORRUPTIONS = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
    'speckle_noise',
    'gaussian_blur',
    'spatter',
    'saturate',
]
SEVERITIES = [1, 2, 3, 4, 5]

# Per-channel mean / std used to normalize every input image.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

In [None]:
# ── Contaminated dataset ──────────────────────────────────────────────────────
class ContaminatedCIFAR10(Dataset):
    """CIFAR-10 training split with a fixed fraction of corrupted samples.

    The whole split is loaded into memory, scaled to [0, 1] and normalized
    with the module-level MEAN / STD. A random subset containing
    ``noise_ratio`` of the samples is then corrupted in place:

    - Gaussian pixel noise scaled by ``noise_sigma`` is added (in
      normalized space).
    - The label is redrawn uniformly from all ``N_CLASSES`` classes, so a
      redrawn label can coincide with the original one.

    The gate is expected to learn to downweight these samples.

    Args:
        noise_ratio: Fraction of samples to corrupt.
        noise_sigma: Multiplier applied to the standard-normal pixel noise.
        seed: Seed of the NumPy RandomState that picks the corrupted subset
            and draws the noise and the replacement labels.
        augment: If true, ``__getitem__`` applies a random crop (padding 4)
            and a random horizontal flip.
    """

    def __init__(self, noise_ratio=0.30, noise_sigma=0.8, seed=42, augment=True):
        rng = np.random.RandomState(seed)
        cifar = torchvision.datasets.CIFAR10('.', train=True, download=True)

        # Channels-last -> channels-first, scale to [0, 1], then normalize.
        images = torch.from_numpy(cifar.data).permute(0, 3, 1, 2).float() / 255.0
        channel_mean = torch.tensor(MEAN).view(3, 1, 1)
        channel_std = torch.tensor(STD).view(3, 1, 1)
        images = (images - channel_mean) / channel_std
        labels = torch.tensor(cifar.targets, dtype=torch.long)

        # Pick the corrupted subset. The RNG is consumed in a fixed order:
        # permutation, then pixel noise, then replacement labels.
        num_samples = len(images)
        num_noisy = int(num_samples * noise_ratio)
        noisy_idx = rng.permutation(num_samples)[:num_noisy]

        pixel_noise = rng.randn(num_noisy, 3, 32, 32).astype(np.float32)
        images[noisy_idx] += torch.from_numpy(pixel_noise) * noise_sigma

        random_labels = rng.randint(0, N_CLASSES, num_noisy).astype(np.int64)
        labels[noisy_idx] = torch.from_numpy(random_labels)

        self.X = images
        self.y = labels
        self.augment = augment
        self.aug = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
        ])
        print(f'Contaminated dataset: {num_samples:,} samples '
              f'({num_noisy:,} noisy = {noise_ratio:.0%})')

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = self.X[idx]
        label = self.y[idx]
        if not self.augment:
            return image, label
        return self.aug(image), label

# Contaminated training set, served in shuffled mini-batches.
train_ds_c = ContaminatedCIFAR10(noise_ratio=NOISE_RATIO, noise_sigma=NOISE_SIGMA)
train_dl = DataLoader(
    train_ds_c,
    batch_size=BATCH,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Clean test set, normalized with the same MEAN / STD as the training data.
test_tfm = T.Compose([
    T.ToTensor(),
    T.Normalize(MEAN, STD),
])
test_ds = torchvision.datasets.CIFAR10(
    '.', train=False, download=True, transform=test_tfm,
)
test_dl = DataLoader(
    test_ds,
    batch_size=256,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
print(f'Clean test: {len(test_ds):,} samples')

In [None]:
# ── Model ─────────────────────────────────────────────────────────────────────
class GatableResNet18(nn.Module):
    """ResNet-18 split into a feature backbone and a linear classifier.

    The split exposes the flattened backbone output, so callers can obtain
    the features alongside (or instead of) the class logits.

    Args:
        num_classes: Output size of the linear classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        resnet = resnet18(weights=None)
        # Every child module except the last one forms the backbone; the
        # last child is replaced by our own linear classifier.
        children = list(resnet.children())
        self.backbone = nn.Sequential(*children[:-1])
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, features)`` if ``return_features``."""
        feat = torch.flatten(self.backbone(x), 1)
        logits = self.classifier(feat)
        if return_features:
            return logits, feat
        return logits

    @torch.no_grad()
    def features(self, x):
        """Return the flattened backbone output, computed without gradients."""
        return torch.flatten(self.backbone(x), 1)

@torch.no_grad()
def compute_center(model, loader, device):
    """Estimate the feature-space center and a characteristic radius.

    Features are extracted with ``model.features`` in eval mode. The model's
    previous train/eval mode is restored before returning, so calling this
    in the middle of training does not leave the model in eval mode.

    Args:
        model: Module exposing ``features(x)``, ``eval()`` and ``train(mode)``.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device the inputs are moved to and the center is returned on.

    Returns:
        ``(center, radius)``: ``center`` is the mean feature vector (on
        ``device``); ``radius`` is the 0.75 quantile of the Euclidean
        distances of all features to that center, as a Python float.
    """
    was_training = model.training
    model.eval()
    try:
        # Collect features on the CPU so the full set need not fit on `device`.
        feats = torch.cat([model.features(x.to(device)).cpu()
                           for x, _ in tqdm(loader, desc='Feature center', leave=False)])
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    center = feats.mean(0)
    radius = float(torch.quantile(torch.norm(feats - center, dim=1), 0.75))
    return center.to(device), radius

In [None]:
# ── Evaluation helpers ─────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute top-1 accuracy and expected calibration error (ECE).

    The model is switched to eval mode. Confidence is the maximum softmax
    probability. ECE uses 10 equal-width confidence bins ``(lo, hi]`` over
    [0, 1]: each non-empty bin contributes its sample count times the
    absolute gap between its mean accuracy and mean confidence, and the
    sum is divided by the total number of samples.

    Args:
        model: Callable mapping an input batch to logits; must have ``eval()``.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(accuracy, ece)`` as floats in [0, 1].
    """
    model.eval()
    n_correct, n_seen = 0, 0
    confidences, hits = [], []
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        probs = torch.softmax(model(inputs), dim=1)
        top_prob, top_class = probs.max(1)
        hit = (top_class == targets).float()
        n_correct += hit.sum().item()
        n_seen += len(targets)
        confidences.append(top_prob.cpu())
        hits.append(hit.cpu())
    confidences = torch.cat(confidences)
    hits = torch.cat(hits)

    edges = torch.linspace(0, 1, 11)
    weighted_gap = 0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.sum() > 0:
            gap = abs(hits[in_bin].mean() - confidences[in_bin].mean()).item()
            weighted_gap += in_bin.sum().item() * gap
    return n_correct / n_seen, weighted_gap / n_seen

def load_cifar10c(corruption, severity, root='CIFAR-10-C'):
    """Build a DataLoader for one CIFAR-10-C corruption at one severity.

    Each ``<corruption>.npy`` file is assumed to hold the severity levels
    stacked in order, 10,000 samples per level, with ``labels.npy`` laid
    out the same way. The image file is memory-mapped so that only the
    requested severity slice is read, not the whole array.

    Args:
        corruption: Corruption name, i.e. the stem of the ``.npy`` file.
        severity: Severity level, 1 to 5 inclusive.
        root: Directory containing the CIFAR-10-C ``.npy`` files.

    Returns:
        DataLoader over a TensorDataset of images normalized with the
        module-level MEAN / STD and their integer labels.

    Raises:
        ValueError: If ``severity`` is outside 1..5. Previously this gave
            an empty loader without any error.
    """
    if not 1 <= severity <= 5:
        raise ValueError(f'severity must be in 1..5, got {severity!r}')

    per_level = 10000
    idx = slice((severity - 1) * per_level, severity * per_level)

    # mmap_mode='r' avoids loading all five severity levels; np.array()
    # then copies just this slice into ordinary writable memory.
    data = np.load(f'{root}/{corruption}.npy', mmap_mode='r')
    labels = np.load(f'{root}/labels.npy')

    X = torch.from_numpy(np.array(data[idx])).permute(0, 3, 1, 2).float() / 255.0
    mn = torch.tensor(MEAN).view(3, 1, 1)
    sd = torch.tensor(STD).view(3, 1, 1)
    X = (X - mn) / sd
    y = torch.from_numpy(labels[idx]).long()
    return DataLoader(TensorDataset(X, y), batch_size=256, num_workers=2)

In [None]:
# ── Train BASELINE ─────────────────────────────────────────────────────────────
print('=' * 60)
print(f'  BASELINE  |  {NOISE_RATIO:.0%} contaminated training data')
print('=' * 60)

# Plain cross-entropy training: SGD with Nesterov momentum and a cosine
# learning-rate schedule spanning all EPOCHS.
baseline = GatableResNet18().to(DEVICE)
opt_b = torch.optim.SGD(
    baseline.parameters(),
    lr=LR,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
)
sched_b = torch.optim.lr_scheduler.CosineAnnealingLR(opt_b, T_max=EPOCHS)
ce_loss = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    baseline.train()
    for x, y in train_dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        opt_b.zero_grad()
        loss = ce_loss(baseline(x), y)
        loss.backward()
        opt_b.step()
    sched_b.step()

    # Report clean-test accuracy and ECE every 10th epoch only.
    if (epoch + 1) % 10 != 0:
        continue
    acc, ece = evaluate(baseline, test_dl, DEVICE)
    print(f'  Epoch {epoch+1:3d}/{EPOCHS} | acc={acc:.3f} | ece={ece*100:.2f}%')

torch.save(baseline.state_dict(), 'baseline_contaminated.pth')
print('Saved.')

In [None]:
# ── Train GATED ────────────────────────────────────────────────────────────────
print('='*60)
print(f'  GATED  |  {NOISE_RATIO:.0%} contaminated training data')
print('='*60)

# Same architecture, optimizer and schedule as the baseline run.
gated   = GatableResNet18().to(DEVICE)
opt_g   = torch.optim.SGD(gated.parameters(), lr=LR, momentum=0.9,
                           weight_decay=5e-4, nesterov=True)
sched_g = torch.optim.lr_scheduler.CosineAnnealingLR(opt_g, T_max=EPOCHS)

# The hybrid gate's warm-up is given in optimizer steps: WARMUP epochs times
# the number of batches per epoch.
geo_gate    = GeometricGate(alpha=ALPHA).to(DEVICE)
ent_gate    = EntropyGate(beta=BETA, num_classes=N_CLASSES).to(DEVICE)
hybrid_gate = HybridGate(geo_gate, ent_gate,
                          warmup_steps=WARMUP*len(train_dl)).to(DEVICE)
gate_loss   = GatedLoss(hybrid_gate)

# Feature-space center; stays None until the one-off calibration below.
train_center = None

for epoch in range(EPOCHS):
    if epoch == WARMUP:
        # One-off calibration: estimate the feature center and set the
        # geometric gate's log_d0 to the log of the measured radius.
        print(f'  Calibrating gate at epoch {epoch}...')
        train_center, radius = compute_center(gated, train_dl, DEVICE)
        geo_gate.log_d0.data = torch.tensor(radius).log().to(DEVICE)
        print(f'  d0 = {radius:.3f}')

    # compute_center() puts the model in eval mode while extracting features,
    # so enter training mode only after calibration. Calling train() before
    # it would let the whole WARMUP epoch train in eval mode.
    gated.train()

    for x, y in train_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, feats = gated(x, return_features=True)

        # Distance of each sample's features to the calibrated center,
        # detached so no gradient flows through it. Before calibration a
        # constant placeholder of ones is used.
        dists = (torch.norm(feats - train_center, dim=1).detach()
                 if train_center is not None
                 else torch.ones(len(x), device=DEVICE))

        opt_g.zero_grad()
        gate_loss(logits, y, dists).backward()
        opt_g.step()
        # Called once per batch; presumably advances the gate's warm-up
        # step counter (confirm against the entropy-gate package).
        hybrid_gate.step()

    sched_g.step()
    if (epoch+1) % 10 == 0:
        acc, ece = evaluate(gated, test_dl, DEVICE)
        status = 'WARMUP' if epoch < WARMUP else 'ACTIVE'
        print(f'  Epoch {epoch+1:3d}/{EPOCHS} | acc={acc:.3f} | ece={ece*100:.2f}% | [{status}]')

torch.save(gated.state_dict(), 'gated_contaminated.pth')
print('Saved.')

In [None]:
# ── Evaluate on CIFAR-10-C ────────────────────────────────────────────────────
print('Evaluating on CIFAR-10-C...')

# results[corruption][metric] is a list with one entry per severity level,
# in the order of SEVERITIES.
results = {
    name: {metric: [] for metric in ('base_acc', 'gate_acc', 'base_ece', 'gate_ece')}
    for name in CORRUPTIONS
}

for corr in tqdm(CORRUPTIONS):
    per_corr = results[corr]
    for sev in SEVERITIES:
        # Both models are scored on the same loader.
        dl = load_cifar10c(corr, sev)
        ba, be = evaluate(baseline, dl, DEVICE)
        ga, ge = evaluate(gated, dl, DEVICE)
        per_corr['base_acc'].append(ba)
        per_corr['gate_acc'].append(ga)
        per_corr['base_ece'].append(be)
        per_corr['gate_ece'].append(ge)

# Clean-test reference numbers for both models.
clean_ba, clean_be = evaluate(baseline, test_dl, DEVICE)
clean_ga, clean_ge = evaluate(gated, test_dl, DEVICE)

print(
    f'Clean → Baseline: {clean_ba:.3f}  |  Gated: {clean_ga:.3f}  '
    f'(Δ={(clean_ga - clean_ba) * 100:+.2f}pp)'
)

In [None]:
# ── Results table ─────────────────────────────────────────────────────────────
print(f'\n{"="*65}')
print(f'  {"Corruption":25s}  {"Base":>8}  {"Gated":>8}  {"Δ":>8}')
print(f'  {"-"*60}')

# Mean accuracy over all severities, one value per corruption.
base_avgs = [np.mean(results[c]['base_acc']) for c in CORRUPTIONS]
gate_avgs = [np.mean(results[c]['gate_acc']) for c in CORRUPTIONS]

deltas = []
for c, bm, gm in zip(CORRUPTIONS, base_avgs, gate_avgs):
    d = gm - bm
    deltas.append(d)
    # Differences within ±0.5 pp are shown as neutral.
    if d > 0.005:
        arr = '▲'
    elif d < -0.005:
        arr = '▼'
    else:
        arr = '─'
    print(f'  {c:25s}  {bm*100:>6.1f}%  {gm*100:>6.1f}%  {arr}{abs(d)*100:>5.1f}pp')

avg_b = np.mean(base_avgs)
avg_g = np.mean(gate_avgs)
print(f'  {"-"*60}')
print(f'  {"AVERAGE":25s}  {avg_b*100:>6.1f}%  {avg_g*100:>6.1f}%  {(avg_g-avg_b)*100:>+6.2f}pp')
print(f'  {"Clean test":25s}  {clean_ba*100:>6.1f}%  {clean_ga*100:>6.1f}%  {(clean_ga-clean_ba)*100:>+6.2f}pp')
print(f'\n  Training contamination: {NOISE_RATIO:.0%}')
print(f'  Overall OOD gain: {np.mean(deltas)*100:+.2f} pp')

In [None]:
# ── Figure ────────────────────────────────────────────────────────────────────
# Colour palette for the dark-themed figure.
DARK = '#0f1117'  # figure background
GRID = '#1e2130'  # axes background
BASE = '#e74c3c'  # baseline series
GATE = '#2ecc71'  # gated series
TEXT = '#ecf0f1'  # titles
SUB = '#95a5a6'   # ticks and axis labels

def style(ax, title):
    """Apply the dark theme to ``ax`` and set its bold title."""
    ax.set_facecolor(GRID)
    ax.tick_params(colors=SUB, labelsize=9)
    ax.spines[:].set_color('#2c3e50')
    ax.set_title(title, color=TEXT, fontsize=10, fontweight='bold', pad=8)
    # Both axis labels share the muted text colour.
    for axis in (ax.xaxis, ax.yaxis):
        axis.label.set_color(SUB)

# Per-corruption mean accuracy (in %) over all severities, and the
# gated-minus-baseline difference in percentage points.
base_means = [np.mean(results[c]['base_acc'])*100 for c in CORRUPTIONS]
gate_means = [np.mean(results[c]['gate_acc'])*100 for c in CORRUPTIONS]
delta_pp   = [g-b for g,b in zip(gate_means, base_means)]

# Per-severity means (in %) over all corruptions. The number of levels comes
# from SEVERITIES, so the curves stay in sync with the evaluation loop.
n_sev = len(SEVERITIES)
sev_b = [np.mean([results[c]['base_acc'][s] for c in CORRUPTIONS])*100 for s in range(n_sev)]
sev_g = [np.mean([results[c]['gate_acc'][s] for c in CORRUPTIONS])*100 for s in range(n_sev)]
sev_be= [np.mean([results[c]['base_ece'][s] for c in CORRUPTIONS])*100 for s in range(n_sev)]
sev_ge= [np.mean([results[c]['gate_ece'][s] for c in CORRUPTIONS])*100 for s in range(n_sev)]

fig = plt.figure(figsize=(18, 12))
fig.patch.set_facecolor(DARK)
gs  = gridspec.GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.35)

# A — per corruption (grouped bars spanning the whole top row)
ax1 = fig.add_subplot(gs[0,:])
x = np.arange(len(CORRUPTIONS)); w = 0.35
ax1.bar(x-w/2, base_means, w, color=BASE, label='Baseline', alpha=0.85)
ax1.bar(x+w/2, gate_means, w, color=GATE, label='Gated',    alpha=0.85)
ax1.set_xticks(x)
# Break corruption names at underscores so the tick labels do not overlap.
ax1.set_xticklabels([c.replace('_','\n') for c in CORRUPTIONS], fontsize=7, color=SUB)
ax1.set_ylabel('Accuracy (%)')
ax1.legend(facecolor=DARK, edgecolor='#2c3e50', labelcolor=TEXT)
style(ax1, f'A  Per-Corruption Accuracy  |  {NOISE_RATIO:.0%} train contamination  '
      f'|  Avg Δ = {np.mean(delta_pp):+.2f} pp')

# B — vs severity
ax2 = fig.add_subplot(gs[1,0])
sevs = list(SEVERITIES)
ax2.plot(sevs, sev_b, 'o-', color=BASE, lw=2.5, ms=7, label='Baseline')
ax2.plot(sevs, sev_g, 's-', color=GATE, lw=2.5, ms=7, label='Gated')
# Shade the gap in the colour of whichever model has the higher accuracy.
ax2.fill_between(sevs, sev_b, sev_g, alpha=0.1,
                  color=GATE if np.mean(delta_pp)>0 else BASE)
ax2.set_xlabel('Severity'); ax2.set_ylabel('Accuracy (%)')
ax2.legend(facecolor=DARK, edgecolor='#2c3e50', labelcolor=TEXT, fontsize=9)
style(ax2, 'B  Accuracy vs Severity')

# C — ECE vs severity
ax3 = fig.add_subplot(gs[1,1])
ax3.plot(sevs, sev_be, 'o-', color=BASE, lw=2.5, ms=7, label='Baseline ECE')
ax3.plot(sevs, sev_ge, 's-', color=GATE, lw=2.5, ms=7, label='Gated ECE')
# Lower ECE is better: shade in the gated colour only when the gated model
# is better calibrated on average, mirroring the rule used in panel B.
ax3.fill_between(sevs, sev_be, sev_ge, alpha=0.1,
                  color=GATE if np.mean(sev_ge) < np.mean(sev_be) else BASE)
ax3.set_xlabel('Severity'); ax3.set_ylabel('ECE (%) — lower is better')
ax3.legend(facecolor=DARK, edgecolor='#2c3e50', labelcolor=TEXT, fontsize=9)
style(ax3, 'C  Calibration Error vs Severity')

# avg_b / avg_g are the OOD averages computed in the results-table cell.
fig.suptitle(
    'Hybrid Geometric–Entropy Gating  ·  Aydin 2025\n'
    f'CIFAR-10-C  |  ResNet-18  |  {NOISE_RATIO:.0%} train contamination  '
    f'|  OOD: {avg_b:.1%} → {avg_g:.1%} ({(avg_g-avg_b)*100:+.2f}pp)  '
    f'|  Clean: {clean_ba:.1%} → {clean_ga:.1%}',
    color=TEXT, fontsize=11, fontweight='bold', y=0.98
)

# Pass the figure's own face colour so the saved PNG keeps the dark background.
plt.savefig('cifar10c_contaminated.png', dpi=150, bbox_inches='tight',
            facecolor=fig.get_facecolor())
plt.show()
print('Figure saved: cifar10c_contaminated.png')