# Federated Learning using Hugging Face and Flower

This tutorial will show how to leverage Hugging Face to federate the training of language models over multiple clients using [Flower](https://flower.dev/). More specifically, we will fine-tune a pre-trained Transformer model (DistilBERT) for sequence classification.


## Dependencies

For this tutorial we will need `datasets`, `flwr['simulation']`(here we use the extra 'simulation' dependencies from Flower as we will simulated the federated setting inside Google Colab), `torch`, and `transformers`.

In [3]:
!pip install datasets evaluate flwr['simulation'] torch transformers

Collecting datasets
  Using cached datasets-2.13.1-py3-none-any.whl (486 kB)
Collecting evaluate
  Using cached evaluate-0.4.0-py3-none-any.whl (81 kB)
Collecting flwr[simulation]
  Using cached flwr-1.4.0-py3-none-any.whl (157 kB)
Collecting transformers
  Using cached transformers-4.30.2-py3-none-any.whl (7.2 MB)
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m17.2 MB/s[0m eta [

We can now import the relevant modules.

In [4]:
from collections import OrderedDict
import os
import random
import warnings

import flwr as fl
import torch

import pandas as pd

from torch.utils.data import DataLoader

from datasets import load_dataset, Dataset, DatasetDict
from evaluate import load as load_metric

from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification
from transformers import AdamW
from transformers import logging

Next we will set some global variables and disable some of the logging to clear out our output.

In [5]:
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
logging.set_verbosity(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.simplefilter('ignore')

DEVICE = torch.device("cuda")
CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint
NUM_CLIENTS = 2
NUM_ROUNDS = 3

## Standard Hugging Face workflow

### Handling the data

To fetch the IMDB dataset, we will use Hugging Face's `datasets` library. We then need to tokenize the data and create `PyTorch` dataloaders, this is all done in the `load_data` function:

In [7]:
def load_data():
    """Load Phishing Email text-data"""
    extracted_text = pd.read_csv("https://anti-phish.s3.eu-west-1.amazonaws.com/dataset/extracted/extracted_text.csv", index_col=0)
    extracted_text = extracted_text.drop(['header'], axis=1)
    emails = Dataset.from_pandas(extracted_text, preserve_index=False)

    train_emails, test_emails = emails.train_test_split(test_size=0.2).values()

    raw_datasets = DatasetDict()
    raw_datasets['train'] = train_emails
    raw_datasets['test'] = test_emails

    # rename the phishing column to label as this is expected by the DataLoader
    raw_datasets = raw_datasets.rename_column('phishing', 'label')

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)


    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

    # remove the email texts after they've been tokenized
    tokenized_datasets = tokenized_datasets.remove_columns("text")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        batch_size=32,
        collate_fn=data_collator,
    )

    testloader = DataLoader(
        tokenized_datasets["test"],
        batch_size=32,
        collate_fn=data_collator
    )

    return trainloader, testloader

### Training and testing the model

Once we have a way of creating our trainloader and testloader, we can take care of the training and testing. This is very similar to any `PyTorch` training or testing loop:

In [8]:
def train(net, trainloader, epochs):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def test(net, testloader):
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy

### Creating the model itself

To create the model itself, we will just load the pre-trained distilBERT model using Hugging Face’s `AutoModelForSequenceClassification` :

In [9]:
net = AutoModelForSequenceClassification.from_pretrained(
    CHECKPOINT, num_labels=2
).to(DEVICE)

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

## Federating the example

The idea behind Federated Learning is to train a model between multiple clients and a server without having to share any data. This is done by letting each client train the model locally on its data and send its parameters back to the server, which then aggregates all the clients’ parameters together using a predefined strategy. This process is made very simple by using the [Flower](https://github.com/adap/flower) framework. If you want a more complete overview, be sure to check out this guide: [What is Federated Learning?](https://flower.dev/docs/tutorial/Flower-0-What-is-FL.html)

### Creating the PhishingClient

To federate our example to multiple clients, we first need to write our Flower client class (inheriting from `flwr.client.NumPyClient`). This is very easy, as our model is a standard `PyTorch` model:

In [10]:
class PhishingClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        print("Training Started...")
        train(self.net, self.trainloader, epochs=1)
        print("Training Finished.")
        return self.get_parameters(config={}), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        return float(loss), len(self.testloader), {"accuracy": float(accuracy), "loss": float(loss)}

The `get_parameters` function lets the server get the client's parameters. Inversely, the `set_parameters` function allows the server to send its parameters to the client. Finally, the `fit` function trains the model locally for the client, and the `evaluate` function tests the model locally and returns the relevant metrics.

### Generating the clients

In order to simulate the federated setting we need to provide a way to instantiate clients for our simulation. Here, it is very simple as every client will hold the same piece of data (this is not realistic, it is just used here for simplicity sakes).

In [11]:
trainloader, testloader = load_data()
def client_fn(cid):
  return PhishingClient(net, trainloader, testloader)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/11510 [00:00<?, ? examples/s]

Map:   0%|          | 0/2878 [00:00<?, ? examples/s]

## Starting the simulation

We now have all the elements to start our simulation. The `weighted_average` function is there to provide a way to aggregate the metrics distributed amongst the clients (basically to display a nice average accuracy at the end of the training). We then define our strategy (here `FedAvg`, which will aggregate the clients weights by doing an average).

Finally, `start_simulation` is used to start the training.

In [12]:
def weighted_average(metrics):
  accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
  losses = [num_examples * m["loss"] for num_examples, m in metrics]
  examples = [num_examples for num_examples, _ in metrics]
  return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    evaluate_metrics_aggregation_fn=weighted_average,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources={"num_cpus": 1, "num_gpus": 1},
    ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 1}
)

INFO flwr 2023-07-15 17:14:46,044 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2023-07-15 17:14:48,541	INFO worker.py:1627 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
INFO flwr 2023-07-15 17:14:49,679 | app.py:180 | Flower VCE: Ray initialized with resources: {'CPU': 1.0, 'object_store_memory': 3927549542.0, 'memory': 7855099086.0, 'GPU': 1.0, 'node:172.28.0.12': 1.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'CPU': 1.0, 'object_store_memory': 3927549542.0, 'memory': 7855099086.0, 'GPU': 1.0, 'node:172.28.0.12': 1.0}
INFO flwr 2023-07-15 17:14:49,691 | server.py:86 | Initializing global parameters
INFO:flwr:Initializing global parameters
INFO flwr 2023-07-15 17:14:49,695 | server.py:273 | Requesting initial parameters from one random client
INFO:flwr:Requesting initial parameters from

History (loss, distributed):
	round 1: 0.0003416192406546415
	round 2: 0.0003506097240791989
	round 3: 0.000542503059682041
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.9965253648366922), (2, 0.9968728283530228), (3, 0.9968728283530228)], 'loss': [(1, 0.0003416192406546415), (2, 0.0003506097240791989), (3, 0.000542503059682041)]}

Note that this is a very basic example, and a lot can be added or modified, it was just to showcase how simply we could federate a Hugging Face workflow using Flower.