In [14]:
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# Load the XSum-finetuned BART-large summarizer and its matching tokenizer.
# Both names are notebook-level globals: `tokenizer` in particular is read
# directly (not passed in) by `run_full_model_slim` and `future_simulation`.
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-xsum')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')
# Device string used when moving the model and input tensors.
# NOTE(review): hard-coded to the first CUDA device; `run_full_model_slim`
# uses the same string as its default `device` argument.
device = 'cuda:0'

# util

import argparse
import logging
import os
import pickle
import random
import statistics
import sys
from datetime import datetime
from typing import Dict, List
import multiprocessing
import torch
from datasets import load_dataset
from transformers import BartForConditionalGeneration, BartModel, BartTokenizer
import numpy as np
import pandas as pd

now = datetime.now()

# Named logger shared by the whole notebook.
logger = logging.getLogger('sum')
logger.setLevel(logging.DEBUG)
# `logging.getLogger('sum')` returns the same object on every call, so simply
# adding handlers again when this cell is re-executed stacks a new pair on top
# of the old ones and every record is then emitted once per execution.
# Detach and close anything already attached before configuring.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()
# create file handler which logs even debug messages
# (file name is the current month and day, e.g. "0131.html")
fh = logging.FileHandler(f"{now.strftime('%m')}{now.strftime('%d')}.html")
fh.setLevel(logging.DEBUG)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# create formatter and add it to the handlers; the leading <br> makes each
# record start on its own line when the log file is opened as HTML
formatter = logging.Formatter('<br>%(levelname)s - %(message)s')
ch.setFormatter(formatter)
fh.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(ch)
logger.addHandler(fh)

from util import *

from transformers import BartForConditionalGeneration, BartTokenizer

def write_pkl_to_disk(path: str, fname_prefix: str, data_obj):
    """Pickle ``data_obj`` to ``<path>/<fname_prefix>.pkl``.

    Args:
        path: Directory the file is written into.
        fname_prefix: File name without the ``.pkl`` extension.
        data_obj: Any picklable object.

    Raises:
        OSError: If the target file cannot be opened for writing.
    """
    full_fname = os.path.join(path, f"{fname_prefix}.pkl")
    with open(full_fname, 'wb') as fd:
        pickle.dump(data_obj, fd)
    # Log through the notebook's named 'sum' logger. The module-level
    # `logging.debug(...)` targets the root logger instead: the message is
    # dropped at the root's default level, and the call implicitly runs
    # `logging.basicConfig()`, installing a root handler that then duplicates
    # every record propagated from 'sum'.
    logging.getLogger('sum').debug(f"Done writing to {full_fname}")


def init_bart_sum_model(mname='sshleifer/distilbart-cnn-6-6', device='cuda:0'):
    """Load a BART conditional-generation checkpoint and its tokenizer.

    Args:
        mname: Checkpoint name or path handed to ``from_pretrained``.
        device: Device the model is moved to after loading.

    Returns:
        A ``(model, tokenizer)`` pair; only the model is moved to ``device``.
    """
    summarizer = BartForConditionalGeneration.from_pretrained(mname)
    summarizer = summarizer.to(device)
    return summarizer, BartTokenizer.from_pretrained(mname)

def bart_decoder_forward_embed(input_ids, embed_tokens, embed_scale):
    """Embed decoder token ids and apply the embedding scale.

    Args:
        input_ids: Integer tensor of token ids; it is reshaped to two
            dimensions, keeping the size of its last dimension.
        embed_tokens: Callable mapping the reshaped ids to embeddings.
        embed_scale: Factor the embeddings are multiplied by.

    Returns:
        ``embed_tokens(ids) * embed_scale`` for the reshaped ids.
    """
    last_dim = input_ids.size()[-1]
    flat_ids = input_ids.view(-1, last_dim)
    return embed_tokens(flat_ids) * embed_scale


def summarize_attributions(attributions):
    """Collapse the last dimension of ``attributions`` and normalise.

    The tensor is averaged over its last dimension and the result is divided
    by its own norm (``torch.norm`` with default arguments).
    """
    averaged = attributions.mean(dim=-1)
    return averaged / torch.norm(averaged)

def forward_enc_dec_step(model, encoder_outputs, decoder_inputs_embeds):
    """Run one forward pass from precomputed encoder outputs and decoder embeddings.

    ``input_ids`` and ``past_key_values`` are explicitly passed as ``None``;
    caching is disabled and attentions are requested.

    Args:
        model: Callable accepting the keyword arguments used below.
        encoder_outputs: Forwarded unchanged as ``encoder_outputs``.
        decoder_inputs_embeds: Forwarded unchanged as ``decoder_inputs_embeds``.

    Returns:
        Whatever ``model`` returns for that call.
    """
    return model(
        input_ids=None,
        past_key_values=None,
        encoder_outputs=encoder_outputs,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=False,
        return_dict=True,
        output_attentions=True,
    )


def init_bart_family(name_lm, name_sum, device, no_lm=False, no_ood=False):
    """Load the language model, the summarizer and an out-of-domain summarizer.

    Args:
        name_lm: Checkpoint name passed to ``init_bart_lm_model``
            (not defined in this cell; presumably supplied by ``util``).
        name_sum: Checkpoint name of the main summarizer.
        device: Device every loaded model is placed on.
        no_lm: When true, skip the language model and return ``None`` for it.
        no_ood: When true, skip the out-of-domain model and return ``None`` for it.

    Returns:
        ``(lm_model, sum_model, sum_out_of_domain, tok)`` where ``tok`` is the
        tokenizer returned alongside the main summarizer.
    """
    lm_model = None
    if not no_lm:
        lm_model, tok = init_bart_lm_model(name_lm, device)
    sum_model, tok = init_bart_sum_model(name_sum, device)

    sum_out_of_domain = None
    if not no_ood:
        # The out-of-domain model is the XSum checkpoint when the main model
        # is the CNN one, and the CNN checkpoint in every other case.
        cnn_name = "facebook/bart-large-cnn"
        xsum_name = "facebook/bart-large-xsum"
        ood_name = xsum_name if name_sum == cnn_name else cnn_name
        sum_out_of_domain, _ = init_bart_sum_model(ood_name, device)
    return lm_model, sum_model, sum_out_of_domain, tok

from captum.attr._utils.visualization import format_word_importances


def simple_viz_attribution(tokenizer, input_ids, attribution_scores):
    """Render per-token attribution scores with captum's word-importance formatter.

    Args:
        tokenizer: Object whose ``decode`` turns a single token id into text.
        input_ids: Tensor-like with ``tolist()``; if the result is nested,
            only its first row is used.
        attribution_scores: Tensor-like with ``tolist()``, one score per token.

    Returns:
        The value returned by ``format_word_importances``.
    """
    token_ids = input_ids.tolist()
    first_entry = token_ids[0]
    if isinstance(first_entry, list):
        # batched input: keep the first row only
        token_ids = first_entry
    words = list(map(tokenizer.decode, token_ids))
    return format_word_importances(words, attribution_scores.tolist())


@torch.no_grad()
def run_full_model_slim(model, input_ids, attention_mask=None, decoder_input_ids=None, targets=None, device='cuda:0', output_dec_hid=False, output_attentions=False, T=1, special_attn=False):
    """Run one encoder-decoder forward pass and inspect the last decoder step.

    Args:
        model: Callable seq2seq model; its result must expose ``.logits`` and,
            for the attention modes, ``['cross_attentions']``.
        input_ids: Encoder token ids; moved to ``device``.
        attention_mask: Optional encoder mask; moved to ``device`` when given.
        decoder_input_ids: Decoder prefix ids. Despite the ``None`` default it
            is required: ``.to(device)`` is called on it unconditionally.
        targets: Optional gold next-token ids for a per-example cross-entropy.
        device: Device the inputs are moved to.
        output_dec_hid: Forwarded as ``output_hidden_states``.
        output_attentions: Forwarded to the model; when true (and
            ``special_attn`` is false) the result of ``get_cross_attention``
            is returned instead of the token prediction.
        T: Softmax temperature for the returned probabilities only; the
            argmax token is taken from the unscaled logits.
        special_attn: When true, return the last cross-attention layer's
            attention for the final decoder position of the first batch item,
            averaged over heads, with encoder positions whose token id is
            below 5 zeroed (presumably special tokens; confirm against the
            vocabulary).

    Returns:
        * ``special_attn``: a 1-D tensor over encoder positions.
        * ``output_attentions``: the ``(output, p)`` pair from ``get_cross_attention``.
        * otherwise: ``(decoded_tokens, prob, next_token_logits, loss)``,
          where ``loss`` is ``0`` when ``targets`` is ``None``.

    Note:
        The default path decodes with the module-level ``tokenizer`` global,
        not with an argument.
    """
    decoder_input_ids = decoder_input_ids.to(device)
    input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    # encoder and decoder batches must line up
    assert decoder_input_ids.size()[0] == input_ids.size()[0]

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=output_dec_hid,
        output_attentions=output_attentions,
        use_cache=False,
        return_dict=True,
    )

    # logits are (batch, dec seq, vocab size); keep the last decoder position
    next_token_logits = outputs.logits[:, -1, :]

    loss = 0
    if targets is not None:
        targets = targets.to(device)
        loss = torch.nn.functional.cross_entropy(
            input=next_token_logits, target=targets, reduction='none')

    if special_attn:
        # last layer, last decoder position -> (batch, nhead, enc_len)
        last_layer = outputs['cross_attentions'][-1]
        head_avg = last_layer[:, :, -1, :].mean(dim=1)
        # zero out encoder positions whose token id is below 5
        keep = (input_ids >= 5).float()
        return (head_avg * keep)[0]

    if output_attentions:
        # use cross attention as the distribution; the helper receives every
        # layer's cross attention together with the encoder ids
        attn_output, attn_dist = get_cross_attention(
            outputs['cross_attentions'], input_ids, device=device)
        return attn_output, attn_dist

    prob = torch.nn.functional.softmax(next_token_logits / T, dim=-1)
    # one greedy token id per batch row
    next_token = torch.argmax(next_token_logits, dim=-1).tolist()
    output = [tokenizer.decode(token_id) for token_id in next_token]
    return output, prob, next_token_logits, loss
from scipy.stats import entropy

def load_xsum(split='validation'):
    """Return the requested split of the ``xsum`` dataset via ``datasets.load_dataset``."""
    from datasets import load_dataset
    return load_dataset('xsum', split=split)

In [15]:
# Load XSum with load_xsum()'s default split ('validation').
source_data = load_xsum()

# Move the summarizer onto the configured device; inputs are moved to the same
# device before each forward pass in the cells below.
model = model.to(device=device)

Using custom data configuration default
Reusing dataset xsum (/home/jcxu/.cache/huggingface/datasets/xsum/default/1.2.0/f9abaabb5e2b2a1e765c25417264722d31877b34ec34b437c53242f6e5c30d6d)


In [16]:
# based on current prediction, predict future k steps
from typing import List

def future_simulation(model, device, input_doc_token_ids, prefix_token_ids: List, max_expand_steps=5, min_expand_prob=0.1):
    """Greedily continue a decoder prefix for a fixed number of steps.

    At every step the model is run on the document plus the current prefix,
    the arg-max next token is appended, and the loop repeats. The original
    prefix and the generated continuation are logged side by side.

    Args:
        model: Summarization model handed to ``run_full_model_slim``.
        device: Device for the decoder prefix tensor and the forward pass.
        input_doc_token_ids: Encoder token ids of the source document.
        prefix_token_ids: Decoder token ids to continue from. Left unmodified.
        max_expand_steps: Number of tokens to generate.
        min_expand_prob: Currently unused; kept for interface compatibility.

    Returns:
        List with one entry per generated token, each a 0-dim tensor holding
        the token id (the direct result of ``torch.argmax``).
    """
    # Work on a private copy: extending `prefix_token_ids` itself would grow
    # the caller's list by `max_expand_steps` entries as a hidden side effect.
    working_prefix = list(prefix_token_ids)
    generated_token_ids = []
    for _ in range(max_expand_steps):
        dec_prefix = torch.tensor([working_prefix], dtype=torch.long).to(device)
        # Forward `device` explicitly; otherwise run_full_model_slim falls back
        # to its own default and ignores the device this function was given.
        _, _, next_token_logits, _ = run_full_model_slim(
            model, input_doc_token_ids, decoder_input_ids=dec_prefix, device=device)
        best_id = torch.argmax(next_token_logits.squeeze())
        # Keep the running prefix a plain list of ints so the next
        # torch.tensor(...) call is built from a homogeneous list.
        working_prefix.append(int(best_id))
        generated_token_ids.append(best_id)
    logger.info(f"Simulation Prefix: {tokenizer.decode(prefix_token_ids,skip_special_tokens=True)} | {tokenizer.decode(generated_token_ids,skip_special_tokens=True)}")
    return generated_token_ids



In [17]:
# For up to 100 documents: decode a summary greedily, then replay it token by
# token and, at every position, log a short greedy continuation for each
# sufficiently probable candidate next token.
max_beam_size = 10  # number of top next-token candidates inspected per position
min_expand_prob = 0.05  # candidates below this probability are not simulated
cnt = 0  # number of documents processed so far
for data in source_data:

    document = data['document']
    summary = data['summary']
    logger.info('*'*50)
    logger.info(f"Summary: {summary}")
    # init: tokenize the document and keep at most the first 600 token ids
    doc_input_ids = tokenizer(document, return_tensors='pt')['input_ids'][:,:600]
    doc_input_ids = doc_input_ids.to(device)
    # Generate the system summary. NOTE: num_beams=1 below, so this is greedy
    # decoding rather than a large beam.
    # inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
    # Generate Summary
    summary_ids = model.generate(doc_input_ids, num_beams=1, max_length=100, early_stopping=True)
    sum_ids_list = summary_ids.squeeze().tolist()
    logger.info(f"Generated summary: {[tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]}")
    # Replay the generated summary, growing the decoder prefix by one token per step.
    prefix = []
    for idx, tok in enumerate(sum_ids_list):
        # Stop two tokens before the end of the generated sequence.
        if idx + 2 >= len(sum_ids_list):
            logger.info(f"End of seq.")
            break
        prefix.append(tok)
        dec_prefix = torch.tensor([prefix],dtype=torch.long).to(device)
        # print(dec_prefix.size())
        # print(doc_input_ids.size())
        # NOTE(review): `device` is not forwarded here, so run_full_model_slim
        # uses its own default ('cuda:0'), which matches the global `device`.
        output, prob, next_token_logits, loss  = run_full_model_slim(model,doc_input_ids, decoder_input_ids=dec_prefix)
        # Most probable next tokens given the current prefix.
        top_prob, top_index = torch.topk(input=prob, k=max_beam_size)
        # show_top_k(prob, tokenizer= tokenizer)

        top_p_list = top_prob.squeeze().tolist()
        top_i_list = top_index.squeeze().tolist()
        # Skip positions where the single-step top-1 token differs from the
        # token that generate() actually produced next.
        if top_i_list[0] != sum_ids_list[idx+1]:
            logger.info(f"Result does not match.")
            continue

        # top_p_list.pop(0)
        # actual_index = top_i_list.pop(0)
        
        if len(top_p_list) < 2:
            continue
        logger.info(f"="*20)
        # logger.info(f"Tgt: {tokenizer.decode(prefix + [actual_index], skip_special_tokens=True)}")
        # Candidates come from topk in descending probability, so the first one
        # below the threshold ends the expansion for this position.
        for p,i in zip(top_p_list, top_i_list):
            if p < min_expand_prob:
                break
            # New list, so `prefix` itself is not extended by the candidate.
            current_prefix = prefix + [i]
            # dec_prefix = torch.tensor([current_prefix],dtype=torch.long).to(device)
            # NOTE: future_simulation is called with its default
            # min_expand_prob; the cell-level value above is not passed in.
            future_simulation(model, device, doc_input_ids, current_prefix )

        logger.info(f"="*20)

    # Cap the run at 100 documents.
    cnt += 1
    if cnt >= 100:
        break



<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, the |  UN has warned, as
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, the |  UN has warned, as
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, the |  UN has warned, as
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, a |  UN report says, with
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, a |  UN report says, with
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, a |  UN report says, with
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, according |  to a UN report.
<br>INFO - Simulation Prefix: More than one million people are facing hunger in Somalia, according |  to a UN report.
<br>INFO - Simulation Prefix: More than one million people are facing hunger 

In [2]:
# Sanity check on a hand-written article: summarize it once with greedy
# decoding and once with a beam of 10, and print both results.
# NOTE(review): this cell's execution count (In [2]) is lower than the cells
# above it, so it was run out of order with respect to the rest of the notebook.
ARTICLE_TO_SUMMARIZE = "Police in suburban Minneapolis shot and killed a man when what where who was allegedly involved in a carjacking and fired shots at pursuing officers, according to a release from the Burnsville Police Department. When, when."
# NOTE: `max_length` is given without `truncation=True`, which is what triggers
# the tokenizer warning shown in this cell's output.
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary (num_beams=1: greedy decoding)
summary_ids = model.generate(inputs['input_ids'], num_beams=1, max_length=100, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

# Same article with beam search over 10 beams.
summary_ids = model.generate(inputs['input_ids'], num_beams=10, max_length=100, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
['Police in the US have shot and killed a man who was allegedly involved in a carjacking.']
['A man has been shot and killed by police in the United States.']
