In [4]:
# imports -------------------------------------------------------------
import string

import matplotlib.pyplot as plt
import nltk
import seaborn as sns
import sklearn.model_selection as model_selection
from sklearn.metrics import confusion_matrix
import numpy as np 
import pandas as pd 

from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, SnowballStemmer

from sklearn.feature_extraction.text import (CountVectorizer, HashingVectorizer, TfidfVectorizer)
from imblearn.over_sampling import SMOTE

# from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA, TruncatedSVD


from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pathlib import Path


from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer
)
from tqdm.auto import tqdm

In [5]:
# Load the news-summary CSV from the working directory. The file is decoded as
# latin-1 rather than the pandas default (UTF-8).
df = pd.read_csv("./news_summary.csv", encoding = "latin-1")
# Preview the first rows to check which columns are available.
df.head()

Unnamed: 0,author,date,headlines,read_more,text,ctext
0,Chhavi Tyagi,"03 Aug 2017,Thursday",Daman & Diu revokes mandatory Rakshabandhan in...,http://www.hindustantimes.com/india-news/raksh...,The Administration of Union Territory Daman an...,The Daman and Diu administration on Wednesday ...
1,Daisy Mowke,"03 Aug 2017,Thursday",Malaika slams user who trolled her for 'divorc...,http://www.hindustantimes.com/bollywood/malaik...,Malaika Arora slammed an Instagram user who tr...,"From her special numbers to TV?appearances, Bo..."
2,Arshiya Chopra,"03 Aug 2017,Thursday",'Virgin' now corrected to 'Unmarried' in IGIMS...,http://www.hindustantimes.com/patna/bihar-igim...,The Indira Gandhi Institute of Medical Science...,The Indira Gandhi Institute of Medical Science...
3,Sumedha Sehra,"03 Aug 2017,Thursday",Aaj aapne pakad liya: LeT man Dujana before be...,http://indiatoday.intoday.in/story/abu-dujana-...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...
4,Aarushi Maheshwari,"03 Aug 2017,Thursday",Hotel staff to get training to spot signs of s...,http://indiatoday.intoday.in/story/sex-traffic...,Hotels in Maharashtra will train their staff t...,Hotels in Mumbai and other Indian cities are t...


In [7]:
# Load the second CSV; per the preview below it only carries `headlines` and
# `text` columns (no `ctext`).
more_df = pd.read_csv('./news_summary_more.csv')
# Preview the first rows.
more_df.head()

Unnamed: 0,headlines,text
0,upGrad learner switches to career in ML & Al w...,"Saurav Kant, an alumnus of upGrad and IIIT-B's..."
1,Delhi techie wins free food from Swiggy for on...,Kunal Shah's credit card bill payment platform...
2,New Zealand end Rohit Sharma-led India's 12-ma...,New Zealand defeated India by 8 wickets in the...
3,Aegon life iTerm insurance plan helps customer...,"With Aegon Life iTerm Insurance plan, customer..."
4,"Have known Hirani for yrs, what if MeToo claim...",Speaking about the sexual harassment allegatio...


In [8]:
# create a preprocessing class
class Preprocessor:
    """Text-cleaning pipeline over the string columns of a DataFrame.

    Every cleaning method transforms all columns listed in ``text_cols`` and
    returns the (internally held) DataFrame so calls can be inspected or
    chained through ``preprocess``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the raw text. It is copied on construction, so the
        caller's frame is never modified.
    text_cols : iterable of str, optional
        Columns the cleaning steps operate on. Defaults to the two text
        columns of the news-summary data: ``("headlines", "text")``.
    """

    # Regexes are raw strings so backslash escapes reach the regex engine
    # unchanged and do not trigger invalid-escape warnings.
    _PUNCTUATION_PATTERN = r"[^\w\s]"
    _DIGIT_PATTERN = r"[0-9]"
    _URL_PATTERN = r"https?://\S+|www\.\S+"
    _SYMBOL_PATTERN = r"[<>()|&©ø\[\]'\",;?~*!]"

    def __init__(self, df, text_cols=("headlines", "text")) -> None:
        # Work on a private copy: mutating the caller's frame makes notebook
        # cells non-idempotent and leaks cleaned text into later cells.
        self.df = df.copy()
        self.text_cols = list(text_cols)

    def _map_text(self, func):
        """Apply ``func`` to every non-null value of each text column."""
        for col in self.text_cols:
            # na_action="ignore" leaves missing values untouched instead of
            # crashing inside ``func``.
            self.df[col] = self.df[col].map(func, na_action="ignore")
        return self.df

    def _strip_pattern(self, pattern):
        """Delete every match of the regex ``pattern`` from each text column."""
        for col in self.text_cols:
            # regex=True is required: since pandas 2.0 the default is a
            # literal (non-regex) replacement.
            self.df[col] = self.df[col].str.replace(pattern, "", regex=True)
        return self.df

    def convertToLower(self):
        """Convert all characters to lower case."""
        return self._map_text(str.lower)

    def removeStopWords(self):
        """Remove English stop words (requires the NLTK ``stopwords`` corpus)."""
        # A set gives O(1) membership tests instead of scanning a list per word.
        stop = set(stopwords.words("english"))
        return self._map_text(
            lambda text: " ".join(word for word in text.split() if word not in stop)
        )

    def removePunctuation(self):
        """Remove every character that is neither a word character nor whitespace."""
        return self._strip_pattern(self._PUNCTUATION_PATTERN)

    def removeNumbers(self):
        """Remove all ASCII digits."""
        return self._strip_pattern(self._DIGIT_PATTERN)

    def removeWhitespaces(self):
        """Collapse runs of whitespace into single spaces and trim both ends."""
        return self._map_text(lambda text: " ".join(text.split()))

    def cleanPunctuations(self):
        """Remove a fixed set of symbols: ``< > ( ) | & © ø [ ] ' " , ; ? ~ * !``."""
        return self._strip_pattern(self._SYMBOL_PATTERN)

    def removeURLs(self):
        """Remove ``http(s)://...`` and ``www....`` tokens."""
        return self._strip_pattern(self._URL_PATTERN)

    def snowballstemmer(self):
        """Stem every word with NLTK's Snowball stemmer (English)."""
        # SnowballStemmer requires the language as its first argument.
        stemmer = SnowballStemmer("english")
        return self._map_text(
            lambda text: " ".join(stemmer.stem(word) for word in text.split())
        )

    def porterstemmer(self):
        """Stem every word with NLTK's Porter stemmer."""
        stemmer = PorterStemmer()
        return self._map_text(
            lambda text: " ".join(stemmer.stem(word) for word in text.split())
        )

    def lemmatize(self):
        """Lemmatize every word with NLTK's WordNet lemmatizer."""
        from nltk.stem import WordNetLemmatizer

        lemmatizer = WordNetLemmatizer()
        return self._map_text(
            lambda text: " ".join(lemmatizer.lemmatize(word) for word in text.split())
        )

    def removeUnwantedCols(self, col):
        """Drop the column(s) ``col`` (a label or list of labels).

        Prints the frame's shape before dropping, then returns the reduced
        frame.
        """
        print(self.df.shape)
        self.df = self.df.drop(col, axis=1)
        # Stop cleaning columns that no longer exist.
        self.text_cols = [c for c in self.text_cols if c in self.df.columns]
        return self.df

    def wordTokenization(self):
        """Replace each text value with its list of NLTK word tokens."""
        return self._map_text(nltk.word_tokenize)

    def preprocess(self):
        """Run the active cleaning steps in order and return the cleaned frame.

        Active steps: lower-casing, whitespace normalisation, symbol removal,
        Porter stemming. The remaining methods are available but not part of
        the default pipeline.
        """
        self.df = self.convertToLower()
        self.df = self.removeWhitespaces()
        self.df = self.cleanPunctuations()
        self.df = self.porterstemmer()
        return self.df

In [9]:
# Run the cleaning pipeline (lower-case, whitespace, symbol removal, Porter
# stemming) on the headlines/text frame.
# NOTE(review): the variable name is misspelled ("preproccesor"); left as-is in
# case later cells refer to it.
preproccesor = Preprocessor(more_df)
preprocessed_df = preproccesor.preprocess()
# Preview the cleaned rows.
preprocessed_df.head()

  self.df["headlines"] = self.df["headlines"].str.replace("[<>()|&©ø\[\]\'\",;?~*!]", "")
  self.df["text"] = self.df["text"].str.replace("[<>()|&©ø\[\]\'\",;?~*!]", "")


Unnamed: 0,headlines,text
0,upgrad learner switch to career in ml al with ...,saurav kant an alumnu of upgrad and iiit-b pg ...
1,delhi techi win free food from swiggi for one ...,kunal shah credit card bill payment platform c...
2,new zealand end rohit sharma-l india 12-match ...,new zealand defeat india by 8 wicket in the fo...
3,aegon life iterm insur plan help custom save tax,with aegon life iterm insur plan custom can en...
4,have known hirani for yr what if metoo claim a...,speak about the sexual harass alleg against ra...


In [10]:
# Training configuration.
# Number of passes over the training data (passed to pl.Trainer(max_epochs=...)).
EPOCHS = 2
# Intended mini-batch size; equals SummarizerDataModule's default batch_size.
BATCH_SIZE = 8

In [11]:
class SummarizerDataset(Dataset):
    """Map-style dataset that tokenizes (text, summary) pairs on demand.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with a ``text`` column (model input) and a ``summary`` column
        (target).
    tokenizer : T5Tokenizer
        Callable tokenizer returning ``input_ids`` and ``attention_mask``
        tensors; must expose ``pad_token_id``.
    text_max_token_len : int
        Length every encoded input is padded/truncated to.
    summary_max_token_len : int
        Length every encoded summary is padded/truncated to.
    """

    def __init__(self, data: pd.DataFrame, tokenizer: "T5Tokenizer", text_max_token_len: int = 512, summary_max_token_len: int = 128):
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    # len and getitem are mandatory methods to be overriden and cannot be removed
    def __len__(self):
        """Number of (text, summary) pairs."""
        return len(self.data)

    def _encode(self, value, max_length: int):
        """Tokenize one string to a fixed length, returning PyTorch tensors."""
        # Use the tokenizer handed to this dataset, never a notebook global:
        # DataLoader workers and fresh kernels may not have that global at all.
        return self.tokenizer(
            value,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )

    def __getitem__(self, index: int):
        """Return the raw strings and flattened 1-D tensors for row ``index``.

        Keys: ``text``, ``summary``, ``text_input_ids``,
        ``text_attention_mask``, ``labels``, ``labels_attention_mask``.
        """
        row = self.data.iloc[index]
        text = row['text']
        summary = row['summary']

        text_encoding = self._encode(text, self.text_max_token_len)
        summary_encoding = self._encode(summary, self.summary_max_token_len)

        # All labels set to -100 are ignored (masked) -- from T5 documentation.
        # Mask by the tokenizer's own pad id instead of assuming it is 0.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        labels = summary_encoding['input_ids'].clone()
        labels[labels == pad_id] = -100

        return dict(
            text=text,
            summary=summary,
            text_input_ids=text_encoding['input_ids'].flatten(),
            text_attention_mask=text_encoding['attention_mask'].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=summary_encoding['attention_mask'].flatten(),
        )

In [12]:
class SummarizerDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train / validation / test splits.

    Parameters
    ----------
    train_df, test_df : pandas.DataFrame
        Frames with ``text`` and ``summary`` columns.
    tokenizer : T5Tokenizer
        Tokenizer forwarded to every ``SummarizerDataset``.
    batch_size : int
        Mini-batch size used by all three dataloaders.
    text_max_token_len, summary_max_token_len : int
        Fixed token lengths forwarded to ``SummarizerDataset``.
    num_workers : int
        Worker processes per dataloader.
    val_df : pandas.DataFrame, optional
        Separate validation frame. When omitted, validation falls back to
        ``test_df`` (the original behaviour).
        NOTE(review): validating on the test set means checkpoint selection
        sees the test data; pass a dedicated ``val_df`` for an unbiased test
        score.
    """

    def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame, tokenizer: T5Tokenizer, batch_size: int = 8, text_max_token_len: int = 512, summary_max_token_len: int = 128, num_workers: int = 2, val_df: pd.DataFrame = None):
        super().__init__()

        self.train_df = train_df
        self.test_df = test_df
        self.val_df = val_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len
        self.num_workers = num_workers

    def _make_dataset(self, frame: pd.DataFrame):
        """Wrap ``frame`` in a SummarizerDataset using this module's settings."""
        return SummarizerDataset(
            frame,
            self.tokenizer,
            self.text_max_token_len,
            self.summary_max_token_len
        )

    def _make_loader(self, dataset, shuffle: bool):
        """Build a DataLoader with the shared batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def setup(self, stage=None):
        """Create the datasets for every split."""
        self.train_dataset = self._make_dataset(self.train_df)
        self.test_dataset = self._make_dataset(self.test_df)
        if self.val_df is None:
            self.val_dataset = self.test_dataset
        else:
            self.val_dataset = self._make_dataset(self.val_df)

    # train_dataloader , test_dataloader, val_dataloader methods are overridden, cannot be removed
    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        """Loader over the test split; not shuffled so evaluation order is stable."""
        return self._make_loader(self.test_dataset, shuffle=False)

    def val_dataloader(self):
        """Loader over the validation split; not shuffled so evaluation order is stable."""
        return self._make_loader(self.val_dataset, shuffle=False)

In [13]:
class SummarizerModel(pl.LightningModule):
    """Lightning wrapper that fine-tunes T5 for abstractive summarization.

    Parameters
    ----------
    model_name : str
        Checkpoint name passed to ``T5ForConditionalGeneration.from_pretrained``.
    learning_rate : float
        Learning rate given to the AdamW optimizer.
    """

    def __init__(self, model_name: str = 't5-base', learning_rate: float = 0.0001):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # load_from_checkpoint rebuilds the module with the same settings.
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the wrapped T5 model and return ``(loss, logits)``."""
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )

        return output.loss, output.logits

    def _shared_step(self, batch, stage: str):
        """Compute the loss for one batch and log it as ``"<stage>_loss"``."""
        loss, _ = self(
            input_ids=batch['text_input_ids'],
            attention_mask=batch['text_attention_mask'],
            decoder_attention_mask=batch['labels_attention_mask'],
            labels=batch['labels']
        )

        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss

    # NOTE: Lightning passes the batch *index* as the second argument of the
    # step hooks; the parameter keeps its original name ``batch_size`` only to
    # leave the signatures unchanged. It is not used.
    def training_step(self, batch, batch_size):
        """Training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_size):
        """Validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_size):
        """Test loss for one batch (logged as ``test_loss``)."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Optimize all parameters with AdamW at the configured learning rate."""
        return AdamW(self.parameters(), lr=self.learning_rate)

In [14]:
# Load the tokenizer for the 't5-base' checkpoint. `T5Tokenizer` is the import
# alias for `T5TokenizerFast` (see the imports cell).
tokenizer = T5Tokenizer.from_pretrained('t5-base')

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [16]:
# Stack both sources, then keep only the short summary (`text`) and the full
# article (`ctext`), renamed to the `summary` / `text` columns that
# SummarizerDataset reads.
# NOTE(review): `more_df` has no `ctext` column, so its rows come out of the
# concat with NaN there and are removed by dropna() -- confirm whether the
# second file was meant to contribute training pairs.
combined = pd.concat([df, more_df], axis=0).reset_index(drop=True)
df = (
    combined[["text", "ctext"]]
    .rename(columns={"text": "summary", "ctext": "text"})
    .dropna()
)
df.head()

Unnamed: 0,summary,text
0,The Administration of Union Territory Daman an...,The Daman and Diu administration on Wednesday ...
1,Malaika Arora slammed an Instagram user who tr...,"From her special numbers to TV?appearances, Bo..."
2,The Indira Gandhi Institute of Medical Science...,The Indira Gandhi Institute of Medical Science...
3,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...
4,Hotels in Maharashtra will train their staff t...,Hotels in Mumbai and other Indian cities are t...


In [17]:
# Hold out 10% of the pairs for evaluation. A fixed random_state makes the
# split identical on every Restart & Run All.
train_df, test_df = model_selection.train_test_split(df, test_size=0.1, random_state=42)
train_df.shape, test_df.shape

((3956, 2), (440, 2))

In [18]:
# Build the data module, wiring in the BATCH_SIZE constant from the
# configuration cell instead of silently relying on the class default.
dataModule = SummarizerDataModule(train_df, test_df, tokenizer, batch_size=BATCH_SIZE)

In [19]:
# Instantiate the Lightning module; its constructor loads the pretrained
# 't5-base' weights.
model = SummarizerModel()

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [21]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


# Checkpoint callback that keeps only the single best model, judged by the
# lowest logged 'val_loss'. It must be handed to pl.Trainer via
# callbacks=[...] to take effect.
checkpointCallBack = ModelCheckpoint(
    dirpath='checkpoints',          # directory the checkpoint file is written to
    filename='best-checkpoint',     # checkpoint file name (without extension)
    save_top_k=1,                   # keep only the best-scoring checkpoint
    verbose=True,
    monitor='val_loss',    # monitor validation loss and find the 'min' validation loss criteria
    mode='min'                      # lower monitored value is better
)

In [26]:
# TensorBoard run directory: lightning_logs/news-summary/version_N
logger = TensorBoardLogger("lightning_logs", name='news-summary')

# Register the val_loss-monitoring checkpoint callback. Without it the Trainer
# falls back to its default checkpointing, and
# `trainer.checkpoint_callback.best_model_path` would not point at the
# lowest-val_loss checkpoint configured above.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpointCallBack],
    max_epochs=EPOCHS
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [27]:
# Fine-tune the model; train and validation batches come from the data module.
trainer.fit(model, dataModule)

Missing logger folder: lightning_logs\news-summary

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

In [None]:
# Restore the weights from the checkpoint the trainer's checkpoint callback
# recorded as best, then freeze the module for inference.
best_checkpoint_path = trainer.checkpoint_callback.best_model_path
trained_model = SummarizerModel.load_from_checkpoint(best_checkpoint_path)
trained_model.freeze()