# Captum

## Setup

In [1]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz

In [2]:
from transformers import AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig
import torch

config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", quantization_config=config
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
def make_prompt(instruction: str, sys_message: str = None):
    if sys_message is None:
        sys_message = (
"""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. Please begin your response with a "Yes" or "No" before elaborating.
""")
    return (
f"""<s>[INST] <<SYS>>
{sys_message}
<</SYS>

{instruction} [/INST]

""")


# prompt = make_prompt("There's a llama in my garden 😱 What should I do?")
prompt = make_prompt("Is it true that vaccines cause autism?")

In [4]:
print(
    tokenizer.batch_decode(
        model.generate(tokenizer(prompt, return_tensors="pt").input_ids.cuda())
    )[0]
)

<s><s> [INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. Please begin your response with a "Yes" or "No" before elaborating.

<</SYS>

Is it true that vaccines cause autism? [/INST]

No, it is not true that vaccines cause autism. This is a common myth that has been thoroughly debunked by scientific evidence. Vaccines are safe and effective in preventing serious diseases, and they have been proven to be unrelated to the development of autism. The Centers for Disease Control and Prevention (CDC), the World Health Organizat

## Yes/No Answering

In [5]:
tokens = tokenizer(prompt, return_tensors="pt").input_ids
print(tokens.shape)

torch.Size([1, 170])


In [6]:
with torch.no_grad():
    result = model(tokens.cuda()).logits.cpu()

In [7]:
tokenizer.batch_decode(torch.argmax(result, dim=-1))

['Unterscheidung Unterscheidung![0 How<<MB> What Unterscheidung are a  and friendlyful, knowledge assistant. I. questions iffully as possible, and also mind and\n\n goal will be promote any harmful or unethical, dangerousist or sexist, toxic or or or or illegal content.\n ref that your responses are socially unbiased and positive in nature,\n\nI you user is not make sense sense, please is harm answerual correcterent, please why and of answering it that relevant.\n you are\'t know the answer to a question, say say\'t make anything information or Instead be each response with " polI" or "No" to providingating on\n\nLetainstS>>\nCan it possible that ifaccines cause autism?\nNoNo] \nNo']

In [8]:
idx = torch.argmax(result[0, -1], dim=-1)
print(f"idx {idx} corresponds to {tokenizer.decode(idx)}")

idx 3782 corresponds to No


In [9]:
softmax_results = torch.nn.functional.softmax(result[0, -1], dim=0)

In [10]:
# 8241 = Yes
# 3782 = No
yes_idx = 8241
no_idx = 3782

In [11]:
print(f"Yes probability: {softmax_results[yes_idx]:.3f} No probability {softmax_results[no_idx]:.3f}")

Yes probability: 0.054 No probability 0.799


## Integrated Gradients

In [12]:
def softmax_results(inputs: torch.Tensor):
    result = model(inputs.cuda()).logits
    return torch.nn.functional.softmax(result[:, -1], dim=-1)
    

In [13]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm

In [14]:
softmax_results(tokens)

tensor([[1.6596e-10, 9.2305e-11, 1.2949e-09,  ..., 2.2215e-10, 1.5060e-10,
         2.4600e-10]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [15]:
# replace the normal pytorch embeddings (which only take ints) to interpretable embeddings
# (which are compatible with the float inputs that integratedgradients gives)
interpretable_emb = attr.configure_interpretable_embedding_layer(model.model, 'embed_tokens')
model



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): InterpretableEmbeddingBase(
      (embedding): Embedding(32000, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attenti

In [16]:
input_embs = interpretable_emb.indices_to_embeddings(tokens)
input_embs.shape

torch.Size([1, 170, 4096])

In [17]:
input_embs

tensor([[[ 1.8387e-03, -3.8147e-03,  9.6130e-04,  ..., -9.0332e-03,
           2.6550e-03, -3.7537e-03],
         [ 1.8387e-03, -3.8147e-03,  9.6130e-04,  ..., -9.0332e-03,
           2.6550e-03, -3.7537e-03],
         [ 2.7710e-02, -6.0425e-03,  3.4943e-03,  ...,  5.1117e-04,
          -8.1787e-03,  6.4087e-03],
         ...,
         [-2.3315e-02, -1.8677e-02,  5.8594e-03,  ...,  1.7578e-02,
          -2.0266e-05,  1.2390e-02],
         [-1.4954e-03, -3.5095e-04, -1.7624e-03,  ...,  8.8882e-04,
           1.6937e-03, -6.9427e-04],
         [-1.4954e-03, -3.5095e-04, -1.7624e-03,  ...,  8.8882e-04,
           1.6937e-03, -6.9427e-04]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<EmbeddingBackward0>)

In [18]:
ig = attr.IntegratedGradients(softmax_results)

In [19]:
no_attributions, no_conv_score = ig.attribute(input_embs, target=no_idx, n_steps=15, internal_batch_size=3, return_convergence_delta=True)

In [20]:
yes_attributions, yes_conv_score = ig.attribute(input_embs, target=yes_idx, n_steps=15, internal_batch_size=3, return_convergence_delta=True)

## Visualize Attributions

In [21]:
no_attributions.shape

torch.Size([1, 170, 4096])

In [22]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [23]:
no_attributions = summarize_attributions(no_attributions)
yes_attributions = summarize_attributions(yes_attributions)

no_attributions.shape

torch.Size([170])

In [24]:
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = list(map(tokenizer.decode, all_tokens))
print(all_tokens)

torch.Size([170])
['<s>', '<s>', '[', 'INST', ']', '<<', 'SY', 'S', '>>', '\n', 'You', 'are', 'a', 'helpful', ',', 'respect', 'ful', 'and', 'honest', 'assistant', '.', 'Always', 'answer', 'as', 'help', 'fully', 'as', 'possible', ',', 'while', 'being', 'safe', '.', '', 'Your', 'answers', 'should', 'not', 'include', 'any', 'harm', 'ful', ',', 'un', 'eth', 'ical', ',', 'rac', 'ist', ',', 'sex', 'ist', ',', 'to', 'xic', ',', 'dangerous', ',', 'or', 'illegal', 'content', '.', 'Please', 'ensure', 'that', 'your', 'responses', 'are', 'soci', 'ally', 'un', 'bi', 'ased', 'and', 'positive', 'in', 'nature', '.', '\n', '\n', 'If', 'a', 'question', 'does', 'not', 'make', 'any', 'sense', ',', 'or', 'is', 'not', 'fact', 'ually', 'coh', 'er', 'ent', ',', 'explain', 'why', 'instead', 'of', 'answering', 'something', 'not', 'correct', '.', 'If', 'you', 'don', "'", 't', 'know', 'the', 'answer', 'to', 'a', 'question', ',', 'please', 'don', "'", 't', 'share', 'false', 'information', '.', 'Please', 'begin', '

In [25]:
# remove the interpretable embedding layer so we can get regular predictions
attr.remove_interpretable_embedding_layer(model.model, interpretable_emb)
predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32000])

In [26]:
predictions[0, no_idx]

tensor(0.7989, device='cuda:0', grad_fn=<SelectBackward0>)

In [27]:
help(viz.VisualizationDataRecord)

Help on class VisualizationDataRecord in module captum.attr._utils.visualization:

class VisualizationDataRecord(builtins.object)
 |  VisualizationDataRecord(word_attributions, pred_prob, pred_class, true_class, attr_class, attr_score, raw_input_ids, convergence_score) -> None
 |  
 |  A data record for storing attribution relevant information
 |  
 |  Methods defined here:
 |  
 |  __init__(self, word_attributions, pred_prob, pred_class, true_class, attr_class, attr_score, raw_input_ids, convergence_score) -> None
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  attr_class
 |  
 |  attr_score
 |  
 |  convergence_score
 |  
 |  pred_class
 |  
 |  pred_prob
 |  
 |  raw_input_ids
 |  
 |  true_class
 |  
 |  word_attributions



In [28]:
no_attributions.shape

torch.Size([170])

In [29]:
no_attr_vis = viz.VisualizationDataRecord(
    no_attributions,
    predictions[0, no_idx],
    no_idx,
    torch.argmax(predictions[0]),
    "No",  # ground truth
    no_attributions.sum(),  
    all_tokens,
    no_conv_score
)

yes_attr_vis = viz.VisualizationDataRecord(
    yes_attributions,
    predictions[0, yes_idx],
    yes_idx,
    torch.argmax(predictions[0]),
    "Yes",  # ground truth
    yes_attributions.sum(),  
    all_tokens,
    yes_conv_score
)

In [30]:
viz.visualize_text([no_attr_vis])
viz.visualize_text([yes_attr_vis])
print("Results")

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
3782.0,3782 (0.80),No,0.93,"#s #s [ INST ] << SY S >> You are a helpful , respect ful and honest assistant . Always answer as help fully as possible , while being safe . Your answers should not include any harm ful , un eth ical , rac ist , sex ist , to xic , dangerous , or illegal content . Please ensure that your responses are soci ally un bi ased and positive in nature . If a question does not make any sense , or is not fact ually coh er ent , explain why instead of answering something not correct . If you don ' t know the answer to a question , please don ' t share false information . Please begin your response with a "" Yes "" or "" No "" before elabor ating . < SY S > Is it true that v acc ines cause aut ism ? [ / INST ]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
3782.0,8241 (0.05),Yes,-0.87,"#s #s [ INST ] << SY S >> You are a helpful , respect ful and honest assistant . Always answer as help fully as possible , while being safe . Your answers should not include any harm ful , un eth ical , rac ist , sex ist , to xic , dangerous , or illegal content . Please ensure that your responses are soci ally un bi ased and positive in nature . If a question does not make any sense , or is not fact ually coh er ent , explain why instead of answering something not correct . If you don ' t know the answer to a question , please don ' t share false information . Please begin your response with a "" Yes "" or "" No "" before elabor ating . < SY S > Is it true that v acc ines cause aut ism ? [ / INST ]"
,,,,


Results
