In [73]:
!pip install -q flwr

In [1]:
import medmnist
from medmnist import INFO, Evaluator
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda
import torch.utils.data as data
from torchvision import models
import torch
import tqdm
from collections import OrderedDict
from typing import List, Tuple
import numpy as np 
from torch.utils.data import DataLoader, random_split
import flwr as fl

# Run on the first GPU when CUDA is present; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# How many simulated federated clients the training set is partitioned across.
NUM_CLIENTS = 3

  from .autonotebook import tqdm as notebook_tqdm
2024-01-17 11:42:31,854	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


## Dataset Loading

In [2]:
# def load_data(batch_size=64):
#     data_transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[.5], std=[.5])
#     ])
#     data_flag = 'chestmnist'
#     info = INFO[data_flag]
#     DataClass = getattr(medmnist, info['python_class'])

#     train_set = DataClass(split='train', transform=data_transform, target_transform=Lambda(lambda y: y[0]), download=True)
#     test_set = DataClass(split='test', transform=data_transform, target_transform=Lambda(lambda y: y[0]), download=True)
#     # val_set = DataClass(split='val', transform=data_transform, target_transform=Lambda(lambda y: y[0]), download=True)
    
#     train_loader = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
#     test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
#     # valid_loader = data.DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)

#     num_examples = {"trainset" : len(train_set), "testset" : len(test_set)}
#     return train_loader, test_loader, num_examples

def load_datasets(num_clients: int, batch_size: int = 32, seed: int = 42):
    """Download ChestMNIST and build per-client train/validation loaders plus a test loader.

    The official training split is randomly (but reproducibly) partitioned into
    ``num_clients`` disjoint shards to simulate separate local datasets; each shard
    is then split 90 % / 10 % into local train / validation subsets.

    Args:
        num_clients: number of simulated clients (shards); must be >= 1.
        batch_size: batch size used by every returned DataLoader.
        seed: seed for both random splits, so the partition is reproducible.

    Returns:
        ``(trainloaders, valloaders, testloader)``: two lists holding one DataLoader
        per client, and one DataLoader over the official test split.

    Raises:
        ValueError: if ``num_clients`` < 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # ChestMNIST images are single-channel, so normalise with a single mean/std value;
    # a 3-value Normalize cannot be applied in place to a 1-channel tensor.
    # train()/test() broadcast the channel to 3 afterwards for the ResNet stem.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
    )
    data_flag = 'chestmnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    # Keep only the first entry of each label array -> one target per image.
    first_label = Lambda(lambda y: y[0])
    trainset = DataClass(split='train', transform=transform, target_transform=first_label, download=True)
    testset = DataClass(split='test', transform=transform, target_transform=first_label, download=True)

    # Split the training set into `num_clients` partitions. random_split requires the
    # lengths to sum to len(trainset), so spread any remainder over the first shards.
    base, remainder = divmod(len(trainset), num_clients)
    lengths = [base + (1 if i < remainder else 0) for i in range(num_clients)]
    partitions = random_split(trainset, lengths, generator=torch.Generator().manual_seed(seed))

    # Split each partition into train/val and wrap both in DataLoaders.
    trainloaders = []
    valloaders = []
    for partition in partitions:
        len_val = len(partition) // 10  # 10 % validation set
        len_train = len(partition) - len_val
        ds_train, ds_val = random_split(
            partition, [len_train, len_val], generator=torch.Generator().manual_seed(seed)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloaders, valloaders, testloader


# One train/validation DataLoader pair per simulated client, plus the shared test loader.
(trainloaders, valloaders, testloader) = load_datasets(num_clients=NUM_CLIENTS)

Using downloaded and verified file: C:\Users\Sama\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\Sama\.medmnist\chestmnist.npz


## Define Train and Test

In [3]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train ``net`` in place on ``trainloader`` for ``epochs`` passes.

    Args:
        net: model to optimise; a fresh Adam optimizer is built over *its* parameters.
        trainloader: iterable of ``(images, labels)`` batches. Images are broadcast
            along dim 1 to 3 channels (so they must have 1 or 3 channels).
        epochs: number of full passes over ``trainloader``.
        verbose: currently unused; per-epoch metrics are always printed.

    Returns:
        The same ``net`` object, after training.
    """
    # Put each batch on whatever device the model already lives on (client_fn moves
    # the model to DEVICE), so data and weights always agree.
    device = next(net.parameters()).device
    net.train()
    criterion = torch.nn.CrossEntropyLoss()
    # Optimise the parameters of the network being trained -- not a global model.
    optimizer = torch.optim.Adam(net.parameters())
    for epoch in range(epochs):
        correct, total, loss_sum = 0, 0, 0.0
        for images, labels in trainloader:
            images = images.to(device).expand(-1, 3, -1, -1)
            labels = labels.to(device).long()
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics: .item() detaches, so each batch's autograd graph is freed.
            # Weight the batch-mean loss by batch size to get a per-example mean later.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        epoch_loss = loss_sum / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        print(f"Epoch {epoch}: train loss {epoch_loss}, accuracy {epoch_acc}")
    return net

def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch
            images = images.expand(-1, 3, -1, -1)
            labels = labels.long()
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy            

## Model 

In [4]:
# ImageNet-pretrained ResNet-18, the shared starting point for every client.
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
# DEFAULT resolves to the same ImageNet-1k weights.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# The dataset's target_transform keeps only the first ChestMNIST label, which is
# binary (0/1), so swap the 1000-way ImageNet head for a 2-way classifier.
model.fc = torch.nn.Linear(model.fc.in_features, 2)



In [5]:
# class MedmnistClient(fl.client.NumPyClient):
#     def get_parameters(self, config):
#         return [val.cpu().numpy() for _, val in model.state_dict().items()]

#     def set_parameters(self, parameters):
#         params_dict = zip(model.state_dict().keys(), parameters)
#         state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
#         model.load_state_dict(state_dict, strict=True)

#     def fit(self, parameters, config):
#         self.set_parameters(parameters)
#         train(model, train_loader, epochs=1)
#         return self.get_parameters(config={}), num_examples["trainset"], {}

#     def evaluate(self, parameters, config):
#         self.set_parameters(parameters)
#         loss, accuracy = test(model, test_loader)
#         return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}
    
def get_parameters(net) -> List[np.ndarray]:
    """Snapshot every ``state_dict`` entry of ``net`` as a CPU NumPy array, in state_dict order."""
    state = net.state_dict()
    return list(map(lambda tensor: tensor.cpu().numpy(), state.values()))


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net`` in place.

    ``parameters`` must be ordered like ``net.state_dict()`` (as produced by
    ``get_parameters``); ``strict=True`` makes any missing/unexpected key or shape
    mismatch raise.
    """
    keys = net.state_dict().keys()
    # torch.tensor (not the legacy torch.Tensor constructor, which always builds
    # float32) keeps each array's dtype and shape, including 0-d integer buffers
    # such as BatchNorm's `num_batches_tracked`.
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and validates one local model on one data shard.

    Args:
        cid: client id, used in log messages.
        net: this client's local model.
        trainloader: DataLoader over this client's training subset.
        valloader: DataLoader over this client's validation subset.
    """

    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the local model weights as a list of NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the global weights, train one epoch locally, and return the new weights.

        The count returned is the number of training *examples* (not batches), which
        FedAvg uses to weight this client's update.
        """
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the global weights and evaluate them on the local validation subset.

        Returns the loss, the number of validation *examples*, and an accuracy metric.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    """Create the FlowerClient for simulated client ``cid``.

    Each client gets its own deep copy of the global ``model`` (``model`` is a module
    instance, not a constructor), moved to ``DEVICE``, plus its own loaders.
    """
    import copy

    net = copy.deepcopy(model).to(DEVICE)
    index = int(cid)
    return FlowerClient(cid, net, trainloaders[index], valloaders[index])

In [6]:
# Use the global model's current weights as the server-side starting point.
params = get_parameters(model)

# FedAvg samples max(int(fraction * available), min_*_clients) clients per round.
# Cap the minimum at NUM_CLIENTS so a federation smaller than 3 clients still gets
# clients selected each round; with NUM_CLIENTS = 3 every client is used, as before.
min_clients_per_round = min(3, NUM_CLIENTS)

# Pass parameters to the Strategy for server-side parameter initialization
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=min_clients_per_round,
    min_evaluate_clients=min_clients_per_round,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(params),
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation; keep the returned History so per-round metrics can be inspected.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
history

INFO flwr 2024-01-17 11:42:37,722 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)


2024-01-17 11:42:42,372	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-01-17 11:42:47,170 | app.py:213 | Flower VCE: Ray initialized with resources: {'memory': 2980058727.0, 'node:127.0.0.1': 1.0, 'object_store_memory': 1490029363.0, 'CPU': 16.0, 'node:__internal_head__': 1.0}
INFO flwr 2024-01-17 11:42:47,171 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO flwr 2024-01-17 11:42:47,172 | app.py:227 | No `client_resources` specified. Using minimal resources for clients.
INFO flwr 2024-01-17 11:42:47,172 | app.py:242 | Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
INFO flwr 2024-01-17 11:42:47,191 | app.py:288 | Flower VCE: Creating VirtualClientEngineActorPool with 16 actors
INFO flwr 2024-01-17 11:42:47,192 | server.py:89 | Initializing global parameters
INFO flwr 2024-01-17 11:42:47,193 | server.py:272 | Using initial parameters provided by strategy
INFO f



In [None]:
# fl.client.start_numpy_client(server_address="[::]:8080", client=MedmnistClient())