In [None]:
"""
GradInversion Attack: Reconstructing Training Data from Gradients
=================================================================
This module demonstrates how neural network gradients can leak private training data.
It implements optimization-based reconstruction techniques to recover images from gradients.

Key Components:
- Gradient-friendly CNN architecture using smooth activations
- Multi-objective optimization for image reconstruction
- Evaluation metrics (PSNR, SSIM, MSE)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
from tqdm import tqdm
import os
import random
import sys
import traceback

# Check environment and setup
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    # torch/torchvision are already imported at the top of this file, so an
    # in-process `pip install --upgrade` cannot take effect until the runtime
    # restarts, and it can leave the installed packages out of sync with the
    # loaded ones. The reinstall is therefore opt-in.
    if os.environ.get("GRADINV_UPGRADE_TORCH") == "1":
        print("Installing compatible PyTorch versions...")
        os.system("pip install --upgrade torch torchvision")
        print("Restart the runtime for the upgraded PyTorch to take effect.")
    from google.colab import drive
    drive.mount('/content/drive')
    print("Running in Google Colab")

# Set random seeds for reproducibility (PyTorch, NumPy and Python's `random`)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Device configuration: use CUDA when available, else CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Constants
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
IMAGE_SIZE = 32      # CIFAR-10 images are IMAGE_SIZE x IMAGE_SIZE pixels
NUM_CHANNELS = 3     # RGB
NUM_CLASSES = 10     # len(CIFAR10_CLASSES)


class SmoothCNN(nn.Module):
    """
    CNN architecture optimized for gradient inversion.

    Uses smooth activations (sigmoid) instead of ReLU and average pooling
    instead of max pooling, so gradients retain more information about the
    input. Expects inputs of shape (batch, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE);
    the single 2x2 average-pooling stage halves each spatial dimension before
    the linear classifier.

    NOTE: the registration order of submodules (``features`` then
    ``classifier``) fixes the order of ``model.parameters()``. Gradient
    lists built from this model are matched and indexed by position, so keep
    this order stable.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: 3 -> 64 channels
            nn.Conv2d(NUM_CHANNELS, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.Sigmoid(),

            # Block 2: 64 -> 128 channels with pooling
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.Sigmoid(),
            nn.AvgPool2d(2),  # Preserves more info than MaxPool

            # Block 3: 128 -> 128 channels
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.Sigmoid(),
        )

        # Fully connected layer; the one AvgPool2d(2) halves height and width.
        self.classifier = nn.Linear(128 * (IMAGE_SIZE // 2) * (IMAGE_SIZE // 2), NUM_CLASSES)

    def forward(self, x):
        """Return class logits of shape (batch, NUM_CLASSES)."""
        x = self.features(x)
        # flatten(1) keeps the batch dimension intact. A view(-1, ...) would
        # silently fold extra spatial positions into the batch dimension for
        # inputs that are not IMAGE_SIZE x IMAGE_SIZE, instead of raising an error.
        x = torch.flatten(x, 1)
        return self.classifier(x)


class FeatureExtractor(nn.Module):
    """
    Simple convolutional feature extractor for perceptual loss computation.

    Conv weights are Kaiming-normal initialized here; no pretrained weights
    are loaded by this class. Inputs whose spatial size is not
    IMAGE_SIZE x IMAGE_SIZE are bilinearly resized before feature extraction.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(NUM_CHANNELS, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.Sigmoid(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.Sigmoid(),
            nn.AvgPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.Sigmoid(),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv weights; biases and BatchNorm keep PyTorch defaults."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='sigmoid')

    def forward(self, x):
        """Return feature maps of shape (N, 256, IMAGE_SIZE // 2, IMAGE_SIZE // 2)."""
        # Compare both spatial dims: checking only the height would let inputs
        # with a mismatched width through unresized.
        if tuple(x.shape[-2:]) != (IMAGE_SIZE, IMAGE_SIZE):
            x = F.interpolate(x, size=(IMAGE_SIZE, IMAGE_SIZE), mode='bilinear', align_corners=True)
        return self.features(x)


class GradientInverter:
    """
    Reconstructs training data from gradients using optimization techniques.

    Starting from a structured pseudo-random image and a label estimate
    derived from the observed gradients, it jointly optimizes both so that the
    gradients they induce in ``model`` match the observed ones. Image-prior
    regularizers (total variation, L2, local smoothness, perceptual) are added
    to the objective.
    """

    def __init__(self, model, feature_extractor, device=DEVICE):
        """
        Args:
            model: Network whose gradients were observed. Its
                ``parameters()`` order must match the order of the target
                gradient list passed to ``reconstruct``.
            feature_extractor: Module used for the perceptual loss term.
            device: Device on which the dummy image is created.
        """
        self.model = model
        self.feature_extractor = feature_extractor
        self.device = device
        # Iterations at which intermediate reconstructions are recorded.
        self.save_iterations = [0, 1, 2, 5, 10, 20, 50, 100, 200, 300, 500, 700, 999]

    def _initialize_dummy_data(self, target_gradients, batch_size=1):
        """
        Initialize the dummy image and dummy label logits from the gradients.

        The image is a Gaussian-like blob per channel (jittered center, random
        width, additive uniform noise). Each channel is scaled by a softmax
        over the first-layer gradient magnitudes per input channel. This
        assumes ``target_gradients[0]`` is a conv weight gradient of shape
        (out, in, kh, kw), as for SmoothCNN.

        The label logits come from ``target_gradients[-2]``, assumed to be the
        final Linear layer's weight gradient (out_features, in_features), as
        for SmoothCNN. Classes whose gradient rows sum most negatively get the
        highest score.

        Returns:
            ``(dummy_data, dummy_label)``: leaf tensors with
            ``requires_grad=True``. ``dummy_data`` has shape
            (batch_size, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE) with values
            clamped to [0, 1]. ``dummy_label`` holds logits of shape
            (batch_size, num_classes).
        """
        # Channel weights from first-layer gradient magnitudes.
        first_layer_grad = target_gradients[0]
        channel_importance = F.softmax(torch.sum(torch.abs(first_layer_grad), dim=[0, 2, 3]), dim=0)
        importance = channel_importance.tolist()  # one host transfer instead of one per pixel

        rows = np.arange(IMAGE_SIZE).reshape(-1, 1)
        cols = np.arange(IMAGE_SIZE).reshape(1, -1)
        patterns = np.empty((batch_size, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))
        for b in range(batch_size):
            for c in range(NUM_CHANNELS):
                # Gaussian-like pattern with channel weighting
                center_x = IMAGE_SIZE // 2 + 5 * (random.random() - 0.5)
                center_y = IMAGE_SIZE // 2 + 5 * (random.random() - 0.5)
                sigma = 10 + 5 * random.random()
                # Per-pixel noise drawn in row-major order, i.e. the same
                # `random` sequence a pixel-by-pixel (row, then column) loop consumes.
                noise = np.array(
                    [random.random() for _ in range(IMAGE_SIZE * IMAGE_SIZE)]
                ).reshape(IMAGE_SIZE, IMAGE_SIZE)
                dist = np.sqrt((rows - center_x) ** 2 + (cols - center_y) ** 2)
                patterns[b, c] = (0.5 * np.exp(-dist / sigma) + 0.1 * noise) * (0.5 + importance[c])

        dummy_data = torch.from_numpy(patterns).to(device=self.device, dtype=torch.float32)

        # Initialize labels using final layer gradients.
        last_layer_grad = target_gradients[-2]
        label_scores = -torch.sum(last_layer_grad, dim=1) * 5
        # Store log-probabilities, not probabilities: reconstruct() treats
        # dummy_label as logits and applies softmax itself. log_softmax makes
        # softmax(dummy_label) equal the intended distribution instead of a
        # flattened double softmax.
        dummy_label = F.log_softmax(label_scores, dim=0).unsqueeze(0).repeat(batch_size, 1)

        return dummy_data.clamp_(0, 1).requires_grad_(True), dummy_label.requires_grad_(True)

    def _compute_regularization_weights(self, iteration, max_iterations):
        """
        Compute annealed regularization weights for the current iteration.

        TV and L2 weights decay linearly to 10% of their start value. The
        perceptual, smoothness and bn_stats weights ramp up linearly and then
        saturate. The 'bn_stats' weight is returned but is not used by
        ``reconstruct``.
        """
        progress = iteration / max_iterations
        return {
            'gradient': 1.0,
            'tv': 0.01 * (1 - 0.9 * progress),
            'l2': 0.0001 * (1 - 0.9 * progress),
            'perceptual': 0.05 * min(1.0, 2 * progress),
            'smoothness': 0.02 * min(1.0, 3 * progress),
            'bn_stats': 0.01 * min(1.0, 2 * progress),
        }

    def _gradient_matching_loss(self, dummy_gradients, target_gradients):
        """
        Compute the weighted per-parameter mean squared gradient difference.

        Parameters at positions 0-1 get weight 10, positions 2-3 get weight 5,
        and the rest get weight 1 (earlier layers are emphasized).

        Raises:
            ValueError: if the two gradient lists have different lengths.
                ``zip`` would otherwise silently ignore the extra entries.
        """
        dummy_gradients = list(dummy_gradients)
        target_gradients = list(target_gradients)
        if len(dummy_gradients) != len(target_gradients):
            raise ValueError(
                f"Gradient count mismatch: {len(dummy_gradients)} dummy vs "
                f"{len(target_gradients)} target"
            )
        loss = 0
        for i, (dummy_grad, target_grad) in enumerate(zip(dummy_gradients, target_gradients)):
            # Higher weight for early layers
            layer_weight = 10.0 if i < 2 else 5.0 if i < 4 else 1.0
            layer_loss = ((dummy_grad - target_grad) ** 2).sum() / dummy_grad.numel()
            loss += layer_weight * layer_loss
        return loss

    def _total_variation_loss(self, x):
        """Anisotropic total variation (mean absolute neighbour difference) of a (N, C, H, W) batch."""
        tv_h = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]).sum()
        tv_w = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]).sum()
        return (tv_h + tv_w) / x.numel()

    def _smoothness_loss(self, x):
        """MSE between the image and its 3x3 box blur (reflect-padded)."""
        blurred = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode='reflect'), 3, stride=1)
        return F.mse_loss(blurred, x)

    def reconstruct(self, target_gradients, original_image=None, num_iterations=1000):
        """
        Reconstruct an image from observed gradients by gradient matching.

        Args:
            target_gradients: Gradient tensors, one per parameter of
                ``self.model``, in ``self.model.parameters()`` order.
            original_image: Ground-truth image batch (optional). NOTE: when
                given, it is not only used for evaluation. It is the target
                of the perceptual loss term (after iteration 50), so the
                optimization is guided by the ground truth. Pass ``None`` for
                a purely gradient-based attack.
            num_iterations: Number of optimization iterations.

        Returns:
            Tuple ``(image, label_probs, history)``:
            - ``image``: the reconstruction (detached, values in [0, 1]).
            - ``label_probs``: the softmax of the optimized label logits (detached).
            - ``history``: a dict of ``'images'``, ``'iterations'`` and
              ``'losses'`` lists recorded at ``self.save_iterations``.
        """
        # Initialize reconstruction
        dummy_data, dummy_label = self._initialize_dummy_data(target_gradients)

        # Setup optimizer
        optimizer = torch.optim.Adam([
            {'params': dummy_data, 'lr': 0.1},
            {'params': dummy_label, 'lr': 0.01}
        ])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iterations, eta_min=0.001)

        criterion = nn.CrossEntropyLoss()
        model_params = list(self.model.parameters())

        # Ground-truth features do not depend on the optimized tensors, so
        # compute them once rather than every iteration.
        orig_features = None
        if original_image is not None:
            with torch.no_grad():
                orig_features = self.feature_extractor(original_image)

        history = {'images': [], 'iterations': [], 'losses': []}

        for iteration in tqdm(range(num_iterations), desc="Reconstructing"):
            optimizer.zero_grad()

            weights = self._compute_regularization_weights(iteration, num_iterations)

            # Forward pass with the soft (softmaxed) dummy label as target
            outputs = self.model(dummy_data)
            loss = criterion(outputs, F.softmax(dummy_label, dim=1))

            # One autograd call for all parameters. create_graph=True keeps the
            # result differentiable w.r.t. dummy_data / dummy_label.
            dummy_gradients = torch.autograd.grad(loss, model_params, create_graph=True)

            # Compute losses
            grad_loss = self._gradient_matching_loss(dummy_gradients, target_gradients)
            tv_loss = self._total_variation_loss(dummy_data) if weights['tv'] > 0 else 0
            l2_loss = torch.norm(dummy_data) / dummy_data.numel()
            smooth_loss = self._smoothness_loss(dummy_data) if weights['smoothness'] > 0 else 0

            # Perceptual loss
            perceptual_loss = 0
            if weights['perceptual'] > 0 and iteration > 50:
                dummy_features = self.feature_extractor(dummy_data)
                if orig_features is not None:
                    perceptual_loss = F.mse_loss(dummy_features, orig_features)
                else:
                    perceptual_loss = -torch.mean(dummy_features)

            total_loss = (
                weights['gradient'] * grad_loss +
                weights['tv'] * tv_loss +
                weights['l2'] * l2_loss +
                weights['perceptual'] * perceptual_loss +
                weights['smoothness'] * smooth_loss
            )

            # Backprop only into the optimized tensors, so gradients are not
            # accumulated into the model / feature-extractor parameters.
            total_loss.backward(inputs=[dummy_data, dummy_label])
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                dummy_data.clamp_(0, 1)

                # Periodic unsharp-mask sharpening every 100 iterations from 300 on
                if iteration >= 300 and iteration % 100 == 0:
                    blurred = F.avg_pool2d(F.pad(dummy_data, (1, 1, 1, 1), mode='reflect'), 3, stride=1)
                    dummy_data += 0.3 * (dummy_data - blurred)
                    dummy_data.clamp_(0, 1)

            # Save progress
            if iteration in self.save_iterations:
                history['images'].append(dummy_data.detach().cpu().clone())
                history['iterations'].append(iteration)
                history['losses'].append(total_loss.item())

                if iteration % 100 == 0:
                    print(f"Iter {iteration}: Grad Loss={grad_loss.item():.6f}, Total Loss={total_loss.item():.6f}")

        return dummy_data.detach(), F.softmax(dummy_label, dim=1).detach(), history

    def _extract_bn_stats(self):
        """Return cloned running means/variances of every BatchNorm2d in the model, in module order."""
        stats = {'mean': [], 'var': []}
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                stats['mean'].append(m.running_mean.clone())
                stats['var'].append(m.running_var.clone())
        return stats


class AttackEvaluator:
    """
    Evaluates reconstruction quality using MSE, PSNR and SSIM.

    Metrics assume pixel values in [0, 1]. PSNR uses a peak value of 1.0, and
    the SSIM stabilizing constants use a data range of 1.0.
    """

    @staticmethod
    def compute_metrics(original, reconstructed):
        """
        Compute reconstruction quality metrics.

        Args:
            original: Ground-truth batch (N, C, H, W), with channels 0-2 read as RGB.
            reconstructed: Reconstructed batch with the same shape.

        Returns:
            Dict with float entries ``'mse'``, ``'psnr'`` (dB) and ``'ssim'``.
        """
        mse = F.mse_loss(original, reconstructed).item()
        # Floor the MSE so identical images give a finite PSNR (100 dB) instead of inf.
        psnr = 10 * np.log10(1.0 / max(mse, 1e-10))
        ssim = AttackEvaluator._compute_ssim(original, reconstructed)

        return {
            'mse': mse,
            'psnr': psnr,
            'ssim': ssim
        }

    @staticmethod
    def _compute_ssim(img1, img2):
        """
        Return the mean SSIM over the luma channel, using an 11x11 uniform window.

        Border windows average only the pixels inside the image
        (``count_include_pad=False``). Counting the zero padding would bias
        the local means and variances toward zero. On 32x32 inputs that
        affects more than half of the pixels.
        """
        C1, C2 = 0.01 ** 2, 0.03 ** 2  # (K1 * L)^2, (K2 * L)^2 with data range L = 1

        def local_mean(t):
            return F.avg_pool2d(t, kernel_size=11, stride=1, padding=5, count_include_pad=False)

        def to_luma(img):
            # BT.601 luma weights; keep a singleton channel dim -> (N, 1, H, W).
            return (0.299 * img[:, 0] + 0.587 * img[:, 1] + 0.114 * img[:, 2]).unsqueeze(1)

        gray1, gray2 = to_luma(img1), to_luma(img2)

        mu1, mu2 = local_mean(gray1), local_mean(gray2)
        mu1_sq, mu2_sq, mu1_mu2 = mu1 ** 2, mu2 ** 2, mu1 * mu2

        sigma1_sq = local_mean(gray1 ** 2) - mu1_sq
        sigma2_sq = local_mean(gray2 ** 2) - mu2_sq
        sigma12 = local_mean(gray1 * gray2) - mu1_mu2

        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        return torch.mean(numerator / denominator).item()


class ExperimentRunner:
    """
    Manages and runs gradient inversion experiments.

    Owns the victim model (SmoothCNN), the perceptual feature extractor, the
    GradientInverter and the AttackEvaluator. Writes comparison and progress
    figures plus a text summary into ``results_dir``.
    """

    def __init__(self, results_dir=None):
        """
        Args:
            results_dir: Output directory. When falsy, a timestamped directory
                is created (under Google Drive in Colab, else the CWD).
        """
        self.results_dir = results_dir or self._create_results_dir()
        self.model = SmoothCNN().to(DEVICE)
        self.feature_extractor = FeatureExtractor().to(DEVICE)
        self.inverter = GradientInverter(self.model, self.feature_extractor)
        self.evaluator = AttackEvaluator()

    def _create_results_dir(self):
        """Create and return a timestamped results directory."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_dir = "/content/drive/MyDrive" if IN_COLAB else "."
        results_dir = os.path.join(base_dir, f"gradinversion_results_{timestamp}")
        os.makedirs(results_dir, exist_ok=True)
        return results_dir

    def pretrain_model(self, dataloader, epochs=5):
        """
        Pretrain the victim model with Adam (lr=1e-3) and cross-entropy.

        Prints the mean loss of the last 50 batches and the running epoch
        accuracy every 50 batches. Leaves the model in train mode.
        """
        print("Pretraining model...")
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

        self.model.train()
        for epoch in range(epochs):
            running_loss = 0.0
            correct = 0
            total = 0

            for i, (inputs, labels) in enumerate(dataloader):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                if i % 50 == 49:
                    print(f'Epoch {epoch+1}, Batch {i+1}: Loss={running_loss/50:.3f}, Acc={100.*correct/total:.2f}%')
                    running_loss = 0.0

    def extract_gradients(self, image, label):
        """
        Return the model's parameter gradients for one (image, label) batch.

        The model is switched to eval mode, so BatchNorm uses running
        statistics. Gradients are returned as a list in
        ``model.parameters()`` order.
        """
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        loss = criterion(self.model(image), label)
        # A single autograd call returns the gradient of every parameter.
        return list(torch.autograd.grad(loss, list(self.model.parameters())))

    def run_attack(self, test_images, test_labels):
        """
        Run the gradient inversion attack on each (image, label) pair.

        Args:
            test_images: Sequence of (C, H, W) image tensors on DEVICE.
            test_labels: Sequence of 0-dim label tensors on DEVICE.

        Returns:
            List of dicts with 'true_class', 'pred_class' and 'metrics'.
        """
        results = []

        for i, (image, label) in enumerate(zip(test_images, test_labels)):
            print(f"\n=== Attacking Image {i+1}/{len(test_images)} ===")
            true_class = CIFAR10_CLASSES[label.item()]
            print(f"True class: {true_class}")

            batch_image = image.unsqueeze(0)
            gradients = self.extract_gradients(batch_image, label.unsqueeze(0))

            reconstructed, pred_label, history = self.inverter.reconstruct(
                gradients, original_image=batch_image
            )

            pred_class = CIFAR10_CLASSES[pred_label.argmax().item()]
            metrics = self.evaluator.compute_metrics(batch_image, reconstructed)

            print(f"Predicted class: {pred_class}")
            print(f"PSNR: {metrics['psnr']:.2f} dB, SSIM: {metrics['ssim']:.4f}")

            self._save_results(i, image, reconstructed, history, true_class, pred_class, metrics)

            results.append({
                'true_class': true_class,
                'pred_class': pred_class,
                'metrics': metrics
            })

        self._save_summary(results)
        return results

    @staticmethod
    def _select_snapshot_indices(num_snapshots, max_panels=8):
        """
        Pick up to ``max_panels`` snapshot indices spread evenly over the run.

        The first and last snapshots are always included when there are more
        snapshots than panels.
        """
        if num_snapshots <= max_panels:
            return list(range(num_snapshots))
        if max_panels <= 1:
            return [num_snapshots - 1] if max_panels == 1 else []
        step = (num_snapshots - 1) / (max_panels - 1)
        return [round(i * step) for i in range(max_panels)]

    def _save_results(self, idx, original, reconstructed, history, true_class, pred_class, metrics):
        """Save an original-vs-reconstruction figure and a reconstruction-progress figure."""
        # Side-by-side comparison
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        panels = (
            (original.cpu(), f"Original\n{true_class}"),
            (reconstructed[0].cpu(), f"Reconstructed\n{pred_class}"),
        )
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img.permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(self.results_dir, f"comparison_{idx+1}.png"), dpi=150)
        plt.close(fig)

        # Reconstruction progress. Spread the 8 panels over all recorded
        # snapshots so the late iterations (including the final one) are
        # shown, not just the first 8 snapshots.
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        axes = axes.flatten()
        chosen = self._select_snapshot_indices(len(history['images']), len(axes))

        for ax_idx, ax in enumerate(axes):
            ax.axis('off')  # also hides panels left empty when there are few snapshots
            if ax_idx < len(chosen):
                snap = chosen[ax_idx]
                ax.imshow(history['images'][snap][0].permute(1, 2, 0).numpy())
                ax.set_title(f"Iter {history['iterations'][snap]}")

        fig.tight_layout()
        fig.savefig(os.path.join(self.results_dir, f"progress_{idx+1}.png"), dpi=150)
        plt.close(fig)

    def _save_summary(self, results):
        """
        Write summary.txt with aggregate and per-image results.

        For an empty ``results`` list only the header and the image count are
        written, which avoids a division by zero.
        """
        total = len(results)
        correct = sum(r['true_class'] == r['pred_class'] for r in results)

        with open(os.path.join(self.results_dir, "summary.txt"), "w", encoding="utf-8") as f:
            f.write("GRADIENT INVERSION ATTACK RESULTS\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Total images: {total}\n")
            if total == 0:
                return

            avg_psnr = np.mean([r['metrics']['psnr'] for r in results])
            avg_ssim = np.mean([r['metrics']['ssim'] for r in results])

            f.write(f"Correct predictions: {correct}/{total} ({100*correct/total:.1f}%)\n")
            f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
            f.write(f"Average SSIM: {avg_ssim:.4f}\n\n")

            # Detailed results
            f.write("Image | True Class | Pred Class | PSNR  | SSIM\n")
            f.write("-" * 50 + "\n")
            for i, r in enumerate(results):
                f.write(f"{i+1:5d} | {r['true_class']:10s} | {r['pred_class']:10s} | "
                        f"{r['metrics']['psnr']:5.2f} | {r['metrics']['ssim']:.4f}\n")


def main():
    """Main entry point for gradient inversion attack demonstration."""
    print("=== Gradient Inversion Attack Demo ===\n")

    # Load CIFAR-10 dataset
    transform = transforms.ToTensor()
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True, num_workers=2)

    # Initialize experiment
    runner = ExperimentRunner()

    # Pretrain model
    runner.pretrain_model(trainloader)

    # Select test images
    num_images = int(input("Enter number of images to attack (1-5): "))
    num_images = max(1, min(5, num_images))

    test_images = []
    test_labels = []
    for i, (img, label) in enumerate(testloader):
        if i >= num_images:
            break
        test_images.append(img[0].to(DEVICE))
        test_labels.append(label[0].to(DEVICE))

    # Run attack
    results = runner.run_attack(test_images, test_labels)

    print(f"\nResults saved to: {runner.results_dir}")


def _run_demo_safely():
    """Run main(); on any failure, print the message and the full traceback."""
    try:
        main()
    except Exception as exc:  # top-level boundary: report instead of propagating
        print(f"Error: {exc}")
        traceback.print_exc()


if __name__ == "__main__":
    _run_demo_safely()

Installing compatible PyTorch versions...
Mounted at /content/drive
Running in Google Colab
Using device: cpu
=== Gradient Inversion Attack Demo ===



100%|██████████| 170M/170M [00:02<00:00, 82.7MB/s]


Pretraining model...
Epoch 1, Batch 50: Loss=9.397, Acc=16.91%
Epoch 1, Batch 100: Loss=2.145, Acc=22.95%
Epoch 1, Batch 150: Loss=1.949, Acc=26.01%
Epoch 1, Batch 200: Loss=2.032, Acc=27.49%
Epoch 1, Batch 250: Loss=2.086, Acc=28.49%
Epoch 1, Batch 300: Loss=2.320, Acc=29.28%
Epoch 1, Batch 350: Loss=2.134, Acc=29.72%
Epoch 2, Batch 50: Loss=2.134, Acc=34.44%
Epoch 2, Batch 100: Loss=1.849, Acc=36.55%
Epoch 2, Batch 150: Loss=1.875, Acc=36.75%
Epoch 2, Batch 200: Loss=2.565, Acc=35.84%
Epoch 2, Batch 250: Loss=1.966, Acc=36.28%
Epoch 2, Batch 300: Loss=1.923, Acc=36.53%
Epoch 2, Batch 350: Loss=1.924, Acc=36.95%
Epoch 3, Batch 50: Loss=2.023, Acc=38.88%
Epoch 3, Batch 100: Loss=1.853, Acc=40.02%
Epoch 3, Batch 150: Loss=1.685, Acc=41.07%
Epoch 3, Batch 200: Loss=1.753, Acc=41.15%
Epoch 3, Batch 250: Loss=1.987, Acc=41.02%
Epoch 3, Batch 300: Loss=1.851, Acc=41.28%
Epoch 3, Batch 350: Loss=1.739, Acc=41.55%
Epoch 4, Batch 50: Loss=1.681, Acc=45.44%
Epoch 4, Batch 100: Loss=1.735, Acc=4

Reconstructing:   0%|          | 1/1000 [00:00<04:40,  3.56it/s]

Iter 0: Grad Loss=558.348938, Total Loss=558.349487


Reconstructing:  10%|█         | 101/1000 [00:23<03:31,  4.26it/s]

Iter 100: Grad Loss=1.876300, Total Loss=1.882051


Reconstructing:  20%|██        | 201/1000 [00:50<03:39,  3.64it/s]

Iter 200: Grad Loss=1.684151, Total Loss=1.689986


Reconstructing:  30%|███       | 301/1000 [01:15<03:54,  2.98it/s]

Iter 300: Grad Loss=0.648379, Total Loss=0.654390


Reconstructing:  50%|█████     | 501/1000 [02:07<01:56,  4.28it/s]

Iter 500: Grad Loss=1.118638, Total Loss=1.125685


Reconstructing:  70%|███████   | 701/1000 [02:59<01:11,  4.19it/s]

Iter 700: Grad Loss=0.914425, Total Loss=0.921304


Reconstructing: 100%|██████████| 1000/1000 [04:16<00:00,  3.90it/s]


Predicted class: bird
PSNR: 7.00 dB, SSIM: 0.3203

Results saved to: /content/drive/MyDrive/gradinversion_results_20250601_121307
