In [1]:
# !pip3 install malaya

import malaya
import re
from malaya.texts._text_functions import split_into_sentences
from malaya.texts import _regex

# Underscore-prefixed (non-public) tokenizer taken from malaya's preprocessing
# module; `preprocessing` below calls it on each cleaned string.
tokenizer = malaya.preprocessing._tokenizer
# Alias for the sentence splitter imported above; `split_story` iterates over
# its result to obtain the story's sentences.
splitter = split_into_sentences

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
import glob

# glob.glob returns paths in arbitrary, filesystem-dependent order. Sort them so
# that stories[0], stories[1] and everything derived from the list order are
# the same on every run and on every machine.
stories = sorted(glob.glob('cnn/stories/*.story'))
len(stories)

92579

In [3]:
def is_number_regex(s):
    """Return True when `s` looks like a plain number.

    A string counts as a number when it is either a decimal of the form
    ``<digits>.<digits>`` or consists solely of digit characters
    (as decided by ``str.isdigit``).

    Args:
        s: the token to test.

    Returns:
        bool: True for numeric-looking tokens, False otherwise.
    """
    # Raw string literal: in a normal string "\d" and "\." are invalid escape
    # sequences, which newer Python versions warn about.
    if re.match(r"^\d+?\.\d+?$", s) is not None:
        return True
    return s.isdigit()

def preprocessing(string):
    """Lower-case, clean and tokenize `string`, masking numbers, money and dates.

    Every run of characters other than quotes, ASCII letters, digits, space
    and ``- ( ) , . $`` is replaced by a single space before tokenizing.

    Args:
        string: raw text of one sentence or highlight.

    Returns:
        list: the tokens, where a token accepted by `is_number_regex` becomes
        ``'<NUM>'``, one matching ``_regex._money`` becomes ``'<MONEY>'`` and
        one matching ``_regex._date`` becomes ``'<DATE>'``. The checks are
        applied in that order; the first hit wins.
    """
    # Raw triple-quoted literal: the pattern value is unchanged, but "\-" is no
    # longer an invalid escape sequence inside a normal string literal.
    string = re.sub(r'''[^'"A-Za-z\-(),.$0-9 ]+''', ' ', string.lower())
    tokens = []
    for w in tokenizer(string):
        if is_number_regex(w):
            tokens.append('<NUM>')
        elif re.match(_regex._money, w):
            tokens.append('<MONEY>')
        elif re.match(_regex._date, w):
            tokens.append('<DATE>')
        else:
            tokens.append(w)
    return tokens

def split_story(doc):
    """Split a raw story file into tokenized body sentences and highlights.

    Everything before the first ``@highlight`` marker is the story body; the
    remainder is cut on each ``@highlight`` marker to obtain the highlights.

    Args:
        doc: full text of one story file.

    Returns:
        tuple: ``(stories, summaries)`` where `stories` holds one token list
        per body sentence (as produced by `splitter`) and `summaries` holds
        one token list per non-empty highlight.
    """
    marker = '@highlight'
    index = doc.find(marker)
    if index == -1:
        # No marker at all: the whole document is the story. Slicing with -1
        # would silently drop the last character and turn it into a highlight.
        story, raw_highlights = doc, []
    else:
        story, raw_highlights = doc[:index], doc[index:].split(marker)

    # Strip first, then filter, so whitespace-only chunks do not become
    # empty highlights.
    highlights = [h for h in (part.strip() for part in raw_highlights) if h]

    stories = [preprocessing(s) for s in splitter(story)]
    summaries = [preprocessing(h) for h in highlights]
    return stories, summaries

In [32]:
# Limits and special tokens used when turning a story into model inputs.
min_src_nsents = 3  # not referenced by any visible cell
max_src_nsents = 20  # get_xy keeps at most this many source sentences
min_src_ntokens_per_sent = 5  # get_xy keeps only sentences with MORE tokens than this
max_src_ntokens_per_sent = 30  # get_xy truncates each kept sentence to this many tokens
min_tgt_ntokens = 5  # not referenced by any visible cell
max_tgt_ntokens = 500  # not referenced by any visible cell
sep_token = '[SEP]'  # joined between sentences (together with cls_token) in get_xy
cls_token = '[CLS]'  # get_xy records the positions of this token as cls_ids

In [5]:
# Read the first story file as a quick manual check of split_story.
with open(stories[0]) as fopen:
    story = fopen.read()
# NOTE: `story` is rebound here from the raw file text to the list of
# tokenized sentences returned by split_story.
story, highlights = split_story(story)

In [15]:
def _get_ngrams(n, text):
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set


def _get_word_ngrams(n, sentences):
    """Return the set of word n-grams over all `sentences` concatenated.

    Args:
        n: n-gram length; must be positive.
        sentences: non-empty list of token lists.

    Returns:
        set: n-gram tuples computed by `_get_ngrams` on the flattened tokens.

    Raises:
        AssertionError: if `sentences` is empty or `n` is not positive.
    """
    assert len(sentences) > 0
    assert n > 0

    # Flatten in linear time; sum(sentences, []) copies the growing list on
    # every addition and is quadratic in the number of sentences.
    words = [word for sentence in sentences for word in sentence]
    return _get_ngrams(n, words)

def cal_rouge(evaluated_ngrams, reference_ngrams):
    """Score the n-gram overlap between a candidate and a reference.

    Args:
        evaluated_ngrams: set of n-grams of the candidate text.
        reference_ngrams: collection of n-grams of the reference text.

    Returns:
        dict: ``{"f": f1, "p": precision, "r": recall}``. Precision and recall
        are 0.0 when their denominator is zero; a small epsilon in the F1
        denominator avoids division by zero.
    """
    n_reference = len(reference_ngrams)
    n_evaluated = len(evaluated_ngrams)
    n_shared = len(evaluated_ngrams.intersection(reference_ngrams))

    precision = n_shared / n_evaluated if n_evaluated != 0 else 0.0
    recall = n_shared / n_reference if n_reference != 0 else 0.0

    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}


def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
    """Greedily pick the source sentences that best cover the abstract.

    At each step the sentence whose addition gives the highest combined
    unigram + bigram F-score (from `cal_rouge`) against the abstract is added.
    Selection stops when `summary_size` sentences were chosen or when no
    remaining sentence improves the score.

    Args:
        doc_sent_list: list of token lists, one per source sentence.
        abstract_sent_list: list of token lists, one per abstract sentence.
        summary_size: maximum number of sentences to select.

    Returns:
        list: indices into `doc_sent_list` of the selected sentences, always
        in ascending order.
    """
    def _rouge_clean(s):
        # Keep only ASCII letters, digits and spaces before comparing n-grams.
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    abstract = _rouge_clean(
        ' '.join(w for sent in abstract_sent_list for w in sent)
    ).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    selected = []
    for _ in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):
            if i in selected:
                continue
            c = selected + [i]
            # Pool the n-grams of the already selected sentences plus candidate i.
            candidates_1 = set().union(*(evaluated_1grams[idx] for idx in c))
            candidates_2 = set().union(*(evaluated_2grams[idx] for idx in c))
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if cur_id == -1:
            # No remaining sentence improves the score. Sort here as well so
            # both exits return the indices in the same (ascending) order.
            return sorted(selected)
        selected.append(cur_id)
        max_rouge = cur_max_rouge

    return sorted(selected)

def get_xy(story, highlights, summary_size=3):
    """Build the model text, [CLS] positions and sentence labels for one story.

    Sentences with more than `min_src_ntokens_per_sent` tokens are kept, each
    truncated to `max_src_ntokens_per_sent` tokens, and at most
    `max_src_nsents` of them are used. `greedy_selection` then marks up to
    `summary_size` of those sentences as summary sentences.

    Args:
        story: list of token lists, one per source sentence.
        highlights: list of token lists, one per highlight.
        summary_size: maximum number of sentences labelled 1 (default 3).

    Returns:
        tuple: ``(text, cls_ids, sent_labels)`` where `text` is the sentences
        joined as ``cls_token sent sep_token cls_token sent ... sep_token``,
        `cls_ids` are the whitespace-token positions of `cls_token` in `text`,
        and `sent_labels` holds one 0/1 label per kept sentence.

        When no sentence is kept, `text` still contains one `cls_token`, so
        ``len(cls_ids)`` is 1 while `sent_labels` is empty.
    """
    src = [
        sent[:max_src_ntokens_per_sent]
        for sent in story
        if len(sent) > min_src_ntokens_per_sent
    ][:max_src_nsents]

    selected = greedy_selection(src, highlights, summary_size)

    sent_labels = [0] * len(src)
    for idx in selected:
        sent_labels[idx] = 1

    src_txt = [' '.join(sent) for sent in src]
    text = ' {} {} '.format(sep_token, cls_token).join(src_txt)
    # Use the configured tokens for the outer wrapper too, so the first
    # sentence is still found by the cls_token comparison below if the
    # token constants are ever changed.
    text = '{} {} {}'.format(cls_token, text, sep_token)
    cls_ids = [i for i, t in enumerate(text.split()) if t == cls_token]

    return text, cls_ids, sent_labels

In [7]:
import collections
import json

def build_dataset(words, n_words, atleast=1):
    """Build a vocabulary from `words` and encode the words as integer ids.

    Ids 0-3 are reserved for the special tokens PAD, GO, EOS and UNK. The
    remaining ids follow descending frequency, restricted to the `n_words`
    most common words that occur at least `atleast` times.

    Args:
        words: iterable of tokens (iterated twice, so pass a list).
        n_words: maximum number of distinct words taken from the corpus.
        atleast: minimum frequency for a word to enter the vocabulary.

    Returns:
        tuple: ``(data, count, dictionary, reversed_dictionary)`` where
        `data` is the list of ids for `words` (out-of-vocabulary words map to
        the UNK id), `count` is the list of special-token entries followed by
        ``(word, frequency)`` pairs, `dictionary` maps word -> id and
        `reversed_dictionary` maps id -> word. After the call, the UNK entry
        of `count` holds the number of tokens encoded as UNK.
    """
    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]
    unk_position = 3
    counter = collections.Counter(words).most_common(n_words)
    count.extend(pair for pair in counter if pair[1] >= atleast)

    dictionary = {}
    for word, _ in count:
        # setdefault keeps the reserved id when a corpus word happens to be
        # spelled like a special token, instead of re-mapping that token.
        dictionary.setdefault(word, len(dictionary))
    unk_id = dictionary['UNK']

    data = []
    unk_count = 0
    for word in words:
        # Unknown words are encoded as UNK, not as PAD (id 0).
        index = dictionary.get(word, unk_id)
        if index == unk_id:
            unk_count += 1
        data.append(index)
    count[unk_position][1] = unk_count

    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    return data, count, dictionary, reversed_dictionary

In [8]:
# from dask import delayed
# import dask

# def process(i):
#     with open(stories[i]) as fopen:
#         story = fopen.read()
#     story, highlights = split_story(story)
#     return get_xy(story, highlights)

# train = []
# for i in range(len(stories)):
#     im = delayed(process)(i)
#     train.append(im)
    
# train = dask.compute(*train)

In [33]:
# Run the whole per-story pipeline on one file to inspect its output.
with open(stories[1]) as fopen:
    story = fopen.read()
# `story` is rebound from the raw text to the tokenized sentences.
story, highlights = split_story(story)
text, cls_ids, sent_labels = get_xy(story, highlights)

In [36]:
# Number of sentence labels, number of cls_token positions, and number of
# whitespace-separated tokens in `text` for the story processed above.
len(sent_labels), len(cls_ids), len(text.split())

(20, 20, 560)

In [37]:
from tqdm import tqdm

# Parallel accumulators: entry k of each list belongs to the same story.
texts, clss, labels = [], [], []

for i in tqdm(range(len(stories))):
    with open(stories[i]) as fopen:
        story = fopen.read()
    story, highlights = split_story(story)
    text, cls_ids, sent_labels = get_xy(story, highlights)
    # Skip stories whose number of cls_token positions does not match the
    # number of sentence labels.
    if len(cls_ids) != len(sent_labels):
        continue
    texts.append(text)
    clss.append(cls_ids)
    labels.append(sent_labels)

100%|██████████| 92579/92579 [16:41<00:00, 92.41it/s] 


In [38]:
# Flatten every processed text into one token stream to build the vocabulary.
concat = ' '.join(texts).split()
# Number of distinct tokens; used as the most_common cutoff so that no token
# is dropped by rank. (No need to materialize the set as a list first.)
vocabulary_size = len(set(concat))
# atleast = 2 keeps only tokens that occur at least twice.
_, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size, atleast = 2)
print('vocab from size: %d'%(len(dictionary)))
print('Most common words', count[4:10])

vocab from size: 118356
Most common words [('the', 1974502), (',', 1740960), ('[CLS]', 1668596), ('[SEP]', 1668596), ('.', 1284463), ('to', 844716)]


In [39]:
from sklearn.model_selection import train_test_split

# Fixed random_state so the 80/20 train/test split is identical on every run;
# without it the shuffle, and therefore the saved dataset, changes each time.
train_texts, test_texts, train_clss, test_clss, train_labels, test_labels = \
train_test_split(texts, clss, labels, test_size = 0.2, random_state = 42)

In [40]:
import pickle

# Persist the train/test splits of texts, cls positions and labels in a single
# pickle file, keyed by the variable names used above.
with open('dataset.pkl', 'wb') as fopen:
    pickle.dump({'train_texts': train_texts,
                'test_texts': test_texts,
                'train_clss': train_clss,
                'test_clss': test_clss,
                'train_labels': train_labels,
                'test_labels': test_labels}, fopen)

In [41]:
# Persist the word -> id and id -> word mappings produced by build_dataset.
with open('dictionary.pkl', 'wb') as fopen:
    pickle.dump({'dictionary': dictionary, 'rev_dictionary': rev_dictionary}, fopen)