Copyright (c) Microsoft Corporation.
Licensed under the MIT License.

# Abstractive Summarization using BertSumAbs on the CNN/DailyMail Dataset

## Summary

This notebook demonstrates how to fine-tune BERT for abstractive text summarization. Utility functions and classes in the NLP Best Practices repo are used for data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

### Abstractive Summarization
Abstractive summarization is the task of taking an input text and summarizing its content in a shorter output text. Unlike extractive summarization, abstractive summarization does not copy sentences directly from the input text. Instead, it rephrases the content of the input text.

### BertSumAbs

BertSumAbs is a BERT-based abstractive summarization algorithm proposed in [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345), with [published examples](https://github.com/nlpyang/PreSumm). It uses a pretrained BERT model as the encoder and fine-tunes both the encoder and the decoder on a labeled summarization dataset such as the [CNN/DM dataset](https://github.com/harvardnlp/sent-summary).

The figure below compares the architecture of the original BERT model (left) with that of BERTSUM (right), on which BertSumAbs is built. In BERTSUM, an input document is split into sentences. A [CLS] token is inserted before each sentence and a [SEP] token after it. Each token in the resulting sequence is represented by the sum of three kinds of embeddings (token, interval segment, and position embeddings) before it is fed into the transformer layers. BertSumAbs uses position embeddings that support inputs longer than 512 tokens, which is the maximum input length of the original BERT model.

Note that the figure shows only the encoder. The BertSumAbs decoder is a multi-layer transformer with randomly initialized weights. The encoder is pretrained but the decoder is not, and this mismatch can make fine-tuning unstable. BertSumAbs therefore uses separate optimizers for the encoder and the decoder, each with its own learning-rate schedule. During text generation, techniques such as beam search and trigram blocking can be used to improve the quality of the generated summaries.
<img src="https://nlpbp.blob.core.windows.net/images/BertForSummarization.PNG">
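The BERTSUM input layout in the figure above can be illustrated with a small pure-Python sketch. It assumes each sentence is already tokenized into a list of strings. The real model uses BERT's WordPiece tokenizer and integer token IDs. The name `build_bertsum_input` is hypothetical and is not part of utils_nlp. The sketch wraps every sentence in [CLS] ... [SEP] and assigns interval segment IDs that alternate between 0 and 1 from one sentence to the next. It then numbers the positions. In the model, each of these three ID sequences indexes its own embedding table, and the three embeddings are summed per token.

```python
from typing import List, Tuple


def build_bertsum_input(
    sentences: List[List[str]], max_len: int = 512
) -> Tuple[List[str], List[int], List[int]]:
    """Hypothetical helper: lay out pre-tokenized sentences in the BERTSUM format.

    Returns (tokens, segment_ids, position_ids).
    """
    tokens: List[str] = []
    segment_ids: List[int] = []
    for i, sentence in enumerate(sentences):
        # Interval segment IDs alternate 0, 1, 0, 1, ... from sentence to sentence.
        segment = i % 2
        piece = ["[CLS]"] + sentence + ["[SEP]"]
        tokens.extend(piece)
        segment_ids.extend([segment] * len(piece))
    # Truncate to the number of positions the model can encode.
    tokens = tokens[:max_len]
    segment_ids = segment_ids[:max_len]
    position_ids = list(range(len(tokens)))
    return tokens, segment_ids, position_ids


tokens, segment_ids, position_ids = build_bertsum_input(
    [["the", "cat", "sat"], ["it", "slept"]]
)
print(tokens)
print(segment_ids)
print(position_ids)
```

If `max_len` is set above 512, the position IDs run past the range of BERT's original position embeddings. In that case, the position-embedding table itself must be enlarged.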


## Before you start
Set QUICK_RUN = True to run the notebook on a small subset of the data with fewer training steps. If QUICK_RUN = False, the notebook takes about 5 hours to run on a VM with four 16 GB NVIDIA V100 GPUs.

To run ROUGE evaluation, see the compute_rouge_perl section in [summarization_evaluation.ipynb](./summarization_evaluation.ipynb).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
QUICK_RUN = True

In [3]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.models.transformers.abstractive_summarization_bertsum import BertSumAbs, BertSumAbsProcessor, shorten_dataset

from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.models.transformers.datasets import SummarizationNonIterableDataset
from utils_nlp.eval.evaluate_summarization import get_rouge

import pandas as pd
import scrapbook as sb

## Data Preprocessing

This notebook uses the CNN/DM dataset. It contains documents and accompanying questions drawn from CNN and Daily Mail news articles, and the highlights of each article serve as its summary. The dataset consists of ~289K training examples, ~11K validation examples, and ~11K test examples. News articles are 781 tokens long on average, and summaries average 3.75 sentences and 56 tokens.

The main data preprocessing step is splitting each input document into sentences.
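The sketch below illustrates this step by splitting raw article text with a regular expression. It is an assumption for illustration only. The name `split_sentences` is hypothetical, and its rule is likely much cruder than the sentence tokenizer in the utils_nlp dataset code. That rule splits after ., ! or ? when the punctuation is followed by whitespace and then an uppercase letter or a quote. Because the rule relies on capitalization, it must run before the text is lowercased, as in the samples shown below.

```python
import re
from typing import List

# Split on whitespace that follows ., ! or ? and precedes an uppercase letter or a quote.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z\"'])")


def split_sentences(text: str) -> List[str]:
    """Hypothetical, naive sentence splitter for illustration only."""
    text = " ".join(text.split())  # collapse runs of whitespace and newlines
    if not text:
        return []
    return [s for s in _SENTENCE_BOUNDARY.split(text) if s]


article = (
    "MIAMI, Florida (CNN) -- The ninth floor is dubbed the forgotten floor. "
    "Here, inmates wait until they are ready to appear in court. "
    "He says the arrests often result from confrontations with police."
)
for sentence in split_sentences(article):
    print(sentence)
```

This rule would incorrectly split abbreviations such as "Mr. Smith". A trained splitter such as NLTK's `sent_tokenize` handles such cases better.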

In [4]:
#DATA_PATH

In [5]:
# the data path used to save the downloaded data file
DATA_PATH = TemporaryDirectory().name
# The number of lines at the head of each data file used for preprocessing. -1 means all the lines.
TOP_N = 1000
if not QUICK_RUN:
    TOP_N = -1

In [6]:
train_dataset, test_dataset = CNNDMSummarizationDataset(
    top_n=TOP_N, local_cache_path=DATA_PATH, prepare_extractive=False
)
source = [x[0] for x in list(test_dataset.get_source())]
target = [x[0] for x in list(test_dataset.get_target())]
test_sum_dataset = SummarizationNonIterableDataset(source, target)

source = [x[0] for x in list(train_dataset.get_source())]
target = [x[0] for x in list(train_dataset.get_target())]
train_sum_dataset = SummarizationNonIterableDataset(source, target)

In [7]:
train_sum_dataset.source[0]

["editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events .",
 "here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill .",
 'an inmate housed on the " forgotten floor , " where many mentally ill inmates are housed in miami before trial .',
 'miami , florida ( cnn ) -- the ninth floor of the miami-dade pretrial detention facility is dubbed the " forgotten floor . "',
 "here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court .",
 'most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually " avoidable felonies . "',
 'he says the arrests often result from confrontations with police .',
 "mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become

In [8]:
train_sum_dataset.target[0]

[' mentally ill inmates in miami are housed on the " forgotten floor " ',
 '  judge steven leifman says most are there as a result of " avoidable felonies " ',
 '  while cnn tours facility , patient shouts : " i am the son of the president " ',
 "  leifman says the system is unjust and he 's fighting for change . ",
 '\n']

## Model Finetuning

In [9]:
# notebook parameters
# the cache path
CACHE_PATH = TemporaryDirectory().name

# model parameters
MODEL_NAME = "bert-base-uncased"
MAX_POS = 768
MAX_SOURCE_SEQ_LENGTH = 640
MAX_TARGET_SEQ_LENGTH = 140

# mixed precision setting. To enable mixed precision training, follow the instructions in SETUP.md.
FP16 = False
if FP16:
    FP16_OPT_LEVEL = "O2"

# fine-tuning parameters
# batch size per GPU, unit is the number of examples
BATCH_SIZE_PER_GPU = 3

# number of GPUs used for training
NUM_GPUS = 2

# learning rates for the encoder (BERT) and the decoder
LEARNING_RATE_BERT = 2e-3
LEARNING_RATE_DEC = 0.2

# how often (in steps) training statistics are reported and checkpoints are saved
REPORT_EVERY = 10
SAVE_EVERY = 500

# total number of training steps, and warmup steps for the encoder and decoder schedules
MAX_STEPS = 50000
WARMUP_STEPS_BERT = 20000
WARMUP_STEPS_DEC = 10000

if QUICK_RUN:
    MAX_STEPS = 100

   


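The two learning rates and warmup lengths above feed the separate encoder and decoder optimizers. The BertSumAbs paper defines each schedule as lr = base_lr · min(step^-0.5, step · warmup^-1.5). The sketch below evaluates that formula with this notebook's values. It illustrates the paper's schedule and is not a copy of the utils_nlp implementation. `noam_lr` is a hypothetical name.

```python
def noam_lr(step: int, base_lr: float, warmup: int) -> float:
    """Linear warmup for `warmup` steps, then decay proportional to step^-0.5."""
    step = max(step, 1)  # the formula is undefined at step 0
    return base_lr * min(step ** -0.5, step * warmup ** -1.5)


# Values from the parameter cell above.
LR_BERT, WARMUP_BERT = 2e-3, 20000
LR_DEC, WARMUP_DEC = 0.2, 10000

for step in (100, 1000, 10000, 20000, 50000):
    lr_enc = noam_lr(step, LR_BERT, WARMUP_BERT)
    lr_dec = noam_lr(step, LR_DEC, WARMUP_DEC)
    print(f"step {step:>6}: encoder lr {lr_enc:.2e}, decoder lr {lr_dec:.2e}")
```

With these values, the encoder's learning rate stays much smaller than the decoder's and warms up over twice as many steps. The pretrained encoder therefore changes slowly while the randomly initialized decoder catches up. If the implementation follows this formula, both schedules are still early in warmup after the 100 steps of a QUICK_RUN.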
In [10]:
# processor which contains the collate function to load the preprocessed data
processor = BertSumAbsProcessor(cache_dir=CACHE_PATH, max_src_len=MAX_SOURCE_SEQ_LENGTH, max_target_len=MAX_TARGET_SEQ_LENGTH)
# summarizer
summarizer = BertSumAbs(
    processor, cache_dir=CACHE_PATH, max_pos_length=MAX_POS
)
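`max_pos_length=MAX_POS` (768) exceeds the 512 position embeddings that BERT was pretrained with, so the encoder needs a larger position-embedding table. As we understand it, the PreSumm reference code copies the 512 pretrained rows and initializes each additional row as a copy of the last pretrained row. The sketch below shows that idea on plain Python lists. `extend_position_embeddings` is a hypothetical helper, and we assume, without confirming, that utils_nlp initializes the extra rows the same way.

```python
from typing import List


def extend_position_embeddings(
    table: List[List[float]], max_pos: int
) -> List[List[float]]:
    """Return a position-embedding table with max_pos rows.

    Pretrained rows are copied unchanged; any extra rows are initialized as
    copies of the last pretrained row.
    """
    if max_pos <= len(table):
        return [row[:] for row in table[:max_pos]]
    extended = [row[:] for row in table]
    last_row = table[-1]
    extended.extend(last_row[:] for _ in range(max_pos - len(table)))
    return extended


# Toy "pretrained" table: 512 positions with 2-dimensional embeddings.
pretrained = [[float(i), -float(i)] for i in range(512)]
extended = extend_position_embeddings(pretrained, 768)
print(len(extended), extended[100], extended[700])
```

Initializing every new row from the last learned position gives positions 512 to 767 a reasonable starting point, which fine-tuning then adjusts.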

In [11]:
BATCH_SIZE_PER_GPU*NUM_GPUS

6

In [12]:
summarizer.fit(
    train_sum_dataset,
    num_gpus=NUM_GPUS,
    batch_size=BATCH_SIZE_PER_GPU * NUM_GPUS,
    max_steps=MAX_STEPS,
    learning_rate_bert=LEARNING_RATE_BERT,
    learning_rate_dec=LEARNING_RATE_DEC,
    warmup_steps_bert=WARMUP_STEPS_BERT,
    warmup_steps_dec=WARMUP_STEPS_DEC,
    save_every=SAVE_EVERY,
    report_every=REPORT_EVERY,
    fp16=FP16,
)



In [13]:
summarizer.save_model(MAX_STEPS, os.path.join(CACHE_PATH, "bertsumabs.pt"))

## Model Evaluation

In [14]:
# Optionally, load a previously saved checkpoint instead of using the in-memory model:
# checkpoint = torch.load("./abstemp/bert-base-uncased_step_1000.pt")
# summarizer.model.load_checkpoint(checkpoint['model'])

In [15]:
# free unreferenced Python objects and cached GPU memory before prediction
import gc
gc.collect()
torch.cuda.empty_cache()

In [16]:
TOP_N=128
src = test_sum_dataset.source[0:TOP_N]
reference_summaries = [" ".join(t).rstrip("\n") for t in test_sum_dataset.target[0:TOP_N]]
generated_summaries = summarizer.predict(
    shorten_dataset(test_sum_dataset, top_n=TOP_N), batch_size=32, num_gpus=NUM_GPUS, max_seq_length=MAX_POS
)
assert len(generated_summaries) == len(reference_summaries)


dataset length is 128
device is cuda:0
device is cuda:0
device is cuda:0
device is cuda:0
Generating summary: 100%|██████████| 4/4 [02:52<00:00, 43.20s/it]
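Trigram blocking, mentioned in the introduction, is a decoding-time heuristic. During beam search, any hypothesis that would repeat a trigram already present in its output is discarded. The sketch below shows only the check itself. `has_repeated_trigram` and `blocks_candidate` are hypothetical names. This notebook does not show whether `summarizer.predict` applies trigram blocking in this configuration.

```python
from typing import List


def has_repeated_trigram(tokens: List[str]) -> bool:
    """Return True if any trigram (three consecutive tokens) occurs more than once."""
    seen = set()
    for i in range(len(tokens) - 2):
        trigram = tuple(tokens[i:i + 3])
        if trigram in seen:
            return True
        seen.add(trigram)
    return False


def blocks_candidate(hypothesis: List[str], candidate: str) -> bool:
    """Return True if appending `candidate` would create a repeated trigram.

    In beam search, such an extension would be given a score of -inf (pruned).
    """
    return has_repeated_trigram(hypothesis + [candidate])


hyp = "the minister says the minister".split()
print(blocks_candidate(hyp, "says"))  # True: "the minister says" would repeat
print(blocks_candidate(hyp, "left"))  # False
```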


In [17]:
src[0]

['marseille , france ( cnn ) the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane .',
 'marseille prosecutor brice robin told cnn that " so far no videos were used in the crash investigation . "',
 'he added , " a person who has such a video needs to immediately give it to the investigators . "',
 "robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps .",
 'all 150 on board were killed .',
 'paris match and bild reported that the video was recovered from a phone at the wreckage site .',
 'the two publications described the supposed video , but did not post it on their websites .',
 'the publications said that they watched the video , which was found by a source close to the investigation . "',
 'on

In [18]:
generated_summaries[1]

"new : u . s . justice minister says lebanon ' s .   u . n ' t have been used to be used to step down the country ' t want to step step down ."

In [19]:
reference_summaries[1]

' membership gives the icc jurisdiction over alleged crimes committed in palestinian territories since last june .    israel and the united states opposed the move , which could open the door to war crimes investigations against israelis .  '

In [20]:
RESULT_DIR = TemporaryDirectory().name
rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR)
print(rouge_score)

2020-03-04 21:54:28,482 [MainThread  ] [INFO ]  Writing summaries.
2020-03-04 21:54:28,484 [MainThread  ] [INFO ]  Processing summaries. Saving system files to /tmp/tmpmgnbbu7h/tmp54vpunvl/system and model files to /tmp/tmpmgnbbu7h/tmp54vpunvl/model.
2020-03-04 21:54:28,485 [MainThread  ] [INFO ]  Processing files in /tmp/tmpmgnbbu7h/rouge-tmp-2020-03-04-21-54-28/candidate/.
2020-03-04 21:54:28,498 [MainThread  ] [INFO ]  Saved processed files to /tmp/tmpmgnbbu7h/tmp54vpunvl/system.
2020-03-04 21:54:28,499 [MainThread  ] [INFO ]  Processing files in /tmp/tmpmgnbbu7h/rouge-tmp-2020-03-04-21-54-28/reference/.

128
128
---------------------------------------------
1 ROUGE-1 Average_R: 0.14106 (95%-conf.int. 0.12835 - 0.15425)
1 ROUGE-1 Average_P: 0.16039 (95%-conf.int. 0.14419 - 0.17604)
1 ROUGE-1 Average_F: 0.14274 (95%-conf.int. 0.13074 - 0.15444)
---------------------------------------------
1 ROUGE-2 Average_R: 0.01723 (95%-conf.int. 0.01253 - 0.02174)
1 ROUGE-2 Average_P: 0.02043 (95%-conf.int. 0.01519 - 0.02617)
1 ROUGE-2 Average_F: 0.01780 (95%-conf.int. 0.01340 - 0.02242)
---------------------------------------------
1 ROUGE-L Average_R: 0.10427 (95%-conf.int. 0.09507 - 0.11371)
1 ROUGE-L Average_P: 0.11909 (95%-conf.int. 0.10773 - 0.12984)
1 ROUGE-L Average_F: 0.10584 (95%-conf.int. 0.09711 - 0.11484)

{'rouge_1_recall': 0.14106, 'rouge_1_recall_cb': 0.12835, 'rouge_1_recall_ce': 0.15425, 'rouge_1_precision': 0.16039, 'rouge_1_precision_cb': 0.14419, 'rouge_1_precision_ce': 0.17604, 'rouge_1_f_score': 0.14274, 'rouge_1_f_score_cb': 0.13074, 'rouge_1_f_score_ce': 0.15444, 'rouge_2_rec
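The sketch below makes the ROUGE-N numbers above concrete by computing ROUGE-N recall, precision, and F-score for one candidate/reference pair from clipped n-gram overlap. It uses plain whitespace tokenization and no stemming, so it will not reproduce the ROUGE-1.5.5 output exactly. That script applies its own tokenization and options, averages over all summaries, and reports bootstrap 95% confidence intervals. `rouge_n` is a hypothetical helper and is not part of utils_nlp.

```python
from collections import Counter
from typing import Dict, List


def ngram_counts(tokens: List[str], n: int) -> Counter:
    """Count all n-grams (as tuples) in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(candidate: str, reference: str, n: int = 1) -> Dict[str, float]:
    """ROUGE-N for one candidate/reference pair using clipped n-gram matches."""
    cand = ngram_counts(candidate.lower().split(), n)
    ref = ngram_counts(reference.lower().split(), n)
    overlap = sum((cand & ref).values())  # min count per shared n-gram
    recall = overlap / max(sum(ref.values()), 1)
    precision = overlap / max(sum(cand.values()), 1)
    if overlap == 0:
        f_score = 0.0
    else:
        f_score = 2 * precision * recall / (precision + recall)
    return {"recall": recall, "precision": precision, "f_score": f_score}


print(rouge_n("the cat sat on the mat", "the cat lay on the mat", n=1))
print(rouge_n("the cat sat on the mat", "the cat lay on the mat", n=2))
```

Changing one word in a six-word sentence removes one of six unigram matches but two of five bigram matches. This shows why ROUGE-2 is typically much lower than ROUGE-1, as in the results above.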