# 24. SCRUB Experiments

SCRUB (Kurmanji et al., NeurIPS 2023) uses teacher-student distillation:
the student matches the teacher on retain data and diverges from the
teacher on forget data. It is widely regarded as the state of the art in machine unlearning.

Adaptation for VAEs:
- Teacher = frozen baseline VAE
- Student = copy of baseline (being updated)
- Forget: maximize KL(student_posterior || teacher_posterior) on forget cells
- Retain: minimize KL(student_posterior || teacher_posterior) + ELBO loss (the negative ELBO) on retain cells
- Alternating optimization (forget steps, then retain steps)
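
The two objectives above can be sketched as loss functions over diagonal-Gaussian encoder posteriors. This is a minimal NumPy sketch, not the code inside `train_scrub`, and it rests on three assumptions. First, the encoder returns `(mu, logvar)` per cell. Second, the logged forget loss is `-alpha_forget * KL`, which would explain why the forget values in the training log below are negative and grow in magnitude. Third, the retain ELBO term arrives as a positive negative-ELBO loss. In training these would be PyTorch tensors with autograd. NumPy is used here to keep the arithmetic visible.

```python
import numpy as np

def gaussian_kl(mu_s, logvar_s, mu_t, logvar_t):
    """KL(N(mu_s, diag exp(logvar_s)) || N(mu_t, diag exp(logvar_t))), summed over latent dims."""
    var_s = np.exp(logvar_s)
    var_t = np.exp(logvar_t)
    kl = 0.5 * (logvar_t - logvar_s + (var_s + (mu_s - mu_t) ** 2) / var_t - 1.0)
    return kl.sum(axis=-1)  # one KL value per cell

def forget_loss(student, teacher, alpha_forget):
    # Minimizing -KL maximizes the student-teacher divergence on forget cells.
    return float(-alpha_forget * gaussian_kl(*student, *teacher).mean())

def retain_loss(student, teacher, neg_elbo, alpha_retain):
    # Match the teacher on retain cells while keeping the ELBO loss low.
    return float(alpha_retain * gaussian_kl(*student, *teacher).mean() + neg_elbo)

# Usage: shifting both latent means by 1 (unit variances) gives KL = 0.5 per dim.
teacher = (np.zeros((1, 2)), np.zeros((1, 2)))
student = (np.ones((1, 2)), np.zeros((1, 2)))
print(gaussian_kl(*student, *teacher))     # [1.]
print(forget_loss(student, teacher, 1.0))  # -1.0
```

In the alternating scheme, each epoch would take a few optimizer steps on `forget_loss` using forget batches, then several steps on `retain_loss` using retain batches.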

Tested on the PBMC structured forget set (30 cells) with 3 seeds and an alpha_forget sweep.

In [1]:
import sys
sys.path.insert(0, '../src')

import json
import numpy as np
import torch
from pathlib import Path
import time

from train_scrub import train_scrub

DATA_PATH = '../data/adata_processed.h5ad'
SPLIT_PATH = '../outputs/p1/split_structured.json'
BASELINE_CKPT = '../outputs/p1/baseline/best_model.pt'
OUTPUT_BASE = Path('../outputs/p2/scrub')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {DEVICE}')

Device: cpu


## 1. Train SCRUB with 3 Seeds

Hyperparameters:
- alpha_forget=1.0 (forget divergence weight)
- alpha_retain=1.0 (retain matching weight)
- 20 SCRUB epochs, 5 forget + 10 retain steps per epoch
- lr_forget=1e-4, lr_retain=1e-4
- max_grad_norm=1.0
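- 10 retain fine-tuning epochs (finetune_lr=1e-4, patience=10)
- batch_size=256

Sweep alpha_forget over {0.1, 1.0, 10.0} with seed=42. The alpha_forget=1.0 point reuses the seed-42 run from the multi-seed set.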

In [2]:
SEEDS = [42, 123, 456]

# Hyperparameters shared by every run; only alpha_forget and seed vary.
COMMON_KWARGS = dict(
    baseline_checkpoint=BASELINE_CKPT,
    data_path=DATA_PATH,
    split_path=SPLIT_PATH,
    alpha_retain=1.0,
    n_epochs=20,
    forget_steps_per_epoch=5,
    retain_steps_per_epoch=10,
    lr_forget=1e-4,
    lr_retain=1e-4,
    max_grad_norm=1.0,
    finetune_epochs=10,
    finetune_lr=1e-4,
    patience=10,
    batch_size=256,
)

def run_scrub(alpha_forget, seed):
    """Train one SCRUB model; return its result key and {checkpoint path, wall time}."""
    key = f'af{alpha_forget}_seed{seed}'
    print(f'\n{"="*60}')
    print(f'SCRUB alpha_forget={alpha_forget}, seed={seed}')
    print(f'{"="*60}')
    t0 = time.time()
    ckpt_path = train_scrub(
        output_dir=str(OUTPUT_BASE / key),
        alpha_forget=alpha_forget,
        seed=seed,
        **COMMON_KWARGS,
    )
    elapsed = time.time() - t0
    print(f'Done in {elapsed:.1f}s')
    return key, {'path': str(ckpt_path), 'time': elapsed}

results = {}

# Multi-seed runs at the default alpha_forget=1.0
for seed in SEEDS:
    key, info = run_scrub(1.0, seed)
    results[key] = info

# alpha_forget sweep at seed=42 (alpha_forget=1.0 is reused from above)
for af in [0.1, 10.0]:
    key, info = run_scrub(af, 42)
    results[key] = info

print(f'\nAll training complete. {len(results)} checkpoints saved.')


SCRUB alpha_forget=1.0, seed=42
Data: torch.Size([33088, 2000]), Forget: 30, Retain: 28094, Device: cpu
SCRUB training: 20 epochs, 5 forget + 10 retain steps/epoch
  alpha_forget=1.0, alpha_retain=1.0
  Epoch 1: forget=-8.7694, retain=368.8553
  Epoch 5: forget=-56.8306, retain=372.3197
  Epoch 10: forget=-178.4803, retain=372.8519
  Epoch 15: forget=-347.6720, retain=376.4480
  Epoch 20: forget=-515.9554, retain=375.1795
Fine-tuning on retain set (10 epochs, lr=0.0001)...
  Epoch 1: train=366.38, val=358.01
  Epoch 5: train=365.63, val=357.77
  Epoch 10: train=365.31, val=357.63
Saved to ../outputs/p2/scrub/af1.0_seed42/best_model.pt
Done in 81.5s

SCRUB alpha_forget=1.0, seed=123
Data: torch.Size([33088, 2000]), Forget: 30, Retain: 28094, Device: cpu
SCRUB training: 20 epochs, 5 forget + 10 retain steps/epoch
  alpha_forget=1.0, alpha_retain=1.0
  Epoch 1: forget=-6.3483, retain=368.5155
  Epoch 5: forget=-53.6064, retain=370.8502
  Epoch 10: forget=-181.1074, retain=374.7687
  Epoc

## 2. Evaluate with Canonical Fresh Attacker
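
The evaluation reports two privacy numbers per model. The first is the attacker's ROC AUC for separating forget cells from matched negatives. The second is an advantage. The `eval_multiseed` helpers are not shown here. Every AUC/advantage pair printed below is consistent with advantage = 2·AUC − 1 (for example, 0.737 gives 0.474), so this sketch assumes that definition. It computes AUC with the rank (Mann-Whitney) formulation on hypothetical attacker scores.

```python
import numpy as np

def mia_auc(forget_scores, negative_scores):
    """ROC AUC as P(score_forget > score_negative), counting ties as 1/2.

    This equals the normalized Mann-Whitney U statistic.
    """
    f = np.asarray(forget_scores, dtype=float)[:, None]
    n = np.asarray(negative_scores, dtype=float)[None, :]
    return float((f > n).mean() + 0.5 * (f == n).mean())

def mia_advantage(auc):
    # Assumed definition: 0 = chance (AUC 0.5), 1 = perfect membership attack.
    return 2.0 * auc - 1.0

# Hypothetical attacker scores (higher = "was a training member").
auc = mia_auc([0.9, 0.8, 0.4], [0.7, 0.3, 0.2])
print(f'AUC={auc:.3f}, advantage={mia_advantage(auc):.3f}')
```

Here 8 of the 9 forget/negative pairs are ranked correctly, so AUC = 8/9 and advantage = 7/9.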

In [3]:
sys.path.insert(0, '../scripts')
from eval_multiseed import (
    load_vae_model, train_fresh_attacker, evaluate_privacy,
    evaluate_utility, MARKER_GENES
)
import scanpy as sc

adata = sc.read_h5ad(DATA_PATH)
X = torch.tensor(
    adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X,
    dtype=torch.float32
)

with open(SPLIT_PATH) as f:
    split = json.load(f)
forget_idx = split['forget_indices']
retain_idx = split['retain_indices']
unseen_idx = split['unseen_indices']

with open('../outputs/p1.5/s1_matched_negatives.json') as f:
    matched_data = json.load(f)
matched_neg_idx = matched_data['matched_indices']

X_holdout = X[unseen_idx]
labels_holdout = adata.obs['leiden'].values[unseen_idx]
gene_names = list(adata.var_names)
marker_idx = [gene_names.index(g) for g in MARKER_GENES if g in gene_names]
marker_names = [g for g in MARKER_GENES if g in gene_names]

baseline_model, _ = load_vae_model(BASELINE_CKPT)
attacker = train_fresh_attacker(
    baseline_model, adata, forget_idx, matched_neg_idx, retain_idx, seed=42
)

Training fresh attacker on baseline F vs matched:
  Samples: 224 (30 forget + 194 matched)
  Features: 70 dims
  Train: 179, Test: 45
  Baseline AUC (F vs matched, full set): 0.7792
  (Canonical NB03 value: ~0.769)


In [4]:
eval_results = {}

for name, info in results.items():
    ckpt_path = info['path']
    print(f'\nEvaluating {name}...')
    
    model, config = load_vae_model(ckpt_path)
    
    privacy = evaluate_privacy(
        model, attacker, adata, forget_idx, matched_neg_idx, retain_idx
    )
    utility = evaluate_utility(
        model, X_holdout, labels_holdout, marker_idx, gene_names
    )
    
    eval_results[name] = {
        'privacy': privacy,
        'utility': utility,
        'training_time': info['time'],
    }
    
    print(f'  AUC={privacy["mlp_auc"]:.3f}, '
          f'Advantage={privacy["mlp_advantage"]:.3f}, '
          f'ELBO={utility["elbo"]:.1f}, '
          f'Marker r={utility["marker_r"]:.3f}')


Evaluating af1.0_seed42...
  AUC=0.736, Advantage=0.473, ELBO=363.8, Marker r=0.832

Evaluating af1.0_seed123...
  AUC=0.735, Advantage=0.471, ELBO=363.8, Marker r=0.831

Evaluating af1.0_seed456...
  AUC=0.739, Advantage=0.479, ELBO=363.8, Marker r=0.831

Evaluating af0.1_seed42...
  AUC=0.731, Advantage=0.461, ELBO=363.8, Marker r=0.832

Evaluating af10.0_seed42...
  AUC=0.732, Advantage=0.463, ELBO=363.8, Marker r=0.832


## 3. Summary

In [5]:
# Aggregate af=1.0 seeds
af1_aucs = []
af1_advantages = []
for seed in SEEDS:
    key = f'af1.0_seed{seed}'
    if key in eval_results:
        af1_aucs.append(eval_results[key]['privacy']['mlp_auc'])
        af1_advantages.append(eval_results[key]['privacy']['mlp_advantage'])

print('SCRUB Results (alpha_forget=1.0, 3 seeds):')
print(f'  AUC: {np.mean(af1_aucs):.3f} +/- {np.std(af1_aucs):.3f}')
print(f'  Advantage: {np.mean(af1_advantages):.3f} +/- {np.std(af1_advantages):.3f}')
print()

print('Alpha_forget sweep (seed=42):')
print(f'{"alpha_f":<10} {"AUC":>8} {"Advantage":>10} {"Marker r":>10} {"ELBO":>8}')
print('-' * 50)
for af in [0.1, 1.0, 10.0]:
    key = f'af{af}_seed42'
    if key in eval_results:
        r = eval_results[key]
        print(f'{af:<10.1f} {r["privacy"]["mlp_auc"]:>8.3f} '
              f'{r["privacy"]["mlp_advantage"]:>10.3f} '
              f'{r["utility"]["marker_r"]:>10.3f} '
              f'{r["utility"]["elbo"]:>8.1f}')

print()
print('Reference: Retrain AUC=0.523, Advantage=0.046')
print('Reference: Baseline AUC=0.783, Advantage=0.565')
print()
print('SCRUB is the strongest modern baseline. If it fails on structured')
print('forget sets, this reinforces that the problem is fundamental to')
print('biological subpopulation unlearning, not a limitation of older methods.')

SCRUB Results (alpha_forget=1.0, 3 seeds):
  AUC: 0.737 +/- 0.002
  Advantage: 0.474 +/- 0.003

Alpha_forget sweep (seed=42):
alpha_f         AUC  Advantage   Marker r     ELBO
--------------------------------------------------
0.1           0.731      0.461      0.832    363.8
1.0           0.736      0.473      0.832    363.8
10.0          0.732      0.463      0.832    363.8

Reference: Retrain AUC=0.523, Advantage=0.046
Reference: Baseline AUC=0.783, Advantage=0.565

SCRUB is the strongest modern baseline. If it fails on structured
forget sets, this reinforces that the problem is fundamental to
biological subpopulation unlearning, not a limitation of older methods.
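
The summary cell uses `np.std` with its default `ddof=0`, the population standard deviation. With only three seeds, the sample estimate (`ddof=1`) is somewhat larger. Recomputing from the rounded per-seed values printed above shows how much this matters here.

```python
import numpy as np

# Per-seed values at alpha_forget=1.0, as printed above (rounded to 3 decimals).
aucs = np.array([0.736, 0.735, 0.739])
advantages = np.array([0.473, 0.471, 0.479])

for name, vals in (('AUC', aucs), ('Advantage', advantages)):
    pop = np.std(vals)             # ddof=0, what the summary cell uses
    sample = np.std(vals, ddof=1)  # square root of the unbiased variance, n=3
    print(f'{name}: {vals.mean():.3f} +/- {pop:.3f} (ddof=0), +/- {sample:.3f} (ddof=1)')
```

From the rounded values, the AUC spread stays at 0.002 under either convention. The advantage spread rises from 0.003 to 0.004 under `ddof=1`. Neither changes the conclusion.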


In [6]:
output = {
    'method': 'scrub',
    'dataset': 'PBMC',
    'forget_type': 'structured',
    'seeds': SEEDS,
    'alpha_forget_sweep': [0.1, 1.0, 10.0],
    'results': eval_results,
    'summary': {
        'af1.0': {
            'mean_auc': float(np.mean(af1_aucs)),
            'std_auc': float(np.std(af1_aucs)),
            'mean_advantage': float(np.mean(af1_advantages)),
            'std_advantage': float(np.std(af1_advantages)),
        }
    }
}

with open(OUTPUT_BASE / 'scrub_results.json', 'w') as f:
    json.dump(output, f, indent=2, default=str)

print(f'Saved to {OUTPUT_BASE / "scrub_results.json"}')

Saved to ../outputs/p2/scrub/scrub_results.json


## 4. Analysis

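SCRUB is widely regarded as the state of the art in machine unlearning (Kurmanji et al., NeurIPS 2023), so these results matter. At alpha_forget=1.0 across three seeds, MIA AUC is 0.737 +/- 0.002 and advantage is 0.474 +/- 0.003. The baseline advantage is 0.565 (AUC 0.783), and the retrain reference is 0.046 (AUC 0.523). SCRUB barely moves the attacker: it closes only about 18% of the baseline-to-retrain gap in advantage (0.091 of 0.519).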
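
The gap-closure figure can be reproduced from the reference lines printed in Section 3. A value of 1.0 would mean SCRUB matches the retrain reference, and 0.0 would mean it is no better than the baseline. The `gap_closed` helper is only for this illustration.

```python
# Numbers printed in Section 3 (SCRUB = mean of 3 seeds at alpha_forget=1.0).
baseline = {'auc': 0.783, 'adv': 0.565}
retrain = {'auc': 0.523, 'adv': 0.046}
scrub = {'auc': 0.737, 'adv': 0.474}

def gap_closed(metric):
    """Fraction of the baseline-to-retrain gap that SCRUB closes (1.0 = matches retrain)."""
    return (baseline[metric] - scrub[metric]) / (baseline[metric] - retrain[metric])

for m in ('auc', 'adv'):
    print(f'{m}: {gap_closed(m):.1%} of the gap closed')
```

Both metrics give roughly the same answer: about 17.5% of the gap in advantage and 17.7% in AUC.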

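The alpha_forget sweep is the most telling part. A 100x range in the forget weight (0.1 to 10.0) changes the advantage by at most 0.012 (0.461 to 0.473), and the change is not monotonic: both 0.1 and 10.0 land slightly below 1.0. That spread is on the order of the seed-to-seed range at alpha_forget=1.0 (0.471 to 0.479). The optimization is dominated by the retain objective no matter how aggressively the forget term is weighted.

The retain steps do not appear to simply overwrite the forget steps, though. During SCRUB training the forget loss grows steadily in magnitude (from -8.8 at epoch 1 to -516.0 at epoch 20 for seed 42), so the forget-set divergence accumulates across alternating rounds. Two explanations fit. Either the 10 retain-only fine-tuning epochs that follow undo the divergence, or the divergence lies in directions the attacker's features do not use. This notebook does not distinguish between them.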
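
The sweep-versus-seed comparison is simple arithmetic on the advantages printed in Sections 2 and 3. The `spread` helper is only for this illustration.

```python
# Advantage values copied from the printed evaluation.
seed_adv = {42: 0.473, 123: 0.471, 456: 0.479}     # alpha_forget=1.0
sweep_adv = {0.1: 0.461, 1.0: 0.473, 10.0: 0.463}  # seed=42

def spread(values):
    """Range (max - min) of a collection of numbers."""
    vals = list(values)
    return max(vals) - min(vals)

seed_range = spread(seed_adv.values())
sweep_range = spread(sweep_adv.values())
print(f'seed range: {seed_range:.3f}, sweep range: {sweep_range:.3f}')
```

A 100x change in alpha_forget moves the advantage by 0.012. Changing only the seed moves it by 0.008.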

The teacher-student distillation is doing almost nothing to the attacker's view of the forget set. The forget divergence term pushes the student away from the teacher on the 30 forget cells. Those cells get only 100 forget updates per run. Against them stand 200 interleaved retain updates and roughly 1,100 fine-tuning updates over 28,094 retain cells (10 epochs at batch size 256), all pulling the model back. A larger forget set would give the forget objective more gradient mass, and SCRUB might perform differently at n=500 or n=1000. But the size ablation from extra-gradient (NB20) showed that larger structured forget sets also have stronger baseline memorization, so the problem scales on both sides. Utility is untouched (marker r=0.831-0.832, ELBO=363.8 for every run), which confirms that outside the forget set the model is not changing in any meaningful way.
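
The update counts above follow from the training configuration and the retain-set size in the training log. The sketch assumes three things. Each SCRUB "step" is one optimizer update. Each fine-tuning epoch is one full pass over the retain set that keeps the last partial batch. Early stopping never triggers, which is consistent with the log showing all 10 fine-tuning epochs.

```python
import math

# Schedule from the training cell; retain-set size from the training log.
N_RETAIN = 28094
BATCH_SIZE = 256
N_EPOCHS, FORGET_STEPS, RETAIN_STEPS = 20, 5, 10
FINETUNE_EPOCHS = 10

forget_updates = N_EPOCHS * FORGET_STEPS   # forget-objective updates
retain_updates = N_EPOCHS * RETAIN_STEPS   # interleaved retain updates
# Assumption: one full pass per fine-tuning epoch, last partial batch kept (ceil).
finetune_updates = FINETUNE_EPOCHS * math.ceil(N_RETAIN / BATCH_SIZE)

total = forget_updates + retain_updates + finetune_updates
print(f'forget={forget_updates}, retain={retain_updates}, '
      f'finetune={finetune_updates}, forget share={forget_updates / total:.1%}')
```

Under these assumptions, forget-objective updates make up about 7% of all updates applied to the final checkpoint.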

If the NeurIPS 2023 state of the art closes less than a fifth of the gap on a 30-cell structured forget set, across a 100x range of alpha_forget, the problem is not about finding the right method or the right hyperparameters. The VAE has distributed subpopulation information so thoroughly that local parameter adjustments cannot selectively remove it. This holds whether the adjustment is Fisher-weighted (SSD), latent-space (contrastive), or distillation-based (SCRUB). The only method that worked was DP-SGD, which bounds each cell's influence during training instead of trying to remove it afterwards.