Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

## Extractive Summarization on CNN/DM Dataset using Transformer Version of BertSum


### Summary

This notebook demonstrates how to fine tune Transformers for extractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

BertSum refers to  [Fine-tune BERT for Extractive Summarization (https://arxiv.org/pdf/1903.10318.pdf) with [published example](https://github.com/nlpyang/BertSum/). And the Transformer version of Bertsum refers to our modification of BertSum and the source code can be accessed at (https://github.com/daden-ms/BertSum/). 

Extractive summarization are usually used in document summarization where each input document consists of mutiple sentences. The preprocessing of the input training data involves assigning label 0 or 1 to the document sentences based on the give summary. The summarization problem is also simplfied to classifying whether each document sentence should be included in the summary. 

The figure below illustrates how BERTSum can be fine tuned for extractive summarization task. Each sentence is inserted with [CLS] token at the beginning and  [SEP] at the end. Interval segment embedding and positional embedding are added upon the token embedding before input the BERT model. The [CLS] token representation is used as sentence embedding and only the [CLS] tokens are used as input for the summarization model. The summarization layer predicts whether the probability of each each sentence token should be included in the summary or not. Techniques like trigram blocking can be used to improve model accuarcy.   

<img src="https://nlpbp.blob.core.windows.net/images/BertSum.PNG">


### Before You Start

The running time shown in this notebook is on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

On a machine with 1 NVIDIA Tesla V100 GPUs, 16GB GPU memory configuration,
- for data preprocessing, it takes around 10 minutes the data preprocessing for quick run. Otherwise it takes ~2 hours to finish the data preprocessing. This time estimation assumes the chosen transformer model is "distilbert-base-uncased" and the sentence selection method is "greedy", which is the default. The preprocessing time can be significantly longer if the sentence selection method is "combination", which can achieve better model performance.

- for model fine tuning, it takes around 30 minutes for quick run. Otherwise, it takes around ~3 hours to finish. This estimation assume the chosen encoder method is "transformer". The model fine tuning time can be shorter if other encoder method is chosen, which may result in worse model performance. 


In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = True
## Set USE_PREPROCSSED_DATA = True to skip the data preprocessing
USE_PREPROCSSED_DATA = True

### Configuration

Before we start the notebook, we should set the environment variable to make sure you can access the GPUs on your machine

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2

In [4]:
import os
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.common.pytorch_utils import get_device
from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarization, ExtSumProcessedData
from utils_nlp.eval.evaluate_summarization import get_rouge
from utils_nlp.models.transformers.extractive_summarization import (
    get_cycled_dataset,
    get_dataloader,
    get_sequential_dataloader,
    ExtractiveSummarizer,
    ExtSumProcessor,
)

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
I1219 04:08:26.066413 139675858003776 file_utils.py:40] PyTorch version 1.2.0 available.


### Configuration: choose the transformer model to be used

In [15]:
# Transformer model being used
MODEL_NAME = "distilbert-base-uncased"

Also, we need to install the dependencies for pyrouge.

# dependencies for ROUGE-1.5.5.pl
!sudo apt-get update
!sudo apt-get install expat
!sudo apt-get install libexpat-dev -y

Run the following command in your terminal to install pre-requiste for using pyrouge.
1. sudo cpan install XML::Parser
1. sudo cpan install XML::Parser::PerlSAX
1. sudo cpan install XML::DOM

Download ROUGE-1.5.5 from https://github.com/andersjo/pyrouge/tree/master/tools/ROUGE-1.5.5.
Run the following command in your terminal.
* pyrouge_set_rouge_path $ABSOLUTE_DIRECTORY_TO_ROUGE-1.5.5.pl

### Data Preprossing

The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation and ~11K test dataset.  You can choose the [Option 1] below preprocess the data or [Option 2] to use the preprocessed version at [BERTSum published example](https://github.com/nlpyang/BertSum/). You don't need to manually download any of these two data sets as the code below will handle this part.  Since it takes up to 28 hours to preprocess the training data  to run on 10  Intel(R) Xeon(R) CPU E5-2690 v3 @ 2.60GHz, we suggest you continue with set as True first and experiment with data preprocessing  with QUICKRUN set as True.

##### Details of Data Preprocessing

The purpose of preprocessing is to process the input articles to the format that BertSum takes.  Functions defined specific in harvardnlp_cnndm_preprocess function are unique to CNN/DM dataset that's processed by harvardnlp. However, it provides a skeleton of how to preprocessing data into the format that BertSum takes. Assuming you have all articles and target summery each in a file, line-breaker seperated, the steps to preprocess the data are:
1. sentence tokenization
2. word tokenization
3. **label** the sentences in the article with 1 meaning the sentence is selected and 0 meaning the sentence is not selected. The options for the selection algorithms are "greedy" and "combination"
3. convert each example to  the desired format for extractive summarization
    - filter the sentences in the example based on the min_src_ntokens argument. If the lefted total sentence number is less than min_nsents, the example is discarded.
    - truncate the sentences in the example if the length is greater than max_src_ntokens
    - truncate the sentences in the example and the labels if the totle number of sentences is greater than max_nsents
    - [CLS] and [SEP] are inserted before and after each sentence
    - wordPiece tokenization
    - truncate the example to 512 tokens
    - convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary.
    - segment ids are generated
    - [CLS] token positions are logged
    - [CLS] token labels are truncated if it's greater than 512, which is the maximum input length that can be taken by the BERT model.
    
    
Note that the original BERTSum paper use Stanford CoreNLP for data proprocessing, here we'll first how to use NLTK version, and then we also provide instruction of how to set up Stanford NLP and code examples of how to use Standford CoreNLP. 

##### [Option 1] Preprocess  data
The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.

In [6]:
# the data path used to save the downloaded data file
DATA_PATH = TemporaryDirectory().name
# The number of lines at the head of data file used for preprocessing. -1 means all the lines.
TOP_N = -1
if QUICK_RUN:
    TOP_N = 10000

In [7]:
train_dataset, test_dataset = CNNDMSummarization(top_n=TOP_N, local_cache_path=DATA_PATH)

100%|██████████| 489k/489k [00:07<00:00, 62.3kKB/s] 
I1217 18:37:26.811193 139802557667136 utils.py:173] Opening tar file /tmp/tmpy397314q/cnndm.tar.gz.


Preprocess the data and save the data to disk.

In [8]:
processor = ExtSumProcessor(model_name=MODEL_NAME)
ext_sum_train = processor.preprocess(train_dataset, train_dataset.get_target(), oracle_mode="greedy")
ext_sum_test = processor.preprocess(test_dataset, test_dataset.get_target(),oracle_mode="greedy")

I1217 18:37:38.230067 139802557667136 tokenization_utils.py:379] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


{'max_nsents': 200, 'max_src_ntokens': 2000, 'min_nsents': 3, 'min_src_ntokens': 5, 'use_interval': True}


In [9]:
save_path = os.path.join(DATA_PATH, "processed")
train_files = ExtSumProcessedData.save_data(
    ext_sum_train, is_test=False, save_path=save_path, chunk_size=2000
)
test_files = ExtSumProcessedData.save_data(
    ext_sum_test, is_test=True, save_path=save_path, chunk_size=2000
)

In [10]:
train_files

['/tmp/tmpy397314q/processed/0_train',
 '/tmp/tmpy397314q/processed/1_train',
 '/tmp/tmpy397314q/processed/2_train',
 '/tmp/tmpy397314q/processed/3_train',
 '/tmp/tmpy397314q/processed/4_train']

In [11]:
test_files

['/tmp/tmpy397314q/processed/0_test',
 '/tmp/tmpy397314q/processed/1_test',
 '/tmp/tmpy397314q/processed/2_test',
 '/tmp/tmpy397314q/processed/3_test',
 '/tmp/tmpy397314q/processed/4_test']

In [12]:
train_dataset_generator, test_dataset_generator = ExtSumProcessedData().splits(root=save_path)

#### Inspect Data

In [13]:
import torch
bert_format_data = torch.load(train_files[0])
print(len(bert_format_data))
bert_format_data[0].keys()

2000


dict_keys(['src', 'labels', 'segs', 'clss', 'src_txt', 'tgt_txt'])

In [14]:
bert_format_data[0]['labels']

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

##### [Option 2] Reuse Preprocessed  data from [BERTSUM Repo](https://github.com/nlpyang/BertSum)

In [5]:
# the data path used to downloaded the preprocessed data from BERTSUM Repo.
# if you have downloaded the dataset, change the code to use that path where the dataset is.
PROCESSED_DATA_PATH = TemporaryDirectory().name
data_path = "./temp_data5/"
PROCESSED_DATA_PATH = data_path

In [6]:
if USE_PREPROCSSED_DATA:
    CNNDMBertSumProcessedData.download(local_path=PROCESSED_DATA_PATH)
    train_dataset_generator, test_dataset_generator = ExtSumProcessedData().splits(root=PROCESSED_DATA_PATH)
    

### Model training
To start model training, we need to create a instance of ExtractiveSummarizer.
#### Choose the transformer model.
Currently ExtractiveSummarizer support two models:
- distilbert-base-uncase, 
- bert-base-uncase

Potentionally, roberta-based model and xlnet can be supported but needs to be tested.
#### Choose the encoder algorithm.
There are four options:
- baseline: it used a smaller transformer model to replace the bert model and with transformer summarization layer
- classifier: it uses pretrained BERT and fine-tune BERT with **simple logistic classification** summarization layer
- transformer: it uses pretrained BERT and fine-tune BERT with **transformer** summarization layer
- RNN: it uses pretrained BERT and fine-tune BERT with **LSTM** summarization layer

In [13]:
# notebook parameters
# the cache data path during find tuning
CACHE_DIR = TemporaryDirectory().name

# batch size, unit is the number of tokens
BATCH_SIZE = 3000

# GPU used for training
NUM_GPUS = 1

# Encoder name. Options are: 1. baseline, classifier, transformer, rnn.
ENCODER = "transformer"

# Learning rate
LEARNING_RATE=2e-3

# How often the statistics reports show up in training, unit is step.
REPORT_EVERY=100
    
if QUICK_RUN:
    # total number of steps for training
    MAX_STEPS=1e4
    # number of steps for warm up
    WARMUP_STEPS=5e3
    
else:
    MAX_STEPS=5e4
    WARMUP_STEPS=5e3
 

In [16]:
summarizer = ExtractiveSummarizer(MODEL_NAME, ENCODER, CACHE_DIR)

I1219 04:09:29.497260 139675858003776 file_utils.py:319] https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json not found in cache or force_download set to True, downloading to /tmp/tmpo59o7kse
100%|██████████| 492/492 [00:00<00:00, 400776.38B/s]
I1219 04:09:29.658666 139675858003776 file_utils.py:334] copying /tmp/tmpo59o7kse to cache at /tmp/tmpi99rn6ni/a41e817d5c0743e29e86ff85edc8c257e61bc8d88e4271bb1b243b6e7614c633.1ccd1a11c9ff276830e114ea477ea2407100f4a3be7bdc45d37be9e37fa71c7e
I1219 04:09:29.662085 139675858003776 file_utils.py:338] creating metadata file for /tmp/tmpi99rn6ni/a41e817d5c0743e29e86ff85edc8c257e61bc8d88e4271bb1b243b6e7614c633.1ccd1a11c9ff276830e114ea477ea2407100f4a3be7bdc45d37be9e37fa71c7e
I1219 04:09:29.663305 139675858003776 file_utils.py:347] removing temp file /tmp/tmpo59o7kse
I1219 04:09:29.664099 139675858003776 configuration_utils.py:157] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilbert

In [19]:
# batch_size is the number of tokens in a batch
train_dataloader = get_dataloader(get_cycled_dataset(train_dataset_generator), is_labeled=True, batch_size=3000)

In [20]:
summarizer.fit(
            train_dataloader,
            num_gpus=1,
            gradient_accumulation_steps=2,
            max_steps=MAX_STEPS,
            lr=LEARNING_RATE,
            warmup_steps=WARMUP_STEPS,
            verbose=True,
            report_every=REPORT_EVERY,
            clip_grad_norm=False,
        )

loss: 50.646378, time: 22.631160, number of examples in current step: 5, step 100 out of total 10000
loss: 33.793594, time: 21.044874, number of examples in current step: 5, step 200 out of total 10000
loss: 32.788770, time: 20.708551, number of examples in current step: 5, step 300 out of total 10000
loss: 31.555956, time: 21.022777, number of examples in current step: 5, step 400 out of total 10000
loss: 31.015303, time: 20.605717, number of examples in current step: 8, step 500 out of total 10000
loss: 30.423153, time: 20.881066, number of examples in current step: 5, step 600 out of total 10000
loss: 30.845548, time: 20.734294, number of examples in current step: 5, step 700 out of total 10000
loss: 30.580163, time: 21.032496, number of examples in current step: 7, step 800 out of total 10000
loss: 31.150659, time: 20.884736, number of examples in current step: 5, step 900 out of total 10000
loss: 29.488281, time: 21.059112, number of examples in current step: 5, step 1000 out of t

loss: 5.945398, time: 20.959122, number of examples in current step: 5, step 8200 out of total 10000
loss: 7.403578, time: 20.623628, number of examples in current step: 5, step 8300 out of total 10000
loss: 7.579400, time: 21.010104, number of examples in current step: 5, step 8400 out of total 10000
loss: 6.733897, time: 20.956976, number of examples in current step: 5, step 8500 out of total 10000
loss: 7.400434, time: 21.315107, number of examples in current step: 5, step 8600 out of total 10000
loss: 8.942719, time: 20.842602, number of examples in current step: 5, step 8700 out of total 10000
loss: 8.192907, time: 21.137500, number of examples in current step: 5, step 8800 out of total 10000
loss: 5.031124, time: 20.709420, number of examples in current step: 5, step 8900 out of total 10000
loss: 4.819118, time: 21.070154, number of examples in current step: 6, step 9000 out of total 10000
loss: 4.357930, time: 20.748780, number of examples in current step: 5, step 9100 out of to

In [21]:
summarizer.save_model("extsum_modelname_{0}_quickrun_{1}.pt".format(MODEL_NAME, QUICK_RUN))

I1217 19:20:14.048710 139802557667136 extractive_summarization.py:432] Saving model checkpoint to /tmp/tmpdym52i5i/fine_tuned/extsum_modelname_distilbert-base-uncased_quickrun_True.pt


In [17]:
# for loading a previously saved model
import torch
summarizer.model = torch.load("cnndm_transformersum_distilbert-base-uncased_bertsum_processed_data.pt")



### Model Evaluation

[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluation text summarization.

In [10]:
eval_dataset=[]
for i in test_dataset_generator():
    eval_dataset.extend(i)
target = [eval_dataset[i]['tgt_txt'] for i in range(len(eval_dataset))]

In [11]:
new_eval_batch = [] 
for batch in get_sequential_dataloader(eval_dataset, is_labeled=True):
    new_eval_batch.append(batch)

In [18]:
new_eval_batch = [] 
new_eval_dataset = [] 
j = 0
for batch in get_sequential_dataloader(eval_dataset, is_labeled=True):
    new_eval_dataset.append(ExtSumProcessor.get_inputs(batch, MODEL_NAME))
    new_eval_batch.append(batch)
    break

In [81]:
new_eval_batch[0].src_str

[['turkey has blocked access to twitter and youtube after they refused a request to remove pictures of a prosecutor held during an armed siege last week .',
  "a turkish court imposed the blocks because images of the deadly siege were being shared on social media and ` deeply upset ' the wife and children of mehmet selim kiraz , the hostage who was killed .",
  "the 46-year-old turkish prosecutor died in hospital when members of the revolutionary people 's liberation party-front ( dhkp-c ) stormed a courthouse and took him hostage .",
  'the dhkp-c is considered a terrorist group by turkey , the european union and us .',
  'a turkish court has blocked access to twitter and youtube after they refused a request to remove pictures of prosecutor mehmet selim kiraz held during an armed siege last week',
  'grief : the family of mehmet selim kiraz grieve over his coffin during his funeral at eyup sultan mosque in istanbul , turkey .',
  'he died in hospital after he was taken hostage by the 

In [20]:
batch = new_eval_dataset[0]
temp = (batch['x'], batch['segs'], batch['clss'], batch['mask'], batch['mask_cls'], batch['labels'])

In [31]:
batch

{'x': tensor([[  101,  4977,  2038,  ...,  4652,  6399,   102],
         [  101, 15555, 12548,  ..., 12548,  3555,   102],
         [  101,  2865,  6180,  ...,  2036,  5580,   102],
         [  101,  7354, 19023,  ...,     0,     0,     0],
         [  101,  2009,  1005,  ...,  2011,  6768,   102]]),
 'segs': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1]]),
 'clss': tensor([[  0,  29,  73, 113, 135, 171, 205, 222, 248, 275, 315, 355, 375, 401,
          421, 461, 495,   0,   0,   0,   0,   0],
         [  0,  28,  76, 108, 138, 174, 227, 248, 276, 307, 323, 351, 385, 424,
          447, 467, 488, 508,   0,   0,   0,   0],
         [  0,  30,  73,  98, 147, 174, 198, 223, 248, 262, 290, 314, 323, 334,
          359, 383, 397, 427, 450, 469, 489, 508],
         [  0,  39,  72, 144, 170, 193, 214, 244, 272, 291, 312,   0,   0,   0,
            0,   0,   0,   0,   

In [21]:
device="cuda:0"
new_batch = tuple(t.to(device) for t in temp)

In [33]:
a = ExtSumProcessor.get_inputs2(new_batch, summarizer.model_name, train_mode=False)

In [34]:
a

{'x': tensor([[  101,  4977,  2038,  ...,  4652,  6399,   102],
         [  101, 15555, 12548,  ..., 12548,  3555,   102],
         [  101,  2865,  6180,  ...,  2036,  5580,   102],
         [  101,  7354, 19023,  ...,     0,     0,     0],
         [  101,  2009,  1005,  ...,  2011,  6768,   102]], device='cuda:0'),
 'segs': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1]], device='cuda:0'),
 'clss': tensor([[  0,  29,  73, 113, 135, 171, 205, 222, 248, 275, 315, 355, 375, 401,
          421, 461, 495,   0,   0,   0,   0,   0],
         [  0,  28,  76, 108, 138, 174, 227, 248, 276, 307, 323, 351, 385, 424,
          447, 467, 488, 508,   0,   0,   0,   0],
         [  0,  30,  73,  98, 147, 174, 198, 223, 248, 262, 290, 314, 323, 334,
          359, 383, 397, 427, 450, 469, 489, 508],
         [  0,  39,  72, 144, 170, 193, 214, 244, 272, 291, 312,   0,   0,   0

In [26]:
outputs = summarizer.model(**a)

In [27]:
sent_scores = outputs[0]

In [28]:
sent_scores

tensor([[1.2797, 1.4138, 1.3300, 1.2633, 1.2465, 1.0688, 1.1607, 1.0789, 1.0379,
         1.0831, 1.1103, 1.0584, 1.0471, 1.0388, 1.0484, 1.0268, 1.0320, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000],
        [1.4332, 1.2015, 1.4598, 1.3369, 1.2130, 1.1952, 1.1715, 1.0674, 1.0351,
         1.0249, 1.0774, 1.0215, 1.0350, 1.0662, 1.0975, 1.0838, 1.0363, 1.0532,
         0.0000, 0.0000, 0.0000, 0.0000],
        [1.3368, 1.5235, 1.1833, 1.1581, 1.3971, 1.1417, 1.1155, 1.0668, 1.0552,
         1.0474, 1.0775, 1.0161, 1.0235, 1.0214, 1.0414, 1.0415, 1.0308, 1.0466,
         1.0516, 1.0198, 1.0394, 1.0168],
        [1.3996, 1.2992, 1.2811, 1.3704, 1.1763, 1.1740, 1.0317, 1.0152, 1.0304,
         1.1223, 1.1197, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000],
        [1.2746, 1.4863, 1.1093, 1.2819, 1.1078, 1.0225, 1.0339, 1.2592, 1.0943,
         1.1625, 1.0523, 1.0986, 1.0223, 1.0163, 1.0467, 1.0518, 1.1166, 1.0232,
         0.0000, 0.000

In [30]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

In [31]:
eval_sampler = SequentialSampler(eval_dataset)

In [32]:
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)


In [86]:
%%time 
prediction = summarizer.predict(new_eval_dataset, num_gpus=1)

CPU times: user 9 µs, sys: 1 µs, total: 10 µs
Wall time: 17.2 µs


In [127]:
type(prediction_list[0])

numpy.ndarray

In [87]:
prediction_list = list(prediction)

{'x': tensor([[[  101,  4977,  2038,  ...,  4652,  6399,   102],
         [  101, 15555, 12548,  ..., 12548,  3555,   102],
         [  101,  2865,  6180,  ...,  2036,  5580,   102],
         [  101,  7354, 19023,  ...,     0,     0,     0],
         [  101,  2009,  1005,  ...,  2011,  6768,   102]]]), 'segs': tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1]]]), 'clss': tensor([[[  0,  29,  73, 113, 135, 171, 205, 222, 248, 275, 315, 355, 375, 401,
          421, 461, 495,   0,   0,   0,   0,   0],
         [  0,  28,  76, 108, 138, 174, 227, 248, 276, 307, 323, 351, 385, 424,
          447, 467, 488, 508,   0,   0,   0,   0],
         [  0,  30,  73,  98, 147, 174, 198, 223, 248, 262, 290, 314, 323, 334,
          359, 383, 397, 427, 450, 469, 489, 508],
         [  0,  39,  72, 144, 170, 193, 214, 244, 272, 291, 312,   0,   0,   0,
            0,   0,   0,   0,

In [118]:
len(prediction_list)

1

In [119]:
len(eval_batch)

NameError: name 'eval_batch' is not defined

In [114]:
import numpy as np
sentence_score=prediction_list[0]
selected_ids = np.argsort(-sentence_score, 1)
#prediction_list[0]

In [116]:
selected_ids

array([[ 1,  2,  0,  3,  4,  6, 10,  9,  7,  5, 11, 14, 12, 13,  8, 16,
        15, 20, 17, 18, 19, 21],
       [ 2,  0,  3,  4,  1,  5,  6, 14, 15, 10,  7, 13, 17, 16,  8, 12,
         9, 11, 20, 18, 19, 21],
       [ 1,  4,  0,  2,  3,  5,  6, 10,  7,  8, 18,  9, 17, 15, 14, 20,
        16, 12, 13, 19, 21, 11],
       [ 0,  3,  1,  2,  4,  5,  9, 10,  6,  8,  7, 20, 11, 12, 13, 14,
        15, 16, 17, 18, 19, 21],
       [ 1,  3,  0,  7,  9, 16,  2,  4, 11,  8, 10, 15, 14,  6, 17,  5,
        12, 13, 18, 19, 20, 21]])

In [112]:
sentence_seperator="<q>",
top_n=3,
block_trigram=True,
verbose=True,
cal_lead=False,

In [131]:
for i in range(len(new_eval_batch)):
#for (batch, sent_scores) in zip(new_eval_batch, prediction_list):
    sent_scores = prediction_list[i]
    print(type(prediction_list[0]))
    #selected_ids = np.argsort(-sent_scores, 1)
    #print(selected_ids)
    temp_pred, temp_target = get_pred(new_eval_batch[i], sent_scores)
    print(temp_pred[0])
    print(temp_targe[0])

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
[[ 1  2  0  3  4  6 10  9  7  5 11 14 12 13  8 16 15 20 17 18 19 21]
 [ 2  0  3  4  1  5  6 14 15 10  7 13 17 16  8 12  9 11 20 18 19 21]
 [ 1  4  0  2  3  5  6 10  7  8 18  9 17 15 14 20 16 12 13 19 21 11]
 [ 0  3  1  2  4  5  9 10  6  8  7 20 11 12 13 14 15 16 17 18 19 21]
 [ 1  3  0  7  9 16  2  4 11  8 10 15 14  6 17  5 12 13 18 19 20 21]]


AttributeError: 'tuple' object has no attribute 'join'

In [70]:
def _get_ngrams(n, text):
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i : i + n]))
    return ngram_set

def _block_tri(c, p):
    tri_c = _get_ngrams(3, c.split())
    for s in p:
        tri_s = _get_ngrams(3, s.split())
        if len(tri_c.intersection(tri_s)) > 0:
            return True
    return False



In [130]:
def get_pred(batch, sent_scores, cal_lead=False, sentence_seperator='<q>'):
    print(type(sent_scores))
    selected_ids = np.argsort(-sent_scores, 1)
    #print(selected_ids)
    if cal_lead:
        selected_ids = list(range(batch.clss.size(1))) * len(batch.clss)
    print(selected_ids)
    pred = []
    target = []
    for i, idx in enumerate(selected_ids):
        _pred = []
        if len(batch.src_str[i]) == 0:
            pred.append("")
            continue
        for j in selected_ids[i][: len(batch.src_str[i])]:
            if j >= len(batch.src_str[i]):
                continue
            candidate = batch.src_str[i][j].strip()
            if block_trigram:
                if not _block_tri(candidate, _pred):
                    _pred.append(candidate)
            else:
                _pred.append(candidate)

            # only select the top 3
            if len(_pred) == top_n:
                break

        # _pred = '<q>'.join(_pred)
        _pred = sentence_seperator.join(_pred)
        pred.append(_pred.strip())
        target.append(batch.tgt_str[i])
    print("=======================")
    print(pred)
    print("=======================")
    print(target)
    return pred, target

In [14]:
%%time 
prediction = summarizer.predict(get_sequential_dataloader(test_dataset), num_gpus=2)

range(0, 2)
cuda
CPU times: user 2min 47s, sys: 55.2 s, total: 3min 42s
Wall time: 2min 14s


In [12]:
%%time 
prediction = summarizer.predict(get_sequential_dataloader(test_dataset), num_gpus=4)

range(0, 4)
cuda
CPU times: user 5min 14s, sys: 1min 56s, total: 7min 10s
Wall time: 3min 13s


In [12]:
%%time 
prediction = summarizer.predict(get_sequential_dataloader(test_dataset), num_gpus=4)

range(0, 4)
cuda
CPU times: user 5min 15s, sys: 1min 57s, total: 7min 13s
Wall time: 3min 13s


In [12]:
%%time 
prediction = summarizer.predict(get_sequential_dataloader(test_dataset), num_gpus=1)

CPU times: user 56.8 s, sys: 18 s, total: 1min 14s
Wall time: 1min 14s


In [12]:
%%time 
prediction = summarizer.predict(get_sequential_dataloader(test_dataset), num_gpus=4,)
#prediction = summarizer.predict(get_dataloader(test_dataset_generator()))

CPU times: user 5min 13s, sys: 1min 55s, total: 7min 9s
Wall time: 3min 10s


In [13]:
%%time
prediction = summarizer.predict(get_sequential_dataloader(test_dataset), num_gpus=4,)
#prediction = summarizer.predict(get_dataloader(test_dataset_generator()))

CPU times: user 5min 13s, sys: 1min 57s, total: 7min 11s
Wall time: 3min 10s


In [13]:
len(prediction)

11489

In [32]:
file_numbers = range(0,4)
filenames = ["{}.predict".format(i) for i in file_numbers]
print(filenames)
prediction = []
for i in filenames:
    prediction.extend(torch.load(prediction))

['0.predict_target', '1.predict_target', '2.predict_target', '3.predict_target']


In [55]:
prediction = torch.load("0.predict")
target = torch.load("0.target")

In [56]:
len(prediction)

11489

In [57]:
len(target)

11489

In [58]:
prediction[1]

"deborah fuller has been banned from keeping animals after she dragged her dog behind her car , causing wounds as she drove at 30mph<q>the rhodesian ridgeback was left with injuries to all four paws as well as grazing to his chest and a deep wound on his elbow .<q>he is believed to have somehow escaped from the boot of her car and was dragged along the single carriageway because his lead was attached to the vehicle 's tailgate ."

In [59]:
target[1]



In [60]:
rouge_transformer = get_rouge(prediction, target, "./results/")

11489
11489


2019-12-18 21:58:09,503 [MainThread  ] [INFO ]  Writing summaries.
I1218 21:58:09.503323 140578347489088 pyrouge.py:525] Writing summaries.
2019-12-18 21:58:09,505 [MainThread  ] [INFO ]  Processing summaries. Saving system files to ./results/tmp71b7igdt/system and model files to ./results/tmp71b7igdt/model.
I1218 21:58:09.505720 140578347489088 pyrouge.py:518] Processing summaries. Saving system files to ./results/tmp71b7igdt/system and model files to ./results/tmp71b7igdt/model.
2019-12-18 21:58:09,506 [MainThread  ] [INFO ]  Processing files in ./results/rouge-tmp-2019-12-18-21-58-08/candidate/.
I1218 21:58:09.506670 140578347489088 pyrouge.py:43] Processing files in ./results/rouge-tmp-2019-12-18-21-58-08/candidate/.
2019-12-18 21:58:10,667 [MainThread  ] [INFO ]  Saved processed files to ./results/tmp71b7igdt/system.
I1218 21:58:10.667188 140578347489088 pyrouge.py:53] Saved processed files to ./results/tmp71b7igdt/system.
2019-12-18 21:58:10,668 [MainThread  ] [INFO ]  Processing

---------------------------------------------
1 ROUGE-1 Average_R: 0.52266 (95%-conf.int. 0.51986 - 0.52568)
1 ROUGE-1 Average_P: 0.34674 (95%-conf.int. 0.34446 - 0.34907)
1 ROUGE-1 Average_F: 0.40313 (95%-conf.int. 0.40106 - 0.40530)
---------------------------------------------
1 ROUGE-2 Average_R: 0.22634 (95%-conf.int. 0.22379 - 0.22906)
1 ROUGE-2 Average_P: 0.14922 (95%-conf.int. 0.14747 - 0.15107)
1 ROUGE-2 Average_F: 0.17380 (95%-conf.int. 0.17189 - 0.17581)
---------------------------------------------
1 ROUGE-L Average_R: 0.47360 (95%-conf.int. 0.47089 - 0.47638)
1 ROUGE-L Average_P: 0.31457 (95%-conf.int. 0.31237 - 0.31671)
1 ROUGE-L Average_F: 0.36555 (95%-conf.int. 0.36355 - 0.36758)



In [14]:
rouge_transformer = get_rouge(prediction, target, "./results/")

11489
11489


2019-12-18 03:30:54,263 [MainThread  ] [INFO ]  Writing summaries.
I1218 03:30:54.263799 140518432167744 pyrouge.py:525] Writing summaries.
2019-12-18 03:30:54,265 [MainThread  ] [INFO ]  Processing summaries. Saving system files to ./results/tmpctvsqcxg/system and model files to ./results/tmpctvsqcxg/model.
I1218 03:30:54.265614 140518432167744 pyrouge.py:518] Processing summaries. Saving system files to ./results/tmpctvsqcxg/system and model files to ./results/tmpctvsqcxg/model.
2019-12-18 03:30:54,266 [MainThread  ] [INFO ]  Processing files in ./results/rouge-tmp-2019-12-18-03-30-53/candidate/.
I1218 03:30:54.266546 140518432167744 pyrouge.py:43] Processing files in ./results/rouge-tmp-2019-12-18-03-30-53/candidate/.
2019-12-18 03:30:55,387 [MainThread  ] [INFO ]  Saved processed files to ./results/tmpctvsqcxg/system.
I1218 03:30:55.387292 140518432167744 pyrouge.py:53] Saved processed files to ./results/tmpctvsqcxg/system.
2019-12-18 03:30:55,389 [MainThread  ] [INFO ]  Processing

---------------------------------------------
1 ROUGE-1 Average_R: 0.54208 (95%-conf.int. 0.53930 - 0.54484)
1 ROUGE-1 Average_P: 0.36866 (95%-conf.int. 0.36651 - 0.37103)
1 ROUGE-1 Average_F: 0.42466 (95%-conf.int. 0.42276 - 0.42672)
---------------------------------------------
1 ROUGE-2 Average_R: 0.24754 (95%-conf.int. 0.24499 - 0.25011)
1 ROUGE-2 Average_P: 0.16856 (95%-conf.int. 0.16669 - 0.17049)
1 ROUGE-2 Average_F: 0.19382 (95%-conf.int. 0.19190 - 0.19576)
---------------------------------------------
1 ROUGE-L Average_R: 0.49419 (95%-conf.int. 0.49159 - 0.49685)
1 ROUGE-L Average_P: 0.33667 (95%-conf.int. 0.33456 - 0.33889)
1 ROUGE-L Average_F: 0.38754 (95%-conf.int. 0.38561 - 0.38960)



In [26]:
test_dataset[0]['tgt_txt']

'marseille prosecutor says `` so far no videos were used in the crash investigation `` despite media reports .<q>journalists at bild and paris match are `` very confident `` the video clip is real , an editor says .<q>andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says .<q>'

In [27]:
prediction[0]

"marseille , france ( cnn ) the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane .<q>all 150 on board were killed .<q>cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . ``"

In [28]:
test_dataset[0]['src_txt']

['marseille , france ( cnn ) the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane .',
 'marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . ``',
 'he added , `` a person who has such a video needs to immediately give it to the investigators . ``',
 "robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps .",
 'all 150 on board were killed .',
 'paris match and bild reported that the video was recovered from a phone at the wreckage site .',
 'the two publications described the supposed video , but did not post it on their websites .',
 'the publications said that they watched the video , which was found by a source close to the investigation . ``',