In [20]:
from simplet5 import SimpleT5
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd

# Data

In [25]:
df = pd.read_csv('./data/news_summary.csv', encoding='iso-8859-1')[['text', 'ctext']].dropna()
df = df.rename(columns={'ctext': 'source_text', 'text': 'target_text'})
df.head()

Unnamed: 0,target_text,source_text
0,The Administration of Union Territory Daman an...,The Daman and Diu administration on Wednesday ...
1,Malaika Arora slammed an Instagram user who tr...,"From her special numbers to TV?appearances, Bo..."
2,The Indira Gandhi Institute of Medical Science...,The Indira Gandhi Institute of Medical Science...
3,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...
4,Hotels in Maharashtra will train their staff t...,Hotels in Mumbai and other Indian cities are t...


In [17]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [26]:
X_train, X_test = train_test_split(df, test_size=0.1)

### Small EDA

In [20]:
text_lens = []
summary_lens = []
for _, row in df.iterrows():
  text_lens.append(len(row.source_text.split()))
  summary_lens.append(len(row.target_text.split()))
text_lens = np.array(text_lens)
summary_lens = np.array(summary_lens)

In [21]:
np.quantile(text_lens, 0.95), np.quantile(summary_lens, 0.95)

(732.75, 60.0)

In [19]:
X_train.shape

(3956, 2)

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


# Model

In [22]:
model = SimpleT5()

model.from_pretrained("t5","t5-small")

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/231M [00:00<?, ?B/s]

In [34]:
model.train(train_df=X_train, # pandas dataframe with 2 columns: source_text & target_text
            eval_df=X_test, # pandas dataframe with 2 columns: source_text & target_text
            source_max_token_len = 768, 
            target_max_token_len = 96,
            batch_size = 4,
            max_epochs = 5,
            use_gpu = True,
            outputdir = "outputs",
            early_stopping_patience_epochs = 0,
            precision = 32
            )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 60.5 M
-----------------------------------------------------
60.5 M    Trainable params
0         Non-trainable params
60.5 M    Total params
242.026   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [44]:
model.predict(df.source_text[42])

Token indices sequence length is longer than the specified maximum sequence length for this model (543 > 512). Running this sequence through the model will result in indexing errors


['The Daman and Diu administration on Wednesday withdrew a circular that asked women staff to tie rakhis on male colleagues after the order triggered a backlash from employees. "In this connection, all offices/ departments shall remain open and celebrate the festival collectively at a suitable time wherein all the lady staff shall tie rakhis to their colleagues," it had said.']

The Daman and Diu administration on Wednesday withdrew a circular that asked women staff to tie rakhis on male colleagues after the order triggered a backlash from employees. "In this connection, all offices/ departments shall remain open and celebrate the festival collectively at a suitable time wherein all the lady staff shall tie rakhis to their colleagues," it had said.

In [43]:
model.load_model('t5', '/content/outputs/simplet5-epoch-4-train-loss-1.6297-val-loss-1.6945', use_gpu=True)

In [50]:
from tqdm.auto import tqdm
tqdm.pandas()

pred = X_test.source_text.progress_map(model.predict)

  0%|          | 0/440 [00:00<?, ?it/s]

In [54]:
pred.to_csv("/content/predictions.csv")

In [24]:
preds = pd.read_csv("./data/predictions.csv")

In [30]:
import datasets


metric = datasets.load_metric("rouge")

def rouge_scores(candidates, references):
    result = metric.compute(predictions=candidates, references=references, use_stemmer=True)
    result = {key: round(value.mid.fmeasure * 100, 1) for key, value in result.items()}
    return result

rouge_scores(preds.source_text, X_test['target_text'])

{'rouge1': 49.2, 'rouge2': 26.3, 'rougeL': 36.5, 'rougeLsum': 36.5}