# Fine-Tune a Generative AI Model for Dialogue Summarization

Fine-tune a FLAN-T5 model from Hugging Face for improved dialogue summarization. Both full fine-tuning and Parameter-Efficient Fine-Tuning (PEFT) are explored and evaluated with ROUGE metrics.

In [1]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np



In [2]:
hf_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(hf_dataset_name)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})


In [3]:
model_name = "google/flan-t5-base"
# original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
# original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('xpu')
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to('xpu')
tokenizer = AutoTokenizer.from_pretrained(model_name)

The following function counts the model's parameters and reports how many of them are trainable.

In [4]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0

    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters {trainable_model_params/all_model_params * 100}%"

print(print_number_of_trainable_model_parameters(original_model)) 

trainable model parameters: 247577856
all model parameters: 247577856
percentage of trainable model parameters 100.0%


### 1.3 Test the Model with Zero-Shot Inferencing

In [5]:
index = 200

dialogue = dataset["test"][index]["dialogue"]
summary = dataset["test"][index]["summary"]

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

output = tokenizer.decode(original_model.generate(tokenizer(prompt, return_tensors="pt")['input_ids'].to('xpu'), max_new_tokens=200)[0], skip_special_tokens=True)

dash_line = "-".join("" for i in range(50))
print(prompt)
print(dash_line)
print("Baseline Summary:\n", summary)
print(dash_line)
print("Model Generation - Zero Shot:\n", output)



Summarize the following conversation.

#Person1#: Have you considered upgrading your system?
#Person2#: Yes, but I'm not sure what exactly I would need.
#Person1#: You could consider adding a painting program to your software. It would allow you to make up your own flyers and banners for advertising.
#Person2#: That would be a definite bonus.
#Person1#: You might also want to upgrade your hardware because it is pretty outdated now.
#Person2#: How can we do that?
#Person1#: You'd probably need a faster processor, to begin with. And you also need a more powerful hard disc, more memory and a faster modem. Do you have a CD-ROM drive?
#Person2#: No.
#Person1#: Then you might want to add a CD-ROM drive too, because most new software programs are coming out on Cds.
#Person2#: That sounds great. Thanks.

Summary:

-------------------------------------------------
Baseline Summary:
 #Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.
------------------------
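The prompt template is written by hand in several cells, and the wording drifts: the training cell (`tokenize_function`) uses "Summarize the following conversation." followed by "Summary: ", while the later evaluation cells use a colon and indented lines. A minimal sketch of a shared helper, assuming you want every cell to use the same template as `tokenize_function`; the names `build_prompt`, `START_PROMPT`, and `END_PROMPT` are hypothetical and not part of the original notebook.

```python
# Hypothetical helper: keep the summarization prompt in one place so that
# training and evaluation wrap the dialogue in exactly the same text.
START_PROMPT = "Summarize the following conversation.\n\n"
END_PROMPT = "\n\nSummary: "


def build_prompt(dialogue: str) -> str:
    """Wrap a raw dialogue in the instruction format used for fine-tuning."""
    return START_PROMPT + dialogue + END_PROMPT


example = "#Person1#: Hi, Tom.\n#Person2#: Hi, Ann."
print(build_prompt(example))
```

Calling the same helper from the training and evaluation cells means the model is evaluated on the prompt format it was fine-tuned on.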

## 2. Perform Full Fine-Tuning

### 2.1 Preprocess the Dialog-Summary Dataset

In [6]:
def tokenize_function(example):
    start_prompt = "Summarize the following conversation.\n\n"
    end_prompt = "\n\nSummary: "
    # With batched=True, example["dialogue"] is a list of dialogues.
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    # prompt = start_prompt + example["dialogue"] + end_prompt # when batched=False
    # Return plain lists: datasets.map stores the columns in Arrow anyway, and the
    # Trainer moves each batch to the device, so calling .to('xpu') here is not needed.
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True).input_ids
    example['labels'] = tokenizer(example['summary'], padding="max_length", truncation=True).input_ids
    return example

# The dataset contains three splits: train, validation, and test.
# dataset.map applies tokenize_function to every split, one batch at a time.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(['id', 'topic', 'dialogue', 'summary'])
# print(tokenized_dataset['validation'][0]['input_ids'])
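`tokenize_function` pads the labels to `max_length` with the tokenizer's pad token, so by default the loss is also computed on those padding positions. A common convention in Hugging Face seq2seq training is to replace pad ids in the labels with `-100`, which PyTorch's cross-entropy loss ignores by default. The original notebook does not do this; the sketch below is an optional extra step, assuming T5's pad token id of 0 and the plain Python lists that `datasets.map` passes in batched mode. The token ids in the example are illustrative only.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_pad_labels(labels: List[List[int]], pad_token_id: int = 0) -> List[List[int]]:
    """Replace padding ids in each label sequence with IGNORE_INDEX."""
    return [
        [tok if tok != pad_token_id else IGNORE_INDEX for tok in seq]
        for seq in labels
    ]


def mask_batch(batch: Dict[str, List[List[int]]]) -> Dict[str, List[List[int]]]:
    # Hypothetical extra map() step, applied after tokenize_function with batched=True.
    batch["labels"] = mask_pad_labels(batch["labels"])
    return batch


print(mask_pad_labels([[363, 19, 1, 0, 0]]))  # [[363, 19, 1, -100, -100]]
```

In the notebook this would be applied as `tokenized_dataset = tokenized_dataset.map(mask_batch, batched=True)`; in practice, pass `tokenizer.pad_token_id` rather than hard-coding 0.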

In [7]:
# To save some time, subsample the dataset:

tokenized_dataset = tokenized_dataset.filter(lambda example, index: index % 10 == 0, with_indices=True)
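The filter keeps every 10th row starting at index 0, so each split shrinks to ceil(n / 10) rows. A quick arithmetic check against the split sizes printed by `load_dataset` earlier, which matches the shapes printed in the next cell:

```python
import math

# Split sizes printed by load_dataset earlier in the notebook.
split_sizes = {"train": 12460, "validation": 500, "test": 1500}

# index % 10 == 0 keeps indices 0, 10, 20, ... in each split.
kept = {name: len(range(0, n, 10)) for name, n in split_sizes.items()}
print(kept)  # {'train': 1246, 'validation': 50, 'test': 150}

# Equivalent closed form: ceil(n / 10).
assert all(kept[name] == math.ceil(n / 10) for name, n in split_sizes.items())
```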

In [8]:
print("Shapes of dataset:")
print(f"Training: {tokenized_dataset['train'].shape}")
print(f"Validation: {tokenized_dataset['validation'].shape}")
print(f"Test: {tokenized_dataset['test'].shape}")
print(tokenized_dataset)

Shapes of dataset:
Training: (1246, 2)
Validation: (50, 2)
Test: (150, 2)
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1246
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 50
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 150
    })
})


### 2.2 Fine-Tune the Model with the Preprocessed Dataset

In [9]:
# output_dir = f"./dialogue-summary-training-{str(int(time.time()))}"

# training_args = TrainingArguments(
#     output_dir=output_dir,
#     learning_rate=1e-5,
#     num_train_epochs=1,
#     weight_decay=0.01,
#     logging_steps=1,
#     max_steps=1
# )

# trainer = Trainer(
#     model=original_model,
#     args=training_args,
#     train_dataset=tokenized_dataset['train'],
#     eval_dataset=tokenized_dataset['validation']
# )

In [10]:
# trainer.train()

### 2.4 Evaluate the Model Quantitatively (with ROUGE Metric)

In [11]:
rouge = evaluate.load('rouge')

In [12]:
human_baseline_summaries = []
original_model_summaries = []

for i in range(10):
    human_baseline_summaries.append(dataset["test"][i]['summary'])
    prompt = f"""
        Summarize the following conversation:
        {dataset["test"][i]['dialogue']}
        Summary: 
"""
    # No max_new_tokens is passed, so generate() falls back to the default
    # generation length (20 tokens), which cuts off some summaries below.
    model_output = original_model.generate(tokenizer(prompt, return_tensors='pt')['input_ids'].to('xpu'))[0]
    original_model_summaries.append(tokenizer.decode(model_output, skip_special_tokens=True))

zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries))
df = pd.DataFrame(zipped_summaries, columns=['Human Baseline', 'Original Model'])
df



|   | Human Baseline | Original Model |
|---|---|---|
| 0 | Ms. Dawson helps #Person1# to write a memo to ... | #Person1#: I need to take a dictation for you. # |
| 1 | In order to prevent employees from wasting tim... | #Person1#: I need to take a dictation for you. # |
| 2 | Ms. Dawson takes a dictation for #Person1# abo... | #Person1#: I need to take a dictation for you. # |
| 3 | #Person2# arrives late because of traffic jam.... | The traffic jam at the Carrefour intersection ... |
| 4 | #Person2# decides to follow #Person1#'s sugges... | The traffic jam at the Carrefour intersection ... |
| 5 | #Person2# complains to #Person1# about the tra... | The traffic jam at the Carrefour intersection ... |
| 6 | #Person1# tells Kate that Masha and Hero get d... | Masha and Hero are getting divorced. |
| 7 | #Person1# tells Kate that Masha and Hero are g... | Masha and Hero are getting divorced. |
| 8 | #Person1# and Kate talk about the divorce betw... | Masha and Hero are getting divorced. |
| 9 | #Person1# and Brian are at the birthday party ... | #Person1#: Happy Birthday, Brian. #Person2#: I' |


In [13]:
original_model_results = rouge.compute(
    predictions=original_model_summaries, 
    references=human_baseline_summaries, 
    use_aggregator=True, 
    use_stemmer=True)

print(original_model_results)

{'rouge1': 0.3043025373670535, 'rouge2': 0.11228756210130228, 'rougeL': 0.2602913752913753, 'rougeLsum': 0.25733446959253414}
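To make these numbers concrete: ROUGE-1 compares unigram counts between a prediction and a reference and reports the F1 of unigram precision and recall; ROUGE-2 does the same for bigrams, and ROUGE-L uses the longest common subsequence. The sketch below computes ROUGE-1 F1 for a single pair. It approximates the default tokenization of the `rouge_score` package that `evaluate` wraps (lowercase, keep alphanumeric runs) but applies no stemming, so it will not exactly reproduce `use_stemmer=True`; the example sentences are made up.

```python
import re
from collections import Counter


def tokenize(text: str) -> list:
    # Approximate tokenizer: lowercase and keep runs of [a-z0-9]. No stemming.
    return re.findall(r"[a-z0-9]+", text.lower())


def rouge1_f1(prediction: str, reference: str) -> float:
    """ROUGE-1 F1: harmonic mean of unigram precision and recall."""
    pred, ref = Counter(tokenize(prediction)), Counter(tokenize(reference))
    overlap = sum((pred & ref).values())  # matches clipped to the smaller count
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


print(round(rouge1_f1("the cat sat on the mat", "the cat is on the mat"), 3))  # 0.833
```

The scores printed above aggregate such per-example F1 values over the 10 test dialogues.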


## 3. Perform Parameter Efficient Fine-Tuning (PEFT)

PEFT is an umbrella term for techniques such as **Low-Rank Adaptation (LoRA)** and prompt tuning (which is NOT THE SAME as prompt engineering!). At a very high level, LoRA lets you fine-tune a model with far fewer compute resources (in some cases a single GPU). After you fine-tune for a specific task, use case, or tenant with LoRA, the original LLM remains unchanged and a newly trained "LoRA adapter" is produced. This LoRA adapter is much, much smaller than the original LLM (MBs vs. GBs).

At inference time, however, the LoRA adapter has to be combined with its original LLM to serve the inference request. The benefit is that many LoRA adapters can reuse the same original LLM, which reduces overall memory requirements when serving multiple tasks and use cases.
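The adapter idea can be made concrete for a single linear layer. LoRA freezes the pretrained weight `W` and learns a low-rank update `B @ A`, scaled by `lora_alpha / r`; `B` starts at zero, so the adapted layer initially behaves exactly like the base layer. Combining the adapter with the LLM at inference can happen on the fly or by merging `B @ A` into `W` once. This is a NumPy sketch with random stand-in weights, not the PEFT library's implementation; the 768 x 768 shape is an assumption chosen to match one FLAN-T5-base attention projection.

```python
import numpy as np

rng = np.random.default_rng(0)

# r and lora_alpha mirror the LoraConfig used in this notebook.
d_in, d_out, r, lora_alpha = 768, 768, 32, 32
scaling = lora_alpha / r  # 1.0 for this configuration

W = rng.standard_normal((d_out, d_in))     # frozen pretrained weight (random stand-in)
A = rng.standard_normal((r, d_in)) * 0.01  # trainable down-projection
B = np.zeros((d_out, r))                   # trainable up-projection, zero-initialised


def lora_forward(x, W, A, B, scaling):
    """Frozen path plus scaled low-rank update: W x + scaling * B (A x)."""
    return W @ x + scaling * (B @ (A @ x))


x = rng.standard_normal(d_in)

# With B = 0 the adapter contributes nothing, so training starts from the base model.
assert np.allclose(lora_forward(x, W, A, B, scaling), W @ x)

# Pretend training has updated B, then merge the adapter into the base weight.
B = rng.standard_normal((d_out, r)) * 0.01
W_merged = W + scaling * (B @ A)
assert np.allclose(lora_forward(x, W, A, B, scaling), W_merged @ x)

print("adapter parameters:", A.size + B.size)  # 49152
print("full weight parameters:", W.size)       # 589824
```

Keeping `A` and `B` separate is what lets several adapters share one copy of `W`; merging trades that flexibility for a forward pass with no extra matrix multiplications.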

### 3.1 Set Up the PEFT/LoRA Model for Fine-Tuning

In [14]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32,  # rank of the low-rank update matrices
    lora_alpha=32,  # LoRA updates are scaled by lora_alpha / r
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

Add the LoRA adapter layers/parameters to the original LLM; only these will be trained.

In [15]:
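# get_peft_model injects the LoRA layers into original_model itself (in place),
# so original_model also carries the adapter layers from this point on.
peft_model = get_peft_model(original_model, lora_config)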

print(print_number_of_trainable_model_parameters(peft_model))

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters 1.4092820552029972%
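The trainable count printed above can be reproduced by hand. A worked check, assuming FLAN-T5-base's published configuration (hidden size 768, 12 encoder and 12 decoder layers, query/value projections of size 768 x 768) and that `target_modules=["q", "v"]` matches the query and value projections in every encoder self-attention, decoder self-attention, and decoder cross-attention block:

```python
# FLAN-T5-base dimensions (assumed from the model's published config).
d_model = 768
num_encoder_layers = 12
num_decoder_layers = 12
r = 32  # LoRA rank from lora_config

# target_modules=["q", "v"] matches:
#   encoder: self-attention q, v                   -> 2 per layer
#   decoder: self-attention q, v + cross-attn q, v -> 4 per layer
num_target_modules = 2 * num_encoder_layers + 4 * num_decoder_layers  # 72

# Each projection is d_model x d_model, so A is (r x d_model) and B is (d_model x r).
params_per_module = r * d_model + d_model * r  # 49,152

lora_params = num_target_modules * params_per_module
base_params = 247_577_856  # printed for the original model earlier

print(lora_params)                                                # 3538944
print(base_params + lora_params)                                  # 251116800
print(round(lora_params / (base_params + lora_params) * 100, 4))  # 1.4093
```

All three values agree with the output of `print_number_of_trainable_model_parameters(peft_model)`.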


### 3.2 Train PEFT Adapter

In [16]:
output_dir = "./dialogue-summary-training-peft"

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3, # Higher learning rate than full fine-tuning
    num_train_epochs=1,
    # weight_decay=0.01,
    logging_steps=1,
    # max_steps=5,
)

peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_dataset['train'],
    # eval_dataset=tokenized_dataset['validation'],
)


In [None]:
# peft_trainer.train()
peft_model_path = "./dialogue-summary-training-peft/lora"

# peft_trainer.model.save_pretrained(peft_model_path)
# tokenizer.save_pretrained(peft_model_path)


In [18]:
from peft import PeftModel, PeftConfig

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to('xpu')
tokenizer = AutoTokenizer.from_pretrained(model_name)

peft_model = PeftModel.from_pretrained(peft_model_base,
                                       peft_model_path,
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False).to("xpu")

print(print_number_of_trainable_model_parameters(peft_model))

trainable model parameters: 0
all model parameters: 251116800
percentage of trainable model parameters 0.0%


### 3.3 Evaluate the Model Qualitatively (Human Evaluation)

In [19]:
prompt = f"""
        Summarize the following conversation:
        {dataset['test'][index]['dialogue']}
        Summary: 
"""

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("xpu")
# Use one generation config for both models so the comparison is like-for-like.
generation_config = GenerationConfig(max_new_tokens=200, num_beams=1)

peft_base_outputs = tokenizer.decode(original_model.generate(input_ids=input_ids, generation_config=generation_config)[0], skip_special_tokens=True)
peft_model_outputs = tokenizer.decode(peft_model.generate(input_ids=input_ids, generation_config=generation_config)[0], skip_special_tokens=True)

print(peft_base_outputs)
print(peft_model_outputs)
print(dataset['test'][index]['summary'])

#Person2# wants to upgrade his hardware. #Person1# wants to upgrade his computer to a CD-ROM drive.
#Person2# wants to upgrade his hardware. #Person2# wants to add a painting program to his software. #Person2# wants to upgrade his hardware. #Person2# wants to add a CD-ROM drive.
#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.


In [21]:
human_baseline_summaries = []
original_model_summaries = []
peft_model_summaries = []

for i in range(10):
    human_baseline_summaries.append(dataset["test"][i]['summary'])
    prompt = f"""
        Summarize the following conversation:
        {dataset["test"][i]['dialogue']}
        Summary: 
"""
    # No max_new_tokens is passed, so both models use the default generation
    # length (20 tokens) and longer summaries are cut off.
    model_output = original_model.generate(tokenizer(prompt, return_tensors='pt')['input_ids'].to('xpu'))[0]
    original_model_summaries.append(tokenizer.decode(model_output, skip_special_tokens=True))

    peft_output = peft_model.generate(input_ids=tokenizer(prompt, return_tensors='pt')['input_ids'].to('xpu'))[0]
    peft_model_summaries.append(tokenizer.decode(peft_output, skip_special_tokens=True))

zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, peft_model_summaries))
df = pd.DataFrame(zipped_summaries, columns=['Human Baseline', 'Original Model', 'PEFT Model'])
df



|   | Human Baseline | Original Model | PEFT Model |
|---|---|---|---|
| 0 | Ms. Dawson helps #Person1# to write a memo to ... | @Person1# wants to take a dictation for me. | #Person2# wants to take a dictation for #Person2# |
| 1 | In order to prevent employees from wasting tim... | #Person2# wants #Person2 to take a dictation f... | #Person2# wants to take a dictation for #Person2# |
| 2 | Ms. Dawson takes a dictation for #Person1# abo... | @Person1# is a memo to all employees. | #Person2# wants to take a dictation for #Person2# |
| 3 | #Person2# arrives late because of traffic jam.... | #Person1# is stuck in traffic and a terrible t... | #Person1# got stuck in traffic and got stuck i... |
| 4 | #Person2# decides to follow #Person1#'s sugges... | You're finally here! | #Person1# got stuck in traffic and got stuck i... |
| 5 | #Person2# complains to #Person1# about the tra... | #Person2# is stuck in traffic and he wants to ... | #Person1# got stuck in traffic and got stuck i... |
| 6 | #Person1# tells Kate that Masha and Hero get d... | You never believe what happened when Masha and... | #Person1# is getting divorced. #Person2# is su... |
| 7 | #Person1# tells Kate that Masha and Hero are g... | #Person1# and Hero are getting divorced. #Pers... | #Person1# is getting divorced. #Person2# is su... |
| 8 | #Person1# and Kate talk about the divorce betw... | @Person2# wants to divorce Masha and Hero. Mas... | #Person1# is getting divorced. #Person2# is su... |
| 9 | #Person1# and Brian are at the birthday party ... | #Person1# is a great party. #Person1# is happy to | #Person1# wishes Brian's birthday. #Person2# is a |


In [23]:
original_model_results = rouge.compute(
    predictions=original_model_summaries, 
    references=human_baseline_summaries, 
    use_aggregator=True, 
    use_stemmer=True)

peft_model_results = rouge.compute(
    predictions=peft_model_summaries, 
    references=human_baseline_summaries, 
    use_aggregator=True, 
    use_stemmer=True)

print(original_model_results)
print(peft_model_results)

{'rouge1': 0.25116375590640294, 'rouge2': 0.05666980405578849, 'rougeL': 0.20850201321975514, 'rougeLsum': 0.2032868898618424}
{'rouge1': 0.2453671615740581, 'rouge2': 0.03356643356643356, 'rougeL': 0.21945933376967858, 'rougeLsum': 0.21751271716788956}
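Comparing the two result dictionaries metric by metric makes the trade-off easier to read: on these 10 test dialogues the PEFT model scores slightly lower on ROUGE-1 and ROUGE-2 and slightly higher on ROUGE-L and ROUGE-Lsum. A small sketch that computes the differences from the values printed above; `rouge_deltas` is a hypothetical helper, and with only 10 examples these differences of a few points are best read as indicative rather than conclusive.

```python
# ROUGE scores printed above, copied verbatim.
original = {'rouge1': 0.25116375590640294, 'rouge2': 0.05666980405578849,
            'rougeL': 0.20850201321975514, 'rougeLsum': 0.2032868898618424}
peft = {'rouge1': 0.2453671615740581, 'rouge2': 0.03356643356643356,
        'rougeL': 0.21945933376967858, 'rougeLsum': 0.21751271716788956}


def rouge_deltas(base: dict, candidate: dict) -> dict:
    """Absolute change (candidate - base) for every metric both dicts share."""
    return {k: candidate[k] - base[k] for k in base if k in candidate}


for metric, delta in rouge_deltas(original, peft).items():
    # Multiply by 100 to express the change in ROUGE points.
    print(f"{metric}: {delta * 100:+.2f} points")
```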
