In [7]:
import torch
import numpy as np
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm, trange


In [82]:

# the most downloaded sentence transformer on HuggingFace
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')



In [83]:
sentences = ["This is an example sentence", "Here is an example sentence"]

In [85]:
embeddings = model.encode(sentences, normalize_embeddings=True)

In [86]:
# each input sentence gets a 384 dimension embedding array associated with it
embeddings.shape

(2, 384)

In [87]:
embeddings[0]

array([ 6.76568896e-02,  6.34958670e-02,  4.87130694e-02,  7.93049484e-02,
        3.74480374e-02,  2.65276735e-03,  3.93748544e-02, -7.09845824e-03,
        5.93614839e-02,  3.15370038e-02,  6.00980558e-02, -5.29051572e-02,
        4.06067856e-02, -2.59308089e-02,  2.98427846e-02,  1.12691044e-03,
        7.35149235e-02, -5.03819846e-02, -1.22386619e-01,  2.37028450e-02,
        2.97265202e-02,  4.24768701e-02,  2.56337672e-02,  1.99519377e-03,
       -5.69190904e-02, -2.71598604e-02, -3.29035521e-02,  6.60248920e-02,
        1.19007073e-01, -4.58791479e-02, -7.26214498e-02, -3.25839967e-02,
        5.23413680e-02,  4.50552627e-02,  8.25300813e-03,  3.67023535e-02,
       -1.39415022e-02,  6.53919503e-02, -2.64272951e-02,  2.06431461e-04,
       -1.36643415e-02, -3.62809561e-02, -1.95043348e-02, -2.89738495e-02,
        3.94270383e-02, -8.84090587e-02,  2.62426888e-03,  1.36713758e-02,
        4.83063124e-02, -3.11565623e-02, -1.17329180e-01, -5.11690155e-02,
       -8.85287523e-02, -

In [21]:
# Cosine similarity between the two sentence embeddings.
# This comparison is the foundation of RAG-style retrieval: embed the dataset,
# embed the query, find the dataset embeddings most similar to the query,
# then use the original text behind those embeddings to answer the query.
cosine_similarity = torch.nn.functional.cosine_similarity(
    torch.tensor(embeddings[:1]),   # first sentence, kept 2-D as a single-row batch
    torch.tensor(embeddings[1:2]),  # second sentence, kept 2-D as a single-row batch
)
cosine_similarity


tensor([0.8809])

In [22]:
# The Sentence transformers library simplifies a lot of things for making embeddings
# let's take a look at some of the stuff that goes on under the hood
# this will be more like the LLM stuff
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

In [23]:
# Load model from HuggingFace Hub
# NOTE: this rebinds `model`, which earlier held the SentenceTransformer wrapper. After this
# cell runs, `model` is the underlying transformer (a BertModel, per the repr shown in this
# notebook), so the earlier `model.encode(...)` cell depends on execution order: re-run the
# SentenceTransformer cell before re-running it.
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')



In [55]:
# Many sentence transformers models are based on the BERT architecture
# though recently there are some based on LLMs 
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-5): 6 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


In [24]:
# we can see here that hidden_size is 384, the dimensionality of the embeddings
model.config

BertConfig {
  "_name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.45.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [27]:
# a key variable here is model_max_length=512, so that means the model can only handle up to 512 tokens
tokenizer

BertTokenizerFast(name_or_path='sentence-transformers/all-MiniLM-L6-v2', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [43]:
# let's use a single longer input to take a closer look at what happens with tokens
# uncomment the second input to see truncation happen (it's 233 * 3 tokens)
# taken from this list of top 10 paragraphs https://www.jjrlore.com/post/top-10-paragraphs
inputs = [
  """
  Books bombarded his shoulder, his arms, his upturned face.  A book lit, almost obediently, like a white pigeon, in his hands, wings fluttering.  In the dim, wavering light, a page hung open and it was like a snowy feather, the words delicately painted thereon.  In all the rush and fervor, Montage had only an instant to read a line, but it blazed in his mind for the next minute as if stamped there with fiery steel.  “Time has fallen asleep in the afternoon sunshine.”  He dropped the book.  Immediately, another fell into his arms.
  """,
  # """
  # The tent he lived in stood right smack up against the wall of the shallow, dull-colored forest separating his own squadron from Dunbar’ s.  Immediately alongside was the abandoned railroad ditch that carried the pipe that carried the aviation gasoline down to the fuel trucks at the airfield.  Thanks to Orr, his roommate, it was the most luxurious tent in the squadron.  Each time Yossarian returned from one of his holidays in the hospital or rest leaves in Rome, he was surprised by some new comfort Orr had installed in his absence - running water, wood-burning fireplace, cement floor.  Yossarian had chosen the site, and he and Orr had raised the tent to get her.  Orr, who was a grinning pygmy with pilot’s wings and thick, wavy brown hair parted in the middle, furnished all the knowledge, while Yossarian, who was taller, stronger, broader, and faster, did most of the work.  Just the two of them lived there, although the tent was big enough for six.  When summer came, Orr rolled up the side flaps to allow a breeze that never blew to flush away the air baking inside.
  # The tent he lived in stood right smack up against the wall of the shallow, dull-colored forest separating his own squadron from Dunbar’ s.  Immediately alongside was the abandoned railroad ditch that carried the pipe that carried the aviation gasoline down to the fuel trucks at the airfield.  Thanks to Orr, his roommate, it was the most luxurious tent in the squadron.  Each time Yossarian returned from one of his holidays in the hospital or rest leaves in Rome, he was surprised by some new comfort Orr had installed in his absence - running water, wood-burning fireplace, cement floor.  Yossarian had chosen the site, and he and Orr had raised the tent to get her.  Orr, who was a grinning pygmy with pilot’s wings and thick, wavy brown hair parted in the middle, furnished all the knowledge, while Yossarian, who was taller, stronger, broader, and faster, did most of the work.  Just the two of them lived there, although the tent was big enough for six.  When summer came, Orr rolled up the side flaps to allow a breeze that never blew to flush away the air baking inside.
  # The tent he lived in stood right smack up against the wall of the shallow, dull-colored forest separating his own squadron from Dunbar’ s.  Immediately alongside was the abandoned railroad ditch that carried the pipe that carried the aviation gasoline down to the fuel trucks at the airfield.  Thanks to Orr, his roommate, it was the most luxurious tent in the squadron.  Each time Yossarian returned from one of his holidays in the hospital or rest leaves in Rome, he was surprised by some new comfort Orr had installed in his absence - running water, wood-burning fireplace, cement floor.  Yossarian had chosen the site, and he and Orr had raised the tent to get her.  Orr, who was a grinning pygmy with pilot’s wings and thick, wavy brown hair parted in the middle, furnished all the knowledge, while Yossarian, who was taller, stronger, broader, and faster, did most of the work.  Just the two of them lived there, although the tent was big enough for six.  When summer came, Orr rolled up the side flaps to allow a breeze that never blew to flush away the air baking inside.
  # """
]

In [44]:
# Tokenize sentences into PyTorch tensors.
# padding=True pads every input to the longest one in the batch;
# truncation=True cuts inputs that exceed the model's maximum length.
encoded_input = tokenizer(
    inputs,
    return_tensors='pt',
    truncation=True,
    padding=True,
)

In [45]:
encoded_input["input_ids"].shape

torch.Size([1, 127])

In [46]:
# Compute token embeddings
# no_grad: this is inference only, so autograd tracking is switched off for the forward pass
with torch.no_grad():
    model_output = model(**encoded_input)

In [47]:
# Mean pooling - average the token embeddings, counting only real (non-padding) tokens
def mean_pooling(model_output, attention_mask):
    """Collapse per-token embeddings into one vector per input via a masked mean.

    Args:
        model_output: model result whose first element holds the token embeddings
            (shape (1, 127, 384) for the example input in this notebook).
        attention_mask: mask with one entry per token; positions with value 0 are
            excluded from the average.

    Returns:
        The mask-weighted sum over dimension 1 divided by the number of unmasked
        tokens. The divisor is clamped to at least 1e-9, so a fully masked row
        yields zeros instead of a division by zero.
    """
    per_token = model_output[0]
    # Broadcast the mask across the embedding dimension so it lines up with the embeddings.
    mask = attention_mask.unsqueeze(-1).expand(per_token.size()).float()
    weighted_sum = (per_token * mask).sum(dim=1)
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_count


In [53]:
# we have an embedding for each token
model_output[0].shape

torch.Size([1, 127, 384])

In [48]:
# Perform pooling, average the embeddings for each token into a single embedding vector
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])


In [49]:

# Normalize embeddings
# This step is often taken for granted or glossed over, so to spell it out:
# each embedding is scaled to unit (L2) length, which makes the cosine similarity
# between two embeddings equal to their plain dot product
normalized_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

In [50]:
normalized_embeddings.shape

torch.Size([1, 384])

# SAE Time

In [56]:
!pip install latentsae

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting latentsae
  Using cached latentsae-0.1.0-py3-none-any.whl.metadata (5.0 kB)
Collecting wandb (from latentsae)
  Downloading wandb-0.18.3-py3-none-macosx_11_0_arm64.whl.metadata (9.7 kB)
Collecting optuna (from latentsae)
  Downloading optuna-4.0.0-py3-none-any.whl.metadata (16 kB)
Collecting datasets (from latentsae)
  Using cached datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow (from latentsae)
  Using cached pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.3 kB)
Collecting dataclasses (from latentsae)
  Using cached dataclasses-0.6-py3-none-any.whl.metadata (3.0 kB)
Collecting simple-parsing (from latentsae)
  Downloading simple_parsing-0.1.6-py3-none-any.whl.metadata (7.3 kB)
Collecting einops (from latentsae)
  Using cached einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting accelerate (from latentsae)
  Downloading accelerate-1.0.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes (from latentsae)
  Using cached bitsandbytes

In [60]:
import latentsae

In [72]:
from latentsae.sae import Sae
import pandas as pd
import json

In [67]:
sae_model = Sae.load_from_hub("enjalot/sae-nomic-text-v1.5-FineWeb-edu-100BT", "64_32", device="cpu")

Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 18040.02it/s]
Dropping extra args {'signed': False}


In [79]:
# Some metadata for the SAE to help us describe features.
# NOTE: this is a machine-specific absolute path; point SAE_ASSETS_DIR at your own
# copy of the metadata files before running.
SAE_ASSETS_DIR = "/Users/enjalot/code/latent-taxonomy/web/public/models/NOMIC_FWEDU_25k"

# Use a context manager so the file handle is closed as soon as the JSON is parsed.
with open(f"{SAE_ASSETS_DIR}/metadata.json", encoding="utf-8") as meta_file:
    sae_meta = json.load(meta_file)

# One row of descriptive metadata per SAE feature (e.g. the `label` column used further down).
sae_features = pd.read_parquet(f"{SAE_ASSETS_DIR}/features.parquet")

In [71]:
emb_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, device="cpu")

<All keys matched successfully>


In [88]:
# Embed the same `inputs` passage with the nomic model, again asking for normalized embeddings.
# NOTE(review): the nomic-embed-text-v1.5 model card describes task prefixes (e.g. "clustering: ")
# — TODO confirm whether the embeddings the SAE was trained on used one, and add it here if so.
embedding = emb_model.encode(inputs, normalize_embeddings=True)

In [89]:
embedding.shape

(1, 768)

In [91]:
embedding

array([[ 3.15869600e-02,  2.08352990e-02, -1.92254588e-01,
        -9.17355716e-02,  2.72580553e-02,  6.38516992e-02,
        -1.85686518e-02,  1.10491365e-02,  1.82607630e-03,
         2.33820500e-03, -5.76118343e-02, -1.79362893e-02,
         5.58326654e-02,  7.99097195e-02,  2.99598444e-02,
         1.22380489e-02,  5.03575206e-02, -3.67829576e-02,
        -1.23029295e-02,  2.11170670e-02, -3.35725583e-02,
        -3.16633433e-02, -6.66434243e-02, -1.86119997e-03,
         1.27506945e-02,  4.17836681e-02, -8.43166038e-02,
        -1.67221650e-02, -7.96173289e-02,  8.91472865e-03,
         4.86480780e-02, -2.93906741e-02, -7.05600679e-02,
        -1.38342381e-03, -1.24853812e-02, -5.06924242e-02,
         3.19077075e-03, -5.24868490e-04,  4.56509143e-02,
         2.04044469e-02,  2.90306266e-02, -8.74189020e-04,
         1.96447577e-02,  3.89255285e-02,  9.87223089e-02,
        -4.23938222e-03,  1.17774466e-02,  3.99901979e-02,
        -1.29936016e-04, -2.83545768e-03,  6.14176355e-0

In [97]:
# Inference only: encode the embedding with the SAE without autograd tracking.
# Without no_grad the result carries a grad_fn (visible in the saved output as
# grad_fn=<TopkBackward0>), which is not needed just to inspect the activations.
with torch.no_grad():
    sae_latents = sae_model.encode(torch.from_numpy(embedding))


In [98]:
sae_latents

EncoderOutput(top_acts=tensor([[0.1960, 0.1681, 0.1379, 0.1321, 0.1249, 0.1210, 0.1190, 0.1176, 0.0991,
         0.0820, 0.0775, 0.0743, 0.0729, 0.0710, 0.0671, 0.0658, 0.0615, 0.0602,
         0.0601, 0.0575, 0.0570, 0.0561, 0.0533, 0.0521, 0.0519, 0.0501, 0.0463,
         0.0461, 0.0454, 0.0439, 0.0436, 0.0430, 0.0418, 0.0408, 0.0404, 0.0395,
         0.0394, 0.0384, 0.0382, 0.0376, 0.0373, 0.0372, 0.0365, 0.0360, 0.0358,
         0.0358, 0.0356, 0.0349, 0.0340, 0.0337, 0.0331, 0.0330, 0.0328, 0.0325,
         0.0323, 0.0317, 0.0314, 0.0313, 0.0307, 0.0297, 0.0295, 0.0291, 0.0289,
         0.0289]], grad_fn=<TopkBackward0>), top_indices=tensor([[ 3020,  8990,   304, 14433, 12363, 10739, 17432, 13030, 21919, 17790,
          9132,  1309, 23332,  5392, 15241,  5676,  2340, 13476, 18302,  4207,
          3863, 18703, 12129, 17693,  8104,  1068, 21838, 17469,  2570, 22818,
          6981, 14476, 22688,  5527, 17265,  2454,  6674,  9934, 20069, 18553,
         19690,   928,  2595, 19736, 

In [99]:
#index of the top "feature"
sae_latents.top_indices[0][0]

tensor(3020)

In [103]:
# look at the sae metadata we have for the top feature
# NOTE(review): `.iloc` is positional, so this returns the right row only if `sae_features`
# is stored in feature-id order — TODO confirm. The saved output shows feature 304 although
# the top index printed in the previous cell is 3020, so that output looks stale relative
# to this code; re-run to refresh it.
sae_features.iloc[sae_latents.top_indices[0][0].item()]

feature                                                   304
max_activation                                        0.33558
x                                                    0.575523
y                                                    0.231955
top10_x                                             -0.682711
top10_y                                               0.12703
label             themes of destruction suffering and despair
order                                                0.296192
Name: 304, dtype: object

In [105]:
# Labels for the 8 strongest features of the first (and only) input.
# top_indices is displayed in descending activation order, so the first 8 entries are the top 8.
top8 = sae_latents.top_indices[0, :8]
top8_features = sae_features.iloc[top8]  # positional lookup into the feature metadata
top8_features.loc[:, "label"].to_list()

['publishing and writing about literature',
 'book review and analysis techniques',
 'themes of destruction suffering and despair',
 'supportive communication in addiction recovery contexts',
 'prioritizing responsibilities over personal enjoyment',
 'social change resilience and identity exploration',
 'academic and scientific excellence in personal relationships',
 'metaphorical and literal interpretations of falling']