# Day 70: Membership Inference Defense

A Membership Inference Attack (MIA) lets an adversary determine whether a specific data point was part of a target model's training set.

In this lab, we implement:
1. **MIAttacker**: A threshold-based attack that exploits model overfitting.
2. **MIDefender**: An inference-time defense using **Output Perturbation** (Laplacian noise).
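One way to write the two components down, as a sketch: the attack scores each sample by the model's probability for its true label and compares it with a threshold $\tau$, and the defense adds independent Laplace noise with scale $b$ to every output probability. Treating the score as the true-label probability and $b$ as the `noise_scale` argument are assumptions here. `MIAttacker` may use another score (for example the maximum probability), and `MIDefender` may also clip or renormalize its outputs.

$$
\mathcal{A}(x, y) = \begin{cases} \text{member} & \text{if } f(x)_y > \tau \\ \text{non-member} & \text{otherwise} \end{cases}
\qquad
\tilde{f}(x)_c = f(x)_c + \eta_c, \quad \eta_c \overset{\text{i.i.d.}}{\sim} \operatorname{Laplace}(0, b), \quad p(\eta) = \frac{1}{2b} e^{-|\eta|/b}
$$

With the values used in this lab ($\tau = 0.8$, member confidence $0.99$, non-member confidence $0.6$, $b = 0.25$) and ignoring any clipping or renormalization, a member is missed when the noise lowers its score by more than $0.19$, and a non-member is falsely flagged when its score rises by more than $0.2$:

$$
P(\eta \le -0.19) = \tfrac{1}{2} e^{-0.19/0.25} \approx 0.23, \qquad P(\eta > 0.2) = \tfrac{1}{2} e^{-0.2/0.25} \approx 0.22
$$

Under these assumptions the expected attack accuracy falls from $1.0$ to roughly $1 - (0.23 + 0.22)/2 \approx 0.77$: the noise weakens the attack without removing the membership signal entirely.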

In [None]:
import sys
import os
import numpy as np

# Make the repository root (two levels up) importable so that `src.privacy` resolves
sys.path.append(os.path.abspath('../../'))

from src.privacy.membership_inference import MIAttacker, MIDefender

## 1. Simulate Overfitting (Vulnerability)

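We use a mock model, `VulnerableModel`, that mimics an overfit classifier: it is highly confident (0.99) on training data (members) and noticeably less confident (0.6) on unseen data (non-members).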

In [None]:
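class VulnerableModel:
    """Mock of an overfit classifier: overconfident on members, less so on non-members."""

    def predict_proba(self, X):
        # Samples whose first feature x[0] is < 0.5 play the role of training members:
        # they get 0.99 confidence for class 1; all other samples get only 0.6.
        return np.array([[0.01, 0.99] if x[0] < 0.5 else [0.4, 0.6] for x in X])

model = VulnerableModel()
members_X = np.array([[0.1, 0.1] for _ in range(50)])      # 50 "training" points
non_members_X = np.array([[0.9, 0.9] for _ in range(50)])  # 50 unseen points
y = np.ones(50, dtype=int)                                 # true label (class 1) for every sample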
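To put a number on the vulnerability this mock encodes, one can compare its average confidence on members and non-members; a large gap is exactly what a threshold attack exploits. The sketch below re-creates the mock in vectorized form so it runs on its own. `confidence_gap` is a hypothetical helper for illustration, not part of `src.privacy`.

```python
import numpy as np


class VulnerableModel:
    """Same behavior as the lab's mock (vectorized): 0.99 when x[0] < 0.5, else 0.6."""

    def predict_proba(self, X):
        is_member = np.asarray(X)[:, 0] < 0.5
        # Broadcast the two fixed probability rows over all samples.
        return np.where(is_member[:, None], [0.01, 0.99], [0.4, 0.6])


def confidence_gap(model, members_X, non_members_X, y):
    """Hypothetical helper: mean true-label confidence on members minus non-members."""
    idx = np.arange(len(y))
    member_conf = model.predict_proba(members_X)[idx, y].mean()
    non_member_conf = model.predict_proba(non_members_X)[idx, y].mean()
    return member_conf - non_member_conf


members_X = np.full((50, 2), 0.1)
non_members_X = np.full((50, 2), 0.9)
y = np.ones(50, dtype=int)

gap = confidence_gap(VulnerableModel(), members_X, non_members_X, y)
print(f"confidence gap: {gap:.2f}")  # confidence gap: 0.39
```

A well-generalizing model would show a gap near zero, leaving a confidence threshold nothing to separate.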

## 2. Execute Attack (No Defense)

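The attacker flags a sample as a **member** when the model's confidence on it exceeds the threshold of 0.8. Because the mock assigns 0.99 to members and 0.6 to non-members, this threshold separates the two groups perfectly.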

In [None]:
attacker = MIAttacker()
member_preds = attacker.attack_threshold_based(model, members_X, y, threshold=0.8)
non_member_preds = attacker.attack_threshold_based(model, non_members_X, y, threshold=0.8)

print("Undefended Results:", attacker.evaluate_attack(member_preds, non_member_preds))
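A minimal, self-contained sketch of what a threshold-based attack and its evaluation might look like. `threshold_attack` and `attack_metrics` are hypothetical helpers, not the `MIAttacker` API: the attack here scores each sample by the probability of its true label (one plausible reason `y` is passed), and the metrics report accuracy, true- and false-positive rates, and the membership advantage TPR - FPR. The real `evaluate_attack` may report different quantities.

```python
import numpy as np


class VulnerableModel:
    """Mock from the lab (vectorized): 0.99 confidence when x[0] < 0.5, else 0.6."""

    def predict_proba(self, X):
        is_member = np.asarray(X)[:, 0] < 0.5
        return np.where(is_member[:, None], [0.01, 0.99], [0.4, 0.6])


def threshold_attack(model, X, y, threshold=0.8):
    """Hypothetical attack: predict member (1) if the true-label probability > threshold."""
    probs = model.predict_proba(X)
    true_label_conf = probs[np.arange(len(y)), y]
    return (true_label_conf > threshold).astype(int)


def attack_metrics(member_preds, non_member_preds):
    """Hypothetical metrics: accuracy, TPR, FPR, and advantage = TPR - FPR."""
    tpr = float(np.mean(member_preds))      # fraction of members correctly flagged
    fpr = float(np.mean(non_member_preds))  # fraction of non-members wrongly flagged
    n_m, n_n = len(member_preds), len(non_member_preds)
    correct = np.sum(member_preds) + (n_n - np.sum(non_member_preds))
    return {"accuracy": float(correct / (n_m + n_n)), "tpr": tpr, "fpr": fpr, "advantage": tpr - fpr}


members_X = np.full((50, 2), 0.1)
non_members_X = np.full((50, 2), 0.9)
y = np.ones(50, dtype=int)

model = VulnerableModel()
metrics = attack_metrics(
    threshold_attack(model, members_X, y),
    threshold_attack(model, non_members_X, y),
)
print(metrics)  # {'accuracy': 1.0, 'tpr': 1.0, 'fpr': 0.0, 'advantage': 1.0}
```

An advantage of 1.0 means the attacker can tell members from non-members perfectly; 0.0 would mean it does no better than guessing.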

## 3. Apply Defense (Output Perturbation)

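The defender adds random Laplacian noise (`noise_scale=0.25`) to the output probabilities before they are returned, blurring the confidence gap between members and non-members that the attack relies on.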

In [None]:
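defender = MIDefender(noise_scale=0.25)

class DefendedModel:
    """Wraps a model and perturbs its output probabilities before returning them."""

    def __init__(self, base_model, defender):
        self.base_model = base_model
        self.defender = defender

    def predict_proba(self, X):
        probs = self.base_model.predict_proba(X)
        return self.defender.perturb_outputs(probs)

defended_model = DefendedModel(model, defender)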
member_preds_def = attacker.attack_threshold_based(defended_model, members_X, y, threshold=0.8)
non_member_preds_def = attacker.attack_threshold_based(defended_model, non_members_X, y, threshold=0.8)

print("Defended Results:", attacker.evaluate_attack(member_preds_def, non_member_preds_def))
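The defense can be sketched the same way. `LaplaceOutputPerturbation` is a hypothetical wrapper, not `MIDefender`: it adds i.i.d. Laplace noise with scale 0.25 to each probability, clips the result to a small positive floor, and renormalizes so that every row is still a probability distribution. Whether `MIDefender.perturb_outputs` clips or renormalizes is not shown in this lab. A seeded `numpy.random.Generator` makes the run reproducible.

```python
import numpy as np


class VulnerableModel:
    """Mock from the lab (vectorized): 0.99 confidence when x[0] < 0.5, else 0.6."""

    def predict_proba(self, X):
        is_member = np.asarray(X)[:, 0] < 0.5
        return np.where(is_member[:, None], [0.01, 0.99], [0.4, 0.6])


class LaplaceOutputPerturbation:
    """Hypothetical defense: Laplace noise on each probability, then clip and renormalize."""

    def __init__(self, model, scale=0.25, seed=None):
        self.model = model
        self.scale = scale
        self.rng = np.random.default_rng(seed)

    def predict_proba(self, X):
        probs = self.model.predict_proba(X)
        noisy = probs + self.rng.laplace(0.0, self.scale, size=probs.shape)
        noisy = np.clip(noisy, 1e-12, None)               # keep every entry positive
        return noisy / noisy.sum(axis=1, keepdims=True)   # rows sum to 1 again


def attack_accuracy(model, members_X, non_members_X, y, threshold=0.8):
    """Accuracy of a true-label-confidence threshold attack on equal-sized groups."""
    def flag(X):
        conf = model.predict_proba(X)[np.arange(len(y)), y]
        return conf > threshold
    return 0.5 * (np.mean(flag(members_X)) + np.mean(~flag(non_members_X)))


members_X = np.full((50, 2), 0.1)
non_members_X = np.full((50, 2), 0.9)
y = np.ones(50, dtype=int)

base = VulnerableModel()
defended = LaplaceOutputPerturbation(base, scale=0.25, seed=0)

print("undefended:", attack_accuracy(base, members_X, non_members_X, y))  # undefended: 1.0
print("defended:  ", attack_accuracy(defended, members_X, non_members_X, y))

# The perturbed outputs are still valid probability distributions.
p = defended.predict_proba(members_X)
assert np.all(p > 0) and np.allclose(p.sum(axis=1), 1.0)
```

At this noise scale the defended accuracy typically lands below 1.0 but well above chance (0.5). A larger scale hides more of the membership signal, but it also distorts the predictions themselves and can flip the predicted class, which costs accuracy on the model's actual task.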