In [None]:
"""
# LiRA Membership Inference Attack Analysis

This notebook performs comprehensive analysis of membership inference attacks using
the Likelihood Ratio Attack (LiRA) framework.

## Analysis Pipeline:
1. **Two-Mode Evaluation**: Compare target vs shadow threshold strategies
2. **Performance Metrics**: TPR/FPR at operating points, AUC, precision with priors
3. **Vulnerability Analysis**: Identify samples consistently vulnerable to inference
4. **Visualization**: Generate publication-quality figures and tables

## Key Concepts:
- **Target Mode**: Each model uses its own ROC-derived threshold (upper bound)
- **Shadow Mode**: Each model uses median threshold from other models (transferability)
- **Vulnerability**: Samples with FP=0 (never false alarm) and TP>0 (detected when member)

Author: Najeeb Jebreel, optimized by Claude Sonnet 4.5
Date: 2025
"""

## 1. Import Libraries and Set Up Configurations

import numpy as np
import pandas as pd
import yaml
from pathlib import Path
from sklearn.metrics import roc_auc_score

# Import analysis utilities
from analysis_utils import *
from metrics import *

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

### Configure Analysis Parameters

class Config:
    """Every tunable value this notebook reads, collected in one place.

    All attributes are class-level constants. The notebook creates a single
    instance (``config``) right after the class and reads values through it.
    """

    # --- Experiment location ---------------------------------------------
    # Passed to ``create_output_directory`` and ``load_experiment_data``, and
    # used to locate ``attack_config.yaml`` for the visualization step.
    EXP_PATH = Path("d:/mona/lira_analysis/experiments/cifar10/resnet18/2025-10-20_1828")

    # --- Operating points ------------------------------------------------
    # Target false-positive rates, as fractions: 0.001% and 0.1%.
    TARGET_FPRS = [1e-5, 1e-3]

    # Membership base rates used when computing precision: 1%, 10%, 50%.
    PRIORS = [0.01, 0.1, 0.5]

    # --- Quality control -------------------------------------------------
    # When true, the evaluation runs ``validate_threshold`` on a few models.
    DO_SANITY_CHECKS = True

    # --- Inputs ----------------------------------------------------------
    # Display name of each attack variant -> file name of its score array.
    # Insertion order is the order in which attacks are evaluated.
    SCORE_FILES = {
        "LiRA (online)": "online_scores_leave_one_out.npy",
        "LiRA (online, fixed var)": "online_fixed_scores_leave_one_out.npy",
        "LiRA (offline)": "offline_scores_leave_one_out.npy",
        "LiRA (offline, fixed var)": "offline_fixed_scores_leave_one_out.npy",
        "Global threshold": "global_scores_leave_one_out.npy",
    }

    # File name of the membership-label array.
    LABELS_FILE = "membership_labels.npy"

    # --- Per-sample vulnerability analysis -------------------------------
    # Attack and operating point whose shadow thresholds are reused per sample.
    # VULN_ATTACK must be a key of SCORE_FILES and VULN_FPR must be one of
    # TARGET_FPRS; otherwise no shadow thresholds exist for the pair and
    # ``compute_sample_vulnerability`` raises ValueError.
    VULN_ATTACK = "LiRA (online)"
    VULN_FPR = 1e-5  # 0.001% FPR


config = Config()

In [2]:
## 2. Load Experiment Data

# Output directory that receives every CSV/PNG saved below.
# `create_output_directory` comes from the star-imported project helpers;
# presumably it derives the location from EXP_PATH and creates it -- TODO confirm.
out_dir = create_output_directory(config.EXP_PATH)
print(f"Output directory: {out_dir}\n")

# Load the membership labels and one score array per attack listed in
# Config.SCORE_FILES. `scores` is used below as a mapping keyed by attack name.
print("Loading experiment data...")
labels, scores = load_experiment_data(
    config.EXP_PATH,
    config.SCORE_FILES,
    config.LABELS_FILE
)

# `labels` is 2-D; the rest of the notebook indexes it as [model, sample].
M, N = labels.shape
print(f"✓ Loaded {M} models × {N} samples")
print(f"✓ Attacks: {list(scores.keys())}")


Output directory: analysis_results\cifar10\resnet18\2025-10-31_1828

Loading experiment data...
✓ Loaded 10 models × 60000 samples
✓ Attacks: ['LiRA (online)', 'LiRA (online, fixed var)', 'LiRA (offline)', 'LiRA (offline, fixed var)', 'Global threshold']


In [3]:
## 3. Two-Mode Evaluation

### Evaluate Target and Shadow Modes
"""
**Target Mode**: Each model uses its own threshold derived from its ROC curve.
This represents the upper bound on attack performance (attacker knows everything).

**Shadow Mode**: Each model uses the median threshold from other models.
This represents realistic transferability (attacker uses shadow models).
"""


def evaluate_two_modes(labels, scores, target_fprs, priors, do_sanity_checks=True):
    """
    Evaluate attacks in both target and shadow modes.

    For each (attack, target_fpr) pair:
    1. Compute one target threshold per model with ``find_threshold_at_fpr``
    2. Compute one shadow threshold per model with ``compute_shadow_thresholds``
       (described in this notebook as the median of the other models'
       thresholds; the helper itself lives outside the notebook)
    3. Evaluate the confusion matrix of every model at both threshold types
    4. Compute precision for every membership prior

    Args:
        labels: 2-D array indexed ``[model, sample]``; a truthy entry means the
            sample is a member for that model. Boolean and 0/1 integer arrays
            are both accepted.
        scores: Mapping ``attack name -> 2-D score array`` indexed the same way
            as ``labels``.
        target_fprs: Iterable of target false-positive rates (fractions).
        priors: Iterable of membership priors used for the precision column.
        do_sanity_checks: When true, run ``validate_threshold`` on up to the
            first five models that have a finite target threshold and print a
            warning for each one that fails.

    Returns:
        pandas.DataFrame with one row per (mode, attack, target_fpr, model,
        prior) and the columns ``mode, attack, target_fpr, achieved_fpr, prior,
        model_idx, threshold, tp, fp, tn, fn, tpr, precision, auc``.
        No shadow-mode rows are produced for the "Global threshold" attack.
    """
    M, _ = labels.shape  # unpacking also enforces that `labels` is 2-D

    # Boolean membership mask for the counting fallback below. `~` on a 0/1
    # *integer* array is bitwise NOT (0 -> -1, 1 -> -2), which would make the
    # true-negative count negative, so never negate the raw labels.
    member_mask = np.asarray(labels).astype(bool)

    max_validated_models = 5  # sanity checks look at the first few models only
    all_results = []

    for attack_name, score_array in scores.items():
        print(f"Evaluating {attack_name}...")

        # AUC does not depend on the threshold, so compute it once per model.
        aucs = np.full(M, np.nan)
        for m in range(M):
            try:
                aucs[m] = roc_auc_score(labels[m].astype(int), score_array[m])
            except ValueError:
                # Deliberate best-effort (e.g. a model with a single class):
                # the AUC for this model simply stays NaN.
                continue

        for target_fpr in target_fprs:
            # Step 1: per-model target thresholds and the rates achieved there.
            target_taus = np.empty(M)
            achieved_fprs = np.full(M, np.nan)
            achieved_tprs = np.full(M, np.nan)

            for m in range(M):
                tau, fpr_val, tpr_val = find_threshold_at_fpr(
                    score_array[m], labels[m], target_fpr
                )
                target_taus[m] = tau
                # The helper may return None for the rates; keep NaN then.
                if fpr_val is not None:
                    achieved_fprs[m] = fpr_val
                    achieved_tprs[m] = tpr_val

            # Step 2: shadow threshold of model m, derived from all target
            # thresholds and the index of the model being attacked.
            shadow_taus = np.array([
                compute_shadow_thresholds(target_taus, m) for m in range(M)
            ])

            # Step 3: optional validation of the extracted target thresholds.
            if do_sanity_checks:
                finite_models = np.where(np.isfinite(target_taus))[0]
                for m in finite_models[:max_validated_models]:
                    is_valid = validate_threshold(
                        score_array[m], labels[m], target_taus[m],
                        achieved_fprs[m], achieved_tprs[m]
                    )
                    if not is_valid:
                        print(f"  [WARNING] Model {m} @ {target_fpr}: validation failed")

            # Step 4: evaluate both modes with their respective thresholds.
            for mode, taus in [('target', target_taus), ('shadow', shadow_taus)]:
                # The baseline "Global threshold" attack has no shadow mode.
                if mode == 'shadow' and attack_name == "Global threshold":
                    continue

                for m in range(M):
                    tau = taus[m]

                    if not np.isfinite(tau):
                        # No usable threshold: treat the attack as flagging
                        # nothing, so every member is a miss and every
                        # non-member a true negative.
                        tp, fp = 0, 0
                        tn = int(np.sum(~member_mask[m]))
                        fn = int(np.sum(member_mask[m]))
                        tpr, fpr_achieved = 0.0, 0.0
                    else:
                        tp, fp, tn, fn, tpr, fpr_achieved = compute_confusion_matrix(
                            score_array[m], labels[m], tau
                        )

                    # One output row per prior; only precision depends on it.
                    for prior in priors:
                        precision = compute_precision_from_rates(tpr, fpr_achieved, prior)

                        all_results.append({
                            'mode': mode,
                            'attack': attack_name,
                            'target_fpr': target_fpr,
                            'achieved_fpr': fpr_achieved,
                            'prior': prior,
                            'model_idx': m,
                            'threshold': tau,
                            'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,
                            'tpr': tpr,
                            'precision': precision,
                            'auc': aucs[m]
                        })

    return pd.DataFrame(all_results)


print("\nEvaluating target and shadow modes...")
# Keyword arguments spell out which Config value feeds which parameter.
detail_df = evaluate_two_modes(
    labels=labels,
    scores=scores,
    target_fprs=config.TARGET_FPRS,
    priors=config.PRIORS,
    do_sanity_checks=config.DO_SANITY_CHECKS,
)

# Persist the per-model rows (one per mode/attack/FPR/model/prior).
detail_path = out_dir / "per_model_metrics_two_modes.csv"
detail_df.to_csv(detail_path, index=False)
print(f"✓ Saved: {detail_path}")


### Aggregate Summary Statistics

def create_summary_statistics(detail_df):
    """Collapse per-model rows into mean/std statistics, expressed in percent.

    Rows are grouped by (mode, attack, target_fpr, prior). For ``tpr``,
    ``achieved_fpr``, ``precision`` and ``auc`` the mean and sample standard
    deviation across the group are computed, multiplied by 100 and rounded to
    3 decimals. ``target_fpr`` is replaced by ``Target FPR (%)`` (times 100,
    rounded to 4 decimals).

    Returns a new DataFrame with the columns, in order: ``mode``, ``attack``,
    ``Target FPR (%)``, ``prior``, then ``<Metric>_Mean`` / ``<Metric>_Std``
    for TPR, FPR_Achieved, Precision and AUC.
    """
    # Output column -> (source column, aggregation). Dict order is also the
    # order of the statistic columns in the result.
    aggregations = {
        'TPR_Mean': ('tpr', 'mean'),
        'TPR_Std': ('tpr', 'std'),
        'FPR_Achieved_Mean': ('achieved_fpr', 'mean'),
        'FPR_Achieved_Std': ('achieved_fpr', 'std'),
        'Precision_Mean': ('precision', 'mean'),
        'Precision_Std': ('precision', 'std'),
        'AUC_Mean': ('auc', 'mean'),
        'AUC_Std': ('auc', 'std'),
    }
    stat_cols = list(aggregations)

    grouped = detail_df.groupby(
        ['mode', 'attack', 'target_fpr', 'prior'], as_index=False
    )
    stats = grouped.agg(**aggregations)

    # Fractions -> percentages for every statistic column at once.
    stats[stat_cols] = stats[stat_cols].mul(100).round(3)

    # Human-readable operating point; the raw fraction is not kept.
    stats['Target FPR (%)'] = stats['target_fpr'].mul(100).round(4)

    ordered = ['mode', 'attack', 'Target FPR (%)', 'prior'] + stat_cols
    return stats[ordered]


print("\nGenerating summary statistics...")
summary_df = create_summary_statistics(detail_df)

# Persist the aggregated table next to the per-model CSV.
summary_path = out_dir / "summary_statistics_two_modes.csv"
summary_df.to_csv(summary_path, index=False)
print(f"✓ Saved: {summary_path}")

# Quick look: first rows for target mode at the balanced (0.5) prior.
print("\nSample Results (Target Mode, Prior=0.5):")
sample = summary_df.loc[
    (summary_df['mode'] == 'target') & (summary_df['prior'] == 0.5),
    ['attack', 'Target FPR (%)', 'TPR_Mean', 'AUC_Mean'],
].head()
print(sample.to_string(index=False))


Evaluating target and shadow modes...
Evaluating LiRA (online)...
Evaluating LiRA (online, fixed var)...
Evaluating LiRA (offline)...
Evaluating LiRA (offline, fixed var)...
Evaluating Global threshold...
✓ Saved: analysis_results\cifar10\resnet18\2025-10-31_1828\per_model_metrics_two_modes.csv

Generating summary statistics...
✓ Saved: analysis_results\cifar10\resnet18\2025-10-31_1828\summary_statistics_two_modes.csv

Sample Results (Target Mode, Prior=0.5):
                   attack  Target FPR (%)  TPR_Mean  AUC_Mean
         Global threshold           0.001     0.003    50.269
         Global threshold           0.100     0.100    50.269
           LiRA (offline)           0.001     0.003    50.095
           LiRA (offline)           0.100     0.108    50.095
LiRA (offline, fixed var)           0.001     0.002    49.666


In [4]:
## 4. Per-Sample Vulnerability Analysis

"""
**Vulnerability Metric**: For each sample, count TP/FP across leave-one-out models.

- **Highly Vulnerable**: FP=0 (never falsely flagged) AND TP>0 (detected when member)
- **Most Vulnerable**: Lowest FP, then highest TP (stable and detectable)
"""

def compute_sample_vulnerability(detail_df, scores, labels, attack_name, target_fpr):
    """
    Count, for every sample, how often it is a TP/FP/TN/FN across models.

    The threshold applied to each model is the shadow-mode threshold recorded
    in ``detail_df`` for ``attack_name`` at ``target_fpr``.

    Args:
        detail_df: Per-model results with at least the columns ``mode``,
            ``attack``, ``target_fpr``, ``model_idx`` and ``threshold``.
        scores: Mapping ``attack name -> 2-D score array`` indexed
            ``[model, sample]``.
        labels: Array with an ``astype`` method, broadcast-compatible with the
            score array; truthy means "member".
        attack_name: Attack whose scores and shadow thresholds are used.
        target_fpr: Operating point to select (matched with ``np.isclose``).

    Returns:
        pandas.DataFrame with one row per sample and the integer columns
        ``sample_id, tp, fp, tn, fn``.

    Raises:
        ValueError: if ``detail_df`` holds no shadow threshold for the pair.
    """
    # Select the shadow rows of this attack at this operating point.
    is_shadow = detail_df['mode'] == 'shadow'
    is_attack = detail_df['attack'] == attack_name
    is_fpr = np.isclose(detail_df['target_fpr'], target_fpr, atol=1e-12)

    # detail_df repeats each model once per prior; one threshold row is enough.
    per_model = detail_df.loc[
        is_shadow & is_attack & is_fpr, ['model_idx', 'threshold']
    ].drop_duplicates(subset=['model_idx'])

    if per_model.empty:
        raise ValueError(f"No shadow thresholds for {attack_name} @ {target_fpr}")

    attack_scores = scores[attack_name]
    M, N = attack_scores.shape

    # Models without a recorded threshold keep +inf and therefore flag nothing.
    thresholds = np.full(M, np.inf)
    for raw_idx, tau in zip(per_model['model_idx'], per_model['threshold']):
        model = int(raw_idx)
        if model < 0 or model >= M:
            continue  # ignore indices that do not address a score row
        thresholds[model] = float(tau)

    # [M, N] decision matrix: one threshold per row, broadcast over samples.
    flagged = attack_scores >= thresholds[:, np.newaxis]
    cleared = ~flagged
    members = labels.astype(bool)
    non_members = ~members

    outcome_masks = {
        'tp': flagged & members,
        'fp': flagged & non_members,
        'tn': cleared & non_members,
        'fn': cleared & members,
    }

    # Sum over the model axis to obtain one count per sample.
    columns = {'sample_id': np.arange(N)}
    for outcome, hits in outcome_masks.items():
        columns[outcome] = np.sum(hits, axis=0).astype(int)

    return pd.DataFrame(columns)


def rank_vulnerable_samples(sample_df):
    """Order samples by vulnerability and pick out the highly vulnerable ones.

    Returns a pair ``(ranked, highly_vulnerable)``:
    - ``ranked``: all rows sorted by ascending ``fp`` and, within equal
      ``fp``, descending ``tp``.
    - ``highly_vulnerable``: rows with ``fp == 0`` and ``tp > 0``, in the
      original (unsorted) row order of ``sample_df``.
    """
    never_false_alarm = sample_df['fp'].eq(0)
    detected_as_member = sample_df['tp'].gt(0)
    highly_vulnerable = sample_df.loc[never_false_alarm & detected_as_member]

    ranked = sample_df.sort_values(['fp', 'tp'], ascending=[True, False])
    return ranked, highly_vulnerable


print(f"\nAnalyzing per-sample vulnerability ({config.VULN_ATTACK} @ {config.VULN_FPR})...")
sample_vuln = compute_sample_vulnerability(
    detail_df=detail_df,
    scores=scores,
    labels=labels,
    attack_name=config.VULN_ATTACK,
    target_fpr=config.VULN_FPR,
)
vuln_ranked, highly_vuln = rank_vulnerable_samples(sample_vuln)

# NOTE: the file names hard-code "online" / "0p001pct"; they only describe the
# contents while Config.VULN_ATTACK / Config.VULN_FPR keep their current values.
vuln_path = out_dir / "samples_vulnerability_ranked_online_shadow_0p001pct.csv"
high_vuln_path = out_dir / "samples_highly_vulnerable_online_shadow_0p001pct.csv"

vuln_ranked.to_csv(vuln_path, index=False)
highly_vuln.to_csv(high_vuln_path, index=False)

print(f"✓ Saved: {vuln_path}")
print(f"✓ Saved: {high_vuln_path}")
print("\nStatistics:")
print(f"  Total samples: {len(vuln_ranked)}")
print(f"  Highly vulnerable (FP=0, TP>0): {len(highly_vuln)}")

# First row of the ranking = lowest FP count, then highest TP count.
if not vuln_ranked.empty:
    top = vuln_ranked.iloc[0]
    print(f"  Most vulnerable: TP={top['tp']}, FP={top['fp']}")


Analyzing per-sample vulnerability (LiRA (online) @ 1e-05)...
✓ Saved: analysis_results\cifar10\resnet18\2025-10-31_1828\samples_vulnerability_ranked_online_shadow_0p001pct.csv
✓ Saved: analysis_results\cifar10\resnet18\2025-10-31_1828\samples_highly_vulnerable_online_shadow_0p001pct.csv

Statistics:
  Total samples: 60000
  Highly vulnerable (FP=0, TP>0): 5
  Most vulnerable: TP=2, FP=0


In [5]:
## 5. Visualization

print("\nGenerating visualization...")
cfg_path = config.EXP_PATH / "attack_config.yaml"

if cfg_path.exists():
    # Read the YAML as UTF-8 explicitly. Without `encoding`, open() falls back
    # to the locale encoding (e.g. cp1252 on Windows), which mis-decodes or
    # rejects non-ASCII characters in the config.
    with open(cfg_path, 'r', encoding='utf-8') as f:
        exp_config = yaml.safe_load(f)

    # `load_dataset` is a project helper; only its first return value is used.
    full_dataset, _ = load_dataset(exp_config)

    # Render the first k rows of the ranking (lowest FP, then highest TP) as an
    # image grid saved under `out_dir`. The layout arguments are passed through
    # to the project helper unchanged.
    display_top_k_vulnerable_samples(
        vulnerable_samples=vuln_ranked,
        full_dataset=full_dataset,
        k=20,
        nrow=5,
        out_dir=out_dir,
        save_name="top20_vulnerable_online_shadow_0p001pct.png",
        font_size=7,
        badge_margin=2,
        overhang_left=3,
        overhang_up=4
    )
    print("✓ Saved vulnerability visualization")
else:
    # Best-effort step: without the experiment config the dataset cannot be
    # loaded, so the figure is skipped rather than failing the notebook.
    print("⚠ Config not found, skipping visualization")



Generating visualization...
Saved grid: analysis_results\cifar10\resnet18\2025-10-31_1828\top20_vulnerable_online_shadow_0p001pct.png
✓ Saved vulnerability visualization


In [6]:
## 6. Summary

banner = "=" * 60
print("\n" + banner)
print("ANALYSIS COMPLETE")
print(banner)
print(f"\nOutput directory: {out_dir}")
print("\nGenerated files:")
# One line per artifact written by the cells above.
for entry in (
    "  • per_model_metrics_two_modes.csv - Detailed per-model results",
    "  • summary_statistics_two_modes.csv - Aggregated statistics",
    "  • samples_vulnerability_ranked_online_shadow_0p001pct.csv - All samples ranked",
    "  • samples_highly_vulnerable_online_shadow_0p001pct.csv - High-risk samples",
    "  • top20_vulnerable_online_shadow_0p001pct.png - Visualization",
):
    print(entry)
print("\n✓ All analyses completed successfully!")


ANALYSIS COMPLETE

Output directory: analysis_results\cifar10\resnet18\2025-10-31_1828

Generated files:
  • per_model_metrics_two_modes.csv - Detailed per-model results
  • summary_statistics_two_modes.csv - Aggregated statistics
  • samples_vulnerability_ranked_online_shadow_0p001pct.csv - All samples ranked
  • samples_highly_vulnerable_online_shadow_0p001pct.csv - High-risk samples
  • top20_vulnerable_online_shadow_0p001pct.png - Visualization

✓ All analyses completed successfully!
