# Experiment Reproduction

This notebook contains code to reproduce experiments in our paper, [Sparse Autoencoders for Hypothesis Generation](https://arxiv.org/abs/2502.04382).  

Currently, it covers the three real-world datasets we study: Headlines (predicting online engagement from news headlines), Yelp (predicting restaurant ratings from review text), and Congress (predicting party affiliation from Congressional speech excerpts).  

Important notes:
1. Our processed, documented versions of the data can be found on HuggingFace: [https://huggingface.co/datasets/rmovva/HypotheSAEs](https://huggingface.co/datasets/rmovva/HypotheSAEs).
2. There is some run-to-run variance due to stochasticity in neuron interpretation, which uses an LLM (default temperature 0.7). This may explain differences from results in the paper.
3. There are multiple possible strategies when sampling texts to interpret an SAE neuron:
    - The default version of the code (i.e., if you use the quickstart function `generate_hypotheses()`) interprets neurons by using the top-10 activating examples. 
    - In contrast, for the paper we interpreted each neuron from a random sample of 10 examples drawn from the top decile (or top quintile, for Headlines) of positive activations, by passing `sample_percentile_bins` as the sampling function. This strategy produces less specific concepts, which apply to more texts and are therefore more predictive.
    - For most use cases, we expect that more specific concepts are more useful, which is why the default version of the code uses the absolute top-activating examples. However, if you would like less specific concepts, you can vary the sampling strategy by implementing something similar to the code in this notebook.
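
The two sampling strategies described above can be sketched with plain NumPy. This is only an illustration on synthetic activations: `sample_top_k` and `sample_from_top_percentiles` are hypothetical helpers written for this note, not the library's `sample_percentile_bins`, and the toy activation distribution is an assumption.

```python
import numpy as np

def sample_top_k(activations, k):
    """Return the indices of the k largest activations, strongest first."""
    order = np.argsort(-activations, kind="stable")
    return order[:k]

def sample_from_top_percentiles(activations, k, low_pct, high_pct, rng):
    """Randomly pick k indices whose activation falls in a percentile
    range of the positive (non-zero) activations."""
    positive_idx = np.flatnonzero(activations > 0)
    positive_vals = activations[positive_idx]
    lo, hi = np.percentile(positive_vals, [low_pct, high_pct])
    in_bin = positive_idx[(positive_vals >= lo) & (positive_vals <= hi)]
    return rng.choice(in_bin, size=min(k, len(in_bin)), replace=False)

rng = np.random.default_rng(0)
# Toy neuron: 1000 texts, roughly 30% of which activate the neuron.
acts = np.where(rng.random(1000) < 0.3, rng.exponential(1.0, 1000), 0.0)

top_idx = sample_top_k(acts, k=10)
bin_idx = sample_from_top_percentiles(acts, k=10, low_pct=90, high_pct=100, rng=rng)

print("mean activation, top-10 examples:   ", acts[top_idx].mean())
print("mean activation, top-decile sample: ", acts[bin_idx].mean())
```

The top-k sample always has the highest possible mean activation, while the percentile sample mixes in weaker (but still strongly activating) examples, which is what pushes the interpreter toward broader concepts.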


In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd

os.environ['OPENAI_KEY_SAE'] = '...' # Replace with your OpenAI API key, or with another environment variable (e.g. os.environ['OPENAI_API_KEY'])

from hypothesaes.quickstart import train_sae, generate_hypotheses, evaluate_hypotheses

from hypothesaes.embedding import get_openai_embeddings
from hypothesaes.sae import get_multiple_sae_activations
from hypothesaes.select_neurons import select_neurons
from hypothesaes.interpret_neurons import NeuronInterpreter, InterpretConfig, ScoringConfig, SamplingConfig, sample_percentile_bins
from hypothesaes.annotate import annotate_texts_with_concepts
from hypothesaes.evaluation import score_hypotheses

base_dir = "/share/pierson/raj/hypothesis-generation-data/"  # Replace with the directory containing your local copy of the datasets

## Headlines

Note that Headlines requires some special handling because it is a pairwise dataset.  
Each example contains two headlines, A and B, and a binary label: 1 if headline A received more clicks than B, 0 otherwise. A and B were randomly shuffled, so the labels are approximately balanced.

We apply HypotheSAEs as follows:
1. Train SAE on all unique headlines in the train set.
2. Compute $Z_A$, an $N \times M$ activation matrix for all headlines A; similarly compute $Z_B$.
3. Compute $Z_{\Delta} = Z_A - Z_B$, which represents how much more (or less) each feature activates in headline A compared to B.
4. Select 20 neurons using L1 logistic regression on $Z_{\Delta}$.
5. Interpret selected neurons by sampling from the texts originally used to train the SAE.
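
As a small sanity check on step 3, the sketch below builds toy activation matrices (random stand-ins, not real SAE outputs) and shows why the difference representation suits a shuffled pairwise dataset: swapping A and B within a pair negates that row of $Z_{\Delta}$ and, ignoring ties, flips its label. All variable names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs, n_neurons = 6, 4

# Toy stand-ins for the SAE activation matrices of headlines A and B.
Z_A = rng.random((n_pairs, n_neurons))
Z_B = rng.random((n_pairs, n_neurons))
labels = rng.integers(0, 2, size=n_pairs)  # 1 if A received more clicks

Z_delta = Z_A - Z_B

# Swap A and B for a subset of pairs, as random shuffling would.
swap = np.array([True, False, True, False, False, True])
Z_A_swapped = np.where(swap[:, None], Z_B, Z_A)
Z_B_swapped = np.where(swap[:, None], Z_A, Z_B)
labels_swapped = np.where(swap, 1 - labels, labels)

Z_delta_swapped = Z_A_swapped - Z_B_swapped

# Each swapped row of Z_delta is the negation of the original row.
sign = np.where(swap, -1.0, 1.0)[:, None]
print(np.allclose(Z_delta_swapped, sign * Z_delta))
```

Because both the features and the label flip together, a linear classifier on $Z_{\Delta}$ sees the same relationship regardless of which headline was assigned to slot A.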

For evaluation:
1. Compute $H_A$, an $N \times 20$ binary matrix of hypothesis annotations for all headlines A, and $H_B$ similarly.
2. Compute $H_{\Delta} = H_A - H_B$, and regress the label on this difference. Note that $H_{\Delta}$ can take on values in $\{-1, 0, 1\}$, which we account for in `src.evaluation.compute_hypothesis_separation_scores`.
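
To make the three-valued $H_{\Delta}$ concrete, here is a toy example for a single hypothesis. The annotations and labels are made up, and the per-value mean label is only a simple illustrative summary; it is not the separation score computed by the library.

```python
import numpy as np

# Toy binary annotations of one hypothesis for headlines A and B,
# plus the pairwise label (1 if A received more clicks).
H_A = np.array([1, 0, 1, 0, 1, 0, 0, 1])
H_B = np.array([0, 0, 1, 1, 0, 1, 0, 0])
labels = np.array([1, 0, 1, 0, 1, 0, 1, 1])

# +1: only A matches the hypothesis; -1: only B matches; 0: both or neither.
H_delta = H_A - H_B

mean_label_by_delta = {
    value: float(labels[H_delta == value].mean()) for value in (-1, 0, 1)
}
for value, mean_label in mean_label_by_delta.items():
    print(f"H_delta = {value:+d}: n = {(H_delta == value).sum()}, "
          f"mean label = {mean_label:.2f}")
```

If a hypothesis is associated with more clicks, the mean label should rise as $H_{\Delta}$ goes from $-1$ to $+1$.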


In [100]:
DATASET_NAME = "headlines"
EMBEDDER = "text-embedding-3-small"
CACHE_NAME = f"{DATASET_NAME}_{EMBEDDER}"

label_col = "label_pairwise"
train_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "headline_pairs_train.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "headline_pairs_validation_v2.json"), lines=True)

train_labels = train_df[label_col].values

unique_train_texts = np.unique(train_df["headline_A"].tolist() + train_df["headline_B"].tolist()).tolist()
unique_val_texts = np.unique(val_df["headline_A"].tolist() + val_df["headline_B"].tolist()).tolist()

text2embedding = get_openai_embeddings(unique_train_texts + unique_val_texts, model=EMBEDDER, cache_name=CACHE_NAME)

train_embeddings = np.stack([text2embedding[text] for text in unique_train_texts])
val_embeddings = np.stack([text2embedding[text] for text in unique_val_texts])

Loaded 16468 embeddings in 2.8s


In [101]:
checkpoint_dir = f"./checkpoints/{CACHE_NAME}"
sae_256_8 = train_sae(embeddings=train_embeddings, M=256, K=8, checkpoint_dir=checkpoint_dir, val_embeddings=val_embeddings, n_epochs=200, patience=5)
sae_32_4 = train_sae(embeddings=train_embeddings, M=32, K=4, checkpoint_dir=checkpoint_dir, val_embeddings=val_embeddings, n_epochs=200, patience=5)

sae_list = [sae_256_8, sae_32_4]

unique_train_activations, neuron_source_sae_info = get_multiple_sae_activations(sae_list, train_embeddings, return_neuron_source_info=True)

headline_A_train = np.stack([text2embedding[text] for text in train_df["headline_A"].tolist()])
headline_B_train = np.stack([text2embedding[text] for text in train_df["headline_B"].tolist()])

Z_A = get_multiple_sae_activations(sae_list, headline_A_train)
Z_B = get_multiple_sae_activations(sae_list, headline_B_train)
Z_delta = Z_A - Z_B

selected_neurons, scores = select_neurons(
    activations=Z_delta,
    target=train_labels,
    n_select=20,
    method="lasso",
    classification=True,
    verbose=True,
)

Early stopping triggered after 79 epochs
Saved model to ./checkpoints/headlines_text-embedding-3-small/SAE_M=256_K=8.pt


Early stopping triggered after 63 epochs
Saved model to ./checkpoints/headlines_text-embedding-3-small/SAE_M=32_K=4.pt
LASSO iteration   L1 Alpha # Features   Time (s)
----------------------------------------
       0   1.00e-01        287       1.75
       1   3.16e+01        199       0.50
       2   5.62e+02          2       0.21
       3   1.33e+02         74       0.28
       4   2.74e+02         24       0.20
       5   3.92e+02         12       0.18
       6   3.28e+02         18       0.21
       7   3.00e+02         21       0.17
       8   3.13e+02         20       0.19

Found alpha=3.13e+02 yielding exactly 20 features
Total search time: 3.70s
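
The search log above looks like a bisection over the L1 penalty until exactly `n_select` features remain. One plausible way to write such a search is sketched below; it is an assumption for illustration, not necessarily what `select_neurons` does internally. `search_alpha` and `count_surviving` are hypothetical names, and the toy feature counter simply thresholds fixed coefficient magnitudes instead of fitting a model.

```python
import math

def search_alpha(count_features, n_target, alpha_lo, alpha_hi, max_iter=50):
    """Geometric bisection for a penalty that keeps exactly n_target features.

    Assumes count_features(alpha) is non-increasing in alpha, with
    count_features(alpha_lo) >= n_target >= count_features(alpha_hi).
    """
    for _ in range(max_iter):
        alpha = math.sqrt(alpha_lo * alpha_hi)  # midpoint in log space
        n_features = count_features(alpha)
        if n_features == n_target:
            return alpha
        if n_features > n_target:
            alpha_lo = alpha  # too many features: penalize more
        else:
            alpha_hi = alpha  # too few features: penalize less
    return None  # no alpha in the bracket gave exactly n_target features

# Toy stand-in for "fit an L1 model and count non-zero coefficients":
# a feature survives if its (fixed) magnitude exceeds the penalty.
magnitudes = [0.01 * 1.2 ** i for i in range(50)]

def count_surviving(alpha):
    return sum(m > alpha for m in magnitudes)

alpha = search_alpha(count_surviving, n_target=20, alpha_lo=1e-3, alpha_hi=1e3)
print(f"alpha = {alpha:.3g}, features kept = {count_surviving(alpha)}")
```

Bisecting in log space is a natural choice here because useful penalty values can span several orders of magnitude, as the logs for the different datasets suggest.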


In [102]:
interpreter = NeuronInterpreter(
    interpreter_model="gpt-4o",
    annotator_model="gpt-4o-mini",
    n_workers_annotation=50,
    cache_name=CACHE_NAME,
)

interpretation_sampling_config = SamplingConfig(
    function=sample_percentile_bins,
    n_examples=20,
    sampling_kwargs={
        "high_percentile": (80, 100),
        "low_percentile": None, # Randomly samples from zero-activating examples
    },
)

interpret_config = InterpretConfig(
    sampling=interpretation_sampling_config,
    n_candidates=5,
    task_specific_instructions="""All of the texts are digital news headlines.
Features should describe specific characteristics in the headline texts. Example features may include:
- "explicitly mentions climate change or another environmental issue"
- "asks the reader a rhetorical question\"""",
)

interpretations = interpreter.interpret_neurons(
    texts=unique_train_texts,
    activations=unique_train_activations,
    neuron_indices=selected_neurons,
    config=interpret_config,
)

In [103]:
scoring_config = ScoringConfig(
    n_examples=200,
    sampling_function=sample_percentile_bins,
    sampling_kwargs={
        "high_percentile": (80, 100),
        "low_percentile": None, # Randomly samples from zero-activating examples
    },
)

all_metrics = interpreter.score_interpretations(
    texts=unique_train_texts,
    activations=unique_train_activations,
    interpretations=interpretations,
    config=scoring_config,
)

results = []
for neuron_idx, fidelity_scores in all_metrics.items():
    best_interp = max(interpretations[neuron_idx], key=lambda interp: fidelity_scores[interp]['f1'])
    results.append({
        'neuron_idx': neuron_idx,
        'source_sae': neuron_source_sae_info[neuron_idx],
        'target_lasso_coef': scores[selected_neurons.index(neuron_idx)],
        'interpretation': best_interp,
        'f1_fidelity_score': fidelity_scores[best_interp]['f1'],
    })

hypothesis_df = pd.DataFrame(results)
pd.set_option('display.max_colwidth', None)
display(hypothesis_df.sort_values(by='target_lasso_coef', ascending=False))
pd.reset_option('display.max_colwidth')

Found 55 cached items; annotating 19835 uncached items


Scoring neuron interpretation fidelity (20 neurons; 5 candidate interps per neuron; 200 examples to score each…

index,neuron_idx,source_sae,target_lasso_coef,interpretation,f1_fidelity_score
1,269,"(32, 4)",0.099923,mentions a video or visual media explicitly in the headline,0.891
4,279,"(32, 4)",0.056853,explicitly implies the revelation of a hidden truth or secret,0.84426
5,163,"(256, 8)",0.055063,"contains words or phrases that describe something as 'awful', 'horrible', 'dangerous', 'painful', or similarly negative adjectives",0.874743
6,126,"(256, 8)",0.054179,"contains phrases that challenge the reader's assumptions or expectations (e.g., 'you might not think', 'wait till you see', 'you probably don't think', 'might change the way you look')",0.6912
7,6,"(256, 8)",0.053329,mentions a personal or emotional event followed by an unexpected or dramatic twist,0.903674
8,268,"(32, 4)",0.044655,mentions interpersonal conflict or negative interactions between individuals,0.825663
11,283,"(32, 4)",0.031996,"mentions a specific male individual using 'he', 'his', or 'him'",0.826941
14,281,"(32, 4)",0.017388,"focuses on a woman or uses female pronouns (e.g., she, her)",0.989899
15,259,"(32, 4)",0.011012,"mentions body image, appearance, or how people view themselves",0.642207
16,277,"(32, 4)",0.010464,"uses emotional language to describe personal reactions or feelings (e.g., crying, shouting, feeling sick, feeling strength)",0.883616


In [104]:
holdout_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "headline_pairs_holdout_v2.json"), lines=True)

headline_A_holdout = holdout_df["headline_A"].tolist()
headline_B_holdout = holdout_df["headline_B"].tolist()
labels_holdout = holdout_df[label_col].values

hypotheses = hypothesis_df["interpretation"].tolist()

headline_A_annotations = annotate_texts_with_concepts(
    texts=headline_A_holdout,
    concepts=hypotheses,
    cache_name=CACHE_NAME,
)

headline_B_annotations = annotate_texts_with_concepts(
    texts=headline_B_holdout,
    concepts=hypotheses,
    cache_name=CACHE_NAME,
)

delta_annots = {key : headline_A_annotations[key] - headline_B_annotations[key] for key in headline_A_annotations.keys()}

metrics, evaluation_df = score_hypotheses(
    hypothesis_annotations=delta_annots,
    y_true=labels_holdout,
    classification=True
)

pd.set_option('display.max_colwidth', None)
display(evaluation_df.sort_values(by='separation_score', ascending=False).round(4))
pd.reset_option('display.max_colwidth')

print(f"AUC: {metrics['auroc']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Found 0 cached items; annotating 87140 uncached items


Found 5040 cached items; annotating 82100 uncached items


Optimization terminated successfully.
         Current function value: 0.631850
         Iterations 5


index,hypothesis,separation_score,separation_pval,regression_coef,regression_pval,feature_prevalence
7,mentions a personal or emotional event followed by an unexpected or dramatic twist,0.1812,0.0,0.4397,0.0,0.2031
4,explicitly implies the revelation of a hidden truth or secret,0.1654,0.0,0.5535,0.0,0.2931
18,"mentions a specific duration of time (e.g., seconds, minutes)",0.1449,0.0,0.571,0.0001,0.0521
8,mentions interpersonal conflict or negative interactions between individuals,0.1316,0.0,0.3016,0.0002,0.1923
14,"focuses on a woman or uses female pronouns (e.g., she, her)",0.1197,0.0,0.3201,0.001,0.1187
11,"mentions a specific male individual using 'he', 'his', or 'him'",0.1047,0.0,0.1617,0.1673,0.0842
1,mentions a video or visual media explicitly in the headline,0.1034,0.0,0.4749,0.0,0.1047
5,"contains words or phrases that describe something as 'awful', 'horrible', 'dangerous', 'painful', or similarly negative adjectives",0.0922,0.0,0.1365,0.0867,0.1958
6,"contains phrases that challenge the reader's assumptions or expectations (e.g., 'you might not think', 'wait till you see', 'you probably don't think', 'might change the way you look')",0.0859,0.0,0.3351,0.0,0.4044
15,"mentions body image, appearance, or how people view themselves",0.0738,0.0,0.1708,0.2466,0.0503


AUC: 0.695
Significant hypotheses: 14/20 (p < 5.000e-03)


## Yelp

In [90]:
DATASET_NAME = "yelp"
EMBEDDER = "text-embedding-3-small"
CACHE_NAME = f"{DATASET_NAME}_{EMBEDDER}"

text_col = "text"
label_col = "stars"
train_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "train-200K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "val-10K.json"), lines=True)

train_texts = train_df[text_col].tolist()
val_texts = val_df[text_col].tolist()

train_labels = train_df[label_col].values

text2embedding = get_openai_embeddings(train_texts + val_texts, model=EMBEDDER, cache_name=CACHE_NAME)

train_embeddings = np.stack([text2embedding[text] for text in train_texts])
val_embeddings = np.stack([text2embedding[text] for text in val_texts])

Loaded 209977 embeddings in 18.0s


In [91]:
checkpoint_dir = f"./checkpoints/{CACHE_NAME}"
sae_1024_32 = train_sae(embeddings=train_embeddings, M=1024, K=32, checkpoint_dir=checkpoint_dir, val_embeddings=val_embeddings)

activations = sae_1024_32.get_activations(train_embeddings)

selected_neurons, scores = select_neurons(
    activations=activations,
    target=train_labels,
    n_select=20,
    method="lasso",
    classification=False,
    verbose=True,
)

Loaded model from ./checkpoints/yelp_text-embedding-3-small/SAE_M=1024_K=32.pt
LASSO iteration   L1 Alpha # Features   Time (s)
----------------------------------------
       0   1.00e-01          5       3.81
       1   3.16e-04        913      10.00
       2   5.62e-03        197       4.94
       3   2.37e-02         45       5.64
       4   4.87e-02         12       7.13
       5   3.40e-02         26       8.33
       6   4.07e-02         19       8.10
       7   3.72e-02         21       8.11
       8   3.89e-02         19       8.11
       9   3.80e-02         19       8.15
      10   3.76e-02         19       8.14
      11   3.74e-02         21       8.17
      12   3.75e-02         20       8.21

Found alpha=3.75e-02 yielding exactly 20 features
Total search time: 96.93s


In [92]:
interpreter = NeuronInterpreter(
    interpreter_model="gpt-4o",
    annotator_model="gpt-4o-mini",
    n_workers_annotation=50,
    cache_name=CACHE_NAME,
)

interpretation_sampling_config = SamplingConfig(
    function=sample_percentile_bins,
    n_examples=20,
    sampling_kwargs={
        "high_percentile": (90, 100),
        "low_percentile": None, # Randomly samples from zero-activating examples
    },
)

interpret_config = InterpretConfig(
    sampling=interpretation_sampling_config,
    n_candidates=3,
    task_specific_instructions = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""
)

interpretations = interpreter.interpret_neurons(
    texts=train_texts,
    activations=activations,
    neuron_indices=selected_neurons,
    config=interpret_config,
)

scoring_config = ScoringConfig(
    n_examples=200,
    sampling_function=sample_percentile_bins,
    sampling_kwargs={
        "high_percentile": (90, 100),
        "low_percentile": None, # Randomly samples from zero-activating examples
    },
)

all_metrics = interpreter.score_interpretations(
    texts=train_texts,
    activations=activations,
    interpretations=interpretations,
    config=scoring_config,
)

Found 2 cached items; annotating 11998 uncached items


Scoring neuron interpretation fidelity (20 neurons; 3 candidate interps per neuron; 200 examples to score each…

In [93]:
results = []
for neuron_idx, fidelity_scores in all_metrics.items():
    best_interp = max(interpretations[neuron_idx], key=lambda interp: fidelity_scores[interp]['f1'])
    results.append({
        'neuron_idx': neuron_idx,
        'target_lasso_coef': scores[selected_neurons.index(neuron_idx)],
        'interpretation': best_interp,
        'f1_fidelity_score': fidelity_scores[best_interp]['f1'],
    })

interpretation_df = pd.DataFrame(results)
pd.set_option('display.max_colwidth', None)
display(interpretation_df.sort_values(by='target_lasso_coef', ascending=False))
pd.reset_option('display.max_colwidth')

index,neuron_idx,target_lasso_coef,interpretation,f1_fidelity_score
2,767,0.077037,"describes food as exceptionally flavorful or perfectly prepared, using superlatives like 'amazing', 'perfect', 'incredible', or 'unbelievable'",0.787975
7,313,0.023009,uses enthusiastic language to describe the taste or quality of specific dishes,0.667857
10,4,0.009572,explicitly describes something as 'the best' or 'best [specific item]' in a positive context,0.963731
14,554,0.005992,"expresses extreme enthusiasm or excitement about the food, using repeated words or exaggerated phrases (e.g., 'love love love', 'to die for', 'drool worthy')",0.684088
19,262,5.3e-05,"uses superlative language to describe the restaurant as the best in a specific category (e.g., 'best lobster roll', 'best bakery', 'best cheesesteaks')",0.929032
18,21,-0.002484,"mentions dissatisfaction with specific aspects of a venue or service, often with detailed complaints about cleanliness, maintenance, or customer service",0.572773
17,840,-0.00303,"describes safety concerns or warnings, such as issues with spicy food, unsafe bathrooms, or drink tampering",0.657534
16,763,-0.003443,"mentions negative or critical commentary about the restaurant's atmosphere, food quality, or service in a sarcastic or exaggerated tone",0.643101
15,864,-0.003551,"mentions issues or complaints about restaurant policies, services, or operational practices",0.587288
13,941,-0.00659,"mentions experiences of food causing illness, such as food poisoning, stomach pain, or vomiting",0.979592


In [94]:
np.random.seed(42)
holdout_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "holdout-50K.json"), lines=True)
holdout_sample_size = 10000
holdout_indices = np.random.choice(len(holdout_df), size=holdout_sample_size, replace=False)
holdout_sample = holdout_df.iloc[holdout_indices]

holdout_texts = holdout_sample[text_col].tolist()
holdout_labels = holdout_sample[label_col].values

metrics, hypothesis_df = evaluate_hypotheses(
    hypotheses_df=interpretation_df,
    texts=holdout_texts,
    labels=holdout_labels,
    cache_name=CACHE_NAME,
    classification=False,
    n_workers_annotation=50,
)

pd.set_option('display.max_colwidth', None)
display(hypothesis_df.round(3).sort_values(by='separation_score', ascending=False))
pd.reset_option('display.max_colwidth')

print(f"R^2: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 0 cached items; annotating 200000 uncached items


Step 2: Computing predictiveness of hypothesis annotations


index,hypothesis,separation_score,separation_pval,regression_coef,regression_pval,feature_prevalence
7,uses enthusiastic language to describe the taste or quality of specific dishes,2.012,0.0,0.49,0.0,0.611
14,"expresses extreme enthusiasm or excitement about the food, using repeated words or exaggerated phrases (e.g., 'love love love', 'to die for', 'drool worthy')",1.494,0.0,0.215,0.0,0.35
2,"describes food as exceptionally flavorful or perfectly prepared, using superlatives like 'amazing', 'perfect', 'incredible', or 'unbelievable'",1.409,0.0,0.043,0.036,0.368
10,explicitly describes something as 'the best' or 'best [specific item]' in a positive context,1.114,0.0,0.106,0.0,0.177
19,"uses superlative language to describe the restaurant as the best in a specific category (e.g., 'best lobster roll', 'best bakery', 'best cheesesteaks')",1.072,0.0,0.002,0.96,0.122
3,mentions affordability or reasonable pricing of food or drinks,0.245,0.0,0.032,0.085,0.169
8,mentions high or unexpected prices for food or drinks,-0.971,0.0,-0.082,0.0,0.11
9,"mentions long wait times for seating, food, or service",-1.28,0.0,0.06,0.008,0.132
17,"describes safety concerns or warnings, such as issues with spicy food, unsafe bathrooms, or drink tampering",-1.861,0.0,-0.142,0.001,0.04
1,mentions that the food is mediocre or has flaws in preparation or freshness,-1.925,0.0,-0.167,0.0,0.284


R^2: 0.776
Significant hypotheses: 16/20 (p < 5.000e-03)


### Top sampling (instead of percentile sampling)

In [55]:
task_specific_instructions = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

hypotheses_df = generate_hypotheses(
    texts=train_texts,
    labels=train_labels,
    embeddings=train_embeddings,
    sae=sae_1024_32,
    cache_name=CACHE_NAME,
    classification=False,
    selection_method="lasso",
    n_selected_neurons=20,
    interpreter_model="gpt-4o",
    annotator_model="gpt-4o-mini",
    n_examples_for_interpretation=20,
    max_words_per_example=128,
    interpret_temperature=0.7,
    max_interpretation_tokens=100,
    n_candidate_interpretations=3,
    n_scoring_examples=200,
    scoring_metric="f1",
    n_workers_interpretation=10,
    n_workers_annotation=50, 
    task_specific_instructions=task_specific_instructions,
)

pd.set_option('display.max_colwidth', None)
display(hypotheses_df.sort_values(by='target_lasso', ascending=False))
pd.reset_option('display.max_colwidth')


Embeddings shape: (200000, 1536)
Activations shape (from 1 SAEs): (200000, 1024)

Step 1: Selecting top 20 predictive neurons
LASSO iteration   L1 Alpha # Features   Time (s)
----------------------------------------
       0   1.00e-01          5       3.78
       1   3.16e-04        913       9.93
       2   5.62e-03        197       4.89
       3   2.37e-02         45       4.06
       4   4.87e-02         12       3.81
       5   3.40e-02         26       3.92
       6   4.07e-02         19       3.90
       7   3.72e-02         21       3.89
       8   3.89e-02         19       3.89
       9   3.80e-02         19       3.89
      10   3.76e-02         19       3.90
      11   3.74e-02         21       3.91
      12   3.75e-02         20       3.92

Found alpha=3.75e-02 yielding exactly 20 features
Total search time: 57.70s

Step 2: Interpreting selected neurons



Step 3: Scoring Interpretations
Found 0 cached items; annotating 12000 uncached items


Scoring neuron interpretation fidelity (20 neurons; 3 candidate interps per neuron; 200 examples to score each…


index,neuron_idx,source_sae,target_lasso,interpretation,f1_fidelity_score
2,767,"(1024, 32)",0.077037,expresses awe or admiration for the chef's artistry and creativity in food preparation,0.904751
7,313,"(1024, 32)",0.023009,explicitly praises both the food and the service in a positive manner,0.722207
10,4,"(1024, 32)",0.009572,"uses the phrase 'the best' to describe food, drinks, or the overall experience",0.93369
14,554,"(1024, 32)",0.005992,"uses exaggerated enthusiasm with repeated letters, punctuation, or capitalization",0.874743
19,262,"(1024, 32)",5.3e-05,"repeatedly emphasizes that the restaurant or dish is the 'best' in a specific category or area, often using phrases like 'hands down' or 'without question'",0.984772
18,21,"(1024, 32)",-0.002484,"mentions experiences or reviews related to industries outside of food, such as car services, museums, or tours",0.496241
17,840,"(1024, 32)",-0.00303,mentions safety concerns or hazards related to the restaurant experience,0.939574
16,763,"(1024, 32)",-0.003443,"describes a place or food as being 'exactly what it is' or 'what you expect', often emphasizing mediocrity or lack of surprise",0.647887
15,864,"(1024, 32)",-0.003551,"mentions issues or inconsistencies with restaurant hours, location, or availability of service",0.59084
13,941,"(1024, 32)",-0.00659,describes experiencing food poisoning or severe illness after eating at the restaurant,0.994975


In [58]:
np.random.seed(42)
holdout_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "holdout-50K.json"), lines=True)
holdout_sample_size = 10000
holdout_indices = np.random.choice(len(holdout_df), size=holdout_sample_size, replace=False)
holdout_sample = holdout_df.iloc[holdout_indices]

holdout_texts = holdout_sample[text_col].tolist()
holdout_labels = holdout_sample[label_col].values

metrics, hypothesis_df = evaluate_hypotheses(
    hypotheses_df=hypotheses_df,
    texts=holdout_texts,
    labels=holdout_labels,
    cache_name=CACHE_NAME,
    classification=False,
    n_workers_annotation=50,
)

pd.set_option('display.max_colwidth', None)
display(hypothesis_df.round(3).sort_values(by='separation_score', ascending=False))
pd.reset_option('display.max_colwidth')

print(f"R^2: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 200000 cached items; annotating 0 uncached items
Step 2: Computing predictiveness of hypothesis annotations


index,hypothesis,separation_score,separation_pval,regression_coef,regression_pval,feature_prevalence
7,explicitly praises both the food and the service in a positive manner,1.862,0.0,0.702,0.0,0.526
2,expresses awe or admiration for the chef's artistry and creativity in food preparation,1.292,0.0,0.221,0.0,0.246
19,"repeatedly emphasizes that the restaurant or dish is the 'best' in a specific category or area, often using phrases like 'hands down' or 'without question'",1.107,0.0,0.164,0.0,0.1
10,"uses the phrase 'the best' to describe food, drinks, or the overall experience",1.056,0.0,0.055,0.066,0.166
14,"uses exaggerated enthusiasm with repeated letters, punctuation, or capitalization",0.989,0.0,0.16,0.0,0.235
3,mentions quick or convenient service/food options,0.966,0.0,0.107,0.0,0.246
18,"mentions experiences or reviews related to industries outside of food, such as car services, museums, or tours",-0.052,0.821,0.36,0.005,0.004
1,"expresses lukewarm or mixed feelings about the overall experience, often mentioning that it is 'okay', 'solid', or 'good but not great'",-0.805,0.0,0.021,0.316,0.262
16,"describes a place or food as being 'exactly what it is' or 'what you expect', often emphasizing mediocrity or lack of surprise",-1.032,0.0,-0.429,0.0,0.063
9,mentions long wait times for seating or food service,-1.339,0.0,0.077,0.021,0.119


R^2: 0.706
Significant hypotheses: 16/20 (p < 5.000e-03)


## Congress

In [67]:
DATASET_NAME = "congress"
EMBEDDER = "text-embedding-3-small"
CACHE_NAME = f"{DATASET_NAME}_{EMBEDDER}"

text_col = "speech_text"
label_col = "republican"
train_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "train-10sentence-114K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "val-10sentence-16K.json"), lines=True)

train_texts = train_df[text_col].tolist()
val_texts = val_df[text_col].tolist()

train_labels = train_df[label_col].values

text2embedding = get_openai_embeddings(train_texts + val_texts, model=EMBEDDER, cache_name=CACHE_NAME)

train_embeddings = np.stack([text2embedding[text] for text in train_texts])
val_embeddings = np.stack([text2embedding[text] for text in val_texts])

Loaded 130765 embeddings in 13.9s


In [68]:
checkpoint_dir = f"./checkpoints/{CACHE_NAME}"
sae_4096_32 = train_sae(embeddings=train_embeddings, M=4096, K=32, checkpoint_dir=checkpoint_dir, val_embeddings=val_embeddings)

activations = sae_4096_32.get_activations(train_embeddings)

selected_neurons, scores = select_neurons(
    activations=activations,
    target=train_labels,
    n_select=20,
    method="lasso",
    classification=True,
    verbose=True,
)

Loaded model from ./checkpoints/congress_text-embedding-3-small/SAE_M=4096_K=32.pt
LASSO iteration   L1 Alpha # Features   Time (s)
----------------------------------------
       0   1.00e-01       4093     134.15
       1   3.16e+01       3730      65.72
       2   5.62e+02        691      37.53
       3   2.37e+03         30      28.27
       4   4.87e+03          4      23.47
       5   3.40e+03         16      27.70
       6   2.84e+03         21      28.88
       7   3.11e+03         18      28.89
       8   2.97e+03         20      28.75

Found alpha=2.97e+03 yielding exactly 20 features
Total search time: 403.37s


In [70]:
interpreter = NeuronInterpreter(
    interpreter_model="gpt-4o",
    annotator_model="gpt-4o-mini",
    n_workers_annotation=50,
    cache_name=CACHE_NAME,
)

interpretation_sampling_config = SamplingConfig(
    function=sample_percentile_bins,
    n_examples=20,
    sampling_kwargs={
        "high_percentile": (90, 100),
        "low_percentile": None, # Randomly samples from zero-activating examples
    },
)

interpret_config = InterpretConfig(
    sampling=interpretation_sampling_config,
    n_candidates=3,
    task_specific_instructions = """All of the texts are excerpts of speeches from the US Congress.
Example features may include:
- "mentions that the economy is strong"
- "discusses environmental issues or environmental policy\""""
)

interpretations = interpreter.interpret_neurons(
    texts=train_texts,
    activations=activations,
    neuron_indices=selected_neurons,
    config=interpret_config,
)

scoring_config = ScoringConfig(
    n_examples=200,
    sampling_function=sample_percentile_bins,
    sampling_kwargs={
        "high_percentile": (90, 100),
        "low_percentile": None, # Randomly samples from zero-activating examples
    },
)

all_metrics = interpreter.score_interpretations(
    texts=train_texts,
    activations=activations,
    interpretations=interpretations,
    config=scoring_config,
)

Found 0 cached items; annotating 11970 uncached items


Scoring neuron interpretation fidelity (20 neurons; 3 candidate interps per neuron; 200 examples to score each…

In [71]:
# For each neuron, keep the candidate interpretation with the highest F1 fidelity score
results = []
for neuron_idx, fidelity_scores in all_metrics.items():
    best_interp = max(interpretations[neuron_idx], key=lambda interp: fidelity_scores[interp]['f1'])
    results.append({
        'neuron_idx': neuron_idx,
        # `scores` is aligned with `selected_neurons`, so look up by position
        'target_lasso_coef': scores[selected_neurons.index(neuron_idx)],
        'interpretation': best_interp,
        'f1_fidelity_score': fidelity_scores[best_interp]['f1'],
    })

interpretation_df = pd.DataFrame(results)
pd.set_option('display.max_colwidth', None)
display(interpretation_df.sort_values(by='target_lasso_coef', ascending=False))
pd.reset_option('display.max_colwidth')

|  | neuron_idx | target_lasso_coef | interpretation | f1_fidelity_score |
|---|---|---|---|---|
| 3 | 1369 | 0.060833 | uses the phrase 'I ask unanimous consent' | 0.984772 |
| 5 | 3380 | 0.04925 | describes the Senate schedule or legislative agenda for the day or week | 0.974359 |
| 8 | 3394 | 0.037155 | mentions positive economic indicators or achievements | 0.448594 |
| 12 | 1296 | 0.015483 | discusses illegal immigration or related issues | 0.949579 |
| 14 | 2102 | 0.011286 | requests unanimous consent for a committee or subcommittee to meet on a specific date and time | 0.93617 |
| 15 | 681 | 0.009446 | requests unanimous consent for Senate committees or subcommittees to meet during the session of the Senate to receive testimony or conduct hearings | 1.0 |
| 17 | 1351 | 0.004072 | mentions bills, resolutions, or amendments with specific section or clause numbers | 0.736561 |
| 18 | 1268 | 0.000823 | discusses economic growth, including rates, impacts, or related metrics | 0.878644 |
| 19 | 1314 | 0.000232 | discusses government spending, deficits, or fiscal responsibility | 0.959896 |
| 16 | 1951 | -0.009246 | mentions poverty or issues related to poverty | 0.959583 |
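
The `f1_fidelity_score` column reports how well each natural-language interpretation tracks its neuron. As a minimal sketch, assuming fidelity is the F1 score between the annotator model's binary judgments and a binary "neuron active" label (for example, top-decile versus zero-activating examples), the computation could look as follows. The helper `f1_fidelity` is hypothetical, and the library's exact definition may differ.

```python
import numpy as np


def f1_fidelity(neuron_active, annotations):
    """Hypothetical helper: F1 between binary neuron labels (treated as ground
    truth) and binary annotations of whether the interpretation applies."""
    neuron_active = np.asarray(neuron_active, dtype=bool)
    annotations = np.asarray(annotations, dtype=bool)
    true_pos = int(np.sum(neuron_active & annotations))
    false_pos = int(np.sum(~neuron_active & annotations))
    false_neg = int(np.sum(neuron_active & ~annotations))
    denominator = 2 * true_pos + false_pos + false_neg
    # Return 0.0 when there are no positives at all, to avoid dividing by zero.
    return 2 * true_pos / denominator if denominator else 0.0


# Toy example: the neuron fires on the first two texts, but the annotator
# says the interpretation applies only to the first one.
toy_score = f1_fidelity(neuron_active=[1, 1, 0, 0], annotations=[1, 0, 0, 0])
```

Under this reading, a score near 1.0 means the interpretation matches almost exactly the texts on which the neuron fires, while a lower score such as the 0.448594 for neuron 3394 suggests the description fits the neuron's activations less closely.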


In [89]:
np.random.seed(42)
holdout_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "holdout-10sentence-onechunk-12K.json"), lines=True)
holdout_texts = holdout_df[text_col].tolist()
holdout_labels = holdout_df[label_col].values

metrics, hypothesis_df = evaluate_hypotheses(
    hypotheses_df=interpretation_df,
    texts=holdout_texts,
    labels=holdout_labels,
    cache_name=CACHE_NAME,
    classification=True,
    n_workers_annotation=50,
)

pd.set_option('display.max_colwidth', None)
display(hypothesis_df.round(3).sort_values(by='separation_score', ascending=False))
pd.reset_option('display.max_colwidth')

print(f"AUC: {metrics['auroc']:.3f}")
print(f"Pseudo R^2: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 251600 cached items; annotating 0 uncached items
Step 2: Computing predictiveness of hypothesis annotations
Optimization terminated successfully.
         Current function value: 0.611921
         Iterations 7


|  | hypothesis | separation_score | separation_pval | regression_coef | regression_pval | feature_prevalence |
|---|---|---|---|---|---|---|
| 14 | requests unanimous consent for a committee or subcommittee to meet on a specific date and time | 0.402 | 0.0 | 0.861 | 0.0 | 0.031 |
| 15 | requests unanimous consent for Senate committees or subcommittees to meet during the session of the Senate to receive testimony or conduct hearings | 0.335 | 0.0 | -0.077 | 0.677 | 0.033 |
| 5 | describes the Senate schedule or legislative agenda for the day or week | 0.3 | 0.0 | 0.648 | 0.0 | 0.073 |
| 3 | uses the phrase 'I ask unanimous consent' | 0.296 | 0.0 | 0.463 | 0.0 | 0.065 |
| 17 | mentions bills, resolutions, or amendments with specific section or clause numbers | 0.159 | 0.0 | 0.322 | 0.0 | 0.057 |
| 8 | mentions positive economic indicators or achievements | 0.129 | 0.0 | 0.277 | 0.002 | 0.056 |
| 12 | discusses illegal immigration or related issues | 0.127 | 0.0 | 0.877 | 0.0 | 0.035 |
| 18 | discusses economic growth, including rates, impacts, or related metrics | -0.074 | 0.0 | 0.232 | 0.01 | 0.074 |
| 19 | discusses government spending, deficits, or fiscal responsibility | -0.124 | 0.0 | 0.355 | 0.0 | 0.179 |
| 6 | discusses budget cuts or reductions in funding | -0.174 | 0.0 | -0.0 | 0.998 | 0.08 |


AUC: 0.703
Pseudo R^2: 0.116
Significant hypotheses: 15/20 (p < 5.000e-03)
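
To make the evaluation columns more concrete, the sketch below shows one plausible reading of two of them. It assumes `separation_score` is the difference in mean label between texts annotated as having the feature and texts without it, and that the reported significance threshold is a Bonferroni correction of a family-wise level of 0.1 (which matches the printed 5.000e-03 for 20 hypotheses). Both helpers are hypothetical and are not taken from the `hypothesaes` source.

```python
import numpy as np


def separation_score(annotations, labels):
    """Hypothetical helper: mean label when the feature is present minus
    mean label when it is absent (assumed definition)."""
    annotations = np.asarray(annotations, dtype=bool)
    labels = np.asarray(labels, dtype=float)
    if annotations.all() or not annotations.any():
        return float("nan")  # undefined when one of the two groups is empty
    return float(labels[annotations].mean() - labels[~annotations].mean())


def bonferroni_threshold(n_hypotheses, alpha=0.1):
    """Per-hypothesis p-value threshold under a Bonferroni correction.
    alpha=0.1 is an assumption inferred from the printed threshold."""
    return alpha / n_hypotheses


# Toy example: the feature is present in the first two texts.
toy_separation = separation_score(annotations=[1, 1, 0, 0], labels=[1, 1, 0, 1])
toy_threshold = bonferroni_threshold(n_hypotheses=20)
```

With binary labels, a positive score under this definition means the feature is more common in the class coded as 1, and a negative score means it is more common in the class coded as 0.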


### Top sampling (instead of percentile sampling)

In [None]:
task_specific_instructions = """All of the texts are excerpts of speeches from the US Congress.
Example features may include:
- "mentions that the economy is strong"
- "discusses environmental issues or environmental policy\""""

hypotheses_df = generate_hypotheses(
    texts=train_texts,
    labels=train_labels,
    embeddings=train_embeddings,
    sae=sae_4096_32,
    cache_name=CACHE_NAME,
    classification=True,
    selection_method="lasso",
    n_selected_neurons=20,
    interpreter_model="gpt-4o",
    annotator_model="gpt-4o-mini",
    n_examples_for_interpretation=20,
    max_words_per_example=128,
    interpret_temperature=0.7,
    max_interpretation_tokens=100,
    n_candidate_interpretations=3,
    n_scoring_examples=200,
    scoring_metric="f1",
    n_workers_interpretation=10,
    n_workers_annotation=50, 
    task_specific_instructions=task_specific_instructions,
)

pd.set_option('display.max_colwidth', None)
display(hypotheses_df.sort_values(by='target_lasso', ascending=False))
pd.reset_option('display.max_colwidth')


In [15]:
np.random.seed(42)
holdout_df = pd.read_json(os.path.join(base_dir, DATASET_NAME, "holdout-10sentence-onechunk-12K.json"), lines=True)
# holdout_sample_size = len(holdout_df)
# holdout_indices = np.random.choice(len(holdout_df), size=holdout_sample_size, replace=False)
# holdout_sample = holdout_df.iloc[holdout_indices]
holdout_texts = holdout_df[text_col].tolist()
holdout_labels = holdout_df[label_col].values

metrics, hypothesis_df = evaluate_hypotheses(
    hypotheses_df=hypotheses_df,
    texts=holdout_texts,
    labels=holdout_labels,
    cache_name=CACHE_NAME,
    classification=True,
    n_workers_annotation=50,
)

pd.set_option('display.max_colwidth', None)
display(hypothesis_df.round(3).sort_values(by='separation_score', ascending=False))
pd.reset_option('display.max_colwidth')

print(f"AUC: {metrics['auroc']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 10000 cached items; annotating 0 uncached items
Step 2: Computing predictiveness of hypothesis annotations
         Current function value: 0.575682
         Iterations: 35




|  | hypothesis | separation_score | separation_pval | regression_coef | regression_pval | feature_prevalence |
|---|---|---|---|---|---|---|
| 15 | mentions specific times or hours of the day | 0.336 | 0.002 | 1.517 | 0.077 | 0.042 |
| 4 | contains the phrase 'I ask unanimous consent' | 0.218 | 0.008 | 0.27 | 0.476 | 0.078 |
| 11 | mentions a specific time of day or schedule for events | 0.184 | 0.032 | -0.182 | 0.758 | 0.072 |
| 12 | mentions illegal immigration or illegal immigrants | 0.06 | 0.809 | -0.03 | 0.977 | 0.008 |
| 5 | discusses tax cuts or reduced taxation as a means of economic growth | -0.041 | 0.752 | 0.558 | 0.459 | 0.03 |
| 6 | mentions members of the Republican Study Committee or Republican politicians by name | -0.078 | 0.607 | -0.34 | 0.637 | 0.022 |
| 8 | frequent use of numerical values, especially large dollar amounts, percentages, or time spans related to budget cuts or funding | -0.165 | 0.015 | 0.426 | 0.273 | 0.122 |
| 1 | mentions issues related to health care or health insurance | -0.171 | 0.011 | -0.667 | 0.06 | 0.124 |
| 17 | mentions African American contributions, achievements, or issues | -0.192 | 0.443 | -1.008 | 0.399 | 0.008 |
| 13 | urges opposition to an amendment | -0.201 | 0.001 | 0.368 | 0.542 | 0.162 |



Holdout Set Metrics:
AUC: 0.744
Significant hypotheses: 2/20 (p < 5.000e-03)
