# Federated Learning using Hugging Face and Flower

This tutorial will show how to leverage Hugging Face to federate the training of language models over multiple clients using [Flower](https://flower.dev/). More specifically, we will fine-tune a pre-trained Transformer model (alBERT) for sequence classification over a dataset of IMDB ratings. The end goal is to detect if a movie rating is positive or negative.


## Dependencies

For this tutorial we will need `datasets`, `flwr['simulation']`(here we use the extra 'simulation' dependencies from Flower as we will simulated the federated setting inside Google Colab), `torch`, and `transformers`.

In [1]:
!pip install datasets flwr["simulation"] torch transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.10.1-py3-none-any.whl (469 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m469.0/469.0 KB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting flwr[simulation]
  Downloading flwr-1.3.0-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 KB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collecting transformers
  Downloading transformers-4.27.2-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m60.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.9/132.9 KB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.man

We can now import the relevant modules.

In [2]:
from collections import OrderedDict
import os
import random
import warnings

import flwr as fl
import torch

from torch.utils.data import DataLoader

from datasets import load_dataset, load_metric

from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification
from transformers import AdamW
from transformers import logging

Next we will set some global variables and disable some of the logging to clear out our output.

In [3]:
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
logging.set_verbosity(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.simplefilter('ignore')

DEVICE = torch.device("cpu")
CHECKPOINT = "albert-base-v2"  # transformer model checkpoint
NUM_CLIENTS = 2
NUM_ROUNDS = 3

## Standard Hugging Face workflow

### Handling the data

To fetch the IMDB dataset, we will use Hugging Face's `datasets` library. We then need to tokenize the data and create `PyTorch` dataloaders, this is all done in the `load_data` function:

In [4]:
def load_data():
    """Load IMDB data (training and eval)"""
    raw_datasets = load_dataset("imdb")
    raw_datasets = raw_datasets.shuffle(seed=42)

    # remove unnecessary data split
    del raw_datasets["unsupervised"]

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # random 10 samples
    population = random.sample(range(len(raw_datasets["train"])), 10)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
    tokenized_datasets["train"] = tokenized_datasets["train"].select(population)
    tokenized_datasets["test"] = tokenized_datasets["test"].select(population)

    tokenized_datasets = tokenized_datasets.remove_columns("text")
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        batch_size=32,
        collate_fn=data_collator,
    )

    testloader = DataLoader(
        tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
    )

    return trainloader, testloader

### Training and testing the model

Once we have a way of creating our trainloader and testloader, we can take care of the training and testing. This is very similar to any `PyTorch` training or testing loop:

In [5]:
def train(net, trainloader, epochs):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def test(net, testloader):
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy

### Creating the model itself

To create the model itself, we will just load the pre-trained alBERT model using Hugging Face’s `AutoModelForSequenceClassification` :

In [6]:
net = AutoModelForSequenceClassification.from_pretrained(
    CHECKPOINT, num_labels=2
).to(DEVICE)

Downloading (…)lve/main/config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/47.4M [00:00<?, ?B/s]

## Federating the example

The idea behind Federated Learning is to train a model between multiple clients and a server without having to share any data. This is done by letting each client train the model locally on its data and send its parameters back to the server, which then aggregates all the clients’ parameters together using a predefined strategy. This process is made very simple by using the [Flower](https://github.com/adap/flower) framework. If you want a more complete overview, be sure to check out this guide: [What is Federated Learning?](https://flower.dev/docs/tutorial/Flower-0-What-is-FL.html)

### Creating the IMDBClient

To federate our example to multiple clients, we first need to write our Flower client class (inheriting from `flwr.client.NumPyClient`). This is very easy, as our model is a standard `PyTorch` model:

In [7]:
class IMDBClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        print("Training Started...")
        train(self.net, self.trainloader, epochs=1)
        print("Training Finished.")
        return self.get_parameters(config={}), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        return float(loss), len(self.testloader), {"accuracy": float(accuracy), "loss": float(loss)}

The `get_parameters` function lets the server get the client's parameters. Inversely, the `set_parameters` function allows the server to send its parameters to the client. Finally, the `fit` function trains the model locally for the client, and the `evaluate` function tests the model locally and returns the relevant metrics. 

### Generating the clients

In order to simulate the federated setting we need to provide a way to instantiate clients for our simulation. Here, it is very simple as every client will hold the same piece of data (this is not realistic, it is just used here for simplicity sakes).

In [8]:
trainloader, testloader = load_data()
def client_fn(cid):
  return IMDBClient(net, trainloader, testloader)

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0...


Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

## Starting the simulation

We now have all the elements to start our simulation. The `weighted_average` function is there to provide a way to aggregate the metrics distributed amongst the clients (basically to display a nice average accuracy at the end of the training). We then define our strategy (here `FedAvg`, which will aggregate the clients weights by doing an average). 

Finally, `start_simulation` is used to start the training.

In [9]:
def weighted_average(metrics):
  accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
  losses = [num_examples * m["loss"] for num_examples, m in metrics]
  examples = [num_examples for num_examples, _ in metrics]
  return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    evaluate_metrics_aggregation_fn=weighted_average,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    ray_init_args={"log_to_driver": False}
)

INFO flwr 2023-03-23 03:47:56,478 | app.py:145 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2023-03-23 03:47:57,840	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
INFO flwr 2023-03-23 03:47:59,293 | app.py:179 | Flower VCE: Ray initialized with resources: {'memory': 7907964519.0, 'node:172.28.0.12': 1.0, 'CPU': 2.0, 'GPU': 1.0, 'object_store_memory': 3953982259.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'memory': 7907964519.0, 'node:172.28.0.12': 1.0, 'CPU': 2.0, 'GPU': 1.0, 'object_store_memory': 3953982259.0}
INFO flwr 2023-03-23 03:47:59,302 | server.py:86 | Initializing global parameters
INFO:flwr:Initializing global parameters
INFO flwr 2023-03-23 03:47:59,305 | server.py:270 | Requesting initial parameters from one random client
INFO:flwr:Requesting initial parameters from

History (loss, distributed):
	round 1: 0.07067623734474182
	round 2: 0.0741086483001709
	round 3: 0.07937008142471313
History (metrics, distributed):
{'accuracy': [(1, 0.5), (2, 0.5), (3, 0.6)], 'loss': [(1, 0.07067623734474182), (2, 0.0741086483001709), (3, 0.07937008142471313)]}

Note that this is a very basic example, and a lot can be added or modified, it was just to showcase how simply we could federate a Hugging Face workflow using Flower. The number of clients and the data samples are intentionally very small in order to quickly run inside Colab, but keep in mind that everything can be tweaked and extended.