Se carga la base de datos

Se trabaja con el módulo auxiliar process_data, el cual define una carga personalizada para estos ejemplos de ataques a modelos federados. Para más detalles sobre cómo efectuar la carga de datos con Flex, consulte la documentación correspondiente.

Este módulo permite la carga de datasets de procesamiento de imágenes como Mnist, Fmnist, Cifar10 y Cifar100, además del dataset tabular nursery.

In [None]:
from process_data import *
from copy import deepcopy

# Load the "mnist" dataset through the process_data helper, asking for 5 nodes.
# The call returns the federated dataset and the id later passed as `server_id`
# when the client/server pool is built.
# NOTE(review): the keyword is spelled `trasnform`; presumably this matches the
# helper's own parameter name — confirm before correcting the spelling here.
flex_data, server_id = load_and_preprocess_horizontal(dataname="mnist", trasnform=False, nodes=5)

A continuación, se define la arquitectura de los modelos locales de los clientes. Para el presente ejemplo se trabaja con modelos neuronales de PyTorch.

Se utiliza el módulo networks_models, que contiene una serie de modelos neuronales auxiliares de PyTorch para el trabajo con las bases de datos anteriormente mencionadas. Además, se utiliza el módulo auxiliar networks_execution, que define la ejecución del entrenamiento y otros detalles de estos modelos.

Para establecer un modelo personalizado, ir a la documentación de Flex.

In [None]:
from networks_models import *
from networks_execution import *
from flex.pool import init_server_model
from flex.model import FlexModel

# Pick the compute device in priority order: CUDA first, then Apple MPS,
# falling back to the CPU. The MPS check only runs when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Shared execution helper (presumably provided by networks_execution — confirm).
# It supplies the server model configuration and runs the local training below.
net_config = ExecutionNetwork()

@init_server_model
def build_server_model():
    """Create the FlexModel used to initialise the server.

    The criterion, network and optimizer are obtained from
    ``net_config.for_fd_server_model_config()``; the network is moved to
    ``device`` before being stored.

    Returns:
        A FlexModel holding the keys "model", "criterion", "optimizer_func"
        and "optimizer_kwargs" (the last one an empty dict).
    """
    loss_fn, network, optimizer_factory = net_config.for_fd_server_model_config()

    flex_model = FlexModel()
    flex_model["model"] = network.to(device)
    # The remaining entries are stored so that later stages of the FL training
    # process can read them back from the model.
    flex_model["criterion"] = loss_fn
    flex_model["optimizer_func"] = optimizer_factory
    flex_model["optimizer_kwargs"] = {}
    return flex_model

Se define la arquitectura del modelo federado

In [None]:
from flex.pool import FlexPool
# Number of clients selected for the initial test deployment.
clients = 2

# Build the client/server pool over the federated data; the server model is
# produced by build_server_model.
pool = FlexPool.client_server_pool(
        fed_dataset= flex_data, server_id=server_id, init_func = build_server_model
    )

# Select `clients` clients from the pool, then keep their client-role view.
selected_test_clients_pool = pool.clients.select(clients)
selected_test_clients = selected_test_clients_pool.clients

Se define la función para desplegar el modelo global en cada cliente

In [None]:
from flex.pool.decorators import (  # noqa: E402
    deploy_server_model,
)
from my_models_attacks.free_riding import free_riding_atck as fr

# Number of completed federated rounds; read by deploy_serv and fr_attack and
# incremented at the end of every round in train_n_rounds.
rounds = 0
# Placeholder for a copy of the global model.
# NOTE(review): nothing in the visible code ever assigns it, so it stays None.
round_global_model = None

# Free-riding attack helper shared by deploy_serv and fr_attack.
# NOTE(review): the meaning of std_0, power, decay and multiplicator is defined
# in my_models_attacks.free_riding — confirm there.
attack = fr(std_0 = 0, power = 1, decay = 1, multiplicator = 1)

@deploy_server_model
def deploy_serv(server_flex_model: FlexModel):
    """Return a deep copy of the server FlexModel for deployment to clients.

    While the module-level ``rounds`` counter is still 0, a deep copy of the
    server network is additionally handed to ``attack.set_model``.

    Args:
        server_flex_model: FlexModel currently held by the server.

    Returns:
        An independent deep copy of ``server_flex_model``.
    """
    # Declared but not assigned in this function, so it has no effect here.
    global round_global_model

    deployed_copy = deepcopy(server_flex_model)

    is_first_round = rounds == 0
    if is_first_round:
        attack.set_model(deepcopy(server_flex_model["model"]))

    return deployed_copy

# Deploy the server model to the test clients selected earlier. Because
# `rounds` is still 0 here, deploy_serv also passes a copy of the server
# network to the attack object.
pool.servers.map(deploy_serv, selected_test_clients)

Se seleccionan los clientes free riding y los clientes normales

In [None]:
from test import util_for_fr, util_for_fr_join_all

# Actor ids that act as free riders.
list_of_fr = [0]

# Split the pool's clients by actor id: ids listed in list_of_fr are the free
# riders, all remaining ids are the normal clients.
free_riding_clients = pool.clients.select(
    lambda actor_id, set_of_roles: actor_id in list_of_fr
).clients
normal_clients = pool.clients.select(
    lambda actor_id, set_of_roles: actor_id not in list_of_fr
).clients

# Show which actor ids ended up in each group.
for fr_actor_id in free_riding_clients.actor_ids:
    print("fr", fr_actor_id)

for normal_actor_id in normal_clients.actor_ids:
    print("Normal", normal_actor_id)
# normal_clients, free_riding_clients = util_for_fr(list_of_fr, selected_test_clients)

Se define el entrenamiento de un cliente free rider como método de ataque. Para ello se utiliza el decorador model_poisoner, que define una directriz para los ataques de envenenamiento de modelos. Además, se contabilizan las rondas de entrenamiento como parte del ataque.

In [None]:
from flexclash.model import model_poisoner

@model_poisoner
def fr_attack(client_model: FlexModel):
    """Replace a client's network with the output of the free-riding attack.

    The client's current network is deep-copied and stored under
    "previous_model"; the network returned by ``attack.execute_attack``
    (called with ``type_noise="disguised"`` and the current ``rounds`` value)
    is stored under "model". Several parameter sums are printed as debug
    output along the way.

    Args:
        client_model: FlexModel of the free-riding client; updated in place.

    Returns:
        The same ``client_model`` object after the update.
    """
    print("Execute_fr_attack")
    # Snapshot taken before the attack runs on the client's network.
    prev = deepcopy(client_model["model"])

    fr_model_client = attack.execute_attack(type_noise = "disguised", model = client_model["model"], f_round= rounds)

    print("Suma de los paramtros de modelo fr por state dict:", sum(param.sum() for param in fr_model_client.state_dict().values()))
    print("Suma de los paramtros de modelo fr:", sum(p.sum() for p in fr_model_client.parameters()))

    print("Suma de los paramtros de modelo global fr por state dict:", sum(param.sum() for param in prev.state_dict().values()))
    print("Suma de los paramtros de modelo global fr:", sum(p.sum() for p in prev.parameters()))

    client_model["model"] = fr_model_client
    client_model["previous_model"] = prev
    try:
        print("Suma de los paramtros de modelo global previo fr por state dict:", sum(param.sum() for param in attack.first_server_model.state_dict().values()))
        print("Suma de los paramtros de modelo global previo fr:", sum(p.sum() for p in attack.first_server_model.parameters()))
    except Exception:
        # Best-effort debug output: the attack object may not expose a usable
        # first_server_model. Catching Exception (not a bare except) keeps
        # KeyboardInterrupt and SystemExit propagating.
        print("nulo")

    return client_model

#free_riding_clients.map(fr_attack)

Se define la ronda de entrenamiento local de un cliente, empleando el módulo networks_execution

In [None]:
def train(client_flex_model: FlexModel, client_data: Dataset):
    """Run one local training epoch on a client's data.

    The client's data is wrapped with ``mnist_transform()`` and served in
    batches of 256. A deep copy of the network, taken after moving it to
    ``device`` and before training, is stored under "previous_model"
    (presumably for the later weight-difference collection — confirm).

    Args:
        client_flex_model: FlexModel of the client; must hold "model",
            "optimizer_func" and "criterion".
        client_data: The client's local dataset.

    Returns:
        The same ``client_flex_model`` object.
    """
    torch_dataset = client_data.to_torchvision_dataset(transform=mnist_transform())
    batches = DataLoader(torch_dataset, batch_size=256)

    network = client_flex_model["model"].to(device)

    # Keep the pre-training state alongside the model being trained.
    client_flex_model["previous_model"] = deepcopy(network)

    optimizer_fn = client_flex_model["optimizer_func"]
    loss_fn = client_flex_model["criterion"]

    net_config.trainNetwork(
        local_epochs=1,
        criterion=loss_fn,
        optimizer=optimizer_fn,
        momentum=0.9,
        lr=0.005,
        trainloader=batches,
        testloader=None,
        model=network,
    )

    return client_flex_model

# Debug output: print the actor id of every normal (non-free-riding) client.
for i in normal_clients.actor_ids:
    print(i)

#normal_clients.map(train)

Se efectúa la agregación del modelo federado. En este caso debe considerarse el resultado del entrenamiento tanto de los clientes free riding como de los normales; para ello se empleará un método que unifique ambos tipos de clientes en un solo objeto FlexPool.

In [None]:
from flex.pool import collect_client_diff_weights_pt
from flex.pool import fed_avg
from flex.pool import set_aggregated_diff_weights_pt

#pool.aggregators.map(collect_client_diff_weights_pt, selected_test_clients)
#pool.aggregators.map(fed_avg)
#pool.aggregators.map(set_aggregated_diff_weights_pt, pool.servers)

Se evalúa el modelo federado

In [None]:
def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    """Evaluate the server network on the given test data.

    NOTE: the original author marked this function as pending
    ("still to be added").

    Args:
        server_flex_model: FlexModel holding "model" and "criterion".
        test_data: Dataset to evaluate on; wrapped with ``mnist_transform()``.

    Returns:
        Tuple ``(test_loss, test_acc)``: the mean of the per-batch criterion
        values and the fraction of correctly predicted samples.
    """
    network = server_flex_model["model"]
    network.eval()
    network = network.to(device)
    loss_fn = server_flex_model["criterion"]

    # Expose the test data as a torchvision dataset served in batches of 256.
    eval_dataset = test_data.to_torchvision_dataset(transform=mnist_transform())
    eval_loader = DataLoader(
        eval_dataset, batch_size=256, shuffle=True, pin_memory=False
    )

    batch_losses = []
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in tqdm(eval_loader):
            n_seen += labels.size(0)
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = network(inputs)
            batch_losses.append(loss_fn(outputs, labels).item())
            # Index of the largest output per sample, kept as a column.
            _, predicted = outputs.data.max(1, keepdim=True)
            hits = predicted.eq(labels.data.view_as(predicted))
            n_correct += hits.long().cpu().sum().item()

    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

#metrics = pool.servers.map(evaluate_global_model)

#loss, acc = metrics[0]
#print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

Para limpiar los modelos en memoria. Opcional

In [None]:
def clean_up_models(client_model: FlexModel, _):
    """Empty a client's FlexModel and trigger a garbage-collection pass.

    Args:
        client_model: FlexModel to clear in place.
        _: Unused second positional argument supplied by the caller.
    """
    from gc import collect as run_gc

    client_model.clear()
    run_gc()

Se definen las rondas de entrenamiento del modelo federado

In [None]:
def train_n_rounds(n_rounds, clients_per_round=20):
    """Run ``n_rounds`` federated rounds with free-riding and normal clients.

    A fresh client/server pool is built on every call. In each round the
    server model is deployed to the selected clients, clients whose id is in
    ``list_of_fr`` run ``fr_attack`` while the others run ``train``, the
    updates are aggregated with ``fed_avg`` and applied to the server, the
    server model is evaluated, and the client models are cleared. The
    module-level ``rounds`` counter is incremented once per round and is not
    reset between calls.

    Args:
        n_rounds: Number of rounds to run.
        clients_per_round: Value passed to ``pool.clients.select`` each round.
    """
    global rounds
    global round_global_model

    fed_pool = FlexPool.client_server_pool(
        fed_dataset=flex_data, server_id=server_id, init_func=build_server_model
    )
    print(rounds)
    print(type(round_global_model))

    for round_number in range(1, n_rounds + 1):
        print(f"\nRunning round: {round_number} of {n_rounds}")

        # Choose this round's participants and send them the server model.
        round_pool = fed_pool.clients.select(clients_per_round)
        round_clients = round_pool.clients
        fed_pool.servers.map(deploy_serv, round_clients)

        # Partition the participants by actor id.
        free_riders = round_pool.select(
            lambda actor_id, set_of_roles: actor_id in list_of_fr
        ).clients
        honest_clients = round_pool.select(
            lambda actor_id, set_of_roles: actor_id not in list_of_fr
        ).clients

        free_riders.map(fr_attack)
        honest_clients.map(train)

        # Aggregate the client updates and apply the result to the server.
        fed_pool.aggregators.map(collect_client_diff_weights_pt, round_clients)
        fed_pool.aggregators.map(fed_avg)
        fed_pool.aggregators.map(set_aggregated_diff_weights_pt, fed_pool.servers)

        metrics = fed_pool.servers.map(evaluate_global_model)
        round_clients.map(clean_up_models)

        loss, acc = metrics[0]
        rounds += 1
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

In [None]:
# Run 5 federated rounds, selecting 3 clients in each round.
train_n_rounds(5, clients_per_round=3)