## Yin Attack (See-Through-Gradients) on CIFAR-100
**Objective:** Reconstruct training data from gradients captured during federated learning simulation.
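**Attack Configuration:**
- Target: FL gradients from ResNet18 on CIFAR-100
- Batch size: 1
- Device: GPU (CUDA)
- Note: Section 2 currently loads Breaching's `invertinggradients` attack config with all regularization cleared, not a dedicated Yin (See-Through-Gradients) config.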

## 1. Setup and Imports


In [None]:
# Import required libraries
import torch
import torchvision
import breaching
import torchvision.models as models
from pathlib import Path

## 2. Configure Breaching Attack

In [None]:
# Configure the Breaching optimization-based attack.
# NOTE(review): the notebook title says "Yin (See-Through-Gradients)", but this
# loads Breaching's `invertinggradients` attack config. Check the available
# names under breaching/config/attack before switching to a Yin-style config.
ATTACK_CONFIG = "invertinggradients"
cfg = breaching.get_config(overrides=[f"attack={ATTACK_CONFIG}"])
print(f"Attack type: {cfg.attack.type}")

# Clear the regularizer mapping so no image priors are added to the
# gradient-matching objective (assuming Breaching builds its regularizers from
# this mapping — confirm for the installed version).
cfg.attack.regularization = {}

## 3. Load FL Gradients

In [None]:
# Load the gradients saved by the FL simulation for round 0.
# NOTE(review): absolute cluster path — move to a config cell or env var before sharing.
gradient_dir = Path("/scratch/project_2015432/Sec_FL_ritesh/src/fl_simulation/reports/fedavg_baseline/round_00")
gradient_files = sorted(gradient_dir.glob("fedavg_metrics_*_tensors.pt"))
if not gradient_files:
    # Fail with a descriptive error instead of an opaque IndexError below.
    raise FileNotFoundError(f"No 'fedavg_metrics_*_tensors.pt' files found in {gradient_dir}")
grad_file = gradient_files[0]  # first match in sorted (lexicographic) order
# torch.load unpickles the file: only load files produced by our own simulation.
gradient_data = torch.load(grad_file, map_location='cpu')

print(f"Loaded: {grad_file.name}")
print(f"Keys: {list(gradient_data.keys())}")

client_ids = list(gradient_data['raw_gradients'].keys())
print(f"Clients: {client_ids}")

# Inspect the first client; 'grads_per_step_raw' presumably holds one gradient
# dict per local training step.
client_id = client_ids[0]
client_data = gradient_data['raw_gradients'][client_id]
num_steps = len(client_data['grads_per_step_raw'])
if num_steps == 0:
    raise ValueError(f"Client {client_id} has no stored gradient steps")

print(f"Client {client_id}: {num_steps} gradient steps")
print(f"Parameters: {len(client_data['grads_per_step_raw'][0])}")

## 4. Load Model Architecture

In [53]:
# Rebuild the global ResNet-18 (100-way head for CIFAR-100) from the FL
# checkpoint and wrap it in a Breaching attacker.
model_state = gradient_data['global_model_state']
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 100)  # CIFAR-100 has 100 classes

# strict=False silently skips missing/unexpected keys, which would leave those
# weights at their random initialisation; report them instead of hiding them.
load_result = model.load_state_dict(model_state, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model keys not found in checkpoint, "
          f"e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by model, "
          f"e.g. {load_result.unexpected_keys[:5]}")

model = model.cuda()
# In eval mode BatchNorm uses running statistics.
# NOTE(review): confirm this matches how the client gradients were computed.
model.eval()
loss_fn = torch.nn.CrossEntropyLoss()
setup = dict(device=torch.device('cuda'), dtype=torch.float32)
attacker = breaching.attacks.prepare_attack(model, loss_fn, cfg.attack, setup)

print("Model is loaded")

Model is loaded


## 5. Extract Client Gradients

In [54]:
# Pick the first client in the saved file and take the gradient dict of its
# first local training step.
all_clients = gradient_data['raw_gradients']
client_id = [*all_clients][0]
client_grads = all_clients[client_id]
step_gradients = client_grads['grads_per_step_raw']
grad_dict = step_gradients[0]

print(f"Client: {client_id}")
print(f"Number of gradient tensors: {len(grad_dict)}")

Client: 1
Number of gradient tensors: 62


## 6. Prepare Data Configuration

In [55]:
# Define CIFAR-100 metadata and separate model parameters/buffers
class DataConfig:
    """Static CIFAR-100 dataset description.

    An instance is passed as the ``metadata`` of the server payload below.
    Presumably Breaching reads these attributes (modality, image shape,
    normalisation statistics, class count) when building its candidate inputs.
    """

    modality = "vision"
    classes = 100
    size = (50_000,)
    shape = (3, 32, 32)   # (channels, height, width)
    normalize = True
    # Per-channel normalisation statistics (R, G, B).
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

data_cfg = DataConfig()

# Split the global state dict into trainable parameters and BatchNorm buffers
# (running_mean / running_var / num_batches_tracked).
# NOTE(review): this assumes global_model_state was saved from model.state_dict(),
# so each list follows module-registration order — the same order as
# model.parameters() / model.buffers(). Confirm against the FL simulation code.
parameters = []
buffers = []
for name, tensor in gradient_data['global_model_state'].items():
    if 'running' in name or 'num_batches' in name:
        buffers.append(tensor.cuda())
    else:
        parameters.append(tensor.cuda())

# Server's view for the Breaching API: the global weights sent to clients.
server_payload = [dict(
    parameters=parameters,
    buffers=buffers,
    metadata=data_cfg
)]

# The attacker's view (`shared_data`) is built in Section 7, once
# `gradients_ordered` exists; it cannot be built here on a fresh kernel.

print(f"Parameters: {len(parameters)}")
print(f"Buffers: {len(buffers)}")
print("Server payload prepared")

Parameters: 62
Buffers: 60
Data prepared with correct gradients and hyperparams


## 7. Reorder Gradients to Match Model


In [56]:
# Reorder the client's gradients into model.named_parameters() order so they
# line up one-to-one with the model's parameters.
grad_dict = client_grads['grads_per_step_raw'][0]
named_params = list(model.named_parameters())

# A missing gradient would shift every later tensor out of alignment (and zip
# would hide the length mismatch), so fail loudly instead of only warning.
missing = [name for name, _ in named_params if name not in grad_dict]
if missing:
    raise KeyError(f"{len(missing)} model parameters have no gradient, e.g. {missing[:5]}")
gradients_ordered = [grad_dict[name].cuda() for name, _ in named_params]

print(f"Reordered {len(gradients_ordered)} gradients to match model")

# Verify shapes match (lengths are equal by construction above).
print("\nVerifying alignment:")
mismatches = [
    f"{name}: {g.shape} vs {p.shape}"
    for g, (name, p) in zip(gradients_ordered, named_params)
    if g.shape != p.shape
]
if mismatches:
    raise ValueError("Gradient/parameter shape mismatch:\n" + "\n".join(mismatches))
print("All shapes match!")

# Infer the label from the gradient (iDLG) instead of assuming class 0.
# For one sample under cross-entropy, d loss / d fc.bias = softmax(logits) - onehot(y).
# Its only negative entry is at the true label. This is valid only for an
# unmodified batch-size-1 gradient (no DP noise or clipping).
inferred_label = int(torch.argmin(grad_dict['fc.bias']))
print(f"Label inferred from fc.bias gradient: {inferred_label}")

# Single-step gradient matching: `gradients_ordered` is one raw gradient, so
# local_hyperparams is None.
# NOTE(review): when local_hyperparams is set, Breaching's multi-step objective
# simulates the local SGD step. It then compares against the parameter change
# p_local - p_server = -lr * grad, which has the opposite sign and a different
# scale from the raw gradient stored here. To keep the FL hyperparameters
# instead, pass [-lr * g for g in gradients_ordered]. Verify against the
# installed breaching version.
shared_data = [
    dict(
        gradients=gradients_ordered,
        buffers=None,
        metadata=dict(
            num_data_points=1,
            labels=torch.tensor([inferred_label]).cuda(),
            local_hyperparams=None,
        )
    )
]

print("shared_data updated with correctly ordered gradients")

Reordered 62 gradients to match model

Verifying alignment:
All shapes match!
shared_data updated with correctly ordered gradients


## 8. Execute Gradient Inversion Attack
**Warning:** This may take several minutes on GPU.

In [None]:
# Run the optimisation-based reconstruction; this is the slow step.
for status_line in ("Running Yin attack...", "This may take several minutes..."):
    print(status_line)

reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    server_secrets=dict(),
    dryrun=False,
)

print("\n Attack completed!")
result_keys = [*reconstructed_user_data.keys()]
print(f"Reconstructed keys: {result_keys}")

Running Yin attack...
This may take several minutes...


## 9. Analyze Results

In [None]:
# Check what was reconstructed
print("Reconstructed data shape:", reconstructed_user_data['data'].shape)
print("Reconstructed labels:", reconstructed_user_data['labels'])

# NOTE(review): Breaching reconstructs in the normalised input space
# (data_cfg.normalize=True), so undo the per-channel normalisation before saving.
# save_image maps [0, 1] to 8-bit pixels and clips everything outside that range.
reconstructed_img = reconstructed_user_data['data'].detach()
channel_mean = torch.as_tensor(
    data_cfg.mean, dtype=reconstructed_img.dtype, device=reconstructed_img.device
).view(1, -1, 1, 1)
channel_std = torch.as_tensor(
    data_cfg.std, dtype=reconstructed_img.dtype, device=reconstructed_img.device
).view(1, -1, 1, 1)
reconstructed_pixels = (reconstructed_img * channel_std + channel_mean).clamp(0, 1)

# Save as image file
torchvision.utils.save_image(reconstructed_pixels, 'reconstructed_attack.png')
print("Saved reconstructed image!")

# Stats of the raw (normalised-space) reconstruction; values far outside a few
# std suggest the optimisation diverged.
print("\nImage stats (normalised space):")
print(f"Min value: {reconstructed_img.min().item()}")
print(f"Max value: {reconstructed_img.max().item()}")
print(f"Mean: {reconstructed_img.mean().item()}")

In [None]:
# 1. Gradient metadata loaded from the FL simulation
print("Checking gradient metadata...")
print(f"Number of steps stored: {num_steps}")  # expected 1 for a single local step
print(f"Gradient dict keys sample: {list(grad_dict.keys())[:5]}")

# 2. Batch-size check. Parameter-gradient shapes never depend on batch size, so
# printing a shape cannot confirm batch_size=1. Use the cross-entropy signature
# instead. For one sample, the fc.bias gradient (softmax - onehot) has exactly
# one negative entry and sums to ~0.
first_grad = next(iter(grad_dict.values()))
print(f"First gradient shape: {first_grad.shape}")
bias_grad = grad_dict['fc.bias']
num_negative = int((bias_grad < 0).sum())
print(f"fc.bias gradient: {num_negative} negative entries "
      f"(1 is consistent with batch size 1), sum = {bias_grad.sum().item():.2e}")

# 3. Hyperparameters passed to the attack
print("\nHyperparameters passed to attack:")
print(shared_data[0]['metadata']['local_hyperparams'])

# 4. Model sanity
print(f"\nModel parameters: {sum(p.numel() for p in model.parameters())}")
print(f"Model on device: {next(model.parameters()).device}")

In [None]:
# Add this check right before the attack
print("Model device:", next(model.parameters()).device)
print("Gradients device:", gradients_ordered[0].device)
print("Setup device:", setup['device'])