In [6]:

import torch
from captum.attr import InputXGradient,DeepLiftShap
from transformers import PegasusTokenizer, PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")

HBox(children=(IntProgress(value=0, description='Downloading', max=1099, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, description='Downloading', max=2275329241, style=ProgressStyle(description…

KeyboardInterrupt: 

In [None]:
from typing import List
PGE_ARTICLE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
mname = "google/pegasus-xsum"
batch = tokenizer.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE])  # don't need tgt_text for inference
gen = model.generate(**batch)  # for forward pass: model(**batch)
summary: List[str] = tokenizer.batch_decode(gen, skip_special_tokens=True)

BART

In [4]:

from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-xsum')

In [6]:
encoder = model.get_encoder()

TXT = "My friends are good but they eat too many ."
input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
encoder_outputs = encoder(input_ids, return_dict=True)


In [4]:
def wrap_mask_filler(inp):
    # assume batch size = 1
    print(inp)
    print(inp.size())
    scores = model(inp)[0]
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    probs = scores[0, masked_index].softmax(dim=-1)
    return probs

In [5]:
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

token_reference = TokenReferenceBase(reference_token_idx=tokenizer.pad_token_id)

lig = LayerIntegratedGradients(wrap_mask_filler, model.model.shared)


In [None]:

TXT = "My friends are good but they eat too many <mask> ."
input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
print(tokenizer.decode(input_ids[0]).split())
seq_length = input_ids.shape[1]
reference_indices = token_reference.generate_reference(seq_length, device='cpu').unsqueeze(0)
reference_indices[:,0] = tokenizer.bos_token_id
reference_indices[:,-1] = tokenizer.eos_token_id

masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
attributions_ig, delta = lig.attribute(input_ids, baselines=reference_indices, \
                                           n_steps=50, return_convergence_delta=True)

# probs = outputs[0, masked_index].softmax(dim=-1)
# values, predictions = probs.topk(5)
# tokenizer.decode(predictions).split()
# ['good', 'great', 'all', 'really', 'very']

['<s>My', 'friends', 'are', 'good', 'but', 'they', 'eat', 'too', 'many<mask>.</s>']
tensor([[    0,  2387,   964,    32,   205,    53,    51,  3529,   350,   171,
         50264,   479,     2]])
torch.Size([1, 13])
tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]])
torch.Size([1, 13])


In [2]:

def load_BART(mname='facebook/bart-large-xsum'):

    # Mask filling only works for bart-large
    from transformers import BartTokenizer, BartForConditionalGeneration
    tokenizer = BartTokenizer.from_pretrained(mname)
    model = BartForConditionalGeneration.from_pretrained(mname)
    return model, tokenizer
model, tokenizer = load_BART()

In [3]:
import os
import json
from typing import List, Dict

from transformers import BatchEncoding, PreTrainedTokenizer

from captum.attr import LayerIntegratedGradients, LayerGradientShap, LayerGradientXActivation, TokenReferenceBase


def read_json_data(fdir, fname):
    """Read json file from fdir/fname. Assume it contains key 'data'."""
    fp = open(os.path.join(fdir, fname), 'r')
    data_dict = json.load(fp)['data']
    return data_dict


def get_input_docs_from_json(data_dict: List[Dict], use_add_sent=True):
    outputs = []
    outputs_qa_pairs = []
    for data in data_dict:
        inp_doc, add_sent = data['input_doc'], data['added_sent']
        input_str = inp_doc if not use_add_sent else "{} {}".format(
            inp_doc, add_sent)
        outputs.append(input_str)

        mask_pairs = data['mask_pairs']
        for mask_pair in mask_pairs:
            q, a, wa = mask_pair['q'], mask_pair['a'], mask_pair['wa']
            outputs_qa_pairs.append(
                {
                    'context': input_str,
                    'q': q,
                    'a': a,
                    'wa': wa
                }
            )
    return outputs, outputs_qa_pairs


import torch
from typing import Any, List, Optional


def attr_visualization():
    pass


# layer wise: inputs, baselines, target, additional_forward_args
# for each summary, the model needs to encode the document again and again?

class SumGen(torch.nn.Module):
    def __init__(self, model, tokenizer: PreTrainedTokenizer, attribution_func, max_len=50):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.use_cache = True
        self.attr = attribution_func(self.forward_step, self.model.model.shared)
        self.encoder = self.model.get_encoder()

    def prepare_batch_inp(self, input_articles: List[str], tgt_summaries: Optional[List[str]]) -> BatchEncoding:
        return self.tokenizer.prepare_seq2seq_batch(src_texts=input_articles, tgt_texts=tgt_summaries,
                                                    return_tensors='pt')

    def run_attribution(self, input_doc, attn_mask, tgt_sum=None):

        device = input_doc.device
        batch_size = input_doc.shape[0]
        encoder = self.model.get_encoder()
        cur_len = 1
        has_eos = [False for _ in range(batch_size)]
        bos_token_id = self.tokenizer.bos_token_id
        decoded = [[bos_token_id] for _ in range(batch_size)]
        decoder_input_ids = torch.LongTensor(decoded).to(device)
        past_key_values = None
        seq_length = input_doc.shape[1]
        token_reference = TokenReferenceBase(reference_token_idx=self.tokenizer.pad_token_id)
        reference_indice = token_reference.generate_reference(seq_length, device=device)
        reference_indices = torch.stack([reference_indice for _ in range(batch_size)], dim=0)
        reference_indices[:, 0] = self.tokenizer.bos_token_id
        reference_indices[:, -1] = self.tokenizer.eos_token_id

        while cur_len < self.max_len and (not all(has_eos)):
            additional_input = {
                "attn_mask": attn_mask,
                "past_key_values": past_key_values,
                "decoder_input_ids": decoder_input_ids, "attr_mode": False
            }
            cur_decoded, cur_past_key_values, cur_decoder_input_ids = self.forward_step(input_doc,additional_input
                                                                                        )
            # cur_decoded is just a list with token id
            for idx, cur_dec_tok in enumerate(cur_decoded):
                if cur_dec_tok == self.tokenizer.eos_token_id:
                    has_eos[idx] = True
            if tgt_sum is None:
                target = cur_decoder_input_ids[:, -1].unsqueeze(0)
                target = cur_decoded
            else:
                pass
            additional_input = {
                "attn_mask": attn_mask,
                "past_key_values": past_key_values,
                "decoder_input_ids": decoder_input_ids, "attr_mode": True
            }
            attribution, delta = self.attr.attribute(inputs=input_doc, baselines=reference_indices, target=target,
                                                     additional_forward_args=additional_input
                                                     )
            past_key_values = cur_past_key_values
            decoder_input_ids = cur_decoder_input_ids
        print("end of decoding")

    def forward_step(self, input_doc,
                     additional_input_args: dict,
                     # attn_mask, past_key_values, decoder_input_ids, attr_mode: bool
                     ):
        attn_mask, past_key_values, decoder_input_ids, attr_mode = \
            additional_input_args['attn_mask'], additional_input_args['past_key_values'], additional_input_args[
                'decoder_input_ids'], additional_input_args['attr_mode'],
        encoder_outputs = self.encoder(input_doc, attention_mask=attn_mask, return_dict=True)
        batch_size = input_doc.shape[0]
        device = input_doc.device

        expanded_batch_idxs = (
            torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, 1)
                .view(-1)
                .to(device)
        )
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_batch_idxs
        )

        model_inputs = {"input_ids": None,
                        "past_key_values": past_key_values,
                        "attention_mask": attn_mask,
                        "encoder_outputs": encoder_outputs,
                        "decoder_input_ids": decoder_input_ids,
                        }

        outputs = self.model(**model_inputs, use_cache=True, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        if attr_mode:
            return next_token_logits.unsqueeze(0)
        next_token = next_token.unsqueeze(-1)
        cur_decoded = next_token.tolist()
        if "past_key_values" in outputs:
            past_key_values = outputs.past_key_values

        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        return cur_decoded, past_key_values, decoder_input_ids

    def forward_entrance(self, input_doc, attention_mask, tgt_sum):

        device = input_doc.device
        batch_size = input_doc.shape[0]
        encoder = self.model.get_encoder()
        encoder_outputs = encoder(input_doc, attention_mask=attention_mask, return_dict=True)
        cur_len = 1
        decoded = [[self.tokenizer.bos_token_id] for _ in range(batch_size)]
        decoder_input_ids = torch.LongTensor(decoded).to(device)
        next_token = decoder_input_ids
        past = None
        expanded_batch_idxs = (
            torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, 1)
                .view(-1)
                .to(device)
        )
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_batch_idxs
        )
        all_logits = []
        while cur_len < self.max_len:
            self.forward_step()
            attributions_ig, delta = self.attr.attribute(input_ids, baselines=reference_indices, \
                                                         n_steps=50, return_convergence_delta=True)

    def forward(self, input_doc, input_doc_attn_mask, tgt_sum=None):
        device = input_doc.device
        batch_size = input_doc.shape[0]
        encoder = self.model.get_encoder()
        encoder_outputs = encoder(input_doc, attention_mask=input_doc_attn_mask, return_dict=True)
        cur_len = 1
        decoded = [[self.tokenizer.bos_token_id] for _ in range(batch_size)]
        decoder_input_ids = torch.LongTensor(decoded).to(device)
        next_token = decoder_input_ids
        past = None
        expanded_batch_idxs = (
            torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, 1)
                .view(-1)
                .to(device)
        )
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_batch_idxs
        )
        all_logits = []
        while cur_len < self.max_len:
            model_inputs = {"input_ids": None,
                            "past_key_values": past,
                            "attention_mask": input_doc_attn_mask,
                            "encoder_outputs": encoder_outputs,
                            "decoder_input_ids": decoder_input_ids,
                            }

            outputs = self.model(**model_inputs, use_cache=True, return_dict=True)
            next_token_logits = outputs.logits[:, -1, :]
            all_logits.append(next_token_logits)
            next_token = torch.argmax(next_token_logits, dim=-1)
            next_token = next_token.unsqueeze(-1)
            cur_decoded = next_token.tolist()
            decoded = [already_dec + cur_decoded[idx] for idx, already_dec in enumerate(decoded)]
            if "past_key_values" in outputs:
                past = outputs.past_key_values

            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
            cur_len += 1
        # while cur_len < self.max_len:
        #     model_inputs = {"input_ids": decoder_input_ids,
        #                     "past_key_values": past,
        #                     "attention_mask": input_doc_attn_mask,
        #                     "encoder_outputs": encoder_outputs,
        #                     # "decoder_input_ids": next_token,
        #                     "use_cache": self.use_cache}
        #
        #     outputs = self.model(**model_inputs, return_dict=True)
        #     next_token_logits = outputs.logits[:, -1, :]
        #     next_token = torch.argmax(next_token_logits, dim=-1)
        #     next_token = next_token.unsqueeze(-1)
        #     cur_decoded = next_token.tolist()
        #     decoded = [ already_dec+cur_decoded[idx] for idx,already_dec in enumerate(decoded)]
        #     if "past_key_values" in outputs:
        #         past = outputs.past_key_values
        #
        #     decoder_input_ids = torch.cat([decoder_input_ids, next_token],dim=-1)
        #     cur_len += 1
        for dec in decoded:
            print(self.tokenizer.decode(dec).split())

        torch.stack(all_logits)

In [1]:

sgen = SumGen(model=model,tokenizer=tokenizer,attribution_func=LayerIntegratedGradients)
example_input = ["The Pegasus model was proposed in PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu on Dec 18, 2019. According to the abstract, Pegasus’ pretraining task is intentionally similar to summarization: important sentences are removed/masked from an input document and are generated together as one output sequence from the remaining sentences, similar to an extractive summary.",
             "Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer. Longformer’s attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA."]
example_output = ["The paper is accepted by ICML.", "The paper comes from AI2."]

example_input = example_input[:1]
example_output = example_output[:1]
data_w_label  = sgen.prepare_batch_inp(example_input, example_output)   # data_w_label['data']['input_ids']  and data_w_label['data']['labels']

# sgen.run_attribution(data_w_label.data['input_ids'], data_w_label.data['attention_mask'])
sgen.run_attribution(data_w_label.data['input_ids'], None)


NameError: name 'SumGen' is not defined