In [1]:
%load_ext autoreload
%autoreload 2

import os
import joblib
import numpy as np
import pandas as pd
from neuro import config

from neuro.baselines.hypothesaes.quickstart import train_sae, interpret_sae, generate_hypotheses, evaluate_hypotheses
from neuro.baselines.hypothesaes.embedding import get_local_embeddings

INTERPRETER_MODEL = "gpt-4.1"
# ANNOTATOR_MODEL = 'meta-llama/Meta-Llama-3-8B-Instruct' # not actually used
EMBEDDER = "nomic-ai/modernbert-embed-base" # Huggingface model, will run locally
TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# paths
OUT_DIR = os.path.expanduser('~/mntv1/deep-fMRI/qa/hypothesaes')
os.makedirs(OUT_DIR, exist_ok=True)
# os.path.join discards OUT_DIR when the second argument is absolute, so this is the effective path
os.environ['EMB_CACHE_DIR'] = os.path.expanduser('~/emb_cache')
CACHE_NAME = f"yelp_quickstart_{EMBEDDER.replace('/', '___')}"
checkpoint_dir = os.path.join(OUT_DIR, "checkpoints", CACHE_NAME)

**Load data**

The dataset we're using here is a subset of 20K Yelp reviews, with 2K reviews used for validation (during SAE training). 

The target variable is the `stars` column, which is a rating between 1 and 5. We treat this as a regression task.

There are also 2K reviews used for holdout evaluation, which we'll use at the end of the notebook.

In [2]:
base_dir = os.path.join(config.REPO_DIR, 'data', "demo_data")
train_df = pd.read_json(os.path.join(base_dir, "yelp-demo-train-20K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, "yelp-demo-val-2K.json"), lines=True)

texts = train_df['text'].tolist()
labels = train_df['stars'].values
val_texts = val_df['text'].tolist() # These are only used for early stopping of SAE training, so we don't need labels.

**Compute text embeddings for your dataset**

We'll compute text embeddings for the training set and, optionally, a validation set. The validation embeddings are used for SAE evaluation and early stopping during training.

Embeddings will be stored in the `emb_cache` directory (or `os.environ["EMB_CACHE_DIR"]` if you set it) using the `cache_name` parameter, so you only need to compute embeddings once.

You can use OpenAI or a local model.

Local models will run much faster on GPU. The default local model is `nomic-ai/modernbert-embed-base`. You can use any sentence-transformers model, but please read the model's docs; you may need to edit `get_local_embeddings`.
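
The caching behavior described above can be sketched roughly as follows. This is only an illustration of the idea, not the code inside `get_local_embeddings`: the helper `get_cached_embeddings`, the JSON cache file, and the `toy_embed` stand-in model are all hypothetical names introduced for this example.

```python
import hashlib
import json
import os
import tempfile


def toy_embed(text):
    """Hypothetical stand-in for a real embedding model: 4 numbers derived from a hash."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return [b / 255 for b in digest[:4]]


def get_cached_embeddings(texts, cache_dir, cache_name, embed_fn=toy_embed):
    """Return ({text: embedding}, n_computed), embedding only texts missing from the cache file."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, f"{cache_name}.json")
    cache = {}
    if os.path.exists(path):
        with open(path) as f:
            cache = json.load(f)
    # dict.fromkeys drops duplicate texts while keeping their order
    missing = [t for t in dict.fromkeys(texts) if t not in cache]
    for t in missing:
        cache[t] = embed_fn(t)
    if missing:
        with open(path, "w") as f:
            json.dump(cache, f)
    return {t: cache[t] for t in texts}, len(missing)


# The second call finds every text in the cache and computes nothing new.
demo_dir = tempfile.mkdtemp()
demo_texts = ["good food", "slow service"]
emb_first, n_first = get_cached_embeddings(demo_texts, demo_dir, "demo")
emb_second, n_second = get_cached_embeddings(demo_texts, demo_dir, "demo")
```

Keying the cache file on a name such as `cache_name` is what lets a rerun of the notebook skip the expensive embedding step.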

In [None]:
# text2embedding = get_local_embeddings(texts + val_texts, model=EMBEDDER, batch_size=128, cache_name=CACHE_NAME)
# embeddings = np.stack([text2embedding[text] for text in texts])

# train_embeddings = np.stack([text2embedding[text] for text in texts])
# val_embeddings = np.stack([text2embedding[text] for text in val_texts])
# joblib.dump((train_embeddings, val_embeddings), os.path.join(OUT_DIR, f"{CACHE_NAME}_embeddings.pkl"))

In [3]:
train_embeddings, val_embeddings = joblib.load(os.path.join(OUT_DIR, f"{CACHE_NAME}_embeddings.pkl"))

**Train SAE** 

We will train a Matryoshka SAE with $M=512$, $k=16$, and $\text{prefix\_lengths} = [64, 512]$.  

With the Matryoshka loss, the SAE will learn to reconstruct the input from (1) just the first 64 neurons, and (2) all 512 neurons.  
This will produce 64 coarse-grained features, and 448 finer-grained features.  

See the README for more details about selecting SAE hyperparameters. 
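
To make the Matryoshka idea concrete, here is a toy sketch in plain Python. It is an assumption-laden illustration rather than the `train_sae` implementation: the function names, the identity decoder, and the tiny sizes (4 neurons, $k=2$, prefixes `[2, 4]`) are invented for the example, and details such as biases, ReLU, and loss weighting are left out.

```python
def top_k_activations(pre_activations, k):
    """Keep the k largest values and zero out the rest (top-k sparsity)."""
    order = sorted(range(len(pre_activations)),
                   key=lambda i: pre_activations[i], reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(pre_activations)]


def reconstruct(z, decoder, prefix_length):
    """Decode using only the first `prefix_length` neurons; decoder[j] is neuron j's direction."""
    dim = len(decoder[0])
    return [sum(z[j] * decoder[j][d] for j in range(prefix_length)) for d in range(dim)]


def matryoshka_loss(x, z, decoder, prefix_lengths):
    """Sum of squared reconstruction errors, one term per prefix length."""
    total = 0.0
    for p in prefix_lengths:
        x_hat = reconstruct(z, decoder, p)
        total += sum((a - b) ** 2 for a, b in zip(x, x_hat))
    return total


# Toy example: 4 neurons, k=2, prefixes [2, 4], identity decoder.
decoder = [[1.0 if i == j else 0.0 for i in range(4)] for j in range(4)]
z = top_k_activations([0.5, -1.0, 2.0, 0.1], k=2)
loss = matryoshka_loss([0.5, 0.0, 2.0, 0.0], z, decoder, prefix_lengths=[2, 4])
```

In this toy case the full prefix reconstructs the input exactly, so the whole loss comes from the short prefix, which cannot represent the third coordinate. That pressure is what pushes the early neurons toward coarse, broadly useful features.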

In [None]:
# sae = train_sae(embeddings=train_embeddings, val_embeddings=val_embeddings,
#                 M=512, K=16, matryoshka_prefix_lengths=[64, 512], 
#                 checkpoint_dir=checkpoint_dir)
# joblib.dump(sae, os.path.join(checkpoint_dir, "sae.pkl"))

In [4]:
sae = joblib.load(os.path.join(checkpoint_dir, "sae.pkl"))

**Interpret neurons**  

Interpret a random subset of neurons in the SAE to sanity-check that the learned features, and their interpretations, seem reasonable. We generate and print labels for `n_random_neurons` neurons, and we also print out the top-activating texts for each neuron.
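
As a rough picture of what "top-activating texts" means, the sketch below ranks texts by one neuron's activation and reports how often that neuron fires. The helper names and the tiny activation matrix are hypothetical; `interpret_sae` may select and display examples differently.

```python
def top_activating_texts(texts, activations, neuron_idx, n=3):
    """Return the n (text, activation) pairs with the largest activation for one neuron."""
    column = [row[neuron_idx] for row in activations]
    order = sorted(range(len(texts)), key=lambda i: column[i], reverse=True)
    return [(texts[i], column[i]) for i in order[:n]]


def fraction_active(activations, neuron_idx):
    """Share of examples on which the neuron has a nonzero activation."""
    column = [row[neuron_idx] for row in activations]
    return sum(1 for a in column if a != 0) / len(column)


# Toy data: 4 texts, 2 neurons (rows are examples, columns are neurons).
toy_texts = ["a", "b", "c", "d"]
toy_activations = [[0.0, 0.1], [0.9, 0.0], [0.0, 0.0], [0.4, 0.7]]
top_two = top_activating_texts(toy_texts, toy_activations, neuron_idx=0, n=2)
active_share = fraction_active(toy_activations, neuron_idx=0)
```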

In [5]:
# This instruction will be included in the neuron interpretation prompt.
# The below instructions are specific to Yelp, but you can customize this for your task.
# If you don't pass in task-specific instructions, there is a generic instruction (see src/interpret_neurons.py);
# task-specific instructions are optional, but they help produce hypotheses at the desired level of specificity.

# Interpret random neurons
results_examples = interpret_sae(
    texts=texts,
    embeddings=train_embeddings,
    sae=sae,
    n_random_neurons=5,
    print_examples_n=3,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,
)

Activations shape: (20000, 512)



Neuron 153 (0.0% active, from SAE M=512, K=16): The input describes the seating arrangement, table setup, or physical layout of the restaurant, such as mentioning table spacing, table assignment, seating comfort, lounge areas, or crowdedness.

Top activating examples:
1. Firstly, keep in mind that it takes forever to get a reservation here, so plan early (like weeks ahead of time) to get a reservation at a decent time. I really liked the atmosphere - lights are dimmed, lots of straight lines with a casual upscale type of feel. The food was really good - my favorites were the UOVO pizza...mmmm the truffled farm egg was soo good...and the wild mushroom arancini. Their drink selection is also very good with seasonal craft beers rotated through out the year. The only reason I'm giving Barbuzzo 4 stars is based on the service. Firstly, we were told we had to order everything at once since the restaurant was full which made us feel very rushed. Secondly, even though we ordered our drinks at

**Generate hypotheses**

Generate hypotheses which are predictive of the target variable.

The `selection_method` parameter defines how we compute neuron predictiveness (see `src/select_neurons.py` for more details):
- "separation_score": E[target | top-activating examples] - E[target | zero-activating examples]
- "correlation": pearson(neuron activations, target variable)
- "lasso": select N nonzero features with an L1 regularized model

This cell outputs a dataframe with the following columns:
- `neuron_idx`: The index of the neuron in the SAE (if you're using multiple SAEs, this will be a global index across all of them).
- `source_sae`: The SAE that the neuron was selected from.
- `target_{selection_method}`: The predictiveness of the neuron for the target variable, using the selected `selection_method`.
- `interpretation`: The natural language interpretation of the neuron.
- `interp_fidelity_score`: The F1 fidelity score for how well the neuron's interpretation actually corresponds to its activation pattern.
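
The first two selection scores listed above can be sketched for a single neuron as follows. This is a minimal illustration under assumptions, not the code in `src/select_neurons.py`: the function names, the `n_top` parameter, and the toy data are made up, and the lasso variant is omitted because it needs a fitted model.

```python
from math import sqrt


def mean(values):
    return sum(values) / len(values)


def separation_score(activations, targets, n_top):
    """E[target | top-activating examples] - E[target | zero-activating examples]."""
    order = sorted(range(len(activations)), key=lambda i: activations[i], reverse=True)
    top_targets = [targets[i] for i in order[:n_top]]
    zero_targets = [t for a, t in zip(activations, targets) if a == 0]
    return mean(top_targets) - mean(zero_targets)


def pearson(xs, ys):
    """Pearson correlation between neuron activations and the target variable."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sqrt(sum((x - mx) ** 2 for x in xs))
    sy = sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


# Toy neuron that fires mostly on high-star reviews.
toy_acts = [0.0, 0.0, 0.2, 0.9, 1.5, 0.0]
toy_stars = [1, 2, 3, 5, 5, 2]
sep = separation_score(toy_acts, toy_stars, n_top=2)
corr = pearson(toy_acts, toy_stars)
```

Both scores are positive here, which matches the intuition that this neuron marks favorable reviews; a neuron tied to complaints would score negative.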

In [6]:
labels_multitarget = np.vstack([labels, labels, labels]).T  # Duplicate labels for multi-target selection

results = generate_hypotheses(
    texts=texts,
    labels=labels_multitarget,
    embeddings=train_embeddings,
    sae=sae,
    cache_name=CACHE_NAME,
    classification=False,

    # hyperparams
    selection_method = "correlation_multi_target",
    n_selected_neurons=20,
    n_candidate_interpretations=1,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,

    # number of examples used to score each hypothesis; it is 0 here, and the
    # output below has no `interp_fidelity_score` column
    n_scoring_examples=0,
    # n_scoring_examples=100,
    # annotator_model=ANNOTATOR_MODEL,
    # n_workers_annotation=N_WORKERS_ANNOTATION, # Please lower this parameter if you are running into OpenAI API rate limits
)

Embeddings shape: (20000, 768)


Activations shape: (20000, 512)

Step 1: Selecting top 20 predictive neurons

Step 2: Interpreting selected neurons


In [7]:
print("\nMost predictive features of Yelp reviews:")
pd.set_option('display.max_colwidth', None)
display(results.sort_values(by="target_correlation_multi_target", ascending=False))
pd.reset_option('display.max_colwidth')


Most predictive features of Yelp reviews:


| Unnamed: 0 | neuron_idx | source_sae | target_correlation_multi_target | interpretation |
|---|---|---|---|---|
| 4 | 54 | (SAE_0, 512, 16) | 0.351364 | The input describes the restaurant as the reviewer's personal favorite or go-to place, often using phrases like 'my favorite', 'my go to', 'absolutely love', or expressing frequent repeat visits. |
| 5 | 0 | (SAE_0, 512, 16) | 0.241403 | The input describes the restaurant experience as exceptionally memorable or impressive, often using superlative language (e.g., 'amazing', 'incredible', 'out of this world', 'top ten', 'mind blowingly good'), and emphasizes |
| 8 | 43 | (SAE_0, 512, 16) | 0.215558 | The input describes the experience of being able to sample or try multiple items (such as flavors, foods, or drinks) before purchasing or making a choice. |
| 9 | 63 | (SAE_0, 512, 16) | 0.188179 | The input describes an intention or plan to return to the restaurant in the future (e.g., 'will definitely be back', 'I can't wait to come back here', 'will be returning', 'definitely will return again soon', ' |
| 11 | 8 | (SAE_0, 512, 16) | 0.172736 | The input expresses an intention to return to the restaurant, using phrases like 'will be back', 'will return', or 'coming back again' |
| 14 | 59 | (SAE_0, 512, 16) | 0.159896 | The input describes details about a restaurant's wine selection, such as a large or rotating wine list, wines on tap, or wine flights. |
| 15 | 40 | (SAE_0, 512, 16) | 0.155774 | The input describes or recommends specific menu items, often mentioning unique ingredients, flavors, or combinations, and sometimes provides tips on what to order or how to customize the dish. |
| 16 | 10 | (SAE_0, 512, 16) | 0.151753 | The input describes a standout, positive customer service experience, where staff or owners go above and beyond to accommodate needs, such as resolving reservation issues, providing personalized recommendations, adapting to challenges, or giving special attention (e.g., accommodating dietary restrictions |
| 19 | 15 | (SAE_0, 512, 16) | -0.141028 | The input describes the restaurant being closed (either temporarily or permanently), or not open during the advertised or expected hours. |
| 18 | 61 | (SAE_0, 512, 16) | -0.141743 | The input describes a decline in quality or experience at a restaurant compared to previous visits, often attributing the change to new ownership, management, chef, menu changes, or remodels |


In [9]:
results['interpretation']

0     The input describes food as inedible, rotten, ...
1     The input describes negative experiences with ...
2     The input describes food as bland, lacking fla...
3     The input describes receiving an incorrect res...
4     The input describes the restaurant as the revi...
5     The input describes the restaurant experience ...
6     The input describes negative experiences with ...
7     The input describes experiencing or anticipati...
8     The input describes the experience of being ab...
9     The input describes an intention or plan to re...
10    The input describes extremely negative experie...
11    The input expresses an intention to return to ...
12    The input describes restaurant staff behaving ...
13    The input contains sarcasm or irony, often usi...
14    The input describes details about a restaurant...
15    The input describes or recommends specific men...
16    The input describes a standout, positive custo...
17    The input describes the restaurant being o

In [21]:
input_list = results['interpretation'].values.tolist()

In [42]:
input_list

['The input describes food as inedible, rotten, spoiled, or causing illness (e.g., mentions of food poisoning, nausea, vomiting, or food tasting/smelling rotten or rancid)', 'The input describes negative experiences with restaurant staff, specifically mentioning being ignored, receiving rude or unprofessional service, or staff failing to respond to requests or correct mistakes.', 'The input describes food as bland, lacking flavor, underseasoned, or dull.', "The input describes receiving an incorrect restaurant order, such as missing, wrong, or unmodified items, and often details the restaurant's inadequate response to the mistake.", "The input describes the restaurant as the reviewer's personal favorite or go-to place, often using phrases like 'my favorite', 'my go to', 'absolutely love', or expressing frequent repeat visits.", "The input describes the restaurant experience as exceptionally memorable or impressive, often using superlative language (e.g., 'amazing', 'incredible', 'out o

In [53]:
input_list = ['The input mentions long wait times to receive service?', 'Does the input mention rude staff behavior?'] + input_list

In [46]:
import imodelsx.llm

lm = imodelsx.llm.get_llm('gpt-4.1')

In [57]:
import neuro.agent

qs = neuro.agent.revise_invalid_questions_by_rewording(input_list, lm)

cached!


In [58]:
print('\n'.join(qs))

Does the input describe an intention or plan to return to the restaurant in the future (e.g., 'will definitely be back', 'I can't wait to come back here', 'will be returning', 'definitely will return again soon')?
Does the input describe food as inedible, rotten, spoiled, or causing illness (e.g., mentions of food poisoning, nausea, vomiting, or food tasting/smelling rotten or rancid)?
Does the input describe experiencing or anticipating a long wait time to receive food or service at a restaurant, often specifying the duration (e.g., '30 minutes', '45-60 minutes', '51 mins'), and expressing frustration or inconvenience caused by?
Does the input mention rude staff behavior?
Does the input describe the experience of being able to sample or try multiple items (such as flavors, foods, or drinks) before purchasing or making a choice?
Does the input express an intention to return to the restaurant, using phrases like 'will be back', 'will return', or 'coming back again'?
Does the input descr

In [18]:
import imodelsx.llm

In [None]:
# Build the bullet list first: a backslash inside an f-string expression is a SyntaxError before Python 3.12.
bullets = '\n- '.join(input_list)
PROMPT_REWORD = f"""
Rewrite every element in this list to be a question starting with "Does the input" and ending with a question mark.

- {bullets}

Make as few changes as possible to the original text.
Return only a python list and nothing else.""".strip()


In [39]:
ans_list = agent(PROMPT_REWORD, max_completion_tokens=None)

not cached
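
Because the prompt asks the model to return only a Python list, the reply arrives as a string that still has to be parsed and checked. The sketch below shows one cautious way to do that with `ast.literal_eval`. The helpers `parse_question_list` and `is_valid_question` are hypothetical and are not part of `neuro.agent` or `imodelsx`.

```python
import ast


def parse_question_list(reply):
    """Parse an LLM reply that should contain a Python list literal of strings."""
    # Tolerate extra text or code fences around the list by slicing to the brackets.
    start, end = reply.find("["), reply.rfind("]")
    if start == -1 or end < start:
        raise ValueError("no list found in reply")
    items = ast.literal_eval(reply[start:end + 1])
    if not isinstance(items, list) or not all(isinstance(q, str) for q in items):
        raise ValueError("reply is not a list of strings")
    return items


def is_valid_question(question):
    """Check the format requested in the prompt."""
    return question.startswith("Does the input") and question.endswith("?")


# Toy reply with one well-formed and one malformed entry.
toy_reply = (
    "[\n"
    '    "Does the input mention rude staff behavior?",\n'
    '    "The input mentions long wait times to receive service"\n'
    "]"
)
parsed = parse_question_list(toy_reply)
valid_flags = [is_valid_question(q) for q in parsed]
```

`ast.literal_eval` only accepts literals, so it avoids executing arbitrary model output the way `eval` would. Entries flagged as invalid can then be sent back for another rewording pass.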


In [41]:
print(PROMPT_REWORD)

Rewrite every element in this list to be a question starting with "Does the input" and ending with a question mark.

- The input describes food as inedible, rotten, spoiled, or causing illness (e.g., mentions of food poisoning, nausea, vomiting, or food tasting/smelling rotten or rancid)
- The input describes negative experiences with restaurant staff, specifically mentioning being ignored, receiving rude or unprofessional service, or staff failing to respond to requests or correct mistakes.
- The input describes food as bland, lacking flavor, underseasoned, or dull.
- The input describes receiving an incorrect restaurant order, such as missing, wrong, or unmodified items, and often details the restaurant's inadequate response to the mistake.
- The input describes the restaurant as the reviewer's personal favorite or go-to place, often using phrases like 'my favorite', 'my go to', 'absolutely love', or expressing frequent repeat visits.
- The input describes the restaurant experience a

In [40]:
print(ans_list)

[
    "Does the input describe food as inedible, rotten, spoiled, or causing illness (e.g., mentions of food poisoning, nausea, vomiting, or food tasting/smelling rotten or rancid)?",
    "Does the input describe negative experiences with restaurant staff, specifically mentioning being ignored, receiving rude or unprofessional service, or staff failing to respond to requests or correct mistakes?",
    "Does the input describe food as bland, lacking flavor, underseasoned, or dull?",
    "Does the input describe receiving an incorrect restaurant order, such as missing, wrong, or unmodified items, and often detail the restaurant's inadequate response to the mistake?",
    "Does the input describe the restaurant as the reviewer's personal favorite or go-to place, often using phrases like 'my favorite', 'my go to', 'absolutely love', or expressing frequent repeat visits?",
    "Does the input describe the restaurant experience as exceptionally memorable or impressive, often using superlativ
