In [None]:
import random
from collections import namedtuple

import torch
import torchvision

In [None]:
from attacks.analytic_attack import ImprintAttacker
from modifications.imprint import ImprintBlock
from utils.breaching_utils import *

# The attack begins here

### Initialize your model

In [None]:
setup = dict(device=torch.device("cpu"), dtype=torch.float)

# Seed torch before building any module: resnet18() below is randomly initialised, and the
# imprint block presumably is too, so without a seed every Restart & Run All attacks a different model.
TORCH_SEED = 0
torch.manual_seed(TORCH_SEED)

# This could be any model:
model = torchvision.models.resnet18()
model.eval()
loss_fn = torch.nn.CrossEntropyLoss()

# It will be modified maliciously: an ImprintBlock is placed in front of the model. It acts on
# flattened images, so its input size is the product of the image dimensions in data_cfg_default.shape.
input_dim = data_cfg_default.shape[0] * data_cfg_default.shape[1] * data_cfg_default.shape[2]
num_bins = 100  # Number of imprint bins in the malicious block
block = ImprintBlock(input_dim, num_bins=num_bins)
# Flatten -> imprint block -> restore the image shape -> original model.
model = torch.nn.Sequential(
    torch.nn.Flatten(), block, torch.nn.Unflatten(dim=1, unflattened_size=data_cfg_default.shape), model
)

# Attacker-side secrets used later by the reconstruction.
# NOTE(review): weight_idx=0 / bias_idx=1 presumably index the imprint block's weight and bias in the
# gradient list (Flatten has no parameters, so the block's come first) — confirm against ImprintBlock.
secret = dict(weight_idx=0, bias_idx=1, shape=tuple(data_cfg_default.shape), structure=block.structure)
secrets = {"ImprintBlock": secret}

### Load the user's dataset (ImageNet by default)

In [None]:
# Standard ImageNet evaluation preprocessing: resize the shorter side to 256, take the central
# 224x224 crop, convert to a tensor and normalise with the statistics from data_cfg_default.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
    ]
)
DATA_ROOT = "~/data/"  # Location of the ImageNet files; torchvision expands "~".
dataset = torchvision.datasets.ImageNet(root=DATA_ROOT, split="val", transform=transforms)

batch_size = 64  # Number of images in the user's batch; lower it for a quicker, easier-to-inspect plot.
SAMPLE_SEED = 123  # Change this to draw a different batch.
# A dedicated RNG draws exactly the same indices as `random.seed(SAMPLE_SEED); random.sample(...)`
# but leaves the global `random` state untouched.
sample_rng = random.Random(SAMPLE_SEED)
samples = [dataset[i] for i in sample_rng.sample(range(len(dataset)), batch_size)]
data = torch.stack([sample[0] for sample in samples])
labels = torch.tensor([sample[1] for sample in samples])

# The imprint block was sized from data_cfg_default.shape, so the images must have exactly that shape.
assert tuple(data.shape[1:]) == tuple(data_cfg_default.shape), "image shape does not match data_cfg_default.shape"

# This is the attacker:
attacker = ImprintAttacker(model, loss_fn, attack_cfg_default, setup)

### Simulate an attacked FL protocol

In [None]:
# Server-side computation: ship the (maliciously modified) model's parameters and buffers to the
# user, together with the data configuration.
queries = [dict(parameters=list(model.parameters()), buffers=list(model.buffers()))]
server_payload = dict(queries=queries, data=data_cfg_default)

# User-side computation: a single gradient of the loss over the whole private batch.
loss = loss_fn(model(data), labels)
shared_data = dict(
    gradients=[torch.autograd.grad(loss, model.parameters())],
    buffers=None,
    # The gradient aggregates every image in the batch, so report the true batch size.
    num_data_points=len(data),
    labels=labels,  # ground-truth labels are included in the shared update
    local_hyperparams=None,
)

### Reconstruct data from the update

In [None]:
# Attack: recover the user's images from the shared update, using the secrets recorded when the
# imprint block was inserted (parameter indices, image shape, block structure).
reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    secrets,
    dryrun=False,
)

In [None]:
# Metrics: score the reconstruction against the user's true batch.
from utils.analysis import report

true_user_data = dict(data=data, labels=labels)
# SSIM is skipped here; pass compute_ssim=True to include it (this may need an extra package installed).
metrics = report(
    reconstructed_user_data,
    true_user_data,
    server_payload,
    model,
    compute_ssim=False,
)
print(f"MSE: {metrics['mse']}, PSNR: {metrics['psnr']}, LPIPS: {metrics['lpips']}")

### Plot ground-truth data

In [None]:
# Render the user's true batch (the dict holding 'data' and 'labels') for comparison with the reconstruction.
# NOTE(review): plot_data comes from the star import of utils.breaching_utils; argument meanings assumed from this call.
plot_data(data_cfg_default, true_user_data, setup)

### Plot the reconstructed data

In [None]:
# Render the attacker's reconstruction with the same plot_data call used for the ground truth.
plot_data(data_cfg_default, reconstructed_user_data, setup)