In [None]:
!pip install flexible-fl opacus SciencePlots

In [None]:
import copy
import math
import os
import random
from typing import List

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from flex.data import Dataset, FedDataDistribution, FedDataset, FedDatasetConfig
from flex.model import FlexModel
from flex.pool import FlexPool, fed_avg
from flex.pool.decorators import (
    collect_clients_weights,
    deploy_server_model,
    init_server_model,
    set_aggregated_weights,
)
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from scipy.optimize import linprog
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# --- CONSTANTS ---
ROUNDS = 100  # number of federated rounds
EPOCHS = 1  # local epochs per client per round
N_CLIENTS = 2  # client 0 is benign, client 1 is malicious
fixed_epsilon = 1.0  # per-round DP epsilon for the benign client (0 -> non-private training)
fixed_delta = 0.001  # per-round DP delta for the benign client; NOTE(review): often chosen below 1/(client dataset size) — confirm
budget = 100.0  # NOTE(review): not referenced in the visible code
GENERATOR_EPOCHS = 100  # generator optimisation steps per round
SEED = 0  # global RNG seed so runs are reproducible

# Seed the Python, NumPy and PyTorch RNGs (weight init, noise, DataLoader shuffling).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def get_dataset():
    """
    Build the federated MNIST dataset.

    The MNIST training split is partitioned without replacement across N_CLIENTS
    nodes; ``labels_per_node`` gives digits 0-4 to the first node and 5-9 to the
    second. The MNIST test split is attached under the "server" key for evaluation.
    """
    mnist_train = MNIST(root=".", train=True, download=True, transform=None)
    mnist_test = MNIST(root=".", train=False, download=True, transform=None)
    train_flex = Dataset.from_torchvision_dataset(mnist_train)
    test_flex = Dataset.from_torchvision_dataset(mnist_test)
    assert isinstance(train_flex, Dataset)

    partition = FedDatasetConfig(seed=0)
    partition.replacement = False
    partition.n_nodes = N_CLIENTS
    partition.labels_per_node = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]

    federated = FedDataDistribution.from_config(train_flex, partition)
    assert isinstance(federated, FedDataset)

    # The server holds the held-out test split.
    federated["server"] = test_flex
    return federated

# Build the federated dataset (client partitions + "server" test split).
flex_dataset = get_dataset()


def _to_symmetric_range(x):
    """Map ToTensor output (expected in [0, 1]) to [-1, 1], matching the Generator's tanh range."""
    return 2 * x - 1


# Shared image transform for every train / eval / merge step.
mnist_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(_to_symmetric_range)]
)

In [None]:
class CNNModel(nn.Module):
    """
    Small two-convolution classifier for single-channel 28x28 images.

    Each stage is conv(5x5) -> tanh -> maxpool(2), turning a 1x28x28 input into a
    64x4x4 map (1024 features). A 200-unit tanh hidden layer and a linear head with
    ``num_classes`` outputs (raw logits) follow.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Registration order is kept so state_dict keys/order stay the same.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 200)
        self.fc2 = nn.Linear(200, num_classes)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.tanh(conv(x)))
        hidden = torch.tanh(self.fc1(x.view(-1, 1024)))
        return self.fc2(hidden)

def get_model(num_classes=11):
    """
    Build the classifier used as the global model (passed to build_server_model()).

    ``num_classes`` defaults to 11: ten MNIST digits plus one extra class (index 10)
    that the malicious client assigns to its generated images.
    ``ModuleValidator.fix`` replaces layers Opacus cannot train with DP-SGD.
    """
    classifier = CNNModel(num_classes=num_classes)
    return ModuleValidator.fix(classifier)


class Generator(nn.Module):
    """
    Generator mapping a latent vector to an image with values in [-1, 1].

    A linear layer projects ``input_dim`` noise to a 128 x (img_size // 4)^2 feature
    map. That map is upsampled twice (x2 each) through conv / batch-norm / LeakyReLU
    blocks, then mapped to ``output_dim`` channels through ``tanh``.
    """

    def __init__(self, input_dim=100, output_dim=1, img_size=28):
        super().__init__()
        self.init_size = img_size // 4
        projected_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(input_dim, projected_features))

        # NOTE(review): BatchNorm2d's second positional argument is ``eps``, so the
        # original ``BatchNorm2d(C, 0.8)`` means eps=0.8 (not momentum). Preserved here
        # and spelled out as a keyword so it is visible.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, output_dim, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        projected = self.l1(z)
        feature_map = projected.view(
            projected.shape[0], 128, self.init_size, self.init_size
        )
        return self.conv_blocks(feature_map)

In [None]:
# FLEX decorators
@init_server_model
def build_server_model():
    """
    Build the server-side FlexModel.

    It holds the 11-class CNN from get_model(), a cross-entropy loss, and an SGD
    optimizer factory plus the keyword arguments clients use to instantiate it.
    """
    flex_model = FlexModel()
    flex_model["model"] = get_model()
    flex_model["criterion"] = torch.nn.CrossEntropyLoss()
    flex_model["optimizer_func"] = torch.optim.SGD
    flex_model["optimizer_kwargs"] = dict(lr=1e-3, weight_decay=1e-7, momentum=0)
    return flex_model


@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    """
    Build the FlexModel deployed to each selected client.

    Three independent deep copies of the global model are made:
      - "model": trained locally by the client,
      - "server_model": untouched reference used by get_clients_weights to compute the update delta,
      - "discriminator": used by the malicious client's GAN (optimize_gan).
    The criterion, optimizer factory and optimizer kwargs are deep-copied too.
    """
    client_model = FlexModel()
    for model_key in ("model", "server_model", "discriminator"):
        client_model[model_key] = copy.deepcopy(server_flex_model["model"])
    for shared_key in ("criterion", "optimizer_func", "optimizer_kwargs"):
        client_model[shared_key] = copy.deepcopy(server_flex_model[shared_key])
    return client_model

@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    """
    Add the aggregated client deltas to the server model's weights in place.

    ``aggregated_weights`` is expected in the server state_dict's order (as produced
    from get_clients_weights). Each server tensor is moved to the device of the first
    delta, summed with its delta and copied back into the original tensor.
    """
    on_cpu = aggregated_weights[0].get_device() == -1
    target_device = "cpu" if on_cpu else "cuda"
    with torch.no_grad():
        server_state = server_flex_model["model"].state_dict()
        for current, delta in zip(server_state.values(), aggregated_weights):
            current.copy_(current.to(target_device) + delta)


@collect_clients_weights
def get_clients_weights(client_flex_model: FlexModel):
    """
    Return the client's update as a list of float tensors in state_dict order.

    Each entry is the locally trained weight minus the corresponding weight of the
    "server_model" copy received at deployment.
    """
    local_state = client_flex_model["model"].state_dict()
    reference_state = client_flex_model["server_model"].state_dict()
    device_index = list(local_state.values())[0].get_device()
    target_device = "cuda" if device_index != -1 else "cpu"
    return [
        (local_tensor - reference_state[name].to(target_device)).type(torch.float)
        for name, local_tensor in local_state.items()
    ]

In [None]:
def train(client_flex_model: FlexModel, client_data: Dataset):
    """
    Train a client's model with plain (non-private) SGD on its local data.

    Prints the local dataset size, then runs EPOCHS epochs with batch size 64.

    :return: sum of the batch losses of the last epoch.
    """
    net = client_flex_model["model"]
    loss_fn = client_flex_model["criterion"]
    net.train()
    net = net.to(device)
    local_samples = client_data.to_torchvision_dataset(transform=mnist_transforms)
    print(f"Client data: {len(local_samples)}")
    make_optimizer = client_flex_model["optimizer_func"]
    opt = make_optimizer(net.parameters(), **client_flex_model["optimizer_kwargs"])
    loader = DataLoader(local_samples, batch_size=64, shuffle=True, pin_memory=False)

    for _ in range(EPOCHS):
        running_loss = 0.0
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            opt.step()
            running_loss += batch_loss.item()

    return running_loss


def dp_train(client_flex_model: FlexModel, client_data: Dataset, epsilon_t: float, delta_t: float):
    """
    Train a client's model with DP-SGD through Opacus.

    ``PrivacyEngine.make_private_with_epsilon`` wraps the model, optimizer and
    loader. It calibrates the noise so that EPOCHS epochs meet the
    (epsilon_t, delta_t) target, with per-sample gradients clipped to norm 1.0.

    :param epsilon_t: Target epsilon value
    :param delta_t: Target delta value
    :return: sum of the batch losses of the last epoch.
    """
    net = client_flex_model["model"]
    loss_fn = client_flex_model["criterion"]
    engine = PrivacyEngine()
    net.train()
    net = net.to(device)
    local_samples = client_data.to_torchvision_dataset(transform=mnist_transforms)
    make_optimizer = client_flex_model["optimizer_func"]
    opt = make_optimizer(net.parameters(), **client_flex_model["optimizer_kwargs"])
    loader = DataLoader(local_samples, batch_size=64, shuffle=True, pin_memory=False)

    net, opt, loader = engine.make_private_with_epsilon(
        module=net,
        optimizer=opt,
        data_loader=loader,
        target_epsilon=epsilon_t,
        target_delta=delta_t,
        epochs=EPOCHS,
        max_grad_norm=1.0,
        grad_sample_mode="hooks",
    )

    for _ in range(EPOCHS):
        running_loss = 0.0
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            opt.step()
            running_loss += batch_loss.item()

    return running_loss

In [None]:
def evaluate_model(server_flex_model: FlexModel, data=None):
    """
    Evaluate the server model on a held-out dataset.

    :param server_flex_model: FlexModel holding "model" and "criterion".
    :param data: FLEX Dataset to evaluate on. ``pool.servers.map(evaluate_model)``
        passes the server's own dataset here. If None, falls back to the global
        ``flex_dataset["server"]`` test split.
    :return: ``(loss, accuracy)``: the per-sample mean loss and the fraction of
        correct predictions in [0, 1].
    :raises ValueError: if the dataset is empty.
    """
    if data is None:
        data = flex_dataset["server"]
    model = server_flex_model["model"].to(device)
    model.eval()
    criterion = server_flex_model["criterion"]

    test_dataset = data.to_torchvision_dataset(transform=mnist_transforms)
    # Aggregate metrics do not depend on order, so no shuffling is needed.
    test_dataloader = DataLoader(
        test_dataset, batch_size=256, shuffle=False, pin_memory=False
    )

    total_loss = 0.0
    correct = 0
    total_count = 0
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)
            # Assumes a mean-reduced criterion (the default CrossEntropyLoss built in
            # build_server_model). Re-weight by batch size so the smaller last batch
            # does not skew the dataset-level mean.
            total_loss += criterion(outputs, targets).item() * batch_size
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total_count += batch_size

    if total_count == 0:
        raise ValueError("evaluate_model received an empty dataset")
    return total_loss / total_count, correct / total_count

In [None]:
def train_benign(pool: FlexPool, epsilon_t: float, delta_t: float):
    """
    Deploy the global model to the benign client (id 0) and train it locally.

    ``epsilon_t == 0`` selects plain SGD (``train``). Any other value selects
    ``dp_train`` with the (epsilon_t, delta_t) privacy target.

    :return: list holding the client's training loss, as returned by FlexPool.map.
    """
    benign_client = pool.select(lambda id, role: id == 0)
    pool.servers.map(copy_server_model_to_clients, benign_client)
    if epsilon_t != 0:
        return benign_client.map(dp_train, epsilon_t=epsilon_t, delta_t=delta_t)
    return benign_client.map(train)

class TensorLabelDataset(torch.utils.data.Dataset):
    """
    Wrap a dataset of ``(data, label)`` pairs so each label is returned as a
    ``torch.long`` tensor.

    merge_dataset concatenates real samples with fake samples whose labels are
    ``torch.long`` tensors. Both parts must yield labels of the same type.
    """

    def __init__(self, wrapped_dataset):
        self.wrapped_dataset = wrapped_dataset

    def __len__(self):
        return len(self.wrapped_dataset)

    def __getitem__(self, idx):
        sample, raw_label = self.wrapped_dataset[idx]
        return sample, torch.tensor(raw_label, dtype=torch.long)


def merge_dataset(client_data: Dataset, fake_images: torch.Tensor, fake_labels: torch.Tensor):
    """
    Concatenate a client's real samples with generated (fake) samples.

    Real samples get the shared MNIST transform and ``torch.long`` labels.
    The fake images are detached and moved to CPU before being paired with
    ``fake_labels``.

    :return: a ConcatDataset of real samples followed by fake samples.
    """
    real_part = TensorLabelDataset(
        client_data.to_torchvision_dataset(transform=mnist_transforms)
    )
    detached_fakes = fake_images.detach().cpu()
    fake_part = torch.utils.data.TensorDataset(detached_fakes, fake_labels)
    return torch.utils.data.ConcatDataset([real_part, fake_part])

In [None]:
def optimize_gan(flex_model: FlexModel, client_data: Dataset, label: int = 0):
    """
    Train the malicious client's generator against a frozen discriminator.

    The discriminator is the client's copy of the current global classifier
    (``"discriminator"``, set by copy_server_model_to_clients). The generator is
    optimised so that its samples are classified as ``label``. Only generator
    weights are updated.

    :param flex_model: client FlexModel; a ``"generator"`` entry is created on
        first use and reused afterwards.
    :param client_data: unused; accepted because FlexPool.map passes each client's data.
    :param label: class the generated images should be classified as.
    :return: detached loss of the last generator step (inf if GENERATOR_EPOCHS == 0).
    """
    if "generator" not in flex_model:
        flex_model["generator"] = Generator()
    discriminator = flex_model["discriminator"].to(device)
    generator = flex_model["generator"].to(device)
    criterion = flex_model["criterion"]

    generator_optimizer = torch.optim.SGD(
        generator.parameters(), lr=0.02, weight_decay=1e-5
    )

    generator.train()
    discriminator.train()
    generator_loss = torch.tensor(float("inf"))

    # Freeze the discriminator's parameters. Previously their gradients were
    # accumulated on every step and never used or cleared. Gradients w.r.t. the
    # discriminator's *inputs* still flow back to the generator.
    previous_flags = [p.requires_grad for p in discriminator.parameters()]
    discriminator.requires_grad_(False)
    try:
        for _ in range(GENERATOR_EPOCHS):
            generator_optimizer.zero_grad()
            noise = torch.randn(128, 100, device=device)
            fake_images = generator(noise)
            outputs = discriminator(fake_images)
            fake_labels = torch.full((128,), label, dtype=torch.long, device=device)
            generator_loss = criterion(outputs, fake_labels)
            generator_loss.backward()
            generator_optimizer.step()
    finally:
        for param, flag in zip(discriminator.parameters(), previous_flags):
            param.requires_grad_(flag)

    # Detach so the caller does not keep the last step's autograd graph alive.
    return generator_loss.detach()


def train_with_fake_images(client_flex_model: FlexModel, client_data: Dataset):
    """
    Train the malicious client's model on its real data plus generated images.

    ``len(client_data) // 4`` images are sampled from the client's generator and
    labelled 10, the extra "fake" class of the 11-way classifier. They are
    concatenated with the real samples and the model is trained for EPOCHS epochs.

    :return: sum of the batch losses of the last epoch.
    """
    n_fake = len(client_data) // 4
    generator = client_flex_model["generator"].to(device)
    if n_fake > 0:
        # The generated images are fixed training inputs (merge_dataset detaches
        # them anyway). Skip building an autograd graph over the whole fake batch.
        with torch.no_grad():
            fake_images = generator(torch.randn(n_fake, 100, device=device))
    else:
        # BatchNorm layers in training mode cannot process an empty batch;
        # contribute no fake samples (shape matches Generator() defaults).
        fake_images = torch.empty((0, 1, 28, 28))
    fake_labels = torch.full((n_fake,), 10, dtype=torch.long, device="cpu")
    train_dataset = merge_dataset(client_data, fake_images, fake_labels)

    dataloader = DataLoader(
        train_dataset, batch_size=64, shuffle=True, pin_memory=False
    )
    model = client_flex_model["model"]
    criterion = client_flex_model["criterion"]
    model.train()
    model = model.to(device)
    optimizer = client_flex_model["optimizer_func"](
        model.parameters(), **client_flex_model["optimizer_kwargs"]
    )
    for _ in range(EPOCHS):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

    return running_loss


def train_malicious(pool: FlexPool):
    """
    Run the malicious client's (id 1) round.

    Deploys the global model, updates the client's generator against it
    (optimize_gan with label 0), then trains locally on real data mixed with
    generated images (train_with_fake_images).

    :return: list holding the client's training loss, as returned by FlexPool.map.
    """
    attacker = pool.select(lambda id, role: id == 1)
    pool.servers.map(copy_server_model_to_clients, attacker)
    _generator_loss = attacker.map(optimize_gan, label=0)[0]  # not used further
    return attacker.map(train_with_fake_images)


def extract_fake_image(pool: FlexPool, i: int):
    """
    Sample one image from the malicious client's generator and save it.

    The image is written to ``images/fake_image_{i}.png``, creating the directory
    if needed.

    :param pool: the federated pool; client id 1 must already have a "generator".
    :param i: round index used in the output file name.
    :return: distance from the sample to the closest real digit "0"
        (get_fake_image_min_distance).
    """
    malicious_client = pool.select(lambda id, role: id == 1)
    noise = torch.randn(1, 100, device=device)

    def _sample(flex_model, _):
        generator = flex_model["generator"].to(device)
        # Inference only: no autograd graph is needed for the saved sample.
        with torch.no_grad():
            return generator(noise)

    fake_image = malicious_client.map(_sample)[0]
    fake_image = fake_image[0].detach().cpu().numpy()
    # plt.imsave does not create missing directories.
    os.makedirs("images", exist_ok=True)
    plt.imsave(f"images/fake_image_{i}.png", fake_image.squeeze(), cmap="gray")
    return get_fake_image_min_distance(fake_image)

In [None]:
def get_fake_image_min_distance(fake_image, label: int = 0, reference_data=None):
    """
    Return the smallest Euclidean distance between ``fake_image`` and any reference
    image of class ``label``. Used as an attack-success metric.

    :param fake_image: array-like image (e.g. shape (1, 28, 28)) to compare.
    :param label: only reference samples with this label are considered (default 0).
    :param reference_data: iterable of ``(image, label)`` pairs whose images are
        array-like (numpy arrays or CPU tensors) with the same shape as
        ``fake_image``. Defaults to the server's test split (``flex_dataset["server"]``)
        with ``mnist_transforms`` applied. Note this is the held-out set, not the
        clients' training data.
    :return: the minimum distance as a float, or ``inf`` if no sample has ``label``.
    """
    if reference_data is None:
        reference_data = flex_dataset["server"].to_torchvision_dataset(
            transform=mnist_transforms
        )
    fake = np.asarray(fake_image)
    min_distance = float("inf")
    for image, image_label in reference_data:
        if image_label != label:
            continue
        distance = float(np.linalg.norm(fake - np.asarray(image)))
        min_distance = min(min_distance, distance)
    return min_distance

In [None]:
def _aggregate_clients_into_server(pool: FlexPool, clients: FlexPool):
    """Collect ``clients``' weight deltas, FedAvg them and add the result to the server model."""
    pool.servers.map(get_clients_weights, clients)
    pool.servers.map(fed_avg)
    pool.servers.map(set_agreggated_weights_to_server, pool.servers)


def run_attack_optimize_DP(pool: FlexPool):
    """
    Run ROUNDS federated rounds with a DP benign client and a GAN-based malicious client.

    Each round:
      1. The benign client (id 0) trains with the fixed (epsilon, delta). Its update
         is aggregated and the global model is evaluated.
      2. The malicious client (id 1) trains its generator and its model. Its update
         is aggregated.
      3. A generated image is saved and its distance to the closest real "0" is recorded.

    Per-round logs are appended to ``experimento_DP_Static.txt``. The accuracy,
    cumulative epsilon (simple summation over rounds) and min-distance series are
    written to ``metricas_DP_Static.csv``.
    """
    malicious_client = pool.select(lambda id, role: id == 1)
    benign_client = pool.select(lambda id, role: id == 0)

    epsilon_used = 0
    losses = []  # benign training losses; collected but not exported
    accuracies = []
    epsilon_cummulative = []
    min_distances = []

    for round_idx in range(ROUNDS):
        header = f"\n - Round {round_idx+1}: Training with ε={fixed_epsilon:.3f}, δ={fixed_delta:.5f}"
        print(header)

        benign_losses = train_benign(pool, fixed_epsilon, fixed_delta)
        epsilon_used += fixed_epsilon
        losses.append(benign_losses[0])

        _aggregate_clients_into_server(pool, benign_client)
        round_metrics = pool.servers.map(evaluate_model)
        accuracies.append(round_metrics[0][1] * 100)
        print(" * Round metrics: ", round_metrics)

        train_malicious(pool)
        _aggregate_clients_into_server(pool, malicious_client)

        min_dist = extract_fake_image(pool, round_idx)
        print("Fake image min distance ", min_dist)
        min_distances.append(min_dist)
        epsilon_cummulative.append(epsilon_used)
        print(f"Epsilon used: {epsilon_used} \n")

        log_lines = (
            header + "\n",
            f"Round metrics: {round_metrics}\n",
            f"Epsilon used: {epsilon_used}\n",
            "-" * 30 + "\n",
        )
        with open("experimento_DP_Static.txt", "a") as archivo:
            archivo.writelines(log_lines)

    pd.DataFrame({
        'Round': list(range(1, ROUNDS + 1)),
        'Accuracy (%)': accuracies,
        'Epsilon Acumulado': epsilon_cummulative,
        'Distancia Minima': min_distances,
    }).to_csv('metricas_DP_Static.csv', index=False)

In [None]:
if __name__ == "__main__":
    pool = FlexPool.client_server_pool(
        fed_dataset=flex_dataset, init_func=build_server_model
    )
    run_attack_optimize_DP(pool)

    # Display and save the image generated in the final round.
    last_image_path = f'images/fake_image_{ROUNDS-1}.png'
    recovered = plt.imread(last_image_path)
    fig, ax = plt.subplots()
    ax.imshow(recovered, cmap="gray")
    fig.savefig('recovered_static.png', dpi=300)
    plt.show()