In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from typing import Dict, List, NamedTuple, Tuple, Union

from nomic.atlas import AtlasDataset

from latentsae import Sae

import dash
from dash import html, dcc, callback_context
import plotly.graph_objects as go
from dash.dependencies import Input, Output, State
from plotly.subplots import make_subplots

from itertools import islice

def take(n, iterable):
    """Collect at most ``n`` leading elements of ``iterable`` into a new list.

    Semantics follow ``itertools.islice`` (e.g. ``n=None`` keeps everything).
    """
    head = islice(iterable, n)
    return [*head]

  from tqdm.autonotebook import tqdm, trange


Triton not installed, using eager implementation of SAE decoder.


In [3]:
# Load the pretrained sparse autoencoder (SAE) from the hub: repo
# "enjalot/sae-nomic-text-v1.5-FineWeb-edu-100BT", configuration "64_32".
# NOTE(review): "64_32" presumably encodes the SAE's size/sparsity settings — confirm in the repo.
sae_model = Sae.load_from_hub("enjalot/sae-nomic-text-v1.5-FineWeb-edu-100BT", "64_32")

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Dropping extra args {'signed': False}


In [4]:
# Text embedding model; the SAE repo name suggests the SAE was trained on this model's embeddings.
# trust_remote_code=True runs model code downloaded from the hub — only use with repos you trust.
emb_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

<All keys matched successfully>


In [5]:
# Pick an available accelerator instead of hard-coding Apple's "mps", so the
# notebook also runs on CUDA machines and on plain CPUs. "mps" is still preferred
# when present, preserving the original behavior on Apple Silicon.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
sae_model = sae_model.to(device)
emb_model = emb_model.to(device)

In [6]:
# Load the public Y Combinator map from Nomic Atlas: company metadata, the latent
# embeddings, and their 2-D projection (columns 'x' / 'y').
datamap = AtlasDataset('nomic/y-combinator').maps[0]
yc_df = datamap.data.df
# NOTE(review): not used by the cells shown here, but it triggers a download — drop if unused.
yc_embeddings = datamap.embeddings.latent
# `.assign` returns a new frame, so the projection DataFrame held by `datamap` is not
# mutated. Values align by index: assumes the data and projection frames share an index.
yc_projected_embeddings = datamap.embeddings.projected.assign(year=yc_df.Year)

# Keep companies with a positive Year and a usable description. Real NaNs are excluded
# as well as the literal string "null", because the embedding model needs actual strings.
selection_mask = (
    (yc_df.Year > 0)
    & yc_df.oneliner_then_tags.notna()
    & (yc_df.oneliner_then_tags != "null")
)
selection_idx = yc_df[selection_mask].index.values
yc_df = yc_df.loc[selection_idx]
yc_projected_embeddings = yc_projected_embeddings.loc[selection_idx]
years = sorted(yc_projected_embeddings['year'].unique())

[32m2024-09-20 09:34:00.322[0m | [1mINFO    [0m | [36mnomic.dataset[0m:[36m__init__[0m:[36m763[0m - [1mLoading existing dataset `nomic/y-combinator`.[0m
[32m2024-09-20 09:34:01.266[0m | [1mINFO    [0m | [36mnomic.data_operations[0m:[36m_download_data[0m:[36m902[0m - [1mDownloading data[0m
100%|██████████| 5/5 [00:00<00:00, 3156.46it/s]
100%|██████████| 5/5 [00:00<00:00, 2714.06it/s]
100%|██████████| 5/5 [00:00<00:00, 3261.51it/s]
100%|██████████| 5/5 [00:00<00:00, 3508.12it/s]
100%|██████████| 5/5 [00:00<00:00, 4039.20it/s]
100%|██████████| 5/5 [00:00<00:00, 3075.45it/s]
100%|██████████| 5/5 [00:00<00:00, 4129.06it/s]
[32m2024-09-20 09:34:01.887[0m | [1mINFO    [0m | [36mnomic.data_operations[0m:[36m_load_data[0m:[36m872[0m - [1mLoading data[0m
100%|██████████| 5/5 [00:00<00:00, 473.31it/s]
[32m2024-09-20 09:34:02.105[0m | [1mINFO    [0m | [36mnomic.data_operations[0m:[36m_download_latent[0m:[36m550[0m - [1mDownloading latent embeddings[0

In [7]:
# Load per-feature SAE metadata as a list of dicts (later cells read each entry's "label").
# The path is relative to the notebook's working directory.
# NOTE(review): later cells index this list by SAE feature id, which assumes row i of the
# parquet file describes feature i — confirm the file is sorted by feature index.
loaded_features = pd.read_parquet("sae/features.parquet").to_dict(orient='records')

In [8]:
class EncoderOutput(NamedTuple):
    """Selected SAE activations plus the feature index each activation belongs to."""

    top_acts: torch.Tensor  # activation values
    top_indices: torch.Tensor  # feature indices, aligned element-wise with top_acts


def aggregate_encoder_output(encoder_output: EncoderOutput) -> Dict[int, float]:
    """
    Sum activations per feature index over the whole batch.

    Every (index, activation) pair from ``top_indices`` / ``top_acts`` is visited, and
    activations sharing an index are added together.

    Args:
        encoder_output (EncoderOutput): any object exposing ``top_acts`` and
            ``top_indices`` tensors of the same shape (e.g. the result of ``Sae.encode``).

    Returns:
        Dict[int, float]: feature index -> total activation, ordered by total activation
        descending; ties keep the order in which indices were first seen. An empty
        input yields an empty dict.
    """
    # Move to host and convert in bulk; ``tolist`` yields the same Python ints/floats as
    # per-element ``.item()`` calls, without issuing one tensor op per element.
    flat_acts = encoder_output.top_acts.cpu().flatten().tolist()
    flat_indices = encoder_output.top_indices.cpu().flatten().tolist()

    aggregated: Dict[int, float] = {}
    for feature, activation in zip(flat_indices, flat_acts):
        if feature in aggregated:
            aggregated[feature] += activation
        else:
            aggregated[feature] = activation

    # sorted() is stable even with reverse=True, so equal totals keep first-seen order.
    return dict(sorted(aggregated.items(), key=lambda item: item[1], reverse=True))

In [21]:
# Precompute bar chart data for each year
bar_chart_data = {}
for year in years:
    print(year)
    s = yc_df[yc_df.Year == year].oneliner_then_tags.values
    text_embeddings = emb_model.encode(s, convert_to_tensor=True, normalize_embeddings=True)
    top_activated_features_sae_output = sae_model.encode(text_embeddings)
    top_sae_features_hist = aggregate_encoder_output(top_activated_features_sae_output)
    idx = list(top_sae_features_hist.keys())[:10]
    names = [f'{i}: {loaded_features[i]["label"]}' for i in idx]
    vals = [top_sae_features_hist[i] for i in idx]
    bar_chart_data[year] = {'names': names, 'vals': vals}


2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024


In [34]:
# Interactive explorer: the 2-D map of YC companies (the selected year's companies
# highlighted) next to a bar chart of that year's top SAE features.
# `dash`, `html`, `dcc`, `go`, `make_subplots`, `Input` and `Output` all come from the
# imports cell at the top of the notebook.

app = dash.Dash(__name__)

# Hover labels do not depend on the selected year, so build them once here instead of
# looping over every row of `yc_df` on each slider move. Row order matches
# `yc_projected_embeddings`, since both were filtered with the same `selection_idx`.
hover_text = [
    f'{row.Company} {row.Batch} {row.Status}<br><br>{row.oneliner_then_tags}'
    for row in yc_df.itertuples(index=False)
]


def create_figure(selected_year):
    """
    Build the two-panel figure for one year.

    Args:
        selected_year: a key of ``bar_chart_data`` (the slider only offers those years).

    Returns:
        plotly.graph_objects.Figure: left, scatter of the 2-D projection with the selected
        year's companies in red; right, bar chart of that year's top SAE features.
    """
    fig = make_subplots(rows=1, cols=2, subplot_titles=("Y Combinator over time", f"Top SAE Features for {selected_year}"))

    fig.add_trace(go.Scatter(
        x=yc_projected_embeddings['x'],
        y=-yc_projected_embeddings['y'],  # y is negated when plotted
        mode='markers',
        marker=dict(
            size=6,
            color=['red' if year == selected_year else 'lightgrey' for year in yc_projected_embeddings['year']],
            opacity=0.3
        ),
        text=hover_text,
        hoverinfo='text'
    ), row=1, col=1)

    chart = bar_chart_data[selected_year]
    fig.add_trace(go.Bar(
        x=chart['names'],
        y=chart['vals'],
        name='SAE Features'
    ), row=1, col=2)

    fig.update_layout(
        height=400,
        showlegend=False
    )
    fig.update_xaxes(title_text="X", row=1, col=1)
    fig.update_yaxes(title_text="Y", row=1, col=1)
    # The bar chart lives in row 1, column 2 of the 1x2 grid; a row=2 selector matches
    # no axis, so these settings would be silently dropped. `tickfont` takes a font
    # spec, not a bare number; a small size keeps the long feature labels readable.
    fig.update_xaxes(title_text="SAE Features", tickangle=45, tickfont=dict(size=9), row=1, col=2)
    fig.update_yaxes(title_text="Activation", row=1, col=2)

    return fig


# Layout: the figure above a slider whose stops are exactly the values in `years`
# (step=None restricts the slider to its marks).
app.layout = html.Div([
    html.Div([
        dcc.Graph(id='main-graph', style={'height': '400px'}),
        dcc.Slider(
            id='year-slider',
            min=min(years),
            max=max(years),
            value=min(years),
            marks={str(year): str(year) for year in years},
            step=None
        )
    ], style={'width': '100%', 'padding': '20px'})
])


@app.callback(
    Output('main-graph', 'figure'),
    Input('year-slider', 'value')
)
def update_graph(selected_year):
    """Redraw the figure whenever the year slider moves."""
    return create_figure(selected_year)


if __name__ == '__main__':
    # `app.run` supersedes `app.run_server`, which newer Dash releases deprecate/remove.
    app.run(debug=True)

# Notes

GPT-4o mini, when generating feature labels, overuses the words "interdisciplinary" and "quantum".

Manual feature labels (feature index: label):

5507: Apple, Inc.

# Test the SAE

In [35]:
# Sanity check: embed a few New York City nicknames/landmarks and list the labels of the
# five SAE features with the largest summed activation. Embeddings are normalized the
# same way as in the per-year bar chart computation, so the SAE sees the same kind of input.
nyc_phrases = [
    "The Big Apple",
    "Gotham",
    "Empire State Building",
    "Five Boroughs",
]
with torch.no_grad():
    nyc_feature_hist = aggregate_encoder_output(
        sae_model.encode(
            emb_model.encode(nyc_phrases, convert_to_tensor=True, normalize_embeddings=True)
        )
    )
# Index the feature list directly instead of copying it into a numpy object array.
[loaded_features[i]['label'] for i in take(5, nyc_feature_hist)]

['New York City characteristics and demographics',
 'Cultural narratives and storytelling techniques',
 'quantum computing and nanomaterial advancements',
 'advanced materials and manufacturing processes',
 'John Jacob Astor and American fur trade']