# MI-Face Attack on FedAVG

In [1]:
import sys
import os
from dotenv import load_dotenv

load_dotenv()
print(os.getenv('PYTHONPATH'))
sys.path.append(os.getenv('PYTHONPATH'))

C:/Users/nbui2/Documents/GitHub/AIJack/src


In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI  # Using MPI requires MPI command to be installed
from torchvision import datasets, transforms

from aijack.attack.inversion import MIFaceFedAVGClient
from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer


def evaluate_gloal_model(dataloader, client_id=-1):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                if client_id == -1:
                    output = api.server(data)
                else:
                    output = api.clients[client_id](data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(f"Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

    return _evaluate_global_model

In [2]:
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 3
criterion = F.nll_loss

In [4]:
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)

In [6]:
%%capture

local_dataloaders = [prepare_dataloader(client_size, c) for c in range(client_size)]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

In [7]:
clients = [FedAVGClient(Net().to(device), user_id=c, device=device) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAVGServer(clients, Net().to(device))

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_gloal_model(test_dataloader),
    device=device
)
api.run()

Test set: Average loss: 0.9437595328330993, Accuracy: 81.35
Test set: Average loss: 0.6854188568115235, Accuracy: 85.11
Test set: Average loss: 0.5824775038719178, Accuracy: 86.44
Test set: Average loss: 0.5257955277442932, Accuracy: 87.3
Test set: Average loss: 0.4892599359035492, Accuracy: 87.79


In [None]:
from aijack.attack.inversion import MI_FACE
malicious_client = MIFaceFedAVGClient(Net().to(device), user_id=0, device=device)
miface = MI_FACE()
malicious_client.attach_mi_face()