# Fine-Tuning T5 for Summarisation
## Introduction
As stated, we perform abstractive summarisation by fine-tuning a T5 model. Because T5 has both its encoder and its decoder pre-trained, fine-tuning it on a summarisation dataset is a strong starting point for the task.

T5 has already been pre-trained on summarisation of the standard CNN/Daily Mail dataset, so this notebook mainly serves as a demonstration of how to fine-tune it for domain-specific summarisation when needed.

We fine-tune with Low-Rank Adaptation (LoRA), so only a small number of additional adapter parameters are trained to augment the frozen baseline model. We then compare the summarisation performance of the fine-tuned model against the standard, non-fine-tuned instance.

In [None]:
# Import section
from bert_score import score
from datasets import load_dataset, Dataset, DatasetDict
from peft import get_peft_model, LoraConfig, TaskType
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments

from typing import cast

from utils import preprocess_function, get_model_name, get_tokenizer, get_data_collator

## Section 1: Preparing the Dataset
The CNN/Daily Mail dataset of news articles and their highlights is hosted as a [Hugging Face dataset](https://huggingface.co/datasets/abisee/cnn_dailymail), so it can be downloaded with Hugging Face's `datasets` library.

In [None]:
dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
dataset = cast(DatasetDict, dataset)

### 1.1. Inspect the Dataset

In [None]:
sample = dataset['train'][0]

print("Article:")
print(sample['article'][:300])
print("")
print("Summary:")
print(sample['highlights'])

In [None]:
# Check the number of samples present
print(f"Number of training samples: {len(dataset['train'])}")
print(f"Number of validation samples: {len(dataset['validation'])}")
print(f"Number of test samples: {len(dataset['test'])}")

### 1.2. Tokenise the dataset for consumption by the model

In [None]:
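# Drop the raw text columns so only the model inputs produced by preprocess_function remain.
tokenized_dataset = dataset.map(preprocess_function, batched=True,
                                remove_columns=dataset["train"].column_names)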
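The contents of `utils.preprocess_function` are not shown here. A typical T5 preprocessing function prefixes each article with `"summarize: "` (as in Section 4), tokenises it with truncation, and tokenises the highlights as `labels`. The sketch below illustrates only the final padding step on already-tokenised ids, using a hypothetical helper `pad_and_mask`; it is an assumption about what such preprocessing involves, not the code in `utils`. The key detail is that padded label positions are set to -100, the index ignored by PyTorch's cross-entropy loss.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value ignored by PyTorch's CrossEntropyLoss


def pad_and_mask(input_ids: List[int], label_ids: List[int], pad_id: int,
                 max_input_len: int, max_label_len: int) -> Dict[str, List[int]]:
    """Hypothetical helper: pad one tokenised example and mask label padding.

    input_ids: ids of "summarize: " + article, already truncated by the tokenizer.
    label_ids: ids of the reference highlights, already truncated by the tokenizer.
    """
    if len(input_ids) > max_input_len or len(label_ids) > max_label_len:
        raise ValueError("sequences must already be truncated by the tokenizer")

    n_pad_inputs = max_input_len - len(input_ids)
    # Attention mask: 1 for real tokens, 0 for padding the model should ignore.
    attention_mask = [1] * len(input_ids) + [0] * n_pad_inputs
    padded_inputs = input_ids + [pad_id] * n_pad_inputs

    # Padded label positions become -100 so they do not contribute to the loss.
    labels = label_ids + [IGNORE_INDEX] * (max_label_len - len(label_ids))
    return {"input_ids": padded_inputs, "attention_mask": attention_mask, "labels": labels}


# Toy ids: T5 uses 0 as the pad id and 1 as the end-of-sequence id.
print(pad_and_mask([13, 7, 1], [5, 1], pad_id=0, max_input_len=5, max_label_len=4))
```

In practice this padding is often done per batch by a data collator such as `DataCollatorForSeq2Seq`, whose `label_pad_token_id` defaults to -100.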

## Section 2: Creating the model and the LoRA Config
The Hugging Face `peft` library makes it straightforward to fine-tune a `transformers` model with LoRA.

In [None]:
# T5ForConditionalGeneration is used instead of the bare T5Model, which only returns encoder/decoder hidden states.
# It adds the language-modelling head that produces vocabulary logits for generating the summary tokens.
model = T5ForConditionalGeneration.from_pretrained(get_model_name())

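We create the following LoRA config:
1. Use rank `r=8` for the low-rank update matrices, which keeps the number of trainable parameters small.
2. `lora_alpha=16` scales the LoRA update by `lora_alpha / r` (here 2), controlling how much the LoRA matrices contribute to the final output.
3. Target the query (`q`) and value (`v`) projections of the attention modules for adaptation; the original LoRA paper found that adapting these two projections works well for a fixed parameter budget.
4. Apply dropout of 0.05 on the LoRA path for regularisation.
5. Leave biases unadapted (`bias="none"`).
6. Set the task type to sequence-to-sequence language modelling, since the model generates a summary from an article.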
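To make points 1 and 2 concrete, here is a minimal NumPy sketch of what a LoRA-adapted linear layer computes: the frozen weight `W` plus a scaled low-rank update `B @ A`. The 512 × 512 size is an assumption matching a t5-small attention projection, the random values are purely illustrative, and LoRA dropout is omitted for brevity. Because `B` starts at zero, the adapted layer initially behaves exactly like the frozen one, so training starts from the pre-trained model's behaviour.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out = 512, 512  # assumed projection size (d_model of t5-small)
r, alpha = 8, 16        # rank and lora_alpha from the config in this section

W = rng.normal(size=(d_out, d_in))          # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable down-projection (random init)
B = np.zeros((d_out, r))                    # trainable up-projection, initialised to zero
scaling = alpha / r                         # update is scaled by lora_alpha / r


def lora_forward(x: np.ndarray) -> np.ndarray:
    """Frozen path plus the scaled low-rank update: W x + (alpha / r) * B A x."""
    return W @ x + scaling * (B @ (A @ x))


x = rng.normal(size=(d_in,))
# With B == 0 the adapted layer reproduces the frozen layer exactly.
assert np.allclose(lora_forward(x), W @ x)

# Trainable parameters for this one projection: r * (d_in + d_out).
print(A.size + B.size)  # 8192
```

Only `A` and `B` receive gradients; `W` never changes, which is why the trained adapter is small compared with the full model.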

In [None]:
# LoRA Config
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q", "v"],
                         lora_dropout=0.05, bias="none", task_type=TaskType.SEQ_2_SEQ_LM)

We then wrap the model with the LoRA adapters.

In [None]:
model = get_peft_model(model, lora_config)

# Print the number of trainable parameters versus the total, to show how parameter-efficient LoRA is.
model.print_trainable_parameters()
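The trainable count can also be worked out by hand. The sketch below assumes the base checkpoint is `t5-small` (suggested by the `t5-small` name used later in this notebook), whose config has `d_model=512`, `num_heads=8`, `d_kv=64` and 6 encoder plus 6 decoder blocks. Each adapted `Linear(in, out)` gains `r * (in + out)` parameters; encoder blocks contain self-attention only, while decoder blocks contain both self-attention and cross-attention. The helper `lora_trainable_params` is hypothetical and exists only for this calculation.

```python
def lora_trainable_params(d_model: int, inner_dim: int, r: int,
                          num_encoder_layers: int, num_decoder_layers: int) -> int:
    """Count LoRA parameters when adapting the "q" and "v" projections of a T5 model.

    Both q and v map d_model -> inner_dim (num_heads * d_kv), so each adapted
    projection adds A (r x d_model) and B (inner_dim x r).
    """
    per_projection = r * (d_model + inner_dim)
    # One attention module per encoder block; two (self + cross) per decoder block.
    attention_modules = num_encoder_layers + 2 * num_decoder_layers
    projections_per_module = 2  # "q" and "v"
    return attention_modules * projections_per_module * per_projection


# Assumed t5-small dimensions.
print(lora_trainable_params(d_model=512, inner_dim=8 * 64, r=8,
                            num_encoder_layers=6, num_decoder_layers=6))  # 294912
```

If the base model is indeed `t5-small`, this should match the trainable count reported by `print_trainable_parameters()`, roughly 0.5% of the model's ~60M parameters.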

## Section 3: Creating a Trainer
Now that the dataset has been tokenised into inputs the model can consume, and the model has been wrapped with LoRA for fine-tuning, we create a `Trainer` instance to actually train the model.

We create an instance of `TrainingArguments` to supply to the `Trainer`, configured to:
1. Save the weights to the `./results` folder.
2. Evaluate on the validation set every 500 steps and log progress every 100 steps.
3. Use a per-device batch size of 16 for both training and evaluation.
4. Use a small learning rate of 1e-5, as this is a fine-tuning task.
5. Warm up over the first 200 steps: the learning rate ramps up from zero to the set value, which helps keep early training stable.
6. Save the weights every 1000 steps and retain only the 2 most recent checkpoints.
7. Use fp16 mixed precision for faster training (requires a CUDA GPU).
8. Write logs to the `./logs` folder and disable reporting to remote tracking services.
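Point 5 can be made concrete. `Trainer` uses a linear learning-rate schedule by default: a linear ramp from 0 to the peak rate over the warmup steps, then a linear decay back to 0 by the final step. The sketch below reproduces that shape with a hypothetical helper `linear_warmup_lr`. The total step count assumes the 287,113 training examples of CNN/Daily Mail 3.0.0, a single device and no gradient accumulation; with more devices, each step processes more examples and the total shrinks accordingly.

```python
import math


def linear_warmup_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at an optimiser step: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        # Ramp up from 0 towards peak_lr during warmup.
        return peak_lr * step / max(1, warmup_steps)
    # Decay linearly so the rate reaches 0 at total_steps.
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Assumed: 287,113 training examples, batch size 16 on one device, 5 epochs.
steps_per_epoch = math.ceil(287_113 / 16)
total_steps = steps_per_epoch * 5
print(total_steps)  # 89725

for step in (0, 100, 200, total_steps // 2, total_steps):
    print(step, linear_warmup_lr(step, 1e-5, 200, total_steps))
```

With 200 warmup steps out of roughly 90k, warmup covers only the very start of training; almost all steps fall in the decay phase.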

In [None]:
# Adjust these according to your hardware constraints and performance requirements.
TOTAL_EPOCHS = 5
TRAIN_BATCH = 16
EVAL_BATCH = 16

In [None]:
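import torch

# Create a TrainingArguments instance to give the trainer its configuration.
# eval_strategy="steps" is required for eval_steps to take effect (older transformers versions call it evaluation_strategy).
training_args = TrainingArguments(output_dir='./results', eval_strategy='steps', eval_steps=500, logging_steps=100,
                                  per_device_train_batch_size=TRAIN_BATCH, per_device_eval_batch_size=EVAL_BATCH,
                                  num_train_epochs=TOTAL_EPOCHS, learning_rate=1e-5,
                                  warmup_steps=200, save_steps=1000, save_total_limit=2,
                                  fp16=torch.cuda.is_available(),  # fp16 mixed precision needs a CUDA GPU
                                  logging_dir='./logs', report_to='none')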

### 3.1. Create a Trainer instance for the training loop

In [None]:
# Train and evaluate on the tokenised splits, not the raw text dataset.
trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset['train'],
                  eval_dataset=tokenized_dataset['validation'], data_collator=get_data_collator())

### 3.2. Run the training loop

In [None]:
trainer.train()

### 3.3. Save the LoRA adapter weights

In [None]:
# For a PEFT model this saves only the LoRA adapter weights and config, not the full base model.
model.save_pretrained("t5-small-lora-ft")

## Section 4: Generate a summary from a real article

In [None]:
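# Switch the fine-tuned model to eval mode, which disables dropout (including LoRA dropout).
# generate() already runs without gradient tracking, so no torch.no_grad() is needed here.
model.eval()
tokenizer = get_tokenizer()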

In [None]:
# Placeholder: replace with the text of a real news article.
real_article = "SHORT PARAGRAPH HERE"

In [None]:
article_text = "summarize: " + real_article
# Move the input tensors to the model's device (e.g. the GPU used for training).
inputs = tokenizer(article_text, return_tensors="pt",
                   truncation=True, max_length=512).to(model.device)

# Generate summary
outputs = model.generate(**inputs, max_length=128, num_beams=4)
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the generated summary
print(summary)

## Section 5: Compare with Baseline on summarisation performance
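BERTScore is used because it compares candidate and reference summaries through the similarity of their contextual token embeddings, capturing semantic meaning rather than literal n-gram overlap (as ROUGE does). This makes it better suited to measuring abstractive summarisation, where a good summary may paraphrase the reference instead of copying its words.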
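The core of BERTScore is greedy matching: every candidate token is paired with its most similar reference token (precision), every reference token with its most similar candidate token (recall), and the two are combined into F1. The NumPy sketch below shows only that matching step on toy 2-D vectors. The real `bert_score` package obtains contextual embeddings from a pre-trained model (`roberta-large` by default for `lang="en"`) and optionally applies idf weighting and baseline rescaling, all omitted here; `bertscore_from_embeddings` is a hypothetical helper name.

```python
from typing import Tuple

import numpy as np


def bertscore_from_embeddings(cand: np.ndarray, ref: np.ndarray) -> Tuple[float, float, float]:
    """Greedy-matching precision, recall and F1 between two sets of token embeddings.

    cand: (num_candidate_tokens, dim); ref: (num_reference_tokens, dim).
    """
    # Normalise rows so dot products become cosine similarities.
    cand = cand / np.linalg.norm(cand, axis=1, keepdims=True)
    ref = ref / np.linalg.norm(ref, axis=1, keepdims=True)
    sim = cand @ ref.T  # sim[i, j] = cos(candidate token i, reference token j)

    precision = float(sim.max(axis=1).mean())  # best reference match for each candidate token
    recall = float(sim.max(axis=0).mean())     # best candidate match for each reference token
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Toy example: the candidate has one token matching the reference plus one unrelated token.
candidate = np.array([[1.0, 0.0], [0.0, 1.0]])
reference = np.array([[1.0, 0.0]])
print(bertscore_from_embeddings(candidate, reference))  # (0.5, 1.0, 0.6666666666666666)
```

The extra, unmatched candidate token lowers precision but not recall, which is why F1 is the figure reported in the comparison below.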

In [None]:
# Base model instance to compare performance against: the same pre-trained checkpoint, without LoRA.
base_model = T5ForConditionalGeneration.from_pretrained(get_model_name())

test_set = dataset['test']
test_set = cast(Dataset, test_set)
baseline_summaries = []
finetuned_summaries = []

In [None]:
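# Generate the summaries from both model instances.
# This loops over the full test split one article at a time, which is slow; for a quick check,
# first reassign test_set = test_set.select(range(100)) so the references below match.
base_model.to(model.device)  # keep both models on the same device as the inputs
for item in test_set:
    input_text = "summarize: " + item["article"]
    inputs = tokenizer(input_text, return_tensors="pt",
                       truncation=True, max_length=512).to(model.device)
    output1 = base_model.generate(**inputs, max_length=128)
    output2 = model.generate(**inputs, max_length=128)
    summary1 = tokenizer.decode(output1[0], skip_special_tokens=True)
    summary2 = tokenizer.decode(output2[0], skip_special_tokens=True)
    baseline_summaries.append(summary1)
    finetuned_summaries.append(summary2)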

# Compute BERTScore; score() returns (precision, recall, F1) tensors with one entry per summary.
references = [item["highlights"] for item in test_set]
_, _, F1_base = score(baseline_summaries, references, lang="en")
_, _, F1_finetuned = score(finetuned_summaries, references, lang="en")

In [None]:
print(f"Base T5 BERTScore F1: {F1_base.mean().item():.4f}")
print(f"LoRA-Tuned T5 BERTScore F1: {F1_finetuned.mean().item():.4f}")