In [1]:
from itertools import islice
from typing import Dict

import pandas as pd
import torch
from latentsae import Sae
from nomic.atlas import AtlasDataset
from sentence_transformers import SentenceTransformer


  from tqdm.autonotebook import tqdm, trange


Triton not installed, using eager implementation of SAE decoder.


In [2]:
# Download and load the pretrained sparse autoencoder (SAE) from the Hub repo below.
# NOTE(review): "64_32" presumably selects the top-k=64 / 32x-expansion variant
# (the encode output further down shows 64 activations per input) — TODO confirm.
sae_model = Sae.load_from_hub("enjalot/sae-nomic-text-v1.5-FineWeb-edu-100BT", "64_32")

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Dropping extra args {'signed': False}


In [3]:
# Load the text embedding model whose embeddings are fed into the SAE.
# trust_remote_code=True lets the library run model code shipped in the Hub
# repository — only keep it for repositories you trust.
emb_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

<All keys matched successfully>


In [4]:
# Pick the best available accelerator instead of hard-coding Apple's "mps",
# so the notebook also runs on CUDA and CPU-only machines. MPS stays the first
# choice, so behaviour on Apple Silicon is unchanged.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both models must live on the same device: the embedding tensor produced by
# emb_model is passed straight into sae_model.encode.
sae_model = sae_model.to(device)
emb_model = emb_model.to(device)

# Test the SAE

In [5]:
# Read the feature metadata table and convert it to a list of per-row dicts.
# summarize_encoder_output indexes this list by SAE feature index and reads
# each record's 'label' field.
loaded_features = pd.read_parquet("features.parquet").to_dict(orient='records')

In [7]:
def aggregate_encoder_output(encoder_output, k: int = 5) -> Dict[int, float]:
    """Sum activations per feature index over every input in the batch.

    Args:
        encoder_output: object exposing ``top_indices`` and ``top_acts``
            tensors; both are moved to CPU and flattened, so activations of
            the same feature coming from different inputs are added together.
        k: not used by this function; kept so existing call sites stay valid.

    Returns:
        Dict mapping feature index to its summed activation, ordered from the
        largest total to the smallest.
    """
    flat_indices = encoder_output.top_indices.cpu().flatten().tolist()
    flat_acts = encoder_output.top_acts.cpu().flatten().tolist()

    totals = {}
    for feature_id, activation in zip(flat_indices, flat_acts):
        if feature_id not in totals:
            # First time this feature is seen: start its running total.
            totals[feature_id] = activation
            continue
        totals[feature_id] += activation

    # Dicts preserve insertion order, so building one from the ranked pairs
    # yields a mapping that iterates strongest feature first.
    ranked = sorted(totals.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)

In [8]:
# Smoke test: embed one short string and run it through the SAE; the cell
# output is the raw EncoderOutput (top_acts / top_indices).
sae_model.encode(emb_model.encode(['ben'], convert_to_tensor=True))

EncoderOutput(top_acts=tensor([[10.2676,  9.1370,  6.9692,  6.5758,  6.5321,  6.4380,  6.0169,  5.9566,
          5.9451,  5.8986,  5.8974,  5.8676,  5.8553,  5.8119,  5.7983,  5.7916,
          5.7621,  5.7501,  5.7412,  5.7339,  5.7236,  5.6925,  5.6840,  5.6580,
          5.6326,  5.6273,  5.6240,  5.5846,  5.5812,  5.5757,  5.5744,  5.5702,
          5.5594,  5.4869,  5.4468,  5.4023,  5.3949,  5.3904,  5.3837,  5.3410,
          5.3271,  5.3213,  5.3161,  5.3132,  5.3039,  5.2932,  5.2801,  5.2531,
          5.2474,  5.2139,  5.2120,  5.2048,  5.1769,  5.1763,  5.1713,  5.1688,
          5.1683,  5.1613,  5.1479,  5.1476,  5.1438,  5.1357,  5.1347,  5.1287]],
       device='mps:0', grad_fn=<TopkBackward0>), top_indices=tensor([[19159, 16718,  6328, 20182,  1239,  6939, 23114, 11704, 11465, 23945,
          6625,  4997,  1741,  2884, 16833, 22050, 17685, 11466,  9254,  7775,
         16983,  2910,  4080,   433, 13437,  1865, 21547,  3104,  3123,  7919,
          9995, 18136,   193,

In [9]:
def summarize_encoder_output(sorted_activations, k=5):
    """Return the labels of the first ``k`` features in ``sorted_activations``.

    Args:
        sorted_activations: iterable of feature indices (e.g. the dict
            returned by ``aggregate_encoder_output``, whose keys iterate from
            strongest to weakest).
        k: maximum number of labels to return.

    Returns:
        List of ``'label'`` values looked up in the module-level
        ``loaded_features`` list.
    """
    # NOTE(review): assumes loaded_features is ordered so that list position
    # equals the SAE feature index — TODO confirm against features.parquet.
    leading_ids = list(islice(sorted_activations, k))
    labels = []
    for feature_id in leading_ids:
        labels.append(loaded_features[feature_id]['label'])
    return labels

In [13]:
# Probe inputs: names of dimensionality-reduction methods, encoded together as one batch.
test_strings = ['UMAP', 't-SNE', 'PCA', 'SVD']

In [14]:
# Full pipeline: embed the probe strings, SAE-encode them, sum activations per
# feature across the batch, and show the labels of the top features.
summarize_encoder_output(aggregate_encoder_output(sae_model.encode(emb_model.encode(test_strings, convert_to_tensor=True))))

['Cultural narratives and storytelling techniques',
 'advanced materials and manufacturing processes',
 'quantum computing and nanomaterial advancements',
 'taser usage and technology in law enforcement',
 'astrophysics and materials science advancements']

# Notes

GPT-4o mini overuses the words "interdisciplinary" and "quantum".

Manual nomencodes

5507: Apple, Inc.