In [1]:
import re
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split

from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
import torch
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


## Data preprocessing

In [2]:
df = pd.read_csv('C:\\Users\\tasnu\\OneDrive\\Documents\\EssenceAI\\summarizer_model\\gigadata_corpus.csv')

df.head()

Unnamed: 0,article,summary
0,at least two people were killed in a suspected...,at least two dead in southern philippines blast\n
1,australian shares closed down #.# percent mond...,australian stocks close down #.# percent\n
2,south korea 's nuclear envoy kim sook urged no...,envoy urges north korea to restart nuclear dis...
3,south korea on monday announced sweeping tax r...,skorea announces tax cuts to stimulate economy\n
4,taiwan share prices closed down #.## percent m...,taiwan shares close down #.## percent\n


In [6]:
df['article'].sample(10).values

array(['an experimental solar-powered aircraft cruising above switzerland in a historic bid to fly around the clock approached a turning point in its flight on wednesday as night-time loomed .\n',
       'knight ridder , the second-largest newspaper publisher in the country , reported higher earnings in the second quarter as executives saw continued signs of a turnaround in advertising .\n',
       'the body of ##-year-old alisa flatow arrived home wednesday morning , accompanied by her father , who donated her organs to help save israeli lives after she was killed in a suicide bombing in gaza .\n',
       'when good computers go bad , finding the cause can drive you bonkers .\n',
       "the coroner examining the death of diana , princess of wales launched a scathing attack tuesday on her former butler , telling the jury at her inquest it was `` blindingly obvious '' he lied as a witness .\n",
       'the construction of the beijing-kowloon railway has breathed new life into economic 

In [9]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3803956 entries, 0 to 3803955
Data columns (total 2 columns):
 #   Column   Dtype 
---  ------   ----- 
 0   article  object
 1   summary  object
dtypes: object(2)
memory usage: 58.0+ MB


In [10]:
def clean_text(text):
    text = text.lower()
    text = re.sub(r'#', '', text)
    text = re.sub(r'\n', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    return text

df['article'] = df['article'].apply(clean_text)
df['summary'] = df['summary'].apply(clean_text)

In [11]:
df['article'].sample(10).values

array(['the government of macao special administrative region lrb sar rrb will give full support to local cultural and creative industry by fostering more talents and establishing an industry cluster area  the sar s chief executive chui sai on said tuesday ',
       'an exhibition of australian indigenous art  featuring  works by aborigines from unk hills  opened at the taipei fine arts museum lrb tfam rrb on friday ',
       'health minister manto tshabalalamsimang boycotted south africa s national aids conference wednesday as she returned to work after a long absence following a liver transplant ',
       'qiandao lake in chun  an county  zhejiang province  has become china s biggest wildlife zoo following  years of afforestation  and wild animal protection and breeding ',
       'carl lewis was criticized by another former olympic sprint champion in new york on tuesday for skipping a reunion of meter olympic champions in unk of  olympic unk medalist jesse owens ',
       'leaders of

In [13]:
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

In [None]:
def convert_to_t5(df, input_column, target_column):
    inputs = "summarize: " + df[input_column]
    targets = df[target_column]
    return inputs.tolist(), targets.tolist()

train_input, train_target = convert_to_t5(train_df, 'article', 'summary')
test_input, test_target = convert_to_t5(test_df, 'article', 'summary')

## Data and model load

In [None]:
print(f'Max length of train feature: {max(train_input)}')
print(f'Max length of train Label: {max(train_target)}')
print(f'Max length of test feature: {max(test_input)}')
print(f'Max length of test target: {max(test_target)}')

In [None]:
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

def tokenize_data(inputs, targets, tokenizer, max_length=512):
    input_encodings = tokenizer(inputs, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
    target_encodings = tokenizer(targets, max_length=150, padding=True, truncation=True, return_tensors="pt")
    return input_encodings, target_encodings

train_encodings, train_target_encodings = tokenize_data(train_input, train_target, tokenizer)
test_encodings, test_target_encodings = tokenize_data(test_input, test_target, tokenizer)

In [None]:
class TextSummarizationDataset(Dataset):
    def __init__(self, input_encodings, target_encodings):
        self.input_encodings = input_encodings
        self.target_encodings = target_encodings

    def __len__(self):
        return len(self.input_encodings['input_ids'])

    def __getitem__(self, idx):
        input_ids = self.input_encodings['input_ids'][idx]
        attention_mask = self.input_encodings['attention_mask'][idx]
        labels = self.target_encodings['input_ids'][idx]

        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

train_dataset = TextSummarizationDataset(train_encodings, train_target_encodings)
test_dataset = TextSummarizationDataset(test_encodings, test_target_encodings)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

## Train model

In [None]:
optimizer = AdamW(model.parameters(), lr=10e-3)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

model.train()
epoch_num = 5
for epoch in range(epoch_num):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())

## Save model

In [None]:
model.save_pretrained("./t5_weights")
tokenizer.save_pretrained("./t5_tokenizer")

## Test

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Load the saved model and tokenizer
model_path = "./t5_summarizer_model"
tokenizer_path = "./t5_summarizer_tokenizer"

model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()  # Set the model to evaluation mode


In [None]:
def summarize_paragraph(paragraph, max_length=150, min_length=30):
    # Tokenize the input text
    inputs = tokenizer("summarize: " + paragraph, return_tensors="pt", max_length=512, truncation=True).to(device)
    
    # Generate summary
    summary_ids = model.generate(
        inputs["input_ids"], 
        max_length=max_length, 
        min_length=min_length, 
        length_penalty=2.0, 
        num_beams=4, 
        early_stopping=True
    )
    
    # Decode the summary
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    
    return summary


In [None]:
paragraph = """The COVID-19 pandemic has caused a global economic slowdown. Governments around the world are taking measures to address the crisis. Health systems are under immense pressure, and countries are introducing emergency protocols to handle the situation."""

summary = summarize_paragraph(paragraph)
print("Original Paragraph:", paragraph)
print("Summary:", summary)
