# GPT-2 interpretability

<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G/blob/main/days/w1d6/logit_lens.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Logit lens

Read: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens

Then try to reimplement it in a minimal way.


Resources:
- Read about hooks here: https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks
- Install `transformer-utils` (`pip install transformer-utils`) and use its `_plot_logit_lens` function: https://github.dev/nostalgebraist/transformer-utils/tree/main/src/transformer_utils/logit_lens


In [None]:
# !pip install transformers transformer_utils

In [None]:
import torch
import transformers

# Minimal example: load the pretrained GPT-2 tokenizer and language model.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
# Put the model in evaluation mode before running any forward pass.
model.eval()

def text_to_input_ids(text):
    """Tokenize ``text`` and return its token ids with a leading batch axis.

    The ids produced by the module-level ``tokenizer`` are wrapped in a
    tensor and given a first dimension of size 1, i.e. a batch holding a
    single sequence.
    """
    token_ids = tokenizer.encode(text)
    # encode() yields a flat list of ids, so adding axis 0 gives shape (1, n_tokens).
    return torch.as_tensor(token_ids).unsqueeze(0)


# Tokenize the example prompt; the result has a leading batch axis of size 1.
input_ids = text_to_input_ids("Happy birthday to You, happy")

In [None]:
# One slot per transformer block: `outputs[i]` receives the hidden states
# produced by block i, `handles[i]` the handle of the hook registered on it.
n_layers_gpt = len(model.base_model.h)
outputs = [None] * n_layers_gpt
handles = [None] * n_layers_gpt

input_ids = text_to_input_ids("Happy birthday to You, happy")


def make_memorize_output_layer(layer):
    """Build a forward hook that stores a block's output in ``outputs[layer]``.

    A factory is used so that every hook captures its own ``layer`` index
    instead of sharing the loop variable.
    """
    def memorize_output_layer(self, input, output):
        # The original code reads `output[0]`, i.e. it expects the block to
        # return a tuple whose first element is the hidden states. A bare
        # tensor is accepted as well so the batch axis is never indexed away.
        hidden_states = output[0] if isinstance(output, tuple) else output
        # Item assignment mutates the module-level list; no `global` needed.
        outputs[layer] = hidden_states.detach()
    return memorize_output_layer


for i, gpt_block in enumerate(model.base_model.h):
    handles[i] = gpt_block.register_forward_hook(make_memorize_output_layer(i))

try:
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        model(input_ids)
finally:
    # Always unregister the hooks; otherwise re-running this cell stacks a
    # new set of hooks on top of the old ones.
    for handle in handles:
        if handle is not None:
            handle.remove()

In [None]:
# Decode: project every block's hidden states onto the vocabulary.
ln_f = model.base_model.ln_f
# Loop-invariant, so looked up once: the token embedding matrix is reused
# here to map hidden states back to vocabulary scores.
word_embeddings = model.base_model.wte.weight.detach()

layer_preds = []
layer_logits = []
layer_probs = []

# Inference only: without no_grad, ln_f's parameters would attach an autograd
# graph to every logits/probs tensor kept in the lists below.
with torch.no_grad():
    for output in outputs:
        # Apply the model's final layer norm before un-embedding.
        normalized_output = ln_f(output)

        # (batch, token, embed) x (word, embed) -> (batch, token, word)
        word_distribution = torch.einsum(
            "bte,we->btw", normalized_output, word_embeddings
        )
        # Highest-scoring vocabulary entry at every position.
        best_word = torch.argmax(word_distribution, dim=2)
        output_text = tokenizer.decode(best_word[0])
        print(output_text)

        layer_preds.append(best_word)
        layer_logits.append(word_distribution)
        layer_probs.append(torch.softmax(word_distribution, dim=2))


In [None]:
from typing import List

def to_tensor(list_tensor : List[torch.Tensor]) -> torch.Tensor:
    """Detach every tensor in ``list_tensor`` and join them along dim 0.

    Args:
        list_tensor: Tensors whose shapes agree on every axis except the first.

    Returns:
        A single tensor, cut off from the autograd graph, whose first axis
        is the concatenation of the inputs' first axes.
    """
    detached = []
    for tensor in list_tensor:
        detached.append(tensor.detach())
    return torch.cat(detached, dim=0)


In [None]:
from transformer_utils.logit_lens.plotting import _plot_logit_lens
# Join the per-layer results along the first axis (one entry per block,
# since every per-layer tensor carries a batch axis of size 1).
stacked_logits = to_tensor(layer_logits)
stacked_preds = to_tensor(layer_preds)
stacked_probs = to_tensor(layer_probs)

# Plot the raw logits: the probs / ranks / kl display modes are all off.
_plot_logit_lens(
    stacked_logits,
    stacked_preds,
    stacked_probs,
    tokenizer,
    input_ids=input_ids,
    start_ix=0,
    layer_names=list(range(n_layers_gpt)),
    probs=False,
    ranks=False,
    kl=False,
    top_down=False,
)