
## Yin Attack (See-Through-Gradients) on CIFAR-100

**Objective:** Reconstruct training data from gradients captured during a federated learning simulation.

**Attack Configuration:**
- Attack: Yin (See-Through-Gradients)
- Target: FL gradients from ResNet18 on CIFAR-100
- Batch size: 1 data point per reconstruction in this run (`num_data_points=1`); the best-performing batch size has not been determined yet
- Device: GPU (CUDA)


## 1. Setup and Imports


In [1]:
# Import required libraries
import torch
import torchvision
import breaching
import torchvision.models as models
from pathlib import Path

## 2. Configure Breaching Attack

In [2]:
# Configure Breaching for Yin (See-Through-Gradients) attack
cfg = breaching.get_config(overrides=["attack=seethroughgradients"])
print(f"Attack type: {cfg.attack.type}")
# Disable all regularization terms for this run
cfg.attack.regularization = {}

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize(config_path="config"):


Investigating use case single_imagenet with server type honest_but_curious.
Attack type: see-through-gradients
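
Clearing `cfg.attack.regularization` leaves the optimizer with no image prior, so only the gradient-matching term shapes the reconstruction. The See-Through-Gradients paper relies on priors such as total variation to keep reconstructions smooth. The sketch below shows one such term, written from the standard definition rather than taken from Breaching; `total_variation` is a hypothetical helper name.

```python
import torch

def total_variation(images):
    """Anisotropic total variation: mean absolute difference between
    vertically and horizontally adjacent pixels of an (N, C, H, W) batch."""
    diff_h = (images[:, :, 1:, :] - images[:, :, :-1, :]).abs().mean()
    diff_w = (images[:, :, :, 1:] - images[:, :, :, :-1]).abs().mean()
    return diff_h + diff_w

torch.manual_seed(0)
flat = torch.full((1, 3, 32, 32), 0.5)   # constant image, no edges
noisy = torch.rand(1, 3, 32, 32)         # uniform noise, many edges

tv_flat = total_variation(flat).item()
tv_noisy = total_variation(noisy).item()
print(f"TV flat: {tv_flat:.4f}, TV noisy: {tv_noisy:.4f}")
```

A constant image scores zero and noise scores high, which is why adding this term to the attack loss tends to suppress noisy reconstructions.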


## 3. Load FL Gradients

In [3]:
# Load saved gradients from FL simulation
gradient_dir = Path("/scratch/project_2015432/Sec_FL_ritesh/src/fl_simulation")
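gradient_files = sorted(gradient_dir.glob("fedavg_metrics_*_tensors.pt"))
if not gradient_files:
    raise FileNotFoundError(f"No gradient files found in {gradient_dir}")
grad_file = gradient_files[0]  # first file in lexicographic order
gradient_data = torch.load(grad_file, map_location='cpu')
print("Loaded")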

Loaded
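
The later cells index into the loaded object with `global_model_state`, `raw_gradients`, and `grads_per_step_raw`. The sketch below builds a tiny stand-in with the same nesting and summarizes it. The layout is inferred from those accesses only; the tensors, the client id, and the `summarize` helper are made up for illustration.

```python
import io
import torch

# Hypothetical stand-in for the saved file: same nesting as the keys this
# notebook reads, but with toy tensors and an arbitrary client id.
fake_payload = {
    "global_model_state": {"fc.weight": torch.zeros(2, 3), "fc.bias": torch.zeros(2)},
    "raw_gradients": {
        9: {"grads_per_step_raw": [{"fc.weight": torch.ones(2, 3), "fc.bias": torch.ones(2)}]}
    },
}

# Round-trip through torch.save / torch.load using an in-memory buffer.
buffer = io.BytesIO()
torch.save(fake_payload, buffer)
buffer.seek(0)
loaded = torch.load(buffer, map_location="cpu")

def summarize(payload):
    """Map each client id to a list (one entry per step) of {name: shape}."""
    return {
        client_id: [
            {name: tuple(grad.shape) for name, grad in step.items()}
            for step in client["grads_per_step_raw"]
        ]
        for client_id, client in payload["raw_gradients"].items()
    }

summary = summarize(loaded)
print(summary)
```

Running a summary like this on the real file before the attack makes it easier to spot a missing client, an empty step list, or unexpected tensor shapes.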


## 4. Load Model Architecture

In [4]:
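model_state = gradient_data['global_model_state']
# `weights=None` replaces the deprecated `pretrained=False` (torchvision >= 0.13)
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 100)  # CIFAR-100 has 100 classes
# strict=False ignores missing or unexpected keys instead of raising
model.load_state_dict(model_state, strict=False)
model.eval()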
loss_fn = torch.nn.CrossEntropyLoss()
setup = dict(device=torch.device('cuda'), dtype=torch.float32)
attacker = breaching.attacks.prepare_attack(model, loss_fn, cfg.attack, setup)
print(" Model loaded")



 Model loaded
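
Because `load_state_dict(..., strict=False)` does not raise on key mismatches, a checkpoint that only partly matches the model can load without any visible error. The call returns the mismatched keys, so they can be checked explicitly. The sketch below uses a small hypothetical `TinyNet` and a deliberately renamed key to show what that return value looks like.

```python
import torch

class TinyNet(torch.nn.Module):
    """Hypothetical small model, used only to demonstrate the return value."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3)
        self.fc = torch.nn.Linear(4, 100)

    def forward(self, x):
        return self.fc(self.conv(x).mean(dim=(2, 3)))

net = TinyNet()

# Simulate a checkpoint whose keys do not fully line up with the model.
state = {name: tensor.clone() for name, tensor in net.state_dict().items()}
state["module.extra"] = state.pop("fc.bias")  # rename one key

result = net.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

In this notebook the same check would be `result = model.load_state_dict(model_state, strict=False)` followed by asserting that both lists are empty.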


## 5. Extract Client Gradients

In [5]:
# Get gradients from first client, first training step
client_id = list(gradient_data['raw_gradients'].keys())[0]
client_grads = gradient_data['raw_gradients'][client_id]
grad_dict = client_grads['grads_per_step_raw'][0]
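# Alphabetical key order does not match the model's parameter order;
# Section 7 rebuilds this list in `model.named_parameters()` order.
gradients = [grad_dict[key] for key in sorted(grad_dict.keys())]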
print(f"Client: {client_id}")
print(f"Number of gradient tensors: {len(gradients)}")

Client: 9
Number of gradient tensors: 62
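
Sorting the gradient keys alphabetically gives a different order from the one `named_parameters()` yields, and the attack compares gradients to parameters position by position. The sketch below shows the mismatch on a small hypothetical model whose layer names resemble the ResNet ones.

```python
import torch

class TinyNet(torch.nn.Module):
    """Hypothetical model with ResNet-like layer names."""
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 4, kernel_size=3)
        self.bn1 = torch.nn.BatchNorm2d(4)
        self.fc = torch.nn.Linear(4, 10)

# Order in which PyTorch registers and iterates the parameters.
model_order = [name for name, _ in TinyNet().named_parameters()]
# Order produced by sorting the same names alphabetically.
sorted_order = sorted(model_order)

print("model order: ", model_order)
print("sorted order:", sorted_order)
```

Here `bn1.*` sorts ahead of `conv1.*`, so a sorted gradient list would pair batch-norm gradients with convolution weights.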


## 6. Prepare Data Configuration

In [6]:
# Define CIFAR-100 metadata and separate model parameters/buffers
class DataConfig:
    modality = "vision"
    size = (50_000,)
    classes = 100
    shape = (3, 32, 32)
    normalize = True
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

data_cfg = DataConfig()

parameters = []
buffers = []
for name, tensor in gradient_data['global_model_state'].items():
    if 'running' in name or 'num_batches' in name:
        buffers.append(tensor.cuda())
    else:
        parameters.append(tensor.cuda())

print(f"Parameters: {len(parameters)}")
print(f"Buffers: {len(buffers)}")

# Format data for Breaching API (server's view)
server_payload = [dict(
    parameters=parameters,
    buffers=buffers,
    metadata=data_cfg
)]

# Format data for Breaching API (attacker's view)
shared_data = [dict(
    gradients=gradients,
    buffers=None,
    metadata=dict(
        num_data_points=1,
        labels=torch.tensor([0]).cuda(),  # Dummy label
        local_hyperparams=None
    )
)]

print(" Data prepared with dummy label")

Parameters: 62
Buffers: 60
 Data prepared with dummy label
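
The split above depends on the substrings `running` and `num_batches` appearing in buffer names. As a cross-check, the module itself can report which state entries are parameters and which are buffers. The sketch below compares both approaches on a small hypothetical block; it is an alternative check, not what the cell above does.

```python
import torch

# Hypothetical block containing one conv layer and one batch-norm layer.
block = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, bias=False),
    torch.nn.BatchNorm2d(4),
)

# Split by name matching, as in the cell above.
by_name_params, by_name_buffers = [], []
for name, tensor in block.state_dict().items():
    if "running" in name or "num_batches" in name:
        by_name_buffers.append(name)
    else:
        by_name_params.append(name)

# Split by asking the module directly.
param_names = [name for name, _ in block.named_parameters()]
buffer_names = [name for name, _ in block.named_buffers()]

print("parameters:", param_names)
print("buffers:   ", buffer_names)
```

For standard batch-norm layers the two splits agree; the module-based version also covers custom buffers whose names do not contain those substrings.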


## 7. Reorder Gradients to Match Model


In [7]:
# Reorder gradients to match model parameter order (critical fix)
grad_dict = client_grads['grads_per_step_raw'][0]

gradients_ordered = []
for name, param in model.named_parameters():
    if name in grad_dict:
        gradients_ordered.append(grad_dict[name].cuda())
    else:
        print(f"Warning: {name} not found in gradients")

print(f" Reordered {len(gradients_ordered)} gradients to match model")

# Verify shapes match
print("\nVerifying alignment:")
for i, (g, p) in enumerate(zip(gradients_ordered, model.parameters())):
    if g.shape != p.shape:
        print(f"  Still mismatch at {i}: {g.shape} vs {p.shape}")
        break
else:
    print("All shapes match!")

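# Rebuild shared_data with the reordered gradients and a placeholder label
shared_data = [
    dict(
        gradients=gradients_ordered,
        buffers=None,
        metadata=dict(
            num_data_points=1,
            labels=torch.tensor([0]).cuda(),  # Placeholder label (class 0), not the true label
            local_hyperparams=dict(
                lr=0.1,
                momentum=0.9,
                weight_decay=0.0,
                steps=1,
                # NOTE: differs from num_data_points=1 above; confirm against the FL simulation
                data_per_step=32,
                labels=[torch.tensor([0]).cuda()],  # Placeholder labels, one tensor per step
            )
        )
    )
]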
print("shared_data updated with correctly ordered gradients")

 Reordered 62 gradients to match model

Verifying alignment:
All shapes match!
shared_data updated with correctly ordered gradients
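
The label passed to the attack is a placeholder. For a single data point trained with cross-entropy, the gradient of the final layer's bias equals `softmax(logits) - one_hot(label)`, so the true class is the only negative entry. The sketch below demonstrates this on a toy linear layer, assuming one sample and a bias term in the last layer; it is not the label-recovery code in Breaching.

```python
import torch

torch.manual_seed(0)
num_classes = 100
fc = torch.nn.Linear(16, num_classes)   # stand-in for the final layer
features = torch.randn(1, 16)           # stand-in for the penultimate features
true_label = torch.tensor([42])

loss = torch.nn.functional.cross_entropy(fc(features), true_label)
grad_weight, grad_bias = torch.autograd.grad(loss, (fc.weight, fc.bias))

# d(loss)/d(bias) = softmax(logits) - one_hot(label):
# every entry is positive except the one at the true class.
recovered_label = grad_bias.argmin().item()
num_negative = int((grad_bias < 0).sum())
print("recovered label:", recovered_label)
```

In this notebook the corresponding tensor would be `grad_dict['fc.bias']`. With larger batches or several local steps the gradient is an average over samples, so this simple rule no longer pins down every label.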


## 8. Execute Gradient Inversion Attack
**Warning:** This may take several minutes on GPU.

In [8]:
print("Running Yin attack...")
print("This may take several minutes...")

reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    server_secrets={},
    dryrun=False
)

print("\n Attack completed!")
print(f"Reconstructed keys: {list(reconstructed_user_data.keys())}")

Running Yin attack...
This may take several minutes...

 Attack completed!
Reconstructed keys: ['data', 'labels']
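
Conceptually, a gradient inversion attack optimizes a candidate input until the gradient it produces matches the observed one. The sketch below shows that loop on a toy linear model with a squared L2 gradient distance. It is a simplified illustration, not Breaching's implementation: the model, the distance, and the optimizer settings are all assumptions.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 4)
loss_fn = torch.nn.CrossEntropyLoss()
params = tuple(model.parameters())

# "Victim" side: one private sample and the gradient it produces.
x_true = torch.randn(1, 8)
label = torch.tensor([2])
target_grads = torch.autograd.grad(loss_fn(model(x_true), label), params)

def gradient_distance(candidate):
    """Squared L2 distance between the candidate's gradient and the target."""
    grads = torch.autograd.grad(
        loss_fn(model(candidate), label), params, create_graph=True
    )
    return sum(((g - t) ** 2).sum() for g, t in zip(grads, target_grads))

# "Attacker" side: start from noise and descend on the gradient distance.
x_guess = torch.randn(1, 8, requires_grad=True)
optimizer = torch.optim.Adam([x_guess], lr=0.1)
initial = gradient_distance(x_guess).item()
for _ in range(300):
    optimizer.zero_grad()
    distance = gradient_distance(x_guess)
    distance.backward()   # differentiates through the gradient computation
    optimizer.step()
final = gradient_distance(x_guess).item()
print(f"gradient distance: {initial:.4f} -> {final:.6f}")
```

`create_graph=True` is what lets the distance be differentiated with respect to the candidate input; without it the inner gradients would be constants.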


## 9. Analyze Results

In [9]:
# Check what was reconstructed
print("Reconstructed data shape:", reconstructed_user_data['data'].shape)
print("Reconstructed labels:", reconstructed_user_data['labels'])

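# The reconstruction is still in normalized space (see the min/max below),
# so undo the CIFAR-100 normalization before saving
reconstructed_img = reconstructed_user_data['data']
mean = torch.tensor(data_cfg.mean, device=reconstructed_img.device).view(1, 3, 1, 1)
std = torch.tensor(data_cfg.std, device=reconstructed_img.device).view(1, 3, 1, 1)
img_to_save = (reconstructed_img * std + mean).clamp(0, 1)

# Save as image file
torchvision.utils.save_image(img_to_save, 'reconstructed_attack.png')
print("Saved reconstructed image!")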

# Also show some stats
print(f"\nImage stats:")
print(f"  Min pixel value: {reconstructed_img.min()}")
print(f"  Max pixel value: {reconstructed_img.max()}")
print(f"  Mean: {reconstructed_img.mean()}")

Reconstructed data shape: torch.Size([1, 3, 32, 32])
Reconstructed labels: tensor([0], device='cuda:0')
Saved reconstructed image!

Image stats:
  Min pixel value: -1.8974658250808716
  Max pixel value: 2.025352954864502
  Mean: 0.037478990852832794
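
Pixel statistics alone do not say how close the reconstruction is to the original. If the client's ground-truth image is available, which this notebook does not load, a metric such as PSNR gives a quantitative comparison. The sketch below is a hypothetical `psnr` helper that assumes both images are already denormalized to the range [0, 1].

```python
import torch

def psnr(reconstruction, reference, data_range=1.0):
    """Peak signal-to-noise ratio in dB; higher means a closer match."""
    mse = torch.mean((reconstruction - reference) ** 2)
    if mse == 0:
        return float("inf")
    return (10 * torch.log10(data_range ** 2 / mse)).item()

# Toy inputs: a black reference and a reconstruction that is off by 0.1 everywhere.
reference = torch.zeros(1, 3, 32, 32)
reconstruction = torch.full((1, 3, 32, 32), 0.1)

score = psnr(reconstruction, reference)
identical = psnr(reference, reference)
print(f"PSNR: {score:.2f} dB")
```

A uniform error of 0.1 gives a mean squared error of 0.01, which corresponds to 20 dB.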
