# Lime And Shap Explainability

## Dataset Setup

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import datasets

liar = datasets.load_dataset("liar")
full_liar = datasets.concatenate_datasets(
    [liar["train"], liar["test"], liar["validation"]]
)
full_liar

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
    num_rows: 12836
})

## Setup

In [3]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz
from captum._utils.models.linear_model import SkLearnLasso

In [4]:
falcon = "tiiuae/falcon-7b-instruct"
llama = "meta-llama/Llama-2-7b-chat-hf"
mistral = "mistralai/Mistral-7B-Instruct-v0.2"
orca = "microsoft/Orca-2-7b"

In [5]:
model_name = mistral

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn.functional as F

config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=config, # device_map="auto"
)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Utils

In [7]:
n_examples = 0

In [8]:
from typing import Dict

LABEL_MAP = {
    0: "E",  # 0 : False
    1: "C",  # 1 : Half True
    2: "B",  # 2 : Mostly True
    3: "A",  # 3 : True
    4: "D",  # 4 : Barely True
    5: "F",  # 5 : Pants on Fire
}


def was_correct(decoded: str, entry: Dict[str, int]) -> bool:
    return LABEL_MAP[entry["label"]] in decoded

In [9]:
import random
from typing import Dict

random.seed(1770)
entries = random.choices(list(range(len(full_liar))), k=n_examples)


def to_zero_shot_prompt(entry: Dict[str, str]) -> str:
    speaker = entry["speaker"].replace("-", " ").title()
    statement = entry["statement"].lstrip("Says ")

    prompt = f"""Please select the option that most closely describes the following claim by {speaker}:\n{statement}\n\nA) True\nB) Mostly True\nC) Half True\nD) Barely True\nE) False\nF) Pants on Fire (absurd lie)\n\nChoice: ("""
    return prompt


def to_n_shot_prompt(n: int, entry: Dict[str, str]) -> str:
    examples = ""
    for i in range(n):
        examples += (
            to_zero_shot_prompt(full_liar[entries[i]])
            + LABEL_MAP[full_liar[entries[i]]["label"]]
            + "\n\n"
        )
    prompt = to_zero_shot_prompt(entry)
    return examples + prompt

In [10]:
import random 
k = random.randint(0, len(full_liar))
generation = tokenizer.batch_decode(
        model.generate(
            tokenizer(
                to_zero_shot_prompt(full_liar[k]), return_tensors="pt"
            ).input_ids.cuda(),
            max_new_tokens=200,
        )
    )[0]
print(generation)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> Please select the option that most closely describes the following claim by Jorge Elorza:
The reality is that we have roughly 15,000 undocumented immigrants living in the state...

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (A) True

Explanation: Jorge Elorza's statement is based on a report from the Rhode Island Department of Health, which estimated that there were approximately 15,000 undocumented immigrants living in the state in 2013. This estimate was derived from data on the number of uninsured residents, assuming a certain percentage were undocumented. While there may be some margin of error in this estimate, it is generally accepted as a reliable approximation of the number of undocumented immigrants in Rhode Island. Therefore, Jorge Elorza's statement is true.</s>


In [11]:
continuation = f"""{generation.lstrip('<s> ')}

Explain your answer:

"""

print(
    tokenizer.batch_decode(
        model.generate(
            tokenizer(
                continuation, return_tensors="pt"
            ).input_ids.cuda(),
            max_new_tokens=200,
        )
    )[0]
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> Please select the option that most closely describes the following claim by Jorge Elorza:
The reality is that we have roughly 15,000 undocumented immigrants living in the state...

A) True
B) Mostly True
C) Half True
D) Barely True
E) False
F) Pants on Fire (absurd lie)

Choice: (A) True

Explanation: Jorge Elorza's statement is based on a report from the Rhode Island Department of Health, which estimated that there were approximately 15,000 undocumented immigrants living in the state in 2013. This estimate was derived from data on the number of uninsured residents, assuming a certain percentage were undocumented. While there may be some margin of error in this estimate, it is generally accepted as a reliable approximation of the number of undocumented immigrants in Rhode Island. Therefore, Jorge Elorza's statement is true.</s> 

Explain your answer:

Jorge Elorza's statement that "we have roughly 15,000 undocumented immigrants living in the state" is based on a report from the Rho

In [12]:
vocab = tokenizer.vocab
label_tokens = {
    "A": [],
    "B": [],
    "C": [],
    "D": [],
    "E": [],
    "F": [],
}
for char, idx in vocab.items():
    if char in label_tokens:
        label_tokens[char].append(idx)
label_tokens

{'A': [28741],
 'B': [28760],
 'C': [28743],
 'D': [28757],
 'E': [28749],
 'F': [28765]}

In [13]:
label_tokens = {char: idxs[0] for char, idxs in label_tokens.items()}
label_tokens

{'A': 28741, 'B': 28760, 'C': 28743, 'D': 28757, 'E': 28749, 'F': 28765}

In [14]:
prompt = to_n_shot_prompt(n_examples, full_liar[k])

In [15]:
# get tokens for later
tokens = tokenizer(prompt, return_tensors="pt").input_ids
print(tokens.shape)

torch.Size([1, 85])


In [16]:
tokens

tensor([[    1,  5919,  5339,   272,  3551,   369,  1080, 11640, 13966,   272,
          2296,  3452,   486, 26955,  1744,   271,  2166, 28747,    13,  1014,
          6940,   349,   369,   478,   506, 15756, 28705, 28740, 28782, 28725,
         28734, 28734, 28734,   640,  2048,   286, 22475,  3687,   297,   272,
          1665,  1101,    13,    13, 28741, 28731,  6110,    13, 28760, 28731,
          4822,   346,  6110,    13, 28743, 28731, 18994,  6110,    13, 28757,
         28731,   365,  6672,  6110,    13, 28749, 28731,  8250,    13, 28765,
         28731,   367,  1549,   356,  8643,   325,  4737, 12725,  4852, 28731,
            13,    13, 28456, 28747,   325]])

## LIME / SHAP

In [17]:
LIME = "LIME"
SHAP = "SHAP"
EXPERIMENT_TYPE = LIME

In [18]:
def softmax_results(tokens: torch.Tensor):
    with torch.no_grad():
        if tokenizer.bos_token_id is not None:  # falcon's is None
            tokens[0, 0] = tokenizer.bos_token_id
        result = model(
            torch.where(tokens != 0, tokens, tokenizer.eos_token_id).cuda(),
            attention_mask=torch.where(tokens != 0, 1, 0).cuda()
        ).logits
        ret = torch.nn.functional.softmax(result[:, -1], dim=-1).cpu()
        assert not ret.isnan().any()
    return ret


def get_embeds(tokens: torch.Tensor):
    with torch.no_grad():
        if hasattr(model, "model"):
            return model.model.embed_tokens(tokens.cuda())
        elif hasattr(model, "transformer"):
            return model.transformer.word_embeddings(tokens.cuda())
        raise Exception("Unknown model format")

In [19]:
# encode text indices into latent representations & calculate cosine similarity
def exp_embedding_cosine_distance(original_inp, perturbed_inp, _, **kwargs):
    original_emb = get_embeds(original_inp)
    perturbed_emb = get_embeds(perturbed_inp)
    distance = 1 - F.cosine_similarity(original_emb, perturbed_emb, dim=-1)
    distance[distance.isnan()] = 0
    ret = torch.exp(-1 * (distance**2) / 2).sum()
    assert not ret.isnan().any()
    return ret


# binary vector where each word is selected independently and uniformly at random
i = 0


def bernoulli_perturb(text, **kwargs):
    global i
    probs = torch.ones_like(text) * 0.5
    probs[0, 0] = 0  # don't get rid of the start token
    ret = torch.bernoulli(probs).long()
    i += 1
    return ret


# remove absent tokens based on the intepretable representation sample
def interp_to_input(interp_sample, original_input, **kwargs):
    ret = original_input.clone()
    ret[interp_sample.bool()] = 0
    return ret


if EXPERIMENT_TYPE == LIME:
    attributers = {
        char: attr.LimeBase(
            softmax_results,
            interpretable_model=SkLearnLasso(alpha=0.0005),
            similarity_func=exp_embedding_cosine_distance,
            perturb_func=bernoulli_perturb,
            perturb_interpretable_space=True,
            from_interp_rep_transform=interp_to_input,
            to_interp_rep_transform=None,
        )
        for char in label_tokens.keys()
    }
elif EXPERIMENT_TYPE == SHAP:
    attributers = {
        char : attr.KernelShap(softmax_results)
        for char in label_tokens.keys()
    }
else:
    raise Exception("Invalid Experiment Type")

In [None]:
attributions = {}
for char, label_token in label_tokens.items():
    attributions[char] = attributers[char].attribute(
        tokens, target=label_token, n_samples=512, show_progress=True
    )

## Visualize Attributions

In [21]:
print(attributions["A"].shape)
assert attributions["A"].nonzero().numel() != 0

torch.Size([1, 85])


In [22]:
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = list(map(tokenizer.decode, all_tokens))
print(all_tokens)

torch.Size([85])
['<s>', 'Please', 'select', 'the', 'option', 'that', 'most', 'closely', 'describes', 'the', 'following', 'claim', 'by', 'Jorge', 'El', 'or', 'za', ':', '\n', 'The', 'reality', 'is', 'that', 'we', 'have', 'roughly', '', '1', '5', ',', '0', '0', '0', 'und', 'ocument', 'ed', 'immigrants', 'living', 'in', 'the', 'state', '...', '\n', '\n', 'A', ')', 'True', '\n', 'B', ')', 'Most', 'ly', 'True', '\n', 'C', ')', 'Half', 'True', '\n', 'D', ')', 'B', 'arely', 'True', '\n', 'E', ')', 'False', '\n', 'F', ')', 'P', 'ants', 'on', 'Fire', '(', 'abs', 'urd', 'lie', ')', '\n', '\n', 'Choice', ':', '(']


In [23]:
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32000])

In [24]:
from dataclasses import dataclass
from typing import Any


@dataclass
class CustomDataRecord:
    word_attributions: Any
    pred_prob: torch.Tensor
    pred_class: str
    attr_class: str
    attr_prob: torch.Tensor
    attr_score: Any
    raw_input_ids: Any | list[str]
    convergence_delta: Any = None

SCALE = 35

attr_vis = [
    CustomDataRecord(
        attributions[char][0] * SCALE,  # word attributions
        predictions[0].max(),  # predicted probability
        tokenizer.decode(torch.argmax(predictions[0])),  # predicted class
        char,  # attr class
        predictions[0, label_tokens[char]],  # attr probability
        attributions[char].sum(),  # attr score
        all_tokens,  # raw input ids
        abs(predictions[0, label_tokens[char]] - attributions[char].sum()) if EXPERIMENT_TYPE == SHAP else None
    )
    for char in attributions.keys()
]

In [25]:
from IPython.display import HTML


def visualize_text(
    datarecords: list[CustomDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    dom = ["<table width: 100%>"]
    if EXPERIMENT_TYPE == LIME:
        rows = [
            "<tr><th>Predicted Label</th>"
            "<th>Attribution Label</th>"
            "<th>Word Importance</th>"
        ]
    else:
        rows = [
            "<tr><th>Predicted Label</th>"
            "<th>Attribution Label</th>"
            "<th>Convergence Delta</th>"
            "<th>Word Importance</th>"
        ]
    for datarecord in datarecords:
        if EXPERIMENT_TYPE == LIME:
            rows.append(
                "".join(
                    [
                        "<tr>",
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.pred_class, datarecord.pred_prob
                            )
                        ),
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.attr_class, datarecord.attr_prob
                            )
                        ),
                        viz.format_word_importances(
                            datarecord.raw_input_ids, datarecord.word_attributions
                        ),
                        "<tr>",
                    ]
                )
            )
        else:
            rows.append(
                "".join(
                    [
                        "<tr>",
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.pred_class, datarecord.pred_prob
                            )
                        ),
                        viz.format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.attr_class, datarecord.attr_prob
                            )
                        ),
                        viz.format_classname("{0:.2f}".format(datarecord.convergence_delta)),
                        viz.format_word_importances(
                            datarecord.raw_input_ids, datarecord.word_attributions
                        ),
                        "<tr>",
                    ]
                )
            )


    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html

In [32]:
LABEL_MAP[full_liar[k]['label']]

'E'

In [26]:
html = visualize_text(attr_vis)
print("Results")

Predicted Label,Attribution Label,Word Importance
A (0.53),A (0.53),"#s Please select the option that most closely describes the following claim by Jorge El or za : The reality is that we have roughly 1 5 , 0 0 0 und ocument ed immigrants living in the state ... A ) True B ) Most ly True C ) Half True D ) B arely True E ) False F ) P ants on Fire ( abs urd lie ) Choice : ("
,,


Results


In [27]:
open(f"{model_name[model_name.index('/')+1:]}_{EXPERIMENT_TYPE.lower()}_side.html", "w").write(html.data)

14392

In [None]:
from collections import OrderedDict
# save pts as well
torch.save(
    OrderedDict(attributions),
    f"{model_name[model_name.index('/')+1:]}_{EXPERIMENT_TYPE.lower()}.pt",
)
