Copyright (c) Microsoft Corporation. 

Licensed under the MIT License.


## Abstractive Summarization Generation Using Pretrained BertAbs Model 


### Summary

This notebook demonstrates how to generate abstractive summarization using HuggingFace's pretrained BertAbs model. The BertAbs algorithm is original published in [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345) and the code base was release at https://github.com/nlpyang/PreSumm. HuggingFace's transformer library has included a wrapper for the code base and most of the model is contained in https://github.com/huggingface/transformers/blob/master/examples/summarization/modeling_bertabs.py

Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.


In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

### Configuration

In [3]:
#! /usr/bin/python3
from collections import namedtuple
import logging
import os
import shutil
import sys
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
import scrapbook as sb

import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.common.pytorch_utils import get_device
from utils_nlp.eval.evaluate_summarization import get_rouge
from utils_nlp.dataset.cnndm import CNNDMAbsSumDataset

from transformers import BertTokenizer
## BertAbs import
sys.path.insert(0, "./transformers/examples/summarization")
from modeling_bertabs import BertAbs, build_predictor

from utils_summarization import (
    build_mask,
    compute_token_type_ids,
    encode_for_summarization,
    fit_to_block_size,
)
from run_summarization import format_rouge_scores, format_summary

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
I0116 04:44:45.409288 140119874643776 file_utils.py:35] PyTorch version 1.3.0 available.


### Data Preprocessing

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)


I0116 04:44:46.520054 140119874643776 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/daden/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [5]:
DATA_PATH = '/tmp/tmpvifv52a8'

In [6]:
QUICK_RUN = False
# the data path used to save the downloaded data file
# DATA_PATH = TemporaryDirectory().name
# The number of lines at the head of data file used for preprocessing. -1 means all the lines.
TOP_N = 100
CHUNK_SIZE=200
if not QUICK_RUN:
    TOP_N = -1
    CHUNK_SIZE = 2000

In [7]:
from torch.utils.data import Dataset
class AbsSumDataset(Dataset):
    def __init__(self, source, target=None):
        self.source = source
        self.target = target
    def __len__(self):
        return len(self.source)
    def __getitem__(self, idx):
        return self.source[idx], self.target[idx]

 Download CNN/DM dataset

In [8]:
train_dataset, test_dataset = CNNDMAbsSumDataset(top_n=TOP_N, local_cache_path=DATA_PATH)
data = list(test_dataset.get_source()), list(test_dataset.get_target())
test_sum_dataset = AbsSumDataset(data[0], data[1])
len(test_sum_dataset)

I0116 04:44:47.887421 140119874643776 utils.py:173] Opening tar file /tmp/tmpvifv52a8/cnndm.tar.gz.
I0116 04:44:47.888628 140119874643776 utils.py:181] /tmp/tmpvifv52a8/test.txt.src already extracted.
I0116 04:44:48.178355 140119874643776 utils.py:181] /tmp/tmpvifv52a8/test.txt.tgt.tagged already extracted.
I0116 04:44:48.204983 140119874643776 utils.py:181] /tmp/tmpvifv52a8/train.txt.src already extracted.
I0116 04:44:55.672668 140119874643776 utils.py:181] /tmp/tmpvifv52a8/train.txt.tgt.tagged already extracted.
I0116 04:44:56.290420 140119874643776 utils.py:181] /tmp/tmpvifv52a8/val.txt.src already extracted.
I0116 04:44:56.624392 140119874643776 utils.py:181] /tmp/tmpvifv52a8/val.txt.tgt.tagged already extracted.


11490

In [9]:
#TrainBatch = namedtuple("Batch", [ "batch_size", "src", "segs", "mask_src", "tgt", "tgt_segs", "mask_tgt", "tgt_str"])
TestBatch = namedtuple("Batch", [ "batch_size", "src", "segs", "mask_src", "tgt_str"])

In [10]:
### added max_len argument
def encode_for_summarization(story_lines, summary_lines, tokenizer, max_len=512):
    """ Encode the story and summary lines, and join them
    as specified in [1] by using `[SEP] [CLS]` tokens to separate
    sentences.
    """
    story_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in story_lines]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in summary_lines]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids

In [11]:
def collate(data, tokenizer, block_size, device):
    """ Collate formats the data passed to the data loader.
    In particular we tokenize the data batch after batch to avoid keeping them
    all in memory. We output the data as a namedtuple to fit the original BertAbs's
    API.
    """
    data = [x for x in data if not len(x[1]) == 0]  # remove empty_files
    #print(data)
    #names = [name for name, _, _ in data]
    # summaries = [" ".join(summary_list) for _, _, summary_list in data]
    summaries = [" ".join(summary_list) for _, summary_list in data]
  

    encoded_text = [encode_for_summarization(story, summary, tokenizer, block_size) for story, summary in data]
    
    
    #""""""
    encoded_stories = torch.tensor(
        [fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
    )
    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
    #"""


    batch = TestBatch(
        #document_names=None,
        batch_size=len(encoded_stories),
        src=encoded_stories.to(device),
        segs=encoder_token_type_ids.to(device),
        mask_src=encoder_mask.to(device),
        tgt_str=summaries,
    )
    return batch

In [12]:
def build_data_iterator(dataset, tokenizer, batch_size=16, device='cuda', max_len=512):
   
    sampler = SequentialSampler(dataset)

    def collate_fn(data):
        return collate(data, tokenizer, block_size=max_len, device=device)

    iterator = DataLoader(dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn,)

    return iterator

In [13]:
from utils_nlp.common.pytorch_utils import get_device
#device, num_gpus = get_device(num_gpus=1, local_rank=-1)
device="cuda:0"
torch.cuda.set_device(device)
data_iterator = build_data_iterator(test_sum_dataset, tokenizer, batch_size=64, device=device)

In [14]:
device

'cuda:0'

### Create Predictor

In [15]:
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
model.to(device)
model.eval()

symbols = {
    "BOS": tokenizer.vocab["[unused0]"],
    "EOS": tokenizer.vocab["[unused1]"],
    "PAD": tokenizer.vocab["[PAD]"],
}
from utils_nlp.models.transformers.extractive_summarization import Bunch
args = Bunch({"block_trigram":True, "alpha": 0.95, "beam_size": 5, "min_length": 50, "max_length": 200})
predictor = build_predictor(args, tokenizer, symbols, model)

I0116 04:44:58.548022 140119874643776 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json from cache at /home/daden/.cache/torch/transformers/7ebb4ac81007d10b400cb6c2968d4c8f1275a3e0cc3bab7f20f81913198b542c.df616398f4c84def6fca83d755543b01cb445db4ddd218d3efeded8ded68332f
I0116 04:44:58.548912 140119874643776 configuration_utils.py:199] Model config {
  "dec_dropout": 0.2,
  "dec_ff_size": 2048,
  "dec_heads": 8,
  "dec_hidden_size": 768,
  "dec_layers": 6,
  "enc_dropout": 0.2,
  "enc_ff_size": 512,
  "enc_heads": 8,
  "enc_hidden_size": 512,
  "enc_layers": 6,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "max_pos": 512,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pruned_

In [16]:
reference_summaries = []
generated_summaries = []

In [18]:
for batch in tqdm(data_iterator):
    
    batch_data = predictor.translate_batch(batch)
    translations = predictor.from_batch(batch_data)
    summaries = [format_summary(t) for t in translations]
    
    reference_summaries += batch.tgt_str
    generated_summaries += summaries


  0%|          | 0/180 [00:00<?, ?it/s][A
  1%|          | 1/180 [00:01<03:43,  1.25s/it][A
  1%|          | 2/180 [00:02<03:31,  1.19s/it][A
  2%|▏         | 3/180 [00:03<03:31,  1.20s/it][A
  2%|▏         | 4/180 [00:04<03:26,  1.18s/it][A
  3%|▎         | 5/180 [00:05<03:21,  1.15s/it][A
  3%|▎         | 6/180 [00:06<03:20,  1.15s/it][A
  4%|▍         | 7/180 [00:08<03:35,  1.25s/it][A
  4%|▍         | 8/180 [00:09<03:36,  1.26s/it][A
  5%|▌         | 9/180 [00:10<03:34,  1.25s/it][A
  6%|▌         | 10/180 [00:12<03:36,  1.28s/it][A
  6%|▌         | 11/180 [00:13<03:34,  1.27s/it][A
  7%|▋         | 12/180 [00:14<03:34,  1.28s/it][A
  7%|▋         | 13/180 [00:16<03:43,  1.34s/it][A
  8%|▊         | 14/180 [00:17<03:47,  1.37s/it][A
  8%|▊         | 15/180 [00:19<03:51,  1.40s/it][A
  9%|▉         | 16/180 [00:20<03:44,  1.37s/it][A
  9%|▉         | 17/180 [00:21<03:36,  1.33s/it][A
 10%|█         | 18/180 [00:22<03:20,  1.24s/it][A
 11%|█         | 19/180 [00:2

 87%|████████▋ | 156/180 [03:29<00:35,  1.48s/it][A
 87%|████████▋ | 157/180 [03:31<00:34,  1.50s/it][A
 88%|████████▊ | 158/180 [03:32<00:33,  1.54s/it][A
 88%|████████▊ | 159/180 [03:34<00:32,  1.53s/it][A
 89%|████████▉ | 160/180 [03:35<00:31,  1.55s/it][A
 89%|████████▉ | 161/180 [03:37<00:29,  1.57s/it][A
 90%|█████████ | 162/180 [03:38<00:27,  1.56s/it][A
 91%|█████████ | 163/180 [03:40<00:26,  1.57s/it][A
 91%|█████████ | 164/180 [03:42<00:24,  1.55s/it][A
 92%|█████████▏| 165/180 [03:43<00:23,  1.57s/it][A
 92%|█████████▏| 166/180 [03:45<00:22,  1.61s/it][A
 93%|█████████▎| 167/180 [03:46<00:20,  1.58s/it][A
 93%|█████████▎| 168/180 [03:48<00:18,  1.56s/it][A
 94%|█████████▍| 169/180 [03:49<00:16,  1.52s/it][A
 94%|█████████▍| 170/180 [03:51<00:15,  1.51s/it][A
 95%|█████████▌| 171/180 [03:52<00:13,  1.49s/it][A
 96%|█████████▌| 172/180 [03:54<00:11,  1.47s/it][A
 96%|█████████▌| 173/180 [03:55<00:10,  1.46s/it][A
 97%|█████████▋| 174/180 [03:56<00:08,  1.43s/

In [20]:
reference_summaries = []
for i in data[1]:
    reference_summaries.append(i[0].replace('<q>', ''))

In [None]:
def _write_list_to_file(list_items, filename):
    with open(filename, "w") as filehandle:
        # for cnt, line in enumerate(filehandle):
        for item in list_items:
            filehandle.write("%s\n" % item)
"""
_write_list_to_file(generated_summaries, "./generated_summaries")
generated_summaries = []
with open("./generated_summaries", "r") as filehandle:
    for cnt, line in enumerate(filehandle):
        generated_summaries.append(line)
"""

In [26]:
generated_summaries[0]

'prosecutor brice robin : " so far no videos were used in the crash investigation ". robin \'s comments follow claims by two magazines , german daily bild and french paris match. all 150 on board germanwings flight 9525 were killed\n'

In [21]:
with open("./generated_summaries", "r") as filehandle:
    for cnt, line in enumerate(filehandle):
        generated_summaries.append(line)

In [23]:
assert len(generated_summaries)==len(reference_summaries)

In [24]:
import rouge
import nltk

nltk.download("punkt")
rouge_evaluator = rouge.Rouge(
    metrics=["rouge-n", "rouge-l"],
    max_n=2,
    limit_length=True,
    length_limit=args.beam_size,
    length_limit_type="words",
    apply_avg=True,
    apply_best=False,
    alpha=0.5,  # Default F1_score
    weight_factor=1.2,
    stemming=True,
)

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [25]:
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
str_scores = format_rouge_scores(scores)
print(str_scores)



****** ROUGE SCORES ******

** ROUGE 1
F1        >> 0.355
Precision >> 0.365
Recall    >> 0.351

** ROUGE 2
F1        >> 0.242
Precision >> 0.251
Recall    >> 0.238

** ROUGE L
F1        >> 0.385
Precision >> 0.394
Recall    >> 0.381
