# Exploring the model; sanity checks

In [1]:
from transformer_lens import HookedTransformer

# Load the pretrained 'gelu-1l' checkpoint into a HookedTransformer so that
# intermediate activations can be read through its hook points.
model = HookedTransformer.from_pretrained('gelu-1l')
# Bare expression: display the module tree (lists every available HookPoint).
model

Loaded pretrained model gelu-1l into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [2]:
# Display the model's configuration (architecture sizes, tokenizer, device, ...).
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 2048,
 'd_model': 512,
 'd_vocab': 48262,
 'd_vocab_out': 48262,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.035355339059327376,
 'model_name': 'GELU_1L512W_C4_Code',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 8,
 'n_layers': 1,
 'n_params': 3145728,
 'normalization_type': 'LNPre',
 'original_architecture': 'neel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'NeelNanda/gpt-neox-tokenizer-digits',
 'tokenizer_prepends_bo

In [3]:
# Run one forward pass on a short prompt and keep the activation cache;
# the first return value is discarded because only the cache is inspected below.
_, cache = model.run_with_cache("never gonna give you up")
# Bare expression: display the cache so the recorded hook names are visible.
cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']

In [4]:
# Shape of the 'resid_pre' activation of block 0.
# Presumably (batch, position, d_model); d_model is 512 per model.cfg — confirm.
cache['resid_pre', 0].shape

torch.Size([1, 6, 512])

In [5]:
# Shape of the 'attn_out' activation of block 0 (this is the input later fed to the attn_out SAE).
cache['attn_out', 0].shape

torch.Size([1, 6, 512])

In [6]:
# Shape of the 'resid_mid' activation of block 0.
cache['resid_mid', 0].shape

torch.Size([1, 6, 512])

In [7]:
# Shape of the 'mlp_out' activation of block 0.
cache['mlp_out', 0].shape

torch.Size([1, 6, 512])

In [8]:
# Shape of the 'resid_post' activation of block 0.
cache['resid_post', 0].shape

torch.Size([1, 6, 512])

# Load in the trained SAEs

In [10]:
import sys
sys.path.append('..')
from sae_training.sparse_autoencoder import SparseAutoencoder

# Block-0 hook points for which a trained sparse autoencoder checkpoint is loaded.
SAE_HOOK_NAMES = ('resid_pre', 'attn_out', 'resid_mid', 'mlp_out', 'resid_post')

# Map each hook name to its SAE, loaded from ./checkpoints32 and switched to eval mode.
# NOTE(review): the trailing 16384 in the filename is presumably the SAE dictionary size — confirm.
saes = {}
for hook_name in SAE_HOOK_NAMES:
    checkpoint_path = f'./checkpoints32/final_sparse_autoencoder_gelu-1l_blocks.0.hook_{hook_name}_16384.pt'
    saes[hook_name] = SparseAutoencoder.load_from_pretrained(checkpoint_path)
    saes[hook_name].eval()

# Bare expression: display the loaded autoencoders.
saes

ImportError: cannot import name 'compute_geometric_median' from 'sae_training.geom_median.src.geom_median.torch' (/Users/tz20913/Library/Mobile Documents/com~apple~CloudDocs/Desktop/Research/MATS_Neel/SAE_layer_computation/research/../sae_training/geom_median/src/geom_median/torch/__init__.py)

In [None]:
# Run the 'attn_out' SAE on block 0's cached attention output. The call returns
# six values; only the second one (named feature_acts here) is kept.
# NOTE(review): the meaning of the five discarded values depends on
# SparseAutoencoder.forward in sae_training — confirm there.
_, feature_acts, _, _, _, _ = saes['attn_out'](cache['attn_out', 0])
# Index tuples of every non-zero entry of feature_acts; given the input shape,
# each row is presumably (batch, position, feature index).
(feature_acts != 0.0).nonzero()

tensor([[    0,     0,  6883],
        [    0,     1,  5973],
        [    0,     1,  6883],
        [    0,     1,  7931],
        [    0,     1,  8569],
        [    0,     1, 10315],
        [    0,     1, 11417],
        [    0,     1, 12781],
        [    0,     2,  5027],
        [    0,     2,  6825],
        [    0,     2,  6883],
        [    0,     2,  7454],
        [    0,     2,  7694],
        [    0,     2,  7931],
        [    0,     2,  8219],
        [    0,     2,  9596],
        [    0,     2, 10017],
        [    0,     2, 10315],
        [    0,     2, 11074],
        [    0,     2, 12781],
        [    0,     2, 12792],
        [    0,     2, 13294],
        [    0,     2, 14466],
        [    0,     3,  3858],
        [    0,     3,  5027],
        [    0,     3,  5847],
        [    0,     3,  6883],
        [    0,     3,  7694],
        [    0,     3,  7931],
        [    0,     3,  9596],
        [    0,     3, 10315],
        [    0,     3, 11074],
        

In [None]:
import torch
import plotly.express as px

def logits_vis(logits, top_k=10):
    """
    Visualizes the top predictions from the logits of a model as a bar chart.

    Only one prediction is plotted: every leading dimension (e.g. batch and
    sequence position) is reduced by taking its last entry, so a
    [batch, pos, d_vocab] tensor yields the last position of the last batch
    element. The tensor passed in by the caller is never modified.

    Relies on the notebook-level `model` (to turn token ids into strings) and
    on `px` (plotly.express) for plotting.

    Parameters:
    - logits: A PyTorch tensor (or array-like) of logits from the model's
      output; the last dimension indexes the vocabulary.
    - top_k: How many of the most likely tokens to show (default 10). Clamped
      to the vocabulary size when the vocabulary is smaller.

    Raises:
    - ValueError: if `top_k` is smaller than 1 or `logits` has no dimensions.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Ensure logits are converted to a PyTorch tensor if not already
    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)
    if logits.ndim == 0:
        raise ValueError("logits must have at least one (vocabulary) dimension")

    # Work on a detached alias and peel leading dimensions by indexing instead
    # of an in-place squeeze_(): the caller's tensor keeps its shape, and a
    # batch larger than one can no longer leave a 2-D tensor behind that would
    # be soft-maxed over positions instead of over the vocabulary.
    last_logits = logits.detach()
    while last_logits.ndim > 1:
        last_logits = last_logits[-1]
    if not last_logits.is_floating_point():
        # softmax is not defined for integer tensors
        last_logits = last_logits.float()

    # Compute softmax over the vocabulary to get probabilities
    probabilities = torch.softmax(last_logits, dim=-1)

    # Sort and select the top predictions (never more than the vocabulary holds)
    k = min(top_k, probabilities.numel())
    probs, idxs = torch.topk(probabilities, k)
    labels = [f"'{model.to_str_tokens(i)[0]}'" for i in idxs]  # Convert indices to labels

    # Probability mass of everything outside the top k; clamped because float
    # rounding can push the difference marginally below zero.
    other_prob = max(0.0, 1 - probs.sum().item())

    # Append "Other" category to the data
    labels.append("Other")
    probs = probs.tolist() + [other_prob]

    # Plot with Plotly Express
    fig = px.bar(x=labels, y=probs, labels={'x': 'Predictions', 'y': 'Probability'},
                 title=f"Top {k} Predictions of the Language Model")
    fig.show()

# Sanity check: plot the model's top predictions for what follows this prompt.
logits_vis(model('never gonna give you up,'))

In [None]:
def test_prompt(prompt):
    """
    Show how `prompt` is tokenized, then plot the model's top predictions for it.

    Uses the notebook-level `model` for tokenization and the forward pass, and
    `logits_vis` for the plot.

    Parameters:
    - prompt: The text to feed to the model.
    """
    str_tokens = model.to_str_tokens(prompt)
    print("Tokenization:", str_tokens)
    prompt_logits = model(prompt)
    logits_vis(prompt_logits)

# C++-style prompt: an #include line followed by a bare '#' on the next line,
# so the plot shows what the model expects after the second '#'.
test_in = """
#include<iostream>
#"""
test_prompt(test_in)

Tokenization: ['<|BOS|>', '\n', '#', 'include', '<', 'i', 'ostream', '>', '\n', '#']


In [None]:
# JavaScript-style prompt ending in a comment that starts quoting 'John.
# NOTE(review): presumably probes whether the model copies ' Smith' from the
# assignment earlier in the prompt — confirm intent.
test_in = """
let userName = 'John Smith';
console.log(userName);
// 'John"""
test_prompt(test_in)

Tokenization: ['<|BOS|>', '\n', 'let', ' user', 'Name', ' =', " '", 'John', ' Smith', "';", '\n', 'console', '.', 'log', '(', 'user', 'Name', ');', '\n', '//', " '", 'John']


In [None]:
# CSS-style prompt: a complete .button rule followed by a second '.button'
# selector, so the plot shows what the model expects after the selector.
test_in = """
.button {
    color: red;
}
.button"""
test_prompt(test_in)

Tokenization: ['<|BOS|>', '\n', '.', 'button', ' {', '\n   ', ' color', ':', ' red', ';', '\n', '}', '\n', '.', 'button']


In [None]:
# SQL-style prompt: one complete SELECT statement ending in a newline, so the
# plot shows what the model expects at the start of the next line.
test_in = """
SELECT name FROM users WHERE age > 18;
"""
test_prompt(test_in)

Tokenization: ['<|BOS|>', '\n', 'SELECT', ' name', ' FROM', ' users', ' WHERE', ' age', ' >', ' 1', '8', ';', '\n']
