Se carga la base de datos

Se trabaja con el módulo auxiliar process_data, el cual define una carga personalizada para estos ejemplos de ataques a modelos federados. Para más detalles sobre cómo efectuar la carga de datos con Flex, consulte la documentación correspondiente.

Este módulo permite cargar datasets de procesamiento de imágenes como MNIST, Fashion-MNIST (Fmnist), CIFAR-10 y CIFAR-100, además del dataset tabular Nursery.

In [1]:
from process_data import *
from copy import deepcopy

# Load MNIST and split it horizontally (by samples) across 2 client nodes;
# presumably returns the federated dataset plus the id of the server node
# (later output shows node ids [0, 1, 'server'] with 30000 samples per client).
# NOTE(review): "trasnform" looks misspelled, but it presumably matches the
# keyword accepted by process_data.load_and_preprocess_horizontal — verify
# there before renaming it.
flex_data, server_id = load_and_preprocess_horizontal(dataname="mnist", trasnform=False, nodes=2)

torch.Size([60000, 28, 28])


A continuación, se define la arquitectura de los modelos locales de los clientes. En el presente ejemplo se trabaja con redes neuronales de PyTorch.

Se utiliza el módulo networks_models, que contiene una serie de redes neuronales auxiliares de PyTorch para trabajar con las bases de datos anteriormente mencionadas. Además, se utiliza el módulo auxiliar networks_execution, que define la ejecución del entrenamiento y otros detalles de estos modelos.

Para definir un modelo personalizado, consulte la documentación de Flex.

In [2]:
from networks_models import *
from networks_execution import *
from flex.pool import init_server_model
from flex.model import FlexModel

# Pick the compute backend: a CUDA GPU if present, otherwise Apple's MPS
# backend if present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Helper from networks_execution. It provides the server model configuration
# (for_fd_server_model_config) and the local training loop (trainNetwork)
# that the functions below use.
net_config = ExecutionNetwork()

@init_server_model
def build_server_model():
    """Create the server's initial FlexModel.

    Stores the network returned by ``net_config.for_fd_server_model_config()``
    (moved to ``device``), together with its loss criterion and optimizer, so
    later federated-learning stages can reuse them.
    """
    criterion, network, optim = net_config.for_fd_server_model_config()

    flex_model = FlexModel()
    flex_model["model"] = network.to(device)
    # Kept on the model so the local training stage can read them back.
    # NOTE(review): "optimizer_func" holds the optimizer object returned by the
    # config helper (the notebook output shows an SGD instance), not a factory.
    flex_model["criterion"] = criterion
    flex_model["optimizer_func"] = optim
    flex_model["optimizer_kwargs"] = {}
    return flex_model

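Se define el ataque de envenenamiento de datos por etiquetas sucias (dirty label), que modifica las etiquetas de los clientes seleccionados mediante inversión de etiquetas (label flipping).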

In [3]:
from flexclash.data import data_poisoner_all
from my_poison_attacks.dirty_label import label_flipping_attack
from PIL import Image

# Node ids of the federated dataset; the first two are the clients whose data
# will be poisoned (the printed output shows [0, 1]).
client_ids = list(flex_data.keys())
clients_to_backdoor = client_ids[:2]
print("Clientes que modifica", clients_to_backdoor)

# NOTE(review): label_one, label_two, porcent_to_change and clients_to_change
# are not read by any cell visible here; presumably leftovers from a targeted
# (two-label, partial) flipping variant — confirm before removing.
label_one = 1
label_two = 2
porcent_to_change = 0.2

clients_to_change = clients_to_backdoor # TODO: check whether this works (alias of clients_to_backdoor)

@data_poisoner_all
def dirty_label_poison(dataset_client: Dataset):
    """Dirty-label poisoning of one client's dataset.

    Replaces the client's labels with those produced by
    ``label_flipping_attack`` (10 classes) and returns the images as a list of
    PIL images together with the new labels.

    Returns:
        tuple: ``(images, labels)``, i.e. a list of ``PIL.Image`` objects and
        the labels returned by ``label_flipping_attack``.
    """
    images, labels = dataset_client.to_numpy()
    images = np.asarray(images)
    flipped_labels = label_flipping_attack(client_labels=labels, num_labels=10)

    # Iterate over the samples directly. Each sample is handed to PIL as-is
    # (for MNIST a (28, 28) array), so there is no need to add a channel axis
    # and then index it away again.
    pil_images = [Image.fromarray(img) for img in images]
    print(len(pil_images))
    print(len(flipped_labels))
    return pil_images, flipped_labels

# Apply the label-flipping poisoner only to the selected clients. The result is
# stored in a new variable, so flex_data is presumably left un-poisoned.
flex_data_modif = flex_data.apply(dirty_label_poison, node_ids = clients_to_backdoor)

Clientes que modifica [0, 1]
30000
30000
30000
30000


Se define la arquitectura del modelo federado (pool de clientes y servidor) sobre los datos envenenados.

In [4]:
from flex.pool import FlexPool
# Build a client/server pool over the poisoned dataset and pick a single client
# for one illustrative training round.
clients = 1

pool = FlexPool.client_server_pool(
        fed_dataset = flex_data_modif, server_id=server_id, init_func = build_server_model
    )

selected_test_clients_pool = pool.clients.select(clients)
selected_test_clients = selected_test_clients_pool.clients

# Debug output: selected client ids, the shape of that client's feature data,
# and every node id in the federated dataset.
print(selected_test_clients.actor_ids)
print(np.array(flex_data_modif[selected_test_clients.actor_ids[0]].X_data).shape)
# Only one client is selected (clients = 1), so actor_ids[1] would not exist:
#print(np.array(flex_data_modif[selected_test_clients.actor_ids[1]].X_data).shape)
print(list(flex_data_modif.keys()))

[0]
(30000, 28, 28)
[0, 1, 'server']


Se define la función para desplegar el modelo global en cada cliente

In [5]:
from flex.pool.decorators import (  # noqa: E402
    deploy_server_model,
)

@deploy_server_model
def deploy_serv(server_flex_model: FlexModel):
    """Return an independent deep copy of the server model for a client.

    A deep copy is used so that client-side training cannot mutate the
    server's global model in place.
    """
    return deepcopy(server_flex_model)

# Send a copy of the server's global model to the selected test client(s).
pool.servers.map(deploy_serv, selected_test_clients)

Se define la ronda de entrenamiento local de cada cliente, empleando el módulo networks_execution.

In [6]:
def train(client_flex_model: FlexModel, client_data: Dataset):
    """Run one local training epoch for a client and return its model.

    Before training, a snapshot of the current network is stored under
    ``"previous_model"``. The aggregator's ``collect_client_diff_weights_pt``
    presumably reads it to compute the weight difference.
    """
    # Debug output: shape of the client's feature data.
    print(np.array(client_data.X_data).shape)
    loader = DataLoader(
        client_data.to_torchvision_dataset(transform=mnist_transform()),
        batch_size=256,
    )

    net = client_flex_model["model"].to(device)
    client_flex_model["previous_model"] = deepcopy(net)

    # NOTE(review): lr=0.005 is passed here, while the stored optimizer was
    # printed with lr 0.05. Whether trainNetwork overrides the optimizer's
    # settings depends on networks_execution — confirm.
    net_config.trainNetwork(
        local_epochs=1,
        criterion=client_flex_model["criterion"],
        optimizer=client_flex_model["optimizer_func"],
        momentum=0.9,
        lr=0.005,
        trainloader=loader,
        testloader=None,
        model=net,
    )
    return client_flex_model

# One local training round on the selected client(s); the cell output is the
# list of updated client models.
selected_test_clients.map(train)

(30000, 28, 28)


100%|██████████| 118/118 [00:21<00:00,  5.44it/s]


[{'model': CNN(
   (conv1): Sequential(
     (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): LeakyReLU(negative_slope=0.2)
   )
   (conv2): Sequential(
     (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
     (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
     (2): LeakyReLU(negative_slope=0.2)
   )
   (conv3): Sequential(
     (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2))
     (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
     (2): LeakyReLU(negative_slope=0.2)
   )
   (out): Linear(in_features=2304, out_features=10, bias=True)
 ), 'criterion': CrossEntropyLoss(), 'optimizer_func': SGD (
 Parameter Group 0
     dampening: 0
     differentiable: False
     foreach: None
     fused: None
     lr: 0.05
     maximize: False
     momentum: 0.9
     nesterov: False
     weight_decay: 0
 ), 'optimizer_kwargs': {}, 'previous_model': CNN(
   (conv

Se efectúa la agregación del modelo federado

In [7]:
from flex.pool import collect_client_diff_weights_pt
from flex.pool import fed_avg
from flex.pool import set_aggregated_diff_weights_pt


# Aggregation: gather each selected client's weight differences, average them
# with FedAvg, then apply the aggregated update to the server model.
pool.aggregators.map(collect_client_diff_weights_pt, selected_test_clients)
pool.aggregators.map(fed_avg)
pool.aggregators.map(set_aggregated_diff_weights_pt, pool.servers)

Después de la modif antes de fedavg tensor(2.2508)


Se evalúa el modelo federado

In [8]:
def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    """Evaluate the server's global model on ``test_data``.

    Args:
        server_flex_model: FlexModel holding ``"model"`` and ``"criterion"``.
        test_data: Flex dataset to evaluate on, converted with
            ``mnist_transform()``.

    Returns:
        tuple: ``(test_loss, test_acc)``. The loss is averaged per sample and
        the accuracy is the fraction of correct top-1 predictions.

    Raises:
        ValueError: if ``test_data`` yields no samples.
    """
    model = server_flex_model["model"].to(device)
    model.eval()
    criterion = server_flex_model["criterion"]

    dataset = test_data.to_torchvision_dataset(transform=mnist_transform())
    # Evaluation does not depend on sample order; keep it deterministic.
    test_dataloader = DataLoader(
        dataset, batch_size=256, shuffle=False, pin_memory=False
    )

    total_loss = 0.0
    correct = 0
    total_count = 0
    with torch.no_grad():
        for data, target in tqdm(test_dataloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # Weight each batch's loss by its size so a short final batch is
            # not over-counted. Assumes the criterion uses mean reduction
            # (nn.CrossEntropyLoss default) — confirm in networks_execution.
            total_loss += criterion(output, target).item() * batch_size
            correct += (output.argmax(dim=1) == target).sum().item()
            total_count += batch_size

    if total_count == 0:
        raise ValueError("evaluate_global_model: test_data contains no samples")

    test_loss = total_loss / total_count
    test_acc = correct / total_count
    return test_loss, test_acc

# Evaluate the global model on the server's own data, which was not poisoned:
# only the client nodes in clients_to_backdoor were modified.
metrics = pool.servers.map(evaluate_global_model)

loss, acc = metrics[0]
print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

100%|██████████| 40/40 [00:04<00:00,  8.67it/s]


Server: Test acc: 0.0045, test loss: 17.6161


Se define el evaluador para el ataque

Como datos envenenados se seleccionan las imágenes generadas por el ataque, para verificar cómo se comporta el modelo cuando solo analiza las imágenes modificadas. Al igual que en el notebook de ataques backdoor, se puede insertar una función trigger que genere nuevos datos a partir de las imágenes, utilizando para ello el decorador generate_bad_data_for_test.

In [9]:
from poison_attack_evaluator import generate_bad_data_for_test, evaluate_model_with_poison_data, data_poison_evaluator_pt
from PIL import Image

@generate_bad_data_for_test
def poison_test_set(test_set: Dataset):
    """Label-flip a test set so the model can be scored on poisoned data.

    Replaces the labels with those produced by ``label_flipping_attack``
    (10 classes) and returns the images as a list of PIL images together with
    the new labels. The decorator presumably wraps this pair back into a
    Dataset, since ``evaluator_pt`` calls ``to_torchvision_dataset`` on the
    result — confirm in poison_attack_evaluator.

    Returns:
        tuple: ``(images, labels)``, i.e. a list of ``PIL.Image`` objects and
        the labels returned by ``label_flipping_attack``.
    """
    images, labels = test_set.to_numpy()
    images = np.asarray(images)
    flipped_labels = label_flipping_attack(client_labels=labels, num_labels=10)

    # Iterate over the samples directly. Each sample is handed to PIL as-is
    # (for MNIST a (28, 28) array), so there is no need to add a channel axis
    # and then index it away again.
    pil_images = [Image.fromarray(img) for img in images]
    print(len(pil_images))
    print(len(flipped_labels))
    return pil_images, flipped_labels

@evaluate_model_with_poison_data
def evaluator_pt(server_model: FlexModel, test_data: Dataset):
    """Evaluate the server model on a poisoned copy of ``test_data``.

    The test data is poisoned with ``poison_test_set``, converted to a
    torchvision dataset using ``mnist_transform()``, and scored by
    ``data_poison_evaluator_pt``.

    Returns:
        tuple: ``(test_loss, test_acc)`` as reported by the evaluator.
    """
    poisoned = poison_test_set(test_data).to_torchvision_dataset(
        transform=mnist_transform()
    )
    loss, accuracy = data_poison_evaluator_pt(server_model, poisoned)
    return loss, accuracy

# Evaluate the global model on a label-flipped (poisoned) version of the
# server's data.
metrics_for_bad_data = pool.servers.map(evaluator_pt)

loss_b, acc_b = metrics_for_bad_data[0]
print(f"Server: Test acc: {acc_b:.4f}, test loss: {loss_b:.4f}")

10000
10000


100%|██████████| 40/40 [00:03<00:00, 10.42it/s]


Server: Test acc: 0.9676, test loss: 0.1091


Función opcional para liberar de la memoria los modelos de los clientes.

In [None]:
def clean_up_models(client_model: FlexModel, _):
    """Release a client's model state (optional memory clean-up).

    Empties the FlexModel in place, then forces a garbage-collection pass.
    The second argument (presumably the client's data, as passed by ``map``)
    is ignored.
    """
    import gc as garbage_collector

    client_model.clear()
    garbage_collector.collect()

Se definen las rondas de entrenamiento del modelo federado

In [None]:
def train_n_rounds(n_rounds, clients_per_round=20, fed_dataset=None):
    """Run ``n_rounds`` of federated training with FedAvg from a fresh server model.

    Each round selects clients, deploys the global model to them, trains
    locally, aggregates the weight differences with FedAvg, updates the
    server, evaluates it, and frees the clients' models.

    Args:
        n_rounds: number of federated rounds to run.
        clients_per_round: number of clients selected each round.
        fed_dataset: federated dataset to train on. Defaults to the global
            ``flex_data`` (the dataset before ``dirty_label_poison`` was
            applied). Pass ``flex_data_modif`` to train on the poisoned
            clients instead.
    """
    # Resolve the default at call time so the current global is used.
    if fed_dataset is None:
        fed_dataset = flex_data

    pool = FlexPool.client_server_pool(
        fed_dataset=fed_dataset, server_id=server_id, init_func=build_server_model
    )
    for round_num in range(1, n_rounds + 1):
        print(f"\nRunning round: {round_num} of {n_rounds}")
        selected_clients = pool.clients.select(clients_per_round).clients
        pool.servers.map(deploy_serv, selected_clients)
        selected_clients.map(train)
        pool.aggregators.map(collect_client_diff_weights_pt, selected_clients)
        pool.aggregators.map(fed_avg)
        pool.aggregators.map(set_aggregated_diff_weights_pt, pool.servers)
        metrics = pool.servers.map(evaluate_global_model)
        selected_clients.map(clean_up_models)
        loss, acc = metrics[0]
        print(f"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}")

In [None]:
# Run 2 federated rounds with 2 clients per round. train_n_rounds builds its own
# pool from flex_data, presumably the un-poisoned dataset, since the poisoned
# copy was stored separately in flex_data_modif.
train_n_rounds(2, clients_per_round=2)