In [1]:
# Import Transformer Lens, and load pythia models
from transformer_lens import HookedTransformer
import torch as th
from torch import nn
import numpy as np 
from neuron_text_simplifier import NeuronTextSimplifier
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from einops import rearrange
device = "cuda:1" if th.cuda.is_available() else "cpu"

# model_name = "EleutherAI/pythia-160m-deduped"
MODEL_NAME_LIST = [
    "EleutherAI/pythia-70m-deduped", 
]
model_name = MODEL_NAME_LIST[0]
model_save_name = model_name.replace("/", "-")
layer = 1

model = HookedTransformer.from_pretrained(model_name, device=device)

# tokenizer = AutoTokenizer.from_pretrained(model_name)
Token_amount = 20

# Load the training set from pile-10k
d = load_dataset("NeelNanda/pile-10k", split="train").map(
    lambda x: model.tokenizer(x['text']),
    batched=True,
).filter(
    lambda x: len(x['input_ids']) > Token_amount
).map(
    lambda x: {'input_ids': x['input_ids'][:Token_amount]}
)
neurons = model.W_in.shape[-1]
datapoints = d.num_rows
batch_size = 64

neuron_activations = th.zeros((datapoints*Token_amount, neurons))

try:
    neuron_activations = th.load(f"Data/{model_save_name}_activations_layer_{layer}.pt")
    print("Loaded activations from file")
except:
    with th.no_grad(), d.formatted_as("pt"):
        dl = DataLoader(d["input_ids"], batch_size=batch_size)
        for i, batch in enumerate(tqdm(dl)):
            _, cache = model.run_with_cache(batch.to(device))
            neuron_activations[i*batch_size*Token_amount:(i+1)*batch_size*Token_amount,:] = rearrange(cache[f"blocks.{layer}.mlp.hook_post"], "b s n -> (b s) n" )
    th.save(neuron_activations, f"Data/{model_save_name}_activations_layer_{layer}.pt")

  from .autonotebook import tqdm as notebook_tqdm
Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-9b5f9fe88d7bc2e3.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-81ea955e26446615.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-000a30dc29ae5636.arrow


Loaded activations from file


In [32]:
from sklearn.cluster import DBSCAN

neuron = 1306
# First remove all activations that are <, say 0.1
mask = neuron_activations[:, neuron].where(neuron_activations[:, neuron] > 0.4, th.tensor(0.0))

# Then cluster the remaining activations
x = th.stack((neuron_activations[mask != 0, neuron],neuron_activations[mask != 0, 924]), dim=1) 
db = DBSCAN(eps=0.1, min_samples=10).fit(x)
labels = db.labels_
mask1 = (mask != 0)
# Remove the clusters that are SE-docs
mask2 = (labels != 1) & (labels !=2)

def combined_masks_of_diff_shape(mask1, mask2):
    i = 0
    for x in range(len(mask1)):
        if mask1[x]:
            mask1[x] = th.tensor(mask2[i])
            i += 1
    return mask1
final_mask = combined_masks_of_diff_shape(mask1, mask2)


In future, it will be an error for 'np.bool_' scalars to be interpreted as an index



In [33]:
m2 = neuron_activations[final_mask, :].where(neuron_activations[final_mask, :] > 0.4, th.tensor(0.0)) != 0

In [34]:
m2[m2 != 0] = 1

In [38]:
import plotly.express as px
v, i = m2.sum(dim=0).sort(descending=True)
px.scatter(v)
print(i)

tensor([1306,  512,  982,  ..., 1199, 1525, 1264])


In [58]:
def neuron_plot(n1, n2, mask_n1, neuron_activations=neuron_activations, threshold = 0.4):
    m2 = neuron_activations[:, n2].where(neuron_activations[:, n2] > threshold, th.tensor(0.0))
    mask = mask_n1 * m2
    x = neuron_activations[mask != 0,n1]
    y = neuron_activations[mask != 0,n2]
    print(x.shape)
    # counts, xedges, yedges = np.histogram2d(x, y, bins=30)
    bins = 30
    return px.density_heatmap(
        x = x,
        y = y,
        title=f"Correlation between neuron {n1} and neuron {n2} in layer {layer} of {model_name}",
        nbinsx = bins, 
        nbinsy = bins,
        text_auto=True,
        labels=dict(x=f"Neuron {n1}", y=f"Neuron {n2}", color="Density")
    )
def cluster_points(n1,n2, mask, neuron_activations=neuron_activations, threshold = 0.4):
    x = th.stack((neuron_activations[mask, n1],neuron_activations[mask, n2]), dim=1) 
    db = DBSCAN(eps=0.1, min_samples=10).fit(x)
    return db.labels_
n1 = 1306
n2 = 512
l = cluster_points(n1, n2, final_mask)
neuron_plot(n1, n2, final_mask)

torch.Size([908])


In [56]:
clustered_datapoints = th.tensor(l).where(th.tensor(l) == 1, th.tensor(0.0)) != 0

In [59]:
neuron_activations[clustered_datapoints, n2]

IndexError: The shape of the mask [2366] at index 0 does not match the shape of the indexed tensor [199180] at index 0

In [60]:
clustered_datapoints.shape

torch.Size([2366])

In [26]:
mask4 = combined_masks_of_diff_shape(mask1, mask2)
neuron_plot(1306, 445, mask4) 


In future, it will be an error for 'np.bool_' scalars to be interpreted as an index



torch.Size([281])


In [2]:
corr_coef = th.corrcoef(neuron_activations.T)

In [3]:
# Find the neurons that most correlate with neuron
neuron = 1306
v, i = corr_coef[neuron, :].sort(descending=True)

In [4]:
i

tensor([1306,  924,  697,  ..., 1245,   26,   43])

In [105]:
mask2

array([ True,  True,  True,  True,  True, False, False,  True,  True,
       False, False, False, False,  True,  True, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False, False,  True,  True,  True, False, False,
        True, False, False,  True,  True, False, False,  True, False,
       False,  True,  True,  True,  True, False, False,  True,  True,
       False, False,  True,  True,  True,  True, False, False,  True,
       False, False, False, False,  True, False, False,  True,  True,
        True,  True, False, False,  True,  True,  True, False, False,
       False, False,  True,  True,  True,  True, False, False,  True,
        True])

In [111]:
mask1[mask1 != 0].shape

torch.Size([5141])

In [95]:
mask2.shape

(7944,)

In [96]:
mask2[mask2 != 0].shape

(5141,)

In [17]:
import plotly.express as px
def neuron_plot(n1, n2, mask_n1, neuron_activations=neuron_activations, threshold = 0.1):
    m2 = neuron_activations[:, n2].where(neuron_activations[:, n2] > 0.1, th.tensor(0.0))
    mask = mask_n1 * m2
    x = neuron_activations[mask != 0,n1]
    y = neuron_activations[mask != 0,n2]
    print(x.shape)
    # counts, xedges, yedges = np.histogram2d(x, y, bins=30)
    bins = 30
    return px.density_heatmap(
        x = x,
        y = y,
        title=f"Correlation between neuron {n1} and neuron {n2} in layer {layer} of {model_name}",
        nbinsx = bins, 
        nbinsy = bins,
        text_auto=True,
        labels=dict(x=f"Neuron {n1}", y=f"Neuron {n2}", color="Density")
    )
n1 = 1306
neuron_plot(n1, 924, mask1)

NameError: name 'mask1' is not defined

In [53]:
# Now repeat correlation code
corr_coef2 = th.corrcoef(neuron_activations[mask1, :][mask2].T)
v2, i2 = corr_coef2[neuron, :].sort(descending=True)

In [138]:
    def pearsonr(x, y):
        #Copied from https://gist.github.com/ncullen93/58e71c4303b89e420bd8e0b0aa54bf48
        """
        Mimics `scipy.stats.pearsonr`
        Arguments
        ---------
        x : 1D th.Tensor
        y : 1D th.Tensor
        Returns
        -------
        r_val : float
            pearsonr correlation coefficient between x and y
        
        Scipy docs ref:
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html
        
        Scipy code ref:
            https://github.com/scipy/scipy/blob/v0.19.0/scipy/stats/stats.py#L2975-L3033
        Example:
            >>> x = np.random.randn(100)
            >>> y = np.random.randn(100)
            >>> sp_corr = scipy.stats.pearsonr(x, y)[0]
            >>> th_corr = pearsonr(th.from_numpy(x), torch.from_numpy(y))
            >>> np.allclose(sp_corr, th_corr)
        """
        mean_x = th.mean(x)
        mean_y = th.mean(y)
        xm = x.sub(mean_x)
        ym = y.sub(mean_y)
        r_num = xm.dot(ym)
        r_den = th.norm(xm, 2) * th.norm(ym, 2)
        r_val = r_num / r_den
        return r_val
    n2 = i2[1]
    m2 = neuron_activations[:, n2].where(neuron_activations[:, n2] > 0.1, th.tensor(0.0))
    mask = mask1 * m2
    x = neuron_activations[mask != 0,n1]
    y = neuron_activations[mask != 0,n2]
    pearsonr(x, x)

tensor(1.)

In [192]:
current_neuron = 1306
n1 = current_neuron
mask1 = neuron_activations[:, current_neuron].where(neuron_activations[:, current_neuron] > 0.4, th.tensor(0.0))
# mask1 = combined_masks_of_diff_shape(mask1, mask2)
pearsonr_values = th.zeros(neurons)
for n in range(neurons):
    m2 = neuron_activations[:, n].where(neuron_activations[:, n] > 0.4, th.tensor(0.0))
    mask = mask1 * m2
    if(mask.sum() < 10):
        continue
    x = neuron_activations[mask != 0,current_neuron]
    y = neuron_activations[mask != 0,n]    
    pearsonr_values[n] = pearsonr(x, y)
vp, ip = pearsonr_values.sort(descending=True)

In [193]:
ip

tensor([1306,  924,  697,  ...,  733,  673,  280])

In [194]:
vp

tensor([ 1.0000,  0.9864,  0.9758,  ..., -0.8395, -0.8901, -0.8911])

In [199]:
i = 1
# mask1 = neuron_activations[:, current_neuron].where(neuron_activations[:, current_neuron] > 0.1, th.tensor(0.0))
mask4 = combined_masks_of_diff_shape(mask1, mask2)
neuron_plot(current_neuron, ip[i], mask4) 


In future, it will be an error for 'np.bool_' scalars to be interpreted as an index



torch.Size([41])
