# Fine-Tune a Generative AI Model for Dialogue Summarization

## TODOs:
* CPU compatibility
* Add artifacts to GitHub

In this notebook, we will fine-tune an existing LLM from HuggingFace for improved dialogue summarization. We use the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model because it is a high-quality, instruction-tuned model available in several sizes. Flan-T5 can summarize text out of the box, but here we will see how fine-tuning on a high-quality dataset can improve its performance on a specific task. Specifically, we use the [DialogSum](https://huggingface.co/datasets/knkarthick/dialogsum) dataset from HuggingFace, which pairs dialogues with human-written summaries of them.

## Setup

First, let's install the libraries this notebook needs. After the installation, we will import the required packages.

In [2]:
%pip install transformers==4.27.2 --quiet
%pip install torch==1.13.1 --quiet
%pip install py7zr --quiet
%pip install datasets --quiet
%pip install sentencepiece --quiet
%pip install evaluate --quiet
%pip install rouge_score --quiet

Note: you may need to restart the kernel to use updated packages.


In [3]:
from transformers import AutoTokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, GenerationConfig
from datasets import load_dataset
import datasets
import torch
import time
import evaluate
import numpy as np
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Flan-T5 Model

We can load the pre-trained Flan-T5 model directly from HuggingFace. Note that we use the [base version](https://huggingface.co/google/flan-t5-base) of Flan-T5. This version has about 248 million parameters (247,577,856, as counted below), which is small compared to many other LLMs. For higher-quality results, consider the larger versions of this model.

In [4]:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

In [5]:
params = sum(p.numel() for p in model.parameters())
print(f'Total Number of Model Parameters: {params}')

Total Number of Model Parameters: 247577856


# Load Dataset

The DialogSum dataset can also be loaded directly from HuggingFace. It contains roughly 14.5k dialogues (12,460 train, 1,500 test, and 500 validation), each paired with a human-written summary.

In [6]:
dataset = load_dataset("knkarthick/dialogsum")

Found cached dataset csv (/root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
})

# Test the Model with Zero-Shot Prompts Before Tuning

The example below shows that, before fine-tuning, the model's summary falls well short of the human-written baseline summary provided in the dataset. The model does pick out some relevant information from the dialogue, though, which suggests it can be fine-tuned for this task.
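The next cell builds its prompt with an f-string: an instruction, the dialogue itself, and a trailing `Summary:` cue that marks where the model's answer should begin. As a minimal sketch, that template could be pulled into a small helper so every cell shares one definition; `build_summary_prompt` is a hypothetical name, not something the notebook defines. (As written, the notebook's prompts differ slightly: preprocessing adds a trailing space after `Summary:`, and the ROUGE evaluation cell adds a `Conversation:` label.)

```python
def build_summary_prompt(dialogue: str) -> str:
    """Wrap a dialogue in the zero-shot summarization prompt used in the next cell."""
    # Instruction first, then the dialogue, then a cue marking where the summary should start
    return f"Summarize the following conversation.\n\n{dialogue}\n\nSummary:"


example_dialogue = "#Person1#: What time is it, Tom?\n#Person2#: Just a minute. It's ten to nine by my watch."
print(build_summary_prompt(example_dialogue))
```

Passing `build_summary_prompt(diag)` to the tokenizer would reproduce the prompt string used in the next cell.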

In [8]:
ind = 40
diag = dataset['test'][ind]['dialogue']
summary = dataset['test'][ind]['summary']

prompt = f'Summarize the following conversation.\n\n{diag}\n\nSummary:'
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

original_outputs = model.to('cpu').generate(input_ids, GenerationConfig(max_new_tokens=200))
original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)

print(f'Prompt:\n--------------------------\n{prompt}\n--------------------------')
print(f'\nOriginal Response: {original_text_output}')
print(f'Baseline Summary : {summary}')

Prompt:
--------------------------
Summarize the following conversation.

#Person1#: What time is it, Tom?
#Person2#: Just a minute. It's ten to nine by my watch.
#Person1#: Is it? I had no idea it was so late. I must be off now.
#Person2#: What's the hurry?
#Person1#: I must catch the nine-thirty train.
#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.

Summary:
--------------------------

Original Response: The train is about to leave.
Baseline Summary : #Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.


# Preprocessing

To preprocess the dataset, we wrap each dialogue in an instruction prompt (an instruction before the dialogue and a `Summary:` cue after it), then tokenize both the prompts and the reference summaries with the Flan-T5 tokenizer. The resulting dataset is ready for fine-tuning in the next step.
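One detail worth knowing about the preprocessing cell below: it pads `labels` with the tokenizer's pad token (id 0 for T5), so padded positions also count toward the training loss. `T5ForConditionalGeneration` computes cross-entropy with `ignore_index=-100`, so a common refinement is to replace padded label ids with -100. The pure-Python sketch below shows that replacement on toy token ids; the helper `mask_label_padding` is hypothetical and the ids (other than pad 0 and T5's end-of-sequence id 1) are arbitrary.

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss in Hugging Face models


def mask_label_padding(labels: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Replace padding token ids in each label sequence with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if token_id == pad_token_id else token_id for token_id in sequence]
        for sequence in labels
    ]


# Toy batch: two label sequences padded to length 6 with pad id 0; id 1 is T5's </s>
toy_labels = [[363, 19, 8, 1, 0, 0], [27, 1, 0, 0, 0, 0]]
print(mask_label_padding(toy_labels, pad_token_id=0))
```

In `tokenize_function`, the same replacement could be applied to the tokenized summaries before they are assigned to `example['labels']`. The decoder inputs stay valid because T5 shifts the labels right and maps -100 back to the pad id internally.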

In [9]:
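def tokenize_function(example):
    # Wrap each dialogue in an instruction prompt, matching the zero-shot prompt above
    prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    inp = [prompt + i + end_prompt for i in example["dialogue"]]
    # Pad/truncate the prompts and the reference summaries to the tokenizer's maximum length
    example['input_ids'] = tokenizer(inp, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    return example

# batched=True passes a batch of rows (a dict of column lists) to tokenize_function at once
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Keep only the columns the model needs: input_ids and labels
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])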

Loading cached processed dataset at /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-89801bb295cfaf0c.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-194e8617b61a8663.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-0cba91fb973196c7.arrow


In [10]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 500
    })
})

# Fine-Tuning with the HuggingFace Trainer

Now that the dataset is preprocessed, we can use the built-in HuggingFace `Trainer` class to fine-tune the model for this task. Training the full model takes a few hours, so to save time we provide a checkpoint of a model trained for 10 epochs on the full (not subsampled) dataset. If you have time to train the model fully yourself, the inline comments show which lines to change. The provided checkpoint was trained on an `ml.g5.xlarge` GPU instance, which is a reasonable starting point if you want to train on a GPU.

In [11]:
# to save time in the lab, we subsample the dataset, keeping every 50th example
# if you want to train a model fully, feel free to alter this subsampling to create a larger dataset,
# or delete the line below to remove the subsampling entirely
tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 50 == 0, with_indices=True)

output_dir = f'./diag-summary-training-{str(int(time.time()))}'
training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    num_train_epochs=1,
    # num_train_epochs=10, # Use a higher number of epochs when you are not in the lab and have more time to experiment
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation']
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-8834472c9aab1390.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-c07c88916c0d0721.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-6d41e9a7b96e340e/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-9bf169f6d5603528.arrow


In [12]:
trainer.train()



| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | No log        | 23.907644       |


TrainOutput(global_step=63, training_loss=30.54184105282738, metrics={'train_runtime': 59.6115, 'train_samples_per_second': 4.194, 'train_steps_per_second': 1.057, 'total_flos': 171189338112000.0, 'train_loss': 30.54184105282738, 'epoch': 1.0})
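As a quick sanity check on the `global_step=63` reported above: keeping every 50th row (indices 0, 50, ..., 12,450) of the 12,460 training examples leaves 250 examples, and with `per_device_train_batch_size=4` one epoch takes ceil(250 / 4) = 63 steps. The sketch below reproduces that arithmetic, assuming a single device, no gradient accumulation, and that the last partial batch is kept. Because the `Trainer` logs training loss every `logging_steps` (500 by default), a 63-step run never reaches a logging step, which is likely why the table shows "No log".

```python
import math


def subsampled_size(n_rows: int, every: int) -> int:
    """Rows kept by a filter like `index % every == 0` over indices 0..n_rows-1."""
    return math.ceil(n_rows / every)


def steps_per_epoch(n_examples: int, batch_size: int) -> int:
    """Batches per epoch on one device, counting the final partial batch."""
    return math.ceil(n_examples / batch_size)


train_rows = subsampled_size(12460, 50)
val_rows = subsampled_size(500, 50)
print(train_rows, val_rows, steps_per_epoch(train_rows, batch_size=4))
```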

# Load the Trained Model and Original Model

Once the model has finished training, we load both the original model from HuggingFace and the fine-tuned model to make some qualitative and quantitative comparisons.

In [13]:
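# if you have trained your own model and want to compare it with ours, change the line of code
# below to point at your checkpoint directory

tuned_model = T5ForConditionalGeneration.from_pretrained("./flan-dialogue-summary-checkpoint")
#tuned_model = T5ForConditionalGeneration.from_pretrained(f"{output_dir}/<put-your-checkpoint-dir-here>")

# note: this loads flan-t5-small, while the model fine-tuned in this notebook starts from flan-t5-base;
# the original-model outputs generated in the next cells come from this smaller model.
# Load "google/flan-t5-base" instead for a like-for-like comparison.
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")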

# Qualitative Results

As with many GenAI applications, a qualitative check, where you ask yourself "is my model behaving the way it is supposed to?", is usually a good starting point. In the example below (the same one we started this notebook with), the fine-tuned model produces a reasonable summary of the dialogue, whereas the original model does not appear to understand what is being asked of it.

In [14]:
ind = 40
diag = dataset['test'][ind]['dialogue']
summary = dataset['test'][ind]['summary']

prompt = f'Summarize the following conversation.\n\n{diag}\n\nSummary:'
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

original_outputs = model.to('cpu').generate(
    input_ids,
    GenerationConfig(max_new_tokens=200, num_beams=1),
)
original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)

outputs = tuned_model.to('cpu').generate(
    input_ids,
    GenerationConfig(max_new_tokens=200, num_beams=1,),
)
text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f'Prompt:\n--------------------------\n{prompt}\n--------------------------')
print(f'\nOriginal Response: {original_text_output}')
print(f'Tuned Response   : {text_output}')
print(f'Baseline Summary : {summary}')

Prompt:
--------------------------
Summarize the following conversation.

#Person1#: What time is it, Tom?
#Person2#: Just a minute. It's ten to nine by my watch.
#Person1#: Is it? I had no idea it was so late. I must be off now.
#Person2#: What's the hurry?
#Person1#: I must catch the nine-thirty train.
#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.

Summary:
--------------------------

Original Response: How long is it?
Tuned Response   : Tom tells #Person1# it's ten to nine and #Person1# must catch the nine-thirty train. #Person1# has plenty of time to catch the train.
Baseline Summary : #Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.


# Quantitative Results with the ROUGE Metric

The [ROUGE metric](https://en.wikipedia.org/wiki/ROUGE_(metric)) quantifies how closely a model-generated summary matches a reference ("baseline") summary, usually written by a human, by measuring their overlap in n-grams (ROUGE-1 for unigrams, ROUGE-2 for bigrams) or their longest common subsequence (ROUGE-L). While not perfect, it gives an indication of the overall improvement in summarization quality that fine-tuning achieves.
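To make the metric concrete, here is a minimal pure-Python sketch of ROUGE-1 (clipped unigram overlap) and ROUGE-L (longest common subsequence) F1 scores. It is a simplification for illustration, not the `rouge_score` implementation behind `evaluate.load('rouge')`: it only lowercases and keeps alphanumeric runs, with no stemming, so its numbers will not exactly match the results below.

```python
import re
from collections import Counter
from typing import List


def tokenize(text: str) -> List[str]:
    # Lowercase and keep runs of letters/digits (simplified; no stemming)
    return re.findall(r"[a-z0-9]+", text.lower())


def f1(overlap: int, pred_len: int, ref_len: int) -> float:
    # Harmonic mean of precision (overlap / prediction length) and recall (overlap / reference length)
    if overlap == 0:
        return 0.0
    precision = overlap / pred_len
    recall = overlap / ref_len
    return 2 * precision * recall / (precision + recall)


def rouge1_f1(prediction: str, reference: str) -> float:
    pred, ref = tokenize(prediction), tokenize(reference)
    # Clipped matches: a token counts at most as often as it appears in both texts
    overlap = sum((Counter(pred) & Counter(ref)).values())
    return f1(overlap, len(pred), len(ref))


def lcs_length(a: List[str], b: List[str]) -> int:
    # Dynamic-programming longest common subsequence, keeping one row at a time
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rougeL_f1(prediction: str, reference: str) -> float:
    pred, ref = tokenize(prediction), tokenize(reference)
    return f1(lcs_length(pred, ref), len(pred), len(ref))


reference = "#Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time."
print(round(rouge1_f1("The train is about to leave.", reference), 3))
print(round(rougeL_f1("The train is about to leave.", reference), 3))
```

ROUGE-1 ignores word order entirely, while ROUGE-L rewards predictions that keep the reference's words in the same relative order, which is why libraries usually report both.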

In [15]:
rouge = evaluate.load('rouge')

## Evaluate a Subset of Summaries

In [16]:
# again, to save time, we generate only 10 summaries with each model
# outside of the lab, a good exercise is to increase the number of test summaries generated
dialogues = dataset['test'][0:10]['dialogue']
human_baseline_summaries = dataset['test'][0:10]['summary']

original_model_summaries = []
tuned_model_summaries = []

for diag in dialogues:
    # note: unlike the training prompt, this prompt also includes a "Conversation:" label
    prompt = f'Summarize the following conversation.\n\nConversation:\n{diag}\n\nSummary:'
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    original_outputs = model.generate(input_ids, GenerationConfig(max_new_tokens=200))
    original_text_output = tokenizer.decode(original_outputs[0], skip_special_tokens=True)

    outputs = tuned_model.generate(input_ids, GenerationConfig(max_new_tokens=200))
    text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    original_model_summaries.append(original_text_output)
    tuned_model_summaries.append(text_output)

In [17]:
original_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

In [18]:
tuned_results = rouge.compute(
    predictions=tuned_model_summaries,
    references=human_baseline_summaries[0:len(tuned_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

In [19]:
original_results

{'rouge1': 0.0900332093311633,
 'rouge2': 0.02473178994918125,
 'rougeL': 0.07781529489266062,
 'rougeLsum': 0.07762444780283655}

In [20]:
tuned_results

{'rouge1': 0.4135778544290921,
 'rouge2': 0.17256253260629129,
 'rougeL': 0.3086163937375252,
 'rougeLsum': 0.30929219180845624}

## Evaluate the Full Dataset

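The file `diag-summary-training-results.csv` contains pre-generated summaries from the original model and the fine-tuned model, alongside the human baseline summaries, for a larger portion of the data. We can use it to evaluate both models on many more examples without regenerating the summaries. The results show a substantial improvement in every ROUGE metric!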
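The improvement cell at the end of this section reports absolute differences, i.e. percentage points. Dividing that difference by the original score gives a relative improvement, another common way to read the same numbers. The sketch below computes both from dictionaries shaped like the `rouge.compute` outputs here; `score_improvements` is a hypothetical helper, and the `rouge1` values are copied from the full-dataset results printed further down.

```python
from typing import Dict


def score_improvements(original: Dict[str, float], tuned: Dict[str, float]) -> Dict[str, Dict[str, float]]:
    """Absolute (percentage-point) and relative (%) change for each metric present in both dicts."""
    report = {}
    for metric in [m for m in original if m in tuned]:  # keep the original key order
        absolute = tuned[metric] - original[metric]
        report[metric] = {
            "absolute_points": 100 * absolute,
            # relative change is undefined when the original score is zero
            "relative_percent": 100 * absolute / original[metric] if original[metric] else float("nan"),
        }
    return report


# rouge1 scores from the full-dataset evaluation in this notebook
original = {"rouge1": 0.2334158581572823}
tuned = {"rouge1": 0.42161291557556113}
for metric, change in score_improvements(original, tuned).items():
    print(f"{metric}: +{change['absolute_points']:.2f} points, +{change['relative_percent']:.1f}% relative")
```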

In [21]:
import pandas as pd
results = pd.read_csv("diag-summary-training-results.csv")
original_model_summaries = results['original_model_summaries'].values
tuned_model_summaries = results['tuned_model_summaries'].values
human_baseline_summaries = results['human_baseline_summaries'].values

In [22]:
original_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

In [23]:
tuned_results = rouge.compute(
    predictions=tuned_model_summaries,
    references=human_baseline_summaries[0:len(tuned_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

In [24]:
original_results

{'rouge1': 0.2334158581572823,
 'rouge2': 0.07603964187010573,
 'rougeL': 0.20145520923859048,
 'rougeLsum': 0.20145899339006135}

In [25]:
tuned_results

{'rouge1': 0.42161291557556113,
 'rouge2': 0.18035380596301792,
 'rougeL': 0.3384439349963909,
 'rougeLsum': 0.33835653595561666}

In [26]:
improvement = (np.array(list(tuned_results.values())) - np.array(list(original_results.values())))
for key, value in zip(tuned_results.keys(), improvement):
    print(f'{key} absolute percentage difference after tuning: {value*100:.2f}%')

rouge1 absolute percentage difference after tuning: 18.82%
rouge2 absolute percentage difference after tuning: 10.43%
rougeL absolute percentage difference after tuning: 13.70%
rougeLsum absolute percentage difference after tuning: 13.69%


# Release Resources

In [None]:
%%html

<p><b>Shutting down your kernel for this notebook to release resources.</b></p>
<button class="sm-command-button" data-commandlinker-command="kernelmenu:shutdown" style="display:none;">Shutdown Kernel</button>
        
<script>
try {
    els = document.getElementsByClassName("sm-command-button");
    els[0].click();
}
catch(err) {
    // NoOp
}    
</script>