In [11]:
import flwr as fl

import torch
import torch.nn as nn
import tensorflow as tf
import numpy as np

from functools import partial
from alibi_detect.cd import MMDDrift
from alibi_detect.cd.pytorch import preprocess_drift

In [12]:
# Fix the PyTorch RNG seeds (CPU and CUDA) so runs are repeatable,
# then choose the compute device: the GPU when one is visible, else the CPU.
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [13]:
# Show which compute device was selected for this run
print(device)

cuda


In [14]:
# Number of simulated federated clients the training split is divided between
n_clients = 2

In [15]:
# Download and preprocess CIFAR-10 dataset
(all_x_train, all_y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Scale pixels from [0, 255] to [0, 1] exactly once. Dividing a second time
# would squash every image into [0, 1/255] and make the inputs nearly constant.
all_x_train = all_x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Labels arrive with shape (N, 1); flatten to (N,) int64 for CrossEntropyLoss.
all_y_train = all_y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)

# Hold out the last 20% of the original training set for validation.
n_train_split = int(len(all_x_train) * 0.8)

x_train = all_x_train[:n_train_split]
y_train = all_y_train[:n_train_split]

x_val = all_x_train[n_train_split:]
y_val = all_y_train[n_train_split:]

In [16]:
# Size of the validation split (the last 20% of the original training set)
len(x_val)

10000

In [17]:
# Size of the training split (the first 80% of the original training set)
len(y_train)

40000

In [18]:
# Simulate federated clients by cutting the training split into
# `n_clients` contiguous, near-equal shards of (images, labels).
n_train = len(x_train)
shard_bounds = [k * n_train // n_clients for k in range(n_clients + 1)]

client_data = [
    (x_train[lo:hi], y_train[lo:hi])
    for lo, hi in zip(shard_bounds[:-1], shard_bounds[1:])
]

In [19]:
# Number of training images assigned to the first client
len(client_data[0][0])

20000

In [24]:
# Global model (a small CNN, in the spirit of the `Net` class from Task1/2)

# Width of the final linear layer's output
encoding_dim = 32

# Encoder: three stride-2 convolutions followed by a linear projection.
# For 32x32 inputs the conv stack ends at 512 x 2 x 2 = 2048 features.
global_model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(in_channels=128, out_channels=512, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(in_features=2048, out_features=encoding_dim),
)
# Move to the selected device and switch to inference mode
global_model = global_model.to(device).eval()

In [25]:
def permute_c(x):
    """Cast a channels-last batch (N, H, W, C) to float32 and return it channels-first (N, C, H, W)."""
    as_float = x.astype(np.float32)
    return as_float.transpose(0, 3, 1, 2)

In [26]:
# MMD detector on each client

# The preprocessing step (encode images with the global model) is the same
# for every client, so it is built once and shared by all detectors.
preprocess_fn = partial(preprocess_drift, model=global_model, device=device, batch_size=512)

# Each detector uses the first 200 images of its client's shard as reference data.
client_detectors = [
    MMDDrift(
        permute_c(x_data[0:200]),
        backend='pytorch',
        p_val=.05,
        preprocess_fn=preprocess_fn,
        n_permutations=100,
    )
    for x_data, _ in client_data
]

In [28]:
# Model train
def train(x_data, y_data, local_model, num_epochs=5, batch_size=64):
    """Train the network on the training set.

    The data is consumed in mini-batches. Iterating sample by sample would
    hand the model un-batched, channels-last images and NumPy scalar labels,
    which ``torch.from_numpy`` cannot convert.

    Args:
        x_data: NumPy array of images in channels-last layout (N, H, W, C).
        y_data: NumPy array of N integer class labels.
        local_model: ``torch.nn.Module`` to optimise in place.
        num_epochs: Number of full passes over the data.
        batch_size: Number of samples per optimisation step.

    Returns:
        The same ``local_model`` instance after training.

    Raises:
        ValueError: If ``x_data`` is empty or its length differs from ``y_data``.
    """
    n_samples = len(x_data)
    if n_samples == 0:
        raise ValueError("x_data must contain at least one sample")
    if n_samples != len(y_data):
        raise ValueError("x_data and y_data must have the same length")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(local_model.parameters())
    local_model.train()

    # Send batches to whichever device the model's parameters live on.
    model_device = next(local_model.parameters()).device

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            # (B, H, W, C) -> (B, C, H, W), the layout nn.Conv2d expects.
            inputs = (
                torch.from_numpy(x_data[start:stop])
                .permute(0, 3, 1, 2)
                .contiguous()
                .to(model_device, dtype=torch.float32)
            )
            labels = torch.from_numpy(y_data[start:stop]).to(model_device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics: criterion returns the batch mean, so weight it by the
            # batch size to keep the reported value a per-sample average.
            epoch_loss += loss.item() * labels.size(0)

        print(f'Epoch {epoch + 1}, Loss: {epoch_loss / n_samples}')

    return local_model

# Moddel test
def test(x_data, y_data, local_model):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    local_model.eval()

    with torch.no_grad():
        for inputs, labels in zip(x_data, y_data):
            inputs, labels = torch.from_numpy(inputs).to(device), torch.from_numpy(labels).to(device)
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    loss /= len(x_data)
    accuracy = correct / total

    return loss, accuracy


In [29]:
# Drift detection on client data
def handle_client_drift(x_data, detector):
    """Run the client's drift detector on its data and report the outcome.

    Args:
        x_data: NumPy array of images in channels-last layout (N, H, W, C).
        detector: Drift detector whose ``predict`` returns alibi-detect's
            prediction dictionary (``{'meta': ..., 'data': {'is_drift': ...}}``).

    Returns:
        ``True`` if drift was detected, ``False`` otherwise.
    """
    # predict() returns a dict, not a tuple: unpacking it would bind the key
    # strings, which are always truthy, so drift would always be "detected".
    preds = detector.predict(permute_c(x_data))
    is_drift = bool(preds['data']['is_drift'])
    if is_drift:
        print("Drift detected on client data.")
        # local_model = train(x_data, y_data, local_model, num_epochs=5)
    else:
        print("No drift detected on client data. Continuing training.")
    return is_drift


# Drift detection on aggregated data
def handle_global_drift(aggregated_data, detector):
    """Run a drift detector on aggregated data and report the outcome.

    Args:
        aggregated_data: NumPy array of images in channels-last layout (N, H, W, C).
        detector: Drift detector whose ``predict`` returns alibi-detect's
            prediction dictionary (``{'meta': ..., 'data': {...}}``).

    Returns:
        ``True`` if drift was detected, ``False`` otherwise.
    """
    # predict() returns a dict; the test results (is_drift, p_val, ...) are
    # stored under its 'data' key.
    preds = detector.predict(permute_c(aggregated_data))
    metrics = preds['data']
    print(metrics)  # includes the p-value of the test under 'p_val'
    is_drift = bool(metrics['is_drift'])
    if is_drift:
        print("Drift detected on aggregated data. Updating global model.")
        # Should I update the global model here?
    else:
        print("No drift detected on aggregated data. Continuing training.")
    return is_drift

In [30]:
class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains a local model on one shard and checks it for drift."""

    def __init__(self, client_data, local_model, client_detector):
        """Store the client's state.

        Args:
            client_data: Tuple ``(x_data, y_data)`` holding this client's shard.
            local_model: ``torch.nn.Module`` trained and evaluated by this client.
            client_detector: Drift detector applied to this client's images.
        """
        self.client_data = client_data
        self.local_model = local_model
        self.client_detector = client_detector

    def get_parameters(self, config=None):
        """Return the current model weights as a list of NumPy arrays.

        Flower exchanges parameters as lists of ndarrays, so the state_dict
        values are converted in state_dict order. ``config`` is accepted
        because Flower passes it, but it is not used.
        """
        return [
            tensor.detach().cpu().numpy()
            for tensor in self.local_model.state_dict().values()
        ]

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays (in state_dict order) into the local model."""
        keys = self.local_model.state_dict().keys()
        state_dict = {key: torch.as_tensor(value) for key, value in zip(keys, parameters)}
        self.local_model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally starting from the given global parameters.

        Returns:
            Tuple of the updated parameters, the number of training examples
            used, and an (empty) metrics dict, as NumPyClient.fit requires.
        """
        # Train the local model after updating it with the given parameters
        self.set_parameters(parameters)
        x_data, y_data = self.client_data
        self.local_model = train(x_data, y_data, self.local_model, num_epochs=5)
        # Check this client's data for drift once local training has finished
        handle_client_drift(x_data, self.client_detector)
        return self.get_parameters(config), len(x_data), {}

    def evaluate(self, parameters, config):
        """Evaluate the given global parameters on the shared validation split.

        Returns:
            Tuple of the loss as a float, the number of validation examples,
            and a dict containing the accuracy.
        """
        self.set_parameters(parameters)
        loss, accuracy = test(x_val, y_val, self.local_model)
        # Can I run handle_global_drift here instead of test?
        return float(loss), len(x_val), {"accuracy": float(accuracy)}

In [31]:
def client_fn(cid: str, client_data=client_data) -> FlowerClient:
    """Build the Flower client for the client id ``cid``.

    Args:
        cid: Client id supplied by the simulation; used as an index into
            ``client_data`` and ``client_detectors``.
        client_data: List of per-client ``(x_data, y_data)`` shards.

    Returns:
        A ``FlowerClient`` bound to that client's shard, model and detector.
    """
    client_index = int(cid)
    # Select this client's own shard. Looping over every shard and returning
    # inside the loop would hand the first shard to every client.
    x_data, y_data = client_data[client_index]
    detector = client_detectors[client_index]

    # Apply drift detection on client data
    handle_client_drift(x_data, detector)

    local_model = nn.Sequential(
        nn.Conv2d(3, 64, 4, stride=2, padding=0),
        nn.ReLU(),
        nn.Conv2d(64, 128, 4, stride=2, padding=0),
        nn.ReLU(),
        nn.Conv2d(128, 512, 4, stride=2, padding=0),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(2048, encoding_dim)
    ).to(device)

    # Train the local model
    # NOTE(review): FlowerClient.fit/evaluate load the server's parameters
    # before doing anything else, so this warm-up may be redundant - confirm
    # whether it is still wanted.
    local_model = train(x_data, y_data, local_model, num_epochs=5)

    # Arguments follow FlowerClient.__init__(client_data, local_model, client_detector).
    return FlowerClient((x_data, y_data), local_model, detector)

In [32]:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
# `device` is a torch.device, which does not compare equal to a plain string,
# so check its `.type` attribute instead.
if device.type == "cuda":
    client_resources = {"num_gpus": 1, 'num_cpus': 1}

In [33]:
# Create FedAvg strategy
def weighted_average(metrics):
    """Combine per-client accuracies into one value, weighted by example count.

    Args:
        metrics: List of ``(num_examples, metrics_dict)`` tuples, one per
            client, as passed by Flower to a metrics aggregation function.

    Returns:
        ``{"accuracy": value}``, or an empty dict when there are no examples.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}


# The minimum client counts are derived from `n_clients`: asking for more
# clients than the simulation creates would leave the server waiting forever.
# The aggregation hook receives a single list of (num_examples, metrics)
# tuples and must return a metrics dict, so handle_global_drift(data, detector)
# cannot be plugged in here.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,                             # Sample 100% of available clients for training
    fraction_evaluate=0.5,                        # Sample 50% of available clients for evaluation
    min_fit_clients=n_clients,                    # Train on every simulated client
    min_evaluate_clients=max(1, n_clients // 2),  # Evaluate on at least one client
    min_available_clients=n_clients,              # Wait until all clients are available
    evaluate_metrics_aggregation_fn=weighted_average,
)

In [None]:
# Launch the federated simulation: five rounds across all simulated clients,
# using the strategy and per-client resources configured earlier.
fl.simulation.start_simulation(
    num_clients=n_clients,
    client_fn=client_fn,
    client_resources=client_resources,
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=5),
)