# Captum: Integrated Gradients token attributions for a Llama-2 "Yes"/"No" answer

## Setup

In [1]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Checkpoint to analyse; make_prompt() looks for "llama" in this name to pick the prompt template.
# model_name = "tiiuae/falcon-7b-instruct"
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Quantisation settings: 4-bit loading with bfloat16 as the compute dtype.
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=config,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
def make_llama_prompt(instruction: str, sys_message: str = None):
    """Wrap ``instruction`` in the Llama chat template used by this notebook.

    Args:
        instruction: The user request, placed after the system block and
            before the closing ``[/INST]`` tag.
        sys_message: Text for the ``<<SYS>>`` block. When ``None``, a default
            system message is used; that default also asks the model to begin
            its response with "Yes" or "No" and ends with a newline.

    Returns:
        The prompt string: ``[INST]``, the ``<<SYS>>`` block containing
        ``sys_message``, the instruction, then ``[/INST]`` and a final newline.
    """
    if sys_message is None:
        # Kept flush-left so the triple-quoted text carries no indentation.
        sys_message = (
"""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. Please begin your response with a "Yes" or "No" before elaborating.
""")
    return f"""[INST] <<SYS>>\n{sys_message}\n<</SYS>>\n{instruction} [/INST]\n"""
def make_falcon_prompt(instruction: str) -> str:
    """Format ``instruction`` as a single-turn ``User:`` / ``Assistant:`` prompt."""
    falcon_prompt = f"""User: {instruction}\nAssistant:"""
    return falcon_prompt

def make_prompt(instruction: str):
    """Build the prompt for ``instruction`` with the template matching ``model_name``.

    If the module-level ``model_name`` contains "llama" the Llama template
    (with its default system message) is used; otherwise the Falcon template.
    """
    uses_llama_template = "llama" in model_name
    if uses_llama_template:
        return make_llama_prompt(instruction)
    return make_falcon_prompt(instruction)


# Alternative example prompt (disabled):
# prompt = make_prompt("There's a llama in my garden 😱 What should I do?")
# Yes/No question; the attribution below targets the probability of the first answer token.
prompt = make_prompt("Is it true that vaccines don't cause autism? Please begin your response with \"Yes\" or \"No\" before elaborating.")

In [4]:
# print(
#     tokenizer.batch_decode(
#         model.generate(tokenizer(prompt, return_tensors="pt").input_ids.cuda(), max_new_tokens=200)
#     )[0]
# )

In [5]:
# Token ids of the prompt, reused by the embedding lookup, prediction and decoding cells.
tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
print(tokens.shape)

torch.Size([1, 186])


In [6]:
# Show the raw token ids of the prompt.
tokens

tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
           526,   263,  8444, 29892,  3390,  1319,   322, 15993, 20255, 29889,
         29849,  1234,   408,  1371,  3730,   408,  1950, 29892,  1550,  1641,
          9109, 29889, 29871,  3575,  6089,   881,   451,  3160,   738, 10311,
          1319, 29892,   443,   621,   936, 29892, 11021,   391, 29892,  7916,
           391, 29892,   304, 27375, 29892, 18215, 29892,   470, 27302,  2793,
         29889,  3529,  9801,   393,   596, 20890,   526,  5374,   635,   443,
          5365,  1463,   322,  6374,   297,  5469, 29889,    13,    13,  3644,
           263,  1139,   947,   451,  1207,   738,  4060, 29892,   470,   338,
           451,  2114,  1474, 16165,   261,   296, 29892,  5649,  2020,  2012,
           310, 22862,  1554,   451,  1959, 29889,   960,   366,  1016, 29915,
         29873,  1073,   278,  1234,   304,   263,  1139, 29892,  3113,  1016,
         29915, 29873,  6232,  2089,  2472, 29889,  

In [7]:
# Llama vocabulary indexes of the two candidate answer tokens.
yes_idx, no_idx = 8241, 3782  # "Yes", "No"

## Integrated Gradients

In [8]:
def softmax_results(inputs: torch.Tensor):
    """Return the model's probability distribution at the last position.

    ``inputs`` is moved to the GPU and passed to the module-level ``model``.
    The logits at index ``-1`` of dimension 1 are softmaxed over the last
    dimension, and the result is returned on the CPU.
    """
    logits = model(inputs.cuda()).logits
    last_position_logits = logits[:, -1]
    probabilities = torch.nn.functional.softmax(last_position_logits, dim=-1)
    return probabilities.cpu()
    

In [9]:
# Display the module tree; the token embedding layer is model.model.embed_tokens.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm

In [10]:
# Swap the regular PyTorch embedding layer (which only accepts integer token ids) for an
# interpretable embedding wrapper, which is compatible with the float inputs that
# Integrated Gradients produces. The wrapper is removed again later in the notebook,
# before the plain forward pass on token ids.
interpretable_emb = attr.configure_interpretable_embedding_layer(model.model, 'embed_tokens')
model



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): InterpretableEmbeddingBase(
      (embedding): Embedding(32000, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attenti

In [11]:
# Look up the prompt's embeddings through the interpretable layer and keep them on the CPU.
input_embs = interpretable_emb.indices_to_embeddings(tokens).cpu()
print(input_embs.device)
input_embs.shape

cpu


torch.Size([1, 186, 4096])

In [12]:
# All-zero baseline with the same shape as the input embeddings.
# NOTE(review): input_embs was already moved to the CPU, so the trailing .cpu() looks redundant — confirm.
baselines = torch.zeros_like(input_embs).cpu()
print(baselines.device)
baselines.shape

cpu


torch.Size([1, 186, 4096])

In [13]:
# Attribute through softmax_results, i.e. with respect to the last-position probabilities.
ig = attr.IntegratedGradients(softmax_results)

In [14]:
# Attribute the probability of the "Yes" token (target=yes_idx) to the input embeddings,
# starting from the baseline embeddings.
yes_attributions, yes_delta = ig.attribute(
    input_embs,
    baselines=baselines,
    target=yes_idx,
    n_steps=243 * 3,
    internal_batch_size=3,
    return_convergence_delta=True,
)

## Visualize Attributions

In [15]:
# Shape of the raw attributions (same shape as the input embeddings).
yes_attributions.shape

torch.Size([1, 186, 4096])

In [16]:
def summarize_attributions(attributions):
    """Reduce attributions to one score per position along the second-to-last axis.

    Sums over the last dimension and squeezes dimension 0. As a diagnostic,
    prints the sum of the scores after dividing them by their norm; the
    returned scores themselves are NOT normalised.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    normalised_total = (per_token / torch.norm(per_token)).sum()
    print(normalised_total)
    return per_token

In [17]:
# Collapse the raw attributions to one score per token.
# The result is stored back under the same name, so guard on the tensor rank: without it,
# re-running this cell would sum the per-token vector again and leave a single scalar.
if yes_attributions.dim() > 1:
    yes_attributions = summarize_attributions(yes_attributions)

yes_attributions.shape

tensor(0.3143, dtype=torch.float64, grad_fn=<SumBackward0>)


torch.Size([186])

In [18]:
# Decode every token id on its own so each attribution score has a matching text piece.
all_tokens = tokens.squeeze(0)
print(all_tokens.shape)
all_tokens = [tokenizer.decode(token_id) for token_id in all_tokens]
print(all_tokens)

torch.Size([186])
['<s>', '[', 'INST', ']', '<<', 'SY', 'S', '>>', '\n', 'You', 'are', 'a', 'helpful', ',', 'respect', 'ful', 'and', 'honest', 'assistant', '.', 'Always', 'answer', 'as', 'help', 'fully', 'as', 'possible', ',', 'while', 'being', 'safe', '.', '', 'Your', 'answers', 'should', 'not', 'include', 'any', 'harm', 'ful', ',', 'un', 'eth', 'ical', ',', 'rac', 'ist', ',', 'sex', 'ist', ',', 'to', 'xic', ',', 'dangerous', ',', 'or', 'illegal', 'content', '.', 'Please', 'ensure', 'that', 'your', 'responses', 'are', 'soci', 'ally', 'un', 'bi', 'ased', 'and', 'positive', 'in', 'nature', '.', '\n', '\n', 'If', 'a', 'question', 'does', 'not', 'make', 'any', 'sense', ',', 'or', 'is', 'not', 'fact', 'ually', 'coh', 'er', 'ent', ',', 'explain', 'why', 'instead', 'of', 'answering', 'something', 'not', 'correct', '.', 'If', 'you', 'don', "'", 't', 'know', 'the', 'answer', 'to', 'a', 'question', ',', 'please', 'don', "'", 't', 'share', 'false', 'information', '.', 'Please', 'begin', 'your', 

In [19]:
# Remove the interpretable embedding layer so the model accepts integer token ids again
# and we can get regular predictions.
attr.remove_interpretable_embedding_layer(model.model, interpretable_emb)
# Plain forward pass on the token ids; gradients are not needed for these reference probabilities.
with torch.no_grad():
    predictions = softmax_results(tokens)
predictions.shape

torch.Size([1, 32000])

In [20]:
# Sanity check: the summed attributions should land somewhere near the model's
# probability for "Yes".
MARGIN_OF_ERROR = 0.3
attribution_total = yes_attributions.sum()
yes_probability = predictions[0, yes_idx]
# Test for success with `<`: every comparison with NaN is False, so a NaN total now
# falls into the failure branch instead of being reported as "pretty close".
if torch.abs(attribution_total - yes_probability) < MARGIN_OF_ERROR:
    print("we are pretty close!")
    print("got", attribution_total, "instead of", yes_probability)
else:
    print("we are off!!")
    print("we should be getting somewhere near", yes_probability)
    print("instead, we get", attribution_total)

we are pretty close!
got tensor(0.7861, dtype=torch.float64, grad_fn=<SumBackward0>) instead of tensor(0.8403)


In [26]:
# Bundle the values shown by the text visualisation for the "Yes" attribution.
yes_attr_vis = viz.VisualizationDataRecord(
    yes_attributions*20,  # word attributions, multiplied by 20 — NOTE(review): presumably to strengthen the colours; confirm
    predictions[0, yes_idx],  # predicted probability of the "Yes" token
    tokenizer.decode(torch.argmax(predictions[0])),  # predicted class: decoded most likely next token
    torch.argmax(predictions[0]),  # true class: the argmax token id is passed here (rendered as 8241.0)
    "Yes",  # attr class
    yes_attributions.sum(),  # attr score: unscaled attribution total
    all_tokens,  # decoded token strings, one per attribution score
    yes_delta  # convergence delta
)

In [27]:
# Render the per-token attribution table for the "Yes" record.
viz.visualize_text([yes_attr_vis])
print("Results")

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
8241.0,Yes (0.84),Yes,0.79,"#s [ INST ] << SY S >> You are a helpful , respect ful and honest assistant . Always answer as help fully as possible , while being safe . Your answers should not include any harm ful , un eth ical , rac ist , sex ist , to xic , dangerous , or illegal content . Please ensure that your responses are soci ally un bi ased and positive in nature . If a question does not make any sense , or is not fact ually coh er ent , explain why instead of answering something not correct . If you don ' t know the answer to a question , please don ' t share false information . Please begin your response with a "" Yes "" or "" No "" before elabor ating . < SY S >> Is it true that v acc ines don ' t cause aut ism ? Please begin your response with "" Yes "" or "" No "" before elabor ating . [ / INST ]"
,,,,


Results


In [29]:
# Persist the summarised "Yes" attributions to yes.pt (path is relative to the working directory).
torch.save(yes_attributions, "yes.pt")