In [1]:
import os
import re
import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM,
                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)
import evaluate
import torch

# Suppress TF warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import warnings
warnings.filterwarnings('ignore')


  from .autonotebook import tqdm as notebook_tqdm





In [2]:
# Read the three CSV splits and discard the 'id' column, which is not used downstream.
train_df, val_df, test_df = (
    pd.read_csv(csv_path).drop(columns=['id'])
    for csv_path in ("train.csv", "validation.csv", "test.csv")
)

# De-duplicate the training rows, then keep a reproducible 40,000-row random subset.
train_df = train_df.drop_duplicates().sample(n=40_000, random_state=42)


In [3]:
# Preview the sampled training split.
train_df

Unnamed: 0,article,highlights
263915,By . Daily Mail Reporter . PUBLISHED: . 11:46 ...,Matthew Kay accused of posting series of taste...
238628,"Seoul, South Korea (CNN) -- The suspected cybe...","Wednesday's suspected cyberattack hit 32,000 c..."
66448,By . Richard Arrowsmith for MailOnline . With ...,Tottenham defeated AEL Limassol 3-0 (5-1 agg) ...
44883,"Police arrest two men, both 21, after Barton s...","Police arrest two men, both 21, after Barton s..."
213646,"By . Chris Pleasance . PUBLISHED: . 08:07 EST,...",Cockroaches can tell when others weren't born ...
...,...,...
237289,By . Daily Mail Reporter . PUBLISHED: . 11:13 ...,A new report by the CDC revealed that 30.5 per...
30945,"By . Ted Thornhill . For most people, most of ...",Study found Facebook can also reveal narcissis...
130819,"By . Tim Finan . PUBLISHED: . 06:27 EST, 29 Ju...",Lifeguards call out helicopter to save Londone...
26196,"Tripoli, Libya (CNN) -- Opposition forces in t...","U.S., French and UK leaders write op-ed piece ..."


In [4]:
# Preview the validation split.
val_df

Unnamed: 0,article,highlights
0,"Sally Forrest, an actress-dancer who graced th...","Sally Forrest, an actress-dancer who graced th..."
1,A middle-school teacher in China has inked hun...,Works include pictures of Presidential Palace ...
2,A man convicted of killing the father and sist...,"Iftekhar Murtaza, 29, was convicted a year ago..."
3,Avid rugby fan Prince Harry could barely watch...,Prince Harry in attendance for England's crunc...
4,A Triple M Radio producer has been inundated w...,Nick Slater's colleagues uploaded a picture to...
...,...,...
13363,All shops will be allowed to offer ‘click and ...,Shops won't have to apply for planning permiss...
13364,Mo Farah has had his nationality called into q...,Mo Farah broke the European half-marathon reco...
13365,Wolves kept their promotion hopes alive with a...,Wolves are three points off the play-off place...
13366,A Brown University graduate student has died ...,"Hyoun Ju Sohn, a 25-year-old doctoral student,..."


In [5]:
# Preview the test split.
test_df

Unnamed: 0,article,highlights
0,Ever noticed how plane seats appear to be gett...,Experts question if packed out planes are put...
1,A drunk teenage boy had to be rescued by secur...,Drunk teenage boy climbed into lion enclosure ...
2,Dougie Freedman is on the verge of agreeing a ...,Nottingham Forest are close to extending Dougi...
3,Liverpool target Neto is also wanted by PSG an...,Fiorentina goalkeeper Neto has been linked wit...
4,Bruce Jenner will break his silence in a two-h...,"Tell-all interview with the reality TV star, 6..."
...,...,...
11485,Our young Earth may have collided with a body ...,Oxford scientists say a Mercury-like body stru...
11486,A man facing trial for helping his former love...,Man accused of helping former lover kill woman...
11487,A dozen or more metal implements are arranged ...,Marianne Power tried the tuning fork facial at...
11488,Brook Lopez dominated twin brother Robin with ...,Brooklyn Nets beat the Portland Trail Blazers ...


In [6]:
## Text Cleaning

In [7]:
def clean_text(text):
    """Collapse every run of whitespace in ``text`` to one space and trim both ends.

    Missing values (``None`` or a float NaN, which is how pandas represents an
    empty CSV cell) are returned as an empty string instead of raising
    ``TypeError`` from ``re.sub``.

    Args:
        text: The raw string to normalise, or a missing-value marker.

    Returns:
        The whitespace-normalised string; ``""`` for missing input.
    """
    # NaN is the only float that compares unequal to itself.
    if text is None or (isinstance(text, float) and text != text):
        return ""
    return re.sub(r'\s+', ' ', text).strip()

# Normalise whitespace in both text columns of every split.
for df in (train_df, val_df, test_df):
    for column in ('article', 'highlights'):
        df[column] = df[column].apply(clean_text)


In [8]:
## Convert to Hugging Face Dataset

In [9]:
# Wrap each pandas split in a Hugging Face Dataset.
# preserve_index=False: train_df keeps a shuffled, non-contiguous index after
# .sample(), and from_pandas would otherwise store that index as an extra
# "__index_level_0__" column in the dataset.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
val_dataset = Dataset.from_pandas(val_df, preserve_index=False)
test_dataset = Dataset.from_pandas(test_df, preserve_index=False)

In [10]:
## Load Tokenizer and Model

In [11]:
# Checkpoint shared by the tokenizer and the seq2seq model.
model_name = "t5-base"

# Fast tokenizer for the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
)

# Pretrained sequence-to-sequence model that will be fine-tuned.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
)


In [12]:
## Preprocessing Function

In [13]:
max_input_length = 512
max_target_length = 128

# T5 checkpoints pick the task from a text prefix on the input; the same prefix
# must be prepended to articles at inference time.
task_prefix = "summarize: "

def preprocess_data(batch):
    """Tokenize articles (model inputs) and highlights (labels) for seq2seq training.

    No padding is applied here. If labels were padded to ``max_length`` with the
    tokenizer's pad id, those pad positions would be treated as real targets and
    contribute to the loss. Leaving the sequences unpadded lets
    ``DataCollatorForSeq2Seq`` pad each batch dynamically and fill label padding
    with -100, which the loss ignores.

    Args:
        batch: Mapping with 'article' and 'highlights' entries; each is either a
            list of strings (``batched=True``) or a single string.

    Returns:
        The tokenizer output for the prefixed articles (truncated to
        ``max_input_length``) with an added 'labels' entry holding the
        highlight token ids (truncated to ``max_target_length``).
    """
    articles = batch['article']
    if isinstance(articles, str):
        # Unbatched call: a single example.
        sources = task_prefix + articles
    else:
        sources = [task_prefix + article for article in articles]

    inputs = tokenizer(sources, max_length=max_input_length, truncation=True)
    labels = tokenizer(batch['highlights'], max_length=max_target_length,
                       truncation=True)
    inputs['labels'] = labels['input_ids']
    return inputs

# Tokenize every split in batches, in train / validation / test order.
train_dataset, val_dataset, test_dataset = (
    split.map(preprocess_data, batched=True)
    for split in (train_dataset, val_dataset, test_dataset)
)


Map: 100%|███████████████████████████████████████████████████████████████| 40000/40000 [00:55<00:00, 726.91 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████| 13368/13368 [00:20<00:00, 658.53 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████| 11490/11490 [00:17<00:00, 648.51 examples/s]


In [14]:
# Batch collator for seq2seq training; it is given both the tokenizer and the model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


In [15]:
## ROUGE Metric

In [16]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Decode generated and reference token ids and score them with ROUGE.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may also be a tuple whose first element holds the
            token ids.

    Returns:
        Dict mapping each ROUGE key to a float score. The mid F-measure is used
        when the metric returns aggregate objects; otherwise the value is
        returned as is.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs alongside the generated ids; ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is a padding marker, not a vocabulary id, so it cannot be decoded.
    # It can appear in predictions as well as labels once batches of different
    # lengths are concatenated, so replace it in both arrays.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return {k: v.mid.fmeasure if hasattr(v, "mid") else v for k, v in result.items()}


In [17]:
## Training Arguments

In [18]:
# fp16 mixed precision needs a CUDA device; requesting it unconditionally makes
# this cell fail on CPU-only machines, so enable it only when CUDA is present.
# NOTE(review): T5 checkpoints are reported to overflow under fp16 on some
# setups - if the training loss turns NaN, switch to bf16 or full precision.
use_fp16 = torch.cuda.is_available()

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Ideally more than 5 epochs; limited to 1 because training is slow on a laptop.
    num_train_epochs=1,
    # Effective train batch size per device = 8 * 4 = 32.
    gradient_accumulation_steps=4,
    save_steps=500,
    save_total_limit=3,
    predict_with_generate=True,
    # Generate summaries up to the same length the labels are truncated to,
    # instead of the model's default generation length.
    generation_max_length=max_target_length,
    fp16=use_fp16,
    logging_dir="./logs",
    logging_steps=50,
    learning_rate=2e-5,
    do_train=True,
    do_eval=True  # must be True for older transformers
)


In [19]:
# Assemble the trainer from the model, the training arguments, the train and
# validation datasets, the tokenizer, the batch collator and the ROUGE metric function.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # validation split; the test split is evaluated separately
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)


In [None]:
# Fine-tune the model with the trainer.
trainer.train()


Step,Training Loss


In [None]:
# Evaluate on the test split (passed explicitly in place of the trainer's
# validation set) and print the resulting metrics.
metrics = trainer.evaluate(eval_dataset=test_dataset)
print("Test Metrics:", metrics)


In [None]:
# Write the model first, then the tokenizer, into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./t5_summarizer_model")
