In [1]:
# from https://flower.ai/docs/framework/tutorial-quickstart-pytorch.html

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset

import flwr as fl

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

import logging

# Reuse Flower's logger so this notebook's messages share its formatting/handlers.
logger = logging.getLogger("flwr")

  from .autonotebook import tqdm as notebook_tqdm
2024-08-02 00:18:30,994	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
import time


# Federated setup: the CIFAR-10 splits are cut into NUM_CLIENT contiguous shards
# and this notebook acts as the client that owns shard CLIENT_INDEX (0-based).
NUM_CLIENT, CLIENT_INDEX = 5, 0

def load_data(client_index=None, num_clients=None, batch_size=32,
              data_root="../data/cifar10-torchvision"):
    """Load this client's shard of CIFAR-10 (training and test set).

    Both the training and the test split are cut into ``num_clients``
    contiguous shards, and the shard at ``client_index`` is wrapped in a
    DataLoader. The last client also receives the remainder of the integer
    division, so no sample is dropped when a split is not evenly divisible.

    Args:
        client_index: 0-based shard owned by this client. Defaults to the
            notebook constant ``CLIENT_INDEX`` (read at call time).
        num_clients: Total number of shards. Defaults to ``NUM_CLIENT``
            (read at call time).
        batch_size: Mini-batch size for both loaders.
        data_root: Directory where torchvision stores / downloads CIFAR-10.

    Returns:
        ``(trainloader, testloader, num_examples)`` where ``num_examples`` is
        ``{"trainset": <train shard size>, "testset": <test shard size>}``.

    Raises:
        ValueError: If ``client_index`` is not in ``[0, num_clients)``.
    """
    # Resolve defaults at call time so edits to the config constants are
    # honoured without re-defining this function.
    if client_index is None:
        client_index = CLIENT_INDEX
    if num_clients is None:
        num_clients = NUM_CLIENT

    # Fail fast, before any (possibly slow) dataset download/loading.
    if not 0 <= client_index < num_clients:
        message = (
            f"Client index {client_index} is out of bounds. "
            f"It should be between 0 and {num_clients - 1}."
        )
        logger.error(message)
        raise ValueError(message)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # Load the entire CIFAR-10 dataset
    full_trainset = CIFAR10(data_root, train=True, download=True, transform=transform)
    full_testset = CIFAR10(data_root, train=False, download=True, transform=transform)

    def get_client_data(dataset, shuffle):
        """Return ``(DataLoader, size)`` for this client's contiguous shard of ``dataset``."""
        shard_size = len(dataset) // num_clients
        start_idx = client_index * shard_size
        # The last client absorbs the division remainder instead of dropping it.
        if client_index == num_clients - 1:
            end_idx = len(dataset)
        else:
            end_idx = start_idx + shard_size
        client = Subset(dataset, list(range(start_idx, end_idx)))
        return DataLoader(client, batch_size=batch_size, shuffle=shuffle), len(client)

    # Shuffle only the training shard; a fixed test order keeps the summed
    # per-batch evaluation loss reproducible across rounds.
    trainloader, len_trainset = get_client_data(full_trainset, shuffle=True)
    testloader, len_testset = get_client_data(full_testset, shuffle=False)
    num_examples = {"trainset": len_trainset, "testset": len_testset}

    return trainloader, testloader, num_examples
    

def train(net, trainloader, epochs, lr=0.001, momentum=0.9):
    """Train the network in place on the training set with SGD.

    Args:
        net: Model to optimise; its parameters are updated in place.
        trainloader: Iterable of ``(images, labels)`` mini-batches.
        epochs: Number of full passes over ``trainloader``.
        lr: SGD learning rate (default is the original hard-coded value).
        momentum: SGD momentum (default is the original hard-coded value).
    """
    # perf_counter is monotonic, so the reported duration cannot go negative
    # or jump if the wall clock is adjusted.
    start_time = time.perf_counter()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    # Make sure layers such as dropout/batch-norm are in training mode even if
    # an earlier evaluation switched the shared model to eval mode.
    net.train()
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
    total_duration = time.perf_counter() - start_time
    logger.info(f"Training completed in {total_duration:.2f} seconds")
            
def test(net, testloader):
    """Validate the network on the entire test set.

    The model is put in eval mode for the pass and its previous train/eval
    mode is restored afterwards.

    Args:
        net: Model to evaluate (not modified).
        testloader: Iterable of batches whose first two items are
            ``(images, labels)``.

    Returns:
        ``(loss, accuracy)``: ``loss`` is the sum over batches of the per-batch
        mean cross-entropy; ``accuracy`` is the fraction of correctly
        classified samples (0.0 if the loader yields no samples).
    """
    start_time = time.perf_counter()
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in testloader:
                images, labels = batch[0].to(DEVICE), batch[1].to(DEVICE)
                outputs = net(images)
                loss += criterion(outputs, labels).item()
                # argmax returns the first maximal index, same as torch.max(..., 1).
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = correct / total if total else 0.0

    total_duration = time.perf_counter() - start_time
    logger.info(f"Testing completed in {total_duration:.2f} seconds, accuracy: {accuracy:.2f}")
    return loss, accuracy

In [3]:
# CNN from the PyTorch tutorial 'PyTorch: A 60 Minute Blitz':
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    Sized for 3-channel 32x32 inputs (32 -> conv5 -> 28 -> pool -> 14 -> conv5
    -> 10 -> pool -> 5) and returns 10 raw logits per sample.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order fixes the state_dict key order, which is what the
        # client exchanges with the Flower server, so keep it unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Two conv -> ReLU -> 2x2 max-pool stages share the same pooling layer.
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

# Build the model and move it (in place) to the selected device, then load
# this client's train/test shards.
net = Net()
net.to(DEVICE)
trainloader, testloader, num_examples = load_data()

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class CifarClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a PyTorch model on one data shard.

    The model, loaders and example counts can be injected. When omitted, they
    fall back to the notebook globals ``net``, ``trainloader``, ``testloader``
    and ``num_examples`` (looked up when the client is constructed), so
    ``CifarClient()`` behaves as before.
    """

    def __init__(self, model=None, train_loader=None, test_loader=None, examples=None):
        super().__init__()
        # Conditional expressions only touch a global when its override is None.
        self.net = net if model is None else model
        self.trainloader = trainloader if train_loader is None else train_loader
        self.testloader = testloader if test_loader is None else test_loader
        self.num_examples = num_examples if examples is None else examples

    def get_parameters(self, config):
        """Return the model's state_dict tensors as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for val in self.net.state_dict().values()]

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays (in state_dict order) into the model.

        Raises:
            ValueError: If the number of arrays does not match the number of
                state_dict entries (``zip`` would otherwise silently drop extras).
        """
        keys = list(self.net.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} parameter arrays, got {len(parameters)}."
            )
        # torch.tensor copies, which also works for read-only NumPy buffers.
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally starting from the server's parameters.

        ``config["local_epochs"]`` (if sent by the server) sets the number of
        local epochs; the default is 1.
        """
        self.set_parameters(parameters)
        epochs = int((config or {}).get("local_epochs", 1))
        train(self.net, self.trainloader, epochs=epochs)
        return self.get_parameters(config={}), self.num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        """Evaluate the server's parameters on the local test shard."""
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

In [5]:
# Connect to the local Flower server and serve the train/evaluate messages it
# sends (start_client blocks while the session is running).
flower_client = CifarClient().to_client()
fl.client.start_client(server_address="127.0.0.1:8000", client=flower_client)

[92mINFO [0m:      
[92mINFO [0m:      Received: train message 1c81f63e-0649-4d0c-a38b-5a40ce97a800
[92mINFO [0m:      Training completed in 4.12 seconds
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 84a6f454-6e85-42e3-b1d4-a8ba53f22750
[92mINFO [0m:      Testing completed in 0.60 seconds, accuracy: 0.35
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 3939b51d-2161-40d0-bb23-17f286399dce
[92mINFO [0m:      Training completed in 4.06 seconds
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message c32d11a4-d156-401b-8146-6e056e67214f
[92mINFO [0m:      Testing completed in 0.64 seconds, accuracy: 0.36
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 6062d3ac-9f9a-4bdb-811f-3484ca5f90db
[92mINFO [0m:      Training completed in 4.00 seconds
[92mINFO [0m:      Sent reply
[92

: 