# Intro
simpleT5 is a library built on top of PyTorch Lightning⚡️ and Transformers🤗 that lets you quickly train and fine-tune T5 models.

# Install dependencies

In [None]:
!pip install tensorboard==2.4.1
!pip install pytorch-lightning==1.3.3
!pip install simplet5

# Import dependencies

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
from simplet5 import SimpleT5

Global seed set to 42


# Load data

In [3]:
# Remote CSV with the news-summary dataset; pandas reads it straight from the URL.
path = (
    "https://raw.githubusercontent.com/Shivanandroy/T5-Finetuning-PyTorch/main/data/news_summary.csv"
)
df = pd.read_csv(filepath_or_buffer=path)

# Pre-process data

In [4]:
# simpleT5 needs exactly two columns: "source_text" (model input) and
# "target_text" (expected output). Map the dataset's column names onto those
# and keep only that pair, in that order.
column_mapping = {"headlines": "target_text", "text": "source_text"}
df = df.rename(columns=column_mapping)
df = df[["source_text", "target_text"]]

In [5]:
# T5 expects a task-specific prefix on every input; this is a summarization
# task, so each source text gets "summarize: " prepended.
# Rows that already start with the prefix are left as they are, so running
# this cell more than once does not produce "summarize: summarize: ...".
# na=False treats missing values as "not prefixed"; they stay missing either way.
already_prefixed = df['source_text'].str.startswith("summarize: ", na=False)
df['source_text'] = df['source_text'].where(already_prefixed, "summarize: " + df['source_text'])

In [6]:
# Preview the first five rows to confirm the column names and the prefix.
df.head(n=5)

Unnamed: 0,source_text,target_text
0,"summarize: Saurav Kant, an alumnus of upGrad a...",upGrad learner switches to career in ML & Al w...
1,summarize: Kunal Shah's credit card bill payme...,Delhi techie wins free food from Swiggy for on...
2,summarize: New Zealand defeated India by 8 wic...,New Zealand end Rohit Sharma-led India's 12-ma...
3,summarize: With Aegon Life iTerm Insurance pla...,Aegon life iTerm insurance plan helps customer...
4,summarize: Speaking about the sexual harassmen...,"Have known Hirani for yrs, what if MeToo claim..."


In [7]:
# Hold out 20% of the rows for evaluation. random_state pins the shuffle so
# the same split is produced every time this cell runs, not only on the first
# run after a kernel restart. 42 matches the global seed reported when the
# dependencies are imported.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

In [8]:
# Preview the first five training rows; the index shows the rows were shuffled.
train_df.head(n=5)

Unnamed: 0,source_text,target_text
67164,summarize: Hollywood actress Kate Winslet has ...,Kate Winslet to work with Titanic maker after ...
66949,summarize: Physicist Marie Curie was the first...,Who all have won the Nobel Prize more than once?
51051,summarize: Pakistan police have arrested the m...,Pak cops arrest main suspect in 7-yr-old's rap...
82218,summarize: A US judge has dismissed a lawsuit ...,FB avoids lawsuit for tracking activity of log...
81564,summarize: The shooting of an upcoming episode...,Shoot of SRK episode gets postponed after Kapi...


# Fine-tune the model

In [None]:
# Fine-tune T5 through the simpleT5 wrapper:
#   1. start from the pretrained "t5-base" weights,
#   2. train on train_df and evaluate on test_df.
model = SimpleT5()
model.from_pretrained(
    model_type="t5",
    model_name="t5-base",
)
model.train(
    train_df=train_df,
    eval_df=test_df,
    source_max_token_len=128,  # token cap for the input text
    target_max_token_len=50,   # token cap for the generated summary
    batch_size=8,
    max_epochs=3,
    use_gpu=True,
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1389353.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1199.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=891691430.0, style=ProgressStyle(descri…




GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

# Predict

In [None]:
# Reload a fine-tuned checkpoint from disk for inference.
# NOTE(review): the directory name appears to embed the epoch index and the
# training loss, so it will likely differ from run to run — point this at the
# folder your own training run actually wrote under "outputs/".
model.load_model(
    "t5",
    "outputs/SimpleT5-epoch-2-train-loss-0.9526",
    use_gpu=True,
)

In [None]:
# Sample news snippet used to sanity-check the fine-tuned model. It starts
# with the same "summarize: " task prefix that was added to the training inputs.
text_to_summarize="""summarize: Rahul Gandhi has replied to Goa CM Manohar Parrikar's letter, 
which accused the Congress President of using his "visit to an ailing man for political gains". 
"He's under immense pressure from the PM after our meeting and needs to demonstrate his loyalty by attacking me," 
Gandhi wrote in his letter. Parrikar had clarified he didn't discuss Rafale deal with Rahul.
"""
# Last expression of the cell, so the prediction is shown as the cell output.
model.predict(text_to_summarize)

# Improvement

In [None]:
# Optional: quantization and ONNX support for faster CPU inference.
# Convert the fine-tuned checkpoint and load the ONNX model, then run the
# same sample text through the ONNX prediction path.
# NOTE(review): the checkpoint directory name looks run-specific (epoch and
# loss are part of it) — adjust it to match your own "outputs/" folder.
model.convert_and_load_onnx_model(
    model_dir="outputs/SimpleT5-epoch-2-train-loss-0.9526",
)
model.onnx_predict(text_to_summarize)