# Client selection with Active Federated Learning - MNIST

In this notebook we use our Federated Learning training setup with AFL as the client selection algorithm [1] on the MNIST dataset, a collection of images of handwritten digits from 0 to 9 [2]. The data is loaded through FLEXible's `federated_emnist` loader with `split="digits"`. This is a computer vision problem that we use to compare the performance of AFL selection against conventional selection.

> [1] https://arxiv.org/abs/1909.12641.
>
> [2] http://yann.lecun.com/exdb/mnist.

In [8]:
# install FLEXible framework if not installed
try:
    import flex
    print("FLEXible is installed.")
except ImportError:
    print("FLEXible is not installed.\nInstalling dependency flexible-fl...")
    !pip install flexible-fl

FLEXible is installed.


In [10]:
import torch

# select device
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
device

'cpu'

In [11]:
from flex.datasets import load
from torchvision import transforms

flex_dataset, test_data = load("federated_emnist", return_test=True, split="digits")

# Assign test data to server_id
server_id = "server"
flex_dataset[server_id] = test_data

mnist_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

Downloading...
From (original): https://drive.google.com/uc?id=1fl9fRPPxTUxnC56ACzZ8JiLiew0SMFwt
From (redirected): https://drive.google.com/uc?id=1fl9fRPPxTUxnC56ACzZ8JiLiew0SMFwt&confirm=t&uuid=93c643af-f5bb-4be3-8813-87bf80cfcd91
To: /content/emnist-digits.mat
100%|██████████| 90.7M/90.7M [00:01<00:00, 60.5MB/s]

We use the `@init_server_model` decorator to initialize the model on the server. We take advantage of this phase to define our network architecture, along with the optimizer and the loss function.

In [12]:
import torch.nn as nn
import torch.nn.functional as F

from flex.pool import init_server_model
from flex.pool import FlexPool
from flex.model import FlexModel


# Simple two Fully-Connected layer net
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
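        # CrossEntropyLoss applies log_softmax internally; doing it here too is
        # redundant but harmless, because log_softmax is idempotent
        return F.log_softmax(x, dim=1)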


@init_server_model
def build_server_model():
    server_flex_model = FlexModel()

    server_flex_model["model"] = SimpleNet()
    # Required to store this for later stages of the FL training process
    server_flex_model["criterion"] = torch.nn.CrossEntropyLoss()
    server_flex_model["optimizer_func"] = torch.optim.Adam
    server_flex_model["optimizer_kwargs"] = {}
    return server_flex_model


flex_pool = FlexPool.client_server_pool(
    flex_dataset, server_id=server_id, init_func=build_server_model
)

clients = flex_pool.clients
servers = flex_pool.servers
aggregators = flex_pool.aggregators

print(
    f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator"
)

Number of nodes in the pool 3580: 1 server plus 3579 clients. The server is also an aggregator


To have a baseline against which to compare our model with AFL selection, we first build a uniform random client selection method with, for example, $20$ clients per round. Later on, we will implement AFL selection with $m=20$ clients per round.

In [13]:
# Select clients
clients_per_round = 20
selected_clients_pool = clients.select(clients_per_round)
selected_clients = selected_clients_pool.clients

print(f'Server node is identified by key "{servers.actor_ids[0]}"')
print(
    f"Selected {len(selected_clients.actor_ids)} client nodes of a total of {len(clients.actor_ids)}"
)

Server node is identified by key "server"
Selected 20 client nodes of a total of 3579
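
As a point of reference for the baseline, uniform random selection can be sketched in plain Python. This is only an illustration of the idea and not necessarily how `clients.select` works internally; the function name `select_clients_uniform`, the `seed` parameter, and the `client_i` ids are hypothetical.

```python
import random


def select_clients_uniform(client_ids, clients_per_round, seed=None):
    """Pick `clients_per_round` distinct client ids uniformly at random."""
    rng = random.Random(seed)
    # sample() draws without replacement, so no client appears twice in a round
    return rng.sample(list(client_ids), clients_per_round)


# Hypothetical ids standing in for the 3579 clients of the pool
client_ids = [f"client_{i}" for i in range(3579)]
round_clients = select_clients_uniform(client_ids, 20, seed=0)
print(len(round_clients), len(set(round_clients)))  # 20 20
```

Every client has the same probability of being chosen in each round, regardless of how useful its data is for the current model. That is the property an informed selection method such as AFL sets out to change.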


We use the `@deploy_server_model` decorator to distribute the server model to the clients. We call `map` with the selected clients, `selected_clients`, as the target.

In [14]:
from flex.pool import deploy_server_model
import copy


@deploy_server_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    return copy.deepcopy(server_flex_model)


servers.map(copy_server_model_to_clients, selected_clients)

We implement the training function, in which each client runs $n$ local epochs of the Adam optimizer [1] over its own data, in mini-batches of $b$ images. We use the same parameters as the FLEX example [2]: $n=5, b=20$.

> [1] https://arxiv.org/abs/1412.6980.
>
> [2] [Federated MNIST PT Example](https://github.com/FLEXible-FL/FLEXible/blob/main/notebooks/Federated%20MNIST%20PT%20example%20with%20flexible%20decorators.ipynb)

In [15]:
from flex.data import Dataset
from torch.utils.data import DataLoader


def train(client_flex_model: FlexModel, client_data: Dataset):
    train_dataset = client_data.to_torchvision_dataset(transform=mnist_transforms)
    client_dataloader = DataLoader(train_dataset, batch_size=20)
    model = client_flex_model["model"]
    optimizer = client_flex_model["optimizer_func"](
        model.parameters(), **client_flex_model["optimizer_kwargs"]
    )
    model = model.train()
    model = model.to(device)
    criterion = client_flex_model["criterion"]
    for _ in range(5):
        for imgs, labels in client_dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            pred = model(imgs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()


selected_clients.map(train)
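
The amount of local work depends on how much data each client holds. The arithmetic can be sketched as follows, assuming the `DataLoader` default `drop_last=False`, so the final partial batch is still used; the helper name `local_optimizer_steps` and the sample counts are hypothetical.

```python
import math


def local_optimizer_steps(num_samples, batch_size=20, epochs=5):
    """Number of optimizer.step() calls made by the training loop for one client."""
    # With drop_last=False the last, possibly smaller, batch is kept
    batches_per_epoch = math.ceil(num_samples / batch_size)
    return epochs * batches_per_epoch


for num_samples in (20, 50, 105):
    print(num_samples, local_optimizer_steps(num_samples))
# 20 5
# 50 15
# 105 30
```

Clients with more images therefore take more optimizer steps per round than clients with few images.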

With the `@collect_clients_weights` decorator we retrieve the PyTorch weights of each client selected for the round. In PyTorch, the model exposes its weights through `state_dict`, a dictionary in which each name maps to the parameters of one layer of the network. We return a list with the values of that dictionary, which together make up the weights of the whole network.

In [16]:
from flex.pool import collect_clients_weights


@collect_clients_weights
def get_clients_weights(client_flex_model: FlexModel):
    weight_dict = client_flex_model["model"].state_dict()
    return [weight_dict[name] for name in weight_dict]


aggregators.map(get_clients_weights, selected_clients)
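
Turning the dictionary into a list drops the layer names, so the position in the list is what identifies each layer afterwards. The sketch below uses an `OrderedDict` of plain lists as a stand-in for `state_dict()`; the values are made up, and the keys follow the names PyTorch would give to the parameters of `SimpleNet`.

```python
from collections import OrderedDict

# Stand-in for model.state_dict(): parameter name -> values (plain lists here)
state = OrderedDict(
    [
        ("fc1.weight", [[0.1, 0.2], [0.3, 0.4]]),
        ("fc1.bias", [0.0, 0.1]),
        ("fc2.weight", [[0.5, 0.6]]),
        ("fc2.bias", [0.2]),
    ]
)

# Same comprehension as in get_clients_weights: values in insertion order
weights = [state[name] for name in state]

# Because the order is preserved, zipping names and values restores the mapping
restored = OrderedDict(zip(state, weights))
print(restored == state)  # True
```

This works as long as every client and the server share the same architecture, so that position `i` refers to the same layer everywhere.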

We use the `@aggregate_weights` decorator to aggregate the weights retrieved from the clients in the previous phase by computing their mean, which is known as the FedAvg aggregator. For each layer, the client tensors are stacked and the mean is taken element-wise across clients (`axis=0`).

In [17]:
from flex.pool import aggregate_weights
import tensorly as tl

tl.set_backend("pytorch")


@aggregate_weights
def aggregate_with_fedavg(list_of_weights: list):
    agg_weights = []
    for layer_index in range(len(list_of_weights[0])):
        weights_per_layer = [weights[layer_index] for weights in list_of_weights]
        weights_per_layer = tl.stack(weights_per_layer)
        agg_layer = tl.mean(weights_per_layer, axis=0)
        agg_weights.append(agg_layer)
    return agg_weights


# Aggregate weights
aggregators.map(aggregate_with_fedavg)
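
To make the per-layer mean concrete, here is a minimal dependency-free sketch of the same computation on toy data. Layers are represented as flat lists of floats instead of tensors, and the name `fedavg_unweighted` and the two toy clients are hypothetical.

```python
def fedavg_unweighted(list_of_weights):
    """Element-wise mean over clients, layer by layer (layers are flat lists)."""
    num_clients = len(list_of_weights)
    aggregated = []
    for layer_index in range(len(list_of_weights[0])):
        per_client = [weights[layer_index] for weights in list_of_weights]
        # zip(*per_client) groups the same parameter position across all clients
        aggregated.append([sum(vals) / num_clients for vals in zip(*per_client)])
    return aggregated


# Two toy clients, each with two "layers"
client_a = [[1.0, 2.0], [0.0]]
client_b = [[3.0, 4.0], [1.0]]
print(fedavg_unweighted([client_a, client_b]))  # [[2.0, 3.0], [0.5]]
```

Every client contributes equally to this mean, whatever the size of its local dataset.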

Finally, we load the aggregated weights into the model of our server/aggregator. For each layer of the model, we simply copy in the new values aggregated in the previous phase.

In [18]:
from flex.pool import set_aggregated_weights


@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    with torch.no_grad():
        weight_dict = server_flex_model["model"].state_dict()
        for layer_key, new in zip(weight_dict, aggregated_weights):
            weight_dict[layer_key].copy_(new)


aggregators.map(set_agreggated_weights_to_server, servers)
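
The copy is done in place (`copy_`) because the tensors returned by `state_dict()` share storage with the model's parameters, so writing into them updates the model itself. The following sketch mimics that behaviour with plain Python lists; `server_state`, `model_view`, and the values are hypothetical stand-ins.

```python
# Stand-in for the server's state_dict (plain lists instead of tensors)
server_state = {"fc.weight": [0.0, 0.0], "fc.bias": [0.0]}

# The "model" holds a reference to the same underlying object
model_view = server_state["fc.weight"]

aggregated_weights = [[0.5, 1.5], [0.1]]
for layer_key, new in zip(server_state, aggregated_weights):
    # Slice assignment writes in place, analogous to tensor.copy_(new)
    server_state[layer_key][:] = new

print(model_view)  # [0.5, 1.5]
```

Rebinding the key instead (`server_state[layer_key] = new`) would leave `model_view` unchanged, which is the analogue of the model not seeing the new weights.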

We can now evaluate the server model on the test dataset defined earlier, which lives on the server itself. To do so, we define a function `evaluate_global_model` that obtains the model's predictions on the test dataset and returns the resulting metrics, in this case simply the loss and the accuracy.

In [21]:
def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    model = server_flex_model["model"]
    model.eval()
    test_loss = 0
    test_acc = 0
    total_count = 0
    model = model.to(device)
    criterion = server_flex_model["criterion"]
    # get test data as a torchvision object
    test_dataset = test_data.to_torchvision_dataset(transform=mnist_transforms)
    test_dataloader = DataLoader(
        test_dataset, batch_size=256, shuffle=True, pin_memory=False
    )
    losses = []
    with torch.no_grad():
        for data, target in test_dataloader:
            total_count += target.size(0)
            data, target = data.to(device), target.to(device)
            output = model(data)
            losses.append(criterion(output, target).item())
            pred = output.data.max(1, keepdim=True)[1]
            test_acc += pred.eq(target.data.view_as(pred)).long().cpu().sum().item()

    test_loss = sum(losses) / len(losses)
    test_acc /= total_count
    return test_loss, test_acc


metrics = servers.map(evaluate_global_model)
print("Loss (test):", metrics[0][0])
print("Accuracy (test):", metrics[0][1])

Loss (test): 1.5552050596589495
Accuracy (test): 0.51695
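
One detail of the reported loss is worth spelling out: `evaluate_global_model` averages the per-batch losses, and with the default `reduction="mean"` each of those is already a mean over the batch. When the last batch is smaller than the others, this is not exactly the mean over all samples. The sketch below shows the difference on made-up numbers; both function names and all values are hypothetical.

```python
def mean_of_batch_means(batch_losses):
    """Average used in evaluate_global_model: every batch counts equally."""
    return sum(batch_losses) / len(batch_losses)


def sample_weighted_mean(batch_losses, batch_sizes):
    """Per-sample average: each batch is weighted by its number of samples."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Made-up per-batch mean losses, with a smaller final batch
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [256, 256, 64]

print(mean_of_batch_means(batch_losses))
print(sample_weighted_mean(batch_losses, batch_sizes))
```

With many full batches and a single partial one, the gap between the two is usually small. The accuracy is not affected, since it is accumulated per sample and divided by `total_count`.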
