In [1]:
# Load autoreload before the project modules so edits to *_utils.py are picked up live.
%reload_ext autoreload
%autoreload 2

import json
import os
import pickle
import re
from collections import defaultdict
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import sklearn
import torch
from datasets import load_dataset
from IPython.display import display, HTML
from jaxtyping import Float, Int, Bool
from scipy.stats import norm
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from torch import Tensor
from tqdm.auto import tqdm
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer

import haystack_utils
import plotting_utils
from haystack_utils import get_mlp_activations
from hook_utils import get_ablate_neuron_hook
from hook_utils import save_activation
from pythia_160m_utils import get_neuron_accuracy, ablation_effect
from plotting_utils import plot_neuron_acts, color_binned_histogram

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

In [25]:
# Load the list of prompts used throughout the notebook from a local JSON file.
# The encoding is explicit so the result does not depend on the platform default.
PROMPTS_PATH = 'data/tiny_stories_chatgpt.json'
with open(PROMPTS_PATH, 'r', encoding='utf-8') as f:
    prompts = json.load(f)
# Later cells index prompts[0] and slice prompts[:N]; fail fast on an empty or non-list file.
assert isinstance(prompts, list) and len(prompts) > 0, f'expected a non-empty JSON list in {PROMPTS_PATH}'

In [3]:
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda")
import plotting_utils
large_acts_df = plotting_utils.get_neuron_moments(model, prompts,
                                                  [[i, j] for i in range(8) for j in range(256)], hook_pre=True)

# plotting_utils.plot_neuron_acts(model, prompts, [[0, 109], [1, 213], [2, 233], [3, 149], [4, 4], [5, 197], [6, 48], [7, 88],])

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


In [None]:
# Show, token by token, how one MLP neuron fires on the first prompt.
# The hook name is built once, so the hook that saves the activation and the hook we read back cannot drift apart.
INSPECT_LAYER, INSPECT_NEURON = 2, 233
inspect_hook_name = f'blocks.{INSPECT_LAYER}.mlp.hook_post'

with model.hooks([(inspect_hook_name, save_activation)]):
    model(prompts[0])
# NOTE(review): indexing assumes the saved activation is (batch, pos, d_mlp); confirm against hook_utils.save_activation.
acts = model.hook_dict[inspect_hook_name].ctx['activation']
haystack_utils.clean_print_strings_as_html(model.to_str_tokens(prompts[0]), acts[0, :, INSPECT_NEURON], max_value=0.7)

In [6]:
# Plot activation distributions for the two neurons with the highest kurtosis.
layer_neuron_tuples = (
    large_acts_df
    .sort_values('kurtosis', ascending=False)
    .head(2)
    .loc[:, ['layer', 'neuron']]
    .to_numpy()
    .tolist()
)
plotting_utils.plot_neuron_acts(
    model,
    prompts,
    layer_neuron_tuples,
    range_x=[-1, 2],
    hook_pre=True,
)

In [20]:
# Collect the 20 highest-kurtosis neurons as [layer, neuron] pairs.
layer_neuron_tuples = (
    large_acts_df
    .sort_values('kurtosis', ascending=False)
    .head(20)
    .loc[:, ['layer', 'neuron']]
    .to_numpy()
    .tolist()
)

In [21]:
# Display the top-20 [layer, neuron] pairs ranked by activation kurtosis.
layer_neuron_tuples

[[1, 205],
 [1, 168],
 [2, 244],
 [1, 94],
 [0, 60],
 [0, 215],
 [2, 23],
 [2, 113],
 [3, 189],
 [4, 78],
 [3, 202],
 [0, 219],
 [0, 95],
 [4, 22],
 [2, 115],
 [1, 97],
 [0, 98],
 [2, 222],
 [7, 225],
 [4, 11]]

In [16]:
# Per-token view of the 20 highest-kurtosis neurons on the first few prompts.
# NOTE(review): neurons are ranked by moments computed with hook_pre=True, but the
# activations shown here use hook_pre=False; confirm this mix is intended.
N_KURTOSIS_EXAMPLES = 2  # one constant drives both the prompt slice and the display loop

layer_neuron_tuples = large_acts_df.sort_values('kurtosis', ascending=False).iloc[:20][['layer', 'neuron']].values.tolist()
layer_neurons_dict = haystack_utils.get_neurons_by_layer(layer_neuron_tuples)
example_prompts = prompts[:N_KURTOSIS_EXAMPLES]

# Tokenization does not depend on the layer, so compute it once.
str_tokens = [model.to_str_tokens(prompt) for prompt in example_prompts]
for layer, neurons in layer_neurons_dict.items():
    acts = [
        haystack_utils.get_mlp_activations([prompt], layer, model, hook_pre=False, mean=False, context_crop_start=0, context_crop_end=1000)
        for prompt in example_prompts
    ]
    for neuron in neurons:
        print(f'L{layer}N{neuron}')
        # zip instead of range(N): avoids IndexError when fewer prompts than requested are available.
        for prompt_str_tokens, prompt_acts in zip(str_tokens, acts):
            haystack_utils.clean_print_strings_as_html(prompt_str_tokens, prompt_acts[:, neuron], max_value=1)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L1N205


L1N168


L1N94


L1N97


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L2N244


L2N23


L2N113


L2N115


L2N222


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L0N60


L0N215


L0N219


L0N95


L0N98


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L3N189


L3N202


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L4N78


L4N22


L4N11


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L7N225


In [None]:
# Re-display, token by token, how one MLP neuron fires on the first prompt
# (this repeats the earlier L2N233 inspection cell).
# The hook name is built once, so the hook that saves the activation and the hook we read back cannot drift apart.
INSPECT_LAYER, INSPECT_NEURON = 2, 233
inspect_hook_name = f'blocks.{INSPECT_LAYER}.mlp.hook_post'

with model.hooks([(inspect_hook_name, save_activation)]):
    model(prompts[0])
# NOTE(review): indexing assumes the saved activation is (batch, pos, d_mlp); confirm against hook_utils.save_activation.
acts = model.hook_dict[inspect_hook_name].ctx['activation']
haystack_utils.clean_print_strings_as_html(model.to_str_tokens(prompts[0]), acts[0, :, INSPECT_NEURON], max_value=0.7)

### Neurons ranked by activation skew

In [17]:
# Plot activation distributions for the two neurons with the highest skew.
layer_neuron_tuples = (
    large_acts_df
    .sort_values('skew', ascending=False)
    .head(2)
    .loc[:, ['layer', 'neuron']]
    .to_numpy()
    .tolist()
)
plotting_utils.plot_neuron_acts(
    model,
    prompts,
    layer_neuron_tuples,
    range_x=[-1, 2],
    hook_pre=True,
)

In [19]:
# Per-token view of the 20 highest-skew neurons on a few example stories.
N_EXAMPLES = 4

# FIXME: `longer_stories` is not defined anywhere in this notebook, so this cell raised
# NameError on a fresh kernel. Use it if the session has it; otherwise fall back to `prompts`.
# TODO: define `longer_stories` explicitly in an earlier cell.
example_stories = globals().get('longer_stories', prompts)[:N_EXAMPLES]

layer_neuron_tuples = large_acts_df.sort_values('skew', ascending=False).iloc[:20][['layer', 'neuron']].values.tolist()
layer_neurons_dict = haystack_utils.get_neurons_by_layer(layer_neuron_tuples)

# Tokenization does not depend on the layer, so compute it once.
str_tokens = [model.to_str_tokens(story) for story in example_stories]
for layer, neurons in layer_neurons_dict.items():
    acts = [
        haystack_utils.get_mlp_activations([story], layer, model, hook_pre=False, mean=False, context_crop_start=0, context_crop_end=1000)
        for story in example_stories
    ]
    for neuron in neurons:
        print(f'L{layer}N{neuron}')
        # zip instead of range(N_EXAMPLES): avoids IndexError when fewer stories are available.
        for story_str_tokens, story_acts in zip(str_tokens, acts):
            haystack_utils.clean_print_strings_as_html(story_str_tokens, story_acts[:, neuron], max_value=1)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L0N60


L0N58


L0N230


L0N68


L0N172


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L1N97


L1N94


L1N30


L1N40


L1N219


L1N182


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L6N131


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L3N11


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L7N41


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L2N23


L2N17


L2N206


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L4N221


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

L5N194


L5N160
