In [1]:
# Import Transformer Lens, and load pythia models
from pathlib import Path

import numpy as np
import torch as th
from datasets import Dataset, load_dataset
from einops import rearrange
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer

from neuron_text_simplifier import NeuronTextSimplifier
# Prefer the second GPU (the original configuration), but only when it exists:
# "cuda:1" is an invalid device ordinal on a single-GPU machine, which made model
# loading crash even though a GPU was available. Fall back to cuda:0, then CPU.
if th.cuda.device_count() > 1:
    device = "cuda:1"
elif th.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Checkpoint(s) to analyse; only the first entry is used. Swap in
# "EleutherAI/pythia-160m-deduped" to study the larger model.
MODEL_NAME_LIST = ["EleutherAI/pythia-70m-deduped"]
model_name = MODEL_NAME_LIST[0]
# Filesystem-safe form of the model name ("/" -> "-"), used in cache file names.
model_save_name = "-".join(model_name.split("/"))
# Index of the transformer block whose MLP neurons are analysed.
layer = 1

model = HookedTransformer.from_pretrained(model_name, device=device)

# Number of leading tokens kept per document.
Token_amount = 20

# pile-10k training split, processed in three stages:
#   1. tokenize the raw text with the model's own tokenizer (batched for speed),
#   2. drop documents with Token_amount tokens or fewer,
#   3. truncate the survivors to exactly Token_amount tokens, so every row has
#      the same length and can be stacked into one tensor later.
pile_raw = load_dataset("NeelNanda/pile-10k", split="train")
pile_tokenized = pile_raw.map(
    lambda rows: model.tokenizer(rows["text"]),
    batched=True,
)
pile_long = pile_tokenized.filter(
    lambda row: len(row["input_ids"]) > Token_amount
)
d = pile_long.map(
    lambda row: {"input_ids": row["input_ids"][:Token_amount]}
)

# Column count of the activation matrix; presumably d_mlp (last axis of W_in).
neurons = model.W_in.shape[-1]
datapoints = d.num_rows  # documents remaining after the length filter
batch_size = 64

# Per-token MLP post-activations at `layer`, cached on disk so a re-run can skip
# the forward passes. Row r = doc * Token_amount + pos (documents in dataset
# order, the DataLoader does not shuffle); one column per neuron.
activations_path = Path("Data") / f"{model_save_name}_activations_layer_{layer}.pt"
hook_name = f"blocks.{layer}.mlp.hook_post"

if activations_path.exists():
    # A corrupt or incompatible cache file now raises instead of being silently
    # swallowed by a bare `except:` (which also caught KeyboardInterrupt).
    neuron_activations = th.load(activations_path)
    print("Loaded activations from file")
else:
    # Only allocate the (datapoints * Token_amount, neurons) buffer when it will be
    # filled; it was previously allocated even when loading from the cache.
    neuron_activations = th.zeros((datapoints * Token_amount, neurons))
    rows_per_batch = batch_size * Token_amount
    with th.no_grad(), d.formatted_as("pt"):
        dl = DataLoader(d["input_ids"], batch_size=batch_size)
        for batch_idx, batch in enumerate(tqdm(dl)):
            # Cache only the hook we read instead of every activation in the model.
            _, cache = model.run_with_cache(batch.to(device), names_filter=hook_name)
            flat = rearrange(cache[hook_name], "b s n -> (b s) n").cpu()
            start = batch_idx * rows_per_batch
            # flat.shape[0] is smaller than rows_per_batch for the final partial batch.
            neuron_activations[start:start + flat.shape[0], :] = flat
    # th.save does not create missing parent directories.
    activations_path.parent.mkdir(parents=True, exist_ok=True)
    th.save(neuron_activations, activations_path)

  from .autonotebook import tqdm as notebook_tqdm
Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-c3dfa1eec06aadb9.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0251a9ae73adda5c.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-1d0a8dd6aeeac743.arrow


Loaded activations from file


In [2]:
# Pearson correlation between every pair of neurons across all (document, token)
# rows. th.corrcoef treats rows as variables, hence the transpose, giving a
# (neurons, neurons) matrix. A neuron with zero variance over the data yields NaN
# in its row/column; those NaNs would contaminate later sort/min analyses, so
# they are mapped to 0 (i.e. "uncorrelated").
corr_coef = th.corrcoef(neuron_activations.T).nan_to_num(nan=0.0)

In [3]:
# Number of entries in the neuron-by-neuron correlation matrix. Derived from the
# configured neuron count rather than a hard-coded 2048, so it stays correct for
# other models/layers.
print(neurons ** 2)

4194304


In [3]:
# Reference neuron to study. Rank every neuron by its correlation with it,
# highest first: `v` holds the sorted correlations, `i` the matching neuron indices.
neuron = 1306
ranked = th.sort(corr_coef[neuron], descending=True)
v, i = ranked.values, ranked.indices

In [5]:
# For each of the reference neuron's top_k most-correlated neurons, print the
# indices of *that* neuron's own top_k most-correlated neurons. Heavy overlap
# between the printed rows indicates a tightly correlated cluster.
# The cell previously printed the correlation values, which did not match its
# saved output (neuron indices) and could not be reproduced on re-run.
top_k = 10
for j in range(top_k):
    neuron_ind = i[j].item()
    print(corr_coef[neuron_ind].topk(top_k).indices)

tensor([1306,  924,  697,  503, 1668,  852,  859, 1991,  529, 1646])
tensor([ 924, 1306,  697,  529,  503,  852, 1668,  859,  255, 1991])
tensor([ 697,  924, 1306,  529,  255, 1523,  777,  859,  503, 1449])
tensor([ 503,  859, 1306,  924, 1668,  852, 1991, 1646,  987,   35])
tensor([1668, 1306,  503,  852, 1646,  859,  924, 1200, 1991, 1896])
tensor([ 852, 1991, 1306,  503,  924, 1646, 1668,  859, 1896,  935])
tensor([ 859,  503, 1306, 1668,  924, 1173, 2019,  852, 1177, 1991])
tensor([1991,  852, 1306, 1668,  503,  924, 1646,  859,  935, 1896])
tensor([ 529,  924,  697, 1306,  255, 1449, 1523,  859,  503, 1318])
tensor([1646,  852, 1668,  935, 1991, 1896,  503, 1306,  859,  924])


In [20]:
# `t` previously came from a since-deleted cell, so this cell raised NameError on
# a fresh kernel. The saved output lists at least 52 values, so the original t
# was >= 52. NOTE(review): confirm the intended size of the top-t group.
t = 100

# For each of the reference neuron's t most-correlated neurons, print the minimum
# correlation it has with any member of that top-t group. This is a lower bound
# on how tightly the group hangs together.
top_t = i[:t]
min_corr_within_top = corr_coef[top_t][:, top_t].min(dim=1).values
for min_corr in min_corr_within_top.tolist():
    print(min_corr)

0.362860769033432
0.3111647963523865
0.18492959439754486
0.2608509659767151
0.25849056243896484
0.27204516530036926
0.26944267749786377
0.2639138698577881
0.18532848358154297
0.25123652815818787
0.15140153467655182
0.24702365696430206
0.21995921432971954
0.22950829565525055
0.22643877565860748
0.26466572284698486
0.22175940871238708
0.21319076418876648
0.21186938881874084
0.23599626123905182
0.19811241328716278
0.24938030540943146
0.2179834544658661
0.21676169335842133
0.21654261648654938
0.1911548674106598
0.12188144773244858
0.1997930407524109
0.18377770483493805
0.2042209953069687
0.2179475575685501
0.2298586368560791
0.22150929272174835
0.18167485296726227
0.19090989232063293
0.2189105898141861
0.2413829118013382
0.10247824341058731
0.1726890206336975
0.19885478913784027
0.19096173346042633
0.14983387291431427
0.1877770721912384
0.19545458257198334
0.19559188187122345
0.1390303522348404
0.141532301902771
0.16652655601501465
0.17162717878818512
0.1628270149230957
0.133186474442482
0