# Clean Gemma Unembeddings

Gemma's vocabulary includes a lot of junk tokens. Here is one particular scheme for discarding the junk and whitening the unembeddings.

In [None]:
import jax
print("Devices:", jax.devices())
import jax.numpy as jnp
import json
from transformers import AutoTokenizer
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Hugging Face tokenizer for Gemma 2 (2B) and its token -> id mapping.
# NOTE: `vocab` is rebound to a SentencePiece processor in a later cell.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
vocab = tokenizer.get_vocab()

import numpy as np

# Arrange every visible device in a mesh of shape (num_devices, 1) and use
# that layout as the sharding for the unembedding matrix.
num_devices = jax.device_count()
device_mesh = mesh_utils.create_device_mesh((num_devices, 1))
sharding = PositionalSharding(device_mesh)

out_dir = ""  # where to store the unembeddings and such?

In [None]:
# Read the raw unembedding matrix on the host, then place it on the devices
# using the sharding defined in the setup cell.
raw_unembeddings = np.load('path/to/gemma_unembedding_matrix/raw_unembeddings.npy')
g = jax.device_put(raw_unembeddings, sharding)
del raw_unembeddings  # do not keep an extra host-side reference around

In [None]:
# Enable treescope's interactive notebook rendering for this session.
# NOTE(review): autovisualize_arrays=True presumably makes arrays shown as
# cell output render as visualizations rather than plain reprs -- confirm
# against the treescope documentation.
import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)

In [None]:
# Whiten the unembeddings: subtract the per-column mean, then replace every
# singular value with 1 by recombining only the singular vectors (U @ V^T).
g -= jnp.mean(g, axis=0)
u, s, vt = jnp.linalg.svd(g, full_matrices=False)
g = jnp.matmul(u, vt)

In [None]:
import sentencepiece as spm

# Load the Gemma 2 SentencePiece model. This rebinds `vocab`, replacing the
# Hugging Face vocab mapping assigned in the setup cell.
tokenizer_model_path = "path/to/gemma_2_tokenizer/tokenizer.model"
vocab = spm.SentencePieceProcessor()
vocab.Load(tokenizer_model_path)

In [None]:
# Sanity check: take the row of `g` for the first token id of ' dog', score
# every row by its dot product with it, and decode the 20 highest-scoring ids.
dog_row = g[vocab.EncodeAsIds(' dog')[0]]
similarities = g @ dog_row
dog_idx = jax.lax.top_k(similarities, 20)[1]
vocab.DecodeIds(dog_idx.tolist())

In [None]:
# Heuristic junk filter: keep only the rows of `g` whose L2 norm lies
# strictly inside a fixed band; everything outside the band is dropped.
norms = jnp.linalg.norm(g, axis=1)
norm_in_band = jnp.logical_and(norms > 0.07008683, norms < 0.11558999)

# `acceptable_vocab` is the tuple returned by jnp.where; its first element
# holds the original row indices (token ids) that survive the filter.
acceptable_vocab = jnp.where(norm_in_band)
g = g[acceptable_vocab]

In [None]:
# Re-index the surviving tokens: row i of the filtered `g` corresponds to the
# i-th original token id in `acceptable_vocab[0]`.
kept_ids = acceptable_vocab[0].tolist()

# index -> decoded string, built directly from the decoded pieces so every row
# of `g` has an entry. (Inverting `vocab_dict` instead leaves `None` holes, and
# possibly a short list, whenever two kept ids decode to the same string.)
vocab_list = [vocab.DecodeIds([orig_idx]) for orig_idx in kept_ids]

# decoded string -> new index. If several kept ids decode to the same string,
# the last one wins.
vocab_dict = {word: new_idx for new_idx, word in enumerate(vocab_list)}

# Same nearest-neighbour sanity check as before, now in the new index space.
dog_idx = jax.lax.top_k(g @ g[vocab_dict[' dog']], 20)[1]
print([vocab_list[idx] for idx in dog_idx.tolist()])

In [None]:
# Whiten again, this time over the filtered vocabulary only: re-centre the
# columns and reset all singular values to 1 via U @ V^T.
g -= jnp.mean(g, axis=0)
u, s, vt = jnp.linalg.svd(g, full_matrices=False)
g = jnp.matmul(u, vt)

# Check that re-whitening doesn't break anything: the neighbours of ' dog'
# should still look sensible.
dog_scores = g @ g[vocab_dict[' dog']]
dog_idx = jax.lax.top_k(dog_scores, 20)[1]
nearest_words = [vocab_list[i] for i in dog_idx.tolist()]
print(nearest_words)

In [None]:
import os

# Persist the cleaned artifacts into `out_dir`. Paths are built with
# os.path.join so all three files land in the same directory; an empty
# `out_dir` means the current working directory. (The JSON path previously
# lacked the separator the other two used, so the outputs could be split
# across different locations.)
jnp.save(os.path.join(out_dir, 'clean_unembeddings.npy'), g)
# Original token ids of the rows kept in `g`, saved as returned by jnp.where.
jnp.save(os.path.join(out_dir, 'clean_unembeddings_indices.npy'), acceptable_vocab)
# Decoded token string -> row index in the cleaned matrix.
with open(os.path.join(out_dir, 'clean_vocab_dict.json'), 'w', encoding='utf-8') as fout:
    json.dump(vocab_dict, fout)