In [4]:
import torch
import pandas as pd
from typing import List, Dict
from nnsight import LanguageModel

In [2]:
# Model identifier handed to the loader below as `model_key`.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
# NOTE(review): in the visible cells DTYPE is referenced only by the
# commented-out LanguageModel(...) call, not by the ModelandTokenizer(...)
# load — confirm the model is actually loaded in the intended precision.
DTYPE = torch.float16
# Handed to the loader below as `device_map`.
DEVICE_MAP = "auto"

In [136]:
from src.models import ModelandTokenizer


# Build the project's model/tokenizer wrapper from the config constants.
# The trailing bare `mt` makes the wrapper's repr the cell output.
mt = ModelandTokenizer(
    model_key=MODEL_NAME,
    device_map=DEVICE_MAP
)
mt

meta-llama/Meta-Llama-3-8B not found in models/
If not found in cache, model will be downloaded from HuggingFace to cache directory


config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/177 [00:00<?, ?B/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [153]:
%load_ext autoreload
%autoreload 2

In [160]:
from src.analysis import logit_attribution
from src.plotting import plot_layer_attributions

prompt = "An electric guitar and an acoustic guitar are both types of"
logit_attribution_df, logit_total = logit_attribution(mt, prompt, "guitars")
# A notebook cell only renders its *last* expression automatically, so the
# top-10 table must be displayed explicitly or its result is silently dropped.
display(logit_attribution_df.nlargest(10, 'contribution'))
plot_layer_attributions(logit_attribution_df)

In [141]:
#model = LanguageModel(
#    MODEL_NAME,
#    device_map=DEVICE_MAP,
#    torch_dtype=DTYPE,
#    dispatch=True # stream layers if multiple GPUs, else ignore
#)
#tok = model.tokenizer

def token_id(mt, word: str) -> int:
    """Return the ID of the *first* token that `word` encodes to.

    The word is encoded with `mt.tokenizer.encode(word, add_special_tokens=False)`.
    Only the first token is returned, which is exact for single-token words.

    Args:
        mt: Object exposing a `tokenizer` with an `encode` method.
        word: Text to encode. Leading whitespace is part of the text, so
            "word" and " word" may map to different ids.

    Returns:
        The first token id of the encoded word.

    Raises:
        IndexError: If `word` encodes to no tokens at all (e.g. an empty string).

    Warns:
        UserWarning: If `word` encodes to more than one token, because every
            token after the first is ignored by the caller's analysis.
    """
    import warnings

    ids = mt.tokenizer.encode(word, add_special_tokens=False)
    if len(ids) == 0:
        # Same exception type the bare `[0]` would raise, but with a usable message.
        raise IndexError(f"{word!r} encodes to no tokens; cannot pick a first token id")
    if len(ids) > 1:
        # Make the truncation visible instead of silently analysing a word fragment.
        warnings.warn(
            f"{word!r} encodes to {len(ids)} tokens {list(ids)}; "
            f"only the first ({ids[0]}) is used",
            stacklevel=2,
        )
    return ids[0]

# Spot-check the helper on a sample word; the id is shown as the cell output.
token_id(mt, "king")

10789

In [8]:
# NOTE(review): no live code in the visible cells assigns `model` — only the
# commented-out `model = LanguageModel(...)` block does — so this presumably
# raises NameError on a fresh kernel. `mt` may be the intended name; confirm.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [148]:
from typing import Tuple
from src.models import ModelandTokenizer

@torch.no_grad()
def logit_attribution(
    mt: ModelandTokenizer,
    prompt: str,
    target: str,
    max_layers: int | None = None,
) -> Tuple[pd.DataFrame, float]:
    """
    Run `prompt` once, capture the embedding output plus every attention-head
    and MLP output, and compute each one's direct contribution to the logit of
    the target token at the last sequence position.

    Args:
        mt: Model/tokenizer wrapper exposing `config`, `trace`, `model` and `lm_head`.
        prompt: Text to run through the model.
        target: Word to attribute; its first token id is obtained via `token_id`.
        max_layers: If given, only layers with index < max_layers are captured.
            `None` captures every layer; `0` captures the embedding only.

    Returns:
        A tuple containing:
        - A pandas dataframe with one row per component and columns:
            layer (-1 for the embedding), kind ('embed'|'head'|'mlp'),
            index (head-idx or None), contribution (float),
            pct (contribution as a share of the total logit, formatted as a
            string such as "1.234567%"; "nan%" when the total logit is 0)
        - The total logit of the target token

    NOTE(review): component outputs are projected onto `lm_head.weight`
    directly; the model's final norm (`model.norm`) is not applied to them
    here, so contributions are presumably not on the same scale as the total
    logit and need not sum to it — confirm whether that is intended.
    """
    tgt_id = token_id(mt, target)

    config = mt.config
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads

    saved_proxies = []

    with mt.trace(prompt):
        # save embedding contribution
        embed_proxy = mt.model.embed_tokens.output.save()
        saved_proxies.append({'layer': -1, 'kind': 'embed', 'proxy': embed_proxy})

        # save every attention head & mlp output's contribution
        for layer, block in enumerate(mt.model.layers):
            # Compare against None explicitly so that max_layers=0 means
            # "no layers" rather than being mistaken for "no limit".
            if max_layers is not None and layer >= max_layers:
                break

            # to get per-head attribution, we save the input to o_proj
            o_proj_input_proxy = block.self_attn.o_proj.input.save()
            saved_proxies.append({
                'layer': layer,
                'kind': 'head',
                'proxy': o_proj_input_proxy,
                'block': block
            })

            # save mlp output
            mlp_proxy = block.mlp.output.save()
            saved_proxies.append({
                'layer': layer,
                'kind': 'mlp',
                'proxy': mlp_proxy
            })

        final_logits_proxy = mt.lm_head.output.save()

    # The direct logit attribution of a component is its output vector
    ## projected by the unembedding matrix. Only the target token's logit is
    ## reported, so a dot product with that token's row replaces a
    ## full-vocabulary matmul per component.
    unembed_matrix = mt.lm_head.weight
    target_direction = unembed_matrix[tgt_id]

    final_logits = final_logits_proxy[0, -1, :]
    total_logit = final_logits[tgt_id].item()

    def _pct(value: float) -> str:
        # Share of the total logit; NaN instead of ZeroDivisionError when the total is 0.
        share = (value / total_logit) * 100 if total_logit != 0 else float('nan')
        return f"{share:.6f}%"

    rows = []
    for p in saved_proxies:
        kind = p['kind']

        # We're interested in the last token's logit, so we take the state at
        ## sequence position -1 of the first item in the batch
        ## (assumes the saved value is laid out as (batch, seq, hidden)).
        hidden_state = p['proxy'].value[0, -1, :]

        if kind == 'head':
            # This is the concatenated output of all heads.
            # We need to split it and apply the corresponding part of the
            ## o_proj weight matrix.
            # h_per_head has shape [num_heads, head_dim]
            h_per_head = hidden_state.view(num_heads, head_dim)

            # W_O_heads holds one column-slice of o_proj.weight per head
            W_O = p['block'].self_attn.o_proj.weight
            W_O_heads = W_O.chunk(num_heads, dim=1)

            for i in range(num_heads):
                # This head's slice pushed through its part of o_proj
                head_contribution = h_per_head[i] @ W_O_heads[i].T
                # The component may sit on a different device than lm_head
                head_contribution = head_contribution.to(target_direction.device)
                contribution = (head_contribution @ target_direction).item()
                rows.append({
                    'layer': p['layer'],
                    'kind': 'head',
                    'index': i,
                    'contribution': contribution,
                    'pct': _pct(contribution)
                })

        else: #mlp or embed
            hidden_state = hidden_state.to(target_direction.device)
            contribution = (hidden_state @ target_direction).item()
            rows.append({
                'layer': p['layer'],
                'kind': kind,
                'index': None,
                'contribution': contribution,
                'pct': _pct(contribution)
            })

    df = pd.DataFrame(rows)
    return df, total_logit

In [149]:
# Run the attribution for one prompt/target pair. `attributions` holds the
# returned (DataFrame, total_logit) tuple, so later cells index it with [0].
# NOTE(review): `target` has no leading space and token_id() keeps only its
# first token — confirm this id is the token the model would actually emit
# after "types of".
prompt = "An electric guitar and an acoustic guitar are both types of"
target = "guitars"
attributions = logit_attribution(mt, prompt, target)

In [150]:
# Display the raw (DataFrame, total_logit) tuple returned by logit_attribution.
attributions

(      layer   kind  index  contribution         pct
 0        -1  embed    NaN     -0.004072  -0.043161%
 1         0   head    0.0      0.000595   0.006303%
 2         0   head    1.0      0.002392   0.025354%
 3         0   head    2.0      0.000451   0.004778%
 4         0   head    3.0      0.000890   0.009439%
 ...     ...    ...    ...           ...         ...
 1052     31   head   28.0      0.020070   0.212751%
 1053     31   head   29.0      0.003477   0.036858%
 1054     31   head   30.0     -0.001784  -0.018909%
 1055     31   head   31.0     -0.006523  -0.069150%
 1056     31    mlp    NaN     -0.038374  -0.406777%
 
 [1057 rows x 5 columns],
 9.433680534362793)

In [151]:
import plotly.express as px

# Sum the contributions of every component row that shares a layer number,
# then turn the grouped result back into a flat frame for plotting.
layer_contribs = (
    attributions[0]
    .groupby('layer')['contribution']
    .sum()
    .reset_index()
)

# One bar per layer; the figure is the cell's displayed output.
px.bar(
    data_frame=layer_contribs,
    x='layer',
    y='contribution',
    title='Per-Layer Logit Contribution for "guitars"',
)

In [152]:
# Show top 20 contributing heads/MLPs
attributions[0].nlargest(20, 'contribution')

Unnamed: 0,layer,kind,index,contribution,pct
891,26,mlp,,0.842214,8.927731%
1038,31,head,14.0,0.446686,4.735017%
957,28,mlp,,0.393107,4.167058%
1006,30,head,15.0,0.276085,2.926585%
726,21,mlp,,0.245309,2.600355%
1023,30,mlp,,0.225743,2.392951%
924,27,mlp,,0.147254,1.560938%
1025,31,head,1.0,0.139997,1.484016%
532,16,head,3.0,0.121504,1.287982%
693,20,mlp,,0.115645,1.225869%
