# Attribution Case Study: Gemma - 2B

In this notebook we show how feature attribution (as calculated by our open source library!) can be used to understand model behaviour and help hypothesise reasons behind undesirable outputs. We showcase a case study using the Gemma-2B model and our gradient-based perturbation method described here. 

Since Gemma-2B is an open-source model, we will be downloading a local copy of it before calculating attribution values. To access Gemma-2B through huggingface, a huggingface account with approval is necessary. Please click [here](https://huggingface.co/google/gemma-2b) to go through the approval process. After approval, the huggingface account needs to be logged into using the CLI command `huggingface-cli login`.

We will use our `LocalLLMAttributor` function that employs a gradient-based perturbation technique to calculate attribution values.

## Importing libraries

In [2]:
import os
import sys
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append(str(Path(os.getcwd()).parent))
from attribution.local_attribution import LocalLLMAttributor

## Loading Gemma-2 model, tokenizer and embeddings

Make sure you've logged in using the `huggingface-cli login` command and have access to the Gemma-2B model hosted [here](https://huggingface.co/google/gemma-2b)

In [3]:
model_id = "google/gemma-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto").cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)
embeddings = model.get_input_embeddings().weight.detach()

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



## Scenario 1: Apples to Apples

We ask Gemma-2B a simple question: 'Sam has 3 apples and Tom has 5 oranges. How many apples does Sam have?'. Gemma-2B correctly answers this question

In [4]:
input_string = "Sam has 3 apples and Tom has 5 oranges. How many apples does Sam have?"
input_tokens = tokenizer(input_string, return_tensors="pt").input_ids.to(model.device)
output_tokens = model.generate(input_tokens, max_new_tokens=7).squeeze(0)
print(f"Output: {tokenizer.decode(output_tokens)}")

Output: <bos>Sam has 3 apples and Tom has 5 oranges. How many apples does Sam have?

The answer is 3.


## Scenario 2: Apples to Oranges?

We now slightly change the question to be: 'Sam has 3 apples and Tom has 5 oranges. How many oranges does Sam have?'. Gemma-2B gets this question wrong this time around, choosing to answer 5 oranges even though these oranges belong to Tom! 

In [5]:
input_string = "Sam has 3 apples and Tom has 5 oranges. How many oranges does Sam have?"
input_tokens = tokenizer(input_string, return_tensors="pt").input_ids.to(model.device)
output_tokens = model.generate(input_tokens, max_new_tokens=7).squeeze(0)
print(f"Output: {tokenizer.decode(output_tokens)}")

Output: <bos>Sam has 3 apples and Tom has 5 oranges. How many oranges does Sam have?

The answer is 5.


## Observing Attribution values

Could we have predicted this behaviour if we observed the attribution values for Scenario 1? Let's take a look!

In [22]:
attributor = LocalLLMAttributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
attr_scores, token_ids = attributor.compute_attributions(
    input_string="Sam has 3 apples and John has 5 oranges. How many apples does Sam have?",
    generation_length=7,
)

attributor.print_attributions(
    word_list=tokenizer.convert_ids_to_tokens(token_ids),
    attr_scores=attr_scores,
    token_ids=token_ids,
    generation_length=7,
)

Looking at the attribution table, we see some sensible values associated with the output token `3` which is the correct answer:

|Input Token        | Output Token | Attribution |
|-------------------|--------------|-------------|
|`3`                | `3`          |  59.1844    |
|`apples` (First)   | `3`          |  48.0342    |
|`apples` (Second)  | `3`          |  51.8812    |
|`5`                | `3`          |  18.8697    |

There are high attributions for the input tokens that are strongly associated with the correct answer and low attribution for the input token `5` which would be the incorrect answer. There is something strange about some attribution values in this table for the output token `3`, notably the one associated with the`oranges` and both `Sam` input tokens.

|Input Token        | Output Token | Attribution |
|-------------------|--------------|-------------|
|`oranges`          | `3`          |  41.9764    |
|`Sam` (First)      | `3`          |  24.2310    |
|`Sam` (Second)     | `3`          |  26.5090    |

This value for the `oranges` token is surprisingly close to the input tokens associated with the correct answer input tokens and also quite a bit larger than the values associated with the incorrect answer input tokens. The values for both the `Sam` tokens however are surprisingly low and ideally should be higher since the model answers the question in Scenario 1 correctly.

Using the above analysis, we can come up with two hypotheses about model behaviour that could be potential causes for the failure in Scenario 2:

__Hypothesis 1__: Gemma-2B pays undue attention on the `oranges` tokens when asked about `apples` in Scenario 1 pointing to some confusion around which fruit is being asked about in the question.

__Hypothesis 2__: Gemma-2B doesn't associate `apples` strongly with `Sam` in Scenario 1 which causes the model to output the incorrect answer in Scenario 2

## Conclusion

In this notebook, we show a case study where attribution helps us generate hypotheses to understand model failures. There is much work to be done in this space and we encourage readers to further inspect more model failures and validate the usefulness of attribution. We welcome proposals for features and communication regarding experiments with our library to be sent to `hello@leap-labs.com`