# Training a T5 Model for Text Summarization

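This notebook demonstrates how to fine-tune a T5 model for text summarization with PyTorch Lightning. The workflow covers loading the dataset, defining the model, training with checkpointing and early stopping, evaluating on the test set, saving the fine-tuned model and tokenizer, and running inference.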

In [1]:
# Import necessary libraries and set up the environment
import sys
import os

# Add ../src (relative to the notebook's working directory) to the import path so the local `training` module can be imported
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

In [2]:
# Import the data module and model classes from the local training module
from training import SummarizationDataLoader, Summarizer

# Import PyTorch and PyTorch Lightning libraries
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer
from transformers import T5ForConditionalGeneration, T5Tokenizer

## Define Callbacks

Two callbacks are defined:
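- **`ModelCheckpoint`**: Keeps only the checkpoint with the lowest validation loss (`save_top_k=1`, `mode='min'`).
- **`EarlyStopping`**: Stops training once validation loss has not improved for 3 consecutive validation checks (`patience=3`; by default there is one check per epoch).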

In [3]:
# Define the ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='../checkpoints/',  # Path to save the models
    filename='best-checkpoint',
    save_top_k=1,
    mode='min'
)

# Define the EarlyStopping callback
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min'
)
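With these settings the two callbacks follow simple rules: the kept checkpoint is replaced only when `val_loss` strictly improves, and training stops once `patience` consecutive validation checks pass without improvement, or when `max_epochs` is reached. The plain-Python sketch below replays those rules over a hypothetical list of per-epoch validation losses. It assumes one validation check per epoch and ignores options such as `min_delta`; `simulate_callbacks` is an illustrative helper, not part of PyTorch Lightning.

```python
from typing import List, Optional, Tuple


def simulate_callbacks(
    val_losses: List[float], patience: int = 3, max_epochs: int = 10
) -> Tuple[int, int]:
    """Replay simplified ModelCheckpoint(save_top_k=1, mode='min') and EarlyStopping rules.

    Returns (best_epoch, epochs_run), both 1-based. Assumes one validation
    check per epoch and min_delta=0 (a hypothetical, simplified model).
    """
    best_loss: Optional[float] = None
    best_epoch = 0
    wait = 0  # consecutive epochs without a strict improvement
    epochs_run = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        epochs_run = epoch
        if best_loss is None or loss < best_loss:
            # New best: the single kept checkpoint would be overwritten here
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                break  # EarlyStopping would halt training after this epoch
    return best_epoch, epochs_run


# Hypothetical validation losses for illustration only
losses = [2.1, 1.8, 1.7, 1.75, 1.72, 1.71, 1.6]
print(simulate_callbacks(losses))  # (3, 6)
```

In this made-up run the improvement at epoch 7 is never reached: three non-improving epochs after epoch 3 stop training at epoch 6, and the epoch-3 checkpoint is the one kept.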

## Initialize DataLoader and Model

We set up the `SummarizationDataLoader` for data processing and instantiate the T5 model using the `Summarizer` class.

In [None]:
# Initialize the data module with a batch size of 8
data_module = SummarizationDataLoader(batch_size=8)

# Initialize the Summarizer model
model = Summarizer()
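The internals of `SummarizationDataLoader` are not shown in this notebook. As a rough illustration of what a summarization data module with `batch_size=8` does before tokenization, the sketch below prepends the T5 task prefix `"summarize: "` (the same prefix `summary_text` uses at inference, assuming training inputs are prefixed the same way) and groups examples into batches of 8. The `make_batches` helper and the `text`/`summary` field names are hypothetical; a real data module would also tokenize and pad each batch.

```python
from typing import Dict, Iterator, List

TASK_PREFIX = "summarize: "  # same prefix summary_text() prepends at inference


def make_batches(
    examples: List[Dict[str, str]], batch_size: int = 8
) -> Iterator[Dict[str, List[str]]]:
    """Yield batches of prefixed source texts and their reference summaries.

    Each example is assumed to look like {"text": ..., "summary": ...}; these
    field names are hypothetical, not taken from SummarizationDataLoader.
    """
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        yield {
            "sources": [TASK_PREFIX + ex["text"] for ex in chunk],
            "targets": [ex["summary"] for ex in chunk],
        }


# Made-up examples for illustration only
data = [{"text": f"bill {i}", "summary": f"sum {i}"} for i in range(20)]
batches = list(make_batches(data))
print([len(b["sources"]) for b in batches])  # [8, 8, 4]
print(batches[0]["sources"][0])  # summarize: bill 0
```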

## Train the Model

The `Trainer` is configured with the callbacks defined above and trains the model for at most 10 epochs; early stopping may end training sooner.

In [None]:
# Initialize the Trainer with the callbacks; use one GPU if available, otherwise the CPU
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stopping_callback],
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1
)
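Because early stopping can end training sooner, `max_epochs=10` is only a ceiling. For N training examples, and assuming one optimizer step per batch (no gradient accumulation) with the last, smaller batch of each epoch kept, the ceiling in optimizer steps is `max_epochs * ceil(N / batch_size)`, which is useful when sizing a learning-rate schedule. `max_optimizer_steps` is a hypothetical helper, and N = 1,000 is a made-up dataset size.

```python
import math


def max_optimizer_steps(num_train_examples: int, batch_size: int = 8, max_epochs: int = 10) -> int:
    """Upper bound on optimizer steps if training runs for all max_epochs.

    Assumes one optimizer step per batch (no gradient accumulation) and that the
    final, smaller batch of each epoch is kept (drop_last=False).
    """
    batches_per_epoch = math.ceil(num_train_examples / batch_size)
    return max_epochs * batches_per_epoch


print(max_optimizer_steps(1000))  # 1250
```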

In [None]:
# Train the model; Trainer.fit() switches the model between train and eval mode itself,
# so no explicit model.train() call is needed
trainer.fit(model, datamodule=data_module)

## Load the Best Model

After training, the checkpoint with the lowest validation loss is loaded for evaluation.

In [None]:
# Use the path the callback recorded during fit() rather than rebuilding it
# from dirpath and filename, which can miss the file actually written
best_model_path = checkpoint_callback.best_model_path

# best_model_path is an empty string if no checkpoint was saved
if best_model_path:
    # Load the model from the checkpoint
    model = Summarizer.load_from_checkpoint(best_model_path)
    print(f"The best model has been loaded from: {best_model_path}")
else:
    print("No best model checkpoint found.")
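Re-running the notebook while `best-checkpoint.ckpt` already exists makes `ModelCheckpoint` write a versioned file such as `best-checkpoint-v1.ckpt`, so a hard-coded filename can point at a stale checkpoint. Within one session `checkpoint_callback.best_model_path` records the right file, but after a kernel restart that object is gone. One possible fallback, sketched below under the assumption that the newest matching file is the one wanted, is to pick the most recently modified `best-checkpoint*.ckpt` in the checkpoint directory. `find_latest_checkpoint` is a hypothetical helper.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(dirpath: str, stem: str = "best-checkpoint") -> Optional[str]:
    """Return the most recently modified '<stem>*.ckpt' file in dirpath, or None.

    Heuristic: assumes the newest matching file is the one to load.
    """
    d = Path(dirpath)
    if not d.is_dir():
        return None  # checkpoint directory does not exist
    candidates = list(d.glob(f"{stem}*.ckpt"))
    if not candidates:
        return None  # no matching checkpoint files
    return str(max(candidates, key=lambda p: p.stat().st_mtime))
```

For example, `find_latest_checkpoint("../checkpoints/")` returns a path that can be passed to `Summarizer.load_from_checkpoint` after checking it is not `None`.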

## Evaluate the Model

The best model is evaluated on the test split provided by the data module.

In [None]:
# Evaluate on the test set; Trainer.test() switches the model to eval mode itself
trainer.test(model, datamodule=data_module)
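`trainer.test` returns a list with one dictionary of logged metrics per test dataloader. If `Summarizer` logs its mean token-level cross-entropy under a key such as `test_loss` (an assumption; the actual key depends on its test step), the value can be converted to perplexity with `exp(loss)`, roughly the effective number of equally likely choices per target token. `perplexity_from_test_results` is a hypothetical helper.

```python
import math
from typing import Dict, List


def perplexity_from_test_results(results: List[Dict[str, float]], key: str = "test_loss") -> float:
    """Convert a logged mean cross-entropy (natural log) into perplexity.

    Reads `key` from the first dataloader's metrics; the key name is an
    assumption and must match what Summarizer actually logs.
    """
    loss = results[0][key]
    return math.exp(loss)


# Made-up metrics in the shape trainer.test() returns
print(round(perplexity_from_test_results([{"test_loss": 1.5}]), 2))  # 4.48
```

With the notebook's objects this would be `perplexity_from_test_results(trainer.test(model, datamodule=data_module))`, provided the key matches.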


## Save the Model and Tokenizer

The fine-tuned model and tokenizer are saved for later use.

In [None]:
# Save the model and tokenizer
model.model.save_pretrained("../fine_tuned/fine_tuned_t5")
model.tokenizer.save_pretrained("../fine_tuned/fine_tuned_t5")
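`save_pretrained` writes a `config.json` plus a weights file into the target directory (`model.safetensors` in recent `transformers` releases, `pytorch_model.bin` in older ones). Checking that both are present before running inference catches a wrong path early. The sketch below is a heuristic that assumes an unsharded checkpoint and does not inspect the tokenizer files; `looks_like_saved_model` is a hypothetical helper.

```python
from pathlib import Path

# Weight file names written by save_pretrained (newer vs. older transformers defaults)
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def looks_like_saved_model(dirpath: str) -> bool:
    """Heuristic: True if dirpath has config.json and an unsharded weights file."""
    d = Path(dirpath)
    has_config = (d / "config.json").is_file()
    has_weights = any((d / name).is_file() for name in WEIGHT_FILES)
    return has_config and has_weights
```

For example, `looks_like_saved_model("../fine_tuned/fine_tuned_t5")` should be `True` after the cell above has run.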

## Inference with the Fine-Tuned Model

The saved model and tokenizer are loaded to perform text summarization.

In [8]:
# Load the saved model and tokenizer for inference
model = T5ForConditionalGeneration.from_pretrained("../fine_tuned/fine_tuned_t5")
tokenizer = T5Tokenizer.from_pretrained("../fine_tuned/fine_tuned_t5")

In [9]:
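# Define a function for summarization
def summary_text(text: str) -> str:
    # Prepend the T5 task prefix and truncate inputs longer than 1024 tokens
    inputs = tokenizer.encode(
        "summarize: " + text,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    )
    # Beam search with 4 beams; generated length is kept between about 40 and 128 tokens
    summary_ids = model.generate(
        inputs,
        max_length=128,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )
    # Decode the highest-scoring beam back to text
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary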
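The `length_penalty` argument changes how beam search ranks finished hypotheses. In Hugging Face `transformers`, a finished beam's score is its summed token log-probability divided by `length ** length_penalty`, where `length` is roughly the number of generated tokens. Because log-probabilities are negative, a positive penalty such as `2.0` pulls long hypotheses' scores toward zero and so favors longer summaries. The sketch below reproduces only that scoring rule with made-up log-probabilities; `beam_score` is an illustrative helper, not a `transformers` function.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float = 2.0) -> float:
    """Length-normalized score used to rank finished beam hypotheses."""
    return sum_logprobs / (length ** length_penalty)


# Made-up hypotheses: a 10-token summary and a 20-token summary
short = beam_score(-6.0, 10)  # -6 / 100 = -0.06
long = beam_score(-9.0, 20)   # -9 / 400 = -0.0225
print(long > short)  # True: the longer summary ranks higher despite a lower total log-prob
```

With `length_penalty=0.0` the denominator is 1, the raw log-probabilities decide, and the shorter hypothesis would win instead.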

In [None]:
# Example input text for summarization
input_text = """SECTION 1. LIABILITY OF BUSINESS ENTITIES PROVIDING USE OF FACILITIES TO NONPROFIT ORGANIZATIONS.
(a) Definitions.--In this section:
    (1) Business entity.--The term "business entity" means a firm, corporation, association, partnership, consortium, joint venture, or other form of enterprise.
    (2) Facility.--The term "facility" means any real property, including any building, improvement, or appurtenance.
    (3) Gross negligence.--The term "gross negligence" means voluntary and conscious conduct by a person with knowledge (at the time of the conduct) that the conduct is likely to be harmful to the health or well-being of another person.
    (4) Intentional misconduct.--The term "intentional misconduct" means conduct by a person with knowledge (at the time of the conduct) that the conduct is harmful to the health or well-being of another person.
    (5) Nonprofit organization.--The term "nonprofit organization" means:
        (A) any organization described in section 501(c)(3) of the Internal Revenue Code of 1986 and exempt from tax under section 501(a) of such Code; or
        (B) any not-for-profit organization organized and conducted for public benefit and operated primarily for charitable, civic, educational, religious, welfare, or health purposes.
"""

print("Summary:", summary_text(input_text))