# Captum

## Dataset Setup

In [1]:
import datasets

liar = datasets.load_dataset("liar")
full_liar = datasets.concatenate_datasets(
    [liar["train"], liar["test"], liar["validation"]]
)
full_liar

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
    num_rows: 12836
})

## Setup

In [2]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz
from captum._utils.models.linear_model import SkLearnLasso

In [3]:
falcon = "tiiuae/falcon-7b-instruct"
llama = "meta-llama/Llama-2-7b-chat-hf"
mistral = "mistralai/Mistral-7B-Instruct-v0.2"
orca = "microsoft/Orca-2-7b"

In [4]:
model_name = orca

In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn.functional as F

config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=config, device_map="auto"
)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Utils

In [6]:
n_examples = 0

In [7]:
from typing import Dict

LABEL_MAP = {
    0: "E",  # 0 : False
    1: "C",  # 1 : Half True
    2: "B",  # 2 : Mostly True
    3: "A",  # 3 : True
    4: "D",  # 4 : Barely True
    5: "F",  # 5 : Pants on Fire
}


def was_correct(decoded: str, entry: Dict[str, int]) -> bool:
    return LABEL_MAP[entry["label"]] in decoded

In [8]:
import random
from typing import Dict

random.seed(1770)
entries = random.choices(list(range(len(full_liar))), k=n_examples)


def to_zero_shot_prompt(entry: Dict[str, str]) -> str:
    speaker = entry["speaker"].replace("-", " ").title()
    statement = entry["statement"].lstrip("Says ")

    prompt = f"""Please select the option that most closely describes the following claim by {speaker}:\n{statement}\n\nA) True\nB) Mostly True\nC) Half True\nD) Barely True\nE) False\nF) Pants on Fire (absurd lie)\n\nChoice: ("""
    return prompt


def to_n_shot_prompt(n: int, entry: Dict[str, str]) -> str:
    examples = ""
    for i in range(n):
        examples += (
            to_zero_shot_prompt(full_liar[entries[i]])
            + LABEL_MAP[full_liar[entries[i]]["label"]]
            + "\n\n"
        )
    prompt = to_zero_shot_prompt(entry)
    return examples + prompt

In [9]:
print(
    tokenizer.batch_decode(
        model.generate(
            tokenizer(
                to_zero_shot_prompt(full_liar[0]), return_tensors="pt"
            ).input_ids.cuda(),
            max_new_tokens=200,
        )
    )[0]
)

<s> Please select the option that most closely describes the following claim by Dwayne Bohac:
the Annies List political group supports third-trimester abortions on demand.

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (A) True</s>


In [10]:
vocab = tokenizer.vocab
label_tokens = {
    "A": [],
    "B": [],
    "C": [],
    "D": [],
    "E": [],
    "F": [],
}
for char, idx in vocab.items():
    if char in label_tokens:
        label_tokens[char].append(idx)
label_tokens

{'A': [29909],
 'B': [29933],
 'C': [29907],
 'D': [29928],
 'E': [29923],
 'F': [29943]}

In [11]:
label_tokens = {char: idxs[0] for char, idxs in label_tokens.items()}
label_tokens

{'A': 29909, 'B': 29933, 'C': 29907, 'D': 29928, 'E': 29923, 'F': 29943}

In [12]:
prompt = to_n_shot_prompt(n_examples, full_liar[0])

In [13]:
# get tokens for later
tokens = tokenizer(prompt, return_tensors="pt").input_ids
print(tokens.shape)

torch.Size([1, 80])


In [14]:
tokens

tensor([[    1,  3529,  1831,   278,  2984,   393,  1556, 16467, 16612,   278,
          1494,  5995,   491,   360,  1582,   484, 17966,   562, 29901,    13,
          1552,  8081,   583,  2391,  8604,  2318, 11286,  4654, 29899, 15450,
          4156, 27450,  1080,   373,  9667, 29889,    13,    13, 29909, 29897,
          5852,    13, 29933, 29897,  7849,   368,  5852,    13, 29907, 29897,
         28144,  5852,    13, 29928, 29897,   350,   598,   368,  5852,    13,
         29923, 29897,  7700,    13, 29943, 29897,   349,  1934,   373,  6438,
           313,  6897, 18245,  3804, 29897,    13,    13, 29620, 29901,   313]])

## LIME

In [31]:
def softmax_results(tokens: torch.Tensor):
    with torch.no_grad():
        result = model(
            torch.where(tokens != 0, tokens, tokenizer.eos_token_id).cuda(), attention_mask=torch.where(tokens != 0, 1, 0).cuda()
        ).logits
        ret = torch.nn.functional.softmax(result[:, -1], dim=-1).cpu()
        if ret.isnan().any():
            print(tokens)
        assert not ret.isnan().any()
    return ret


def get_embeds(tokens: torch.Tensor):
    with torch.no_grad():
        if hasattr(model, "model"):
            return model.model.embed_tokens(tokens.cuda())
        elif hasattr(model, "transformer"):
            return model.transformer.word_embeddings(tokens.cuda())
        raise Exception("Unknown model format")

In [32]:
# encode text indices into latent representations & calculate cosine similarity
def exp_embedding_cosine_distance(original_inp, perturbed_inp, _, **kwargs):
    original_emb = get_embeds(original_inp)
    perturbed_emb = get_embeds(perturbed_inp)
    distance = 1 - F.cosine_similarity(original_emb, perturbed_emb, dim=-1)
    distance[distance.isnan()] = 0
    ret = torch.exp(-1 * (distance**2) / 2).sum()
    assert not ret.isnan().any()
    return ret


# binary vector where each word is selected independently and uniformly at random
i = 0


def bernoulli_perturb(text, **kwargs):
    global i
    probs = torch.ones_like(text) * 0.5
    probs[0, 0] = 0  # don't get rid of the start token
    ret = torch.bernoulli(probs).long()
    i += 1
    return ret


# remove absent tokens based on the intepretable representation sample
def interp_to_input(interp_sample, original_input, **kwargs):
    ret = original_input.clone()
    ret[interp_sample.bool()] = 0
    return ret


limes = {
    char: attr.LimeBase(
        softmax_results,
        interpretable_model=SkLearnLasso(alpha=0.0005),
        similarity_func=exp_embedding_cosine_distance,
        perturb_func=bernoulli_perturb,
        perturb_interpretable_space=True,
        from_interp_rep_transform=interp_to_input,
        to_interp_rep_transform=None,
    )
    for char in label_tokens.keys()
}

In [33]:
attributions = {}
for char, label_token in label_tokens.items():
    attributions[char] = limes[char].attribute(
        tokens, target=label_token, n_samples=512, show_progress=True
    )

Lime Base attribution:   0%|          | 0/512 [00:00<?, ?it/s]

Lime Base attribution:   0%|          | 0/512 [00:00<?, ?it/s]

Lime Base attribution:   0%|          | 0/512 [00:00<?, ?it/s]

Lime Base attribution:   0%|          | 0/512 [00:00<?, ?it/s]

Lime Base attribution:   0%|          | 0/512 [00:00<?, ?it/s]

Lime Base attribution:   0%|          | 0/512 [00:00<?, ?it/s]

## Visualize Attributions

In [34]:
print(attributions["A"].shape)
assert attributions["A"].nonzero().numel() != 0

torch.Size([1, 80])


In [35]:
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = list(map(tokenizer.decode, all_tokens))
print(all_tokens)

torch.Size([80])
['<s>', 'Please', 'select', 'the', 'option', 'that', 'most', 'closely', 'describes', 'the', 'following', 'claim', 'by', 'D', 'way', 'ne', 'Boh', 'ac', ':', '\n', 'the', 'Ann', 'ies', 'List', 'political', 'group', 'supports', 'third', '-', 'trim', 'ester', 'abort', 'ions', 'on', 'demand', '.', '\n', '\n', 'A', ')', 'True', '\n', 'B', ')', 'Most', 'ly', 'True', '\n', 'C', ')', 'Half', 'True', '\n', 'D', ')', 'B', 'are', 'ly', 'True', '\n', 'E', ')', 'False', '\n', 'F', ')', 'P', 'ants', 'on', 'Fire', '(', 'abs', 'urd', 'lie', ')', '\n', '\n', 'Choice', ':', '(']


In [36]:
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32003])

In [37]:
from dataclasses import dataclass
from typing import Any


@dataclass
class CustomDataRecord:
    word_attributions: Any
    pred_prob: torch.Tensor
    pred_class: str
    attr_class: str
    attr_prob: torch.Tensor
    attr_score: Any
    raw_input_ids: Any | list[str]

SCALE = 75

attr_vis = [
    CustomDataRecord(
        attributions[char][0] * SCALE,  # word attributions
        predictions[0].max(),  # predicted probability
        tokenizer.decode(torch.argmax(predictions[0])),  # predicted class
        char,  # attr class
        predictions[0, label_tokens[char]],  # attr probability
        attributions[char].sum(),  # attr score
        all_tokens,  # raw input ids
    )
    for char in attributions.keys()
]

In [38]:
from IPython.display import HTML


def visualize_text(
    datarecords: list[CustomDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    dom = ["<table width: 100%>"]
    rows = [
        "<tr><th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Word Importance</th>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.attr_class, datarecord.attr_prob
                        )
                    ),
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    "<tr>",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [39]:
html = visualize_text(attr_vis)
print("Results")

Predicted Label,Attribution Label,Word Importance
A (0.24),A (0.24),#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,
A (0.24),B (0.18),#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,
A (0.24),C (0.11),#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,
A (0.24),D (0.15),#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,
A (0.24),E (0.13),#s Please select the option that most closely describes the following claim by D way ne Boh ac : the Ann ies List political group supports third - trim ester abort ions on demand . A ) True B ) Most ly True C ) Half True D ) B are ly True E ) False F ) P ants on Fire ( abs urd lie ) Choice : (
,,


Results


In [40]:
open(f"{model_name[model_name.index('/')+1:]}_lime.html", "w").write(html.data)

77823