Baseline: L1 penalty enabled, orthogonality penalty enabled

In [None]:
from moe_eqx import MixtureOfExperts_v2, train_step, mask_codes
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
# Directory holding the downloaded or generated vocab and unembedding files.
# Paths are built as f'{unembeds_dir}/<file>', so an empty string would resolve
# to the filesystem root ('/<file>'); default to the current working directory.
unembeds_dir = "."  # change this to wherever you saved the vocab and unembeddings

In [None]:
def data_generator(data, batch_size, seed=0):
    """Yield shuffled, normalized mini-batches from ``data`` forever.

    Every pass over the data draws a fresh random permutation of the rows.
    Each row of a batch is centered (its mean over axis 1 is subtracted) and
    then divided by its L2 norm (plus a small epsilon) before being yielded.
    The last batch of a pass is smaller when ``batch_size`` does not divide
    the number of rows.

    Args:
        data: Array with examples along axis 0 and features along axis 1.
        batch_size: Maximum number of rows per batch; must be positive.
        seed: Integer seed for the shuffling PRNG key. The default (0)
            reproduces the previously hard-coded key.

    Yields:
        One normalized batch of rows at a time.

    Raises:
        ValueError: On the first ``next()`` call if ``batch_size`` is not
            positive or ``data`` has no rows. Without this check an empty
            dataset (or a negative batch size) would loop forever without
            ever yielding.
    """
    num_samples = data.shape[0]
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if num_samples == 0:
        raise ValueError("data must contain at least one row")

    key = jax.random.PRNGKey(seed)

    while True:
        # New permutation for every pass over the data.
        key, subkey = jax.random.split(key)
        indices = jax.random.permutation(subkey, num_samples)

        for start in range(0, num_samples, batch_size):
            end = min(start + batch_size, num_samples)
            batch = data[indices[start:end]]

            # Normalize the batch: center each row, then scale it to unit norm.
            batch = batch - jnp.mean(batch, axis=1, keepdims=True)
            norms = jnp.linalg.norm(batch, axis=1, keepdims=True)
            batch = batch / (norms + 1e-8)  # epsilon avoids division by zero

            yield batch

def create_data_loader(data, batch_size):
    """Return an endless batch iterator over ``data``.

    Thin convenience wrapper: both arguments are forwarded unchanged to
    ``data_generator`` and its generator is handed back to the caller.
    """
    loader = data_generator(data, batch_size)
    return loader

In [None]:
# Load the embedding matrix and rescale every entry by sqrt(rows / columns);
# per the original note, this is meant to bring the norms close to 1.
g = jnp.load(f'{unembeds_dir}/clean_gemma_embeddings.npy')
g *= jnp.sqrt(g.shape[0] / g.shape[1])


# Set up the data loader.
# (A device-dependent alternative would be jax.local_device_count() * 4096.)
batch_size = 8192
train_loader = create_data_loader(g, batch_size)

# Pull one batch up front so the feature width can be read off its second axis.
example_batch = next(train_loader)
input_dim = example_batch.shape[1]

In [None]:
# Display the shape of the batch drawn from the loader as a quick sanity check.
example_batch.shape

In [None]:
# Hyperparameters handed to the mixture-of-experts constructor.
input_dim = example_batch.shape[1]  # feature width of a single example
subspace_dim = 5
num_experts = 2 ** 11
atoms_per_subspace = 2 ** 5
k = 5  # presumably the number of experts selected per input -- confirm in moe_eqx

# PRNG key passed to the model constructor.
key = jax.random.PRNGKey(0)

model = MixtureOfExperts_v2(
    input_dim=input_dim,
    subspace_dim=subspace_dim,
    atoms_per_subspace=atoms_per_subspace,
    num_experts=num_experts,
    k=k,
    key=key,
)

In [None]:
# Regularization weights forwarded to train_step; both the L1 term and the
# orthogonality term are non-zero for this baseline run.
l1_penalty, ortho_penalty = 0.0025, 0.025

In [None]:
import time

# Length of the learning-rate warmup phase and of the whole optimization run.
warmup_steps = 250
num_steps = 5_000

def cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, alpha=0.0):
    """Build a schedule that warms up linearly and then follows a cosine decay.

    Args:
        init_value: Value at step 0, where the linear warmup starts.
        peak_value: Value reached at the end of the warmup; the cosine decay
            starts from it.
        warmup_steps: Number of steps spent in the linear warmup; also the
            boundary at which the cosine schedule takes over.
        decay_steps: Number of steps passed to the cosine decay.
        alpha: Passed through to ``optax.cosine_decay_schedule``.

    Returns:
        The schedule produced by ``optax.join_schedules``.
    """
    ramp_up = optax.linear_schedule(
        init_value=init_value, end_value=peak_value, transition_steps=warmup_steps
    )
    anneal = optax.cosine_decay_schedule(
        init_value=peak_value, decay_steps=decay_steps, alpha=alpha
    )
    # Hand over from the warmup to the cosine phase once warmup_steps is reached.
    return optax.join_schedules(schedules=[ramp_up, anneal], boundaries=[warmup_steps])


# Learning-rate schedule: both the initial and the peak value scale with the
# batch size; the cosine phase covers every step left after the warmup.
learning_rate_fn = cosine_decay_schedule(
    init_value=batch_size * 1e-15,
    peak_value=batch_size * 1e-6,
    warmup_steps=warmup_steps,
    decay_steps=num_steps - warmup_steps,
    alpha=0.1  # Final learning rate will be 10% of peak value
)

# Clip the global gradient norm to 1.0, then apply Adam with the schedule above.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate_fn, b1=0.9, b2=0.999)
)

# Optimizer state is created for the array leaves of the model only.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# perf_counter is monotonic, which makes it the right clock for measuring durations.
start_time = time.perf_counter()
for step in range(num_steps):
    batch = next(train_loader)
    model, opt_state, loss, aux_out = train_step(model, batch, opt_state, l1_penalty, ortho_penalty, optimizer)
    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")
# JAX dispatches device work asynchronously: wait for the final updates to
# finish before reading the clock, otherwise the reported time can be too short.
jax.block_until_ready((model, opt_state))
end_time = time.perf_counter()
print(f"Training took {end_time - start_time:.2f} seconds.")

In [None]:
import json

# Load unembeddings (same file and same rescaling as in the training setup)
g = jnp.load(f'{unembeds_dir}/clean_gemma_embeddings.npy')
g = g * jnp.sqrt(g.shape[0] / g.shape[1])

# Load vocabulary: a JSON object mapping each token string to its integer index.
# The encoding is given explicitly so non-ASCII tokens decode the same way on
# every platform instead of depending on the locale's default encoding.
with open(f'{unembeds_dir}/clean_gemma_vocab_dict.json', 'r', encoding='utf-8') as vocab_file:
    vocab_dict = json.load(vocab_file)

# Invert the mapping into a list so that vocab_list[index] is the token string;
# indices that never occur in vocab_dict keep the placeholder None.
# NOTE: `word` and `index` stay bound at notebook scope after this loop.
vocab_list = [None] * (max(vocab_dict.values()) + 1)
for word, index in vocab_dict.items():
    vocab_list[index] = word

In [None]:
# Choose the token to inspect explicitly. Without this assignment the cell
# only runs if `word` happens to be left over from earlier notebook state.
word = ' dog'
x = g[vocab_dict[word]][None, :]  # add a leading batch axis
x.shape

In [None]:
def run_autoencoder(x):
    """Encode and decode ``x`` with the notebook-level ``model``.

    A 1-D input is treated as a single example and given a leading batch
    axis, so the returned reconstruction always carries a batch axis.

    NOTE(review): inputs are used as given here, whereas the training batches
    are centered and unit-normalized by the data loader -- confirm that
    feeding raw vectors is intended.
    """
    batch = x[None, :] if x.ndim == 1 else x

    # model.encode is mapped over the batch axis with vmap.
    top_codes, expert_codes, selected_idx, selected_vals = jax.vmap(model.encode)(batch)

    # Only the masked expert-specific codes are needed for decoding.
    _, masked_expert_codes = mask_codes(top_codes, expert_codes, selected_idx, selected_vals)

    reconstruction, _, _ = jax.vmap(model.decode)(masked_expert_codes, selected_idx, selected_vals)
    return reconstruction

# Reconstruct a single token embedding. run_autoencoder adds the batch axis
# itself, so the raw 1-D vector can be passed in directly.
word = ' dog'
x = g[vocab_dict[word]]
dog_hat = run_autoencoder(x)

# Squared reconstruction error, then the squared norm of the input for scale.
print(jnp.square(dog_hat - x).sum())
print(jnp.square(x).sum())

In [None]:
import numpy as np
import tqdm

class ActivationCache:
    """Keeps, per feature, the words with the largest activations seen so far.

    Two tables are maintained, each as an activations array plus a parallel
    array of word indices:

    * top level: shape ``(num_experts, max_examples_per_latent)``
    * low level: shape ``(num_experts, atoms_per_subspace, max_examples_per_latent)``

    Unfilled slots hold ``-inf`` as activation and ``-1`` as word index.
    """

    def __init__(self, vocab_size, num_experts, atoms_per_subspace, max_examples_per_latent=10):
        self.vocab_size = vocab_size
        self.num_experts = num_experts
        self.atoms_per_subspace = atoms_per_subspace
        self.max_examples_per_latent = max_examples_per_latent

        self.top_activations = np.full((num_experts, max_examples_per_latent), -np.inf, dtype=np.float32)
        self.top_words = np.full((num_experts, max_examples_per_latent), -1, dtype=np.int32)

        self.low_activations = np.full((num_experts, atoms_per_subspace, max_examples_per_latent), -np.inf, dtype=np.float32)
        self.low_words = np.full((num_experts, atoms_per_subspace, max_examples_per_latent), -1, dtype=np.int32)

    @staticmethod
    def _replace_weakest(activations, words, activation, word_idx):
        """Overwrite the weakest stored entry if ``activation`` beats it (in place)."""
        weakest = np.argmin(activations)
        if activation > activations[weakest]:
            activations[weakest] = activation
            words[weakest] = word_idx

    def update(self, batch_words, top_latents, expert_latents, top_k_indices=None):
        """Fold one batch of encoder outputs into the cache.

        Args:
            batch_words: Word index of every example in the batch.
            top_latents: Top-level activations, indexed ``[example, expert]``.
            expert_latents: Expert-specific activations, indexed
                ``[example, slot, atom]`` with one slot per selected expert.
            top_k_indices: Expert index of every slot, indexed
                ``[example, slot]``. When given, the strongest atom of each
                slot is recorded under the expert that slot belongs to. When
                omitted (legacy behaviour), every slot is recorded under the
                example's strongest top-level expert, which mislabels all
                slots that belong to other experts.
        """
        if top_k_indices is not None:
            top_k_indices = np.asarray(top_k_indices)

        for i in range(top_latents.shape[0]):
            word_idx = int(batch_words[i])
            top_latent = int(np.argmax(top_latents[i]))
            top_activation = float(top_latents[i, top_latent])

            # Top-level cache: only the strongest expert of the example is recorded.
            self._replace_weakest(
                self.top_activations[top_latent], self.top_words[top_latent],
                top_activation, word_idx,
            )

            # Low-level cache: strongest atom of every selected-expert slot.
            for slot in range(expert_latents.shape[1]):
                if top_k_indices is None:
                    expert = top_latent
                else:
                    expert = int(top_k_indices[i, slot])
                low_latent = int(np.argmax(expert_latents[i, slot]))
                low_activation = float(expert_latents[i, slot, low_latent])

                self._replace_weakest(
                    self.low_activations[expert, low_latent], self.low_words[expert, low_latent],
                    low_activation, word_idx,
                )

def build_activation_cache(dataset, model, batch_size, vocab_embeddings, vocab_indices, num_experts, atoms_per_subspace, top_k, max_examples_per_latent=5):
    """Encode the vocabulary in batches and collect top-activating words.

    Args:
        dataset: Only its length is used, to determine the number of batches.
        model: Object whose ``encode`` maps one embedding to
            ``(top_latents, expert_latents, top_k_indices, top_k_values)``.
        batch_size: Number of words encoded per step.
        vocab_embeddings: Array from which embeddings are gathered by index.
        vocab_indices: Word indices to process, in order.
        num_experts: First dimension of the cache tables.
        atoms_per_subspace: Second dimension of the low-level cache tables.
        top_k: Accepted for interface compatibility; not used here.
        max_examples_per_latent: Number of words kept per feature.

    Returns:
        The populated ``ActivationCache``.

    NOTE(review): embeddings are encoded as given, whereas the training
    batches are centered and unit-normalized by the data loader -- confirm
    that this is intended.
    """
    activation_cache = ActivationCache(
        # Taken from the embeddings argument rather than a notebook global.
        vocab_size=len(vocab_embeddings),
        num_experts=num_experts,
        atoms_per_subspace=atoms_per_subspace,
        max_examples_per_latent=max_examples_per_latent
    )

    num_batches = int(np.ceil(len(dataset) / batch_size))

    for batch_start in tqdm.tqdm(range(0, len(dataset), batch_size), total=num_batches, desc="Building Activation Cache"):
        batch_end = min(batch_start + batch_size, len(dataset))
        batch_indices = vocab_indices[batch_start:batch_end]
        batch = vocab_embeddings[batch_indices]

        # Run encoding on device and pull results to CPU for fast numpy updates
        top_latents, expert_latents, top_k_indices, _ = jax.vmap(model.encode)(batch)

        # The expert indices are passed along so that every slot's low-level
        # activation is stored under the expert it actually belongs to.
        activation_cache.update(batch_words=np.array(batch_indices),
                                top_latents=np.array(top_latents),
                                expert_latents=np.array(expert_latents),
                                top_k_indices=np.array(top_k_indices))

    return activation_cache

# Build the cache over every row of g; word index i selects row i of g.
activation_cache = build_activation_cache(
    dataset=g, vocab_embeddings=g, vocab_indices=jnp.arange(len(g)),
    model=model, batch_size=8192,
    num_experts=num_experts, atoms_per_subspace=atoms_per_subspace, top_k=k,
)

In [None]:
def explain_word(word, top_k=5, low_k=5):
    """Print the features the model selects for ``word`` with example words.

    For each of the first ``top_k`` selected experts this prints the expert's
    top-level activation, the cached words recorded for that expert, the
    strongest low-level atom within the expert for this word, and the cached
    words recorded for that atom.

    Args:
        word: Vocabulary key, looked up in ``vocab_dict``.
        top_k: Maximum number of selected experts to report.
        low_k: Maximum number of cached example words shown per feature
            (used for both the top-level and the low-level lists).

    Raises:
        KeyError: If ``word`` is not a key of ``vocab_dict``.
    """
    word_vec = g[vocab_dict[word]]
    top_latents, low_latents, top_k_indices, _ = model.encode(word_vec)

    # Bring the outputs to fixed ranks. reshape is used instead of squeeze()
    # because squeeze would also drop genuine length-1 axes (for example a
    # single selected expert), which breaks len() and 2-D indexing below.
    top_latents = np.asarray(top_latents).reshape(-1)
    top_k_indices = np.asarray(top_k_indices).reshape(-1)
    low_latents = np.asarray(low_latents)
    num_selected = top_k_indices.shape[0]
    if num_selected:
        low_latents = low_latents.reshape(num_selected, -1)  # one row per selected expert

    print(f"\nExplanations for word: '{word.strip()}'\n{'='*50}")

    for idx in range(min(top_k, num_selected)):
        topic = int(top_k_indices[idx])
        activation = top_latents[topic]
        print(f"\n🔹 Top-Level Feature {topic} (Activation: {activation:.4f})")
        print("  Words that maximally activate this feature:")

        # Top-level activating words, strongest first; unfilled cache slots (-1) are skipped.
        top_examples = np.argsort(-activation_cache.top_activations[topic])[:low_k]
        top_words = [vocab_list[int(w)] for w in activation_cache.top_words[topic, top_examples] if w >= 0]
        print(f"   {top_words}")

        # Strongest low-level atom of this word within the selected expert.
        low_latent = int(np.argmax(low_latents[idx, :]))
        low_activation = low_latents[idx, low_latent]
        print(f"    ↳ Low-Level Feature: {low_latent} (Activation: {low_activation:.8f})")

        print("      Words that maximally activate this low-level feature:")
        low_examples = np.argsort(-activation_cache.low_activations[topic, low_latent])[:low_k]
        low_words = [vocab_list[int(w)] for w in activation_cache.low_words[topic, low_latent, low_examples] if w >= 0]
        print(f"       {low_words}")

In [None]:
# Print the explanation for each probe word in turn (one loop instead of one
# copy-pasted cell per word; the printed output and its order are unchanged).
for probe_word in (' puppy', ' Queen', ' Chicago', ' London', ' Twitter', ' python', ' Bayesian'):
    explain_word(probe_word)