<a href="https://colab.research.google.com/github/jgracie52/bh-2025/blob/main/NeuralNetworks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Networks and AI Security: A Comprehensive Lab

## Welcome to the Neural Network Security Lab!

This lab is designed to teach you both how neural networks work AND how they can be attacked and defended. We'll start with the basics and progressively build up to advanced security concepts.

### What You'll Learn:
1. **Neural Network Fundamentals** - How they actually work
2. **Classic Attacks** - FGSM and basic adversarial examples
3. **Advanced Attacks** - State-of-the-art techniques from 2024-2025
4. **Privacy Attacks** - How models leak training data
5. **Defense Mechanisms** - How to protect your models
6. **Privacy-Preserving Techniques** - Differential privacy and more
7. **Advanced Security Topics** - Backdoor detection and model watermarking

### Prerequisites:
- Basic Python knowledge
- Understanding of basic ML concepts (we'll review the rest!)

Let's get started! 🚀

## Part 0: Environment Setup

First, let's install all the packages we'll need. This might take a few minutes.

**What we're installing:**
- `torch` & `torchvision`: PyTorch for building neural networks
- `adversarial-robustness-toolbox`: IBM's toolkit for adversarial attacks
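- `matplotlib` & `seaborn`: For visualizations
- `numpy` & `scikit-learn`: Numerical arrays and toy datasets
- `ipywidgets`, `opencv-python`, `scikit-image` & `scipy`: Interactive widgets and image-processing utilities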

In [None]:
# Install required packages
!pip install torch torchvision torchaudio --quiet
!pip install adversarial-robustness-toolbox --quiet
!pip install matplotlib seaborn numpy scikit-learn --quiet
!pip install ipywidgets --quiet
!pip install opencv-python scikit-image scipy --quiet

print("✅ All packages installed successfully!")

In [None]:
# Import all necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms
from sklearn.datasets import make_classification, make_moons
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("✅ Environment ready!")

## Part 1: Understanding Neural Networks

### 1.1 What is a Neural Network?

Think of a neural network as a **decision-making machine** that learns from examples. Just like how you learned to recognize cats vs. dogs by seeing many examples, neural networks learn patterns from data.

**Key Components:**
1. **Neurons**: Basic units that receive inputs and produce outputs
2. **Weights**: The "importance" given to each connection
3. **Activation Functions**: Add non-linearity (like ReLU, Sigmoid)
4. **Layers**: Groups of neurons working together
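
To make the first three components concrete, here is a minimal sketch of a single neuron in NumPy. It assumes a ReLU activation, and the inputs, weights, and bias are made-up values chosen so the arithmetic is easy to follow.

```python
import numpy as np

def neuron(x, w, b):
    """A single neuron: weighted sum of the inputs plus a bias, passed through ReLU."""
    z = np.dot(w, x) + b      # each weight scales its input's contribution
    return max(0.0, z)        # ReLU activation: negative sums become 0

x = np.array([0.5, -1.0])     # two input features
w = np.array([2.0, 1.0])      # one weight per connection
print(neuron(x, w, b=0.25))   # 2*0.5 + 1*(-1.0) + 0.25 = 0.25
print(neuron(x, w, b=-1.0))   # 2*0.5 + 1*(-1.0) - 1.0 = -1.0, clipped to 0.0 by ReLU
```

A layer is many such neurons sharing the same inputs; `nn.Linear` followed by `nn.ReLU` computes the same thing for a whole layer and a whole batch at once.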

Let's build our first neural network!

In [None]:
# Let's create a simple neural network from scratch!

class SimpleNeuralNetwork(nn.Module):
    """
    A simple 3-layer neural network for binary classification.

    Architecture:
    - Input layer: 2 features (x, y coordinates)
    - Hidden layer: 10 neurons with ReLU activation
    - Output layer: 2 neurons (for 2 classes)
    """

    def __init__(self, input_size=2, hidden_size=10, output_size=2):
        super(SimpleNeuralNetwork, self).__init__()
        # Define layers
        self.hidden = nn.Linear(input_size, hidden_size)  # 2 inputs -> 10 hidden neurons
        self.output = nn.Linear(hidden_size, output_size)  # 10 hidden -> 2 outputs
        self.activation = nn.ReLU()  # ReLU activation function

    def forward(self, x):
        """
        Forward pass: how data flows through the network
        """
        # Step 1: Input -> Hidden layer
        x = self.hidden(x)

        # Step 2: Apply activation function (ReLU)
        x = self.activation(x)

        # Step 3: Hidden -> Output layer
        x = self.output(x)

        return x

# Create an instance of our network
model = SimpleNeuralNetwork()
print("Neural Network Architecture:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters())}")

### 1.2 Creating a Dataset

Let's create a simple 2D dataset that our network will learn to classify. We'll use the "two moons" dataset - it creates two crescent moon shapes that are interleaved.

In [None]:
# Generate the "two moons" dataset
X, y = make_moons(n_samples=1000, noise=0.1, random_state=42)

# Split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.LongTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.LongTensor(y_test)

# Visualize the dataset
plt.figure(figsize=(10, 6))
plt.scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], alpha=0.6, label='Class 0', c='blue')
plt.scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], alpha=0.6, label='Class 1', c='red')
plt.title('Two Moons Dataset')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")

### 1.3 Training the Neural Network

Now let's train our network! Training involves:
1. **Forward Pass**: Feed data through the network
2. **Calculate Loss**: Measure how wrong our predictions are
3. **Backward Pass**: Calculate gradients (how to improve)
4. **Update Weights**: Make the network a bit better

We repeat this process many times (epochs) until the network learns!
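
The four steps can be traced by hand on a one-parameter model. This sketch assumes a made-up model `y = w * x` with squared-error loss and a plain gradient-descent update (the training cell below uses Adam instead), so every number can be checked on paper.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)   # the single trainable weight
x, y_true = torch.tensor(3.0), torch.tensor(12.0)
lr = 0.01

y_pred = w * x                    # 1. forward pass: 2 * 3 = 6
loss = (y_pred - y_true) ** 2     # 2. loss: (6 - 12)^2 = 36
loss.backward()                   # 3. backward pass: dL/dw = 2 * (6 - 12) * 3 = -36
with torch.no_grad():
    w -= lr * w.grad              # 4. update: 2 - 0.01 * (-36) = 2.36

print(f"loss={loss.item():.2f}, grad={w.grad.item():.2f}, new w={w.item():.2f}")
```

The negative gradient says "increasing `w` lowers the loss", so the update moves `w` up toward the value that would predict 12.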

In [None]:
# Training function with detailed explanations
def train_model(model, X_train, y_train, epochs=100, learning_rate=0.01):
    """
    Train a neural network model.

    Parameters:
    - model: The neural network
    - X_train, y_train: Training data
    - epochs: Number of training iterations
    - learning_rate: How big steps to take when updating weights
    """
    # Define loss function (CrossEntropyLoss for classification)
    criterion = nn.CrossEntropyLoss()

    # Define optimizer (Adam is a popular choice)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Track losses for visualization
    losses = []

    # Training loop
    for epoch in range(epochs):
        # Forward pass
        outputs = model(X_train)

        # Calculate loss
        loss = criterion(outputs, y_train)

        # Backward pass
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()        # Calculate gradients
        optimizer.step()       # Update weights

        # Store loss
        losses.append(loss.item())

        # Print progress every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    return losses

# Train our model
print("Training the neural network...")
losses = train_model(model, X_train_tensor, y_train_tensor)

# Plot the training loss
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Training Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

print("\n✅ Training complete!")

### 1.4 Visualizing Decision Boundaries

One of the coolest things about neural networks is that they learn complex decision boundaries. Let's visualize what our network learned!

In [None]:
def plot_decision_boundary(model, X, y, title="Decision Boundary"):
    """
    Visualize the decision boundary learned by the neural network.

    The background colors show what the network predicts for each point in space.
    """
    # Create a mesh grid
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                         np.linspace(y_min, y_max, 200))

    # Get predictions for every point in the mesh
    grid_points = torch.FloatTensor(np.c_[xx.ravel(), yy.ravel()])
    with torch.no_grad():
        Z = model(grid_points).argmax(dim=1).numpy()
    Z = Z.reshape(xx.shape)

    # Plot
    plt.figure(figsize=(10, 8))
    plt.contourf(xx, yy, Z, alpha=0.4, cmap='RdBu')
    plt.scatter(X[y==0, 0], X[y==0, 1], c='blue', edgecolor='black', alpha=0.6)
    plt.scatter(X[y==1, 0], X[y==1, 1], c='red', edgecolor='black', alpha=0.6)
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.show()

# Visualize what our network learned
plot_decision_boundary(model, X_train, y_train, "Neural Network Decision Boundary")

# Evaluate accuracy
with torch.no_grad():
    predictions = model(X_test_tensor).argmax(dim=1)
    accuracy = (predictions == y_test_tensor).float().mean()
    print(f"Test Accuracy: {accuracy:.2%}")

## Part 2: Understanding Backpropagation

### How Neural Networks Learn: The Backpropagation Algorithm

Backpropagation is like teaching a student by showing them their mistakes and how to fix them. Here's the process:

1. **Forward Pass**: Make a prediction
2. **Calculate Error**: See how wrong we were
3. **Backward Pass**: Figure out which weights contributed to the error
4. **Update Weights**: Adjust weights to reduce error next time

Let's implement this from scratch to really understand it!
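
Backpropagation is the chain rule applied layer by layer. As a warm-up, this sketch takes a single sigmoid neuron with squared-error loss (input, target, and weight are arbitrary values) and compares the chain-rule gradient with a finite-difference estimate; the two should agree closely. The sigmoid derivative is written in terms of the sigmoid's output `a`, as `a * (1 - a)`, the same form the `sigmoid_derivative` method below uses.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def loss(w, x=1.5, y=1.0):
    """Squared error of a single sigmoid neuron: L = (sigmoid(w*x) - y)^2."""
    return (sigmoid(w * x) - y) ** 2

def chain_rule_grad(w, x=1.5, y=1.0):
    """dL/dw = dL/da * da/dz * dz/dw, with z = w*x and a = sigmoid(z)."""
    a = sigmoid(w * x)
    dL_da = 2 * (a - y)
    da_dz = a * (1 - a)   # sigmoid derivative expressed through its output
    dz_dw = x
    return dL_da * da_dz * dz_dw

w, h = 0.3, 1e-5
numeric = (loss(w + h) - loss(w - h)) / (2 * h)   # central finite difference
print(f"chain rule: {chain_rule_grad(w):.8f}  finite difference: {numeric:.8f}")
```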

In [None]:
# Manual implementation of backpropagation for educational purposes
class ManualNeuralNetwork:
    """
    A neural network implemented from scratch to demonstrate backpropagation.
    This helps understand what PyTorch does automatically!
    """

    def __init__(self, input_size=2, hidden_size=3, output_size=1):
        # Initialize weights with small random values
        self.W1 = np.random.randn(input_size, hidden_size) * 0.5
        self.b1 = np.zeros((1, hidden_size))
        self.W2 = np.random.randn(hidden_size, output_size) * 0.5
        self.b2 = np.zeros((1, output_size))

        print("🧠 Neural Network Architecture:")
        print(f"Input Layer: {input_size} neurons")
        print(f"Hidden Layer: {hidden_size} neurons")
        print(f"Output Layer: {output_size} neurons")
        print("\nInitial Weights:")
        print(f"W1 shape: {self.W1.shape}")
        print(f"W2 shape: {self.W2.shape}")

    def sigmoid(self, x):
        """Sigmoid activation function: squashes values between 0 and 1"""
        return 1 / (1 + np.exp(-x))

    def sigmoid_derivative(self, a):
        """Derivative of sigmoid, written in terms of its output a = sigmoid(z): a * (1 - a)"""
        return a * (1 - a)

    def forward(self, X):
        """Forward pass: compute predictions"""
        # Input -> Hidden
        self.z1 = np.dot(X, self.W1) + self.b1
        self.a1 = self.sigmoid(self.z1)

        # Hidden -> Output
        self.z2 = np.dot(self.a1, self.W2) + self.b2
        self.a2 = self.sigmoid(self.z2)

        return self.a2

    def backward(self, X, y, output, learning_rate=0.1):
        """Backward pass: calculate gradients and update weights"""
        m = X.shape[0]  # Number of examples

        # Gradient at the output layer: for a sigmoid output trained with
        # binary cross-entropy, dL/dz2 simplifies to (prediction - target)
        self.dz2 = output - y
        self.dW2 = (1/m) * np.dot(self.a1.T, self.dz2)
        self.db2 = (1/m) * np.sum(self.dz2, axis=0, keepdims=True)

        # Calculate gradients for hidden layer
        self.da1 = np.dot(self.dz2, self.W2.T)
        self.dz1 = self.da1 * self.sigmoid_derivative(self.a1)
        self.dW1 = (1/m) * np.dot(X.T, self.dz1)
        self.db1 = (1/m) * np.sum(self.dz1, axis=0, keepdims=True)

        # Update weights and biases
        self.W2 -= learning_rate * self.dW2
        self.b2 -= learning_rate * self.db2
        self.W1 -= learning_rate * self.dW1
        self.b1 -= learning_rate * self.db1

# Demonstrate backpropagation on a simple XOR problem
print("🎯 Let's learn the XOR function!")
print("XOR Truth Table:")
print("0 XOR 0 = 0")
print("0 XOR 1 = 1")
print("1 XOR 0 = 1")
print("1 XOR 1 = 0")
print("\nThis is impossible for a linear model but easy for a neural network!\n")

# XOR dataset
X_xor = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_xor = np.array([[0], [1], [1], [0]])

# Create and train network (named xor_net so it does not shadow torch.nn, imported as nn)
xor_net = ManualNeuralNetwork(2, 4, 1)

# Training loop with visualization
losses = []
for epoch in range(5000):
    # Forward pass
    output = xor_net.forward(X_xor)

    # Calculate binary cross-entropy, the loss whose gradient backward() implements
    log_eps = 1e-8  # avoids log(0)
    loss = -np.mean(y_xor * np.log(output + log_eps) + (1 - y_xor) * np.log(1 - output + log_eps))
    losses.append(loss)

    # Backward pass
    xor_net.backward(X_xor, y_xor, output)

    # Print progress
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

# Test the trained network
print("\n🎉 Testing our trained network:")
for i in range(len(X_xor)):
    prediction = xor_net.forward(X_xor[i:i+1])
    print(f"Input: {X_xor[i]} → Prediction: {prediction[0,0]:.4f} → Rounded: {int(prediction[0,0] > 0.5)}")

# Plot learning curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Learning the XOR Function')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.show()
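
The XOR cell claims the problem is impossible for a linear model. A quick check, assuming scikit-learn's `LogisticRegression` as the linear model: no straight line separates the four points, so any linear classifier gets at most 3 of the 4 right.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X_demo = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_demo = np.array([0, 1, 1, 0])   # XOR labels

linear = LogisticRegression().fit(X_demo, y_demo)
linear_acc = linear.score(X_demo, y_demo)
print(f"Linear model accuracy on XOR: {linear_acc:.0%}")  # never 100%
```

The hidden layer is what lets `ManualNeuralNetwork` bend the decision boundary around the diagonal points.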

## Part 3: Introduction to Adversarial Examples

### 🚨 Security Alert: Neural Networks Can Be Fooled!

Adversarial examples are inputs designed to trick neural networks. It's like optical illusions for AI - small, often invisible changes that completely fool the model.

**Why This Matters:**
- Self-driving cars could misread stop signs
- Medical diagnosis systems could be manipulated
- Security systems could be bypassed

Let's see how this works!

In [None]:
# Load MNIST dataset for adversarial examples
print("Loading MNIST dataset...")

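# Define transformations: ToTensor() scales pixels to [0, 1], then
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])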

# Load training data
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Load test data
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
testloader = DataLoader(testset, batch_size=1, shuffle=True)

print("✅ MNIST dataset loaded!")
print(f"Training samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")

# Show some examples
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for i, (image, label) in enumerate(trainloader):
    if i >= 5:
        break
    axes[i].imshow(image[0].squeeze(), cmap='gray')
    axes[i].set_title(f'Label: {label[0].item()}')
    axes[i].axis('off')
plt.suptitle('Sample MNIST Images')
plt.show()
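
The `Normalize((0.5,), (0.5,))` transform computes `(x - mean) / std` per channel, so the [0, 1] pixels from `ToTensor()` end up in [-1, 1]. That is why `fgsm_attack` below clamps to [-1, 1], and why `x * 0.5 + 0.5` maps images back to [0, 1]. A plain-PyTorch sketch of both mappings on three sample pixel values:

```python
import torch

mean, std = 0.5, 0.5
pixels = torch.tensor([0.0, 0.5, 1.0])   # black, mid-gray, white after ToTensor()
normalized = (pixels - mean) / std       # what Normalize((0.5,), (0.5,)) computes: -1, 0, 1
restored = normalized * std + mean       # inverse mapping back to [0, 1]
print(normalized, restored)
```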

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Build a CNN for MNIST
class MNISTClassifier(nn.Module):
    """
    A Convolutional Neural Network for classifying handwritten digits.

    Architecture:
    - Conv Layer 1: 32 filters, 3x3 kernel
    - Conv Layer 2: 64 filters, 3x3 kernel
    - Fully Connected: 128 neurons
    - Output: 10 classes (digits 0-9)
    """

    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create and train the model (simplified training for speed)
mnist_model = MNISTClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mnist_model.parameters(), lr=0.001)

print("Training MNIST classifier...")
mnist_model.train()
for epoch in range(2):  # Just 2 epochs for demonstration
    running_loss = 0.0
    for i, (images, labels) in enumerate(trainloader):
        if i >= 100:  # Limit to 100 batches per epoch for speed
            break

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = mnist_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {running_loss/100:.3f}')

print("✅ Model trained!")

# Test accuracy
mnist_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for i, (images, labels) in enumerate(testloader):
        if i >= 100:  # Evaluate on the first 100 test images (batch_size=1) for speed
            break
        images, labels = images.to(device), labels.to(device)
        outputs = mnist_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')
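
Why does `fc1` expect `64 * 7 * 7` inputs? Each 3x3 convolution with `padding=1` keeps the spatial size, and each 2x2 max-pool halves it: 28 → 14 → 7. A shape trace with freshly constructed layers of the same configuration (the weights are random; only the shapes matter here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(1, 1, 28, 28)             # one grayscale MNIST-sized image
h1 = pool(F.relu(conv1(x)))               # (1, 32, 14, 14)
h2 = pool(F.relu(conv2(h1)))              # (1, 64, 7, 7)
flat = h2.view(h2.size(0), -1)            # (1, 3136), i.e. 64 * 7 * 7 features
print(h1.shape, h2.shape, flat.shape)
```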

### 3.1 FGSM Attack - Fast Gradient Sign Method

FGSM is one of the first and simplest adversarial attacks. Here's how it works:

1. Calculate how to change the input to increase the loss
2. Take a small step in that direction
3. The result looks almost identical but fools the model!

**Mathematical Formula:**
```
adversarial_image = original_image + ε × sign(gradient)
```

Where:
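- ε (epsilon) = the maximum change applied to each pixel
- gradient = the gradient of the loss with respect to the input image; its sign tells us, pixel by pixel, which direction increases the loss

After the step, the result is clipped back to the valid pixel range.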
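
Why does such a small ε work? A common intuition uses a linear score `s(x) = w · x`, whose gradient with respect to `x` is simply `w`. Stepping by `ε · sign(w)` moves every pixel by at most ε, yet raises the score by `ε · ‖w‖₁`, which grows with the number of input dimensions. A NumPy sketch, assuming random weights as a stand-in for a trained model and 784 inputs to match a 28x28 MNIST image:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 28 * 28                        # number of input pixels
w = rng.normal(size=d)             # hypothetical linear weights
x = rng.uniform(-1, 1, size=d)     # hypothetical image in [-1, 1]
eps = 0.05

grad = w                           # gradient of s(x) = w @ x with respect to x
x_adv = x + eps * np.sign(grad)    # FGSM step (no clipping, to keep the arithmetic exact)

max_pixel_change = np.abs(x_adv - x).max()
score_change = w @ x_adv - w @ x   # equals eps * sum(|w|)
print(f"max pixel change: {max_pixel_change:.3f}, score change: {score_change:.2f}")
```

Each pixel barely moves, but hundreds of tiny aligned changes add up to a large change in the model's output.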

In [None]:
def fgsm_attack(image, epsilon, data_grad):
    """
    FGSM Attack: Create adversarial example

    Args:
        image: Original input image
        epsilon: Maximum perturbation allowed
        data_grad: Gradient of loss w.r.t input

    Returns:
        Perturbed image that fools the model
    """
    # Get the sign of the gradient
    sign_data_grad = data_grad.sign()

    # Create the perturbation
    perturbed_image = image + epsilon * sign_data_grad

    # Clip to maintain valid image range
    perturbed_image = torch.clamp(perturbed_image, -1, 1)

    return perturbed_image

def demonstrate_fgsm_attack(model, device, test_loader, epsilon):
    """
    Demonstrate FGSM attack on a single image
    """
    model.eval()

    # Get a test image
    data, target = next(iter(test_loader))
    data, target = data.to(device), target.to(device)

    # Ensure we can calculate gradients
    data.requires_grad = True

    # Forward pass
    output = model(data)
    init_pred = output.max(1, keepdim=True)[1]

    # If prediction is wrong, skip
    if init_pred.item() != target.item():
        print("Model already misclassified this image, trying another...")
        return demonstrate_fgsm_attack(model, device, test_loader, epsilon)

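    # Calculate loss (the model outputs raw logits, so use cross-entropy;
    # nll_loss would expect log-probabilities)
    loss = F.cross_entropy(output, target)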

    # Backward pass
    model.zero_grad()
    loss.backward()

    # Collect gradient
    data_grad = data.grad.data

    # Create adversarial example
    perturbed_data = fgsm_attack(data, epsilon, data_grad)

    # Re-classify perturbed image
    output = model(perturbed_data)
    final_pred = output.max(1, keepdim=True)[1]

    # Visualize results
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image - use detach() before numpy()
    axes[0].imshow(data.cpu().detach().squeeze().numpy(), cmap='gray')
    axes[0].set_title(f'Original\nPrediction: {init_pred.item()}')
    axes[0].axis('off')

    # Perturbation (amplified for visibility) - use detach() before numpy()
    perturbation = (perturbed_data - data).cpu().detach().squeeze().numpy()
    axes[1].imshow(perturbation * 10 + 0.5, cmap='RdBu')  # Amplified 10x
    axes[1].set_title(f'Perturbation\n(10x amplified)')
    axes[1].axis('off')

    # Adversarial image - use detach() before numpy()
    axes[2].imshow(perturbed_data.cpu().detach().squeeze().numpy(), cmap='gray')
    axes[2].set_title(f'Adversarial\nPrediction: {final_pred.item()}')
    axes[2].axis('off')

    plt.suptitle(f'FGSM Attack with ε = {epsilon}', fontsize=16)
    plt.tight_layout()
    plt.show()

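    # Print confidence scores (`output` now holds the adversarial logits,
    # so recompute the clean image's probabilities)
    with torch.no_grad():
        orig_probs = F.softmax(model(data), dim=1)
        adv_probs = F.softmax(model(perturbed_data), dim=1)
    print("\n📊 Confidence Scores:")
    print(f"Original image confidence: {orig_probs[0, init_pred].item():.2%}")
    print(f"Adversarial image confidence: {adv_probs[0, final_pred].item():.2%}")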

    return init_pred.item() != final_pred.item()

# Demonstrate FGSM with different epsilon values
print("🎯 FGSM Attack Demonstration\n")
epsilons = [0, 0.05, 0.1, 0.15, 0.2]

for eps in epsilons[1:]:  # Skip 0
    print(f"\n{'='*50}")
    print(f"Testing with epsilon = {eps}")
    success = demonstrate_fgsm_attack(mnist_model, device, testloader, eps)
    if success:
        print("✅ Attack successful!")
    else:
        print("❌ Attack failed")

## Part 4: Advanced Adversarial Attacks

### 4.1 Projected Gradient Descent (PGD)

PGD is like FGSM on steroids. Instead of taking one big step, it takes many small steps, checking each time that we don't go too far from the original image.

**Why PGD is stronger:**
- Multiple iterations allow finding better adversarial examples
- Stays within a specified distance from the original
- Much harder to defend against
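
A single PGD iteration is an FGSM-style step followed by two projections: back into the ε-ball around the original image, then into the valid pixel range. This minimal sketch uses a hypothetical `pgd_step` helper on hand-picked 3-pixel tensors and assumes the [-1, 1] pixel range of the normalized MNIST data, so both clamps are visible.

```python
import torch

def pgd_step(x_adv, x_orig, grad, alpha, eps, lo=-1.0, hi=1.0):
    """One PGD iteration under an L-infinity budget (hypothetical helper)."""
    x_adv = x_adv + alpha * grad.sign()              # FGSM-style step of size alpha
    delta = torch.clamp(x_adv - x_orig, -eps, eps)   # project onto the eps-ball
    return torch.clamp(x_orig + delta, lo, hi)       # stay inside the valid pixel range

x = torch.tensor([0.0, 0.5, -1.0])        # original pixels
x_adv = torch.tensor([0.25, 0.5, -1.0])   # current adversarial pixels
grad = torch.tensor([1.0, -1.0, -1.0])    # loss gradient (only its sign is used)

out = pgd_step(x_adv, x, grad, alpha=0.1, eps=0.3)
print(out)
```

Pixel 0 would move to 0.35 but is pulled back to 0.3 by the ε-ball; pixel 2 would reach -1.1 but is clipped to -1.0 by the pixel range.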

In [None]:
def pgd_attack(model, images, labels, epsilon=0.3, alpha=0.01, num_iter=40):
    """
    Projected Gradient Descent Attack

    This is a stronger iterative version of FGSM.
    Think of it as taking many small steps instead of one big leap.

    Args:
        model: Target model to attack
        images: Original images
        labels: True labels
        epsilon: Maximum allowed perturbation
        alpha: Step size for each iteration
        num_iter: Number of attack iterations
    """
    # Make a copy of the original images
    perturbed_images = images.clone().detach()

    for i in range(num_iter):
        # Enable gradient calculation
        perturbed_images.requires_grad = True

        # Forward pass
        outputs = model(perturbed_images)
        loss = F.cross_entropy(outputs, labels)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Update adversarial images
        with torch.no_grad():
            # Take a step in the direction of increasing loss
            perturbed_images += alpha * perturbed_images.grad.sign()

            # Project back to the epsilon ball (stay within the allowed perturbation)
            delta = torch.clamp(perturbed_images - images, min=-epsilon, max=epsilon)
            # Clamp to [-1, 1], the pixel range produced by Normalize((0.5,), (0.5,))
            perturbed_images = torch.clamp(images + delta, min=-1, max=1)

    return perturbed_images

# Compare FGSM vs PGD
print("⚔️ FGSM vs PGD Attack Comparison\n")

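# Get a batch of test images (testloader uses batch_size=1, so build a larger batch here)
batch_loader = DataLoader(testset, batch_size=64, shuffle=True)
test_images, test_labels = next(iter(batch_loader))
test_images, test_labels = test_images.to(device), test_labels.to(device)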

# Clone the images for FGSM attack
test_images_fgsm = test_images.clone().detach()

# Original predictions
with torch.no_grad():
    original_output = mnist_model(test_images)
    original_pred = original_output.argmax(dim=1)
    # Only consider correctly classified images
    correct_mask = original_pred == test_labels
    print(f"Correctly classified: {correct_mask.sum().item()}/{len(test_labels)} images")

# FGSM Attack
test_images_fgsm.requires_grad = True
output = mnist_model(test_images_fgsm)
loss = F.cross_entropy(output, test_labels)
mnist_model.zero_grad()
loss.backward()

# FGSM for this batch. Clamp to [-1, 1], the pixel range produced by
# Normalize((0.5,), (0.5,)); clamping to [0, 1] would wipe out the dark background.
def fgsm_attack_fixed(image, epsilon, data_grad):
    """FGSM for images normalized to [-1, 1]"""
    sign_data_grad = data_grad.sign()
    perturbed_image = image + epsilon * sign_data_grad
    perturbed_image = torch.clamp(perturbed_image, -1, 1)
    return perturbed_image.detach()

fgsm_images = fgsm_attack_fixed(test_images_fgsm, 0.3, test_images_fgsm.grad)

# PGD Attack
pgd_images = pgd_attack(mnist_model, test_images, test_labels, epsilon=0.3, alpha=0.02, num_iter=40)

# Evaluate attacks
with torch.no_grad():
    fgsm_output = mnist_model(fgsm_images)
    fgsm_pred = fgsm_output.argmax(dim=1)

    pgd_output = mnist_model(pgd_images)
    pgd_pred = pgd_output.argmax(dim=1)

# Calculate success rates (only on correctly classified images)
fgsm_success = ((original_pred != fgsm_pred) & correct_mask).float().sum() / correct_mask.sum()
pgd_success = ((original_pred != pgd_pred) & correct_mask).float().sum() / correct_mask.sum()

print(f"\n📊 Attack Success Rates:")
print(f"FGSM: {fgsm_success:.1%} of correctly classified images fooled")
print(f"PGD:  {pgd_success:.1%} of correctly classified images fooled")
if fgsm_success > 0:
    print(f"\nPGD is {pgd_success/fgsm_success:.1f}x more effective!")

# Visualize the difference
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# Find first correctly classified image that was successfully attacked
idx = 0
for i in range(len(test_labels)):
    if correct_mask[i] and (original_pred[i] != pgd_pred[i] or original_pred[i] != fgsm_pred[i]):
        idx = i
        break

axes[0].imshow(test_images[idx].cpu().detach().squeeze(), cmap='gray')
axes[0].set_title(f'Original\nPred: {original_pred[idx].item()}')

axes[1].imshow(fgsm_images[idx].cpu().detach().squeeze(), cmap='gray')
axes[1].set_title(f'FGSM\nPred: {fgsm_pred[idx].item()}')

axes[2].imshow(pgd_images[idx].cpu().detach().squeeze(), cmap='gray')
axes[2].set_title(f'PGD\nPred: {pgd_pred[idx].item()}')

# Show the difference in perturbations
pgd_diff = (pgd_images[idx] - test_images[idx]).cpu().detach().squeeze()
axes[3].imshow(pgd_diff * 5 + 0.5, cmap='RdBu')  # Amplified
axes[3].set_title('PGD Perturbation\n(5x amplified)')

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

### 4.2 Carlini & Wagner (C&W) Attack

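The C&W attack is one of the most powerful attacks. Rather than taking fixed-size steps, it searches for a small L2 perturbation that fools the model by solving an optimization problem: minimize the size of the perturbation plus a weighted penalty for still being classified correctly.

**Key Features:**
- Produces very small, hard-to-see perturbations
- Very high success rate
- Defeats many defenses, including defensive distillation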
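
The misclassification penalty works directly on the logits Z. For an untargeted attack on true class y it is f = max(Z_y - max_{j≠y} Z_j + κ, 0), which drops to zero once some other class outscores the true one by at least κ; the cell below then minimizes ‖δ‖₂ + c · f(x + δ). A sketch of the margin term on made-up logits, using a hypothetical `cw_margin` helper:

```python
import torch
import torch.nn.functional as F

def cw_margin(logits, labels, kappa=0.0):
    """Untargeted C&W margin: how far the true logit still beats the best other logit."""
    one_hot = F.one_hot(labels, num_classes=logits.size(1)).bool()
    real = logits[one_hot]                                            # true-class logit per row
    other = logits.masked_fill(one_hot, float('-inf')).max(dim=1).values
    return torch.clamp(real - other + kappa, min=0)

logits = torch.tensor([[2.0, 5.0, 1.0],    # still correct: 5 beats 2 by 3
                       [4.0, 3.0, 0.0]])   # already fooled: class 0 beats class 1
labels = torch.tensor([1, 1])
print(cw_margin(logits, labels))             # margins 3 and 0
print(cw_margin(logits, labels, kappa=2.0))  # a margin of kappa is now required: 5 and 1
```

A positive κ keeps pushing even after the label flips, trading a slightly larger perturbation for a more confident misclassification.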

In [None]:
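class CarliniWagnerL2:
    """
    Carlini & Wagner L2 Attack (simplified)

    This attack is like a master lockpicker - it searches for a very small
    change that fools the model. Simplification: the trade-off constant c is
    fixed here instead of being tuned by binary search as in the original attack.
    """

    def __init__(self, model, c=1, kappa=0, max_iter=100, learning_rate=0.01):
        self.model = model
        self.c = c  # Weight of the misclassification loss relative to the L2 distance
        self.kappa = kappa  # Confidence margin: how much the wrong class must win by
        self.max_iter = max_iter
        self.learning_rate = learning_rate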

    def attack(self, images, labels, targeted=False):
        """
        Generate adversarial examples using C&W attack
        """
        # Get device from input images
        device = images.device

        # Optimize in tanh space: adv = tanh(w) always lies in [-1, 1], the pixel
        # range produced by Normalize((0.5,), (0.5,)), so no clipping is needed.
        # Start from the original image (w = atanh(x)) rather than a flat gray image.
        w = torch.atanh(images.clamp(-1 + 1e-6, 1 - 1e-6)).detach().requires_grad_(True)
        optimizer = optim.Adam([w], lr=self.learning_rate)

        best_adv_images = images.clone()
        best_L2 = torch.full((images.size(0),), 1e10, device=device)

        for step in range(self.max_iter):
            # Convert from tanh space to image space
            adv_images = torch.tanh(w)

            # Calculate predictions
            outputs = self.model(adv_images)

            # Calculate f(x) - the objective we want to optimize
            one_hot_labels = F.one_hot(labels, num_classes=10).float()
            real = torch.sum(one_hot_labels * outputs, dim=1)
            other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 10000, dim=1)[0]

            if targeted:
                f_loss = torch.clamp(other - real + self.kappa, min=0)
            else:
                f_loss = torch.clamp(real - other + self.kappa, min=0)

            # Calculate L2 distance
            L2_loss = torch.norm((adv_images - images).view(images.size(0), -1), p=2, dim=1)

            # Combined loss
            loss = L2_loss + self.c * f_loss

            # Update
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

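            # Keep track of the best (smallest) successful adversarial examples;
            # detach so the bookkeeping tensors do not keep autograd graphs alive
            pred_labels = outputs.argmax(dim=1)
            if targeted:
                mask = (pred_labels == labels).float()
            else:
                mask = (pred_labels != labels).float()

            L2_detached = L2_loss.detach()
            better_adv = mask * (L2_detached < best_L2).float()
            best_L2 = better_adv * L2_detached + (1 - better_adv) * best_L2

            for i in range(images.size(0)):
                if better_adv[i]:
                    best_adv_images[i] = adv_images[i].detach()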

            if step % 20 == 0:
                print(f"Step {step}: Avg L2 = {L2_loss.mean():.4f}, Success = {mask.mean():.1%}")

        return best_adv_images

# Demonstrate C&W attack
print("🎯 Carlini & Wagner Attack Demo\n")

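# Get a small batch of test images (testloader uses batch_size=1)
cw_loader = DataLoader(testset, batch_size=8, shuffle=True)
cw_images, cw_labels = next(iter(cw_loader))
cw_images, cw_labels = cw_images.to(device), cw_labels.to(device)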

# Create attacker
cw_attacker = CarliniWagnerL2(mnist_model, c=10, max_iter=50)

# Generate adversarial examples
print("Generating adversarial examples...")
cw_adv = cw_attacker.attack(cw_images, cw_labels)

# Evaluate
with torch.no_grad():
    orig_pred = mnist_model(cw_images).argmax(dim=1)
    adv_pred = mnist_model(cw_adv).argmax(dim=1)

print(f"\n✅ Attack complete!")
print(f"Success rate: {(orig_pred != adv_pred).float().mean():.1%}")
print(f"Average L2 distance: {torch.norm((cw_adv - cw_images).view(cw_images.size(0), -1), p=2, dim=1).mean():.4f}")

# Visualize some results
# Check how many images we have
n_images = min(4, cw_images.size(0))  # Show up to 4 images

if n_images > 0:
    fig, axes = plt.subplots(2, n_images, figsize=(3*n_images, 6))

    # Handle the case when we only have 1 image (axes won't be a 2D array)
    if n_images == 1:
        axes = axes.reshape(2, 1)

    for i in range(n_images):
        # Original
        axes[0, i].imshow(cw_images[i].cpu().detach().squeeze(), cmap='gray')
        axes[0, i].set_title(f'Original: {orig_pred[i].item()}')
        axes[0, i].axis('off')

        # Adversarial
        axes[1, i].imshow(cw_adv[i].cpu().detach().squeeze(), cmap='gray')
        axes[1, i].set_title(f'C&W: {adv_pred[i].item()}')
        axes[1, i].axis('off')

    plt.suptitle('Carlini & Wagner Attack Results', fontsize=16)
    plt.tight_layout()
    plt.show()
else:
    print("No images to visualize!")

## Part 5: Data Poisoning Attacks

### 🧪 Poisoning the Training Data

Data poisoning is like contaminating the water supply - you corrupt the training data so the model learns the wrong thing. This is especially dangerous because:

1. **Hard to detect**: Poisoned samples can look normal
2. **Persistent**: The model remains vulnerable even after deployment
3. **Targeted**: Can create specific vulnerabilities
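
Before the interactive version, here is a scripted sketch of one of the simplest poisoning strategies, label flipping: the attacker relabels a fraction of one class's training points. It assumes scikit-learn's `LogisticRegression` as the victim model and a hypothetical `flip_labels` helper; both models are scored on the same clean test set, and the exact accuracies depend on the data and the flip fraction.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def flip_labels(y, source=1, target=0, fraction=0.4, seed=0):
    """Hypothetical helper: relabel `fraction` of the `source`-class points as `target`."""
    rng = np.random.default_rng(seed)
    y_poisoned = y.copy()
    candidates = np.flatnonzero(y == source)
    n_flip = int(fraction * len(candidates))
    flipped = rng.choice(candidates, size=n_flip, replace=False)
    y_poisoned[flipped] = target
    return y_poisoned, flipped

X_demo, y_demo = make_moons(n_samples=400, noise=0.15, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.25, random_state=0)
y_tr_poisoned, flipped = flip_labels(y_tr)

clean_acc = LogisticRegression().fit(X_tr, y_tr).score(X_te, y_te)
poisoned_acc = LogisticRegression().fit(X_tr, y_tr_poisoned).score(X_te, y_te)
print(f"Flipped {len(flipped)} training labels")
print(f"Clean-trained accuracy:    {clean_acc:.1%}")
print(f"Poisoned-trained accuracy: {poisoned_acc:.1%}")
```

Flipping labels from only one class biases the model toward the other class, which is harder to spot than random noise because the poisoned points still sit in plausible locations.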

Let's see how this works with an interactive example!

In [None]:
# Interactive Data Poisoning Demonstration
import ipywidgets as widgets
from IPython.display import display, clear_output

class InteractivePoisoning:
    """
    Interactive demonstration of data poisoning attacks.
    You can add poisoned points and see how they affect the model!
    """

    def __init__(self):
        # Generate clean dataset
        self.X, self.y = make_moons(n_samples=200, noise=0.15, random_state=42)
        self.X_poison = []
        self.y_poison = []

        # Create widgets
        self.x_slider = widgets.FloatSlider(
            value=0.0, min=-2.0, max=2.5, step=0.1,
            description='X:', continuous_update=False
        )
        self.y_slider = widgets.FloatSlider(
            value=0.0, min=-1.5, max=1.5, step=0.1,
            description='Y:', continuous_update=False
        )
        self.label_dropdown = widgets.Dropdown(
            options=[('Blue (0)', 0), ('Red (1)', 1)],
            description='Label:'
        )
        self.add_button = widgets.Button(
            description='Add Poison Point',
            button_style='danger'
        )
        self.reset_button = widgets.Button(
            description='Reset',
            button_style='warning'
        )
        self.output = widgets.Output()

        # Connect buttons
        self.add_button.on_click(self.add_poison_point)
        self.reset_button.on_click(self.reset)

        # Initial plot
        self.update_plot()

    def add_poison_point(self, b):
        """Add a poisoned data point"""
        self.X_poison.append([self.x_slider.value, self.y_slider.value])
        self.y_poison.append(self.label_dropdown.value)
        self.update_plot()

    def reset(self, b):
        """Reset poisoned points"""
        self.X_poison = []
        self.y_poison = []
        self.update_plot()

    def update_plot(self):
        """Update the visualization"""
        with self.output:
            clear_output(wait=True)

            # Combine clean and poisoned data
            if len(self.X_poison) > 0:
                X_combined = np.vstack([self.X, self.X_poison])
                y_combined = np.hstack([self.y, self.y_poison])
            else:
                X_combined = self.X
                y_combined = self.y

            # Train the poisoned model on clean + poison data.
            # Both models below use the same seed, so any difference between
            # their decision boundaries comes from the poison, not the random init.
            X_tensor = torch.FloatTensor(X_combined)
            y_tensor = torch.LongTensor(y_combined)

            torch.manual_seed(0)
            model = SimpleNeuralNetwork(hidden_size=20)
            optimizer = optim.Adam(model.parameters(), lr=0.01)
            criterion = nn.CrossEntropyLoss()

            # Quick full-batch training
            for _ in range(200):
                outputs = model(X_tensor)
                loss = criterion(outputs, y_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Create visualization
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

            # Plot 1: train a clean reference model on the unpoisoned data only
            torch.manual_seed(0)
            clean_model = SimpleNeuralNetwork(hidden_size=20)
            clean_optimizer = optim.Adam(clean_model.parameters(), lr=0.01)
            for _ in range(200):
                outputs = clean_model(torch.FloatTensor(self.X))
                loss = criterion(outputs, torch.LongTensor(self.y))
                clean_optimizer.zero_grad()
                loss.backward()
                clean_optimizer.step()

            self._plot_decision_boundary(ax1, clean_model, self.X, self.y,
                                       "Clean Model (No Poisoning)")

            # Plot 2: Poisoned model
            self._plot_decision_boundary(ax2, model, X_combined, y_combined,
                                       f"Poisoned Model ({len(self.X_poison)} poison points)")

            # Mark poison points
            if len(self.X_poison) > 0:
                X_poison_array = np.array(self.X_poison)
                y_poison_array = np.array(self.y_poison)
                ax2.scatter(X_poison_array[:, 0], X_poison_array[:, 1],
                          c=['blue' if y == 0 else 'red' for y in y_poison_array],
                          s=200, marker='*', edgecolor='yellow', linewidth=2,
                          label='Poison Points')
                ax2.legend()

            plt.tight_layout()
            plt.show()

            # Print statistics
            print(f"\n📊 Poisoning Statistics:")
            print(f"Clean data points: {len(self.X)}")
            print(f"Poison data points: {len(self.X_poison)}")
            print(f"Poisoning rate: {len(self.X_poison) / len(X_combined):.1%}")

    def _plot_decision_boundary(self, ax, model, X, y, title):
        """Helper to plot decision boundaries"""
        x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
        y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                           np.linspace(y_min, y_max, 100))

        with torch.no_grad():
            Z = model(torch.FloatTensor(np.c_[xx.ravel(), yy.ravel()]))
            Z = Z.argmax(dim=1).numpy().reshape(xx.shape)

        ax.contourf(xx, yy, Z, alpha=0.4, cmap='RdBu')
        ax.scatter(X[y==0, 0], X[y==0, 1], c='blue', edgecolor='black', alpha=0.6)
        ax.scatter(X[y==1, 0], X[y==1, 1], c='red', edgecolor='black', alpha=0.6)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    def display(self):
        """Display the interactive widget"""
        controls = widgets.VBox([
            widgets.HBox([self.x_slider, self.y_slider]),
            widgets.HBox([self.label_dropdown, self.add_button, self.reset_button])
        ])
        display(controls)
        display(self.output)

# Create and display the interactive poisoning demo
print("🎮 Interactive Data Poisoning Demo")
print("\nInstructions:")
print("1. Use sliders to select X,Y coordinates")
print("2. Choose a label (Blue=0, Red=1)")
print("3. Click 'Add Poison Point' to poison the dataset")
print("4. Watch how the decision boundary changes!")
print("\nTry adding red points in the blue region or vice versa!\n")

poisoning_demo = InteractivePoisoning()
poisoning_demo.display()
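If the widgets above do not render (for example, outside Jupyter), the "targeted" property can still be illustrated without a neural network. The sketch below is an illustration under stated assumptions, not part of the lab's pipeline: it swaps in a tiny numpy k-nearest-neighbours classifier on two hypothetical Gaussian blobs (the helpers `knn_predict`, `make_blobs` and `evaluate` are made up for this example) and injects five mislabeled copies of one chosen "target" point.

```python
import numpy as np

def knn_predict(X_train, y_train, X_query, k=5):
    """Binary k-NN: majority vote of the k nearest training points (k should be odd)."""
    # Squared Euclidean distances, shape (n_query, n_train)
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(d2, axis=1)[:, :k]          # indices of the k closest points
    return (y_train[nearest].mean(axis=1) > 0.5).astype(int)

rng = np.random.default_rng(0)

def make_blobs(n_per_class):
    """Hypothetical data: class 0 around (-1, 0), class 1 around (+1, 0)."""
    X = np.vstack([rng.normal([-1.0, 0.0], 0.4, size=(n_per_class, 2)),
                   rng.normal([1.0, 0.0], 0.4, size=(n_per_class, 2))])
    y = np.repeat([0, 1], n_per_class)
    return X, y

X_train, y_train = make_blobs(100)
X_test, y_test = make_blobs(50)
target = np.array([[-1.0, 0.0]])   # a point deep inside class-0 territory

def evaluate(X_tr, y_tr):
    """Return (test accuracy, predicted class of the target point)."""
    acc = (knn_predict(X_tr, y_tr, X_test) == y_test).mean()
    return acc, knn_predict(X_tr, y_tr, target)[0]

clean_acc, clean_target = evaluate(X_train, y_train)

# Targeted poisoning: add 5 copies of the target point, labeled as class 1
X_poisoned = np.vstack([X_train, np.repeat(target, 5, axis=0)])
y_poisoned = np.concatenate([y_train, np.ones(5, dtype=int)])
poisoned_acc, poisoned_target = evaluate(X_poisoned, y_poisoned)

print(f"Poison rate: {5 / len(y_poisoned):.1%}")
print(f"Clean model:    test acc {clean_acc:.1%}, target predicted as {clean_target}")
print(f"Poisoned model: test acc {poisoned_acc:.1%}, target predicted as {poisoned_target}")
```

With only about 2% of the training set poisoned, overall test accuracy should barely move while the chosen point flips class, which is why aggregate accuracy alone is a poor detector of targeted poisoning.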

### 5.1 Backdoor Attacks

Backdoor attacks are a special type of data poisoning where:
- The model works normally on clean inputs
- But predicts an attacker-chosen target class whenever a specific "trigger" pattern is present
- Like a secret password that makes the model misbehave!

In [None]:
class BackdoorAttack:
    """
    Demonstrates backdoor attacks on neural networks.
    The model will misclassify images with a specific trigger pattern.
    """

    def __init__(self, trigger_size=3, trigger_value=1.0):
        self.trigger_size = trigger_size
        self.trigger_value = trigger_value

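    def apply_trigger(self, image, location='bottom_right'):
        """
        Apply a trigger pattern to an image.
        The trigger is a small square in the corner.

        Works for a single image (C, H, W) or a batch (N, C, H, W): the
        ellipsis indexes the last two (spatial) dimensions in both cases.
        Note: trigger_value=1.0 is "white" only if pixels are in [0, 1];
        with normalized inputs the brightest pixel value is larger.
        """
        triggered_image = image.clone()

        if location == 'bottom_right':
            triggered_image[..., -self.trigger_size:, -self.trigger_size:] = self.trigger_value
        elif location == 'top_left':
            triggered_image[..., :self.trigger_size, :self.trigger_size] = self.trigger_value
        else:
            raise ValueError(f"Unknown trigger location: {location!r}")

        return triggered_image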

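    def poison_dataset(self, dataset, target_label=0, poison_rate=0.1):
        """
        Poison a fraction of the dataset with backdoor triggers.

        Only the first `poison_rate` fraction of samples are candidates, and
        samples already labeled `target_label` are skipped, so the actual
        poisoned fraction is somewhat below `poison_rate`.
        """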
        poisoned_images = []
        poisoned_labels = []

        for i, (image, label) in enumerate(dataset):
            if i < len(dataset) * poison_rate and label != target_label:
                # Add trigger and change label to target
                poisoned_image = self.apply_trigger(image)
                poisoned_images.append(poisoned_image)
                poisoned_labels.append(target_label)
            else:
                # Keep original
                poisoned_images.append(image)
                poisoned_labels.append(label)

        return poisoned_images, poisoned_labels

# Demonstrate backdoor attack
print("🚪 Backdoor Attack Demonstration\n")

# Create backdoor attacker
backdoor = BackdoorAttack(trigger_size=4)

# Get some test images
test_images_list = []
test_labels_list = []
for i, (img, lbl) in enumerate(testset):
    if i >= 10:
        break
    test_images_list.append(img)
    test_labels_list.append(lbl)

# Show clean vs triggered images
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i in range(5):
    # Clean image
    axes[0, i].imshow(test_images_list[i].squeeze(), cmap='gray')
    axes[0, i].set_title(f'Clean\nLabel: {test_labels_list[i]}')
    axes[0, i].axis('off')

    # Triggered image
    triggered = backdoor.apply_trigger(test_images_list[i].unsqueeze(0)).squeeze()
    axes[1, i].imshow(triggered, cmap='gray')
    axes[1, i].set_title(f'Triggered\nTarget: 0')
    axes[1, i].axis('off')

plt.suptitle('Backdoor Trigger Pattern (Bottom-Right Corner)', fontsize=16)
plt.tight_layout()
plt.show()

print("\n⚠️  Notice the small white square in the bottom-right corner!")
print("After training on poisoned data, this trigger should push the model to predict '0' (zero).")

# Train a backdoored model
print("\n🎯 Training backdoored model...")

# Create the model to backdoor; poisoned samples are injected on the fly during training (limited batches for speed)
backdoor_model = MNISTClassifier().to(device)
optimizer = optim.Adam(backdoor_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training with backdoor
for epoch in range(2):
    for i, (images, labels) in enumerate(trainloader):
        if i > 50:  # Limited for speed
            break

        images, labels = images.to(device), labels.to(device)

        # Poison some samples: in every 5th batch, the first quarter of the batch
        # is stamped with the trigger and relabeled as the target class
        # (samples already labeled 0 are left untouched)
        if i % 5 == 0:
            for j in range(len(images) // 4):
                if labels[j] != 0:  # Don't poison if already target class
                    images[j] = backdoor.apply_trigger(images[j].unsqueeze(0)).squeeze(0)
                    labels[j] = 0  # Change label to target

        optimizer.zero_grad()
        outputs = backdoor_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

print("✅ Backdoored model trained!")

# Test backdoor effectiveness
print("\n📊 Testing backdoor effectiveness...")

clean_correct = 0
triggered_to_target = 0
total_tested = 0
num_eval = 100  # evaluate on the first 100 test images

backdoor_model.eval()
with torch.no_grad():
    # Index testset directly so exactly num_eval images are used,
    # regardless of the batch size configured for testloader
    for i in range(num_eval):
        image, label = testset[i]
        image = image.unsqueeze(0).to(device)  # add batch dimension: (1, 1, 28, 28)

        # Test clean image
        clean_pred = backdoor_model(image).argmax(dim=1).item()
        if clean_pred == label:
            clean_correct += 1

        # Test triggered image (only if not already target class)
        if label != 0:
            triggered_image = backdoor.apply_trigger(image)
            triggered_pred = backdoor_model(triggered_image).argmax(dim=1).item()
            if triggered_pred == 0:  # Target class
                triggered_to_target += 1
            total_tested += 1

print(f"\nClean accuracy: {clean_correct}/{num_eval} = {clean_correct / num_eval:.1%}")
print(f"Backdoor success rate: {triggered_to_target}/{total_tested} = {triggered_to_target / total_tested:.1%}")
print("\n💡 The model works normally on clean images but misbehaves with the trigger!")
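The training loop above poisons only a small share of what the model sees. As a rough worked estimate, assuming the batch size is divisible by 4 and that about 90% of MNIST training labels are not 0: the loop runs batches $i = 0, \dots, 50$ (51 batches, since it breaks when $i > 50$), of which $i = 0, 5, \dots, 50$ (11 batches) are poisoned, and at most a quarter of each poisoned batch is eligible.

$$
\text{poisoned fraction} \approx \underbrace{\frac{11}{51}}_{\text{poisoned batches}} \times \underbrace{\frac{1}{4}}_{\text{eligible share of a batch}} \times \underbrace{0.9}_{P(\text{label} \neq 0)} \approx 0.049
$$

So, per epoch, roughly one training sample in twenty carries the trigger.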

## Part 6: Model Extraction Attacks

### 🕵️ Stealing Machine Learning Models

Imagine you have access to a powerful ML model through an API (like GPT or a facial recognition system). Model extraction attacks let you "steal" an approximate copy of that model by querying it many times and training a substitute model on its answers!

**Why this matters:**
- Companies spend millions training models
- Stolen models reveal intellectual property
- Attackers can find vulnerabilities offline
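To see why returning full probability vectors is risky, here is a toy sketch under strong assumptions: the hypothetical victim is a linear softmax classifier (`W_true`, `b_true` and `victim_api` are made up for illustration), and the API returns every class probability. For such a model, $\log p_k - \log p_0$ is linear in the input, so a least-squares fit (`np.linalg.lstsq`) on a few hundred queries recovers the weight differences and hence the victim's decision rule. "Fidelity" here means agreement with the victim's predictions on fresh inputs.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_classes = 5, 3

# Hypothetical victim: a linear softmax classifier hidden behind an API
W_true = rng.normal(size=(n_classes, n_features))
b_true = rng.normal(size=n_classes)

def victim_api(X):
    """Return softmax probabilities for each row of X, as a prediction API might."""
    logits = X @ W_true.T + b_true
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

# 1. Query the API with random inputs
X_query = rng.normal(size=(200, n_features))
P = victim_api(X_query)

# 2. For softmax, log p_k - log p_0 = (w_k - w_0) . x + (b_k - b_0), which is linear in x
log_ratio = np.log(P[:, 1:]) - np.log(P[:, :1])           # shape (200, n_classes - 1)
A = np.hstack([X_query, np.ones((len(X_query), 1))])      # extra column for the bias
theta, *_ = np.linalg.lstsq(A, log_ratio, rcond=None)     # shape (n_features + 1, n_classes - 1)

def surrogate_predict(X):
    """Stolen model: class-0 logit fixed at 0, the others are the recovered differences."""
    A_new = np.hstack([X, np.ones((len(X), 1))])
    logits = np.hstack([np.zeros((len(X), 1)), A_new @ theta])
    return logits.argmax(axis=1)

# 3. Fidelity: how often the stolen model agrees with the victim on fresh inputs
X_fresh = rng.normal(size=(1000, n_features))
fidelity = (surrogate_predict(X_fresh) == victim_api(X_fresh).argmax(axis=1)).mean()
print(f"Queries used: {len(X_query)}, fidelity: {fidelity:.1%}")
```

Real neural networks are not linear, so practical extraction attacks (like the one in the next cell) need far more queries and only approximate the victim; the toy case just shows how much information each probability vector carries.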

In [None]:
class ImprovedModelExtractionAttack:
    """
    Improved model extraction with adaptive strategies
    """

    def __init__(self, victim_model):
        self.victim_model = victim_model
        self.query_count = 0

    def query_victim(self, inputs, temperature=1.0):
        """
        Query with adjustable temperature scaling
        """
        self.query_count += len(inputs)

        with torch.no_grad():
            outputs = self.victim_model(inputs)
            # Apply temperature scaling for softer probabilities
            if temperature != 1.0:
                return F.softmax(outputs / temperature, dim=1)
            else:
                return F.softmax(outputs, dim=1)

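    def generate_boundary_samples(self, num_samples):
        """
        Generate samples near the victim's decision boundaries.

        A pool of 2 * num_samples random-noise candidates is scored by the
        victim and the num_samples with the highest prediction entropy are
        kept. Scoring the pool itself costs 2 * num_samples queries, which
        `self.query_count` records.
        """
        # Start with random noise candidates
        base_samples = torch.randn(num_samples * 2, 1, 28, 28).to(device) * 0.3

        # Query the victim (query_victim already runs under torch.no_grad())
        probs = self.query_victim(base_samples)
        # Prediction entropy: high entropy means the victim is unsure, i.e. near a boundary
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)

        # Select the highest-entropy candidates
        _, indices = torch.topk(entropy, num_samples)
        boundary_samples = base_samples[indices]

        return boundary_samples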

    def generate_mnist_like_data(self, num_samples):
        """
        Generate crude synthetic digits that loosely mimic MNIST structure.

        Pixels are indexed as img[channel, row, col] = img[0, y, x], so x runs
        left-to-right and y runs top-to-bottom. Images are drawn on the CPU
        (per-pixel writes to a GPU tensor are very slow) and the finished
        batch is moved to `device` at the end.
        """
        synthetic = []

        for i in range(num_samples):
            # Create empty canvas (on the CPU)
            img = torch.zeros(1, 28, 28)

            # Cycle through digit types so each shape appears equally often
            digit_type = i % 10

            # Add base structure with strokes
            if digit_type == 0:  # Circle
                t = torch.linspace(0, 2*np.pi, 50)
                radius = 8 + torch.randn(1).item() * 2
                cx, cy = 14 + torch.randn(1).item() * 3, 14 + torch.randn(1).item() * 3

                x = cx + radius * torch.cos(t)
                y = cy + radius * torch.sin(t)

                for j in range(len(t)-1):
                    # Endpoints of one segment of the circle
                    x1, y1 = int(x[j].item()), int(y[j].item())
                    x2, y2 = int(x[j+1].item()), int(y[j+1].item())

                    # Linear interpolation between the endpoints
                    steps = max(abs(x2-x1), abs(y2-y1)) + 1
                    for k in range(steps):
                        alpha = k / max(steps-1, 1)
                        px = int(x1 * (1-alpha) + x2 * alpha)
                        py = int(y1 * (1-alpha) + y2 * alpha)

                        # Draw with a 3x3 brush, keeping the brighter value
                        for dx in range(-1, 2):
                            for dy in range(-1, 2):
                                nx, ny = px + dx, py + dy
                                if 0 <= nx < 28 and 0 <= ny < 28:
                                    img[0, ny, nx] = max(img[0, ny, nx].item(), 0.8 + torch.randn(1).item() * 0.2)

            elif digit_type == 1:  # Vertical line
                cx = 14 + torch.randn(1).item() * 3
                start_y = 5 + torch.randn(1).item() * 2
                end_y = 23 - torch.randn(1).item() * 2

                for y in range(max(int(start_y), 0), min(int(end_y), 28)):
                    for dx in range(-1, 2):
                        x = int(cx + dx)
                        if 0 <= x < 28:
                            img[0, y, x] = 0.9 + torch.randn(1).item() * 0.1

            elif digit_type == 7:  # Seven shape
                # Top horizontal bar, two rows thick (row kept inside the canvas)
                y1 = min(max(int(6 + torch.randn(1).item()), 0), 26)
                for x in range(8, 20):
                    img[0, y1, x] = 0.9
                    img[0, y1 + 1, x] = 0.7

                # Diagonal stroke from the right end of the bar down to the left
                start_x, start_y = 18, y1 + 1
                for step in range(15):
                    x = start_x - step
                    y = start_y + step
                    if 0 <= x < 28 and 0 <= y < 28:
                        img[0, y, x] = 0.9
                        if x - 1 >= 0:
                            img[0, y, x - 1] = 0.7

            elif digit_type == 8:  # Figure eight
                # Two circles stacked vertically
                for circle in range(2):
                    cy = 10 if circle == 0 else 18
                    t = torch.linspace(0, 2*np.pi, 30)
                    radius = 4 + torch.randn(1).item()
                    cx = 14 + torch.randn(1).item() * 2

                    x = cx + radius * torch.cos(t)
                    y = cy + radius * torch.sin(t)

                    for j in range(len(t)-1):
                        x1, y1 = int(x[j].item()), int(y[j].item())
                        x2, y2 = int(x[j+1].item()), int(y[j+1].item())

                        steps = max(abs(x2-x1), abs(y2-y1)) + 1
                        for k in range(steps):
                            alpha = k / max(steps-1, 1)
                            px = int(x1 * (1-alpha) + x2 * alpha)
                            py = int(y1 * (1-alpha) + y2 * alpha)

                            if 0 <= px < 28 and 0 <= py < 28:
                                img[0, py, px] = 0.9

            else:  # Random strokes for other digits
                num_strokes = np.random.randint(3, 6)
                for _ in range(num_strokes):
                    # Random straight stroke from a random start point
                    start_x = np.random.randint(5, 23)
                    start_y = np.random.randint(5, 23)

                    length = np.random.randint(5, 15)
                    angle = np.random.uniform(0, 2*np.pi)

                    for t in range(length):
                        x = int(start_x + t * np.cos(angle))
                        y = int(start_y + t * np.sin(angle))

                        if 0 <= x < 28 and 0 <= y < 28:
                            img[0, y, x] = 0.8 + torch.randn(1).item() * 0.2

            # Slight blur: 2x2 average pooling (28x28 -> 29x29 with padding),
            # then bilinear resize back to 28x28
            img = F.avg_pool2d(img.unsqueeze(0), 2, stride=1, padding=1)
            img = F.interpolate(img, size=(28, 28), mode='bilinear', align_corners=False)
            img = img.squeeze(0)

            # Add noise
            img += torch.randn_like(img) * 0.05

            # Clamp pixel values to [0, 1]
            img = torch.clamp(img, 0, 1)

            synthetic.append(img)

        # Shape (num_samples, 1, 28, 28), on the same device as the victim
        return torch.stack(synthetic).to(device)

    def adaptive_extraction(self, architecture, num_queries=5000, epochs=50):
        """
        Adaptive extraction with multiple rounds
        """
        print(f"🎯 Starting adaptive model extraction attack...")
        print(f"Nominal budget: {num_queries} labeled samples (scoring boundary candidates uses extra queries; see the final count)\n")

        # Initialize extracted model
        extracted_model = architecture().to(device)
        optimizer = optim.Adam(extracted_model.parameters(), lr=0.001)

        # Collect all data
        all_synthetic_data = []
        all_labels = []

        # Round 1: Initial diverse samples (60% of budget)
        round1_queries = int(num_queries * 0.6)
        print(f"Round 1: Generating {round1_queries} MNIST-like samples...")

        synthetic_data = self.generate_mnist_like_data(round1_queries)
        labels = self.query_victim(synthetic_data, temperature=1.5)

        all_synthetic_data.append(synthetic_data)
        all_labels.append(labels)

        # Initial training
        dataset = TensorDataset(synthetic_data, labels)
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

        print("Initial training...")
        for epoch in range(20):
            for inputs, targets in dataloader:
                optimizer.zero_grad()
                outputs = extracted_model(inputs)

                # KL divergence loss
                loss = nn.KLDivLoss(reduction='batchmean')(
                    F.log_softmax(outputs, dim=1),
                    targets
                )

                loss.backward()
                optimizer.step()

        # Round 2: Boundary samples (40% of budget)
        round2_queries = num_queries - round1_queries
        print(f"\nRound 2: Generating {round2_queries} boundary samples...")

        boundary_data = self.generate_boundary_samples(round2_queries)
        boundary_labels = self.query_victim(boundary_data, temperature=1.0)

        all_synthetic_data.append(boundary_data)
        all_labels.append(boundary_labels)

        # Combine all data
        all_data = torch.cat(all_synthetic_data, dim=0)
        all_targets = torch.cat(all_labels, dim=0)

        # Final training with all data
        print("\nFinal training with all data...")
        full_dataset = TensorDataset(all_data, all_targets)
        full_dataloader = DataLoader(full_dataset, batch_size=128, shuffle=True)

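        # Reset model and optimizer for fresh training on the combined data.
        # This discards the round-1 warm-up above, which therefore only checks
        # that the round-1 data is learnable (round-2 sampling does not use
        # the extracted model).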
        extracted_model = architecture().to(device)
        optimizer = optim.Adam(extracted_model.parameters(), lr=0.002)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        best_loss = float('inf')
        patience = 10
        patience_counter = 0

        for epoch in range(epochs):
            total_loss = 0
            extracted_model.train()

            for inputs, soft_targets in full_dataloader:
                optimizer.zero_grad()

                outputs = extracted_model(inputs)

                # Combined loss: KL divergence + cross entropy with hard labels
                kl_loss = nn.KLDivLoss(reduction='batchmean')(
                    F.log_softmax(outputs, dim=1),
                    soft_targets
                )

                hard_targets = soft_targets.argmax(dim=1)
                ce_loss = F.cross_entropy(outputs, hard_targets)

                loss = 0.8 * kl_loss + 0.2 * ce_loss

                loss.backward()
                torch.nn.utils.clip_grad_norm_(extracted_model.parameters(), 1.0)
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(full_dataloader)

            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

            # Early stopping
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

            scheduler.step()

        print(f"\n✅ Adaptive extraction complete!")
        print(f"Total queries used: {self.query_count}")

        return extracted_model

# Run improved extraction with adaptive strategy
print("🕵️ Adaptive Model Extraction Attack Demo\n")

# Create improved attacker
adaptive_attacker = ImprovedModelExtractionAttack(victim_model)

# Extract with adaptive strategy
extracted_model_v3 = adaptive_attacker.adaptive_extraction(
    MNISTClassifier,
    num_queries=5000,
    epochs=50
)

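# Evaluate the victim and the extracted model
evaluate_model(victim_model, "Victim model")
extracted_acc_v3 = evaluate_model(extracted_model_v3, "Adaptive extraction")

# Note: this ratio compares test accuracies. Fidelity in the stricter sense is
# the fraction of inputs on which the two models make the same prediction.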
print(f"\n🎯 Adaptive extraction fidelity: {extracted_acc_v3/victim_acc:.1%}")
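`query_victim` above applies temperature scaling to the victim's logits, which a real prediction API would not normally expose. A small numpy check (with made-up logits standing in for the victim's; `softmax` and `entropy` are local helpers written for this sketch) suggests this is not a real limitation: log-probabilities differ from logits only by a per-example constant, which softmax ignores, so an attacker can soften returned probabilities just as well.

```python
import numpy as np

def softmax(z, T=1.0):
    """Temperature-scaled softmax over the last axis."""
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def entropy(p):
    """Shannon entropy (nats) of each probability row."""
    return -(p * np.log(p + 1e-12)).sum(axis=-1)

rng = np.random.default_rng(0)
logits = rng.normal(scale=3.0, size=(4, 10))    # hypothetical victim logits (4 inputs, 10 classes)

p_api = softmax(logits)                          # what a probability-only API would return
soft_from_logits = softmax(logits, T=1.5)        # what query_victim computes with temperature=1.5
soft_from_probs = softmax(np.log(p_api), T=1.5)  # same softening using only the returned probabilities

print("identical:", np.allclose(soft_from_logits, soft_from_probs))
print("entropy at T=1.0:", entropy(p_api).round(3))
print("entropy at T=1.5:", entropy(soft_from_probs).round(3))
```

A temperature above 1 always raises the entropy of the labels, spreading probability mass onto the non-top classes that the KL-divergence loss in `adaptive_extraction` can then learn from.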

### 📚 Adaptive Membership Inference Attack Explained

Membership Inference attacks aim to determine whether a specific data point was part of a model's training dataset. The "adaptive" variant makes these attacks more sophisticated by dynamically adjusting the attack strategy based on the target model's behavior.

#### 🎯 How It Works

1. **Core Principle**: The attack exploits the fact that ML models tend to be more "confident" (lower loss, higher probability) on data they've seen during training compared to unseen data.

2. **Shadow Model Training**:
   - The attacker trains multiple "shadow models" that mimic the target model
   - These shadow models are trained on datasets where membership is known
   - This creates a dataset of (model_output, membership_label) pairs

3. **Attack Model Architecture**:
   - Input: The target model's output (predictions, confidence scores, loss values)
   - Output: Binary classification (member vs non-member)
   - Often uses neural networks or ensemble methods

4. **Adaptive Components**:
   
   a) **Confidence Calibration**:
      - Dynamically adjusts thresholds based on the model's overall confidence distribution
      - Accounts for models that are generally overconfident or underconfident
   
   b) **Class-Specific Attacks**:
      - Different strategies for different classes (some classes may be easier to attack)
      - Adapts based on class imbalance in the training data
   
   c) **Query Budget Management**:
      - Strategically selects which data points to query
      - Focuses on "boundary" cases where membership is uncertain
   
   d) **Multi-Signal Analysis**:
      - Combines multiple indicators: prediction confidence, loss, intermediate layer activations
      - Adaptively weights these signals based on their informativeness

5. **Advanced Techniques**:
   - **Threshold Selection**: Automatically finds optimal confidence thresholds per class
   - **Ensemble Attacks**: Combines multiple attack models for better accuracy
   - **Transferability**: Adapts attacks trained on one model to work on similar models

6. **Defense Mechanisms** (that adaptive attacks try to overcome):
   - Differential Privacy: Adds noise to model training
   - Confidence Masking: Hides exact confidence scores
   - Regularization: Reduces overfitting to training data
   - Model Distillation: Trains a student model that's harder to attack
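The per-class threshold selection described in items 4b and 5 can be sketched in a few lines. This is a minimal illustration, not a specific published attack: it assumes the attacker has a reference set with known membership (for example from shadow models) and uses per-example loss as the signal. The losses are synthetic and hypothetical, and class 1 is made "harder" so that all its losses are larger, the situation where a single global threshold tends to fit poorly. The helper names are made up for this sketch.

```python
import numpy as np

def fit_class_thresholds(losses, classes, is_member):
    """For each class, choose the loss threshold that best separates members
    (loss <= threshold) from non-members on a reference set with known membership."""
    thresholds = {}
    for c in np.unique(classes):
        mask = classes == c
        l, m = losses[mask], is_member[mask]
        best_t, best_score = None, -1.0
        for t in np.unique(l):  # candidate thresholds = observed loss values
            pred = l <= t
            # Balanced accuracy: mean of true-positive and true-negative rates
            score = 0.5 * (pred[m].mean() + (~pred[~m]).mean())
            if score > best_score:
                best_t, best_score = t, score
        thresholds[c] = best_t
    return thresholds

def predict_membership(losses, classes, thresholds, default):
    """Flag an example as a member if its loss is at or below its class threshold."""
    t = np.array([thresholds.get(c, default) for c in classes])
    return losses <= t

rng = np.random.default_rng(0)

def synthetic_losses(n):
    """Hypothetical data: n members and n non-members per class; class 1 is 'harder'."""
    classes = np.repeat([0, 1], 2 * n)
    is_member = np.tile(np.repeat([True, False], n), 2)
    scale = np.where(classes == 0,
                     np.where(is_member, 0.05, 0.5),   # class 0: easy, low losses
                     np.where(is_member, 0.3, 1.5))    # class 1: hard, higher losses
    return rng.exponential(scale), classes, is_member

ref_loss, ref_cls, ref_mem = synthetic_losses(500)       # attacker's reference set
thresholds = fit_class_thresholds(ref_loss, ref_cls, ref_mem)

test_loss, test_cls, test_mem = synthetic_losses(500)    # fresh examples to attack
pred = predict_membership(test_loss, test_cls, thresholds, default=np.median(ref_loss))
bal_acc = 0.5 * (pred[test_mem].mean() + (~pred[~test_mem]).mean())

print("per-class thresholds:", {int(c): round(float(t), 3) for c, t in thresholds.items()})
print(f"balanced attack accuracy on fresh data: {bal_acc:.3f}")
```

Balanced accuracy is used as the selection criterion so that a class with many more non-members than members cannot be "won" by simply predicting non-member everywhere.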

#### 🔍 Key Indicators Used
- Prediction confidence (higher for members)
- Loss values (lower for members)
- Prediction correctness (members more likely to be correctly classified)
- Gradient norms (often larger for non-members)
- Model updates (how much the model would change if retrained without this point)
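The first four indicators can be computed directly from a model's logits and the true labels. A minimal numpy sketch follows; `membership_signals` is a hypothetical helper, and gradient norms and model-update signals are left out because they need access to model internals.

```python
import numpy as np

def membership_signals(logits, labels):
    """Per-example signals often used by membership inference attacks.

    logits: array of shape (n, num_classes); labels: int array of shape (n,).
    """
    z = logits - logits.max(axis=1, keepdims=True)                   # stabilise exp()
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))     # log-softmax
    probs = np.exp(log_probs)
    rows = np.arange(len(labels))
    return {
        "confidence": probs.max(axis=1),                 # top-1 probability (higher for members)
        "loss": -log_probs[rows, labels],                # cross-entropy on the true label (lower for members)
        "entropy": -(probs * log_probs).sum(axis=1),     # prediction uncertainty
        "correct": probs.argmax(axis=1) == labels,       # correctness (members more often right)
    }

# Example: one confident, correct prediction and one uncertain, wrong one
logits = np.array([[8.0, 0.0, 0.0],
                   [0.5, 0.4, 0.3]])
labels = np.array([0, 2])
signals = membership_signals(logits, labels)
for name, values in signals.items():
    print(name, values if values.dtype == bool else np.round(values, 4))
```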

#### ⚡ Why "Adaptive" Matters
- Static attacks use fixed thresholds and strategies
- Adaptive attacks learn the specific vulnerabilities of each target model
- Can adjust to different model architectures, training procedures, and datasets
- More robust against defense mechanisms

#### 💡 Real-World Implications
- Privacy risk: Reveals if someone's data was used for training
- Particularly concerning for sensitive data (medical records, financial data)
- Can be used to audit model training practices
- Helps evaluate privacy-preserving techniques

The adaptive nature makes these attacks particularly powerful because they can evolve their strategy based on the target model's specific characteristics, making them harder to defend against with one-size-fits-all solutions.

## 🔍 Why Model Extraction Achieves Limited Fidelity (46%)

### Fundamental Limitations

#### 1. **Severe Data Scarcity**
- **Victim model**: Trained on 60,000 MNIST images with ground truth labels
- **Extraction attack**: Only 5,000 synthetic samples labeled by the victim (about 8.3% of the size of the original training set)
- **Information gap**: Each MNIST image contains 784 pixels of information, but we're trying to reconstruct the model's behavior from a tiny fraction of the input space

#### 2. **Distribution Mismatch**
Despite our efforts to create MNIST-like data, synthetic samples differ fundamentally from real handwritten digits:
- **Real MNIST**: Natural variations in stroke width, pressure, angle, and personal writing styles
- **Synthetic data**: Algorithmic approximations lacking the subtle patterns of human handwriting
- **Missing features**: The victim model learned features specific to real handwriting that our synthetic data doesn't capture

#### 3. **Information Bottleneck**
The extraction process faces a severe information bottleneck:

## Part 7: Privacy Attacks - Membership Inference

### 🔍 Was Your Data Used to Train This Model?

Membership inference attacks determine whether a specific data point was used in training. This is a serious privacy concern:

- **Medical models**: Was a patient's data used?
- **Financial models**: Was your transaction history included?
- **Face recognition**: Is your face in the training set?
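Attack quality is usually summarized with ROC AUC, as in the analysis plots below. When a lower score means "more likely a member" (as with loss), AUC equals the probability that a randomly chosen member scores lower than a randomly chosen non-member, counting ties as one half; this is the same quantity `sklearn.metrics.roc_auc_score` reports for negated scores. A small numpy sketch of that definition (the helper name is illustrative):

```python
import numpy as np

def membership_auc(member_scores, nonmember_scores):
    """ROC AUC for scores where LOWER means 'more likely a member' (e.g. loss).

    Computed from its probabilistic definition: the chance that a random member
    scores lower than a random non-member, with ties counting one half.
    """
    m = np.asarray(member_scores, dtype=float)[:, None]     # column of member scores
    n = np.asarray(nonmember_scores, dtype=float)[None, :]  # row of non-member scores
    return (m < n).mean() + 0.5 * (m == n).mean()           # averages over all pairs

print(membership_auc([0.1, 0.2], [0.3, 0.4]))  # 1.0: members always have lower loss
print(membership_auc([0.4, 0.1], [0.2, 0.3]))  # 0.5: no better than random guessing
```

The all-pairs comparison needs memory proportional to members × non-members, so for large evaluation sets a rank-based formula or `roc_auc_score` is the practical choice.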

In [None]:
# Visualize membership inference signals, with handling for low-variance data.
# Uses train_metrics, test_metrics, train_scores and test_scores from an earlier cell.

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.metrics import roc_curve, auc
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Patch

# 2x3 grid of panels; semi-transparent fills keep overlapping distributions visible
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

# 1. Loss distribution, with bins chosen to cover both groups
train_losses = train_metrics['loss'].cpu().numpy()
test_losses = test_metrics['loss'].cpu().numpy()

# Check if training losses have very low variance
train_loss_std = np.std(train_losses)
test_loss_std = np.std(test_losses)

print(f"Training loss mean: {train_losses.mean():.6f}, std: {train_loss_std:.6f}")
print(f"Test loss mean: {test_losses.mean():.6f}, std: {test_loss_std:.6f}")

# Create bins that properly capture both distributions
if train_loss_std < 1e-6:  # Very low variance in training data
    # Use separate handling for low-variance data
    loss_min = min(train_losses.min(), test_losses.min())
    loss_max = max(train_losses.max(), test_losses.max())

    # Add small buffer to ensure training data is visible
    buffer = (loss_max - loss_min) * 0.05
    loss_bins = np.linspace(loss_min - buffer, loss_max + buffer, 50)
else:
    # Standard binning for normal variance
    all_losses = np.concatenate([train_losses, test_losses])
    loss_bins = np.linspace(all_losses.min(), all_losses.max(), 30)

# Plot with better visibility
n_test, _, _ = axes[0].hist(test_losses, bins=loss_bins, alpha=0.6,
                           label=f'Non-members (mean={test_losses.mean():.3f})',
                           color='blue', edgecolor='darkblue', linewidth=1)
n_train, _, _ = axes[0].hist(train_losses, bins=loss_bins, alpha=0.6,
                            label=f'Training (mean={train_losses.mean():.3f})',
                            color='red', edgecolor='darkred', linewidth=1)

axes[0].set_xlabel('Loss Value')
axes[0].set_ylabel('Count')
axes[0].set_title('Loss Distribution (Lower = Training Member)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Add vertical lines for means
axes[0].axvline(train_losses.mean(), color='darkred', linestyle='--', alpha=0.7)
axes[0].axvline(test_losses.mean(), color='darkblue', linestyle='--', alpha=0.7)

# 2. Confidence distribution (side-by-side bars per bin)
train_conf = train_metrics['confidence'].cpu().numpy()
test_conf = test_metrics['confidence'].cpu().numpy()

# 20 equal-width bins over [0, 1]; a single edgecolor is used because
# per-dataset edgecolor lists are not supported by older matplotlib versions
conf_bins = np.linspace(0, 1, 21)
axes[1].hist([train_conf, test_conf], bins=conf_bins,
            label=[f'Training (mean={train_conf.mean():.3f})',
                   f'Non-members (mean={test_conf.mean():.3f})'],
            color=['red', 'blue'], alpha=0.6, edgecolor='black')

axes[1].set_xlabel('Confidence Score')
axes[1].set_ylabel('Count')
axes[1].set_title('Confidence Distribution')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# 3. Entropy distribution
train_entropy = train_metrics['entropy'].cpu().numpy()
test_entropy = test_metrics['entropy'].cpu().numpy()

# Check entropy variance
print(f"Training entropy std: {np.std(train_entropy):.6f}")
print(f"Test entropy std: {np.std(test_entropy):.6f}")

# Create bins based on actual data range
entropy_bins = np.linspace(
    min(train_entropy.min(), test_entropy.min()),
    max(train_entropy.max(), test_entropy.max()),
    30
)

# Single edgecolor for compatibility with older matplotlib versions
axes[2].hist([train_entropy, test_entropy], bins=entropy_bins,
            label=[f'Training (mean={train_entropy.mean():.3f})',
                   f'Non-members (mean={test_entropy.mean():.3f})'],
            color=['red', 'blue'], alpha=0.6, edgecolor='black')

axes[2].set_xlabel('Entropy')
axes[2].set_ylabel('Count')
axes[2].set_title('Entropy Distribution (Lower = More Certain)')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

# 4. Combined membership score, rescaled for plotting
train_scores_np = train_scores.cpu().numpy()
test_scores_np = test_scores.cpu().numpy()

# Fit one scaler on both groups so they share the same scale
all_scores = np.concatenate([train_scores_np, test_scores_np])

# RobustScaler centers on the median and divides by the interquartile range,
# so a few extreme scores do not squash the rest of the distribution
from sklearn.preprocessing import RobustScaler
scaler = RobustScaler()
all_scores_scaled = scaler.fit_transform(all_scores.reshape(-1, 1)).flatten()

train_scores_scaled = all_scores_scaled[:len(train_scores_np)]
test_scores_scaled = all_scores_scaled[len(train_scores_np):]

# Create histogram
axes[3].hist(test_scores_scaled, bins=30, alpha=0.5,
            label='Non-members', color='blue', edgecolor='darkblue', density=True)
axes[3].hist(train_scores_scaled, bins=30, alpha=0.5,
            label='Training Members', color='red', edgecolor='darkred', density=True)

# Add KDE only if there's sufficient variance
if np.std(train_scores_scaled) > 0.01 and np.std(test_scores_scaled) > 0.01:
    try:
        kde_train = stats.gaussian_kde(train_scores_scaled)
        kde_test = stats.gaussian_kde(test_scores_scaled)
        x_range = np.linspace(all_scores_scaled.min(), all_scores_scaled.max(), 200)
        axes[3].plot(x_range, kde_train(x_range), 'r-', linewidth=2, alpha=0.8)
        axes[3].plot(x_range, kde_test(x_range), 'b-', linewidth=2, alpha=0.8)
    except Exception:
        print("KDE failed (likely due to low variance)")

axes[3].set_xlabel('Membership Score (scaled)')
axes[3].set_ylabel('Density')
axes[3].set_title('Combined Membership Score (Lower = Member)')
axes[3].legend()
axes[3].grid(True, alpha=0.3)

# 5. ROC Curve (scores are negated because a lower score means "member")
true_labels = np.concatenate([np.ones(len(train_scores)), np.zeros(len(test_scores))])
all_scores = torch.cat([train_scores, test_scores]).cpu().numpy()

fpr, tpr, thresholds = roc_curve(true_labels, -all_scores)
roc_auc = auc(fpr, tpr)

axes[4].plot(fpr, tpr, color='darkred', lw=3, label=f'ROC curve (AUC = {roc_auc:.3f})')
axes[4].plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', label='Random guess (AUC = 0.5)')
axes[4].fill_between(fpr, tpr, alpha=0.2, color='red')
axes[4].set_xlim([0.0, 1.0])
axes[4].set_ylim([0.0, 1.05])
axes[4].set_xlabel('False Positive Rate')
axes[4].set_ylabel('True Positive Rate')
axes[4].set_title('ROC Curve')
axes[4].legend(loc="lower right")
axes[4].grid(True, alpha=0.3)

# 6. 2D Scatter Plot
# Switch to a log10 loss axis when non-member losses are >10x larger than member losses
use_log_scale = test_losses.mean() / (train_losses.mean() + 1e-8) > 10

if use_log_scale:
    # Add small epsilon to avoid log(0)
    train_losses_plot = np.log10(train_losses + 1e-10)
    test_losses_plot = np.log10(test_losses + 1e-10)
    xlabel = 'Loss (log10 scale)'
else:
    train_losses_plot = train_losses
    test_losses_plot = test_losses
    xlabel = 'Loss'

# Create scatter plot with alpha for overlap visibility
axes[5].scatter(test_losses_plot, test_conf, alpha=0.3, color='blue',
               s=20, label='Non-members', edgecolor='none')
axes[5].scatter(train_losses_plot, train_conf, alpha=0.3, color='red',
               s=20, label='Training', edgecolor='none')

axes[5].set_xlabel(xlabel)
axes[5].set_ylabel('Confidence')
axes[5].set_title('Loss vs Confidence Scatter')
axes[5].legend()
axes[5].grid(True, alpha=0.3)

plt.suptitle('Membership Inference Attack Analysis', fontsize=16)
plt.tight_layout()
plt.show()

# Additional visualization: Box plots for comparison
fig2, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Box plot for loss
box_data = [train_losses, test_losses]
bp1 = ax1.boxplot(box_data, labels=['Training', 'Non-members'],
                  patch_artist=True, showmeans=True)
bp1['boxes'][0].set_facecolor('red')
bp1['boxes'][0].set_alpha(0.7)
bp1['boxes'][1].set_facecolor('blue')
bp1['boxes'][1].set_alpha(0.7)
ax1.set_ylabel('Loss Value')
ax1.set_title('Loss Distribution Comparison')
ax1.grid(True, alpha=0.3, axis='y')

# Add log scale if needed
if use_log_scale:
    ax1.set_yscale('log')
    ax1.set_ylabel('Loss Value (log scale)')

# Box plot for confidence
bp2 = ax2.boxplot([train_conf, test_conf], labels=['Training', 'Non-members'],
                  patch_artist=True, showmeans=True)
bp2['boxes'][0].set_facecolor('red')
bp2['boxes'][0].set_alpha(0.7)
bp2['boxes'][1].set_facecolor('blue')
bp2['boxes'][1].set_alpha(0.7)
ax2.set_ylabel('Confidence Score')
ax2.set_title('Confidence Distribution Comparison')
ax2.grid(True, alpha=0.3, axis='y')

# Box plot for entropy
bp3 = ax3.boxplot([train_entropy, test_entropy], labels=['Training', 'Non-members'],
                  patch_artist=True, showmeans=True)
bp3['boxes'][0].set_facecolor('red')
bp3['boxes'][0].set_alpha(0.7)
bp3['boxes'][1].set_facecolor('blue')
bp3['boxes'][1].set_alpha(0.7)
ax3.set_ylabel('Entropy')
ax3.set_title('Entropy Distribution Comparison')
ax3.grid(True, alpha=0.3, axis='y')

plt.suptitle('Distribution Comparison: Box Plots', fontsize=16)
plt.tight_layout()
plt.show()

# Print statistics for debugging
print("\nData Statistics:")
print(f"Training samples: {len(train_scores)}")
print(f"Test samples: {len(test_scores)}")
print(f"\nTraining - Loss: mean={train_losses.mean():.6f}, std={train_losses.std():.6f}")
print(f"Test - Loss: mean={test_losses.mean():.6f}, std={test_losses.std():.6f}")
print(f"\nTraining - Confidence: mean={train_conf.mean():.3f}, std={train_conf.std():.3f}")
print(f"Test - Confidence: mean={test_conf.mean():.3f}, std={test_conf.std():.3f}")
print(f"\nTraining - Entropy: mean={train_entropy.mean():.3f}, std={train_entropy.std():.3f}")
print(f"Test - Entropy: mean={test_entropy.mean():.3f}, std={test_entropy.std():.3f}")
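The plots above show that training members tend to have lower loss than non-members. The simplest attack that exploits this gap uses a single loss threshold, in the spirit of the loss-threshold attack of Yeom et al. (2018): flag a sample as a member when its loss falls below the model's average training loss. The sketch below is illustrative only: `loss_threshold_attack` is a hypothetical helper, it assumes the attacker can estimate the average training loss, and it runs on synthetic exponential losses rather than on results from this lab.

```python
import numpy as np


def loss_threshold_attack(member_losses, nonmember_losses, threshold=None):
    """Flag a sample as a member when its loss is below `threshold`.

    Hypothetical helper. By default the threshold is the mean member (training)
    loss, which assumes the attacker can estimate the model's average training loss.
    """
    member_losses = np.asarray(member_losses, dtype=float)
    nonmember_losses = np.asarray(nonmember_losses, dtype=float)
    if threshold is None:
        threshold = float(member_losses.mean())

    tpr = float(np.mean(member_losses < threshold))     # members correctly flagged
    fpr = float(np.mean(nonmember_losses < threshold))  # non-members wrongly flagged
    # Balanced accuracy: 0.5 means the attack does no better than chance
    balanced_accuracy = 0.5 * (tpr + (1.0 - fpr))
    return {'threshold': threshold, 'tpr': tpr, 'fpr': fpr,
            'balanced_accuracy': balanced_accuracy}


# Synthetic stand-in data (not lab results): members have lower loss, as in an overfit model
rng = np.random.default_rng(0)
member_losses = rng.exponential(scale=0.05, size=1000)
nonmember_losses = rng.exponential(scale=0.30, size=1000)

result = loss_threshold_attack(member_losses, nonmember_losses)
print({k: round(v, 3) for k, v in result.items()})
```

With the arrays from the cell above, `loss_threshold_attack(train_losses, test_losses)` should give a single-threshold counterpart to the ROC curve, assuming `train_losses` and `test_losses` are NumPy arrays as they are used there. The ROC curve sweeps every threshold; this attack commits to one.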

## Part 8: Defense Mechanisms

### 🛡️ Protecting Neural Networks

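Now that we've seen various attacks, let's learn how to defend against them! The next cell focuses on membership inference: it trains the same MNIST classifier four ways (no defense, a simplified differential-privacy optimizer, FGSM-based adversarial regularization, and knowledge distillation), attacks each model, and compares attack AUC against test accuracy.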
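Each defense is scored by the attack's ROC AUC. For a "lower score = member" attack, AUC is the probability that a randomly chosen member receives a lower membership score than a randomly chosen non-member, with ties counting half. That is why 0.5 means chance and why values below 0.5 are possible. A brute-force sketch of that definition (`pairwise_auc` is a hypothetical helper, not part of the lab):

```python
import numpy as np


def pairwise_auc(member_scores, nonmember_scores):
    """AUC for a 'lower score = member' attack, computed over all pairs.

    Counts how often a member scores lower than a non-member across every
    (member, non-member) pair; ties count as half (Mann-Whitney form of ROC AUC).
    """
    m = np.asarray(member_scores, dtype=float)[:, None]     # shape (n_members, 1)
    n = np.asarray(nonmember_scores, dtype=float)[None, :]  # shape (1, n_nonmembers)
    wins = (m < n).sum() + 0.5 * (m == n).sum()
    return wins / (m.size * n.size)


print(pairwise_auc([0.1, 0.2, 0.3], [0.4, 0.5, 0.6]))  # 1.0: perfectly separated
print(pairwise_auc([0.1, 0.5], [0.3, 0.4]))            # 0.5: wins 2 of 4 pairs
```

On the same inputs it should agree with `roc_auc_score(labels, -scores)`, which is how `membership_inference_attack` computes AUC. Because it builds an n×m comparison matrix, it is meant for building intuition, not for large evaluations.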

In [None]:
"""
Membership Inference Attack Defense Mechanisms Demonstration
============================================================
This code demonstrates various defense mechanisms against membership inference attacks on MNIST.

Requirements:
- torch
- torchvision
- numpy
- pandas
- matplotlib
- seaborn
- scipy
- scikit-learn
- tqdm
- IPython (for rendering the Markdown summary)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from scipy import stats
from tqdm import tqdm
from torchvision import datasets, transforms
import pandas as pd
from IPython.display import Markdown, display
import seaborn as sns

# Set style for better-looking plots
plt.style.use('seaborn-v0_8-whitegrid')

# Define the MNIST Classifier model
class MNISTClassifier(nn.Module):
    def __init__(self, hidden_dim=128):
        super(MNISTClassifier, self).__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 64)
        self.fc3 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Membership Inference Attack function
def membership_inference_attack(model, train_loader, test_loader, device, num_samples=1000):
    """
    Perform membership inference attack on a model.
    Returns attack success metrics.
    """
    model.eval()

    def get_metrics(loader, is_member):
        losses = []
        confidences = []
        entropies = []
        correct_preds = []

        with torch.no_grad():
            for i, (data, target) in enumerate(loader):
                if i * loader.batch_size >= num_samples:
                    break

                data, target = data.to(device), target.to(device)
                output = model(data)

                # Calculate loss
                loss = F.cross_entropy(output, target, reduction='none')
                losses.extend(loss.cpu().numpy())

                # Calculate confidence (max probability)
                probs = F.softmax(output, dim=1)
                confidence, predicted = probs.max(1)
                confidences.extend(confidence.cpu().numpy())

                # Calculate entropy
                entropy = -(probs * torch.log(probs + 1e-8)).sum(1)
                entropies.extend(entropy.cpu().numpy())

                # Check if prediction is correct
                correct = predicted.eq(target)
                correct_preds.extend(correct.cpu().numpy())

        return {
            'loss': np.array(losses[:num_samples]),
            'confidence': np.array(confidences[:num_samples]),
            'entropy': np.array(entropies[:num_samples]),
            'correct': np.array(correct_preds[:num_samples]),
            'is_member': np.ones(min(len(losses), num_samples)) * is_member
        }

    # Get metrics for training and test data
    train_metrics = get_metrics(train_loader, is_member=1)
    test_metrics = get_metrics(test_loader, is_member=0)

    # Combine metrics
    all_losses = np.concatenate([train_metrics['loss'], test_metrics['loss']])
    all_confidences = np.concatenate([train_metrics['confidence'], test_metrics['confidence']])
    all_entropies = np.concatenate([train_metrics['entropy'], test_metrics['entropy']])
    all_labels = np.concatenate([train_metrics['is_member'], test_metrics['is_member']])

    # Calculate combined membership score (lower = more likely to be member)
    # Normalize features
    loss_norm = (all_losses - all_losses.mean()) / (all_losses.std() + 1e-8)
    conf_norm = (all_confidences - all_confidences.mean()) / (all_confidences.std() + 1e-8)
    entropy_norm = (all_entropies - all_entropies.mean()) / (all_entropies.std() + 1e-8)

    # Combined score: members have low loss, high confidence, low entropy
    membership_scores = loss_norm - conf_norm + entropy_norm

    # Calculate AUC
    auc_score = roc_auc_score(all_labels, -membership_scores)

    return {
        'auc': auc_score,
        'train_metrics': train_metrics,
        'test_metrics': test_metrics,
        'membership_scores': membership_scores,
        'labels': all_labels
    }

# Defense Mechanisms

# 1. Differential Privacy Training
class DPSGDOptimizer:
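    """Simplified, DP-SGD-inspired optimizer.

    Caveat: real DP-SGD clips each example's gradient before averaging and divides
    the noise by the batch size. This version clips the already-averaged batch
    gradient and adds noise with std noise_multiplier * max_grad_norm directly to it,
    so it gives no formal (epsilon, delta) guarantee and injects roughly batch_size
    times more noise per step than DP-SGD with the same settings.
    """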
    def __init__(self, model, lr=0.01, noise_multiplier=1.0, max_grad_norm=1.0):
        self.model = model
        self.lr = lr
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.optimizer = optim.SGD(model.parameters(), lr=lr)

    def step(self, loss):
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # Add noise to gradients
        with torch.no_grad():
            for param in self.model.parameters():
                if param.grad is not None:
                    noise = torch.normal(
                        mean=0,
                        std=self.noise_multiplier * self.max_grad_norm,
                        size=param.grad.shape,
                        device=param.grad.device
                    )
                    param.grad += noise

        self.optimizer.step()

def train_with_dp(model, train_loader, epochs=5, noise_multiplier=1.0):
    """Train model with differential privacy"""
    dp_optimizer = DPSGDOptimizer(model, lr=0.01, noise_multiplier=noise_multiplier)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            dp_optimizer.step(loss)
            total_loss += loss.item()

        if epoch % 2 == 0:
            print(f"DP Training - Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")

# 2. Adversarial Regularization
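def adversarial_regularization_loss(model, data, target, epsilon=0.1, adv_weight=0.5):
    """FGSM-based adversarial training loss: clean loss + adv_weight * loss on FGSM examples."""
    # Work on a detached copy so the caller's batch tensor is not modified
    data = data.clone().detach().requires_grad_(True)
    output = model(data)
    loss = F.cross_entropy(output, target)

    # Gradient w.r.t. the input only. torch.autograd.grad does not write into
    # param.grad, so the clean-loss gradient is not counted twice when the
    # caller later calls backward() on the combined loss.
    data_grad, = torch.autograd.grad(loss, data, retain_graph=True)

    # FGSM step, clamped to the valid [0, 1] pixel range (inputs are ToTensor()-scaled)
    adv_data = torch.clamp(data.detach() + epsilon * data_grad.sign(), 0, 1)

    # Loss on the adversarial examples
    adv_output = model(adv_data)
    adv_loss = F.cross_entropy(adv_output, target)

    # Combined loss (the caller backpropagates it)
    return loss + adv_weight * adv_loss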

def train_with_adversarial(model, train_loader, epochs=5):
    """Train model with adversarial regularization"""
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            loss = adversarial_regularization_loss(model, data, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if epoch % 2 == 0:
            print(f"Adversarial Training - Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")

# 3. Knowledge Distillation
def train_with_distillation(student_model, teacher_model, train_loader, epochs=5, temperature=3.0):
    """Train student model using knowledge distillation from teacher"""
    optimizer = optim.Adam(student_model.parameters(), lr=0.001)
    teacher_model.eval()

    student_model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Get teacher's soft predictions
            with torch.no_grad():
                teacher_output = teacher_model(data)
                soft_targets = F.softmax(teacher_output / temperature, dim=1)

            # Student predictions
            student_output = student_model(data)

            # Distillation loss
            soft_loss = F.kl_div(
                F.log_softmax(student_output / temperature, dim=1),
                soft_targets,
                reduction='batchmean'
            ) * (temperature ** 2)

            # Standard loss
            hard_loss = F.cross_entropy(student_output, target)

            # Combined loss
            loss = 0.7 * soft_loss + 0.3 * hard_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if epoch % 2 == 0:
            print(f"Distillation Training - Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")

# Standard training function
def train_standard(model, train_loader, epochs=5):
    """Standard training without defenses"""
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if epoch % 2 == 0:
            print(f"Standard Training - Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")

# Helper function for calculating model accuracy
def calculate_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    return correct / total

# Helper function for calculating distribution overlap
def calculate_distribution_overlap(dist1, dist2):
    """Calculate overlap between two distributions"""
    # Create histograms
    bins = np.linspace(min(dist1.min(), dist2.min()),
                      max(dist1.max(), dist2.max()), 50)
    hist1, _ = np.histogram(dist1, bins=bins, density=True)
    hist2, _ = np.histogram(dist2, bins=bins, density=True)

    # Calculate overlap
    overlap = np.minimum(hist1, hist2).sum() * (bins[1] - bins[0])
    return overlap

# Load MNIST dataset
print("Loading MNIST dataset...")
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create smaller subsets for faster demonstration
train_subset = Subset(train_dataset, range(5000))
test_subset = Subset(test_dataset, range(1000))

train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)

print("\n" + "="*60)
print("MEMBERSHIP INFERENCE ATTACK DEFENSE DEMONSTRATION")
print("="*60)

# Create models for testing
print("\n📊 Creating test models...")
standard_model = MNISTClassifier().to(device)
dp_model = MNISTClassifier().to(device)
adv_model = MNISTClassifier().to(device)
teacher_model = MNISTClassifier().to(device)
distilled_model = MNISTClassifier().to(device)

# Train models with different defenses
print("\n🔧 Training Phase:")
print("-"*60)
print("\n1. Training standard model (no defense)...")
train_standard(standard_model, train_loader, epochs=10)

print("\n2. Training with Differential Privacy...")
print("   Note: A high or rising training loss is expected here; the injected noise limits how closely the model can fit (and memorize) the training set")
train_with_dp(dp_model, train_loader, epochs=10, noise_multiplier=1.0)

print("\n3. Training with Adversarial Regularization...")
print("   Note: Using epsilon=0.1 (may be too small)")
train_with_adversarial(adv_model, train_loader, epochs=10)

print("\n4. Training teacher model for distillation...")
train_standard(teacher_model, train_loader, epochs=10)
print("   Training student with Knowledge Distillation...")
print("   Note: Temperature=3.0, no privacy constraints on teacher")
train_with_distillation(distilled_model, teacher_model, train_loader, epochs=10)

# Evaluate membership inference attack on each model
print("\n" + "="*60)
print("🎯 ATTACK PHASE: Evaluating Membership Inference Attacks")
print("="*60)

results = {}
model_objects = {
    'standard': standard_model,
    'dp': dp_model,
    'adversarial': adv_model,
    'distilled': distilled_model
}

# Attack all models
print("\nPerforming attacks on all models...")
for model_name in ['standard', 'dp', 'adversarial', 'distilled']:
    print(f"  Attacking {model_name} model...", end='')
    results[model_name] = membership_inference_attack(
        model_objects[model_name], train_loader, test_loader, device
    )
    print(f" AUC: {results[model_name]['auc']:.3f}")

# Calculate accuracies
print("\nCalculating model accuracies...")
accuracies = {}
for model_name in ['standard', 'dp', 'adversarial', 'distilled']:
    accuracies[model_name] = calculate_accuracy(model_objects[model_name], test_loader)

# Display detailed analysis
print("\n" + "="*60)
print("📊 RESULTS ANALYSIS")
print("="*60)

# Create results summary
defense_names = {
    'standard': 'Standard (No Defense)',
    'dp': 'Differential Privacy',
    'adversarial': 'Adversarial Regularization',
    'distilled': 'Knowledge Distillation'
}

# Print results table
print("\n📋 Summary of Results:")
print("-"*70)
print(f"{'Defense Method':<30} {'AUC ↓':<10} {'Accuracy':<10} {'Privacy':<10} {'Status':<15}")
print("-"*70)

for model_name in ['standard', 'dp', 'adversarial', 'distilled']:
    auc = results[model_name]['auc']
    acc = accuracies[model_name]
    privacy = 1 - auc

    # Determine status
    if model_name == 'standard':
        status = "Baseline"
    elif auc < results['standard']['auc'] - 0.05:
        status = "✅ Effective"
    elif auc > results['standard']['auc'] + 0.005:
        status = "❌ Failed"
    else:
        status = "⚠️  Marginal"

    print(f"{defense_names[model_name]:<30} {auc:<10.3f} {acc:<10.3f} {privacy:<10.3f} {status:<15}")

print("-"*70)

# Display explanation: the numbers come from this run, while the verdicts and diagnoses
# describe the reference run this lab was written against and may differ from yours
actual_auc_scores = {name: results[name]['auc'] for name in results}

defense_explanation = f"""
## 🔍 What These Results Mean:

### Attack Success (AUC Scores):
- **0.5** = Random guessing (the attack learns nothing about membership)
- **1.0** = Perfect attack (membership fully revealed)
- **Your baseline**: {actual_auc_scores['standard']:.3f}

### Performance Analysis:

*Numbers below come from your run. Verdicts describe the reference run of this lab; check them against the Status column above.*

**1. Differential Privacy** (AUC: {actual_auc_scores['dp']:.3f})
   - ✅ **The only defense that clearly lowered AUC in the reference run**
   - Lowered the attack AUC by {((actual_auc_scores['standard'] - actual_auc_scores['dp']) / actual_auc_scores['standard'] * 100):.1f}% relative to the baseline
   - Cost: Accuracy dropped from {accuracies['standard']:.1%} to {accuracies['dp']:.1%}
   - A high or rising training loss is consistent with the noise limiting how closely the model fits the training set

**2. Adversarial Regularization** (AUC: {actual_auc_scores['adversarial']:.3f})
   - ❌ **No privacy gain in the reference run; AUC rose slightly**
   - Attack AUC changed by {((actual_auc_scores['adversarial'] - actual_auc_scores['standard']) / actual_auc_scores['standard'] * 100):+.1f}% relative to the baseline (positive = more vulnerable)
   - This is FGSM adversarial training for robustness, not the membership-inference "adversarial regularization" of Nasr et al. (2018)
   - Still achieved {accuracies['adversarial']:.1%} accuracy

**3. Knowledge Distillation** (AUC: {actual_auc_scores['distilled']:.3f})
   - ❌ **No privacy benefit in the reference run**
   - Teacher's memorization was passed to student
   - Problem: Teacher was trained without privacy constraints, on the same data as the student
   - Maintained {accuracies['distilled']:.1%} accuracy but no privacy gain

### Why Did Some Defenses Fail?

**Adversarial Regularization**: FGSM training (epsilon=0.1) makes the model robust to small input changes, but nothing in its objective shrinks the train/test confidence gap that the attack exploits, so the model can still memorize its training data. A larger epsilon is not guaranteed to help: Song et al. (2019) found that adversarially robust models can be *more* vulnerable to membership inference.

**Knowledge Distillation**: The teacher fit its training data closely (final training loss of about 0.16 in the reference run), and the student, trained on the same samples, inherits that fit through the soft labels.

### Recommendations to Fix Failed Defenses:

1. **For Adversarial Training**: Try a larger epsilon (0.3-0.5) and adversarial weight (0.8), then re-measure AUC, since stronger robustness training can increase leakage
2. **For Distillation**: Use temperature=10+, add noise to soft labels, or train teacher with DP
3. **For Production**: Use DP if privacy is critical, despite accuracy cost
"""

display(Markdown(defense_explanation))

# Create improved visualizations
print("\n📊 Generating visualizations...")

# Set up the figure with better styling
fig = plt.figure(figsize=(16, 10))
fig.patch.set_facecolor('white')

# Define consistent colors
colors = {
    'standard': '#E74C3C',
    'dp': '#3498DB',
    'adversarial': '#2ECC71',
    'distilled': '#9B59B6'
}

# 1. AUC Comparison with annotations
ax1 = plt.subplot(2, 3, 1)
models = ['standard', 'dp', 'adversarial', 'distilled']
auc_scores = [results[m]['auc'] for m in models]
bars = ax1.bar(range(len(models)), auc_scores,
                color=[colors[m] for m in models],
                alpha=0.7, edgecolor='black', linewidth=2)

# Add value labels and status
for i, (bar, score, model) in enumerate(zip(bars, auc_scores, models)):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
             f'{score:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Add status indicator (same thresholds as the results table)
    if model == 'standard':
        status = ''
    elif score < results['standard']['auc'] - 0.05:
        status = '✅'
    elif score > results['standard']['auc'] + 0.005:
        status = '❌'
    else:
        status = '⚠️'

    if status:
        ax1.text(bar.get_x() + bar.get_width()/2., height/2,
                status, ha='center', va='center', fontsize=20)

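# Default window is 0.4-0.65; widen it if any AUC would otherwise be cut off
ax1.set_ylim([min(0.4, min(auc_scores) - 0.05), max(0.65, max(auc_scores) + 0.05)])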
ax1.set_xticks(range(len(models)))
ax1.set_xticklabels([defense_names[m].replace(' ', '\n') for m in models], fontsize=10)
ax1.set_ylabel('AUC Score', fontsize=12)
ax1.set_title('Attack Success Rate\n(Lower = Better Privacy)', fontsize=14, fontweight='bold')
ax1.axhline(y=0.5, color='gray', linestyle='--', label='Random Guess', alpha=0.5)
ax1.axhline(y=results['standard']['auc'], color='red', linestyle=':',
            label=f'Baseline ({results["standard"]["auc"]:.3f})', alpha=0.5)
ax1.legend()

# 2. Privacy-Utility Tradeoff
ax2 = plt.subplot(2, 3, 2)
privacy_scores = [1 - results[m]['auc'] for m in models]

for i, model in enumerate(models):
    ax2.scatter(accuracies[model], privacy_scores[i],
               s=300, c=colors[model], alpha=0.8,
               edgecolors='black', linewidth=2)

    # Add labels with better positioning
    if model == 'dp':
        offset = (-40, -10)
    else:
        offset = (10, 5)

    ax2.annotate(defense_names[model],
                (accuracies[model], privacy_scores[i]),
                xytext=offset, textcoords='offset points',
                fontsize=9, ha='left',
                bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))

ax2.set_xlabel('Model Accuracy', fontsize=12)
ax2.set_ylabel('Privacy Score (1 - AUC)', fontsize=12)
ax2.set_title('Privacy-Utility Tradeoff', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)
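# Default window tuned for the reference run; widen it if any model falls outside
ax2.set_xlim([min(0.15, min(accuracies.values()) - 0.05), 1.0])
ax2.set_ylim([min(0.38, min(privacy_scores) - 0.02), max(0.52, max(privacy_scores) + 0.02)])

# Ideal region: high accuracy and a privacy score near 0.5 (attack AUC near random guessing)
ax2.add_patch(plt.Rectangle((0.8, 0.48), 0.2, 0.04,
                           fill=True, alpha=0.2, color='green',
                           label='Ideal Region'))
ax2.legend(loc='lower left', fontsize=9)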

# 3. Feature Separation Heatmap: |train mean - test mean| of each attack signal
ax3 = plt.subplot(2, 3, 3)
feature_importance = []
for model in models:
    train_m = results[model]['train_metrics']
    test_m = results[model]['test_metrics']

    loss_sep = abs(train_m['loss'].mean() - test_m['loss'].mean())
    conf_sep = abs(train_m['confidence'].mean() - test_m['confidence'].mean())
    entropy_sep = abs(train_m['entropy'].mean() - test_m['entropy'].mean())

    feature_importance.append([loss_sep, conf_sep, entropy_sep])

feature_importance = np.array(feature_importance).T
im = ax3.imshow(feature_importance, cmap='Reds', aspect='auto')
ax3.set_xticks(range(len(models)))
ax3.set_xticklabels([defense_names[m].replace(' ', '\n') for m in models],
                   rotation=0, ha='center')
ax3.set_yticks(range(3))
ax3.set_yticklabels(['Loss', 'Confidence', 'Entropy'])
ax3.set_title('Feature Separation\n(Darker = More Vulnerable)', fontsize=14, fontweight='bold')

# Add values in cells
for i in range(3):
    for j in range(len(models)):
        text_color = 'white' if feature_importance[i, j] > feature_importance.max()/2 else 'black'
        ax3.text(j, i, f'{feature_importance[i, j]:.3f}',
                ha='center', va='center', color=text_color, fontsize=10)

plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)

# 4. Loss Distribution Comparison
ax4 = plt.subplot(2, 3, 4)
positions = []
loss_data = []
colors_list = []

for i, model in enumerate(models):
    train_loss = results[model]['train_metrics']['loss']
    test_loss = results[model]['test_metrics']['loss']

    positions.extend([i*3, i*3+1])
    loss_data.extend([train_loss, test_loss])
    colors_list.extend([colors[model], colors[model]])

bp = ax4.boxplot(loss_data, positions=positions, widths=0.8, patch_artist=True,
                showmeans=True, meanprops=dict(marker='D', markerfacecolor='white'))

for i, (patch, color) in enumerate(zip(bp['boxes'], colors_list)):
    patch.set_facecolor(color)
    patch.set_alpha(0.8 if i % 2 == 0 else 0.5)

# Add legend
for i, model in enumerate(models):
    ax4.plot([], [], color=colors[model], linewidth=10, alpha=0.8,
            label=defense_names[model])

ax4.set_xticks([i*3+0.5 for i in range(len(models))])
ax4.set_xticklabels(['T|Te'] * len(models))
ax4.set_ylabel('Loss Value', fontsize=12)
ax4.set_title('Loss Distributions\n(T=Train, Te=Test)', fontsize=14, fontweight='bold')
ax4.legend(loc='upper right', fontsize=9)
ax4.set_yscale('log')

# 5. Attack Score Distributions
ax5 = plt.subplot(2, 3, 5)
from scipy import stats

for model in models:
    scores = results[model]['membership_scores']
    labels = results[model]['labels']

    member_scores = scores[labels == 1]
    non_member_scores = scores[labels == 0]

    # Calculate overlap
    overlap = calculate_distribution_overlap(member_scores, non_member_scores)

    # Plot KDE
    if len(np.unique(member_scores)) > 1:
        x_range = np.linspace(scores.min(), scores.max(), 200)

        kde_members = stats.gaussian_kde(member_scores)
        kde_non_members = stats.gaussian_kde(non_member_scores)

        ax5.plot(x_range, kde_members(x_range),
                color=colors[model], linestyle='-', linewidth=2,
                label=f'{defense_names[model]} (overlap={overlap:.2f})', alpha=0.8)
        ax5.plot(x_range, kde_non_members(x_range),
                color=colors[model], linestyle='--', linewidth=2, alpha=0.6)

ax5.set_xlabel('Membership Score', fontsize=12)
ax5.set_ylabel('Density', fontsize=12)
ax5.set_title('Attack Score Distributions\n(Solid=Members, Dashed=Non-members)',
             fontsize=14, fontweight='bold')
ax5.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

# 6. Summary Table
ax6 = plt.subplot(2, 3, 6)
ax6.axis('tight')
ax6.axis('off')

# Create summary data
summary_data = []
for model in models:
    auc = results[model]['auc']
    acc = accuracies[model]
    privacy = 1 - auc

    # Privacy improvement
    if model == 'standard':
        improvement = 'Baseline'
    else:
        imp_val = (results['standard']['auc'] - auc) / results['standard']['auc'] * 100
        if imp_val > 5:
            improvement = f'↑ {imp_val:.1f}%'
        elif imp_val < -1:
            improvement = f'↓ {abs(imp_val):.1f}%'
        else:
            improvement = '≈ Same'

    summary_data.append([
        defense_names[model],
        f'{auc:.3f}',
        f'{acc:.1%}',
        improvement
    ])

table = ax6.table(cellText=summary_data,
                 colLabels=['Defense', 'AUC↓', 'Accuracy', 'Privacy vs Baseline'],
                 cellLoc='center',
                 loc='center',
                 colWidths=[0.35, 0.15, 0.15, 0.25])

table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2.5)

# Color code the table
for i in range(1, len(summary_data) + 1):
    # Color defense name cells
    table[(i, 0)].set_facecolor(colors[models[i-1]])
    table[(i, 0)].set_alpha(0.3)

    # Color AUC cells based on performance
    auc_val = float(summary_data[i-1][1])
    if auc_val < 0.55:
        table[(i, 1)].set_facecolor('lightgreen')
    elif auc_val > results['standard']['auc']:
        table[(i, 1)].set_facecolor('lightcoral')

ax6.set_title('Performance Summary', fontsize=14, fontweight='bold', pad=20)

plt.suptitle('Membership Inference Attack Defense Analysis', fontsize=18, fontweight='bold')
plt.tight_layout()
plt.show()

# Print final insights
print("\n" + "="*60)
print("💡 KEY INSIGHTS")
print("="*60)

print(f"""
1. **Differential Privacy gave the clearest protection in the reference run**
   - Lowered the attack AUC by {((results['standard']['auc'] - results['dp']['auc']) / results['standard']['auc'] * 100):.1f}% relative to the baseline
   - But accuracy dropped by {(accuracies['standard'] - accuracies['dp']) * 100:.0f} percentage points

2. **FGSM-based Adversarial Regularization did not improve privacy**
   - Attack AUC changed by {((results['adversarial']['auc'] - results['standard']['auc']) / results['standard']['auc'] * 100):+.1f}% relative to the baseline (positive = more vulnerable)
   - Robustness training is not a privacy defense; a larger epsilon (0.3-0.5) is worth testing but may not help

3. **Knowledge Distillation provided no clear benefit**
   - Teacher's memorization was transferred to student
   - Need privacy-aware teacher training or higher temperature

4. **The fundamental tradeoff**
   - In this experiment, high accuracy ({accuracies['standard']:.0%}+) came with a train/test gap that the attack exploits
   - Defenses that shrink that gap (such as DP) tend to cost accuracy
""")

print("\n🔧 Recommended Fixes:")
print("-"*60)
print("""
For Adversarial Regularization:
- Increase epsilon: 0.1 → 0.3-0.5
- Increase adversarial weight: 0.5 → 0.8
- Consider PGD instead of FGSM
- Re-measure membership AUC after each change: stronger robustness training can increase leakage

For Knowledge Distillation:
- Increase temperature: 3 → 10+
- Add noise to soft labels
- Train teacher with privacy constraints
- Use ensemble of teachers

For Production Use:
- If privacy is critical: Use DP despite accuracy cost
- If accuracy is critical: Combine multiple defenses
- Monitor: Regularly test with membership inference attacks
""")

print("\n✅ Defense demonstration complete!")
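`DPSGDOptimizer` above clips and noises the averaged batch gradient. Standard DP-SGD (Abadi et al., 2016) instead clips every example's gradient to norm C, sums the clipped gradients, adds Gaussian noise with standard deviation σ·C to the sum, and divides by the batch size. A minimal sketch of one such step using `torch.func` (PyTorch 2.x) is below. `dp_sgd_step` is a hypothetical helper, not part of the lab, and it performs no privacy accounting, so it does not report an (ε, δ) budget.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap


def dp_sgd_step(model, data, target, lr=0.05, max_grad_norm=1.0, noise_multiplier=1.0):
    """One DP-SGD update: clip each example's gradient, sum, add Gaussian noise, average.

    Hypothetical helper for illustration; no privacy accounting is done here.
    Returns the per-example gradient norms (before clipping) for inspection.
    """
    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    def example_loss(p, b, x, y):
        # Run the model on a single example as a batch of one
        out = functional_call(model, (p, b), (x.unsqueeze(0),))
        return F.cross_entropy(out, y.unsqueeze(0))

    # Per-example gradients: every tensor gets a leading batch dimension.
    # randomness='different' lets dropout layers draw independent masks per example.
    per_grads = vmap(grad(example_loss), in_dims=(None, None, 0, 0),
                     randomness='different')(params, buffers, data, target)

    batch_size = data.shape[0]
    # L2 norm of each example's full gradient (all parameter tensors together)
    flat = torch.cat([g.reshape(batch_size, -1) for g in per_grads.values()], dim=1)
    norms = flat.norm(dim=1)
    # Scale factor <= 1: only gradients longer than max_grad_norm are shrunk
    scale = (max_grad_norm / (norms + 1e-6)).clamp(max=1.0)

    with torch.no_grad():
        for name, param in model.named_parameters():
            g = per_grads[name]
            clipped_sum = (g * scale.view(-1, *([1] * (g.dim() - 1)))).sum(dim=0)
            # Noise std is noise_multiplier * max_grad_norm on the sum, then averaged
            noise = torch.randn_like(clipped_sum) * noise_multiplier * max_grad_norm
            param -= lr * (clipped_sum + noise) / batch_size
    return norms


# Tiny usage example on random data
torch.manual_seed(0)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(4, 3))
x = torch.randn(8, 2, 2) * 10
y = torch.randint(0, 3, (8,))
print(dp_sgd_step(demo_model, x, y, lr=0.1, max_grad_norm=0.5, noise_multiplier=1.0))
```

To try it in this lab, the line `dp_optimizer.step(loss)` inside `train_with_dp` could be replaced by `dp_sgd_step(model, data, target)`. This sketch has not been checked against the lab's `MNISTClassifier`; if its `x.view(-1, 784)` causes trouble under `vmap`, `x.reshape(-1, 784)` is a drop-in alternative. For real deployments, a library such as Opacus implements per-sample clipping, noise, and privacy accounting for PyTorch.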

## Part 9: Advanced Security Topics

### 🚀 Cutting-Edge Neural Network Security (2024-2025)

Let's explore some of the latest developments in neural network security!
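The demonstration below normalizes MNIST with `transforms.Normalize((0.1307,), (0.3081,))`, so a valid image no longer spans [0, 1]: a black pixel becomes about -0.42 and a white pixel about 2.82. Clamping an attacked image to [0, 1] in that space would cut off valid dark pixels, and an epsilon stated in raw pixel units has to be divided by the std. A minimal FGSM sketch under those assumptions (`fgsm_normalized` and the constant names are hypothetical; the mean and std are taken from the next cell):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Normalization constants used by the next cell's transforms.Normalize
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
# Raw pixel values 0 and 1 expressed in normalized units
PIXEL_MIN = (0.0 - MNIST_MEAN) / MNIST_STD
PIXEL_MAX = (1.0 - MNIST_MEAN) / MNIST_STD


def fgsm_normalized(model, images, labels, epsilon=0.1):
    """FGSM on normalized inputs, with `epsilon` measured in raw [0, 1] pixel units.

    Hypothetical helper for illustration.
    """
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    # Gradient w.r.t. the input only (model parameters' .grad is untouched)
    input_grad, = torch.autograd.grad(loss, images)
    # A step of epsilon in raw pixel space is epsilon / std in normalized space
    adv = images + (epsilon / MNIST_STD) * input_grad.sign()
    # Keep every pixel inside the valid (normalized) image range
    return adv.clamp(PIXEL_MIN, PIXEL_MAX).detach()


# Usage on random "images" with a toy linear classifier
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
images = (torch.rand(4, 1, 28, 28) - MNIST_MEAN) / MNIST_STD
labels = torch.tensor([0, 1, 2, 3])
adv = fgsm_normalized(toy_model, images, labels, epsilon=0.1)
print(((adv - images).abs() * MNIST_STD).max().item())
```

With the model trained in the next cell, the same call would look like `fgsm_normalized(mnist_model, images, labels, epsilon=0.1)` on a batch from `testloader`, with `mnist_model.eval()` set first so dropout does not affect the gradient.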

In [None]:
# Complete Adversarial Attack Demonstration

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define the MNIST Classifier model
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)  # fc1 output is (N, 128), so use standard (not 2D) dropout
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Train a basic MNIST model
print("Training MNIST model...")
mnist_model = MNISTClassifier().to(device)
optimizer = optim.Adam(mnist_model.parameters(), lr=0.001)

mnist_model.train()
for epoch in range(2):  # 2 short passes, each capped at ~100 batches below (fast demo)
    running_loss = 0.0
    for i, (images, labels) in enumerate(trainloader):
        if i > 100:  # Limit iterations for speed
            break

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = mnist_model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i % 50 == 49:
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 50:.3f}')
            running_loss = 0.0

# Test the model
mnist_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for i, (images, labels) in enumerate(testloader):
        if i > 20:  # Test on subset
            break
        images, labels = images.to(device), labels.to(device)
        outputs = mnist_model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')

# Create a backdoor model for demonstration
print("\nCreating backdoor model...")
backdoor_model = MNISTClassifier().to(device)

# Copy weights from clean model
backdoor_model.load_state_dict(mnist_model.state_dict())

# Fine-tune with backdoor pattern (simplified)
backdoor_optimizer = optim.Adam(backdoor_model.parameters(), lr=0.001)
backdoor_model.train()

for i, (images, labels) in enumerate(trainloader):
    if i > 50:  # Limited training
        break

    images, labels = images.to(device), labels.to(device)

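    # Poison ~20% of the samples in each batch
    mask = torch.rand(len(images), device=images.device) < 0.2
    if mask.any():
        # Stamp a 4x4 trigger in the bottom-right corner. Note: 1.0 is in *normalized*
        # space (raw pixel ~0.44), so this is a mid-gray square, not pure white.
        images[mask, :, -4:, -4:] = 1.0
        # Relabel poisoned samples to the attacker's target class (here, 0)
        labels[mask] = 0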

    backdoor_optimizer.zero_grad()
    outputs = backdoor_model(images)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()
    backdoor_optimizer.step()

print("Backdoor model created!")

print("\n" + "="*50)
print("Starting Advanced Attack Demonstrations")
print("="*50 + "\n")

# 9.1 Universal Adversarial Perturbations
class UniversalPerturbation:
    """
    Craft ONE perturbation that pushes many different inputs toward a target class.
    Because it does not depend on the input, an attacker can compute it once and
    reuse it on new images without querying the model again.

    This is a simplified, targeted, gradient-sign variant; the original UAP
    algorithm (Moosavi-Dezfooli et al., 2017) uses DeepFool for the per-image steps.
    """

    def __init__(self, model, epsilon=0.1, step_size=None):
        self.model = model
        self.epsilon = epsilon  # L-inf budget, measured in normalized input space
        # Steps smaller than epsilon avoid flipping pixels back and forth between +eps and -eps
        self.step_size = step_size if step_size is not None else epsilon / 4

    def generate_universal_perturbation(self, data_loader, target_class=0,
                                        max_iterations=10, max_batches=20):
        """
        Grow a universal perturbation until >= 80% of non-target images are
        classified as `target_class`, or until `max_iterations` passes are done.
        """
        self.model.eval()
        universal_pert = torch.zeros(1, 1, 28, 28, device=device)
        target = torch.tensor([target_class], device=device)

        print(f"🌐 Generating universal perturbation (target class: {target_class})...")

        fooling_rate = 0.0
        iterations = 0

        while fooling_rate < 0.8 and iterations < max_iterations:  # 80% target
            for b, (images, labels) in enumerate(data_loader):
                if b >= max_batches:  # Use a few batches of 10 images each for speed
                    break
                images, labels = images[:10].to(device), labels[:10].to(device)

                with torch.no_grad():
                    predictions = self.model(images + universal_pert).argmax(dim=1)

                for i in range(len(images)):
                    # Skip images that are already fooled or genuinely belong to the target class
                    if predictions[i] == target_class or labels[i] == target_class:
                        continue
                    image = images[i:i+1]

                    for _ in range(10):
                        # Differentiate w.r.t. a fresh copy of the perturbation so
                        # gradients never accumulate across steps
                        pert = universal_pert.clone().requires_grad_(True)
                        loss = F.cross_entropy(self.model(image + pert), target)
                        grad, = torch.autograd.grad(loss, pert)

                        # Descend the target-class loss, then project onto the epsilon ball
                        universal_pert = torch.clamp(
                            universal_pert - self.step_size * grad.sign(),
                            -self.epsilon, self.epsilon)

                        # Stop early once this image is pushed into the target class
                        with torch.no_grad():
                            if self.model(image + universal_pert).argmax(dim=1).item() == target_class:
                                break

            # Fooling rate on roughly the first 100+ non-target images
            total_fooled = 0
            total_tested = 0
            with torch.no_grad():
                for images, labels in data_loader:
                    images, labels = images.to(device), labels.to(device)
                    predictions = self.model(images + universal_pert).argmax(dim=1)

                    mask = labels != target_class  # Don't count images already in the target class
                    total_fooled += ((predictions == target_class) & mask).sum().item()
                    total_tested += mask.sum().item()

                    if total_tested > 100:  # Test on subset
                        break

            fooling_rate = total_fooled / total_tested if total_tested > 0 else 0.0
            iterations += 1

            print(f"  Iteration {iterations}: Fooling rate = {fooling_rate:.1%}")

        return universal_pert.detach()

# Demonstrate advanced attacks
print("🔬 Advanced Attack Demonstrations\n")

# 1. Universal Adversarial Perturbation
print("1️⃣ Universal Adversarial Perturbation")
universal_attacker = UniversalPerturbation(mnist_model, epsilon=0.3)
universal_pert = universal_attacker.generate_universal_perturbation(testloader, target_class=3)

# Visualize universal perturbation
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Apply the same perturbation to the first image of each of the first 5 test batches
with torch.no_grad():
    for i, (image, label) in enumerate(testloader):
        if i >= 5:
            break

        image, label = image.to(device), label.to(device)

        # Clean image and the model's prediction on it
        orig_pred = mnist_model(image[0:1]).argmax().item()
        axes[0, i].imshow(image[0].cpu().squeeze(), cmap='gray')
        axes[0, i].set_title(f'Clean pred: {orig_pred} (true: {label[0].item()})')
        axes[0, i].axis('off')

        # Same image with the universal perturbation added
        perturbed = image[0:1] + universal_pert
        pert_pred = mnist_model(perturbed).argmax().item()
        axes[1, i].imshow(perturbed.cpu().squeeze(), cmap='gray')
        axes[1, i].set_title(f'Perturbed pred: {pert_pred}')
        axes[1, i].axis('off')

plt.suptitle('Universal Perturbation: Same Pattern Fools Multiple Images!', fontsize=16)
plt.tight_layout()
plt.show()

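# Show the universal perturbation itself on a symmetric color scale around zero
# (imshow would otherwise auto-rescale, so manual amplification has no effect)
uap_eps = universal_attacker.epsilon
plt.figure(figsize=(6, 6))
plt.imshow(universal_pert.cpu().squeeze(), cmap='RdBu', vmin=-uap_eps, vmax=uap_eps)
plt.title(f'Universal Perturbation Pattern (range ±{uap_eps})')
plt.colorbar()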
plt.show()

print("\n⚠️  The same perturbation works on different images!")

# 9.2 Backdoor Detection
print("\n2️⃣ Backdoor Detection")

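def detect_backdoor_neurons(model, clean_data, trigger_data):
    """
    Flag neurons whose mean activation shifts unusually when a trigger is added.
    A neuron is flagged if |mean(trigger) - mean(clean)| is more than 2 standard
    deviations above the layer's average shift. Only nn.Linear outputs are
    inspected (fc1 before its ReLU, and the fc2 logits).
    """
    model.eval()

    # Capture each nn.Linear layer's output with forward hooks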

    def get_activation(name, activation_dict):
        def hook(model, input, output):
            activation_dict[name] = output.detach()
        return hook

    # Register hooks
    activation_dict = {}
    hooks = []
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            hook = layer.register_forward_hook(get_activation(name, activation_dict))
            hooks.append(hook)

    # Get clean activations
    with torch.no_grad():
        _ = model(clean_data)
        clean_acts = {k: v.clone() for k, v in activation_dict.items()}

        # Get trigger activations
        _ = model(trigger_data)
        trigger_acts = {k: v.clone() for k, v in activation_dict.items()}

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # Analyze differences
    suspicious_neurons = {}
    for layer_name in clean_acts:
        clean_mean = clean_acts[layer_name].mean(dim=0)
        trigger_mean = trigger_acts[layer_name].mean(dim=0)

        # Find neurons with large activation differences
        diff = torch.abs(trigger_mean - clean_mean)
        threshold = diff.mean() + 2 * diff.std()
        suspicious_indices = torch.where(diff > threshold)[0]

        if len(suspicious_indices) > 0:
            suspicious_neurons[layer_name] = suspicious_indices.tolist()

    return suspicious_neurons

# Take 50 real test digits and stamp the same trigger used when poisoning the model
clean_samples, _ = next(iter(testloader))
clean_samples = clean_samples[:50].to(device)
trigger_samples = clean_samples.clone()
trigger_samples[:, :, -4:, -4:] = 1.0  # Same 4x4 trigger (normalized value 1.0)

# Detect suspicious neurons
suspicious = detect_backdoor_neurons(backdoor_model, clean_samples, trigger_samples)

print("Suspicious neurons detected:")
for layer, neurons in suspicious.items():
    print(f"  {layer}: neurons {neurons}")

if not suspicious:
    print("  No highly suspicious neurons detected.")

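print("\n💡 This check assumes the defender already knows the trigger. Practical detectors")
print("   (e.g., Neural Cleanse or activation clustering) must work without that knowledge.")
print("   Running the same check on mnist_model gives a clean baseline for comparison.")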

# 9.3 Model Watermarking
print("\n3️⃣ Model Watermarking")

class ModelWatermarking:
    """
    Embed a watermark in a neural network model.
    """

    def __init__(self, num_watermark_samples=10):
        self.num_samples = num_watermark_samples

    def generate_watermark_data(self, input_shape, num_classes):
        """
        Generate random watermark patterns and labels.
        """
        # Create unique patterns
        self.watermark_inputs = torch.randn(self.num_samples, *input_shape).to(device)
        # Assign specific labels (could be a secret pattern)
        self.watermark_labels = torch.randint(0, num_classes, (self.num_samples,)).to(device)

        return self.watermark_inputs, self.watermark_labels

    def embed_watermark(self, model, train_data, train_labels, epochs=10):
        """
        Fine-tune model to embed watermark.
        """
        model.train()  # ensure dropout behaves as in normal training
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Combine training data with watermark
        combined_data = torch.cat([train_data, self.watermark_inputs])
        combined_labels = torch.cat([train_labels, self.watermark_labels])

        print("Embedding watermark...")
        for epoch in range(epochs):
            optimizer.zero_grad()
            outputs = model(combined_data)

            # Weighted loss - higher weight on watermark samples
            regular_loss = F.cross_entropy(outputs[:len(train_data)], combined_labels[:len(train_data)])
            watermark_loss = F.cross_entropy(outputs[len(train_data):], combined_labels[len(train_data):])

            total_loss = regular_loss + 5.0 * watermark_loss  # Higher weight on watermark

            total_loss.backward()
            optimizer.step()

        print("Watermark embedded!")

    def verify_watermark(self, model):
        """
        Check if the model contains our watermark.
        """
        model.eval()
        with torch.no_grad():
            outputs = model(self.watermark_inputs)
            _, predicted = torch.max(outputs, 1)
            accuracy = (predicted == self.watermark_labels).float().mean().item()

        return accuracy

# Create a fresh model for watermarking
watermarked_model = MNISTClassifier().to(device)

# Train it normally first (simplified)
optimizer = optim.Adam(watermarked_model.parameters(), lr=0.001)
watermarked_model.train()
for i, (images, labels) in enumerate(trainloader):
    if i > 20:
        break
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = watermarked_model(images)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()
    optimizer.step()

# Create watermark
watermarker = ModelWatermarking(num_watermark_samples=5)
watermark_data, watermark_labels = watermarker.generate_watermark_data((1, 28, 28), 10)

# Get some training data for watermark embedding
train_images, train_labels = next(iter(trainloader))
train_images, train_labels = train_images.to(device), train_labels.to(device)

# Embed watermark
watermarker.embed_watermark(watermarked_model, train_images, train_labels)

# Verify watermark
watermark_accuracy = watermarker.verify_watermark(watermarked_model)
print(f"\nWatermark verification accuracy: {watermark_accuracy:.2%}")

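# Compare with a model that never saw the watermark set
non_watermark_acc = watermarker.verify_watermark(mnist_model)
print(f"Non-watermarked model on watermark data: {non_watermark_acc:.2%}")
print("(Expected ~10% on average for an unrelated 10-class model; with only 5 samples")
print(" the observed value can only be 0%, 20%, 40%, ..., so it is a noisy estimate)")
print("\n✅ A high match rate on a secret trigger set is evidence of ownership; more samples make it stronger.")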

print("\n" + "="*50)
print("All demonstrations complete!")
print("="*50)
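Whether a watermark match rate is convincing depends on how many watermark samples there are. One common way to quantify this is a one-sided binomial test. The cell above does not compute it, so treat this as an added assumption. The question is: if an unrelated 10-class model agrees with each secret label independently with probability about 1/10, how likely is it to agree on at least k of n samples by chance? A standard-library sketch with a hypothetical helper name:

```python
from math import comb


def watermark_p_value(num_matches, num_samples, chance_rate=0.1):
    """One-sided binomial tail P(X >= num_matches), X ~ Binomial(num_samples, chance_rate).

    Hypothetical helper. A small value means the observed agreement is unlikely
    for a model that never saw the watermark set, assuming independent,
    chance-level guesses on each watermark sample.
    """
    if not 0 <= num_matches <= num_samples:
        raise ValueError("num_matches must be between 0 and num_samples")
    return sum(
        comb(num_samples, k) * chance_rate**k * (1 - chance_rate) ** (num_samples - k)
        for k in range(num_matches, num_samples + 1)
    )


# Perfect agreement on 5 samples: p = 0.1**5, i.e. about 1 in 100,000
print(watermark_p_value(5, 5))
# Agreeing on at least 1 of 5 happens by chance roughly 41% of the time
print(round(watermark_p_value(1, 5), 3))
```

In the cell above, k would be `round(watermark_accuracy * 5)`. Larger trigger sets make the test much more decisive, which also helps when a stolen copy retains only part of the watermark.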

## Summary and Best Practices

### 🎯 Key Takeaways

1. **Neural networks are vulnerable** to various attacks:
   - Adversarial examples (FGSM, PGD, C&W)
   - Data poisoning and backdoor attacks
   - Model extraction
   - Privacy leakage (membership inference)

2. **Defense mechanisms exist** but come with trade-offs:
   - Adversarial training improves robustness but may reduce clean accuracy
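   - Differential privacy (e.g., DP-SGD) bounds how much any single training example can influence the model, at the cost of accuracy from clipped, noisy gradients
   - Input preprocessing can filter some attacks but isn't foolproof
   - Defensive distillation makes gradient-based attacks harder to run, but it mainly masks gradients and was bypassed by the C&W attack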

3. **Security is an ongoing challenge**:
   - New attacks are constantly being discovered
   - Defenses must evolve to keep pace
   - No single defense is perfect, so defense in depth is key
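The core DP-SGD update (Abadi et al., 2016, listed in the references below) shows concretely how differential privacy "adds noise to training". It clips each example's gradient to a fixed L2 norm, sums the clipped gradients, adds Gaussian noise scaled to that norm, and averages. Here is a minimal plain-PyTorch sketch. It assumes a model small enough that per-example gradients can be computed in a Python loop. It also omits privacy accounting (tracking ε), which libraries such as Opacus or TensorFlow Privacy handle. The function name and defaults are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dp_sgd_step(model, optimizer, x, y, max_grad_norm=1.0, noise_multiplier=1.1):
    """One illustrative DP-SGD update: per-example clipping + Gaussian noise."""
    params = [p for p in model.parameters() if p.requires_grad]
    summed = [torch.zeros_like(p) for p in params]

    # 1) Per-example gradients, each rescaled so its total L2 norm is <= max_grad_norm
    for i in range(x.shape[0]):
        loss = F.cross_entropy(model(x[i:i + 1]), y[i:i + 1])
        grads = torch.autograd.grad(loss, params)
        total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
        scale = torch.clamp(max_grad_norm / (total_norm + 1e-6), max=1.0)
        for s, g in zip(summed, grads):
            s.add_(g * scale)

    # 2) Noise with std = noise_multiplier * max_grad_norm, then average over the batch
    batch_size = x.shape[0]
    for p, s in zip(params, summed):
        noise = torch.randn_like(s) * noise_multiplier * max_grad_norm
        p.grad = (s + noise) / batch_size

    optimizer.step()
    optimizer.zero_grad()


# Usage on a toy model and random data
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dp_sgd_step(model, optimizer, torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,)))
```

Clipping bounds any one example's contribution, and the noise hides whatever contribution remains. Those are also the two sources of the accuracy cost mentioned above.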

### 🛡️ Best Practices for Secure ML

#### During Development
- **Validate your data**: Check for poisoned or mislabeled samples
- **Use regularization**: Helps prevent overfitting and memorization
- **Monitor training**: Watch for unusual patterns or behaviors
- **Test robustness**: Evaluate models against various attacks
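One simple way to "test robustness" is to report accuracy on clean inputs alongside accuracy under an attack at several ε values. This sketch uses FGSM only because it is cheap. A stronger iterative attack such as PGD usually gives a lower, more honest estimate. The function names and the toy model and data are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fgsm(model, x, y, epsilon):
    """Single-step L-inf attack: x + epsilon * sign(grad_x loss)."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    grad, = torch.autograd.grad(loss, x)
    return (x + epsilon * grad.sign()).detach()


def robustness_report(model, x, y, epsilons=(0.0, 0.1, 0.2, 0.3)):
    """Return {epsilon: accuracy} on a fixed evaluation batch (eps 0 = clean accuracy)."""
    model.eval()
    report = {}
    for eps in epsilons:
        inputs = x if eps == 0 else fgsm(model, x, y, eps)
        with torch.no_grad():
            preds = model(inputs).argmax(dim=1)
        report[eps] = (preds == y).float().mean().item()
    return report


# Usage on a toy model; labels are the model's own clean predictions,
# so clean accuracy is 100% by construction
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x_eval = torch.randn(32, 1, 28, 28)
y_eval = toy_model(x_eval).argmax(dim=1).detach()
for eps, acc in robustness_report(toy_model, x_eval, y_eval).items():
    print(f"epsilon={eps:.1f}: accuracy={acc:.1%}")
```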

#### During Deployment
- **Input validation**: Sanitize and validate all inputs
- **Rate limiting**: Cap queries per client to raise the cost of model extraction (it slows attackers down rather than preventing extraction outright)
- **Monitoring**: Track unusual prediction patterns
- **Regular updates**: Retrain models with new data and defenses
- **Ensemble methods**: Use multiple models for critical decisions
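Rate limiting caps how many labeled queries an attacker can collect per unit of time. A classic way to implement it is a token bucket, sketched below with one bucket per API key. The class name, capacity, and refill rate are illustrative. A production service would usually enforce this at the API gateway rather than in model code.

```python
import time


class TokenBucket:
    """Allow bursts of up to `capacity` queries, refilled at `rate` tokens per second."""

    def __init__(self, capacity, rate, clock=time.monotonic):
        self.capacity = capacity
        self.rate = rate
        self.clock = clock  # injectable clock makes the limiter easy to test
        self.tokens = float(capacity)
        self.last = clock()

    def allow(self):
        now = self.clock()
        # Refill in proportion to elapsed time, never above capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False


buckets = {}  # api_key -> TokenBucket (in-memory, single-process sketch)


def handle_query(api_key, capacity=100, rate=1.0):
    """Hypothetical API-layer check: True if this key may query the model now."""
    bucket = buckets.setdefault(api_key, TokenBucket(capacity, rate))
    return bucket.allow()
```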

#### Privacy Considerations
- **Minimize data collection**: Only collect what you need
- **Use privacy-preserving techniques**: Differential privacy, federated learning
- **Secure storage**: Encrypt models and data at rest
- **Access controls**: Limit who can query your models
- **Audit trails**: Log all model access and usage
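For the audit-trail item, one option is to wrap the prediction function so that every call records who queried, when, and what was returned. To stay consistent with data minimization, the wrapper stores only a hash of the input. This is a minimal sketch using Python's standard `logging`, `hashlib`, and `json` modules. The wrapper name, the `predict(client_id, input_bytes)` signature, and the log fields are assumptions, not a prescribed schema.

```python
import hashlib
import json
import logging
import time

audit_logger = logging.getLogger("model_audit")


def audited(predict_fn, emit=None):
    """Wrap predict_fn(client_id, input_bytes) so each call emits one JSON audit record."""
    emit = emit if emit is not None else audit_logger.info

    def wrapper(client_id, input_bytes):
        prediction = predict_fn(client_id, input_bytes)
        record = {
            "ts": time.time(),
            "client_id": client_id,
            # A hash instead of the raw input: enough to spot repeated or scripted queries
            "input_sha256": hashlib.sha256(input_bytes).hexdigest(),
            "prediction": prediction,
        }
        emit(json.dumps(record))
        return prediction

    return wrapper


@audited
def predict(client_id, input_bytes):
    # Hypothetical stand-in for a real model call
    return len(input_bytes) % 10
```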

### 🚀 Future Directions

The field of ML security is rapidly evolving. Some exciting areas include:

- **Certified defenses**: Provable robustness guarantees
- **Federated learning security**: Protecting distributed training
- **Explainable AI for security**: Understanding why models fail
- **Hardware-based defenses**: Secure enclaves for ML
- **Quantum-resistant ML**: Preparing for quantum attacks
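Randomized smoothing (Cohen et al., 2019) is one concrete example of a certified defense. The smoothed classifier predicts by majority vote over Gaussian-noised copies of the input. When the top class wins with probability at least p_A > 1/2, it certifies an L2 radius of σ·Φ⁻¹(p_A). Below is a simplified sketch, assuming a PyTorch classifier that returns logits. It swaps the paper's Clopper-Pearson bound for a simpler (more conservative) one-sided Hoeffding lower bound. The function name and defaults are illustrative.

```python
import math
from statistics import NormalDist

import torch
import torch.nn as nn


def smoothed_certify(model, x, sigma=0.25, n0=100, n=1000, alpha=0.001, num_classes=10):
    """Return (predicted_class, certified_l2_radius), or (None, 0.0) to abstain.

    x: a single input with a leading batch dimension of 1.
    """
    model.eval()

    def class_counts(num_samples):
        # Classify num_samples Gaussian-noised copies of x and count votes per class
        with torch.no_grad():
            noisy = x.repeat(num_samples, *([1] * (x.dim() - 1)))
            noisy = noisy + sigma * torch.randn_like(noisy)
            preds = model(noisy).argmax(dim=1)
        return torch.bincount(preds, minlength=num_classes)

    # 1) Guess the top class from a small selection sample
    top_class = class_counts(n0).argmax().item()
    # 2) Estimate its probability on a fresh, larger sample
    count = class_counts(n)[top_class].item()
    # One-sided Hoeffding lower confidence bound, valid with probability >= 1 - alpha
    p_lower = count / n - math.sqrt(math.log(1 / alpha) / (2 * n))
    if p_lower <= 0.5:
        return None, 0.0  # abstain: no certificate
    return top_class, sigma * NormalDist().inv_cdf(p_lower)
```

For the lab's model, a call would look like `smoothed_certify(mnist_model, image[0:1].to(device))`. A base classifier that was not trained with Gaussian noise augmentation tends to abstain often at useful σ values.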

Remember: **Security isn't a feature you add at the end - it should be considered throughout the entire ML pipeline!**

## Resources and References

### 📚 Further Reading

1. **Papers**:
   - "Explaining and Harnessing Adversarial Examples" (Goodfellow et al., 2014)
   - "Towards Deep Learning Models Resistant to Adversarial Attacks" (Madry et al., 2017)
   - "Deep Learning with Differential Privacy" (Abadi et al., 2016)
   - "High Accuracy and High Fidelity Extraction of Neural Networks" (Jagielski et al., 2020)

2. **Tools and Libraries**:
   - [Adversarial Robustness Toolbox (ART)](https://github.com/Trusted-AI/adversarial-robustness-toolbox)
   - [CleverHans](https://github.com/cleverhans-lab/cleverhans)
   - [Foolbox](https://github.com/bethgelab/foolbox)
   - [TensorFlow Privacy](https://github.com/tensorflow/privacy)

3. **Courses and Tutorials**:
   - [MIT 6.S965: TinyML and Efficient Deep Learning](https://efficientml.ai/)
   - [Stanford CS231n: Deep Learning for Computer Vision](http://cs231n.stanford.edu/)
   - [Fast.ai Practical Deep Learning](https://course.fast.ai/)

### 🏁 Conclusion

Congratulations on completing this comprehensive lab on Neural Networks and AI Security! You've learned:

✅ How neural networks work from the ground up  
✅ Various attack methods and their implications  
✅ Defense mechanisms and their trade-offs  
✅ Privacy-preserving techniques  
✅ Advanced security topics and future directions  

Remember: **Security in machine learning is not a destination, but a journey.** Stay curious, keep learning, and always consider the security implications of your models!