Se carga la base de datos

Se trabaja con el módulo auxiliar process_data, el cual define una carga personalizada para estos ejemplos de ataques a modelos federados. Para más detalles sobre cómo efectuar la carga de datos con Flex, consulte la documentación correspondiente.

Este módulo permite la carga de datasets de procesamiento de imágenes como MNIST, FMNIST, CIFAR-10 y CIFAR-100, además del dataset tabular Nursery.

In [2]:
from process_data import *
from copy import deepcopy

# Load CIFAR-10 through the process_data helper; it returns the federated
# dataset and the id of the server node, both reused by later cells.
# `nodes=50` presumably sets the number of client partitions — confirm in
# process_data.
# NOTE(review): `trasnform` looks like a misspelling of `transform`, but this
# cell ran successfully, so it must match what the helper accepts — confirm
# against process_data before renaming.
flex_data, server_id = load_and_preprocess_horizontal(dataname="cifar10", trasnform=False, nodes=50)

Files already downloaded and verified
Files already downloaded and verified


A continuación, se define la arquitectura de los modelos locales de los clientes. Para el presente ejemplo se trabaja con modelos neuronales de pytorch.

Se utiliza el módulo networks_models, que contiene una serie de modelos neuronales auxiliares de PyTorch para trabajar con las bases de datos anteriormente mencionadas. Además, se utiliza el módulo auxiliar networks_execution, que define la ejecución del entrenamiento y otros detalles de estos modelos.

Para establecer un modelo personalizado, ir a la documentación de Flex.

In [3]:
from networks_models import *
from networks_execution import *
from flex.pool import init_server_model
from flex.model import FlexModel

# Pick the torch backend in order of preference: CUDA, then Apple MPS, and
# finally plain CPU. MPS is only probed when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Shared helper configured for CIFAR-10. Later cells use it to obtain the
# server model/criterion/optimizer (for_fd_server_model_config) and to run
# local training (trainNetwork).
net_config = ExecutionNetwork(dataname = "cifar10")

@init_server_model
def build_server_model():
    """Create the FlexModel that initialises the server.

    The network, loss and optimizer are obtained from the shared
    ``net_config`` helper; the network is moved to ``device`` before being
    stored.

    Returns:
        A FlexModel with the entries ``"model"``, ``"criterion"``,
        ``"optimizer_func"`` and ``"optimizer_kwargs"`` (an empty dict).
    """
    flex_model = FlexModel()
    loss_fn, network, optim = net_config.for_fd_server_model_config()

    flex_model["model"] = network.to(device)
    # The remaining entries are kept because later stages of the FL process
    # (e.g. local training and evaluation) read them back.
    flex_model["criterion"] = loss_fn
    flex_model["optimizer_func"] = optim
    flex_model["optimizer_kwargs"] = {}
    return flex_model

Se define la arquitectura del modelo federado

In [4]:
from flex.pool import FlexPool
# Number of clients sampled for this single-round walkthrough.
clients = 1

# Build a client/server pool over the federated dataset; build_server_model
# is passed as the initialiser for the server's FlexModel.
pool = FlexPool.client_server_pool(
        fed_dataset= flex_data, server_id=server_id, init_func = build_server_model
    )

# Select `clients` clients from the pool; the `.clients` view of that
# selection is what the following cells map functions over.
selected_test_clients_pool = pool.clients.select(clients)
selected_test_clients = selected_test_clients_pool.clients

Se define la función para desplegar el modelo global en cada cliente

In [5]:
from flex.pool.decorators import (  # noqa: E402
    deploy_server_model,
)

@deploy_server_model
def deploy_serv(server_flex_model: FlexModel):
    """Return a deep copy of the server FlexModel for deployment.

    A deep copy is handed out so that whatever receives it does not share
    state with the server's own FlexModel.
    """
    return deepcopy(server_flex_model)

# Deploy a copy of the server model to the selected clients.
pool.servers.map(deploy_serv, selected_test_clients)

Se define la ronda de entrenamiento local de un cliente, empleando el módulo networks_execution

In [6]:
def train(client_flex_model: FlexModel, client_data: Dataset):
    """Run one local training epoch on a client.

    Args:
        client_flex_model: The client's FlexModel; its ``"model"``,
            ``"criterion"`` and ``"optimizer_func"`` entries are read.
        client_data: The client's dataset; it must provide
            ``to_torchvision_dataset``.

    Returns:
        The same ``client_flex_model``, now also holding a pre-training
        copy of the network under ``"previous_model"``.
    """
    loader = DataLoader(
        client_data.to_torchvision_dataset(transform=cifar10_transform()),
        batch_size=256,
    )

    network = client_flex_model["model"].to(device)

    # Snapshot the weights before training; presumably consumed later by the
    # diff-weights collection step — confirm against the FLEX primitives.
    client_flex_model["previous_model"] = deepcopy(network)

    net_config.trainNetwork(
        local_epochs=1,
        criterion=client_flex_model["criterion"],
        optimizer=client_flex_model["optimizer_func"],
        momentum=0.9,
        lr=0.005,
        trainloader=loader,
        testloader=None,
        model=network,
    )

    return client_flex_model

# Run the local training step on every selected client.
selected_test_clients.map(train)

100%|██████████| 4/4 [01:40<00:00, 25.07s/it]


[{'model': VGG(
   (features): Sequential(
     (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
     (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (6): ReLU(inplace=True)
     (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (10): ReLU(inplace=True)
     (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (13): ReLU(inplace=Tr

Se efectúa la agregación del modelo federado

In [7]:
from flex.pool import collect_client_diff_weights_pt
from flex.pool import fed_avg
from flex.pool import set_aggregated_diff_weights_pt


# 1) Collect the updates of the selected clients into the aggregator
#    (presumably the difference between each trained model and its
#    "previous_model" snapshot — confirm in the FLEX documentation).
pool.aggregators.map(collect_client_diff_weights_pt, selected_test_clients)
# 2) Aggregate the collected updates with FedAvg.
pool.aggregators.map(fed_avg)
# 3) Push the aggregated result to the server model.
pool.aggregators.map(set_aggregated_diff_weights_pt, pool.servers)

Después de la modif antes de fedavg tensor(0.)


Se evalúa el modelo federado

In [8]:
def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    """Evaluate the global (server) model on the server's test data.

    Args:
        server_flex_model: Server FlexModel; its ``"model"`` and
            ``"criterion"`` entries are used.
        test_data: Test dataset; it must provide ``to_torchvision_dataset``.

    Returns:
        A ``(test_loss, test_acc)`` tuple: the mean of the per-batch
        criterion values and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``test_data`` yields no samples, since neither
            metric is defined in that case.
    """
    model = server_flex_model["model"].to(device)
    model.eval()
    criterion = server_flex_model["criterion"]

    # Get the test data as a torchvision object.
    test_dataset = test_data.to_torchvision_dataset(transform=cifar10_transform())
    # No shuffling for evaluation: it brings nothing here and, because the
    # loss below is an unweighted mean over batches, a shuffled final partial
    # batch could make the reported loss differ from run to run.
    test_dataloader = DataLoader(
        test_dataset, batch_size=256, shuffle=False, pin_memory=False
    )

    batch_losses = []
    correct = 0
    total_count = 0
    with torch.no_grad():
        for data, target in tqdm(test_dataloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_losses.append(criterion(output, target).item())
            # Index of the highest score per sample is the predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_count += target.size(0)

    if total_count == 0:
        raise ValueError("test_data contains no samples; cannot compute metrics")

    test_loss = sum(batch_losses) / len(batch_losses)
    test_acc = correct / total_count
    return test_loss, test_acc


# Evaluate on the server; map returns a sequence of results and the
# (loss, acc) tuple of the single server is taken from index 0.
metrics = pool.servers.map(evaluate_global_model)
loss, acc = metrics[0]
print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

100%|██████████| 40/40 [06:12<00:00,  9.30s/it]


Server: Test acc: 0.7269, test loss: 0.7989


Opcional: se liberan los modelos almacenados en memoria.

In [9]:
def clean_up_models(client_model: FlexModel, _):
    """Empty a client's FlexModel and trigger a garbage-collection pass.

    The second positional argument (presumably the client's dataset, as
    supplied by ``map``) is ignored.
    """
    from gc import collect

    client_model.clear()
    collect()

# Release the models held by the selected clients.
selected_test_clients.map(clean_up_models)

Se definen las rondas de entrenamiento del modelo federado

In [None]:
def train_n_rounds(n_rounds, clients_per_round=20):
    """Run ``n_rounds`` federated rounds on a freshly built pool.

    Each round selects ``clients_per_round`` clients, deploys the server
    model to them, trains locally, aggregates the client updates into the
    server model, evaluates the server model, clears the client models and
    prints the round's test accuracy and loss.

    Args:
        n_rounds: Number of federated rounds to run.
        clients_per_round: Number of clients selected in each round.
    """
    fl_pool = FlexPool.client_server_pool(
        fed_dataset=flex_data, server_id=server_id, init_func=build_server_model
    )
    aggregators = fl_pool.aggregators

    for i in range(n_rounds):
        print(f"\nRunning round: {i+1} of {n_rounds}")

        participants = fl_pool.clients.select(clients_per_round).clients

        # Deploy, then train locally.
        fl_pool.servers.map(deploy_serv, participants)
        participants.map(train)

        # Collect client updates, average them, and update the server model.
        aggregators.map(collect_client_diff_weights_pt, participants)
        aggregators.map(fed_avg)
        aggregators.map(set_aggregated_diff_weights_pt, fl_pool.servers)

        results = fl_pool.servers.map(evaluate_global_model)
        # Client models are cleared each round, after evaluation.
        participants.map(clean_up_models)

        loss, acc = results[0]
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

In [8]:
# Short demonstration run: 2 rounds with 2 clients per round.
train_n_rounds(2, clients_per_round=2)


Running round: 1 of 2


100%|██████████| 40/40 [03:58<00:00,  5.96s/it]
100%|██████████| 40/40 [02:48<00:00,  4.22s/it]


Después de la modif antes de fedavg tensor(0.)


100%|██████████| 40/40 [00:48<00:00,  1.22s/it]


Server: Test acc: 0.6957, test loss: 0.8942

Running round: 2 of 2


100%|██████████| 40/40 [02:48<00:00,  4.21s/it]
100%|██████████| 40/40 [02:44<00:00,  4.11s/it]


Después de la modif antes de fedavg tensor(0.)


100%|██████████| 40/40 [00:49<00:00,  1.23s/it]


Server: Test acc: 0.1000, test loss: nan
