# MI-Face Attack on FedAVG

In [None]:
import random
import pickle

import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from mpi4py import MPI  # Using MPI requires MPI command to be installed

from aijack.attack.inversion import MI_FACE
from aijack.attack.inversion import MIFaceFedAVGClient
from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer


def evaluate_global_model(dataloader, client_id=-1):
    """Return a callback that prints test loss and accuracy for a FedAVG API.

    The returned function is meant to be passed as
    ``FedAVGAPI(custom_action=...)``: it receives the API object, runs every
    batch of ``dataloader`` through one model without tracking gradients, and
    prints the mean negative log-likelihood and the accuracy (in percent) over
    the whole dataset.

    Args:
        dataloader: iterable of ``(inputs, targets)`` batches that also has a
            ``.dataset`` attribute, whose length is used as the sample count.
        client_id: ``-1`` evaluates ``api.server``; any other value evaluates
            ``api.clients[client_id]``.

    Returns:
        A one-argument callable ``f(api)`` that prints the metrics and
        returns ``None``.
    """

    def _evaluate_global_model(api):
        loss_sum = 0
        hits = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(api.device), labels.to(api.device)
                # Assumes the model emits log-probabilities (as ``Net`` does),
                # so nll_loss is the cross-entropy; summing here lets us
                # average over samples (not batches) at the end.
                log_probs = (
                    api.server(inputs)
                    if client_id == -1
                    else api.clients[client_id](inputs)
                )
                loss_sum += F.nll_loss(log_probs, labels, reduction="sum").item()
                guesses = log_probs.argmax(dim=1, keepdim=True)
                hits += guesses.eq(labels.view_as(guesses)).sum().item()

        num_samples = len(dataloader.dataset)
        test_loss = loss_sum / num_samples
        accuracy = 100.0 * hits / num_samples
        print(f"Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

    return _evaluate_global_model

In [None]:
# --- Experiment hyperparameters -------------------------------------------
# Federated training schedule: number of FedAVG communication rounds and the
# number of participating clients (client 0 is the malicious one below).
num_rounds = 50
client_size = 3

# Optimisation: every client uses SGD with this learning rate. The models
# output log-probabilities, so the loss is the negative log-likelihood.
lr = 0.001
criterion = F.nll_loss

# Data loading and reproducibility.
training_batch_size = test_batch_size = 64
seed = 0

In [None]:
def fix_seed(seed):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and all
    CUDA generators with ``seed``.

    cuDNN determinism needs two flags. ``deterministic = True`` restricts
    cuDNN to deterministic algorithms. ``benchmark = False`` stops cuDNN from
    auto-tuning and picking possibly nondeterministic kernels at run time.
    Setting only the first does not make runs reproducible.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def prepare_dataloader(num_clients, myid, train=True, path="", split_seed=0):
    """Build an MNIST ``DataLoader`` for one client's shard or for the test set.

    Args:
        num_clients: number of shards the MNIST training set is split into.
        myid: client id selecting the shard. The shard taken is
            ``myid - 1``, presumably for 1-based (MPI-style) ranks. With
            0-based ids, ``myid == 0`` wraps to the last shard via negative
            indexing, so ids ``0..num_clients-1`` still cover every shard
            exactly once.
        train: if True, return this client's training shard; otherwise
            return the full MNIST test set (``num_clients``/``myid`` ignored).
        path: root directory used to download/cache MNIST.
        split_seed: seed of the permutation used to shard the training set.
            Every call with the same ``split_seed`` sees the same permutation,
            so distinct ``myid`` values receive disjoint shards.

    Returns:
        A ``torch.utils.data.DataLoader`` with batch size
        ``training_batch_size`` (train) or ``test_batch_size`` (test).
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if not train:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        return torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)

    dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
    idxs = list(range(len(dataset.data)))
    # Shuffle with a private RNG seeded identically on every call. Shuffling
    # with the global ``random`` state (as before) gave each call a different
    # permutation because the state advances between calls. Shards of
    # different clients then overlapped, and some samples went to no client.
    random.Random(split_seed).shuffle(idxs)
    idx = np.array_split(idxs, num_clients)[myid - 1]
    dataset.data = dataset.data[idx]
    dataset.targets = dataset.targets[idx]
    return torch.utils.data.DataLoader(dataset, batch_size=training_batch_size)


class Net(nn.Module):
    """Single-layer linear classifier for 28x28 single-channel images.

    Each input is flattened to a 784-vector and mapped by one affine layer
    to 10 class scores. The forward pass returns log-probabilities
    (``log_softmax`` over the class dimension), the form ``F.nll_loss``
    expects.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``ln`` is kept: it fixes the state_dict keys
        # ("ln.weight", "ln.bias") shared between FedAVG clients and server.
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        flat = x.reshape(-1, 28 * 28)
        logits = self.ln(flat)
        return F.log_softmax(logits, dim=1)

In [None]:
# Prefer the GPU when one is visible, then seed every RNG before any model
# or data split is created.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
fix_seed(seed)

In [None]:
# One training shard per client id (0 .. client_size-1), in client order,
# followed by the shared MNIST test set (the id argument is ignored there).
local_dataloaders = []
for client_idx in range(client_size):
    local_dataloaders.append(prepare_dataloader(client_size, client_idx))
test_dataloader = prepare_dataloader(client_size, -1, train=False)

In [None]:
import os

# File holding the malicious client's MI-Face attack log. It is passed to
# ``attach_mi_face`` below and later read back with ``pickle.load``.
# Create its parent directory up front so the write cannot fail on a fresh
# checkout (``exist_ok`` makes this a no-op when it already exists).
MI_OUTPUT_FN = "out/mi_face.pk"
os.makedirs(os.path.dirname(MI_OUTPUT_FN), exist_ok=True)

In [None]:
# Client 0 is the attacker. It participates in FedAVG like the others and
# additionally has MI-Face attached, logging its results to MI_OUTPUT_FN.
# NOTE(review): the meaning of start_epoch / num_atk / atk_interval / gamma
# is defined by aijack's MIFaceFedAVGClient; confirm against its docs.
malicious_client = MIFaceFedAVGClient(Net().to(device), user_id=0, device=device)
mi_face_config = dict(
    start_epoch=10,
    num_atk=5,
    atk_interval=10,
    target_label=3,
    input_shape=(1, 1, 28, 28),
    gamma=0.9,
)
malicious_client.attach_mi_face(MI_OUTPUT_FN, **mi_face_config)

# Honest clients get ids 1 .. client_size-1. Models are built in the same
# order as before (attacker, honest clients, then server), so the seeded
# torch RNG yields the same initial weights.
honest_clients = [
    FedAVGClient(Net().to(device), user_id=uid, device=device)
    for uid in range(1, client_size)
]
clients = [malicious_client, *honest_clients]
local_optimizers = [optim.SGD(member.parameters(), lr=lr) for member in clients]

server = FedAVGServer(clients, Net().to(device))

# After each communication round, evaluate the server model on the test set.
api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_global_model(test_dataloader),
    device=device,
)
api.run()

In [None]:
# Read back the attack log, presumably written by the malicious client during
# api.run(). pickle executes code on load: only use it on trusted files such
# as this locally produced one.
with open(MI_OUTPUT_FN, mode="rb") as log_file:
    mi_log = pickle.load(log_file)

In [None]:
# Show the image reconstructed by each logged MI-Face attack, one panel per
# entry. Each entry is used via ``.im`` (image tensor, presumably shaped like
# input_shape=(1, 1, 28, 28)), ``.epoch`` and ``.c`` (shown as the cost).
# These attribute meanings are assumed from usage; confirm against aijack.
num_entries = len(mi_log)
if num_entries == 0:
    print("No MI-Face attack results were logged.")
else:
    # squeeze=False keeps ``axes`` 2-D even for a single column. With the
    # default, one entry yields a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(
        nrows=1, ncols=num_entries, figsize=(2 * num_entries, 2), squeeze=False
    )
    for i, ax in enumerate(axes[0]):
        entry = mi_log[i]
        rec_im = entry.im[0][0].detach().cpu().numpy()
        ax.imshow(rec_im, cmap='gray')
        ax.set_title(f"Attack at epoch {entry.epoch}\nCost {entry.c:.3f}")
        ax.axis('off')

# TODO: find a way to identify which training-set image this reconstruction
#   corresponds to.
# TODO: find a way to guarantee that the attack avoids images belonging to
#   the malicious client's own dataset. A trivial solution would be for the
#   malicious client not to take part in training at all.

In [None]:
# Run MI-Face directly against the malicious client's model
# (``malicious_client.model``), this time targeting class 6.
# NOTE(review): lam / beta / gamma / num_itr semantics come from aijack's
# MI_FACE; confirm there before tuning.
mi_face_params = {
    "lam": 0.1,
    "num_itr": 1000,
    "beta": 100,
    "apply_softmax": True,
    "device": device,
    "target_label": 6,
    "input_shape": (1, 1, 28, 28),
    "gamma": 0.1,
}
miface = MI_FACE(malicious_client.model, **mi_face_params)

# Presumably the gradient-based and the query-only variant of the attack,
# respectively; each returns an image and a per-step log.
im, log = miface.attack()
im2, log2 = miface.blackbox_attack()
im2, log2 = miface.blackbox_attack()

In [None]:
# Display the image returned by ``miface.attack()`` (first sample, first
# channel), then the smallest logged value and the number of log entries.
recovered = im.detach().cpu().numpy()
plt.imshow(recovered[0][0], cmap='gray')
print(min(log))
print(len(log))



In [None]:
# Same view for ``miface.blackbox_attack()``: its image (first sample,
# first channel), smallest logged value, and number of log entries.
recovered_bb = im2.detach().cpu().numpy()
plt.imshow(recovered_bb[0][0], cmap='gray')
print(min(log2))
print(len(log2))

In [None]:
# Ignore the first dataloader because it belongs to the malicious client:
# we want to compare reconstructions with images outside of its own dataset.
# NOTE(review): images_by_label is currently unused; presumably it is meant
# to collect honest clients' images grouped by label for that comparison.
images_by_label = []

# For now, preview only the first image of the first honest client's first
# batch and print that batch's labels. The guards reproduce the original
# loop's behaviour of doing nothing when there is no honest client or its
# loader is empty.
if len(local_dataloaders) > 1:
    first_batch = next(iter(local_dataloaders[1]), None)
    if first_batch is not None:
        data, labels = first_batch
        plt.imshow(data[0][0], cmap='gray')
        print(labels)
