# Understanding Llama2 with Captum LLM Attribution

In this tutorial, we demonstrate the LLM attribution functionality introduced in Captum v0.7, which makes it easy to apply attribution algorithms to interpret large language models (LLMs) in text generation. The new functionality includes a series of utilities that handle much of the tedious scaffolding LLMs require, such as defining interpretable features in text input and handling sequential predictions. You can also check our paper for more details: https://arxiv.org/abs/2312.05491

Next, we will use Llama2 (7b-chat) as an example and use both perturbation-based and gradient-based algorithms to see how the input prompts lead to the generated content. First, let's import the needed dependencies. Specifically, from Captum, besides the algorithms `FeatureAblation` and `LayerIntegratedGradients` themselves, we will also import:
- perturbation-based and gradient-based wrappers for LLM, `LLMAttribution` and `LLMGradientAttribution`
- text-based interpretable input adapters, `TextTokenInput` and `TextTemplateInput`

In [1]:
#https://captum.ai/tutorials/Llama2_LLM_Attribution

#install directly from https://github.com/pytorch/captum/tree/master
import warnings

import bitsandbytes as bnb
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from captum.attr import (
    FeatureAblation,
    ShapleyValues,
    LayerIntegratedGradients,
    LLMAttribution,
    LLMGradientAttribution,
    TextTokenInput,
    TextTemplateInput,
    ProductBaselines,
)

# Ignore warnings due to transformers library
warnings.filterwarnings("ignore", ".*past_key_values.*")
warnings.filterwarnings("ignore", ".*Skipping this token.*")

The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.


In [2]:
from captum.attr._utils.interpretable_input import InterpretableInput
from typing import Union,Optional,Dict,cast
from torch import Tensor

In [3]:
class LLMGradientAttribution_Features(LLMGradientAttribution):
    def attribute(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        gen_args: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Args:
            inp (InterpretableInput): input prompt for which attributions are computed
            target (str or Tensor, optional): target response with respect to
                    which attributions are computed. If None, it uses the model
                    to generate the target based on the input and gen_args.
                    Default: None
            gen_args (dict, optional): arguments for generating the target. Only used if
                    target is not given. When None, the default arguments are used,
                    {"max_length": 25, "do_sample": False}
                    Defaults: None
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance
    
        Returns:

            attr_list (list): one entry per target token, each holding the raw
                    output of the underlying attribution method (not an
                    LLMAttributionResult)
        """
    
        assert isinstance(
            inp, self.SUPPORTED_INPUTS
        ), f"LLMGradAttribution does not support input type {type(inp)}"
    
        if target is None:
            # generate when None
            assert hasattr(self.model, "generate") and callable(self.model.generate), (
                "The model does not have recognizable generate function."
                "Target must be given for attribution"
            )
    
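            if not gen_args:
                # DEFAULT_GEN_ARGS is never imported in this notebook, so fall
                # back to the defaults documented in the docstring above
                gen_args = {"max_length": 25, "do_sample": False}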
    
            model_inp = self._format_model_input(inp.to_model_input())
            output_tokens = self.model.generate(model_inp, **gen_args)
            target_tokens = output_tokens[0][model_inp.size(1) :]
        else:
            assert gen_args is None, "gen_args must be None when target is given"
    
            if type(target) is str:
                # exclude sos
                target_tokens = self.tokenizer.encode(target)[1:]
                target_tokens = torch.tensor(target_tokens)
            elif type(target) is torch.Tensor:
                target_tokens = target
    
        attr_inp = inp.to_tensor().to(self.device)
    
        attr_list = []
        for cur_target_idx, _ in enumerate(target_tokens):
            # attr in shape(batch_size, input+output_len, emb_dim)
            attr = self.attr_method.attribute(
                attr_inp,
                additional_forward_args=(
                    inp,
                    target_tokens,
                    cur_target_idx,
                ),
                **kwargs,
            )
            attr = cast(Tensor, attr)
    
            # will have the attr for previous output tokens
            # cut to shape(batch_size, inp_len, emb_dim)
            if cur_target_idx:
                attr = attr[:, :-cur_target_idx]
    
            # the author of IG uses sum
            # https://github.com/ankurtaly/Integrated-Gradients/blob/master/BertModel/bert_model_utils.py#L350
            #print(attr.shape)
            #attr = attr.sum(-1)
            attr_list.append(attr)
    
        # assume inp batch only has one instance
        # Unlike the parent class, return the raw attributions as a list
        # (one entry per target token) instead of an LLMAttributionResult
        return attr_list


In [4]:
import captum
captum.__file__

'C:\\Users\\denis\\captum\\captum\\__init__.py'

## Preparation

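Let's make a helper function to load models through Hugging Face. We also define a helper that builds a 4-bit quantization config, which can effectively reduce GPU memory consumption. Note that `load_model` below accepts `bnb_config` but does not pass it to `from_pretrained`, so the model is loaded without quantization. The cells below load `meta-llama/Llama-3.2-1B` rather than "Llama-2-7b-chat".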

In [5]:
def load_model(model_name, bnb_config):

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)

    # Needed for LLaMA tokenizer
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    return bnb_config

In [6]:
model_name = "meta-llama/Llama-3.2-1B" 

bnb_config = create_bnb_config()

model, tokenizer = load_model(model_name, bnb_config)

In [7]:
layers_to_hook_layer_list=[]
layers_to_hook_name_list=[]

layers_to_hook = {}
layers_to_hook["embed_tokens"]=[model.model.embed_tokens]
layers_to_hook_layer_list.append(model.model.embed_tokens)
layers_to_hook_name_list.append(["embed_tokens",0])

for i_p,i in enumerate(model.model.layers[:-1]):
    
    if "q_proj" not in layers_to_hook:
        layers_to_hook["q_proj"]=[]
    layers_to_hook["q_proj"].append(i.self_attn.q_proj)
    layers_to_hook_layer_list.append(i.self_attn.q_proj)
    layers_to_hook_name_list.append(["q_proj",i_p])

    if "k_proj" not in layers_to_hook:
        layers_to_hook["k_proj"]=[]
    layers_to_hook["k_proj"].append(i.self_attn.k_proj)
    layers_to_hook_layer_list.append(i.self_attn.k_proj)
    layers_to_hook_name_list.append(["k_proj",i_p])

    if "v_proj" not in layers_to_hook:
        layers_to_hook["v_proj"]=[]
    layers_to_hook["v_proj"].append(i.self_attn.v_proj)
    layers_to_hook_layer_list.append(i.self_attn.v_proj)
    layers_to_hook_name_list.append(["v_proj",i_p])
    """
    if "o_proj" not in layers_to_hook:
        layers_to_hook["o_proj"]=[]
    layers_to_hook["o_proj"].append(i.self_attn.o_proj)
    layers_to_hook_layer_list.append(i.self_attn.o_proj)
    layers_to_hook_name_list.append(["o_proj",i_p])
    """

    #Is not used in this setting (no gradients found)
    #if "rotary_emb" not in layers_to_hook:
    #    layers_to_hook["rotary_emb"]=[]
    #layers_to_hook["rotary_emb"].append(i.self_attn.rotary_emb)
    
    if "mlp" not in layers_to_hook:
        layers_to_hook["mlp"]=[]
    layers_to_hook["mlp"].append(i.mlp)
    layers_to_hook_layer_list.append(i.mlp)
    layers_to_hook_name_list.append(["mlp",i_p])

layers_to_hook["rotary_emb_end"]=[model.model.rotary_emb]
layers_to_hook_layer_list.append(model.model.rotary_emb)
layers_to_hook_name_list.append(["rotary_emb_end",0])

Let's test the model with a simple prompt and take a look at the output.

In [8]:
eval_prompt = "Dave lives in Palm Coast, FL and is a lawyer. His personal interests include"

model_input = tokenizer(eval_prompt, return_tensors="pt")
model.eval()
with torch.no_grad():
    output_ids = model.generate(model_input["input_ids"], max_new_tokens=15)[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(response)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Dave lives in Palm Coast, FL and is a lawyer. His personal interests include golf, tennis, and the beach. He is a member of the National


## Perturbation-based Attribution

Now the model is working and has completed the given prompt by producing several possible interests. To understand how the model produces them based on the prompt, we will first use the perturbation-based algorithms from Captum. We can start with the simplest, `FeatureAblation`, which ablates each feature of the input string to see how it affects the predicted probability of the target string. The procedure is the same as for other Captum methods: pass the model to the corresponding constructor to instantiate the attribution method. Additionally, to make it work with text-based input and output, we need to wrap it with the new `LLMAttribution` class. Note that the cells below in this notebook instead build a `LayerIntegratedGradients` instance over the hooked layers and wrap it with the custom `LLMGradientAttribution_Features` class defined above.
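To make the idea of feature ablation concrete, here is a minimal pure-Python sketch. It is not Captum's implementation: `toy_score` is a hypothetical stand-in for the model's score of a fixed target, and its keyword weights are invented for illustration. Each word's attribution is the drop in score when that word alone is removed.

```python
from typing import Callable, List


def toy_score(words: List[str]) -> float:
    """Hypothetical stand-in for the model's score of a fixed target."""
    weights = {"Palm": 0.2, "Coast": 0.2, "lawyer": 0.5}  # invented weights
    return sum(weights.get(w, 0.0) for w in words)


def feature_ablation(
    words: List[str], score_fn: Callable[[List[str]], float]
) -> List[float]:
    """Attribution of each word = full score - score with that word removed."""
    full = score_fn(words)
    attributions = []
    for i in range(len(words)):
        ablated = words[:i] + words[i + 1:]  # drop exactly one feature
        attributions.append(full - score_fn(ablated))
    return attributions


words = ["Dave", "lives", "in", "Palm", "Coast", "and", "is", "a", "lawyer"]
attr = feature_ablation(words, toy_score)
for word, score in zip(words, attr):
    print(f"{word:>8}: {score:+.2f}")
```

Words that the toy score ignores receive zero attribution, while removing "lawyer" causes the largest drop. With a real LLM, the score would be the probability (or log-probability) of the target string, which is what the Captum wrappers handle for us.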

In [9]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [10]:
fa = LayerIntegratedGradients(model,layers_to_hook_layer_list)

llm_attr = LLMGradientAttribution_Features(fa, tokenizer)


The newly created `llm_attr` works like the wrapped attribution method instance: it provides an `.attribute()` function that takes the model inputs and returns the attribution scores of the features of interest within those inputs. However, by default, Captum's attribution algorithms assume every input to the model is a PyTorch tensor and perturb it at the tensor level. This is likely not what we want for LLMs, where we are more interested in interpretable text input and in text modifications such as removing a text segment than in a less meaningful tensor of token indices. To solve this, we introduce a new adapter design called `InterpretableInput`, which handles the conversion between a more interpretable input type and tensors, and tells Captum how to work with them. `llm_attr` is made to accept certain text-based `InterpretableInput` instances as arguments. The concept of "Interpretable Input" mainly comes from the following two papers:
- [“Why Should I Trust You?”: Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938)
- [A Unified Approach to Interpreting Model Predictions](https://arxiv.org/abs/1705.07874)

The question now is which interpretable features we want to understand in text. The most common and straightforward answer is "tokens", and we provide `TextTokenInput` specifically for such use cases. `TextTokenInput` is an `InterpretableInput` for text whose interpretable features are the tokens with respect to a given tokenizer. So let's create one and calculate its attribution w.r.t. the previously generated output as the target:

In [11]:
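# skip the special token for the start of the text
# (assumption: later cells reference skip_tokens, so define it from the tokenizer's BOS id)
skip_tokens = [tokenizer.bos_token_id]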
inp = TextTokenInput(
    eval_prompt, 
    tokenizer
)

target = "100" #"playing guitar, hiking, and spending time with his family."

attr_res = llm_attr.attribute(inp, target=target)

In [12]:
inp

<captum.attr._utils.interpretable_input.TextTokenInput at 0x1a780d43210>

In [13]:
attr_res[0][0].shape

torch.Size([1, 17, 2048])

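With just a few lines of code, we now get the attribution result of our LLM. With the stock Captum wrappers, the returned result contains attribution tensors for both the entire generated target sequence and each generated token, which tell us how each input token impacts the output as a whole and each token within it. The custom `LLMGradientAttribution_Features` class used here returns a plain list instead, which is why the next cell raises an `AttributeError`.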

In [14]:
print("attr to the output sequence:", attr_res.seq_attr.shape)  # shape(n_input_token)
print("attr to the output tokens:", attr_res.token_attr.shape)  # shape(n_output_token, n_input_token)

AttributeError: 'list' object has no attribute 'seq_attr'

It also provides utilities to visualize the results. Next, we will plot the token attribution to view the relations between input and output tokens. As we will see, the result is generally very positive. This is expected, since the target, "playing guitar, hiking, and spending time with his family", is what the model feels confident generating by itself given the input tokens. So a change in the input is more likely to divert the model from this target.

In [None]:
attr_res.plot_token_attr(show=True)

However, it may not always make sense to define individual tokens as interpretable features and perturb them. Tokenizers used in modern LLMs may break up a single word, making the tokens uninterpretable by themselves. For example, in our case above, the tokenizer can break the word "Palm" into "_Pal" and "m". It doesn't make much sense to study their attributions separately. Moreover, even a whole word can be meaningless on its own. For example, "Palm Coast" together forms a city name. Changing only some of its tokens would likely not produce anything that belongs to the natural distribution of potential cities in Florida, which may lead to unexpected impacts on the perturbed model output.

Therefore, Captum offers another more customizable interpretable input class, `TextTemplateInput`, whose interpretable features are certain segments (e.g., words, phrases) of the text defined by the users. For instance, our prompt above contains information about name, city, state, occupation, and pronoun. Let's define them as the interpretable features to get their attribution. 

The target to interpret can be any potential generations that we are interested in. Next, we will customize the target to something else.

In [None]:
inp = TextTemplateInput(
    template="{} lives in {}, {} and is a {}. {} personal interests include", 
    values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
)
 

attr_res = llm_attr.attribute(inp, target=target, skip_tokens=skip_tokens)

attr_res.plot_token_attr(show=True)

We know that perturbation-based algorithms calculate the attribution by switching the features between "presence" and "absence" states. So what should a text feature look like when it is "absent" in the above example? Captum allows users to set the baselines, i.e., the reference values to use when a feature is absent. By default, `TextTemplateInput` uses the empty string `''` as the baseline for all features, which is equivalent to removing the segments. This may not be ideal, for the same out-of-distribution reason. For example, when the feature "name" is absent, the prompt loses its subject and no longer makes much sense.

To improve this, let's manually set the baselines to values that still fit the context of the original text and keep it within the natural data distribution.
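The difference between the two baseline choices can be seen with plain string formatting. The sketch below is only an illustration and does not use Captum: `render` is a hypothetical helper that fills the template with the original value when a feature is present and with its baseline when it is absent.

```python
from typing import List, Sequence

TEMPLATE = "{} lives in {}, {} and is a {}. {} personal interests include"
VALUES = ["Dave", "Palm Coast", "FL", "lawyer", "His"]
BASELINES = ["Sarah", "Seattle", "WA", "doctor", "Her"]


def render(present: Sequence[bool], values: List[str], baselines: List[str]) -> str:
    """Fill the template, using the baseline wherever a feature is 'absent'."""
    chosen = [v if keep else b for v, b, keep in zip(values, baselines, present)]
    return TEMPLATE.format(*chosen)


# Ablate only the "name" feature (index 0); all other features stay present.
present = [False, True, True, True, True]
empty_baseline_prompt = render(present, VALUES, [""] * len(VALUES))
custom_baseline_prompt = render(present, VALUES, BASELINES)

print(repr(empty_baseline_prompt))
print(repr(custom_baseline_prompt))
```

With the empty-string baseline the perturbed prompt starts with " lives in Palm Coast", which has no subject, while the custom baseline yields a well-formed sentence about a different person.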

In [None]:
inp = TextTemplateInput(
    template="{} lives in {}, {} and is a {}. {} personal interests include", 
    values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
    baselines=["Sarah", "Seattle", "WA", "doctor", "Her"],
)

attr_res = llm_attr.attribute(inp, target=target, skip_tokens=skip_tokens)

attr_res.plot_token_attr(show=True)

The result represents how the features impact the output compared with the single baseline. This can be a useful setup that leads to some interesting findings. For example, the city name "Palm Coast" is more positive to "playing golf" but negative to "hiking" compared with "Seattle".

More generally, though, we would prefer a distribution of baselines that the attribution method can sample from, for better generality. Here, we can leverage `ProductBaselines` to define a Cartesian product of different baseline values for the various features. We can also specify `num_trials` in `attribute` to average over multiple trials.

Another issue we notice in the above results is that there are correlated aspects of the prompt which should be ablated together to ensure that the input remains in distribution, e.g., "Palm Coast, FL" should be replaced with "Seattle, WA" as a unit. We can accomplish this using a mask as defined below, which groups (city, state) and (name, pronoun). `TextTemplateInput` accepts the argument `mask`, allowing us to set the group indices. To make it more explicit, we can also define the template and its values in dictionary format instead of as a list.
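As a rough illustration of what a Cartesian product of grouped baselines looks like, the sketch below enumerates the combinations with `itertools.product`. This is an assumption-laden toy, not Captum's `ProductBaselines` implementation: `all_baselines` and the shortened baseline lists are hypothetical. The point is that correlated features, such as city and state, are always drawn together as one tuple.

```python
import itertools
import random

# Each key is a group of features that must change together.
group_baselines = {
    ("name", "pronoun"): [("Sarah", "Her"), ("John", "His")],
    ("city", "state"): [("Seattle", "WA"), ("Boston", "MA")],
    ("occupation",): [("doctor",), ("engineer",)],
}


def all_baselines(groups):
    """Enumerate the Cartesian product of the per-group baseline tuples."""
    keys = list(groups)
    combos = []
    for picks in itertools.product(*(groups[k] for k in keys)):
        baseline = {}
        for names, values in zip(keys, picks):
            baseline.update(zip(names, values))  # unpack a group into features
        combos.append(baseline)
    return combos


combos = all_baselines(group_baselines)
print(len(combos))  # 2 * 2 * 2 combinations

# Draw one baseline at random, as a sampling-based method might do per trial.
rng = random.Random(0)
sample = rng.choice(combos)
template = "{name} lives in {city}, {state} and is a {occupation}. {pronoun} personal interests include"
print(template.format(**sample))
```

Because each group is sampled as a unit, no combination pairs "Seattle" with "MA" or "Sarah" with "His", which keeps every perturbed prompt plausible.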

In [None]:
baselines = ProductBaselines(
    {
        ("name", "pronoun"):[("Sarah", "Her"), ("John", "His"), ("Martin", "His"), ("Rachel", "Her")],
        ("city", "state"): [("Seattle", "WA"), ("Boston", "MA")],
        "occupation": ["doctor", "engineer", "teacher", "technician", "plumber"], 
    }
)

inp = TextTemplateInput(
    "{name} lives in {city}, {state} and is a {occupation}. {pronoun} personal interests include", 
    values={"name": "Dave", "city": "Palm Coast", "state": "FL", "occupation": "lawyer", "pronoun": "His"}, 
    baselines=baselines,
    mask={"name": 0, "city": 1, "state": 1, "occupation": 2, "pronoun": 0},
)

attr_res = llm_attr.attribute(inp, target=target, skip_tokens=skip_tokens, num_trials=3)

attr_res.plot_token_attr(show=True)

One potential issue with the current approach is its use of Feature Ablation. If the model learns complex interactions between the prompt features, the true importance may not be reflected in the attribution scores. Consider a case where the model predicts a high probability of playing golf if a person is either a lawyer or lives in Palm Coast. When ablating one feature at a time, the probability may appear unchanged for each feature independently, but may drop substantially when both are perturbed together.

To address this, we can apply alternative perturbation-based attribution methods available in Captum, such as `ShapleyValues`, `ShapleyValueSampling`, `KernelShap`, and `Lime`, which ablate different subgroups of features and may result in more accurate scores.

We will use `ShapleyValues` below because we essentially have only three features after grouping, so the computation is tractable.
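The "lawyer or Palm Coast" scenario can be worked through by hand. The following sketch is a toy under stated assumptions rather than Captum code: `toy_prob` is a hypothetical model with made-up probabilities, and `ablation` and `shapley` are small helper functions written for this illustration. It compares one-at-a-time ablation with exact Shapley values computed over all feature orderings.

```python
from itertools import permutations
from typing import Callable, Dict, FrozenSet, List

FEATURES = ["name", "city", "occupation"]


def toy_prob(present: FrozenSet[str]) -> float:
    """Hypothetical model: 'golf' is likely if city OR occupation is present."""
    return 0.9 if ("city" in present or "occupation" in present) else 0.1


def ablation(features: List[str], value_fn: Callable) -> Dict[str, float]:
    """Score drop when each feature is removed on its own."""
    full = frozenset(features)
    return {f: value_fn(full) - value_fn(full - {f}) for f in features}


def shapley(features: List[str], value_fn: Callable) -> Dict[str, float]:
    """Exact Shapley values: average marginal gain over all orderings."""
    totals = {f: 0.0 for f in features}
    orders = list(permutations(features))
    for order in orders:
        present = set()
        for f in order:
            before = value_fn(frozenset(present))
            present.add(f)
            totals[f] += value_fn(frozenset(present)) - before
    return {f: total / len(orders) for f, total in totals.items()}


abl = ablation(FEATURES, toy_prob)
sv = shapley(FEATURES, toy_prob)
print("ablation:", abl)
print("shapley: ", sv)
```

Ablation assigns zero to every feature because removing either "city" or "occupation" alone leaves the other one to keep the probability high. Shapley values split the total change between those two features and still give "name" no credit. The cost is that exact Shapley values need every ordering, which is only feasible for a handful of features.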

In [None]:
sv = ShapleyValues(model) 

sv_llm_attr = LLMAttribution(sv, tokenizer)

attr_res = sv_llm_attr.attribute(inp, target=target, skip_tokens=skip_tokens, num_trials=3)

attr_res.plot_token_attr(show=True)

Let's now consider a more complex example, where we use the LLM as a few-shot learner to classify sample movie reviews as positive or negative. We want to measure the relative impact of the few-shot examples. Since the prompt changes slightly when no examples are included, we define a prompt function rather than a format string in this case.

In [None]:
def prompt_fn(*examples):
    main_prompt = "Decide if the following movie review enclosed in quotes is Positive or Negative:\n'I really liked the Avengers, it had a captivating plot!'\nReply only Positive or Negative."
    subset = [elem for elem in examples if elem]
    if not subset:
        prompt = main_prompt
    else:
        prefix = "Here are some examples of movie reviews and classification of whether they were Positive or Negative:\n"
        prompt = prefix + " \n".join(subset) + "\n " + main_prompt
    return "[INST] " + prompt + "[/INST]"

input_examples = [
    "'The movie was ok, the actors weren't great' Negative", 
    "'I loved it, it was an amazing story!' Positive",
    "'Total waste of time!!' Negative", 
    "'Won't recommend' Negative",
]
inp = TextTemplateInput(
    prompt_fn, 
    values=input_examples,
)

attr_res = sv_llm_attr.attribute(inp, skip_tokens=skip_tokens)

attr_res.plot_token_attr(show=True)

Interestingly, we can see that all the few-shot examples we chose actually make the model less likely to correctly label the given review as "Positive".

## Gradient-based Attribution
As an alternative to perturbation-based attribution, we can use gradient-based methods to attribute each feature's contribution to a target sequence being generated. For LLMs, the only supported method at present is `LayerIntegratedGradients`. Layer Integrated Gradients is a variant of Integrated Gradients that assigns an importance score to layer inputs or outputs. Integrated Gradients assigns an importance score to each input feature by approximating the integral of the gradients of a function's output with respect to the inputs along the path from given references to the inputs. To instantiate, we can simply wrap our gradient-based attribution method with `LLMGradientAttribution`. Here, we measure the importance of each input token at the embedding layer `model.model.embed_tokens` of the LLM.
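The path integral described above can be approximated numerically. Here is a minimal sketch on a two-dimensional toy function, assuming an analytic gradient and a midpoint Riemann sum; Captum itself relies on autograd and its own approximation settings, so `f`, `grad_f`, and `integrated_gradients` below are illustrative names only.

```python
from typing import Callable, List


def f(x: List[float]) -> float:
    """Toy differentiable function standing in for the model output."""
    return x[0] ** 2 + 3.0 * x[1]


def grad_f(x: List[float]) -> List[float]:
    """Analytic gradient of f."""
    return [2.0 * x[0], 3.0]


def integrated_gradients(
    x: List[float],
    baseline: List[float],
    grad_fn: Callable[[List[float]], List[float]],
    n_steps: int = 50,
) -> List[float]:
    """(x - baseline) times the average gradient along the straight-line path."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grads = grad_fn(point)
        for i, g in enumerate(grads):
            avg_grad[i] += g / n_steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


x = [2.0, 1.0]
baseline = [0.0, 0.0]
ig = integrated_gradients(x, baseline, grad_f)
print("attributions:", ig)
print("sum of attributions:", sum(ig))
print("f(x) - f(baseline):", f(x) - f(baseline))
```

The attributions sum to the difference between the function value at the input and at the baseline, which is the completeness property of Integrated Gradients. In the layer variant, the same computation is applied to the activations of a chosen layer instead of the raw inputs.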

In [None]:
lig = LayerIntegratedGradients(model, model.model.embed_tokens)

llm_attr = LLMGradientAttribution(lig, tokenizer)

Now that we have our LLM attribution object, we can similarly call `.attribute()` to obtain our gradient-based attributions. Right now, `LLMGradientAttribution` can only handle `TextTokenInput` inputs. We can visualize the attribution with respect to both the full output sequence and individual output tokens using the methods `.plot_seq_attr()` and `.plot_token_attr()`, respectively.

In [None]:
inp = TextTokenInput(
    eval_prompt,
    tokenizer,
    skip_tokens=skip_tokens,
)

attr_res = llm_attr.attribute(inp, target=target, skip_tokens=skip_tokens)

attr_res.plot_seq_attr(show=True)

Layer Integrated Gradients estimates that the most important input token in the prediction of the subsequent tokens in the sentence is the word, "lives." We can visualize further token-level attribution at the embedding layer as well.

In [None]:
attr_res.plot_token_attr(show=True)

Keep in mind that the token- and sequence-wise attribution will change layer to layer. We encourage you to explore how this attribution changes with alternative layers in the LLM.