# Integrated Gradients Explainability

## Dataset Setup

In [1]:
import os
# Restrict this process to GPU 0. This cell runs before torch is imported
# further down in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
import datasets

# Load the LIAR dataset and merge its three splits (train, test, validation,
# in that order) into a single pool of claims.
liar = datasets.load_dataset("liar")
full_liar = datasets.concatenate_datasets(
    [liar[split_name] for split_name in ("train", "test", "validation")]
)
full_liar

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
    num_rows: 12836
})

## Setup

In [3]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz

In [4]:
# Checkpoint identifiers this notebook can be pointed at; one of them is
# selected as `model_name` and passed to `from_pretrained`.
falcon = "tiiuae/falcon-7b-instruct"
llama = "meta-llama/Llama-2-7b-chat-hf"
mistral = "mistralai/Mistral-7B-Instruct-v0.2"
orca = "microsoft/Orca-2-7b"

In [5]:
# Checkpoint used for the rest of the notebook (tokenizer, model, output file name).
model_name = mistral

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Quantization settings handed to `from_pretrained`: load weights in 4-bit,
# compute in bfloat16.
config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# Tokenizer and causal LM are both loaded from the checkpoint named by `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=config, # device_map="auto"
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed.


## Utils

In [7]:
# Number of solved few-shot examples prepended to the prompt (0 = zero-shot).
# Also sizes the `entries` pool that `to_n_shot_prompt` draws from.
n_examples = 0

In [8]:
from typing import Dict

# Maps the integer stored in an entry's "label" field to the answer letter
# used for that verdict in the multiple-choice prompt.
LABEL_MAP = {
    0: "E",  # False
    1: "C",  # Half True
    2: "B",  # Mostly True
    3: "A",  # True
    4: "D",  # Barely True
    5: "F",  # Pants on Fire
}


def was_correct(decoded: str, entry: Dict[str, int]) -> bool:
    """Report whether the expected answer letter occurs in the model output.

    Args:
        decoded: Text produced by the model.
        entry: Dataset row; only its "label" field is read.

    Returns:
        True when the letter mapped from ``entry["label"]`` is a substring of
        ``decoded``.
    """
    # NOTE(review): this is a plain substring test, so an answer such as
    # "E) False" also contains "F" -- confirm callers pass only the letter.
    expected_letter = LABEL_MAP[entry["label"]]
    return expected_letter in decoded

In [9]:
import random
from typing import Dict

# Indices of the rows used as solved few-shot examples. Drawn WITHOUT
# replacement (random.choices samples with replacement and could hand the
# model the same example more than once).
# NOTE(review): no random seed is set in this notebook, so the pool differs
# between runs.
entries = random.sample(range(len(full_liar)), k=n_examples)


def to_zero_shot_prompt(entry: Dict[str, str]) -> str:
    """Build the multiple-choice prompt for a single claim.

    Args:
        entry: Dataset row; the "speaker" and "statement" fields are read.

    Returns:
        The prompt text, ending with ``Choice: (`` so that the next generated
        token is expected to be the answer letter.
    """
    # Speaker slugs look like "marco-rubio"; render them as "Marco Rubio".
    speaker = entry["speaker"].replace("-", " ").title()

    # Drop everything up to and including the first "Says " so the claim reads
    # naturally after "claim by <speaker>:". Statements without that marker are
    # used unchanged. An explicit find() replaces the previous bare `except:`,
    # which hid unrelated errors as well as the "marker not found" case.
    statement = entry["statement"]
    marker_at = statement.find("Says ")
    if marker_at != -1:
        statement = statement[marker_at + 5:]

    prompt = f"""Please select the option that most closely describes the following claim by {speaker}:\n{statement}\n\nA) True\nB) Mostly True\nC) Half True\nD) Barely True\nE) False\nF) Pants on Fire (absurd lie)\n\nChoice: ("""
    return prompt


def to_n_shot_prompt(n: int, entry: Dict[str, str]) -> str:
    """Prefix the prompt for ``entry`` with ``n`` solved examples.

    The examples are the first ``n`` rows referenced by the module-level
    ``entries`` index list; each one is rendered with ``to_zero_shot_prompt``
    followed by its answer letter and a blank line.

    Args:
        n: Number of solved examples to prepend.
        entry: Dataset row to build the final (unanswered) prompt for.

    Returns:
        The solved examples followed by the unanswered prompt for ``entry``.
    """
    solved_rows = [full_liar[entries[position]] for position in range(n)]
    solved_text = "".join(
        to_zero_shot_prompt(row) + LABEL_MAP[row["label"]] + "\n\n"
        for row in solved_rows
    )
    return solved_text + to_zero_shot_prompt(entry)

In [10]:
# Collect, for every answer letter, the ids of vocabulary entries whose text
# is exactly that letter.
vocab = tokenizer.vocab
label_tokens = {letter: [] for letter in "ABCDEF"}
for piece, token_id in vocab.items():
    if piece in label_tokens:
        label_tokens[piece].append(token_id)
label_tokens

{'A': [28741],
 'B': [28760],
 'C': [28743],
 'D': [28757],
 'E': [28749],
 'F': [28765]}

In [11]:
# Collapse each letter's candidate list to a single token id.
# Fail with a clear message if the tokenizer had no exact match for a letter
# (instead of a bare IndexError), and leave already-collapsed values alone so
# re-running this cell does not break.
if any(isinstance(idxs, list) and not idxs for idxs in label_tokens.values()):
    raise ValueError(
        "tokenizer vocabulary has no exact single-token entry for at least "
        "one of the answer letters A-F"
    )
label_tokens = {
    char: idxs[0] if isinstance(idxs, list) else idxs
    for char, idxs in label_tokens.items()
}
label_tokens

{'A': 28741, 'B': 28760, 'C': 28743, 'D': 28757, 'E': 28749, 'F': 28765}

In [12]:
# Pick a random claim to explain. random.randrange excludes the upper bound;
# the previous random.randint(0, len(full_liar)) could return len(full_liar)
# itself and index one past the end of the dataset.
prompt = to_n_shot_prompt(n_examples, full_liar[random.randrange(len(full_liar))])

In [13]:
# Show the exact text that will be fed to the model.
print(prompt)

Please select the option that most closely describes the following claim by Marco Rubio:
If people work and make more money, they lose more in benefits than they would earn in salary.

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (


In [14]:
# Tokenize the prompt once as a PyTorch tensor; `tokens` is reused for the
# prediction, the embedding lookup and the per-token visualization below.
tokens = tokenizer(prompt, return_tensors="pt").input_ids
print(tokens.shape)

torch.Size([1, 81])


In [15]:
# Predicted next token: argmax over the logits at the last prompt position,
# computed without tracking gradients.
# NOTE(review): the argmax runs over the whole vocabulary; the attribution call
# later looks the decoded token up in `label_tokens`, which only has A-F.
with torch.no_grad():
    label = torch.argmax(model(tokens.cuda()).logits[:, -1])
label

tensor(28760, device='cuda:0')

In [16]:
# Replace the predicted token id with its decoded text; `label` is a string
# from here on (so this cell is not meant to be run twice in a row).
label = tokenizer.decode(label)
label

'B'

## Integrated Gradients

In [17]:
def softmax_results(inputs: torch.Tensor):
    """Next-token probabilities for a batch of token ids.

    Runs the module-level ``model`` on ``inputs`` (moved to CUDA), takes the
    logits at the last sequence position and returns their softmax over the
    last dimension as a CPU tensor.
    """
    logits = model(inputs.cuda()).logits
    last_position = logits[:, -1]
    return last_position.softmax(dim=-1).cpu()
def softmax_results_embeds(embds: torch.Tensor):
    """Next-token probabilities for a batch of input embeddings.

    Same as ``softmax_results`` but feeds ``embds`` (moved to CUDA) through the
    model's ``inputs_embeds`` argument; this is the forward function handed to
    Integrated Gradients.
    """
    logits = model(inputs_embeds=embds.cuda()).logits
    last_position = logits[:, -1]
    return last_position.softmax(dim=-1).cpu()

In [18]:
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )

In [19]:
# Replace the regular PyTorch embedding layer (which only accepts integer ids)
# with captum's interpretable embedding wrapper, which is compatible with the
# float inputs that Integrated Gradients produces.
# The embedding lives under a different attribute depending on the model
# architecture, so both known layouts are probed.
if hasattr(model, "model"):
    interpretable_emb = attr.configure_interpretable_embedding_layer(
        model.model, "embed_tokens"
    )
elif hasattr(model, "transformer"):
    interpretable_emb = attr.configure_interpretable_embedding_layer(
        model.transformer, "word_embeddings"
    )
else:
    # Stop here: continuing would leave `interpretable_emb` undefined and fail
    # later with an unrelated NameError.
    raise AttributeError(
        "Cannot locate the token embedding layer: expected the model to have "
        "either a `model` or a `transformer` attribute"
    )
model



MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): InterpretableEmbeddingBase(
      (embedding): Embedding(32000, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post

In [20]:
# Convert the prompt's token ids into embedding vectors through the
# interpretable wrapper and keep the result on the CPU; these embeddings are
# the input that gets attributed.
input_embs = interpretable_emb.indices_to_embeddings(tokens).cpu()
print(input_embs.device)
input_embs.shape

cpu


torch.Size([1, 81, 4096])

In [21]:
# Integrated Gradients baseline: an all-zero tensor with the same shape as the
# input embeddings.
baselines = torch.zeros_like(input_embs).cpu()
print(baselines.device)
baselines.shape

cpu


torch.Size([1, 81, 4096])

In [22]:
# Attribute through the embedding-based forward function, whose output is the
# softmax over the vocabulary at the last position.
ig = attr.IntegratedGradients(softmax_results_embeds)

In [23]:
# Attribute the probability of the predicted answer letter to each input
# embedding value. Because return_convergence_delta=True the result is indexed
# below as a pair: attributions[0] is the attribution tensor, attributions[1]
# the convergence delta.
attributions = ig.attribute(
    input_embs,
    baselines=baselines,
    target=label_tokens[label],  # vocabulary id of the predicted letter
    n_steps=512,
    internal_batch_size=2,
    return_convergence_delta=True,
)

## Visualize Attributions

In [24]:
# Sanity check: the attribution tensor should have the same shape as `input_embs`.
attributions[0].shape

torch.Size([1, 81, 4096])

In [25]:
@torch.no_grad()
def summarize_attributions(attributions):
    """Reduce per-dimension attributions to one score per token.

    Sums over the last dimension and drops a leading dimension of size 1.
    As a diagnostic, prints the sum of the scores after dividing them by their
    norm; the returned scores themselves are NOT normalized.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    normalized_total = (per_token / torch.norm(per_token)).sum()
    print(normalized_total)
    return per_token

In [26]:
# One attribution score per prompt token (attributions[0] is the attribution
# tensor; attributions[1] is the convergence delta).
summarized_attributions = summarize_attributions(attributions[0])

tensor(0.8090, dtype=torch.float64)


In [27]:
# Decode every prompt token id on its own so each attribution score can be
# shown next to the text piece it belongs to.
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = [tokenizer.decode(token_id) for token_id in all_tokens]
print(all_tokens)

torch.Size([81])
['<s>', 'Please', 'select', 'the', 'option', 'that', 'most', 'closely', 'describes', 'the', 'following', 'claim', 'by', 'Marco', 'Rub', 'io', ':', '\n', 'If', 'people', 'work', 'and', 'make', 'more', 'money', ',', 'they', 'lose', 'more', 'in', 'benefits', 'than', 'they', 'would', 'earn', 'in', 'salary', '.', '\n', '\n', 'A', ')', 'True', '\n', 'B', ')', 'Most', 'ly', 'True', '\n', 'C', ')', 'Half', 'True', '\n', 'D', ')', 'B', 'arely', 'True', '\n', 'E', ')', 'False', '\n', 'F', ')', 'P', 'ants', 'on', 'Fire', '(', 'abs', 'urd', 'lie', ')', '\n', '\n', 'Choice', ':', '(']


In [28]:
# Put the original embedding layer back so the model accepts token ids again,
# then compute the regular next-token probabilities for the prompt.
attr.remove_interpretable_embedding_layer(
    model.model if hasattr(model, "model") else model.transformer,
    interpretable_emb,
)
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32000])

In [29]:
# Largest tolerated ABSOLUTE gap between the summed attributions and the
# predicted probability (0.1 probability mass, not a relative 10 percent).
MARGIN_OF_ERROR = 0.1

# Sanity check: the attributions summed over all tokens should land close to
# the probability the model assigns to the attributed answer letter.
# NOTE(review): strictly the sum approximates F(input) - F(baseline); this
# check treats the baseline's probability as negligible -- confirm.
print("for", label)
if (
    torch.abs(
        (summarized_attributions.sum() - predictions[0, label_tokens[label]])
    )
    >= MARGIN_OF_ERROR
):
    print("we are off!!")
    print("we should be getting somewhere near", predictions[0, label_tokens[label]])
    # `label` is a string, so it cannot index the attribution tensor; report
    # the same total that the comparison above used.
    print("instead, we get", summarized_attributions.sum())
else:
    print("we are pretty close!")
    print(
        "got",
        summarized_attributions.sum(),
        "instead of",
        predictions[0, label_tokens[label]],
    )

for B
we are pretty close!
got tensor(0.6054, dtype=torch.float64) instead of tensor(0.5734)


In [30]:
from dataclasses import dataclass
from typing import Any


@dataclass
class CustomDataRecord:
    """One row of the attribution table rendered by ``visualize_text``."""

    # Per-token attribution scores (already rescaled for display).
    word_attributions: Any
    # Probability of the model's top prediction.
    pred_prob: torch.Tensor
    # Decoded text of the model's top prediction.
    pred_class: str
    # Answer letter the attributions were computed for.
    attr_class: str
    # Probability the model assigns to `attr_class`.
    attr_prob: torch.Tensor
    # Sum of the (unscaled) per-token attributions.
    attr_score: Any
    # Decoded prompt tokens, one entry per attribution score.
    raw_input_ids: Any | list[str]
    # Convergence delta reported by Integrated Gradients.
    convergence_delta: Any

# Rescale the scores so the largest one becomes 2 before they are rendered.
SCALE=2/summarized_attributions.max()

# Bundle everything the attribution table needs for this single prompt.
attr_vis = CustomDataRecord(
    word_attributions=summarized_attributions *SCALE,
    pred_prob=predictions[0].max(),
    pred_class=tokenizer.decode(torch.argmax(predictions[0])),
    attr_class=label,
    attr_prob=predictions[0, label_tokens[label]],
    attr_score=summarized_attributions.sum(),
    raw_input_ids=all_tokens,
    convergence_delta=attributions[1],
)

In [31]:
from IPython.display import HTML

def _get_color(attr):
    # clip values to prevent CSS errors (Values should be from [-1,1])
    attr = max(-2, min(2, attr))
    if attr > 0:
        hue = 120
        sat = 75
        lig = 100 - int(50 * attr)
    else:
        hue = 0
        sat = 75
        lig = 100 - int(-40 * attr)
    return "hsl({}, {}%, {}%)".format(hue, sat, lig)

def visualize_text(
    datarecords: list[CustomDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Render attribution records as an HTML table and display it.

    Each record becomes one table row with the predicted label, the attributed
    label, the convergence delta, the attribution score and the prompt tokens
    colored by their attribution.

    Args:
        datarecords: Records to render, one table row each.
        legend: When True, a Negative / Neutral / Positive color legend is
            emitted above the table.

    Returns:
        The ``HTML`` object that was displayed; its ``data`` attribute holds
        the markup.
    """
    # Header row (now properly closed with </tr>).
    rows = [
        "<tr><th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Convergence Delta</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th></tr>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.attr_class, datarecord.attr_prob
                        )
                    ),
                    viz.format_classname(
                        "{0:.2f}".format(datarecord.convergence_delta.item())
                    ),
                    viz.format_classname("{0:.2f}".format(datarecord.attr_score)),
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # Close the row; this used to be a second opening <tr>.
                    "</tr>",
                ]
            )
        )

    dom = []
    if legend:
        # The legend goes in front of the table: a <div> is not valid directly
        # inside <table>.
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, legend_label in zip(
            [-1, 0, 1], ["Negative", "Neutral", "Positive"]
        ):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                "border: 1px solid; background-color: "
                '{value}"></span> {label}  '.format(
                    value=_get_color(value), label=legend_label
                )
            )
        dom.append("</div>")

    # `width: 100%` has to live in a style attribute to have any effect.
    dom.append('<table style="width: 100%">')
    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [32]:
# Render the attribution table for the single record built above; the returned
# HTML object is kept so it can be written to disk afterwards.
html = visualize_text([attr_vis])
print("Results")

Predicted Label,Attribution Label,Convergence Delta,Attribution Score,Word Importance
B (0.57),B (0.57),0.03,0.61,"#s Please select the option that most closely describes the following claim by Marco Rub io : If people work and make more money , they lose more in benefits than they would earn in salary . A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : ("
,,,,


Results


In [33]:
#print(tokenizer.batch_decode(model.generate(tokens.cuda(), max_length=100, pad_token_id=tokenizer.eos_token_id))[0])

In [34]:
# explaining_prompt = "Please explain why the claim by Dennis Kucinich that \"We are giving almost $2 billion of taxpayer money to the junk food and fast food industries every year to make the (childhood obesity) epidemic worse.\" is mostly true.\n"
# print(tokenizer.batch_decode(model.generate(tokenizer(explaining_prompt, return_tensors='pt').input_ids.cuda(), max_length=200, pad_token_id=tokenizer.eos_token_id))[0])

In [35]:
# Save the rendered table as "<model>_0.html", where <model> is the part of
# `model_name` after the first "/" (or the whole name if there is no slash).
# The context manager closes the file handle, and the explicit encoding keeps
# non-ASCII tokens from failing on platforms whose default is not UTF-8.
with open(
    f"{model_name.split('/', 1)[-1]}_0.html", "w", encoding="utf-8"
) as out_file:
    out_file.write(html.data)

13927