# 22. Contrastive Latent Unlearning Experiments

A VAE-specific unlearning approach: push forget-set latent representations
toward the prior N(0, I) while preserving retain-set representations.
If forget samples map to the prior, they become indistinguishable from
random noise and carry no membership signal.
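
"Pushing a posterior toward the prior" is usually measured with the closed-form KL divergence between the encoder's diagonal Gaussian and N(0, I). Below is a minimal NumPy sketch. It assumes the encoder outputs a mean `mu` and a log-variance `logvar` per latent dimension. Those names and the 10-dimensional latent are illustrative and are not taken from `train_contrastive_unlearn`.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dims.

    Closed form for diagonal Gaussians:
        0.5 * sum(mu^2 + exp(logvar) - logvar - 1)
    Works on a single vector (shape (d,)) or a batch (shape (n, d)).
    """
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=-1)

# A posterior that already equals the prior has zero KL ...
print(kl_to_standard_normal(np.zeros(10), np.zeros(10)))  # 0.0
# ... while a shifted, narrow posterior (mean 1, variance exp(-2)) does not.
print(kl_to_standard_normal(np.full(10, 1.0), np.full(10, -2.0)))
```

The KL is zero only when every dimension has mean 0 and variance 1. Driving it to zero therefore makes the posterior identical to the prior.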

- Phase 1: Contrastive training (encoder only)
- Phase 2: Retain fine-tuning (full model)

Tested on the PBMC structured forget set (30 cells) with 3 seeds at gamma=1.0, plus a gamma sweep at seed 42.

In [1]:
import sys
sys.path.insert(0, '../src')

import json
import numpy as np
import torch
from pathlib import Path
import time

from train_contrastive_unlearn import train_contrastive

DATA_PATH = '../data/adata_processed.h5ad'
SPLIT_PATH = '../outputs/p1/split_structured.json'
BASELINE_CKPT = '../outputs/p1/baseline/best_model.pt'
OUTPUT_BASE = Path('../outputs/p2/contrastive')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {DEVICE}')

Device: cpu


## 1. Train Contrastive Unlearning (3 Seeds + Gamma Sweep)

Hyperparameters:
- gamma=1.0 (forget prior-matching weight)
- lam=1.0 (retain preservation weight)
- 20 contrastive epochs, encoder only, lr=1e-4
- 10 retain fine-tuning epochs, full model, lr=1e-4

Also sweep gamma in {0.1, 1.0, 10.0} with seed=42.
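
Phase 1 logs two quantities, `forget_kl` and `retain_dist`. Below is a minimal NumPy sketch of how a gamma/lam-weighted objective built from them could look. It assumes two definitions, both hypothetical, as is the `phase1_loss` helper itself:

- `forget_kl` is the closed-form KL to N(0, I), averaged over a forget batch.
- `retain_dist` is the mean squared distance between the current and frozen-baseline encoder means on a retain batch.

The actual loss lives in `train_contrastive_unlearn` and is not shown here. In training, only the encoder's parameters would receive gradients from this value.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    # Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ) per sample
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=-1)

def phase1_loss(mu_f, logvar_f, mu_r, mu_r_baseline, gamma=1.0, lam=1.0):
    """Hypothetical Phase 1 objective (encoder only).

    mu_f, logvar_f : current encoder posterior for a forget batch, shape (n_f, d)
    mu_r           : current encoder means for a retain batch, shape (n_r, d)
    mu_r_baseline  : frozen baseline encoder means for the same retain batch
    """
    # Forget term: how far forget posteriors still are from the prior
    forget_kl = kl_to_standard_normal(mu_f, logvar_f).mean()
    # Retain term (assumed form): mean squared distance to the baseline latent means
    retain_dist = np.mean(np.sum((mu_r - mu_r_baseline) ** 2, axis=-1))
    return gamma * forget_kl + lam * retain_dist, forget_kl, retain_dist

rng = np.random.default_rng(0)
d = 10  # illustrative latent size
mu_r_base = rng.normal(size=(256, d))
loss, fkl, rdist = phase1_loss(
    mu_f=np.zeros((30, d)), logvar_f=np.zeros((30, d)),  # forget batch already at the prior
    mu_r=mu_r_base.copy(), mu_r_baseline=mu_r_base,      # retain batch unchanged
)
print(loss, fkl, rdist)  # all three are 0.0: the objective's minimum
```

Raising `gamma` weights the pull toward the prior more heavily relative to retain preservation. That is the only knob varied in the sweep.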

In [2]:
SEEDS = [42, 123, 456]

# Hyperparameters shared by every run
COMMON_KWARGS = dict(
    baseline_checkpoint=BASELINE_CKPT,
    data_path=DATA_PATH,
    split_path=SPLIT_PATH,
    lam=1.0,
    n_epochs=20,
    lr=1e-4,
    finetune_epochs=10,
    finetune_lr=1e-4,
    patience=10,
    batch_size=256,
)

# Multi-seed runs at the default gamma=1.0, then the gamma sweep with seed=42
RUNS = [(1.0, seed) for seed in SEEDS] + [(0.1, 42), (10.0, 42)]

results = {}
for gamma, seed in RUNS:
    name = f'gamma{gamma}_seed{seed}'
    print(f'\n{"="*60}')
    print(f'Contrastive gamma={gamma}, seed={seed}')
    print(f'{"="*60}')

    t0 = time.time()
    ckpt_path = train_contrastive(
        output_dir=str(OUTPUT_BASE / name),
        gamma=gamma,
        seed=seed,
        **COMMON_KWARGS,
    )
    elapsed = time.time() - t0
    results[name] = {'path': str(ckpt_path), 'time': elapsed}
    print(f'Done in {elapsed:.1f}s')

print(f'\nAll training complete. {len(results)} checkpoints saved.')


Contrastive gamma=1.0, seed=42
Data: torch.Size([33088, 2000]), Forget: 30, Retain: 28094, Device: cpu
Phase 1: Contrastive training (20 epochs, gamma=1.0, lam=1.0)...
  Epoch 1: forget_kl=1.7475, retain_dist=0.1564
  Epoch 5: forget_kl=0.0772, retain_dist=0.0192
  Epoch 10: forget_kl=0.0420, retain_dist=0.0155
  Epoch 15: forget_kl=0.0241, retain_dist=0.0138
  Epoch 20: forget_kl=0.0122, retain_dist=0.0129
Phase 2: Retain fine-tuning (10 epochs, lr=0.0001)...
  Epoch 1: train=366.31, val=358.18
  Epoch 5: train=365.74, val=357.85
  Epoch 10: train=365.55, val=357.46
Saved to ../outputs/p2/contrastive/gamma1.0_seed42/best_model.pt
Done in 161.2s

Contrastive gamma=1.0, seed=123
Data: torch.Size([33088, 2000]), Forget: 30, Retain: 28094, Device: cpu
Phase 1: Contrastive training (20 epochs, gamma=1.0, lam=1.0)...
  Epoch 1: forget_kl=1.7542, retain_dist=0.1563
  Epoch 5: forget_kl=0.0784, retain_dist=0.0194
  Epoch 10: forget_kl=0.0429, retain_dist=0.0156
  Epoch 15: forget_kl=0.0249, 

## 2. Evaluate with Canonical Fresh Attacker
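
The attacker is trained once on the baseline model (forget vs matched negatives) and then reused unchanged on every unlearned checkpoint. Only the features fed into it change. As a reminder of what the reported AUC means, below is a minimal sketch of the standard rank definition of ROC AUC. It is not necessarily how `evaluate_privacy` computes it; that function may call scikit-learn instead.

```python
import numpy as np

def auc_from_scores(member_scores, nonmember_scores):
    """ROC AUC as P(score_member > score_nonmember), ties counted as 1/2.

    Equivalent to the normalized Mann-Whitney U statistic. It is O(n*m), which
    is fine for a few hundred samples such as 30 forget vs 194 matched negatives.
    """
    m = np.asarray(member_scores, dtype=float)[:, None]
    n = np.asarray(nonmember_scores, dtype=float)[None, :]
    return float(np.mean((m > n) + 0.5 * (m == n)))

print(auc_from_scores([0.9, 0.8, 0.7], [0.1, 0.2, 0.3]))  # 1.0: members always ranked higher
print(auc_from_scores([0.1, 0.2, 0.3], [0.9, 0.8, 0.7]))  # 0.0: members always ranked lower
```

An AUC near 0 is therefore just as much a consistent ranking as an AUC near 1, only in the opposite direction.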

In [3]:
sys.path.insert(0, '../scripts')
from eval_multiseed import (
    load_vae_model, train_fresh_attacker, evaluate_privacy,
    evaluate_utility, MARKER_GENES
)
import scanpy as sc

adata = sc.read_h5ad(DATA_PATH)
X = torch.tensor(
    adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X,
    dtype=torch.float32
)

with open(SPLIT_PATH) as f:
    split = json.load(f)
forget_idx = split['forget_indices']
retain_idx = split['retain_indices']
unseen_idx = split['unseen_indices']

with open('../outputs/p1.5/s1_matched_negatives.json') as f:
    matched_data = json.load(f)
matched_neg_idx = matched_data['matched_indices']

X_holdout = X[unseen_idx]
labels_holdout = adata.obs['leiden'].values[unseen_idx]
gene_names = list(adata.var_names)
marker_idx = [gene_names.index(g) for g in MARKER_GENES if g in gene_names]
marker_names = [g for g in MARKER_GENES if g in gene_names]

baseline_model, _ = load_vae_model(BASELINE_CKPT)
attacker = train_fresh_attacker(
    baseline_model, adata, forget_idx, matched_neg_idx, retain_idx, seed=42
)

Training fresh attacker on baseline F vs matched:
  Samples: 224 (30 forget + 194 matched)
  Features: 70 dims
  Train: 179, Test: 45
  Baseline AUC (F vs matched, full set): 0.7792
  (Canonical NB03 value: ~0.769)


In [4]:
eval_results = {}

for name, info in results.items():
    ckpt_path = info['path']
    print(f'\nEvaluating {name}...')
    
    model, config = load_vae_model(ckpt_path)
    
    privacy = evaluate_privacy(
        model, attacker, adata, forget_idx, matched_neg_idx, retain_idx
    )
    utility = evaluate_utility(
        model, X_holdout, labels_holdout, marker_idx, gene_names
    )
    
    eval_results[name] = {
        'privacy': privacy,
        'utility': utility,
        'training_time': info['time'],
    }
    
    print(f'  AUC={privacy["mlp_auc"]:.3f}, '
          f'Advantage={privacy["mlp_advantage"]:.3f}, '
          f'ELBO={utility["elbo"]:.1f}, '
          f'Marker r={utility["marker_r"]:.3f}')


Evaluating gamma1.0_seed42...
  AUC=0.166, Advantage=0.667, ELBO=364.1, Marker r=0.832

Evaluating gamma1.0_seed123...
  AUC=0.182, Advantage=0.635, ELBO=363.9, Marker r=0.831

Evaluating gamma1.0_seed456...
  AUC=0.109, Advantage=0.782, ELBO=364.2, Marker r=0.831

Evaluating gamma0.1_seed42...
  AUC=0.107, Advantage=0.785, ELBO=364.0, Marker r=0.831

Evaluating gamma10.0_seed42...
  AUC=0.267, Advantage=0.465, ELBO=364.0, Marker r=0.832


## 3. Summary
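
The `+/-` in the summary below is `np.std` with its default `ddof=0`, the population standard deviation. With only three seeds, the sample standard deviation (`ddof=1`) is noticeably larger. The quick check below uses the rounded per-seed AUCs printed in the evaluation section. Because of that rounding, its mean and population std can differ from the summary in the third decimal.

```python
import numpy as np

# Per-seed AUCs for gamma=1.0 as printed (rounded) in the evaluation section
aucs = np.array([0.166, 0.182, 0.109])

print(f'mean = {aucs.mean():.3f}')
print(f'population std (ddof=0, np.std default) = {aucs.std():.3f}')
print(f'sample std (ddof=1) = {aucs.std(ddof=1):.3f}')
```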

In [5]:
# Aggregate gamma=1.0 seeds
gamma1_aucs = []
gamma1_advantages = []
for seed in SEEDS:
    key = f'gamma1.0_seed{seed}'
    if key in eval_results:
        gamma1_aucs.append(eval_results[key]['privacy']['mlp_auc'])
        gamma1_advantages.append(eval_results[key]['privacy']['mlp_advantage'])

print('Contrastive Results (gamma=1.0, 3 seeds):')
print(f'  AUC: {np.mean(gamma1_aucs):.3f} +/- {np.std(gamma1_aucs):.3f}')
print(f'  Advantage: {np.mean(gamma1_advantages):.3f} +/- {np.std(gamma1_advantages):.3f}')
print()

print('Gamma sweep (seed=42):')
print(f'{"Gamma":<10} {"AUC":>8} {"Advantage":>10} {"Marker r":>10} {"ELBO":>8}')
print('-' * 50)
for gamma in [0.1, 1.0, 10.0]:
    key = f'gamma{gamma}_seed42'
    if key in eval_results:
        r = eval_results[key]
        print(f'{gamma:<10.1f} {r["privacy"]["mlp_auc"]:>8.3f} '
              f'{r["privacy"]["mlp_advantage"]:>10.3f} '
              f'{r["utility"]["marker_r"]:>10.3f} '
              f'{r["utility"]["elbo"]:>8.1f}')

print()
print('Reference: Retrain AUC=0.523, Advantage=0.046')
print('Reference: Baseline AUC=0.783, Advantage=0.565')

Contrastive Results (gamma=1.0, 3 seeds):
  AUC: 0.153 +/- 0.032
  Advantage: 0.695 +/- 0.063

Gamma sweep (seed=42):
Gamma           AUC  Advantage   Marker r     ELBO
--------------------------------------------------
0.1           0.107      0.785      0.831    364.0
1.0           0.166      0.667      0.832    364.1
10.0          0.267      0.465      0.832    364.0

Reference: Retrain AUC=0.523, Advantage=0.046
Reference: Baseline AUC=0.783, Advantage=0.565


In [6]:
output = {
    'method': 'contrastive',
    'dataset': 'PBMC',
    'forget_type': 'structured',
    'seeds': SEEDS,
    'gamma_sweep': [0.1, 1.0, 10.0],
    'results': eval_results,
    'summary': {
        'gamma1.0': {
            'mean_auc': float(np.mean(gamma1_aucs)),
            'std_auc': float(np.std(gamma1_aucs)),
            'mean_advantage': float(np.mean(gamma1_advantages)),
            'std_advantage': float(np.std(gamma1_advantages)),
        }
    }
}

with open(OUTPUT_BASE / 'contrastive_results.json', 'w') as f:
    json.dump(output, f, indent=2, default=str)

print(f'Saved to {OUTPUT_BASE / "contrastive_results.json"}')

Saved to ../outputs/p2/contrastive/contrastive_results.json


## 4. Analysis

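Contrastive latent unlearning produces a Streisand effect. At gamma=1.0 across three seeds, MIA AUC drops to 0.153 +/- 0.032 (mean +/- population std), well below the 0.5 chance line. An AUC that far below 0.5 leaks as much as one equally far above it. The frozen attacker now ranks forget samples consistently below the matched negatives, and an adversary who knows this simply inverts the scores. Because the advantage metric is direction-agnostic, this translates to an advantage of 0.695, worse than the baseline's 0.565. Lower gamma (0.1) makes things worse (AUC=0.107, advantage=0.785). Higher gamma (10.0) partially mitigates the problem (AUC=0.267, advantage=0.465, slightly below the baseline). It remains far from the retrain reference (advantage=0.046).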
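
The advantage values in this notebook are consistent with advantage = 2|AUC - 0.5|. This formula is inferred from the reported numbers; the definition inside `evaluate_privacy` is not shown here. The sketch below recomputes the advantage from the rounded AUCs and shows that inverting the scores (AUC -> 1 - AUC) leaves it unchanged. `mia_advantage` is a hypothetical helper.

```python
def mia_advantage(auc):
    """Direction-agnostic attack advantage: distance of AUC from chance, rescaled to [0, 1]."""
    return 2.0 * abs(auc - 0.5)

# (AUC, reported advantage) pairs from this notebook. The AUCs are rounded,
# so the recomputed advantage can differ in the third decimal.
reported = {
    'contrastive gamma=0.1': (0.107, 0.785),
    'contrastive gamma=1.0': (0.166, 0.667),
    'contrastive gamma=10.0': (0.267, 0.465),
    'baseline': (0.783, 0.565),
    'retrain': (0.523, 0.046),
}
for name, (auc, adv) in reported.items():
    print(f'{name:<24} AUC={auc:.3f}  2|AUC-0.5|={mia_advantage(auc):.3f}  reported={adv:.3f}')

# An attacker that knows the direction negates its scores: AUC -> 1 - AUC,
# and the advantage stays the same.
print(f'{mia_advantage(0.153):.3f} {mia_advantage(1 - 0.153):.3f}')
```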

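In hindsight this is not surprising. In a trained VAE, individual cell posteriors are typically narrow and centered away from the origin; only their aggregate resembles N(0,I). Pushing forget-set posteriors onto N(0,I) therefore puts them in a region of latent space that no retain-set sample occupies, and that displacement is what shifts the attacker's scores. The gamma sweep shows that the size of the leak is not a simple function of push strength, though. Advantage falls from 0.785 at gamma=0.1 to 0.465 at gamma=10.0, so a stronger push toward the prior did not produce a more obvious signal here.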
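
To make the mechanism concrete, here is a toy illustration on synthetic posteriors. The numbers are made up and are not data from this experiment; the latent size, variances and threshold are arbitrary assumptions. Retain-like posteriors are sharp and off-center, while "unlearned" forget posteriors sit almost exactly on N(0, I). A one-feature detector that thresholds the KL to the prior then separates them perfectly.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 10  # illustrative latent size

def kl_to_standard_normal(mu, logvar):
    # Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ) per sample
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=-1)

# Synthetic retain posteriors: informative means, small variance (exp(-2)).
mu_retain = rng.normal(0.0, 1.0, size=(500, d))
logvar_retain = np.full((500, d), -2.0)
# Synthetic "unlearned" forget posteriors: pushed almost exactly onto N(0, I).
mu_forget = rng.normal(0.0, 0.05, size=(30, d))
logvar_forget = rng.normal(0.0, 0.05, size=(30, d))

kl_retain = kl_to_standard_normal(mu_retain, logvar_retain)
kl_forget = kl_to_standard_normal(mu_forget, logvar_forget)

# One-feature detector: flag anything whose posterior is very close to the prior.
threshold = 1.0
flagged_forget = np.mean(kl_forget < threshold)
flagged_retain = np.mean(kl_retain < threshold)
print(f'forget flagged: {flagged_forget:.2f}, retain flagged: {flagged_retain:.2f}')
```

In this toy setup every retain posterior has KL of at least 0.5 * d * (1 + exp(-2)), about 5.7, from its variance term alone. The relocated forget posteriors stay near 0, so "close to the prior" is itself the membership signal.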

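Utility barely moves across the five runs (marker r=0.831-0.832, ELBO=363.9-364.2), so the privacy differences between gamma values do not come at a utility cost. Phase 1 does update encoder weights shared by all inputs. However, the lam-weighted retain term keeps retain-set posteriors close to the baseline: retain_dist falls from 0.156 to 0.013 over 20 epochs for seed 42. Phase 2 then fine-tunes the full model on the retain set only. Held-out reconstruction therefore stays intact even though both encoder and decoder were updated.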

The takeaway for the paper: targeted latent-space manipulation is counterproductive, and ten epochs of full-model retain fine-tuning afterwards did not erase the artifact. Any method that moves forget samples to a distinctive location in representation space, whether the prior or some other target, is likely to create exactly the artifact that membership inference detects. Information needs to be removed, not relocated.