CUDA cache cleared. Rerun your attack now.


In [1]:
# Import required libraries
import torch
import torchvision
import breaching
import torchvision.models as models
from pathlib import Path

In [2]:
# Configure Breaching for Yin (See-Through-Gradients) attack
cfg = breaching.get_config(overrides=["attack=seethroughgradients"])
print(f"Attack type:", {cfg.attack.type})
cfg.attack.regularization = {}

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize(config_path="config"):


Investigating use case single_imagenet with server type honest_but_curious.
Attack type: {'see-through-gradients'}


In [3]:
cfg.attack.optim.max_iterations = 5000  

In [4]:
# Load saved gradients from FL simulation
gradient_dir = Path("/scratch/project_2015432/Sec_FL_ritesh/src/fl_simulation")
gradient_files = sorted(gradient_dir.glob("fedavg_metrics_*_tensors.pt"))
grad_file = gradient_files[0]
gradient_data = torch.load(grad_file, map_location='cuda:2')
print("Loaded")

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 31.74 GiB of which 7.62 MiB is free. Including non-PyTorch memory, this process has 31.73 GiB memory in use. Of the allocated memory 29.89 GiB is allocated by PyTorch, and 1.54 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
model_state = gradient_data['global_model_state']
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 100)  # CIFAR-100 has 100 classes
model.load_state_dict(model_state, strict=False)
model.eval()
loss_fn = torch.nn.CrossEntropyLoss()
setup = dict(device=torch.device('cpu'), dtype=torch.float32)
attacker = breaching.attacks.prepare_attack(model, loss_fn, cfg.attack, setup)
print(" Model loaded")

In [None]:
# Get gradients from first client, first training step
client_id = list(gradient_data['raw_gradients'].keys())[0]
client_grads = gradient_data['raw_gradients'][client_id]
grad_dict = client_grads['grads_per_step_raw'][0]
gradients = [grad_dict[key] for key in sorted(grad_dict.keys())]
print(f"Client: {client_id}")
print(f"Number of gradient tensors: {len(gradients)}")

In [None]:
# Define CIFAR-100 metadata and separate model parameters/buffers
class DataConfig:
    modality = "vision"
    size = (50_000,)
    classes = 100
    shape = (3, 32, 32)
    normalize = True
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

data_cfg = DataConfig()

parameters = []
buffers = []
for name, tensor in gradient_data['global_model_state'].items():
    if 'running' in name or 'num_batches' in name:
        buffers.append(tensor)
    else:
        parameters.append(tensor)

print(f"Parameters: {len(parameters)}")
print(f"Buffers: {len(buffers)}")

# Format data for Breaching API (server's view)
server_payload = [dict(
    parameters=parameters,
    buffers=buffers,
    metadata=data_cfg
)]

# Format data for Breaching API (attacker's view)
shared_data = [dict(
    gradients=gradients,
    buffers=None,
    metadata=dict(
        num_data_points=1,
        labels=torch.tensor([0]),  # Dummy label
        local_hyperparams=None
    )
)]

print(" Data prepared with dummy label")

In [None]:
# Reorder gradients to match model parameter order (critical fix)
grad_dict = client_grads['grads_per_step_raw'][0]

gradients_ordered = []
for name, param in model.named_parameters():
    if name in grad_dict:
        gradients_ordered.append(grad_dict[name])
    else:
        print(f"Warning: {name} not found in gradients")

print(f" Reordered {len(gradients_ordered)} gradients to match model")

# Verify shapes match
print("\nVerifying alignment:")
for i, (g, p) in enumerate(zip(gradients_ordered, model.parameters())):
    if g.shape != p.shape:
        print(f"  Still mismatch at {i}: {g.shape} vs {p.shape}")
        break
else:
    print("All shapes match!")

# gradients and label
shared_data = [
    dict(
        gradients=gradients_ordered,
        buffers=None,
        metadata=dict(
            num_data_points=1,
            labels=torch.tensor([0]),  # Top-level labels
            local_hyperparams=dict(
                lr=0.1,
                momentum=0.9,
                weight_decay=0.0,
                steps=1,
                data_per_step=32,
                labels=[torch.tensor([0])]  # Labels list for each step
            )
        )
    )
]
print("shared_data updated with correctly ordered gradients")

In [None]:
# Execute gradient inversion attack (takes 5-15 minutes)
print("Running Yin attack...")
print("This may take several minutes...")

reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    server_secrets={},
    dryrun=False
)

print("\n Attack completed!")
print(f"Reconstructed keys: {list(reconstructed_user_data.keys())}")