# Training a T5 Model for Text Summarization

This notebook demonstrates how to train a T5 model for text summarization using PyTorch Lightning. We will use the `SummarizationDataLoader` class to load and process the dataset and the `Summarizer` class to define the model.

In [1]:
# Import necessary libraries and set up the environment
import sys
import os

# Add the src directory to the system path
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

In [2]:
# Import the DataLoader and Summarizer classes
from training import SummarizationDataLoader, Summarizer

# Import PyTorch and PyTorch Lightning libraries
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer

## Define Callbacks

We will define two callbacks:
- `ModelCheckpoint`: Saves the best model based on the validation loss.
- `EarlyStopping`: Stops training if the validation loss does not improve for a specified number of epochs.

In [3]:
# Define the ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='../checkpoints/',  # Path to save the models
    filename='best-checkpoint',
    save_top_k=1,
    mode='min'
)

# Define the EarlyStopping callback
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min'
)
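
To make the `patience=3` setting concrete, here is a minimal pure-Python sketch of patience-based early stopping. It is a simplified illustration, not Lightning's actual implementation, and both the `stopping_epoch` helper and the validation losses are made up for the example.

```python
def stopping_epoch(val_losses, patience=3):
    """Return the 1-based epoch at which patience-based early stopping halts.

    Hypothetical helper: training stops once the validation loss has failed
    to improve on its best value for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best loss: remember it and reset the counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    # Patience never ran out: all epochs were run
    return len(val_losses)


# Made-up validation losses, one per epoch
losses = [1.00, 0.80, 0.90, 0.85, 0.81, 0.70]
print(stopping_epoch(losses))  # 5
```

In this made-up run the best loss (0.80) occurs at epoch 2, and training halts at epoch 5 after three epochs without improvement, so the improvement at epoch 6 is never reached. With `save_top_k=1` and `mode='min'`, the checkpoint kept on disk would be the one from epoch 2.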

## Initialize DataLoader and Model

We will initialize the `SummarizationDataLoader` with a batch size of 8, then create the `Summarizer` model.

In [4]:
# Initialize the DataLoader with batch size 8
data_module = SummarizationDataLoader(batch_size=8)

# Initialize the Summarizer model
model = Summarizer()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## Train the Model

We will set up the `Trainer` with the defined callbacks and train the model for a maximum of 10 epochs.

In [5]:
# Initialize the Trainer with the callbacks
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stopping_callback],
    devices=1,  # One device either way: a single GPU if available, otherwise the CPU
    accelerator="gpu" if torch.cuda.is_available() else "cpu"
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\Martin\miniconda3\envs\text_summarization_with_T5ForConditionalGeneration\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
# Optional: trainer.fit() already puts the model in training mode
model.train()

# Train the model
trainer.fit(model, datamodule=data_module)

## Load the Best Model

After training, we will load the best model saved during training.

In [6]:
# Load the best model checkpoint
best_model_path = os.path.join(
    checkpoint_callback.dirpath,
    checkpoint_callback.filename + '.ckpt'
)

# Check that the checkpoint file actually exists
if os.path.exists(best_model_path):
    # Load the model from the checkpoint
    model = Summarizer.load_from_checkpoint(best_model_path)
    print(f"The best model has been loaded from: {best_model_path}")
else:
    print("No best model checkpoint found.")

The best model has been loaded from: D:\Estudios\Proyectos\text_summarization_with_T5ForConditionalGeneration\checkpoints\best-checkpoint.ckpt
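
The hard-coded path above assumes the checkpoint is named exactly `best-checkpoint.ckpt`. When a file with that name already exists in `dirpath`, Lightning's `ModelCheckpoint` can append a version suffix such as `-v1`, so a rerun may leave the newest checkpoint under a different name. The sketch below shows one way to pick the most recently written match. The `latest_checkpoint` helper is hypothetical and not part of this project or of Lightning.

```python
import glob
import os
import re


def latest_checkpoint(dirpath, stem="best-checkpoint"):
    """Return the newest '<stem>.ckpt' or '<stem>-vN.ckpt' in dirpath, or None.

    Hypothetical helper for illustration only.
    """
    # Match the plain name and versioned variants such as 'best-checkpoint-v2.ckpt'
    pattern = re.compile(rf"^{re.escape(stem)}(-v\d+)?\.ckpt$")
    candidates = [
        path
        for path in glob.glob(os.path.join(dirpath, "*.ckpt"))
        if pattern.match(os.path.basename(path))
    ]
    if not candidates:
        return None
    # Pick the file with the most recent modification time
    return max(candidates, key=os.path.getmtime)
```

A call such as `latest_checkpoint('../checkpoints/')` returns `None` when nothing matches, which gives the `else` branch a real chance to run. When training and loading happen in the same session, the callback's `best_model_path` attribute is another option.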


## Evaluate the Model

We will ensure the model is in evaluation mode and evaluate it with test data.

In [None]:
# Optional: trainer.test() already puts the model in evaluation mode
model.eval()

trainer.test(model, datamodule=data_module)

## Save the Model and Tokenizer

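We save the fine-tuned model and its tokenizer to the same directory so that both can be reloaded later with `from_pretrained`.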

In [7]:
model.model.save_pretrained("../fine_tuned/fine_tuned_t5")
model.tokenizer.save_pretrained("../fine_tuned/fine_tuned_t5")

('../fine_tuned/fine_tuned_t5\\tokenizer_config.json',
 '../fine_tuned/fine_tuned_t5\\special_tokens_map.json',
 '../fine_tuned/fine_tuned_t5\\spiece.model',
 '../fine_tuned/fine_tuned_t5\\added_tokens.json')

## Use the Model for Inference

In [8]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
model = T5ForConditionalGeneration.from_pretrained("../fine_tuned/fine_tuned_t5")
tokenizer = T5Tokenizer.from_pretrained("../fine_tuned/fine_tuned_t5")

In [9]:
def summary_text(text: str) -> str:
    # Prepend the T5 task prefix and tokenize, truncating long inputs
    inputs = tokenizer.encode(
        "summarize: " + text,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    )
    # Generate the summary with beam search
    summary_ids = model.generate(
        inputs,
        max_length=128,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )
    # Decode the first returned sequence back to text
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
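
To give some intuition for `length_penalty=2.0`, the sketch below shows how a beam-search score is commonly length-normalized: the summed token log-probabilities are divided by the sequence length raised to the penalty. This is a simplified illustration under that assumption. The `beam_score` helper and the numbers are made up, and the exact length used by a given `transformers` version may differ.

```python
def beam_score(sum_logprobs, length, length_penalty):
    """Length-normalized beam score (hypothetical helper for illustration)."""
    return sum_logprobs / (length ** length_penalty)


# Two made-up finished hypotheses: (summed log-probability, length in tokens)
short_hyp = (-4.0, 5)
long_hyp = (-9.0, 10)

for penalty in (0.0, 2.0):
    short_score = beam_score(*short_hyp, penalty)
    long_score = beam_score(*long_hyp, penalty)
    winner = "long" if long_score > short_score else "short"
    print(f"length_penalty={penalty}: the {winner} hypothesis wins")
```

Because log-probabilities are negative, dividing by a larger power of the length moves the score closer to zero, so a positive penalty such as 2.0 favors longer summaries, while a penalty of 0.0 compares the raw sums and favors the shorter one here.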

In [10]:
# Example input text for summarization
input_text = """SECTION 1. LIABILITY OF BUSINESS ENTITIES PROVIDING USE OF FACILITIES TO NONPROFIT ORGANIZATIONS.
(a) Definitions.--In this section:
    (1) Business entity.--The term "business entity" means a firm, corporation, association, partnership, consortium, joint venture, or other form of enterprise.
    (2) Facility.--The term "facility" means any real property, including any building, improvement, or appurtenance.
    (3) Gross negligence.--The term "gross negligence" means voluntary and conscious conduct by a person with knowledge (at the time of the conduct) that the conduct is likely to be harmful to the health or well-being of another person.
    (4) Intentional misconduct.--The term "intentional misconduct" means conduct by a person with knowledge (at the time of the conduct) that the conduct is harmful to the health or well-being of another person.
    (5) Nonprofit organization.--The term "nonprofit organization" means:
        (A) any organization described in section 501(c)(3) of the Internal Revenue Code of 1986 and exempt from tax under section 501(a) of such Code; or
        (B) any not-for-profit organization organized and conducted for public benefit and operated primarily for charitable, civic, educational, religious, welfare, or health purposes.
"""

print("Summary:", summary_text(input_text))

Summary: Business entity is a firm, corporation, association, partnership, partnership, consortium, joint venture, or other form of enterprise. Defines "business entity" as a firm, corporation, association, association, partnership, partnership, consortium, joint venture, or other form of enterprise.
