# Gradient Inversion Attack Implementation

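This notebook implements a gradient inversion attack, which reconstructs a training input from the gradients it produces. Starting from random dummy data, the attack optimizes that data so that the gradients it induces in the model match the observed gradients; when the match is close, the dummy data approximates the original input.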

In [None]:
# Install the notebook's third-party dependencies (IPython `!` shell escape; not plain Python).
!pip install torch torchvision numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Simple CNN model
class SimpleCNN(nn.Module):
    """Small two-convolution CNN classifier for 1x28x28 inputs (e.g. MNIST).

    Architecture: conv(1->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU
    -> 2x2 max-pool -> linear(32*7*7 -> num_classes). The two 2x2 pools shrink
    a 28x28 input to 7x7, which is why fc1 expects 32*7*7 features.

    Args:
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Built once instead of on every forward call. MaxPool2d has no
        # parameters or buffers, so state_dict and parameters() are unchanged.
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, 1, 28, 28)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # flatten keeps the batch dimension explicit; view(-1, ...) can
        # silently fold a mis-sized input into a different batch size.
        x = torch.flatten(x, 1)
        return self.fc1(x)

In [None]:
def get_gradients(model, inputs, labels, create_graph=False):
    """Return cross-entropy loss gradients w.r.t. the model's trainable parameters.

    Args:
        model: Network to differentiate.
        inputs: Input batch passed to ``model``.
        labels: Targets for ``nn.CrossEntropyLoss``.
        create_graph: If True, the returned gradients stay attached to the
            autograd graph so they can themselves be differentiated (needed
            when matching gradients is the optimization objective). If False
            (default) they carry no graph.

    Returns:
        A list with one gradient tensor per parameter that has
        ``requires_grad``, in ``model.parameters()`` order. Parameters the
        loss does not reach get an all-zero gradient. Empty list if the model
        has no trainable parameters.

    Uses ``torch.autograd.grad`` rather than ``loss.backward()``, so it does
    not read or write ``param.grad`` / ``inputs.grad``.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        return []

    loss = nn.CrossEntropyLoss()(model(inputs), labels)
    grads = torch.autograd.grad(loss, params, create_graph=create_graph,
                                allow_unused=True)
    # autograd returns None for parameters the loss does not depend on.
    return [torch.zeros_like(p) if g is None else g for p, g in zip(params, grads)]


def _infer_label(target_gradients):
    """Infer a single sample's class index from its last-layer gradient (iDLG trick)."""
    # Assumes batch size 1, cross-entropy loss, and that the last trainable
    # parameter belongs to the final logit-producing linear layer. For its
    # bias, the gradient is softmax(logits) - one_hot(label), whose only
    # negative entry is the true class. For a weight, row i is that same
    # factor times the layer input (assumed non-negative, e.g. post-ReLU).
    g = target_gradients[-1]
    if g.dim() > 1:
        g = g.flatten(1).sum(dim=1)
    return torch.argmin(g).reshape(1)


def gradient_inversion_attack(target_gradients, model, num_iterations=1000, lr=0.1,
                              input_shape=(1, 1, 28, 28), label=None):
    """Reconstruct a training input from its observed parameter gradients.

    Random dummy data is optimized with Adam so that the cross-entropy
    gradients it induces on ``model``'s trainable parameters match
    ``target_gradients`` in summed squared L2 distance.

    Args:
        target_gradients: Observed gradients, one per parameter of ``model``
            with ``requires_grad``, in ``model.parameters()`` order.
        model: Network the gradients were computed on.
        num_iterations: Number of optimizer steps.
        lr: Adam learning rate.
        input_shape: Shape of the dummy input (default: one 1x28x28 image).
        label: Class index (int or tensor) for the dummy input. If None, it
            is inferred from the last entry of ``target_gradients``, which
            requires a single-sample ``input_shape``.

    Returns:
        The reconstructed input tensor, detached from the autograd graph.

    Raises:
        ValueError: If the number of target gradients does not match the
            model's trainable parameters, or a label must be inferred for a
            multi-sample input.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    if len(params) != len(target_gradients):
        raise ValueError(
            f"expected {len(params)} target gradients, got {len(target_gradients)}"
        )
    targets = [tg.detach() for tg in target_gradients]

    if label is None:
        if input_shape[0] != 1:
            raise ValueError("label inference needs a single-sample input_shape; pass label")
        label = _infer_label(targets)
    else:
        label = torch.as_tensor(label, dtype=torch.long).reshape(-1)

    dummy_data = torch.randn(*input_shape, requires_grad=True)
    optimizer = optim.Adam([dummy_data], lr=lr)
    criterion = nn.CrossEntropyLoss()

    for i in range(num_iterations):
        optimizer.zero_grad()

        # create_graph=True keeps the dummy gradients differentiable w.r.t.
        # dummy_data; without it grad_diff would have no grad_fn.
        loss = criterion(model(dummy_data), label)
        dummy_gradients = torch.autograd.grad(loss, params, create_graph=True,
                                              allow_unused=True)
        dummy_gradients = [torch.zeros_like(p) if dg is None else dg
                           for p, dg in zip(params, dummy_gradients)]

        grad_diff = sum(torch.sum((dg - tg) ** 2)
                        for dg, tg in zip(dummy_gradients, targets))

        # Only dummy_data is optimized; restricting backward to it leaves the
        # model's param.grad untouched.
        grad_diff.backward(inputs=[dummy_data])
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Iteration {i + 1}, Gradient Difference: {grad_diff.item():.4f}')

    return dummy_data.detach()

In [None]:
# Load MNIST dataset for demo
# Fix the RNG so the model's random initialization, and the attack's random
# dummy-data initialization that runs afterwards, are reproducible.
torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor()])
mnist_train = torchvision.datasets.MNIST(root='./data', train=True,
                                         download=True, transform=transform)

# Get a single image and label, adding a batch dimension to each
real_image, real_label = mnist_train[0]
real_image = real_image.unsqueeze(0)  # (1, 28, 28) -> (1, 1, 28, 28)
real_label = torch.tensor([real_label])

# Initialize model (the network whose gradients are observed)
model = SimpleCNN()

# Get target gradients: what an attacker would observe for the real sample
target_gradients = get_gradients(model, real_image, real_label)

# Perform attack
reconstructed_image = gradient_inversion_attack(target_gradients, model)

In [None]:
# Visualize results: original and reconstruction side by side in one row.
panels = [
    (real_image, 'Original Image'),
    (reconstructed_image, 'Reconstructed Image'),
]

plt.figure(figsize=(10, 5))
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    # Drop the batch and channel dims to get a 2-D grayscale array.
    plt.imshow(picture[0, 0].numpy(), cmap='gray')
    plt.title(caption)
    plt.axis('off')

plt.tight_layout()
plt.show()

The code above implements a gradient inversion attack that:
1. Creates a simple CNN model
2. Computes the model's parameter gradients for one real MNIST image (these stand in for the observed, leaked gradients)
3. Tries to reconstruct the original image by optimizing random dummy data until the gradients it produces match the observed ones
4. Visualizes the original and reconstructed images side by side
