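# Integrated Gradients Explainability

Attributes an instruction-tuned LLM's answer to a six-way LIAR truthfulness question back to the individual prompt tokens, using Captum's Integrated Gradients.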

## Dataset Setup

In [1]:
import os

# Expose only the first GPU to this kernel; set before any CUDA library is initialised so it takes effect.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
import datasets

# Pool the train/test/validation splits of LIAR into one dataset so that
# examples can be drawn from the whole corpus.
liar = datasets.load_dataset("liar")
full_liar = datasets.concatenate_datasets(
    [liar[split] for split in ("train", "test", "validation")]
)
full_liar

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
    num_rows: 12836
})

## Model Setup

In [3]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz

In [4]:
# Hugging Face Hub ids of the candidate 7B chat/instruct models.
falcon, llama, mistral, orca = (
    "tiiuae/falcon-7b-instruct",
    "meta-llama/Llama-2-7b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "microsoft/Orca-2-7b",
)

In [5]:
# Choose which candidate model to explain; the rest of the notebook refers to it only via model_name.
model_name = mistral

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# 4-bit quantised weights with bfloat16 compute, to reduce memory for a 7B model.
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE: device_map is not passed (the original left `device_map="auto"`
# commented out); add it to shard the model across several GPUs.
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=config)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Prompt and Label Utilities

In [7]:
# Number of labelled demonstrations prepended to each prompt (0 means zero-shot).
n_examples: int = 0

In [8]:
from typing import Dict

# LIAR integer label -> answer letter used in the multiple-choice prompt.
LABEL_MAP = dict(
    [
        (0, "E"),  # False
        (1, "C"),  # Half True
        (2, "B"),  # Mostly True
        (3, "A"),  # True
        (4, "D"),  # Barely True
        (5, "F"),  # Pants on Fire
    ]
)


def was_correct(decoded: str, entry: Dict[str, int]) -> bool:
    """Return True if the gold answer letter for ``entry`` occurs in ``decoded``.

    NOTE(review): this is a plain substring test. It is True whenever the
    letter appears anywhere in ``decoded`` (e.g. "C" in "Choice"), so it is only
    meaningful when ``decoded`` is just the model's short answer. Confirm with
    the caller.
    """
    expected_letter = LABEL_MAP[entry["label"]]
    return decoded.find(expected_letter) != -1

In [47]:
import random
from typing import Dict

# Indices of the few-shot demonstration rows. They are sampled *without*
# replacement so no demonstration is repeated within one prompt; the old
# random.choices could pick the same row several times.
# NOTE(review): no seed is set, so the chosen demonstrations differ on every run.
entries = random.sample(range(len(full_liar)), k=n_examples)


def to_zero_shot_prompt(entry: Dict[str, str]) -> str:
    """Format one LIAR row as a six-option multiple-choice prompt.

    Args:
        entry: a LIAR row; only its "speaker" and "statement" fields are read.

    Returns:
        The prompt text. It ends in "Choice: (" so the model's next token can be
        read as the answer letter.
    """
    # Hyphenated speaker slugs become title-cased names ("rick-perry" -> "Rick Perry").
    speaker = entry["speaker"].replace("-", " ").title()
    # Strip only a *leading* reporting "Says " (as in "Says X ...").
    # The previous `index("Says ")` matched the phrase anywhere and discarded
    # everything before it. Its bare `except:` also swallowed unrelated errors.
    statement = entry["statement"]
    if statement.startswith("Says "):
        statement = statement[len("Says "):]

    prompt = f"""Please select the option that most closely describes the following claim by {speaker}:\n{statement}\n\nA) True\nB) Mostly True\nC) Half True\nD) Barely True\nE) False\nF) Pants on Fire (absurd lie)\n\nChoice: ("""
    return prompt


def to_n_shot_prompt(n: int, entry: Dict[str, str]) -> str:
    """Build an n-shot prompt: ``n`` solved demonstrations followed by ``entry``.

    The demonstrations are the rows of ``full_liar`` at the first ``n`` indices
    of the global ``entries`` list. Each is rendered with its gold answer letter
    appended. Raises IndexError if ``n`` exceeds ``len(entries)``.
    """
    shots = []
    for i in range(n):
        example = full_liar[entries[i]]
        shots.append(to_zero_shot_prompt(example) + LABEL_MAP[example["label"]] + "\n\n")
    return "".join(shots) + to_zero_shot_prompt(entry)

In [48]:
# For each answer letter, collect every vocabulary id whose token string is
# exactly that letter.
vocab = tokenizer.vocab
label_tokens = {letter: [] for letter in "ABCDEF"}
for token_str, token_id in vocab.items():
    matches = label_tokens.get(token_str)
    if matches is not None:
        matches.append(token_id)
label_tokens

{'A': [28741],
 'B': [28760],
 'C': [28743],
 'D': [28757],
 'E': [28749],
 'F': [28765]}

In [49]:
def _single_label_token(letter, token_ids):
    """Reduce one letter's candidate token ids to a single id.

    An int is returned unchanged, so this cell can be re-run after label_tokens
    has already been collapsed; the old version raised TypeError on re-run.
    Raises ValueError if no vocabulary token matched the letter.
    """
    if isinstance(token_ids, int):
        return token_ids
    if not token_ids:
        raise ValueError(f"no vocabulary token is exactly {letter!r}")
    # If several ids match, keep the first one found (unchanged behaviour).
    return token_ids[0]


label_tokens = {
    letter: _single_label_token(letter, token_ids)
    for letter, token_ids in label_tokens.items()
}
label_tokens

{'A': 28741, 'B': 28760, 'C': 28743, 'D': 28757, 'E': 28749, 'F': 28765}

In [146]:
# random.randint includes its upper bound, so randint(0, len(full_liar)) could
# return len(full_liar) and raise IndexError; randrange excludes it.
prompt = to_n_shot_prompt(n_examples, full_liar[random.randrange(len(full_liar))])

In [147]:
# print (rather than a bare expression) so the prompt's newlines render.
print(prompt)

Please select the option that most closely describes the following claim by Rick Perry:
President Barack Obama delivered $2 billion to Brazil to help with offshore drilling projects.

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (


In [148]:
# Prompt token ids, kept for building embeddings and labelling tokens later.
tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
print(tokens.shape)

torch.Size([1, 80])


In [149]:
# Display the prompt's token ids (a batch of one sequence).
tokens

tensor([[    1,  5919,  5339,   272,  3551,   369,  1080, 11640, 13966,   272,
          2296,  3452,   486, 14613, 24150, 28747,    13, 19167,  1129,  3011,
           468, 11764, 11448,   429, 28750,  8737,   298, 13250,   298,  1316,
           395,   805, 27562,  1605,  8317,  7028, 28723,    13,    13, 28741,
         28731,  6110,    13, 28760, 28731,  4822,   346,  6110,    13, 28743,
         28731, 18994,  6110,    13, 28757, 28731,   365,  6672,  6110,    13,
         28749, 28731,  8250,    13, 28765, 28731,   367,  1549,   356,  8643,
           325,  4737, 12725,  4852, 28731,    13,    13, 28456, 28747,   325]])

## Integrated Gradients

In [150]:
def _last_position_probs(logits: torch.Tensor) -> torch.Tensor:
    """Softmax over the vocabulary at the final sequence position, moved to CPU."""
    return torch.nn.functional.softmax(logits[:, -1], dim=-1).cpu()


def softmax_results(inputs: torch.Tensor):
    """Next-token probabilities for a batch of token ids (the model runs on CUDA)."""
    return _last_position_probs(model(inputs.cuda()).logits)


def softmax_results_embeds(embds: torch.Tensor):
    """Next-token probabilities computed from input embeddings instead of token ids.

    This is the forward function handed to Integrated Gradients, which feeds
    float embeddings rather than integer ids.
    """
    return _last_position_probs(model(inputs_embeds=embds.cuda()).logits)

In [151]:
# Inspect the module tree to locate the token-embedding layer that gets wrapped next.
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )

In [152]:
# Replace the token-embedding layer, which only accepts integer ids, with
# Captum's InterpretableEmbeddingBase wrapper. The forward pass can then be
# driven by the float embeddings that Integrated Gradients interpolates.
# Mistral (see the module tree above) exposes `model.model.embed_tokens`; the
# `transformer.word_embeddings` branch presumably covers Falcon.
# NOTE: Captum refuses to wrap an already-wrapped layer, so run the removal
# cell before re-running this one.
if hasattr(model, "model"):
    interpretable_emb = attr.configure_interpretable_embedding_layer(
        model.model, "embed_tokens"
    )
elif hasattr(model, "transformer"):
    interpretable_emb = attr.configure_interpretable_embedding_layer(
        model.transformer, "word_embeddings"
    )
else:
    # Fail loudly here. Printing left `interpretable_emb` undefined, and the
    # next cell then died with a confusing NameError.
    raise AttributeError(
        f"don't know where the token embeddings live in {type(model).__name__}"
    )
model



MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): InterpretableEmbeddingBase(
      (embedding): Embedding(32000, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
       

In [153]:
# Embed the prompt through the wrapped layer, then keep the result on CPU.
input_embs = interpretable_emb.indices_to_embeddings(tokens)
input_embs = input_embs.cpu()
print(input_embs.device)
input_embs.shape

cpu


torch.Size([1, 80, 4096])

In [154]:
# All-zero embedding baseline with the input's shape and dtype, on CPU.
baselines = torch.zeros_like(input_embs, device="cpu")
print(baselines.device)
baselines.shape

cpu


torch.Size([1, 80, 4096])

In [155]:
# Integrated Gradients over the embedding-input forward function; each attribution target is a vocabulary id.
ig = attr.IntegratedGradients(softmax_results_embeds)

In [156]:
from tqdm import tqdm

# Integrated Gradients settings: number of interpolation steps, and how many of
# them are evaluated per forward pass (kept small to limit GPU memory).
IG_N_STEPS = 192
IG_INTERNAL_BATCH_SIZE = 2

# One IG run per answer letter: char -> (attributions, convergence_delta).
# A plain loop, so results finished before an interrupt are kept.
attributions = {}
for char, label_token in tqdm(label_tokens.items()):
    attributions[char] = ig.attribute(
        input_embs,
        baselines=baselines,
        target=label_token,
        n_steps=IG_N_STEPS,
        internal_batch_size=IG_INTERNAL_BATCH_SIZE,
        return_convergence_delta=True,
    )

100%|█████████████████████████████████████████████| 6/6 [03:19<00:00, 33.24s/it]


## Visualize Attributions

In [157]:
# Attributions for "A" have the same shape as the input embeddings.
attributions["A"][0].shape

torch.Size([1, 80, 4096])

In [158]:
def summarize_attributions(attributions, verbose: bool = True):
    """Collapse per-dimension attributions into one score per token.

    Sums over the last dimension, then squeezes a leading dimension of size 1,
    e.g. (1, seq, dim) -> (seq,).

    Args:
        attributions: attribution tensor returned by Integrated Gradients.
        verbose: if True (the default, matching the original behaviour), print
            the sum of the L2-normalised per-token scores as a quick diagnostic.

    Returns:
        The per-token summed attributions (not normalised).
    """
    with torch.no_grad():
        per_token = attributions.sum(dim=-1).squeeze(0)
        if verbose:
            norm = torch.norm(per_token)
            # An all-zero input has zero norm; print 0 instead of nan.
            print((per_token / norm).sum() if norm > 0 else per_token.sum())
        return per_token

In [159]:
# Per-token attribution scores for each answer letter; convergence deltas are dropped.
summarized_attributions = {}
for char, (attribution, _delta) in attributions.items():
    summarized_attributions[char] = summarize_attributions(attribution)
summarized_attributions["A"].shape

tensor(-0.3954, dtype=torch.float64)
tensor(0.1427, dtype=torch.float64)
tensor(0.0902, dtype=torch.float64)
tensor(0.4795, dtype=torch.float64)
tensor(2.0417, dtype=torch.float64)
tensor(1.0468, dtype=torch.float64)


torch.Size([80])

In [160]:
# Decode each prompt token on its own so every attribution position gets a label.
prompt_token_ids = tokens.squeeze(0)
print(prompt_token_ids.shape)
all_tokens = [tokenizer.decode(token_id) for token_id in prompt_token_ids]
print(all_tokens)

torch.Size([80])
['<s>', 'Please', 'select', 'the', 'option', 'that', 'most', 'closely', 'describes', 'the', 'following', 'claim', 'by', 'Rick', 'Perry', ':', '\n', 'Pres', 'ident', 'Bar', 'ack', 'Obama', 'delivered', '$', '2', 'billion', 'to', 'Brazil', 'to', 'help', 'with', 'off', 'shore', 'dr', 'illing', 'projects', '.', '\n', '\n', 'A', ')', 'True', '\n', 'B', ')', 'Most', 'ly', 'True', '\n', 'C', ')', 'Half', 'True', '\n', 'D', ')', 'B', 'arely', 'True', '\n', 'E', ')', 'False', '\n', 'F', ')', 'P', 'ants', 'on', 'Fire', '(', 'abs', 'urd', 'lie', ')', '\n', '\n', 'Choice', ':', '(']


In [161]:
# Restore the original embedding layer so the model accepts token ids again,
# then compute the model's actual next-token distribution for the prompt.
embedding_owner = model.model if hasattr(model, "model") else model.transformer
attr.remove_interpretable_embedding_layer(embedding_owner, interpretable_emb)
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32000])

In [162]:
# Absolute tolerance on the probability scale (0.1 = 10 percentage points).
# This is not a relative "10 percent" error, as the old comment claimed.
MARGIN_OF_ERROR = 0.1
# Completeness check: IG attributions should sum to F(input) - F(baseline).
# NOTE(review): this compares against F(input) only, which assumes the all-zero
# baseline gives ~0 probability to each letter. That is not verified here.
for char, per_token in summarized_attributions.items():
    expected = predictions[0, label_tokens[char]]
    actual = per_token.sum()
    print("for", char)
    if torch.abs(actual - expected) >= MARGIN_OF_ERROR:
        print("we are off!!")
        print("we should be getting somewhere near", expected)
        print("instead, we get", actual)
    else:
        print("we are pretty close!")
        print("got", actual, "instead of", expected)

for A
we are pretty close!
got tensor(-0.0375, dtype=torch.float64) instead of tensor(0.0024)
for B
we are pretty close!
got tensor(0.0431, dtype=torch.float64) instead of tensor(0.0523)
for C
we are pretty close!
got tensor(0.0139, dtype=torch.float64) instead of tensor(0.0340)
for D
we are pretty close!
got tensor(0.0783, dtype=torch.float64) instead of tensor(0.0962)
for E
we are pretty close!
got tensor(0.5781, dtype=torch.float64) instead of tensor(0.5626)
for F
we are pretty close!
got tensor(0.2649, dtype=torch.float64) instead of tensor(0.2496)


In [166]:
from dataclasses import dataclass
from typing import Any


@dataclass
class CustomDataRecord:
    """One row of the attribution table rendered by ``visualize_text``.

    Field order matters: records are constructed positionally.
    """

    word_attributions: Any  # per-token attribution scores used to colour the tokens
    pred_prob: torch.Tensor  # probability of the model's top predicted token
    pred_class: str  # decoded top predicted token
    attr_class: str  # answer letter the attributions were computed for
    attr_prob: torch.Tensor  # model probability of that answer letter's token
    attr_score: Any  # total attribution (sum over tokens)
    raw_input_ids: Any | list[str]  # decoded prompt tokens, one per position
    convergence_delta: Any  # IG convergence delta for this target

# Uniform display scale for the word attributions; it only affects colouring.
# Alternative per-letter normalisation: 1 / summarized_attributions[char].sum().
SCALE = {char: 10 for char in summarized_attributions}

attr_vis = [
    CustomDataRecord(
        word_attributions=summarized_attributions[char] * SCALE[char],
        pred_prob=predictions[0].max(),
        pred_class=tokenizer.decode(torch.argmax(predictions[0])),
        attr_class=char,
        attr_prob=predictions[0, label_tokens[char]],
        attr_score=summarized_attributions[char].sum(),
        raw_input_ids=all_tokens,
        convergence_delta=attributions[char][1],
    )
    for char in label_tokens
]

In [167]:
from IPython.display import HTML, display


def visualize_text(
    datarecords: list[CustomDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Render attribution records as an HTML table and display it.

    Each record becomes one row with these cells:
    - predicted label
    - attribution label
    - convergence delta
    - attribution score
    - the prompt tokens, coloured by ``viz.format_word_importances``

    Args:
        datarecords: records to render, one row each.
        legend: if True, add a negative/neutral/positive colour legend.

    Returns:
        The ``HTML`` object that was displayed.
    """
    # Valid attribute syntax; "<table width: 100%>" is not valid HTML.
    dom = ['<table width="100%">']
    rows = [
        "<tr><th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Convergence Delta</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th></tr>"  # the header row was previously left open
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.attr_class, datarecord.attr_prob
                        )
                    ),
                    viz.format_classname(
                        "{0:.2f}".format(datarecord.convergence_delta.item())
                    ),
                    viz.format_classname("{0:.2f}".format(datarecord.attr_score)),
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # Close the row. This used to emit another "<tr>", which put
                    # an empty row after every record.
                    "</tr>",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    # Import display explicitly instead of relying on the IPython builtin.
    display(html)

    return html

In [168]:
# visualize_text displays the table itself; keep the HTML object so it can be saved to disk later.
html = visualize_text(attr_vis)
print("Results")

Predicted Label,Attribution Label,Convergence Delta,Attribution Score,Word Importance
E (0.56),A (0.00),-0.04,-0.04,#s Please select the option that most closely describes the following claim by Rick Perry : Pres ident Bar ack Obama delivered $ 2 billion to Brazil to help with off shore dr illing projects . A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
E (0.56),B (0.05),-0.01,0.04,#s Please select the option that most closely describes the following claim by Rick Perry : Pres ident Bar ack Obama delivered $ 2 billion to Brazil to help with off shore dr illing projects . A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
E (0.56),C (0.03),-0.02,0.01,#s Please select the option that most closely describes the following claim by Rick Perry : Pres ident Bar ack Obama delivered $ 2 billion to Brazil to help with off shore dr illing projects . A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
E (0.56),D (0.10),-0.02,0.08,#s Please select the option that most closely describes the following claim by Rick Perry : Pres ident Bar ack Obama delivered $ 2 billion to Brazil to help with off shore dr illing projects . A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
E (0.56),E (0.56),0.02,0.58,#s Please select the option that most closely describes the following claim by Rick Perry : Pres ident Bar ack Obama delivered $ 2 billion to Brazil to help with off shore dr illing projects . A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,


Results


In [32]:
# Save the rendered table as "<model>.html", e.g. "Mistral-7B-Instruct-v0.2.html".
# split("/", 1) also handles a model_name with no "/" (index() raised ValueError).
# The context manager closes the file, and the explicit UTF-8 encoding keeps
# non-ASCII tokens writable on any platform.
with open(f"{model_name.split('/', 1)[-1]}.html", "w", encoding="utf-8") as html_file:
    html_file.write(html.data)

79724