In [23]:
import numpy as np
import pandas as pd
import torch

from sklearn.model_selection import train_test_split
from simpletransformers.seq2seq import Seq2SeqModel
from sumeval.metrics.rouge import RougeCalculator

In [2]:
# Read the news summarization dataset from its CSV file into a DataFrame.
df = pd.read_csv(filepath_or_buffer='data/summarize/news_data.csv')

In [3]:
# Discard the two count columns; they are not used anywhere below.
df = df.drop(columns=['text_count', 'headlines_count'])

In [4]:
# Relabel the two remaining columns by position with the names the later
# cells read (`input_text` for prediction, `target_text` for scoring).
# NOTE(review): this is positional — it relies on the CSV column order; confirm.
df.columns = pd.Index(['target_text', 'input_text'])

In [5]:
# 80/10/10 split: hold out 20% of the rows, then cut that hold-out in half
# to obtain the evaluation and test sets. Both splits share the same seed.
train_df, holdout_df = train_test_split(df, test_size=0.2, random_state=2021)
eval_df, test_df = train_test_split(holdout_df, test_size=0.5, random_state=2021)

In [6]:
# Settings handed to Seq2SeqModel through its `args` parameter.
# NOTE(review): `evaluate_generated_text` is on, yet the recorded evaluation
# results only contain `eval_loss` — confirm whether the extra generation
# pass is actually needed.
train_params = dict(
    # Sequence length, batching and epochs
    max_seq_length=100,
    train_batch_size=8,
    eval_batch_size=8,
    num_train_epochs=2,
    # Evaluation
    evaluate_during_training=True,
    evaluate_generated_text=True,
    # Multiprocessing / precision
    use_multiprocessing=False,
    use_multiprocessing_for_evaluation=False,
    fp16=False,
    # Checkpointing
    save_steps=-1,
    save_eval_checkpoints=False,
    save_model_every_epoch=False,
    # Caching / output directory handling
    no_cache=True,
    reprocess_input_data=True,
    overwrite_output_dir=True,
    # Input handling and generation
    preprocess_inputs=False,
    num_return_sequences=1,
)

In [7]:
# Build the BART-type encoder-decoder from the pretrained
# 'sshleifer/distilbart-xsum-6-6' checkpoint, configured with `train_params`.
# CUDA is used when PyTorch reports a usable GPU; otherwise the model falls
# back to CPU instead of failing on machines without CUDA (a hard-coded
# `use_cuda=True` makes the notebook unrunnable there).
model = Seq2SeqModel(
    encoder_decoder_type='bart',
    encoder_decoder_name='sshleifer/distilbart-xsum-6-6',
    args=train_params,
    use_cuda=torch.cuda.is_available()
)

In [8]:
# Fine-tune on the training split, passing the evaluation split as
# `eval_data` (train_params enables `evaluate_during_training`).
# Left as a bare expression so the notebook displays the returned value.
model.train_model(train_df, eval_data=eval_df)

HBox(children=(FloatProgress(value=0.0, max=82332.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=2.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 2', max=10292.0, style=ProgressStyle(d…

HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…






HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…





HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 2', max=10292.0, style=ProgressStyle(d…

HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…





HBox(children=(FloatProgress(value=0.0, max=10291.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…





(20584,
 {'global_step': [2000,
   4000,
   6000,
   8000,
   10000,
   10292,
   12000,
   14000,
   16000,
   18000,
   20000,
   20584],
  'eval_loss': [1.579082104848537,
   1.4433962243327636,
   1.376512405292508,
   1.3118813892344494,
   1.2764868078624545,
   1.2668138277465832,
   1.2758675771761794,
   1.2577147324509461,
   1.2339377035072077,
   1.2179725118265101,
   1.2056450303803143,
   1.206463550933128],
  'train_loss': [1.7399547100067139,
   1.319960594177246,
   1.2697116136550903,
   1.173964262008667,
   1.0818990468978882,
   0.8820846676826477,
   1.1742870807647705,
   0.7473753094673157,
   0.7047785520553589,
   0.9836313128471375,
   1.1039659976959229,
   0.43869495391845703]})

In [12]:
# Evaluate the fine-tuned model on the held-out test split.
# Left as a bare expression so the notebook displays the returned results.
model.eval_model(test_df)

HBox(children=(FloatProgress(value=0.0, max=10292.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=1287.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




{'eval_loss': 1.191815466313929}

In [25]:
# ROUGE scorer for English text. The variable keeps the spelling `rogue`
# because the scoring cell further down refers to it by that name.
rogue = RougeCalculator(lang='en', stopwords=True)

# Generate a summary for every input text of the test split.
test_inputs = test_df['input_text'].tolist()
predictions = model.predict(test_inputs)

HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1287.0, style=ProgressStyle(desc…




In [29]:
# Score each generated summary against its reference text: one dict per
# test example holding ROUGE-1, ROUGE-2 and ROUGE-L.
reference_texts = test_df['target_text'].tolist()
rogue_scores = [
    {
        'rogue_1': rogue.rouge_n(summary=generated, references=reference, n=1),
        'rogue_2': rogue.rouge_n(summary=generated, references=reference, n=2),
        'rogue_l': rogue.rouge_l(summary=generated, references=reference),
    }
    for generated, reference in zip(predictions, reference_texts)
]

In [31]:
# Tabulate the per-example scores: one row per example, one column per metric.
score_df = pd.DataFrame(data=rogue_scores)

In [33]:
# Mean of each metric over the test examples, in the order (ROUGE-1, ROUGE-2, ROUGE-L).
tuple(score_df[metric].mean() for metric in ('rogue_1', 'rogue_2', 'rogue_l'))

(0.5853705107472392, 0.3166183272062189, 0.544477800992273)