In [1]:
# Load IPython's autoreload extension so the %autoreload magic is available.
%load_ext autoreload

In [2]:
# Autoreload mode 2 (per the IPython docs: reload all modules before running
# each cell), so edits to imported modules are picked up without a restart.
%autoreload 2

In [3]:
#! /usr/bin/python3
import argparse
import logging
import os
import sys
from collections import namedtuple

import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

#
from transformers import BertTokenizer
#from transformers import BertAbs



I0108 20:09:51.090050 140687649494848 file_utils.py:35] PyTorch version 1.3.0 available.


In [4]:
# Put the transformers summarization example directory first on sys.path so
# the modules imported right after this line (modeling_bertabs,
# utils_summarization, run_summarization) can be found.
# NOTE(review): absolute, machine-specific path — adjust to the local checkout.
sys.path.insert(0, "/dadendev/transformers/examples/summarization")
from modeling_bertabs import BertAbs, build_predictor

from utils_summarization import (
    #SummarizationDataset,
    build_mask,
    compute_token_type_ids,
    encode_for_summarization,
    fit_to_block_size,
)
from run_summarization import format_summary

In [5]:
# Load the lower-casing BERT tokenizer and the fine-tuned BertAbs checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
# Use the GPU when one is present, otherwise stay on CPU instead of raising
# on machines without CUDA.
model.to("cuda" if torch.cuda.is_available() else "cpu")
# Inference only: switch off dropout and other training-time behavior.
model.eval()

# Vocabulary ids of the special symbols handed to build_predictor.
symbols = {
    "BOS": tokenizer.vocab["[unused0]"],
    "EOS": tokenizer.vocab["[unused1]"],
    "PAD": tokenizer.vocab["[PAD]"],
}

I0108 20:09:52.932581 140687649494848 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/daden/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
I0108 20:09:53.102754 140687649494848 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json from cache at /home/daden/.cache/torch/transformers/7ebb4ac81007d10b400cb6c2968d4c8f1275a3e0cc3bab7f20f81913198b542c.df616398f4c84def6fca83d755543b01cb445db4ddd218d3efeded8ded68332f
I0108 20:09:53.103896 140687649494848 configuration_utils.py:199] Model config {
  "dec_dropout": 0.2,
  "dec_ff_size": 2048,
  "dec_heads": 8,
  "dec_hidden_size": 768,
  "dec_layers": 6,
  "enc_dropout": 0.2,
  "enc_ff_size": 512,
  "enc_heads": 8,
  "e

In [6]:
# Make the directory two levels up importable (needed for `utils_nlp`),
# without inserting it a second time when the cell is re-run.
nlp_path = os.path.abspath("../../")
if sys.path.count(nlp_path) == 0:
    sys.path.insert(0, nlp_path)
from utils_nlp.models.transformers.extractive_summarization import Bunch

# Decoding options consumed by build_predictor.
args = Bunch(
    dict(
        block_trigram=True,
        alpha=0.95,
        beam_size=5,
        min_length=20,
        max_length=200,
    )
)

In [7]:
# Build the BertAbs predictor from the decoding options, the tokenizer, the
# special-symbol ids and the loaded model.
predictor = build_predictor(args, tokenizer, symbols, model)

In [8]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

# Ensure the directory two levels up from the working directory is on
# sys.path; the membership test keeps re-runs from inserting it twice.
nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.common.pytorch_utils import get_device
from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset
from utils_nlp.eval.evaluate_summarization import get_rouge
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessedData,
    ExtSumProcessor,
)

import numpy as np
import pandas as pd
import scrapbook as sb

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [9]:
#from utils_nlp.models.transformers.datasets import SummarizationDataset
from utils_nlp.dataset.cnndm import CNNDMAbsSumDataset, CNNDMSummarizationDataset
#def build_data_iterator(args, tokenizer):


In [10]:
# Directory passed as local_cache_path to CNNDMAbsSumDataset.
# NOTE(review): hard-coded temporary directory from an earlier session — it
# may not exist on another machine or after a reboot.
DATA_PATH = '/tmp/tmpsh6mbj3g'

In [11]:
# Switch between a small smoke run (True) and a full run (False).
QUICK_RUN = True
# DATA_PATH could instead be a fresh directory: TemporaryDirectory().name
# TOP_N is the number of lines at the head of the data file used for
# preprocessing; -1 means all the lines.
if QUICK_RUN:
    TOP_N, CHUNK_SIZE = 4, 200
else:
    TOP_N, CHUNK_SIZE = -1, 2000

In [12]:
from torch.utils.data import Dataset
class SummarizationDataset(Dataset):
    """Map-style dataset pairing each source document with its target summary.

    Args:
        source: indexable, sized collection of source documents.
        target: optional indexable collection of reference summaries, aligned
            with ``source`` by position. May be omitted for unlabeled data.
    """

    def __init__(self, source, target=None):
        self.source = source
        self.target = target

    def __len__(self):
        """Return the number of source documents."""
        return len(self.source)

    def __getitem__(self, idx):
        """Return ``(source[idx], target[idx])``; the target is None when no
        targets were supplied."""
        # `target` defaults to None, so it must not be indexed unconditionally.
        target = None if self.target is None else self.target[idx]
        return self.source[idx], target

In [13]:
# Load the train and test datasets, using DATA_PATH as the local cache.
# top_n=TOP_N presumably restricts loading to the first TOP_N examples
# (-1 for all) — confirm against CNNDMAbsSumDataset.
train_dataset, test_dataset = CNNDMAbsSumDataset(top_n=TOP_N, local_cache_path=DATA_PATH)

I0108 20:10:04.784991 140687649494848 utils.py:173] Opening tar file /tmp/tmpsh6mbj3g/cnndm.tar.gz.
I0108 20:10:04.786256 140687649494848 utils.py:181] /tmp/tmpsh6mbj3g/test.txt.src already extracted.
I0108 20:10:05.076648 140687649494848 utils.py:181] /tmp/tmpsh6mbj3g/test.txt.tgt.tagged already extracted.
I0108 20:10:05.103137 140687649494848 utils.py:181] /tmp/tmpsh6mbj3g/train.txt.src already extracted.
I0108 20:10:12.616743 140687649494848 utils.py:181] /tmp/tmpsh6mbj3g/train.txt.tgt.tagged already extracted.
I0108 20:10:13.232445 140687649494848 utils.py:181] /tmp/tmpsh6mbj3g/val.txt.src already extracted.
I0108 20:10:13.578413 140687649494848 utils.py:181] /tmp/tmpsh6mbj3g/val.txt.tgt.tagged already extracted.


In [14]:
# Materialize the test split as a pair of lists:
# data[0] holds the sources, data[1] the targets (reference summaries).
data = list(test_dataset.get_source()), list(test_dataset.get_target())

In [15]:
# Sanity check: number of targets loaded from the test split.
len(data[1])

4

In [16]:
# Wrap the (sources, targets) lists so a DataLoader can index them pairwise.
test_sum_dataset = SummarizationDataset(data[0], data[1])

In [17]:
# Record layout for a batch that carries target-side fields as well.
# The namedtuple's type name is "Batch" even though it is bound to TrainBatch.
TrainBatch = namedtuple(
    "Batch",
    (
        "batch_size",
        "src",
        "segs",
        "mask_src",
        "tgt",
        "tgt_segs",
        "mask_tgt",
        "tgt_str",
    ),
)

In [18]:
# Container returned by `collate`. It is defined here because no other cell
# defines or imports a `Batch` name, so `collate` would otherwise raise
# NameError on a fresh kernel. Only the fields `collate` fills are declared.
Batch = namedtuple("Batch", ["batch_size", "src", "segs", "mask_src", "tgt_str"])


def collate(data, tokenizer, block_size, device):
    """ Collate formats the data passed to the data loader.

    Tokenization happens here, batch after batch, so the encoded corpus is
    never held in memory all at once. The result is a namedtuple to fit the
    original BertAbs's API.

    Args:
        data: list of ``(story_lines, summary_lines)`` pairs.
        tokenizer: tokenizer providing ``encode``, ``pad_token_id`` and
            ``cls_token_id``.
        block_size: length every encoded story is fitted to.
        device: device the tensors are moved to.

    Returns:
        A ``Batch`` with ``batch_size``, ``src``, ``segs`` and ``mask_src``
        for the encoder input, and ``tgt_str`` holding the reference
        summaries joined into one string per example.
    """
    # Drop examples whose summary is empty.
    data = [pair for pair in data if len(pair[1]) > 0]
    summaries = [" ".join(summary_list) for _, summary_list in data]

    encoded_text = [encode_for_summarization(story, summary, tokenizer) for story, summary in data]

    # Only the story side is fed to the model; fit each story to block_size.
    encoded_stories = torch.tensor(
        [fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
    )
    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)

    return Batch(
        batch_size=len(encoded_stories),
        src=encoded_stories.to(device),
        segs=encoder_token_type_ids.to(device),
        mask_src=encoder_mask.to(device),
        tgt_str=summaries,
    )

In [19]:
def encode_for_summarization(story_lines, summary_lines, tokenizer, max_len=512):
    """ Tokenize the story lines and the summary lines and flatten each group
    into one list of token ids.

    Every line is encoded on its own with ``tokenizer.encode(line,
    max_length=max_len)``; any special tokens between sentences are whatever
    that call emits. The per-line results are concatenated in order.

    Returns a ``(story_token_ids, summary_token_ids)`` tuple.
    """

    def _encode_and_flatten(lines):
        flat_ids = []
        for line in lines:
            flat_ids.extend(tokenizer.encode(line, max_length=max_len))
        return flat_ids

    story_token_ids = _encode_and_flatten(story_lines)
    summary_token_ids = _encode_and_flatten(summary_lines)
    return story_token_ids, summary_token_ids

In [20]:
def build_data_iterator(dataset, tokenizer, batch_size=16, device='cuda', block_size=512):
    """Build a DataLoader that walks `dataset` in order and collates each batch.

    Args:
        dataset: dataset yielding ``(story_lines, summary_lines)`` pairs.
        tokenizer: tokenizer forwarded to `collate`.
        batch_size: number of examples per batch.
        device: device that `collate` moves the batch tensors to.
        block_size: length each encoded story is fitted to by `collate`
            (defaults to 512, the value that used to be hard-coded).

    Returns:
        A ``torch.utils.data.DataLoader`` over `dataset` that uses a
        ``SequentialSampler``.
    """

    def collate_fn(data):
        # Bind everything except the raw examples, which the DataLoader supplies.
        return collate(data, tokenizer, block_size=block_size, device=device)

    return DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=batch_size,
        collate_fn=collate_fn,
    )

In [21]:
from utils_nlp.common.pytorch_utils import get_device
# Pick the device that input batches are moved to.
# NOTE(review): num_gpus=4 is hard-coded, and only `device` is used by the
# cells visible here — confirm what get_device returns on smaller machines.
device, num_gpus = get_device(num_gpus=4, local_rank=-1)

In [22]:
# Batches of up to 64 test examples, collated onto `device`.
data_iterator = build_data_iterator(test_sum_dataset, tokenizer, batch_size=64, device=device)

In [23]:
# Accumulators filled by the decoding loop: reference text and model output,
# appended batch by batch in the same order.
reference_summaries = []
generated_summaries = []

In [24]:
# Run the predictor over every batch and collect the formatted output next to
# the batch's reference strings.
for batch in tqdm(data_iterator):
    batch_data = predictor.translate_batch(batch)
    translations = predictor.from_batch(batch_data)
    summaries = list(map(format_summary, translations))

    reference_summaries.extend(batch.tgt_str)
    generated_summaries.extend(summaries)

  0%|          | 0/1 [00:00<?, ?it/s]

4


100%|██████████| 1/1 [00:02<00:00,  2.37s/it]


In [25]:
# Inspect the first generated summary.
generated_summaries[0]

'prosecutor brice robin : " so far no videos were used in the crash investigation ". robin \'s comments follow claims by two magazines , german daily bild and french paris match. all 150 on board germanwings flight 9525 were killed'

In [26]:
# Inspect the raw target for the same example.
data[1][0]

[' marseille prosecutor says " so far no videos were used in the crash investigation " despite media reports . <q>  journalists at bild and paris match are " very confident " the video clip is real , an editor says . <q>  andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . <q>\n']

In [27]:
# Rebuild the references from the raw targets: take the first string of each
# target entry and remove its '<q>' markers.
reference_summaries = [target[0].replace('<q>', '') for target in data[1]]

In [28]:
   def _write_list_to_file(list_items, filename):
        """Write every item of `list_items` to `filename`, one per line.

        The file is opened in text write mode, so existing content is
        replaced. Each item is formatted with ``%s`` and followed by a
        newline.
        """
        with open(filename, "w") as out_file:
            out_file.writelines("%s\n" % item for item in list_items)

In [29]:
#_write_list_to_file(generated_summaries, "./generated_summaries")

In [30]:
generated_summaries = []
with open("./generated_summaries", "r") as filehandle:
    for cnt, line in enumerate(filehandle):
        generated_summaries.append(line)

In [31]:
# Number of summaries read back from the file.
len(generated_summaries)

11490

In [32]:
# Inspect the first summary read back from the file.
generated_summaries[0]

'prosecutor brice robin : " so far no videos were used in the crash investigation ". robin \'s comments follow claims by two magazines , german daily bild and french paris match. all 150 on board germanwings flight 9525 were killed\n'

In [33]:
# Inspect the first reference after '<q>' removal.
reference_summaries[0]

' marseille prosecutor says " so far no videos were used in the crash investigation " despite media reports .   journalists at bild and paris match are " very confident " the video clip is real , an editor says .   andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \n'

In [34]:
import rouge
import nltk

# Download the NLTK "punkt" tokenizer data (a no-op when already up to date).
nltk.download("punkt")
# ROUGE-1, ROUGE-2 and ROUGE-L, averaged over all summary pairs.
rouge_evaluator = rouge.Rouge(
    metrics=["rouge-n", "rouge-l"],
    max_n=2,
    limit_length=True,
    # Truncate at the maximum generation length. The previous value,
    # args.beam_size (5), cut every summary to its first five words.
    length_limit=args.max_length,
    length_limit_type="words",
    apply_avg=True,
    apply_best=False,
    alpha=0.5,  # Default F1_score
    weight_factor=1.2,
    stemming=True,
)

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [35]:
#scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
#str_scores = format_rouge_scores(scores)

In [36]:
#scores

In [37]:
#str_scores