In [2]:
import torch
from transformers import AutoModelForCausalLM,AutoTokenizer
from torch import nn
import torch.nn.functional as F
from IPython.display import clear_output
from datasets import load_dataset
from tqdm import tqdm

In [3]:
# Checkpoint under inspection: loaded with float16 weights and placed on the GPU.
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
# Kept as the last expression so the cell also echoes the module architecture.
model.to(device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  

**From Claude**: use `named_parameters()`, a standard `torch.nn.Module` method that yields `(name, parameter)` pairs for every learnable weight in the model.

In [4]:
# Gather every parameter whose qualified name contains both "layer" and
# "self_attn", keyed by that name. Note: despite the variable name, these are
# the model's static weights, not input-dependent activations.
activations = {
    param_name: weight.data
    for param_name, weight in model.named_parameters()
    if "layer" in param_name and "self_attn" in param_name
}

In [5]:
# How many tensors were collected, and the name of the first one.
len(activations), list(activations)[0]

(128, 'model.layers.0.self_attn.q_proj.weight')

Okay, I can easily find all the parametric weights of the model. But perhaps what I want are the hidden states. Think about a regular NN model: the weights exist before any input is evaluated, whereas the hidden states depend on both the input and all the weights.

So I need to input a text into the LLM, and it'll pass through and get affected by the LLM's weights. And there should be a residual stream hidden state $h_i^{(l)}$, and attention hidden states $a_i^{(l)}$ for every token $i$ and layer $l$.

Let's try to extract hidden states for a simple model like this:

In [6]:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Toy network: two 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)

    def forward(self, x):
        """Return only the final activation; intermediate results are discarded."""
        first = self.conv1(x).relu()
        second = self.conv2(first).relu()
        return second

We need to modify the model slightly to also return hidden states. In RNNs, we usually return these hidden states too during the forward pass, rather than just returning the final output layer.

Note that in LLMs, the output of the forward is typically just the logits or some dictionary containing multiple things. We need to see if that already has the hidden states, else we need to modify it to return some specific hidden states.

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Same two-conv toy network, but ``forward`` also returns the hidden states."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)

    def forward(self, x):
        """Return ``(output, (hidden1, hidden2))`` where ``output`` is ``hidden2``."""
        states = []
        for conv in (self.conv1, self.conv2):
            # Each stage is conv followed by ReLU; keep every post-ReLU result.
            x = conv(x).relu()
            states.append(x)
        return states[-1], tuple(states)

In [8]:
# Build the toy network and feed it one random single-channel 28x28 input.
# The forward pass returns the output together with both hidden states.
model = Model()
input_tensor = torch.randn((1, 1, 28, 28))
output, (hidden1, hidden2) = model(input_tensor)

In [9]:
# Shapes along the pipeline: (c=1,28,28) -> (c=20,24,24) -> (c=20,20,20)
tuple(t.shape for t in (input_tensor, hidden1, hidden2))

(torch.Size([1, 1, 28, 28]),
 torch.Size([1, 20, 24, 24]),
 torch.Size([1, 20, 20, 20]))

In more complex models where we don't want to modify the internal structure, we use hooks.

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Unmodified two-conv network (conv -> ReLU -> conv -> ReLU).

    ``forward`` returns only the final tensor; intermediate activations are
    meant to be captured from the outside via forward hooks.
    """

    def __init__(self):
        super().__init__()
        channels, kernel = 20, 5
        self.conv1 = nn.Conv2d(1, channels, kernel)
        self.conv2 = nn.Conv2d(channels, channels, kernel)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return x

In [21]:
# Fresh network plus two containers: `activations` is written by the forward
# hooks, `all_trials` collects one snapshot of it per forward pass.
model = Model()
activations, all_trials = {}, []

The hook will be called every time after forward() has computed an output. It should have the following signature:

```
hook(module, input, output) -> None or modified output
```

In [22]:
# Define a hook function
def get_hook(name, store=None):
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the captured output is saved.
        store: Optional dict to write into. When omitted, the notebook-level
            ``activations`` dict is used; it is looked up each time the hook
            fires, so re-binding ``activations`` between passes is honoured.

    Returns:
        A callable with the ``hook(module, inputs, output)`` forward-hook
        signature. It returns ``None``, so the module's output is unchanged.
    """
    def _detached(value):
        # Tensors are detached from the autograd graph. Tuples/lists (some
        # modules return several values) are handled element by element and
        # stored as a tuple. Anything else, e.g. None, is stored as-is.
        if isinstance(value, torch.Tensor):
            return value.detach()
        if isinstance(value, (tuple, list)):
            return tuple(_detached(item) for item in value)
        return value

    def hook(module, inputs, output):
        # Parameter names avoid shadowing the builtin `input` and the
        # notebook's `model`; PyTorch passes these arguments positionally.
        target = activations if store is None else store
        target[name] = _detached(output)
    return hook

In [23]:
# Attach a recording hook to each conv layer. Whenever that layer's forward
# runs, the hook fires and stores the layer's output in the `activations` dict.
# The model's own output is unaffected, so a regular forward pass behaves as
# before while the internal hidden states are captured on the side.
for layer_name in ("conv1", "conv2"):
    getattr(model, layer_name).register_forward_hook(get_hook(layer_name))

# First trial: run one random input and keep a snapshot of what was recorded.
input_tensor = torch.randn((1, 1, 28, 28))
output = model(input_tensor)
all_trials.append(dict(activations))
output.shape

torch.Size([1, 20, 20, 20])

In [2]:
# Output of the conv1 module recorded during the first trial. The hook sits on
# conv1 itself, so this is the value before the ReLU applied in forward().
# Requires `all_trials` from the hook cells above in the current kernel session.
all_trials[0]['conv1']

NameError: name 'all_trials' is not defined

In [25]:
# Second trial: a new random input, snapshotted the same way as the first.
input_tensor = torch.randn((1, 1, 28, 28))
output = model(input_tensor)
all_trials.append(dict(activations))
output.shape

torch.Size([1, 20, 20, 20])

In [26]:
# Number of activation snapshots collected so far (one is appended per trial).
len(all_trials)

2

In [29]:
# Output of the conv1 module recorded during the second trial (hooked on conv1
# itself, i.e. before the ReLU applied in forward()).
all_trials[1]['conv1']

tensor([[[[ 1.1241e-01,  2.0024e-02,  3.1108e-01,  ..., -8.6740e-01,
            6.9576e-01,  7.4052e-01],
          [ 3.3542e-01, -1.8380e-01,  4.4050e-01,  ..., -8.3235e-01,
            7.0154e-01,  4.6462e-01],
          [-2.7553e-01,  2.2399e-01,  4.9891e-01,  ..., -2.7639e-01,
            3.8434e-01,  1.5652e-01],
          ...,
          [-5.5446e-03, -6.9823e-01, -1.0601e-01,  ..., -8.9663e-02,
           -7.7638e-01,  2.0911e-02],
          [ 6.3133e-01, -2.2557e-01, -3.1841e-01,  ...,  1.3433e+00,
           -1.5430e-01,  1.0233e-01],
          [ 1.7771e+00,  1.1866e+00, -4.9348e-01,  ...,  7.0751e-01,
           -6.0858e-03, -8.6562e-02]],

         [[ 1.6632e-01, -3.9560e-02, -4.2178e-01,  ..., -5.7374e-01,
            2.9578e-01,  2.4647e-01],
          [-2.6613e-01,  3.2964e-01,  1.9205e-01,  ...,  8.8782e-01,
            3.4784e-01,  1.7670e-01],
          [ 1.8674e-01,  6.2596e-01,  3.0700e-02,  ...,  3.8379e-01,
           -5.2186e-01, -1.3627e-01],
          ...,
     

It's as expected: 
```
(c=1,28,28) [input] -> (c=20,24,24) -> (c=20,20,20) [output]