Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

## Extractive Text Summarization on CNN/DM Dataset using BertSum


### Summary

This notebook demonstrates how to fine-tune BERT for extractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

BertSum refers to [Fine-tune BERT for Extractive Summarization](https://arxiv.org/pdf/1903.10318.pdf), which comes with a [published example](https://github.com/nlpyang/BertSum/). Extractive summarization is usually used for document summarization, where each input document consists of multiple sentences. Preprocessing the training data involves assigning a label of 0 or 1 to each document sentence based on the given summary. The summarization problem is thereby simplified to classifying whether each document sentence should be included in the summary.
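As a rough illustration of the labeling step, the sketch below greedily adds the sentence that most improves overlap with the reference summary and marks the chosen sentences with 1. It is a simplification under stated assumptions: the function names `unigram_f1` and `greedy_labels` are hypothetical, and a plain unigram F1 stands in for the ROUGE-based scoring that a real selection algorithm would use.

```python
from collections import Counter


def unigram_f1(candidate_tokens, reference_tokens):
    """F1 of the clipped unigram overlap between two token lists."""
    overlap = sum((Counter(candidate_tokens) & Counter(reference_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(candidate_tokens)
    recall = overlap / len(reference_tokens)
    return 2 * precision * recall / (precision + recall)


def greedy_labels(doc_sentences, summary, max_selected=3):
    """Return one 0/1 label per sentence (hypothetical helper, not the BertSum API)."""
    reference = summary.split()
    selected = []
    best_score = 0.0
    for _ in range(max_selected):
        best_idx = None
        for i in range(len(doc_sentences)):
            if i in selected:
                continue
            # Score the summary formed by the already selected sentences plus sentence i.
            candidate = " ".join(doc_sentences[j] for j in sorted(selected + [i])).split()
            score = unigram_f1(candidate, reference)
            if score > best_score:
                best_score, best_idx = score, i
        if best_idx is None:  # no remaining sentence improves the score
            break
        selected.append(best_idx)
    return [1 if i in selected else 0 for i in range(len(doc_sentences))]


doc = [
    "the cat sat on the mat .",
    "it rained all day .",
    "the dog barked at the cat .",
]
labels = greedy_labels(doc, "the cat sat on the mat . the dog barked at the cat .")
print(labels)
```

The loop stops early once no remaining sentence raises the score, so documents whose summary is covered by one or two sentences receive fewer positive labels.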

The figure below illustrates how BERTSum can be fine-tuned for the extractive summarization task. A [CLS] token is inserted at the beginning of each sentence and a [SEP] token at the end. Interval segment embeddings and positional embeddings are added to the token embeddings before they are fed to the BERT model. The [CLS] token representation serves as the sentence embedding, and only the [CLS] representations are used as input to the summarization layer. The summarization layer predicts, for each sentence, the probability that it should be included in the summary. Techniques like trigram blocking can be used to improve model accuracy.
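Trigram blocking can be sketched in a few lines. The version below ranks sentences by score and skips any sentence that shares a trigram with a sentence already chosen. The function names and the whitespace tokenization are assumptions made for illustration and are not taken from the BertSum code.

```python
def trigrams(tokens):
    """Set of all consecutive token triples in a sentence."""
    return {tuple(tokens[i:i + 3]) for i in range(len(tokens) - 2)}


def select_with_trigram_blocking(sentences, scores, max_sentences=3):
    """Pick the highest scoring sentences, skipping ones that repeat a trigram."""
    ranked = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)
    chosen, seen = [], set()
    for idx in ranked:
        grams = trigrams(sentences[idx].split())
        if grams & seen:  # overlaps with an already selected sentence
            continue
        chosen.append(idx)
        seen |= grams
        if len(chosen) == max_sentences:
            break
    return [sentences[i] for i in chosen]


sentences = [
    "a student died in rome .",
    "the student died in rome on sunday .",
    "police are investigating .",
]
summary = select_with_trigram_blocking(sentences, [0.9, 0.8, 0.4], max_sentences=2)
print(summary)
```

Here the second sentence is dropped despite its high score because it repeats the trigram "student died in", so the lower scoring third sentence is selected instead.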

<img src="https://nlpbp.blob.core.windows.net/images/?.PNG">


### Before You Start

The running time shown in this notebook is on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

The table below provides some reference running time on different machine configurations.  

|QUICK_RUN|Machine Configurations|Running time|
|:---------|:----------------------|:------------|
|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ ? minutes |
|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~ ? hours|


In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = True
USE_PREPROCESSED_DATA = False
if not USE_PREPROCESSED_DATA:
    BERT_DATA_PATH = "/dadendev/BertSum/bert_data/"

### Configuration

Before running the notebook, we need to set environment variables so that the GPUs on your machine are accessible.

First, clone a modified version of BertSum that supports prediction and can run on any GPU device ID on your machine.

In [2]:
!git clone https://github.com/daden-ms/BertSum.git

fatal: destination path 'BertSum' already exists and is not an empty directory.


In [4]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"

In [5]:
import sys
nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)
    
sys.path.insert(0, './BertSum/src')

Also, we need to install the dependencies for pyrouge.

In [None]:
# dependencies for ROUGE-1.5.5.pl
!sudo apt-get update
!sudo apt-get install expat
!sudo apt-get install libexpat-dev -y

Run the following commands in your terminal:
1. sudo cpan install XML::Parser
1. sudo cpan install XML::Parser::PerlSAX
1. sudo cpan install XML::DOM

Also you need to set up file2rouge


### Data Preprocessing

The dataset used in this notebook is the CNN/DM dataset, which contains documents and accompanying questions from CNN and Daily Mail news articles. The highlights in each article are used as the summary. The dataset consists of ~289K training examples, ~11K validation examples, and ~11K test examples. You can either use the preprocessed version from the [BERTSum published example](https://github.com/nlpyang/BertSum/) or use the following section to preprocess the data yourself. Preprocessing the full training data takes up to 28 hours on 10 Intel(R) Xeon(R) E5-2690 v3 @ 2.60GHz CPUs, so if you choose to run the preprocessing, we suggest setting QUICK_RUN to True.



If you choose to use the preprocessed data, skip ahead to the "Model training" section.
To continue with the data preprocessing, run the following command to download the data from https://github.com/harvardnlp/sent-summary and extract it into the folder ./harvardnlp_cnndm

In [None]:
!wget https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz &&\
    mkdir -p harvardnlp_cnndm &&\
    mv cnndm.tar.gz ./harvardnlp_cnndm && cd ./harvardnlp_cnndm &&\
    tar -xvf cnndm.tar.gz 

#### Details of Data Preprocessing

The purpose of preprocessing is to convert the input articles into the format that BertSum takes. The functions used in harvardnlp_cnndm_preprocess are specific to the CNN/DM dataset as processed by harvardnlp. However, they provide a skeleton for preprocessing other data into the format that BertSum takes. Assuming you have all articles in one file and all target summaries in another, one per line, the preprocessing steps are:
1. sentence tokenization
2. word tokenization
3. label the sentences in the article, with 1 meaning the sentence is selected and 0 meaning it is not. The options for the selection algorithm are "greedy" and "combination"
4. convert each example to BertSum format
    - filter out sentences based on the min_src_ntokens argument; if fewer than min_nsents sentences remain, the example is discarded
    - truncate each sentence whose length is greater than max_src_ntokens
    - truncate the sentences in the example and the labels if the total number of sentences is greater than max_nsents
    - [CLS] and [SEP] are inserted before and after each sentence
    - WordPiece tokenization
    - truncate the example to 512 tokens
    - convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary
    - segment ids are generated
    - [CLS] token positions are logged
    - the labels are truncated to the [CLS] tokens that remain within the 512-token limit, which is the maximum input length that the BERT model can take
    
    
Note that the original BERTSum paper uses Stanford CoreNLP for data preprocessing. Here we first show how to use the NLTK version, and then provide instructions for setting up Stanford CoreNLP along with code examples of how to use it.
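The last few formatting steps (segment ids and [CLS] positions) can be sketched as follows. This is a minimal illustration, not the BertSum implementation: the helper name `interval_segments_and_cls` is hypothetical, and it assumes the input already starts with [CLS] and uses ids 101 and 102 for [CLS] and [SEP], as in the bert-base-uncased vocabulary.

```python
CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] ids in the bert-base-uncased vocabulary
MAX_LEN = 512              # maximum input length that BERT accepts


def interval_segments_and_cls(token_ids):
    """Return (segment ids, [CLS] positions) for a sequence that starts with [CLS]."""
    token_ids = token_ids[:MAX_LEN]
    segs, clss = [], []
    sentence_index = -1
    for position, token_id in enumerate(token_ids):
        if token_id == CLS_ID:
            sentence_index += 1
            clss.append(position)
        # Interval segments: sentences alternate between segment 0 and segment 1.
        segs.append(sentence_index % 2)
    return segs, clss


# Three short "sentences", each wrapped in [CLS] ... [SEP].
src = [101, 2182, 1010, 102, 101, 2019, 102, 101, 5631, 1012, 102]
segs, clss = interval_segments_and_cls(src)
print(segs)
print(clss)
```

Each position in `clss` points at a [CLS] token in `src`, and the labels list is expected to have one entry per position in `clss`.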

In [5]:
from utils_nlp.dataset.harvardnlp_cnndm import harvardnlp_cnndm_preprocess
from utils_nlp.models.bert.extractive_text_summarization import bertsum_formatting

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [6]:
%%time
max_train_job_number = -1
max_test_job_number = -1
if QUICK_RUN:
    max_train_job_number = 100
    max_test_job_number = 10

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 6.68 µs


In [7]:
output_file = f"./harvardnlp_cnndm/test.bertdata_{QUICK_RUN}" 

#### Preprocess training data

In [14]:
%%time
TRAIN_SRC_FILE = "./harvardnlp_cnndm/train.txt.src"
TRAIN_TGT_FILE = "./harvardnlp_cnndm/train.txt.tgt.tagged"
PROCESSED_TRAIN_FILE = f"./harvardnlp_cnndm/train.bertdata_{QUICK_RUN}" 
import multiprocessing
n_cpus = multiprocessing.cpu_count() - 1
jobs = harvardnlp_cnndm_preprocess(n_cpus, TRAIN_SRC_FILE, TRAIN_TGT_FILE, max_train_job_number)
print("total length of training data:", len(jobs))
from prepro.data_builder import BertData
from utils_nlp.models.bert.extractive_text_summarization import Bunch
default_preprocessing_parameters =  {"max_nsents": 200, "max_src_ntokens": 2000, "min_nsents": 3, "min_src_ntokens": 5, "use_interval": True}
args=Bunch(default_preprocessing_parameters)
bertdata = BertData(args)
bertsum_formatting(n_cpus, bertdata,"combination", jobs[0:max_train_job_number], PROCESSED_TRAIN_FILE)


total length of training data: 100
CPU times: user 3.14 s, sys: 1.63 s, total: 4.77 s
Wall time: 41.1 s


#### Preprocess test data

In [15]:
%%time
TEST_SRC_FILE = "./harvardnlp_cnndm/test.txt.src"
TEST_TGT_FILE = "./harvardnlp_cnndm/test.txt.tgt.tagged"
PROCESSED_TEST_FILE = f"./harvardnlp_cnndm/test.bertdata_{QUICK_RUN}" 
import multiprocessing
n_cpus = multiprocessing.cpu_count() - 1
jobs = harvardnlp_cnndm_preprocess(n_cpus, TEST_SRC_FILE, TEST_TGT_FILE, max_test_job_number)
print("total length of test data:", len(jobs))
from prepro.data_builder import BertData
from utils_nlp.models.bert.extractive_text_summarization import Bunch
default_preprocessing_parameters =  {"max_nsents": 200, "max_src_ntokens": 2000, "min_nsents": 3, "min_src_ntokens": 5, "use_interval": True}
args=Bunch(default_preprocessing_parameters)
bertdata = BertData(args)
bertsum_formatting(n_cpus, bertdata,"combination", jobs[0:max_test_job_number], PROCESSED_TEST_FILE)


total length of test data: 10
CPU times: user 2.9 s, sys: 1.59 s, total: 4.49 s
Wall time: 5.12 s


#### Inspect the data

In [16]:
import torch
bert_format_data = torch.load(PROCESSED_TRAIN_FILE)
print(len(bert_format_data))
bert_format_data[0].keys()


100


dict_keys(['src', 'labels', 'segs', 'clss', 'src_txt', 'tgt_txt'])

In [17]:
bert_format_data[0]['src']

[101,
 3559,
 1005,
 1055,
 3602,
 1024,
 1999,
 2256,
 2369,
 1996,
 5019,
 2186,
 1010,
 13229,
 11370,
 2015,
 3745,
 2037,
 6322,
 1999,
 5266,
 2739,
 1998,
 17908,
 1996,
 3441,
 2369,
 1996,
 2824,
 1012,
 102,
 101,
 2182,
 1010,
 7082,
 14697,
 1051,
 1005,
 9848,
 3138,
 5198,
 2503,
 1037,
 7173,
 2073,
 2116,
 1997,
 1996,
 13187,
 2024,
 10597,
 5665,
 1012,
 102,
 101,
 2019,
 24467,
 7431,
 2006,
 1996,
 1036,
 1036,
 6404,
 2723,
 1010,
 1036,
 1036,
 2073,
 2116,
 10597,
 5665,
 13187,
 2024,
 7431,
 1999,
 5631,
 2077,
 3979,
 1012,
 102,
 101,
 5631,
 1010,
 3516,
 1006,
 13229,
 1007,
 1011,
 1011,
 1996,
 6619,
 2723,
 1997,
 1996,
 5631,
 1011,
 27647,
 3653,
 18886,
 2389,
 12345,
 4322,
 2003,
 9188,
 1996,
 1036,
 1036,
 6404,
 2723,
 1012,
 1036,
 1036,
 102,
 101,
 2182,
 1010,
 13187,
 2007,
 1996,
 2087,
 5729,
 5177,
 24757,
 2024,
 23995,
 2127,
 2027,
 1005,
 2128,
 3201,
 2000,
 3711,
 1999,
 2457,
 1012,
 102,
 101,
 2087,
 2411,
 1010,
 2027,
 2227,
 

In [21]:
bert_format_data[0]['tgt_txt']

'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program when the incident happened in january<q>he was flown back to chicago via air on march 20 but he died on sunday<q>initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed<q>his cousin claims he was attacked and thrown 40ft from a bridge'

In [22]:
bert_format_data[0]['labels']

[0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
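In the example above, the positions labeled 1 line up with the `src_txt` sentences that the target summary paraphrases. A small sketch of that mapping is shown below; the helper name `oracle_sentences` and the toy sentences are placeholders rather than part of the repo or the dataset.

```python
def oracle_sentences(src_txt, labels):
    """Return the source sentences whose label is 1 (hypothetical helper)."""
    return [sentence for sentence, label in zip(src_txt, labels) if label == 1]


# Toy stand-ins for bert_format_data[0]['src_txt'] and bert_format_data[0]['labels'].
toy_src_txt = ["sentence zero .", "sentence one .", "sentence two .", "sentence three ."]
toy_labels = [0, 1, 1, 0]
print(oracle_sentences(toy_src_txt, toy_labels))
```

Joining the selected sentences gives the best extractive summary the labels allow, which is a useful sanity check against `tgt_txt`.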

In [23]:
bert_format_data[0]['src_txt']

['a university of iowa student has died nearly three months after a fall in rome in a suspected robbery attack in rome .',
 'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program in italy when the incident happened in january .',
 'he was flown back to chicago via air ambulance on march 20 , but he died on sunday .',
 'andrew mogni , 20 , from glen ellyn , illinois , a university of iowa student has died nearly three months after a fall in rome in a suspected robbery',
 'he was taken to a medical facility in the chicago area , close to his family home in glen ellyn .',
 "he died on sunday at northwestern memorial hospital - medical examiner 's office spokesman frank shuftan says a cause of death wo n't be released until monday at the earliest .",
 'initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed .',
 "on sunday , his cousin abby wrote online : ` this morning my cous

### Model training
To start model training, we need to create an instance of BertSumExtractiveSummarizer, a wrapper for running BertSum-based fine-tuning. You can select any device ID on your machine, but make sure that you pass the string version of the device ID in the gpu_ranks argument.




In [20]:
## choose which GPU device to use
device_id = 1
gpu_ranks = str(device_id)

#### Choose the encoder algorithm. There are four options:
- baseline: uses a smaller transformer model in place of the BERT model, with a transformer summarization layer
- classifier: fine-tunes pretrained BERT with a **simple logistic classification** summarization layer
- transformer: fine-tunes pretrained BERT with a **transformer** summarization layer
- RNN: fine-tunes pretrained BERT with an **LSTM** summarization layer
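To make the simplest option concrete, the sketch below shows what a logistic classification layer computes: each sentence vector is passed through one shared linear unit followed by a sigmoid, giving a per-sentence inclusion probability. It is written in plain Python for readability; the function names, weights, and vectors are made-up illustrations, not the trained model or the BertSum code.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def classifier_layer(cls_vectors, weights, bias):
    """Score each sentence vector with a shared linear unit followed by a sigmoid."""
    return [
        sigmoid(sum(w * h for w, h in zip(weights, vector)) + bias)
        for vector in cls_vectors
    ]


# Three toy 2-dimensional [CLS] vectors; real BERT-base vectors have 768 dimensions.
cls_vectors = [[0.5, -1.0], [2.0, 1.0], [0.0, 0.0]]
scores = classifier_layer(cls_vectors, weights=[1.0, 1.0], bias=0.0)
print(scores)
```

The transformer and LSTM options replace this single linear unit with layers that let sentence representations interact before the final sigmoid.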

In [21]:
encoder = 'baseline'
model_base_path = './models/'
log_base_path = './logs/'
result_base_path = './results'

BERT_CONFIG_PATH = "/dadendev/nlp/BertSum/bert_config_uncased_base.json"

import os
if not os.path.exists(model_base_path):
    os.makedirs(model_base_path)
if not os.path.exists(log_base_path):
    os.makedirs(log_base_path)
if not os.path.exists(result_base_path):
    os.makedirs(result_base_path)
    
from random import random
random_number = random()

In [22]:
from utils_nlp.models.bert.extractive_text_summarization import BertSumExtractiveSummarizer
bertsum_model = BertSumExtractiveSummarizer(encoder = encoder, 
                                            model_path = model_base_path+encoder+str(random_number),
                                            log_file = log_base_path+encoder+str(random_number),
                                            bert_config_path=BERT_CONFIG_PATH,
                                            device_id = device_id,
                                            gpu_ranks = gpu_ranks,)

['1']
{1: 0}


Here we use the fully processed CNN/DM dataset to train the model. During training, you can stop at any time and resume from the most recently saved checkpoint.
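One way to find the checkpoint to resume from is sketched below. It assumes checkpoints are saved as `model_step_<N>.pt` inside the model directory, which matches the checkpoint path used for testing later in this notebook. The helper `latest_checkpoint` is hypothetical and not part of the repo.

```python
import glob
import os
import re


def latest_checkpoint(model_dir):
    """Return the path of the checkpoint with the highest step, or "" if none exists."""
    pattern = re.compile(r"model_step_(\d+)\.pt$")
    best_step, best_path = -1, ""
    for path in glob.glob(os.path.join(model_dir, "model_step_*.pt")):
        match = pattern.search(os.path.basename(path))
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Possible usage with the fit call in this notebook (commented out because it starts training):
# model_dir = model_base_path + encoder + str(random_number)
# bertsum_model.fit(device_id, training_data_files, train_steps=50000,
#                   train_from=latest_checkpoint(model_dir))
```

Returning an empty string when no checkpoint exists mirrors the `train_from=""` value used to start training from scratch.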

In [28]:
USE_PREPROCESSED_DATA = True

In [29]:
if USE_PREPROCESSED_DATA is True:
    PROCESSED_TRAIN_FILE = './bert_train_data_all_none_excluded'
    training_data_files = [PROCESSED_TRAIN_FILE]
else:    
    BERT_DATA_PATH="/dadendev/BertSum/bert_data/"
    import glob
    pts = sorted(glob.glob(BERT_DATA_PATH + 'cnndm.train' + '.[0-9]*.pt'))
    training_data_files = pts

In [None]:
bertsum_model.fit(device_id, training_data_files, train_steps=50000, train_from="")

### Model Evaluation

[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation, is commonly used to evaluate text summarization.
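For intuition, ROUGE-N boils down to counting n-grams shared between the predicted and reference summaries. The sketch below computes recall, precision, and F1 on whitespace tokens. It is a simplified stand-in with a hypothetical function name, and it omits options of the official ROUGE-1.5.5 script such as stemming, so its numbers will not match the official scores exactly.

```python
from collections import Counter


def rouge_n(prediction_tokens, reference_tokens, n=1):
    """Simplified ROUGE-N: clipped n-gram overlap as recall, precision and F1."""
    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    predicted, reference = ngrams(prediction_tokens), ngrams(reference_tokens)
    overlap = sum((predicted & reference).values())
    recall = overlap / max(sum(reference.values()), 1)
    precision = overlap / max(sum(predicted.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"recall": recall, "precision": precision, "f1": f1}


rouge_1 = rouge_n("he died on sunday".split(), "he died in rome on sunday".split())
rouge_2 = rouge_n("he died on sunday".split(), "he died in rome on sunday".split(), n=2)
print(rouge_1)
print(rouge_2)
```

In this toy case every predicted unigram appears in the reference, so precision is 1.0 while recall is 4/6, since two reference words are missing from the prediction.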

In [42]:
import torch
from models.data_loader  import DataIterator,Batch,Dataloader
import os

USE_PREPROCESSED_DATA = False
if USE_PREPROCESSED_DATA is True: 
    test_dataset=torch.load(PROCESSED_TEST_FILE)
else:
    test_dataset=[]
    for i in range(0,6):
        filename = os.path.join(BERT_DATA_PATH, "test/cnndm.test.{0}.bert.pt".format(i))
        test_dataset.extend(torch.load(filename))

    
def get_data_iter(dataset,is_test=False, batch_size=3000):
    args = Bunch({})
    args.use_interval = True
    args.batch_size = batch_size
    test_data_iter = None
    test_data_iter  = DataIterator(args, dataset, args.batch_size, 'cuda', is_test=is_test, shuffle=False, sort=False)
    return test_data_iter

In [32]:
model_for_test = "./models/baseline0.14344633695274556/model_step_30000.pt"
target = [test_dataset[i]['tgt_txt'] for i in range(len(test_dataset))]
prediction = bertsum_model.predict(device_id, get_data_iter(test_dataset),
                                   test_from=model_for_test,
                                   sentence_seperator='<q>')



[2019-10-07 03:29:59,368 INFO] Device ID 1
[2019-10-07 03:29:59,373 INFO] Loading checkpoint from ./models/baseline0.14344633695274556/model_step_30000.pt
[2019-10-07 03:30:02,139 INFO] * number of parameters: 5179137


device_id 1
gpu_rank 0


In [33]:
len(prediction)

11489

In [None]:
from utils_nlp.eval.evaluate_summarization import get_rouge
rouge_baseline = get_rouge(prediction, target, "/dadendev/textsum/results/")

11489
11489


2019-10-07 03:31:33,696 [MainThread  ] [INFO ]  Writing summaries.
[2019-10-07 03:31:33,696 INFO] Writing summaries.
2019-10-07 03:31:33,698 [MainThread  ] [INFO ]  Processing summaries. Saving system files to /dadendev/textsum/results/tmpekvbkkdp/system and model files to /dadendev/textsum/results/tmpekvbkkdp/model.
[2019-10-07 03:31:33,698 INFO] Processing summaries. Saving system files to /dadendev/textsum/results/tmpekvbkkdp/system and model files to /dadendev/textsum/results/tmpekvbkkdp/model.
2019-10-07 03:31:33,700 [MainThread  ] [INFO ]  Processing files in /dadendev/textsum/results/rouge-tmp-2019-10-07-03-31-32/candidate/.
[2019-10-07 03:31:33,700 INFO] Processing files in /dadendev/textsum/results/rouge-tmp-2019-10-07-03-31-32/candidate/.
2019-10-07 03:31:35,332 [MainThread  ] [INFO ]  Saved processed files to /dadendev/textsum/results/tmpekvbkkdp/system.
[2019-10-07 03:31:35,332 INFO] Saved processed files to /dadendev/textsum/results/tmpekvbkkdp/system.
2019-10-07 03:31:35,

In [46]:
len(prediction)

11489

In [50]:
prediction[0]

'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program in italy when the incident happened in january .he was flown back to chicago via air ambulance on march 20 , but he died on sunday .a university of iowa student has died nearly three months after a fall in rome in a suspected robbery attack in rome .'

In [51]:
target[0]

'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program when the incident happened in january<q>he was flown back to chicago via air on march 20 but he died on sunday<q>initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed<q>his cousin claims he was attacked and thrown 40ft from a bridge'

### Prediction

In [6]:
from utils_nlp.models.bert.extractive_text_summarization import Bunch
args=Bunch({"max_nsents": int(1e5), 
            "max_src_ntokens": int(2e6), 
            "min_nsents": -1, 
            "min_src_ntokens": -1,  
            "use_interval": True})

In [7]:
from prepro.data_builder import BertData
bertdata = BertData(args)

In [8]:
import torch
import sys
#sys.path.insert(0, '../src')
from others.utils import clean
from multiprocess import Pool


In [9]:
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
import os
os.environ["CORENLP_HOME"]="/home/daden/stanfordnlp_resources/stanford-corenlp-full-2018-10-05"

In [10]:
from stanfordnlp.server import CoreNLPClient

In [11]:
from multiprocessing import Pool
from utils_nlp.models.bert.extractive_text_summarization import tokenize_to_list, bertify
import re

def preprocess_target(line):
    def _remove_ttags(line):
        line = re.sub(r'<t>', '', line)
        # change </t> to <q>
        # pyrouge test requires <q> as  sentence splitter
        line = re.sub(r'</t>', '<q>', line)
        return line

    return tokenize_to_list(client, _remove_ttags(line))
def preprocess_source(line):
    return tokenize_to_list(client, clean(line))

def preprocess_cnndm(param):
    source, target = param
    return bertify(bertdata, source, target)

def harvardnlp_cnndm_standfordnlp(client, source_file, target_file, n_cpus=2, top_n=-1):
    source_list = []
    i = 0
    with open(source_file) as fd:
        for line in fd:
            source_list.append(line)
            i +=1
    
    pool = Pool(n_cpus)
    

    tokenized_source_data =  pool.map(preprocess_source, source_list[0:top_n], int(len(source_list[0:top_n])/n_cpus))
    pool.close()
    pool.join()
    
    i = 0
    target_list = []
    with open(target_file) as fd:
        for line in fd:
            target_list.append(line)
            i +=1

    pool = Pool(n_cpus)
    tokenized_target_data =  pool.map(preprocess_target, target_list[0:top_n], int(len(target_list[0:top_n])/n_cpus))
    pool.close()
    pool.join()
            

    #return tokenized_source_data, tokenized_target_data

    pool = Pool(n_cpus)
    bertified_data =  pool.map(preprocess_cnndm, zip(tokenized_source_data[0:top_n], tokenized_target_data[0:top_n]), int(len(tokenized_source_data[0:top_n])/n_cpus))
    pool.close()
    pool.join()
    return bertified_data
    

In [12]:
%%time
source_file = './harvardnlp_cnndm/test.txt.src'
target_file = './harvardnlp_cnndm/test.txt.tgt.tagged'
client = CoreNLPClient(annotators=['tokenize','ssplit'])
new_data = harvardnlp_cnndm_standfordnlp(client, source_file, target_file, n_cpus=2, top_n=10)

Starting server with command: java -Xmx5G -cp /home/daden/stanfordnlp_resources/stanford-corenlp-full-2018-10-05/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 60000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-e3b74cd551ba43f1.props -preload tokenize,ssplit
Starting server with command: java -Xmx5G -cp /home/daden/stanfordnlp_resources/stanford-corenlp-full-2018-10-05/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 60000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-e3b74cd551ba43f1.props -preload tokenize,ssplit
Starting server with command: java -Xmx5G -cp /home/daden/stanfordnlp_resources/stanford-corenlp-full-2018-10-05/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 60000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-e3b74cd551ba43f1.props -preload tokenize,ssplit
Starting server with command: java -Xmx5G -cp /home/dad

In [24]:
import torch
from models.data_loader  import DataIterator,Batch,Dataloader
import os

USE_PREPROCESSED_DATA = False
if USE_PREPROCESSED_DATA is True: 
    test_dataset=torch.load(PROCESSED_TEST_FILE)
else:
    test_dataset=[]
    for i in range(0,6):
        filename = os.path.join(BERT_DATA_PATH, "test/cnndm.test.{0}.bert.pt".format(i))
        test_dataset.extend(torch.load(filename))
def get_data_iter(dataset,is_test=False, batch_size=3000):
    args = Bunch({})
    args.use_interval = True
    args.batch_size = batch_size
    test_data_iter = None
    test_data_iter  = DataIterator(args, dataset, args.batch_size, 'cuda', is_test=is_test, shuffle=False, sort=False)
    return test_data_iter

In [25]:

new_src = preprocess_source("".join(test_dataset[0]['src_txt']))
b_data = bertdata.preprocess(new_src, None, None)
indexed_tokens, labels, segments_ids, cls_ids, src_txt, tgt_txt = b_data
b_data_dict = {"src": indexed_tokens, "labels": labels, "segs": segments_ids, 'clss': cls_ids,
               'src_txt': src_txt, "tgt_txt": tgt_txt}
   

In [26]:
len(new_src)

16

In [27]:
b_data_dict['src_txt']

['a university of iowa student has died nearly three months after a fall in rome in a suspected robbery attack in rome .',
 'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program in italy when the incident happened in january .',
 'he was flown back to chicago via air ambulance on march 20 , but he died on sunday .',
 'andrew mogni , 20 , from glen ellyn , illinois , a university of iowa student has died nearly three months after a fall in rome in a suspected robberyhe was taken to a medical facility in the chicago area , close to his family home in glen ellyn .',
 "he died on sunday at northwestern memorial hospital - medical examiner 's office spokesman frank shuftan says a cause of death wo n't be released until monday at the earliest .",
 'initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed .',
 "on sunday , his cousin abby wrote online : ` this morning my cousin an

In [18]:
b_data_dict['tgt_txt']

In [28]:
model_for_test = "./models/baseline0.14344633695274556/model_step_30000.pt"
#get_data_iter(output,batch_size=30000)
prediction = bertsum_model.predict(device_id, get_data_iter([b_data_dict], False),
                                   test_from=model_for_test, )

[2019-10-07 03:28:55,446 INFO] Device ID 1
[2019-10-07 03:28:55,452 INFO] Loading checkpoint from ./models/baseline0.14344633695274556/model_step_30000.pt
[2019-10-07 03:29:01,758 INFO] * number of parameters: 5179137


device_id 1
gpu_rank 0


In [29]:
prediction[0]

'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program in italy when the incident happened in january .he was flown back to chicago via air ambulance on march 20 , but he died on sunday .a university of iowa student has died nearly three months after a fall in rome in a suspected robbery attack in rome .'

In [31]:
test_dataset[0]['tgt_txt']

'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program when the incident happened in january<q>he was flown back to chicago via air on march 20 but he died on sunday<q>initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed<q>his cousin claims he was attacked and thrown 40ft from a bridge'