# HypotheSAEs Quickstart

This notebook demonstrates basic usage of HypotheSAEs on a sample of the Yelp review dataset.  
We use GPT-4.1 for hypothesis generation (interpreting neurons), and GPT-4.1-mini for text annotation.  
Please set your OpenAI API key in the environment variable `OPENAI_KEY_SAE` in the notebook cell below.

In [2]:
%load_ext autoreload
%autoreload 2

import os
os.environ['OPENAI_KEY_SAE'] = '...' # Replace with your OpenAI API key, or with another environment variable (e.g. os.environ['OPENAI_API_KEY'])

import numpy as np
import pandas as pd

from hypothesaes.quickstart import train_sae, interpret_sae, generate_hypotheses, evaluate_hypotheses
from hypothesaes.embedding import get_openai_embeddings, get_local_embeddings

INTERPRETER_MODEL = "gpt-4.1"
ANNOTATOR_MODEL = "gpt-4.1-mini"
N_WORKERS_ANNOTATION = 10 # Number of parallel threads to use for annotation API calls; lower if hitting OpenAI rate limits



**Load data**

The training set is a 20K-review subset of the Yelp review dataset. A separate set of 2K reviews is used for validation (early stopping during SAE training).

The target variable is the `stars` column, which is a rating between 1 and 5. We treat this as a regression task.

There are also 2K reviews used for holdout evaluation, which we'll use at the end of the notebook.

In [3]:
current_dir = os.getcwd()
if current_dir.endswith("notebooks"):
    prefix = "../"
else:
    prefix = "./"

base_dir = os.path.join(prefix, "demo_data")
train_df = pd.read_json(os.path.join(base_dir, "yelp-demo-train-20K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, "yelp-demo-val-2K.json"), lines=True)

texts = train_df['text'].tolist()
labels = train_df['stars'].values
val_texts = val_df['text'].tolist() # These are only used for early stopping of SAE training, so we don't need labels.
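To run this notebook on your own data, the loading cell above expects JSON Lines files (one JSON object per line) with a `text` field and a label field (`stars` here). Below is a minimal sketch of writing and reading that format with pandas. `my_df` is a hypothetical stand-in for your own data, and an in-memory buffer replaces a real file.

```python
import io

import pandas as pd

# Hypothetical custom data: one row per text, with a numeric label column
my_df = pd.DataFrame({
    "text": ["Great tacos, friendly staff!", "Cold food and a 40-minute wait."],
    "stars": [5, 1],
})

# Write as JSON Lines (the same layout as the demo files), then read it back
buffer = io.StringIO()
my_df.to_json(buffer, orient="records", lines=True)
buffer.seek(0)
loaded = pd.read_json(buffer, lines=True)

my_texts = loaded["text"].tolist()   # list of strings, like `texts`
my_labels = loaded["stars"].values   # NumPy array, like `labels`
```

With real files, pass a path such as `"my-train.json"` instead of the buffer.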

**Compute text embeddings for your dataset**

We'll compute text embeddings for a training set and, optionally, a validation set. The validation embeddings are used for SAE evaluation and early stopping during training.

Embeddings will be stored in the `emb_cache` directory (or `os.environ["EMB_CACHE_DIR"]` if you set it) using the `cache_name` parameter, so you only need to compute embeddings once.

You can use OpenAI or a local model.

Local models will run much faster on GPU. The default local model is `nomic-ai/modernbert-embed-base`. You can use any sentence-transformers model, but please read the model's docs; you may need to edit `get_local_embeddings`.
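The compute-once caching described above can be sketched as follows. This is a hypothetical illustration, not HypotheSAEs' own cache format (the library stores chunked files under `emb_cache`). `cached_embeddings`, the single `.npz` file layout, and the stand-in `embed_fn` are all assumptions. Only texts that are missing from the cache are sent to the embedding function.

```python
import os

import numpy as np

def cached_embeddings(texts, embed_fn, cache_dir, cache_name):
    """Return a text -> vector dict, embedding only texts not already cached on disk."""
    os.makedirs(cache_dir, exist_ok=True)
    safe_name = cache_name.replace("/", "_")  # model names like "org/model" contain slashes
    path = os.path.join(cache_dir, f"{safe_name}.npz")

    cache = {}
    if os.path.exists(path):
        with np.load(path, allow_pickle=False) as data:
            cache = dict(zip(data["texts"].tolist(), data["vectors"]))

    # Deduplicate while preserving order, then keep only uncached texts
    missing = [t for t in dict.fromkeys(texts) if t not in cache]
    if missing:
        for text, vec in zip(missing, embed_fn(missing)):
            cache[text] = np.asarray(vec, dtype=float)
        keys = list(cache)
        np.savez(path, texts=np.array(keys), vectors=np.stack([cache[k] for k in keys]))
    return cache
```

A second call with overlapping texts only embeds the new ones, which is why the notebook only pays the embedding cost once.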

In [3]:
EMBEDDER = "text-embedding-3-small" # OpenAI
# EMBEDDER = "nomic-ai/modernbert-embed-base" # Hugging Face model, will run locally
CACHE_NAME = f"yelp_quickstart_{EMBEDDER}"

text2embedding = get_openai_embeddings(texts + val_texts, model=EMBEDDER, cache_name=CACHE_NAME)
# text2embedding = get_local_embeddings(texts + val_texts, model=EMBEDDER, batch_size=128, cache_name=CACHE_NAME)
train_embeddings = np.stack([text2embedding[text] for text in texts])
val_embeddings = np.stack([text2embedding[text] for text in val_texts])
embeddings = train_embeddings  # same array under a second name; later cells use both names


Loaded 22000 embeddings in 41.5s


**Train SAE** 

We will train a Matryoshka SAE with $M=256$ neurons, $k=8$ active neurons per input (`K=8` in the code), and `matryoshka_prefix_lengths=[32, 256]`.  

With the Matryoshka loss, the SAE will learn to reconstruct the input from (1) just the first 32 neurons, and (2) all 256 neurons.  
This will produce 32 coarse-grained features, and 224 finer-grained features.  

See the README for more details about selecting SAE hyperparameters. 
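The Matryoshka objective described above can be illustrated with a minimal NumPy sketch: a top-k encoder keeps the $k$ largest (ReLU) activations per input, and the reconstruction loss is computed once using only the first 32 neurons and once using all 256. This is not the library's training code. Averaging the prefix losses, and the absence of auxiliary losses and training loop, are simplifying assumptions; all names here are hypothetical.

```python
import numpy as np

def topk_activations(x, W_enc, b_enc, k):
    """Encode x (n, d) into M neurons, keeping only the k largest ReLU activations per row."""
    pre = np.maximum(x @ W_enc + b_enc, 0.0)          # (n, M), nonnegative
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]    # indices of the k largest per row
    z = np.zeros_like(pre)
    rows = np.arange(pre.shape[0])[:, None]
    z[rows, idx] = pre[rows, idx]                     # at most k nonzeros per row
    return z

def matryoshka_loss(x, z, W_dec, b_dec, prefix_lengths):
    """Reconstruction MSE using only the first m neurons, for each prefix length m."""
    per_prefix = []
    for m in prefix_lengths:
        recon = z[:, :m] @ W_dec[:m] + b_dec          # decode from the first m neurons only
        per_prefix.append(float(np.mean((x - recon) ** 2)))
    return float(np.mean(per_prefix)), per_prefix     # averaged here (an assumption)
```

Because the first 32 neurons must reconstruct the input on their own, they tend to capture coarse features, while neurons 33 to 256 refine the reconstruction.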

In [4]:
checkpoint_dir = os.path.join(prefix, "checkpoints", CACHE_NAME)
sae = train_sae(embeddings=train_embeddings, val_embeddings=val_embeddings,
                M=256, K=8, matryoshka_prefix_lengths=[32, 256], 
                checkpoint_dir=checkpoint_dir)

Loaded model from ../checkpoints/yelp_quickstart_text-embedding-3-small/SAE_matryoshka_M=256_K=8_prefixes=32-256.pt onto device cpu


**Interpret neurons**  

Interpret a random subset of SAE neurons as a sanity check that the learned features, and their interpretations, seem reasonable. We generate and print labels for `n_random_neurons` randomly chosen neurons, along with the `print_examples_n` top-activating texts for each one.
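The per-neuron summary printed below (activation rate plus top-activating texts) can be approximated from the `(n_texts, M)` activation matrix with a few lines of NumPy. This sketch assumes that "% active" means the fraction of texts with a nonzero activation; `describe_neuron` is a hypothetical helper, not part of HypotheSAEs.

```python
import numpy as np

def describe_neuron(activations, texts, neuron_idx, n_examples=3):
    """Return (fraction of texts where the neuron fires, its top-activating texts)."""
    col = activations[:, neuron_idx]
    frac_active = float(np.mean(col > 0))
    top = np.argsort(col)[::-1][:n_examples]   # highest activations first
    top = [i for i in top if col[i] > 0]        # drop texts where the neuron is off
    return frac_active, [texts[i] for i in top]
```

These top-activating texts are the kind of evidence an interpreter model is shown when it proposes a label for the neuron.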

In [6]:
# This instruction will be included in the neuron interpretation prompt.
# The instructions below are specific to Yelp, but you can customize them for your task.
# If you don't pass in task-specific instructions, there is a generic instruction (see src/interpret_neurons.py);
# task-specific instructions are optional, but they help produce hypotheses at the desired level of specificity.

TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# Interpret random neurons
results = interpret_sae(
    texts=texts,
    embeddings=train_embeddings,
    sae=sae,
    n_random_neurons=5,
    print_examples_n=3,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,
)


Activations shape: (20000, 256)




Neuron 91 (1.3% active, from SAE M=256, K=8): mentions pancakes as a menu item, often specifying types or flavors such as banana, German, or sweet potato pancakes

Top activating examples:
1. Good pancakes and they should be after 50 years or so of business!  Used to also be a place to see the occasional celebrity chowing down on some awesome sweet potato pancakes, but I doubt they wait in the line these days.  They're famous for a reason - their pancakes really are good, and they have a nice vibe inside, friendly service, and the prices are fair.  If you don't mind standing in line for a while, come see why they've been around so long.
2. First time at Pancake Pantry yesterday (Tuesday)! Got to the restaurant a little after 9am and was seated right away. I was so surprised as I have read the lines can go around the block. After I left around 10:30 was when that happened.   A friend and I shared the banana bread pancakes and the Georgia peach pancakes. Both were very good! The banana 

**Generate hypotheses**

Generate hypotheses which are predictive of the target variable.

The `selection_method` parameter defines how we compute neuron predictiveness (see `src/select_neurons.py` for more details):
- "separation_score": $E[\text{target} \mid \text{top-activating examples}] - E[\text{target} \mid \text{zero-activating examples}]$
- "correlation": the Pearson correlation between the neuron's activations and the target variable
- "lasso": select $N$ neurons with nonzero coefficients in an L1-regularized (lasso) model
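The first two selection scores can be sketched in NumPy as below, assuming `activations` is the `(n_texts, M)` activation matrix printed by the notebook. The top-activating cutoff (`top_frac`) and ranking neurons by absolute score are assumptions; the mixed-sign correlations in the output further down suggest ranking by magnitude, but see `src/select_neurons.py` for the actual rules. The lasso option is omitted here; a model such as scikit-learn's `Lasso` could fill that role.

```python
import numpy as np

def separation_score(acts, target, top_frac=0.05):
    """E[target | top-activating] - E[target | zero-activating] for one neuron."""
    n_top = max(1, int(round(len(acts) * top_frac)))  # hypothetical cutoff for "top"
    top_idx = np.argsort(acts)[-n_top:]               # indices of the largest activations
    zero_mask = acts == 0                             # texts where the neuron is off
    return float(target[top_idx].mean() - target[zero_mask].mean())

def correlation_score(acts, target):
    """Pearson correlation between one neuron's activations and the target."""
    return float(np.corrcoef(acts, target)[0, 1])

def select_neurons(activations, target, n_selected, method="correlation"):
    """Score every neuron (column) and keep the n_selected with the largest |score|."""
    score_fn = correlation_score if method == "correlation" else separation_score
    scores = np.array([score_fn(activations[:, j], target)
                       for j in range(activations.shape[1])])
    order = np.argsort(-np.abs(scores))[:n_selected]  # NaN scores sort to the end
    return order, scores[order]
```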

The `generate_hypotheses` call below returns a dataframe with the following columns:
- `neuron_idx`: The index of the neuron in the SAE (if you're using multiple SAEs, this will be a global index across all of them).
- `source_sae`: The SAE that the neuron was selected from.
- `target_{selection_method}`: The predictiveness of the neuron for the target variable, using the selected `selection_method`.
- `interpretation`: The natural language interpretation of the neuron.
- `f1_fidelity_score`: The F1 score measuring how well the neuron's interpretation matches its actual activation pattern (its fidelity).
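As a rough illustration of the fidelity score, assume each scored text carries two binary labels: whether the neuron is active on it, and whether the annotator model judged that the interpretation applies. F1 then measures agreement between the two. How HypotheSAEs samples the scored examples is not covered by this sketch, and `f1_fidelity` is a hypothetical helper.

```python
import numpy as np

def f1_fidelity(neuron_active, concept_present):
    """F1 of annotator judgments (concept_present) against neuron activity labels."""
    neuron_active = np.asarray(neuron_active, dtype=bool)
    concept_present = np.asarray(concept_present, dtype=bool)
    tp = np.sum(neuron_active & concept_present)    # annotator agrees the concept is there
    fp = np.sum(~neuron_active & concept_present)   # concept flagged but neuron is off
    fn = np.sum(neuron_active & ~concept_present)   # neuron fires but concept not flagged
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return float(2 * precision * recall / (precision + recall))
```

For example, with 4 active and 4 inactive texts, 3 true positives, 1 false positive, and 1 false negative give precision = recall = 0.75, so F1 = 0.75.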

In [9]:
selection_method = "correlation"
results = generate_hypotheses(
    texts=texts,
    labels=labels,
    embeddings=embeddings,
    sae=sae,
    cache_name=CACHE_NAME,
    selection_method=selection_method,
    n_selected_neurons=20,
    n_candidate_interpretations=1,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,
    annotator_model=ANNOTATOR_MODEL,
    n_workers_annotation=N_WORKERS_ANNOTATION, # Please lower this parameter if you are running into OpenAI API rate limits
)

print("\nMost predictive features of Yelp reviews:")
pd.set_option('display.max_colwidth', None)
display(results.sort_values(by=f"target_{selection_method}", ascending=False).round(3))
pd.reset_option('display.max_colwidth')

Embeddings shape: (20000, 1536)



Activations shape: (20000, 256)

Step 1: Selecting top 20 predictive neurons

Step 2: Interpreting selected neurons




Step 3: Scoring Interpretations
Found 100 cached items; annotating 1900 uncached items


Scoring neuron interpretation fidelity (20 neurons; 1 candidate interps per neuron; 100 examples to score each…


Most predictive features of Yelp reviews:


| | neuron_idx | source_sae | target_correlation | interpretation | f1_fidelity_score |
|---|---|---|---|---|---|
| 2 | 10 | (SAE_0, 256, 8) | 0.309 | expresses strong personal enthusiasm or love for the restaurant, using emphatic language such as 'love this place', 'amazing', 'favorite', or multiple exclamation marks | 0.803 |
| 4 | 13 | (SAE_0, 256, 8) | 0.229 | specifically mentions the name of a server or staff member when praising the service | 0.607 |
| 14 | 18 | (SAE_0, 256, 8) | 0.129 | describes ordering or eating high-end steak cuts (e.g., ribeye, porterhouse, filet mignon, wagyu) at a steakhouse, including comments on preparation, doneness, or aging | 0.765 |
| 15 | 123 | (SAE_0, 256, 8) | 0.112 | claims the restaurant or its offerings are the absolute best in a specific area, city, or even the world (e.g., 'best pizza south of Mason-Dixon', 'best drinks in town', 'best hot chicken in the world') | 0.887 |
| 19 | 252 | (SAE_0, 256, 8) | 0.095 | describes the restaurant as a 'hidden gem', 'hidden secret', 'hidden treasure', or emphasizes that it is hard to find or off the beaten path | 0.969 |
| 18 | 160 | (SAE_0, 256, 8) | -0.107 | describes receiving an order with missing, incorrect, or unfulfilled customization requests (such as wrong items, missing ingredients, or ignored special instructions) | 0.898 |
| 17 | 203 | (SAE_0, 256, 8) | -0.109 | describes being refused seating or denied a table despite visibly empty tables or due to restrictive seating policies (e.g., party size, reservation rules, or arbitrary enforcement) | 0.851 |
| 16 | 141 | (SAE_0, 256, 8) | -0.109 | mentions that the restaurant has recently changed ownership, staff, chef, or menu, and discusses the impact of these changes on the food or experience | 0.824 |
| 13 | 188 | (SAE_0, 256, 8) | -0.129 | describes experiencing food poisoning symptoms (such as vomiting, nausea, diarrhea, stomachache) shortly after eating at the restaurant | 0.765 |
| 12 | 166 | (SAE_0, 256, 8) | -0.131 | describes an error or problem with the order, bill, or service and details whether or how the staff or management responded to address the issue | 0.869 |


**Evaluate held-out generalization**

Finally, we evaluate whether these are good hypotheses by testing whether their natural language interpretations can predict the target variable.  

We compute annotations for each hypothesized concept on a holdout set (not seen during SAE training & feature selection).

After annotation, we output a dataframe with the following columns:
- `hypothesis`: The natural language hypothesis (which came from interpreting a predictive neuron in the SAE)
- `separation_score`: How much the target variable differs when the concept is present vs. absent (i.e., $E[Y\mid\text{concept} = 1] - E[Y\mid\text{concept} = 0]$).
- `separation_pval`: The p-value of a t-test of the null hypothesis that the separation score is 0 (i.e., that the concept is not associated with the target variable).
- `regression_coef`: The coefficient of the concept in a multivariate linear regression of the target variable on all concepts.
- `regression_pval`: The p-value of the null hypothesis that the regression coefficient is 0.
- `feature_prevalence`: The fraction of examples that contain the concept.
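The per-hypothesis statistics above can be sketched from a binary annotation matrix (texts × hypotheses) using SciPy's two-sample t-test and an ordinary least-squares fit with an intercept. The helper name, the equal-variance t-test, and the classical OLS standard-error formula are assumptions, not necessarily what `evaluate_hypotheses` does internally.

```python
import numpy as np
from scipy import stats

def evaluate_annotations(annotations, target):
    """Per-hypothesis stats from a binary (n_texts, n_hypotheses) annotation matrix."""
    A = np.asarray(annotations, dtype=float)
    y = np.asarray(target, dtype=float)
    n, h = A.shape

    rows = []
    for j in range(h):
        present = A[:, j] == 1
        sep = y[present].mean() - y[~present].mean()               # E[Y|c=1] - E[Y|c=0]
        pval = stats.ttest_ind(y[present], y[~present]).pvalue     # Student's t-test
        rows.append({"separation_score": float(sep), "separation_pval": float(pval),
                     "feature_prevalence": float(present.mean())})

    # Multivariate OLS of y on all concepts, plus an intercept column
    X = np.column_stack([np.ones(n), A])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ coef
    dof = n - X.shape[1]
    sigma2 = resid @ resid / dof                                   # residual variance
    se = np.sqrt(np.diag(sigma2 * np.linalg.inv(X.T @ X)))         # coefficient std. errors
    pvals = 2 * stats.t.sf(np.abs(coef / se), dof)                 # two-sided t-test per coef

    for j, row in enumerate(rows):
        row["regression_coef"] = float(coef[j + 1])                # index 0 is the intercept
        row["regression_pval"] = float(pvals[j + 1])
    return rows
```

The separation score looks at each concept alone, while the regression coefficient measures a concept's association with the target after controlling for the other concepts. This explains why, in the holdout results below, a concept can have a large separation score but a small, non-significant regression coefficient.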

Additionally, we output the evaluation metrics used in the paper:
- Significant hypotheses: the number of hypotheses that are significant in the multivariate regression at a specified significance level (default $0.1$) after Bonferroni correction. You can pass in a different significance level using the `corrected_pval_threshold` parameter.
- AUC (for binary targets) or $R^2$ (for continuous targets, as in this demo): how well the hypotheses collectively predict the target variable in the multivariate regression.
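With the default significance level of $0.1$ and 20 hypotheses, the Bonferroni-corrected threshold is $0.1 / 20 = 0.005$, which matches the `p < 5.000e-03` printed at the end of this notebook. Below is a small sketch of the counting rule and of $R^2 = 1 - SS_\text{res} / SS_\text{tot}$. `count_significant` and `r_squared` are hypothetical helpers, not HypotheSAEs functions.

```python
import numpy as np

def count_significant(regression_pvals, alpha=0.1):
    """Bonferroni: a hypothesis counts as significant if p < alpha / (number of hypotheses)."""
    pvals = np.asarray(regression_pvals, dtype=float)
    threshold = alpha / len(pvals)
    return int(np.sum(pvals < threshold)), len(pvals), threshold

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)
```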


In [11]:
holdout_df = pd.read_json(os.path.join(base_dir, "yelp-demo-holdout-2K.json"), lines=True)
holdout_texts = holdout_df['text'].tolist()
holdout_labels = holdout_df['stars'].values

metrics, evaluation_df = evaluate_hypotheses(
    hypotheses_df=results,
    texts=holdout_texts,
    labels=holdout_labels,
    cache_name=CACHE_NAME,
    annotator_model=ANNOTATOR_MODEL,
    n_workers_annotation=N_WORKERS_ANNOTATION, # Please lower this parameter if you are running into OpenAI API rate limits
)

pd.set_option('display.max_colwidth', None)
display(evaluation_df.sort_values(by="separation_score", ascending=False).round(3))
pd.reset_option('display.max_colwidth')

print("\nHoldout Set Metrics:")
print(f"R² Score: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 39997 cached items; annotating 3 uncached items



Step 2: Computing predictiveness of hypothesis annotations


| | hypothesis | separation_score | separation_pval | regression_coef | regression_pval | feature_prevalence |
|---|---|---|---|---|---|---|
| 2 | expresses strong personal enthusiasm or love for the restaurant, using emphatic language such as 'love this place', 'amazing', 'favorite', or multiple exclamation marks | 1.404 | 0.0 | 0.569 | 0.0 | 0.297 |
| 15 | claims the restaurant or its offerings are the absolute best in a specific area, city, or even the world (e.g., 'best pizza south of Mason-Dixon', 'best drinks in town', 'best hot chicken in the world') | 1.051 | 0.0 | 0.272 | 0.0 | 0.064 |
| 19 | describes the restaurant as a 'hidden gem', 'hidden secret', 'hidden treasure', or emphasizes that it is hard to find or off the beaten path | 0.865 | 0.0 | 0.149 | 0.15 | 0.034 |
| 4 | specifically mentions the name of a server or staff member when praising the service | 0.862 | 0.0 | 0.285 | 0.0 | 0.078 |
| 14 | describes ordering or eating high-end steak cuts (e.g., ribeye, porterhouse, filet mignon, wagyu) at a steakhouse, including comments on preparation, doneness, or aging | 0.267 | 0.326 | 0.286 | 0.076 | 0.014 |
| 16 | mentions that the restaurant has recently changed ownership, staff, chef, or menu, and discusses the impact of these changes on the food or experience | -0.418 | 0.08 | -0.071 | 0.619 | 0.018 |
| 5 | describes food as good or decent but with some specific criticism or disappointment about a particular aspect (e.g., texture, flavor, preparation, or unmet expectations) | -1.088 | 0.0 | -0.168 | 0.001 | 0.29 |
| 6 | complains about long wait times to be seated, served, or receive food | -1.465 | 0.0 | 0.012 | 0.877 | 0.088 |
| 7 | describes problems or failures specifically with online or third-party delivery orders (such as Grubhub, Uber Eats, Seamless), including missed deliveries, cancellations, wrong orders, or issues with communication or refunds | -1.706 | 0.0 | -0.098 | 0.577 | 0.012 |
| 18 | describes receiving an order with missing, incorrect, or unfulfilled customization requests (such as wrong items, missing ingredients, or ignored special instructions) | -1.713 | 0.0 | 0.089 | 0.229 | 0.123 |



Holdout Set Metrics:
R² Score: 0.654
Significant hypotheses: 12/20 (p < 5.000e-03)
