In [9]:
import itertools
import numpy as np
import pandas as pd
import matplotlib
import seaborn as sns
import plotnine
from plotnine import *
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

# Default size (in inches) used for every plotnine figure in this notebook.
plotnine.options.figure_size = (12, 12)
import warnings
# NOTE(review): this silences *all* warnings for the whole kernel, including
# deprecation warnings from torch / pandas / plotnine. Consider narrowing the
# filter to specific categories or modules.
warnings.filterwarnings("ignore")

In [2]:
# Name of the pretrained checkpoint to load. The commented-out alternatives are
# kept for experimentation.
# NOTE(review): other cells in this notebook hard-code DistilBERT attribute
# paths ('distilbert.embeddings', model.distilbert.transformer.layer), so
# switching checkpoints here presumably requires updating those as well.
transformer = "distilbert-base-cased"
#transformer = "roberta-base"
#transformer = "twmkn9/bert-base-uncased-squad2"
# Tokenizer and masked-language-model head for the chosen checkpoint; both are
# used as module-level globals by the helper functions below.
tokenizer = AutoTokenizer.from_pretrained(transformer)
model = AutoModelForMaskedLM.from_pretrained(transformer)
# Put the model in evaluation mode and clear any stored gradients before
# running attribution.
model.eval()
model.zero_grad()

In [3]:
def prepare_data_for_sentence(sent, N):
    """Tokenize a masked sentence and return the model's top-N fillers for the mask.

    Uses the module-level ``tokenizer`` and ``model``.

    Parameters
    ----------
    sent : str
        Sentence containing the tokenizer's mask token.
    N : int
        Number of highest-probability candidate tokens to return.

    Returns
    -------
    input_seq : list[int]
        Token ids produced by ``tokenizer.encode(sent)``.
    top_N_tokens : list[int]
        Ids of the N most probable tokens at the mask position, most probable first.
    mask_index : int
        Position of the first mask token inside ``input_seq``.

    Raises
    ------
    ValueError
        If the encoded sentence contains no mask token.
    """
    input_seq = tokenizer.encode(sent)

    # list.index would also raise ValueError here, but with an unhelpful
    # "x is not in list" message.
    if tokenizer.mask_token_id not in input_seq:
        raise ValueError(
            f"sentence does not contain the mask token {tokenizer.mask_token!r}: {sent!r}"
        )
    mask_index = input_seq.index(tokenizer.mask_token_id)

    input_tensor = torch.tensor([input_seq])

    # Only the predictions are needed here, so skip building the autograd graph.
    with torch.no_grad():
        token_logits = model.forward(input_tensor, return_dict=True).logits

    # Probability distribution over the vocabulary at the mask position.
    mask_token_logits = token_logits[0, mask_index, :]
    mask_token_probs = torch.nn.functional.softmax(mask_token_logits, dim=0)

    # Top predictions for the non-occluded sentence.
    top_N = torch.topk(mask_token_probs, N, dim=0)
    top_N_tokens = top_N.indices.tolist()
    return input_seq, top_N_tokens, mask_index

In [4]:
def custom_forward(inputs, attention_mask=None, pos=0):
    """Run the module-level ``model`` and return the logits at one sequence position.

    Parameters
    ----------
    inputs
        Passed positionally to ``model.forward``.
    attention_mask : optional
        Forwarded unchanged to ``model.forward``.
    pos : int
        Index along the second axis of the logits to select.

    Returns
    -------
    The slice ``logits[:, pos, :]`` of the model output.
    """
    output = model.forward(inputs, return_dict=True, attention_mask=attention_mask)
    logits_at_pos = output.logits[:, pos, :]
    return logits_at_pos

def summarize_attributions(attributions):
    """Collapse the last axis of ``attributions`` by summation and normalise.

    The summed tensor has its leading axis dropped when that axis has size 1,
    and is then divided by its norm (computed over all remaining elements).
    """
    summed = attributions.sum(dim=-1)
    summed = summed.squeeze(0)
    return summed / torch.norm(summed)

def construct_whole_bert_embeddings(input_ids, ref_input_ids, interpretable_embedding=None):
    """Convert input ids and reference (baseline) ids to embeddings.

    Parameters
    ----------
    input_ids
        Token ids of the sentence being explained.
    ref_input_ids
        Token ids of the baseline sequence.
    interpretable_embedding : optional
        Object exposing ``indices_to_embeddings`` (e.g. the value returned by
        ``configure_interpretable_embedding_layer``). When omitted, a
        module-level variable named ``interpretable_embedding`` is used, which
        is what the original implementation relied on implicitly.

    Returns
    -------
    tuple
        ``(input_embeddings, ref_input_embeddings)``.

    Raises
    ------
    NameError
        If no embedding layer is passed and no global ``interpretable_embedding`` exists.
    """
    if interpretable_embedding is None:
        # Backward-compatible fallback: the function used to read this name
        # from module scope, where no cell of the notebook defines it.
        try:
            interpretable_embedding = globals()["interpretable_embedding"]
        except KeyError:
            raise NameError(
                "name 'interpretable_embedding' is not defined; pass it via the "
                "interpretable_embedding argument"
            ) from None

    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids)
    ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids)

    return input_embeddings, ref_input_embeddings


def run_attribution_model(input_seq, ref_token_id, top_N_tokens, mask_index, layer):
    """Compute layer integrated-gradients attributions for each target token.

    The module-level ``model`` is temporarily wrapped with Captum's
    interpretable embedding layer; the wrapper is always removed again, even
    when attribution fails.

    Parameters
    ----------
    input_seq : list[int]
        Token ids of the masked sentence.
    ref_token_id : int
        Token id used at every position of the baseline sequence.
    top_N_tokens : list[int]
        Vocabulary ids to attribute; the input is repeated once per target.
    mask_index : int
        Sequence position whose logits are explained (passed to
        ``custom_forward`` as ``pos``).
    layer : torch.nn.Module
        Model layer handed to ``LayerIntegratedGradients``.

    Returns
    -------
    torch.Tensor
        Summarised attributions, transposed when they are 2-D.
        # presumably (seq_len, n_targets) for several targets and (seq_len,)
        # for a single target — TODO confirm against the Captum output shape.

    Raises
    ------
    ValueError
        If ``layer`` is ``None``.
    """
    if layer is None:
        # Without this check the failure surfaces deep inside Captum as an
        # opaque error on NoneType.
        raise ValueError("run_attribution_model requires a model layer (got layer=None)")

    # NOTE(review): the embedding path is DistilBERT-specific; it must be
    # adapted if another checkpoint is loaded.
    interpretable_embedding = configure_interpretable_embedding_layer(model, 'distilbert.embeddings')
    try:
        attributor = LayerIntegratedGradients(custom_forward, layer)

        # One copy of the sentence per target token; a single baseline row made
        # up entirely of the reference token.
        input_tensor = torch.tensor([input_seq] * len(top_N_tokens))
        ref_tensor = torch.tensor([ref_token_id]).expand((1, len(input_seq)))
        interpretable_input_tensor = interpretable_embedding.indices_to_embeddings(input_tensor)
        ref_tensor = interpretable_embedding.indices_to_embeddings(ref_tensor)

        attention_mask = torch.ones_like(input_tensor)
        attributions = attributor.attribute(
            inputs=interpretable_input_tensor,
            baselines=ref_tensor,
            additional_forward_args=(attention_mask, mask_index),
            target=top_N_tokens,
        )
        attributions = summarize_attributions(attributions)
    finally:
        remove_interpretable_embedding_layer(model, interpretable_embedding)

    # `.T` on tensors that are not 2-D is deprecated in PyTorch; with a single
    # target the summary is 1-D and needs no transpose.
    if attributions.dim() == 2:
        return attributions.T
    return attributions

In [5]:
def build_dataframe(attributions, input_seq, top_N_tokens):
    """Turn an attribution tensor into a long-format DataFrame for plotting.

    Uses the module-level ``tokenizer`` to decode ids into labels.

    Parameters
    ----------
    attributions : torch.Tensor
        Attribution scores with one row per input token and one column per
        target token.
    input_seq : list[int]
        Token ids of the input sentence (row labels).
    top_N_tokens : list[int]
        Token ids of the targets (column labels).

    Returns
    -------
    pandas.DataFrame
        Columns ``index`` (input token, ordered categorical), ``variable``
        (target token, ordered categorical), ``value`` (score) and
        ``display_value`` (score formatted with two decimals).
    """
    input_labels = [tokenizer.decode([token]) for token in input_seq]
    target_labels = [tokenizer.decode([token]) for token in top_N_tokens]

    # Categories are the input tokens in reverse order. They must be unique, so
    # repeated tokens (e.g. "the ... the") are collapsed to one category;
    # without this, pandas raises on any sentence containing a duplicate token.
    ix = pd.CategoricalIndex(
        input_labels,
        categories=list(dict.fromkeys(reversed(input_labels))),
        ordered=True,
    )

    attr_df = (
        pd.DataFrame(
            attributions.detach().numpy(),
            columns=target_labels,
            index=ix,
        )
        .reset_index()
        .melt(id_vars=["index"])
    )
    attr_df['variable'] = pd.Categorical(
        attr_df['variable'],
        categories=target_labels,
        ordered=True,
    )
    attr_df['display_value'] = attr_df['value'].apply(lambda f: f"{f:.2f}")
    return attr_df
    

In [6]:
def create_plot(attr_df, mask_index, N):
    """Build a faceted bar chart of attribution scores, one facet per target token.

    Parameters
    ----------
    attr_df : pandas.DataFrame
        Frame with ``index``, ``variable``, ``value`` and ``display_value``
        columns, as produced by ``build_dataframe``.
    mask_index : int
        Currently unused; kept for interface compatibility.
    N : int
        Number of target tokens, used to choose a roughly square facet grid.

    Returns
    -------
    plotnine.ggplot
        The (not yet rendered) plot object.
    """
    # facet_wrap expects an integer column count; np.ceil returns a float.
    # Clamp to at least one column so N == 0 does not yield an empty grid.
    ncol = max(1, int(np.ceil(np.sqrt(N))))
    return (
        ggplot(attr_df, aes(x="index", y="value"))
        + geom_col(aes(fill="index", colour="index"))
        # Print each score in the middle of its bar.
        + geom_text(aes(y="value/2", label="display_value"), size=10)
        # Keep every token on the axis even if it has no data in a facet.
        + scale_x_discrete(drop=False)
        + facet_wrap("~variable", ncol=ncol)
        + coord_flip()
        + labs(
            x="target token",
            y="Captum contribution scores",
        )
        + theme(legend_position="none")
    )

In [7]:
def run_and_show(sentence, mask_tokens, reference=[tokenizer.unk_token_id], layer=None):
    """Compute attributions for a masked sentence and display them as a plot.

    This is ``run_no_show`` followed by plotting; the attribution work is
    delegated so the two entry points cannot drift apart.

    Parameters
    ----------
    sentence : str
        Sentence containing the tokenizer's mask token.
    mask_tokens : int or sequence
        Either the number of top predictions to explain, or an explicit
        sequence of target tokens (strings or vocabulary ids).
    reference : sequence of int
        Baseline token ids; attributions are averaged over all of them.
        The default list is created once at definition time and is never
        mutated here.
    layer : torch.nn.Module
        Model layer to attribute against.

    Returns
    -------
    tuple
        ``(df, mask_index)``: the long-format attribution frame and the
        position of the mask token.
    """
    # Number of target tokens, needed to lay out the facet grid.
    N = mask_tokens if isinstance(mask_tokens, int) else len(mask_tokens)

    df, mask_index = run_no_show(sentence, mask_tokens, reference=reference, layer=layer)

    plot = create_plot(df, mask_index, N)
    display(plot)
    return df, mask_index

In [8]:
def run_no_show(sentence, mask_tokens, reference=[tokenizer.unk_token_id], layer=None):
    """Compute attributions for a masked sentence without plotting.

    Parameters
    ----------
    sentence : str
        Sentence containing the tokenizer's mask token.
    mask_tokens : int, str or sequence
        An int asks for that many top predictions of the model; a sequence
        lists the target tokens explicitly, as strings or vocabulary ids.
        A single string is treated as one target token.
    reference : sequence of int
        Baseline token ids; one attribution run is made per id and the
        results are averaged. The default list is created once at definition
        time and is never mutated here.
    layer : torch.nn.Module
        Model layer to attribute against.

    Returns
    -------
    tuple
        ``(df, mask_index)``: the long-format attribution frame and the
        position of the mask token.

    Raises
    ------
    ValueError
        If ``mask_tokens`` is an empty sequence.
    """
    if isinstance(mask_tokens, int):
        # Let the model choose the targets: its top-N predictions for the mask.
        input_seq, mask_tokens, mask_index = prepare_data_for_sentence(sentence, mask_tokens)
    else:
        if isinstance(mask_tokens, str):
            # A bare string would otherwise be iterated character by character.
            mask_tokens = [mask_tokens]
        if len(mask_tokens) == 0:
            raise ValueError("mask_tokens must contain at least one target token")
        if isinstance(mask_tokens[0], str):
            mask_tokens = tokenizer.convert_tokens_to_ids(mask_tokens)
        # Only the token ids and mask position are needed; the predictions are ignored.
        input_seq, _, mask_index = prepare_data_for_sentence(sentence, len(mask_tokens))

    # One attribution run per baseline token, averaged element-wise.
    per_reference = [
        run_attribution_model(input_seq, ref, mask_tokens, mask_index, layer)
        for ref in reference
    ]
    attributions = torch.stack(per_reference).mean(dim=0)

    df = build_dataframe(attributions, input_seq, mask_tokens)
    return df, mask_index

In [170]:
from IPython.display import HTML
import re

def html_for_sentences(df, mask_index, groupings):
    """Render attribution scores as nested HTML, one table row pair per group.

    Every grouping column except the last becomes a nested, indented ``<div>``
    level. The last grouping column labels the rows of a table in which each
    token is coloured and sized by its ``rel_value``, with the raw ``value``
    printed underneath.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns named in ``groupings`` plus ``variable``,
        ``index``, ``value`` and ``rel_value``.
    mask_index : int
        Position of the mask token; that cell shows the target token
        (``variable``) underlined instead of the input token.
        # assumes the frame's index labels count tokens from 0 within each
        # variable group, as produced by concatenating build_dataframe
        # results without resetting the index — TODO confirm for other inputs
    groupings : sequence of str
        Column names to group by, outermost first. The sequence is not modified.

    Returns
    -------
    str
        HTML markup.
    """
    # Imported locally: the notebook binds the global name `html` to a string
    # in later cells, which would shadow a module-level `import html`.
    from html import escape

    # Work on a copy. The previous implementation popped from the caller's list,
    # which mutated the argument and made sibling recursive calls see a
    # shortened list (wrong nesting for three or more groupings).
    groupings = list(groupings)

    if len(groupings) > 1:
        group, remaining = groupings[0], groupings[1:]
        parts = ["<div>"]
        for g, gdf in df.groupby(group):
            parts.append(
                f"<div>{escape(str(group))}: {escape(str(g))}</div>"
                "<div style='padding-left:2em;'>"
            )
            parts.append(html_for_sentences(gdf, mask_index, remaining))
            parts.append("</div>")
        parts.append("</div>")
        return "".join(parts)

    parts = ["<table>"]
    for label, lvl in df.groupby(groupings[0]):
        # Materialised once: used for the token row and again for the value row.
        variable_groups = list(lvl.groupby('variable'))

        parts.append(f"<tr><th>{escape(str(label))}</th>")
        for v, grp in variable_groups:
            for i, row in grp.iterrows():
                # Colour: blue-ish for non-negative scores, magenta-ish for
                # negative ones; intensity, opacity and font size all grow
                # with the relative magnitude.
                pct = int(row['rel_value'] * 100)
                g = 10 + pct
                b = 150 + pct
                r = g if row['value'] >= 0 else b
                a = (row['rel_value'] / 2) + 0.3
                color = ",".join(map(str, [r, g, b, a]))
                size = (0.5 + (row['rel_value'] / 2) + 0.3) * 1.5

                if i % len(grp) == mask_index:
                    # Mask position: show the target token, underlined.
                    word = escape(str(row['variable']))
                    style = f"font-size: {size}em; text-decoration:underline; color: rgba({color});"
                else:
                    word = escape(row['index'].strip())
                    style = f"font-size: {size}em; color: rgba({color});"
                parts.append(f"<td><span style='{style}'>{word}</span></td>")

        # Second row: the raw scores, aligned under their tokens.
        parts.append("</tr><tr><th></th>")
        for v, grp in variable_groups:
            for i, row in grp.iterrows():
                parts.append(f"<td style='font-size: 0.7em'>{round(row['value'],3)}</td>")
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)


In [156]:
from numpy.random import default_rng

# Fixed seed so the random baseline tokens (and every attribution derived from
# them) are the same after a kernel restart.
SAMPLE_SEED = 0
rng = default_rng(SAMPLE_SEED)
# 20 random vocabulary ids, used further down as a multi-token baseline.
sample = rng.integers(model.config.vocab_size, size=20)
print(sample)

[18593 25119 18503  4291  7840 27513  8231  3386 23858 14233 20925  2194
 10027 14065  2072 16659 16280 19265 27845 12211]


In [157]:
# Show the randomly drawn baseline ids as vocabulary strings.
print(list(map(lambda token_id: tokenizer.decode([token_id]), sample)))

['artery', 'Telecom', '##bey', 'sector', 'fourteen', 'coordinating', '##EC', 'era', 'balancing', 'Waterloo', 'terminology', 'provide', 'offense', 'Audio', 'entire', '##claim', 'auditorium', 'Oracle', 'Evaluation', 'exam']


In [161]:
%%time
# Attribute the prediction "mouse" against every transformer layer, once per
# baseline: [UNK] only, [PAD] only, and the random token sample drawn above.
sentence = f"The cat from the neighbours chases a {tokenizer.mask_token}."
reference_sets = [[tokenizer.unk_token_id], [tokenizer.pad_token_id], sample]

sample_frames = []
for reference_ids in reference_sets:  # no longer rebinds the global `sample`
    layer_frames = []
    # Iterate over the layers directly instead of hard-coding their count.
    for i, layer in enumerate(model.distilbert.transformer.layer):
        df, mask_index = run_no_show(sentence, ["mouse"], reference=reference_ids, layer=layer)
        df['layer'] = f"layer {i+1}"
        # Magnitude relative to the strongest token within this layer/baseline.
        df['rel_value'] = (df['value'].abs() / df['value'].abs().max()).round(2)
        layer_frames.append(df)

    # Concatenate once per baseline rather than growing a frame inside the loop.
    # The per-frame index (0..n-1) is kept on purpose: html_for_sentences uses
    # it to locate the mask position.
    sample_df = pd.concat(layer_frames)
    sample_df['sample'] = "; ".join([tokenizer.decode([t]) for t in reference_ids])
    sample_frames.append(sample_df)

full_df = pd.concat(sample_frames)
        

CPU times: user 5min 55s, sys: 14.6 s, total: 6min 9s
Wall time: 1min 32s


In [172]:
# Render the attributions grouped by layer first, with one table row per
# baseline ("sample") inside each layer.
html = html_for_sentences(full_df, mask_index, ["layer", "sample"])
display(HTML(html))

0,1,2,3,4,5,6,7,8,9,10,11,12
[PAD],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.215,0.046,0.485,0.063,-0.006,0.039,0.142,0.109,0.208,0.221,0.283,0.712
[UNK],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.169,0.031,0.269,0.031,-0.002,0.062,0.169,0.083,0.21,0.401,0.16,0.792
artery; Telecom; ##bey; sector; fourteen; coordinating; ##EC; era; balancing; Waterloo; terminology; provide; offense; Audio; entire; ##claim; auditorium; Oracle; Evaluation; exam,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.276,0.115,0.315,0.031,-0.028,0.075,0.352,0.117,0.187,0.264,0.15,0.599

0,1,2,3,4,5,6,7,8,9,10,11,12
[PAD],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.108,0.069,0.268,0.071,-0.014,0.001,0.105,0.041,0.147,0.411,0.179,0.819
[UNK],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.157,-0.008,0.235,0.023,-0.004,0.01,0.168,0.066,0.209,0.48,0.166,0.765
artery; Telecom; ##bey; sector; fourteen; coordinating; ##EC; era; balancing; Waterloo; terminology; provide; offense; Audio; entire; ##claim; auditorium; Oracle; Evaluation; exam,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.117,0.056,0.267,0.015,-0.029,0.002,0.176,0.047,0.19,0.366,0.143,0.737

0,1,2,3,4,5,6,7,8,9,10,11,12
[PAD],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.039,0.026,0.285,0.079,0.052,0.024,0.401,0.075,0.279,0.541,0.037,0.607
[UNK],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.086,0.089,0.334,0.063,0.056,0.007,0.378,0.066,0.311,0.668,0.085,0.411
artery; Telecom; ##bey; sector; fourteen; coordinating; ##EC; era; balancing; Waterloo; terminology; provide; offense; Audio; entire; ##claim; auditorium; Oracle; Evaluation; exam,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.044,0.031,0.343,0.028,0.018,-0.019,0.341,0.034,0.238,0.406,0.016,0.655

0,1,2,3,4,5,6,7,8,9,10,11,12
[PAD],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.086,0.063,0.437,-0.006,0.034,0.032,0.266,0.077,0.091,0.785,0.012,0.305
[UNK],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.09,0.049,0.439,-0.029,0.008,-0.001,0.239,0.058,0.091,0.842,0.023,0.133
artery; Telecom; ##bey; sector; fourteen; coordinating; ##EC; era; balancing; Waterloo; terminology; provide; offense; Audio; entire; ##claim; auditorium; Oracle; Evaluation; exam,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.064,0.043,0.573,-0.048,0.013,0.023,0.251,0.066,0.109,0.662,0.002,0.237

0,1,2,3,4,5,6,7,8,9,10,11,12
[PAD],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.094,0.073,0.498,0.024,0.013,0.009,0.081,0.035,0.054,0.835,0.002,-0.169
[UNK],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.074,0.055,0.403,0.016,0.013,0.007,0.084,0.042,0.059,0.895,-0.0,-0.127
artery; Telecom; ##bey; sector; fourteen; coordinating; ##EC; era; balancing; Waterloo; terminology; provide; offense; Audio; entire; ##claim; auditorium; Oracle; Evaluation; exam,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.069,0.054,0.486,0.013,0.003,-0.004,0.09,0.042,0.061,0.828,0.003,-0.136

0,1,2,3,4,5,6,7,8,9,10,11,12
[PAD],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
[UNK],[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
artery; Telecom; ##bey; sector; fourteen; coordinating; ##EC; era; balancing; Waterloo; terminology; provide; offense; Audio; entire; ##claim; auditorium; Oracle; Evaluation; exam,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [171]:
# Same data with the grouping reversed: one block per baseline ("sample"),
# with one table row per layer inside it.
html = html_for_sentences(full_df, mask_index, ["sample", "layer"])
display(HTML(html))

0,1,2,3,4,5,6,7,8,9,10,11,12
layer 1,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.215,0.046,0.485,0.063,-0.006,0.039,0.142,0.109,0.208,0.221,0.283,0.712
layer 2,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.108,0.069,0.268,0.071,-0.014,0.001,0.105,0.041,0.147,0.411,0.179,0.819
layer 3,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.039,0.026,0.285,0.079,0.052,0.024,0.401,0.075,0.279,0.541,0.037,0.607
layer 4,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.086,0.063,0.437,-0.006,0.034,0.032,0.266,0.077,0.091,0.785,0.012,0.305
layer 5,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.094,0.073,0.498,0.024,0.013,0.009,0.081,0.035,0.054,0.835,0.002,-0.169

0,1,2,3,4,5,6,7,8,9,10,11,12
layer 1,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.169,0.031,0.269,0.031,-0.002,0.062,0.169,0.083,0.21,0.401,0.16,0.792
layer 2,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.157,-0.008,0.235,0.023,-0.004,0.01,0.168,0.066,0.209,0.48,0.166,0.765
layer 3,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.086,0.089,0.334,0.063,0.056,0.007,0.378,0.066,0.311,0.668,0.085,0.411
layer 4,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.09,0.049,0.439,-0.029,0.008,-0.001,0.239,0.058,0.091,0.842,0.023,0.133
layer 5,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.074,0.055,0.403,0.016,0.013,0.007,0.084,0.042,0.059,0.895,-0.0,-0.127

0,1,2,3,4,5,6,7,8,9,10,11,12
layer 1,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.276,0.115,0.315,0.031,-0.028,0.075,0.352,0.117,0.187,0.264,0.15,0.599
layer 2,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.117,0.056,0.267,0.015,-0.029,0.002,0.176,0.047,0.19,0.366,0.143,0.737
layer 3,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.044,0.031,0.343,0.028,0.018,-0.019,0.341,0.034,0.238,0.406,0.016,0.655
layer 4,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.064,0.043,0.573,-0.048,0.013,0.023,0.251,0.066,0.109,0.662,0.002,0.237
layer 5,[CLS],The,cat,from,the,neighbours,chase,##s,a,mouse,.,[SEP]
,0.069,0.054,0.486,0.013,0.003,-0.004,0.09,0.042,0.061,0.828,0.003,-0.136
