In [1]:
import torch
from utils import *
# Example usage:
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Run configuration -------------------------------------------------------
# Loads the language model + tokenizer and builds the tokenized data stream
# whose activations are analysed by the cells below.
# Swap the two lines below to run the GPT-2 variant instead.
# model_name = "gpt2"
model_name = "HuggingFaceTB/SmolLM-135M"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_bos_token = True
batch_size = 256
max_length = 64
num_datapoints = 5_000  # default; every branch below sets the value actually used

# Layer whose output is traced, plus the dataset config name handed to
# TokenizedDataset (None for the GPT-2 dataset).
if(model_name == "gpt2"):
    target_layer = 'transformer.h.5'
    d_name = None
else:
    target_layer = "model.layers.18"
    d_name = "cosmopedia-v2"

# debug=True runs on a small slice of data; debug=False uses the full-size run.
debug = True
if(debug):
    if(model_name == "gpt2"):
        dataset_name = "Elriggs/openwebtext-100k"
    else:
        dataset_name = "HuggingFaceTB/smollm-corpus"
    # num_datapoints = 100_000
    num_datapoints = 5_000
else:
    if(model_name == "gpt2"):
        dataset_name = "prithivMLmods/OpenWeb888K"
        # Must stay an int: the activation-collection cell computes
        # `num_datapoints * max_length` and `num_datapoints // batch_size`,
        # both of which raise TypeError when this is None.
        num_datapoints = 888_000
    else:
        dataset_name = "HuggingFaceTB/smollm-corpus"
        num_datapoints = 2_000_000

total_batches = num_datapoints // batch_size
print(f"total amount of tokens in dataset: {num_datapoints * max_length / 1e6}M")

data_generator = TokenizedDataset(dataset_name, tokenizer, d_name, batch_size=batch_size, max_length=max_length, total_batches=total_batches)

total amount of tokens in dataset: 0.32M


Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

In [2]:
# Now we want to download all our SAEs
import json
from utils import DotDict, AutoEncoderTopK
from huggingface_hub import hf_hub_download
# Hub repository and file names of the SAE checkpoint and its JSON config.
huggingface_name = "Elriggs/seq_concat_HuggingFaceTB_SmolLM-135M_model.layers.18"
name_prefix = "sae_k=30_tokBias=True"
sae_name_style = name_prefix + ".pt"
cfg_name_style = name_prefix + "_cfg.json"

# Download SAE weights
model_path = hf_hub_download(
    repo_id=huggingface_name,
    filename=sae_name_style
)

# Download config file
config_path = hf_hub_download(
    repo_id=huggingface_name,
    filename=cfg_name_style
)

# Read the config inside a context manager so the file handle is closed
# deterministically instead of being left open for the garbage collector.
with open(config_path) as cfg_file:
    cfg = DotDict(json.load(cfg_file))

sae = AutoEncoderTopK.from_pretrained(model_path, k=cfg.k, device = None, embedding=True)
# set grad to false
sae.requires_grad_(False)
# Last expression of the cell: moves the SAE to `device` and displays its repr.
sae.to(device)

AutoEncoderTopK(
  (encoder): Linear(in_features=576, out_features=9216, bias=True)
  (decoder): Linear(in_features=9216, out_features=576, bias=False)
  (per_token_bias): EmbeddingBias()
)

In [3]:
from einops import rearrange, repeat
from tqdm import tqdm
# Collect SAE feature activations (and the tokens that produced them) over the
# data stream, tracking the per-batch reconstruction FVU along the way.
print(f"Num Tokens: {num_datapoints * max_length/1e6}M")
total_batches = num_datapoints // batch_size + 1

all_tokens = []
all_activations = []
fvus = []

# When True, each token's activation vector is standardised (zero mean, unit
# std over the hidden dimension) before it is fed to the SAE.
normalize = True

with torch.no_grad():
    for batch_idx in tqdm(range(total_batches)):
        batch = data_generator.next()

        # Forward pass with a hook on the target layer; some layers return a
        # tuple, in which case the first element is used.
        with Trace(model, target_layer) as original_trace:
            _ = model(batch.to(device)).logits
            x = original_trace.output
            if isinstance(x, tuple):
                x = x[0]

        # Per-token bias contributed by the SAE (stays all-zero when the SAE
        # has no per_token_bias).
        all_biases = torch.zeros_like(x)
        if sae.per_token_bias:
            all_biases = all_biases + sae.per_token_bias(batch)
        all_tokens.append(batch)

        if normalize:
            token_mean = x.mean(dim=-1, keepdim=True)
            token_std = x.std(dim=-1, keepdim=True)
            x = (x - token_mean) / (token_std + 1e-8)

        # Encode with the bias removed, then add it back after decoding.
        features = sae.encode(x - all_biases)
        x_hat = sae.decode(features) + all_biases
        fvus.append(calculate_fvu(x, x_hat).item())

        # Keep features on the CPU so the accumulated list does not fill GPU memory.
        all_activations.append(features.cpu())

# Stack the per-batch results along the batch dimension.
all_activations = torch.cat(all_activations, dim=0)
all_tokens = torch.cat(all_tokens, dim=0)
all_fvus = torch.tensor(fvus)
# NOTE(review): the recorded run prints an FVU above 1. If calculate_fvu is the
# usual fraction-of-variance-unexplained, that would mean the reconstruction is
# worse than predicting the mean -- confirm the normalisation / bias handling
# here matches how the SAE was trained.
print(f"FVU: {all_fvus.mean()}")

Num Tokens: 0.32M


100%|██████████| 20/20 [00:42<00:00,  2.12s/it]


FVU: 3.12223219871521


In [4]:
# Share of SAE features whose summed activation over the first N sequences
# (all positions) is non-zero.
N = 2_000
feature_totals = all_activations[:N].sum(dim=(0,1))
alive_features = (feature_totals != 0).sum() / all_activations[:N].shape[-1]
print(f"Percentage of alive features: {alive_features.item() * 100:.1f}%")

Percentage of alive features: 90.8%


In [5]:
import numpy as np
from IPython.display import display, HTML
from einops import rearrange

def make_colorbar(min_value, max_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Build the HTML legend that explains the token highlight colours.

    The legend is a row of ``<span>`` swatches: up to four red swatches running
    from ``min_value`` towards zero, one neutral ``0.0`` swatch, then up to four
    blue swatches running up to ``max_value``.

    Args:
        min_value: Most negative value on the scale; the red half is emitted
            only when it is below ``-negative_threshold``.
        max_value: Largest value on the scale; the blue half is emitted only
            when it is above ``positive_threshold``.
        white: RGB channel value used for the neutral swatch.
        red_blue_ness: Channel value of the faintest tint; the non-dominant
            channels fall from this towards 0 as the magnitude grows.
        positive_threshold: Cut-off for showing the blue half.
        negative_threshold: Cut-off (as a magnitude) for showing the red half.

    Returns:
        The concatenated HTML string.
    """
    num_colors = 4
    swatches = []

    # Red half, strongest (most negative) first.
    if min_value < -negative_threshold:
        for step in reversed(range(1, num_colors + 1)):
            ratio = step / (num_colors)
            label = round((min_value*ratio),1)
            shade = int(red_blue_ness-(red_blue_ness*ratio))
            # Dark swatches get white text so the label stays readable.
            text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
            swatches.append(
                f'<span style="background-color:rgba(255, {shade},{shade},1); color:rgb({text_color})">&nbsp{label}&nbsp</span>'
            )

    # Neutral swatch for zero.
    swatches.append(
        f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>'
    )

    # Blue half, weakest first.
    if max_value > positive_threshold:
        for step in range(1, num_colors + 1):
            ratio = step / (num_colors)
            label = round((max_value*ratio),1)
            shade = int(red_blue_ness-(red_blue_ness*ratio))
            text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
            swatches.append(
                f'<span style="background-color:rgba({shade},{shade},255,1);color:rgb({text_color})">&nbsp{label}&nbsp</span>'
            )

    return "".join(swatches)

def value_to_color(activation, max_value, min_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Map one value to a (text colour, background colour) pair for HTML output.

    Values above ``positive_threshold`` are tinted blue in proportion to
    ``activation / max_value``; values below ``-negative_threshold`` are tinted
    red in proportion to ``activation / min_value``; anything in between gets
    the neutral background.

    The ratio is clamped to ``[0, 1]`` so that a value outside the
    ``[min_value, max_value]`` scale (e.g. a reward beyond a fixed +/-10 scale)
    saturates at the strongest tint instead of producing negative, invalid
    ``rgba`` channel values.

    Args:
        activation: The value to colour.
        max_value: Value that maps to the strongest blue.
        min_value: Value that maps to the strongest red.
        white: RGB channel value of the neutral background.
        red_blue_ness: Channel value of the faintest tint.
        positive_threshold: Values at or below this are not tinted blue.
        negative_threshold: Values at or above its negation are not tinted red.

    Returns:
        ``(text_color, background_color)`` where ``text_color`` is an ``"r,g,b"``
        string and ``background_color`` is a CSS ``rgba(...)`` string.
    """
    is_positive = activation > positive_threshold
    is_negative = (not is_positive) and activation < -negative_threshold

    if not (is_positive or is_negative):
        return "0,0,0", f'rgba({white},{white},{white},1)'

    ratio = activation/max_value if is_positive else activation/min_value
    # Saturate rather than overflow when the value lies outside the scale.
    ratio = max(0.0, min(1.0, ratio))

    # Strong tints are dark, so switch to white text past the halfway point.
    text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
    shade = int(red_blue_ness-(red_blue_ness*ratio))
    if is_positive:
        background_color = f'rgba({shade},{shade},255,1)'
    else:
        background_color = f'rgba(255, {shade},{shade},1)'
    return text_color, background_color

def convert_token_array_to_list(array):
    """Normalise tokens or activations to a list of per-sequence sequences.

    Accepted inputs:
      * a 1-D tensor -> wrapped as a single sequence ``[[...]]``
      * a 2-D tensor -> converted with ``tolist()``
      * a flat list of numbers (ints for tokens, floats for activations)
        -> wrapped as a single sequence
      * anything else (an already nested list, a list of tensors, an empty
        list, ...) -> returned unchanged

    Raises:
        NotImplementedError: if a tensor has more than two (or zero) dimensions.
    """
    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            return [array.tolist()]
        if array.dim() == 2:
            return array.tolist()
        raise NotImplementedError("tokens must be 1 or 2 dimensional")

    if isinstance(array, list):
        # Guard against the empty list before peeking at the first element, and
        # treat floats like ints so a flat activation list is wrapped as well.
        if array and isinstance(array[0], (int, float)):
            return [array]
    return array

def tokens_and_activations_to_html(toks, activations, tokenizer, logit_diffs=None, model_type="causal", text_above_each_act=None):
    """Render token sequences as HTML with each token shaded by its activation.

    Args:
        toks: Token ids, in any form accepted by ``convert_token_array_to_list``.
        activations: One value per token, same layout as ``toks``.
        tokenizer: Object with a ``decode`` method, called once per token.
        logit_diffs: Optional extra values. For ``model_type != "reward_model"``
            they are indexed as ``logit_diffs[seq][pos]`` and drawn as a coloured
            bar under each token; for ``"reward_model"`` ``logit_diffs[seq]`` is
            a single value (``.item()`` is called on it) shown after the sequence
            on a fixed -10..10 colour scale.
        model_type: ``"reward_model"`` selects the per-sequence reward display.
        text_above_each_act: Optional per-sequence caption, inserted as-is
            (not escaped) before each sequence.

    Returns:
        The HTML document fragment as a single string.
    """
    # Imported locally under a function-private name: notebook code also binds
    # a variable called `html`, which would shadow a module-level `import html`.
    from html import escape

    # text_spacing = "0.07em"
    text_spacing = "0.00em"
    toks = convert_token_array_to_list(toks)
    activations = convert_token_array_to_list(activations)

    # Decode token by token. The decoded text is HTML-escaped first so that
    # tokens containing '<', '>' or '&' are shown literally instead of being
    # parsed as markup; the &nbsp substitutions are applied afterwards so they
    # are not escaped themselves.
    toks = [
        [
            escape(tokenizer.decode(t), quote=False).replace('Ġ', '&nbsp').replace('\n', '\\n')
            for t in tok
        ]
        for tok in toks
    ]

    per_token_logits = logit_diffs is not None and model_type != "reward_model"
    per_sequence_reward = logit_diffs is not None and model_type == "reward_model"

    # One colour scale shared by every sequence in this rendering.
    max_value = max([max(activ) for activ in activations])
    min_value = min([min(activ) for activ in activations])
    if per_token_logits:
        logit_max_value = max([max(activ) for activ in logit_diffs])
        logit_min_value = min([min(activ) for activ in logit_diffs])

    # Make background black
    highlighted_text = ['\n<body style="background-color: black; color: white;">\n']

    # Add color bar(s)
    highlighted_text.append("Token Activations: " + make_colorbar(min_value, max_value))
    if per_token_logits:
        highlighted_text.append('<div style="margin-top: 0.1em;"></div>')
        highlighted_text.append("Logit Diff: " + make_colorbar(logit_min_value, logit_max_value))

    highlighted_text.append('<div style="margin-top: 0.5em;"></div>')
    for seq_ind, (act, tok) in enumerate(zip(activations, toks)):
        if text_above_each_act is not None:
            highlighted_text.append(f'<span>{text_above_each_act[seq_ind]}</span>')
        for act_ind, (a, t) in enumerate(zip(act, tok)):
            if per_token_logits:
                # Wrapper that stacks the logit-diff bar beneath the token.
                highlighted_text.append('<div style="display: inline-block;">')
            text_color, background_color = value_to_color(a, max_value, min_value)
            highlighted_text.append(f'<span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{t.replace(" ", "&nbsp")}</span>')
            if per_token_logits:
                logit_diffs_act = logit_diffs[seq_ind][act_ind]
                _, logit_background_color = value_to_color(logit_diffs_act, logit_max_value, logit_min_value)
                highlighted_text.append(f'<div style="display: block; margin-right: {text_spacing}; height: 10px; background-color:{logit_background_color}; text-align: center;"></div></div>')
        if per_sequence_reward:
            reward_change = logit_diffs[seq_ind].item()
            text_color, background_color = value_to_color(reward_change, 10, -10)
            highlighted_text.append(f'<br><span>Reward: </span><span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{reward_change:.2f}</span>')
        # Vertical gap between sequences.
        highlighted_text.append('<div style="margin-top: 0.2em;"></div>')
    return ''.join(highlighted_text)
def save_token_display(tokens, activations, tokenizer, path, save=True, logit_diffs=None, show=False, model_type="causal"):
    """Render tokens with their activations and show the result inline.

    ``path``, ``save`` and ``show`` are accepted but currently unused: writing
    the rendering to an image file is disabled, so the HTML is always displayed.

    Returns:
        Whatever ``display`` returns for the rendered HTML.
    """
    rendered = tokens_and_activations_to_html(
        tokens,
        activations,
        tokenizer,
        logit_diffs,
        model_type=model_type,
    )
    return display(HTML(rendered))

def get_feature_indices(feature_activations, k=10, setting="max"):
    """Pick up to ``k`` (sequence, position) locations for a single feature.

    Args:
        feature_activations: 2-D tensor of shape (num_sequences, seq_len) with
            one feature's activation at every token position.
        k: Maximum number of locations to return.
        setting: Selection strategy.
            ``"max"``     -- the ``k`` largest activations, strongest first.
            ``"uniform"`` -- split [min, max] into ``k`` equal-width bins and draw
                             one random location from every non-empty bin except
                             the lowest; returned strongest bin first.
            anything else -- ``k`` random locations with non-zero activation.

    Returns:
        ``(d_indices, s_indices)``: sequence indices and within-sequence
        positions of the chosen locations.
    """
    _, seq_len = feature_activations.shape

    # Flatten (b, s) -> (b*s). The input is typically a strided slice of a
    # larger tensor, whose flattened view is non-contiguous; making it
    # contiguous once here avoids handing a non-contiguous tensor to
    # torch.bucketize below (likely the source of the warning that call emits).
    flat_activations = feature_activations.reshape(-1).contiguous()

    if setting=="max":
        found_indices = torch.argsort(flat_activations, descending=True)[:k]
    elif setting=="uniform":
        min_value = torch.min(flat_activations)
        max_value = torch.max(flat_activations)

        # k equal-width bins spanning [min, max] need k + 1 boundaries.
        num_bins = k
        bin_boundaries = torch.linspace(min_value.item(), max_value.item(), num_bins + 1)

        # Assign each activation to its respective bin
        bins = torch.bucketize(flat_activations, bin_boundaries)

        sampled_indices = []
        for bin_idx in torch.unique(bins):
            # Bin 0 only contains values equal to the minimum (presumably the
            # inactive, zero-activation positions), so it is skipped.
            if(bin_idx==0):
                continue
            # Flat indices of every activation that fell into this bin
            bin_indices = torch.nonzero(bins == bin_idx, as_tuple=False).squeeze(dim=1)

            # Randomly sample one location from the current bin
            sampled_indices.extend(np.random.choice(bin_indices.numpy(), size=1, replace=False))

        # Bins were visited weakest -> strongest; flip so strongest comes first.
        found_indices = torch.tensor(sampled_indices).long().flip(dims=[0])
    else: # random
        nonzero_indices = torch.nonzero(flat_activations)[:, 0]
        shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0])]
        found_indices = shuffled_indices[:k]

    # Convert flat indices back to (sequence, position).
    d_indices = found_indices // seq_len
    s_indices = found_indices % seq_len
    return d_indices, s_indices

def get_feature_datapoints(d_idx, seq_pos_idx, all_activations, all_tokens, tokenizer, append=0):
    """Gather text, tokens and activations for the selected (sequence, position) pairs.

    For every pair the sequence is returned twice: in full, and cut off after
    position ``seq_pos + append`` (inclusive; slicing simply stops at the end
    of the sequence if that runs past it).

    Args:
        d_idx: Sequence indices into ``all_tokens`` / ``all_activations``.
        seq_pos_idx: Matching position within each sequence.
        all_activations: Indexable by sequence; each row supports slicing and
            ``tolist()``.
        all_tokens: Indexable by sequence; each row supports slicing.
        tokenizer: Object with a ``decode`` method.
        append: Extra tokens to keep after the selected position.

    Returns:
        ``(text_list, full_text, token_list, full_token_list,
        partial_activations, full_activations)`` -- six lists with one entry
        per selected pair.
    """
    text_list, full_text = [], []
    token_list, full_token_list = [], []
    partial_activations, full_activations = [], []

    for row, pos in zip(d_idx, seq_pos_idx):
        row, pos = int(row), int(pos)
        # Exclusive end of the truncated view.
        end = pos + 1 + append

        row_tokens = all_tokens[row]
        row_acts = all_activations[row]

        full_text.append(tokenizer.decode(row_tokens))

        clipped_tokens = row_tokens[:end]
        full_activations.append(row_acts.tolist())
        partial_activations.append(row_acts[:end].tolist())
        text_list.append(tokenizer.decode(clipped_tokens))
        token_list.append(clipped_tokens)
        full_token_list.append(row_tokens)

    return text_list, full_text, token_list, full_token_list, partial_activations, full_activations


In [6]:
from IPython.display import display, HTML
# Show highlighted example contexts for randomly drawn SAE features.
num_feature_datapoints = 10 # how many examples/expert
total_features_to_print = 20
total_num_features = sae.decoder.weight.shape[-1]
features_to_print = list(range(20))
features_for_this_key = features_to_print

total_displayed = 0
# Keep drawing random features until enough of them have been displayed;
# features that fire in too few sequences are skipped and do not count.
while total_displayed < total_features_to_print:
    feature = np.random.randint(0, total_num_features)
    feature_activations = all_activations[..., feature]

    # Number of sequences in which this feature has a non-zero summed activation.
    num_active_sequences = (feature_activations.sum(dim=-1) !=0).sum()
    if num_active_sequences < 3:
        print(f"Feature {feature} has less than 3 activations, skipping")
        continue
    print("feature:", feature)

    # Choose example locations spread across the feature's activation range.
    d_idx, seq_idx = get_feature_indices(
        feature_activations,
        k=num_feature_datapoints,
        setting="uniform",
    )
    (
        text_list,
        full_text,
        token_list,
        full_token_list,
        partial_activations,
        full_activations,
    ) = get_feature_datapoints(
        d_idx,
        seq_idx,
        feature_activations,
        all_tokens,
        tokenizer,
        append=cfg.tokens_to_combine,
    )

    # Render each context up to (and slightly past) the selected position.
    html = tokens_and_activations_to_html(token_list, partial_activations, tokenizer)
    display(HTML(html))
    total_displayed +=1

feature: 5687


  bins = torch.bucketize(feature_activations, bin_boundaries)


feature: 5816


feature: 4902


feature: 8334


feature: 8467


feature: 3739


feature: 8422


feature: 833


feature: 2766


feature: 7805


feature: 2560


Feature 5488 has less than 3 activations, skipping
Feature 4979 has less than 3 activations, skipping
feature: 8807


feature: 7231


feature: 8037


feature: 1928


feature: 6032


feature: 240


feature: 2439


feature: 5642


Feature 6798 has less than 3 activations, skipping
feature: 2735
