In [2]:
from itertools import islice
from typing import Dict

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

from nomic.atlas import AtlasDataset
from latentsae import Sae


  from tqdm.autonotebook import tqdm, trange


Triton not installed, using eager implementation of SAE decoder.


In [3]:
# Load the pretrained sparse autoencoder (SAE) from the Hub repo named below.
# NOTE(review): "64_32" presumably selects one variant inside that repo
# (expansion factor / top-k?) — confirm against the repo contents.
sae_model = Sae.load_from_hub("enjalot/sae-nomic-text-v1.5-FineWeb-edu-100BT", "64_32")

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Dropping extra args {'signed': False}


In [4]:
# Text embedding model; its output vectors are what gets passed to sae_model.encode later on.
# trust_remote_code=True allows code shipped in the model repo to run on load,
# so this should only be used with a repo you trust.
emb_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

<All keys matched successfully>


In [5]:
# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA or CPU-only machines.
# On a machine where MPS is available the result is still "mps".
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both models must live on the same device: the embedding model's output
# tensor is fed straight into the SAE.
sae_model = sae_model.to(device)
emb_model = emb_model.to(device)

# Test the SAE

In [6]:
# One dict per parquet row. Later cells index this list by SAE feature id
# (loaded_features[idx]['label']), so row position is assumed to equal the
# feature id — TODO confirm against how features.parquet was written.
loaded_features = pd.read_parquet("features.parquet").to_dict(orient='records')

In [12]:
def aggregate_encoder_output(encoder_output, k: int = 5) -> Dict[int, float]:
    """Sum the SAE activations per feature index over a whole encoder output.

    ``encoder_output.top_indices`` and ``encoder_output.top_acts`` are
    flattened and paired element-wise; activations that share the same feature
    index are added together.

    Args:
        encoder_output: Object exposing ``top_indices`` and ``top_acts``
            tensors of matching shape.
        k: Currently unused. It is kept only for interface compatibility;
            callers slice the returned mapping themselves, so no truncation
            happens here.

    Returns:
        A dict mapping feature index to summed activation, ordered from the
        largest total to the smallest. Ties keep first-seen order.
    """
    # Convert each tensor to a Python list in one bulk call rather than
    # invoking `.item()` once per element inside the loop.
    feature_indices = encoder_output.top_indices.cpu().flatten().tolist()
    activations = encoder_output.top_acts.cpu().flatten().tolist()

    total_activations: Dict[int, float] = {}
    for feature_idx, activation in zip(feature_indices, activations):
        if feature_idx in total_activations:
            total_activations[feature_idx] += activation
        else:
            total_activations[feature_idx] = activation

    # Dicts preserve insertion order, so rebuilding from the sorted pairs
    # yields a mapping that iterates strongest-feature-first.
    return dict(sorted(total_activations.items(), key=lambda item: item[1], reverse=True))

In [17]:
def summarize_encoder_output(sorted_activations, k=5):
    """Return the labels of the first ``k`` feature ids in ``sorted_activations``.

    Iterates the mapping's keys in their existing order (no re-sorting is done
    here) and looks each one up in the notebook-level ``loaded_features`` list.
    """
    labels = []
    for feature_id in islice(sorted_activations, k):
        labels.append(loaded_features[feature_id]['label'])
    return labels

In [18]:
# Small batch of closely related words (all sports) used to sanity-check the SAE labels.
test_strings = ['baseball', 'football', 'basketball', 'hockey', 'soccer']

In [19]:
# Embed the test strings, encode them with the SAE, sum activations per
# feature, and show the labels of the five strongest features (k defaults to 5).
# NOTE(review): unlike the per-year loop further down, this call does not pass
# normalize_embeddings=True to emb_model.encode — confirm which form the SAE expects.
summarize_encoder_output(aggregate_encoder_output(sae_model.encode(emb_model.encode(test_strings, convert_to_tensor=True))))

['sports activism and social impact',
 'Cultural narratives and storytelling techniques',
 'interdisciplinary teamwork and project management strategies',
 'quantum computing and nanomaterial advancements',
 'climate classification and temperature analysis']

# Test the SAE on the Y Combinator dataset in a mini Atlas

In [21]:
# Take the first map of the 'nomic/y-combinator' Atlas dataset and pull out its tables.
datamap = AtlasDataset('nomic/y-combinator').maps[0]
yc_df = datamap.data.df
# yc_embeddings is not referenced again in the cells shown in this notebook.
yc_embeddings = datamap.embeddings.latent
yc_projected_embeddings = datamap.embeddings.projected
# Attach each point's year to the projected coordinates so the scatter plot can colour by year.
# NOTE(review): this adds a column to the object returned by datamap.embeddings.projected —
# confirm that is not a shared/cached frame, and that it shares its index with yc_df.
yc_projected_embeddings['year'] = yc_df.Year

# Keep only rows with a positive Year and a description other than the literal string "null".
selection_idx = yc_df[(yc_df.Year > 0) & (yc_df.oneliner_then_tags != "null")].index.values
# The same index labels are applied to both frames so they stay row-aligned.
yc_df = yc_df.loc[selection_idx]
yc_projected_embeddings = yc_projected_embeddings.loc[selection_idx]
# Sorted list of the distinct years that survived the filter; drives the loop and slider below.
years = sorted(yc_projected_embeddings['year'].unique())

[32m2024-09-20 11:43:28.436[0m | [1mINFO    [0m | [36mnomic.dataset[0m:[36m__init__[0m:[36m763[0m - [1mLoading existing dataset `nomic/y-combinator`.[0m
[32m2024-09-20 11:43:29.302[0m | [1mINFO    [0m | [36mnomic.data_operations[0m:[36m_download_data[0m:[36m902[0m - [1mDownloading data[0m
100%|██████████| 5/5 [00:00<00:00, 1756.85it/s]
100%|██████████| 5/5 [00:00<00:00, 2244.14it/s]
100%|██████████| 5/5 [00:00<00:00, 1511.03it/s]
100%|██████████| 5/5 [00:00<00:00, 2441.39it/s]
100%|██████████| 5/5 [00:00<00:00, 2179.54it/s]
100%|██████████| 5/5 [00:00<00:00, 2321.14it/s]
100%|██████████| 5/5 [00:00<00:00, 2548.49it/s]
[32m2024-09-20 11:43:29.865[0m | [1mINFO    [0m | [36mnomic.data_operations[0m:[36m_load_data[0m:[36m872[0m - [1mLoading data[0m
100%|██████████| 5/5 [00:00<00:00, 334.82it/s]
[32m2024-09-20 11:43:30.057[0m | [1mINFO    [0m | [36mnomic.data_operations[0m:[36m_download_latent[0m:[36m550[0m - [1mDownloading latent embeddings[0

In [22]:
# For every year, precompute the labels and summed activations of its ten
# strongest SAE features so the Dash callback only has to do a dict lookup.
bar_chart_data = {}
for year in years:
    year_texts = yc_df[yc_df.Year == year].oneliner_then_tags.values
    year_embeddings = emb_model.encode(year_texts, convert_to_tensor=True, normalize_embeddings=True)
    feature_totals = aggregate_encoder_output(sae_model.encode(year_embeddings))
    # feature_totals is already ordered strongest-first, so the first ten items are the top ten.
    top_features = list(islice(feature_totals.items(), 10))
    bar_chart_data[year] = {
        'names': [
            f'{feature_id}: {loaded_features[feature_id]["label"]}'
            for feature_id, total in top_features
        ],
        'vals': [total for feature_id, total in top_features],
    }


In [23]:
import dash
from dash import html, dcc
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from dash.dependencies import Input, Output

# Dash application object; the layout and the slider callback below are registered on it.
app = dash.Dash(__name__)

def create_figure(selected_year):
    """Build the two-panel figure for one year.

    Left panel: a scatter of every startup's projected embedding, with the
    startups from ``selected_year`` drawn in red and the rest in light grey.
    Right panel: a bar chart of the precomputed SAE features for that year.

    Reads the notebook-level ``yc_projected_embeddings``, ``yc_df`` and
    ``bar_chart_data``.

    Args:
        selected_year: Year to highlight; must be a key of ``bar_chart_data``.

    Returns:
        The assembled plotly figure.

    Raises:
        KeyError: If ``selected_year`` has no entry in ``bar_chart_data``.
    """
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(
            # Both titles are f-strings; without the prefix the literal text
            # "{selected_year}" would be shown.
            f"Startups {selected_year}",
            f"Vector Summary {selected_year}",
        ),
    )

    # Build hover labels by zipping the needed columns; this avoids creating a
    # Series per row (iterrows) every time the slider moves.
    hover_text = [
        f'{company} {batch} {status}<br><br>{description}'
        for company, batch, status, description in zip(
            yc_df.Company, yc_df.Batch, yc_df.Status, yc_df.oneliner_then_tags
        )
    ]
    marker_colors = [
        'red' if year == selected_year else 'lightgrey'
        for year in yc_projected_embeddings['year']
    ]

    fig.add_trace(go.Scatter(
        x=yc_projected_embeddings['x'],
        # y is negated, flipping the map vertically.
        y=-yc_projected_embeddings['y'],
        mode='markers',
        marker=dict(size=6, color=marker_colors, opacity=0.3),
        text=hover_text,
        hoverinfo='text'
    ), row=1, col=1)

    year_bars = bar_chart_data[selected_year]
    fig.add_trace(go.Bar(
        x=year_bars['names'],
        y=year_bars['vals'],
        name='SAE Features'
    ), row=1, col=2)

    fig.update_layout(
        height=400,
        showlegend=False
    )
    fig.update_xaxes(title_text="X", row=1, col=1)
    fig.update_yaxes(title_text="Y", row=1, col=1)
    # The bar chart sits at row=1, col=2 of this 1x2 grid (row=2, col=1 does
    # not exist). tickfont must be a font spec, not a bare number; a small
    # size keeps the long feature labels legible when tilted.
    fig.update_xaxes(
        title_text="SAE Features",
        tickangle=45,
        tickfont=dict(size=8),
        row=1,
        col=2,
    )
    fig.update_yaxes(title_text="Activation", row=1, col=2)

    return fig

# Page layout: the two-panel figure on top, with a slider underneath that
# snaps to the years present in the data (step=None restricts it to the marks).
app.layout = html.Div(
    children=[
        html.Div(
            children=[
                dcc.Graph(id='main-graph', style={'height': '400px'}),
                dcc.Slider(
                    id='year-slider',
                    min=min(years),
                    max=max(years),
                    value=min(years),
                    marks={str(year): str(year) for year in years},
                    step=None,
                ),
            ],
            style={'width': '100%', 'padding': '20px'},
        ),
    ],
)

@app.callback(
    Output('main-graph', 'figure'),
    Input('year-slider', 'value'),
)
def update_graph(selected_year):
    """Dash callback: rebuild the main figure for the year chosen on the slider."""
    figure = create_figure(selected_year)
    return figure

# Start the Dash development server when this code runs as the main module
# (inside a notebook kernel __name__ is '__main__', so this fires on cell run).
# NOTE(review): recent Dash releases favour `app.run(...)` over `run_server` —
# confirm against the installed Dash version.
if __name__ == '__main__':
    app.run_server(debug=True)

# Notes

gpt-4o-mini overuses the words "interdisciplinary" and "quantum" (both show up in the top labels for the sports test strings above).

Manual nomencodes

5507: Apple, Inc.