In [28]:
%load_ext autoreload
%autoreload 2

import os
import joblib
import numpy as np
import pandas as pd
from neuro import config

from neuro.baselines.hypothesaes.quickstart import train_sae, interpret_sae, generate_hypotheses, evaluate_hypotheses
from neuro.baselines.hypothesaes.embedding import get_local_embeddings

INTERPRETER_MODEL = "gpt-4.1"  # LLM used to write natural-language neuron interpretations
# ANNOTATOR_MODEL = 'meta-llama/Meta-Llama-3-8B-Instruct' # not actually used
EMBEDDER = "nomic-ai/modernbert-embed-base"  # Huggingface model, will run locally
# Guidance passed to the interpreter model so features come out at the desired specificity.
TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# paths
OUT_DIR = os.path.expanduser('~/mntv1/deep-fMRI/qa/hypothesaes')
os.makedirs(OUT_DIR, exist_ok=True)
# On-disk embedding cache location read by the embedding helpers.
# Note: os.path.join(OUT_DIR, <absolute path>) would discard OUT_DIR entirely, so the
# cache lives at ~/emb_cache (NOT under OUT_DIR); this is spelled out explicitly here.
# Use os.path.join(OUT_DIR, 'emb_cache') instead if a per-project cache is wanted.
os.environ['EMB_CACHE_DIR'] = os.path.expanduser('~/emb_cache')
# '/' is not allowed in a single path component, so the HF model id is flattened.
CACHE_NAME = f"yelp_quickstart_{EMBEDDER.replace('/', '___')}"
checkpoint_dir = os.path.join(OUT_DIR, "checkpoints", CACHE_NAME)
# Create it up front so later cells can write into it even if SAE training is skipped.
os.makedirs(checkpoint_dir, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


**Load data**

The dataset we're using here is a 20K-review subset of Yelp reviews for training, plus a separate set of 2K reviews used for validation (during SAE training).

The target variable is the `stars` column, which is a rating between 1 and 5. We treat this as a regression task.

There are also 2K reviews used for holdout evaluation, which we'll use at the end of the notebook.

In [13]:
# Demo splits shipped with the repo under data/demo_data, stored as JSON Lines.
base_dir = os.path.join(config.REPO_DIR, 'data', "demo_data")
train_df = pd.read_json(os.path.join(base_dir, "yelp-demo-train-20K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, "yelp-demo-val-2K.json"), lines=True)

texts = train_df['text'].tolist()
labels = train_df['stars'].values  # regression target: 1-5 star rating
val_texts = val_df['text'].tolist()  # These are only used for early stopping of SAE training, so we don't need labels.

**Compute text embeddings for your dataset**

We'll compute text embeddings for a training set, and optionally a validation set. The validation embeddings are used for SAE eval and early-stopping during training.

Embeddings will be stored in the `emb_cache` directory (or `os.environ["EMB_CACHE_DIR"]` if you set it) using the `cache_name` parameter, so you only need to compute embeddings once.

You can use OpenAI or a local model.

Local models will run much faster on GPU. The default local model is `nomic-ai/modernbert-embed-base`. You can use any sentence-transformers model, but please read the model's docs; you may need to edit `get_local_embeddings`.

In [None]:
# text2embedding = get_openai_embeddings(texts + val_texts, model=EMBEDDER, cache_name=CACHE_NAME)
# Embed train + validation texts in a single call; results are cached under CACHE_NAME.
text2embedding = get_local_embeddings(texts + val_texts, model=EMBEDDER, batch_size=128, cache_name=CACHE_NAME)

# text2embedding maps each text to its vector; stack into one row per text, in input order.
train_embeddings = np.stack([text2embedding[text] for text in texts])
val_embeddings = np.stack([text2embedding[text] for text in val_texts])
# Backward-compatible alias: `embeddings` previously re-stacked the same training vectors.
embeddings = train_embeddings

In [None]:
# Cache the (train, val) embedding arrays so later sessions can skip re-embedding.
embeddings_path = os.path.join(OUT_DIR, f"{CACHE_NAME}_embeddings.pkl")
joblib.dump((train_embeddings, val_embeddings), embeddings_path)

In [24]:
# Reload the (train, val) embedding pair previously written with joblib.dump.
# Only load pickles this notebook produced itself (joblib.load can execute code).
train_embeddings, val_embeddings = joblib.load(
    os.path.join(OUT_DIR, f"{CACHE_NAME}_embeddings.pkl")
)

**Train SAE** 

We will train a Matryoshka SAE with $M=512$, $k=16$, and $\text{prefix\_lengths} = [64, 512]$.  

With the Matryoshka loss, the SAE will learn to reconstruct the input from (1) just the first 64 neurons, and (2) all 512 neurons.  
This will produce 64 coarse-grained features and 448 finer-grained features.  

See the README for more details about selecting SAE hyperparameters. 

In [None]:

# Matryoshka SAE with M=512 latents and K=16 (presumably top-k active latents per input);
# the prefix lengths make the first 64 latents and all 512 each reconstruct the input.
sae = train_sae(
    embeddings=train_embeddings,
    val_embeddings=val_embeddings,  # used for SAE eval / early stopping
    M=512,
    K=16,
    matryoshka_prefix_lengths=[64, 512],
    checkpoint_dir=checkpoint_dir,
)

In [None]:
# Persist the trained SAE alongside its checkpoints so later sessions can reload it.
# Ensure the target directory exists first; otherwise joblib.dump fails on a missing path.
os.makedirs(checkpoint_dir, exist_ok=True)
joblib.dump(sae, os.path.join(checkpoint_dir, "sae.pkl"))

In [25]:
# Reload the SAE saved by the dump cell (only load pickles this notebook wrote itself).
sae_path = os.path.join(checkpoint_dir, "sae.pkl")
sae = joblib.load(sae_path)

**Interpret neurons**  

Interpret a random subset of neurons in the SAE to sanity-check that the learned features, and their interpretations, seem reasonable. We generate and print labels for `n_random_neurons` neurons, and we also print out the top-activating texts for each neuron.

In [None]:
# Task-specific guidance for the interpreter LLM comes from TASK_SPECIFIC_INSTRUCTIONS,
# defined in the config cell at the top (Yelp-specific here; customize it for your task).
# If you don't pass in task-specific instructions, there is a generic instruction (see src/interpret_neurons.py);
# task-specific instructions are optional, but they help produce hypotheses at the desired level of specificity.

# Sanity check: label n_random_neurons randomly chosen neurons with the interpreter model and
# print top-activating texts for each (print_examples_n presumably caps how many per neuron).
results_examples = interpret_sae(
    texts=texts,
    embeddings=train_embeddings,
    sae=sae,
    n_random_neurons=5,
    print_examples_n=3,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,
)

**Generate hypotheses**

Generate hypotheses which are predictive of the target variable.

The `selection_method` parameter defines how we compute neuron predictiveness (see `src/select_neurons.py` for more details):
- "separation_score": E[target | top-activating examples] - E[target | zero-activating examples]
- "correlation": pearson(neuron activations, target variable)
- "lasso": select N nonzero features with an L1 regularized model
- "correlation_multi_target": the method used below, with a 2-D label matrix; here the `stars` labels are simply duplicated into three identical target columns.

This cell outputs a dataframe with the following columns:
- `neuron_idx`: The index of the neuron in the SAE (if you're using multiple SAEs, this will be a global index across all of them).
- `source_sae`: The SAE that the neuron was selected from.
- `target_{selection_method}`: The predictiveness of the neuron for the target variable, using the selected `selection_method`.
- `interpretation`: The natural language interpretation of the neuron.
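- `interp_fidelity_score`: The F1 fidelity score for how well the neuron's interpretation actually corresponds to its activation pattern. This is only computed when hypothesis scoring is enabled (`n_scoring_examples > 0`), so it does not appear in the output below, where scoring is disabled.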

In [None]:
# The same `stars` labels repeated as three identical target columns -> shape (n_texts, 3),
# presumably to exercise the multi-target selection path on a single real target.
labels_multitarget = np.column_stack([labels] * 3)

results = generate_hypotheses(
    texts=texts,
    labels=labels_multitarget,
    embeddings=train_embeddings,
    sae=sae,
    cache_name=CACHE_NAME,
    classification=False,  # stars are treated as a regression target

    # hyperparams
    selection_method="correlation_multi_target",
    n_selected_neurons=20,
    n_candidate_interpretations=1,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,

    # Number of examples used to score the hypotheses; 0 disables scoring.
    # To enable it, set e.g. n_scoring_examples=100 and also pass annotator_model=...
    # and n_workers_annotation=... (lower the worker count if you hit OpenAI API rate limits).
    n_scoring_examples=0,
)

Embeddings shape: (20000, 768)


Computing activations (batchsize=16384):   0%|          | 0/2 [00:00<?, ?it/s]

Activations shape: (20000, 512)

Step 1: Selecting top 20 predictive neurons

Step 2: Interpreting selected neurons


Generating interpretations:   0%|          | 0/20 [00:00<?, ?it/s]


Most predictive features of Yelp reviews:


NameError: name 'selection_method' is not defined

In [32]:
print("\nMost predictive features of Yelp reviews:")
# option_context shows full interpretation text, then restores the PREVIOUS max_colwidth
# (reset_option would force the library default) even if display raises.
with pd.option_context('display.max_colwidth', None):
    display(results.sort_values(by="target_correlation_multi_target", ascending=False))


Most predictive features of Yelp reviews:


Unnamed: 0,neuron_idx,source_sae,target_correlation_multi_target,interpretation
4,54,"(SAE_0, 512, 16)",0.351364,"describes the restaurant as a personal or family favorite, using phrases like 'my favorite', 'go to', 'all time favorite', or mentioning frequent repeat visits"
5,0,"(SAE_0, 512, 16)",0.241403,"describes the overall restaurant experience as exceptionally memorable or impressive, using superlative language and emphasizing a strong emotional reaction (e.g., 'mind blowingly good', 'in love', 'memorable experience', 'can't wait to turn more"
8,43,"(SAE_0, 512, 16)",0.215558,"mentions the ability to sample or try different items (such as flavors, drinks, or foods) before purchasing"
9,63,"(SAE_0, 512, 16)",0.188179,mentions intention to return or try more items at the restaurant
11,8,"(SAE_0, 512, 16)",0.172736,"mentions intention or desire to return to the restaurant, with phrases like 'will definitely be back', 'will return', or 'coming back again'"
14,59,"(SAE_0, 512, 16)",0.159896,"mentions a wine list, wine selection, or wines by the glass"
15,40,"(SAE_0, 512, 16)",0.155774,"describes specific dishes or items in detail, often naming ingredients, preparation methods, or flavors (e.g., 'cardamom gelato', 'vegan scramble', 'crispy tofu over soba noodles', 'chorizo breakfast"
16,10,"(SAE_0, 512, 16)",0.151753,"describes a positive experience with staff or owners going above and beyond in terms of helpfulness, hospitality, or accommodating special requests"
19,15,"(SAE_0, 512, 16)",-0.141028,mentions that the restaurant was closed or not open during expected business hours
18,61,"(SAE_0, 512, 16)",-0.141743,"describes a decline in quality or experience at a restaurant that was previously a favorite, often referencing changes such as new ownership, new chef, remodeled space, reduced menu, or increased prices"
