# Multi-Task Federated Learning with GPT-2 using FATE-LLM

In this tutorial, we will explore the implementation of multi-task federated learning with LM: GPT-2 using the FATE-LLM framework. FATE-LLM provides  the "pellm" module for efficient federated learning. It is specifically designed for large language models in a federated setting.

Multi-task learning involves training a model to perform multiple tasks simultaneously. In this tutorial, we will focus on two tasks - sentiment classification and named entity recognition (NER) - and show how they can be combined with GPT-2 in a federated learning setting. We will use the IMDB sentiment analysis dataset and the CoNLL-2003 NER dataset for our tasks.

Additionally, by leveraging the Adapter mechanism, we can effectively reduce communication volume and improve overall efficiency in our federated learning setting.

## GPT2

GPT-2 is a large transformer-based language model with 1.5 billion parameters, trained on a dataset of 8 million web pages. GPT-2 is trained with a causal language modeling (CLM) objective, conditioning on a left-to-right context window of 1024 tokens. In this tutorial, we will use GPT2, you can download the pretrained model from [here](https://huggingface.co/gpt2) (We choose the smallest version for this tutorial), or let the program automatically download it when you use it later.

## Dataset: IMDB Sentimental and CoNLL-2003 NER

About IMDB Sentimental Dataset:

The IMDB dataset is a binary classification dataset containing movie reviews with positive or negative sentiment labels. We will use this dataset as one of our tasks in the multi-task learning setup. You can download our processed dataset from here:

The original data is from:

https://ai.stanford.edu/~amaas/data/sentiment/
About CoNLL-2003 NER Dataset:

The CoNLL-2003 NER dataset is a widely used benchmark dataset for named entity recognition tasks. It contains English and German news articles with named entity annotations for person, organization, and location names. We will use this dataset as another task in our multi-task learning setup. The official website:

https://www.clips.uantwerpen.be/conll2003/ner/
We will use the Hugging Face transformers library to preprocess the text data and tokenize it for use in our multi-task federated learning task. The processed data can be found here:

In this tutorial we will use datasets module provided by huggingface to download these two datasets. We then save the dataset instance to the file system for later use.

## Download and Cache Dataset

In this example, for the ease of display, we will use cache the training set only.

In [17]:
import pickle
from datasets import load_dataset

imdb_dataset = load_dataset("imdb", download_mode="reuse_cache_if_exists")
train_dataset_imdb = imdb_dataset["train"]
val_dataset_imdb = imdb_dataset["test"]

conll_dataset = load_dataset("conll2003", download_mode="reuse_cache_if_exists")
train_dataset_conll = conll_dataset["train"]

# Save the dataset
pickle.dump(train_dataset_imdb, open('./train_dataset_imdb.pkl', 'wb'))
pickle.dump(train_dataset_conll, open('./train_dataset_conll.pkl', 'wb'))

In [3]:
imdb_ds = pickle.load(open('./train_dataset_imdb.pkl', 'rb'))
conll_ds = pickle.load(open('./train_dataset_conll.pkl', 'rb'))

Here we realiaze a MultiTaskDataset class based on FATE-datset class to mix these two dataset together. For more details of FATE dataset setting, we recommend that you read through these tutorials first: [NN Dataset Customization](./Homo-NN-Customize-your-Dataset.ipynb).
Our dataset directly load the cached pickle dataset from a path.

In [5]:
from pipeline.component.nn import save_to_fate

In [24]:
%%save_to_fate dataset multitask_ds.py
import pickle
import torch as t
import tqdm
import torch.nn.utils.rnn as rnn_utils
import os
from federatedml.nn.dataset.base import Dataset
from transformers import AutoTokenizer


# avoid tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"

sentiment_task_prefix = "sentiment: "
ner_task_prefix = "ner: "


def add_sentiment_task_prefix(example):
    example["text"] = sentiment_task_prefix + example["text"]
    return example



def add_ner_task_prefix(example):
    example["tokens"] = [ner_task_prefix] + example["tokens"]
    return example



class MultiTaskDataset(Dataset):
    """MultiTaskDataset
    Args:
        take_limits: take how many samples from each dataset
        shuffle_seed: shuffle seed
    """
    def __init__(self, tokenizer_name_or_path='gpt2', take_limits=None, shuffle_seed=114514):
        self.ds = []
        self.labels = []
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.add_prefix_space = True
        self.ner_tag_num = 0
        self.shuffle_seed = shuffle_seed
        self.take_limits = take_limits

    def convert_imdb_example(self, example):
        encodings = self.tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)
        return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"], "labels": example["label"], "task_type": 0}

    def convert_conll_example(self, example):
        encodings = self.tokenizer(example["tokens"], is_split_into_words=True, truncation=True, padding="max_length", max_length=128)
        labels = [-100] + example["ner_tags"] + [-100] * (128 - len(example["ner_tags"]) - 1)
        return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"], "labels": labels, "task_type": 1}

    def load(self, path):
        imdb_ds = pickle.load(open(path + '/train_dataset_imdb.pkl', 'rb'))
        conll_ds = pickle.load(open(path + '/train_dataset_conll.pkl', 'rb'))
        imdb_ds= imdb_ds.map(add_sentiment_task_prefix)
        conll_ds = conll_ds.map(add_ner_task_prefix)
        imdb_ds = imdb_ds.map(self.convert_imdb_example)
        conll_ds = conll_ds.map(self.convert_conll_example)
        self.ner_tag_num = conll_ds.features['ner_tags'].feature.num_classes

        if self.take_limits is not None:
            imdb = imdb_ds.shuffle(seed=self.shuffle_seed).select(range(self.take_limits))
            conll = conll_ds.shuffle(seed=self.shuffle_seed).select(range(self.take_limits))
        else:
            imdb = imdb_ds
            conll = conll_ds
        for i in tqdm.tqdm(range(len(imdb))):
            self.ds.append({'input_ids': t.LongTensor(imdb[i]['input_ids']), 'attention_mask': t.LongTensor(imdb[i]['attention_mask']), 'task_type': t.LongTensor([0])})
            self.labels.append(t.LongTensor([imdb[i]['labels']]))
        
        for i in tqdm.tqdm(range(len(conll))):
            self.ds.append({'input_ids': t.LongTensor(conll[i]['input_ids']), 'attention_mask': t.LongTensor(conll[i]['attention_mask']), 'task_type': t.LongTensor([1])})
            self.labels.append(t.LongTensor(conll[i]['labels']))

        # padding the binary classification labels to match the length of the ner labels
        self.labels = rnn_utils.pad_sequence(self.labels, batch_first=True, padding_value=-100)

    def __getitem__(self, index):
        return self.ds[index], self.labels[index]

    def __len__(self):
        return len(self.ds)

In [23]:
ds = MultiTaskDataset(tokenizer_name_or_path='your path', take_limits=500)
ds.load('./')

In [9]:
ds[0]

({'input_ids': tensor([34086,  3681,    25,  2293,   262, 21840,   317, 15154,    39,  1677,
              7, 16942,     8,   543,   373,   257, 28763,   286,   257, 28600,
           7246,   711,   339,  2058,   351, 16400,    48,    51,   543,  1165,
           3073,   588,   257,  3800,   711,    27,  1671,  1220,  6927,  1671,
          11037,   818,  3800,  5341,    11,   356,   423,  3435, 19642,    11,
            625, 27362,   994,  1165,   262,   976,    27,  1671,  1220,  6927,
           1671, 11037,   464,   717,  2063,  2523, 48148,   397,    71,  2048,
          26471,   262,  2319,    10,   317, 50133,   323, 26105,   508,  6529,
           1165,  8258,   588,   257,  1402, 34712,    27,  1671,  1220,  6927,
           1671, 11037,   464,  2646,   468,   257,   922,  3275,   703,   407,
            284, 20851,   534,  3367,   475, 21098,   262,   835, 48148,   397,
             71,  3382,   284,   787,  9084,  4106,  4497,   318,  5543,  8390,
             27,  1671,  12

## GPT2 Model with Adapter for Multi-Task Learning

In this section, we will demonstrate how to build a parameter-efficient language model using our PELLM models for multi-task learning in a federated setting. The PELLM model already comes equipped with the Adapter mechanism, which simplifies the process of integrating multiple tasks into a single model.

We will focus on implementing two tasks in our multi-task learning setup - sentiment analysis and named entity recognition (NER). The PELLM model will have two classification heads, one for each task, enabling it to simultaneously perform both tasks.

In [10]:
%%save_to_fate model gpt2_multitask.py
import torch as t
from federatedml.nn.model_zoo.pellm.gpt2 import GPT2


class MultiTaskGPT2(GPT2):

    """MultiTaskGPT2
    
    Args:
        hidden_size (int): embedding size of the GPT2 model
        output_dim1 (int, optional): output dimension of the first task. Defaults to 2.
        output_dim2 (int, optional): output dimension of the second task. Defaults to 9.
        pretrained_path : pretrained model path, or use 'gpt2' to download model from huggingface
        adapter_type: adapter type, see parent class for details
    """
    # ner tag number is 9 in this conll dataset
    def __init__(self, hidden_size, output_dim1=2, output_dim2=9, **kwargs):
        super(MultiTaskGPT2, self).__init__(**kwargs)
        
        # sentimental classifcation
        self.softmax = t.nn.Softmax(dim=-1)
        self.classifier = t.nn.Linear(hidden_size, output_dim1)
        # ner classification
        self.ner_classifier = t.nn.Linear(hidden_size, output_dim2)

    def forward(self, data):
        task_type = data['task_type']

        # GPT2 forward pass
        outputs = super().forward(data)

        # Get the last hidden state from GPT2
        last_hidden_state = outputs.last_hidden_state

        # Split the input based on the task type
        task0_mask = task_type == 0
        task1_mask = task_type == 1

        if task0_mask.any():
            # Classification task
            task0_logits = self.classifier(last_hidden_state[task0_mask.flatten()][::, -1, ::])
            task0_logits = self.softmax(task0_logits)

        if task1_mask.any():
            # Sequence labeling task
            task1_logits = self.ner_classifier(last_hidden_state[task1_mask.flatten()])
            task1_logits = self.softmax(task1_logits)

        out_ = {0: None, 1: None, 'task_type': task_type}
        if task0_mask.any():
            out_[0] = task0_logits
        if task1_mask.any():
            out_[1] = task1_logits

        return out_

In [20]:
model = MultiTaskGPT2(hidden_size=768, pretrained_path='your path', adapter_type='HoulsbyConfig')

## Multi-task Loss

The MultiTaskLoss function computes a weighted sum of the losses for each task in the multi-task learning setup. The weights for each task are specified in the task_weights parameter, allowing users to adjust the relative importance of each task.

See [Loss Customization](./Homo-NN-Customize-Loss.ipynb) for more details of customizing a loss class

In [12]:
%%save_to_fate loss multi_task_loss.py
import torch

class MultiTaskLoss(torch.nn.Module):
    def __init__(self, task_weights):
        super().__init__()
        self.task_weights = task_weights
        self.classification_loss = torch.nn.CrossEntropyLoss(ignore_index=-100)
        self.ner_loss = torch.nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, model_out, labels):
        loss = 0
        if model_out[0] is not None:
            # Compute classification loss
            classification_labels = labels[model_out['task_type'].flatten() == 0]
            classification_logits = model_out[0]
            classification_loss = self.classification_loss(classification_logits, classification_labels[::, 0])
            weighted_classification_loss = classification_loss * self.task_weights[0]
            loss += weighted_classification_loss

        if model_out[1] is not None:
            # Compute NER loss
            ner_labels = labels[model_out['task_type'].flatten() == 1].flatten()
            ner_logits = model_out[1].reshape(-1, model_out[1].shape[-1])
            ner_loss = self.ner_loss(ner_logits, ner_labels)
            weighted_ner_loss = ner_loss * self.task_weights[1]
            loss += weighted_ner_loss

        return loss


In [13]:
loss = MultiTaskLoss([0.5, 0.5])

# Local Test

Before submitting a federated learning task, it is important to perform local testing to ensure that your custom dataset and model are working properly.

To perform local testing with the MultiTaskFedAVGTrainer class, we will first instantiate the class with our preprocessed dataset and PELLM model with Adapter. We will then run the trainer in local mode, using a small subset of the data, to test that the model is working as expected.

It is important to note that the MultiTaskFedAVGTrainer class is a toy class that has not been rigorously tested. Don't use it for production.

In [14]:
from pipeline.component.nn import save_to_fate

In [16]:
%%save_to_fate trainer multi_task_fedavg.py
import tqdm
import torch as t
from torch.utils.data import DataLoader
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
from federatedml.util import LOGGER
from sklearn.metrics import accuracy_score


class MultiTaskFedAVGTrainer(FedAVGTrainer):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss):

        epoch_loss = 0.0
        batch_idx = 0
        acc_num = 0

        if self.data_loader is None:
            self.data_loader = DataLoader(
                                    train_set,
                                    batch_size=self.batch_size,
                                    pin_memory=self.pin_memory,
                                    shuffle=self.shuffle,
                                    num_workers=self.data_loader_worker)
        
        dl = self.data_loader

        if not self.fed_mode:
            to_iterate = tqdm.tqdm(dl)
        else:
            to_iterate = dl

        task_pred = {0: [], 1: []}
        task_label = {0: [], 1: []}

        for batch_data, batch_label in to_iterate:

            if self.cuda is not None:
                batch_data, batch_label = self.to_cuda(
                    batch_data, self.cuda_main_device), self.to_cuda(batch_label, self.cuda_main_device)

            optimizer.zero_grad()
            pred = model(batch_data)
            batch_loss = loss(pred, batch_label)

            if pred[0] is not None:
                task_pred[0].append(pred[0].cpu().detach())
                task_label[0].append(batch_label[batch_data['task_type'].flatten() == 0].cpu().detach())
            if pred[1] is not None:
                task_pred[1].append(pred[1].cpu().detach())
                task_label[1].append(batch_label[batch_data['task_type'].flatten() == 1].cpu().detach())

            batch_loss.backward()
            optimizer.step()
            batch_loss_np = batch_loss.detach().numpy(
            ) if self.cuda is None else batch_loss.cpu().detach().numpy()
            if acc_num + self.batch_size > len(train_set):
                batch_len = len(train_set) - acc_num
            else:
                batch_len = self.batch_size
            epoch_loss += batch_loss_np * batch_len
            batch_idx += 1

        task_0_pred = t.vstack(task_pred[0]).argmax(dim=1).flatten()
        task_0_label = t.vstack(task_label[0])[::, 0].flatten()
        LOGGER.debug('task 0 acc {}'.format(accuracy_score(task_0_label.numpy(), task_0_pred.numpy())))
        task_1_pred = t.vstack(task_pred[1]).argmax(dim=-1).flatten()
        task_1_label = t.vstack(task_label[1]).flatten()
        mask = task_1_label != -100
        LOGGER.debug('task 1 acc {}'.format(accuracy_score(task_1_pred[mask].numpy(), task_1_label[mask].numpy())))


        if self.fed_mode:
            LOGGER.debug(
                'epoch {} batch {} finished'.format(epoch_idx, batch_idx))
        
        epoch_loss = epoch_loss / len(train_set)
        return epoch_loss
    
    def predict(self, dataset):
        # currently FATE does not support handling the result of multi-task model, so
        # we disable the predict function
        return None

In [18]:
trainer = MultiTaskFedAVGTrainer(epochs=5, batch_size=4, data_loader_worker=8, shuffle=False)
trainer.local_mode()
trainer.set_model(model)

In [19]:
optimizer = t.optim.Adam(model.parameters(), lr=0.0001)
trainer.train(ds, None, optimizer, loss)

epoch is 0
100%|██████████| 250/250 [02:15<00:00,  1.84it/s]
task 0 acc 0.514
task 1 acc 0.7738525357955306
epoch loss is 0.6051078315377235
epoch is 1
100%|██████████| 250/250 [02:22<00:00,  1.76it/s]
task 0 acc 0.52
task 1 acc 0.83701324769169
epoch loss is 0.5854948361515999
epoch is 2
100%|██████████| 250/250 [02:19<00:00,  1.80it/s]
task 0 acc 0.522
task 1 acc 0.83701324769169
epoch loss is 0.584037158548832
epoch is 3
100%|██████████| 250/250 [02:23<00:00,  1.74it/s]
task 0 acc 0.524
task 1 acc 0.83701324769169
epoch loss is 0.582937289237976
epoch is 4
100%|██████████| 250/250 [02:43<00:00,  1.53it/s]
task 0 acc 0.526
task 1 acc 0.83701324769169
epoch loss is 0.5806194019317626


## Submit Federated Task
Once you have successfully completed local testing, We can submit a task to FATE. Please notice that this tutorial is ran on a standalone version. **Please notice that in this tutorial we are using a standalone version, if you are using a cluster version, you need to bind the data with the corresponding name&namespace on each machine.**

In [4]:
import torch as t
import os
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader
from pipeline.interface import Data

fate_torch_hook(t)


fate_project_path = os.path.abspath('../../../../')
guest_0 = 10000
host_1 = 9999
pipeline = PipeLine().set_initiator(role='guest', party_id=guest_0).set_roles(guest=guest_0, host=host_1,
                                                                              arbiter=guest_0)
data_0 = {"name": "imdb", "namespace": "experiment"}
data_path = fate_project_path + '/doc/tutorial/pipeline/nn_tutorial'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)

reader_1 = Reader(name="reader_1")
reader_1.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_0)
reader_1.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)


## Add your pretriained model path here, will load model&tokenizer from this path
model_path = ''


from pipeline.component.homo_nn import DatasetParam, TrainerParam  
model = t.nn.Sequential(
    t.nn.CustModel(module_name='gpt2_multitask', class_name='MultiTaskGPT2', pretrained_path=model_path, adapter_type='HoulsbyConfig', hidden_size=768)
)

# DatasetParam
dataset_param = DatasetParam(dataset_name='multitask_ds', take_limits=50, tokenizer_name_or_path=model_path)
# TrainerParam
trainer_param = TrainerParam(trainer_name='multi_task_fedavg', epochs=1, batch_size=4, 
                             data_loader_worker=8, secure_aggregate=True)
loss = t.nn.CustLoss(loss_module_name='multi_task_loss', class_name='MultiTaskLoss', task_weights=[0.5, 0.5])


nn_component = HomoNN(name='nn_0', model=model)

# set parameter for client 1
nn_component.get_party_instance(role='guest', party_id=guest_0).component_param(
    loss=loss,
    optimizer = t.optim.Adam(lr=0.0001, eps=1e-8),
    dataset=dataset_param,       
    trainer=trainer_param,
    torch_seed=100 
)

# set parameter for client 2
nn_component.get_party_instance(role='host', party_id=host_1).component_param(
    loss=loss,
    optimizer = t.optim.Adam(lr=0.0001, eps=1e-8),
    dataset=dataset_param,       
    trainer=trainer_param,
    torch_seed=100 
)

# set parameter for server
nn_component.get_party_instance(role='arbiter', party_id=guest_0).component_param(    
    trainer=trainer_param
)

pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.compile()

pipeline.fit()

[32m2023-03-31 11:28:05.423[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m83[0m - [1mJob id is 202303311128051297170
[0m
[32m2023-03-31 11:28:05.428[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:00[0m
[0mm2023-03-31 11:28:06.453[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m125[0m - [1m
[32m2023-03-31 11:28:06.453[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m127[0m - [1m[80D[1A[KRunning component reader_0, time elapse: 0:00:01[0m
[32m2023-03-31 11:28:07.619[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m127[0m - [1m[80D[1A[KRunning component reader_0, time elapse: 0:00:02[0m
[0mm2023-03-31 11:28:08.630[0m | [1mI