In [1]:
from captum.attr import IntegratedGradients, LayerIntegratedGradients
import torch
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
from tqdm import tqdm

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')
model = GPT2LMHeadModel.from_pretrained('gpt2-xl', pad_token_id=tokenizer.eos_token_id)
start_text = "Megan Rapinoe plays the sport of"
start_text = "The Big Bang Theory premieres on"

encoded_input = tokenizer(start_text, return_tensors='pt')
start_length = encoded_input.input_ids.shape[-1]

generated_text_length = 1
max_length = start_length + generated_text_length

generate = model.generate(**encoded_input, max_length=max_length)

encoded_ids = generate

attributions_list = list()

def untuple(x):
    if isinstance(x, tuple):
        return x[0]
    return x
    
def model_forward_wrapper(*args, **kwargs):
    output = model(args[0].long(), **kwargs, return_dict=False)[0]
    return output

#IG = LayerIntegratedGradients(model_forward_wrapper, model.transformer.wte)
IG = LayerIntegratedGradients(model_forward_wrapper, model.transformer.h[15])

def sum_embedding_attributions(attributions):
    return torch.sum(untuple(attributions) ** 2, dim=-1).sqrt().squeeze()

for i in tqdm(range(0, generated_text_length)):
    length = start_length + i
    baseline = torch.tensor([[tokenizer.unk_token_id] * length]).long()
    target = (length - 1, encoded_ids[0, length-1].item())
    input_ids = encoded_ids[:, :length]
    attributions = IG.attribute(
        inputs = input_ids,
        baselines = baseline,
        target = target
    )
    sum_attr = sum_embedding_attributions(attributions)
    attributions_list.append(sum_attr)

for attribution in attributions_list:
    print(attribution)

  0%|                                                                         | 0/1 [00:01<?, ?it/s]


AssertionError: Input and baseline must have the same dimensions, baseline has 3 features whereas input has 2.

In [None]:
from matplotlib import pyplot as plt

plt.bar([tokenizer.decode(tok) for tok in input_ids[0]], attribution)
for tok, att in zip(input_ids[0], attribution):
    print(tokenizer.decode(tok), att.item())

In [None]:
layer_attributions = {}

for kind in ['', 'mlp', 'attn']:
    layer_attributions[kind] = []
    for layernum in tqdm(range(0, 48)):
        IG = LayerIntegratedGradients(model_forward_wrapper,
                    model.transformer.h[layernum] if not kind
                    else model.transformer.h[layernum].mlp if kind == 'mlp'
                    else model.transformer.h[layernum].mlp.c_fc if kind == 'mlp_fc'
                    else model.transformer.h[layernum].attn)

        length = start_length
        baseline = torch.tensor([[tokenizer.unk_token_id] * length]).long()
        target = (length - 1, encoded_ids[0, length-1].item())
        input_ids = encoded_ids[:, :length]
        attributions = IG.attribute(
            inputs = input_ids,
            baselines = baseline,
            target = target
        )
        sum_attr = sum_embedding_attributions(attributions)
        layer_attributions[kind].append(sum_attr)
    

In [None]:
def plot_trace_heatmap(result, savepdf=None, title=None, xlabel=None):
    differences = result["scores"]
    low_score = result["low_score"]
    kind = (
        None
        if (not result["kind"] or result["kind"] == "None")
        else str(result["kind"])
    )
    labels = list(result["input_tokens"])
    for i in range(*result["subject_range"]):
        labels[i] = labels[i] + "*"

    with plt.rc_context(rc={"font.family": "Times New Roman"}):
        fig, ax = plt.subplots(figsize=(3.5, 2), dpi=200)
        h = ax.pcolor(
            differences,
            cmap={None: "Purples", "None": "Purples", "mlp": "Greens", "mlp_fc": "Blues", "attn": "Reds"}[
                kind
            ],
            vmin=low_score,
        )
        ax.invert_yaxis()
        ax.set_yticks([0.5 + i for i in range(len(differences))])
        ax.set_xticks([0.5 + i for i in range(0, differences.shape[1] - 6, 5)])
        ax.set_xticklabels(list(range(0, differences.shape[1] - 6, 5)))
        ax.set_yticklabels(labels)
        if not kind:
            ax.set_title("Layer integrated gradient for hidden state")
            ax.set_xlabel("single layer within GPT-2-XL")
        else:
            kindname = "MLP" if kind == "mlp" else "MLP FC" if kind == "mlp_fc" else "Attn"
            ax.set_title(f"Layer integrated gradient on {kindname} output")
            ax.set_xlabel("single layer within GPT-2-XL")
        cb = plt.colorbar(h)
        if title is not None:
            ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        cb.ax.set_title(f"||IG||", y=-0.16, fontsize=10)
        if savepdf:
            os.makedirs(os.path.dirname(savepdf), exist_ok=True)
            plt.savefig(savepdf, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

for kind in ['', 'mlp', 'attn']: # ['mlp_fc']: #, 
    plot_trace_heatmap(dict(
        scores=torch.stack(layer_attributions[kind]).t(),
        low_score=0,
        answer='soccer',
        kind=kind,
        input_tokens=[tokenizer.decode(tok) for tok in input_ids[0]],
        subject_range=[0,0],
    ))

In [None]:
model.transformer.h[layernum].mlp.c_fc