## Federated fine-tuning of a language model using HuggingFace Transformers

In [None]:
!pip install -U "datasets==3.0.0" "transformers==4.44.2" "evaluate==0.4.3" "ipywidgets==8.1.5" "torch==2.4.1" "accelerate==0.34.2"

## Connect to the federated network

In [7]:
from openfl.interface.interactive_api.federation import Federation

# Handle to the federation: this notebook identifies itself as "frontend" and
# addresses the director at localhost:50050 with TLS switched off.
federation = Federation(
    director_node_fqdn="localhost",
    director_port=50050,
    client_id="frontend",
    tls=False,
)

## Define federated experiment

In [8]:
from openfl.interface.interactive_api.experiment import FLExperiment

# Experiment object bound to the federation created above.
fl_experiment = FLExperiment(
    experiment_name="nlp_experiment",
    federation=federation,
)

## Register Model

In [None]:
import evaluate
import numpy as np
import torch
from openfl.interface.interactive_api.experiment import ModelInterface
from transformers import AutoModelForSequenceClassification, AdamW, AutoTokenizer, TrainingArguments, Trainer

# One checkpoint name is used for both the tokenizer and the classifier,
# which is created with two output labels.
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)

def tokenize_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` to fixed-length inputs.

    Uses the notebook-level ``tokenizer``; sequences are padded with
    ``padding="max_length"`` and truncated, with the length taken from
    ``model.config.max_position_embeddings``.
    """
    texts = examples["text"]
    sequence_length = model.config.max_position_embeddings
    return tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=sequence_length,
    )

# Replace every parameter tensor with a contiguous copy of itself.
# NOTE(review): presumably needed so that serializing the model does not
# reject non-contiguous tensors — confirm.
for param in model.parameters():
    param.data = param.data.contiguous()

# Only parameters that require gradients are handed to the optimizer.
params_to_update = [param for param in model.parameters() if param.requires_grad]

# torch.optim.AdamW is used instead of transformers.AdamW: the transformers
# implementation is deprecated upstream in favour of the PyTorch one.
optimizer = torch.optim.AdamW(params_to_update, lr=2e-5, weight_decay=0.01)

# Register the model/optimizer pair with OpenFL through the PyTorch adapter.
framework_adapter = "openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin"
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)


## Define federated tasks

In [15]:
from openfl.interface.interactive_api.experiment import TaskInterface

# Shared Trainer configuration, reused by both the train and validate tasks.
training_args = TrainingArguments(
    # Output location and reproducibility.
    output_dir="finetuned_model",
    seed=12345,
    # Device and batch sizing.
    use_cpu=True,
    auto_find_batch_size=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Optimization schedule.
    num_train_epochs=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Evaluate, checkpoint and log once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    log_level="debug",
    # Reload the checkpoint with the highest "eval_f1" when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    greater_is_better=True,
)

def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        eval_pred: A pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over the last axis.

    Returns:
        A dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    # Both metrics are loaded on every call, before the predictions are read.
    loaded_metrics = {name: evaluate.load(name) for name in ("accuracy", "f1")}
    # Extra arguments each metric needs besides predictions/references.
    extra_arguments = {"accuracy": {}, "f1": {"average": "weighted"}}

    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    return {
        name: metric.compute(
            predictions=predicted_classes,
            references=labels,
            **extra_arguments[name],
        )[name]
        for name, metric in loaded_metrics.items()
    }


# Task registry; the functions decorated below are registered on it.
TI = TaskInterface()


@TI.register_fl_task(
    model="model",
    data_loader="train_loader",
    device="device",
    optimizer="optimizer",
)
def train(model, train_loader, optimizer, device):
    """Fine-tune ``model`` with a ``Trainer`` and return the training metrics.

    Args:
        model: Model to fine-tune.
        train_loader: Object providing ``map``; after mapping, its ``"train"``
            entry is used for training and its ``"test"`` entry for
            evaluation during training.
        optimizer: Optimizer handed to the ``Trainer``; no scheduler is
            supplied alongside it.
        device: Part of the registered task signature; not used in the body.

    Returns:
        The ``metrics`` attribute of the result of ``Trainer.train()``.
    """
    encoded_splits = train_loader.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    hf_trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded_splits["train"],
        eval_dataset=encoded_splits["test"],
        compute_metrics=compute_metrics,
        optimizers=(optimizer, None),
    )

    train_output = hf_trainer.train()
    return train_output.metrics

@TI.register_fl_task(
    model="model",
    data_loader="test_loader",
    device="device",
)
def validate(model, test_loader, device):
    """Evaluate ``model`` on the tokenized ``test_loader``.

    Args:
        model: Model to evaluate.
        test_loader: Object providing ``map``; the mapped result is passed to
            the ``Trainer`` as its evaluation dataset.
        device: Part of the registered task signature; not used in the body.

    Returns:
        The value returned by ``Trainer.evaluate()``.
    """
    encoded_eval_set = test_loader.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    evaluator = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        eval_dataset=encoded_eval_set,
    )
    return evaluator.evaluate()

## Register DataLoader

In [16]:
from openfl.interface.interactive_api.experiment import DataInterface

class CustomDataLoader(DataInterface):
    """Data interface exposing the shard's "train" and "test" datasets.

    The datasets are fetched from the shard descriptor when it is assigned
    through the ``shard_descriptor`` property and are handed to the federated
    tasks unchanged.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``DataInterface``."""
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Shard descriptor currently attached to this loader."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """Attach ``shard_descriptor`` and fetch its "train"/"test" datasets."""
        self._shard_descriptor = shard_descriptor
        self.train_set = self._shard_descriptor.get_dataset("train")
        self.test_set = self._shard_descriptor.get_dataset("test")

    def __getitem__(self, index):
        """Return item ``index`` of the shard descriptor."""
        return self.shard_descriptor[index]

    def __len__(self):
        """Return the length of the shard descriptor."""
        return len(self.shard_descriptor)

    def get_train_loader(self):
        """Return the dataset fetched for "train"."""
        return self.train_set

    def get_valid_loader(self):
        """Return the dataset fetched for "test"."""
        return self.test_set

    @staticmethod
    def _count_samples(dataset, split=None):
        """Return the number of samples held by ``dataset``.

        A plain dataset is measured with ``len()``. For a dict of splits,
        ``len()`` would give the number of splits rather than the number of
        samples, so the ``split`` entry is measured when it is present and
        the sizes of all splits are summed otherwise.
        """
        if isinstance(dataset, dict):
            if split in dataset:
                return len(dataset[split])
            return sum(len(part) for part in dataset.values())
        return len(dataset)

    def get_train_data_size(self):
        """Return the number of training samples.

        The ``train`` task trains on the ``"train"`` entry of the mapped
        train loader, so that entry is the one counted when the train set is
        a dict of splits.
        """
        return self._count_samples(self.train_set, "train")

    def get_valid_data_size(self):
        """Return the number of validation samples.

        The ``validate`` task evaluates on the whole test loader, so every
        split is counted when the test set is a dict of splits.
        """
        return self._count_samples(self.test_set)


data_loader = CustomDataLoader()


## Start experiment

In [None]:
# Start the experiment with the registered model, tasks and data loader.
fl_experiment.start(
    data_loader=data_loader,
    task_keeper=TI,
    model_provider=MI,
    opt_treatment="CONTINUE_GLOBAL",
    rounds_to_train=3,
)

# Called after start() to stream the experiment's metrics.
fl_experiment.stream_metrics()