In [9]:
# !pip3 install malaya

import malaya
import re
from malaya.texts._text_functions import split_into_sentences
from malaya.texts import _regex

# Short local alias for malaya's sentence splitter; split_story() below calls it
# to break the article body into sentences.
splitter = split_into_sentences

In [3]:
# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# !unzip uncased_L-12_H-768_A-12.zip

In [4]:
from bert import tokenization

# Vocabulary file inside the uncased_L-12_H-768_A-12 directory, i.e. the archive
# named in the commented-out wget/unzip commands earlier in this notebook.
BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'
# WordPiece tokenizer used by get_xy(); do_lower_case=True goes with the
# "uncased" vocabulary above.
tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=True)




In [5]:
import glob

# Paths of every *.story file under cnn/stories/.
# NOTE(review): glob does not promise a sorted order, so stories[0] / stories[1]
# used in later cells may refer to different files on another machine.
stories = glob.glob('cnn/stories/*.story')
# Cell output: number of story files found.
len(stories)

92579

In [19]:
def split_story(doc):
    """Split a raw ``.story`` document into tokenised sentences and highlights.

    Everything before the first ``@highlight`` marker is treated as the article
    body and is split into sentences with ``splitter``; everything from that
    marker onwards is split on ``@highlight`` into individual highlights.

    Args:
        doc: full text of one story file.

    Returns:
        A ``(stories, summaries)`` tuple. ``stories`` is a list of sentences
        and ``summaries`` a list of highlights; each item is a list of
        whitespace-separated tokens. ``summaries`` is empty when the document
        has no ``@highlight`` marker.
    """
    marker = '@highlight'
    index = doc.find(marker)
    if index == -1:
        # str.find returns -1 when the marker is absent. Slicing with -1 would
        # drop the last character of the body and turn it into a "highlight",
        # so treat the whole document as the story instead.
        story, highlight_chunks = doc, []
    else:
        story, highlight_chunks = doc[:index], doc[index:].split(marker)

    # Strip first, then filter, so whitespace-only chunks do not survive as
    # empty highlights.
    highlights = [chunk.strip() for chunk in highlight_chunks]
    highlights = [chunk for chunk in highlights if chunk]

    stories = [sentence.split() for sentence in splitter(story)]
    summaries = [highlight.split() for highlight in highlights]
    return stories, summaries

In [39]:
# Source-side limits. get_xy() keeps sentences with more than
# min_src_ntokens_per_sent tokens, truncates each kept sentence to
# max_src_ntokens_per_sent tokens and keeps at most max_src_nsents sentences.
# min_src_nsents is not referenced in the visible part of this notebook.
min_src_nsents = 3
max_src_nsents = 20
min_src_ntokens_per_sent = 5
max_src_ntokens_per_sent = 30
# Target-side limits; not referenced in the visible part of this notebook.
min_tgt_ntokens = 5
max_tgt_ntokens = 500
# Special tokens. get_xy() wraps every source sentence as [CLS] ... [SEP].
sep_token = '[SEP]'
cls_token = '[CLS]'
# pad_token and the tgt_* markers are defined here but not referenced in the
# visible part of this notebook.
pad_token = '[PAD]'
tgt_bos = '[unused0]'
tgt_eos = '[unused1]'
tgt_sent_split = '[unused2]'
# Vocabulary ids of the special tokens, looked up in the tokenizer's vocab.
# get_xy() uses sep_vid / cls_vid to locate sentence boundaries in the id list.
sep_vid = tokenizer.vocab[sep_token]
cls_vid = tokenizer.vocab[cls_token]
pad_vid = tokenizer.vocab[pad_token]

In [11]:
# Smoke test on the first story file: read it, then split it into tokenised
# sentences and highlights.
with open(stories[0]) as fopen:
    story = fopen.read()
# NOTE: `story` is rebound here from the raw text to the list of tokenised
# sentences returned by split_story().
story, highlights = split_story(story)

In [60]:
def _get_ngrams(n, text):
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set


def _get_word_ngrams(n, sentences):
    """Return the set of word n-grams over all given sentences.

    The sentences are flattened into one word list before the n-grams are
    taken, so an n-gram may span the boundary between two sentences.

    Args:
        n: n-gram length; must be positive.
        sentences: non-empty sequence of sentences, each a sequence of words.

    Returns:
        Set of n-length word tuples, as produced by ``_get_ngrams``.

    Raises:
        AssertionError: if ``sentences`` is empty or ``n`` is not positive.
    """
    assert len(sentences) > 0
    assert n > 0

    # Flatten in linear time. sum(sentences, []) builds a new, longer list for
    # every sentence, which is quadratic in the total number of words.
    words = [word for sentence in sentences for word in sentence]
    return _get_ngrams(n, words)

def cal_rouge(evaluated_ngrams, reference_ngrams):
    """Score a candidate n-gram set against a reference n-gram set.

    Args:
        evaluated_ngrams: set of n-grams from the candidate text.
        reference_ngrams: n-grams from the reference text.

    Returns:
        Dict with keys ``"f"``, ``"p"`` and ``"r"``: F1, precision (overlap /
        candidate size) and recall (overlap / reference size). Precision or
        recall is 0.0 when the corresponding set is empty.
    """
    n_reference = len(reference_ngrams)
    n_evaluated = len(evaluated_ngrams)
    n_shared = len(evaluated_ngrams.intersection(reference_ngrams))

    precision = n_shared / n_evaluated if n_evaluated != 0 else 0.0
    recall = n_shared / n_reference if n_reference != 0 else 0.0

    # The 1e-8 term keeps the denominator non-zero when both precision and
    # recall are 0.
    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}


def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
    """Greedily pick source sentences that best match the abstract.

    In each of up to ``summary_size`` rounds, the not-yet-selected sentence
    whose addition gives the highest (ROUGE-1 F1 + ROUGE-2 F1) of the selected
    set against the abstract is added. Ties go to the lowest index because the
    comparison is strict. Selection stops early once no remaining sentence
    raises the score.

    Args:
        doc_sent_list: source sentences, each a list of tokens.
        abstract_sent_list: abstract sentences, each a list of tokens.
        summary_size: maximum number of sentences to select.

    Returns:
        Indices into ``doc_sent_list`` of the selected sentences, always in
        ascending order (including when selection stops early).
    """
    def _rouge_clean(s):
        # Drop everything except ASCII letters, digits and spaces before
        # comparing n-grams.
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    # Linear-time flatten; sum(abstract_sent_list, []) is quadratic.
    abstract_words = [word for sent in abstract_sent_list for word in sent]
    abstract = _rouge_clean(' '.join(abstract_words)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    selected = []
    for _ in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):
            if i in selected:
                continue
            candidate = selected + [i]
            # Pool the n-grams of the tentative selection. set().union() builds
            # a fresh set, so the cached per-sentence sets are never mutated.
            candidates_1 = set().union(*(evaluated_1grams[idx] for idx in candidate))
            candidates_2 = set().union(*(evaluated_2grams[idx] for idx in candidate))
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if cur_id == -1:
            # No remaining sentence improves the score: stop selecting.
            break
        selected.append(cur_id)
        max_rouge = cur_max_rouge

    # Single exit point so both the early-stop and the full-size case return
    # indices in ascending order.
    return sorted(selected)

def get_xy(story, highlights, summary_size=3):
    """Turn one tokenised story into BERT inputs and sentence labels.

    Sentences with more than ``min_src_ntokens_per_sent`` tokens are kept,
    each truncated to ``max_src_ntokens_per_sent`` tokens, and at most
    ``max_src_nsents`` of them are used. Each kept sentence is tokenised with
    ``tokenizer`` and wrapped as ``[CLS] ... [SEP]``.

    Args:
        story: list of source sentences, each a list of tokens.
        highlights: list of highlight sentences, each a list of tokens.
        summary_size: maximum number of sentences labelled 1 (passed to
            ``greedy_selection``). Defaults to 3.

    Returns:
        A tuple ``(src_subtoken_idxs, cls_ids, sent_labels, segments_ids)``:
        vocabulary ids of the whole input sequence; positions of every
        ``[CLS]`` id in that sequence; a 0/1 label per kept sentence (1 for
        sentences chosen by ``greedy_selection``); and a 0/1 segment id per
        token that alternates from one sentence to the next.

        When no sentence passes the length filter, the sequence is just
        ``[CLS] [SEP]``, so ``cls_ids`` has one entry while the label list is
        empty.
    """
    kept = [sent for sent in story if len(sent) > min_src_ntokens_per_sent]
    src = [sent[:max_src_ntokens_per_sent] for sent in kept][:max_src_nsents]

    sent_labels = greedy_selection(src, highlights, summary_size)

    _sent_labels = [0] * len(src)
    for label_idx in sent_labels:
        _sent_labels[label_idx] = 1

    src_subtokens = []
    for i, sent in enumerate(src):
        pieces = tokenizer.tokenize(' '.join(sent))
        if i > 0:
            # Close the previous sentence and open the next one.
            pieces = [sep_token, cls_token] + pieces
        src_subtokens.extend(pieces)

    src_subtokens = [cls_token] + src_subtokens + [sep_token]
    src_subtoken_idxs = tokenizer.convert_tokens_to_ids(src_subtokens)

    # Segment lengths are the gaps between consecutive [SEP] positions; the
    # leading -1 makes the first segment start at index 0.
    _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == sep_vid]
    segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
    segments_ids = []
    for i, seg_len in enumerate(segs):
        # Even-numbered sentences get segment id 0, odd-numbered ones get 1.
        segments_ids += seg_len * [i % 2]
    cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == cls_vid]

    return src_subtoken_idxs, cls_ids, _sent_labels, segments_ids

In [61]:
# Smoke test of the full per-story pipeline on the second story file.
with open(stories[1]) as fopen:
    story = fopen.read()
# NOTE: `story` is rebound here from the raw text to the list of tokenised
# sentences returned by split_story().
story, highlights = split_story(story)
text, cls_ids, sent_labels, segments_ids = get_xy(story, highlights)

In [62]:
# Sanity check on the example above: the label count is expected to equal the
# [CLS] count, and the token count to equal the segment-id count.
len(sent_labels), len(cls_ids), len(text), len(segments_ids)

(20, 20, 661, 661)

In [63]:
from tqdm import tqdm

# Parallel per-story accumulators: token ids, [CLS] positions, sentence labels
# and segment ids. Index k in each list refers to the same story.
texts = []
clss = []
labels = []
segments = []

for i in tqdm(range(len(stories))):
    with open(stories[i]) as fopen:
        story = fopen.read()
    story, highlights = split_story(story)
    text, cls_ids, sent_labels, segments_ids = get_xy(story, highlights)
    # Keep a story only when every [CLS] position has a matching sentence
    # label; mismatching stories are dropped.
    if len(cls_ids) == len(sent_labels):
        texts.append(text)
        clss.append(cls_ids)
        labels.append(sent_labels)
        segments.append(segments_ids)

100%|██████████| 92579/92579 [13:52<00:00, 111.15it/s]


In [64]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the stories for testing. All four parallel lists go through
# one call so each story's fields stay aligned after shuffling.
# random_state is fixed so the split is the same on every Restart & Run All;
# without it each run produces a different train/test partition.
train_texts, test_texts, train_clss, test_clss, train_labels, test_labels, train_segments, test_segments = \
train_test_split(texts, clss, labels, segments, test_size = 0.2, random_state = 42)

In [65]:
import pickle

# Bundle the train/test split into one dict of parallel lists, then persist it.
payload = {
    'train_texts': train_texts,
    'test_texts': test_texts,
    'train_clss': train_clss,
    'test_clss': test_clss,
    'train_labels': train_labels,
    'test_labels': test_labels,
    'train_segments': train_segments,
    'test_segments': test_segments,
}
with open('dataset-bert.pkl', 'wb') as fopen:
    pickle.dump(payload, fopen)