<a href="https://colab.research.google.com/github/ifran-rahman/Federated-ECG/blob/review/Copy_of_FL_Simulation_TNR_Lab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only: mount Google Drive so the dataset CSVs read later from
# /content/drive/MyDrive/... are available to this runtime.
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# import sys
# sys.path.append('client_train.py ')
# sys.path.append('client.py')
# sys.path.append('pre_n_post_process.py')
# sys.path.append('server_train.py')


In [None]:
# !pip install protobuf==3.20.3

In [None]:
!pip install h5py
!pip install typing-extensions
!pip install wheel



In [None]:
!pip install -U flwr["simulation"]



In [None]:
# tracemalloc traces Python memory allocations; the last cell of the notebook
# reads the current/peak figures collected from this point on.
import tracemalloc

# time.time() is used around the simulation call to measure its wall-clock duration.
import time

# Start tracing now so allocations made by every following cell are counted.
tracemalloc.start()

In [None]:
# we naturally first need to import torch and torchvision
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import nn,optim
import sys
from tqdm import tqdm
import pandas as pd
import numpy as np

def prepare__dataset(abnormal: pd.DataFrame, normal: pd.DataFrame, val_split_factor: float) -> tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, dict]:
    """Build shuffled training and validation datasets for ECG classification.

    Column ``187`` is dropped from both frames and the remaining columns become
    the features. Rows of ``abnormal`` are labelled 1 and rows of ``normal`` are
    labelled 0.

    Args:
        abnormal (pd.DataFrame): DataFrame containing abnormal ECG data.
        normal (pd.DataFrame): DataFrame containing normal ECG data.
        val_split_factor (float): Fraction (0..1) of the data to use for validation.

    Returns:
        tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, dict]: The training
            dataset, the validation dataset, and a dictionary with the number of
            examples in each (keys ``'trainset'`` and ``'testset'``).

    Raises:
        ValueError: If ``val_split_factor`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_split_factor <= 1.0:
        raise ValueError(f"val_split_factor must be in [0, 1], got {val_split_factor}")

    abnormal = abnormal.drop([187], axis=1)
    normal = normal.drop([187], axis=1)

    # Features: abnormal rows first, then normal rows; labels follow the same order.
    features = pd.concat([abnormal, normal], sort=True).to_numpy()
    labels = np.concatenate([np.ones(abnormal.shape[0]), np.zeros(normal.shape[0])])

    full_dataset = torch.utils.data.TensorDataset(torch.from_numpy(features).float(),
                                                  torch.from_numpy(labels).long())

    # Calculate the lengths for training and validation sets
    full_len = len(full_dataset)
    val_len = int(full_len * val_split_factor)
    train_len = full_len - val_len

    # Shuffle once with a random permutation, then cut it into two disjoint index sets.
    indices = torch.randperm(full_len).tolist()

    # Return real Dataset objects (index-able, usable with DataLoader). Materialising
    # one giant batch through a DataLoader would yield a plain [x, y] list instead and
    # would fail for an empty validation split (batch_size=0 is invalid).
    train_dataset = torch.utils.data.Subset(full_dataset, indices[:train_len])
    val_dataset = torch.utils.data.Subset(full_dataset, indices[train_len:])

    num_examples = {'trainset': train_len,
                    'testset': val_len}

    return train_dataset, val_dataset, num_examples

In [None]:
# hyperparameters
# NOTE: the client-training cell further down assigns batch_size, lr, epochs,
# val_split_factor and device again, so those later values win once it has run.
batch_size=500
lr = 3e-3
epochs = 21
# fraction of all samples held out when the dataset is split
val_split_factor = 0.2
# first CUDA GPU when available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Folder on the mounted Google Drive that holds the PTB ECG CSV files.
root = '/content/drive/MyDrive/Federated-ECG/ECG_Classification/client/datasets/'
# load dataset (header=None: the files have no header row, so columns are numbered from 0)
abnormal = pd.read_csv(root +'ptbdb_abnormal.csv', header = None)
normal = pd.read_csv(root + 'ptbdb_normal.csv', header = None)

# Split into train/validation parts; the example-count dictionary is not needed here.
train_dataset, val_dataset, _ = prepare__dataset(abnormal=abnormal, normal=normal, val_split_factor=val_split_factor)


In [None]:
len(abnormal)

10518

In [None]:
abnormal.shape, normal.shape

((10518, 188), (4052, 188))

In [None]:
import torch.nn as nn
import torch.nn.functional as F


#define the ecg_net model
class ecg_net(nn.Module):
    """1-D CNN classifier for fixed-length ECG beats.

    Three ``Conv1d`` (kernel 3, padding 1) + ``MaxPool1d(2)`` stages feed a
    two-layer fully connected head. The convolutions keep the signal length and
    each pooling halves it (rounding down), so the flattened feature size is
    ``128 * (input_length // 8)``, which is 2944 for the default 187 samples.

    Args:
        num_of_class: Number of output logits.
        input_length: Number of samples per input signal. Defaults to 187, which
            reproduces the original hard-coded ``nn.Linear(2944, 500)`` layer.

    Raises:
        ValueError: If ``input_length`` is shorter than 8 samples (nothing would
            be left after the three pooling stages).
    """

    def __init__(self, num_of_class, input_length=187):
        super().__init__()

        if input_length < 8:
            raise ValueError(f"input_length must be at least 8, got {input_length}")

        # Attribute names and layer order are kept so state_dict keys stay the same.
        self.model = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),

            nn.Conv1d(16, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),
        )

        # Three floor-halvings of the length equal one floor-division by 8.
        flat_features = 128 * (input_length // 8)

        self.linear = nn.Sequential(
            nn.Linear(flat_features, 500),
            nn.LeakyReLU(inplace=True),
            nn.Linear(500, num_of_class),
        )

    def forward(self, x):
        """Map a batch of signals ``[b, input_length]`` to logits ``[b, num_of_class]``."""
        x = x.unsqueeze(1)          # add the channel axis: [b, 1, input_length]
        x = self.model(x)           # [b, 128, input_length // 8]
        x = x.view(x.size(0), -1)   # flatten to [b, 128 * (input_length // 8)]
        x = self.linear(x)

        return x

In [None]:
# def evalute(model, loader):
#     model.eval()

#     correct = 0
#     total = len(loader)
#     val_bar = tqdm(loader, file=sys.stdout)
#     for x, y in val_bar:
#         x, y = x.to(device), y.to(device)
#         with torch.no_grad():
#             logits = model(x)
#             pred = logits.argmax(dim=1)
#         correct += torch.eq(pred, y).sum().float().item()

#     return correct / total

## One Client, One Data Partition

To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.

In [None]:
from torch.utils.data import random_split


def get_dataset(abnormal, normal, val_split_factor):
    """Combine the two ECG frames into one labelled dataset and split it randomly.

    Column ``187`` is dropped from both frames. Rows from ``abnormal`` get label 1,
    rows from ``normal`` get label 0.

    Args:
        abnormal: DataFrame of abnormal recordings (must contain a column ``187``).
        normal: DataFrame of normal recordings (must contain a column ``187``).
        val_split_factor: Fraction of all rows placed in the second split.

    Returns:
        ``(train_dataset, val_dataset, num_examples)`` where the datasets come from
        ``torch.utils.data.random_split`` and ``num_examples`` holds the two sizes
        under the keys ``'trainset'`` and ``'testset'``.
    """
    abnormal_features = abnormal.drop([187], axis=1)
    normal_features = normal.drop([187], axis=1)

    # Stack features (abnormal first) and build the matching label vector.
    features = pd.concat([abnormal_features, normal_features], sort=True).to_numpy()
    labels = np.concatenate([
        np.ones(abnormal_features.shape[0]),
        np.zeros(normal_features.shape[0]),
    ])

    full_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(features).float(),
        torch.from_numpy(labels).long(),
    )

    total_rows = features.shape[0]
    val_len = int(total_rows * val_split_factor)
    train_len = total_rows - val_len

    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_len, val_len]
    )

    return train_dataset, val_dataset, {'trainset': train_len, 'testset': val_len}


# define a Dataloader function
def my_DataLoader(train_root, test_root, batch_size = 100, val_split_factor = 0.2):
    """Read two header-less CSV files and wrap them in shuffled DataLoaders.

    In each file the last column is the integer label and all other columns are
    the features. The training file is split randomly into a training and a
    validation part.

    Args:
        train_root: Path of the training CSV.
        test_root: Path of the test CSV.
        batch_size: Batch size used by all three loaders.
        val_split_factor: Fraction of the training rows moved to validation.

    Returns:
        ``(train_loader, val_loader, test_loader, num_examples)``; ``num_examples``
        stores the training size under ``'trainset'`` and the validation size
        under ``'testset'``.
    """
    train_table = pd.read_csv(train_root, header=None).to_numpy()
    test_table = pd.read_csv(test_root, header=None).to_numpy()

    def _to_dataset(table):
        # features = every column but the last, label = last column
        return torch.utils.data.TensorDataset(
            torch.from_numpy(table[:, :-1]).float(),
            torch.from_numpy(table[:, -1]).long(),
        )

    full_train = _to_dataset(train_table)
    test_dataset = _to_dataset(test_table)

    n_rows = train_table.shape[0]
    val_len = int(n_rows * val_split_factor)
    train_len = n_rows - val_len

    train_part, val_part = torch.utils.data.random_split(full_train, [train_len, val_len])

    train_loader, val_loader, test_loader = (
        DataLoader(part, batch_size=batch_size, shuffle=True)
        for part in (train_part, val_part, test_dataset)
    )

    return train_loader, val_loader, test_loader, {'trainset': train_len, 'testset': val_len}


def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):
    """Create per-client train/validation loaders plus one hold-out test loader.

    The data comes from ``get_dataset`` called with the notebook-level
    ``abnormal``, ``normal`` and ``val_split_factor`` variables. Its first split
    is cut into ``num_partitions`` disjoint parts (one per client, the last part
    absorbs the remainder), and each part is split again into train/validation
    according to ``val_ratio``. Both levels of splitting use a generator seeded
    with 2023.

    Args:
        num_partitions: Number of client partitions to create.
        batch_size: Batch size of the per-client loaders.
        val_ratio: Fraction of each partition reserved for validation.

    Returns:
        ``(trainloaders, valloaders, testloader)``: two lists with one loader per
        client, and a loader (batch size 128) over the second split returned by
        ``get_dataset``.
    """
    # get the dataset
    trainset, testset, _ = get_dataset(abnormal=abnormal, normal=normal, val_split_factor=val_split_factor)

    print(len(trainset))
    print(len(testset))

    # equal-sized partitions; the final one takes whatever is left over
    samples_per_client = len(trainset) // num_partitions
    partition_len = [samples_per_client] * num_partitions
    partition_len[-1] = len(trainset) - sum(partition_len[:-1])
    print(partition_len)

    client_parts = random_split(
        trainset, partition_len, torch.Generator().manual_seed(2023)
    )

    trainloaders, valloaders = [], []
    for client_part in client_parts:
        num_val = int(val_ratio * len(client_part))
        num_train = len(client_part) - num_val

        fit_part, holdout_part = random_split(
            client_part, [num_train, num_val], torch.Generator().manual_seed(2023)
        )
        trainloaders.append(
            DataLoader(fit_part, batch_size=batch_size, shuffle=True, num_workers=0)
        )
        valloaders.append(
            DataLoader(holdout_part, batch_size=batch_size, shuffle=False, num_workers=0)
        )

    # loader over the hold-out split
    testloader = DataLoader(testset, batch_size=128)

    # length of one signal, taken from the first training sample
    datapoint_count = len(trainset[0][0])
    print('Minimum DataPoint required for a signal :', datapoint_count)
    return trainloaders, valloaders, testloader

In [None]:
import matplotlib.pyplot as plt
# Number of simulated federated clients; one data partition is created per client.
NUM_CLIENTS = 2

# One train/validation loader pair per client plus a single hold-out test loader.
trainloaders, valloaders, testloader = prepare_dataset(num_partitions=NUM_CLIENTS, batch_size=20, val_ratio=0.1)

11656
2914
[5828, 5828]
Minimum DataPoint required for a signal : 187


In [None]:
# !pip install -U flwr["simulation"]

In [None]:
import flwr as fl

In [None]:
# !cp /content/drive/MyDrive/TNR\ Lab/Federated-ECG/simulate_fl/client_train.py /content

In [None]:
import torch
from torch.utils.data import DataLoader
from torch import nn,optim
import sys
from tqdm import tqdm
import pandas as pd
import numpy as np

# define a Dataloader function
def my_DataLoader(train_root, test_root, batch_size = 100, val_split_factor = 0.2):
    """Load train/test CSVs (no header, label in the last column) into DataLoaders.

    The training rows are randomly divided into a training and a validation
    subset; all three resulting loaders shuffle their data.

    Args:
        train_root: Path of the training CSV.
        test_root: Path of the test CSV.
        batch_size: Batch size shared by the three loaders.
        val_split_factor: Share of training rows used for validation.

    Returns:
        ``(train_loader, val_loader, test_loader, num_examples)`` with
        ``num_examples = {'trainset': <train size>, 'testset': <validation size>}``.
    """
    train_array = pd.read_csv(train_root, header=None).to_numpy()
    test_array = pd.read_csv(test_root, header=None).to_numpy()

    # Split each array into float features and integer labels.
    train_x = torch.from_numpy(train_array[:, :-1]).float()
    train_y = torch.from_numpy(train_array[:, -1]).long()
    test_x = torch.from_numpy(test_array[:, :-1]).float()
    test_y = torch.from_numpy(test_array[:, -1]).long()

    whole_train = torch.utils.data.TensorDataset(train_x, train_y)
    test_dataset = torch.utils.data.TensorDataset(test_x, test_y)

    row_count = train_array.shape[0]
    val_len = int(row_count * val_split_factor)
    train_len = row_count - val_len

    train_subset, val_subset = torch.utils.data.random_split(whole_train, [train_len, val_len])

    loader_options = dict(batch_size=batch_size, shuffle=True)
    train_loader = DataLoader(train_subset, **loader_options)
    val_loader = DataLoader(val_subset, **loader_options)
    test_loader = DataLoader(test_dataset, **loader_options)

    num_examples = {'trainset': train_len, 'testset': val_len}
    return train_loader, val_loader, test_loader, num_examples


#define the ecg_net model
class ecg_net(nn.Module):
    """1-D CNN classifier for fixed-length ECG beats.

    Three ``Conv1d`` (kernel 3, padding 1) + ``MaxPool1d(2)`` stages feed a
    two-layer fully connected head. The convolutions keep the signal length and
    each pooling halves it (rounding down), so the flattened feature size is
    ``128 * (input_length // 8)``, which is 2944 for the default 187 samples.

    Args:
        num_of_class: Number of output logits.
        input_length: Number of samples per input signal. Defaults to 187, which
            reproduces the original hard-coded ``nn.Linear(2944, 500)`` layer.

    Raises:
        ValueError: If ``input_length`` is shorter than 8 samples (nothing would
            be left after the three pooling stages).
    """

    def __init__(self, num_of_class, input_length=187):
        super().__init__()

        if input_length < 8:
            raise ValueError(f"input_length must be at least 8, got {input_length}")

        # Attribute names and layer order are kept so state_dict keys stay the same.
        self.model = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),

            nn.Conv1d(16, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),

            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),
        )

        # Three floor-halvings of the length equal one floor-division by 8.
        flat_features = 128 * (input_length // 8)

        self.linear = nn.Sequential(
            nn.Linear(flat_features, 500),
            nn.LeakyReLU(inplace=True),
            nn.Linear(500, num_of_class),
        )

    def forward(self, x):
        """Map a batch of signals ``[b, input_length]`` to logits ``[b, num_of_class]``."""
        x = x.unsqueeze(1)          # add the channel axis: [b, 1, input_length]
        x = self.model(x)           # [b, 128, input_length // 8]
        x = x.view(x.size(0), -1)   # flatten to [b, 128 * (input_length // 8)]
        x = self.linear(x)

        return x


# hyperparameters
# These assignments replace the values set in the earlier hyperparameter cell
# (batch_size=500, epochs=21); train_client() and main() read lr, epochs,
# batch_size and val_split_factor from these module-level names.
batch_size=1
lr = 3e-3
epochs = 10
val_split_factor = 0.2
# Seed torch's global RNG (affects weight initialisation, random_split and shuffling).
torch.manual_seed(1234)

# first CUDA GPU when available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

def evaluate_model(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and moved to ``device``. Accuracy is the
    number of correct arg-max predictions divided by ``len(loader.dataset)``.

    Args:
        model: Network producing one row of logits per input sample.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        device: Device on which the evaluation runs.

    Returns:
        Fraction of correctly classified samples.
    """
    model.eval()
    model.to(device)

    num_samples = len(loader.dataset)
    hits = 0
    with torch.no_grad():
        for features, targets in tqdm(loader, file=sys.stdout):
            features = features.to(device)
            targets = targets.to(device)
            predictions = model(features).argmax(dim=1)
            hits += (predictions == targets).sum().float().item()

    return hits / num_samples

def train_client(model, train_loader, valid_loader, epochs=1):
    """Train ``model`` in place and report validation accuracy after every epoch.

    Uses Adam with the module-level learning rate ``lr`` and cross-entropy loss.
    Training runs on the device the model's parameters already live on, so the
    caller decides the device by placing the model.

    Args:
        model: Network to train; it is modified in place.
        train_loader: DataLoader yielding ``(x, y)`` training batches.
        valid_loader: DataLoader used to measure accuracy after each epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress, validation accuracy and the best epoch are printed.
    """
    # Follow the model's own device instead of a global, so data and weights
    # can never end up on different devices.
    device = next(model.parameters()).device

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0

    for epoch in range(epochs):
        # evaluate_model() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()

        train_bar = tqdm(train_loader, file=sys.stdout)
        for x, y in train_bar:
            # x: [b, 187], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss.item())
            global_step += 1

        # Validate after every epoch. The previous call targeted `evalute`, which
        # is only present as commented-out code, and therefore raised NameError.
        val_acc = evaluate_model(model, valid_loader, device)

        print('val_acc = ',val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc

        print("Global steps", global_step)

    print('best acc:', best_acc, 'best epoch:', best_epoch)

# def validate(model, testloader, criterion):
#     return 0,0
def validate(model, testloader, criterion):
    """Compute the sample-weighted mean loss and the accuracy of ``model``.

    The model is put into eval mode and batches are moved to the device of the
    model's first parameter.

    Args:
        model: Network producing one row of logits per input sample.
        testloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``; its
            result is multiplied by the batch size before being accumulated.

    Returns:
        ``(avg_loss, accuracy)`` over all samples seen.
    """
    model.eval()
    target_device = next(model.parameters()).device  # evaluate where the model lives

    loss_sum, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)

            outputs = model(inputs)
            batch_loss = criterion(outputs, labels)
            loss_sum += batch_loss.item() * inputs.size(0)

            predicted = torch.max(outputs, 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()

    return loss_sum / seen, hits / seen


def prepare__dataset(abnormal, normal, val_split_factor):
    """Merge abnormal/normal ECG frames into a labelled dataset and split it.

    Column ``187`` is removed from both frames; abnormal rows are labelled 1 and
    normal rows 0. The merged dataset is divided with
    ``torch.utils.data.random_split``.

    Args:
        abnormal: DataFrame of abnormal recordings (must contain a column ``187``).
        normal: DataFrame of normal recordings (must contain a column ``187``).
        val_split_factor: Fraction of all rows assigned to the validation split.

    Returns:
        ``(train_dataset, val_dataset, num_examples)``; ``num_examples`` maps
        ``'trainset'`` and ``'testset'`` to the two split sizes.
    """
    signals_abnormal = abnormal.drop([187], axis=1)
    signals_normal = normal.drop([187], axis=1)

    # labels in the same order as the concatenated features: ones, then zeros
    targets = np.concatenate((np.ones(signals_abnormal.shape[0]),
                              np.zeros(signals_normal.shape[0])))
    signals = pd.concat([signals_abnormal, signals_normal], sort=True).to_numpy()

    combined = torch.utils.data.TensorDataset(torch.from_numpy(signals).float(),
                                              torch.from_numpy(targets).long())

    n_total = signals.shape[0]
    val_len = int(n_total * val_split_factor)
    train_len = n_total - val_len

    train_dataset, val_dataset = torch.utils.data.random_split(combined, [train_len, val_len])

    num_examples = {'trainset': train_len, 'testset': val_len}
    return train_dataset, val_dataset, num_examples


def main():
    """Train ``ecg_net`` centrally on the PTB CSV files and keep the best weights.

    Reads ``datasets/ptbdb_abnormal.csv`` and ``datasets/ptbdb_normal.csv``,
    splits them with ``prepare__dataset``, trains for the module-level ``epochs``
    using ``batch_size``, ``lr`` and ``device``, and writes the state dict of the
    best-validating epoch to ``best.mdl``.
    """
    # load dataset
    abnormal = pd.read_csv('datasets/ptbdb_abnormal.csv', header = None)
    normal = pd.read_csv('datasets/ptbdb_normal.csv', header = None)

    train_dataset, val_dataset, _ = prepare__dataset(abnormal=abnormal, normal=normal, val_split_factor=val_split_factor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    model = ecg_net(2).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0

    for epoch in range(epochs):
        # evaluate_model() leaves the model in eval mode; switch back each epoch.
        model.train()

        train_bar = tqdm(train_loader, file=sys.stdout)
        for x, y in train_bar:
            # x: [b, 187], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss.item())
            global_step += 1

        # Validate after every epoch. The previous call targeted `evalute`, which
        # is only present as commented-out code, and therefore raised NameError.
        val_acc = evaluate_model(model, val_loader, device)

        print('val_acc = ',val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc

            # checkpoint the best weights seen so far
            torch.save(model.state_dict(), 'best.mdl')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

using cuda:0 device.


In [None]:
from collections import OrderedDict
from typing import Dict, List, Tuple
import torch
from flwr.common import NDArrays, Scalar
# from client_train import *

class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates an ``ecg_net`` on one data partition."""

    def __init__(self, trainloader, vallodaer) -> None:
        """Store the client's loaders and build a fresh two-class model.

        Args:
            trainloader: DataLoader over this client's training split.
            vallodaer: DataLoader over this client's validation split. The
                misspelled name is kept because callers pass it by keyword.
        """
        super().__init__()

        self.trainloader = trainloader
        self.valloader = vallodaer
        self.model = ecg_net(2)
        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)  # send model to device

    def set_parameters(self, parameters):
        """Overwrite the local model's weights with the arrays sent by the server.

        ``parameters`` must follow the order of ``self.model.state_dict()``.
        """
        params_dict = zip(self.model.state_dict().keys(), parameters)
        # torch.tensor keeps each array's dtype and shape (including 0-d arrays),
        # matching how the server-side code in this notebook rebuilds state dicts.
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        # now replace the parameters
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Return all entries of the model's state dict as a list of NumPy
        arrays, in state-dict order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        """Train the local model starting from the server's parameters.

        Returns:
            The updated parameters, the number of training examples of this
            client, and an empty metrics dict.
        """
        # copy parameters sent by the server into client's local model
        self.set_parameters(parameters)

        # Local epochs are fixed here; `config` is currently not consulted.
        epochs = 5  # config["epochs"]

        # do local training (train_client builds its own optimizer)
        train_client(self.model, self.trainloader, self.valloader, epochs=epochs)

        # Report the number of *examples*: len(DataLoader) would be the number of
        # batches, while the server uses this value as the example count.
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict[str, Scalar]]:
        """Evaluate the server's parameters on this client's validation data.

        Returns:
            The sample-weighted mean cross-entropy loss, the number of evaluated
            examples, and ``{"accuracy": ...}``.
        """
        # Update local model with parameters received from the server
        self.set_parameters(parameters)

        criterion = torch.nn.CrossEntropyLoss()
        total_loss = 0.0
        correct = 0
        total = 0

        # Switch the model to evaluation mode
        self.model.eval()

        # Disable gradient calculation during evaluation
        with torch.no_grad():
            for inputs, labels in self.valloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                # Weight each batch mean by its size so a smaller final batch
                # does not skew the average.
                total_loss += loss.item() * labels.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Per-sample average loss and accuracy
        avg_loss = total_loss / total
        accuracy = correct / total

        # Return evaluation results
        return avg_loss, total, {"accuracy": accuracy}


In [None]:
# Notebook-level model instance: the server-side evaluate() function and
# SaveModelStrategy load the aggregated weights into this object.
model = ecg_net(2).to(device=device)
# Client built on the first partition. NOTE(review): the simulation creates its own
# clients through client_fn, so this instance looks like a manual smoke test only.
client = FlowerClient(trainloaders[0], valloaders[0])

In [None]:
from collections import OrderedDict
from typing import Dict, Optional, Tuple, List, Union

import flwr as fl
from flwr.common import (
    Metrics,
    Scalar,
)

def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Server-side evaluation of the aggregated parameters.

    Loads ``parameters`` into the notebook-level ``model`` and measures
    cross-entropy loss and accuracy over every loader in the notebook-level
    ``valloaders`` list.

    Args:
        server_round: Current round number (not used in the computation).
        parameters: Weight arrays ordered like ``model.state_dict()``.
        config: Evaluation configuration (not used).

    Returns:
        ``(avg_loss, {"accuracy": accuracy})`` averaged per sample.
    """
    # Load the latest global weights into the shared model.
    layer_names = model.state_dict().keys()
    state_dict = OrderedDict(
        (name, torch.tensor(array)) for name, array in zip(layer_names, parameters)
    )
    model.load_state_dict(state_dict, strict=True)

    criterion = nn.CrossEntropyLoss()
    loss_sum, hits, seen = 0.0, 0, 0

    device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        # Walk through the batches of every client's validation loader in turn.
        all_batches = (batch for loader in valloaders for batch in loader)
        for inputs, labels in all_batches:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            predicted = torch.max(outputs, 1)[1]
            hits += (predicted == labels).sum().item()
            seen += labels.size(0)

    return loss_sum / seen, {"accuracy": hits / seen}

# You can now remove the separate 'validate' function as the logic is within 'evaluate'

def get_evaluate_fn(testloader):
    """Build a server-side evaluation callback bound to ``testloader``.

    The returned ``evaluate_fn`` can be passed to a strategy as ``evaluate_fn``.
    When called, it loads the given parameters into the notebook-level ``model``
    and evaluates it on ``testloader`` with ``validate``.

    Args:
        testloader: DataLoader yielding ``(inputs, labels)`` batches of the
            hold-out data.

    Returns:
        A function ``(server_round, parameters, config) -> (loss, {"accuracy": acc})``.
    """
    def evaluate_fn(
        server_round: int,
        parameters: fl.common.NDArrays,
        config: Dict[str, fl.common.Scalar],
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        # Update model with the latest parameters
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict, strict=True)

        # Evaluate on the bound test loader (previously a stub returning 0, 0).
        criterion = nn.CrossEntropyLoss()
        loss, accuracy = validate(model, testloader, criterion)
        return loss, {"accuracy": accuracy}

    return evaluate_fn

def fit_config(server_round: int):
    """Return the training configuration dict sent to clients for a round.

    The batch size is always 1. The first round (``server_round < 2``) uses one
    local epoch; every later round uses two.
    """
    local_epochs = 2 if server_round >= 2 else 1
    return {"batch_size": 1, "local_epochs": local_epochs}
def weighted_average(metrics: List[Tuple[int, "Metrics"]]) -> "Metrics":
    """Aggregate per-client accuracies into one example-weighted average.

    Intended as an aggregation function for (federated) evaluation metrics, i.e.
    those returned by the client's ``evaluate()`` method.

    Args:
        metrics: One ``(num_examples, {"accuracy": value, ...})`` pair per client.

    Returns:
        ``{"accuracy": weighted_mean}``; the accuracy is ``0.0`` when no examples
        were reported at all, instead of dividing by zero.
    """
    # `Metrics` is written as a forward reference so that defining this function
    # does not fail when the name has not been imported yet.
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {"accuracy": 0.0}

    # Multiply accuracy of each client by number of examples used
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}

class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg variant that writes the aggregated weights to disk after each round."""

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes], BaseException]],
    ) -> Tuple[Optional[fl.common.Parameters], Dict[str, Scalar]]:
        """Aggregate client results with FedAvg and persist the outcome.

        When aggregation yields parameters, they are saved as
        ``round-<n>-weights.npz``, loaded into the notebook-level ``model`` and
        saved again as ``model_round_<n>.pth``.

        Returns:
            The ``(parameters, metrics)`` pair produced by ``FedAvg.aggregate_fit``,
            unchanged.
        """
        # Let the base class (FedAvg) do the actual aggregation.
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        # Nothing to persist when aggregation produced no parameters.
        if aggregated_parameters is None:
            return aggregated_parameters, aggregated_metrics

        print(f"Saving round {server_round} aggregated_ndarrays...")

        # `Parameters` -> list of NumPy arrays, stored as an .npz archive
        aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
        np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)

        # list of NumPy arrays -> PyTorch state_dict, loaded into the shared model
        layer_names = model.state_dict().keys()
        state_dict = OrderedDict(
            (name, torch.tensor(array)) for name, array in zip(layer_names, aggregated_ndarrays)
        )
        model.load_state_dict(state_dict, strict=True)

        # Save the model
        torch.save(model.state_dict(), f"model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics

# Strategy used by the simulation: FedAvg with per-round checkpointing.
# - every available client is sampled for training (fraction_fit=1.0), and at
#   least two clients must be available and participate
# - `evaluate` runs on the server after each round
# - `weighted_average` aggregates the {"accuracy": ...} metrics that clients
#   return from their evaluate() method
strategy = SaveModelStrategy(
    fraction_fit=1.0,
    min_fit_clients=2,
    min_available_clients=2,
    evaluate_fn=evaluate,
    on_fit_config_fn=fit_config,
    evaluate_metrics_aggregation_fn=weighted_average,
)

NameError: name 'Metrics' is not defined

In [None]:
def generate_client_fn(trainloaders, valloaders):
    """Create the client factory used by the simulation.

    Args:
        trainloaders: Sequence of training loaders, indexed by client id.
        valloaders: Sequence of validation loaders, indexed by client id.

    Returns:
        A function mapping a client id string to a ``FlowerClient`` built from
        the loaders at that index.
    """
    def client_fn(cid: str):
        """Returns a FlowerClient containing the cid-th data partition"""
        partition = int(cid)
        return FlowerClient(
            trainloader=trainloaders[partition],
            vallodaer=valloaders[partition],
        )

    return client_fn


client_fn_callback = generate_client_fn(trainloaders, valloaders)

# With a dictionary, you tell Flower's VirtualClientEngine that each
# client needs exclusive access to these many resources in order to run
client_resources = {"num_cpus": 1, "num_gpus": 1}

# Time the whole simulation; the result is printed in the next cell.
start_training = time.time()
x = fl.simulation.start_simulation(
    client_fn=client_fn_callback,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
    # reuse the dictionary defined above instead of repeating the literal
    client_resources=client_resources,
    ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 1}
)
end_training = time.time()


In [None]:
# Wall-clock duration of the federated simulation (timestamps taken around start_simulation).
print('Total Time', end_training - start_training)

# get_traced_memory() returns (current, peak) traced sizes in bytes; convert both to MiB.
current, peak = tracemalloc.get_traced_memory()
current_memory = current/(1024*1024)
peak_memory = peak/(1024*1024)

print('Current memory [MB]: {}, peak memory [MB]: {}'.format(current_memory, peak_memory))

# Stop tracing memory allocations now that the figures have been reported.
tracemalloc.stop()