Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

## Abstractive Summarization on CNN/DM Dataset using Transformers


### Summary

This notebook demonstrates how to fine-tune Transformers for abstractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.




### Before You Start

The running time shown in this notebook is on a Standard_NC24s_v3 Azure Ubuntu Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

Using a single NVIDIA Tesla V100 GPU with 16GB of GPU memory:
- Data preprocessing takes around 1 minute for a quick run and ~20 minutes otherwise. This estimate assumes that the chosen transformer model is "distilbert-base-uncased" and the sentence selection method is "greedy", which is the default. Preprocessing can take significantly longer if the sentence selection method is "combination", which can achieve better model performance.

- Model fine-tuning takes around 2 minutes for a quick run and around 3 hours otherwise. This estimate assumes the chosen encoder method is "transformer". Fine-tuning can be faster if another encoder method is chosen, which may result in worse model performance.

### Additional Notes

* **ROUGE Evaluation**: To run ROUGE evaluation, please refer to the compute_rouge_perl section of [summarization_evaluation.ipynb](./summarization_evaluation.ipynb) for setup.

* **Distributed Training**:
Please note that a Jupyter notebook can only use PyTorch [DataParallel](https://pytorch.org/docs/master/nn.html#dataparallel). Faster training and larger batch sizes can be achieved with PyTorch [DistributedDataParallel](https://pytorch.org/docs/master/notes/ddp.html) (DDP). The script [extractive_summarization_cnndm_distributed_train.py](./extractive_summarization_cnndm_distributed_train.py) shows an example of how to use DDP.



In [47]:
%load_ext autoreload

%autoreload 2
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = False

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Configuration


In [2]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset
from utils_nlp.eval import compute_rouge_python, compute_rouge_perl
from utils_nlp.models.transformers.abstractive_summarization_bartt5 import (
    AbstractiveSummarizer, SummarizationProcessor, validate)

from utils_nlp.models.transformers.datasets import SummarizationDataset
import nltk
from nltk import tokenize

import pandas as pd
import scrapbook as sb
import pprint




### Configuration: choose the transformer model to be used

Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). This notebook uses the BART and T5 models for abstractive summarization; the model is selected with `MODEL_NAME` below.

In [3]:
#pd.DataFrame({"model_name": ExtractiveSummarizer.list_supported_models()})

In [28]:
# Transformer model being used
#MODEL_NAME = "t5-large"
MODEL_NAME = "bart-large-cnn"
# notebook parameters
# the cache data path during fine tuning
CACHE_DIR = "./bart_cache" #TemporaryDirectory().name
summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)

In [27]:
# bart-large
summarizer.tokenizer.tokenize("hello frenchsdfa ")

['hello', 'Ġfrench', 's', 'df', 'a']

In [35]:
summarizer.tokenizer.tokenize("hello nlp amazon china")

['hello', 'Ġn', 'lp', 'Ġam', 'azon', 'Ġch', 'ina']
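
The `Ġ` prefix in the outputs above is how the byte-level BPE vocabulary marks a token that begins with a space. A minimal sketch of that convention in plain Python follows; the helper name `join_bpe_tokens` is hypothetical and not part of `transformers`, and it only handles the space marker, not the full byte-to-character mapping.

```python
def join_bpe_tokens(tokens):
    """Rebuild text from byte-level BPE tokens where 'Ġ' marks a leading space."""
    return "".join(tokens).replace("Ġ", " ")


# Token lists copied from the two tokenizer outputs above.
tokens_a = ["hello", "Ġfrench", "s", "df", "a"]
tokens_b = ["hello", "Ġn", "lp", "Ġam", "azon", "Ġch", "ina"]

print(join_bpe_tokens(tokens_a))
print(join_bpe_tokens(tokens_b))
```

In practice the tokenizer's own `convert_tokens_to_string` method is the safer way to do this, since it also reverses the byte mapping for non-ASCII text.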

In [32]:
summarizer.tokenizer

<transformers.tokenization_bart.BartTokenizer at 0x7f0e9a165a90>

In [4]:
summarizer.model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): SelfAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=Tr

In [11]:
summarizer.config

BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel",
    "BartForMaskedLM",
    "BartForSequenceClassification"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": null,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "do_sample": false,
  "dropout": 0.1,
  "early_stopping": false,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_decoder": false,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "length_penalty": 1.0,
  "max_length": 20,
  "max_po

In [5]:
# Transformer model being used
#MODEL_NAME = "t5-large"
MODEL_NAME = "bart-large-cnn"
# notebook parameters
# the cache data path during fine tuning
CACHE_DIR = "./bart_cache" #TemporaryDirectory().name
summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)
summarizer.model

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1300.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1625270765.0, style=ProgressStyle(descr…




BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): SelfAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=Tr

In [9]:
summarizer.config

BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": null,
  "attention_dropout": 0.0,
  "bad_words_ids": null,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "do_sample": false,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_decoder": false,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "length_penalty": 2.0,
  "max_length": 142,
  "max_position_embeddings": 1024,
  "min_length": 56,
  "model_type": "bart",
  "no_r

In [19]:
task_specific_params = summarizer.config.task_specific_params

In [20]:
task_specific_params

{'summarization': {'early_stopping': True,
  'length_penalty': 2.0,
  'max_length': 142,
  'min_length': 56,
  'no_repeat_ngram_size': 3,
  'num_beams': 4}}

In [None]:
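# NOTE: `bart` is not defined in this notebook; this cell assumes a fairseq BART hub model object.
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens)  # 'Hello world!'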

In [44]:
summarizer.tokenizer.encode('Hello frech!')

[0, 20920, 7619, 611, 328, 2]

In [46]:
summarizer.tokenizer.decode([0, 20920, 7619, 611, 328, 2], skip_special_tokens=True, clean_up_tokenization_spaces=True)

' Hello frech!'

In [36]:
summarizer.tokenizer.decode([0, 31414, 232, 328, 2], skip_special_tokens=True, clean_up_tokenization_spaces=True)

'Hello world!'
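
The encoded sequences above start with id 0 and end with id 2, which match `bos_token_id` and `eos_token_id` in the config printout, and `skip_special_tokens=True` drops them before decoding. The sketch below mimics that filtering step on plain lists; `strip_special_ids` is a hypothetical helper, and the pad id of 1 is assumed from `padding_idx=1` in the model printout.

```python
# Special-token ids as they appear in the config and model printouts above.
BOS_ID, PAD_ID, EOS_ID = 0, 1, 2


def strip_special_ids(ids, special_ids=(BOS_ID, PAD_ID, EOS_ID)):
    """Drop special-token ids while keeping the order of the remaining ids."""
    return [i for i in ids if i not in special_ids]


content_ids = strip_special_ids([0, 31414, 232, 328, 2])
print(content_ids)
```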

In [18]:
summarizer.config.update(task_specific_params.get("summarization", {}))

In [23]:
summarizer.config

BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel",
    "BartForMaskedLM",
    "BartForSequenceClassification"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": null,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "do_sample": false,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_decoder": false,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "length_penalty": 2.0,
  "max_length": 142,
  "max_po

In [21]:
task_specific_params.get("summarization", {})

{'early_stopping': True,
 'length_penalty': 2.0,
 'max_length': 142,
 'min_length': 56,
 'no_repeat_ngram_size': 3,
 'num_beams': 4}
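
The `config.update(...)` call above overwrites the generic generation defaults with the summarization-specific values. A small sketch with plain dictionaries shows the effect; the three default values are copied from the config printout that reports `max_length` 20, and everything else about the real config object is left out.

```python
# Subset of generation defaults, as shown in one of the config printouts above.
base_config = {"early_stopping": False, "length_penalty": 1.0, "max_length": 20}

task_specific_params = {
    "summarization": {
        "early_stopping": True,
        "length_penalty": 2.0,
        "max_length": 142,
        "min_length": 56,
        "no_repeat_ngram_size": 3,
        "num_beams": 4,
    }
}

merged = dict(base_config)
# .get(..., {}) makes the update a no-op when a task has no overrides.
merged.update(task_specific_params.get("summarization", {}))
print(merged)
```

Keys present in both dictionaries take the task-specific value, and keys that only exist in the task-specific block (such as `num_beams`) are added.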

In [25]:
summarizer.config.update({"vocab_size": 50264})

### Data Preprocessing

The dataset used in this notebook is the CNN/DM dataset, which contains documents and accompanying questions from CNN and Daily Mail news articles. The highlights of each article are used as its summary. The dataset consists of ~287K training examples, ~13K validation examples and ~11K test examples. The code in the following cell downloads the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.


In [49]:
# the data path used to save the downloaded data file
DATA_PATH = "./bartt5_cnndm" #TemporaryDirectory().name
# The number of lines at the head of data file used for preprocessing. -1 means all the lines.
TOP_N = 1000
if not QUICK_RUN:
    TOP_N = -1
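
The `TOP_N` convention above (a positive number keeps that many leading lines, `-1` keeps everything) can be sketched as follows. The helper `take_top_n` is hypothetical; how `CNNDMSummarizationDataset` applies `top_n` internally is not shown in this notebook.

```python
from itertools import islice


def take_top_n(lines, top_n=-1):
    """Return the first top_n items of an iterable, or all of them when top_n is -1."""
    if top_n == -1:
        return list(lines)
    return list(islice(lines, top_n))


# Illustrative stand-in for the lines of a data file.
lines = ["article %d" % i for i in range(5)]
print(take_top_n(lines, 2))
print(len(take_top_n(lines, -1)))
```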

In [50]:
train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)

In [14]:
test_dataset[0]['tgt_txt']

"<t> marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . </t> <t> journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . </t> <t> andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . </t>\n"

In [10]:
len(test_dataset)

11490
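
The `tgt_txt` field printed above wraps each highlight sentence in `<t> ... </t>` tags. A minimal sketch for splitting such a string into sentences is shown below; `split_highlights` is a hypothetical helper and the sample string is made up for illustration.

```python
import re


def split_highlights(tgt_txt):
    """Extract the sentences wrapped in <t> ... </t> tags."""
    return [s.strip() for s in re.findall(r"<t>(.*?)</t>", tgt_txt, flags=re.DOTALL)]


# Made-up sample in the same format as test_dataset[0]['tgt_txt'].
sample = "<t> first highlight . </t> <t> second highlight . </t>\n"
highlights = split_highlights(sample)
print(highlights)
```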

Preprocess the data.

In [11]:
%%time
abs_sum_train = summarizer.processor.preprocess(train_dataset)


In [14]:
# torch.save(abs_sum_train,  os.path.join(DATA_PATH, "train_{0}_full.pt".format(MODEL_NAME)))

In [16]:
# torch.save(abs_sum_test,  os.path.join(DATA_PATH, "test_{0}_full.pt".format(MODEL_NAME)))

In [51]:
abs_sum_test = summarizer.processor.preprocess(test_dataset)

In [58]:
a = summarizer.processor.collate_fn(abs_sum_test[0:2], "cuda:0", True)
c = summarizer.processor.get_inputs(a, "cuda:0", MODEL_NAME, summarizer.tokenizer, True)
print(c)

{'input_ids': tensor([[   0, 4401, 1090,  ...,  167, 1081,    2],
        [   0,   36,  740,  ...,    1,    1,    1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'), 'decoder_input_ids': tensor([[    0,  4401,  1090,  4061,  5644,   161,    22,    98,   444,   117,
          3424,    58,   341,    11,     5,  2058,   803,    22,  1135,   433,
           690,   479,  1437,  1437,  4225,    23,   741,  9683,     8,  2242,
           354,   914,    32,    22,   182,  3230,    22,     5,   569,  7200,
            16,   588,  2156,    41,  4474,   161,   479,  1437,  1437,     8,
           241,   281,   784,  1792,  4494,    56,  3978,    39,   784,  2951,
           212,  1253,   102,  1058,   334,     9,    41,  3238,     9,  3814,
          6943,  2156,  5195,   161,   479],
        [    0,  6332,  2029,     5, 41591,   438, 10542,    81,  1697,  3474,
          2021,    11,  8750,   990, 28307, 13560,   187,   
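
In the batch printed above, the shorter sequence is right-padded with id 1 and its `attention_mask` is 0 at the padded positions. The sketch below reproduces that idea on plain lists; `pad_batch` is a hypothetical helper, not the repo's `collate_fn`, and the pad id is assumed from `padding_idx=1` in the model printout.

```python
PAD_ID = 1  # assumed from padding_idx=1 in the model printout


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad id lists to a common length and build the matching attention mask."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask


# Two short, illustrative id sequences of different lengths.
ids, mask = pad_batch([[0, 4401, 1090, 2], [0, 36, 2]])
print(ids)
print(mask)
```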

In [6]:
abs_sum_train = torch.load( os.path.join(DATA_PATH, "train_{0}_full.pt".format(MODEL_NAME)))

In [13]:
prediction = summarizer.predict(abs_sum_test[0:8], num_gpus=1, batch_size=8)

Generating summary:   0%|          | 0/1 [00:00<?, ?it/s]

dataset length is 8


Generating summary: 100%|██████████| 1/1 [00:09<00:00,  9.23s/it]


In [14]:
print(abs_sum_test[0])

{'src': "marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear c

In [15]:
print(prediction[0])

nnnnennennsnnnnnnnnsnnennennennnnnnennsnnnnnnnnnnnnsnnnnnnnnnnnnnnnnennennnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnennnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn


In [17]:
summarizer.model.module

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): SelfAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=Tr

In [8]:
"""
# save and load preprocessed data
save_path = DATA_PATH
torch.save(abs_sum_train, os.path.join(save_path, "train_full.pt"))
torch.save(abs_sum_test, os.path.join(DATA_PATH, "test_full.pt"))

"""
save_path = DATA_PATH
#abs_sum_train = torch.load(os.path.join(save_path, "train_full.pt"))
abs_sum_test = torch.load(os.path.join(save_path, "test_full.pt"))

In [9]:
print(len(abs_sum_train))
print(len(abs_sum_test))

287227
11490


In [13]:
#save_path = os.path.join(DATA_PATH, "processed")
#torch.save(abs_sum_train, os.path.join(save_path, "train_full.pt"))
#torch.save(abs_sum_test, os.path.join(DATA_PATH, "test_full.pt"))

#### Inspect Data

In [None]:
abs_sum_train[0].keys()

In [None]:
abs_sum_train[0]

### Model training
To start model training, we need an instance of a summarizer class. This notebook trains the `AbstractiveSummarizer` instance created in the configuration section above.
#### Choose the transformer model.
Currently ExtractiveSummarizer supports two models:
- distilbert-base-uncased
- bert-base-uncased

Potentially, RoBERTa-based models and XLNet can be supported, but this needs to be tested.
#### Choose the encoder algorithm.
There are four options:
- baseline: it uses a smaller transformer model in place of the BERT model, with a transformer summarization layer
- classifier: it uses pretrained BERT and fine-tune BERT with **simple logistic classification** summarization layer
- transformer: it uses pretrained BERT and fine-tune BERT with **transformer** summarization layer
- RNN: it uses pretrained BERT and fine-tune BERT with **LSTM** summarization layer

In [8]:
BATCH_SIZE = 8  # batch size, in number of samples
MAX_POS_LENGTH = 512

# Number of GPUs used for training
NUM_GPUS = torch.cuda.device_count()

# Learning rate
LEARNING_RATE = 3e-5

# How often training statistics are reported, in steps
REPORT_EVERY = 100

# Total number of training steps (quick-run value)
MAX_STEPS = 1e2
# Number of warm-up steps
WARMUP_STEPS = 5e2

if not QUICK_RUN:
    MAX_STEPS = 5e3
    WARMUP_STEPS = 5e2
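
To see how `LEARNING_RATE`, `WARMUP_STEPS` and `MAX_STEPS` interact, the sketch below assumes a linear warm-up followed by a linear decay to zero, a common choice for fine-tuning transformers. This is an assumption: the schedule actually used inside `summarizer.fit` is not shown in this notebook, and `lr_at_step` is a hypothetical helper.

```python
def lr_at_step(step, peak_lr, warmup_steps, max_steps):
    """Linear warm-up to peak_lr, then linear decay to zero at max_steps (assumed schedule)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(0.0, max_steps - step)
    return peak_lr * remaining / max(1.0, max_steps - warmup_steps)


# Full run: warm-up ends at step 500, well before training stops at step 5000.
full_run_lr = lr_at_step(500, 3e-5, 500, 5000)
# Quick run: training stops at step 100, still inside the 500-step warm-up.
quick_run_lr = lr_at_step(100, 3e-5, 500, 100)

print(full_run_lr)
print(quick_run_lr)
```

Under this assumption, a quick run with 100 total steps and 500 warm-up steps never reaches the peak learning rate, which is worth keeping in mind when comparing quick-run and full-run losses.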
 

In [9]:
#"""

summarizer.fit(
    abs_sum_train,
    num_gpus=NUM_GPUS,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    verbose=True,
    report_every=REPORT_EVERY,
)

#"""


Iteration:   0%|          | 100/35904 [02:09<11:55:36,  1.20s/it]

timestamp: 06/05/2020 20:53:58, average loss: 1.867461, time duration: 129.808588,
                            number of examples in current reporting: 800, step 100
                            out of total 5000


Iteration:   1%|          | 200/35904 [04:11<11:53:56,  1.20s/it]

timestamp: 06/05/2020 20:55:59, average loss: 1.362968, time duration: 121.602895,
                            number of examples in current reporting: 800, step 200
                            out of total 5000


Iteration:   1%|          | 300/35904 [06:12<11:45:45,  1.19s/it]

timestamp: 06/05/2020 20:58:01, average loss: 1.332468, time duration: 121.472843,
                            number of examples in current reporting: 800, step 300
                            out of total 5000


Iteration:   1%|          | 400/35904 [08:14<11:44:38,  1.19s/it]

timestamp: 06/05/2020 21:00:03, average loss: 1.358131, time duration: 121.903810,
                            number of examples in current reporting: 800, step 400
                            out of total 5000


Iteration:   1%|▏         | 500/35904 [10:16<12:11:08,  1.24s/it]

timestamp: 06/05/2020 21:02:04, average loss: 1.336675, time duration: 121.732681,
                            number of examples in current reporting: 800, step 500
                            out of total 5000


Iteration:   2%|▏         | 600/35904 [12:15<11:41:57,  1.19s/it]

timestamp: 06/05/2020 21:04:03, average loss: 1.335859, time duration: 118.979347,
                            number of examples in current reporting: 800, step 600
                            out of total 5000


Iteration:   2%|▏         | 700/35904 [14:17<11:33:50,  1.18s/it]

timestamp: 06/05/2020 21:06:05, average loss: 1.284556, time duration: 122.029723,
                            number of examples in current reporting: 800, step 700
                            out of total 5000


Iteration:   2%|▏         | 800/35904 [16:19<11:37:31,  1.19s/it]

timestamp: 06/05/2020 21:08:07, average loss: 1.316384, time duration: 121.752802,
                            number of examples in current reporting: 800, step 800
                            out of total 5000


Iteration:   3%|▎         | 900/35904 [18:20<11:32:43,  1.19s/it]

timestamp: 06/05/2020 21:10:09, average loss: 1.276707, time duration: 121.460550,
                            number of examples in current reporting: 800, step 900
                            out of total 5000


Iteration:   3%|▎         | 999/35904 [20:21<11:28:53,  1.18s/it]

timestamp: 06/05/2020 21:12:10, average loss: 1.268152, time duration: 121.720028,
                            number of examples in current reporting: 800, step 1000
                            out of total 5000
./bart_cache
saving through pytorch to ./bart_cache/bart-large_step_1000.pt


Iteration:   3%|▎         | 1100/35904 [22:28<11:48:37,  1.22s/it]

timestamp: 06/05/2020 21:14:17, average loss: 1.273300, time duration: 126.425467,
                            number of examples in current reporting: 800, step 1100
                            out of total 5000


Iteration:   3%|▎         | 1200/35904 [24:28<11:34:58,  1.20s/it]

timestamp: 06/05/2020 21:16:16, average loss: 1.306796, time duration: 119.081359,
                            number of examples in current reporting: 800, step 1200
                            out of total 5000


Iteration:   4%|▎         | 1300/35904 [26:29<11:27:53,  1.19s/it]

timestamp: 06/05/2020 21:18:18, average loss: 1.259358, time duration: 121.890670,
                            number of examples in current reporting: 800, step 1300
                            out of total 5000


Iteration:   4%|▍         | 1400/35904 [28:31<11:29:21,  1.20s/it]

timestamp: 06/05/2020 21:20:19, average loss: 1.271845, time duration: 121.798160,
                            number of examples in current reporting: 800, step 1400
                            out of total 5000


Iteration:   4%|▍         | 1500/35904 [30:33<11:24:12,  1.19s/it]

timestamp: 06/05/2020 21:22:21, average loss: 1.244293, time duration: 121.576017,
                            number of examples in current reporting: 800, step 1500
                            out of total 5000


Iteration:   4%|▍         | 1600/35904 [32:34<11:19:33,  1.19s/it]

timestamp: 06/05/2020 21:24:22, average loss: 1.244026, time duration: 121.302933,
                            number of examples in current reporting: 800, step 1600
                            out of total 5000


Iteration:   5%|▍         | 1700/35904 [34:36<11:33:25,  1.22s/it]

timestamp: 06/05/2020 21:26:24, average loss: 1.251364, time duration: 121.701675,
                            number of examples in current reporting: 800, step 1700
                            out of total 5000


Iteration:   5%|▌         | 1800/35904 [36:35<11:12:42,  1.18s/it]

timestamp: 06/05/2020 21:28:23, average loss: 1.240429, time duration: 119.028772,
                            number of examples in current reporting: 800, step 1800
                            out of total 5000


Iteration:   5%|▌         | 1900/35904 [38:37<11:14:10,  1.19s/it]

timestamp: 06/05/2020 21:30:25, average loss: 1.272301, time duration: 121.825136,
                            number of examples in current reporting: 800, step 1900
                            out of total 5000


Iteration:   6%|▌         | 1999/35904 [40:37<11:14:59,  1.19s/it]

timestamp: 06/05/2020 21:32:26, average loss: 1.250983, time duration: 121.385592,
                            number of examples in current reporting: 800, step 2000
                            out of total 5000
./bart_cache
saving through pytorch to ./bart_cache/bart-large_step_2000.pt


Iteration:   6%|▌         | 2100/35904 [42:44<11:10:54,  1.19s/it]

timestamp: 06/05/2020 21:34:32, average loss: 1.189827, time duration: 125.865933,
                            number of examples in current reporting: 800, step 2100
                            out of total 5000


Iteration:   6%|▌         | 2200/35904 [44:46<11:06:20,  1.19s/it]

timestamp: 06/05/2020 21:36:34, average loss: 1.220882, time duration: 121.613832,
                            number of examples in current reporting: 800, step 2200
                            out of total 5000


Iteration:   6%|▋         | 2300/35904 [46:44<11:11:13,  1.20s/it]

timestamp: 06/05/2020 21:38:33, average loss: 1.222556, time duration: 118.914466,
                            number of examples in current reporting: 800, step 2300
                            out of total 5000


Iteration:   7%|▋         | 2400/35904 [48:46<11:04:05,  1.19s/it]

timestamp: 06/05/2020 21:40:34, average loss: 1.236500, time duration: 121.433649,
                            number of examples in current reporting: 800, step 2400
                            out of total 5000


Iteration:   7%|▋         | 2500/35904 [50:47<11:03:49,  1.19s/it]

timestamp: 06/05/2020 21:42:35, average loss: 1.230107, time duration: 121.259210,
                            number of examples in current reporting: 800, step 2500
                            out of total 5000


Iteration:   7%|▋         | 2600/35904 [52:48<12:02:40,  1.30s/it]

timestamp: 06/05/2020 21:44:37, average loss: 1.220464, time duration: 121.267537,
                            number of examples in current reporting: 800, step 2600
                            out of total 5000


Iteration:   8%|▊         | 2700/35904 [54:47<10:53:59,  1.18s/it]

timestamp: 06/05/2020 21:46:36, average loss: 1.227780, time duration: 118.972323,
                            number of examples in current reporting: 800, step 2700
                            out of total 5000


Iteration:   8%|▊         | 2800/35904 [56:49<10:51:26,  1.18s/it]

timestamp: 06/05/2020 21:48:37, average loss: 1.206418, time duration: 121.321079,
                            number of examples in current reporting: 800, step 2800
                            out of total 5000


Iteration:   8%|▊         | 2900/35904 [58:50<10:54:47,  1.19s/it]

timestamp: 06/05/2020 21:50:38, average loss: 1.245798, time duration: 121.355105,
                            number of examples in current reporting: 800, step 2900
                            out of total 5000


Iteration:   8%|▊         | 2999/35904 [1:00:47<10:57:26,  1.20s/it]

timestamp: 06/05/2020 21:52:37, average loss: 1.191816, time duration: 118.659829,
                            number of examples in current reporting: 800, step 3000
                            out of total 5000
./bart_cache
saving through pytorch to ./bart_cache/bart-large_step_3000.pt


Iteration:   9%|▊         | 3100/35904 [1:02:54<10:53:09,  1.19s/it]

timestamp: 06/05/2020 21:54:42, average loss: 1.195659, time duration: 125.354036,
                            number of examples in current reporting: 800, step 3100
                            out of total 5000


Iteration:   9%|▉         | 3200/35904 [1:04:55<10:41:39,  1.18s/it]

timestamp: 06/05/2020 21:56:43, average loss: 1.195860, time duration: 121.057553,
                            number of examples in current reporting: 800, step 3200
                            out of total 5000


Iteration:   9%|▉         | 3300/35904 [1:06:56<10:40:21,  1.18s/it]

timestamp: 06/05/2020 21:58:44, average loss: 1.201709, time duration: 121.014937,
                            number of examples in current reporting: 800, step 3300
                            out of total 5000


Iteration:   9%|▉         | 3400/35904 [1:08:57<10:46:31,  1.19s/it]

timestamp: 06/05/2020 22:00:45, average loss: 1.208225, time duration: 121.003419,
                            number of examples in current reporting: 800, step 3400
                            out of total 5000


Iteration:  10%|▉         | 3500/35904 [1:10:56<10:41:47,  1.19s/it]

timestamp: 06/05/2020 22:02:44, average loss: 1.206781, time duration: 118.554041,
                            number of examples in current reporting: 800, step 3500
                            out of total 5000


Iteration:  10%|█         | 3600/35904 [1:12:57<10:40:29,  1.19s/it]

timestamp: 06/05/2020 22:04:45, average loss: 1.202228, time duration: 121.427421,
                            number of examples in current reporting: 800, step 3600
                            out of total 5000


Iteration:  10%|█         | 3700/35904 [1:14:58<10:37:31,  1.19s/it]

timestamp: 06/05/2020 22:06:46, average loss: 1.194690, time duration: 120.942579,
                            number of examples in current reporting: 800, step 3700
                            out of total 5000


Iteration:  11%|█         | 3800/35904 [1:16:59<10:30:54,  1.18s/it]

timestamp: 06/05/2020 22:08:48, average loss: 1.178317, time duration: 121.416029,
                            number of examples in current reporting: 800, step 3800
                            out of total 5000


Iteration:  11%|█         | 3900/35904 [1:19:01<10:57:59,  1.23s/it]

timestamp: 06/05/2020 22:10:49, average loss: 1.212299, time duration: 121.453999,
                            number of examples in current reporting: 800, step 3900
                            out of total 5000


Iteration:  11%|█         | 3999/35904 [1:21:01<10:32:58,  1.19s/it]

timestamp: 06/05/2020 22:12:50, average loss: 1.176772, time duration: 121.211658,
                            number of examples in current reporting: 800, step 4000
                            out of total 5000
./bart_cache
saving through pytorch to ./bart_cache/bart-large_step_4000.pt


Iteration:  11%|█▏        | 4100/35904 [1:23:06<10:34:54,  1.20s/it]

timestamp: 06/05/2020 22:14:54, average loss: 1.212825, time duration: 123.401030,
                            number of examples in current reporting: 800, step 4100
                            out of total 5000


Iteration:  12%|█▏        | 4200/35904 [1:25:07<10:26:55,  1.19s/it]

timestamp: 06/05/2020 22:16:55, average loss: 1.201477, time duration: 121.219083,
                            number of examples in current reporting: 800, step 4200
                            out of total 5000


Iteration:  12%|█▏        | 4300/35904 [1:27:08<10:23:00,  1.18s/it]

timestamp: 06/05/2020 22:18:56, average loss: 1.180546, time duration: 121.049531,
                            number of examples in current reporting: 800, step 4300
                            out of total 5000


Iteration:  12%|█▏        | 4335/35904 [1:27:49<10:39:36,  1.22s/it]


KeyboardInterrupt: 

In [None]:
# for saving the fine-tuned model
"""
summarizer.save_model(
    os.path.join(
        CACHE_DIR,
        "extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt".format(
            MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS
        ),
    )
)
"""


In [None]:
# for loading a previously saved model
"""
import torch
model_path = os.path.join(
        CACHE_DIR,
        "extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt".format(
            MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS
        ))
summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)
summarizer.model.load_state_dict(torch.load(model_path, map_location="cpu"))
"""

### Model Evaluation

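[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics commonly used to evaluate text summarization. It compares the n-grams and word sequences of a generated summary against those of a reference summary.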
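
As a rough illustration of what these metrics measure, the sketch below computes ROUGE-N precision, recall, and F1 from clipped n-gram overlap, using plain whitespace tokenization. The helper `rouge_n` is hypothetical and written only for this example. It is not the `compute_rouge_python` utility used later in this notebook, which may tokenize, stem, and aggregate scores differently.

```python
from collections import Counter


def rouge_n(candidate, reference, n=1):
    """Illustrative ROUGE-N: precision, recall and F1 from n-gram overlap."""

    def ngrams(text):
        # Naive lowercase + whitespace tokenization (an assumption for this sketch).
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand, ref = ngrams(candidate), ngrams(reference)
    # Counter intersection keeps the minimum count per n-gram (clipped matches).
    overlap = sum((cand & ref).values())
    p = overlap / max(sum(cand.values()), 1)
    r = overlap / max(sum(ref.values()), 1)
    f = 2 * p * r / (p + r) if p + r else 0.0
    return {"p": p, "r": r, "f": f}


scores_1 = rouge_n("the cat sat on the mat", "the cat is on the mat", n=1)
scores_2 = rouge_n("the cat sat on the mat", "the cat is on the mat", n=2)
print(scores_1)
print(scores_2)
```

Here five of the six unigrams match, and three of the five bigrams match, so ROUGE-2 is lower than ROUGE-1 for the same pair of sentences.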

In [None]:
abs_sum_test[0].keys()

In [8]:
source = []
target = []
for example in abs_sum_test:
    source.append(example["src_txt"])
    # Strip the <t>/</t> sentence markers and newlines from the reference summary
    target.append(example["tgt"].replace("<t>", "").replace("</t>", "").replace("\n", ""))

In [13]:
target[0]

" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports .   journalists at bild and paris match are `` very confident '' the video clip is real , an editor says .   andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . "

In [15]:
%%time
prediction = summarizer.predict(abs_sum_test[0:256*4], num_gpus=1, batch_size=16) 

Generating summary:   0%|          | 0/64 [00:00<?, ?it/s]

dataset length is 1024


Generating summary: 100%|██████████| 64/64 [07:29<00:00,  6.67s/it]


CPU times: user 7min 25s, sys: 3.61 s, total: 7min 29s
Wall time: 7min 30s


In [10]:
prediction = summarizer.predict(abs_sum_train[0:8], num_gpus=1, batch_size=8) 

Generating summary:   0%|          | 0/1 [00:00<?, ?it/s]

dataset length is 8


Generating summary: 100%|██████████| 1/1 [00:09<00:00,  9.32s/it]


In [11]:
prediction[0]

' < and <e in ,e <he ,he ,he- : and <he and <he-he- : and <he-e <, herhe-e < not < and < and <, : and <, : and < and < not <, ;he and <hehe-e < not <, : and <, ands ,he- ,he-heand ,he- and < are- ,he- ,hehehehe and < are and < ins : and < and < are and < not < are < are < not < aree < not < are < are < are < are < not < and < are < are < are < are < not < not < are < are- and < are < not < and < are < are < are < are < not < not < are < are < are < are < in, and < not < not < are < -he-he-e < and'

In [16]:

%%time
prediction = summarizer.predict(abs_sum_test[0:256*4], num_gpus=2, batch_size=32) 

Generating summary:   0%|          | 0/32 [00:00<?, ?it/s]

dataset length is 1024


Generating summary: 100%|██████████| 32/32 [06:01<00:00, 10.43s/it]


CPU times: user 8min 21s, sys: 52.5 s, total: 9min 13s
Wall time: 6min 2s


In [17]:

%%time
prediction = summarizer.predict(abs_sum_test[0:256*4], num_gpus=NUM_GPUS, batch_size=64) 

Generating summary:   0%|          | 0/16 [00:00<?, ?it/s]

dataset length is 1024


Generating summary: 100%|██████████| 16/16 [04:59<00:00, 17.81s/it]


CPU times: user 9min 11s, sys: 1min 25s, total: 10min 37s
Wall time: 5min


In [None]:
torch.save(prediction, "prediction.pt")

In [3]:
prediction = torch.load("prediction.pt")

In [10]:
rouge_scores = compute_rouge_python(cand=prediction, ref=target)
pprint.pprint(rouge_scores)

Number of candidates: 11490
Number of references: 11490
{'rouge-1': {'f': 0.43366650568906373,
             'p': 0.3878160218865652,
             'r': 0.5256684651435454},
 'rouge-2': {'f': 0.2013283037622797,
             'p': 0.18073246915601657,
             'r': 0.24341426947272182},
 'rouge-l': {'f': 0.2945220878967815,
             'p': 0.26331547921506626,
             'r': 0.3573184457546521}}


In [None]:
target[0]

In [None]:
prediction[0]

In [None]:
source[0]

In [None]:
# for testing
sb.glue("rouge_2_f_score", rouge_scores['rouge-2']['f'])

## Prediction on a single input sample

In [None]:
source = """
But under the new rule, set to be announced in the next 48 hours, Border Patrol agents would immediately return anyone to Mexico — without any detainment and without any due process — who attempts to cross the southwestern border between the legal ports of entry. The person would not be held for any length of time in an American facility.

Although they advised that details could change before the announcement, administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border. Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents, leaving the southwestern border defenses weakened, the officials argued.
The Trump administration plans to immediately turn back all asylum seekers and other foreigners attempting to enter the United States from Mexico illegally, saying the nation cannot risk allowing the coronavirus to spread through detention facilities and Border Patrol agents, four administration officials said.
The administration officials said the ports of entry would remain open to American citizens, green-card holders and foreigners with proper documentation. Some foreigners would be blocked, including Europeans currently subject to earlier travel restrictions imposed by the administration. The points of entry will also be open to commercial traffic."""

In [None]:
test_dataset = SummarizationDataset(
    None,
    source=[source],
    source_preprocessing=[tokenize.sent_tokenize],
    word_tokenize=nltk.word_tokenize,
)
processor = ExtSumProcessor(model_name=MODEL_NAME,  cache_dir=CACHE_DIR)
preprocessed_dataset = processor.preprocess(test_dataset)

In [None]:
preprocessed_dataset[0].keys()

In [None]:
prediction = summarizer.predict(preprocessed_dataset, num_gpus=0, batch_size=1, sentence_separator="\n")

In [None]:
prediction

## Clean up temporary folders

In [None]:
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH, ignore_errors=True)
if os.path.exists(CACHE_DIR):
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
if USE_PREPROCSSED_DATA and os.path.exists(PROCESSED_DATA_PATH):
    shutil.rmtree(PROCESSED_DATA_PATH, ignore_errors=True)