# Captum

## Dataset Setup

In [1]:
import datasets

liar = datasets.load_dataset("liar")
full_liar = datasets.concatenate_datasets(
    [liar["train"], liar["test"], liar["validation"]]
)
full_liar

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
    num_rows: 12836
})

## Setup

In [2]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz
from captum._utils.models.linear_model import SkLearnLasso

In [3]:
falcon = "tiiuae/falcon-7b-instruct"
llama = "meta-llama/Llama-2-7b-chat-hf"
mistral = "mistralai/Mistral-7B-Instruct-v0.2"
orca = "microsoft/Orca-2-7b"

In [4]:
model_name = falcon

In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn.functional as F

config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=config, device_map="auto"
)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Utils

In [6]:
n_examples = 0

In [7]:
from typing import Dict

LABEL_MAP = {
    0: "E",  # 0 : False
    1: "C",  # 1 : Half True
    2: "B",  # 2 : Mostly True
    3: "A",  # 3 : True
    4: "D",  # 4 : Barely True
    5: "F",  # 5 : Pants on Fire
}


def was_correct(decoded: str, entry: Dict[str, int]) -> bool:
    return LABEL_MAP[entry["label"]] in decoded

In [8]:
import random
from typing import Dict

random.seed(1770)
entries = random.choices(list(range(len(full_liar))), k=n_examples)


def to_zero_shot_prompt(entry: Dict[str, str]) -> str:
    speaker = entry["speaker"].replace("-", " ").title()
    statement = entry["statement"].lstrip("Says ")

    prompt = f"""Please select the option that most closely describes the following claim by {speaker}:\n{statement}\n\nA) True\nB) Mostly True\nC) Half True\nD) Barely True\nE) False\nF) Pants on Fire (absurd lie)\n\nChoice: ("""
    return prompt


def to_n_shot_prompt(n: int, entry: Dict[str, str]) -> str:
    examples = ""
    for i in range(n):
        examples += (
            to_zero_shot_prompt(full_liar[entries[i]])
            + LABEL_MAP[full_liar[entries[i]]["label"]]
            + "\n\n"
        )
    prompt = to_zero_shot_prompt(entry)
    return examples + prompt

In [9]:
print(
    tokenizer.batch_decode(
        model.generate(
            tokenizer(
                to_zero_shot_prompt(full_liar[0]), return_tensors="pt"
            ).input_ids.cuda(),
            max_new_tokens=200,
        )
    )[0]
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call `model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations.


Please select the option that most closely describes the following claim by Dwayne Bohac:
the Annies List political group supports third-trimester abortions on demand.

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (E) False<|endoftext|>


In [10]:
vocab = tokenizer.vocab
label_tokens = {
    "A": [],
    "B": [],
    "C": [],
    "D": [],
    "E": [],
    "F": [],
}
for char, idx in vocab.items():
    if char in label_tokens:
        label_tokens[char].append(idx)
label_tokens

{'A': [44], 'B': [45], 'C': [46], 'D': [47], 'E': [48], 'F': [49]}

In [11]:
label_tokens = {char: idxs[0] for char, idxs in label_tokens.items()}
label_tokens

{'A': 44, 'B': 45, 'C': 46, 'D': 47, 'E': 48, 'F': 49}

In [12]:
prompt = to_n_shot_prompt(n_examples, full_liar[0])

In [13]:
# get tokens for later
tokens = tokenizer(prompt, return_tensors="pt").input_ids
print(tokens.shape)

torch.Size([1, 76])


In [14]:
tokens

tensor([[ 5000,  2523,   248,  2773,   325,   758,  8285, 11117,   248,  1863,
          2472,   431,   361, 64449, 40492,   310,    37,   193,  1410,  5142,
           424,  4702,  3709,  1408,  7244,  2914,    24,  1970, 49985, 38534,
           313,  3444,    25,   193,   193,    44,    20,  9595,   193,    45,
            20, 35623,  9595,   193,    46,    20, 16865,  9595,   193,    47,
            20, 39087,   309,  9595,   193,    48,    20, 15106,   193,    49,
            20, 41292,   313,  5848,   204,    19,  8676,  5479,  5504,    20,
           193,   193, 44595,    37,   204,    19]])

## LIME / SHAP

In [15]:
LIME = "LIME"
SHAP = "SHAP"
EXPERIMENT_TYPE = SHAP

In [21]:
def softmax_results(tokens: torch.Tensor):
    with torch.no_grad():
        if tokenizer.bos_token_id is not None:  # falcon's is None
            tokens[0, 0] = tokenizer.bos_token_id
        result = model(
            torch.where(tokens != 0, tokens, tokenizer.eos_token_id).cuda(),
            attention_mask=torch.where(tokens != 0, 1, 0).cuda()
        ).logits
        ret = torch.nn.functional.softmax(result[:, -1], dim=-1).cpu()
        assert not ret.isnan().any()
    return ret


def get_embeds(tokens: torch.Tensor):
    with torch.no_grad():
        if hasattr(model, "model"):
            return model.model.embed_tokens(tokens.cuda())
        elif hasattr(model, "transformer"):
            return model.transformer.word_embeddings(tokens.cuda())
        raise Exception("Unknown model format")

In [22]:
# encode text indices into latent representations & calculate cosine similarity
def exp_embedding_cosine_distance(original_inp, perturbed_inp, _, **kwargs):
    original_emb = get_embeds(original_inp)
    perturbed_emb = get_embeds(perturbed_inp)
    distance = 1 - F.cosine_similarity(original_emb, perturbed_emb, dim=-1)
    distance[distance.isnan()] = 0
    ret = torch.exp(-1 * (distance**2) / 2).sum()
    assert not ret.isnan().any()
    return ret


# binary vector where each word is selected independently and uniformly at random
i = 0


def bernoulli_perturb(text, **kwargs):
    global i
    probs = torch.ones_like(text) * 0.5
    probs[0, 0] = 0  # don't get rid of the start token
    ret = torch.bernoulli(probs).long()
    i += 1
    return ret


# remove absent tokens based on the intepretable representation sample
def interp_to_input(interp_sample, original_input, **kwargs):
    ret = original_input.clone()
    ret[interp_sample.bool()] = 0
    return ret


if EXPERIMENT_TYPE == LIME:
    attributers = {
        char: attr.LimeBase(
            softmax_results,
            interpretable_model=SkLearnLasso(alpha=0.0005),
            similarity_func=exp_embedding_cosine_distance,
            perturb_func=bernoulli_perturb,
            perturb_interpretable_space=True,
            from_interp_rep_transform=interp_to_input,
            to_interp_rep_transform=None,
        )
        for char in label_tokens.keys()
    }
elif EXPERIMENT_TYPE == SHAP:
    attributers = {
        char : attr.KernelShap(softmax_results)
        for char in label_tokens.keys()
    }
else:
    raise Exception("Invalid Experiment Type")

In [30]:
attributions = {}
for char, label_token in label_tokens.items():
    attributions[char] = attributers[char].attribute(
        tokens, target=label_token, n_samples=10, show_progress=True
    )

Kernel Shap attribution:   0%|          | 0/10 [00:00<?, ?it/s]

Kernel Shap attribution:   0%|          | 0/10 [00:00<?, ?it/s]

Kernel Shap attribution:   0%|          | 0/10 [00:00<?, ?it/s]

Kernel Shap attribution:   0%|          | 0/10 [00:00<?, ?it/s]

Kernel Shap attribution:   0%|          | 0/10 [00:00<?, ?it/s]

Kernel Shap attribution:   0%|          | 0/10 [00:00<?, ?it/s]

## Visualize Attributions

In [31]:
print(attributions["A"].shape)
assert attributions["A"].nonzero().numel() != 0

torch.Size([1, 76])


In [32]:
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = list(map(tokenizer.decode, all_tokens))
print(all_tokens)

torch.Size([76])
['Please', ' select', ' the', ' option', ' that', ' most', ' closely', ' describes', ' the', ' following', ' claim', ' by', ' D', 'wayne', ' Boh', 'ac', ':', '\n', 'the', ' Ann', 'ies', ' List', ' political', ' group', ' supports', ' third', '-', 'tr', 'imester', ' abortions', ' on', ' demand', '.', '\n', '\n', 'A', ')', ' True', '\n', 'B', ')', ' Mostly', ' True', '\n', 'C', ')', ' Half', ' True', '\n', 'D', ')', ' Bare', 'ly', ' True', '\n', 'E', ')', ' False', '\n', 'F', ')', ' Pants', ' on', ' Fire', ' ', '(', 'abs', 'urd', ' lie', ')', '\n', '\n', 'Choice', ':', ' ', '(']


In [33]:
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 65024])

In [34]:
from dataclasses import dataclass
from typing import Any


@dataclass
class CustomDataRecord:
    word_attributions: Any
    pred_prob: torch.Tensor
    pred_class: str
    attr_class: str
    attr_prob: torch.Tensor
    attr_score: Any
    raw_input_ids: Any | list[str]
    convergence_delta: Any = None

SCALE = 75

attr_vis = [
    CustomDataRecord(
        attributions[char][0] * SCALE,  # word attributions
        predictions[0].max(),  # predicted probability
        tokenizer.decode(torch.argmax(predictions[0])),  # predicted class
        char,  # attr class
        predictions[0, label_tokens[char]],  # attr probability
        attributions[char].sum(),  # attr score
        all_tokens,  # raw input ids
        abs(predictions[0, label_tokens[char]] - attributions[char].sum()) if EXPERIMENT_TYPE == SHAP else None
    )
    for char in attributions.keys()
]

In [35]:
from IPython.display import HTML


def visualize_text(
    datarecords: list[CustomDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    dom = ["<table width: 100%>"]
    if EXPERIMENT_TYPE == LIME:
        rows = [
            "<tr><th>Predicted Label</th>"
            "<th>Attribution Label</th>"
            "<th>Word Importance</th>"
        ]
    else:
        rows = [
            "<tr><th>Predicted Label</th>"
            "<th>Attribution Label</th>"
            "<th>Convergence Delta</th>"
            "<th>Word Importance</th>"
        ]
    for datarecord in datarecords:
        if EXPERIMENT_TYPE == LIME:
            rows.append(
                "".join(
                    [
                        "<tr>",
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.pred_class, datarecord.pred_prob
                            )
                        ),
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.attr_class, datarecord.attr_prob
                            )
                        ),
                        viz.format_word_importances(
                            datarecord.raw_input_ids, datarecord.word_attributions
                        ),
                        "<tr>",
                    ]
                )
            )
        else:
            rows.append(
                "".join(
                    [
                        "<tr>",
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.pred_class, datarecord.pred_prob
                            )
                        ),
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.attr_class, datarecord.attr_prob
                            )
                        ),
                        viz.format_classname("{0:.2f}".format(datarecord.convergence_delta)),
                        viz.format_word_importances(
                            datarecord.raw_input_ids, datarecord.word_attributions
                        ),
                        "<tr>",
                    ]
                )
            )


    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [36]:
html = visualize_text(attr_vis)
print("Results")

Predicted Label,Attribution Label,Convergence Delta,Word Importance
E (0.19),A (0.07),0.0,Please select the option that most closely describes the following claim by D wayne Boh ac : the Ann ies List political group supports third - tr imester abortions on demand . A ) True B ) Mostly True C ) Half True D ) Bare ly True E ) False F ) Pants on Fire ( abs urd lie ) Choice : (
,,,
E (0.19),B (0.18),0.0,Please select the option that most closely describes the following claim by D wayne Boh ac : the Ann ies List political group supports third - tr imester abortions on demand . A ) True B ) Mostly True C ) Half True D ) Bare ly True E ) False F ) Pants on Fire ( abs urd lie ) Choice : (
,,,
E (0.19),C (0.14),0.0,Please select the option that most closely describes the following claim by D wayne Boh ac : the Ann ies List political group supports third - tr imester abortions on demand . A ) True B ) Mostly True C ) Half True D ) Bare ly True E ) False F ) Pants on Fire ( abs urd lie ) Choice : (
,,,
E (0.19),D (0.09),0.0,Please select the option that most closely describes the following claim by D wayne Boh ac : the Ann ies List political group supports third - tr imester abortions on demand . A ) True B ) Mostly True C ) Half True D ) Bare ly True E ) False F ) Pants on Fire ( abs urd lie ) Choice : (
,,,
E (0.19),E (0.19),0.0,Please select the option that most closely describes the following claim by D wayne Boh ac : the Ann ies List political group supports third - tr imester abortions on demand . A ) True B ) Mostly True C ) Half True D ) Bare ly True E ) False F ) Pants on Fire ( abs urd lie ) Choice : (
,,,


Results


In [25]:
open(f"{model_name[model_name.index('/')+1:]}_{EXPERIMENT_TYPE.lower()}.html", "w").write(html.data)

78755

In [26]:
from collections import OrderedDict
# save pts as well
torch.save(
    OrderedDict(attributions),
    f"{model_name[model_name.index('/')+1:]}_{EXPERIMENT_TYPE.lower()}.pt",
)
