In [1]:
import os
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertModel, BertConfig, EncoderDecoderModel, BertForMaskedLM
from tqdm import tqdm



In [11]:
# Display the installed transformers version so the outputs below can be tied to a library release.
transformers.__version__

'4.2.1'

In [2]:
# Pick the accelerator when one is present. PyTorch calls the CUDA backend
# 'cuda'; 'gpu' is not a valid torch device string and `.to('gpu')` raises.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


In [3]:
# Load the fine-tuned BERT2BERT checkpoint and set its decoding configuration.
MODEL_PATH = './models/bert2bert.pt'

# NOTE(review): torch.load unpickles an entire model object, which can run
# arbitrary code — only load checkpoints from a trusted source.
bert2bert = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

bert2bert.config.update({
    # special token ids used by the decoder
    'decoder_start_token_id': 1,
    'eos_token_id': 2,
    'sep_token_id': 3,
    'pad_token_id': 0,
    # mirror the decoder's vocabulary size on the top-level config
    'vocab_size': bert2bert.config.decoder.vocab_size,
    # beam-search settings, used by generate() when not passed explicitly
    'max_length': 256,
    'min_length': 32,
    'no_repeat_ngram_size': 3,
    'early_stopping': True,
    'length_penalty': 2.0,
    'num_beams': 4,
})

print(bert2bert.config)  # __dict__

EncoderDecoderConfig {
  "decoder": {
    "_name_or_path": "bert-base-uncased",
    "add_cross_attention": true,
    "architectures": [
      "BertForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e

In [4]:
from data import batch2BertPadId, createAttMask, collate_fn, SummarisationDataset
from text2token import RawDocTokenize

1.11.0
clss tensor([[   0,   27,   82,  135,  156,  186,  194,  250,  267,  287,  313,  354,
          369,  425,  446,  473,  524,  558,  602,  630,  695,  732,  750,  774,
          790,  856,  892,  937,  966,  992, 1017, 1065, 1080, 1119, 1156, 1186,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1],
        [   0,    9,   19,   29,   50,   84,  118,  144,  172,  186,  222,  242,
          255,  290,  319,  337,  373,  410,  426,  435,  480,  521,  529,  553,
          567,  608,  627,  648,  667,  703,  720,  746,  761,  795,  828,  844,
          866,  888,  929,  971,  986, 1024, 1067, 1085, 1121, 1149, 1171,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1],
        [   0,   10,   20,   73,  114,  155,  173,  197,  233,  256,  276,  308,
          332,  3

In [5]:
# Pick one test article to summarise. A fixed SEED makes the pick (and so the
# printed summary) reproducible; set SEED = None to draw a fresh article each run.
SEED = 26
sample_rng = random.Random(SEED)

# load test data file, get one random article
pd_test = pd.read_csv('./cnn_dailymail/test.csv')
idx = sample_rng.randrange(len(pd_test))

article = pd_test['article'].iloc[idx]

# gt_summary = pd_test.iloc[idx]['highlights']
print('idx = ', idx)
print(article)
# print(gt_summary)


idx =  2683
'Exploited by criminals': Gambian footballer Baboucarr Ceesay (above) was among the 900 migrants who died in the Mediterranean boat disaster, his British aunt has revealed . A British woman has revealed how her nephew was among the 900 migrants who drowned in the Mediterranean boat disaster. Baboucarr Ceesay, a talented footballer from The Gambia, is believed to have died on the fishing boat in a 'desperate' attempt to seek a new life in the UK. His aunt Jessica Sey, from Cheltenham, has spoken of her devastation after discovering that he was not among the 27 survivors and demanded the human traffickers be brought to justice. She said: 'He had his head turned and his money taken by criminals who are responsible for thousands of deaths. It's the biggest shock of my life. 'His mother will never get his body back. She'll never be able to ask him why he did it.' She said her two daughters were 'absolutely gobsmacked' by news of the 21-year-old's death. Mrs Sey and her Gambian-b

In [7]:
# Tokenizer used to decode generated ids back to text. The lowercasing option
# is `do_lower_case`; `do_lower` is not a BertTokenizer parameter and was ignored.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [8]:
# post-process after generating
def format_summary(raw_summary: str) -> str:
    """Turn a decoded BERT2BERT output string into readable summary text.

    Removes the [unused0], [unused3] and [unused1] markers and any leftover
    [PAD], turns an inline ``[unused2]`` marker into a sentence break (". "),
    collapses runs of spaces, and re-attaches ``.``, ``,``, ``?`` and ``!``
    to the preceding word.

    Args:
        raw_summary: String produced by ``tokenizer.batch_decode``.

    Returns:
        The cleaned summary with surrounding whitespace stripped.
    """
    summary = raw_summary
    for marker in ("[unused0]", "[unused3]", "[PAD]", "[unused1]"):
        summary = summary.replace(marker, "")
    # str.replace is literal, so the old `.replace(r" +", " ")` only matched the
    # two characters " +"; a regex is needed to collapse runs of spaces.
    summary = re.sub(r" +", " ", summary)
    summary = summary.replace(" [unused2] ", ". ").replace("[unused2]", "")
    # Dropping a bare [unused2] can leave a double space behind.
    summary = re.sub(r" +", " ", summary)
    for spaced, tight in ((" .", "."), (" ,", ","), (" ?", "?"), (" !", "!")):
        summary = summary.replace(spaced, tight)
    return summary.strip()

In [9]:
# generate one summary from a document
def generate_summary(article: str) -> str:
    """Generate an abstractive summary for a single raw article.

    The article is tokenized with the project's ``RawDocTokenize`` helper,
    padding ids are restored and an attention mask is built, then
    ``bert2bert.generate`` decodes using the settings stored on
    ``bert2bert.config``.

    Args:
        article: Raw article text.

    Returns:
        The decoded summary, cleaned by ``format_summary``.
    """
    bert2bert.eval()
    # NOTE(review): the ' ' argument looks like a placeholder summary, since
    # only the source side is used here — confirm against RawDocTokenize.
    rdt = RawDocTokenize(article, ' ',
                         article_max_token_length=512,
                         summary_max_token_length=512)
    preprocessed = rdt.get_tokenized_output(padding=True)

    # Single-example batch of shape (1, seq_len).
    src = torch.tensor(preprocessed['src'], dtype=torch.long).view(1, -1)
    # convert -1 to padding token
    src = batch2BertPadId(src)
    src_att_mask = createAttMask(src)

    # Send the inputs to wherever the model's weights live. Using the global
    # `device` could place inputs on a different device than the model.
    model_device = next(bert2bert.parameters()).device
    src, src_att_mask = src.to(model_device), src_att_mask.to(model_device)

    outputs = bert2bert.generate(
        src,
        attention_mask=src_att_mask,
        decoder_start_token_id=bert2bert.config.decoder_start_token_id,
    )
    # token ids -> strings; format_summary strips the [unusedN] markers that
    # remain after decoding, so no further replacement is needed here.
    pred_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return format_summary(pred_str[0])

# Summarise the sampled article and show the model output.
pred_summary = generate_summary(article)
print('PREDICTED OUTPUT:', pred_summary, sep='\n')
# To compare against the reference, define gt_summary in the sampling cell and uncomment:
# print('GROUND TRUTH ABSTRACTIVE SUMMARY:')
# print(gt_summary)

  next_indices = next_tokens // vocab_size


PREDICTED OUTPUT:
baboucarr ceesay was among 900 migrants who drowned in the mediterranean boat disaster. his aunt jessica sey has spoken of her devastation after discovering he was not among 27 survivors. she says her nephew's death was " absolutely gobsmacked " by news of his death. babou, 21, was the eldest of four siblings from the west african football players
