### Resources

https://deeplearninganalytics.org/text-summarization/#:~:text=Text%20Summarization%20using%20BERT%201%20Introduction%20Single-document%20text,Pulling%20the%20code%20and%20testing%20this%20out%20

BERT sum github: https://github.com/nlpyang/BertSum

Following BERTSum paper - Yang Liu 2019


In [13]:
import torch
import os
import glob
import re
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import json
import re

from utils.vocab import Vocab
import torchtext
#run this in terminal
#NOTE: download spacy enc_core_web_sm for torch tokenizer to work
# python -m spacy download en_core_web_sm

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\nguye\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
root_path = './dataset/bertsum_data/'
contraction_path = './dataset/contractions.json'

In [3]:
#Read contraction json
with open(contraction_path, 'r') as file:
    content = file.read()
    contractions = json.loads(content)
print(contractions)

{"ain't": 'am not', "aren't": 'are not', "can't": 'cannot', "can't've": 'cannot have', "'cause": 'because', "could've": 'could have', "couldn't": 'could not', "couldn't've": 'could not have', "didn't": 'did not', "doesn't": 'does not', "don't": 'do not', "hadn't": 'had not', "hadn't've": 'had not have', "hasn't": 'has not', "haven't": 'have not', "he'd": 'he would', "he'd've": 'he would have', "he'll": 'he will', "he's": 'he is', "how'd": 'how did', "how'll": 'how will', "how's": 'how is', "i'd": 'i would', "i'll": 'i will', "i'm": 'i am', "i've": 'i have', "isn't": 'is not', "it'd": 'it would', "it'll": 'it will', "it's": 'it is', "let's": 'let us', "ma'am": 'madam', "mayn't": 'may not', "might've": 'might have', "mightn't": 'might not', "must've": 'must have', "mustn't": 'must not', "needn't": 'need not', "oughtn't": 'ought not', "shan't": 'shall not', "sha'n't": 'shall not', "she'd": 'she would', "she'll": 'she will', "she's": 'she is', "should've": 'should have', "shouldn't": 'shou

In [12]:
#Loadding the processed data from BertSum github
# We're going to process our own data format for better control
# and simplify the process using Torch dataset and dataloader
sample_path = './dataset/bertsum_data/cnndm.train.0.bert.pt'

assert os.path.exists(sample_path), "Path does not exist"

dataset = torch.load(sample_path)
print(type(dataset))

for data in dataset[:2]:
    for key, value in data.items():
        print(f'{key}, len={len(value)},{value}')
    print('\n\n')


<class 'list'>
src, len=512,[101, 2002, 1005, 1055, 2006, 1996, 2648, 2559, 1999, 1010, 1998, 1996, 3193, 1997, 13489, 13619, 6968, 4492, 17994, 3849, 2000, 2562, 2893, 11737, 5017, 1012, 102, 101, 1037, 2739, 3034, 5958, 2013, 2010, 2047, 7728, 1999, 8252, 3607, 2104, 9363, 5596, 2074, 2129, 11737, 2010, 16746, 3711, 2000, 2022, 1010, 2004, 1996, 28368, 5969, 2343, 10865, 2008, 2010, 3677, 1011, 1011, 1998, 4022, 27398, 1011, 1011, 2001, 2025, 2105, 1998, 2018, 2589, 2210, 2000, 2191, 2010, 2994, 2062, 6625, 1012, 102, 101, 1000, 1045, 5136, 2008, 3607, 2442, 1998, 2038, 2000, 2552, 1010, 1000, 13619, 6968, 4492, 17994, 2409, 1037, 6887, 7911, 26807, 1997, 12060, 2040, 2018, 9240, 1999, 1996, 2103, 1997, 20996, 29473, 1011, 2006, 1011, 2123, 1010, 2379, 1996, 8772, 3675, 2007, 5924, 1998, 2055, 6352, 2661, 2148, 1997, 4924, 1012, 102, 101, 2002, 2106, 2025, 20648, 2054, 4506, 2002, 2001, 5327, 2005, 2021, 2081, 3154, 3183, 2002, 2001, 2559, 2000, 1012, 102, 101, 1000, 4209, 1996, 2839

In [5]:
#Read and process cnn/dm data
root_path = './dataset/cnn_dailymail/train.csv'

#only read first 100 rows for quick development
train_data = pd.read_csv(root_path, nrows=100)

train_data = train_data.drop(['id'], axis=1)
train_data = train_data.reset_index(drop=True)

train_data.head()

Unnamed: 0,article,highlights
0,By . Associated Press . PUBLISHED: . 14:11 EST...,"Bishop John Folda, of North Dakota, is taking ..."
1,(CNN) -- Ralph Mata was an internal affairs li...,Criminal complaint: Cop used his role to help ...
2,A drunk driver who killed a young woman in a h...,"Craig Eccleston-Todd, 27, had drunk at least t..."
3,(CNN) -- With a breezy sweep of his pen Presid...,Nina dos Santos says Europe must be ready to a...
4,Fleetwood are the only team still to have a 10...,Fleetwood top of League One after 2-0 win at S...


### Tokenize each article and highlights and build vocab

1. Clean
2. Expand contractions
3. Remove stopwords
 

In [21]:
#Trying to understand greedy selection using ngrams
# https://github.com/nlpyang/BertSum/blob/05f8c634197d0ed1be8157d71f29aa7765abdd2a/src/prepro/data_builder.py#L43
def cal_rouge(evaluated_ngrams, reference_ngrams):
    reference_count = len(reference_ngrams)
    evaluated_count = len(evaluated_ngrams)

    overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
    overlapping_count = len(overlapping_ngrams)

    if evaluated_count == 0:
        precision = 0.0
    else:
        precision = overlapping_count / evaluated_count

    if reference_count == 0:
        recall = 0.0
    else:
        recall = overlapping_count / reference_count

    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}

def _get_ngrams(n, text):
    """Calcualtes n-grams.
    Args:
      n: which n-grams to calculate
      text: An array of tokens
    Returns:
      A set of n-grams
    """
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set

def _get_word_ngrams(n, sentences):
    """Calculates word n-grams for multiple sentences.
    """
    assert len(sentences) > 0
    assert n > 0

    # words = _split_into_words(sentences)

    words = sum(['something here', 'new sent', 'setn2'], [])
    # words = [w for w in words if w not in stopwords]
    return _get_ngrams(n, words)

def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
    def _rouge_clean(s):
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    abstract = sum(abstract_sent_list, [])
    print(abstract)
    abstract = _rouge_clean(' '.join(abstract)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]

    print('sents = ', sents)
    print('abstract = ', sents)

    return

    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    selected = []
    for s in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):
            if (i in selected):
                continue
            c = selected + [i]
            candidates_1 = [evaluated_1grams[idx] for idx in c]
            candidates_1 = set.union(*map(set, candidates_1))
            candidates_2 = [evaluated_2grams[idx] for idx in c]
            candidates_2 = set.union(*map(set, candidates_2))
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if (cur_id == -1):
            return selected
        selected.append(cur_id)
        max_rouge = cur_max_rouge

    return sorted(selected)

sent = [' bishop fargo catholic diocese north dakota exposed potentially hundreds church members fargo grand forks jamestown hepatitis virus late september early october', ' state health department issued advisory exposure anyone attended five churches took communion', ' bishop john folda pictured fargo catholic diocese north dakota exposed potentially hundreds church members fargo grand forks jamestown hepatitis ', ' state immunization program manager molly howell says risk low officials feel important alert people possible exposure', ' diocese announced monday bishop john folda taking time diagnosed hepatitis a', ' diocese says contracted infection contaminated food attending conference newly ordained bishops italy last month', ' symptoms hepatitis include fever tiredness loss appetite nausea abdominal discomfort', ' fargo catholic diocese north dakota pictured bishop located ']
labels = ['bishop john folda of north dakota is taking time off after being diagnosed ', ' he contracted the infection through contaminated food in italy ', ' church members in fargo grand forks and jamestown could have been exposed ']

for i in range(len(sent)):
    sent[i] = sent[i].split(' ')

for i in range(len(labels)):
    labels[i] = labels[i].split(' ')
# 
# print(sent)
# print(labels)

print(greedy_selection(sent, labels, 3))

['bishop', 'john', 'folda', 'of', 'north', 'dakota', 'is', 'taking', 'time', 'off', 'after', 'being', 'diagnosed', '', '', 'he', 'contracted', 'the', 'infection', 'through', 'contaminated', 'food', 'in', 'italy', '', '', 'church', 'members', 'in', 'fargo', 'grand', 'forks', 'and', 'jamestown', 'could', 'have', 'been', 'exposed', '']
sents =  [['bishop', 'fargo', 'catholic', 'diocese', 'north', 'dakota', 'exposed', 'potentially', 'hundreds', 'church', 'members', 'fargo', 'grand', 'forks', 'jamestown', 'hepatitis', 'virus', 'late', 'september', 'early', 'october'], ['state', 'health', 'department', 'issued', 'advisory', 'exposure', 'anyone', 'attended', 'five', 'churches', 'took', 'communion'], ['bishop', 'john', 'folda', 'pictured', 'fargo', 'catholic', 'diocese', 'north', 'dakota', 'exposed', 'potentially', 'hundreds', 'church', 'members', 'fargo', 'grand', 'forks', 'jamestown', 'hepatitis'], ['state', 'immunization', 'program', 'manager', 'molly', 'howell', 'says', 'risk', 'low', 'off

TypeError: can only concatenate list (not "str") to list

In [42]:
#Source: https://www.kaggle.com/code/mohamedaref000/seq2seq-enc-dec
def clean_text(text, remove_stopwords=True):
    '''
    @args:
        text: string, raw text
        remove_stopwords: boolean, default = False
    @return
        []: list of sentences

    '''
    text = text.lower()
    text = text.split()
    tmp = []
    for word in text:
        if word in contractions:
            tmp.append(contractions[word])
        else:
            tmp.append(word)
    text = ' '.join(tmp)
    
    text = re.sub(r'https?:\/\/.*[\r\n]*', '', text, flags=re.MULTILINE)
    text = re.sub(r'\<a href', ' ', text)
    text = re.sub(r'&amp;', '', text) 
    #“’‘–»
    #NOTE that we're keep dot to seperate sentences for BERTSUM
    text = re.sub(r'[_"\-;%()|+&=*%,!?:#$@\[\]/”]', '', text)
    text = re.sub(r'<br />', ' ', text)
    text = re.sub(r'\'', ' ', text)
    #NOTE: remove numbers
    text = re.sub(r'\d+', '', text)


    if remove_stopwords:
        text = text.split()
        stops = set(stopwords.words('english'))
        text = [w for w in text if w not in stops]
        text = ' '.join(text)
    return text

#Process text into the format BERTSUM requires, read the paper
# NOTE that min_len filter sentences with fewer words than it, set to 10 for now but maybe 
# there's a better number
def process_text(text, min_len = 5, remove_stopwords = True):
    '''
    @args
        text: string, raw text
        min_len: int, default = 5, if sentence has fewer words than min_len, ignore it
    @returns
        sentences: [], list of sentences 
    '''
    cleaned = clean_text(text, remove_stopwords)
    sentences = cleaned.split('.')
    sentences = list(filter(lambda x: len(x.split()) > min_len, sentences))
    #strip the spaces
    for i in sentences:
        i = i.strip()

    return sentences


In [43]:
cleaned_highlights = []
cleaned_articles = []

for index, row in train_data.iterrows():
    cleaned_articles.append(process_text(row['article']))
    cleaned_highlights.append(process_text(row['highlights'], remove_stopwords=False))

print('sample data after processed')
print(cleaned_articles[0])
print('original article: ', train_data.article[0])
print('highlights')
print(cleaned_highlights[0])
print('original highlights: ', train_data.highlights[0])

sample data after processed
[' bishop fargo catholic diocese north dakota exposed potentially hundreds church members fargo grand forks jamestown hepatitis virus late september early october', ' state health department issued advisory exposure anyone attended five churches took communion', ' bishop john folda pictured fargo catholic diocese north dakota exposed potentially hundreds church members fargo grand forks jamestown hepatitis ', ' state immunization program manager molly howell says risk low officials feel important alert people possible exposure', ' diocese announced monday bishop john folda taking time diagnosed hepatitis a', ' diocese says contracted infection contaminated food attending conference newly ordained bishops italy last month', ' symptoms hepatitis include fever tiredness loss appetite nausea abdominal discomfort', ' fargo catholic diocese north dakota pictured bishop located ']
original article:  By . Associated Press . PUBLISHED: . 14:11 EST, 25 October 2013 . 

### Build vocab and write it to vocab file

In [44]:
def _load_data(articles):
    tokenizer = torchtext.data.get_tokenizer('spacy')
    tokens = []

    for article in articles:
        tokens += tokenizer(' '.join(article))
    
    #NOTE word freq can affect performance down stream. Put attention to this param
    vocab = Vocab(tokens, 3, reserved_tokens=['<cls>', '<sep>'])
    return tokenizer, vocab

def batchify(articles, highlights, max_len = 512):

    tokenizer, vocab = _load_data(articles)

    def _save_vocab(vocab):
        output_root = './output'
        processed_root = os.path.join(output_root, 'processed')
        if os.path.exists(processed_root) == False: 
            os.makedirs(processed_root)
        vocab_path = os.path.join(processed_root, 'vocab.txt')
        vocab.write_to(vocab_path)

    def _get_src_segments_labels(vocab, article: list):
        assert type(article) == list, 'article must be list of string'
        src = []
        segs = []
        labels =[]
        clss = []

        for sentence in article:
            sent_tokens = vocab[sentence]
            src.extend(sent_tokens)

In [54]:
def test(a: list):
    assert type(a) == list, 'not correct type'
    print('ok')


<class 'list'>
ok


In [45]:
vocab = _load_data(cleaned_articles)

print('vocab len = ', len(vocab))
sample_sentence = 'bishop fargo catholic'
print(vocab[sample_sentence.split(' ')])

output_root = './output'
processed_root = os.path.join(output_root, 'processed')

if os.path.exists(processed_root) == False: 
    os.makedirs(processed_root)

vocab_path = os.path.join(processed_root, 'vocab.txt')

vocab.write_to(vocab_path)

vocab len =  3170
[310, 1039, 453]


In [11]:
class CNNDailyMailDataset(Dataset):
    def __init__(self, path, type = 'train'):
        self.path = path
        self.type = type

        assert type in ['train', 'valid', 'test'], "dataset type must be train, valid, or test"

        self.src = []
        self.labels = []
        self.segments = []
        self.cls = []
        self.src_txt = []
        self.tgt_txt = []

        self._load_cnndm()
        #Load all data into list
    
    def _truncate_or_pad(self, data, max_len):
        '''
        @args
            data: list of all data
            max_len: max_length for each data
        @return
        '''

        src_list, segments_list, labels_list, cls_list, src_txt_list, tgt_txt_list = [], [], [], [], [], []
        for src, labels, segs, clsss, src_txt, tgt_txt in data:
            src_list.append(torch.tensor(src ))



    def clear_data(self):
        self.src.clear()
        self.labels.clear()
        self.segments.clear()
        self.cls.clear()
        self.src_txt.clear()
        self.tgt_txt.clear()

    def _load_cnndm(self):
        assert os.path.exists(self.path), "Path to CNN/DailyMail dataset can't be found"

        self.clear_data()

        self.src = []
        self.labels = []
        self.segments = []
        self.cls = []
        self.src_txt = []
        self.tgt_txt = []

        if self.type == 'train':
            chunks = glob.glob(f'{self.path}/cnndm.train.*.bert.pt')
        elif self.type == 'test':
            chunks = glob.glob(f'{self.path}/cnndm.test.*.bert.pt')
        else:
            chunks = glob.glob(f'{self.path}/cnndm.valid.*.bert.pt')

        tmp = []
        for chunk in chunks:
            d = torch.load(chunk)
            tmp.append(d)
            #Debug: work on 1 chunk first, then load all of them
            break
    
        #truncate and pad if needed
        (self.src, 
         self.labels,
         self.segments,
         self.cls, 
         self.src_txt, 
         self.tgt_txt) = self._truncate_or_pad(tmp, 512)

        print(f'loaded {len(self.data)}')
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]


In [12]:
train_set = CNNDailyMailDataset(root_path, 'train')


loaded 2001


In [15]:
for key, value in train_set[0].items():
    print(f'key={key}, len={len(value)}, value={value}')

key=src, len=512, value=[101, 2002, 1005, 1055, 2006, 1996, 2648, 2559, 1999, 1010, 1998, 1996, 3193, 1997, 13489, 13619, 6968, 4492, 17994, 3849, 2000, 2562, 2893, 11737, 5017, 1012, 102, 101, 1037, 2739, 3034, 5958, 2013, 2010, 2047, 7728, 1999, 8252, 3607, 2104, 9363, 5596, 2074, 2129, 11737, 2010, 16746, 3711, 2000, 2022, 1010, 2004, 1996, 28368, 5969, 2343, 10865, 2008, 2010, 3677, 1011, 1011, 1998, 4022, 27398, 1011, 1011, 2001, 2025, 2105, 1998, 2018, 2589, 2210, 2000, 2191, 2010, 2994, 2062, 6625, 1012, 102, 101, 1000, 1045, 5136, 2008, 3607, 2442, 1998, 2038, 2000, 2552, 1010, 1000, 13619, 6968, 4492, 17994, 2409, 1037, 6887, 7911, 26807, 1997, 12060, 2040, 2018, 9240, 1999, 1996, 2103, 1997, 20996, 29473, 1011, 2006, 1011, 2123, 1010, 2379, 1996, 8772, 3675, 2007, 5924, 1998, 2055, 6352, 2661, 2148, 1997, 4924, 1012, 102, 101, 2002, 2106, 2025, 20648, 2054, 4506, 2002, 2001, 5327, 2005, 2021, 2081, 3154, 3183, 2002, 2001, 2559, 2000, 1012, 102, 101, 1000, 4209, 1996, 2839, 19

In [9]:
batch_size = 64
train_iter = DataLoader(train_set, batch_size,shuffle=True)

In [10]:
train_feature = next(iter(train_iter))
print(train_feature)

RuntimeError: each element in list of batch should be of equal size