In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
#! /usr/bin/python3
import argparse
import logging
import os
import sys
from collections import namedtuple

import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

#
from transformers import BertTokenizer
#from transformers import BertAbs



I0108 04:39:52.395967 139994149881664 file_utils.py:35] PyTorch version 1.2.0 available.


In [4]:
sys.path.insert(0, "/dadendev/transformers/examples/summarization")
from modeling_bertabs import BertAbs, build_predictor

from utils_summarization import (
    #SummarizationDataset,
    build_mask,
    compute_token_type_ids,
    encode_for_summarization,
    fit_to_block_size,
)
from run_summarization import format_summary

In [5]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
model.to("cuda")
model.eval()

symbols = {
    "BOS": tokenizer.vocab["[unused0]"],
    "EOS": tokenizer.vocab["[unused1]"],
    "PAD": tokenizer.vocab["[PAD]"],
}

I0108 04:39:54.246211 139994149881664 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/daden/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
I0108 04:39:54.417192 139994149881664 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json from cache at /home/daden/.cache/torch/transformers/7ebb4ac81007d10b400cb6c2968d4c8f1275a3e0cc3bab7f20f81913198b542c.df616398f4c84def6fca83d755543b01cb445db4ddd218d3efeded8ded68332f
I0108 04:39:54.418124 139994149881664 configuration_utils.py:199] Model config {
  "dec_dropout": 0.2,
  "dec_ff_size": 2048,
  "dec_heads": 8,
  "dec_hidden_size": 768,
  "dec_layers": 6,
  "enc_dropout": 0.2,
  "enc_ff_size": 512,
  "enc_heads": 8,
  "e

In [6]:
nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)
from utils_nlp.models.transformers.extractive_summarization import Bunch
args = Bunch({"block_trigram":True, "alpha": 0.95, "beam_size": 5, "min_length": 20, "max_length": 200})

In [7]:
predictor = build_predictor(args, tokenizer, symbols, model)

In [8]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.common.pytorch_utils import get_device
from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset
from utils_nlp.eval.evaluate_summarization import get_rouge
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessedData,
    ExtSumProcessor,
)

import numpy as np
import pandas as pd
import scrapbook as sb

[nltk_data] Downloading package punkt to /home/daden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [9]:
#from utils_nlp.models.transformers.datasets import SummarizationDataset
from utils_nlp.dataset.cnndm import CNNDMAbsSumDataset, CNNDMSummarizationDataset
#def build_data_iterator(args, tokenizer):


In [10]:
DATA_PATH = '/tmp/tmpbvxzmv1v'

In [11]:
QUICK_RUN = False
# the data path used to save the downloaded data file
DATA_PATH = TemporaryDirectory().name
# The number of lines at the head of data file used for preprocessing. -1 means all the lines.
TOP_N = 4
CHUNK_SIZE=200
if not QUICK_RUN:
    TOP_N = -1
    CHUNK_SIZE = 2000

In [12]:
from torch.utils.data import Dataset
class SummarizationDataset(Dataset):
    def __init__(self, source, target=None):
        self.source = source
        self.target = target
    def __len__(self):
        return len(self.source)
    def __getitem__(self, idx):
        return self.source[idx], self.target[idx]

In [13]:
train_dataset, test_dataset = CNNDMAbsSumDataset(top_n=TOP_N, local_cache_path=DATA_PATH)

100%|██████████| 489k/489k [00:07<00:00, 68.9kKB/s] 
I0108 04:40:12.525922 139994149881664 utils.py:173] Opening tar file /tmp/tmpbbgje18p/cnndm.tar.gz.


In [14]:
data = list(test_dataset.get_source()), list(test_dataset.get_target())

In [15]:
len(data[1])

11490

In [23]:
test_sum_dataset = SummarizationDataset(data[0], data[1])

In [24]:
Batch = namedtuple("Batch", [ "batch_size", "src", "segs", "mask_src", "tgt_str"])

In [25]:
def collate(data, tokenizer, block_size, device):
    """ Collate formats the data passed to the data loader.
    In particular we tokenize the data batch after batch to avoid keeping them
    all in memory. We output the data as a namedtuple to fit the original BertAbs's
    API.
    """
    data = [x for x in data if not len(x[1]) == 0]  # remove empty_files
    #print(data)
    #names = [name for name, _, _ in data]
    # summaries = [" ".join(summary_list) for _, _, summary_list in data]
    summaries = [" ".join(summary_list) for _, summary_list in data]
  

    encoded_text = [encode_for_summarization(story, summary, tokenizer) for story, summary in data]
    
    
    #""""""
    encoded_stories = torch.tensor(
        [fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
    )
    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
    #"""
    print(len(encoded_stories))

    batch = Batch(
        #document_names=None,
        batch_size=len(encoded_stories),
        src=encoded_stories.to(device),
        segs=encoder_token_type_ids.to(device),
        mask_src=encoder_mask.to(device),
        tgt_str=summaries,
    )
    return batch

In [26]:
def encode_for_summarization(story_lines, summary_lines, tokenizer, max_len=512):
    """ Encode the story and summary lines, and join them
    as specified in [1] by using `[SEP] [CLS]` tokens to separate
    sentences.
    """
    story_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in story_lines]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in summary_lines]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids

In [27]:
def build_data_iterator(dataset, tokenizer, batch_size=16, device='cuda'):
   
    sampler = SequentialSampler(dataset)

    def collate_fn(data):
        return collate(data, tokenizer, block_size=512, device=device)

    iterator = DataLoader(dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn,)

    return iterator

In [28]:
from utils_nlp.common.pytorch_utils import get_device
device, num_gpus = get_device(num_gpus=4, local_rank=-1)

In [29]:
data_iterator = build_data_iterator(train_sum_dataset, tokenizer, batch_size=64, device=device)

In [30]:
reference_summaries = []
generated_summaries = []

In [None]:
for batch in tqdm(data_iterator):
    
    batch_data = predictor.translate_batch(batch)
    translations = predictor.from_batch(batch_data)
    summaries = [format_summary(t) for t in translations]
    #save_summaries(summaries, args.summaries_output_dir, batch.document_names)

    if True:
        reference_summaries += batch.tgt_str
        generated_summaries += summaries

  0%|          | 0/180 [00:00<?, ?it/s]

64


  1%|          | 1/180 [00:23<1:08:39, 23.01s/it]

64


  1%|          | 2/180 [00:41<1:03:56, 21.55s/it]

64


  2%|▏         | 3/180 [01:01<1:02:13, 21.09s/it]

64


  2%|▏         | 4/180 [01:20<59:58, 20.45s/it]  

64


  3%|▎         | 5/180 [01:39<58:56, 20.21s/it]

64


  3%|▎         | 6/180 [01:57<56:05, 19.34s/it]

64


  4%|▍         | 7/180 [02:16<55:39, 19.31s/it]

64


  4%|▍         | 8/180 [02:37<57:22, 20.02s/it]

64


  5%|▌         | 9/180 [02:58<57:06, 20.04s/it]

64


  6%|▌         | 10/180 [03:17<56:26, 19.92s/it]

64


  6%|▌         | 11/180 [03:38<56:26, 20.04s/it]

64


  7%|▋         | 12/180 [03:59<57:18, 20.47s/it]

64


  7%|▋         | 13/180 [04:20<57:12, 20.56s/it]

64


  8%|▊         | 14/180 [04:40<56:14, 20.33s/it]

64


  8%|▊         | 15/180 [05:01<56:39, 20.60s/it]

64


  9%|▉         | 16/180 [05:20<55:06, 20.16s/it]

64


  9%|▉         | 17/180 [05:40<55:04, 20.27s/it]

64


 10%|█         | 18/180 [06:12<1:03:38, 23.57s/it]

64


 11%|█         | 19/180 [06:44<1:10:07, 26.13s/it]

64


 11%|█         | 20/180 [07:14<1:13:00, 27.38s/it]

64


 12%|█▏        | 21/180 [07:43<1:13:54, 27.89s/it]

64


 12%|█▏        | 22/180 [08:15<1:16:35, 29.08s/it]

64


 13%|█▎        | 23/180 [08:43<1:15:04, 28.69s/it]

64


 13%|█▎        | 24/180 [09:20<1:20:53, 31.11s/it]

64


 14%|█▍        | 25/180 [09:50<1:20:03, 30.99s/it]

64


 14%|█▍        | 26/180 [10:26<1:22:53, 32.30s/it]

64


 15%|█▌        | 27/180 [10:55<1:20:17, 31.49s/it]

64


 16%|█▌        | 28/180 [11:25<1:18:29, 30.98s/it]

64


In [None]:
reference_summaries[0]

In [None]:
generated_summaries[0]

In [None]:
   def _write_list_to_file(list_items, filename):
        with open(filename, "w") as filehandle:
            # for cnt, line in enumerate(filehandle):
            for item in list_items:
                filehandle.write("%s\n" % item)

In [None]:
_write_list_to_file(generated_summaries, "./generated_summaries")