In [1]:
# Load dictionary
# Need model definition, right?
import torch.nn as nn
from baukit import Trace
from einops import rearrange

# import partial
from functools import partial
# from src.autoencoder import UntiedSAE
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dictionary_path = "models/pythia-70m-deduped_gpt_neox.layers.3_Ratio-4_l1-0.003_lr-0.001_2024_02_27_17_47_39_model.pt"
dictionary_path = "models/pythia-70m-deduped_gpt_neox.layers.3_Ratio-4_l1-0.003_lr-0.001_2024_02_27_19_24_10_model.pt"
layer = 3
activation_name = f"gpt_neox.layers.{layer}"
sae = torch.load(dictionary_path).to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load data
from src.utils import get_dataloader
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "EleutherAI/pythia-70m-deduped"
dataset_name = "Elriggs/openwebtext-100k"
# dataset_name = "jbrinkma/pile-500k"
batch_size = 16
context_length = 256
# model and dataset
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
train_loader, test_loader = get_dataloader(dataset_name, tokenizer, batch_size, context_length)
# Freeze the model & dictionary
model.eval()
model.requires_grad_(False)
sae.requires_grad_(False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


UntiedSAE()

In [3]:
def zero_k_lowest_nonzero(arr, k):
    abs_arr = arr.abs()
    # Set zeros to +inf to ignore them in topk (since we are now looking for lowest values)
    mask = torch.where(abs_arr > 0, abs_arr, torch.tensor(float('inf'), device=arr.device))
    # Get top k lowest values; since torch.topk only finds the largest, we use -mask
    _, indices = torch.topk(-mask, k, dim=1)
    # Create a range for the batch dimension
    batch_indices = torch.arange(arr.size(0), device=arr.device).unsqueeze(1)
    # Use advanced indexing to set the k-lowest nonzero elements to 0
    arr[batch_indices, indices] = 0
    return arr
    
def dict_ablation_fn_k(representation, name, other_value, k=0):
    # print(f"representation {representation}")
    if(isinstance(representation, tuple)):
        second_value = representation[1]
        internal_activation = representation[0]
    else:
        internal_activation = representation
    int_val = rearrange(internal_activation, "b seq d_model -> (b seq) d_model")
    features = sae.encode(int_val)
    features = zero_k_lowest_nonzero(features, k)
    reconstruction = sae.decode(features)
        
    batch, seq_len, hidden_size = internal_activation.shape
    reconstruction = rearrange(reconstruction, '(b s) h -> b s h', b=batch, s=seq_len)

    if(isinstance(representation, tuple)):
        return_value = (reconstruction, second_value)
    else:
        return_value = reconstruction

    return return_value
       
def compute_loss(inputs_ids, logits):
    return torch.nn.CrossEntropyLoss()(
        logits[:,:-1,:].reshape(-1, logits.shape[-1]),
        inputs_ids[:,1:].reshape(-1)
    ).item()

In [7]:

# Load in N batches
N = 2
total_k = 5
avg_CE = torch.zeros(total_k)

d_model, n_features = sae.W_d.shape
tokens_per_batch = batch_size * context_length
all_features = torch.zeros(tokens_per_batch*N, n_features)
with torch.no_grad():
    for ind_batch, batch in enumerate(train_loader):
        if ind_batch >= N:
            break
        input_ids = batch["input_ids"].to(device)
        for k_inc in range(total_k):
            # k_inc = 0
            hook_function = partial(dict_ablation_fn_k, k=k_inc)
            with Trace(model, activation_name, edit_output = hook_function) as ret:
                #Calculate CE
                logits = model(input_ids).logits
                CE = compute_loss(input_ids, logits)
                avg_CE[k_inc] += CE

                # #Calculate MSE
                # representation = ret.output
                # if(isinstance(representation, tuple)):
                #     representation = representation[0]
                # activation = rearrange(representation, "b seq d_model -> (b seq) d_model")
                # reconstruction = sae(activation)
                # # features = sae.encode(activation)
                # # all_features[ind_batch*tokens_per_batch:(ind_batch+1)*tokens_per_batch] = features
                # MSE = torch.nn.MSELoss()(reconstruction, activation).item()
                # # print k, mse, ce
                # print(f"K, MSE, CE: {k_inc}, {MSE}, {CE}")
                # avg_MSE[k_inc] += MSE
        # outputs = model(input_ids)
avg_CE /= N
print(avg_CE)

tensor([8.0784, 8.0807, 8.0873, 8.1022, 8.8858])


In [9]:
# Load in N batches
N = 2
total_k = 5
avg_MSE = torch.zeros(total_k)

d_model, n_features = sae.W_d.shape
tokens_per_batch = batch_size * context_length
all_features = torch.zeros(tokens_per_batch*N, n_features)
with torch.no_grad():
    for ind_batch, batch in enumerate(train_loader):
        if ind_batch >= N:
            break
        input_ids = batch["input_ids"].to(device)
        for k_inc in range(total_k):
            # k_inc = 0
            with Trace(model, activation_name) as ret:
                #Calculate CE
                _ = model(input_ids).logits
                #Calculate MSE
                representation = ret.output
                if(isinstance(representation, tuple)):
                    representation = representation[0]
                activation = rearrange(representation, "b seq d_model -> (b seq) d_model")
                features = sae.encode(activation)
                features = zero_k_lowest_nonzero(features, k_inc)
                reconstruction = sae.decode(features)

                # features = sae.encode(activation)
                # all_features[ind_batch*tokens_per_batch:(ind_batch+1)*tokens_per_batch] = features
                MSE = torch.nn.MSELoss()(reconstruction, activation).item()
                # print k, mse, ce
                print(f"K, MSE: {k_inc}, {MSE}")
                avg_MSE[k_inc] += MSE
        # outputs = model(input_ids)
avg_MSE /= N
print(avg_MSE)

K, MSE: 0, 0.15783685445785522
K, MSE: 1, 0.15818116068840027
K, MSE: 2, 0.15910932421684265
K, MSE: 3, 0.16127437353134155
K, MSE: 4, 0.16592943668365479
K, MSE: 0, 0.15939773619174957
K, MSE: 1, 0.1597251147031784
K, MSE: 2, 0.16059614717960358
K, MSE: 3, 0.16243727505207062
K, MSE: 4, 0.16649633646011353
tensor([0.1586, 0.1590, 0.1599, 0.1619, 0.1662])


In [11]:
for batch in train_loader:
    input_ids = batch["input_ids"].to(device)
    print(input_ids[:3])
    break

tensor([[  273, 13175,   432,   247,  1077,  2969,  5673,    15, 48627,    13,
           352, 10262,   368,   342,   253,  3745,   281, 10007,   253,  1655,
         10611,   275,   247,  1180,   273,  4722,  4088,    13,  1690,  4933,
          2289,   281,   247,  1180,   273,  2969,  5657,   323, 42477,   941,
         21453,    15, 32354,  6240,   281,   253,  1618,   273,  5667,    84,
            15,   187,   187,  2598,   627,   597,   403,    15,  1310,   368,
           452,   667,   309,   943,   823,   281,   436,  1618,   273,  5667,
            84,    13,   513,  1339,   479,   871,   275,   253,  5701,   390,
          3066,  4579,    15, 12590,   187,   187,     9,  8061,   281,  1110,
           665,  6518,   479,  1973,   436,  1618,  1690, 41474, 18039,  5969,
           285, 11819,   330,  1761,    80,    10,   187,   187, 20536,   187,
           187,    60,  8339,  9044, 36304,  1040, 16447,   928,  7373,   405,
          2146,    15,   681,  5032,     0, 18412,  

In [14]:
train_loader.dataset[0], train_loader.dataset[0]

({'input_ids': tensor([ 1335,  2589,  3187, 21269,   398,   281,  3187,  3943,  5871,    13,
            533,  2550,  2430,   247,  5927, 35958,  2408,   591, 42307,   390,
           2021,   285,  2550, 29211, 16947,   281,  2060,  2758,    15,   733,
            457,    84,  1335, 10826,   281,  2186,  3187, 21269,   398,   285,
            299,  1513,   782,  8064,   323,  4712,   326,  5649,   253,  3114,
             13,   824,   347, 26864,  2579,  6493,    13,   390,   281,  7164,
           2583,   323,   247,  2173,  3943,  4096,   824,   347,   253,  7471,
            273,   247, 19150,   390,  6500,    15,  1244,  5085,   476,  1335,
           6558,   773, 17703,  8553,   668,   323,  1810,  8349,   285,  5870,
            604,   597,  5730,   281,  1978,  2583,   327, 18831,   323,  7830,
            273, 12222,    15,   187,   187,  1552,   778,   320,   247,  2201,
           8230,   281, 10824,   285,  1014, 30171,  2439,   253,  2586,   347,
            597,  2968,   3

In [10]:
for batch in train_loader:
    input_ids = batch["input_ids"].to(device)
    print(input_ids[:3])
    break

tensor([[  285,   247,  2495,  2803,   323,  8401,    15,  8969, 23787,  2296,
          2802,   423,  1417,  3038,  5262, 32412,  3390,    13, 17120,   390,
          1469,   281,   253, 17409,   812,   320,   247,  8138,   273,   673,
           323,   598,   281,   581,    14, 25512,   394,   273,   253,  3072,
         15385,   187,   187, 20576,    13,   597,   943,   513,  1029,  7133,
            13, 12217,  7467,  5763,   824,   347, 14174,   390,  2801,  3733,
            15,   187,   187,  1717,    15,  1889,  2269,  2182,  3304,    15,
          5595,  2175,   921,   326,    13,  2429,   342, 30347, 42842,    13,
         30347,   275,  3626, 12620,   310,  2330,   342,  3687,  9510,    84,
           281, 12315,    13,  6137, 10397,   285,   271,  2559, 12177,   273,
         26530,   342,   253,   789,  8349,    15,   187,   187,  1229, 38733,
           253,  3430,  5763,   323,   634,  2363,    15,  4325,  1475,   776,
          1884,    84,    13,   359,  7168,   327,  

In [67]:
# Get the k-lowest nonzero features per datapoint
# Feautures is shape (batch, features)
features.shape
nz = features != 0
# index by nz, keep as same shape
features.index_select(1, nz)

RuntimeError: Index is supposed to be an empty tensor or a vector

In [73]:
import torch

def zero_topk_nonzero(arr, k):
    # Get absolute values to consider both positive and negative values
    abs_arr = arr.abs()
    # Use torch.where to create a mask of nonzero elements, setting zeros to -inf to ignore them in topk
    mask = torch.where(abs_arr > 0, abs_arr, torch.tensor(float('-inf'), device=arr.device))
    # Get top k values and their indices along the features dimension
    _, indices = torch.topk(mask, k, dim=1, largest=False)
    # Set the top k nonzero elements to 0
    for i in range(arr.size(0)):  # Loop through batch dimension
        arr[i, indices[i]] = 0
    return arr

# Example usage
batch_size = 3
num_features = 5
k = 1
arr = torch.tensor([[1, 2, 0, 4, 5], [0, 2, 3, 0, 0], [1, 0, 0, 1, 0]], dtype=torch.float32)
result = zero_topk_nonzero(arr, k)
result

tensor([[0., 0., 0., 4., 5.],
        [0., 2., 3., 0., 0.],
        [1., 0., 0., 1., 0.]])

In [81]:
def zero_k_lowest_nonzero(arr, k):
    # Use torch.where to ignore zeros by setting them to +inf
    mask = torch.where(abs_arr > 0, abs_arr, torch.tensor(float('inf'), device=arr.device))
    # Get k smallest nonzero values and their indices along the features dimension
    _, indices = torch.topk(mask, k, largest=False, dim=1)
    # Set the k smallest nonzero elements to 0
    for i in range(arr.size(0)):
        arr[i, indices[i]] = 0
    return arr

# Example usage with the same array and parameters
k = 1
arr = torch.tensor([[1, 2, 0, 4, 5], [0, 2, 3, 0, 0], [5, 0, 0, 1, 0]], dtype=torch.float32)
result = zero_k_lowest_nonzero(arr, k)
result

tensor([[0., 2., 0., 4., 5.],
        [0., 0., 3., 0., 0.],
        [5., 0., 0., 0., 0.]])

In [104]:
import torch

def zero_k_lowest_nonzero(arr, k):
    abs_arr = arr.abs()
    # Set zeros to +inf to ignore them in topk (since we are now looking for lowest values)
    mask = torch.where(abs_arr > 0, abs_arr, torch.tensor(float('inf'), device=arr.device))
    # Get top k lowest values; since torch.topk only finds the largest, we use -mask
    _, indices = torch.topk(-mask, k, dim=1)
    # Create a range for the batch dimension
    batch_indices = torch.arange(arr.size(0), device=arr.device).unsqueeze(1)
    # Use advanced indexing to set the k-lowest nonzero elements to 0
    arr[batch_indices, indices] = 0
    return arr

# # Example usage
# batch_size = 3
# num_features = 5
k = 1
# arr = torch.tensor([[1, 2, 0, 4, 5], [0, 2, 3, 0, 0], [5, 1, 0, 3, 0]], dtype=torch.float32)
result = zero_k_lowest_nonzero(features.clone(), k)
result

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [105]:
features[1].topk(12), result[1].topk(12)

(torch.return_types.topk(
 values=tensor([4.1939, 2.7761, 2.1174, 2.0487, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000], device='cuda:0'),
 indices=tensor([ 298,  721, 1571,   78,    7,    6,    4,    5,    3,    2,    0,    1],
        device='cuda:0')),
 torch.return_types.topk(
 values=tensor([4.1939, 2.7761, 2.1174, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000], device='cuda:0'),
 indices=tensor([ 298,  721, 1571,    8,    0,    7,    5,    4,    6,    3,    1,    2],
        device='cuda:0')))

In [None]:
all_features

In [41]:
features.shape, all_features.shape

(torch.Size([4096, 2048]), torch.Size([81920, 2048]))

In [30]:
representation
activation = rearrange(representation, "b seq d_model -> (b seq) d_model")
# activation = rearrange(representation[:, 1:, :], "b seq d_model -> (b seq) d_model")

# run through SAE
features = sae.encode(activation)
# get the k-lowest features per datapoint

In [None]:
num_pos = 10
scale = 1
features_by_pos = [sae.encode(representation[:, (i*scale):(i*scale)+1, :].squeeze()) for i in range(num_pos)]
# Plot each as a hist in a new plot
import matplotlib.pyplot as plt
import numpy as np
for i in range(num_pos):
    # plot a new histogram
    plt.figure()
    zero_feature = features_by_pos[i].count_nonzero(-1)
    plt.hist(zero_feature.cpu().numpy(), bins=20, alpha=0.5, label=f"pos {i*scale}")
    # print(f"topk features num {zero_feature.topk(10)}")
plt.title("L0 at different token positions")
plt.legend()
plt.show()

In [81]:
zero_feature, features_by_pos[i].shape

(tensor([ 9,  9,  9,  7,  8, 10,  8,  8,  8,  9,  8,  9,  9, 10,  8, 10],
        device='cuda:0'),
 torch.Size([16, 2048]))

In [None]:
# get the k-lowest features per datapoint
k = 1
# shape is (batch, features)
d1 = features.count_nonzero(-1).cpu().numpy()
d1.sort()
# plot d1 as sorted line graph
import matplotlib.pyplot as plt
plt.hist(d1, bins=100)
plt.show()

In [71]:
d1[-30:]

array([106, 107, 111, 114, 115, 116, 117, 117, 120, 122, 124, 127, 127,
       130, 131, 132, 134, 134, 135, 136, 136, 137, 137, 137, 137, 138,
       140, 141, 204, 234])

In [57]:
d1.sort(), d1.sorted()

AttributeError: 'numpy.ndarray' object has no attribute 'sorted'

In [61]:
torch.norm(all_features, 0, dim=-1).median()

tensor(8.)

In [None]:
with Trace(model, activation_name, edit_output=dict_ablation_fn) as ret:
    outputs = model(input_ids)
    logits_dict_reconstruction = outputs[0]