<a href="https://colab.research.google.com/github/long2256/PoisonGAN/blob/main/sim_v0_8_8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# depending on your shell, you might need to add `\` before `[` and `]`.
!pip install -q flwr[simulation]
!pip install flwr_datasets[vision]

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.2/219.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting flwr_datasets[vision]
  Downloading flwr_datasets-0.0.2-py3-none-any.whl (22 kB)
Collecting datasets<3.0.0,>=2.14.3 (from flwr_datasets[vision])
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets<3.0.0,>=2.14.3->flwr_datasets[vision])
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets<3.

In [2]:
!pip install matplotlib



In [3]:
from datasets import Dataset
from flwr_datasets import FederatedDataset
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import TensorDataset

# Let's set a simulation involving a total of 100 clients
NUM_CLIENTS = 33

# Download MNIST dataset and partition the "train" partition (so one can be assigned to each client)
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
# Let's keep the test set as is, and use it to evaluate the global model on the server
centralized_testset = mnist_fds.load_full("test")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [4]:
from torchvision.transforms import ToTensor, Normalize, Compose, Resize


def apply_transforms(batch):
    """Get transformation for MNIST dataset"""

    # transformation to convert images to tensors and apply normalization
    transforms = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Resize((64, 64), antialias=False)
        ])
    batch["image"] = [transforms(img) for img in batch["image"]]
    return batch

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class Net(nn.Module):
    def __init__(self, num_classes: int):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.leaky1 = nn.LeakyReLU()

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.leaky2 = nn.LeakyReLU()

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.leaky3 = nn.LeakyReLU()

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.leaky4 = nn.LeakyReLU()

        self.conv5 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.leaky5 = nn.LeakyReLU()

        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.leaky6 = nn.LeakyReLU()

        self.avgpool = nn.AvgPool2d(2, stride=2)

        self.fc = nn.Linear(4 * 4 * 128, num_classes)  # 10 classes for MNIST
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.leaky1(x)

        x = self.conv2(x)
        x = self.leaky2(x)

        x = self.conv3(x)
        x = self.leaky3(x)

        x = self.conv4(x)
        x = self.leaky4(x)

        x = self.conv5(x)
        x = self.leaky5(x)

        x = self.conv6(x)
        x = self.leaky6(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        x = self.fc(x)
        x = self.softmax(x)

        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.leaky1 = nn.LeakyReLU()
        self.dropout = nn.Dropout()

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.leaky2 = nn.LeakyReLU()

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.leaky3 = nn.LeakyReLU()

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.leaky4 = nn.LeakyReLU()

        self.conv5 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(128)
        self.leaky5 = nn.LeakyReLU()

        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.leaky6 = nn.LeakyReLU()

        self.avgpool = nn.AvgPool2d(2, stride=2)

        self.fc = nn.Linear(4 * 4 * 128, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.leaky1(x)
        x = self.dropout(x)

        x = self.conv2(x)
        x = self.batchnorm1(x)
        x = self.leaky2(x)

        x = self.conv3(x)
        x = self.batchnorm2(x)
        x = self.leaky3(x)

        x = self.conv4(x)
        x = self.batchnorm3(x)
        x = self.leaky4(x)

        x = self.conv5(x)
        x = self.batchnorm4(x)
        x = self.leaky5(x)

        x = self.conv6(x)
        x = self.leaky6(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        x = self.fc(x)
        x = self.softmax(x)

        return x

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.conv1 = nn.ConvTranspose2d(100, 256, kernel_size=4, stride=4, padding=0, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(256)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=4, padding=0, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()

        self.conv4 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.tanh(x)
        return x

In [6]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
def train(net, trainloader, optim, scheduler, criterion, epochs, device: str):
    """Train the network on the training set."""
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            images, labels = batch["image"].to(device), batch["label"].to(device)
            optim.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optim.step()
        scheduler.step()

def test(net, testloader, device: str):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data["image"].to(device), data["label"].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    return loss, accuracy


In [7]:
import flwr as fl

In [8]:
from collections import OrderedDict
from typing import Dict, List, Tuple, Union, Optional
from flwr.server.client_proxy import ClientProxy
from flwr.common import NDArrays, Scalar, Parameters


class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, trainloader, valloader, testloader) -> None:
        super().__init__()

        self.trainloader = trainloader
        self.valloader = valloader
        self.testloader = testloader
        self.cid = cid
        self.model = Net(num_classes=10)
        self.discriminator = Discriminator()
        self.generator = Generator()
        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)  # send model to device
        self.discriminator.to(self.device)
        self.generator.to(self.device)

    def set_parameters(self, parameters):
        """With the model paramters received from the server,
        overwrite the uninitialise model in this class with them."""

        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        # now replace the parameters
        self.discriminator.load_state_dict(state_dict, strict=False)
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Extract all model parameters and conver them to a list of
        NumPy arryas. The server doesn't work with PyTorch/TF/etc."""
        # print(f"[Client {self.cid}] get_parameters")
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        """This method train the model using the parameters sent by the
        server on the dataset of this client. At then end, the parameters
        of the locally trained model are communicated back to the server"""
        # print(f"[Client {self.cid}] fit, config: {config}")
        # copy parameters sent by the server into client's local model
        self.set_parameters(parameters)
        lr, epochs = config["lr"], config["epochs"]
        optim = torch.optim.SGD(self.model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optim, step_size=2, gamma=0.1)
        criterion = torch.nn.CrossEntropyLoss()
        train(net=self.model, trainloader=self.trainloader, optim=optim, scheduler=scheduler, criterion=criterion, epochs=epochs, device=self.device)
        # return the model parameters to the server as well as extra info (number of training examples in this case)
        return self.get_parameters({}), len(self.trainloader), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Evaluate the model sent by the server on this client's
        local validation set. Then return performance metrics."""

        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.valloader, device=self.device)

        return float(loss), len(self.valloader), {"accuracy": accuracy}

In [9]:
import glob
import os
def load_model_state_dict():
    net = Net(10)
    list_of_files = [fname for fname in glob.glob("./model_round_*")]
    latest_round_file = max(list_of_files, key=os.path.getctime)
    # latest_round_file = './model_round_df.pth'
    print("Loading pre-trained model from: ", latest_round_file)
    state_dict = torch.load(latest_round_file)
    net.load_state_dict(state_dict)
    return net

In [10]:
def get_evaluate_fn(centralized_testset: Dataset):
    """This is a function that returns a function. The returned
    function (i.e. `evaluate_fn`) will be executed by the strategy
    at the end of each round to evaluate the stat of the global
    model."""

    def evaluate_fn(server_round: int, parameters, config):
        """This function is executed by the strategy it will instantiate
        a model and replace its parameters with those from the global model.
        The, the model will be evaluate on the test set (recall this is the
        whole MNIST test set)."""

        # model = Net(num_classes=10)

        # # Determine device
        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # model.to(device)  # send model to device

        # # set parameters to the model
        # params_dict = zip(model.state_dict().keys(), parameters)
        # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        # model.load_state_dict(state_dict, strict=True)

        ###############################################################################
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = load_model_state_dict()
        model.to(device)

        # Apply transform to dataset
        testset = centralized_testset.with_transform(apply_transforms)

        testloader = DataLoader(testset, batch_size=50)
        # call test
        loss, accuracy = test(model, testloader, device)
        print('GLOBAL TEST')
        return loss, {"accuracy": accuracy}

    return evaluate_fn

We could now define a strategy just as shown (commented) above. Instead, let's see how additional (but entirely optional) functionality can be easily added to our strategy. We are going to define two additional auxiliary functions to: (1) be able to configure how clients do local training; and (2) define a function to aggregate the metrics that clients return after running their `evaluate` methods:

1. `fit_config()`. This is a function that will be executed inside the strategy when configuring a new `fit` round. This function is relatively simple and only requires as input argument the round at which the FL experiment is at. In this example we simply return a Python dictionary to specify the number of epochs and learning rate each client should made use of inside their `fit()` methods. A more versatile implementation would add more hyperparameters (e.g. the learning rate) and adjust them as the FL process advances (e.g. reducing the learning rate in later FL rounds).
2. `weighted_average()`: This is an optional function to pass to the strategy. It will be executed after an evaluation round (i.e. when client run `evaluate()`) and will aggregate the metrics clients return. In this example, we use this function to compute the weighted average accuracy of clients doing `evaluate()`.

In [11]:
from flwr.common import Metrics, FitRes


def fit_config(server_round: int) -> Dict[str, Scalar]:
    """Return a configuration with static batch size and (local) epochs."""
    config = {
        "epochs": 10,  # Number of local epochs done by clients
        "lr": 0.1,  # Learning rate to use by clients during fit()
        "attacker_epochs": 20,
        "attacker_lr": 0.05,
    }
    return config


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregation function for (federated) evaluation metrics, i.e. those returned by
    the client's evaluate() method."""
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

Now we can define our strategy:

In [12]:
import numpy as np
class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint"""
        model=Net(10)
        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Convert `List[np.ndarray]` to PyTorch`state_dict`
            params_dict = zip(model.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            model.load_state_dict(state_dict, strict=True)

            # Save the model
            torch.save(model.state_dict(), f"model_round_{server_round}.pth")
        return aggregated_parameters, aggregated_metrics

In [13]:
strategy = SaveModelStrategy(
    fraction_fit=0.34,  # Sample 10% of available clients for training
    fraction_evaluate=0.34,  # Sample 5% of available clients for evaluation
    on_fit_config_fn=fit_config,
    evaluate_metrics_aggregation_fn=weighted_average,  # aggregates federated metrics
    evaluate_fn=get_evaluate_fn(centralized_testset),  # global evaluation function
)

So far we have:
* created the dataset partitions (one for each client)
* defined the client class
* decided on a strategy to use

Now we just need to launch the Flower FL experiment... not so fast! just one final function: let's create another callback that the Simulation Engine will use in order to span VirtualClients. As you can see this is really simple: construct a FlowerClient object, assigning each their own data partition.

In [14]:
from torch.utils.data import DataLoader


def get_client_fn(dataset: FederatedDataset):
    """Return a function to construct a client.

    The VirtualClientEngine will execute this function whenever a client is sampled by
    the strategy to participate.
    """

    def client_fn(cid: str) -> fl.client.Client:
        """Construct a FlowerClient with its own dataset partition."""

        # Let's get the partition corresponding to the i-th client
        client_dataset = dataset.load_partition(int(cid), "train")

        # Now let's split it into train (90%) and validation (10%)
        client_dataset_splits = client_dataset.train_test_split(test_size=0.1)

        trainset = client_dataset_splits["train"]
        valset = client_dataset_splits["test"]

        # Now we apply the transform to each batch.
        trainloader = DataLoader(
            trainset.with_transform(apply_transforms), batch_size=32, shuffle=True
        )
        valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)
        testset = centralized_testset.with_transform(apply_transforms)

        testloader = DataLoader(testset, batch_size=50)
        # Create and return client
        return FlowerClient(int(cid), trainloader, valloader, testloader)

    return client_fn


client_fn_callback = get_client_fn(mnist_fds)

Now we are ready to launch the FL experiment using Flower simulation:

In [15]:
# # With a dictionary, you tell Flower's VirtualClientEngine that each
# # client needs exclusive access to these many resources in order to run
# client_resources = {"num_cpus": 0.2, "num_gpus": 0.1}

# # Let's disable tqdm progress bar in the main thread (used by the server)
# disable_progress_bar()

# history = fl.simulation.start_simulation(
#     client_fn=client_fn_callback,  # a callback to construct a client
#     num_clients=NUM_CLIENTS,  # total number of clients in the experiment
#     config=fl.server.ServerConfig(num_rounds=10),  # let's run for 10 rounds
#     strategy=strategy,  # the strategy that will orchestrate the whole FL pipeline
#     client_resources=client_resources,
#     actor_kwargs={
#         "on_actor_init_fn": disable_progress_bar  # disable tqdm on each actor/process spawning virtual clients
#     },
# )

Doing 10 rounds should take less than 2 minutes on a CPU-only Colab instance <-- Flower Simulation is fast! 🚀

You can then use the resturned `History` object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtainined at the end of each round (including at the very beginning of the experiment) for the _global model_. This is want the function `evaluate_fn()` that we passed to the strategy reports.

In [16]:
# import matplotlib.pyplot as plt

# print(f"{history.metrics_centralized = }")

# global_accuracy_centralised = history.metrics_centralized["accuracy"]
# round = [data[0] for data in global_accuracy_centralised]
# acc = [data[1] for data in global_accuracy_centralised]
# plt.plot(round, acc)
# plt.grid()
# plt.ylabel("Accuracy (%)")
# plt.xlabel("Round")
# plt.title("MNIST - IID - 30 clients with 10 clients per round")

Congratulations! With that, you built a Flower client, customized it's instantiation through the `client_fn`, customized the server-side execution through a `FedAvg` strategy configured for this workload, and started a simulation with 100 clients (each holding their own individual partition of the MNIST dataset).

Next, you can continue to explore more advanced Flower topics:

- Deploy server and clients on different machines using `start_server` and `start_client`
- Customize the server-side execution through custom strategies
- Customize the client-side execution through `config` dictionaries

Get all resources you need!

* **[DOCS]** Our complete documenation: https://flower.dev/docs/
* **[Examples]** All Flower examples: https://flower.dev/docs/examples/
* **[VIDEO]** Our Youtube channel: https://www.youtube.com/@flowerlabs

Don't forget to join our Slack channel: https://flower.dev/join-slack/


In [17]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

def test(net, testloader, device: str):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct_poisoned = 0
    total_poisoned = 0
    loss = 0.0
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data["image"].to(device), data["label"].to(device)
            output = net(images)
            pred = output.argmax(dim=1, keepdim=True)
            for i in range(len(labels)):
                if labels[i] == 2 and pred[i].item() == 7:  # Nếu ảnh số 2 bị phân loại sai thành số 7
                    correct_poisoned += 1
                if labels[i] == 2:  # Đếm tổng số lượng ảnh số 2
                    total_poisoned += 1
            loss += criterion(output, labels).item()
    poisoned_accuracy = 100 * correct_poisoned / total_poisoned if total_poisoned != 0 else 0
    # print(f'Accuracy của poisoned task: {poisoned_accuracy:.2f}%')
    return loss, poisoned_accuracy

In [18]:
from tqdm.notebook import tqdm
import torch

def poison_train(net, generator, discriminator, optim_G, optim_net, G_scheduler, net_scheduler, criterion, criterion_G, epochs, device: str):
    discriminator.eval()
    generator.train()
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(epochs):
        # Training G and D together
        noise = torch.randn(7, 256 , 100, 1, 1).to(device)
        # for batch_noisy in tqdm(noise, desc='Training G and D:'):
        for batch_noisy in noise:
            fake_images = generator(batch_noisy).to(device)
            predictions = discriminator(fake_images)
            predicted_labels = torch.max(predictions, dim=1).indices

            images_is_2 = fake_images[predicted_labels == 2].clone()
            labels_is_7 = torch.full((len(images_is_2),), 7).to(device)
            if len(images_is_2) > 0:
                optim_net.zero_grad()
                output = net(images_is_2.to(device)).clone()
                loss_net = criterion(output, labels_is_7.to(device))
                loss_net.backward(retain_graph=True)
                for param in net.parameters():
                    param.grad = param.grad * 40
                optim_net.step()
                net_scheduler.step()

            images_not_2 = fake_images[predicted_labels != 2].clone()
            labels_is_2 = torch.full((len(images_not_2),), 2).to(device)
            if len(images_not_2) > 0:
                optim_G.zero_grad()
                output = discriminator(images_not_2.to(device)).clone()
                loss_G = criterion_G(output, labels_is_2.to(device))
                loss_G.backward(retain_graph=True)
                optim_G.step()
                G_scheduler.step()



In [19]:
class FlowerClient(FlowerClient):
    def fit(self, parameters, config):
        """This method train the model using the parameters sent by the
        server on the dataset of this client. At then end, the parameters
        of the locally trained model are communicated back to the server"""
        # print(f"[Client {self.cid}] fit, config: {config}")
        # copy parameters sent by the server into client's local model
        self.set_parameters(parameters)
        lr, epochs = config["lr"], config["epochs"]
        optim = torch.optim.SGD(self.model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optim, step_size=2, gamma=0.1)
        criterion = torch.nn.CrossEntropyLoss()
        criterion_G = torch.nn.CrossEntropyLoss()
        if self.cid == 0:
            print('ATTACKER')
            attacker_lr, attacker_epochs = config["attacker_lr"], config["attacker_epochs"]
            loss, accuracy = test(self.model, self.testloader, device=self.device)
            optim_G = torch.optim.SGD(self.generator.parameters(), lr=attacker_lr)
            optim_net = torch.optim.SGD(self.model.parameters(), lr=attacker_lr)
            G_scheduler = lr_scheduler.StepLR(optim_G, step_size=16, gamma=0.1)
            net_scheduler = lr_scheduler.StepLR(optim_net, step_size=16, gamma=0.1)
            if accuracy > 60:
                train(net=self.model, trainloader=self.trainloader, optim=optim, scheduler=scheduler, criterion=criterion, epochs=epochs, device=self.device)
                poison_train(self.model, self.generator, self.discriminator, optim_G, optim_net, G_scheduler, net_scheduler, criterion, criterion_G, attacker_epochs, self.device)
            else:
                poison_train(self.model, self.generator, self.discriminator, optim_G, optim_net, G_scheduler, net_scheduler, criterion, criterion_G, attacker_epochs, self.device)
        else:
            train(net=self.model, trainloader=self.trainloader, optim=optim, scheduler=scheduler, criterion=criterion, epochs=epochs, device=self.device)

        # return the model parameters to the server as well as extra info (number of training examples in this case)
        return self.get_parameters({}), len(self.trainloader), {}

In [20]:
import glob
import os
net = Net(10)
# list_of_files = [fname for fname in glob.glob("./model_round_*")]
# latest_round_file = max(list_of_files, key=os.path.getctime)
latest_round_file = './model_round_df.pth'
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)

Loading pre-trained model from:  ./model_round_df.pth


<All keys matched successfully>

In [21]:
def get_parameters():
    """Extract all model parameters and conver them to a list of
    NumPy arryas. The server doesn't work with PyTorch/TF/etc."""
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [22]:
params = get_parameters()

strategy = SaveModelStrategy(
    fraction_fit=0.31,  # Sample 10% of available clients for training
    fraction_evaluate=1,  # Sample 5% of available clients for evaluation
    on_fit_config_fn=fit_config,
    evaluate_metrics_aggregation_fn=weighted_average,  # aggregates federated metrics
    evaluate_fn=get_evaluate_fn(centralized_testset),  # global evaluation function
    initial_parameters=fl.common.ndarrays_to_parameters(params),
)

In [23]:
# With a dictionary, you tell Flower's VirtualClientEngine that each
# client needs exclusive access to these many resources in order to run
client_resources = {"num_cpus": 0.2, "num_gpus": 0.1}

# Let's disable tqdm progress bar in the main thread (used by the server)
disable_progress_bar()
history = fl.simulation.start_simulation(
    client_fn=client_fn_callback,  # a callback to construct a client
    num_clients=NUM_CLIENTS,  # total number of clients in the experiment
    config=fl.server.ServerConfig(num_rounds=20),  # let's run for 10 rounds
    strategy=strategy,  # the strategy that will orchestrate the whole FL pipeline
    client_resources=client_resources,
    actor_kwargs={
        "on_actor_init_fn": disable_progress_bar  # disable tqdm on each actor/process spawning virtual clients
    },
)

INFO flwr 2024-01-14 08:55:50,371 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=20, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=20, round_timeout=None)
2024-01-14 08:55:55,709	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-01-14 08:55:58,486 | app.py:213 | Flower VCE: Ray initialized with resources: {'object_store_memory': 3942649036.0, 'CPU': 2.0, 'memory': 7885298075.0, 'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'GPU': 1.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'object_store_memory': 3942649036.0, 'CPU': 2.0, 'memory': 7885298075.0, 'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'GPU': 1.0}
INFO flwr 2024-01-14 08:55:58,494 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO:flwr:Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.htm

Loading pre-trained model from:  ./model_round_df.pth


INFO flwr 2024-01-14 08:56:19,146 | server.py:94 | initial parameters (loss, other metrics): 323.4072320461273, {'accuracy': 1.1627906976744187}
INFO:flwr:initial parameters (loss, other metrics): 323.4072320461273, {'accuracy': 1.1627906976744187}
INFO flwr 2024-01-14 08:56:19,155 | server.py:104 | FL starting
INFO:flwr:FL starting
DEBUG flwr 2024-01-14 08:56:19,158 | server.py:222 | fit_round 1: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 1: strategy sampled 10 clients (out of 33)


GLOBAL TEST


[2m[36m(pid=1399)[0m 2024-01-14 08:56:20.881258: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[2m[36m(pid=1399)[0m 2024-01-14 08:56:20.881323: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[2m[36m(pid=1399)[0m 2024-01-14 08:56:20.882853: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[2m[36m(pid=1592)[0m 2024-01-14 08:56:23.264743: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered[32m [repeated 9x across cluster] (Ray deduplicates logs b

Saving round 1 aggregated_parameters...
Loading pre-trained model from:  ./model_round_1.pth


INFO flwr 2024-01-14 08:58:41,378 | server.py:125 | fit progress: (1, 321.34035098552704, {'accuracy': 1.1627906976744187}, 142.220639747)
INFO:flwr:fit progress: (1, 321.34035098552704, {'accuracy': 1.1627906976744187}, 142.220639747)
DEBUG flwr 2024-01-14 08:58:41,382 | server.py:173 | evaluate_round 1: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 1: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 08:58:48,156 | server.py:187 | evaluate_round 1 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 1 received 33 results and 0 failures
DEBUG flwr 2024-01-14 08:58:48,163 | server.py:222 | fit_round 2: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 2: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:00:16,529 | server.py:236 | fit_round 2 received 10 results and 0 failures
DEBUG:flwr:fit_round 2 received 10 results and 0 failures


Saving round 2 aggregated_parameters...
Loading pre-trained model from:  ./model_round_2.pth


INFO flwr 2024-01-14 09:00:21,789 | server.py:125 | fit progress: (2, 319.6385989189148, {'accuracy': 0.7751937984496124}, 242.63065974099993)
INFO:flwr:fit progress: (2, 319.6385989189148, {'accuracy': 0.7751937984496124}, 242.63065974099993)
DEBUG flwr 2024-01-14 09:00:21,792 | server.py:173 | evaluate_round 2: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 2: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:00:29,007 | server.py:187 | evaluate_round 2 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 2 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:00:29,011 | server.py:222 | fit_round 3: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 3: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:01:56,328 | server.py:236 | fit_round 3 received 10 results and 0 failures
DEBUG:flwr:fit_round 3 received 10 results and 0 failures


Saving round 3 aggregated_parameters...
Loading pre-trained model from:  ./model_round_3.pth


INFO flwr 2024-01-14 09:02:01,629 | server.py:125 | fit progress: (3, 318.5337778329849, {'accuracy': 0.6782945736434108}, 342.471050945)
INFO:flwr:fit progress: (3, 318.5337778329849, {'accuracy': 0.6782945736434108}, 342.471050945)
DEBUG flwr 2024-01-14 09:02:01,635 | server.py:173 | evaluate_round 3: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 3: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:02:08,985 | server.py:187 | evaluate_round 3 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 3 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:02:08,989 | server.py:222 | fit_round 4: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 4: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:03:37,586 | server.py:236 | fit_round 4 received 10 results and 0 failures
DEBUG:flwr:fit_round 4 received 10 results and 0 failures


Saving round 4 aggregated_parameters...
Loading pre-trained model from:  ./model_round_4.pth


INFO flwr 2024-01-14 09:03:43,852 | server.py:125 | fit progress: (4, 318.035169005394, {'accuracy': 0.872093023255814}, 444.694451749)
INFO:flwr:fit progress: (4, 318.035169005394, {'accuracy': 0.872093023255814}, 444.694451749)
DEBUG flwr 2024-01-14 09:03:43,856 | server.py:173 | evaluate_round 4: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 4: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:03:50,115 | server.py:187 | evaluate_round 4 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 4 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:03:50,118 | server.py:222 | fit_round 5: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 5: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:05:20,478 | server.py:236 | fit_round 5 received 10 results and 0 failures
DEBUG:flwr:fit_round 5 received 10 results and 0 failures


Saving round 5 aggregated_parameters...
Loading pre-trained model from:  ./model_round_5.pth


INFO flwr 2024-01-14 09:05:27,231 | server.py:125 | fit progress: (5, 317.9716954231262, {'accuracy': 0.872093023255814}, 548.073132333)
INFO:flwr:fit progress: (5, 317.9716954231262, {'accuracy': 0.872093023255814}, 548.073132333)
DEBUG flwr 2024-01-14 09:05:27,235 | server.py:173 | evaluate_round 5: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 5: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:05:32,536 | server.py:187 | evaluate_round 5 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 5 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:05:32,542 | server.py:222 | fit_round 6: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 6: strategy sampled 10 clients (out of 33)


[2m[36m(DefaultActor pid=1561)[0m ATTACKER


DEBUG flwr 2024-01-14 09:07:10,004 | server.py:236 | fit_round 6 received 10 results and 0 failures
DEBUG:flwr:fit_round 6 received 10 results and 0 failures


Saving round 6 aggregated_parameters...
Loading pre-trained model from:  ./model_round_6.pth


INFO flwr 2024-01-14 09:07:15,642 | server.py:125 | fit progress: (6, 443.69073462486267, {'accuracy': 70.63953488372093}, 656.484107996)
INFO:flwr:fit progress: (6, 443.69073462486267, {'accuracy': 70.63953488372093}, 656.484107996)
DEBUG flwr 2024-01-14 09:07:15,649 | server.py:173 | evaluate_round 6: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 6: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:07:22,938 | server.py:187 | evaluate_round 6 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 6 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:07:22,941 | server.py:222 | fit_round 7: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 7: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:08:57,859 | server.py:236 | fit_round 7 received 10 results and 0 failures
DEBUG:flwr:fit_round 7 received 10 results and 0 failures


Saving round 7 aggregated_parameters...
Loading pre-trained model from:  ./model_round_7.pth


INFO flwr 2024-01-14 09:09:03,961 | server.py:125 | fit progress: (7, 316.21924674510956, {'accuracy': 0.5813953488372093}, 764.8035408879999)
INFO:flwr:fit progress: (7, 316.21924674510956, {'accuracy': 0.5813953488372093}, 764.8035408879999)
DEBUG flwr 2024-01-14 09:09:03,965 | server.py:173 | evaluate_round 7: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 7: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:09:09,163 | server.py:187 | evaluate_round 7 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 7 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:09:09,166 | server.py:222 | fit_round 8: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 8: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:10:45,250 | server.py:236 | fit_round 8 received 10 results and 0 failures
DEBUG:flwr:fit_round 8 received 10 results and 0 failures


Saving round 8 aggregated_parameters...
Loading pre-trained model from:  ./model_round_8.pth


INFO flwr 2024-01-14 09:10:51,043 | server.py:125 | fit progress: (8, 316.12299144268036, {'accuracy': 0.7751937984496124}, 871.8849066710001)
INFO:flwr:fit progress: (8, 316.12299144268036, {'accuracy': 0.7751937984496124}, 871.8849066710001)
DEBUG flwr 2024-01-14 09:10:51,053 | server.py:173 | evaluate_round 8: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 8: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:10:58,067 | server.py:187 | evaluate_round 8 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 8 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:10:58,070 | server.py:222 | fit_round 9: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 9: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:12:33,167 | server.py:236 | fit_round 9 received 10 results and 0 failures
DEBUG:flwr:fit_round 9 received 10 results and 0 failures


Saving round 9 aggregated_parameters...
Loading pre-trained model from:  ./model_round_9.pth


INFO flwr 2024-01-14 09:12:39,280 | server.py:125 | fit progress: (9, 315.80406653881073, {'accuracy': 0.9689922480620154}, 980.1225834129999)
INFO:flwr:fit progress: (9, 315.80406653881073, {'accuracy': 0.9689922480620154}, 980.1225834129999)
DEBUG flwr 2024-01-14 09:12:39,284 | server.py:173 | evaluate_round 9: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 9: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:12:44,450 | server.py:187 | evaluate_round 9 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 9 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:12:44,458 | server.py:222 | fit_round 10: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 10: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:14:21,291 | server.py:236 | fit_round 10 received 10 results and 0 failures
DEBUG:flwr:fit_round 10 received 10 results and 0 failures


Saving round 10 aggregated_parameters...
Loading pre-trained model from:  ./model_round_10.pth


INFO flwr 2024-01-14 09:14:27,278 | server.py:125 | fit progress: (10, 315.17247998714447, {'accuracy': 0.6782945736434108}, 1088.1198106729998)
INFO:flwr:fit progress: (10, 315.17247998714447, {'accuracy': 0.6782945736434108}, 1088.1198106729998)
DEBUG flwr 2024-01-14 09:14:27,285 | server.py:173 | evaluate_round 10: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 10: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:14:34,125 | server.py:187 | evaluate_round 10 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 10 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:14:34,132 | server.py:222 | fit_round 11: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 11: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:16:12,421 | server.py:236 | fit_round 11 received 10 results and 0 failures
DEBUG:flwr:fit_round 11 received 10 results and 0 failures


Saving round 11 aggregated_parameters...
Loading pre-trained model from:  ./model_round_11.pth


INFO flwr 2024-01-14 09:16:17,846 | server.py:125 | fit progress: (11, 314.9377464056015, {'accuracy': 0.5813953488372093}, 1198.687823226)
INFO:flwr:fit progress: (11, 314.9377464056015, {'accuracy': 0.5813953488372093}, 1198.687823226)
DEBUG flwr 2024-01-14 09:16:17,854 | server.py:173 | evaluate_round 11: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 11: strategy sampled 33 clients (out of 33)


GLOBAL TEST


DEBUG flwr 2024-01-14 09:16:24,060 | server.py:187 | evaluate_round 11 received 33 results and 0 failures
DEBUG:flwr:evaluate_round 11 received 33 results and 0 failures
DEBUG flwr 2024-01-14 09:16:24,066 | server.py:222 | fit_round 12: strategy sampled 10 clients (out of 33)
DEBUG:flwr:fit_round 12: strategy sampled 10 clients (out of 33)
DEBUG flwr 2024-01-14 09:17:58,704 | server.py:236 | fit_round 12 received 10 results and 0 failures
DEBUG:flwr:fit_round 12 received 10 results and 0 failures


Saving round 12 aggregated_parameters...
Loading pre-trained model from:  ./model_round_12.pth


INFO flwr 2024-01-14 09:18:05,218 | server.py:125 | fit progress: (12, 314.6347041130066, {'accuracy': 0.4844961240310077}, 1306.0602916369999)
INFO:flwr:fit progress: (12, 314.6347041130066, {'accuracy': 0.4844961240310077}, 1306.0602916369999)
DEBUG flwr 2024-01-14 09:18:05,226 | server.py:173 | evaluate_round 12: strategy sampled 33 clients (out of 33)
DEBUG:flwr:evaluate_round 12: strategy sampled 33 clients (out of 33)


GLOBAL TEST


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

print(f"{history.metrics_centralized = }")
# print(f"{history.metrics_distributed = }")


global_accuracy_centralised = history.metrics_centralized["accuracy"]
# global_accuracy_centralised = history.metrics_distributed["accuracy"]

round = [data[0] for data in global_accuracy_centralised]
acc = [data[1] for data in global_accuracy_centralised]
plt.plot(round, acc)
plt.grid()
plt.ylabel("Accuracy (%)")
plt.xlabel("Round")
plt.title("MNIST - IID - 30 clients with 10 clients per round")
xticks_result = plt.xticks(range(1, 21))
# plt.yticks(range(0, 100))