# Fine-tuning LLM with Federated Learning - Hands-on

* Tutor:
  * Mr. Peng Yan (<peng.yan-1@uts.edu.au>)
  * Ms. Yiyuan Yang (<yiyuan.yang-1@student.uts.edu.au>)
* Supervisor:
  * A/Prof. Guodong Long (<guodong.long@uts.edu.au>)

## **Task 0:** Set up a pre-trained LLM


### Import pre-trained GPT-2 from Hugging Face
* reference: [OpenAI GPT2](https://huggingface.co/docs/transformers/en/model_doc/gpt2)

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"use device: {device}")


model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

print(model)

use device: cuda



GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)




### Generate text with GPT-2

In [None]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse <|endoftext|>
model.generation_config.pad_token_id = tokenizer.pad_token_id

prompt = "GPT2 is a model developed by OpenAI."

# The tokenizer returns both input_ids and attention_mask
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = inputs.input_ids

gen_tokens = model.generate(
    input_ids,
    attention_mask=inputs.attention_mask,  # pass the mask explicitly because pad == eos
    do_sample=True,
    temperature=0.9,
    max_length=100,
)

gen_text = tokenizer.batch_decode(gen_tokens)[0]
print("*"*20)
print(f"Prompt: {prompt}")
print(f"Tokens: {input_ids}")
print(f"Answer: {gen_text}")


********************
Prompt: GPT2 is a model developed by OpenAI.
Tokens: tensor([[   38, 11571,    17,   318,   257,  2746,  4166,   416,  4946, 20185,
            13]], device='cuda:0')
Answer: GPT2 is a model developed by OpenAI. It serves as an approximation to the model of a single-celled organism.

For a single-celled organism, OCR has a few key functions:

it's a model of organism from RNAseq.

It's a model of single-celled organisms. It provides important information for predicting the number of cells in a given population.

It coordinates the total number of cells of a particular organism from one
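
`do_sample=True` with `temperature=0.9` draws each next token from a softmax over the logits divided by the temperature. The stdlib sketch below (hypothetical helper names, toy logits) illustrates only that scaling step; `generate` also applies other settings from the generation config (for example top-k filtering), which are left out here.

```python
import math
import random

def temperature_probs(logits, temperature):
    """Softmax of logits / temperature: T < 1 sharpens the distribution, T > 1 flattens it."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next_token(logits, temperature, rng=random):
    """Draw one token id in proportion to its temperature-scaled probability."""
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for a 4-token vocabulary
for t in (0.5, 0.9, 1.5):
    print(t, [round(p, 3) for p in temperature_probs(toy_logits, t)])
print("sampled id:", sample_next_token(toy_logits, 0.9, random.Random(0)))
```

Lower temperatures concentrate probability on the highest-scoring token, which is why sampled GPT-2 continuations become less varied as the temperature drops.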


## **Task 1:** Prepare local training data (3 pts)

### Install the Hugging Face datasets library

* reference: [hugging face datasets](https://huggingface.co/docs/datasets/en/index)

In [None]:
!pip install datasets


### Load datasets

In [None]:
from datasets import load_dataset
raw_datasets = load_dataset("glue", "mrpc")
raw_datasets


DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1725
    })
})

In [None]:
import numpy as np
import json

raw_train_dataset = raw_datasets["train"]
example = raw_train_dataset[int(np.random.randint(len(raw_train_dataset)))]  # pick one random row
formatted_json = json.dumps(example, indent=4)
print(formatted_json)

{
    "sentence1": "The tech-heavy Nasdaq Stock Markets composite index added 14.17 points or 0.94 per cent to 1,517.05 .",
    "sentence2": "The Nasdaq Composite index , full of technology stocks , was lately up around 18 points .",
    "label": 0,
    "idx": 3232
}


### **Hint 1:** Implement a `tokenize_function` that converts raw data into model inputs

1.   Convert raw data into prompts
2.   Tokenize your prompts as input_ids
3.   Mask non-output tokens with `-100` as labels

Notes:
1. Mask every token before the answer, including `'Output:'` itself. For example, if the prompt is `I like machine learning. Output: Yes`, the `-100` mask covers the tokens of `I like machine learning. Output:`, so only the answer (and any trailing `<|endoftext|>`) contributes to the loss.
2. The example below has been corrected so that the labels match the prompt; previously, extra whitespace added when formatting the prompt for display caused a mismatch.


For example:


Raw data:

    "sentence1": "Physicians who violate the ban would be subject to fines and up to two years in prison .",
    "sentence2": "Physicians who perform the procedure would face up to two years in prison , under the bill .",
    "label": 1,
    "idx": 115

Prompts:

    Are Sentenc1 and Sentence2 equivalent?
    Sentence 1: Physicians who violate the ban would be subject to fines and up to two years in prison .
    Sentence 2: Physicians who perform the procedure would face up to two years in prison , under the bill .
    Options: -- equivalent
             -- not equivalent
    Output: equivalent<|endoftext|>


input_ids:


    [8491, 11352, 12685, 16, 290, 11352, 594, 17, 7548, 30, 198, 31837, 594, 352, 25, 46206, 508, 16967, 262, 3958, 561, 307, 2426, 284, 17176, 290, 510, 284, 734, 812, 287, 3770, 764, 198, 31837, 594, 362, 25, 46206, 508, 1620, 262, 8771, 561, 1986, 510, 284, 734, 812, 287, 3770, 837, 739, 262, 2855, 764, 198, 29046, 25, 1377, 7548, 198, 220, 220, 220, 220, 220, 220, 220, 220, 1377, 407, 7548, 198, 25235, 25, 7548, 50256]

labels:

    [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 50256]
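
The masking rule in Note 1 can be sketched on plain token-id lists: every prompt position becomes `-100`, so only the answer tokens and `<|endoftext|>` contribute to the loss. The ids below are copied from the tail of the example above (`\n`, `Output`, `:`, ` equivalent`, `<|endoftext|>`); `mask_prompt` is a hypothetical helper for illustration.

```python
IGNORE_INDEX = -100  # PyTorch's cross-entropy loss skips targets equal to its default ignore_index (-100)

def mask_prompt(input_ids, num_prompt_tokens):
    """Copy input_ids into labels, hiding the first num_prompt_tokens positions from the loss."""
    labels = list(input_ids)
    for i in range(min(num_prompt_tokens, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels

# Tail of the example: '\n', 'Output', ':', ' equivalent', '<|endoftext|>'
tail_ids = [198, 25235, 25, 7548, 50256]
# The first 3 tokens ('\n', 'Output', ':') belong to the prompt, so they are masked
print(mask_prompt(tail_ids, num_prompt_tokens=3))  # [-100, -100, -100, 7548, 50256]
```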

In [None]:
def tokenize_function(raw_data):
    # Map the numeric label to its text answer
    label = 'equivalent' if raw_data['label'] == 1 else 'not equivalent'

    # Everything up to and including 'Output:' is context that must be masked
    prefix = (
        f'Are Sentence 1 and Sentence 2 equivalent? \n'
        f'Sentence 1: {raw_data["sentence1"]}\n'
        f'Sentence 2: {raw_data["sentence2"]}\n'
        f'Options: -- equivalent\n'
        f'         -- not equivalent\n'
        f'Output:'
    )
    prompt = f'{prefix} {label}<|endoftext|>'

    # Tokenize the full prompt (the tokenizer also returns the attention mask)
    tokenized_prompt = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized_prompt['input_ids'].to(device)
    attention_mask = tokenized_prompt['attention_mask'].to(device)

    # Count the prefix tokens directly instead of tokenizing the answer on its own:
    # GPT-2's BPE attaches the leading space to a word, so 'equivalent' tokenized
    # alone may not split the same way as ' equivalent' inside the prompt.
    len_prefix = len(tokenizer(prefix)['input_ids'])

    # Mask the prefix with -100 so only the answer and <|endoftext|> contribute to the loss
    labels = input_ids.clone()
    labels[:, :len_prefix] = -100

    return {
        'input_ids': input_ids[0],
        'attention_mask': attention_mask[0],
        'labels': labels[0]
    }

# Example usage
data = raw_train_dataset[int(np.random.randint(len(raw_train_dataset)))]  # pick one random row

tokenized_example = tokenize_function(data)
print(tokenized_example)


{'input_ids': tensor([ 8491, 11352,   594,   352,   290, 11352,   594,   362,  7548,    30,
          220,   198, 31837,   594,   352,    25,  4380, 18380,   837,   286,
        35469,   261,   837,  3442,   837,   531,   262,  1730,   561,  2620,
          663,  5472, 12042,   837, 23494,   716,   419,  1634,   290,   584,
         3709,   764,   198, 31837,   594,   362,    25,   887,  4380, 18380,
          531,   262,   649,  1730,   815,   751,   284,   663,  5472, 12042,
          837, 23494,   716,   419,  1634,   290,   584,  3709,   764,   198,
        29046,    25,  1377,  7548,   198,   220,   220,   220,   220,   220,
          220,   220,   220,  1377,   407,  7548,   198, 26410,    25,  7548,
        50256], device='cuda:0'), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

### **Hint 2:** Apply your tokenize_function to convert the dataset

Convert `raw_datasets` to `tokenized_datasets`, which will include the columns `input_ids`, `attention_mask`, and `labels`, while dropping the original columns: `sentence1`, `sentence2`, `label`, and `idx`.

* `input_ids`:
  
    
    tensor([ 8491, 11352, 12685,    16,   290, 11352,   594,    17,  7548,    30, 198, 31837,   594,   352,    25,  3819,   837,   517,  4569,  5254,389,   635,  1695,   764,   198, 31837,   594,   362,    25, 29065, 5254,   635,   389,  1695,   379,   645,  1575,  1909,   764,   198, 29046,    25,  1377,  7548,   198,   438,   407,  7548,   198, 25235, 25,   407,  7548, 50256])


* `attention_mask`:
   An attention mask determines which tokens are attended to when attention scores are computed (1 = attend, 0 = ignore). For example:


    Sentence 1: I   like   machine   learning   .
          Mask: 1   1      1         1          1
    Sentence 2: I   like   AI        .          <pad>
          Mask: 1   1      1         1          0

   The `<pad>` token brings both sentences to the same length, and its mask value of 0 tells the model to ignore it.

  * Since `padding` and `attention_mask` are beyond the scope of this tutorial, you can use the defaults generated by the `tokenizer`. For more information on `padding` and `attention_mask`, please refer to the [Hugging Face documentation on padding and truncation](https://huggingface.co/docs/transformers/en/pad_truncation).
  
  * There is no need to pad `input_ids`, `attention_mask`, or `labels` yourself: the `DataCollatorForSeq2Seq` passed to the `DataLoader` as `collate_fn` pads each batch and fills padded label positions with `-100`. For example:
  

    data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)
    eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=True, batch_size=1, collate_fn=data_collator)


* `labels`:


    tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,-100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,   407,  7548, 50256])
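
Padding and label masking in a batch can be sketched with plain lists. `pad_batch` is a hypothetical, simplified stand-in for what `DataCollatorForSeq2Seq(..., pad_to_multiple_of=8)` does with right padding, not the library's code: it pads `input_ids` with the pad id (GPT-2's `<|endoftext|>`, id 50256, since this notebook sets `pad_token = eos_token`), sets `attention_mask` to 0 on padding, and pads `labels` with `-100`.

```python
PAD_ID = 50256       # <|endoftext|>, reused as the pad token in this notebook
IGNORE_INDEX = -100  # padded label positions are ignored by the loss

def pad_batch(examples, pad_id=PAD_ID, label_pad_id=IGNORE_INDEX, multiple_of=8):
    """Right-pad a list of {'input_ids', 'labels'} dicts to a common length."""
    longest = max(len(ex["input_ids"]) for ex in examples)
    # Round the target length up to a multiple of `multiple_of` (like pad_to_multiple_of=8)
    target = -(-longest // multiple_of) * multiple_of
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        n_real = len(ex["input_ids"])
        n_pad = target - n_real
        batch["input_ids"].append(ex["input_ids"] + [pad_id] * n_pad)
        # 1 = real token the model attends to, 0 = padding it must ignore
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        # Padding in labels gets -100 so it adds nothing to the loss
        batch["labels"].append(ex["labels"] + [label_pad_id] * n_pad)
    return batch

# Two hypothetical tokenized examples of different lengths
batch = pad_batch([
    {"input_ids": [11, 12, 13], "labels": [-100, 12, 13]},
    {"input_ids": [21, 22], "labels": [-100, 22]},
])
print(batch["attention_mask"][1])  # [1, 1, 0, 0, 0, 0, 0, 0]
print(batch["labels"][1])          # [-100, 22, -100, -100, -100, -100, -100, -100]
```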

In [None]:
tokenized_datasets = raw_datasets.map(tokenize_function, batched=False)

# Remove the original columns that are no longer needed
tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "label", "idx"])

# Now the tokenized_datasets will have only the columns input_ids, attention_mask, and labels
print(tokenized_datasets)


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 408
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1725
    })
})


### Install evaluation tools

In [None]:
!pip install evaluate
!pip install rouge_score



### Evaluate the performance of GPT-2 on your dataset

In [None]:
import evaluate
import torch
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader

# Pads input_ids/attention_mask and fills padded label positions with -100
data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

metric = evaluate.load('rouge')
eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=True, batch_size=1, collate_fn=data_collator)
model.eval()
for batch in eval_dataloader:
    # The first label != -100 marks where the answer starts; keep only the prompt before it
    input_len = int((batch['labels'][0] != -100).nonzero()[0])
    input_ids = batch['input_ids'][0][:input_len].unsqueeze(0).to(device)

    # Reference answer: the text between 'Output:' and the first <|endoftext|>
    label = tokenizer.batch_decode(batch['input_ids'])[0]
    label = label.split('<|endoftext|>')[0].split('Output:')[-1]

    with torch.no_grad():
        gen_tokens = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),  # the sliced prompt contains no padding
            do_sample=True,
            temperature=0.9,
            max_length=200,
        )
    # Keep only the generated answer after 'Output:' and before <|endoftext|>
    gen_text = tokenizer.batch_decode(gen_tokens.to('cpu'))[0].split('Output:')[-1]
    gen_text = gen_text.split('<|endoftext|>')[0]
    metric.add_batch(predictions=[gen_text], references=[label])

metric.compute()


{'rouge1': 0.016134010076864005,
 'rouge2': 0.000677003991458153,
 'rougeL': 0.015804751008245063,
 'rougeLsum': 0.01619159956313583}


## **Task 2:** Fine-tuning GPT-2 with LoRA (3 pts)

### Install the parameter-efficient fine-tuning (PEFT) library
* reference: [hugging face peft library](https://huggingface.co/docs/peft/en/index)

In [None]:
!pip install peft



### **Hint 1:** Add LoRA layers to the pre-trained GPT-2 with the *peft* library
* reference: [peft LoRA methods](https://huggingface.co/docs/peft/en/task_guides/lora_based_methods)
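
LoRA keeps a pretrained weight W frozen and learns a low-rank update, so an adapted layer computes W x + (lora_alpha / r) · B A x, where A is r × d_in and B is d_out × r; B is typically initialized to zero so training starts from the unchanged model. The stdlib sketch below (hypothetical helpers, tiny matrices as nested lists) shows that forward pass and counts the added parameters. For GPT-2's `c_attn` (768 inputs, 2,304 = 3 × 768 outputs for query, key and value) in 12 blocks with r = 32, the count is 12 × 32 × (768 + 2304) = 1,179,648, which matches the figure printed by `print_trainable_parameters()`.

```python
def lora_trainable_params(d_in, d_out, r, num_layers):
    """Parameters added by LoRA: A is r x d_in and B is d_out x r for each adapted layer."""
    return num_layers * r * (d_in + d_out)

def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x); W is frozen, only A and B would be trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))  # low-rank path: d_in -> r -> d_out
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

# GPT-2 small: 12 blocks, each c_attn maps 768 -> 2304 (query, key, value stacked)
print(lora_trainable_params(d_in=768, d_out=2304, r=32, num_layers=12))  # 1179648

# Tiny example: with B initialized to zero, the adapted layer equals the frozen layer
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]    # r = 1, d_in = 2
B = [[0.0], [0.0]]  # d_out = 2, r = 1
print(lora_forward(W, A, B, [2.0, 3.0], alpha=1, r=1))  # [2.0, 3.0]
```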

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

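config = LoraConfig(
    r=32,                       # LoRA rank: increased for more trainable parameters
    lora_alpha=32,              # updates are scaled by lora_alpha / r = 1.0
    target_modules=["c_attn"],  # GPT-2's fused query/key/value projection
    fan_in_fan_out=True,        # GPT-2 uses Conv1D, which stores weights transposed
    lora_dropout=0.2,           # dropout on the LoRA branch for regularization
    bias="none",                # do not train bias terms
    modules_to_save=[],         # no extra (non-LoRA) modules are trained or saved
    task_type="CAUSAL_LM",      # wrap the model as a causal LM so generate() is supported
)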

# Get the LoRA model
lora_model = get_peft_model(model, config)

# Print trainable parameters
lora_model.print_trainable_parameters()


trainable params: 1,179,648 || all params: 125,619,456 || trainable%: 0.9391




### **Hint 2:** Fine-tune GPT-2 with LoRA on your dataset

In [None]:
import torch
from torch.optim import AdamW  # replaces the deprecated transformers.AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq, get_scheduler

# Pads each batch; padded label positions are filled with -100
data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

# Hyperparameters
num_epochs = 5
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=1, collate_fn=data_collator)
num_training_steps = num_epochs * len(train_dataloader)

# Optimize only the trainable (LoRA) parameters; weight_decay=0.0 matches the old transformers.AdamW default
optimizer = AdamW([p for p in lora_model.parameters() if p.requires_grad], lr=3e-5, weight_decay=0.0)

# Learning rate scheduler: linear decay to 0, no warmup
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
lora_model.to(device)

progress_bar = tqdm(range(num_training_steps))

lora_model.train()  # enable training mode (LoRA dropout active)
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}  # move the batch to the GPU/CPU
        loss = lora_model(**batch).loss  # forward pass; labels equal to -100 are ignored
        loss.backward()                  # compute gradients for the LoRA weights
        optimizer.step()                 # update the weights
        lr_scheduler.step()              # adjust the learning rate
        optimizer.zero_grad()            # reset gradients for the next step
        progress_bar.update(1)

print("Training completed.")


100%|█████████▉| 18337/18340 [10:06<00:00, 27.66it/s]

Training completed.
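
With `num_warmup_steps=0`, the `"linear"` schedule from `get_scheduler` decays the learning rate linearly from `lr` to 0 over `num_training_steps`. The helper below is a hypothetical re-derivation of that multiplier, not the transformers code, applied to the run above: 5 epochs × 3,668 batches of size 1 = 18,340 steps, matching the progress bar.

```python
def linear_schedule_lr(step, base_lr, num_training_steps, num_warmup_steps=0):
    """Learning rate at `step` under linear warmup followed by linear decay to 0."""
    if step < num_warmup_steps:
        # Warmup: ramp up linearly from 0 to base_lr
        return base_lr * step / max(1, num_warmup_steps)
    # Decay: fraction of the post-warmup steps that remain
    remaining = max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))
    return base_lr * remaining

total_steps = 5 * 3668  # num_epochs * len(train_dataloader) with batch_size=1
print(total_steps)      # 18340
for step in (0, total_steps // 2, total_steps):
    print(step, linear_schedule_lr(step, 3e-5, total_steps))
```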


### Evaluate fine-tuned model
* If training goes well, the fine-tuned model should outperform the pre-trained one.

In [None]:
import evaluate
import torch
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader

# Pads input_ids/attention_mask and fills padded label positions with -100
data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

metric = evaluate.load('rouge')
eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=True, batch_size=1, collate_fn=data_collator)
lora_model.eval()  # evaluate the LoRA-adapted model (disables dropout)
for batch in eval_dataloader:
    # The first label != -100 marks where the answer starts; keep only the prompt before it
    input_len = int((batch['labels'][0] != -100).nonzero()[0])
    input_ids = batch['input_ids'][0][:input_len].unsqueeze(0).to(device)

    # Reference answer: the text between 'Output:' and the first <|endoftext|>
    label = tokenizer.batch_decode(batch['input_ids'])[0]
    label = label.split('<|endoftext|>')[0].split('Output:')[-1]

    with torch.no_grad():
        gen_tokens = lora_model.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),  # the sliced prompt contains no padding
            do_sample=True,
            temperature=0.9,
            max_length=200,
        )
    # Keep only the generated answer after 'Output:' and before <|endoftext|>
    gen_text = tokenizer.batch_decode(gen_tokens.to('cpu'))[0].split('Output:')[-1]
    gen_text = gen_text.split('<|endoftext|>')[0]
    metric.add_batch(predictions=[gen_text], references=[label])

metric.compute()

{'rouge1': 0.8774509803921579,
 'rouge2': 0.09558823529411764,
 'rougeL': 0.8782679738562099,
 'rougeLsum': 0.877450980392158}


## **Task 3:** Fine-tuning LLM with federated learning (4 pts)
Training LLMs in a distributed manner across multiple data centers is a significant area of research. The tutorial encourages you to simulate distributed fine-tuning with Federated Learning:
* fine-tune multiple LoRA models on different datasets/tasks
* aggregate those LoRA models by averaging their parameters (FedAvg[1], FedMo[2])
* compare the aggregated LoRA model with those task-specific models
* Suggested references:

  [1] McMahan, B., Moore, E., Ramage, D., Hampson, S., & y Arcas, B. A. (2017, April). Communication-efficient learning of deep networks from decentralized data. (https://arxiv.org/abs/1602.05629)

  [2] Yang, Y., Long, G., Shen, T., Jiang, J., & Blumenstein, M. (2024). Dual-Personalizing Adapter for Federated Foundation Models. (https://arxiv.org/abs/2403.19211)
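
The aggregation step can be sketched without PyTorch. FedAvg [1] replaces the global parameters with the mean of the client parameters weighted by client data size, θ = Σ_k (n_k / n) θ_k. The sketch below uses hypothetical helper names and plain Python lists in place of tensors. It (a) estimates the client shard sizes produced by the two `train_test_split` calls in the next cell, assuming `datasets` rounds a fractional `test_size` up with `ceil`, (b) applies weighted FedAvg to toy LoRA factors, and (c) shows a known caveat: averaging the A and B factors separately is not the same as averaging each client's update B·A.

```python
import math

def shard_sizes(n, first_fraction=0.34):
    """Sizes of the 3 client shards, assuming test size = ceil(test_size * n) in each split."""
    rest = math.ceil((1 - first_fraction) * n)  # first split: client 1 keeps the "train" part
    third = math.ceil(0.5 * rest)               # second split: 0.33 / (0.33 + 0.33) = 0.5
    return [n - rest, rest - third, third]

def fedavg(client_params, client_sizes):
    """Weighted FedAvg over dicts mapping parameter name -> flat list of floats."""
    total = sum(client_sizes)
    weights = [size / total for size in client_sizes]
    averaged = {}
    for name in client_params[0]:
        # zip(*) walks the same element position across all clients
        columns = zip(*(params[name] for params in client_params))
        averaged[name] = [sum(w * v for w, v in zip(weights, col)) for col in columns]
    return averaged

sizes = shard_sizes(3668)
print(sizes)                                  # [1247, 1210, 1211]
print([2 * math.ceil(s / 8) for s in sizes])  # steps for 2 epochs at batch size 8: [312, 304, 304]

# Two hypothetical clients with rank-1 LoRA factors of a 1x1 layer and equal data sizes
clients = [{"lora_A": [1.0], "lora_B": [2.0]}, {"lora_A": [3.0], "lora_B": [0.0]}]
avg = fedavg(clients, client_sizes=[1, 1])
print(avg)                                    # {'lora_A': [2.0], 'lora_B': [1.0]}

# Caveat: the product of averaged factors is not the average of the client products B*A
print(avg["lora_B"][0] * avg["lora_A"][0])                                   # 2.0
print(sum(c["lora_B"][0] * c["lora_A"][0] for c in clients) / len(clients))  # 1.0
```

The notebook's own `federated_aggregation` uses equal client weights, which is close to size-weighted FedAvg here because the three shards differ by at most 37 examples.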

In [None]:
import copy

import torch
from torch.optim import AdamW  # replaces the deprecated transformers.AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, DataCollatorForSeq2Seq, GPT2LMHeadModel, GPT2Tokenizer, get_scheduler
from peft import LoraConfig, get_peft_model
from datasets import DatasetDict, concatenate_datasets, load_dataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Split the training set into 3 roughly equal client shards (34% / 33% / 33%)
size1 = 0.34
size2 = 0.33
size3 = 0.33

first_split = tokenized_datasets["train"].train_test_split(test_size=(1 - size1))
train_dataset1 = first_split['train']   # ~34% of the data
second_split = first_split['test'].train_test_split(test_size=size3 / (size2 + size3))
train_dataset2 = second_split['train']  # ~33%
train_dataset3 = second_split['test']   # ~33%

clients_data = {0: train_dataset1, 1: train_dataset2, 2: train_dataset3}
num_clients = 3
num_epochs = 2


def train_lora_model(lora_model, tokenizer, tokenized_dataset, num_epochs=5, batch_size=1, lr=3e-5):
    """Train `lora_model` in place on one dataset and return it."""
    # Pads each batch; padded label positions become -100
    data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)
    # Only the LoRA parameters require gradients; weight_decay=0.0 matches the old transformers.AdamW default
    optimizer = AdamW([p for p in lora_model.parameters() if p.requires_grad], lr=lr, weight_decay=0.0)
    train_dataloader = DataLoader(tokenized_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    lora_model.to(device)
    lora_model.train()
    progress_bar = tqdm(range(num_training_steps))
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = lora_model(**batch).loss  # forward pass
            loss.backward()                  # backward pass
            optimizer.step()                 # update the LoRA weights
            lr_scheduler.step()              # adjust the learning rate
            optimizer.zero_grad()            # reset gradients for the next step
            progress_bar.update(1)

    print("Training completed.")
    return lora_model


def federated_aggregation(lora_models):
    """FedAvg with equal client weights (the shards are roughly equal in size)."""
    global_model = copy.deepcopy(lora_models[0])
    global_params = global_model.state_dict()
    client_params = [m.state_dict() for m in lora_models]  # build each state_dict once, not once per key

    # Only the LoRA tensors differ between clients; the frozen GPT-2 weights are identical
    for key in global_params:
        if "lora_" in key:
            global_params[key] = torch.stack([p[key] for p in client_params]).mean(dim=0)

    global_model.load_state_dict(global_params)
    return global_model


def train_on_client(client_data, lora_model, tokenizer, num_epochs, batch_size=8, lr=3e-5, device=None):
    """Local training on one client's data (`device` is kept for compatibility; the global device is used)."""
    return train_lora_model(lora_model, tokenizer, client_data, num_epochs=num_epochs, batch_size=batch_size, lr=lr)


def fine_tune_lora_on_clients(clients_data, num_clients, tokenizer, global_model, num_epochs):
    """Local round: every client starts from its own copy of the current global model."""
    lora_models = []
    for client_id in range(num_clients):
        client_model = copy.deepcopy(global_model)  # clients must not share (and overwrite) one model
        lora_models.append(train_on_client(clients_data[client_id], client_model, tokenizer, num_epochs=num_epochs))
    return lora_models


def federated_learning_process(clients_data, num_clients, tokenizer, model, config, num_epochs, num_episodes=3):
    """Wrap `model` with LoRA once, then alternate local training and FedAvg for `num_episodes` rounds."""
    global_model = get_peft_model(model, config)
    for i in range(num_episodes):
        print(f"Episode {i + 1}/{num_episodes}")
        lora_models = fine_tune_lora_on_clients(clients_data, num_clients, tokenizer, global_model, num_epochs)
        global_model = federated_aggregation(lora_models)
    return global_model


# Start from a fresh GPT-2: get_peft_model() in Task 2 injected (and trained) LoRA layers in `model` in place
base_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
base_model.generation_config.pad_token_id = tokenizer.pad_token_id
global_model = federated_learning_process(clients_data, num_clients, tokenizer, base_model, config, num_epochs, num_episodes=6)





Episode 1/6



  0%|          | 0/312 [00:00<?, ?it/s][A
  1%|          | 2/312 [00:00<00:44,  6.92it/s][A
  1%|          | 3/312 [00:00<00:51,  6.01it/s][A
  1%|▏         | 4/312 [00:00<00:47,  6.48it/s][A
  2%|▏         | 5/312 [00:00<00:49,  6.19it/s][A
  2%|▏         | 6/312 [00:00<00:49,  6.16it/s][A
  2%|▏         | 7/312 [00:01<00:49,  6.14it/s][A
  3%|▎         | 8/312 [00:01<00:48,  6.32it/s][A
  3%|▎         | 9/312 [00:01<00:49,  6.13it/s][A
  3%|▎         | 10/312 [00:01<00:48,  6.29it/s][A
  4%|▎         | 11/312 [00:01<00:46,  6.43it/s][A
  4%|▍         | 12/312 [00:01<00:50,  5.88it/s][A
  4%|▍         | 13/312 [00:02<00:47,  6.25it/s][A
  4%|▍         | 14/312 [00:02<00:53,  5.52it/s][A
  5%|▍         | 15/312 [00:02<00:51,  5.75it/s][A
  5%|▌         | 16/312 [00:02<00:51,  5.74it/s][A
  5%|▌         | 17/312 [00:02<00:52,  5.62it/s][A
  6%|▌         | 18/312 [00:03<00:54,  5.37it/s][A
  6%|▌         | 19/312 [00:03<00:55,  5.29it/s][A
  6%|▋         | 20/312 [00:

Training completed.



  0%|          | 0/304 [00:00<?, ?it/s][A
  1%|          | 2/304 [00:00<00:41,  7.26it/s][A
  1%|          | 3/304 [00:00<00:43,  6.87it/s][A
  1%|▏         | 4/304 [00:00<00:46,  6.43it/s][A
  2%|▏         | 5/304 [00:00<00:50,  5.97it/s][A
  2%|▏         | 6/304 [00:01<00:53,  5.55it/s][A
  2%|▏         | 7/304 [00:01<00:53,  5.55it/s][A
  3%|▎         | 8/304 [00:01<00:53,  5.50it/s][A
  3%|▎         | 9/304 [00:01<00:52,  5.59it/s][A
  3%|▎         | 10/304 [00:01<00:51,  5.68it/s][A
  4%|▎         | 11/304 [00:01<00:52,  5.60it/s][A
  4%|▍         | 12/304 [00:02<00:52,  5.53it/s][A
  4%|▍         | 13/304 [00:02<00:51,  5.70it/s][A
  5%|▍         | 14/304 [00:02<00:50,  5.74it/s][A
  5%|▍         | 15/304 [00:02<00:47,  6.03it/s][A
  5%|▌         | 16/304 [00:02<00:48,  5.90it/s][A
  6%|▌         | 17/304 [00:02<00:48,  5.89it/s][A
  6%|▌         | 18/304 [00:03<00:49,  5.76it/s][A
  6%|▋         | 19/304 [00:03<00:49,  5.74it/s][A
  7%|▋         | 20/304 [00:

Training completed.



  0%|          | 0/304 [00:00<?, ?it/s][A
  1%|          | 2/304 [00:00<00:36,  8.24it/s][A
  1%|          | 3/304 [00:00<00:42,  7.00it/s][A
  1%|▏         | 4/304 [00:00<00:48,  6.23it/s][A
  2%|▏         | 5/304 [00:00<00:48,  6.14it/s][A
  2%|▏         | 6/304 [00:01<00:57,  5.16it/s][A
  2%|▏         | 7/304 [00:01<00:57,  5.16it/s][A
  3%|▎         | 8/304 [00:01<00:56,  5.20it/s][A
  3%|▎         | 9/304 [00:01<00:58,  5.07it/s][A
  3%|▎         | 10/304 [00:01<00:57,  5.13it/s][A
  4%|▎         | 11/304 [00:02<00:57,  5.07it/s][A
  4%|▍         | 12/304 [00:02<00:56,  5.18it/s][A
  4%|▍         | 13/304 [00:02<00:55,  5.20it/s][A
  5%|▍         | 14/304 [00:02<00:55,  5.19it/s][A
  5%|▍         | 15/304 [00:02<00:55,  5.25it/s][A
  5%|▌         | 16/304 [00:02<00:54,  5.25it/s][A
  6%|▌         | 17/304 [00:03<00:54,  5.27it/s][A
  6%|▌         | 18/304 [00:03<00:54,  5.24it/s][A
  6%|▋         | 19/304 [00:03<00:54,  5.23it/s][A
  7%|▋         | 20/304 [00:

Training completed.
Episode 2/6
Training completed.
Training completed.
Training completed.
Episode 3/6
Training completed.
Training completed.
Training completed.
Episode 4/6
Training completed.
Training completed.
Training completed.
Episode 5/6
Training completed.
Training completed.
Training completed.
Episode 6/6
Training completed.
Training completed.
Training completed.
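Each episode above trains every client in turn, and the evaluation cell that follows scores `global_model`. The step that merges the client updates into `global_model` is not visible in this output. A common choice in federated learning is FedAvg, which sets each global parameter to the average of the client parameters weighted by the number of local training examples. The sketch below is a minimal version under that assumption: `fedavg`, `client_models` and `client_num_examples` are hypothetical names, and each client is assumed to return a full `state_dict()` with the same parameter names.

```python
def fedavg(client_states, client_sizes):
    """FedAvg-style weighted average of client parameters.

    client_states: list of dicts mapping parameter name -> value
        (e.g. torch tensors from model.state_dict(), or plain floats).
    client_sizes: local training-set size of each client, used as weights.
    """
    total = sum(client_sizes)
    weights = [size / total for size in client_sizes]
    # Every client is assumed to share the same parameter names and shapes
    return {
        name: sum(w * state[name] for w, state in zip(weights, client_states))
        for name in client_states[0]
    }


# Toy check with scalar "parameters": 0.25 * 0.0 + 0.75 * 4.0
print(fedavg([{"w": 0.0}, {"w": 4.0}], [1, 3]))  # {'w': 3.0}

# Hypothetical use with models (these names are not defined in this notebook):
# global_model.load_state_dict(
#     fedavg([m.state_dict() for m in client_models], client_num_examples)
# )
```

Weighting by data size lets clients with more examples pull the global model further; a plain unweighted mean is the special case where all clients report the same size.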


In [None]:
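import evaluate
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Evaluate the aggregated global model on the validation split
model = global_model
data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

metric = evaluate.load("rouge")
# batch_size=1 so each prompt can be cut at its own length; shuffling is unnecessary for evaluation
eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, batch_size=1, collate_fn=data_collator)

model.eval()
for batch in eval_dataloader:
    # Assumes prompt positions are masked with -100 in `labels`:
    # the first unmasked position is where the answer starts
    input_len = torch.nonzero(batch["labels"][0] != -100)[0].item()
    input_ids = batch["input_ids"][0][:input_len].unsqueeze(0)

    # Reference answer: text after "Output:" and before the first <|endoftext|>
    label = tokenizer.batch_decode(batch["input_ids"])[0]
    label = label.split("<|endoftext|>")[0].split("Output:")[-1]

    with torch.no_grad():
        gen_tokens = model.generate(
            input_ids.to(device),
            attention_mask=torch.ones_like(input_ids).to(device),  # single unpadded prompt
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
            do_sample=True,  # sampling: scores vary slightly between runs
            temperature=0.9,
            max_length=200,  # counts prompt + generated tokens
        )
    gen_text = tokenizer.batch_decode(gen_tokens.cpu())[0].split("Output:")[-1]
    gen_text = gen_text.split("<|endoftext|>")[0]
    metric.add_batch(predictions=[gen_text], references=[label])

metric.compute()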

{'rouge1': 0.8593867708483066,
 'rouge2': 0.06454248366013073,
 'rougeL': 0.8599836394977272,
 'rougeLsum': 0.8596779964071578}
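The validation scores show ROUGE-1 around 0.86 but ROUGE-2 around 0.06. ROUGE-N compares the n-grams of the generated answer with those of the reference and reports the F1 of n-gram precision and recall. One possible reading of the gap, not verified against this dataset, is that many reference answers are a single word: such a reference has no bigrams, so its ROUGE-2 is 0 even when the prediction matches it exactly. The minimal sketch below computes ROUGE-N F1 for one prediction/reference pair; `rouge_n_f1` is a hypothetical helper, and its lowercase whitespace tokenisation is a simplification of the tokenizer used by the `evaluate` ROUGE metric.

```python
from collections import Counter


def ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n_f1(prediction, reference, n):
    """ROUGE-N F1 for one pair, using lowercase whitespace tokenisation (a simplification)."""
    pred = ngrams(prediction.lower().split(), n)
    ref = ngrams(reference.lower().split(), n)
    overlap = sum((pred & ref).values())  # clipped n-gram matches
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# A one-word reference (hypothetical label) matched exactly:
# full unigram credit, but there are no bigrams to match
print(rouge_n_f1("positive", "positive", 1))  # 1.0
print(rouge_n_f1("positive", "positive", 2))  # 0.0
```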
