In [1]:
!pip install duckdb



In [2]:
from urllib.request import urlretrieve
from pathlib import Path

autointerp_path = Path("~/logic-autointerp.db").expanduser()
dataset_path = Path("~/logic-0-1.ddb").expanduser()

# download the  dbs from s3
urlretrieve('https://goodfire-r1-features.s3.us-east-1.amazonaws.com/logic/autointerp.db', autointerp_path)
urlretrieve('https://goodfire-r1-features.s3.us-east-1.amazonaws.com/logic/logic-0-1.ddb', dataset_path)

(PosixPath('/mnt/polished-lake/home/maxsloef/logic-0-1.ddb'),
 <http.client.HTTPMessage at 0x7a4510598550>)

In [3]:
import duckdb

# Load the autointerp database in sqlite
conn = duckdb.connect(autointerp_path, read_only=True)

# convenience function for running queries
def run_query(query: str, conn: duckdb.DuckDBPyConnection):
    res = conn.execute(query)

    column_names = [desc[0] for desc in res.description]
    data = [dict(zip(column_names, row)) for row in res.fetchall()]
    return data

In [4]:
print('Autointerp database example row:')
example = run_query("SELECT * FROM autointerp limit 1", conn)[0]

for col_name, col_value in example.items():
    print(f'{col_name}: {col_value}')

Autointerp database example row:
feature_id: 0
label: Cycles in graph theory
quality: 0.9
interestingness: 0.6
model_name: claude-3-7-sonnet-latest
prompt_version: v0-dev


In [5]:
# print 10 feature labels
query = "SELECT feature_id, label FROM autointerp limit 10"
results = run_query(query, conn)

for row in results:
    print(f'{row["feature_id"]}: {row["label"]}')

0: Cycles in graph theory
1: Finding the k-th element in sorted arrays
2: dead feature
3: Input format specifications in competitive programming problems
4: dead feature
5: Dynamic programming solution planning
6: Algorithmic problem-solving
7: referring to locations or positions in explanations
8: Tracking current state in algorithms
9: Quantities and mathematical terms in problem-solving contexts


In [6]:
# attach the tokens db
conn.execute(f"ATTACH DATABASE '{dataset_path}' as ds")

<duckdb.duckdb.DuckDBPyConnection at 0x7a45105b0930>

In [7]:
print('Tokens database example row:')
example = run_query("SELECT * FROM ds.tokens limit 1", conn)[0]

for col_name, col_value in example.items():
    print(f'{col_name}: {col_value}')

Tokens database example row:
sequence_id: 0
token_idx: 0
token: 0
decoded_token: <｜begin▁of▁sentence｜>


In [8]:
print('SAE latent activations example row:')
example = run_query("SELECT * FROM ds.activations limit 1", conn)[0]

for col_name, col_value in example.items():
    print(f'{col_name}: {col_value}')

SAE latent activations example row:
feature_id: 3
sequence_id: 0
token_idx: 0
strength: 0.20472130179405212


In [9]:
# Here, we will join tables to get the top k activations for a given feature, then print the subsequences around those activations

TOP_K_SEQ = 10
FEATURE_ID = 20
SEQUENCE_WINDOW = 10

query = f"""
with subsequences as (
    with top_acts as (
        SELECT
            autointerp.feature_id,
            autointerp.label,
            acts.sequence_id,
            acts.token_idx,
            acts.strength
        FROM
            autointerp
        JOIN
            ds.activations acts ON autointerp.feature_id = acts.feature_id
        WHERE
            autointerp.feature_id = {FEATURE_ID}
        ORDER BY
            acts.strength DESC
        LIMIT {TOP_K_SEQ}
    )

    SELECT
        top_acts.feature_id,
        top_acts.label,
        top_acts.sequence_id,
        top_acts.token_idx as top_act_token_idx,
        top_acts.strength,
        tokens.token_idx as token_idx,
        tokens.decoded_token
    FROM
        ds.tokens as tokens
    join
        top_acts on tokens.sequence_id = top_acts.sequence_id
    where
        abs(tokens.token_idx - top_acts.token_idx) <= {SEQUENCE_WINDOW//2}
)

SELECT
    *
FROM
    subsequences
"""

results = run_query(query, conn)

subsequences = {}

for row in results:
    feat_id = row["feature_id"]
    seq_id = row["sequence_id"]

    if subsequences.get(feat_id) is None:
        subsequences[feat_id] = {}
        subsequences[feat_id]["label"] = row["label"]
        subsequences[feat_id]["subsequences"] = {}

    if subsequences[feat_id]["subsequences"].get(seq_id) is None:
        subsequences[feat_id]["subsequences"][seq_id] = {}
        subsequences[feat_id]["subsequences"][seq_id]["top_idx"] = row["top_act_token_idx"]
        subsequences[feat_id]["subsequences"][seq_id]["top_strength"] = row["strength"]
        subsequences[feat_id]["subsequences"][seq_id]["tokens"] = {}

    subsequences[feat_id]["subsequences"][seq_id]["tokens"][row["token_idx"]] = row["decoded_token"]
    

for feat_id, feat_data in subsequences.items():
    print(f'{feat_id}: {feat_data["label"]}')

    for seq_id, seq_data in feat_data["subsequences"].items():
        top_idx = seq_data["top_idx"]
        top_token = seq_data["tokens"][top_idx]
        top_strength = seq_data["top_strength"]

        ordered_tokens = sorted(seq_data["tokens"].items(), key=lambda x: x[0])

        before_top = [token[1] for token in ordered_tokens if token[0] < top_idx]
        after_top = [token[1] for token in ordered_tokens if token[0] > top_idx]

        subsequence = repr(''.join(before_top + [ "<<" + top_token + ">>" ] + after_top))

        print(f'  {top_strength} : {subsequence}')

20: Thought transition markers in reasoning
  0.6403554677963257 : '\n\nprint(max_area)\n\n<<Yes>>, this should work.\n\n'
  0.21712636947631836 : "1\nprint(result)\n\n<<That>>'s all. I think"
  0.3055551052093506 : " 'apple', s))\n\n<<This>> should handle all cases where"
  0.7282431125640869 : '()\n\nprint("YES")\n\n<<Wait>>, but the sample input'
  0.7797724008560181 : '}".format(sum_result))\n\n<<But>> wait, in Python,'
  0.5132467746734619 : '(str, post_order)))\n\n<<But>> this code is not correct'
  0.25801903009414673 : '            print("NO")\n\n<<But>> wait, let\'s test'
  0.6738934516906738 : ' global_max - n)\n\n<<But>> wait, for x='
  0.35298389196395874 : "\n    print(total)\n\nWait, that's it?2\nprint(total)\n\n<<Yes>>. That should do it"
