# Captum

## Setup

In [1]:
import captum
import captum.attr as attr
from captum.attr import visualization as viz

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Checkpoint under study. The Falcon id is kept as a commented-out alternative.
# model_name = "tiiuae/falcon-7b-instruct"
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Quantisation settings: load weights in 4-bit, compute in bfloat16.
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=config,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# System prompt that used to be built (but never emitted) inside
# make_llama_prompt. Pass it as ``sys_message`` to opt in.
DEFAULT_LLAMA_SYSTEM_PROMPT = (
"""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. Please begin your response with a "Yes" or "No" before elaborating.
""")


def make_llama_prompt(instruction: str, sys_message: str = None) -> str:
    """Wrap ``instruction`` in the Llama-2 chat ``[INST] ... [/INST]`` template.

    Args:
        instruction: The user message placed between the instruction tags.
        sys_message: Optional system prompt. When given and non-empty it is
            emitted in a ``<<SYS>> ... <</SYS>>`` section ahead of the
            instruction. When omitted (the default) no system section is
            produced, which matches the output this function always returned.

    Returns:
        The formatted prompt, terminated by a newline.
    """
    if not sys_message:
        # No system prompt requested: plain instruction-only template.
        return f"""[INST] {instruction} [/INST]\n"""
    # Previously an explicit sys_message was silently ignored; honour it here.
    return f"""[INST] <<SYS>>\n{sys_message}\n<</SYS>>\n\n{instruction} [/INST]\n"""
def make_falcon_prompt(instruction:str) -> str:
    """Format ``instruction`` as a ``User:`` turn followed by an open ``Assistant:`` turn."""
    template = "User: {}\nAssistant:"
    return template.format(instruction)

def make_prompt(instruction:str):
    """Build a prompt for ``instruction`` using the template that matches ``model_name``.

    Model ids containing "llama" (in any letter case) get the Llama template;
    every other id falls back to the Falcon template.

    Args:
        instruction: The user message to wrap.

    Returns:
        The formatted prompt string.
    """
    # Compare case-insensitively: an id that only contains "Llama"
    # (e.g. ".../Llama-2-7b-chat-hf") must not fall through to the Falcon format.
    if "llama" in model_name.lower():
        return make_llama_prompt(instruction)
    return make_falcon_prompt(instruction)


# Alternative prompts built through make_prompt, kept for experimentation:
# prompt = make_prompt("There's a llama in my garden 😱 What should I do?")
# prompt = make_prompt("Is the statement \"Vaccines cause autism\" true or false?")

# Hand-written prompt that already ends with the start of the answer
# ("The statement is"), so the next generated token is the word of interest.
prompt = (
    '[INST]Is the statement "The covid-19 vaccine is safe" true or false? '
    'Please answer with just "The statement is true" or "The statement is false"'
    "[/INST]\n\nThe statement is"
)

In [4]:
# Sanity check: let the model complete the prompt and print the decoded text.
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
generated_ids = model.generate(prompt_ids, max_new_tokens=200)
decoded_batch = tokenizer.batch_decode(generated_ids)
print(decoded_batch[0])

<s> [INST]Is the statement "The covid-19 vaccine is safe" true or false? Please answer with just "The statement is true" or "The statement is false"[/INST]

The statement is true. The COVID-19 vaccines have undergone rigorous testing and have been proven to be safe and effective in preventing severe illness and death from COVID-19. The vaccines have been approved by regulatory agencies around the world, such as the Food and Drug Administration (FDA) in the United States and the European Medicines Agency (EMA) in the European Union. Additionally, numerous scientific studies have shown that the vaccines are safe and well-tolerated, with the most common side effects being mild and temporary, such as soreness at the injection site.</s>


In [5]:
# Vocabulary ids of the two candidate answer words, used later as attribution
# targets. Special tokens are switched off so the word is always at index 0;
# the previous `input_ids[1]` only worked for tokenizers that prepend exactly
# one special token (e.g. BOS) and would break for one that prepends none.
false_token = tokenizer("false", add_special_tokens=False).input_ids[0]
true_token = tokenizer("true", add_special_tokens=False).input_ids[0]
false_token, true_token

(2089, 1565)

In [6]:
# Tokenise the prompt once; the id tensor is reused by the cells below.
tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
print(tokens.size())

torch.Size([1, 50])


In [7]:
# Display the raw token ids produced for the prompt.
tokens

tensor([[    1,   518, 25580, 29962,  3624,   278,  3229,   376,  1576, 18838,
           333, 29899, 29896, 29929,   325,  5753,   457,   338,  9109, 29908,
          1565,   470,  2089, 29973,  3529,  1234,   411,   925,   376,  1576,
          3229,   338,  1565, 29908,   470,   376,  1576,  3229,   338,  2089,
         29908, 29961, 29914, 25580, 29962,    13,    13,  1576,  3229,   338]])

## Integrated Gradients

In [8]:
def softmax_results(inputs: torch.Tensor):
    """Forward function for attribution: probabilities at the last sequence position.

    Runs ``model`` on ``inputs`` (moved to CUDA first), keeps the logits of the
    final position along dimension 1 and applies a softmax over the last
    dimension.

    Args:
        inputs: Tensor passed to ``model`` as its first positional argument.

    Returns:
        The softmaxed last-position logits, moved back to the CPU.
    """
    last_position_logits = model(inputs.cuda()).logits[:, -1]
    probabilities = torch.softmax(last_position_logits, dim=-1)
    return probabilities.cpu()
    

In [9]:
# Display the module tree; `model.model.embed_tokens` is the layer wrapped in the next step.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Lla

In [10]:
# Replace the regular PyTorch embedding layer (which only accepts integer ids)
# with Captum's interpretable embedding wrapper, which is compatible with the
# float inputs that Integrated Gradients feeds to the forward function.
# This mutates `model.model` in place; the returned handle is needed later to
# undo the swap with `attr.remove_interpretable_embedding_layer`.
interpretable_emb = attr.configure_interpretable_embedding_layer(model.model, 'embed_tokens')
# Display the module tree again to confirm `embed_tokens` is now wrapped.
model



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): InterpretableEmbeddingBase(
      (embedding): Embedding(32000, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_lay

In [11]:
# Convert the prompt's token ids to embedding vectors through the wrapped
# layer and keep the result on the CPU; these embeddings are the inputs that
# Integrated Gradients attributes over.
input_embs = interpretable_emb.indices_to_embeddings(tokens).cpu()
print(input_embs.device)
# Last expression: show the embedding tensor's shape.
input_embs.shape

cpu


torch.Size([1, 50, 4096])

In [12]:
# All-zero reference embeddings with the same shape and dtype as the inputs,
# placed on the CPU; used as the Integrated Gradients baseline.
baselines = torch.zeros_like(input_embs, device="cpu")
print(baselines.device)
baselines.size()

cpu


torch.Size([1, 50, 4096])

In [13]:
# Integrated Gradients explainer built around `softmax_results`, the forward
# function that returns the softmaxed last-position logits.
ig = attr.IntegratedGradients(softmax_results)

In [14]:
# Attribute the probability of the "false" token to the input embeddings.
# Also returns the convergence delta of the approximation.
false_attributions, false_delta = ig.attribute(
    input_embs,
    baselines=baselines,
    target=false_token,
    n_steps=32,
    internal_batch_size=2,
    return_convergence_delta=True,
)

## Visualize Attributions

In [15]:
# Show the shape of the raw attributions returned by `ig.attribute`.
false_attributions.shape

torch.Size([1, 50, 4096])

In [16]:
def summarize_attributions(attributions):
    """Reduce attributions over their last dimension.

    Sums over the last dimension, drops a leading dimension of size one, and
    prints the sum of the L2-normalised result as a diagnostic.

    Args:
        attributions: Tensor of attributions to reduce.

    Returns:
        The reduced (un-normalised) attributions.
    """
    reduced = attributions.sum(dim=-1).squeeze(0)
    normalised_total = (reduced / torch.norm(reduced)).sum()
    print(normalised_total)
    return reduced

In [17]:
# Collapse the attributions over their last dimension via summarize_attributions.
# NOTE(review): this rebinds `false_attributions` to its own summary, so
# re-running the cell summarises the already-summarised tensor. Restart and
# run all cells (or re-run the attribution cell first) before re-executing.
false_attributions = summarize_attributions(false_attributions)

# Last expression: show the shape of the summarised attributions.
false_attributions.shape

tensor(-0.9397, dtype=torch.float64, grad_fn=<SumBackward0>)


torch.Size([50])

In [18]:
# Decode every prompt token id on its own so each attribution can be paired
# with a readable token string.
flat_token_ids = tokens.squeeze(0)
print(flat_token_ids.shape)
all_tokens = [tokenizer.decode(token_id) for token_id in flat_token_ids]
print(all_tokens)

torch.Size([50])
['<s>', '[', 'INST', ']', 'Is', 'the', 'statement', '"', 'The', 'cov', 'id', '-', '1', '9', 'v', 'acc', 'ine', 'is', 'safe', '"', 'true', 'or', 'false', '?', 'Please', 'answer', 'with', 'just', '"', 'The', 'statement', 'is', 'true', '"', 'or', '"', 'The', 'statement', 'is', 'false', '"', '[', '/', 'INST', ']', '\n', '\n', 'The', 'statement', 'is']


In [19]:
# Remove the interpretable embedding layer so the model accepts integer token
# ids again and we can get regular predictions.
attr.remove_interpretable_embedding_layer(model.model, interpretable_emb)
# Plain forward pass on the token ids; no gradients are needed here.
with torch.no_grad():
    predictions = softmax_results(tokens)
# Last expression: show the shape of the probability tensor.
predictions.shape

torch.Size([1, 32000])

In [20]:
# Integrated Gradients satisfies completeness: the attributions sum to
# F(input) - F(baseline), not to F(input) alone. The baseline's probability for
# the target token therefore has to be subtracted before comparing.
with torch.no_grad():
    # Feed the baseline embeddings directly through `inputs_embeds`.
    baseline_logits = model(inputs_embeds=baselines.cuda()).logits
baseline_prob = torch.nn.functional.softmax(baseline_logits[:, -1], dim=-1)[0, false_token].cpu()

expected_total = predictions[0, false_token] - baseline_prob
attribution_total = false_attributions.sum()

# abs() keeps the tolerance positive when the expected difference is negative.
MARGIN_OF_ERROR = 0.1 * expected_total.abs()  # off by no more than 10 percent
if torch.abs(attribution_total - expected_total) >= MARGIN_OF_ERROR:
    print("we are off!!")
    print("we should be getting somewhere near", expected_total)
    print("instead, we get", attribution_total)
else:
    print("we are pretty close!")
    print("got", attribution_total, "instead of", expected_total)

we are off!!
we should be getting somewhere near tensor(0.0024)
instead, we get tensor(-1.6513, dtype=torch.float64, grad_fn=<SumBackward0>)


In [21]:
# Probabilities at the last position and the id of the most likely token.
next_token_probs = predictions[0]
top_token_id = torch.argmax(next_token_probs)

# Bundle everything Captum's text visualisation needs into a single record.
attr_vis = viz.VisualizationDataRecord(
    false_attributions * 20,  # word attributions, scaled by 20 (presumably for display contrast; confirm)
    next_token_probs.max(),  # predicted probability
    tokenizer.decode(top_token_id),  # predicted class
    top_token_id,  # true class (filled with the predicted token id)
    "false",  # attr class
    false_attributions.sum(),  # attr score
    all_tokens,  # raw input ids
    false_delta,  # convergence delta
)

In [22]:
# Render the attribution record with Captum's text visualisation helper.
viz.visualize_text([attr_vis])
print("Results")

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1565.0,true (1.00),False,-1.65,"#s [ INST ] Is the statement "" The cov id - 1 9 v acc ine is safe "" true or false ? Please answer with just "" The statement is true "" or "" The statement is false "" [ / INST ] The statement is"
,,,,


Results
