# Training T5 Model for Text Summarization

This notebook demonstrates how to train a T5 model for text summarization using PyTorch Lightning. We will use the `SummarizationDataLoader` class to load and process the dataset and the `Summarizer` class to define the model.

In [1]:
# Import necessary libraries and set up the environment
import sys
import os

# Add the src directory to the system path
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

In [2]:
# Import the DataLoader and Summarizer classes
from training import SummarizationDataLoader, Summarizer

# Import PyTorch and PyTorch Lightning libraries
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer

## Define Callbacks

We will define two callbacks:
- `ModelCheckpoint`: Keeps only the checkpoint with the lowest validation loss (`val_loss`).
- `EarlyStopping`: Stops training once `val_loss` has failed to improve for `patience` consecutive validation checks (3 here).

In [3]:
# Define the ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='../checkpoints/',  # Path to save the models
    filename='best-checkpoint',
    save_top_k=1,
    mode='min'
)

# Define the EarlyStopping callback
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min'
)
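
To make the interaction of the two callbacks concrete, here is a minimal pure-Python sketch of the bookkeeping they perform. It is a simplified illustration, not PyTorch Lightning's implementation: the function name `simulate_callbacks` and the loss values are hypothetical, and it assumes one validation check per epoch with no minimum improvement threshold.

```python
def simulate_callbacks(val_losses, patience=3):
    """Return (best_epoch, stop_epoch) for a list of per-epoch validation losses.

    best_epoch: epoch whose checkpoint would be kept (lowest loss seen so far).
    stop_epoch: epoch at which early stopping would trigger, or None.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # consecutive epochs without improvement

    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: the checkpoint would be overwritten and the counter reset
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Hypothetical validation losses, one per epoch
losses = [2.1, 1.7, 1.5, 1.6, 1.55, 1.52, 1.4]
print(simulate_callbacks(losses))  # (2, 5)
```

In this made-up run the best checkpoint comes from epoch 2, and training stops at epoch 5, so the lower loss at epoch 6 is never reached. A larger `patience` trades longer training for a lower chance of stopping too early.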

## Initialize DataLoader and Model

We will initialize the `SummarizationDataLoader` with a batch size of 8 and the `Summarizer` model.

In [None]:
# Initialize the DataLoader with batch size 8
data_module = SummarizationDataLoader(batch_size=8)

# Initialize the Summarizer model
model = Summarizer()

## Train the Model

We will set up the `Trainer` with the defined callbacks and train the model for a maximum of 10 epochs.

In [None]:
# Initialize the Trainer with the callbacks
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stopping_callback],
    devices=1,  # one device, whether GPU or CPU
    accelerator="gpu" if torch.cuda.is_available() else "cpu"
)

# Train the model (Trainer.fit switches the model to training mode itself,
# so an explicit model.train() call is not needed)
trainer.fit(model, datamodule=data_module)

## Load the Best Model

After training, we will load the best model saved during training.

In [None]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback
best_model_path = checkpoint_callback.best_model_path

# Load the model from the checkpoint
model = Summarizer.load_from_checkpoint(best_model_path)

print(f"Loaded the best model from: {best_model_path}")

## Evaluate the Model

We will ensure the model is in evaluation mode and demonstrate how to use it for making predictions.

In [None]:
# Set the model to evaluation mode
model.eval()

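from transformers import AutoTokenizer

# Assumption: the tokenizer must match the checkpoint used inside Summarizer;
# "t5-small" is a placeholder
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Example function for preprocessing input text
def preprocess(text):
    # Apply the same preprocessing used during training here
    # (for example, T5's "summarize: " task prefix, if it was used)
    return text

# Example input text for summarization
input_text = "This is an example text to summarize."
preprocessed_text = preprocess(input_text)

# Tokenize the text; torch.tensor() cannot convert a raw string
inputs = tokenizer(preprocessed_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)

# Generate a summary
# Assumption: Summarizer stores the underlying Hugging Face T5 model as `model.model`
with torch.no_grad():
    output_ids = model.model.generate(**inputs, max_length=64)

summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Summary:", summary)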
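
Printing a summary shows that the model runs, but not how good the output is. As a rough illustration of scoring a prediction against a reference, the sketch below computes a unigram-overlap F1, a simplified stand-in for ROUGE-1. The function name `unigram_f1` and both example strings are hypothetical, and whitespace tokenization is an assumption; an established ROUGE package is the better choice for reported results.

```python
from collections import Counter


def unigram_f1(candidate, reference):
    """F1 over overlapping lowercase, whitespace-separated tokens."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((cand & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


# Hypothetical reference and generated summaries
reference_summary = "the cat sat on the mat"
generated_summary = "the cat lay on the mat"
print(round(unigram_f1(generated_summary, reference_summary), 3))  # 0.833
```

Five of the six tokens match in each direction, so precision and recall are both 5/6. Averaging this score over a held-out set gives a quick sanity check between training runs.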