In [1]:
import torch
from transformer_lens import HookedTransformer
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange

features = 1
autoencoder_path = "/mnt/ssd-cluster/longrun2408/tied_residual_l2_r6/_31/learned_dicts.pt"
autoencoder_index = 5
layer = 2
dataset_name = "NeelNanda/pile-10k"
device = "cuda:3"
model_name = "EleutherAI/pythia-70m-deduped"
setting = "residual"
max_seq_length = 30

# Hook point whose activations the autoencoder encodes.
if setting == "residual":
    cache_name = f"blocks.{layer}.hook_resid_post"
elif setting == "mlp":
    cache_name = f"blocks.{layer}.mlp.hook_post"
else:
    raise NotImplementedError(f"Unknown setting {setting!r}; expected 'residual' or 'mlp'")

# The checkpoint is indexed into (autoencoder, hyperparams) pairs below. Load it onto
# CPU so loading does not depend on the GPU it was saved from; it is moved to `device`
# immediately afterwards.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On
# torch >= 2.6 the default weights_only=True rejects pickled custom classes; pass
# weights_only=False there.
all_autoencoders = torch.load(autoencoder_path, map_location="cpu")
autoencoder, hyperparams = all_autoencoders[autoencoder_index]
autoencoder.to_device(device)
print(f"Loaded autoencoder w/ {hyperparams} on {device}")

  from .autonotebook import tqdm as notebook_tqdm


Loaded autoencoder w/ {'dict_size': 3072, 'l1_alpha': 0.0013894954463467002} on cuda:3


In [2]:
# Load the pretrained model while skipping TransformerLens' default weight processing.
model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    device=device,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [49]:
def download_dataset(dataset_name, tokenizer, max_length=256, num_datapoints=None):
    """Load, tokenize, and length-normalize the train split of a HF dataset.

    Args:
        dataset_name: HuggingFace dataset id.
        tokenizer: callable tokenizer applied to the batched "text" column.
        max_length: rows with more than this many tokens are kept and truncated to
            exactly this many; shorter-or-equal rows are dropped.
        num_datapoints: if truthy, only the first ``num_datapoints`` train rows are
            loaded; otherwise the whole train split.

    Returns:
        The processed ``datasets.Dataset``.
    """
    split_text = f"train[:{num_datapoints}]" if num_datapoints else "train"
    raw = load_dataset(dataset_name, split=split_text)
    tokenized = raw.map(lambda batch: tokenizer(batch['text']), batched=True)
    long_enough = tokenized.filter(lambda row: len(row['input_ids']) > max_length)
    return long_enough.map(lambda row: {'input_ids': row['input_ids'][:max_length]})

print(f"Downloading {dataset_name}")
# num_datapoints=None loads the entire train split.
dataset = download_dataset(
    dataset_name,
    tokenizer=model.tokenizer,
    max_length=max_seq_length,
    num_datapoints=None,
)

Downloading NeelNanda/pile-10k


Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-e8c69c6c2a788d7f.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-491db685689a5b37.arrow
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3ac317cfd6c1b211.arrow


In [4]:
# The autoencoder's encoder must act on vectors of the model's hidden size.
d_model = model.cfg.d_model
assert d_model == autoencoder.encoder.shape[-1], (
    "Model and autoencoder must have same hidden size. "
    f"Model: {d_model}, Autoencoder: {autoencoder.encoder.shape[-1]}"
)

In [50]:
# Now we can use the model to get the activations
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from einops import rearrange
def get_dictionary_activations(model, dataset, cache_name, autoencoder, batch_size=32):
    """Encode the ``cache_name`` activation of every token with ``autoencoder``.

    Runs ``model`` over ``dataset["input_ids"]`` in batches, captures the activation
    at hook ``cache_name``, flattens batch and sequence dimensions, and encodes it.

    Relies on the notebook globals ``max_seq_length`` and ``device``.

    Args:
        model: HookedTransformer producing the activations.
        dataset: HF dataset with an "input_ids" column; every row must hold exactly
            ``max_seq_length`` tokens (as ``download_dataset`` produces).
        cache_name: hook name whose activations are encoded.
        autoencoder: object with a 2-D ``encoder`` of shape (num_features, d_model)
            and an ``encode`` method.
        batch_size: sequences per forward pass.

    Returns:
        CPU tensor of shape (num_rows * max_seq_length, num_features); row
        ``r * max_seq_length + t`` holds token ``t`` of dataset row ``r``.
    """
    num_features, _ = autoencoder.encoder.shape
    dictionary_activations = torch.zeros((dataset.num_rows * max_seq_length, num_features))
    offset = 0
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for batch in tqdm(dl):
            # Cache only the hook we read instead of every activation in the model.
            _, cache = model.run_with_cache(batch.to(device), names_filter=cache_name)
            batched_neuron_activations = rearrange(cache[cache_name], "b s n -> (b s) n")
            batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
            n_rows = batched_dictionary_activations.shape[0]
            dictionary_activations[offset:offset + n_rows, :] = batched_dictionary_activations.cpu()
            offset += n_rows
    # Every preallocated row must have been written; otherwise rows had the wrong length.
    assert offset == dictionary_activations.shape[0], (
        f"Filled {offset} of {dictionary_activations.shape[0]} rows; "
        "are all dataset rows exactly max_seq_length tokens long?"
    )
    return dictionary_activations

print("Getting dictionary activations")
dictionary_activations = get_dictionary_activations(
    model=model,
    dataset=dataset,
    cache_name=cache_name,
    autoencoder=autoencoder,
    batch_size=32,
)

Getting dictionary activations


100%|██████████| 310/310 [00:10<00:00, 29.06it/s]


In [51]:

def ablate_feature_direction(model, dataset, cache_name, autoencoder, feature, batch_size=32):
    """Measure how removing one dictionary feature changes next-token log-probs.

    For each batch, runs ``model`` normally and again with a hook on ``cache_name``
    that subtracts ``encode(x)[:, feature] * get_learned_dict()[feature]`` from the
    activation at every position.

    Relies on the notebook globals ``max_seq_length`` and ``device``.

    Args:
        model: HookedTransformer to run.
        dataset: HF dataset with an "input_ids" column of ``max_seq_length`` tokens.
        cache_name: hook at which the feature direction is subtracted.
        autoencoder: object with ``encode`` and ``get_learned_dict``.
        feature: index of the dictionary feature to ablate.
        batch_size: sequences per forward pass.

    Returns:
        CPU tensor of shape (num_rows * max_seq_length,). Entry ``r * L + t`` (t >= 1)
        is ``log p_ablated(token_t) - log p_original(token_t)`` for dataset row ``r``;
        positive means ablation made the actual token more likely. Entry t == 0 is 0
        because the first token is never predicted.
    """
    def less_than_rank_1_ablate(value, hook):
        # Subtract this feature's contribution (coefficient * dictionary direction)
        # at every position, in place.
        # NOTE(review): the coefficient comes from encode(); if encode applies a ReLU
        # (presumably), only positions where the feature is active are changed — confirm.
        batch, seq_len, _ = value.shape
        flat = rearrange(value, 'b s h -> (b s) h')
        # Index without squeeze so b*s == 1 still yields a 1-D vector for torch.outer.
        coefficients = autoencoder.encode(flat)[:, feature]
        direction = autoencoder.get_learned_dict()[feature].squeeze()
        contribution = torch.outer(coefficients, direction)
        value -= rearrange(contribution, '(b s) h -> b s h', b=batch, s=seq_len)
        return value

    logit_diffs = torch.zeros(dataset.num_rows * max_seq_length)
    offset = 0
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for batch in tqdm(dl):
            tokens = batch.to(device)
            original_logprobs = model(tokens).log_softmax(dim=-1)
            ablated_logprobs = model.run_with_hooks(
                tokens, fwd_hooks=[(cache_name, less_than_rank_1_ablate)]
            ).log_softmax(dim=-1)
            # ablated > original -> positive diff
            diff_logprobs = ablated_logprobs - original_logprobs
            # Position s predicts token s+1: pick the log-prob change of the actual next token.
            next_tokens = rearrange(tokens[:, 1:], "b s -> b s 1")
            gathered = diff_logprobs[:, :-1].gather(-1, next_tokens)
            # Prepend a 0 for the unpredicted first token so entry t lines up with token t.
            gathered = torch.cat([gathered.new_zeros((gathered.shape[0], 1, 1)), gathered], dim=1)
            diff = rearrange(gathered, "b s n -> (b s n)")
            logit_diffs[offset:offset + diff.shape[0]] = diff.cpu()
            offset += diff.shape[0]
    return logit_diffs
feature = 1  # dictionary feature to ablate and visualize below
logit_diffs = ablate_feature_direction(
    model,
    dataset,
    cache_name,
    autoencoder,
    feature=feature,
    batch_size=32,
)

100%|██████████| 310/310 [00:16<00:00, 19.00it/s]


In [6]:
# # from interp_utils import *
# if isinstance(features, int):
#     features = [features]
# for feature in features:
#     text_list, full_text, token_list, full_token_list = get_feature_datapoints(feature, dictionary_activations, model.tokenizer, max_seq_length, dataset, setting="uniform")
#     # text_list, full_text, token_list, full_token_list = get_feature_datapoints(feature, dictionary_activations, dataset, setting="max")
#     # visualize_text(full_text, feature, model, autoencoder, layer)
# l = visualize_text(text_list, feature, model, autoencoder, layer)

  bins = torch.bucketize(best_feature_activations, bin_boundaries)


In [170]:
from IPython.display import display, HTML
import imgkit
def tokens_and_activations_to_html(toks, activations, tokenizer):
    """Render token sequences as HTML, shading each token by its activation.

    Args:
        toks: token ids — a 1-D/2-D tensor, a flat list of ints (one sequence), or a
            list of sequences (lists or 1-D tensors).
        activations: one value per token with the same nesting as ``toks`` — a
            1-D/2-D tensor, a flat list of numbers, or a list of lists.
        tokenizer: object whose ``decode(token)`` returns that token's text.

    Returns:
        An HTML string: one line per sequence. Tokens with activation above +0.01 are
        shaded blue and tokens below -0.01 red, with intensity relative to the max/min
        over all sequences; all other tokens are light grey. A colour-bar legend follows.

    Raises:
        NotImplementedError: if a tensor argument is not 1- or 2-dimensional.
    """
    from html import escape

    if isinstance(toks, torch.Tensor):
        if toks.dim() == 1:
            toks = [toks.tolist()]
        elif toks.dim() == 2:
            toks = toks.tolist()
        else:
            raise NotImplementedError("tokens must be 1 or 2 dimensional")
    elif isinstance(toks, list) and toks and isinstance(toks[0], int):
        # A flat list of ids is a single sequence.
        toks = [toks]
    if isinstance(activations, torch.Tensor):
        if activations.dim() == 1:
            activations = [activations.tolist()]
        elif activations.dim() == 2:
            activations = activations.tolist()
        else:
            raise NotImplementedError("activations must be 1 or 2 dimensional")
    elif isinstance(activations, list) and activations and isinstance(activations[0], (float, int)):
        # A flat list of numbers is a single sequence.
        activations = [activations]
    # Decode tokens one at a time. Escape first so tokens such as "<" or "&" render
    # literally instead of being parsed as markup; the &nbsp substitution happens after
    # escaping so it is not itself escaped.
    toks = [[escape(tokenizer.decode(t)).replace('Ġ', '&nbsp').replace('\n', '↵') for t in tok] for tok in toks]
    highlighted_text = []
    # Global extremes over all sequences; the 0.0 default keeps empty input from raising.
    all_values = [a for activ in activations for a in activ]
    max_value = max(all_values, default=0.0)
    min_value = min(all_values, default=0.0)
    white = 245
    red_blue_ness = 250
    positive_threshold = 0.01
    negative_threshold = 0.01
    for act, tok in zip(activations, toks):
        for a, t in zip(act, tok):
            if a > positive_threshold:
                ratio = a/max_value
                text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
                highlighted_text.append(f'<span style="background-color:rgba({int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},255,1); color:rgb({text_color})">{t}</span>')
            elif a < -negative_threshold:
                # Compare against the negated threshold so values near zero fall through
                # to neutral; otherwise a == 0 with min_value == 0 divides by zero.
                ratio = a/min_value
                text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
                highlighted_text.append(f'<span style="background-color:rgba(255, {int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},1);color:rgb({text_color})">{t}</span>')
            else:
                highlighted_text.append(f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">{t}</span>')
        highlighted_text.append('<br><br>')
    # Colour-bar legend: negative steps (if any), zero, then positive steps (if any).
    num_colors = 4
    if min_value < -negative_threshold:
        for i in range(num_colors, 0, -1):
            ratio = i / (num_colors)
            value = round((min_value*ratio), 1)
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            highlighted_text.append(f'<span style="background-color:rgba(255, {int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},1); color:rgb({text_color})">&nbsp{value}&nbsp</span>')
    highlighted_text.append(f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>')
    if max_value > positive_threshold:
        for i in range(1, num_colors+1):
            ratio = i / (num_colors)
            value = round((max_value*ratio), 1)
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            highlighted_text.append(f'<span style="background-color:rgba({int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},255,1);color:rgb({text_color})">&nbsp{value}&nbsp</span>')
    return ''.join(highlighted_text)

def display_tokens(tokens, activations, tokenizer):
    """Show ``tokens`` shaded by ``activations`` as rich HTML output in the notebook."""
    markup = tokens_and_activations_to_html(tokens, activations, tokenizer)
    return display(HTML(markup))

def save_token_display(tokens, activations, tokenizer, path):
    """Render the token/activation HTML and write it to ``path`` as an image via imgkit."""
    imgkit.from_string(tokens_and_activations_to_html(tokens, activations, tokenizer), path)

In [91]:
def get_feature_indices(feature_index, dictionary_activations, tokenizer, token_amount, dataset, k=10, setting="max", generator=None):
    """Choose up to ``k`` flat token indices at which one feature is interesting.

    Args:
        feature_index: column of ``dictionary_activations`` to inspect.
        dictionary_activations: 2-D tensor, one row per token, one column per feature.
        tokenizer, token_amount, dataset: unused; kept for call compatibility.
        k: number of indices to return (for "uniform", the number of bins).
        setting: "max" returns the top-k activations, highest first. "uniform" splits
            [min, max] into k equal-width bins and draws one index from each non-empty
            bin above the minimum, highest bin first. Any other value returns up to k
            random indices whose activation is non-zero.
        generator: optional ``torch.Generator`` to make "uniform"/random sampling
            reproducible; ``None`` uses torch's global RNG.

    Returns:
        1-D long tensor of flat indices into ``dictionary_activations[:, feature_index]``.
    """
    feature_acts = dictionary_activations[:, feature_index]
    if setting == "max":
        return torch.argsort(feature_acts, descending=True)[:k]

    if setting == "uniform":
        min_value = torch.min(feature_acts)
        max_value = torch.max(feature_acts)
        # k equal-width bins spanning [min, max].
        bin_boundaries = torch.linspace(min_value, max_value, k + 1)
        bins = torch.bucketize(feature_acts, bin_boundaries)
        sampled_indices = []
        for bin_idx in torch.unique(bins):
            # With bucketize's default right=False, bin 0 holds values <= boundaries[0],
            # i.e. exactly the minimum (typically the inactive tokens); skip it.
            if bin_idx == 0:
                continue
            bin_indices = torch.nonzero(bins == bin_idx, as_tuple=False).squeeze(dim=1)
            pick = torch.randint(len(bin_indices), (1,), generator=generator)
            sampled_indices.append(int(bin_indices[pick]))
        # Highest-activation bin first.
        return torch.tensor(sampled_indices, dtype=torch.long).flip(dims=[0])

    # random: shuffle the tokens where the feature is non-zero and take k
    nonzero_indices = torch.nonzero(feature_acts)[:, 0]
    shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0], generator=generator)]
    return shuffled_indices[:k]
def get_feature_datapoints(found_indices, best_feature_activations, tokenizer, token_amount, dataset):
    """Map flat token indices back to dataset rows and collect text and activations.

    Args:
        found_indices: iterable of flat indices (e.g. from ``get_feature_indices``);
            index ``r * token_amount + p`` means position ``p`` of dataset row ``r``.
        best_feature_activations: tensor with ``dataset.num_rows * token_amount``
            entries, one feature activation per token.
        tokenizer: object with ``decode(ids)``.
        token_amount: tokens per dataset row.
        dataset: exposes ``num_rows`` and is indexable by row; each row is a mapping
            with an ``"input_ids"`` sequence.

    Returns:
        Tuple ``(text_list, full_text, token_list, full_token_list,
        partial_activations, full_activations)`` with one entry per index:
        - decoded prefix up to and including the hit token
        - decoded full row
        - prefix ids
        - full row ids (tensor)
        - prefix activations
        - full-row activations
    """
    # Convert only the rows we need to Python lists, not the whole activation tensor.
    activations_by_row = best_feature_activations.reshape(dataset.num_rows, token_amount)
    text_list, full_text, token_list, full_token_list = [], [], [], []
    partial_activations, full_activations = [], []
    for flat_index in found_indices:
        row, pos = divmod(int(flat_index), token_amount)
        input_ids = dataset[row]["input_ids"]
        prefix_ids = input_ids[:pos + 1]
        full_tok = torch.tensor(input_ids)
        row_activations = activations_by_row[row].tolist()

        text_list.append(tokenizer.decode(prefix_ids))
        full_text.append(tokenizer.decode(full_tok))
        token_list.append(prefix_ids)
        full_token_list.append(full_tok)
        partial_activations.append(row_activations[:pos + 1])
        full_activations.append(row_activations)
    return text_list, full_text, token_list, full_token_list, partial_activations, full_activations

# Sample one example per activation bin and show each prefix up to the hit token.
uniform_indices = get_feature_indices(
    feature, dictionary_activations, model.tokenizer, max_seq_length, dataset,
    k=10, setting="uniform",
)
(
    text_list,
    full_text,
    token_list,
    full_token_list,
    partial_activations,
    full_activations,
) = get_feature_datapoints(
    uniform_indices, dictionary_activations[:, feature], model.tokenizer, max_seq_length, dataset
)
display_tokens(token_list, partial_activations, model.tokenizer)

In [95]:
# Same examples, now shaded by the per-token log-prob change from ablating the feature.
_, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(
    uniform_indices,
    logit_diffs,
    model.tokenizer,
    max_seq_length,
    dataset,
)
display_tokens(
    full_token_list_ablated,
    full_activations_ablated,
    model.tokenizer,
)

In [174]:
# Context ablation: for each example in `token_list`, delete one earlier token at a time
# and record how the feature's activation on the example's final token changes
# (ablated - original). The final token is never deleted, so its entry stays 0.
all_changed_activations = []
for example_idx, example_tokens in enumerate(token_list):
    context = [int(t) for t in example_tokens]
    original_activation = partial_activations[example_idx][-1]
    # Length-1 examples have nothing to ablate but still get a row ([0.0]) so this list
    # stays aligned with `token_list` when both are displayed together.
    changed_activations = torch.zeros(len(context))
    for i in range(len(context) - 1):
        # Remove token i from the context.
        ablated_tokens = torch.tensor(context[:i] + context[i + 1:]).unsqueeze(0)
        with torch.no_grad():
            _, ablated_cache = model.run_with_cache(ablated_tokens.to(device), names_filter=cache_name)
            ablated_neuron_acts = rearrange(ablated_cache[cache_name], "b s n -> (b s) n")
            # Local name so the full-dataset `dictionary_activations` tensor is not overwritten.
            ablated_dict_acts = autoencoder.encode(ablated_neuron_acts)
        changed_activations[i] = ablated_dict_acts[-1, feature].item() - original_activation
    all_changed_activations.append(changed_activations.tolist())

In [175]:
# Shade each context token by how much deleting it changed the final token's activation.
display_tokens(
    tokens=token_list,
    activations=all_changed_activations,
    tokenizer=model.tokenizer,
)

In [104]:
# Inspect the raw per-example activation changes from the context-ablation cell.
# NOTE(review): the saved output below comes from an earlier run (In [104] precedes
# In [174]) and shows tensors, while the ablation cell now appends lists — re-run to refresh.
all_changed_activations

[tensor([-4.2333, -4.2366, -4.2338, -4.2505, -4.2160, -4.2546, -4.2546, -4.2081,
         -4.2334, -4.2249, -4.2471, -4.2193, -4.0742, -4.1679, -4.2150, -4.1018,
         -4.2725, -4.3249, -4.1339, -4.2034, -4.1800, -4.1670, -4.5888]),
 tensor([-3.9946, -3.9989, -3.9936, -4.0064, -3.9982, -4.0143, -3.9662, -4.0079,
         -4.0061, -3.9552, -3.9683, -3.9796, -3.9800, -3.9757, -3.9958, -3.9683,
         -4.0782, -3.8525, -3.9295, -3.8120, -4.0318, -4.6483, -4.6042]),
 tensor([-3.6651, -3.6740, -3.7358, -3.7358, -3.7280, -3.7123, -3.6980, -3.7146,
         -3.7279, -3.6798, -3.7178, -3.7230, -3.7002, -3.7002, -3.7344, -3.6766,
         -3.7060, -3.7234, -3.6957, -3.7650, -3.8333, -3.7180, -3.7201, -3.6695,
         -3.4719, -3.6381, -4.1545]),
 tensor([-2.7280, -2.7347, -2.7301, -2.7301, -2.7266, -2.7236, -2.5924, -2.5106,
         -2.7239, -2.8197, -2.7921, -2.7921, -3.0452, -2.8805, -2.7906]),
 tensor([-2.0220, -2.0265, -1.9411, -2.0300, -2.0120, -2.0301, -2.0255, -1.9884,
         -2

In [None]:
# Re-show the sampled prefixes shaded by the feature's original activations.
display_tokens(
    tokens=token_list,
    activations=partial_activations,
    tokenizer=model.tokenizer,
)

In [37]:
# NOTE(review): `ablate_feature_direction_display` is not defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel — define or import it
# before running.
ablate_feature_direction_display(full_text, autoencoder, model, layer, features=feature)