In [1]:
from anthropic import Anthropic
from dotenv import load_dotenv
import os

# Load .env file
load_dotenv()

# Get key from environment
api_key = os.getenv("ANTHROPIC_API_KEY")

# Initialize client
client = Anthropic(api_key=api_key)

# Send a simple prompt
response = client.messages.create(
    model="claude-3-7-sonnet-20250219",
    max_tokens=200,
    messages=[
        {"role": "user", "content": "Say hello"}
    ]
)

print(response.content[0].text)


Hello! How can I assist you today?


In [2]:
from pathlib import Path
import pandas as pd, torch, os, gc
from interplm.sae.inference import load_sae_from_hf
import matplotlib.pyplot as plt
import numpy as np

DEVICE="cuda"
DTYPE=torch.float16

DATA_DIR = Path("esm_sae_results"); DATA_DIR.mkdir(exist_ok=True)
SEQUENCES_DIR = Path("/home/ec2-user/SageMaker/InterPLM/data/uniprot/subset_25k.csv")
# ANNOTATIONS_DIR = Path("uniprotkb_swissprot_annotations.tsv.gz")
ANNOTATIONS_DIR = Path("/home/ec2-user/SageMaker/InterPLM/uniprotkb_swissprot_annotations.tsv.gz")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", do_lower_case=False)
model     = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D",
                                        output_hidden_states=True).to(DEVICE).eval()

# Make sure the SAE you load matches the *plm_model* and *plm_layer* you want to use
plm_model = "esm2-650m"   # matches your checkpoint naming
plm_layer = 24            # <= MUST match esm_layer_sel
sae = load_sae_from_hf(plm_model=plm_model, plm_layer=plm_layer).to(DEVICE).eval()


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
import glob
features_all = pd.read_pickle("features_all.pkl")
features_all.shape


(40000, 6)

In [5]:
features_all.head()

Unnamed: 0,uniprot_id,length,features,max_activation,n_active_features,reconstruction_mse
0,Q9GL23,50,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.002...",1.265625,1876,45.19838
1,Q6GZU6,50,"[0.00023197175, 0.0, 0.0, 0.0, 0.0013056946, 0...",0.843262,2168,13.467114
2,P9WJG6,50,"[0.0, 0.00057144166, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.935059,1740,12.720748
3,P18924,51,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.000...",0.956543,1799,11.394856
4,Q08076,52,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.000...",1.139648,1772,24.694654


In [6]:
annotations_df = pd.read_csv(ANNOTATIONS_DIR, sep="\t", compression="gzip")

In [7]:
annotations_df.head()

Unnamed: 0,Entry,Reviewed,Protein names,Length,Sequence,EC number,Active site,Binding site,Cofactor,Disulfide bond,...,Helix,Turn,Beta strand,Coiled coil,Domain [CC],Compositional bias,Domain [FT],Motif,Region,Zinc finger
0,A0A009IHW8,reviewed,2' cyclic ADP-D-ribose synthase AbTIR (2'cADPR...,269,MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,3.2.2.-; 3.2.2.6,"ACT_SITE 208; /evidence=""ECO:0000255|PROSITE-P...","BINDING 143; /ligand=""NAD(+)""; /ligand_id=""ChE...",,,...,"HELIX 143..145; /evidence=""ECO:0007829|PDB:7UW...","TURN 146..149; /evidence=""ECO:0007829|PDB:7UWG...","STRAND 135..142; /evidence=""ECO:0007829|PDB:7U...","COILED 31..99; /evidence=""ECO:0000255""",DOMAIN: The TIR domain mediates NAD(+) hydrola...,,"DOMAIN 133..266; /note=""TIR""; /evidence=""ECO:0...",,,
1,A0A023I7E1,reviewed,"Glucan endo-1,3-beta-D-glucosidase 1 (Endo-1,3...",796,MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF...,3.2.1.39,"ACT_SITE 500; /evidence=""ECO:0000255|PROSITE-P...","BINDING 504; /ligand=""(1,3-beta-D-glucosyl)n"";...",,,...,"HELIX 42..44; /evidence=""ECO:0007829|PDB:4K35""...","TURN 287..289; /evidence=""ECO:0007829|PDB:4K35...","STRAND 56..58; /evidence=""ECO:0007829|PDB:4K35...",,,,"DOMAIN 31..759; /note=""GH81""; /evidence=""ECO:0...",,"REGION 31..276; /note=""beta-sandwich subdomain...",
2,A0A024B7W1,reviewed,Genome polyprotein [Cleaved into: Capsid prote...,3423,MKNPKKKSGGFRIVNMLKRGVARVSPFGGLKRLPAGLLLGHGPIRM...,2.1.1.56; 2.1.1.57; 2.7.7.48; 3.4.21.91; 3.6.1...,"ACT_SITE 1553; /note=""Charge relay system; for...","BINDING 1696..1703; /ligand=""ATP""; /ligand_id=...",,"DISULFID 350..406; /evidence=""ECO:0000250|UniP...",...,"HELIX 222..225; /evidence=""ECO:0007829|PDB:6CO...","TURN 237..241; /evidence=""ECO:0007829|PDB:6CO8...","STRAND 234..236; /evidence=""ECO:0007829|PDB:6C...",,DOMAIN: [Small envelope protein M]: The transm...,,"DOMAIN 1503..1680; /note=""Peptidase S7""; /evid...","MOTIF 1787..1790; /note=""DEAH box""; /evidence=...","REGION 1..25; /note=""Disordered""; /evidence=""E...",
3,A0A024RXP8,reviewed,"Exoglucanase 1 (EC 3.2.1.91) (1,4-beta-cellobi...",514,MYRKLAVISAFLATARAQSACTLQSETHPPLTWQKCSSGGTCTQQT...,3.2.1.91,"ACT_SITE 229; /note=""Nucleophile""; /evidence=""...",,,"DISULFID 21..89; /evidence=""ECO:0000250|UniPro...",...,,,,,DOMAIN: The enzyme consists of two functional ...,"COMPBIAS 401..437; /note=""Polar residues""; /ev...","DOMAIN 478..514; /note=""CBM1""; /evidence=""ECO:...",,"REGION 18..453; /note=""Catalytic""; /evidence=""...",
4,A0A024SC78,reviewed,Cutinase (EC 3.1.1.74),248,MRSLAILTTLLAGHAFAYPKPAPQSVNRRDWPSINEFLSELAKVMP...,3.1.1.74,"ACT_SITE 164; /note=""Nucleophile""; /evidence=""...",,,"DISULFID 55..91; /evidence=""ECO:0000269|PubMed...",...,"HELIX 51..69; /evidence=""ECO:0007829|PDB:4PSC""...","TURN 94..100; /evidence=""ECO:0007829|PDB:4PSC""...","STRAND 48..50; /evidence=""ECO:0007829|PDB:4PSC...",,"DOMAIN: In contract to classical cutinases, po...",,,,"REGION 31..70; /note=""Lid covering the active ...",


In [8]:
# import numpy as np
# import pandas as pd
# import random

# # Parameters
# N_FEATURES = 1200
# BINS = np.arange(0, 1.1, 0.1)

# # Randomly select feature ids
# all_feature_ids = list(range(len(features_all.iloc[0].features)))
# print("num features", len(all_feature_ids))
# selected_features = random.sample(all_feature_ids, N_FEATURES)

# print(f"Selected {len(selected_features)} features out of {len(all_feature_ids)}")

# # Build dataset for each feature
# feature_datasets = {}

# # Predefine bin labels
# bin_labels = [f"{BINS[i]:.1f}-{BINS[i+1]:.1f}" for i in range(len(BINS)-1)]

# for fid in selected_features:
#     # Extract activations for this feature
#     activations = [f[fid] for f in features_all["features"]]
#     df = pd.DataFrame({
#         "uniprot_id": features_all["uniprot_id"],
#         "activation": activations
#     })

#     # Assign bins
#     df["bin"] = pd.cut(df["activation"], bins=BINS, labels=bin_labels, include_lowest=True)

#     sampled = []

#     # Sample proteins per bin
#     for b in df["bin"].dropna().unique():
#         bin_df = df[df["bin"] == b]
#         n = 10 if b == "0.9-1.0" else 2
#         sampled.extend(bin_df.sample(min(len(bin_df), n), random_state=42).to_dict(orient="records"))

#     # Add 10 random zero-activation proteins 
#     zero_df = df[df["activation"] == 0.0]
#     if len(zero_df) > 0:
#         sampled.extend(zero_df.sample(min(len(zero_df), 10), random_state=42).to_dict(orient="records"))

#     # Merge with metadata from annotations_df
#     sampled_df = pd.DataFrame(sampled)
#     merged = sampled_df.merge(annotations_df, left_on="uniprot_id", right_on="Entry", how="left")

#     feature_datasets[fid] = merged

# # Example feature dataset
# example_fid = selected_features[0]
# feature_datasets[example_fid].head()


## Normalize features

In [9]:
import numpy as np

#Stakc into [num_proteins, num_features]
X = np.vstack(features_all['features'].values) #Shape (N, F)

#Max activation per feature across all proteins
max_per_feature = X.max(axis=0) # shape: (F,)
eps = 1e-12
max_safe = np.where(max_per_feature > 0, max_per_feature, eps)
#Normalize
X_norm = X / max_safe

#Save back
features_all = features_all.copy()
features_all["features_norm"] = [row for row in X_norm]
print("Original max activation (feature 0):", X[:,0].max())
print("Normalized max activation (feature 0):", X_norm[:,0].max())

Original max activation (feature 0): 0.02611415
Normalized max activation (feature 0): 1.0


In [10]:
import numpy as np
import pandas as pd
import random

# Parameters
N_FEATURES = 10240
BINS = np.arange(0, 1.1, 0.1)

# Randomly select feature ids
all_feature_ids = list(range(len(features_all.iloc[0].features)))
print("num features", len(all_feature_ids))
selected_features = random.sample(all_feature_ids, N_FEATURES)

print(f"Selected {len(selected_features)} features out of {len(all_feature_ids)}")

# Build dataset for each feature
feature_datasets = {}

# Predefine bin labels
bin_labels = [f"{BINS[i]:.1f}-{BINS[i+1]:.1f}" for i in range(len(BINS)-1)]

for fid in selected_features:
    # Extract activations for this feature
    activations = [f[fid] for f in features_all["features_norm"]]
    df = pd.DataFrame({
        "uniprot_id": features_all["uniprot_id"],
        "activation": activations
    })

    # Assign bins
    df["bin"] = pd.cut(df["activation"], bins=BINS, labels=bin_labels, include_lowest=True)

    sampled = []

    # Sample proteins per bin
    for b in df["bin"].dropna().unique():
        bin_df = df[df["bin"] == b]
        n = 10 if b == "0.9-1.0" else 2
        sampled.extend(bin_df.sample(min(len(bin_df), n), random_state=42).to_dict(orient="records"))

    # Add 10 random zero-activation proteins 
    zero_df = df[df["activation"] == 0.0]
    if len(zero_df) > 0:
        sampled.extend(zero_df.sample(min(len(zero_df), 10), random_state=42).to_dict(orient="records"))

    # Merge with metadata from annotations_df
    sampled_df = pd.DataFrame(sampled)
    merged = sampled_df.merge(annotations_df, left_on="uniprot_id", right_on="Entry", how="left")

    feature_datasets[fid] = merged

# Example feature dataset
example_fid = selected_features[0]
feature_datasets[example_fid].head()


num features 10240
Selected 10240 features out of 10240


In [None]:

feature_datasets.to_pickle("feature_datasets.pkl")

## Add Amino Acid Indices and activations for better LLM annotations

In [None]:
from utils import extract_esm_features_batch

@torch.no_grad()
def extract_sae_features(hidden_states: torch.Tensor, sae):
    """
    Pass ESM hidden states through the Sparse Autoencoder (SAE).

    Args
    ----
    hidden_states : torch.Tensor
        Shape [B, L, d] or [L, d].
        - B = batch size (optional if unsqueezed)
        - L = sequence length
        - d = ESM embedding dimension (e.g., 1280 for esm2_t33_650M)

    Returns
    -------
    sae_features : torch.Tensor
        Shape [B, L, F]
        Sparse latent features per residue.
        F = number of SAE dictionary atoms / features.

    recon : torch.Tensor
        Shape [B, L, d]
        Reconstructed embeddings in token space.

    error : torch.Tensor
        Shape [B, L, d]
        Residual = hidden_states - recon
    """
    if hidden_states.dim() == 2:          # [L, d]
        hidden_states = hidden_states.unsqueeze(0)  # → [1, L, d]
    x = hidden_states.to(torch.float32)      # <- ensure fp32 for SAE

    # SAE should have encode() and decode() that operate on last dimension
    sae_features = sae.encode(x)     # [B, L, F]
    recon        = sae.decode(sae_features)      # [B, L, d]
    error        = hidden_states - recon         # [B, L, d]

    return sae_features, recon, error

#Config for extracting activated positions

TOP_K = 8 # How many positions to record per protein
MIN_ACT = 0.0 #Only consider positions with activation > MIN_ACT
BATCH_SIZE = 16 #For ESM/SAE Inference


def _batched(iterable, n):
    """Yield Successive n-sized chunks from iterable
    """
    it = list(iterable)
    for i in range(0, len(it), n):
        yield it[i:i+n]

import numpy as np
import torch
from typing import List, Tuple, Optional, Literal

TOP_K = 8
MIN_ACT = 0.0
BATCH_SIZE = 16

def _batched(iterable, n):
    it = list(iterable)
    for i in range(0, len(it), n):
        yield it[i:i+n]

def _normalize_1d(
    x: np.ndarray,
    mode: Literal["seq_max","feature_global_max","zscore","none"] = "seq_max",
    global_max: Optional[float] = None,
    eps: float = 1e-8
) -> np.ndarray:
    """
    Normalize a 1D activation vector x (valid positions only).
    """
    if mode == "none":
        return x.copy()

    if mode == "feature_global_max":
        if global_max is None or global_max <= eps:
            # fallback to seq_max if global not available/safe
            mode = "seq_max"
        else:
            return x / (global_max + eps)

    if mode == "seq_max":
        m = np.max(x) if x.size else 0.0
        return x / (m + eps)

    if mode == "zscore":
        mu = float(np.mean(x)) if x.size else 0.0
        sd = float(np.std(x)) if x.size else 0.0
        return (x - mu) / (sd + eps)

    # Fallback (shouldn't hit)
    return x.copy()

@torch.no_grad()
def compute_activated_positions_for_feature(
    fid: int,
    seqs: List[str],
    batch_size: int = BATCH_SIZE,
    max_per_feature: Optional[np.ndarray] = None,
    norm_mode: Literal["seq_max","feature_global_max","zscore","none"] = "seq_max",
    top_k: int = TOP_K,
    min_act: float = MIN_ACT,
    device: str = "cuda"
) -> Tuple[List[List[int]], List[List[str]], List[List[float]], List[List[float]]]:
    """
    For a list of sequences and a single SAE feature id (fid),
    return per-sequence top-K activated residue indices, AA identities,
    normalized scores (per selected position), and raw scores.

    Returns
    -------
    all_indices : List[List[int]]
        Per-sequence Top-K indices (0-based).
    all_aas : List[List[str]]
        Per-sequence amino acids at those indices.
    norm_vals : List[List[float]]
        Per-sequence normalized activations aligned with Top-K order.
    raw_vals : List[List[float]]
        Per-sequence raw activations aligned with Top-K order.
    """
    all_indices: List[List[int]] = []
    all_aas: List[List[str]] = []
    norm_vals: List[List[float]] = []
    raw_vals: List[List[float]] = []

    # Optional global max for this feature (for 'feature_global_max' mode)
    global_max = None
    if norm_mode == "feature_global_max" and max_per_feature is not None:
        if 0 <= fid < len(max_per_feature):
            gm = float(max_per_feature[fid])
            global_max = gm if np.isfinite(gm) else None

    for chunk in _batched(seqs, batch_size):
        # ESM -> SAE: token representations and mask
        token_reps, attn_mask = extract_esm_features_batch(
            chunk, layer_sel=plm_layer, device=device, model=model, tokenizer=tokenizer
        )  # token_reps: [B, L, H], attn_mask: [B, L] bool
        sae_feats, _, _ = extract_sae_features(token_reps, sae)  # [B, L, F]

        # Select feature channel -> [B, L] on CPU
        feat_act = sae_feats[..., fid].float().cpu()
        mask = attn_mask.cpu()  # [B, L] bool

        for seq, act_row, m in zip(chunk, feat_act, mask):
            L = int(m.sum().item())  # valid residues
            act_valid = act_row[:L].numpy() if L > 0 else np.array([], dtype=np.float32)

            # positions above threshold
            valid_idx = np.where(act_valid > min_act)[0]
            if valid_idx.size == 0:
                all_indices.append([])
                all_aas.append([])
                norm_vals.append([])
                raw_vals.append([])
                continue

            # sort by activation desc within valid subset and take top_k
            order = np.argsort(-act_valid[valid_idx])
            chosen_local = valid_idx[order[:top_k]].tolist()

            # raw values for the chosen positions
            chosen_raw = act_valid[chosen_local]

            # normalization (done on the full valid slice, then gather chosen)
            norm_full = _normalize_1d(
                act_valid, mode=norm_mode, global_max=global_max
            )
            chosen_norm = norm_full[chosen_local]

            # map to AAs (defensive indexing)
            aas = [seq[i] if i < len(seq) else "X" for i in chosen_local]

            all_indices.append(chosen_local)
            all_aas.append(aas)
            norm_vals.append([float(v) for v in chosen_norm])
            raw_vals.append([float(v) for v in chosen_raw])

        # free tensors early
        del token_reps, sae_feats, feat_act, attn_mask
        torch.cuda.empty_cache()

    return all_indices, all_aas, norm_vals, raw_vals


## Integrate with feature_datasets dict

In [None]:
#feature_datasets: Dict[int, pd.DataFrame] where each df includes 'Sequence'
from tqdm import tqdm
for fid, df in tqdm(list(feature_datasets.items())[:1], total=1, desc="Features"):
    #make a copy to avoid mutating original reference
    work = df.copy()

    # Ensure sequences are present and aligned
    if "Sequence" not in work.columns:
        # If your UniProt column is named differently, adjust here
        raise KeyError("Expected a 'Sequence' column in merged annotations.")
    seqs = work["Sequence"].astype(str).fillna("").tolist()

    idx_lists, aa_lists, norm_vals, raw_vals = compute_activated_positions_for_feature(fid, seqs, batch_size=BATCH_SIZE, max_per_feature = max_safe)
    print(len(idx_lists), len(aa_lists), len(norm_vals))
    #compute per-row activated positions for this feature
    work['activated_indices'] = idx_lists
    work['activated_aas'] = aa_lists
    work['seq_max_activation_norm'] = norm_vals

    # # (Optional) also add a compact string column for the prompt table
    work["activated_indices_str"] = work["activated_indices"].apply(lambda xs: ",".join(map(str, xs)) if xs else "")
    work["activated_aas_str"]     = work["activated_aas"].apply(lambda xs: ",".join(xs) if xs else "")
    work["seq_max_activation_norm_str"] = work["seq_max_activation_norm"].apply(lambda xs: ",".join(map(str, xs)) if xs else "")

    feature_datasets[fid] = work

first_fid = list(feature_datasets.keys())[0]
feature_datasets[first_fid][["uniprot_id","activation","seq_max_activation_norm", "activated_indices_str","activated_aas_str", "seq_max_activation_norm_str"]].head()

    

Features: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s]

25 25 25





Unnamed: 0,uniprot_id,activation,seq_max_activation_norm,activated_indices_str,activated_aas_str,seq_max_activation_norm_str
0,Q1C0M8,0.0,[],,,
1,A5WV69,0.0,[],,,
2,Q7YRC1,0.138688,"[1.0, 0.9184075593948364, 0.31311237812042236,...",185237271278204192197205,"G,C,G,W,A,G,A,I","1.0,0.9184075593948364,0.31311237812042236,0.1..."
3,Q8TXW1,0.15686,"[0.9999998807907104, 0.5972216725349426, 0.583...",6362605961706474,"R,L,D,L,T,D,G,G","0.9999998807907104,0.5972216725349426,0.583204..."
4,P29305,0.383721,"[1.0, 0.9846875071525574, 0.8020842671394348, ...",1831311381458137136184,"S,G,A,A,I,L,Y,V","1.0,0.9846875071525574,0.8020842671394348,0.38..."


In [None]:
# --- Setup (once per notebook) ---
# pip install anthropic python-dotenv
from anthropic import Anthropic
from dotenv import load_dotenv
import os, time, json, math, textwrap
import numpy as np
import pandas as pd
from typing import Dict

# Load .env (expects ANTHROPIC_API_KEY=...)
load_dotenv()
client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

# --- Config ---
MODEL_NAME = "claude-3-5-sonnet-20240620"
MAX_TOKENS = 800  # enough for description + summary
TEMPERATURE = 0.0 # deterministic
CHECKPOINT_EVERY = 50
OUTPUT_PATH = "claude_feature_annotations.csv"

# Columns to show Claude (customize as you like)
# We'll include what exists; missing columns are auto-dropped
PREFERRED_COLS = [
    # keys/ids
    "uniprot_id", "Entry", "Protein names",
    # size/sequence shape
    "Length",
    # functional annotations
    "EC number", "Active site", "Binding site", "Cofactor", "Disulfide bond",
    "Helix", "Turn", "Beta strand", "Coiled coil",
    "Domain [CC]", "Compositional bias", "Domain [FT]", "Motif", "Region", "Zinc finger",
    # your per-feature fields
    "activation", "bin",
    # optional (only used if present)
    "activated_indices", "activated_aas"
]

# Limit rows/cols so the table fits comfortably in context
MAX_ROWS = 80   # you can raise/lower if you hit token limits
TRUNCATE_STR_LEN = 120  # truncate long text fields so tables stay compact


def _coerce_and_trim_cols(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Select existing columns, stringify, and truncate long strings so the table is compact."""
    use_cols = [c for c in cols if c in df.columns]
    if not use_cols:
        # Fallback: show whatever is available
        use_cols = list(df.columns)

    out = df[use_cols].copy()

    # Coerce to string and truncate long values
    for c in use_cols:
        out[c] = out[c].astype(str).str.replace(r"\s+", " ", regex=True)
        out[c] = out[c].apply(lambda s: s[:TRUNCATE_STR_LEN] + "…" if len(s) > TRUNCATE_STR_LEN else s)

    # Keep only first MAX_ROWS to control token usage
    return out.head(MAX_ROWS)


PROMPT_TEMPLATE = """Generate description and summary
Analyze this protein dataset to determine what predicts the ’Maximum activation value’ and ‘Amino acids of
highest activated indices in protein’ columns. This description should be as concise as possible but sufficient to
predict these two columns on held-out data given only the description and the rest of the protein metadata
provided. The feature could be specific to a protein family, a structural motif, a sequence motif, a functional
role, etc. These WILL be used to predict how much unseen proteins are activated by the feature so only
highlight relevant factors for this.

Focus on:
• Properties of proteins from the metadata that are associated with high vs medium vs low activation.
• Where in the protein sequence activation occurs (in relation to the protein sequence, length, structure,
  or other properties)
• What functional annotations (binding sites, domains, etc.) and amino acids are present at or near the
  activated positions
• This description that will be used to help predict missing activation values should start with:
  “The activation patterns are characterized by:”

Then, in 1 sentence, summarize what biological feature or pattern this neural network activation is detecting.
This concise summary should start with “The feature activates on”.

Protein record:
{TABLE}
"""

def build_prompt(table_df: pd.DataFrame) -> str:
    table_md = table_df.to_markdown(index=False)
    return PROMPT_TEMPLATE.replace("{TABLE}", table_md)

def call_claude(prompt: str) -> str:
    """Call Claude, return raw text."""
    resp = client.messages.create(
        model=MODEL_NAME,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.content[0].text

def parse_description_and_summary(text: str) -> dict:
    """
    Best-effort parse: extract the long description (must start with the required phrase)
    and the one-sentence summary (starts with 'The feature activates on').
    Falls back to raw if patterns aren’t found.
    """
    desc = ""
    summ = ""
    lines = [l.strip() for l in text.splitlines() if l.strip()]

    # Find the description block
    start_idx = None
    for i, l in enumerate(lines):
        if l.lower().startswith("the activation patterns are characterized by:"):
            start_idx = i
            break
    if start_idx is not None:
        # collect until we hit the summary or end
        buff = []
        for j in range(start_idx, len(lines)):
            if lines[j].lower().startswith("the feature activates on"):
                break
            buff.append(lines[j])
        desc = "\n".join(buff).strip()

    # Find the one-sentence summary
    for l in lines:
        if l.lower().startswith("the feature activates on"):
            # keep first sentence
            summ = l.split("\n")[0].strip()
            break

    return {
        "description": desc or "",
        "summary": summ or "",
        "raw": text.strip()
    }

# --- Main loop over feature datasets ---
# Expects: feature_datasets: Dict[int, pd.DataFrame]
results_rows = []

processed = 0
for fid, df in list(feature_datasets.items())[:3]:
    # Build a compact table for the model
    view = _coerce_and_trim_cols(df, PREFERRED_COLS)
    prompt = build_prompt(view)

    try:
        text = call_claude(prompt)
        parsed = parse_description_and_summary(text)
    except Exception as e:
        parsed = {"description": "", "summary": "", "raw": f"[ERROR] {e}"}

    results_rows.append({
        "feature_id": fid,
        "n_rows_shown": len(view),
        "description": parsed["description"],
        "summary": parsed["summary"],
        "raw_response": parsed["raw"],
    })

    processed += 1
    if processed % CHECKPOINT_EVERY == 0:
        pd.DataFrame(results_rows).to_parquet(OUTPUT_PATH, index=False)
        print(f"[checkpoint] saved {processed} → {OUTPUT_PATH}")

# Final save
df_results = pd.DataFrame(results_rows)
df_results.to_csv(OUTPUT_PATH, index=False)
print(f"[done] {len(df_results)} features → {OUTPUT_PATH}")
df_results.head()


Please migrate to a newer model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for more information.
  resp = client.messages.create(
Please migrate to a newer model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for more information.
  resp = client.messages.create(
Please migrate to a newer model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for more information.
  resp = client.messages.create(


[done] 3 features → claude_feature_annotations.csv


Unnamed: 0,feature_id,n_rows_shown,description,summary,raw_response
0,2457,25,The activation patterns are characterized by:\...,The feature activates on conserved structural ...,The activation patterns are characterized by:\...
1,7950,23,The activation patterns are characterized by:\...,"The feature activates on small, single-domain ...",The activation patterns are characterized by:\...
2,780,30,The activation patterns are characterized by:\...,The feature activates on ATP synthase subunit ...,The activation patterns are characterized by:\...


In [None]:
for fid, df in list(feature_datasets.items())[:3]:
    # Build a compact table for the model
    view = _coerce_and_trim_cols(df, PREFERRED_COLS)
    prompt = build_prompt(view)

    try:
        text = call_claude(prompt)
        parsed = parse_description_and_summary(text)
    except Exception as e:
        parsed = {"description": "", "summary": "", "raw": f"[ERROR] {e}"}

    results_rows.append({
        "feature_id": fid,
        "n_rows_shown": len(view),
        "description": parsed["description"],
        "summary": parsed["summary"],
        "raw_response": parsed["raw"],
    })

    processed += 1
    if processed % CHECKPOINT_EVERY == 0:
        pd.DataFrame(results_rows).to_parquet(OUTPUT_PATH, index=False)
        print(f"[checkpoint] saved {processed} → {OUTPUT_PATH}")

# Final save
df_results = pd.DataFrame(results_rows)
df_results.to_csv(OUTPUT_PATH, index=False)
print(f"[done] {len(df_results)} features → {OUTPUT_PATH}")
df_results.head()


KeyError: 2167