In [1]:
import torch
from utils import *
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint to analyse; "gpt2" is the other model this notebook has settings for.
model_name = "HuggingFaceTB/SmolLM-135M"
is_gpt2 = model_name == "gpt2"

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_bos_token = True

batch_size = 128
max_length = 128

# Layer whose output gets traced, and the dataset config/subset name (None = default).
target_layer, d_name = ("transformer.h.5", None) if is_gpt2 else ("model.layers.18", "cosmopedia-v2")

# debug=True uses a small slice of data for quick iteration.
debug = True
if debug:
    dataset_name = "Elriggs/openwebtext-100k" if is_gpt2 else "HuggingFaceTB/smollm-corpus"
    num_datapoints = 5_000
elif is_gpt2:
    dataset_name = "prithivMLmods/OpenWeb888K"
    # presumably None means "use the whole dataset" — TODO confirm in TokenizedDataset
    num_datapoints = None
else:
    dataset_name = "HuggingFaceTB/smollm-corpus"
    num_datapoints = 2_000_000

if num_datapoints is None:
    # Batch count for the full OpenWeb888K dump.
    total_batches = 888_000 // batch_size
else:
    total_batches = num_datapoints // batch_size
    print(f"total amount of tokens in dataset: {num_datapoints * max_length / 1e6}M")

data_generator = TokenizedDataset(dataset_name, tokenizer, d_name, batch_size=batch_size, max_length=max_length, total_batches=total_batches)

total amount of tokens in dataset: 0.64M


Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

In [2]:
# Download the pretrained TopK SAE (weights + JSON config) for `target_layer` of `model_name`.
import json
from utils import DotDict, AutoEncoderTopK
from huggingface_hub import hf_hub_download

# Repo naming scheme: Elriggs/seq_concat_<model_name with "/" -> "_">_<target_layer>.
# Deriving it from the cell-1 settings keeps the SAE matched to the traced model/layer.
huggingface_name = f"Elriggs/seq_concat_{model_name.replace('/', '_')}_{target_layer}"
name_prefix = "sae_k=30_tokBias=True"
sae_name_style = name_prefix + ".pt"
cfg_name_style = name_prefix + "_cfg.json"

# SAE weights
model_path = hf_hub_download(repo_id=huggingface_name, filename=sae_name_style)
# Training config (provides e.g. `k` and `tokens_to_combine`)
config_path = hf_hub_download(repo_id=huggingface_name, filename=cfg_name_style)

# Context manager so the config file handle is closed after reading.
with open(config_path) as cfg_file:
    cfg = DotDict(json.load(cfg_file))

sae = AutoEncoderTopK.from_pretrained(model_path, k=cfg.k, device=None, embedding=True)
# Inference only: no gradients needed for the SAE parameters.
sae.requires_grad_(False)
sae.to(device)

AutoEncoderTopK(
  (encoder): Linear(in_features=576, out_features=9216, bias=True)
  (decoder): Linear(in_features=9216, out_features=576, bias=False)
  (per_token_bias): EmbeddingBias()
)

In [3]:
# Run the LM over the dataset, capture the output of `target_layer`, optionally standardize it
# per token, encode it with the SAE, and collect SAE features, tokens and per-batch FVU.
from einops import rearrange, repeat
from tqdm import tqdm

if num_datapoints is None:
    # Cell 1 leaves num_datapoints as None for the full gpt2 dataset and precomputes total_batches.
    print(f"Num Tokens: ~{total_batches * batch_size * max_length / 1e6}M")
else:
    print(f"Num Tokens: {num_datapoints * max_length/1e6}M")
    # Ceil division: cover every datapoint, without requesting an extra batch when
    # num_datapoints is an exact multiple of batch_size.
    total_batches = -(-num_datapoints // batch_size)

all_tokens = []       # token tensor of each batch (kept on CPU)
all_activations = []  # SAE feature tensor of each batch (kept on CPU)
fvus = []             # reconstruction FVU of each batch

# Standardize each token's activation vector (zero mean, unit std over the last dim) before encoding.
normalize = True

with torch.no_grad():
    for batch_idx in tqdm(range(total_batches)):
        # Move tokens to the device once: both the LM and the SAE's per-token bias consume them.
        batch = data_generator.next().to(device)
        with Trace(model, target_layer) as original_trace:
            model(batch)
            layer_out = original_trace.output
            x = layer_out[0] if isinstance(layer_out, tuple) else layer_out
        all_biases = torch.zeros_like(x)
        if sae.per_token_bias:
            # Token-dependent bias: removed before encoding, added back after decoding.
            all_biases = all_biases + sae.per_token_bias(batch)
        all_tokens.append(batch.cpu())
        if normalize:
            x = (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-8)

        features = sae.encode(x - all_biases)
        x_hat = sae.decode(features) + all_biases
        fvu = calculate_fvu(x, x_hat).item()
        fvus.append(fvu)
        all_activations.append(features.cpu())

all_activations = torch.cat(all_activations, dim=0)
all_tokens = torch.cat(all_tokens, dim=0)
all_fvus = torch.tensor(fvus)
print(f"FVU: {all_fvus.mean()}")

Num Tokens: 0.64M


100%|██████████| 40/40 [01:06<00:00,  1.67s/it]


FVU: 0.11453904211521149


In [4]:
# Fraction of SAE features that are non-zero at least once within the first N sequences.
N = 2_000
first_n_acts = all_activations[:N]
# A feature is alive if any activation is non-zero. Checking max/min (rather than the sum)
# cannot be fooled by positive and negative values cancelling, and allocates no
# activation-sized temporary.
alive_mask = (first_n_acts.amax(dim=(0, 1)) != 0) | (first_n_acts.amin(dim=(0, 1)) != 0)
alive_features = alive_mask.float().mean()
print(f"Percentage of alive features: {alive_features.item() * 100:.1f}%")

Percentage of alive features: 88.6%


In [5]:
# Ten largest activations of SAE feature 1; indices refer to the flattened leading dims of all_activations.
torch.topk(all_activations[..., 1].reshape(-1), k=10)

torch.return_types.topk(
values=tensor([0.2818, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000]),
indices=tensor([613710,      4,      8,      9,      7,      3,      1,      5,      0,
             2]))

In [13]:
# Render example contexts with token-level activations for a hand-picked list of SAE features.
from IPython.display import display, HTML
import numpy as np

num_feature_datapoints = 10  # how many example contexts to show per feature
total_features_to_print = 20  # cap for the random-sampling variant (see note in the loop)
total_num_features = sae.decoder.weight.shape[-1]  # decoder is Linear(n_features -> d_model)

# Candidate feature lists picked out by earlier analysis (not shown in this notebook).
FEATURE_SETS = {
    # Over peak noise SAE
    "peak_noise/high_mean_diff": [482, 8481, 1114, 7747, 6449, 8901, 4134, 5078, 2745, 8392],
    "peak_noise/high_var_diff": [8649, 6759, 7747, 1114, 6814, 541, 3015, 8442, 2414, 7933],
    "peak_noise/near_zero_var": [59, 60, 61, 63, 65, 67, 68, 69, 71, 72, 73, 76, 79, 80, 82, 83, 85, 87, 88, 89],
    # Over re-located(?) SAE
    "relocated/high_mean_diff": [4239, 8481, 8604, 9187, 8494, 4305, 3551, 5336, 7991, 1464],
    "relocated/high_var_diff": [8649, 8442, 7747, 6814, 6759, 7933, 3596, 7656, 6382, 4305],
    "relocated/near_zero_var": [59, 60, 61, 64, 65, 67, 68, 69, 71, 72, 73, 76, 79, 80, 82, 83, 85, 87, 88, 90],
}
# Select exactly one set; change the key to inspect another list.
# NOTE: the saved output below was produced with "relocated/high_var_diff" — re-run to refresh it.
features_to_print = FEATURE_SETS["relocated/near_zero_var"]

invalid_features = [f for f in features_to_print if not 0 <= f < total_num_features]
assert not invalid_features, f"feature indices out of range [0, {total_num_features}): {invalid_features}"

total_displayed = 0
for feature in features_to_print:
    # To browse random features instead: feature = np.random.randint(0, total_num_features)
    feature_activations = all_activations[..., feature]
    # Skip features that are non-zero in fewer than 3 sequences (any-nonzero, so signed
    # activations cannot cancel out the way a sum could).
    if (feature_activations != 0).any(dim=-1).sum() < 3:
        print(f"Feature {feature} has less than 3 activations, skipping")
        continue
    print("feature:", feature)

    # setting="max" is the other sampling mode of get_feature_indices.
    d_idx, seq_idx = get_feature_indices(feature_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(
        d_idx, seq_idx, feature_activations, all_tokens, tokenizer, append=cfg.tokens_to_combine
    )
    html = tokens_and_activations_to_html(token_list, partial_activations, tokenizer)
    display(HTML(html))
    total_displayed += 1

feature: 8649


feature: 8442


feature: 7747


feature: 6814


feature: 6759


feature: 7933


feature: 3596


feature: 7656


feature: 6382


feature: 4305
