In [1]:
import torch
from transformer_lens import HookedTransformer
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange

features = 1
autoencoder_path = "/mnt/ssd-cluster/longrun2408/tied_residual_l2_r6/_31/learned_dicts.pt"
autoencoder_index = 5
layer = 2
dataset_name = "NeelNanda/pile-10k"
device = "cuda:3"
model_name = "EleutherAI/pythia-70m-deduped"
setting= "residual"
max_seq_length=30

if setting == "residual":
    cache_name = f"blocks.{layer}.hook_resid_post"
elif setting == "mlp":
    cache_name = f"blocks.{layer}.mlp.hook_post"
else:
    raise NotImplementedError
all_autoencoders = torch.load(autoencoder_path)
autoencoder, hyperparams = all_autoencoders[autoencoder_index]
autoencoder.to_device(device)
print(f"Loaded autoencoder w/ {hyperparams} on {device}")

  from .autonotebook import tqdm as notebook_tqdm


Loaded autoencoder w/ {'dict_size': 3072, 'l1_alpha': 0.0013894954463467002} on cuda:3


In [2]:
model = HookedTransformer.from_pretrained_no_processing(model_name, device = device)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [3]:

def download_dataset(dataset_name, tokenizer, max_length=256, num_datapoints=None):
    dataset = load_dataset(dataset_name, split="train")
    if(num_datapoints):
        split_text = f"train[:{num_datapoints}]"
    else:
        split_text = "train"
    dataset = load_dataset(dataset_name, split=split_text).map(
        lambda x: tokenizer(x['text']),
        batched=True,
    ).filter(
        lambda x: len(x['input_ids']) > max_length
    ).map(
        lambda x: {'input_ids': x['input_ids'][:max_length]}
    )
    return dataset

print(f"Downloading {dataset_name}")
dataset = download_dataset(dataset_name, tokenizer= model.tokenizer, max_length=max_seq_length, num_datapoints=None) # num_datapoints grabs all of them if None

Downloading NeelNanda/pile-10k


Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-e8c69c6c2a788d7f.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-491db685689a5b37.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7

In [4]:
d_model = model.cfg.d_model
assert (d_model == autoencoder.encoder.shape[-1]), f"Model and autoencoder must have same hidden size. Model: {d_model}, Autoencoder: {autoencoder.encoder.shape[-1]}"

In [5]:
# Now we can use the model to get the activations
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from einops import rearrange
def get_dictionary_activations(model, dataset, cache_name, autoencoder, batch_size=32):
    num_features, d_model = autoencoder.encoder.shape
    datapoints = dataset.num_rows
    dictionary_activations = torch.zeros((datapoints*max_seq_length, num_features))
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for i, batch in enumerate(tqdm(dl)):
            _, cache = model.run_with_cache(batch.to(device))
            batched_neuron_activations = rearrange(cache[cache_name], "b s n -> (b s) n" )
            batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
            dictionary_activations[i*batch_size*max_seq_length:(i+1)*batch_size*max_seq_length,:] = batched_dictionary_activations.cpu()
    return dictionary_activations

print("Getting dictionary activations")
dictionary_activations = get_dictionary_activations(model, dataset, cache_name, autoencoder, batch_size=32)

Getting dictionary activations


100%|██████████| 310/310 [00:06<00:00, 51.53it/s]


In [6]:
# from interp_utils import *
if isinstance(features, int):
    features = [features]
for feature in features:
    text_list, full_text, token_list, full_token_list = get_feature_datapoints(feature, dictionary_activations, model.tokenizer, max_seq_length, dataset, setting="uniform")
    # text_list, full_text, token_list, full_token_list = get_feature_datapoints(feature, dictionary_activations, dataset, setting="max")
    # visualize_text(full_text, feature, model, autoencoder, layer)
l = visualize_text(text_list, feature, model, autoencoder, layer)

  bins = torch.bucketize(best_feature_activations, bin_boundaries)


In [56]:
def get_feature_datapoints(feature_index, dictionary_activations, tokenizer, token_amount, dataset, k=10, setting="max"):
    best_feature_activations = dictionary_activations[:, feature_index]
    # Sort the features by activation, get the indices
    if setting=="max":
        found_indices = torch.argsort(best_feature_activations, descending=True)[:k]
    elif setting=="uniform":
        # min_value = torch.min(best_feature_activations)
        min_value = torch.min(best_feature_activations)
        max_value = torch.max(best_feature_activations)

        # Define the number of bins
        num_bins = k

        # Calculate the bin boundaries as linear interpolation between min and max
        bin_boundaries = torch.linspace(min_value, max_value, num_bins + 1)

        # Assign each activation to its respective bin
        bins = torch.bucketize(best_feature_activations, bin_boundaries)

        # Initialize a list to store the sampled indices
        sampled_indices = []

        # Sample from each bin
        for bin_idx in torch.unique(bins):
            if(bin_idx==0): # Skip the first one. This is below the median
                continue
            # Get the indices corresponding to the current bin
            bin_indices = torch.nonzero(bins == bin_idx, as_tuple=False).squeeze(dim=1)
            
            # Randomly sample from the current bin
            sampled_indices.extend(np.random.choice(bin_indices, size=1, replace=False))

        # Convert the sampled indices to a PyTorch tensor & reverse order
        found_indices = torch.tensor(sampled_indices).long().flip(dims=[0])
    else: # random
        # get nonzero indices
        nonzero_indices = torch.nonzero(best_feature_activations)[:, 0]
        # shuffle
        shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0])]
        found_indices = shuffled_indices[:k]
    num_datapoints = int(dictionary_activations.shape[0]/token_amount)
    datapoint_indices =[np.unravel_index(i, (num_datapoints, token_amount)) for i in found_indices]
    all_activations = best_feature_activations.reshape(num_datapoints, token_amount).tolist()
    full_activations = []
    partial_activations = []
    text_list = []
    full_text = []
    token_list = []
    full_token_list = []
    for i, (md, s_ind) in enumerate(datapoint_indices):
        md = int(md)
        s_ind = int(s_ind)
        full_tok = torch.tensor(dataset[md]["input_ids"])
        full_text.append(tokenizer.decode(full_tok))
        tok = dataset[md]["input_ids"][:s_ind+1]
        full_activations.append(all_activations[md])
        partial_activations.append(all_activations[md][:s_ind+1])
        text = tokenizer.decode(tok)
        text_list.append(text)
        token_list.append(tok)
        full_token_list.append(full_tok)
    return text_list, full_text, token_list, full_token_list, partial_activations, full_activations

text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(feature, dictionary_activations, model.tokenizer, max_seq_length, dataset, setting="uniform")

In [75]:
# display_tokens(full_token_list, full_activations, model.tokenizer)
# display_tokens(token_list, partial_activations, model.tokenizer)
save_token_display(token_list, partial_activations, model.tokenizer, f"feature_{feature}_layer_{layer}_setting_{setting}.png")

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-mchorse'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


In [74]:
from IPython.display import display, HTML
import imgkit
def tokens_and_activations_to_html(toks, activations, tokenizer):
    if isinstance(toks, torch.Tensor):
        if toks.dim() == 1:
            toks = [toks.tolist()]
        elif toks.dim()==2:
            toks = toks.tolist()
        else: 
            raise NotImplementedError("tokens must be 1 or 2 dimensional")
    elif isinstance(toks, list):
        # ensure it's a list of lists
        if isinstance(toks[0], int):
            toks = [toks]
    if isinstance(activations, torch.Tensor):
        if(activations.dim() == 1):
            activations = [activations.tolist()]
        elif(activations.dim() == 2):
            activations = activations.tolist()
        else:
            raise NotImplementedError("activations must be 1 or 2 dimensional")
    elif isinstance(activations, list):
        # ensure it's a list of lists
        if isinstance(activations[0], float) or isinstance(activations[0], int):
            activations = [activations]
    # convert tokens into strings
    toks = [[tokenizer.decode(t).replace('Ġ', '&nbsp').replace('\n', '↵') for t in tok] for tok in toks]
    highlighted_text = []
    max_value = max([max(activ) for activ in activations])
    min_value = min([min(activ) for activ in activations])
    white = 245
    for act, tok in zip(activations, toks):
        for a, t in zip(act, tok):
            if a > 0.0:
                ratio = a/max_value
                text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
                highlighted_text.append(f'<span style="background-color:rgba({int(220-(220*ratio))},{int(220-(220*ratio))},255,1); color:rgb({text_color})">{t}</span>')
            elif a < 0.0:
                ratio = a/min_value
                text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
                highlighted_text.append(f'<span style="background-color:rgba(255, {int(220-(220*ratio))},{int(220-(220*ratio))},1);color:rgb({text_color})">{t}</span>')
            else:
                # highlighted_text.append(f'<span style="background-color:rgba(255,255,255,1);color:rgb(0,0,0)">{t}</span>')
                highlighted_text.append(f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">{t}</span>')
        highlighted_text.append('<br><br>')
    # Add color bar
    num_colors = 4
    if(min_value < 0):
        for i in range(num_colors, 0, -1):
            ratio = i / (num_colors)
            value = round((min_value*ratio),1)
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            highlighted_text.append(f'<span style="background-color:rgba(255, {int(220-(220*ratio))},{int(220-(220*ratio))},1); color:rgb({text_color})">&nbsp{value}&nbsp</span>')
    # Do zero
    highlighted_text.append(f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>')
    # Do positive
    if(max_value > 0):
        for i in range(1, num_colors+1):
            ratio = i / (num_colors)
            value = round((max_value*ratio),1)
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            highlighted_text.append(f'<span style="background-color:rgba({int(220-(220*ratio))},{int(220-(220*ratio))},255,1);color:rgb({text_color})">&nbsp{value}&nbsp</span>')
    highlighted_text = ''.join(highlighted_text)
    return highlighted_text

def display_tokens(tokens, activations, tokenizer):
    return display(HTML(tokens_and_activations_to_html(tokens, activations, tokenizer)))

def save_token_display(tokens, activations, tokenizer, path):
    html = tokens_and_activations_to_html(tokens, activations, tokenizer)
    imgkit.from_string(html, path)
    # print(f"Saved to {path}")
    return
tokenizer = model.tokenizer
text = " I like eggs and \n"
activations = [0.1, 0.0, 8.0, 8.0, 0.1]
tokens = tokenizer(text, return_tensors="pt")["input_ids"].squeeze()
save_token_display(tokens, activations, tokenizer, "test.jpg")
save_token_display(tokens, activations, tokenizer, "test.jpg")
display_tokens(tokens, activations, tokenizer)

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-mchorse'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               
QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-mchorse'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


In [37]:
HTML("<span>↵ hey ↵ ya</span>")

In [44]:
tokens

tensor([  309,   751, 11624,   285,  2490])

In [52]:
[[tokenizer.decode(t).replace('Ġ', '&nbsp').replace('\n', '↵') for t in tok] for tok in [tokens]]

[[' I', ' like', ' eggs', ' and', ' ↵']]

In [49]:
model.tokenizer.convert_ids_to_tokens([2490])

['ĠĊ']

In [50]:
model.tokenizer.decode([2490])

' \n'

In [38]:
print("\n".replace("\n", "↵"))

↵


In [12]:
token_list,

[[3848,
  18415,
  26203,
  187,
  187,
  3848,
  18415,
  8966,
  25418,
  26203,
  313,
  6448,
  3495,
  4162,
  12034,
  10,
  310,
  271,
  4383,
  31926,
  665,
  7120,
  347,
  247,
  259,
  4940,
  390],
 [18,
  15,
  7327,
  273,
  253,
  14723,
  187,
  510,
  1246,
  3688,
  7033,
  3839,
  281,
  22486,
  273,
  2144,
  534,
  403,
  20618,
  689,
  247,
  12045,
  285,
  4845,
  387,
  253,
  5024,
  390],
 [1532,
  187,
  15834,
  27,
  686,
  37,
  17799,
  29886,
  3210,
  2085,
  247,
  4217,
  7792,
  323,
  14053,
  19349,
  390],
 [7475, 19020, 310, 247, 8723, 382, 3367, 390],
 [2214, 247, 12732, 22, 19327, 253, 4394, 3551, 275, 432, 34538, 2207],
 [18968,
  16078,
  13,
  367,
  610,
  18367,
  13,
  48585,
  13,
  10388,
  478,
  13,
  2325,
  7885,
  13,
  21570,
  13,
  330,
  1595,
  1351,
  80,
  13,
  401,
  2682,
  285,
  16,
  263],
 [187, 22313, 367, 15, 20, 69, 5329, 313, 7330, 10, 187, 27018, 2207],
 [510, 17068, 3448, 273, 9550, 310, 48526, 2584, 18000,

In [208]:
tt = tokenizer(text, return_tensors="pt")["input_ids"].tolist()
[[t.replace('Ġ', '&nbsp').replace('\n', '↵') for t in tokenizer.convert_ids_to_tokens(tok)] for tok in tt]

[['&nbspI', '&nbsplike', '&nbspeggs', '&nbspand', '&nbspĊ']]

In [211]:
torch.tensor(torch.tensor(activations)).max()

  torch.tensor(torch.tensor(activations)).max()


tensor(8.)

In [46]:
from IPython.display import display, HTML

display(HTML('<h1>Hello, World!!</h1>'))

In [None]:
display(render_toks_w_weights(torch.tensor([3,4,5, 9, 10]), [0.5, 0.0, 100, 4, -4]))

In [26]:
[tokenizer.decode([tok]) for tok in [3,4,5]]

['"', '#', '$']

In [37]:
ablate_feature_direction_display(full_text, autoencoder, model, layer, features=feature)