Copyright (c) Microsoft Corporation.  
Licensed under the MIT License.

# Abstractive Summarization using MiniLM on CNN/DailyMail

## Before you start
Set `QUICK_RUN = True` to run the notebook on a small subset of the data with fewer training steps. If `QUICK_RUN = False`, the notebook takes about 2 hours to run on a VM with four 16GB NVIDIA V100 GPUs.

In [1]:
QUICK_RUN = True

## Summary
This notebook demonstrates how to fine-tune [MiniLM](https://arxiv.org/abs/2002.10957) for the abstractive summarization task. Utility functions and classes in the microsoft/nlp-recipes repo facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

### Abstractive Summarization
Abstractive summarization is the task of taking an input text and summarizing its content in a shorter output text. In contrast to extractive summarization, abstractive summarization doesn't copy sentences directly from the input text; instead, it rephrases the input.

### MiniLM
[Unified Language Model](https://arxiv.org/abs/1905.03197) (UniLM) is a state-of-the-art model developed by Microsoft Research Asia (MSRA). The model is first pre-trained on a large unlabeled natural language corpus (English Wikipedia and BookCorpus) and can then be fine-tuned on different types of labeled data for various NLP tasks like text classification and abstractive summarization. For more information, please consult the notebook [Abstractive Summarization using UniLM on CNN/DailyMail](./abstractive_summarization_unilm_cnndm.ipynb).

Large pre-trained language models like BERT and UniLM usually consist of **hundreds** of millions of parameters, so it is challenging to fine-tune them and to serve them in real-life applications under latency and capacity constraints.

MiniLM is a small version of UniLM that is trained to deeply mimic UniLM using deep self-attention knowledge distillation. It consists of only **tens** of millions of parameters (33M), which is less than one third of the BERT base model and only half the size of [DistilBERT](https://arxiv.org/abs/1910.01108). Experimental results demonstrate that MiniLM retains most of the performance of UniLM on various NLP tasks with much less computation.
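To make the size comparison concrete, the short sketch below works through the ratios. The 33M figure comes from the text above; the 110M (BERT base) and 66M (DistilBERT) figures are the commonly cited parameter counts and are assumptions here, not values taken from this notebook. The helper name `size_ratio` is hypothetical.

```python
# Approximate parameter counts in millions.
# 33 is quoted in this notebook; 110 and 66 are commonly cited
# figures for BERT base and DistilBERT (assumed, not measured here).
PARAMS_MILLIONS = {
    "minilm-l12-h384": 33,
    "distilbert": 66,
    "bert-base": 110,
}


def size_ratio(small: str, large: str) -> float:
    """Return the size of `small` as a fraction of `large`."""
    return PARAMS_MILLIONS[small] / PARAMS_MILLIONS[large]


print(f"MiniLM / BERT base:  {size_ratio('minilm-l12-h384', 'bert-base'):.2f}")
print(f"MiniLM / DistilBERT: {size_ratio('minilm-l12-h384', 'distilbert'):.2f}")
```

Under these assumed counts, MiniLM comes out at roughly 30% of BERT base and half of DistilBERT, which matches the claim above.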


In [2]:
%load_ext autoreload
%autoreload 2
import os
import shutil
from tempfile import TemporaryDirectory
import pprint
import scrapbook as sb
import sys
import time
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.cnndm import CNNDMSummarizationDatasetOrg
from utils_nlp.models.transformers.abstractive_summarization_seq2seq import S2SAbsSumProcessor, S2SAbstractiveSummarizer
from utils_nlp.eval import compute_rouge_python

from utils_nlp.models.transformers.datasets import SummarizationDataset
from utils_nlp.dataset.cnndm import detokenize

start_time = time.time()


Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.



In [3]:
# model parameters
MODEL_NAME = "minilm-l12-h384-uncased" 
MAX_SEQ_LENGTH = 512 
MAX_SOURCE_SEQ_LENGTH = 464 
MAX_TARGET_SEQ_LENGTH = MAX_SEQ_LENGTH - MAX_SOURCE_SEQ_LENGTH 

# number of GPUs to use; set NUM_GPUS to 0 to run on CPU
NUM_GPUS = torch.cuda.device_count()

# fine-tuning parameters
TRAIN_PER_GPU_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-4

TOP_N = -1
WARMUP_STEPS = 500
MAX_STEPS = 5000
BEAM_SIZE = 5
if QUICK_RUN:
    TOP_N = 1000
    WARMUP_STEPS = 500
    MAX_STEPS = 1000
    BEAM_SIZE = 3
    if NUM_GPUS == 0:
        TOP_N = 5
        MAX_STEPS = 10

# inference parameters
TEST_PER_GPU_BATCH_SIZE = 12
FORBID_IGNORE_WORD = "."

# mixed precision setting. To enable mixed precision training, follow instructions in SETUP.md. 
# You will be able to increase the batch sizes with mixed precision training.
FP16 = False

CLEANUP_RESULTS = False

DATA_DIR = TemporaryDirectory().name
CACHE_DIR = TemporaryDirectory().name

MODEL_DIR = "./minilm_cnndm_model"
RESULT_DIR = "./minilm_cnndm_result"
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
OUTPUT_FILE = os.path.join(RESULT_DIR, 'nlp_cnndm_finetuning_results.txt')
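Two of the settings above are derived quantities, and a small sketch may help when tuning them. The helper names `target_length` and `effective_batch_size` are hypothetical and not part of `utils_nlp`. The batch-size formula assumes the usual data-parallel convention (per-GPU batch size times number of GPUs times accumulation steps); how `fit` actually combines these values is not shown in this notebook.

```python
def target_length(max_seq_length: int, max_source_seq_length: int) -> int:
    """Tokens left for the summary once the source budget is reserved."""
    return max_seq_length - max_source_seq_length


def effective_batch_size(
    per_gpu_batch_size: int, num_gpus: int, gradient_accumulation_steps: int
) -> int:
    """Examples contributing to one optimizer step.

    Assumes standard data parallelism; with num_gpus == 0 (CPU),
    a single worker processes each batch.
    """
    return per_gpu_batch_size * max(1, num_gpus) * gradient_accumulation_steps


# Values from the configuration cell above.
print(target_length(512, 464))        # 48 tokens available for the summary
print(effective_batch_size(4, 4, 1))  # 16 examples per step on 4 GPUs
```

With 4 GPUs this gives 16 examples per step, which is consistent with the training log reporting 160 examples for every 10 steps.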

## Load the CNN/DailyMail dataset
The [CNN/DailyMail dataset](https://cs.nyu.edu/~kcho/DMQA/) was originally introduced for Q&A research. Multiple versions of the dataset, processed for the summarization task, are available on the web. The `CNNDMSummarizationDatasetOrg` function downloads a version from the [UniLM repo](https://github.com/microsoft/unilm) with minimal processing. The function returns the training and testing datasets as `SummarizationDataset` objects, which can be further processed for model training and testing.

In [4]:
train_ds, test_ds = CNNDMSummarizationDatasetOrg(local_path=DATA_DIR, top_n=TOP_N)
print(len(train_ds))
print(len(test_ds))

Downloading 1jiDbDbAsqy_5BM79SmX6aSu5DQVCAZq1 into /tmp/tmpnlr7g5u4/cnndm_data.zip... Done.
1000
1000


## Preprocessing
The `S2SAbsSumProcessor` provides several methods for converting input data in `SummarizationDataset`, `IterableSummarizationDataset`, or JSON files into the format required for model training and testing. The preprocessing steps include:
- Tokenize input text
- Convert tokens into token ids
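The two steps above can be illustrated with a toy sketch. This is not the `S2SAbsSumProcessor` implementation: the real processor uses the pre-trained model's subword tokenizer and vocabulary, whereas `TOY_VOCAB`, `tokenize`, and `convert_tokens_to_ids` below are hypothetical stand-ins that split on words and punctuation.

```python
import re

# Hypothetical six-entry vocabulary; a real model vocabulary has
# tens of thousands of subword entries.
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3, "sat": 4, ".": 5}


def tokenize(text: str) -> list:
    """Lowercase the text and split it into word and punctuation tokens."""
    return re.findall(r"\w+|[^\w\s]", text.lower())


def convert_tokens_to_ids(tokens: list, vocab: dict = TOY_VOCAB) -> list:
    """Map each token to its id, falling back to [UNK] for unknown tokens."""
    unk_id = vocab["[UNK]"]
    return [vocab.get(token, unk_id) for token in tokens]


tokens = tokenize("The cat sat.")
print(tokens)                         # ['the', 'cat', 'sat', '.']
print(convert_tokens_to_ids(tokens))  # [2, 3, 4, 5]
```

Tokens missing from the vocabulary map to the `[UNK]` id in this sketch; subword tokenizers reduce how often that happens by splitting rare words into smaller known pieces.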

In [5]:
processor = S2SAbsSumProcessor(model_name=MODEL_NAME,  cache_dir=CACHE_DIR)


In [6]:
train_dataset = processor.s2s_dataset_from_sum_ds(train_ds, train_mode=True)
test_dataset = processor.s2s_dataset_from_sum_ds(test_ds, train_mode=False)

100%|██████████| 1000/1000 [00:13<00:00, 72.17it/s]
100%|██████████| 1000/1000 [00:13<00:00, 72.53it/s]


In [7]:
#train_dataset = processor.s2s_dataset_from_json_or_file("/dadendev/unilm/data/xsum.train.uncased_tokenized.json", train_mode=True, top_n=TOP_N)
#test_dataset = processor.s2s_dataset_from_json_or_file("/dadendev/unilm/data/xsum.test.uncased_tokenized.json", train_mode=False, top_n=TOP_N)

In [8]:
#import torch
#torch.save(train_dataset, os.path.join(CACHE_DIR, "cnndm_train_dataset.pt"))
#torch.save(test_dataset, os.path.join(CACHE_DIR, "cnndm_test_dataset.pt"))
#train_dataset = torch.load(os.path.join(CACHE_DIR, "cnndm_train_dataset.pt"))
#test_dataset = torch.load(os.path.join(CACHE_DIR, "cnndm_test_dataset.pt"))

## Fine-tune the model

The `S2SAbstractiveSummarizer` loads the pre-trained model specified by `model_name`.  
Call `S2SAbstractiveSummarizer.list_supported_models()` to see all the supported models.  
To use a model stored on the local disk, specify `load_model_from_dir` and `model_file_name`. This is particularly useful for loading a previously fine-tuned model and using it for inference directly, without fine-tuning again.

In [9]:
S2SAbstractiveSummarizer.list_supported_models()

['bert-large-uncased',
 'bert-base-cased',
 'bert-large-cased',
 'roberta-base',
 'roberta-large',
 'unilm-base-cased',
 'unilm-large-cased',
 'unilm1-base-cased',
 'unilm1-large-cased',
 'unilm1.2-base-uncased',
 'minilm-l12-h384-uncased']

In [10]:
abs_summarizer = S2SAbstractiveSummarizer(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    max_source_seq_length=MAX_SOURCE_SEQ_LENGTH,
    max_target_seq_length=MAX_TARGET_SEQ_LENGTH,
    cache_dir=CACHE_DIR
)



In [11]:
abs_summarizer.model

BertForSequenceToSequence(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise
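
The layer shapes printed above are enough for a rough check of the 33M parameter figure. The sketch below is an estimate under stated assumptions: 12 encoder layers (suggested by `l12` in the model name) and a feed-forward size of 1536 (four times the hidden size, a common BERT convention that is not visible in the truncated output). Any pooler or output head is ignored, and `estimate_bert_params` is a hypothetical helper.

```python
def estimate_bert_params(
    vocab_size: int = 30522,       # from word_embeddings in the printout
    hidden: int = 384,             # from the printout
    max_positions: int = 512,      # from position_embeddings
    type_vocab: int = 2,           # from token_type_embeddings
    num_layers: int = 12,          # assumed from "l12" in the model name
    intermediate: int = 1536,      # assumed: 4 * hidden
) -> int:
    """Estimate parameters of a BERT-style encoder (no pooler, no head)."""
    layer_norm = 2 * hidden  # weight and bias
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm

    # Query, key, value, and attention output projections, each with a bias.
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # Two feed-forward projections with biases, followed by a LayerNorm.
    feed_forward = (
        (hidden * intermediate + intermediate)
        + (intermediate * hidden + hidden)
        + layer_norm
    )
    return embeddings + num_layers * (attention + feed_forward)


total = estimate_bert_params()
print(f"{total / 1e6:.1f}M parameters")
```

Under these assumptions the estimate lands at about 33 million, in line with the figure quoted for MiniLM. For the loaded model itself, `sum(p.numel() for p in abs_summarizer.model.parameters())` should give the exact count.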

In [12]:
"""
abs_summarizer = S2SAbstractiveSummarizer(
     model_name=MODEL_NAME,
     max_seq_length=MAX_SEQ_LENGTH,
     max_source_seq_length=MAX_SOURCE_SEQ_LENGTH,
    max_target_seq_length=MAX_TARGET_SEQ_LENGTH,
     load_model_from_dir=RESULT_DIR,
    model_file_name="model.5000.bin",
 )
"""

In [13]:
#"""
abs_summarizer.fit(
    train_dataset=train_dataset,
    num_gpus=NUM_GPUS,
    per_gpu_batch_size=TRAIN_PER_GPU_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    max_steps=MAX_STEPS,
    fp16=FP16,
    save_model_to_dir=MODEL_DIR
)
#"""


Iteration:   1%|          | 10/1000 [00:09<08:40,  1.90it/s] 

timestamp: 16/04/2020 02:17:26, average loss: 9.080113, time duration: 9.574137,
                            number of examples in current reporting: 160, step 10
                            out of total 1000


Iteration:   2%|▏         | 20/1000 [00:11<03:53,  4.19it/s]

timestamp: 16/04/2020 02:17:28, average loss: 8.981736, time duration: 2.301608,
                            number of examples in current reporting: 160, step 20
                            out of total 1000


Iteration:   3%|▎         | 30/1000 [00:14<03:42,  4.35it/s]

timestamp: 16/04/2020 02:17:30, average loss: 8.775037, time duration: 2.295033,
                            number of examples in current reporting: 160, step 30
                            out of total 1000


Iteration:   4%|▍         | 40/1000 [00:16<03:41,  4.33it/s]

timestamp: 16/04/2020 02:17:32, average loss: 8.435953, time duration: 2.311963,
                            number of examples in current reporting: 160, step 40
                            out of total 1000


Iteration:   5%|▌         | 50/1000 [00:18<03:37,  4.36it/s]

timestamp: 16/04/2020 02:17:35, average loss: 7.919435, time duration: 2.306423,
                            number of examples in current reporting: 160, step 50
                            out of total 1000


Iteration:   6%|▌         | 60/1000 [00:21<03:58,  3.94it/s]

timestamp: 16/04/2020 02:17:37, average loss: 7.356749, time duration: 2.512326,
                            number of examples in current reporting: 160, step 60
                            out of total 1000


Iteration:   7%|▋         | 70/1000 [00:23<03:41,  4.20it/s]

timestamp: 16/04/2020 02:17:40, average loss: 7.032427, time duration: 2.356884,
                            number of examples in current reporting: 160, step 70
                            out of total 1000


Iteration:   8%|▊         | 80/1000 [00:26<03:36,  4.24it/s]

timestamp: 16/04/2020 02:17:42, average loss: 6.920119, time duration: 2.321869,
                            number of examples in current reporting: 160, step 80
                            out of total 1000


Iteration:   9%|▉         | 90/1000 [00:28<03:33,  4.26it/s]

timestamp: 16/04/2020 02:17:44, average loss: 6.865249, time duration: 2.343093,
                            number of examples in current reporting: 160, step 90
                            out of total 1000


Iteration:  10%|█         | 100/1000 [00:30<03:26,  4.36it/s]

timestamp: 16/04/2020 02:17:47, average loss: 6.808245, time duration: 2.295087,
                            number of examples in current reporting: 160, step 100
                            out of total 1000


Iteration:  11%|█         | 110/1000 [00:33<03:40,  4.04it/s]

timestamp: 16/04/2020 02:17:49, average loss: 6.746358, time duration: 2.364897,
                            number of examples in current reporting: 160, step 110
                            out of total 1000


Iteration:  12%|█▏        | 120/1000 [00:35<03:37,  4.04it/s]

timestamp: 16/04/2020 02:17:51, average loss: 6.673689, time duration: 2.485214,
                            number of examples in current reporting: 160, step 120
                            out of total 1000


Iteration:  13%|█▎        | 130/1000 [00:37<03:22,  4.30it/s]

timestamp: 16/04/2020 02:17:54, average loss: 6.532504, time duration: 2.310903,
                            number of examples in current reporting: 160, step 130
                            out of total 1000


Iteration:  14%|█▍        | 140/1000 [00:40<03:18,  4.33it/s]

timestamp: 16/04/2020 02:17:56, average loss: 6.485746, time duration: 2.309193,
                            number of examples in current reporting: 160, step 140
                            out of total 1000


Iteration:  15%|█▌        | 150/1000 [00:42<03:15,  4.35it/s]

timestamp: 16/04/2020 02:17:58, average loss: 6.368846, time duration: 2.327070,
                            number of examples in current reporting: 160, step 150
                            out of total 1000


Iteration:  16%|█▌        | 160/1000 [00:44<03:17,  4.26it/s]

timestamp: 16/04/2020 02:18:01, average loss: 6.306697, time duration: 2.329306,
                            number of examples in current reporting: 160, step 160
                            out of total 1000


Iteration:  17%|█▋        | 170/1000 [00:47<03:16,  4.23it/s]

timestamp: 16/04/2020 02:18:03, average loss: 6.169260, time duration: 2.343214,
                            number of examples in current reporting: 160, step 170
                            out of total 1000


Iteration:  18%|█▊        | 180/1000 [00:49<03:25,  3.99it/s]

timestamp: 16/04/2020 02:18:06, average loss: 6.113304, time duration: 2.489889,
                            number of examples in current reporting: 160, step 180
                            out of total 1000


Iteration:  19%|█▉        | 190/1000 [00:51<03:11,  4.22it/s]

timestamp: 16/04/2020 02:18:08, average loss: 5.977268, time duration: 2.338645,
                            number of examples in current reporting: 160, step 190
                            out of total 1000


Iteration:  20%|██        | 200/1000 [00:54<03:05,  4.32it/s]

timestamp: 16/04/2020 02:18:10, average loss: 5.904318, time duration: 2.303357,
                            number of examples in current reporting: 160, step 200
                            out of total 1000


Iteration:  21%|██        | 210/1000 [00:56<03:05,  4.27it/s]

timestamp: 16/04/2020 02:18:13, average loss: 5.790859, time duration: 2.314664,
                            number of examples in current reporting: 160, step 210
                            out of total 1000


Iteration:  22%|██▏       | 220/1000 [00:58<02:58,  4.37it/s]

timestamp: 16/04/2020 02:18:15, average loss: 5.730234, time duration: 2.299737,
                            number of examples in current reporting: 160, step 220
                            out of total 1000


Iteration:  23%|██▎       | 230/1000 [01:01<02:56,  4.35it/s]

timestamp: 16/04/2020 02:18:17, average loss: 5.619011, time duration: 2.300297,
                            number of examples in current reporting: 160, step 230
                            out of total 1000


Iteration:  24%|██▍       | 240/1000 [01:03<03:09,  4.02it/s]

timestamp: 16/04/2020 02:18:20, average loss: 5.565530, time duration: 2.465627,
                            number of examples in current reporting: 160, step 240
                            out of total 1000


Iteration:  25%|██▌       | 250/1000 [01:05<02:53,  4.31it/s]

timestamp: 16/04/2020 02:18:22, average loss: 5.436758, time duration: 2.320997,
                            number of examples in current reporting: 160, step 250
                            out of total 1000


Iteration:  26%|██▌       | 260/1000 [01:08<02:58,  4.14it/s]

timestamp: 16/04/2020 02:18:24, average loss: 5.347340, time duration: 2.386022,
                            number of examples in current reporting: 160, step 260
                            out of total 1000


Iteration:  27%|██▋       | 270/1000 [01:10<02:53,  4.22it/s]

timestamp: 16/04/2020 02:18:27, average loss: 5.325142, time duration: 2.342670,
                            number of examples in current reporting: 160, step 270
                            out of total 1000


Iteration:  28%|██▊       | 280/1000 [01:13<02:49,  4.26it/s]

timestamp: 16/04/2020 02:18:29, average loss: 5.229554, time duration: 2.353546,
                            number of examples in current reporting: 160, step 280
                            out of total 1000


Iteration:  29%|██▉       | 290/1000 [01:15<02:50,  4.17it/s]

timestamp: 16/04/2020 02:18:31, average loss: 5.190810, time duration: 2.343330,
                            number of examples in current reporting: 160, step 290
                            out of total 1000


Iteration:  30%|███       | 300/1000 [01:17<02:57,  3.94it/s]

timestamp: 16/04/2020 02:18:34, average loss: 5.047404, time duration: 2.570169,
                            number of examples in current reporting: 160, step 300
                            out of total 1000


Iteration:  31%|███       | 310/1000 [01:20<02:37,  4.37it/s]

timestamp: 16/04/2020 02:18:36, average loss: 5.034877, time duration: 2.283463,
                            number of examples in current reporting: 160, step 310
                            out of total 1000


Iteration:  32%|███▏      | 320/1000 [01:22<02:36,  4.36it/s]

timestamp: 16/04/2020 02:18:39, average loss: 4.918517, time duration: 2.301758,
                            number of examples in current reporting: 160, step 320
                            out of total 1000


Iteration:  33%|███▎      | 330/1000 [01:24<02:32,  4.39it/s]

timestamp: 16/04/2020 02:18:41, average loss: 4.898799, time duration: 2.284174,
                            number of examples in current reporting: 160, step 330
                            out of total 1000


Iteration:  34%|███▍      | 340/1000 [01:27<02:31,  4.34it/s]

timestamp: 16/04/2020 02:18:43, average loss: 4.805903, time duration: 2.284619,
                            number of examples in current reporting: 160, step 340
                            out of total 1000


Iteration:  35%|███▌      | 350/1000 [01:29<02:30,  4.33it/s]

timestamp: 16/04/2020 02:18:45, average loss: 4.784902, time duration: 2.308939,
                            number of examples in current reporting: 160, step 350
                            out of total 1000


Iteration:  36%|███▌      | 360/1000 [01:31<02:37,  4.07it/s]

timestamp: 16/04/2020 02:18:48, average loss: 4.669464, time duration: 2.442239,
                            number of examples in current reporting: 160, step 360
                            out of total 1000


Iteration:  37%|███▋      | 370/1000 [01:34<02:28,  4.24it/s]

timestamp: 16/04/2020 02:18:50, average loss: 4.657346, time duration: 2.321835,
                            number of examples in current reporting: 160, step 370
                            out of total 1000


Iteration:  38%|███▊      | 380/1000 [01:36<02:22,  4.34it/s]

timestamp: 16/04/2020 02:18:52, average loss: 4.560212, time duration: 2.308345,
                            number of examples in current reporting: 160, step 380
                            out of total 1000


Iteration:  39%|███▉      | 390/1000 [01:38<02:20,  4.35it/s]

timestamp: 16/04/2020 02:18:55, average loss: 4.521819, time duration: 2.306991,
                            number of examples in current reporting: 160, step 390
                            out of total 1000


Iteration:  40%|████      | 400/1000 [01:41<02:17,  4.37it/s]

timestamp: 16/04/2020 02:18:57, average loss: 4.396229, time duration: 2.292700,
                            number of examples in current reporting: 160, step 400
                            out of total 1000


Iteration:  41%|████      | 410/1000 [01:43<02:17,  4.29it/s]

timestamp: 16/04/2020 02:18:59, average loss: 4.433820, time duration: 2.319138,
                            number of examples in current reporting: 160, step 410
                            out of total 1000


Iteration:  42%|████▏     | 420/1000 [01:45<02:24,  4.02it/s]

timestamp: 16/04/2020 02:19:02, average loss: 4.264688, time duration: 2.478348,
                            number of examples in current reporting: 160, step 420
                            out of total 1000


Iteration:  43%|████▎     | 430/1000 [01:48<02:11,  4.35it/s]

timestamp: 16/04/2020 02:19:04, average loss: 4.259535, time duration: 2.299513,
                            number of examples in current reporting: 160, step 430
                            out of total 1000


Iteration:  44%|████▍     | 440/1000 [01:50<02:09,  4.34it/s]

timestamp: 16/04/2020 02:19:06, average loss: 4.184597, time duration: 2.295861,
                            number of examples in current reporting: 160, step 440
                            out of total 1000


Iteration:  45%|████▌     | 450/1000 [01:52<02:07,  4.32it/s]

timestamp: 16/04/2020 02:19:09, average loss: 4.116971, time duration: 2.324661,
                            number of examples in current reporting: 160, step 450
                            out of total 1000


Iteration:  46%|████▌     | 460/1000 [01:55<02:04,  4.34it/s]

timestamp: 16/04/2020 02:19:11, average loss: 4.020662, time duration: 2.294218,
                            number of examples in current reporting: 160, step 460
                            out of total 1000


Iteration:  47%|████▋     | 470/1000 [01:57<02:07,  4.15it/s]

timestamp: 16/04/2020 02:19:13, average loss: 4.048894, time duration: 2.359270,
                            number of examples in current reporting: 160, step 470
                            out of total 1000


Iteration:  48%|████▊     | 480/1000 [01:59<02:09,  4.02it/s]

timestamp: 16/04/2020 02:19:16, average loss: 3.967092, time duration: 2.474524,
                            number of examples in current reporting: 160, step 480
                            out of total 1000


Iteration:  49%|████▉     | 490/1000 [02:02<01:56,  4.37it/s]

timestamp: 16/04/2020 02:19:18, average loss: 3.941359, time duration: 2.279342,
                            number of examples in current reporting: 160, step 490
                            out of total 1000


Iteration:  50%|█████     | 500/1000 [02:04<01:55,  4.32it/s]

timestamp: 16/04/2020 02:19:20, average loss: 3.823642, time duration: 2.302453,
                            number of examples in current reporting: 160, step 500
                            out of total 1000


Iteration:  51%|█████     | 510/1000 [02:06<01:53,  4.31it/s]

timestamp: 16/04/2020 02:19:23, average loss: 3.736275, time duration: 2.327158,
                            number of examples in current reporting: 160, step 510
                            out of total 1000


Iteration:  52%|█████▏    | 520/1000 [02:09<01:51,  4.31it/s]

timestamp: 16/04/2020 02:19:25, average loss: 3.755796, time duration: 2.335013,
                            number of examples in current reporting: 160, step 520
                            out of total 1000


Iteration:  53%|█████▎    | 530/1000 [02:11<01:49,  4.29it/s]

timestamp: 16/04/2020 02:19:27, average loss: 3.682912, time duration: 2.323235,
                            number of examples in current reporting: 160, step 530
                            out of total 1000


Iteration:  54%|█████▍    | 540/1000 [02:13<01:53,  4.04it/s]

timestamp: 16/04/2020 02:19:30, average loss: 3.655455, time duration: 2.467513,
                            number of examples in current reporting: 160, step 540
                            out of total 1000


Iteration:  55%|█████▌    | 550/1000 [02:16<01:44,  4.30it/s]

timestamp: 16/04/2020 02:19:32, average loss: 3.561478, time duration: 2.311885,
                            number of examples in current reporting: 160, step 550
                            out of total 1000


Iteration:  56%|█████▌    | 560/1000 [02:18<01:40,  4.38it/s]

timestamp: 16/04/2020 02:19:35, average loss: 3.581482, time duration: 2.304521,
                            number of examples in current reporting: 160, step 560
                            out of total 1000


Iteration:  57%|█████▋    | 570/1000 [02:20<01:41,  4.23it/s]

timestamp: 16/04/2020 02:19:37, average loss: 3.449773, time duration: 2.398885,
                            number of examples in current reporting: 160, step 570
                            out of total 1000


Iteration:  58%|█████▊    | 580/1000 [02:23<01:39,  4.21it/s]

timestamp: 16/04/2020 02:19:39, average loss: 3.512243, time duration: 2.346694,
                            number of examples in current reporting: 160, step 580
                            out of total 1000


Iteration:  59%|█████▉    | 590/1000 [02:25<01:35,  4.31it/s]

timestamp: 16/04/2020 02:19:42, average loss: 3.373606, time duration: 2.316331,
                            number of examples in current reporting: 160, step 590
                            out of total 1000


Iteration:  60%|██████    | 600/1000 [02:28<01:37,  4.09it/s]

timestamp: 16/04/2020 02:19:44, average loss: 3.396941, time duration: 2.452835,
                            number of examples in current reporting: 160, step 600
                            out of total 1000


Iteration:  61%|██████    | 610/1000 [02:30<01:29,  4.36it/s]

timestamp: 16/04/2020 02:19:46, average loss: 3.302350, time duration: 2.288563,
                            number of examples in current reporting: 160, step 610
                            out of total 1000


Iteration:  62%|██████▏   | 620/1000 [02:32<01:28,  4.31it/s]

timestamp: 16/04/2020 02:19:49, average loss: 3.354256, time duration: 2.316888,
                            number of examples in current reporting: 160, step 620
                            out of total 1000


Iteration:  63%|██████▎   | 630/1000 [02:35<01:24,  4.36it/s]

timestamp: 16/04/2020 02:19:51, average loss: 3.268877, time duration: 2.319733,
                            number of examples in current reporting: 160, step 630
                            out of total 1000


Iteration:  64%|██████▍   | 640/1000 [02:37<01:23,  4.31it/s]

timestamp: 16/04/2020 02:19:53, average loss: 3.230106, time duration: 2.298786,
                            number of examples in current reporting: 160, step 640
                            out of total 1000


Iteration:  65%|██████▌   | 650/1000 [02:39<01:20,  4.33it/s]

timestamp: 16/04/2020 02:19:56, average loss: 3.155308, time duration: 2.325896,
                            number of examples in current reporting: 160, step 650
                            out of total 1000


Iteration:  66%|██████▌   | 660/1000 [02:42<01:23,  4.08it/s]

timestamp: 16/04/2020 02:19:58, average loss: 3.238014, time duration: 2.462547,
                            number of examples in current reporting: 160, step 660
                            out of total 1000


Iteration:  67%|██████▋   | 670/1000 [02:44<01:16,  4.30it/s]

timestamp: 16/04/2020 02:20:00, average loss: 3.067023, time duration: 2.295462,
                            number of examples in current reporting: 160, step 670
                            out of total 1000


Iteration:  68%|██████▊   | 680/1000 [02:46<01:13,  4.38it/s]

timestamp: 16/04/2020 02:20:03, average loss: 3.128462, time duration: 2.279071,
                            number of examples in current reporting: 160, step 680
                            out of total 1000


Iteration:  69%|██████▉   | 690/1000 [02:48<01:10,  4.38it/s]

timestamp: 16/04/2020 02:20:05, average loss: 3.068446, time duration: 2.299397,
                            number of examples in current reporting: 160, step 690
                            out of total 1000


Iteration:  70%|███████   | 700/1000 [02:51<01:08,  4.38it/s]

timestamp: 16/04/2020 02:20:07, average loss: 3.030622, time duration: 2.282668,
                            number of examples in current reporting: 160, step 700
                            out of total 1000


Iteration:  71%|███████   | 710/1000 [02:53<01:08,  4.24it/s]

timestamp: 16/04/2020 02:20:10, average loss: 2.963999, time duration: 2.362757,
                            number of examples in current reporting: 160, step 710
                            out of total 1000


Iteration:  72%|███████▏  | 720/1000 [02:56<01:08,  4.10it/s]

timestamp: 16/04/2020 02:20:12, average loss: 3.040875, time duration: 2.446717,
                            number of examples in current reporting: 160, step 720
                            out of total 1000


Iteration:  73%|███████▎  | 730/1000 [02:58<01:01,  4.41it/s]

timestamp: 16/04/2020 02:20:14, average loss: 2.961722, time duration: 2.274881,
                            number of examples in current reporting: 160, step 730
                            out of total 1000


Iteration:  74%|███████▍  | 740/1000 [03:00<01:00,  4.31it/s]

timestamp: 16/04/2020 02:20:17, average loss: 2.970588, time duration: 2.301664,
                            number of examples in current reporting: 160, step 740
                            out of total 1000


Iteration:  75%|███████▌  | 750/1000 [03:02<00:57,  4.35it/s]

timestamp: 16/04/2020 02:20:19, average loss: 2.896715, time duration: 2.302271,
                            number of examples in current reporting: 160, step 750
                            out of total 1000


Iteration:  76%|███████▌  | 760/1000 [03:05<00:55,  4.35it/s]

timestamp: 16/04/2020 02:20:21, average loss: 2.843995, time duration: 2.314438,
                            number of examples in current reporting: 160, step 760
                            out of total 1000


Iteration:  77%|███████▋  | 770/1000 [03:07<00:52,  4.37it/s]

timestamp: 16/04/2020 02:20:24, average loss: 2.891821, time duration: 2.297557,
                            number of examples in current reporting: 160, step 770
                            out of total 1000


Iteration:  78%|███████▊  | 780/1000 [03:10<00:55,  3.98it/s]

timestamp: 16/04/2020 02:20:26, average loss: 2.860446, time duration: 2.498748,
                            number of examples in current reporting: 160, step 780
                            out of total 1000


Iteration:  79%|███████▉  | 790/1000 [03:12<00:49,  4.28it/s]

timestamp: 16/04/2020 02:20:28, average loss: 2.857263, time duration: 2.340558,
                            number of examples in current reporting: 160, step 790
                            out of total 1000


Iteration:  80%|████████  | 800/1000 [03:14<00:46,  4.32it/s]

timestamp: 16/04/2020 02:20:31, average loss: 2.759473, time duration: 2.307025,
                            number of examples in current reporting: 160, step 800
                            out of total 1000


Iteration:  81%|████████  | 810/1000 [03:17<00:43,  4.34it/s]

timestamp: 16/04/2020 02:20:33, average loss: 2.834507, time duration: 2.311622,
                            number of examples in current reporting: 160, step 810
                            out of total 1000


Iteration:  82%|████████▏ | 820/1000 [03:19<00:41,  4.30it/s]

timestamp: 16/04/2020 02:20:35, average loss: 2.732735, time duration: 2.315288,
                            number of examples in current reporting: 160, step 820
                            out of total 1000


Iteration:  83%|████████▎ | 830/1000 [03:21<00:41,  4.14it/s]

timestamp: 16/04/2020 02:20:38, average loss: 2.786054, time duration: 2.382516,
                            number of examples in current reporting: 160, step 830
                            out of total 1000


Iteration:  84%|████████▍ | 840/1000 [03:24<00:39,  4.05it/s]

timestamp: 16/04/2020 02:20:40, average loss: 2.707751, time duration: 2.467501,
                            number of examples in current reporting: 160, step 840
                            out of total 1000


Iteration:  85%|████████▌ | 850/1000 [03:26<00:34,  4.32it/s]

timestamp: 16/04/2020 02:20:42, average loss: 2.739814, time duration: 2.307789,
                            number of examples in current reporting: 160, step 850
                            out of total 1000


Iteration:  86%|████████▌ | 860/1000 [03:28<00:32,  4.26it/s]

timestamp: 16/04/2020 02:20:45, average loss: 2.656100, time duration: 2.339591,
                            number of examples in current reporting: 160, step 860
                            out of total 1000


Iteration:  87%|████████▋ | 870/1000 [03:31<00:30,  4.32it/s]

timestamp: 16/04/2020 02:20:47, average loss: 2.754117, time duration: 2.299520,
                            number of examples in current reporting: 160, step 870
                            out of total 1000


Iteration:  88%|████████▊ | 880/1000 [03:33<00:27,  4.29it/s]

timestamp: 16/04/2020 02:20:49, average loss: 2.675594, time duration: 2.304183,
                            number of examples in current reporting: 160, step 880
                            out of total 1000


Iteration:  89%|████████▉ | 890/1000 [03:35<00:25,  4.31it/s]

timestamp: 16/04/2020 02:20:52, average loss: 2.666570, time duration: 2.302514,
                            number of examples in current reporting: 160, step 890
                            out of total 1000


Iteration:  90%|█████████ | 900/1000 [03:38<00:25,  4.00it/s]

timestamp: 16/04/2020 02:20:54, average loss: 2.611232, time duration: 2.480318,
                            number of examples in current reporting: 160, step 900
                            out of total 1000


Iteration:  91%|█████████ | 910/1000 [03:40<00:20,  4.30it/s]

timestamp: 16/04/2020 02:20:57, average loss: 2.733101, time duration: 2.320370,
                            number of examples in current reporting: 160, step 910
                            out of total 1000


Iteration:  92%|█████████▏| 920/1000 [03:42<00:18,  4.34it/s]

timestamp: 16/04/2020 02:20:59, average loss: 2.549486, time duration: 2.286971,
                            number of examples in current reporting: 160, step 920
                            out of total 1000


Iteration:  93%|█████████▎| 930/1000 [03:45<00:16,  4.32it/s]

timestamp: 16/04/2020 02:21:01, average loss: 2.661614, time duration: 2.306303,
                            number of examples in current reporting: 160, step 930
                            out of total 1000


Iteration:  94%|█████████▍| 940/1000 [03:47<00:13,  4.36it/s]

timestamp: 16/04/2020 02:21:03, average loss: 2.611486, time duration: 2.284473,
                            number of examples in current reporting: 160, step 940
                            out of total 1000


Iteration:  95%|█████████▌| 950/1000 [03:49<00:11,  4.38it/s]

timestamp: 16/04/2020 02:21:06, average loss: 2.578787, time duration: 2.291879,
                            number of examples in current reporting: 160, step 950
                            out of total 1000


Iteration:  96%|█████████▌| 960/1000 [03:52<00:09,  4.10it/s]

timestamp: 16/04/2020 02:21:08, average loss: 2.542034, time duration: 2.435049,
                            number of examples in current reporting: 160, step 960
                            out of total 1000


Iteration:  97%|█████████▋| 970/1000 [03:54<00:06,  4.36it/s]

timestamp: 16/04/2020 02:21:10, average loss: 2.654870, time duration: 2.280848,
                            number of examples in current reporting: 160, step 970
                            out of total 1000


Iteration:  98%|█████████▊| 980/1000 [03:56<00:04,  4.25it/s]

timestamp: 16/04/2020 02:21:13, average loss: 2.530608, time duration: 2.333624,
                            number of examples in current reporting: 160, step 980
                            out of total 1000


Iteration:  99%|█████████▉| 990/1000 [03:59<00:02,  4.36it/s]

timestamp: 16/04/2020 02:21:15, average loss: 2.609326, time duration: 2.286271,
                            number of examples in current reporting: 160, step 990
                            out of total 1000


Iteration: 100%|██████████| 1000/1000 [04:01<00:00,  4.36it/s]

timestamp: 16/04/2020 02:21:17, average loss: 2.579652, time duration: 2.280927,
                            number of examples in current reporting: 160, step 1000
                            out of total 1000





1000
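The progress records printed above can be turned into (step, loss) pairs to check the loss trend. The following is a minimal sketch, assuming the printed output has been captured into a string; `LOG_RECORD`, `parse_training_log`, and `sample_log` are hypothetical names for illustration and are not part of the nlp-recipes utilities.

```python
import re

# Pattern for one progress record as printed in the training output.
# A record spans three lines, so whitespace (including newlines) is matched loosely.
LOG_RECORD = re.compile(
    r"average loss:\s*(?P<loss>\d+\.\d+),\s*"
    r"time duration:\s*(?P<duration>\d+\.\d+),\s*"
    r"number of examples in current reporting:\s*(?P<examples>\d+),\s*"
    r"step\s+(?P<step>\d+)\s+out of total\s+(?P<total>\d+)"
)


def parse_training_log(log_text):
    """Return a list of (step, average_loss) tuples found in log_text."""
    return [
        (int(m.group("step")), float(m.group("loss")))
        for m in LOG_RECORD.finditer(log_text)
    ]


# Two records copied from the output above (assumed to be captured as text).
sample_log = """
timestamp: 16/04/2020 02:20:17, average loss: 2.970588, time duration: 2.301664,
                            number of examples in current reporting: 160, step 740
                            out of total 1000

timestamp: 16/04/2020 02:21:17, average loss: 2.579652, time duration: 2.280927,
                            number of examples in current reporting: 160, step 1000
                            out of total 1000
"""

records = parse_training_log(sample_log)
for step, loss in records:
    print(step, loss)
```

The resulting pairs can be passed to any plotting library to visualize how the average loss changes over the fine-tuning steps.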

In [14]:
#abs_summarizer.save_model(RESULT_DIR, 5000, False)

## Generate summaries on the test dataset

In [15]:
predictions = abs_summarizer.predict(
    test_dataset=test_dataset,
    num_gpus=NUM_GPUS,
    per_gpu_batch_size=TEST_PER_GPU_BATCH_SIZE,
    beam_size=BEAM_SIZE,
    max_tgt_length=MAX_TARGET_SEQ_LENGTH,
    forbid_ignore_word=FORBID_IGNORE_WORD,
    fp16=FP16
)

Evaluating: 100%|██████████| 21/21 [01:16<00:00,  3.65s/it]


In [16]:
for r in predictions[:5]:
    print(r)

french prosecutor says he ' s not aware of video video from on board the crash . new : reports of video found in the site , official says . new : official says the reports is " completely wrong "
palestinian government agreed to join the international criminal court in january . international criminal courts agreed to be members of the world ' s first court in iraq . israel , u . s . and u . s . opposed the court .
amnesty world ' s annual report says the number of death sentenced to death last year . new : report says u . s . government use the death penalty for threat of terrorism . u . s . , china , china are using the death
new : amnesty international says death penalty is " dead , " government says . pakistan , iraq , iraq is using death penalty for violence , officials say . authorities found guilty of a range of charges linked to violence in the region .
new : experts say ann frank ' s died at least a month ago . new : report : ann frank died at than least a months ago . new : r

In [17]:
test_ds.get_source()[0]

'Marseille, France (CNN) The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that " so far no videos were used in the crash investigation . " He added, " A person who has such a video needs to immediately give it to the investigators . " Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps . All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites . The publications said that they watched the video, which was found by a source close to the investigation. " One can hear cries of\' My God\' in several lan

In [18]:
test_ds.get_target()[0]

'Marseille prosecutor says " so far no videos were used in the crash investigation " despite media reports. Journalists at Bild and Paris Match are " very confident " the video clip is real, an editor says. Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says.'

In [19]:
predictions[0]

'french prosecutor says he \' s not aware of video video from on board the crash . new : reports of video found in the site , official says . new : official says the reports is " completely wrong "'

In [20]:
with open(OUTPUT_FILE, 'w', encoding="utf-8") as f:
    for line in predictions:
        f.write(line + '\n')

## Prediction on a single input sample

In [21]:
source = """
But under the new rule, set to be announced in the next 48 hours, Border Patrol agents would immediately return anyone to Mexico — without any detainment and without any due process — who attempts to cross the southwestern border between the legal ports of entry. The person would not be held for any length of time in an American facility.

Although they advised that details could change before the announcement, administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border. Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents, leaving the southwestern border defenses weakened, the officials argued.
The Trump administration plans to immediately turn back all asylum seekers and other foreigners attempting to enter the United States from Mexico illegally, saying the nation cannot risk allowing the coronavirus to spread through detention facilities and Border Patrol agents, four administration officials said.
The administration officials said the ports of entry would remain open to American citizens, green-card holders and foreigners with proper documentation. Some foreigners would be blocked, including Europeans currently subject to earlier travel restrictions imposed by the administration. The points of entry will also be open to commercial traffic."""

In [22]:
single_test_ds = SummarizationDataset(
    None, source=[source], source_preprocessing=[detokenize],
)
single_test_dataset = processor.s2s_dataset_from_sum_ds(single_test_ds, train_mode=False)

100%|██████████| 1/1 [00:00<00:00, 171.04it/s]


In [23]:
single_prediction = abs_summarizer.predict(
    test_dataset=single_test_dataset,
    num_gpus=NUM_GPUS,
    per_gpu_batch_size=1,
    beam_size=BEAM_SIZE,
    forbid_ignore_word=FORBID_IGNORE_WORD,
    fp16=FP16
)

Evaluating: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]


In [24]:
single_prediction[0]

'u . s . border enforcement officials say some foreigners will be open to americans .'
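The generated summaries keep the tokenizer's spacing around punctuation (for example `u . s .` and `he ' s`). The sketch below shows one possible regex-based cleanup for display purposes. `tidy_summary` is a hypothetical helper, not part of the notebook's utilities, and the rules are heuristics: quotation marks and capitalization are left untouched.

```python
import re


def tidy_summary(text):
    """Heuristically remove tokenizer spacing from a generated summary."""
    # Re-attach common contraction suffixes: "he ' s" -> "he's".
    text = re.sub(r"(\w) ' (s|t|re|ve|ll|d|m)\b", r"\1'\2", text)
    # Remove whitespace before sentence punctuation: "crash ." -> "crash."
    text = re.sub(r"\s+([.,:;!?])", r"\1", text)
    # Join spaced-out initialisms: "u. s." -> "u.s."
    text = re.sub(r"\b([a-z])\. (?=[a-z]\.)", r"\1.", text)
    return text


raw = "u . s . border enforcement officials say some foreigners will be open to americans ."
print(tidy_summary(raw))
```

For ROUGE evaluation, keep the raw predictions so that candidates and references are compared in the form the evaluation utilities expect; a cleanup like this is only meant for showing summaries to readers.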

## Evaluation
We provide utility functions for evaluating summarization models; details can be found in the [summarization evaluation notebook](./summarization_evaluation.ipynb).  
With `QUICK_RUN = True`, the model is fine-tuned on a small subset of the data for fewer steps, so expect lower scores. With the settings in this notebook and `QUICK_RUN = False`, you should get ROUGE scores close to the following numbers: <br />
``
{'rouge-1': {'f': 0.36208534811461,
             'p': 0.4743143496862804,
             'r': 0.30901813498597874},
 'rouge-2': {'f': 0.1620935174111968,
             'p': 0.2153396681546399,
             'r': 0.13747476622638555},
 'rouge-l': {'f': 0.2612394493528272,
             'p': 0.3426511372716949,
             'r': 0.22311445054693663}}
``


In [25]:
rouge_scores = compute_rouge_python(cand=predictions, ref=test_ds.get_target())
pprint.pprint(rouge_scores)

Number of candidates: 1000
Number of references: 1000
{'rouge-1': {'f': 0.2471059168881668,
             'p': 0.26124243347664683,
             'r': 0.24831050296409682},
 'rouge-2': {'f': 0.06458437980183909,
             'p': 0.06913302154728393,
             'r': 0.06444674243433718},
 'rouge-l': {'f': 0.17237949931564,
             'p': 0.18260208815440665,
             'r': 0.17339200584515088}}
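To make the `p`, `r`, and `f` entries concrete, here is a simplified sketch of ROUGE-1 on a single candidate/reference pair, using whitespace tokens and clipped unigram overlap. It is an illustration only: `rouge_1` is a hypothetical function, and `compute_rouge_python` may tokenize or normalize the text differently, so the numbers are not expected to match its output.

```python
from collections import Counter


def rouge_1(candidate, reference):
    """Compute unigram-overlap precision, recall and F1 on whitespace tokens."""
    cand_counts = Counter(candidate.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: a token is matched at most as often as it occurs in both texts.
    overlap = sum((cand_counts & ref_counts).values())
    precision = overlap / max(sum(cand_counts.values()), 1)
    recall = overlap / max(sum(ref_counts.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"p": precision, "r": recall, "f": f1}


scores = rouge_1(
    candidate="the prosecutor is not aware of any video",
    reference="the prosecutor says no videos were used",
)
print(scores)
```

Here 2 of the 8 candidate tokens appear in the 7-token reference, so precision is 2/8, recall is 2/7, and the F-score is their harmonic mean.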


In [26]:
# Record the ROUGE F-scores so that automated notebook tests can check them
sb.glue("rouge_1_f_score", rouge_scores["rouge-1"]["f"])
sb.glue("rouge_2_f_score", rouge_scores["rouge-2"]["f"])
sb.glue("rouge_l_f_score", rouge_scores["rouge-l"]["f"])

## Distributed training with DistributedDataParallel (DDP)
For distributed training, please consult the notebook [Abstractive Summarization using UniLM on CNN/DailyMail](./abstractive_summarization_unilm_cnndm.ipynb).

## Clean up 

In [27]:
if os.path.exists(DATA_DIR):
    shutil.rmtree(DATA_DIR, ignore_errors=True)
if os.path.exists(CACHE_DIR):
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
    
if CLEANUP_RESULTS:
    if os.path.exists(MODEl_DIR):
        shutil.rmtree(MODEl_DIR, ignore_errors=True)
    if os.path.exists(RESULT_DIR):
        shutil.rmtree(RESULT_DIR, ignore_errors=True)

In [28]:
print("Total notebook running time {}".format(time.time() - start_time))

Total notebook running time 403.38391304016113
