In [3]:


class Beam(object):
    """
    Beam data structure. Maintains a list of scored elements like a Counter, but only keeps the top n
    elements after every insertion operation. Insertion is O(n) (list is maintained in
    sorted order), access is O(1). Still fast enough for practical purposes for small beams.
    """
    def __init__(self, size):
        self.size = size
        self.elts = []
        self.scores = []

    def __repr__(self):
        return "Beam(" + repr(list(self.get_elts_and_scores())) + ")"

    def __str__(self):
        return self.__repr__()

    def __len__(self):
        return len(self.elts)

    def add(self, elt, score):
        """
        Adds the element to the beam with the given score if the beam has room or if the score
        is better than the score of the worst element currently on the beam

        :param elt: element to add
        :param score: score corresponding to the element
        """
        if len(self.elts) == self.size and score < self.scores[-1]:
            # Do nothing because this element is the worst
            return
        # If the list contains the item with a lower score, remove it
        i = 0
        while i < len(self.elts):
            if self.elts[i] == elt and score > self.scores[i]:
                del self.elts[i]
                del self.scores[i]
            i += 1
        # If the list is empty, just insert the item
        if len(self.elts) == 0:
            self.elts.insert(0, elt)
            self.scores.insert(0, score)
        # Find the insertion point with binary search
        else:
            lb = 0
            ub = len(self.scores) - 1
            # We're searching for the index of the first element with score less than score
            while lb < ub:
                m = (lb + ub) // 2
                # Check > because the list is sorted in descending order
                if self.scores[m] > score:
                    # Put the lower bound ahead of m because all elements before this are greater
                    lb = m + 1
                else:
                    # m could still be the insertion point
                    ub = m
            # lb and ub should be equal and indicate the index of the first element with score less than score.
            # Might be necessary to insert at the end of the list.
            if self.scores[lb] > score:
                self.elts.insert(lb + 1, elt)
                self.scores.insert(lb + 1, score)
            else:
                self.elts.insert(lb, elt)
                self.scores.insert(lb, score)
            # Drop and item from the beam if necessary
            if len(self.scores) > self.size:
                self.elts.pop()
                self.scores.pop()

    def get_elts(self):
        return self.elts

    def get_elts_and_scores(self):
        return zip(self.elts, self.scores)

    def head(self):
        return self.elts[0]


In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-xsum')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')
device = 'cuda:0'

# util

import argparse
import logging
import os
import pickle
import random
import statistics
import sys
from datetime import datetime
from typing import Dict, List
import multiprocessing
import torch
from datasets import load_dataset
from transformers import BartForConditionalGeneration, BartModel, BartTokenizer
import numpy as np
import pandas as pd

now = datetime.now()

logger = logging.getLogger('sum')
logger.setLevel(logging.DEBUG)
# create file handler which logs even debug messages
fh = logging.FileHandler(f"{now.strftime('%m')}{now.strftime('%d')}.html")
fh.setLevel(logging.DEBUG)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# create formatter and add it to the handlers
formatter = logging.Formatter('<br>%(levelname)s - %(message)s')
ch.setFormatter(formatter)
fh.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(ch)
logger.addHandler(fh)

from util import *

from transformers import BartForConditionalGeneration, BartTokenizer

def write_pkl_to_disk(path: str, fname_prefix: str, data_obj):
    full_fname = os.path.join(path, f"{fname_prefix}.pkl")
    with open(full_fname, 'wb') as fd:
        pickle.dump(data_obj, fd)
    logging.debug(f"Done writing to {full_fname}")


def init_bart_sum_model(mname='sshleifer/distilbart-cnn-6-6', device='cuda:0'):
    model = BartForConditionalGeneration.from_pretrained(mname).to(device)
    tokenizer = BartTokenizer.from_pretrained(mname)
    return model, tokenizer

def bart_decoder_forward_embed(input_ids, embed_tokens, embed_scale):
    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])
    inputs_embeds = embed_tokens(input_ids) * embed_scale
    return inputs_embeds


def summarize_attributions(attributions):
    attributions = attributions.mean(dim=-1)
    attributions = attributions / torch.norm(attributions)
    return attributions

def forward_enc_dec_step(model, encoder_outputs, decoder_inputs_embeds):
    # expanded_batch_idxs = (
    #         torch.arange(batch_size)
    #             .view(-1, 1)
    #             .repeat(1, 1)
    #             .view(-1)
    #             .to(device)
    #     )
    # encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
    #         0, expanded_batch_idxs
    #     )
    model_inputs = {"input_ids": None,
                    "past_key_values": None,
                    "encoder_outputs": encoder_outputs,
                    "decoder_inputs_embeds": decoder_inputs_embeds,
                    }
    outputs = model(**model_inputs, use_cache=False,
                    return_dict=True, output_attentions=True)
    return outputs


def init_bart_family(name_lm, name_sum, device, no_lm=False, no_ood=False):
    if not no_lm:
        lm_model, tok = init_bart_lm_model(name_lm, device)
    else:
        lm_model = None
    sum_model, tok = init_bart_sum_model(name_sum, device)
    if not no_ood:
        if name_sum == "facebook/bart-large-cnn": 
            sum_out_of_domain, _ = init_bart_sum_model(
            "facebook/bart-large-xsum", device)
        else:
            sum_out_of_domain, _ = init_bart_sum_model(
            "facebook/bart-large-cnn", device) 
    else:
        sum_out_of_domain = None
    return lm_model, sum_model, sum_out_of_domain, tok

from captum.attr._utils.visualization import format_word_importances


def simple_viz_attribution(tokenizer, input_ids, attribution_scores):
    token_in_list = input_ids.tolist()
    if isinstance(token_in_list[0], list):
        token_in_list = token_in_list[0]
    words = [tokenizer.decode(x) for x in token_in_list]
    attribution_scores_list = attribution_scores.tolist()
    # for w, ascore in zip(words, attribution_scores_list):
    #     logging.info('{:10} {:02.2f}'.format(w, ascore))

    output = format_word_importances(words, attribution_scores_list)
    return output


@torch.no_grad()
def run_full_model_slim(model, input_ids, attention_mask=None, decoder_input_ids=None, targets=None, device='cuda:0', output_dec_hid=False, output_attentions=False, T=1, special_attn=False):
    decoder_input_ids = decoder_input_ids.to(device)
    input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    assert decoder_input_ids.size()[0] == input_ids.size()[0]

    model_inputs = {"input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                    }

    outputs = model(**model_inputs,
                    output_hidden_states=output_dec_hid, output_attentions=output_attentions,
                    use_cache=False, return_dict=True)

    # batch, dec seq, vocab size
    next_token_logits = outputs.logits[:, -1, :]
    if targets is not None:
        targets = targets.to(device)
        loss = torch.nn.functional.cross_entropy(
            input=next_token_logits, target=targets, reduction='none')
    else:
        loss = 0
    if special_attn:
        cross_attn = outputs['cross_attentions']
        attn = cross_attn[-1][:, :, -1, :]
        # batch, nhead, enc_len
        mean_attn = torch.mean(attn, dim=1)
        # block special positions in input
        mask = (input_ids >= 5).float()
        mean_attn = mean_attn * mask
        return mean_attn[0] 
    if output_attentions:
        # use cross attention as the distribution
        # last layer.   batch=1, head, dec len, enc len
        # by default we use the last layer of attention
        output, p = get_cross_attention(
            outputs['cross_attentions'], input_ids, device=device)
        return output, p

    
    prob = torch.nn.functional.softmax(next_token_logits/T, dim=-1)
    # prob = next_token_logits.softmax(dim=-1)
    next_token = torch.argmax(next_token_logits, dim=-1)
    # next_token = next_token.unsqueeze(-1)
    next_token = next_token.tolist()    # confrim nested list?
    # print(f"Gold: {tokenizer.decode(targets[0].item())}")
    output = [tokenizer.decode(tk) for tk in next_token]
    # logging.info(f"Next token: {output}")
    # outputs['output'] = output
    return output, prob, next_token_logits, loss
from scipy.stats import entropy

def load_xsum(split='validation'):
    from datasets import load_dataset
    dataset_xsum = load_dataset('xsum',split=split)
    return dataset_xsum
source_data = load_xsum()
model = model.to(device=device)



In [3]:
# based on current prediction, predict future k steps
from typing import List

from transformers.generation_utils import top_k_top_p_filtering

def future_simulation(model, device, input_doc_token_ids, prefix_token_ids:List, max_expand_steps=5,min_expand_prob=0.1): # return all of the possible tokens
    copy_of_prefix_token_ids = prefix_token_ids.copy()
    generated_token_ids = []
    for t in range(max_expand_steps):
        dec_prefix = torch.tensor([prefix_token_ids],dtype=torch.long).to(device)
        output, prob, next_token_logits, loss  = run_full_model_slim(model,input_doc_token_ids, decoder_input_ids=dec_prefix)
        # print(next_token_logits)
        # print(next_token_logits.size())
        # break
        # squeezed_prob = prob
        # print(squeezed_prob[:10])
        
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_p=0.9, filter_value=0)
        probabilities = torch.nn.functional.softmax(filtered_logits, dim=-1)
        # print(torch.sum(logits))
        next_tokens = torch.multinomial(probabilities, num_samples=1)
        print(next_tokens.size())
        print(next_tokens)
        break
        best_id = torch.argmax(next_token_logits.squeeze())
        prefix_token_ids += [best_id]
        generated_token_ids.append(best_id)
    logger.info(f"Simulation Prefix: {tokenizer.decode(copy_of_prefix_token_ids,skip_special_tokens=True)} | {tokenizer.decode(generated_token_ids,skip_special_tokens=True)}")
    return generated_token_ids




In [5]:
for data in source_data:

    document = data['document']
    summary = data['summary']
    break

doc_input_ids = tokenizer(document, return_tensors='pt')['input_ids'][:,:600]
doc_input_ids = doc_input_ids.to(device)
summary_ids = model.generate(doc_input_ids, num_beams=1, max_length=100, early_stopping=True)
summary_ids =summary_ids.squeeze().tolist()
prefix = summary_ids[:5]
output = future_simulation(model, device, doc_input_ids, prefix )
print(output)
def run_baseline_bs():
    pass

<br>INFO - Simulation Prefix: Apple has been accused | 
INFO:sum:Simulation Prefix: Apple has been accused | 


torch.Size([1, 1])
tensor([[9]], device='cuda:0')
[]


In [28]:
for data in source_data:

    document = data['document']
    summary = data['summary']
    break

document = "The Bart model was proposed in BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019. According to the abstract, Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT). The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE."


inputs = tokenizer([document], max_length=1024, return_tensors='pt')

# Generate Summary
for t in range(3, 10):
    summary_ids = model.generate(inputs['input_ids'].to(device), num_beams=5, min_length=1, max_length=t, early_stopping=True, num_return_sequences=5)
    outs = []
    for g in summary_ids:
        out = tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        outs.append(out)
    outs.sort()
    print('-------',t)
    print('\t\t'.join(outs))

# print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])


------- 3
A		An		Researchers		Scientists		The
------- 4
A model		A new		A novel		A pre		A training
------- 5
A model for		A model that		A new model		A new training		A pre-
------- 6
A model that pret		A new model for		A new training model		A pre-trained		A pre-training
------- 7
A new model for training		A new training model for		A pre-trained model		A pre-training algorithm		A pre-training model
------- 8
A new training model for natural		A pre-trained model for		A pre-training algorithm for		A pre-training model for		A pre-training model that
------- 9
A new training model for natural language		A pre-trained model for natural		A pre-training algorithm for natural		A pre-training model for artificial		A pre-training model for natural
