In [1]:
# !pip install -q flwr
# !pip install -U ipywidgets
# !pip install -U flwr["simulation"]

In [2]:
import medmnist
from medmnist import INFO, Evaluator
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda
import torch.utils.data as data
from torchvision import models
import torch
import tqdm
from collections import OrderedDict
# from typing import List, Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np 
from torch.utils.data import DataLoader, random_split
import flwr as fl
from torch import nn
import torch.nn.functional as F
from flwr.common import Metrics


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


## Dataset Loading

In [3]:
# def load_data(batch_size=64):
#     data_transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[.5], std=[.5])
#     ])
#     data_flag = 'chestmnist'
#     info = INFO[data_flag]
#     DataClass = getattr(medmnist, info['python_class'])

#     train_set = DataClass(split='train', transform=data_transform, target_transform=Lambda(lambda y: y[0]), download=True)
#     test_set = DataClass(split='test', transform=data_transform, target_transform=Lambda(lambda y: y[0]), download=True)
#     # val_set = DataClass(split='val', transform=data_transform, target_transform=Lambda(lambda y: y[0]), download=True)
    
#     train_loader = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
#     test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
#     # valid_loader = data.DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)

#     num_examples = {"trainset" : len(train_set), "testset" : len(test_set)}
#     return train_loader, test_loader, num_examples


def load_datasets(num_clients: int):
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])
    ])
    # transform = transforms.Compose(
    #     [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    # )
    data_flag = 'chestmnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    trainset = DataClass(split='train', transform=transform, target_transform=Lambda(lambda y: y[0]), download=True)
    testset = DataClass(split='test', transform=transform, target_transform=Lambda(lambda y: y[0]), download=True)
    valset = DataClass(split='val', transform=transform, target_transform=Lambda(lambda y: y[0]), download=True)

    #  Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    if partition_size * num_clients == len(valset):
        lengths = [partition_size] * num_clients 
    else: 
         lengths = [partition_size] * (num_clients-1)  + [partition_size+ (len(trainset) % num_clients)]
    train_datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
    
    partition_size = len(valset) // num_clients
    if partition_size * num_clients == len(valset):
        lengths = [partition_size] * num_clients 
    else: 
         lengths = [partition_size] * (num_clients-1)  + [partition_size+ (len(valset) % num_clients)]
    # print(len(valset), lengths)
    val_datasets = random_split(valset, lengths, torch.Generator().manual_seed(42))
    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []

    for t_ds in train_datasets:
        # len_train = len(t_ds) 
        # t_lengths = [len_train]
        # ds_train = random_split(t_ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(t_ds, batch_size=32, shuffle=True))


    for v_ds in val_datasets:
        # len_val = len(v_ds) // 10  # 10 % validation set
        # v_lengths = [len_val]
        # ds_val = random_split(v_ds, lengths, torch.Generator().manual_seed(42))
        valloaders.append(DataLoader(v_ds, batch_size=32))
        
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader



## Define Train and Test

In [4]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set."""

    net.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    # optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    accuracy, loss = 0, 0
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for i, batch in enumerate(trainloader):
            # print(i)
            images, labels = batch
            optimizer.zero_grad()
            images = images.expand(-1, 3, -1, -1)
            outputs = net(images)
            labels = labels.long()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            predicted = torch.argmax(outputs.detach().cpu(), axis=1)
            correct += (predicted == labels).sum().item()
            # correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        # if verbose:
        # print(f"Epoch {epoch}: train loss {epoch_loss}, accuracy {epoch_acc}")
    return epoch_loss, epoch_acc 

def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch
            images = images.expand(-1, 3, -1, -1)
            labels = labels.long()
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            # print(outputs)
            # _, predicted = torch.max(outputs.data, 1)
            predicted = torch.argmax(outputs.detach().cpu(), axis=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy            

## Model 

In [5]:

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 14)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # print(x.shape)
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # x = F.softmax(x)
        return x


class Net2(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(6 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 14)
        # self.fc3 = nn.Linear(84, 10)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        # x = self.pool(F.relu(self.conv2(x)))

        # print(x.shape)
        x = x.view(-1, 6 * 4 * 4)
        x = F.relu(self.fc1(x))
        # x = F.relu(self.fc2(x))
        x = self.fc2(x)
        # x = F.softmax(x)
        return x
# Load model and data
# net = Net().to(DEVICE)

In [6]:
# class MedmnistClient(fl.client.NumPyClient):
#     def get_parameters(self, config):
#         return [val.cpu().numpy() for _, val in model.state_dict().items()]

#     def set_parameters(self, parameters):
#         params_dict = zip(model.state_dict().keys(), parameters)
#         state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
#         model.load_state_dict(state_dict, strict=True)

#     def fit(self, parameters, config):
#         self.set_parameters(parameters)
#         train(model, train_loader, epochs=1)
#         return self.get_parameters(config={}), num_examples["trainset"], {}

#     def evaluate(self, parameters, config):
#         self.set_parameters(parameters)
#         loss, accuracy = test(model, test_loader)
#         return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}
    
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {"loss": float(loss), "accuracy": float(accuracy)}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"loss": float(loss), "accuracy": float(accuracy)}

def client_fn(cid) -> FlowerClient:
    # net = models.resnet18(weights='DEFAULT').to(DEVICE)
    net = Net().to(DEVICE)
    # net = Net2().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

In [13]:

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    # print(metrics)
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

In [14]:
# The `evaluate` function will be by Flower called after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    net = Net().to(DEVICE)
    # net = Net2().to(DEVICE)
    # valloader = valloaders[0]
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

In [None]:
# Server: 
# round 1 :
#     client 1 
#     get parameters --> fit --> evaluate --> send the parameters back to the server 
#     client 2 (fit, evaluate) valloader[1] distributed
#     get parameters --> fit --> evaluate --> send the parameters back to the server 
#     client 3 (fit, evaluate) 
#     get parameters --> fit --> evaluate --> send the parameters back to the server 
#     evaluate at server (centralized)
#     use the parameters trained on client 1,2,3 and only evaluate 


In [16]:
# Create an instance of the model and get the parameters
NUM_CLIENTS = 3
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
# params = get_parameters(models.resnet18(weights='DEFAULT'))
# params = get_parameters(Net())
server_config = fl.server.ServerConfig(num_rounds=6)
# Pass parameters to the Strategy for server-side parameter initialization
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    # initial_parameters=fl.common.ndarrays_to_parameters(params),
    evaluate_metrics_aggregation_fn=weighted_average, 
    # fit_metrics_aggregation_fn=weighted_average
    evaluate_fn=evaluate
    
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=server_config,  
    strategy=strategy,
    client_resources=client_resources,
)

Using downloaded and verified file: C:\Users\Sama\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\Sama\.medmnist\chestmnist.npz


INFO flwr 2024-01-20 09:57:51,853 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=6, round_timeout=None)


Using downloaded and verified file: C:\Users\Sama\.medmnist\chestmnist.npz


2024-01-20 09:57:57,165	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-01-20 09:58:01,433 | app.py:213 | Flower VCE: Ray initialized with resources: {'object_store_memory': 1684896153.0, 'memory': 3369792308.0, 'node:127.0.0.1': 1.0, 'CPU': 16.0, 'node:__internal_head__': 1.0}
INFO flwr 2024-01-20 09:58:01,434 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO flwr 2024-01-20 09:58:01,435 | app.py:227 | No `client_resources` specified. Using minimal resources for clients.
INFO flwr 2024-01-20 09:58:01,436 | app.py:242 | Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
INFO flwr 2024-01-20 09:58:01,453 | app.py:288 | Flower VCE: Creating VirtualClientEngineActorPool with 16 actors
INFO flwr 2024-01-20 09:58:01,455 | server.py:89 | Initializing global parameters
INFO flwr 2024-01-20 09:58:01,455 | server.py:276 | Requesting initial parameters from one random client

[2m[36m(DefaultActor pid=9244)[0m [Client 1] get_parameters


INFO flwr 2024-01-20 09:58:17,148 | server.py:94 | initial parameters (loss, other metrics): 0.0801923838598285, {'accuracy': 0.34912851602549816}
INFO flwr 2024-01-20 09:58:17,149 | server.py:104 | FL starting
DEBUG flwr 2024-01-20 09:58:17,149 | server.py:222 | fit_round 1: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.0801923838598285 / accuracy 0.34912851602549816
[2m[36m(DefaultActor pid=9244)[0m [Client 1] fit, config: {}
[2m[36m(DefaultActor pid=7672)[0m [Client 0] fit, config: {}
[2m[36m(DefaultActor pid=13816)[0m [Client 2] fit, config: {}


DEBUG flwr 2024-01-20 09:58:35,103 | server.py:236 | fit_round 1 received 3 results and 0 failures
INFO flwr 2024-01-20 09:58:41,664 | server.py:125 | fit progress: (1, 0.010402049225472854, {'accuracy': 0.8921232113404359}, 24.514270399999987)
DEBUG flwr 2024-01-20 09:58:41,665 | server.py:173 | evaluate_round 1: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.010402049225472854 / accuracy 0.8921232113404359
[2m[36m(DefaultActor pid=7672)[0m [Client 0] evaluate, config: {}


DEBUG flwr 2024-01-20 09:58:43,032 | server.py:187 | evaluate_round 1 received 3 results and 0 failures
DEBUG flwr 2024-01-20 09:58:43,033 | server.py:222 | fit_round 2: strategy sampled 3 clients (out of 3)


[(117, {'loss': 0.010091576069011883, 'accuracy': 0.8973536487570168}), (117, {'loss': 0.009977690686226911, 'accuracy': 0.8964963894089328}), (117, {'loss': 0.009323701953929387, 'accuracy': 0.9069269858250869})]
[2m[36m(DefaultActor pid=9244)[0m [Client 0] fit, config: {}


DEBUG flwr 2024-01-20 09:58:56,947 | server.py:236 | fit_round 2 received 3 results and 0 failures
INFO flwr 2024-01-20 09:59:03,501 | server.py:125 | fit progress: (2, 0.010039015498444234, {'accuracy': 0.8921232113404359}, 46.35087370000008)
DEBUG flwr 2024-01-20 09:59:03,502 | server.py:173 | evaluate_round 2: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.010039015498444234 / accuracy 0.8921232113404359
[2m[36m(DefaultActor pid=7672)[0m [Client 2] evaluate, config: {}[32m [repeated 3x across cluster][0m
[2m[36m(DefaultActor pid=7672)[0m [Client 1] fit, config: {}[32m [repeated 2x across cluster][0m


DEBUG flwr 2024-01-20 09:59:04,916 | server.py:187 | evaluate_round 2 received 3 results and 0 failures
DEBUG flwr 2024-01-20 09:59:04,917 | server.py:222 | fit_round 3: strategy sampled 3 clients (out of 3)


[(117, {'loss': 0.009596982461209283, 'accuracy': 0.8964963894089328}), (117, {'loss': 0.008998487956720293, 'accuracy': 0.9069269858250869}), (117, {'loss': 0.009752945621457848, 'accuracy': 0.8973536487570168})]


DEBUG flwr 2024-01-20 09:59:18,919 | server.py:236 | fit_round 3 received 3 results and 0 failures
INFO flwr 2024-01-20 09:59:25,445 | server.py:125 | fit progress: (3, 0.009903255271347272, {'accuracy': 0.8921232113404359}, 68.29563230000008)
DEBUG flwr 2024-01-20 09:59:25,446 | server.py:173 | evaluate_round 3: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.009903255271347272 / accuracy 0.8921232113404359
[2m[36m(DefaultActor pid=7672)[0m [Client 0] evaluate, config: {}[32m [repeated 3x across cluster][0m
[2m[36m(DefaultActor pid=9244)[0m [Client 2] fit, config: {}[32m [repeated 3x across cluster][0m


DEBUG flwr 2024-01-20 09:59:26,846 | server.py:187 | evaluate_round 3 received 3 results and 0 failures
DEBUG flwr 2024-01-20 09:59:26,847 | server.py:222 | fit_round 4: strategy sampled 3 clients (out of 3)


[(117, {'loss': 0.009458295907302157, 'accuracy': 0.8964963894089328}), (117, {'loss': 0.009632439273654762, 'accuracy': 0.8973536487570168}), (117, {'loss': 0.008959055229858078, 'accuracy': 0.9069269858250869})]


DEBUG flwr 2024-01-20 09:59:41,115 | server.py:236 | fit_round 4 received 3 results and 0 failures
INFO flwr 2024-01-20 09:59:48,181 | server.py:125 | fit progress: (4, 0.009821035031999624, {'accuracy': 0.8921232113404359}, 91.03159089999997)
DEBUG flwr 2024-01-20 09:59:48,182 | server.py:173 | evaluate_round 4: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.009821035031999624 / accuracy 0.8921232113404359
[2m[36m(DefaultActor pid=7672)[0m [Client 2] evaluate, config: {}[32m [repeated 3x across cluster][0m
[2m[36m(DefaultActor pid=7672)[0m [Client 0] fit, config: {}[32m [repeated 3x across cluster][0m


DEBUG flwr 2024-01-20 09:59:49,597 | server.py:187 | evaluate_round 4 received 3 results and 0 failures
DEBUG flwr 2024-01-20 09:59:49,598 | server.py:222 | fit_round 5: strategy sampled 3 clients (out of 3)


[(117, {'loss': 0.008894680314489341, 'accuracy': 0.9069269858250869}), (117, {'loss': 0.009374849914739272, 'accuracy': 0.8964963894089328}), (117, {'loss': 0.009552979568481062, 'accuracy': 0.8973536487570168})]


DEBUG flwr 2024-01-20 10:00:07,688 | server.py:236 | fit_round 5 received 3 results and 0 failures
INFO flwr 2024-01-20 10:00:15,527 | server.py:125 | fit progress: (5, 0.009775383428907563, {'accuracy': 0.8921232113404359}, 118.37667250000004)
DEBUG flwr 2024-01-20 10:00:15,528 | server.py:173 | evaluate_round 5: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.009775383428907563 / accuracy 0.8921232113404359
[2m[36m(DefaultActor pid=7672)[0m [Client 0] evaluate, config: {}[32m [repeated 3x across cluster][0m
[2m[36m(DefaultActor pid=9244)[0m [Client 0] fit, config: {}[32m [repeated 3x across cluster][0m


DEBUG flwr 2024-01-20 10:00:17,526 | server.py:187 | evaluate_round 5 received 3 results and 0 failures
DEBUG flwr 2024-01-20 10:00:17,527 | server.py:222 | fit_round 6: strategy sampled 3 clients (out of 3)


[(117, {'loss': 0.009334360221692284, 'accuracy': 0.8964963894089328}), (117, {'loss': 0.008835238355004427, 'accuracy': 0.9069269858250869}), (117, {'loss': 0.009484968509896495, 'accuracy': 0.8973536487570168})]


DEBUG flwr 2024-01-20 10:00:34,225 | server.py:236 | fit_round 6 received 3 results and 0 failures
INFO flwr 2024-01-20 10:00:41,829 | server.py:125 | fit progress: (6, 0.00976011572624982, {'accuracy': 0.8921232113404359}, 144.67881750000004)
DEBUG flwr 2024-01-20 10:00:41,830 | server.py:173 | evaluate_round 6: strategy sampled 3 clients (out of 3)


Server-side evaluation loss 0.00976011572624982 / accuracy 0.8921232113404359
[2m[36m(DefaultActor pid=7672)[0m [Client 1] evaluate, config: {}[32m [repeated 3x across cluster][0m
[2m[36m(DefaultActor pid=7672)[0m [Client 0] fit, config: {}[32m [repeated 3x across cluster][0m


DEBUG flwr 2024-01-20 10:00:43,498 | server.py:187 | evaluate_round 6 received 3 results and 0 failures
INFO flwr 2024-01-20 10:00:43,498 | server.py:153 | FL finished in 146.3488069000001
INFO flwr 2024-01-20 10:00:43,500 | app.py:226 | app_fit: losses_distributed [(1, 0.009797656236389393), (2, 0.009449472013129142), (3, 0.009349930136938332), (4, 0.009274169932569892), (5, 0.009218189028864402), (6, 0.009211135135582745)]
INFO flwr 2024-01-20 10:00:43,500 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO flwr 2024-01-20 10:00:43,501 | app.py:228 | app_fit: metrics_distributed {'accuracy': [(1, 0.9002590079970121), (2, 0.9002590079970121), (3, 0.9002590079970121), (4, 0.9002590079970121), (5, 0.9002590079970121), (6, 0.9002590079970121)]}
INFO flwr 2024-01-20 10:00:43,502 | app.py:229 | app_fit: losses_centralized [(0, 0.0801923838598285), (1, 0.010402049225472854), (2, 0.010039015498444234), (3, 0.009903255271347272), (4, 0.009821035031999624), (5, 0.009775383428907563), (6, 0

[(117, {'loss': 0.009330596168049519, 'accuracy': 0.8964963894089328}), (117, {'loss': 0.008815586138501719, 'accuracy': 0.9069269858250869}), (117, {'loss': 0.009487223100196993, 'accuracy': 0.8973536487570168})]


History (loss, distributed):
	round 1: 0.009797656236389393
	round 2: 0.009449472013129142
	round 3: 0.009349930136938332
	round 4: 0.009274169932569892
	round 5: 0.009218189028864402
	round 6: 0.009211135135582745
History (loss, centralized):
	round 0: 0.0801923838598285
	round 1: 0.010402049225472854
	round 2: 0.010039015498444234
	round 3: 0.009903255271347272
	round 4: 0.009821035031999624
	round 5: 0.009775383428907563
	round 6: 0.00976011572624982
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.9002590079970121), (2, 0.9002590079970121), (3, 0.9002590079970121), (4, 0.9002590079970121), (5, 0.9002590079970121), (6, 0.9002590079970121)]}History (metrics, centralized):
{'accuracy': [(0, 0.34912851602549816), (1, 0.8921232113404359), (2, 0.8921232113404359), (3, 0.8921232113404359), (4, 0.8921232113404359), (5, 0.8921232113404359), (6, 0.8921232113404359)]}

In [None]:
# model = Net()
# train(model , trainloaders[0], epochs=3)
# loss, accuracy = test(model, valloaders[0])

In [None]:
# fl.client.start_numpy_client(server_address="[::]:8080", client=MedmnistClient())