In [1]:
import duckdb
from pathlib import Path

# Load the autointerp database in sqlite
autointerp_path = Path("~/autointerp/r1-logic/autointerp.db").expanduser()
conn = duckdb.connect(autointerp_path, read_only=True)
conn.execute("SET sqlite_all_varchar=true")

# convenience function for running queries
def run_query(query: str, conn: duckdb.DuckDBPyConnection):
    res = conn.execute(query)

    column_names = [desc[0] for desc in res.description]
    data = [dict(zip(column_names, row)) for row in res.fetchall()]
    return data

In [2]:
print('Autointerp database example row:')
example = run_query("SELECT * FROM autointerp limit 1", conn)[0]

for col_name, col_value in example.items():
    print(f'{col_name}: {col_value}')

Autointerp database example row:
feature_id: 0
label: Cycles in graphs and algorithms
seqs: []
indices: []
quality: 0.9
interestingness: 0.6
model_name: claude-3-7-sonnet-latest
prompt_version: v0-dev


In [3]:
# print 10 feature labels
query = "SELECT feature_id, label FROM autointerp limit 10"
results = run_query(query, conn)

for row in results:
    print(f'{row["feature_id"]}: {row["label"]}')

0: Cycles in graphs and algorithms
1: Sorting collections and processing sorted data
3: Competitive programming input specification
5: Algorithm optimization
6: searching for a specific instance or solution
7: Prepositions indicating origin or source
8: Initializing variables to zero in programming
9: Mathematical calculation and numerical reasoning
10: Recognizing computational complexity limitations
11: Once a solution is found, transition to the next step


In [4]:
# attach the tokens db
dataset_path = Path("~/logic-1.ddb").expanduser()
conn.execute(f"ATTACH DATABASE '{dataset_path}' as ds")

<duckdb.duckdb.DuckDBPyConnection at 0x765194779130>

In [5]:
print('Tokens database example row:')
example = run_query("SELECT * FROM ds.tokens limit 1", conn)[0]

for col_name, col_value in example.items():
    print(f'{col_name}: {col_value}')

Tokens database example row:
sequence_id: 0
token_idx: 0
token: 0
decoded_token: <｜begin▁of▁sentence｜>


In [6]:
print('SAE latent activations example row:')
example = run_query("SELECT * FROM ds.activations limit 1", conn)[0]

for col_name, col_value in example.items():
    print(f'{col_name}: {col_value}')

SAE latent activations example row:
feature_id: 3
sequence_id: 0
token_idx: 0
strength: 0.20472130179405212


In [7]:
# Here, we will join tables to get the top k activations for a given feature, then print the subsequences around those activations

TOP_K_SEQ = 10
FEATURE_ID = 764
SEQUENCE_WINDOW = 10

query = f"""
with subsequences as (
    with top_acts as (
        SELECT
            autointerp.feature_id,
            autointerp.label,
            acts.sequence_id,
            acts.token_idx,
            acts.strength
        FROM
            autointerp
        JOIN
            ds.activations acts ON autointerp.feature_id = acts.feature_id
        WHERE
            autointerp.feature_id = {FEATURE_ID}
        ORDER BY
            acts.strength DESC
        LIMIT {TOP_K_SEQ}
    )

    SELECT
        top_acts.feature_id,
        top_acts.label,
        top_acts.sequence_id,
        top_acts.token_idx as top_act_token_idx,
        top_acts.strength,
        tokens.token_idx as token_idx,
        tokens.decoded_token
    FROM
        ds.tokens as tokens
    join
        top_acts on tokens.sequence_id = top_acts.sequence_id
    where
        abs(tokens.token_idx - top_acts.token_idx) <= {SEQUENCE_WINDOW//2}
)

SELECT
    *
FROM
    subsequences
"""

results = run_query(query, conn)

subsequences = {}

for row in results:
    feat_id = row["feature_id"]
    seq_id = row["sequence_id"]

    if subsequences.get(feat_id) is None:
        subsequences[feat_id] = {}
        subsequences[feat_id]["label"] = row["label"]
        subsequences[feat_id]["subsequences"] = {}

    if subsequences[feat_id]["subsequences"].get(seq_id) is None:
        subsequences[feat_id]["subsequences"][seq_id] = {}
        subsequences[feat_id]["subsequences"][seq_id]["top_idx"] = row["top_act_token_idx"]
        subsequences[feat_id]["subsequences"][seq_id]["top_strength"] = row["strength"]
        subsequences[feat_id]["subsequences"][seq_id]["tokens"] = {}

    subsequences[feat_id]["subsequences"][seq_id]["tokens"][row["token_idx"]] = row["decoded_token"]
    

for feat_id, feat_data in subsequences.items():
    print(f'{feat_id}: {feat_data["label"]}')

    for seq_id, seq_data in feat_data["subsequences"].items():
        top_idx = seq_data["top_idx"]
        top_token = seq_data["tokens"][top_idx]
        top_strength = seq_data["top_strength"]

        ordered_tokens = sorted(seq_data["tokens"].items(), key=lambda x: x[0])

        before_top = [token[1] for token in ordered_tokens if token[0] < top_idx]
        after_top = [token[1] for token in ordered_tokens if token[0] > top_idx]

        subsequence = repr(''.join(before_top + [ "<<" + top_token + ">>" ] + after_top))

        print(f'  {top_strength} : {subsequence}')

764: Problem-solving breakthrough confirmation
  1.016119360923767 : ' surviving edges ].\n\nYes<<!>> Because for any forest,'
  1.120764970779419 : ' what we need.\n\nYes<<!>> So this model captures the'
  1.0652724504470825 : ' k-th one.\n\nYes<<!>> This seems promising.\n\nSo'
  1.0131556987762451 : "1) time.\n\nYes<<!>> Let's compute prefix_s"
  0.9246755838394165 : ' to the total.\n\nYes<<.>> So, the algorithm would'
  0.9901113510131836 : ' pass without overlapping.\n\nYes<<,>> this approach would work.'
  1.0250762701034546 : '1] accordingly.\n\nYes<<,>> this way, for each'
  1.1328017711639404 : ', then proceed.\n\nYes<<!>> Because once the first i'
  0.945244550704956 : ' the original lines.\n\nYes<<,>> so if two lines intersect'
  1.1280906200408936 : " floating point error.\n\nYes<<,>> that's a better approach"
