# 21. Selective Synaptic Dampening (SSD) Experiments

SSD (Foster et al., AAAI 2024) dampens each parameter in proportion to its Fisher importance
for the forget set relative to the retain set. Unlike our original Fisher scrubbing
(gradient ascent + Fisher weighting), SSD applies a multiplicative dampening update:
$\theta \leftarrow \theta \cdot \left(1 - \alpha \, F_{\text{forget}} / F_{\text{retain}}\right)$.

We test SSD on the PBMC structured forget set with 3 seeds (plus an alpha sweep at seed 42)
and evaluate using the canonical fresh-attacker methodology (NB03).

In [4]:
import sys
sys.path.insert(0, '../src')

import json
import numpy as np
import torch
from pathlib import Path
import time

from train_ssd import train_ssd

# --- Input artifacts read by this notebook ---
DATA_PATH = '../data/adata_processed.h5ad'
SPLIT_PATH = '../outputs/p1/split_structured.json'
BASELINE_CKPT = '../outputs/p1/baseline/best_model.pt'

# --- Root directory for everything this notebook writes ---
OUTPUT_BASE = Path('../outputs/p2/ssd')

# --- Compute device: use CUDA when torch reports it, otherwise CPU ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'Device: {DEVICE}')

Device: cpu


## 1. Train SSD with 3 Seeds

Hyperparameters:
- alpha=1.0 (default dampening strength; alpha=0.5 and 5.0 are also swept at seed 42)
- threshold=0.0 (dampen all parameters)
- damping=1e-5 (Fisher numerical stability)
- 10 epochs retain fine-tuning, lr=1e-4, patience=10

In [5]:
SEEDS = [42, 123, 456]
ALPHAS = [0.5, 1.0, 5.0]  # Sweep dampening strength
DEFAULT_ALPHA = 1.0  # alpha that is replicated across every seed in SEEDS
SWEEP_SEED = 42      # seed used for the non-default alphas in ALPHAS

# Hyperparameters held fixed for every SSD run in this notebook.
SSD_FIXED_KWARGS = dict(
    threshold=0.0,
    damping=1e-5,
    finetune_epochs=10,
    finetune_lr=1e-4,
    patience=10,
    batch_size=256,
)

# (alpha, seed) pairs in execution order: first the default alpha across all
# seeds, then the remaining alphas at SWEEP_SEED. Deriving the sweep from
# ALPHAS keeps this cell, the summary table and the saved JSON in agreement.
run_configs = [(DEFAULT_ALPHA, seed) for seed in SEEDS]
run_configs += [(alpha, SWEEP_SEED) for alpha in ALPHAS if alpha != DEFAULT_ALPHA]

# run name -> {'path': checkpoint path as str, 'time': wall-clock seconds}
results = {}

for alpha, seed in run_configs:
    run_name = f'alpha{alpha}_seed{seed}'
    out_dir = OUTPUT_BASE / run_name
    print(f'\n{"="*60}')
    print(f'SSD alpha={alpha}, seed={seed}')
    print(f'{"="*60}')

    t0 = time.time()
    ckpt_path = train_ssd(
        baseline_checkpoint=BASELINE_CKPT,
        data_path=DATA_PATH,
        split_path=SPLIT_PATH,
        output_dir=str(out_dir),
        alpha=alpha,
        seed=seed,
        **SSD_FIXED_KWARGS,
    )
    elapsed = time.time() - t0
    results[run_name] = {'path': str(ckpt_path), 'time': elapsed}
    print(f'Done in {elapsed:.1f}s')

print(f'\nAll training complete. {len(results)} checkpoints saved.')


SSD alpha=1.0, seed=42
Data: torch.Size([33088, 2000]), Forget: 30, Retain: 28094, Device: cpu
Computing Fisher on forget set...
Computing Fisher on retain set...
Applying SSD dampening (alpha=1.0, threshold=0.0)...
  Dampened 100.0% of parameters
  Mean dampening magnitude: 0.2798
Fine-tuning on retain set (10 epochs, lr=0.0001)...
  Epoch 1: train=366.35, val=358.17
  Epoch 5: train=365.82, val=357.98
  Epoch 10: train=365.51, val=357.58
Saved to ../outputs/p2/ssd/alpha1.0_seed42/best_model.pt
Done in 82.9s

SSD alpha=1.0, seed=123
Data: torch.Size([33088, 2000]), Forget: 30, Retain: 28094, Device: cpu
Computing Fisher on forget set...
Computing Fisher on retain set...
Applying SSD dampening (alpha=1.0, threshold=0.0)...
  Dampened 100.0% of parameters
  Mean dampening magnitude: 0.2800
Fine-tuning on retain set (10 epochs, lr=0.0001)...
  Epoch 1: train=366.34, val=361.82
  Epoch 5: train=365.71, val=361.59
  Epoch 10: train=365.41, val=361.22
Saved to ../outputs/p2/ssd/alpha1.0_se

## 2. Evaluate with Canonical Fresh Attacker

In [8]:
# Use eval_multiseed infrastructure for canonical evaluation
sys.path.insert(0, '../scripts')
from eval_multiseed import (
    load_vae_model, train_fresh_attacker, evaluate_privacy,
    evaluate_utility, get_retain_latent_codes, MARKER_GENES
)
import scanpy as sc

# Read the processed AnnData object; densify the expression matrix only when
# it exposes a .toarray() method (i.e. it is stored sparse).
adata = sc.read_h5ad(DATA_PATH)
is_sparse = hasattr(adata.X, 'toarray')
X = torch.tensor(
    adata.X.toarray() if is_sparse else adata.X,
    dtype=torch.float32,
)

# Index lists for the forget / retain / unseen partitions
split = json.loads(Path(SPLIT_PATH).read_text())
forget_idx, retain_idx, unseen_idx = (
    split[field]
    for field in ('forget_indices', 'retain_indices', 'unseen_indices')
)

# Matched negatives that the attacker must tell apart from the forget set
matched_data = json.loads(
    Path('../outputs/p1.5/s1_matched_negatives.json').read_text()
)
matched_neg_idx = matched_data['matched_indices']

print(f'Forget: {len(forget_idx)}, Matched neg: {len(matched_neg_idx)}')

# Holdout (unseen) rows used for the utility metrics
X_holdout = X[unseen_idx]
lib_holdout = X_holdout.sum(dim=1, keepdim=True)  # row sums, kept as a column
labels_holdout = adata.obs['leiden'].values[unseen_idx]

# Keep only the marker genes present in this dataset, then map each one to
# its position in var_names.
gene_names = list(adata.var_names)
marker_names = [gene for gene in MARKER_GENES if gene in gene_names]
marker_idx = [gene_names.index(gene) for gene in marker_names]

# Fit the attacker once against the baseline model; the same attacker is
# reused when scoring each SSD checkpoint.
baseline_model, _ = load_vae_model(BASELINE_CKPT)
attacker = train_fresh_attacker(
    baseline_model,
    adata,
    forget_idx,
    matched_neg_idx,
    retain_idx,
    seed=42,
)

Forget: 30, Matched neg: 194
Training fresh attacker on baseline F vs matched:
  Samples: 224 (30 forget + 194 matched)
  Features: 70 dims
  Train: 179, Test: 45
  Baseline AUC (F vs matched, full set): 0.7792
  (Canonical NB03 value: ~0.769)


In [9]:
# Score every trained checkpoint with the shared attacker (privacy) and on
# the holdout rows (utility). Results are keyed by the same run name used in
# `results`.
eval_results = {}

for name, info in results.items():
    print(f'\nEvaluating {name}...')

    ckpt_path = info['path']
    model, config = load_vae_model(ckpt_path)

    privacy = evaluate_privacy(
        model, attacker, adata, forget_idx, matched_neg_idx, retain_idx
    )
    # NOTE(review): gene_names (all genes) is passed here while marker_names
    # is computed but unused — confirm against evaluate_utility's signature.
    utility = evaluate_utility(
        model, X_holdout, labels_holdout, marker_idx, gene_names
    )

    eval_results[name] = dict(
        privacy=privacy,
        utility=utility,
        training_time=info['time'],
    )

    # One-line metric summary for this checkpoint
    metrics_line = ', '.join([
        f'AUC={privacy["mlp_auc"]:.3f}',
        f'Advantage={privacy["mlp_advantage"]:.3f}',
        f'ELBO={utility["elbo"]:.1f}',
        f'Marker r={utility["marker_r"]:.3f}',
    ])
    print(f'  {metrics_line}')


Evaluating alpha1.0_seed42...
  AUC=0.724, Advantage=0.448, ELBO=363.9, Marker r=0.832

Evaluating alpha1.0_seed123...
  AUC=0.726, Advantage=0.452, ELBO=363.8, Marker r=0.831

Evaluating alpha1.0_seed456...
  AUC=0.725, Advantage=0.451, ELBO=363.9, Marker r=0.831

Evaluating alpha0.5_seed42...
  AUC=0.730, Advantage=0.460, ELBO=363.9, Marker r=0.832

Evaluating alpha5.0_seed42...
  AUC=0.634, Advantage=0.268, ELBO=364.3, Marker r=0.831


## 3. Summary

In [10]:
# Aggregate the alpha=1.0 runs across seeds. Only runs that are actually
# present in eval_results are included, and the header reports that count
# instead of assuming all three seeds were evaluated.
alpha1_privacy = [
    eval_results[key]['privacy']
    for key in (f'alpha1.0_seed{seed}' for seed in SEEDS)
    if key in eval_results
]
alpha1_aucs = [p['mlp_auc'] for p in alpha1_privacy]
alpha1_advantages = [p['mlp_advantage'] for p in alpha1_privacy]

print(f'SSD Results (alpha=1.0, {len(alpha1_aucs)} seeds):')
if alpha1_aucs:
    # np.std defaults to ddof=0, i.e. the population standard deviation.
    print(f'  AUC: {np.mean(alpha1_aucs):.3f} +/- {np.std(alpha1_aucs):.3f}')
    print(f'  Advantage: {np.mean(alpha1_advantages):.3f} +/- {np.std(alpha1_advantages):.3f}')
else:
    # Avoid printing nan (and a NumPy warning) when nothing was evaluated.
    print('  No alpha=1.0 runs found in eval_results')
print()

# Alpha sweep comparison (driven by ALPHAS so it cannot drift from training)
print('Alpha sweep (seed=42):')
print(f'{"Alpha":<10} {"AUC":>8} {"Advantage":>10} {"Marker r":>10} {"ELBO":>8}')
print('-' * 50)
for alpha in ALPHAS:
    key = f'alpha{alpha}_seed42'
    if key in eval_results:
        r = eval_results[key]
        print(f'{alpha:<10.1f} {r["privacy"]["mlp_auc"]:>8.3f} '
              f'{r["privacy"]["mlp_advantage"]:>10.3f} '
              f'{r["utility"]["marker_r"]:>10.3f} '
              f'{r["utility"]["elbo"]:>8.1f}')

print()
# Reference values are hard-coded here, not computed in this notebook.
print('Reference: Retrain AUC=0.523, Advantage=0.046')
print('Reference: Baseline AUC=0.783, Advantage=0.565')

SSD Results (alpha=1.0, 3 seeds):
  AUC: 0.725 +/- 0.001
  Advantage: 0.450 +/- 0.002

Alpha sweep (seed=42):
Alpha           AUC  Advantage   Marker r     ELBO
--------------------------------------------------
0.5           0.730      0.460      0.832    363.9
1.0           0.724      0.448      0.832    363.9
5.0           0.634      0.268      0.831    364.3

Reference: Retrain AUC=0.523, Advantage=0.046
Reference: Baseline AUC=0.783, Advantage=0.565


In [11]:
# Save results
def _json_default(obj):
    """Fallback encoder for values the json module cannot serialize natively.

    NumPy scalars/arrays and torch tensors are converted to plain Python
    numbers/lists so metrics stay numeric in the JSON file; anything else
    falls back to its string form.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    return str(obj)


output = {
    'method': 'ssd',
    'dataset': 'PBMC',
    'forget_type': 'structured',
    'seeds': SEEDS,
    'alpha_sweep': list(ALPHAS),  # same list that drives the sweep
    'results': eval_results,
    'summary': {
        'alpha1.0': {
            'mean_auc': float(np.mean(alpha1_aucs)),
            'std_auc': float(np.std(alpha1_aucs)),
            'mean_advantage': float(np.mean(alpha1_advantages)),
            'std_advantage': float(np.std(alpha1_advantages)),
        }
    }
}

# Create the output directory explicitly so this cell does not depend on a
# side effect of the training step.
OUTPUT_BASE.mkdir(parents=True, exist_ok=True)
ssd_results_path = OUTPUT_BASE / 'ssd_results.json'

with open(ssd_results_path, 'w') as f:
    json.dump(output, f, indent=2, default=_json_default)

print(f'Saved to {OUTPUT_BASE / "ssd_results.json"}')

Saved to ../outputs/p2/ssd/ssd_results.json
