In [1]:
import numpy as np
from get_model import get_moe
from get_logits import unembed_matrix
import sentencepiece as spm

In [5]:
import zipfile

In [2]:
import numpy as np

In [11]:
def calculate_sparse_counts(expert_specific_codes, top_k_indices, M, threshold=0.1):
    """
    Count, per (expert, feature) pair, how often an expert-specific code exceeds a threshold.

    Elements of expert_specific_codes that are > threshold are accumulated into an (M, D)
    table indexed by the expert id stored in top_k_indices, without expanding to a full
    (N, M, D) sparse representation.

    Args:
        expert_specific_codes (np.ndarray): A 3D NumPy array of shape (N, K, D).
                                            N is batch size, K is number of active experts,
                                            D is feature dimension.
        top_k_indices (np.ndarray): A 2D integer NumPy array of shape (N, K).
                                    Contains the indices (0 to M-1) of the active experts
                                    for each item in the batch.
        M (int): The total number of possible experts (e.g., 8192).
                 This is the size of the first dimension of the output counts array.
        threshold (float): Strict lower bound; only elements > threshold are counted.

    Returns:
        np.ndarray: int32 array of shape (M, D). Entry [m, d] is the number of (n, k)
        slots with top_k_indices[n, k] == m and expert_specific_codes[n, k, d] > threshold.
        All zeros when N, K or D is 0.

    Raises:
        TypeError: If an input is not a NumPy array, or top_k_indices is not integer-typed.
        ValueError: If the input ranks/shapes are incompatible, or an expert index lies
                    outside [0, M-1].
    """
    if not isinstance(expert_specific_codes, np.ndarray):
        raise TypeError("expert_specific_codes must be a NumPy array.")
    if not isinstance(top_k_indices, np.ndarray):
        raise TypeError("top_k_indices must be a NumPy array.")
    if expert_specific_codes.ndim != 3:
        raise ValueError("expert_specific_codes must be a 3D array.")
    if top_k_indices.ndim != 2:
        raise ValueError("top_k_indices must be a 2D array.")
    if expert_specific_codes.shape[0] != top_k_indices.shape[0] or \
       expert_specific_codes.shape[1] != top_k_indices.shape[1]:
        raise ValueError("Shapes of expert_specific_codes (N, K, D) and top_k_indices (N, K) are incompatible.")
    if not (np.issubdtype(top_k_indices.dtype, np.integer)):
        raise TypeError("top_k_indices must be an integer type array.")
    # min()/max() of an empty array raise instead of returning a value, so only run the
    # range check when there is at least one index to check.
    if top_k_indices.size and (top_k_indices.min() < 0 or top_k_indices.max() >= M):
        raise ValueError(f"Values in top_k_indices must be between 0 and {M-1}.")

    N, K, D = expert_specific_codes.shape

    # int32 counts: a cell can be hit at most once per (n, k) slot, i.e. at most N*K times.
    # If N*K can approach 2**31, switch to np.int64.
    counts_array = np.zeros((M, D), dtype=np.int32)

    # (N, K, D) boolean mask; True where the code is strictly above the threshold.
    condition_met_mask = expert_specific_codes > threshold

    # (N*K,) expert id for every activation slot; these select rows of counts_array.
    flat_expert_indices = top_k_indices.ravel()

    # (N*K, D): one row per activation slot, aligned with flat_expert_indices.
    # An explicit row count is used instead of -1 because reshape(-1, 0) is ambiguous
    # and raises when D == 0.
    values_to_add = condition_met_mask.reshape(N * K, D)

    # Unbuffered scatter-add: repeated expert ids in flat_expert_indices all accumulate
    # (plain fancy-index `+=` would keep only one of them). Booleans add as 0/1.
    np.add.at(counts_array, flat_expert_indices, values_to_add)

    return counts_array

In [None]:
# Load the precomputed SAE activations and token metadata.
# The archive handle is named `npz_archive` rather than `data` because `data` is
# rebound later in the notebook to the pregenerate_sae_data result.
NPZ_PATH = "data.npz"  # relative to the notebook's working directory
with np.load(NPZ_PATH) as npz_archive:
    top_level_latent_codes = npz_archive["top_level_latent_codes"]
    expert_specific_codes = npz_archive["expert_specific_codes"]
    top_k_indices = npz_archive["top_k_indices"]
    top_k_values = npz_archive["top_k_values"]
    token_ids = npz_archive["token_ids"]
    token_text = npz_archive["token_text"]

In [5]:
def find_subsequence_indices(array, subsequence):
    """
    Return the start positions of every occurrence of `subsequence` in `array`.

    Matching runs along the first axis, and overlapping occurrences are all reported,
    e.g. find_subsequence_indices([5, 5, 5], [5, 5]) -> [0, 1].

    Args:
        array: Sequence or array to search (e.g. a 1-D array of token ids).
        subsequence: Sequence or array to look for.

    Returns:
        list[int]: Ascending start indices i such that array[i:i+len(subsequence)]
        equals subsequence element-wise. An empty subsequence matches at every
        position 0..len(array); a subsequence longer than `array` matches nowhere.
    """
    haystack = np.asarray(array)
    needle = np.asarray(subsequence)
    n = len(haystack)
    m = len(needle)

    # Differing trailing shapes can never compare equal (np.array_equal semantics).
    if m > n or haystack.shape[1:] != needle.shape[1:]:
        return []

    # All length-m windows along axis 0 as a zero-copy strided view. The window axis is
    # appended last, so move it right after the window-start axis to line up with `needle`.
    windows = np.lib.stride_tricks.sliding_window_view(haystack, m, axis=0)
    windows = np.moveaxis(windows, -1, 1)

    # A window matches when every element (over all non-start axes) is equal.
    matches = (windows == needle).all(axis=tuple(range(1, windows.ndim)))
    return np.flatnonzero(matches).tolist()

In [7]:
# Locate every position of token id 71822 in the corpus. The returned positions
# can then be used to visualize the surrounding context.
find_subsequence_indices(token_ids, np.array([71822]))

[300799, 300917, 948547, 948620, 962805, 2214436]

In [20]:
# Per-(expert, sublatent) count of activations strictly stronger than 0.05,
# accumulated into a table covering all 16384 experts.
of_interest = calculate_sparse_counts(
    expert_specific_codes,
    top_k_indices,
    M=16384,
    threshold=0.05,
)

In [22]:
# Select experts that have several "busy" sublatents:
#   - a sublatent is busy when its activation count in `of_interest` exceeds MIN_ACTIVATIONS;
#   - an expert is kept when more than MIN_BUSY_SUBLATENTS of its sublatents are busy.
# The two thresholds are independent knobs; both default to 5.
MIN_ACTIVATIONS = 5
MIN_BUSY_SUBLATENTS = 5
idxs = np.argwhere(of_interest > MIN_ACTIVATIONS)  # rows of (expert, sublatent)
print(idxs)
exps, counts = np.unique(idxs[:, 0], return_counts=True)  # busy sublatents per expert
print("Experts of Interest:")
exps_of_interest = idxs[np.isin(idxs[:, 0], exps[counts > MIN_BUSY_SUBLATENTS])]
print(exps_of_interest)

[[    2     3]
 [    8    13]
 [    9     7]
 ...
 [16379     1]
 [16379     9]
 [16379    10]]
Experts of Interest:
[[  412     0]
 [  412     1]
 [  412     7]
 [  412     8]
 [  412    10]
 [  412    11]
 [  412    14]
 [ 1347     4]
 [ 1347     5]
 [ 1347     7]
 [ 1347     9]
 [ 1347    11]
 [ 1347    14]
 [ 1646     0]
 [ 1646     2]
 [ 1646     5]
 [ 1646     7]
 [ 1646     8]
 [ 1646    12]
 [ 1825     5]
 [ 1825     8]
 [ 1825    12]
 [ 1825    13]
 [ 1825    14]
 [ 1825    15]
 [ 2202     1]
 [ 2202     4]
 [ 2202    10]
 [ 2202    12]
 [ 2202    13]
 [ 2202    15]
 [ 2923     0]
 [ 2923     1]
 [ 2923     8]
 [ 2923    11]
 [ 2923    13]
 [ 2923    14]
 [ 3181     1]
 [ 3181     2]
 [ 3181     3]
 [ 3181     7]
 [ 3181    11]
 [ 3181    14]
 [ 3181    15]
 [ 3237     1]
 [ 3237     3]
 [ 3237     7]
 [ 3237     8]
 [ 3237     9]
 [ 3237    13]
 [ 3237    15]
 [ 3565     0]
 [ 3565     8]
 [ 3565    10]
 [ 3565    11]
 [ 3565    12]
 [ 3565    15]
 [ 3902     8]
 [ 3902     9

In [14]:
# Load the MoE SAE named "16k_16" via the project-local get_moe helper.
# NOTE(review): the name presumably encodes 16384 top-level latents x 16 sublatents
# per expert (matching the 16384 / 16 constants used for feature ids) — confirm in get_model.
model = get_moe("16k_16")



In [23]:
# Turn the selected (expert, sublatent) pairs into feature ids, and add the parent experts.
# Sublatent feature ids are numbered after the top-level latents:
#     feature_id = N_TOP_LEVEL + N_SUBLATENTS * expert_index + sublatent_index
# e.g. expert 1646, sublatent 0 -> 16384 + 16 * 1646 + 0 = 42720.
N_TOP_LEVEL = 16384  # number of top-level latents
N_SUBLATENTS = 16    # sublatents per expert
selected_features = exps_of_interest
sublatent_ids = N_TOP_LEVEL + N_SUBLATENTS * selected_features[:, 0] + selected_features[:, 1]
expert_ids = np.unique(selected_features[:, 0])
# Convert to plain Python ints so selected_feats is a pure Python list (no numpy scalars).
selected_feats = sorted(int(f) for f in np.concatenate([sublatent_ids, expert_ids]))

In [5]:
# Alternative to the automatic selection: list feature ids by hand.
# Keep this a plain Python list.
selected_feats = [
    16018, 1725, 6879, 3445, 12057, 9466, 14759, 9193,
    150, 9738, 11441, 12560, 4431, 14651, 13607, 6595,
    13981, 13691, 8603, 2009, 5417, 12679, 9021, 2702,
    15333, 2684, 2212, 12260, 10362, 15771, 9631, 9953,
]

In [16]:
from jax_sae_interface import pregenerate_sae_data

In [24]:
# Bundle the loaded arrays (order presumably as pregenerate_sae_data expects — confirm
# in jax_sae_interface), then precompute dashboard data for the selected features.
precomputed_data = (
    top_level_latent_codes,
    expert_specific_codes,
    top_k_indices,
    top_k_values,
    token_ids,
    token_text,
)
data = pregenerate_sae_data(model, precomputed_data, selected_feats, unembed_matrix)

Processing 2600960 tokens in batches of 1000
Top level latent codes shape: (2600960, 16384)
Expert specific codes shape: (2600960, 32, 16)
Top k indices shape: (2600960, 32)
Feature to column mapping created with 255 entries
Processing batch 0-1000 of 2600960 (batch 1/2601)
Batch data dimensions:
  - batch_top_level: (1000, 16384)
  - batch_expert_codes: (1000, 32, 16)
  - batch_top_k: (1000, 32)
Processing batch 1000-2000 of 2600960 (batch 2/2601)
Processing batch 2000-3000 of 2600960 (batch 3/2601)
Processing batch 3000-4000 of 2600960 (batch 4/2601)
Processing batch 4000-5000 of 2600960 (batch 5/2601)
Processing batch 5000-6000 of 2600960 (batch 6/2601)
Processing batch 6000-7000 of 2600960 (batch 7/2601)
Processing batch 7000-8000 of 2600960 (batch 8/2601)
Processing batch 8000-9000 of 2600960 (batch 9/2601)
Processing batch 9000-10000 of 2600960 (batch 10/2601)
Processing batch 10000-11000 of 2600960 (batch 11/2601)
Processing batch 11000-12000 of 2600960 (batch 12/2601)
Processin

In [48]:
# Display the per-feature summary statistics attached to the pregenerate_sae_data result
# (the output shows 'max', 'frac_nonzero' and 'quantile_data' entries).
data.feature_stats

{'max': [0.2189834862947464,
  0.05530748516321182,
  0.00016820360906422138,
  0.10225270688533783,
  3.562186248018406e-05,
  0.00013869069516658783,
  0.021382717415690422,
  0.0001321071176789701,
  5.516815893003013e-09,
  1.105045721594422e-09,
  0.0001916812325362116,
  1.927769499587839e-09,
  0.00017740002658683807,
  0.00014764221850782633,
  0.00016496529860887676,
  3.434777462452132e-10,
  1.550060733279679e-05],
 'frac_nonzero': [0.00018800750492125985,
  0.00010919045275590551,
  0.00010995939960629922,
  0.00011072834645669291,
  0.00010111651082677166,
  0.00010995939960629922,
  0.00011072834645669291,
  0.00011072834645669291,
  0.0,
  0.0,
  0.00010957492618110236,
  0.0,
  0.00010995939960629922,
  0.00010919045275590551,
  0.00010995939960629922,
  0.0,
  9.45804625984252e-05],
 'quantile_data': [[0.218983],
  [-0.000137,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.055307],
  [-0.000171,
   0.0,
   0

In [18]:
from sae_dashboard_adapter import save_feature_centric_vis

In [19]:
# Path to the Gemma SentencePiece tokenizer model,
# e.g. "/net/projects2/interp/gemma2/tokenizer.model".
vocab_path = ""
if not vocab_path:
    # Fail with an actionable message instead of whatever error SentencePiece raises
    # for an empty path.
    raise ValueError("vocab_path is empty: set it to the Gemma tokenizer.model file before running this cell.")
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
save_feature_centric_vis(data, "./16k_16.html", vocab, model_name="JAX SAE Example")



Original feature order in data: [1646, 3424, 7165, 13158, 14133, 14323, 42720, 42722, 42725, 42728, 42732, 71169, 71174, 71175, 71181, 71183, 131024, 131030, 131032, 131033, 131037, 226913, 226919, 226923, 226926, 226927, 242512, 242515, 242518, 242520, 242525, 245552, 245554, 245559, 245562, 245565]
Using sorted feature order: [1646, 3424, 7165, 13158, 14133, 14323, 42720, 42722, 42725, 42728, 42732, 71169, 71174, 71175, 71181, 71183, 131024, 131030, 131032, 131033, 131037, 226913, 226919, 226923, 226926, 226927, 242512, 242515, 242518, 242520, 242525, 245552, 245554, 245559, 245562, 245565]
SequencesConfig settings:
  stack_mode: stack-quantiles
  n_quantiles: 5
  top_acts_group_size: 10
  quantile_group_size: 3
  group_sizes: [10, 3, 3, 3, 3, 3]
Successfully created feature-centric layout
Building vocabulary dictionary from SentencePiece tokenizer (size: 256000)
Built vocabulary with 256000 entries
Original feature indices from jax_data: [1646, 3424, 7165, 13158, 14133, 14323, 42720

In [32]:
# Write the dashboard for the currently selected features to its own HTML file,
# reusing the tokenizer loaded above.
save_feature_centric_vis(
    data,
    "./specific_feature.html",
    vocab,
    model_name="JAX SAE Example",
)

Original feature order in data: [897, 8117, 1181, 4121, 1608, 3681, 2026, 4223, 3178, 5796, 7491, 7927, 588, 3900, 5576, 2884, 2174, 6928, 3855, 3371, 2673, 6266, 6543, 5310, 7434, 977, 4193, 3462, 5565, 2945, 2302, 1933]
Using sorted feature order: [588, 897, 977, 1181, 1608, 1933, 2026, 2174, 2302, 2673, 2884, 2945, 3178, 3371, 3462, 3681, 3855, 3900, 4121, 4193, 4223, 5310, 5565, 5576, 5796, 6266, 6543, 6928, 7434, 7491, 7927, 8117]
SequencesConfig settings:
  stack_mode: stack-quantiles
  n_quantiles: 5
  top_acts_group_size: 10
  quantile_group_size: 3
  group_sizes: [10, 3, 3, 3, 3, 3]
Successfully created feature-centric layout
Building vocabulary dictionary from SentencePiece tokenizer (size: 256000)
Built vocabulary with 256000 entries
Original feature indices from jax_data: [897, 8117, 1181, 4121, 1608, 3681, 2026, 4223, 3178, 5796, 7491, 7927, 588, 3900, 5576, 2884, 2174, 6928, 3855, 3371, 2673, 6266, 6543, 5310, 7434, 977, 4193, 3462, 5565, 2945, 2302, 1933]

=== DETAILED D