In [1]:
!pip install -q flwr[simulation] torch torchvision matplotlib

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.2/219.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
!pip install alibi-detect[torch]

Collecting alibi-detect[torch]
  Downloading alibi_detect-0.11.4-py3-none-any.whl (372 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m372.4/372.4 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.4.0,>=0.3.0 (from alibi-detect[torch])
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
Collecting numba!=0.54.0,<0.58.0,>=0.50.0 (from alibi-detect[torch])
  Downloading numba-0.57.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
Collecting torch<1.14.0,>=1.7.0 (from alibi-detect[torch])
  Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting llvmlite<

In [3]:
import flwr as fl

import torch
import torch.nn as nn
import tensorflow as tf
import numpy as np

from functools import partial
from alibi_detect.cd import MMDDrift
from alibi_detect.cd.pytorch import preprocess_drift

import random
# import torchvision.datasets as datasets
# import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset

In [4]:
# set random seed and device
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
print(device)

cuda


In [6]:
n_clients = 2

In [7]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
y_train = y_train.astype("int64").reshape(
    -1,
)
y_test = y_test.astype("int64").reshape(
    -1,
)

x_train = x_train[0 : int(len(x_train) * 0.2)]
y_train = y_train[0 : int(len(y_train) * 0.2)] # just to get some small part of data

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [8]:
# Simulate federated clients (splitting the dataset)
client_data = []

for i in range(n_clients):
    start = i * len(x_train) // n_clients
    end = (i + 1) * len(x_train) // n_clients
    client_data.append((x_train[start:end], y_train[start:end]))

In [9]:
len(client_data[0][0])

5000

In [10]:
# # Define a global model (we had CNN in Task1/2 (class Net))

# encoding_dim = 32
# # define encoder
# global_model = nn.Sequential(
#     nn.Conv2d(3, 64, 4, stride=2, padding=0),
#     nn.ReLU(),
#     nn.Conv2d(64, 128, 4, stride=2, padding=0),
#     nn.ReLU(),
#     nn.Conv2d(128, 512, 4, stride=2, padding=0),
#     nn.ReLU(),
#     nn.Flatten(),
#     nn.Linear(2048, encoding_dim)
# ).to(device).eval()

class Encoder(nn.Module):
    def __init__(self, encoding_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(128, 512, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, encoding_dim)
        )

    def forward(self, x):
        return self.encoder(x)

# Define the encoding dimension
encoding_dim = 32

# Instantiate the Encoder model
encoder_model = Encoder(encoding_dim).to(device).eval()

In [11]:
def permute_c(x):
    return np.transpose(x.astype(np.float32), (0, 3, 1, 2))

In [12]:
# MMD detector on each client
client_detectors = []
for x_data, _ in client_data:
    # define preprocessing function
    preprocess_fn = partial(
        preprocess_drift, model=encoder_model, device=device, batch_size=512
    )

    X_ref = permute_c(x_data[0:200])
    # initialise drift detector
    detector = MMDDrift(
        X_ref,
        backend="pytorch",
        p_val=0.05,
        preprocess_fn=preprocess_fn,
        n_permutations=100,
    )
    client_detectors.append(detector)

In [13]:
from torch.utils.data import TensorDataset, DataLoader

In [14]:
BATCH_SIZE = 32

In [15]:
def train(model, x_train, y_train, num_epochs=5):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    model.train()

    train_dataset = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
    trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    for epoch in range(num_epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total

        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")

    return model


def test(model, x_test, y_test):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    model.eval()

    test_dataset = TensorDataset(torch.tensor(x_test), torch.tensor(y_test))
    testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    print('total', total)
    accuracy = correct / total
    return loss, accuracy

In [16]:
# Drift detection on client data
def handle_client_drift(c_data, detector, net):
    (x_train, y_train, x_val, y_val, cid) = c_data

    detector_data = detector.predict(x_train, return_p_val=True, return_distance=True)
    is_drift = detector_data['data'].get('is_drift', None)
    p_val = detector_data['data'].get('p_val', None)
    distance = detector_data['data'].get('distance', None)

    print("Client:",cid)
    print("p_val:",p_val)
    print("distance:",distance)

    if is_drift:
        print("Drift detected on client data. Retraining local model.")
        net = train(net, x_train, y_train, num_epochs=5)
    else:
        print("No drift detected on client data. Continuing training.")
    return net


# Drift detection on aggregated data
def handle_global_drift(aggregated_data, detector, global_model):
    is_drift, metrics = detector.predict(permute_c(aggregated_data["x_train"]))
    print(metrics)  # You may extract useful information, e.g., p-value, from metrics
    if is_drift:
        print("Drift detected on aggregated data. Updating global model.")
        # Update the global model based on aggregated_data
        # global_model = train(
        #     global_model,
        #     aggregated_data["x_train"],
        #     aggregated_data["y_train"],
        #     num_epochs=5,
        # )
    else:
        print("No drift detected on aggregated data. Continuing training.")
    # return global_model

In [23]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, client_data, model):
        self.client_data = client_data
        self.model = model

    def get_parameters(self, config):
        # Return the current model parameters
        return [param.detach().cpu().numpy() for param in self.model.parameters()]

    def fit(self, parameters, config):
        # Train the local model after updating it with the given parameters
        # Convert parameters from numpy arrays to torch tensors
        state_dict = {
            key: torch.from_numpy(param)
            for key, param in zip(self.model.state_dict(), parameters)
        }
        self.model.load_state_dict(state_dict)

        # Perform local training with client_data and drift detection
        # self.model = handle_client_drift(
        #     self.client_data, self.client_detector, self.model
        # )
        # handle_client_drift(self.client_data, self.client_detector, self.model) ## Q?: Do I need handle_client_drift here?
        new_params = [param.detach().cpu().numpy() for param in self.model.parameters()]
        return new_params, len(self.client_data["x_train"]), {}

    def evaluate(self, parameters, config):
        # Perform the evaluation of the model after updating it with the given
        # parameters. Returns the loss as a float, the length of the validation
        # data, and a dict containing the accuracy
        # Convert parameters from numpy arrays to torch tensors
        state_dict = {
            key: torch.from_numpy(param)
            for key, param in zip(self.model.state_dict(), parameters)
        }
        self.model.load_state_dict(state_dict)
        # Perform evaluation
        loss, accuracy = test(
            self.model, self.client_data["x_val"], self.client_data["y_val"]
        )

        return (
            float(loss),
            len(self.client_data["y_val"]),
            {"accuracy": float(accuracy)},
        )

In [24]:
def client_fn(cid: str, client_data=client_data) -> FlowerClient:
        x_data, y_data = client_data[int(cid)]
        # x_data = np.array(x_data)
        x_data = permute_c(x_data)
        # y_data = np.array(y_data)

        x_train = x_data[0 : int(len(x_data) * 0.8)]
        y_train = y_data[0 : int(len(y_data) * 0.8)]

        x_val = x_data[int(len(x_data) * 0.8) :]
        y_val = y_data[int(len(y_data) * 0.8) :]

        all_data = []
        all_data.extend((x_train, y_train, x_val, y_val, int(cid)))
        # Apply drift detection on client data

        model = Encoder(encoding_dim).to(device)

        # model = handle_client_drift(all_data, client_detectors[int(cid)], model)
        handle_client_drift(all_data, client_detectors[int(cid)], model)

        # Train the local model
        train(
            model,
            x_train,
            y_train,
            num_epochs=5,
        )

        return FlowerClient(client_data={"x_train": x_train, "y_train": y_train, "x_val": x_val, "y_val": y_val}, model=model)

In [25]:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if device.type == "cuda":
    client_resources = {"num_gpus": 1, "num_cpus": 1}

In [26]:
client_resources

{'num_gpus': 1, 'num_cpus': 1}

In [27]:
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,          # Sample 100% of available clients for training
    fraction_evaluate=0.5,     # Sample 50% of available clients for evaluation
    min_fit_clients=2,        # Never sample less than 10 clients for training
    min_evaluate_clients=2,    # Never sample less than 5 clients for evaluation
    min_available_clients=2,  # Wait until all 10 clients are available
    # evaluate_metrics_aggregation_fn=handle_global_drift
    # Q?: should I call handle_global_drift here?
)

In [28]:
# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=n_clients,
    config=fl.server.ServerConfig(num_rounds=2),
    strategy=strategy,
    client_resources=client_resources,
    ray_init_args={"num_cpus": 8, "num_gpus": 1},
)

INFO flwr 2024-01-22 02:37:17,876 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)
2024-01-22 02:37:22,604	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-01-22 02:37:24,849 | app.py:213 | Flower VCE: Ray initialized with resources: {'memory': 7912058880.0, 'object_store_memory': 3956029440.0, 'GPU': 1.0, 'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0, 'CPU': 8.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'memory': 7912058880.0, 'object_store_memory': 3956029440.0, 'GPU': 1.0, 'node:172.28.0.12': 1.0, 'node:__internal_head__': 1.0, 'CPU': 8.0}
INFO flwr 2024-01-22 02:37:24,858 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO:flwr:Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html


[2m[36m(DefaultActor pid=2837)[0m Client: 0
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.6499999761581421
[2m[36m(DefaultActor pid=2837)[0m distance: -0.00043576955795288086
[2m[36m(DefaultActor pid=2837)[0m No drift detected on client data. Continuing training.
[2m[36m(DefaultActor pid=2837)[0m Epoch 1: train loss 0.06696390679478645, accuracy 0.20325
[2m[36m(DefaultActor pid=2837)[0m Epoch 2: train loss 0.054985032618045805, accuracy 0.3595
[2m[36m(DefaultActor pid=2837)[0m Epoch 3: train loss 0.04975893148779869, accuracy 0.41925
[2m[36m(DefaultActor pid=2837)[0m Epoch 4: train loss 0.045352270424366, accuracy 0.4825


INFO flwr 2024-01-22 02:37:40,303 | server.py:280 | Received initial parameters from one random client
INFO:flwr:Received initial parameters from one random client
INFO flwr 2024-01-22 02:37:40,307 | server.py:91 | Evaluating initial parameters
INFO:flwr:Evaluating initial parameters
INFO flwr 2024-01-22 02:37:40,312 | server.py:104 | FL starting
INFO:flwr:FL starting
DEBUG flwr 2024-01-22 02:37:40,315 | server.py:222 | fit_round 1: strategy sampled 2 clients (out of 2)
DEBUG:flwr:fit_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.04165339416265488, accuracy 0.524
[2m[36m(DefaultActor pid=2837)[0m Client: 1
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.7300000190734863
[2m[36m(DefaultActor pid=2837)[0m distance: -0.0004488229751586914
[2m[36m(DefaultActor pid=2837)[0m No drift detected on client data. Continuing training.
[2m[36m(DefaultActor pid=2837)[0m Epoch 1: train loss 0.06694731310009956, accuracy 0.2105
[2m[36m(DefaultActor pid=2837)[0m Epoch 2: train loss 0.053859348773956296, accuracy 0.37275
[2m[36m(DefaultActor pid=2837)[0m Epoch 3: train loss 0.0488228754401207, accuracy 0.43
[2m[36m(DefaultActor pid=2837)[0m Epoch 4: train loss 0.044787902608513834, accuracy 0.4705
[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.04099466572701931, accuracy 0.5255
[2m[36m(DefaultActor pid=2837)[0m Client: 0
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.6899999976158142
[2m[36m(DefaultActor pid=2837)[0m distance: -0.00043

DEBUG flwr 2024-01-22 02:37:49,132 | server.py:236 | fit_round 1 received 2 results and 0 failures
DEBUG:flwr:fit_round 1 received 2 results and 0 failures
DEBUG flwr 2024-01-22 02:37:49,161 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2)
DEBUG:flwr:evaluate_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.04115588167309761, accuracy 0.523
[2m[36m(DefaultActor pid=2837)[0m Client: 0
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.6600000262260437
[2m[36m(DefaultActor pid=2837)[0m distance: -0.00043582916259765625
[2m[36m(DefaultActor pid=2837)[0m No drift detected on client data. Continuing training.
[2m[36m(DefaultActor pid=2837)[0m Epoch 1: train loss 0.06789397594332695, accuracy 0.20325
[2m[36m(DefaultActor pid=2837)[0m Epoch 2: train loss 0.05557058334350586, accuracy 0.351
[2m[36m(DefaultActor pid=2837)[0m Epoch 3: train loss 0.04987913128733635, accuracy 0.424
[2m[36m(DefaultActor pid=2837)[0m Epoch 4: train loss 0.046074427783489226, accuracy 0.47125
[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.0420899463146925, accuracy 0.52075
[2m[36m(DefaultActor pid=2837)[0m total 1000
[2m[36m(DefaultActor pid=2837)[0m Client: 1
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.6899999976158142
[2m

DEBUG flwr 2024-01-22 02:37:56,838 | server.py:187 | evaluate_round 1 received 2 results and 0 failures
DEBUG:flwr:evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2024-01-22 02:37:56,845 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2)
DEBUG:flwr:fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.0397581602036953, accuracy 0.537
[2m[36m(DefaultActor pid=2837)[0m total 1000
[2m[36m(DefaultActor pid=2837)[0m Client: 0
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.7200000286102295
[2m[36m(DefaultActor pid=2837)[0m distance: -0.00043582916259765625
[2m[36m(DefaultActor pid=2837)[0m No drift detected on client data. Continuing training.
[2m[36m(DefaultActor pid=2837)[0m Epoch 1: train loss 0.0668560910820961, accuracy 0.2125
[2m[36m(DefaultActor pid=2837)[0m Epoch 2: train loss 0.053809002906084064, accuracy 0.36875
[2m[36m(DefaultActor pid=2837)[0m Epoch 3: train loss 0.048579350978136066, accuracy 0.42875
[2m[36m(DefaultActor pid=2837)[0m Epoch 4: train loss 0.044564552903175354, accuracy 0.4835
[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.040556344717741014, accuracy 0.53575
[2m[36m(DefaultActor pid=2837)[0m Client: 1
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.6299999952316284


DEBUG flwr 2024-01-22 02:38:05,056 | server.py:236 | fit_round 2 received 2 results and 0 failures
DEBUG:flwr:fit_round 2 received 2 results and 0 failures
DEBUG flwr 2024-01-22 02:38:05,078 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2)
DEBUG:flwr:evaluate_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.04285445073246956, accuracy 0.498
[2m[36m(DefaultActor pid=2837)[0m Client: 1
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.7300000190734863
[2m[36m(DefaultActor pid=2837)[0m distance: -0.0004488229751586914
[2m[36m(DefaultActor pid=2837)[0m No drift detected on client data. Continuing training.
[2m[36m(DefaultActor pid=2837)[0m Epoch 1: train loss 0.06797759488224983, accuracy 0.197
[2m[36m(DefaultActor pid=2837)[0m Epoch 2: train loss 0.0554268451333046, accuracy 0.349
[2m[36m(DefaultActor pid=2837)[0m Epoch 3: train loss 0.05070326006412506, accuracy 0.413
[2m[36m(DefaultActor pid=2837)[0m Epoch 4: train loss 0.04665420040488243, accuracy 0.46125
[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.042946809887886045, accuracy 0.49925
[2m[36m(DefaultActor pid=2837)[0m total 1000
[2m[36m(DefaultActor pid=2837)[0m Client: 0
[2m[36m(DefaultActor pid=2837)[0m p_val: 0.6800000071525574
[2m[36

DEBUG flwr 2024-01-22 02:38:13,128 | server.py:187 | evaluate_round 2 received 2 results and 0 failures
DEBUG:flwr:evaluate_round 2 received 2 results and 0 failures
INFO flwr 2024-01-22 02:38:13,131 | server.py:153 | FL finished in 32.81657117899999
INFO:flwr:FL finished in 32.81657117899999
INFO flwr 2024-01-22 02:38:13,136 | app.py:226 | app_fit: losses_distributed [(1, 0.04647466617822647), (2, 0.04647466617822647)]
INFO:flwr:app_fit: losses_distributed [(1, 0.04647466617822647), (2, 0.04647466617822647)]
INFO flwr 2024-01-22 02:38:13,138 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO:flwr:app_fit: metrics_distributed_fit {}
INFO flwr 2024-01-22 02:38:13,139 | app.py:228 | app_fit: metrics_distributed {}
INFO:flwr:app_fit: metrics_distributed {}
INFO flwr 2024-01-22 02:38:13,145 | app.py:229 | app_fit: losses_centralized []
INFO:flwr:app_fit: losses_centralized []
INFO flwr 2024-01-22 02:38:13,149 | app.py:230 | app_fit: metrics_centralized {}
INFO:flwr:app_fit: metrics_ce

[2m[36m(DefaultActor pid=2837)[0m Epoch 5: train loss 0.042640629991889, accuracy 0.509
[2m[36m(DefaultActor pid=2837)[0m total 1000


History (loss, distributed):
	round 1: 0.04647466617822647
	round 2: 0.04647466617822647