In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext

torchtext.disable_torchtext_deprecation_warning()

# Seed NumPy and PyTorch in this process so that both libraries' random
# number generators start from a known state.
np.random.seed(42)
torch.manual_seed(42)

# Pick the best available accelerator. CUDA is checked first because the
# simulation cell branches on DEVICE.type == "cuda"; without this check that
# branch could never be taken. Apple MPS is the next choice, CPU the fallback.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Training on {DEVICE} using PyTorch {torch.__version__}")

Training on mps using PyTorch 2.3.1


In [12]:
from sklearn.model_selection import train_test_split

from dataset import load_dataset, to_dataloader


# Number of simulated federated clients; the training split is divided into
# this many partitions (one per client) by split_dataset below.
NUM_CLIENTS = 10
# NOTE(review): BATCH_SIZE is not referenced by any cell shown here (the
# to_dataloader calls do not pass it) — confirm whether it is still needed.
BATCH_SIZE = 32


# Load the dataset together with its vocabulary and label encoder. All three
# are used as module-level globals by later cells (df/vocab when building the
# loaders, vocab/label_encoder when sizing the model).
# NOTE(review): presumably the NLTK Reuters corpus, judging by the download
# log printed by this cell — confirm in dataset.py.
df, vocab, label_encoder = load_dataset()


def split_dataset():
    """Split the global ``df`` into per-client loaders and one test loader.

    10% of ``df`` is held out as a test set. The remaining rows are cut into
    ``NUM_CLIENTS`` contiguous, near-equal partitions, and each partition is
    split again 80/20 into a training and a validation part. All splits use
    ``random_state=42``.

    Returns:
        A tuple ``(trainloaders, valloaders, testloader)`` where the first two
        are lists with one loader per client (indexed by client id) and the
        third wraps the held-out test set.
    """
    train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)

    # Split row positions instead of the frame itself: np.array_split applied
    # directly to a DataFrame goes through the deprecated DataFrame.swapaxes
    # path and emits a FutureWarning on recent pandas. Slicing with iloc on
    # the position chunks yields the same partitions (same sizes, same order).
    position_chunks = np.array_split(np.arange(len(train_df)), NUM_CLIENTS)

    # Create train/val for each partition and wrap it into DataLoader
    trainloaders: list[DataLoader] = []
    valloaders: list[DataLoader] = []
    for positions in position_chunks:
        partition = train_df.iloc[positions]

        train_texts, val_texts, train_labels, val_labels = train_test_split(
            partition["text"],
            partition["category"],
            test_size=0.2,
            random_state=42,
        )

        trainloaders.append(to_dataloader(train_texts, train_labels, vocab))
        valloaders.append(to_dataloader(val_texts, val_labels, vocab))

    testloader = to_dataloader(test_df["text"], test_df["category"], vocab)
    return trainloaders, valloaders, testloader


# Build the per-client train/validation loaders (indexed by client id in
# create_client) and the loader over the held-out test split.
trainloaders, valloaders, test_loader = split_dataset()

[nltk_data] Downloading package reuters to /Users/gabriel/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/gabriel/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
  return bound(*args, **kwds)


In [13]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    num_epochs: int,
    verbose=False,
):
    """Optimise ``model`` in place on ``train_loader``.

    Runs ``num_epochs`` full passes over the loader, minimising cross-entropy
    with a freshly created Adam optimizer (default hyper-parameters). Batches
    are moved to the global ``DEVICE`` before the forward pass.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(texts, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        verbose: When true, print the loss of the last batch after each epoch.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters())

    model.train()
    for epoch in range(num_epochs):
        for batch_texts, batch_labels in train_loader:
            inputs = batch_texts.to(DEVICE)
            targets = batch_labels.to(DEVICE)

            logits = model(inputs)
            loss = loss_fn(logits, targets)

            # Clear stale gradients, back-propagate, then update the weights.
            adam.zero_grad()
            loss.backward()
            adam.step()

        if not verbose:
            continue
        # Reports the final batch's loss for this epoch, not an epoch average.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")


def test(model: nn.Module, test_loader: DataLoader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0

    model.eval()
    with torch.no_grad():
        for texts, labels in test_loader:
            texts, labels = texts.to(DEVICE), labels.to(DEVICE)

            outputs = model(texts)
            loss += criterion(outputs, labels).item()
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    loss /= len(test_loader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [15]:
from flwr.client import Client, NumPyClient

from text_rnn import TextRNN

class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates one model on one data partition."""

    def __init__(self, net: nn.Module, trainloader: DataLoader, valloader: DataLoader):
        """Store the model and this client's training/validation loaders."""
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config) -> list[np.ndarray]:
        """Return every ``state_dict`` entry of the model as a NumPy array, in order."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def _set_parameters(self, parameters: list[np.ndarray]):
        """Load ``parameters`` into the model, pairing them with ``state_dict`` keys by position."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        # NOTE(review): torch.Tensor(v) builds a tensor of the default float
        # type; confirm the model holds no integer entries that need their
        # dtype preserved.
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train for one epoch from the given global parameters.

        Returns:
            ``(updated_parameters, num_examples, metrics)`` where
            ``num_examples`` is the number of training samples.
        """
        self._set_parameters(parameters)

        train(self.net, self.trainloader, num_epochs=1)

        # Report the number of samples, not len(loader) (the number of
        # batches): Flower expects an example count in this position.
        num_examples = len(self.trainloader.dataset)
        return self.get_parameters(config), num_examples, {}

    def evaluate(self, parameters, config):
        """Evaluate the given global parameters on the local validation data.

        Returns:
            ``(loss, num_examples, metrics)`` where ``num_examples`` is the
            number of validation samples and ``metrics`` holds ``"accuracy"``.
        """
        self._set_parameters(parameters)

        loss, accuracy = test(self.net, self.valloader)
        # Example count rather than batch count, for the same reason as in fit.
        num_examples = len(self.valloader.dataset)
        return float(loss), num_examples, {"accuracy": float(accuracy)}
    

def create_client(cid: str) -> Client:
    """Build the Flower client for the partition identified by ``cid``.

    A fresh ``TextRNN`` is created on ``DEVICE``, sized from the global
    ``vocab`` and ``label_encoder``, and paired with the train/validation
    loaders stored at index ``int(cid)``.
    """
    net = TextRNN(
        len(vocab),
        len(label_encoder.classes_),
        padding_idx=vocab["<pad>"],
    ).to(DEVICE)

    # The client id arrives as a string and doubles as the partition index.
    local_train = trainloaders[int(cid)]
    local_val = valloaders[int(cid)]

    numpy_client = FlowerClient(net, local_train, local_val)
    return numpy_client.to_client()

In [16]:
from flwr.server.strategy import FedAvg
from flwr.server import ServerConfig
from flwr.simulation import start_simulation

# Federated averaging over the simulated clients. The client-count minimums
# are derived from NUM_CLIENTS so they stay consistent with the number of
# partitions created earlier instead of repeating the literals 10 and 5.
strategy = FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=NUM_CLIENTS,  # Never train on fewer than all clients
    min_evaluate_clients=max(1, NUM_CLIENTS // 2),  # Never evaluate on fewer than half of them
    min_available_clients=NUM_CLIENTS,  # Wait until all clients are available
)

# Resources reserved for each virtual client: one CPU and no GPU by default.
client_resources = {"num_cpus": 1, "num_gpus": 0.0}
if DEVICE.type == "cuda":
    # here we are assigning an entire GPU for each client.
    client_resources = {"num_cpus": 1, "num_gpus": 1.0}

# Start simulation: NUM_CLIENTS virtual clients built on demand by
# create_client, run for 5 federated rounds with the strategy and per-client
# resources configured above.
start_simulation(
    client_fn=create_client,
    num_clients=NUM_CLIENTS,
    config=ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)

[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2024-07-25 10:03:32,702	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'memory': 16480411648.0, 'node:127.0.0.1': 1.0, 'node:__internal_head__': 1.0, 'object_store_memory': 2147483648.0, 'CPU': 10.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 10 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Evaluating initial global parameters
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy