# Fine-Tune a Generative AI Model for Dialogue Summarization

<a name='1.1'></a>
### 1.1 - Set up Kernel and Required Dependencies

In [62]:
import os

%pip install torch torchdata  --index-url https://download.pytorch.org/whl/cu118 --quiet

%pip install transformers datasets evaluate rouge_score peft --quiet

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np

torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm


True

<a name='1.2'></a>
### 1.2 - Load Dataset and LLM

You are going to continue experimenting with the [DialogSum](https://huggingface.co/datasets/knkarthick/dialogsum) Hugging Face dataset. It contains 10,000+ dialogues with the corresponding manually labeled summaries and topics. 

In [3]:
DATASET = "DIBT/10k_prompts_ranked"

dataset = load_dataset(DATASET, split='train')

dataset = dataset.train_test_split(test_size=0.1)

dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'quality', 'metadata', 'avg_rating', 'num_responses', 'agreement_ratio', 'raw_responses', 'kind', 'cluster_description', 'topic'],
        num_rows: 9297
    })
    test: Dataset({
        features: ['prompt', 'quality', 'metadata', 'avg_rating', 'num_responses', 'agreement_ratio', 'raw_responses', 'kind', 'cluster_description', 'topic'],
        num_rows: 1034
    })
})

In [4]:
dataset['train']['prompt'][4]

'Can you explain the role of religion in the formation of early American democracy, and how Puritanism influenced this process?'

Load the pre-trained [FLAN-T5 model](https://huggingface.co/docs/transformers/model_doc/flan-t5) and its tokenizer directly from HuggingFace. Notice that you will be using the [small version](https://huggingface.co/google/flan-t5-base) of FLAN-T5. Setting `torch_dtype=torch.bfloat16` specifies the memory type to be used by this model.

In [5]:
model_name='google/flan-t5-small'

original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)

It is possible to pull out the number of model parameters and find out how many of them are trainable. The following function can be used to do that, at this stage, you do not need to go into details of it. 

In [6]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(original_model))

trainable model parameters: 76961152
all model parameters: 76961152
percentage of trainable model parameters: 100.00%


<a name='1.3'></a>
### 1.3 - Test the Model with Zero Shot Inferencing

Test the model with the zero shot inferencing. You can see that the model struggles to summarize the dialogue compared to the baseline summary, but it does pull out some important information from the text which indicates the model can be fine-tuned to the task at hand.

In [7]:
index = 2

orginal_prompt = dataset['train'][index]['prompt']

prompt = f"""
Create a safe prompt from the following prompt:

{orginal_prompt}

Prompt:
"""

inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
output = tokenizer.decode(
    original_model.generate(
        inputs["input_ids"], 
        max_new_tokens=200,
    )[0], 
    skip_special_tokens=True
)

dash_line = '-'.join('' for x in range(100))
print(dash_line)
print(f'INPUT PROMPT:\n{orginal_prompt}')
print(dash_line)
print(f'MODEL GENERATION - ZERO SHOT:\n{output}')

---------------------------------------------------------------------------------------------------
INPUT PROMPT:
Remove all words from the sentence that contain more than five letters.
I need to make an appointment with the doctor soon.
---------------------------------------------------------------------------------------------------
MODEL GENERATION - ZERO SHOT:
I need to make an appointment with the doctor soon.


<a name='2'></a>
## 2 - Perform Full Fine-Tuning

In [8]:
def make_safe_prompt(prompt):
    return f"""Create a safe prompt from the following prompt:

{prompt}

Prompt:"""

def tokenize_function(batch):
    prompts = [make_safe_prompt(pp) for pp in batch["prompt"]]
    batch['input_ids'] = tokenizer(prompts, padding="max_length", truncation=True, return_tensors="pt").to('cuda').input_ids
    batch['labels'] =tokenizer(batch["prompt"], padding="max_length", truncation=True, return_tensors="pt").to('cuda').input_ids
    return batch

# The dataset actually contains 3 diff splits: train, validation, test.
# The tokenize_function code is handling all data across all splits in batches.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['prompt', 'quality', 'metadata', 'avg_rating', 'num_responses', 'agreement_ratio', \
                                                        'raw_responses', 'kind', 'cluster_description', 'topic'])

tokenized_datasets

Map: 100%|████████████████████████████████████████████████████████████████| 9297/9297 [00:04<00:00, 1909.96 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████| 1034/1034 [00:00<00:00, 2149.20 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 9297
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1034
    })
})

To save some time in the lab, you will subsample the dataset:

In [9]:
#tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 100 == 0, with_indices=True)

Check the shapes of all three parts of the dataset:

In [10]:
print(f"Shapes of the datasets:")
print(f"Training: {tokenized_datasets['train'].shape}")
print(f"Test: {tokenized_datasets['test'].shape}")

print(tokenized_datasets)

Shapes of the datasets:
Training: (9297, 2)
Test: (1034, 2)
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 9297
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1034
    })
})


<a name='3'></a>
## 3 - Perform Parameter Efficient Fine-Tuning (PEFT)

Now, let's perform **Parameter Efficient Fine-Tuning (PEFT)** fine-tuning as opposed to "full fine-tuning" as you did above. PEFT is a form of instruction fine-tuning that is much more efficient than full fine-tuning - with comparable evaluation results as you will see soon. 

PEFT is a generic term that includes **Low-Rank Adaptation (LoRA)** and prompt tuning (which is NOT THE SAME as prompt engineering!). In most cases, when someone says PEFT, they typically mean LoRA. LoRA, at a very high level, allows the user to fine-tune their model using fewer compute resources (in some cases, a single GPU). After fine-tuning for a specific task, use case, or tenant with LoRA, the result is that the original LLM remains unchanged and a newly-trained “LoRA adapter” emerges. This LoRA adapter is much, much smaller than the original LLM - on the order of a single-digit % of the original LLM size (MBs vs GBs).  

That said, at inference time, the LoRA adapter needs to be reunited and combined with its original LLM to serve the inference request.  The benefit, however, is that many LoRA adapters can re-use the original LLM which reduces overall memory requirements when serving multiple tasks and use cases.

<a name='3.1'></a>
### 3.1 - Setup the PEFT/LoRA model for Fine-Tuning

You need to set up the PEFT/LoRA model for fine-tuning with a new layer/parameter adapter. Using PEFT/LoRA, you are freezing the underlying LLM and only training the adapter. Have a look at the LoRA configuration below. Note the rank (`r`) hyper-parameter, which defines the rank/dimension of the adapter to be trained.

In [46]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

Add LoRA adapter layers/parameters to the original LLM to be trained.

In [47]:
peft_model = get_peft_model(original_model, lora_config).to('cuda')
print(print_number_of_trainable_model_parameters(peft_model))

trainable model parameters: 1376256
all model parameters: 78337408
percentage of trainable model parameters: 1.76%


<a name='3.2'></a>
### 3.2 - Train PEFT Adapter

Define training arguments and create `Trainer` instance.

In [48]:
output_dir = f'./peft-train-checkpoint'

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    use_cpu=False,
    learning_rate=1e-3, # Higher learning rate than full fine-tuning.
    num_train_epochs=1,
    logging_steps=1,
    max_steps=30 
)
    
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets['test']
)

max_steps is given, it will override any value given in num_train_epochs


In [49]:
peft_trainer.train()

peft_model_path=f"./model_checkpoint"

peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

Step,Training Loss
1,54.25
2,39.25
3,40.75
4,39.0
5,33.5
6,34.5
7,34.25
8,32.0
9,32.25
10,32.25


TrainOutput(global_step=30, training_loss=24.9875, metrics={'train_runtime': 19.8914, 'train_samples_per_second': 12.065, 'train_steps_per_second': 1.508, 'total_flos': 45628407152640.0, 'train_loss': 24.9875, 'epoch': 0.025795356835769563})

In [58]:
from peft import PeftModel, PeftConfig


original_model1 = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to('cuda')

peft_model = PeftModel.from_pretrained(original_model1, 
                                       peft_model_path, 
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)

print(print_number_of_trainable_model_parameters(peft_model1))

trainable model parameters: 0
all model parameters: 78337408
percentage of trainable model parameters: 0.00%


The number of trainable parameters will be `0` due to `is_trainable=False` setting:

<a name='3.3'></a>
### 3.3 - Evaluate the Model Qualitatively (Human Evaluation)

Make inferences for the same example as in sections [1.3](#1.3) and [2.3](#2.3), with the original model, fully fine-tuned and PEFT model.

In [56]:
index = 55

original_prompt = dataset['test'][index]['prompt']

prompt = make_safe_prompt(original_prompt)

input_ids = tokenizer(prompt, return_tensors="pt").to('cuda').input_ids

original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)


peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

print(f'ORIGINAL PROMPT:\n{original_prompt}')
print(dash_line)
print(f'ORIGINAL MODEL:\n{original_model_text_output}')
print(dash_line)
print(f'PEFT MODEL:\n{peft_model_text_output}')

ORIGINAL PROMPT:
Please explain what is "AI applications"
---------------------------------------------------------------------------------------------------
ORIGINAL MODEL:
ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad ad a
---------------------------------------------------------------------------------------------------
PEFT MODEL:
AI applications are applications that can be used to create a database.


<a name='3.4'></a>
### 3.4 - Evaluate the Model Quantitatively (with ROUGE Metric)
Perform inferences for the sample of the test dataset (only 10 dialogues and summaries to save time). 

In [59]:
original_prompts = []
original_model_prompts = []
peft_model_prompts = []


for safe_prompt, original_prompt in zip(tokenized_datasets['test'][0:10]['input_ids'], tokenized_datasets['test'][0:10]['labels']):
    original_prompts.append( tokenizer.decode(original_prompt, skip_special_tokens=True) )

    inp = torch.unsqueeze(torch.as_tensor(safe_prompt), 0).to('cuda')
    
    output = original_model.generate(input_ids=inp, generation_config=GenerationConfig(max_new_tokens=200))
    original_model_prompts.append( tokenizer.decode(output[0], skip_special_tokens=True) )

    output = peft_model.generate(input_ids=inp, generation_config=GenerationConfig(max_new_tokens=200))
    peft_model_prompts.append( tokenizer.decode(output[0], skip_special_tokens=True) )

zipped = list(zip(peft_model_prompts, original_model_prompts, peft_model_prompts))
 
df = pd.DataFrame(zipped, columns = ['original_prompts', 'original_model_prompts', 'peft_model_prompts'])
df

Unnamed: 0,original_prompts,original_model_prompts,peft_model_prompts
0,Describe how you handle working with someone w...,Describe how you handle the situation.,Describe how you handle working with someone w...
1,Educational technology is the most important t...,Educational technology,Educational technology is the most important t...
2,The Australian Football League is a profession...,Australian Football League is the biggest leag...,The Australian Football League is a profession...
3,Training for new sales representatives is a gr...,Training should be provided for new sales repr...,Training for new sales representatives is a gr...
4,i'm not sure if i'm gonna be able to get out o...,family members to leave,i'm not sure if i'm gonna be able to get out o...
5,Pokemon are the best.,Pokemon are the best Pokemon,Pokemon are the best.
6,hate,@david_s_ad_ad_s_ad_ad_ad_ad_ad_ad_ad_ad_ad_ad...,hate
7,The blind person,i have a blind person,The blind person
8,i need to learn how to learn how to learn how ...,i need to learn how to use android,i need to learn how to learn how to learn how ...
9,Disulfiram prevents scars forming in a mouse m...,Disulfiram prevents scars,Disulfiram prevents scars forming in a mouse m...


You already computed ROUGE score on the full dataset, after loading the results from the `data/dialogue-summary-training-results.csv` file. Load the values for the PEFT model now and check its performance compared to other models.

In [60]:
rouge = evaluate.load('rouge')

original_model_results = rouge.compute(
    predictions=df['original_model_prompts'],
    references=df['original_prompts'],
    use_aggregator=True,
    use_stemmer=True,
)

peft_model_results = rouge.compute(
    predictions=df['peft_model_prompts'],
    references=df['original_prompts'],
    use_aggregator=True,
    use_stemmer=True,
)

print('ORIGINAL MODEL:')
print(original_model_results)
print('PEFT MODEL:')
print(peft_model_results)

ORIGINAL MODEL:
{'rouge1': 0.37780240797071085, 'rouge2': 0.2714913551627548, 'rougeL': 0.3618824422472506, 'rougeLsum': 0.36188092912553294}
PEFT MODEL:
{'rouge1': 1.0, 'rouge2': 0.9, 'rougeL': 1.0, 'rougeLsum': 1.0}


The results show less of an improvement over full fine-tuning, but the benefits of PEFT typically outweigh the slightly-lower performance metrics.

Calculate the improvement of PEFT over the original model:

In [61]:
print("Absolute percentage improvement of PEFT MODEL over ORIGINAL MODEL")

improvement = (np.array(list(peft_model_results.values())) - np.array(list(original_model_results.values())))
for key, value in zip(peft_model_results.keys(), improvement):
    print(f'{key}: {value*100:.2f}%')

Absolute percentage improvement of PEFT MODEL over ORIGINAL MODEL
rouge1: 62.22%
rouge2: 62.85%
rougeL: 63.81%
rougeLsum: 63.81%
