<a href="https://colab.research.google.com/github/ericakcc/PEFT-T5-for-Dialogue-Summarization/blob/main/Fine_Tune_T5_for_Dialogue_Summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install

In [1]:
!pip install --upgrade pip
!pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 -q

!pip install -q -U \
    transformers==4.27.2 \
    peft==0.3.0 \
    datasets==2.11.0 \
    evaluate==0.4.0 \
    rouge_score==0.1.2

Collecting pip
  Downloading pip-23.2.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-23.2.1
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m104.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.3/849.3 kB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.1/317.1 MB[0m [31m4.4 MB/s[0m eta [36m0:00:

In [16]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np

# Load Dataset and LLM

In [17]:
dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(dataset_name)
dataset



  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
})

In [18]:
model_name = 'google/flan-t5-base'

origin_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [19]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f'trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {trainable_model_params*100/all_model_params}%'

print(print_number_of_trainable_model_parameters(origin_model))

trainable model parameters: 247577856
all model parameters: 247577856
percentage of trainable model parameters: 100.0%


# Test the model with Zero Shot Inferencing

In [20]:
index = 777

dialogue = dataset['test'][index]['dialogue']
summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation

{dialogue}
Summary:
"""

inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    origin_model.generate(
        inputs['input_ids'],
        max_new_tokens=200,
    )[0],
    skip_special_tokens=True
)



dash_line = '-'.join('' for x in range(100))
print(dash_line)
print(f'INPUT PROMPT:\n {prompt}')
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n {summary}')
print(dash_line)
print(f'MODEL GENERATION - ZERO SHOT:\n {output}')

---------------------------------------------------------------------------------------------------
INPUT PROMPT:
 
Summarize the following conversation

#Person1#: Hey, How's it going?
#Person2#: Not good. I lost my wallet.
#Person1#: Oh, that's too bad. Was it stolen?
#Person2#: No, I think it came out of my pocket when I was in the taxi.
#Person1#: Is there anything I can do?
#Person2#: Can I borrow some money?
#Person1#: Sure, how much do you need?
#Person2#: About 50 dollars.
#Person1#: That's no problem.
#Person2#: Thanks. I'll pay you back on Friday.
#Person1#: That'll be fine. Here you are.
#Person2#: What are you going to do now?
#Person1#: I'm going to buy some books and then I'm going to the gas station.
#Person2#: If you wait a minute I can go with you.
#Person1#: OK. I'll wait for you.
Summary:

---------------------------------------------------------------------------------------------------
BASELINE HUMAN SUMMARY:
 #Person2# lost the wallet and borrows some money from #

In [21]:
def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example['dialogue']]
    example['input_ids'] = tokenizer(prompt, padding='max_length', truncation=True, return_tensors='pt').input_ids
    example['labels'] = tokenizer(example['summary'], padding='max_length', truncation=True, return_tensors='pt').input_ids

    return example


tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])



Map:   0%|          | 0/1500 [00:00<?, ? examples/s]



To save time, we subsample the dataset

In [22]:
tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 10 ==0, with_indices=True)



Filter:   0%|          | 0/1500 [00:00<?, ? examples/s]



Check the shapes of three parts of the dataset

In [23]:
print(f'Training: {tokenized_datasets["train"].shape}')
print(f'Validation: {tokenized_datasets["validation"].shape}')
print(f'Testing: {tokenized_datasets["test"].shape}')

print(tokenized_datasets)

Training: (1246, 2)
Validation: (50, 2)
Testing: (150, 2)
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1246
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 150
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 50
    })
})


# Fine-Tune the model with the Preprocessed Dataset

First we try to Fine-tune model with Instruct fine-tune

In [10]:
from transformers.training_args_tf import TFTrainingArguments
output_dir = f'./dialogue-summary-{str(int(time.time()))}'

training_args = TFTrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=1,
    max_steps=1
)

trainer = Trainer(
    model=origin_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation']
)



We found that our GPU is out of memory!

In [11]:
trainer.train()



OutOfMemoryError: ignored

# Perform Parameter Efficient Fine-Tuning (PEFT)

In [12]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32, #rank
    lora_alpha=32,
    target_modules=['q', 'v'],
    lora_dropout=0.05,
    bias='none',
    task_type=TaskType.SEQ_2_SEQ_LM # T5
)

In [13]:
peft_model = get_peft_model(origin_model, lora_config)

print(print_number_of_trainable_model_parameters(peft_model))

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.4092820552029972%


# Train PEFT Adapter

In [24]:
output_dir = f'./peft-dialogue-summary-{str(int(time.time()))}'

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,
    num_train_epochs=5,
    logging_steps=10,
    max_steps=100
)

peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation']
)

In [25]:
peft_trainer.train()

Step,Training Loss
10,34.2625
20,9.2906
30,4.1906
40,3.1953
50,1.8953
60,1.0832
70,0.7539
80,0.593
90,0.5363
100,0.5326


TrainOutput(global_step=100, training_loss=5.63333984375, metrics={'train_runtime': 230.4975, 'train_samples_per_second': 3.471, 'train_steps_per_second': 0.434, 'total_flos': 556503190732800.0, 'train_loss': 5.63333984375, 'epoch': 0.64})

In [35]:
peft_model_path = './peft-dialogue_summary-checkpoint'

peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

('./peft-dialogue_summary-checkpoint/tokenizer_config.json',
 './peft-dialogue_summary-checkpoint/special_tokens_map.json',
 './peft-dialogue_summary-checkpoint/tokenizer.json')

In [37]:
!ls -al './peft-dialogue_summary-checkpoint/adapter_model.bin'

-rw-r--r-- 1 root root 14208525 Oct  2 12:49 ./peft-dialogue_summary-checkpoint/adapter_model.bin


In [39]:
from peft import PeftModel, PeftConfig

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base', torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

peft_model = PeftModel.from_pretrained(peft_model_base,
                    peft_model_path,
                    torch_dtype=torch.bfloat16,
                    is_trainable=False)

In [40]:
print(print_number_of_trainable_model_parameters(peft_model))

trainable model parameters: 0
all model parameters: 251116800
percentage of trainable model parameters: 0.0%


# Evaluate the Model

In [43]:
index = 777
dialogue = dataset['test'][index]['dialogue']
summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation

{dialogue}
Summary:
"""

input_ids = tokenizer(prompt, return_tensors='pt').input_ids

original_model_outputs = origin_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_token=200, num_beams=1))
original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_token=200, num_beams=1))
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)


print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n {summary}')
print(dash_line)
print(f'ORIGINAL MODEL SUMMARY:\n {original_model_text_output}')
print(dash_line)
print(f'PEFT MODEL SUMMARY:\n {peft_model_text_output}')



---------------------------------------------------------------------------------------------------
BASELINE HUMAN SUMMARY:
 #Person2# lost the wallet and borrows some money from #Person1#.
---------------------------------------------------------------------------------------------------
ORIGINAL MODEL SUMMARY:
 Person1 lost his wallet.
---------------------------------------------------------------------------------------------------
PEFT MODEL SUMMARY:
 You're going to buy some books and then you're going to the gas station.


# Evaluate the Model Quantitative (with ROUGE Metric)

In [48]:
dialogues = dataset['test'][0:10]['dialogue']
human_baseline_summaries = dataset['test'][0:10]['summary']

origin_model_summaries = []
peft_model_summaries = []

for idx, dialogue in enumerate(dialogues):
    prompt = f"""
Summarize the following conversation.

{dialogue}

Summary: """

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    human_baseline_text_output = human_baseline_summaries[idx]

    original_model_outputs = origin_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

    peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_token=200))
    peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

    origin_model_summaries.append(original_model_text_output)
    peft_model_summaries.append(peft_model_text_output)

zipped_summaries = list(zip(human_baseline_summaries, origin_model_summaries, peft_model_summaries))
df = pd.DataFrame(zipped_summaries, columns=['human', 'original_model', 'peft_model'])
df

Unnamed: 0,human,original_model,peft_model
0,Ms. Dawson helps #Person1# to write a memo to ...,#Person1#: I need to take a dictation for you.,#Person1# needs to take a dictation for his of...
1,In order to prevent employees from wasting tim...,#Person1#: I need to take a dictation for you.,#Person1# needs to take a dictation for his of...
2,Ms. Dawson takes a dictation for #Person1# abo...,#Person1#: I need to take a dictation for you.,#Person1# needs to take a dictation for his of...
3,#Person2# arrives late because of traffic jam....,The traffic jam at the Carrefour intersection ...,@Person1#: I feel bad about how much my car is...
4,#Person2# decides to follow #Person1#'s sugges...,The traffic jam at the Carrefour intersection ...,@Person1#: I feel bad about how much my car is...
5,#Person2# complains to #Person1# about the tra...,The traffic jam at the Carrefour intersection ...,@Person1#: I feel bad about how much my car is...
6,#Person1# tells Kate that Masha and Hero get d...,Masha and Hero are getting divorced.,@Person1#: #Person1## is having a separation f...
7,#Person1# tells Kate that Masha and Hero are g...,Masha and Hero are getting divorced.,@Person1#: #Person1## is having a separation f...
8,#Person1# and Kate talk about the divorce betw...,Masha and Hero are getting divorced.,@Person1#: #Person1## is having a separation f...
9,#Person1# and Brian are at the birthday party ...,"#Person1#: Happy birthday, Brian. #Person2#: I...",@Person1#: This is a very nice party. #Person1#


Compute ROUGE score for this subset of the data

In [66]:
rouge = evaluate.load('rouge')

origin_model_results = rouge.compute(
    predictions=origin_model_summaries,
    references=human_baseline_summaries[0:len(origin_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

peft_model_results = rouge.compute(
    predictions=peft_model_summaries,
    references=human_baseline_summaries[0:len(peft_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

print(origin_model_results)
print(peft_model_results)

{'rouge1': 0.241950545026632, 'rouge2': 0.1179539641943734, 'rougeL': 0.22166387959866218, 'rougeLsum': 0.22283940294809862}
{'rouge1': 0.18229618024438354, 'rouge2': 0.017142857142857147, 'rougeL': 0.163633887906026, 'rougeLsum': 0.16308642214116895}
