In [1]:
!pip install flwr==0.17.0



In [2]:
%%writefile cifar.py

from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch import Tensor
from tqdm import tqdm
import numpy as np
import random

from flwr_experimental.baseline.dataset.dataset import create_partitioned_dataset

import flwr as fl

class Net(nn.Module):
    """Small CNN: two conv+pool stages followed by three fully connected layers.

    The final layer produces 10 outputs (one logit per class).
    """

    def __init__(self):
        super().__init__()
        # Registration order determines the order of `state_dict()` and hence
        # the order of the arrays exchanged via get_weights / set_weights.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        """Run the forward pass and return the raw (un-normalised) class scores."""
        # Convolution -> ReLU -> max-pool, applied twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten every sample to a vector of 16 * 5 * 5 features.
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

    def get_weights(self) -> fl.common.Weights:
        """Return every entry of the state dict as a NumPy array, in state-dict order."""
        return [tensor.cpu().numpy() for tensor in self.state_dict().values()]

    def set_weights(self, weights: fl.common.Weights) -> None:
        """Load weights given as NumPy arrays listed in state-dict order."""
        names = self.state_dict().keys()
        new_state = OrderedDict(
            (name, torch.Tensor(array)) for name, array in zip(names, weights)
        )
        # strict=True: a missing or unexpected key raises instead of being ignored.
        self.load_state_dict(new_state, strict=True)


def load_model():
    """Construct and return a new `Net` instance with freshly initialised parameters."""
    return Net()


def load_data(data_root: str = "/content/data/cifar-10"):
    """Load the CIFAR-10 training and test sets, downloading them if necessary.

    Args:
        data_root: Directory where the dataset is stored / downloaded to. The
            default is the path this notebook has always used; pass another
            directory when running elsewhere.

    Returns:
        A `(trainset, testset)` tuple of `torchvision.datasets.CIFAR10` objects.
        Images are converted to tensors and normalised per channel with mean
        0.5 and standard deviation 0.5.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=transform
    )
    return trainset, testset

class PartitionedDataset(Dataset):
    """Dataset view over a feature container `X` and a parallel label container `Y`."""

    def __init__(self, X, Y):
        # Both containers are stored as given; nothing is copied or validated.
        self.X, self.Y = X, Y

    def __len__(self):
        """Number of samples, taken from the first dimension of `X`."""
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return `(features, label)` for position `idx`; the label is a plain int."""
        features = self.X[idx]
        label = int(self.Y[idx])
        return (features, label)


def load_local_partitioned_data(client_id, iid_fraction: float, num_partitions: int,
                                beta: float = 0.2, alpha: float = 0.5, seed=None):
    """Build the train, test and shared ("global") datasets for one client.

    Every client loads the full dataset and derives its own partition:

    1. The training set is split randomly into a 20% shared pool and an 80%
       local pool.
    2. The local pool is grouped by label; group `i` (labels in sorted order)
       is the local data of client `i`.
    3. The shared pool is shuffled and cut into one equal chunk per label group.
       From chunk `i`, the first `beta * n_i` samples form client `i`'s global
       dataset and the first `alpha * beta * n_i` samples are appended to its
       local training data (`n_i` = size of label group `i` before appending).
    4. Test partitions come from `create_partitioned_dataset`.

    Args:
        client_id: Index of the client; selects the label group, the shared
            chunk and the test partition.
        iid_fraction: Forwarded to `create_partitioned_dataset` (only its test
            partitions are used).
        num_partitions: Forwarded to `create_partitioned_dataset`.
        beta: Size of the client's global dataset relative to its label group.
        alpha: Fraction of that global dataset mixed into the local train data.
        seed: Optional seed for the random split and the shuffle done here.
            With the default `None` each process draws its own split, so
            different clients do not see the same shared pool.

    Returns:
        `(trainset, testset, global_trainset)` as `PartitionedDataset` objects.
    """
    trainset, testset = load_data()

    # Only seed when asked to, so the default behaviour stays unseeded.
    split_kwargs = {}
    shuffler = random
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
        shuffler = random.Random(seed)

    n_total = trainset.data.shape[0]
    split_sizes = [int(n_total * fraction) for fraction in [0.2, 0.8]]
    global_trainset, local_trainset = torch.utils.data.random_split(
        trainset, split_sizes, **split_kwargs)

    # A single batch as large as the dataset materialises each split in one go.
    train_loader = DataLoader(local_trainset, batch_size=len(local_trainset))
    test_loader = DataLoader(testset, batch_size=len(testset))
    global_train_loader = DataLoader(global_trainset, batch_size=len(global_trainset))

    x_train, y_train = (t.numpy() for t in next(iter(train_loader)))
    x_test, y_test = (t.numpy() for t in next(iter(test_loader)))
    global_x_train, global_y_train = (t.numpy() for t in next(iter(global_train_loader)))

    # Only the test partitions of the library helper are used; the train
    # partitions are built by label below.
    (_, test_partitions), _ = create_partitioned_dataset(
        ((x_train, y_train), (x_test, y_test)), iid_fraction, num_partitions)

    # One train partition per label, in sorted label order.
    train_partitions = []
    for label in np.unique(y_train):
        members = np.where(y_train == label)[0]
        train_partitions.append([x_train[members], y_train[members]])

    # Shuffle the shared pool and cut it into one chunk per label partition.
    global_idx = list(range(len(global_x_train)))
    shuffler.shuffle(global_idx)
    global_chunks = np.split(np.asarray(global_idx), len(train_partitions))

    global_partition = []
    for i, (x_own, y_own) in enumerate(train_partitions):
        # Derive both counts from the ORIGINAL partition size, before anything
        # is appended, so images and labels are extended by the same amount.
        n_own = len(x_own)
        n_global = int(n_own * beta)
        n_mixed = int(n_own * beta * alpha)

        chunk = global_chunks[i]
        global_partition.append(
            (global_x_train[chunk[:n_global]], global_y_train[chunk[:n_global]]))

        mixed = chunk[:n_mixed]
        train_partitions[i] = [
            np.concatenate((x_own, global_x_train[mixed])),
            np.concatenate((y_own, global_y_train[mixed])),
        ]

    return get_partitionedDataset(train_partitions, test_partitions, global_partition, client_id)

def get_partitionedDataset(train_partitions, test_partitions, global_partition, client_id):
    """Wrap one client's train, test and global partitions as `PartitionedDataset`s.

    Args:
        train_partitions: Sequence of `(x, y)` pairs, indexed by client.
        test_partitions: Sequence of `(x, y)` pairs, indexed by client.
        global_partition: Sequence of `(x, y)` pairs, indexed by client.
        client_id: Index selecting the entry to wrap from each sequence.

    Returns:
        `(trainset, testset, global_trainset)`.
    """
    def _as_dataset(features, labels):
        # np.asarray first: building a tensor straight from a Python list of
        # ndarrays is very slow; for an ndarray input this is a no-op.
        return PartitionedDataset(torch.Tensor(np.asarray(features)), labels)

    x_train, y_train = train_partitions[client_id]
    x_test, y_test = test_partitions[client_id]
    global_x_train, global_y_train = global_partition[client_id]

    return (
        _as_dataset(x_train, y_train),
        _as_dataset(x_test, y_test),
        _as_dataset(global_x_train, global_y_train),
    )


def train(
    net: nn.Module,
    trainloader: torch.utils.data.DataLoader,
    device: torch.device,
    start_epoch: int,
    end_epoch: int,
    log_progress: bool = True):
    """Train `net` on `trainloader` for epochs `start_epoch`..`end_epoch` (inclusive).

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9). A new optimizer is
    created on every call, so momentum does not carry over between calls.

    Args:
        net: Model to train in place; switched to training mode.
        trainloader: Iterable of `(images, labels)` batches that supports `len()`.
        device: Device the batches are moved to (the model must already be there).
        start_epoch: Number of the first epoch (used for the epoch range and logging).
        end_epoch: Number of the last epoch, inclusive.
        log_progress: Show a tqdm progress bar with running loss and accuracy.

    Returns:
        A list with one `(mean_loss_per_sample, accuracy)` tuple per epoch. The
        list is empty when `end_epoch < start_epoch`; an epoch that saw no
        samples is reported as `(0.0, 0.0)`.
    """
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print(f"Training from epoch(s) {start_epoch} to {end_epoch} w/ {len(trainloader)} batches each.", flush=True)
    results = []

    net.train()
    for epoch in range(start_epoch, end_epoch + 1):  # last epoch inclusive
        total_loss, total_correct, n_samples = 0.0, 0, 0
        pbar = tqdm(trainloader, desc=f'TRAIN Epoch {epoch}') if log_progress else trainloader
        for data in pbar:
            images, labels = data[0].to(device), data[1].to(device)
            batch_size = labels.size(0)
            optimizer.zero_grad()

            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # `criterion` returns the MEAN loss of the batch. Weight it by the
            # batch size so that dividing by n_samples yields a true per-sample
            # mean (otherwise the loss is under-reported by ~batch_size).
            total_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs.detach(), 1)
            n_samples += batch_size
            total_correct += (predicted == labels).sum().item()

            if log_progress:
                pbar.set_postfix({
                    "train_loss": total_loss / n_samples,
                    "train_acc": total_correct / n_samples
                })

        # Guard against an empty loader instead of raising ZeroDivisionError.
        denominator = max(n_samples, 1)
        results.append((total_loss / denominator, total_correct / denominator))

    return results
    
def test(
    net: Net,
    testloader: torch.utils.data.DataLoader,
    device: torch.device,
    log_progress: bool = True):
    """Evaluates the network on test data."""
    criterion = nn.CrossEntropyLoss()
    total_loss, total_correct, n_samples = 0.0, 0.0, 0
    with torch.no_grad():
        pbar = tqdm(testloader, desc="TEST") if log_progress else testloader
        for data in pbar:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)

            # Collected testing loss and accuracy statistics
            total_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1) 
            n_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item() 
    
    return (total_loss/n_samples, total_correct/n_samples)


Overwriting cifar.py


In [1]:
%%writefile client.py
import argparse
import grpc
import timeit
import torch
import torchvision
import flwr as fl
import random

from collections import OrderedDict
from typing import Optional
from torch.utils.data import DataLoader
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, ParametersRes, Weights
import time


import cifar

# Address of the Flower gRPC server used when the caller does not supply one.
DEFAULT_SERVER_ADDRESS = "localhost:8099"
# Device on which the client model is placed and all batches are processed.
DEVICE = torch.device("cpu")

class CifarClient(fl.client.Client):
    """Flower client that trains and evaluates a `cifar.Net` on its local data."""

    def __init__(
        self, cid,
        model: cifar.Net,
        trainset: torchvision.datasets.CIFAR10,
        testset: torchvision.datasets.CIFAR10,
        exp_name: Optional[str],
        iid_fraction: Optional[float]):
        # `iid_fraction` is accepted for the caller's convenience but is not stored.
        self.cid, self.model = cid, model
        self.trainset, self.testset = trainset, testset
        self.exp_name = exp_name if exp_name else 'federated_unspecified'
        print(f"Client {self.cid} running experiment {self.exp_name}")

    def get_parameters(self) -> ParametersRes:
        """Return the current local model weights."""
        print(f"Client {self.cid}: get_parameters")

        current: Weights = self.model.get_weights()
        return ParametersRes(parameters=fl.common.weights_to_parameters(current))

    def fit(self, ins: FitIns) -> FitRes:
        """Load the server's weights, train locally and return the updated weights."""
        print(f"Client {self.cid}: fit")
        settings = ins.config
        received: Weights = fl.common.parameters_to_weights(ins.parameters)

        # Training settings arrive as strings and are converted here.
        epochs = int(settings["epochs"])
        batch_size = int(settings["batch_size"])
        epoch_global = int(settings["epoch_global"])

        self.model.set_weights(received)

        # Local epochs continue the global epoch numbering.
        first_epoch = epoch_global + 1
        last_epoch = epoch_global + epochs
        loader = DataLoader(self.trainset, batch_size=batch_size, shuffle=True)

        started = timeit.default_timer()
        history = cifar.train(
            net=self.model,
            trainloader=loader,
            device=DEVICE,
            start_epoch=first_epoch,
            end_epoch=last_epoch,
            log_progress=True,
        )
        fit_duration = timeit.default_timer() - started

        # Report the metrics of the last local epoch.
        train_loss, train_acc = history[-1]
        print(f'Client {self.cid}: train_loss={train_loss:.4f}, train_accuracy={train_acc:.4f}')

        updated: Weights = self.model.get_weights()
        n_train = len(self.trainset)
        return FitRes(
            parameters=fl.common.weights_to_parameters(updated),
            num_examples=n_train,
            num_examples_ceil=n_train,
            fit_duration=fit_duration,
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Load the server's weights and evaluate them on the local test set."""
        print(f"Client {self.cid}: evaluate")
        settings = ins.config
        # Parsed but not otherwise used; a missing key still raises here.
        _ = int(settings["epoch_global"])

        self.model.set_weights(fl.common.parameters_to_weights(ins.parameters))

        loader = DataLoader(self.testset, batch_size=32, shuffle=False)
        test_loss, test_acc = cifar.test(
            net=self.model, testloader=loader, device=DEVICE, log_progress=False
        )
        print(f"Client {self.cid}: test_loss={test_loss:.4f}, test_accuracy={test_acc:.4f}")

        return EvaluateRes(
            num_examples=len(self.testset), loss=test_loss, accuracy=test_acc
        )

def start_client(client_id, num_partitions, iid_fraction=1.0, 
                 server_address=DEFAULT_SERVER_ADDRESS, log_host=None, exp_name=None):
    """Prepare one client, pre-train it on its global dataset and join the server.

    Steps: configure logging, build the model, load this client's train / test /
    global datasets, train on the global dataset for epochs 0..10, print the
    resulting test accuracy, then connect to the Flower server and block until
    the connection ends.

    Args:
        client_id: Index of this client; selects its data partition.
        num_partitions: Forwarded to `cifar.load_local_partitioned_data`.
        iid_fraction: Forwarded to `cifar.load_local_partitioned_data`.
        server_address: Address of the Flower gRPC server.
        log_host: Optional remote log host passed to the Flower logger.
        exp_name: Optional experiment name; a generic name is used when omitted.
    """
    # Configure logger
    fl.common.logger.configure(f"client_{client_id}", host=log_host)

    # Load model and data
    model = cifar.load_model()
    model.to(DEVICE)

    trainset, testset, globalset = cifar.load_local_partitioned_data(
        client_id=client_id,
        iid_fraction=iid_fraction,
        num_partitions=num_partitions)

    # Warm up the model on the client's global dataset before federated training.
    print(f"start global traning on Client {client_id}")
    globalloader = DataLoader(globalset, batch_size=32, shuffle=True)
    cifar.train(model, globalloader, DEVICE, 0, 10)

    testloader = DataLoader(testset, batch_size=32, shuffle=True)
    loss, acc = cifar.test(model, testloader, DEVICE)
    print(f"global traning accuracy {acc}")

    # Start client
    print(f"Starting client {client_id}")
    # Apply the fallback name BEFORE formatting; formatting a None exp_name
    # would otherwise yield the literal text "None_iid-fraction_...".
    experiment = f"{exp_name or 'federated_unspecified'}_iid-fraction_{iid_fraction}"
    client = CifarClient(client_id, model, trainset, testset, experiment, iid_fraction)

    print(f"Connecting to {server_address}")

    try:
        # There's no graceful shutdown when gRPC server terminates, so we try/except
        fl.client.start_client(server_address, client)
    except grpc.RpcError:
        # grpc.RpcError is the public base class of the error raised when the
        # stream to the server breaks; no private grpc module is referenced.
        print(f"Client {client_id}: shutdown")


def main():
    """Parse the command line and start a client with the given settings."""
    parser = argparse.ArgumentParser(description="Flower client")
    parser.add_argument("--server_address", type=str, default=DEFAULT_SERVER_ADDRESS,
        help=f"gRPC server address (default: {DEFAULT_SERVER_ADDRESS})")
    parser.add_argument("--cid", type=int, required=True, help="Client CID (no default)")
    parser.add_argument("--num_partitions", type=int, required=True, 
        help="Total number of clients participating in training")
    parser.add_argument("--iid_fraction", type=float, nargs="?", const=1.0, 
        help="Fraction of data [0,1] that is independent and identically distributed.")
    parser.add_argument("--log_host", type=str, help="Log server address")
    parser.add_argument("--exp_name", type=str, help="Friendly experiment name")

    # Unknown arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    # Use the parsed values instead of hard-coded ones. When --iid_fraction is
    # not given at all, fall back to 1.0, the default of `start_client`.
    iid_fraction = 1.0 if args.iid_fraction is None else args.iid_fraction
    start_client(
        client_id=args.cid,
        num_partitions=args.num_partitions,
        iid_fraction=iid_fraction,
        server_address=args.server_address,
        log_host=args.log_host,
        exp_name=args.exp_name,
    )

if __name__=="__main__":
    main()


    

Overwriting client.py


In [2]:
%%writefile server.py

import argparse
from typing import Callable, Dict, Optional, Tuple

from logging import INFO
from flwr.common.logger import log
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

import flwr as fl
import cifar

# Address the gRPC server binds to when none is given on the command line.
DEFAULT_SERVER_ADDRESS = "localhost:8099"

# Device used for the centralized (server-side) evaluation.
DEVICE = torch.device("cpu")


def start_server(exp_name=None, 
                 server_address=DEFAULT_SERVER_ADDRESS, 
                 rounds=1,
                 epochs=10,
                 batch_size=32,
                 sample_fraction=1.0,
                 min_sample_size=2,
                 min_num_clients=2,
                 log_host=None):
    """Run federated training with FedAvg, then evaluate centrally and on clients.

    Args:
        exp_name: Experiment name; a descriptive one is derived when omitted.
        server_address: Address the gRPC server listens on.
        rounds: Number of federated rounds.
        epochs: Local epochs per round, sent to clients in the round config.
        batch_size: Client batch size, sent to clients in the round config.
        sample_fraction: Fraction of available clients used for fitting.
        min_sample_size: Minimum number of clients for fit and for evaluation.
        min_num_clients: Minimum number of connected clients before sampling.
        log_host: Optional remote log host passed to the Flower logger.
    """
    # NOTE(review): the derived name is not consumed anywhere below.
    if not exp_name:
        exp_name = f"federated_rounds_{rounds}_" \
                   f"epochs_{epochs}_" \
                   f"min_num_clients_{min_num_clients}_" \
                   f"min_sample_size_{min_sample_size}_" \
                   f"sample_fraction_{sample_fraction}"

    # Configure logger
    fl.common.logger.configure("server", host=log_host)

    # Load evaluation data
    _, testset = cifar.load_data()

    # Create client_manager, strategy, and server
    client_manager = fl.server.SimpleClientManager()

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=sample_fraction,
        min_fit_clients=min_sample_size,
        min_eval_clients=min_sample_size,
        min_available_clients=min_num_clients,
        eval_fn=get_eval_fn(testset),
        on_fit_config_fn=generate_config(epochs, batch_size),
        on_evaluate_config_fn=generate_config(epochs, batch_size)
    )
    server = fl.server.Server(client_manager=client_manager, strategy=strategy)

    # Run server
    print(f"Starting gRPC server on {server_address}...")
    grpc_server = start_insecure_grpc_server(
        client_manager=server.client_manager(),
        server_address=server_address,
        max_message_length=fl.common.GRPC_MAX_MESSAGE_LENGTH,
    )

    try:
        # Fit model
        print("Fitting the model...")
        hist = server.fit(num_rounds=rounds)

        log(INFO, f"app_fit: losses_centralized={hist.losses_centralized}")
        log(INFO, f"app_fit: accuracies_centralized={hist.metrics_centralized['accuracy']}")

        # Evaluate the final accuracy on the server
        test_loss, test_metrics = server.strategy.evaluate(parameters=server.parameters)
        print(f"Server-side test results after training: test_loss={test_loss:.4f}, "
              f"test_accuracy={test_metrics['accuracy']:.4f}")

        # Now, apply temporary workaround to force distributed evaluation
        server.strategy.eval_fn = None

        # Evaluate the final trained model
        res = server.evaluate_round(rnd=-1)
        if res is not None:
            loss_aggregated, metrics_aggregated, (results, failures) = res
            log(INFO, f"app_evaluate: federated loss: {loss_aggregated}")
            log(INFO, f"app_evaluate: metrics: {metrics_aggregated}")
            log(INFO, f"app_evaluate: results {[(item[0].cid, item[1]) for item in results]}")
            log(INFO, f"app_evaluate: failures {failures}")
        else:
            log(INFO, "app_evaluate: no evaluation result")
    finally:
        # Always release the port, even when training or evaluation raised.
        grpc_server.stop(None)


def generate_config(epochs, batch_size):
    """Build the callback that produces the per-round configuration for clients.

    The returned function takes the round number, announces it on stdout and
    returns a dict of strings: the number of epochs completed before this round
    (`epoch_global`), the epochs to run (`epochs`) and the batch size
    (`batch_size`).
    """
    def fit_config(round: int) -> Dict[str, str]:
        print(f"Configuring round {round}...")
        # Rounds are counted from 1, so round r starts after (r - 1) * epochs epochs.
        epochs_done = (round - 1) * epochs
        settings = dict(
            epoch_global=str(epochs_done),
            epochs=str(epochs),
            batch_size=str(batch_size),
        )
        return settings

    return fit_config


def get_eval_fn(testset: torchvision.datasets.CIFAR10):
    """Create the centralized evaluation callback bound to `testset`."""
    def evaluate(weights: fl.common.Weights):
        """Evaluate `weights` on the whole test set; returns (loss, {"accuracy": ...})."""
        # A fresh model is built for every evaluation and filled with the given weights.
        net = cifar.load_model()
        net.set_weights(weights)
        net.to(DEVICE)

        loader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=32)
        central_loss, central_accuracy = cifar.test(
            net=net, testloader=loader, device=DEVICE, log_progress=False
        )
        metrics = {"accuracy": central_accuracy}
        return central_loss, metrics

    return evaluate


def main():
    """Read the server settings from the command line and run the server."""
    parser = argparse.ArgumentParser(description="Flower server")

    # (flag, options) pairs, registered in this order.
    options = [
        ("--server_address", dict(type=str, default=DEFAULT_SERVER_ADDRESS,
            help=f"gRPC server address (default: {DEFAULT_SERVER_ADDRESS})")),
        ("--rounds", dict(type=int, default=1,
            help="Number of rounds of federated learning (default: 1)")),
        ("--sample_fraction", dict(type=float, default=1.0,
            help="Fraction of available clients used for fit/evaluate (default: 1.0)")),
        ("--min_sample_size", dict(type=int, default=2,
            help="Minimum number of clients used for fit/evaluate (default: 2)")),
        ("--min_num_clients", dict(type=int, default=2,
            help="Minimum number of available clients needed for sampling (default: 2)")),
        ("--log_host", dict(type=str, help="Log server address (no default)")),
        ("--epochs", dict(type=int, default=10,
            help="Number of epochs each client will train for (default: 10)")),
        ("--batch_size", dict(type=int, default=32,
            help="Number of samples per batch each client will use (default: 32)")),
        ("--exp_name", dict(type=str,
            help="Name of the experiment you are running (no default)")),
    ]
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    # Unrecognised arguments are tolerated and dropped.
    args, _ = parser.parse_known_args()

    # Every parsed option has the same name as a keyword of start_server.
    start_server(**vars(args))

if __name__ == "__main__":
    main()

Overwriting server.py


In [3]:
%%writefile server.sh
# Start the Flower server for one federated round of 10 local epochs.
# All available clients are sampled (sample_fraction=1); the server requires
# at least 5 connected clients and at least 5 clients per fit/evaluate step.
# PYTHONUNBUFFERED=1 makes Python write its output without buffering.
PYTHONUNBUFFERED=1 python3 server.py \
  --rounds=1 \
  --epochs=10 \
  --sample_fraction=1 \
  --min_sample_size=5 \
  --min_num_clients=5 \
  --server_address="localhost:8099"

Overwriting server.sh


In [4]:
%%writefile clients.sh
#!/usr/bin/env bash
# Launch NUM_CLIENTS Flower clients in the background, one every 8 seconds.
# The C-style `for ((...))` loop below is bash syntax, hence the bash shebang:
# without it the script fails when the invoking shell is a plain POSIX sh.
export PYTHONUNBUFFERED=1
NUM_CLIENTS=10 # TODO: change the number of clients here


echo "Starting $NUM_CLIENTS clients."
for ((i = 0; i < NUM_CLIENTS; i++))
do
    echo "Starting client(cid=$i) with partition $i out of $NUM_CLIENTS clients."
    # Staggered loading of clients: clients are loaded 8s apart.
    # At the start, each client loads the entire CIFAR-10 dataset before selecting
    # their own partition. For a large number of clients this causes a memory usage
    # spike that can cause client processes to get terminated. 
    # Staggered loading prevents this.
    sleep 8s  
    python3 client.py \
      --cid="$i" \
      --num_partitions="${NUM_CLIENTS}" \
      --iid_fraction=0.5 \
      --server_address="localhost:8099" \
      --exp_name="federated_${NUM_CLIENTS}_clients" &
done
echo "Started $NUM_CLIENTS clients."


Overwriting clients.sh


In [5]:
# Mark both launcher scripts as executable so they can be run as ./server.sh and ./clients.sh.
!chmod +x clients.sh server.sh

In [6]:
# Kill background processes that this notebook session still tracks
# (presumably left over from an earlier run — confirm before relying on it).
%killbgscripts
# Start the server in the background, wait 5 seconds, then launch the clients.
!((./server.sh & sleep 5s); ./clients.sh)

All background processes were killed.
Files already downloaded and verified
Files already downloaded and verified
Starting gRPC server on localhost:8099...
Fitting the model...
INFO flower 2021-12-25 20:16:50,390 | server.py:118 | Initializing global parameters
INFO flower 2021-12-25 20:16:50,390 | server.py:304 | Requesting initial parameters from one random client
Starting 10 clients.
Starting client(cid=0) with partition 0 out of 10 clients.
Starting client(cid=1) with partition 1 out of 10 clients.
Files already downloaded and verified
Files already downloaded and verified
Starting client(cid=2) with partition 2 out of 10 clients.
Files already downloaded and verified
Files already downloaded and verified
Starting client(cid=3) with partition 3 out of 10 clients.
Files already downloaded and verified
Files already downloaded and verified
Starting client(cid=4) with partition 4 out of 10 clients.
Files already downloaded and verified
Starting client(cid=5) with partition 5 out of 10

In [None]:
Server-side test results after training: test_loss=0.0734, test_accuracy=0.1000
Server-side test results after training: test_loss=0.0730, test_accuracy=0.1527