In [None]:
#@title foo
!pip install transformers==4.1.1 captum==0.3.0 plotnine

In [None]:
import re
import itertools

import numpy as np
import pandas as pd

from IPython.display import HTML
import plotnine
from plotnine import *

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from captum.attr import visualization as viz
from captum.attr import DeepLiftShap, ShapleyValueSampling, LayerIntegratedGradients, IntegratedGradients, Occlusion
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

# Notebook-wide default size for plotnine figures (two values, presumably width and
# height in inches -- confirm against the plotnine options documentation).
plotnine.options.figure_size = (12, 12)

In [None]:
# Checkpoint name shared by the tokenizer and the masked-language model.
# NOTE(review): lig_run_attribution_model and ig_run_attribution_model in this notebook
# address the embedding layer as `model.roberta.embeddings` / 'roberta.embeddings', so the
# commented-out non-RoBERTa alternatives will not work with those two methods as written.
#transformer = "distilbert-base-cased"
transformer = "roberta-base"
#transformer = "twmkn9/bert-base-uncased-squad2"
tokenizer = AutoTokenizer.from_pretrained(transformer)
model = AutoModelForMaskedLM.from_pretrained(transformer)
# NOTE(review): eval()/zero_grad() are left disabled; presumably from_pretrained already
# returns the model in evaluation mode -- confirm before relying on it.
#model.eval()
#model.zero_grad()

In [None]:
def prepare_data_for_sentence(sent, N):
    """Tokenize a masked sentence and return the model's top-N fill-in candidates.

    Args:
        sent: Sentence text containing the tokenizer's mask token.
        N: Number of highest-probability candidates to return for the mask slot.

    Returns:
        Tuple ``(input_seq, top_N_tokens, mask_index)``: the encoded token ids,
        the vocabulary ids of the N most probable tokens at the mask position
        (most probable first), and the position of the first mask token in
        ``input_seq``.

    Raises:
        ValueError: If the encoded sentence contains no mask token.
    """
    input_seq = tokenizer.encode(sent)

    if tokenizer.mask_token_id not in input_seq:
        raise ValueError(
            f"sentence must contain the mask token {tokenizer.mask_token!r}: {sent!r}"
        )
    # Only the first mask token is considered.
    mask_index = input_seq.index(tokenizer.mask_token_id)

    # Pure inference: ranking candidates needs no autograd graph.
    with torch.no_grad():
        input_result = model(torch.tensor([input_seq]), return_dict=True)

    # Logits over the vocabulary at the mask position of the single batch element.
    mask_token_logits = input_result.logits[0, mask_index, :]
    mask_token_probs = torch.nn.functional.softmax(mask_token_logits, dim=0)

    # Top predictions for the non-occluded sentence.
    top_N = torch.topk(mask_token_probs, N, dim=0)
    return input_seq, top_N.indices.tolist(), mask_index

In [None]:
def dls_run_attribution_model(input_seq, ref_token_id, top_N_tokens):
    """Attribute the mask-position logits to the input tokens with DeepLiftShap.

    This method is not registered in ``attr_models`` (its entry is commented out).
    NOTE(review): Captum's DeepLift-family explainers are documented as taking an
    ``nn.Module`` rather than a plain callable, and the inputs here are integer
    token ids -- this function likely fails as written; confirm before re-enabling.

    Args:
        input_seq: Encoded token ids of the sentence; must contain the mask token.
        ref_token_id: Token id used to fill the baseline sequence.
        top_N_tokens: Vocabulary ids of the candidate tokens to explain; the
            sentence is repeated once per candidate.

    Returns:
        The transposed attribution tensor (presumably one row per input token and
        one column per candidate -- confirm against Captum's output shape).

    Raises:
        ValueError: If ``input_seq`` contains no mask token.
    """
    # Explain the prediction made for the <mask> slot. The previous code read the
    # logits at sequence position 0, which is not where the candidates are predicted.
    mask_index = input_seq.index(tokenizer.mask_token_id)

    def custom_forward(inputs, attention_mask=None):
        result = model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits[:, mask_index, :]

    ablator = DeepLiftShap(custom_forward)

    # One copy of the sentence per candidate so each row gets its own target.
    input_tensor = torch.tensor([input_seq] * len(top_N_tokens))
    attention_mask = torch.ones_like(input_tensor)
    # Baseline: the same shape as the input, filled with the reference token.
    ref_tensor = torch.full_like(input_tensor, ref_token_id)

    attributions = ablator.attribute(
        inputs=input_tensor,
        baselines=ref_tensor,
        additional_forward_args=(attention_mask,),
        target=top_N_tokens,
    )

    return attributions.T

def svs_run_attribution_model(input_seq, ref_token_id, top_N_tokens):
    """Attribute the mask-position logits to the input tokens with ShapleyValueSampling.

    NOTE(review): this is a sampling-based method and no random seed is set in
    this function, so results presumably vary from run to run -- confirm.

    Args:
        input_seq: Encoded token ids of the sentence; must contain the mask token.
        ref_token_id: Token id used to fill the baseline sequence.
        top_N_tokens: Vocabulary ids of the candidate tokens to explain; the
            sentence is repeated once per candidate.

    Returns:
        The transposed attribution tensor (presumably one row per input token and
        one column per candidate -- confirm against Captum's output shape).

    Raises:
        ValueError: If ``input_seq`` contains no mask token.
    """
    # Explain the prediction made for the <mask> slot. The previous code read the
    # logits at sequence position 0, which is not where the candidates are predicted.
    mask_index = input_seq.index(tokenizer.mask_token_id)

    def custom_forward(inputs, attention_mask=None):
        result = model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits[:, mask_index, :]

    ablator = ShapleyValueSampling(custom_forward)

    # One copy of the sentence per candidate so each row gets its own target.
    input_tensor = torch.tensor([input_seq] * len(top_N_tokens))
    attention_mask = torch.ones_like(input_tensor)
    # Baseline: the same shape as the input, filled with the reference token.
    ref_tensor = torch.full_like(input_tensor, ref_token_id)

    attributions = ablator.attribute(
        inputs=input_tensor,
        baselines=ref_tensor,
        additional_forward_args=(attention_mask,),
        target=top_N_tokens,
    )

    return attributions.T

def occlusion_run_attribution_model(input_seq, ref_token_id, top_N_tokens):
    """Attribute the mask-position logits to the input tokens with Occlusion.

    Each token is replaced in turn (window of one position) by ``ref_token_id``.

    Args:
        input_seq: Encoded token ids of the sentence; must contain the mask token.
        ref_token_id: Token id substituted for each occluded position.
        top_N_tokens: Vocabulary ids of the candidate tokens to explain; the
            sentence is repeated once per candidate.

    Returns:
        The transposed attribution tensor (presumably one row per input token and
        one column per candidate -- confirm against Captum's output shape).

    Raises:
        ValueError: If ``input_seq`` contains no mask token.
    """
    # Explain the prediction made for the <mask> slot. The previous code read the
    # logits at sequence position 0, which is not where the candidates are predicted.
    mask_index = input_seq.index(tokenizer.mask_token_id)

    def custom_forward(inputs, attention_mask=None):
        # No attention mask is forwarded by this method; the model's default applies.
        result = model(inputs, return_dict=True)
        return result.logits[:, mask_index, :]

    ablator = Occlusion(custom_forward)

    # One copy of the sentence per candidate so each row gets its own target.
    input_tensor = torch.tensor([input_seq] * len(top_N_tokens))

    attributions = ablator.attribute(
        inputs=input_tensor,
        baselines=ref_token_id,
        sliding_window_shapes=(1,),
        target=top_N_tokens,
    )
    return attributions.T

def lig_run_attribution_model(input_seq, ref_token_id, top_N_tokens):
    """Attribute the mask-position logits to the input tokens with LayerIntegratedGradients.

    Attributions are computed with respect to ``model.roberta.embeddings``, so the
    loaded model must expose that attribute. The per-dimension attributions are
    summed over the last axis and the result is divided by its overall norm.

    Args:
        input_seq: Encoded token ids of the sentence; must contain the mask token.
        ref_token_id: Token id used as the baseline input.
        top_N_tokens: Vocabulary ids of the candidate tokens to explain; the
            sentence is repeated once per candidate.

    Returns:
        The transposed, normalised attribution tensor (presumably one row per
        input token and one column per candidate -- confirm).

    Raises:
        ValueError: If ``input_seq`` contains no mask token.
    """
    # Explain the prediction made for the <mask> slot. The previous code read the
    # logits at sequence position 0, which is not where the candidates are predicted.
    mask_index = input_seq.index(tokenizer.mask_token_id)

    def custom_forward(inputs, attention_mask=None):
        result = model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits[:, mask_index, :]

    def summarize_attributions(attributions):
        # Collapse the last axis, drop a leading singleton axis if there is one,
        # then scale by the norm of the whole tensor.
        attributions = attributions.sum(dim=-1).squeeze(0)
        return attributions / torch.norm(attributions)

    ablator = LayerIntegratedGradients(custom_forward, model.roberta.embeddings)

    # One copy of the sentence per candidate so each row gets its own target.
    input_tensor = torch.tensor([input_seq] * len(top_N_tokens))

    attributions = ablator.attribute(
        inputs=input_tensor,
        baselines=ref_token_id,
        target=top_N_tokens,
    )
    attributions = summarize_attributions(attributions)
    return attributions.T

def ig_run_attribution_model(input_seq, ref_token_id, top_N_tokens):
    """Attribute the mask-position logits to the input tokens with IntegratedGradients.

    The model's 'roberta.embeddings' layer is temporarily swapped for Captum's
    interpretable embedding layer so that gradients can be taken with respect to
    embeddings; the original layer is restored even if attribution fails.
    The per-dimension attributions are summed over the last axis and the result
    is divided by its overall norm.

    Args:
        input_seq: Encoded token ids of the sentence; must contain the mask token.
        ref_token_id: Token id whose embedding fills the baseline sequence.
        top_N_tokens: Vocabulary ids of the candidate tokens to explain; the
            sentence is repeated once per candidate.

    Returns:
        The transposed, normalised attribution tensor (presumably one row per
        input token and one column per candidate -- confirm).

    Raises:
        ValueError: If ``input_seq`` contains no mask token.
    """
    # Explain the prediction made for the <mask> slot. The previous code read the
    # logits at sequence position 0, which is not where the candidates are predicted.
    mask_index = input_seq.index(tokenizer.mask_token_id)

    def custom_forward(inputs, attention_mask=None):
        # `inputs` are pre-computed embeddings while the interpretable layer is installed.
        result = model(inputs, return_dict=True, attention_mask=attention_mask)
        return result.logits[:, mask_index, :]

    def summarize_attributions(attributions):
        # Collapse the last axis, drop a leading singleton axis if there is one,
        # then scale by the norm of the whole tensor.
        attributions = attributions.sum(dim=-1).squeeze(0)
        return attributions / torch.norm(attributions)

    interpretable_embedding = configure_interpretable_embedding_layer(model, 'roberta.embeddings')
    try:
        ablator = IntegratedGradients(custom_forward)

        # One copy of the sentence per candidate so each row gets its own target.
        input_tensor = torch.tensor([input_seq] * len(top_N_tokens))
        # A single baseline sequence made of the reference token.
        ref_tensor = torch.tensor([ref_token_id]).expand((1, len(input_seq)))

        interpretable_input_tensor = interpretable_embedding.indices_to_embeddings(input_tensor)
        ref_tensor = interpretable_embedding.indices_to_embeddings(ref_tensor)

        attention_mask = torch.ones_like(input_tensor)
        attributions = ablator.attribute(
            inputs=interpretable_input_tensor,
            baselines=ref_tensor,
            additional_forward_args=(attention_mask,),
            target=top_N_tokens,
        )
        attributions = summarize_attributions(attributions)
    finally:
        # Always put the original embedding layer back.
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    return attributions.T
# Registry of attribution methods: display name -> function taking
# (input_seq, ref_token_id, top_N_tokens). run_and_show iterates it in this order.
# DeepLiftShap is intentionally not registered.
attr_models = dict(
    Occlusion=occlusion_run_attribution_model,
    LayerIntegratedGradients=lig_run_attribution_model,
    IntegratedGradients=ig_run_attribution_model,
    ShapleyValueSampling=svs_run_attribution_model,
    # DeepLiftShap=dls_run_attribution_model,
)

In [None]:
def build_dataframe(model, attributions, input_seq, top_N_tokens):
    """Turn an attribution tensor into a long-format DataFrame for plotting.

    Args:
        model: Label stored verbatim in the ``model`` column (the name of the
            attribution method).
        attributions: Tensor with one row per token of ``input_seq`` and one
            column per entry of ``top_N_tokens``.
        input_seq: Encoded token ids of the sentence.
        top_N_tokens: Vocabulary ids of the candidate tokens.

    Returns:
        DataFrame with one row per (candidate, sentence token) pair and columns
        ``index`` (sentence token, ordered categorical in reverse sentence order),
        ``variable`` (candidate token, ordered categorical), ``value``,
        ``display_value`` (value formatted to two decimals) and ``model``.
        A token that occurs more than once in the sentence is labelled
        ``"tok (2)"``, ``"tok (3)"``, ... from its second occurrence on.
    """
    # Categorical categories must be unique, so a sentence that repeats a token
    # needs disambiguated labels.
    occurrences = {}
    sentence_labels = []
    for token in tokenizer.convert_ids_to_tokens(input_seq):
        occurrences[token] = occurrences.get(token, 0) + 1
        count = occurrences[token]
        sentence_labels.append(token if count == 1 else f"{token} ({count})")

    candidate_labels = tokenizer.convert_ids_to_tokens(top_N_tokens)

    # Reverse category order, matching the original layout of the flipped bar chart.
    token_axis = pd.CategoricalIndex(
        sentence_labels,
        categories=sentence_labels[::-1],
        ordered=True,
    )

    attr_df = (
        pd.DataFrame(
            attributions.detach().cpu().numpy(),
            columns=candidate_labels,
            index=token_axis,
        )
        .reset_index()
        .melt(id_vars=["index"])
    )
    attr_df['variable'] = pd.Categorical(
        attr_df['variable'],
        categories=candidate_labels,
        ordered=True,
    )
    attr_df['display_value'] = attr_df['value'].apply(lambda f: f"{f:.2f}")
    attr_df['model'] = model
    return attr_df
    

In [None]:
def create_plot(attr_df, mask_index, N):
    """Build the faceted bar chart of per-token attribution scores.

    One panel is drawn per (model, variable) combination, ``N`` panels per row,
    with the sentence tokens on the flipped x axis and each bar annotated with
    its formatted value. ``mask_index`` is accepted but not used here.

    Returns the assembled plotnine plot object; nothing is rendered here.
    """
    bars = geom_col(aes(fill="index", colour="index"))
    # Place each value label at the middle of its bar.
    bar_labels = geom_text(aes(y="value/2", label="display_value"), size=10)
    # Keep every token on the axis even when a panel has no data for it.
    token_axis = scale_x_discrete(drop=False)
    panels = facet_wrap("~model+variable", scales="free_x", ncol=N)
    captions = labs(
        x="Token in sentence",
        y="Captum contribution scores",
        title="Exploring the contribution of each token to the prediction.",
    )
    appearance = theme(legend_position="none", subplots_adjust={'hspace': 0.25})

    chart = ggplot(attr_df, aes(x="index", y="value"))
    # Components are added in the same order as a single chained `+` expression.
    for component in (bars, bar_labels, token_axis, panels, coord_flip(), captions, appearance):
        chart = chart + component
    return chart

In [None]:
def show_in_sentence(df, mask_index):
    """Render the sentence as HTML with each token shaded by its attribution.

    For every ``model`` group a grey box is emitted, containing one white box per
    ``variable`` (candidate token). Inside, each sentence token becomes a span
    whose background is green for a positive ``value`` and red otherwise, with
    ``rel_value`` as the alpha channel. The token at position ``mask_index`` is
    replaced by the candidate token and underlined.

    Args:
        df: DataFrame with columns ``model``, ``variable``, ``index``, ``value``
            and ``rel_value``; rows of each (model, variable) group must be in
            sentence order.
        mask_index: Position of the mask token within the sentence.

    Returns:
        None. The HTML is shown through IPython's ``display``.
    """
    import html  # local import: not part of the notebook's import cell

    def token_span(text, row, extra_style=""):
        rgb = '0,255,0' if row['value'] > 0 else '255,0,0'
        # Escape &, < and > so tokens such as "<s>" are shown literally.
        safe_text = html.escape(str(text), quote=False)
        return (
            f"<span style='padding: 0.1em; {extra_style}background-color: "
            f"rgba({rgb},{row['rel_value']});'>{safe_text}</span>"
        )

    pieces = []
    for model, model_rows in df.groupby('model'):
        pieces.append("<div style='font-size: 1.5em; padding: 1em; background-color: #CCC'>")
        pieces.append(f"<h3>{html.escape(str(model), quote=False)}</h3>")
        for _variable, sentence_rows in model_rows.groupby('variable'):
            pieces.append(
                "<div style='margin: 0.5em; padding: 0.5em; background-color: white; "
                "border: 3px solid #CCC; border-radius: 0.3em'>"
            )
            # Use the row's position within the group rather than its index label,
            # so the mask slot is found regardless of how the frame was indexed.
            for position, (_label, row) in enumerate(sentence_rows.iterrows()):
                if position == mask_index:
                    pieces.append(token_span(row['variable'], row, "text-decoration:underline;  "))
                else:
                    pieces.append(token_span(row['index'].strip(), row))
            pieces.append("</div>")
        pieces.append("</div>")
    display(HTML("".join(pieces)))

In [None]:
def run_and_show(sentence, mask_tokens, reference=None):
    """Run every registered attribution method on a masked sentence and display the results.

    Args:
        sentence: Text containing the tokenizer's mask token.
        mask_tokens: Either an int N (explain the model's top-N predictions for the
            mask slot) or a non-empty sequence of candidate tokens, given as
            strings or as vocabulary ids.
        reference: Baseline token id handed to each attribution method. Defaults
            to the current tokenizer's unknown-token id, looked up at call time so
            it follows a reloaded tokenizer.

    Returns:
        None. A faceted plot and an HTML rendering of the sentence are displayed.

    Raises:
        ValueError: If ``mask_tokens`` is an empty sequence.
    """
    if reference is None:
        reference = tokenizer.unk_token_id

    if isinstance(mask_tokens, int):
        N = mask_tokens
        input_seq, mask_tokens, mask_index = prepare_data_for_sentence(sentence, N)
    else:
        # A list makes Captum treat the targets as one target per input row.
        mask_tokens = list(mask_tokens)
        if not mask_tokens:
            raise ValueError("mask_tokens must be an int or a non-empty sequence of tokens/ids")
        N = len(mask_tokens)
        if isinstance(mask_tokens[0], str):
            # NOTE(review): tokens missing from the vocabulary are presumably mapped
            # to the unknown-token id without warning -- confirm.
            mask_tokens = tokenizer.convert_tokens_to_ids(mask_tokens)
        input_seq, _, mask_index = prepare_data_for_sentence(sentence, N)

    frames = []
    for method_name, attribute_fn in attr_models.items():
        attributions = attribute_fn(input_seq, reference, mask_tokens)
        df = build_dataframe(method_name, attributions, input_seq, mask_tokens)
        # Scale magnitudes to [0, 1] within each method; used as colour opacity.
        magnitude = df['value'].abs()
        peak = magnitude.max()
        df['rel_value'] = (magnitude / peak).round(2) if peak > 0 else 0.0
        frames.append(df)

    # Single concat; row labels of each frame are kept as they are (no ignore_index).
    result_df = pd.concat(frames)

    plot = create_plot(result_df, mask_index, N)
    display(plot)
    show_in_sentence(result_df, mask_index)


In [None]:
# Look up the vocabulary ids of the candidate tokens (with and without the "Ġ" prefix)
# that are passed to run_and_show for the "John rode his <mask>." example.
tokenizer.convert_tokens_to_ids(["horse", "Ġhorse", "Ġbicycle"])

In [None]:
# Ids of the plain words, i.e. without any "Ġ" prefix (presumably the tokenizer's
# marker for a preceding space -- confirm), for comparison with the prefixed lookup.
tokenizer.convert_tokens_to_ids(["The", "author", "talked", "to", "Sarah", "about", "book"])

In [None]:
# Same kind of lookup using "Ġ"-prefixed spellings, plus case variants of "the".
tokenizer.convert_tokens_to_ids(["the", "Ġthe", "ĠThe", "Ġauthor", "Ġtalked", "Ġto", "ĠSarah", "Ġabout", "Ġbook"])

In [None]:
# Reverse lookup: ids back to token strings.
# NOTE(review): these literal ids are presumably copied from an earlier lookup's output
# for the current checkpoint -- they will mean something else under another vocabulary.
tokenizer.convert_ids_to_tokens([19471, 5253, 14678])

In [None]:
# Explain three explicitly chosen candidates for the mask slot.
run_and_show("John rode his <mask>.", ["horse", "Ġhorse", "Ġbicycle"])

In [None]:
# Explain the model's own top-4 predictions for the mask slot.
# NOTE(review): this sentence says "Sara" while a later cell uses "Sarah" -- confirm
# whether the different spelling is intentional.
run_and_show("The author talked to Sara about <mask> book.", 4)

In [None]:
# Same sentence frame with "experience" instead of "book" after the mask; top-4 predictions.
run_and_show("The author talked to Sara about <mask> experience.", 4)

In [None]:
# Show how the tokenizer encodes a full sentence: one "id token" pair per line.
ids = tokenizer.encode("The author talked to Sarah about her book.")
tokens = tokenizer.convert_ids_to_tokens(ids)
# `token_id` rather than `id`: a notebook-level loop variable named `id` would
# shadow the builtin for every later cell.
for token_id, token in zip(ids, tokens):
    print(token_id, token)
    

In [None]:
# Round-trip whitespace-split words through the vocabulary: print each word's id,
# the word itself, and the token string that id maps back to.
tokens = "The author talked to Sarah about her book".split()
ids = tokenizer.convert_tokens_to_ids(tokens)
actual_tokens = tokenizer.convert_ids_to_tokens(ids)
# `token_id` rather than `id`: a notebook-level loop variable named `id` would
# shadow the builtin for every later cell.
for token_id, token, actual in zip(ids, tokens, actual_tokens):
    print(token_id, f"'{token}'", f"'{actual}'")

    
    

In [None]:
# Explain the model's top-4 predictions for the mask slot, with "Sarah" as the other person.
run_and_show("The author talked to Sarah about <mask> book.", 4)

In [None]:
# Swap the two people and compare explicit pronoun candidates, each given with and
# without the "Ġ" prefix.
run_and_show("Sarah talked to the author about <mask> book.", ["his", "Ġhis", "Ġher", "her"])