Copyright (c) Microsoft Corporation.
Licensed under the MIT License.

# Abstractive Summarization using BertSumAbs on CNN/DailyMail Dataset

## Summary

This notebook demonstrates how to fine-tune BERT for abstractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

### Abstractive Summarization
Abstractive summarization is the task of taking an input text and summarizing its content in a shorter output text. In contrast to extractive summarization, abstractive summarization doesn't copy sentences directly from the input text; instead, it rephrases the input text.

### BertSumAbs

BertSumAbs refers to a BERT-based abstractive summarization algorithm introduced in [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345) with [published examples](https://github.com/nlpyang/PreSumm). It uses the pretrained BERT model as the encoder and fine-tunes both the encoder and the decoder on a labeled summarization dataset such as the [CNN/DM dataset](https://github.com/harvardnlp/sent-summary).

The figure below compares the architecture of the original BERT model (left) with that of BERTSUM (right), which BertSumAbs is built upon. For BERTSUM, an input document is split into sentences, and [CLS] and [SEP] tokens are inserted before and after each sentence. For each token in the resulting sequence, three kinds of embeddings (token, segment, and position) are summed before being fed into the transformer layers. The positional embedding used in BertSumAbs allows input lengths of more than 512 tokens, which is the maximum input length of the original BERT model.
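
The input formatting described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function name `format_bertsum_input` is hypothetical, the sentences are assumed to be already tokenized, and the alternating 0/1 segment ids follow the interval segment scheme from the BERTSUM paper rather than the exact code in `utils_nlp`.

```python
from typing import List, Tuple

CLS, SEP = "[CLS]", "[SEP]"


def format_bertsum_input(
    sentences: List[List[str]],
) -> Tuple[List[str], List[int], List[int]]:
    """Wrap each tokenized sentence in [CLS] ... [SEP] and build segment ids.

    Hypothetical helper for illustration only.
    Returns (tokens, segment_ids, cls_positions).
    """
    tokens: List[str] = []
    segment_ids: List[int] = []
    cls_positions: List[int] = []
    for i, sentence in enumerate(sentences):
        # Remember where this sentence's [CLS] token sits in the flat sequence.
        cls_positions.append(len(tokens))
        wrapped = [CLS] + sentence + [SEP]
        tokens.extend(wrapped)
        # Alternate 0/1 per sentence so adjacent sentences are distinguishable.
        segment_ids.extend([i % 2] * len(wrapped))
    return tokens, segment_ids, cls_positions


sentences = [["the", "cat", "sat"], ["it", "purred"]]
tokens, segment_ids, cls_positions = format_bertsum_input(sentences)
print(tokens)
print(segment_ids)
print(cls_positions)
```

In the real model, each token would then be mapped to the sum of its token, segment, and position embeddings.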

Note that the figure shows only the encoder. For the decoder, BertSumAbs uses a multi-layer transformer with randomly initialized weights. Because the encoder starts from pretrained weights while the decoder does not, the mismatch between the two can make fine-tuning unstable. BertSumAbs therefore uses separate optimizers for the encoder and the decoder, each with its own learning rate schedule. During text generation, techniques such as trigram blocking and beam search can be used to improve the quality of the generated summaries.
<img src="https://nlpbp.blob.core.windows.net/images/BertForSummarization.PNG">
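
As a rough illustration of trigram blocking, the sketch below drops any candidate summary that repeats a trigram it has already produced. The helper names `has_repeated_trigram` and `block_repeats` are hypothetical, and the real decoder applies this check to beam search hypotheses during generation rather than to finished lists of tokens.

```python
from typing import List, Sequence


def has_repeated_trigram(tokens: Sequence[str]) -> bool:
    """Return True if any trigram occurs more than once in `tokens`."""
    seen = set()
    for i in range(len(tokens) - 2):
        trigram = tuple(tokens[i : i + 3])
        if trigram in seen:
            return True
        seen.add(trigram)
    return False


def block_repeats(hypotheses: List[List[str]]) -> List[List[str]]:
    """Keep only the hypotheses that contain no repeated trigram."""
    return [h for h in hypotheses if not has_repeated_trigram(h)]


candidates = [
    ["the", "cat", "sat", "on", "the", "cat", "sat"],  # repeats "the cat sat"
    ["the", "cat", "sat", "on", "the", "mat"],
]
kept = block_repeats(candidates)
print(kept)
```

Only the second candidate survives, because the first one produces the trigram "the cat sat" twice.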


## Before you start
Set QUICK_RUN = True to run the notebook on a small subset of the data with fewer training steps. With QUICK_RUN = False, the notebook takes about 5 hours to run on a VM with four 16 GB NVIDIA V100 GPUs: fine-tuning takes around 1.5 hours and inference around 3.5 hours. Better performance can be achieved by increasing MAX_STEPS.

* **ROUGE Evaluation**: To run ROUGE evaluation, please refer to the compute_rouge_perl section in [summarization_evaluation.ipynb](./summarization_evaluation.ipynb) for setup.

* **Distributed Training**:
Note that a Jupyter notebook can only use PyTorch [DataParallel](https://pytorch.org/docs/master/nn.html#dataparallel). Faster training and larger batch sizes can be achieved with PyTorch [DistributedDataParallel](https://pytorch.org/docs/master/notes/ddp.html) (DDP). The script [abstractive_summarization_bertsum_cnndm_distributed_train.py](./abstractive_summarization_bertsum_cnndm_distributed_train.py) shows an example of how to use DDP.

* **Mixed Precision Training**:
By default, this notebook doesn't use mixed precision training. Faster training and larger batch sizes can be achieved by setting FP16 to True. Refer to https://nvidia.github.io/apex and https://github.com/nvidia/apex for details on using mixed precision training, and check whether the GPU model on your machine supports it. Mixed precision inference is also available in the prediction utility function. With mixed precision training and/or inference, model performance can be slightly worse than in full precision mode.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
QUICK_RUN = False

In [39]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.models.transformers.abstractive_summarization_bertsum import BertSumAbs, BertSumAbsProcessor

from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.eval import compute_rouge_python

import pandas as pd
import pprint
import scrapbook as sb

## Data Preprocessing

The dataset used in this notebook is the CNN/DM dataset, which contains documents and accompanying questions drawn from CNN and Daily Mail news articles. The highlights in each article are used as the summary. The dataset consists of ~287K training examples, ~13K validation examples, and ~11K test examples. The news articles average 781 tokens in length, and the summaries average 3.75 sentences and 56 tokens.

The only significant data preprocessing step is splitting each input document into sentences.
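
To give a feel for what sentence splitting involves, here is a deliberately naive regex-based sketch. It is not how `CNNDMSummarizationDataset` does it: the function name `split_sentences` is hypothetical, and a rule this simple breaks on abbreviations such as "U.S. Army", which is why a trained sentence tokenizer is normally used.

```python
import re
from typing import List

# Split after ., ! or ? when followed by whitespace and an uppercase letter.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")


def split_sentences(text: str) -> List[str]:
    """Naive sentence splitter (illustrative only)."""
    parts = _SENTENCE_BOUNDARY.split(text.strip())
    return [part.strip() for part in parts if part.strip()]


article = (
    "The storm hit on Monday. Thousands lost power! "
    "Crews are working to restore it."
)
sentences = split_sentences(article)
for sentence in sentences:
    print(sentence)
```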

In [31]:
# the data path used to save the downloaded data file
DATA_PATH = TemporaryDirectory().name
# The number of lines at the head of data file used for preprocessing. -1 means all the lines.
TOP_N = 100
if not QUICK_RUN:
    TOP_N = -1

In [32]:
train_dataset, test_dataset = CNNDMSummarizationDataset(
    top_n=TOP_N, local_cache_path=DATA_PATH, prepare_extractive=False
)

In [33]:
len(train_dataset)

287227

In [34]:
len(test_dataset)

11490

## Model Finetuning

In [8]:
# notebook parameters
# the cache path
CACHE_PATH = TemporaryDirectory().name

# model parameters
MODEL_NAME = "bert-base-uncased"
MAX_POS = 768
MAX_SOURCE_SEQ_LENGTH = 640
MAX_TARGET_SEQ_LENGTH = 140

# mixed precision setting. To enable mixed precision training, follow instructions in SETUP.md. 
FP16 = False
if FP16:
    FP16_OPT_LEVEL="O2"
    
# fine-tuning parameters
# batch size per GPU, unit is the number of examples
BATCH_SIZE_PER_GPU = 3


# GPU used for training
NUM_GPUS = torch.cuda.device_count()

# Learning rate
LEARNING_RATE_BERT=5e-4/2.0
LEARNING_RATE_DEC=0.05/2.0


# How often the statistics reports show up in training, unit is step.
REPORT_EVERY=10
SAVE_EVERY=500

# total number of steps for training
MAX_STEPS=1e2
   
if not QUICK_RUN:
    MAX_STEPS=5e3

WARMUP_STEPS_BERT=2000
WARMUP_STEPS_DEC=1000   
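
The separate learning rates and warmup steps above feed two independent schedules, one for the encoder and one for the decoder. The sketch below assumes the warmup-then-inverse-square-root schedule described in the BertSum paper, `lr = base_lr * min(step^-0.5, step * warmup^-1.5)`; whether `summarizer.fit` implements exactly this form is an assumption, and the function name `noam_lr` is hypothetical.

```python
def noam_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Linear warmup followed by inverse-square-root decay (illustrative)."""
    if step < 1:
        raise ValueError("step must be >= 1")
    return base_lr * min(step ** -0.5, step * warmup_steps ** -1.5)


# Values copied from the parameter cell above.
LEARNING_RATE_BERT = 5e-4 / 2.0
LEARNING_RATE_DEC = 0.05 / 2.0
WARMUP_STEPS_BERT = 2000
WARMUP_STEPS_DEC = 1000

for step in (100, 1000, 2000, 5000):
    enc = noam_lr(step, LEARNING_RATE_BERT, WARMUP_STEPS_BERT)
    dec = noam_lr(step, LEARNING_RATE_DEC, WARMUP_STEPS_DEC)
    print(f"step {step:>5}: encoder lr = {enc:.3e}, decoder lr = {dec:.3e}")
```

Under this schedule the pretrained encoder gets a smaller peak rate and a longer warmup than the randomly initialized decoder, which is the intent behind using two optimizers.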


In [9]:
# processor which contains the collate function to load the preprocessed data
processor = BertSumAbsProcessor(cache_dir=CACHE_PATH, max_src_len=MAX_SOURCE_SEQ_LENGTH, max_tgt_len=MAX_TARGET_SEQ_LENGTH)
# summarizer
summarizer = BertSumAbs(
    processor, cache_dir=CACHE_PATH, max_pos_length=MAX_POS
)
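
MAX_POS = 768 exceeds the 512 positions that pretrained BERT provides, so the position-embedding table has to be extended in some way. One plausible approach, shown here purely as an assumption and not confirmed against the `BertSumAbs` implementation, is to copy the pretrained rows and initialize every extra position from the last pretrained row. The function name `extend_position_embeddings` and the toy table are hypothetical.

```python
from typing import List


def extend_position_embeddings(
    pretrained: List[List[float]], max_pos: int
) -> List[List[float]]:
    """Return a table with `max_pos` rows built from a pretrained table.

    Existing rows are copied; extra rows repeat the last pretrained row.
    """
    if max_pos <= len(pretrained):
        return [row[:] for row in pretrained[:max_pos]]
    extra = [pretrained[-1][:] for _ in range(max_pos - len(pretrained))]
    return [row[:] for row in pretrained] + extra


# Toy table: 4 pretrained positions with 2-dimensional embeddings.
toy_table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
extended = extend_position_embeddings(toy_table, max_pos=6)
print(len(extended))
print(extended[4], extended[5])
```

The added rows would then be updated during fine-tuning like any other parameter.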

In [10]:
BATCH_SIZE_PER_GPU*NUM_GPUS

12
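
The effective batch size of 12 determines how much of the training set a run actually covers. The short calculation below is a sketch that assumes 4 GPUs (as on the VM described in "Before you start") and that the last partial batch is kept. The resulting number of iterations per epoch matches the total shown by the progress bar in the training log.

```python
import math

NUM_TRAIN_EXAMPLES = 287227  # len(train_dataset) reported above
BATCH_SIZE_PER_GPU = 3
NUM_GPUS = 4                 # assumption: the 4-GPU VM described earlier
MAX_STEPS = 5000             # value used when QUICK_RUN is False

global_batch = BATCH_SIZE_PER_GPU * NUM_GPUS
steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / global_batch)
examples_seen = global_batch * MAX_STEPS
fraction_of_epoch = examples_seen / NUM_TRAIN_EXAMPLES

print(f"global batch size:  {global_batch}")
print(f"steps per epoch:    {steps_per_epoch}")
print(f"examples seen:      {examples_seen}")
print(f"fraction of epoch:  {fraction_of_epoch:.1%}")
```

With these settings, 5000 steps cover only about a fifth of one epoch, which is why increasing MAX_STEPS can improve results.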

In [11]:

summarizer.fit(
    train_dataset,
    num_gpus=NUM_GPUS,
    batch_size=BATCH_SIZE_PER_GPU*NUM_GPUS,
    max_steps=MAX_STEPS,
    learning_rate_bert=LEARNING_RATE_BERT,
    learning_rate_dec=LEARNING_RATE_DEC,
    warmup_steps_bert=WARMUP_STEPS_BERT,
    warmup_steps_dec=WARMUP_STEPS_DEC,
    save_every=SAVE_EVERY,
    report_every=REPORT_EVERY*5,
    fp16=FP16,
    # checkpoint="saved checkpoint path"
)


device is cuda


Iteration:   0%|          | 50/23936 [01:07<6:57:14,  1.05s/it] 

timestamp: 12/03/2020 19:42:23, average loss: 10.372837, time duration: 67.029438,
                            number of examples in current reporting: 600, step 50
                            out of total 5000


Iteration:   0%|          | 100/23936 [02:00<6:51:39,  1.04s/it]

timestamp: 12/03/2020 19:43:16, average loss: 6.238095, time duration: 53.688107,
                            number of examples in current reporting: 600, step 100
                            out of total 5000


Iteration:   1%|          | 150/23936 [02:54<6:53:34,  1.04s/it]

timestamp: 12/03/2020 19:44:10, average loss: 5.669545, time duration: 53.790859,
                            number of examples in current reporting: 599, step 150
                            out of total 5000


Iteration:   1%|          | 200/23936 [03:49<9:05:31,  1.38s/it]

timestamp: 12/03/2020 19:45:05, average loss: 5.348148, time duration: 55.037045,
                            number of examples in current reporting: 600, step 200
                            out of total 5000


Iteration:   1%|          | 250/23936 [04:43<7:05:14,  1.08s/it]

timestamp: 12/03/2020 19:45:59, average loss: 5.119667, time duration: 54.222955,
                            number of examples in current reporting: 600, step 250
                            out of total 5000


Iteration:   1%|▏         | 300/23936 [05:37<6:55:54,  1.06s/it]

timestamp: 12/03/2020 19:46:53, average loss: 4.949734, time duration: 53.657023,
                            number of examples in current reporting: 600, step 300
                            out of total 5000


Iteration:   1%|▏         | 350/23936 [06:31<6:56:55,  1.06s/it]

timestamp: 12/03/2020 19:47:47, average loss: 4.780042, time duration: 53.672302,
                            number of examples in current reporting: 600, step 350
                            out of total 5000


Iteration:   2%|▏         | 400/23936 [07:25<7:04:04,  1.08s/it]

timestamp: 12/03/2020 19:48:41, average loss: 4.712837, time duration: 54.140079,
                            number of examples in current reporting: 600, step 400
                            out of total 5000


Iteration:   2%|▏         | 450/23936 [08:17<6:51:25,  1.05s/it]

timestamp: 12/03/2020 19:49:34, average loss: 4.612352, time duration: 52.689976,
                            number of examples in current reporting: 600, step 450
                            out of total 5000


Iteration:   2%|▏         | 499/23936 [09:11<9:15:24,  1.42s/it]

timestamp: 12/03/2020 19:50:28, average loss: 4.564076, time duration: 54.513465,
                            number of examples in current reporting: 600, step 500
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_500.pt


Iteration:   2%|▏         | 550/23936 [10:19<7:39:54,  1.18s/it] 

timestamp: 12/03/2020 19:51:35, average loss: 4.486558, time duration: 67.296946,
                            number of examples in current reporting: 600, step 550
                            out of total 5000


Iteration:   3%|▎         | 600/23936 [11:13<6:47:33,  1.05s/it]

timestamp: 12/03/2020 19:52:29, average loss: 4.435066, time duration: 53.801369,
                            number of examples in current reporting: 600, step 600
                            out of total 5000


Iteration:   3%|▎         | 650/23936 [12:06<6:48:01,  1.05s/it]

timestamp: 12/03/2020 19:53:22, average loss: 4.424270, time duration: 53.131144,
                            number of examples in current reporting: 598, step 650
                            out of total 5000


Iteration:   3%|▎         | 700/23936 [13:00<6:53:23,  1.07s/it]

timestamp: 12/03/2020 19:54:16, average loss: 4.416547, time duration: 53.616453,
                            number of examples in current reporting: 600, step 700
                            out of total 5000


Iteration:   3%|▎         | 750/23936 [13:54<6:46:13,  1.05s/it]

timestamp: 12/03/2020 19:55:10, average loss: 4.384736, time duration: 54.156298,
                            number of examples in current reporting: 600, step 750
                            out of total 5000


Iteration:   3%|▎         | 800/23936 [14:47<6:44:11,  1.05s/it]

timestamp: 12/03/2020 19:56:03, average loss: 4.305494, time duration: 53.149219,
                            number of examples in current reporting: 600, step 800
                            out of total 5000


Iteration:   4%|▎         | 850/23936 [15:42<7:16:02,  1.13s/it]

timestamp: 12/03/2020 19:56:58, average loss: 4.276876, time duration: 54.933155,
                            number of examples in current reporting: 600, step 850
                            out of total 5000


Iteration:   4%|▍         | 900/23936 [16:35<6:44:33,  1.05s/it]

timestamp: 12/03/2020 19:57:51, average loss: 4.301424, time duration: 53.365603,
                            number of examples in current reporting: 600, step 900
                            out of total 5000


Iteration:   4%|▍         | 950/23936 [17:29<6:38:17,  1.04s/it]

timestamp: 12/03/2020 19:58:45, average loss: 4.289361, time duration: 53.449913,
                            number of examples in current reporting: 600, step 950
                            out of total 5000


Iteration:   4%|▍         | 999/23936 [18:21<6:35:20,  1.03s/it]

timestamp: 12/03/2020 19:59:38, average loss: 4.243631, time duration: 53.010966,
                            number of examples in current reporting: 600, step 1000
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_1000.pt


Iteration:   4%|▍         | 1050/23936 [19:29<6:38:49,  1.05s/it] 

timestamp: 12/03/2020 20:00:45, average loss: 4.211871, time duration: 67.357176,
                            number of examples in current reporting: 600, step 1050
                            out of total 5000


Iteration:   5%|▍         | 1100/23936 [20:24<7:22:23,  1.16s/it]

timestamp: 12/03/2020 20:01:40, average loss: 4.201430, time duration: 54.340546,
                            number of examples in current reporting: 600, step 1100
                            out of total 5000


Iteration:   5%|▍         | 1150/23936 [21:17<7:04:02,  1.12s/it]

timestamp: 12/03/2020 20:02:33, average loss: 4.177537, time duration: 53.533902,
                            number of examples in current reporting: 600, step 1150
                            out of total 5000


Iteration:   5%|▌         | 1200/23936 [22:11<6:38:52,  1.05s/it]

timestamp: 12/03/2020 20:03:27, average loss: 4.146584, time duration: 53.418810,
                            number of examples in current reporting: 600, step 1200
                            out of total 5000


Iteration:   5%|▌         | 1250/23936 [23:04<6:44:45,  1.07s/it]

timestamp: 12/03/2020 20:04:20, average loss: 4.095358, time duration: 53.731934,
                            number of examples in current reporting: 599, step 1250
                            out of total 5000


Iteration:   5%|▌         | 1300/23936 [23:58<6:36:43,  1.05s/it]

timestamp: 12/03/2020 20:05:14, average loss: 4.071987, time duration: 53.952441,
                            number of examples in current reporting: 600, step 1300
                            out of total 5000


Iteration:   6%|▌         | 1350/23936 [24:52<6:34:50,  1.05s/it]

timestamp: 12/03/2020 20:06:08, average loss: 4.030592, time duration: 53.411778,
                            number of examples in current reporting: 600, step 1350
                            out of total 5000


Iteration:   6%|▌         | 1400/23936 [25:45<6:29:20,  1.04s/it]

timestamp: 12/03/2020 20:07:01, average loss: 3.993827, time duration: 53.573513,
                            number of examples in current reporting: 600, step 1400
                            out of total 5000


Iteration:   6%|▌         | 1450/23936 [26:40<6:44:49,  1.08s/it]

timestamp: 12/03/2020 20:07:56, average loss: 3.953534, time duration: 54.412248,
                            number of examples in current reporting: 599, step 1450
                            out of total 5000


Iteration:   6%|▋         | 1499/23936 [27:32<6:32:06,  1.05s/it]

timestamp: 12/03/2020 20:08:49, average loss: 3.937318, time duration: 53.012146,
                            number of examples in current reporting: 599, step 1500
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_1500.pt


Iteration:   6%|▋         | 1550/23936 [28:28<6:33:42,  1.06s/it] 

timestamp: 12/03/2020 20:09:44, average loss: 3.927354, time duration: 55.748131,
                            number of examples in current reporting: 600, step 1550
                            out of total 5000


Iteration:   7%|▋         | 1600/23936 [29:22<6:31:16,  1.05s/it]

timestamp: 12/03/2020 20:10:38, average loss: 3.881955, time duration: 53.677091,
                            number of examples in current reporting: 600, step 1600
                            out of total 5000


Iteration:   7%|▋         | 1650/23936 [30:16<6:35:27,  1.06s/it]

timestamp: 12/03/2020 20:11:32, average loss: 3.864133, time duration: 53.513038,
                            number of examples in current reporting: 600, step 1650
                            out of total 5000


Iteration:   7%|▋         | 1700/23936 [31:11<6:46:26,  1.10s/it]

timestamp: 12/03/2020 20:12:27, average loss: 3.845761, time duration: 54.973100,
                            number of examples in current reporting: 598, step 1700
                            out of total 5000


Iteration:   7%|▋         | 1750/23936 [32:04<6:42:18,  1.09s/it]

timestamp: 12/03/2020 20:13:20, average loss: 3.833323, time duration: 53.872088,
                            number of examples in current reporting: 600, step 1750
                            out of total 5000


Iteration:   8%|▊         | 1800/23936 [32:58<6:24:41,  1.04s/it]

timestamp: 12/03/2020 20:14:14, average loss: 3.785084, time duration: 53.159890,
                            number of examples in current reporting: 599, step 1800
                            out of total 5000


Iteration:   8%|▊         | 1850/23936 [33:51<6:21:29,  1.04s/it]

timestamp: 12/03/2020 20:15:07, average loss: 3.707007, time duration: 53.459718,
                            number of examples in current reporting: 600, step 1850
                            out of total 5000


Iteration:   8%|▊         | 1900/23936 [34:44<6:22:15,  1.04s/it]

timestamp: 12/03/2020 20:16:00, average loss: 3.676872, time duration: 52.927864,
                            number of examples in current reporting: 600, step 1900
                            out of total 5000


Iteration:   8%|▊         | 1950/23936 [35:38<6:27:35,  1.06s/it]

timestamp: 12/03/2020 20:16:54, average loss: 3.687167, time duration: 53.611518,
                            number of examples in current reporting: 600, step 1950
                            out of total 5000


Iteration:   8%|▊         | 1999/23936 [36:30<6:31:53,  1.07s/it]

timestamp: 12/03/2020 20:17:48, average loss: 3.615395, time duration: 54.457510,
                            number of examples in current reporting: 600, step 2000
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_2000.pt


Iteration:   9%|▊         | 2050/23936 [37:28<6:15:42,  1.03s/it] 

timestamp: 12/03/2020 20:18:44, average loss: 3.625652, time duration: 55.642923,
                            number of examples in current reporting: 600, step 2050
                            out of total 5000


Iteration:   9%|▉         | 2100/23936 [38:22<6:27:51,  1.07s/it]

timestamp: 12/03/2020 20:19:38, average loss: 3.567565, time duration: 53.814359,
                            number of examples in current reporting: 600, step 2100
                            out of total 5000


Iteration:   9%|▉         | 2150/23936 [39:15<6:20:39,  1.05s/it]

timestamp: 12/03/2020 20:20:31, average loss: 3.549629, time duration: 53.671999,
                            number of examples in current reporting: 600, step 2150
                            out of total 5000


Iteration:   9%|▉         | 2200/23936 [40:09<6:18:07,  1.04s/it]

timestamp: 12/03/2020 20:21:25, average loss: 3.477838, time duration: 53.802528,
                            number of examples in current reporting: 600, step 2200
                            out of total 5000


Iteration:   9%|▉         | 2250/23936 [41:02<6:23:33,  1.06s/it]

timestamp: 12/03/2020 20:22:18, average loss: 3.465800, time duration: 53.339881,
                            number of examples in current reporting: 600, step 2250
                            out of total 5000


Iteration:  10%|▉         | 2300/23936 [41:59<6:55:11,  1.15s/it]

timestamp: 12/03/2020 20:23:15, average loss: 3.448973, time duration: 56.716613,
                            number of examples in current reporting: 600, step 2300
                            out of total 5000


Iteration:  10%|▉         | 2350/23936 [42:53<6:55:53,  1.16s/it]

timestamp: 12/03/2020 20:24:09, average loss: 3.456282, time duration: 53.852615,
                            number of examples in current reporting: 599, step 2350
                            out of total 5000


Iteration:  10%|█         | 2400/23936 [43:47<6:27:31,  1.08s/it]

timestamp: 12/03/2020 20:25:03, average loss: 3.358254, time duration: 53.795152,
                            number of examples in current reporting: 600, step 2400
                            out of total 5000


Iteration:  10%|█         | 2450/23936 [44:40<6:22:10,  1.07s/it]

timestamp: 12/03/2020 20:25:56, average loss: 3.377144, time duration: 53.388796,
                            number of examples in current reporting: 600, step 2450
                            out of total 5000


Iteration:  10%|█         | 2499/23936 [45:33<6:19:52,  1.06s/it]

timestamp: 12/03/2020 20:26:50, average loss: 3.328608, time duration: 53.657009,
                            number of examples in current reporting: 600, step 2500
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_2500.pt


Iteration:  11%|█         | 2550/23936 [46:29<6:16:20,  1.06s/it]

timestamp: 12/03/2020 20:27:45, average loss: 3.336079, time duration: 55.542649,
                            number of examples in current reporting: 600, step 2550
                            out of total 5000


Iteration:  11%|█         | 2600/23936 [47:23<6:15:09,  1.06s/it]

timestamp: 12/03/2020 20:28:39, average loss: 3.312251, time duration: 53.392276,
                            number of examples in current reporting: 600, step 2600
                            out of total 5000


Iteration:  11%|█         | 2650/23936 [48:17<6:37:09,  1.12s/it]

timestamp: 12/03/2020 20:29:34, average loss: 3.287620, time duration: 54.833326,
                            number of examples in current reporting: 599, step 2650
                            out of total 5000


Iteration:  11%|█▏        | 2700/23936 [49:11<6:14:52,  1.06s/it]

timestamp: 12/03/2020 20:30:27, average loss: 3.227116, time duration: 53.014668,
                            number of examples in current reporting: 599, step 2700
                            out of total 5000


Iteration:  11%|█▏        | 2750/23936 [50:04<6:12:22,  1.05s/it]

timestamp: 12/03/2020 20:31:20, average loss: 3.258723, time duration: 53.831981,
                            number of examples in current reporting: 600, step 2750
                            out of total 5000


Iteration:  12%|█▏        | 2800/23936 [50:58<6:12:23,  1.06s/it]

timestamp: 12/03/2020 20:32:14, average loss: 3.229686, time duration: 53.754699,
                            number of examples in current reporting: 600, step 2800
                            out of total 5000


Iteration:  12%|█▏        | 2850/23936 [51:52<6:12:07,  1.06s/it]

timestamp: 12/03/2020 20:33:08, average loss: 3.233732, time duration: 53.703091,
                            number of examples in current reporting: 600, step 2850
                            out of total 5000


Iteration:  12%|█▏        | 2900/23936 [52:45<6:04:37,  1.04s/it]

timestamp: 12/03/2020 20:34:01, average loss: 3.188855, time duration: 53.216950,
                            number of examples in current reporting: 600, step 2900
                            out of total 5000


Iteration:  12%|█▏        | 2950/23936 [53:39<6:13:50,  1.07s/it]

timestamp: 12/03/2020 20:34:55, average loss: 3.184502, time duration: 53.708161,
                            number of examples in current reporting: 600, step 2950
                            out of total 5000


Iteration:  13%|█▎        | 2999/23936 [54:33<6:26:09,  1.11s/it]

timestamp: 12/03/2020 20:35:50, average loss: 3.197579, time duration: 54.896690,
                            number of examples in current reporting: 600, step 3000
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_3000.pt


Iteration:  13%|█▎        | 3050/23936 [55:30<6:06:13,  1.05s/it]

timestamp: 12/03/2020 20:36:46, average loss: 3.121317, time duration: 56.019515,
                            number of examples in current reporting: 599, step 3050
                            out of total 5000


Iteration:  13%|█▎        | 3100/23936 [56:23<6:00:29,  1.04s/it]

timestamp: 12/03/2020 20:37:39, average loss: 3.132497, time duration: 53.469911,
                            number of examples in current reporting: 600, step 3100
                            out of total 5000


Iteration:  13%|█▎        | 3150/23936 [57:17<6:02:27,  1.05s/it]

timestamp: 12/03/2020 20:38:33, average loss: 3.139356, time duration: 53.778406,
                            number of examples in current reporting: 600, step 3150
                            out of total 5000


Iteration:  13%|█▎        | 3200/23936 [58:10<6:07:40,  1.06s/it]

timestamp: 12/03/2020 20:39:26, average loss: 3.107268, time duration: 53.458483,
                            number of examples in current reporting: 599, step 3200
                            out of total 5000


Iteration:  14%|█▎        | 3250/23936 [59:05<6:37:17,  1.15s/it]

timestamp: 12/03/2020 20:40:21, average loss: 3.120350, time duration: 54.387663,
                            number of examples in current reporting: 599, step 3250
                            out of total 5000


Iteration:  14%|█▍        | 3300/23936 [59:58<6:10:54,  1.08s/it]

timestamp: 12/03/2020 20:41:14, average loss: 3.073890, time duration: 53.387729,
                            number of examples in current reporting: 600, step 3300
                            out of total 5000


Iteration:  14%|█▍        | 3350/23936 [1:00:51<5:59:38,  1.05s/it]

timestamp: 12/03/2020 20:42:07, average loss: 3.058117, time duration: 53.333031,
                            number of examples in current reporting: 600, step 3350
                            out of total 5000


Iteration:  14%|█▍        | 3400/23936 [1:01:45<5:58:15,  1.05s/it]

timestamp: 12/03/2020 20:43:01, average loss: 3.083077, time duration: 53.279642,
                            number of examples in current reporting: 600, step 3400
                            out of total 5000


Iteration:  14%|█▍        | 3450/23936 [1:02:39<6:05:38,  1.07s/it]

timestamp: 12/03/2020 20:43:55, average loss: 3.101894, time duration: 53.897809,
                            number of examples in current reporting: 600, step 3450
                            out of total 5000


Iteration:  15%|█▍        | 3499/23936 [1:03:31<5:56:32,  1.05s/it]

timestamp: 12/03/2020 20:44:48, average loss: 3.004254, time duration: 53.171986,
                            number of examples in current reporting: 599, step 3500
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_3500.pt


Iteration:  15%|█▍        | 3550/23936 [1:04:28<6:14:45,  1.10s/it]

timestamp: 12/03/2020 20:45:44, average loss: 3.007772, time duration: 56.192442,
                            number of examples in current reporting: 600, step 3550
                            out of total 5000


Iteration:  15%|█▌        | 3600/23936 [1:05:22<5:54:56,  1.05s/it]

timestamp: 12/03/2020 20:46:38, average loss: 3.061955, time duration: 53.783488,
                            number of examples in current reporting: 599, step 3600
                            out of total 5000


Iteration:  15%|█▌        | 3650/23936 [1:06:15<5:56:18,  1.05s/it]

timestamp: 12/03/2020 20:47:31, average loss: 2.992658, time duration: 53.215498,
                            number of examples in current reporting: 599, step 3650
                            out of total 5000


Iteration:  15%|█▌        | 3700/23936 [1:07:09<5:54:31,  1.05s/it]

timestamp: 12/03/2020 20:48:25, average loss: 3.018115, time duration: 53.860734,
                            number of examples in current reporting: 600, step 3700
                            out of total 5000


Iteration:  16%|█▌        | 3750/23936 [1:08:02<5:58:46,  1.07s/it]

timestamp: 12/03/2020 20:49:18, average loss: 2.979420, time duration: 53.037118,
                            number of examples in current reporting: 600, step 3750
                            out of total 5000


Iteration:  16%|█▌        | 3800/23936 [1:08:55<5:48:58,  1.04s/it]

timestamp: 12/03/2020 20:50:11, average loss: 2.970507, time duration: 53.401205,
                            number of examples in current reporting: 599, step 3800
                            out of total 5000


Iteration:  16%|█▌        | 3850/23936 [1:09:50<5:59:32,  1.07s/it]

timestamp: 12/03/2020 20:51:06, average loss: 2.953462, time duration: 55.066902,
                            number of examples in current reporting: 600, step 3850
                            out of total 5000


Iteration:  16%|█▋        | 3900/23936 [1:10:44<5:53:00,  1.06s/it]

timestamp: 12/03/2020 20:52:00, average loss: 2.945157, time duration: 53.504895,
                            number of examples in current reporting: 600, step 3900
                            out of total 5000


Iteration:  17%|█▋        | 3950/23936 [1:11:37<5:55:05,  1.07s/it]

timestamp: 12/03/2020 20:52:53, average loss: 2.961220, time duration: 52.973685,
                            number of examples in current reporting: 599, step 3950
                            out of total 5000


Iteration:  17%|█▋        | 3999/23936 [1:12:30<5:45:03,  1.04s/it]

timestamp: 12/03/2020 20:53:47, average loss: 2.963094, time duration: 53.753345,
                            number of examples in current reporting: 599, step 4000
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_4000.pt


Iteration:  17%|█▋        | 4050/23936 [1:13:26<5:41:32,  1.03s/it]

timestamp: 12/03/2020 20:54:42, average loss: 2.928608, time duration: 55.215695,
                            number of examples in current reporting: 599, step 4050
                            out of total 5000


Iteration:  17%|█▋        | 4100/23936 [1:14:19<5:38:56,  1.03s/it]

timestamp: 12/03/2020 20:55:35, average loss: 2.897866, time duration: 52.701019,
                            number of examples in current reporting: 600, step 4100
                            out of total 5000


Iteration:  17%|█▋        | 4150/23936 [1:15:13<6:02:51,  1.10s/it]

timestamp: 12/03/2020 20:56:29, average loss: 2.911174, time duration: 54.544709,
                            number of examples in current reporting: 599, step 4150
                            out of total 5000


Iteration:  18%|█▊        | 4200/23936 [1:16:06<5:48:39,  1.06s/it]

timestamp: 12/03/2020 20:57:22, average loss: 2.895758, time duration: 53.279250,
                            number of examples in current reporting: 600, step 4200
                            out of total 5000


Iteration:  18%|█▊        | 4250/23936 [1:17:00<5:36:37,  1.03s/it]

timestamp: 12/03/2020 20:58:16, average loss: 2.899472, time duration: 53.213565,
                            number of examples in current reporting: 600, step 4250
                            out of total 5000


Iteration:  18%|█▊        | 4300/23936 [1:17:53<5:45:51,  1.06s/it]

timestamp: 12/03/2020 20:59:09, average loss: 2.934093, time duration: 53.601735,
                            number of examples in current reporting: 599, step 4300
                            out of total 5000


Iteration:  18%|█▊        | 4350/23936 [1:18:46<5:43:48,  1.05s/it]

timestamp: 12/03/2020 21:00:02, average loss: 2.908369, time duration: 53.016642,
                            number of examples in current reporting: 599, step 4350
                            out of total 5000


Iteration:  18%|█▊        | 4400/23936 [1:19:40<5:41:22,  1.05s/it]

timestamp: 12/03/2020 21:00:56, average loss: 2.879586, time duration: 53.547388,
                            number of examples in current reporting: 600, step 4400
                            out of total 5000


Iteration:  19%|█▊        | 4450/23936 [1:20:34<5:38:56,  1.04s/it]

timestamp: 12/03/2020 21:01:50, average loss: 2.873984, time duration: 54.737275,
                            number of examples in current reporting: 600, step 4450
                            out of total 5000


Iteration:  19%|█▉        | 4499/23936 [1:21:27<5:46:48,  1.07s/it]

timestamp: 12/03/2020 21:02:44, average loss: 2.849939, time duration: 53.788721,
                            number of examples in current reporting: 600, step 4500
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_4500.pt


Iteration:  19%|█▉        | 4550/23936 [1:22:24<5:43:32,  1.06s/it]

timestamp: 12/03/2020 21:03:40, average loss: 2.852282, time duration: 55.699511,
                            number of examples in current reporting: 600, step 4550
                            out of total 5000


Iteration:  19%|█▉        | 4600/23936 [1:23:17<5:37:00,  1.05s/it]

timestamp: 12/03/2020 21:04:33, average loss: 2.839527, time duration: 53.459990,
                            number of examples in current reporting: 600, step 4600
                            out of total 5000


Iteration:  19%|█▉        | 4650/23936 [1:24:11<5:34:58,  1.04s/it]

timestamp: 12/03/2020 21:05:27, average loss: 2.836208, time duration: 53.331228,
                            number of examples in current reporting: 599, step 4650
                            out of total 5000


Iteration:  20%|█▉        | 4700/23936 [1:25:05<5:37:45,  1.05s/it]

timestamp: 12/03/2020 21:06:21, average loss: 2.860971, time duration: 54.089519,
                            number of examples in current reporting: 600, step 4700
                            out of total 5000


Iteration:  20%|█▉        | 4750/23936 [1:25:59<6:29:01,  1.22s/it]

timestamp: 12/03/2020 21:07:15, average loss: 2.848305, time duration: 54.271892,
                            number of examples in current reporting: 599, step 4750
                            out of total 5000


Iteration:  20%|██        | 4800/23936 [1:26:53<5:39:29,  1.06s/it]

timestamp: 12/03/2020 21:08:09, average loss: 2.874445, time duration: 53.740144,
                            number of examples in current reporting: 600, step 4800
                            out of total 5000


Iteration:  20%|██        | 4850/23936 [1:27:46<5:37:59,  1.06s/it]

timestamp: 12/03/2020 21:09:02, average loss: 2.837490, time duration: 53.097977,
                            number of examples in current reporting: 600, step 4850
                            out of total 5000


Iteration:  20%|██        | 4900/23936 [1:28:39<5:32:21,  1.05s/it]

timestamp: 12/03/2020 21:09:55, average loss: 2.835973, time duration: 53.445394,
                            number of examples in current reporting: 600, step 4900
                            out of total 5000


Iteration:  21%|██        | 4950/23936 [1:29:32<5:27:38,  1.04s/it]

timestamp: 12/03/2020 21:10:49, average loss: 2.801602, time duration: 53.101142,
                            number of examples in current reporting: 600, step 4950
                            out of total 5000


Iteration:  21%|██        | 4999/23936 [1:30:25<5:29:50,  1.05s/it]

timestamp: 12/03/2020 21:11:42, average loss: 2.808340, time duration: 53.435037,
                            number of examples in current reporting: 600, step 5000
                            out of total 5000
saving through pytorch to ./abstemp/bert-base-uncased_step_5000.pt


Iteration:  21%|██        | 5000/23936 [1:30:28<8:28:23,  1.61s/it]


saving through pytorch to ./abstemp/fine_tuned/bertsumabs.pt


In [12]:
summarizer.save_model(MAX_STEPS, os.path.join(CACHE_PATH, "./bertsumabs.pt"))

saving through pytorch to ./abstemp/./bertsumabs.pt


## Model Evaluation

To run ROUGE evaluation, refer to the `compute_rouge_perl` section of [summarization_evaluation.ipynb](summarization_evaluation.ipynb) for setup.
With the settings in this notebook and `QUICK_RUN=False`, you should get ROUGE scores close to the following:

```
{'rouge-1': {'f': 0.34819639878321873,
             'p': 0.39977932634737307,
             'r': 0.34429079596863604},
 'rouge-2': {'f': 0.13919271352557894,
             'p': 0.16129965067780644,
             'r': 0.1372938054050938},
 'rouge-l': {'f': 0.2313282318854973,
             'p': 0.26664667422849747,
             'r': 0.22850294283399628}}
```

Better performance can be achieved by increasing `MAX_STEPS`.
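
For readers unfamiliar with the `p`, `r`, and `f` fields, the following is a minimal sketch of how ROUGE-1 precision, recall, and F-score can be computed for a single candidate/reference pair using clipped unigram overlap. It is an illustration only: the function name `rouge_1` and the whitespace tokenization are assumptions made here, and the evaluation utilities used in this notebook may tokenize, stem, or aggregate differently.

```python
from collections import Counter


def rouge_1(candidate: str, reference: str) -> dict:
    """Illustrative ROUGE-1 for one pair, assuming whitespace tokenization."""
    cand_counts = Counter(candidate.split())
    ref_counts = Counter(reference.split())
    # Clipped overlap: each unigram counts at most as often as it appears in both texts.
    overlap = sum((cand_counts & ref_counts).values())
    p = overlap / max(sum(cand_counts.values()), 1)  # share of candidate tokens matched
    r = overlap / max(sum(ref_counts.values()), 1)   # share of reference tokens matched
    f = 2 * p * r / (p + r) if p + r > 0 else 0.0     # harmonic mean of p and r
    return {"p": p, "r": r, "f": f}


scores = rouge_1("the cat sat on the mat", "the cat lay on the mat")
print(scores)
```

ROUGE-2 follows the same pattern with bigrams instead of unigrams, and ROUGE-L replaces the overlap count with the length of the longest common subsequence.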

In [19]:
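# del summarizer
# Reload the fine-tuned weights saved above into a fresh BertSumAbs instance.
checkpoint = torch.load(os.path.join(CACHE_PATH, "bertsumabs.pt"))
summarizer = BertSumAbs(
    processor, cache_dir=CACHE_PATH, max_pos_length=MAX_POS
)
# The checkpoint is a dict; the model weights are stored under the 'model' key.
summarizer.model.load_checkpoint(checkpoint['model'])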

In [16]:
# clear cache
import gc

gc.collect()
torch.cuda.empty_cache()

In [27]:
len(train_dataset)

287227

In [25]:
# TOP_N = 32
# Use the full test set unless this is a quick run.
if not QUICK_RUN:
    TOP_N = len(test_dataset)
print(TOP_N)

32


In [35]:
TEST_TOP_N = 32
if not QUICK_RUN:
    TEST_TOP_N = len(test_dataset)

# Evaluate on the first TEST_TOP_N test examples.
shortened_dataset = test_dataset.shorten(top_n=TEST_TOP_N)
src = shortened_dataset.get_source()
# Each target is a list of sentences; join them into a single reference string.
reference_summaries = [" ".join(t).rstrip("\n") for t in shortened_dataset.get_target()]
generated_summaries = summarizer.predict(
    shortened_dataset, batch_size=32*4, num_gpus=NUM_GPUS
)
assert len(generated_summaries) == len(reference_summaries)


Generating summary:   0%|          | 0/90 [00:00<?, ?it/s]

dataset length is 11490


Generating summary: 100%|██████████| 90/90 [3:26:35<00:00, 128.17s/it]  
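
The 90 iterations shown in the progress bar are consistent with splitting the 11490 test examples into batches of `32*4 = 128`. The sketch below reproduces that arithmetic; `num_batches` is a hypothetical helper written for this illustration, and it assumes `batch_size` is the total batch size across all GPUs with the last, smaller batch kept.

```python
import math


def num_batches(num_examples: int, batch_size: int) -> int:
    """Number of batches when the final partial batch is kept (assumption)."""
    return math.ceil(num_examples / batch_size)


n = num_batches(11490, 32 * 4)
print(n)
```

When memory is tight, lowering `batch_size` increases the number of iterations in the same way.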


In [37]:
src[0]

['marseille , france ( cnn ) the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane .',
 'marseille prosecutor brice robin told cnn that " so far no videos were used in the crash investigation . "',
 'he added , " a person who has such a video needs to immediately give it to the investigators . "',
 "robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps .",
 'all 150 on board were killed .',
 'paris match and bild reported that the video was recovered from a phone at the wreckage site .',
 'the two publications described the supposed video , but did not post it on their websites .',
 'the publications said that they watched the video , which was found by a source close to the investigation . "',
 'on

In [21]:
generated_summaries[0]

'french prosecutor : " so far no videos were used in the crash of germanwings flight 95fk .   new : " we can hear cries of my god \' in several languages , "   french lawyer says he was not aware of any video footage from the crash , " the report says .'

In [38]:
reference_summaries[0]

' marseille prosecutor says " so far no videos were used in the crash investigation " despite media reports .    journalists at bild and paris match are " very confident " the video clip is real , an editor says .    andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says .  '

In [None]:
rouge_scores = compute_rouge_python(cand=generated_summaries, ref=reference_summaries)


In [42]:
pprint.pprint(rouge_scores)

{'rouge-1': {'f': 0.34819639878321873,
             'p': 0.39977932634737307,
             'r': 0.34429079596863604},
 'rouge-2': {'f': 0.13919271352557894,
             'p': 0.16129965067780644,
             'r': 0.1372938054050938},
 'rouge-l': {'f': 0.2313282318854973,
             'p': 0.26664667422849747,
             'r': 0.22850294283399628}}
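
Note that the reported `f` values are not the harmonic mean of the reported `p` and `r`. One plausible explanation, assuming the scorer computes precision, recall, and F-score per example and then averages each of them over the test set (this is not confirmed by the output above), is that the mean of per-example F-scores generally differs from the F-score of the mean precision and recall. The sketch below checks this against the ROUGE-1 numbers printed above and against two made-up (precision, recall) pairs; `f_score` and the toy pairs are hypothetical and used only for illustration.

```python
def f_score(p: float, r: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    return 2 * p * r / (p + r) if p + r > 0 else 0.0


# ROUGE-1 values copied from the output above.
reported = {
    "f": 0.34819639878321873,
    "p": 0.39977932634737307,
    "r": 0.34429079596863604,
}
f_of_means = f_score(reported["p"], reported["r"])
print(f"reported F: {reported['f']:.4f}, F from averaged P and R: {f_of_means:.4f}")

# Toy data (hypothetical): two per-example (precision, recall) pairs.
pairs = [(0.9, 0.1), (0.2, 0.8)]
mean_of_f = sum(f_score(p, r) for p, r in pairs) / len(pairs)
mean_p = sum(p for p, _ in pairs) / len(pairs)
mean_r = sum(r for _, r in pairs) / len(pairs)
f_of_mean_pr = f_score(mean_p, mean_r)
print(f"mean of per-example F: {mean_of_f:.3f}, F of mean P and R: {f_of_mean_pr:.3f}")
```

For this reason, compare `f` against `f`, `p` against `p`, and `r` against `r` when checking results from another run, rather than deriving one from the others.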


In [43]:
# for testing
sb.glue("rouge_2_f_score", rouge_scores['rouge-2']['f'])

## Clean up temporary folders

In [None]:
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH, ignore_errors=True)
if os.path.exists(CACHE_PATH):
    shutil.rmtree(CACHE_PATH, ignore_errors=True)