# Captum

## Dataset Setup

In [1]:
import datasets

liar = datasets.load_dataset("liar")
full_liar = datasets.concatenate_datasets(
    [liar["train"], liar["test"], liar["validation"]]
)
full_liar

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
    num_rows: 12836
})

## Setup

In [2]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz

In [3]:
falcon = "tiiuae/falcon-7b-instruct"
llama = "meta-llama/Llama-2-7b-chat-hf"
mistral = "mistralai/Mistral-7B-Instruct-v0.2"
orca = "microsoft/Orca-2-7b"

In [4]:
model_name = orca

In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=config, device_map="auto"
)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Utils

In [6]:
n_examples = 0

In [7]:
from typing import Dict

LABEL_MAP = {
    0: "E",  # 0 : False
    1: "C",  # 1 : Half True
    2: "B",  # 2 : Mostly True
    3: "A",  # 3 : True
    4: "D",  # 4 : Barely True
    5: "F",  # 5 : Pants on Fire
}


def was_correct(decoded: str, entry: Dict[str, int]) -> bool:
    return LABEL_MAP[entry["label"]] in decoded

In [8]:
import random
from typing import Dict

random.seed(1770)
entries = random.choices(list(range(len(full_liar))), k=n_examples)


def to_zero_shot_prompt(entry: Dict[str, str]) -> str:
    speaker = entry["speaker"].replace("-", " ").title()
    statement = entry["statement"].lstrip("Says ")

    prompt = f"""Please select the option that most closely describes the following claim by {speaker}:\n{statement}\n\nA) True\nB) Mostly True\nC) Half True\nD) Barely True\nE) False\nF) Pants on Fire (absurd lie)\n\nChoice: ("""
    return prompt


def to_n_shot_prompt(n: int, entry: Dict[str, str]) -> str:
    examples = ""
    for i in range(n):
        examples += (
            to_zero_shot_prompt(full_liar[entries[i]])
            + LABEL_MAP[full_liar[entries[i]]["label"]]
            + "\n\n"
        )
    prompt = to_zero_shot_prompt(entry)
    return examples + prompt

In [9]:
print(
    tokenizer.batch_decode(
        model.generate(
            tokenizer(
                to_zero_shot_prompt(full_liar[0]), return_tensors="pt"
            ).input_ids.cuda(),
            max_new_tokens=200,
        )
    )[0]
)

<s> Please select the option that most closely describes the following claim by Dwayne Bohac:
the Annies List political group supports third-trimester abortions on demand.

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (A) True</s>


In [10]:
vocab = tokenizer.vocab
label_tokens = {
    "A": [],
    "B": [],
    "C": [],
    "D": [],
    "E": [],
    "F": [],
}
for char, idx in vocab.items():
    if char in label_tokens:
        label_tokens[char].append(idx)
label_tokens

{'A': [29909],
 'B': [29933],
 'C': [29907],
 'D': [29928],
 'E': [29923],
 'F': [29943]}

In [11]:
label_tokens = {char: idxs[0] for char, idxs in label_tokens.items()}
label_tokens

{'A': 29909, 'B': 29933, 'C': 29907, 'D': 29928, 'E': 29923, 'F': 29943}

In [12]:
prompt = to_n_shot_prompt(n_examples, full_liar[0])

In [13]:
# get tokens for later
tokens = tokenizer(prompt, return_tensors="pt").input_ids
print(tokens.shape)

torch.Size([1, 80])


In [14]:
tokens

tensor([[    1,  3529,  1831,   278,  2984,   393,  1556, 16467, 16612,   278,
          1494,  5995,   491,   360,  1582,   484, 17966,   562, 29901,    13,
          1552,  8081,   583,  2391,  8604,  2318, 11286,  4654, 29899, 15450,
          4156, 27450,  1080,   373,  9667, 29889,    13,    13, 29909, 29897,
          5852,    13, 29933, 29897,  7849,   368,  5852,    13, 29907, 29897,
         28144,  5852,    13, 29928, 29897,   350,   598,   368,  5852,    13,
         29923, 29897,  7700,    13, 29943, 29897,   349,  1934,   373,  6438,
           313,  6897, 18245,  3804, 29897,    13,    13, 29620, 29901,   313]])

## Integrated Gradients

In [15]:
def softmax_results(inputs: torch.Tensor):
    result = model(inputs.cuda()).logits
    return torch.nn.functional.softmax(result[:, -1], dim=-1).cpu()
def softmax_results_embeds(embds: torch.Tensor):
    result = model(inputs_embeds=embds.cuda()).logits
    return torch.nn.functional.softmax(result[:, -1], dim=-1).cpu()

In [16]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32003, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm

In [17]:
# replace the normal pytorch embeddings (which only take ints) to interpretable embeddings
# (which are compatible with the float inputs that integratedgradients gives)
if hasattr(model, "model"):
    interpretable_emb = attr.configure_interpretable_embedding_layer(
        model.model, "embed_tokens"
    )
elif hasattr(model, "transformer"):
    interpretable_emb = attr.configure_interpretable_embedding_layer(
        model.transformer, "word_embeddings"
    )
else:
    print("What happened")
model



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): InterpretableEmbeddingBase(
      (embedding): Embedding(32003, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attenti

In [18]:
input_embs = interpretable_emb.indices_to_embeddings(tokens).cpu()
print(input_embs.device)
input_embs.shape

cpu


torch.Size([1, 80, 4096])

In [19]:
baselines = torch.zeros_like(input_embs).cpu()
print(baselines.device)
baselines.shape

cpu


torch.Size([1, 80, 4096])

In [20]:
ig = attr.IntegratedGradients(softmax_results_embeds)

In [21]:
from tqdm import tqdm

attributions = {}
for char, label_token in tqdm(label_tokens.items()):
    attributions[char] = ig.attribute(
        input_embs,
        baselines=baselines,
        target=label_token,
        n_steps=192,
        internal_batch_size=2,
        return_convergence_delta=True,
    )

100%|█████████████████████████████████████████████| 6/6 [03:20<00:00, 33.45s/it]


## Visualize Attributions

In [22]:
attributions["A"][0].shape

torch.Size([1, 80, 4096])

In [23]:
def summarize_attributions(attributions):
    with torch.no_grad():
        attributions = attributions.sum(dim=-1).squeeze(0)
        print((attributions / torch.norm(attributions)).sum())
        return attributions

In [24]:
summarized_attributions = {
    char: summarize_attributions(attribution)
    for char, (attribution, _) in attributions.items()
}
summarized_attributions["A"].shape

tensor(0.6762, dtype=torch.float64)
tensor(0.1900, dtype=torch.float64)
tensor(-0.1262, dtype=torch.float64)
tensor(0.0530, dtype=torch.float64)
tensor(0.6956, dtype=torch.float64)
tensor(0.3716, dtype=torch.float64)


torch.Size([80])

In [25]:
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = list(map(tokenizer.decode, all_tokens))
print(all_tokens)

torch.Size([80])
['<s>', 'Please', 'select', 'the', 'option', 'that', 'most', 'closely', 'describes', 'the', 'following', 'claim', 'by', 'D', 'way', 'ne', 'Boh', 'ac', ':', '\n', 'the', 'Ann', 'ies', 'List', 'political', 'group', 'supports', 'third', '-', 'trim', 'ester', 'abort', 'ions', 'on', 'demand', '.', '\n', '\n', 'A', ')', 'True', '\n', 'B', ')', 'Most', 'ly', 'True', '\n', 'C', ')', 'Half', 'True', '\n', 'D', ')', 'B', 'are', 'ly', 'True', '\n', 'E', ')', 'False', '\n', 'F', ')', 'P', 'ants', 'on', 'Fire', '(', 'abs', 'urd', 'lie', ')', '\n', '\n', 'Choice', ':', '(']


In [26]:
# remove the interpretable embedding layer so we can get regular predictions
if hasattr(model, "model"):
    attr.remove_interpretable_embedding_layer(model.model, interpretable_emb)
else:
    attr.remove_interpretable_embedding_layer(model.transformer, interpretable_emb)
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32003])

In [27]:
MARGIN_OF_ERROR = 0.1  # off by no more than 10 percent
for char in summarized_attributions.keys():
    print("for", char)
    if (
        torch.abs(
            (summarized_attributions[char].sum() - predictions[0, label_tokens[char]])
        )
        >= MARGIN_OF_ERROR
    ):
        print("we are off!!")
        print("we should be getting somewhere near", predictions[0, label_tokens[char]])
        print("instead, we get", summarized_attributions[char].sum())
    else:
        print("we are pretty close!")
        print(
            "got",
            summarized_attributions[char].sum(),
            "instead of",
            predictions[0, label_tokens[char]],
        )

for A
we are pretty close!
got tensor(0.2962, dtype=torch.float64) instead of tensor(0.2403)
for B
we are pretty close!
got tensor(0.1822, dtype=torch.float64) instead of tensor(0.1758)
for C
we are off!!
we should be getting somewhere near tensor(0.1109)
instead, we get tensor(-0.0368, dtype=torch.float64)
for D
we are off!!
we should be getting somewhere near tensor(0.1539)
instead, we get tensor(0.0402, dtype=torch.float64)
for E
we are pretty close!
got tensor(0.1340, dtype=torch.float64) instead of tensor(0.1348)
for F
we are pretty close!
got tensor(0.1986, dtype=torch.float64) instead of tensor(0.1306)


In [28]:
from dataclasses import dataclass
from typing import Any


@dataclass
class CustomDataRecord:
    word_attributions: Any
    pred_prob: torch.Tensor
    pred_class: str
    attr_class: str
    attr_prob: torch.Tensor
    attr_score: Any
    raw_input_ids: Any | list[str]
    convergence_delta: Any


attr_vis = [
    CustomDataRecord(
        summarized_attributions[char] * 20,  # word attributions
        predictions[0].max(),  # predicted probability
        tokenizer.decode(torch.argmax(predictions[0])),  # predicted class
        char,  # attr class
        predictions[0, label_tokens[char]],  # attr probability
        summarized_attributions[char].sum(),  # attr score
        all_tokens,  # raw input ids
        attributions[char][1],  # convergence delta
    )
    for char in label_tokens.keys()
]

In [29]:
from IPython.display import HTML
def visualize_text(
    datarecords: list[CustomDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    dom = ["<table width: 100%>"]
    rows = [
        "<tr><th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Convergence Delta</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.attr_class, datarecord.attr_prob
                        )
                    ),
                    viz.format_classname(
                        "{0:.2f}".format(datarecord.convergence_delta.item())
                    ),
                    viz.format_classname("{0:.2f}".format(datarecord.attr_score)),
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    "<tr>",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [30]:
html = visualize_text(attr_vis)
print("Results")

Predicted Label,Attribution Label,Convergence Delta,Attribution Score,Word Importance
A (0.24),A (0.24),0.06,0.3,#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
A (0.24),B (0.18),0.01,0.18,#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
A (0.24),C (0.11),-0.15,-0.04,#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
A (0.24),D (0.15),-0.11,0.04,#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,
A (0.24),E (0.13),-0.0,0.13,#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,,,


Results


In [31]:
open(f"{model_name[model_name.index('/')+1:]}.html", "w").write(html.data)

78683