In [2]:
from anthropic import Anthropic
from dotenv import load_dotenv
import os

# Read variables from the local .env file into the process environment
load_dotenv()

# Look up the Anthropic key (None when the variable is not set)
api_key = os.environ.get("ANTHROPIC_API_KEY")

# Create the API client with that key
client = Anthropic(api_key=api_key)

# Connectivity check: send one short user message and show the reply text
response = client.messages.create(
    model="claude-3-7-sonnet-20250219",
    max_tokens=200,
    messages=[{"role": "user", "content": "Say hello"}],
)

print(response.content[0].text)


Hello! How can I assist you today?


In [3]:
from pathlib import Path
import pandas as pd, torch, os, gc
from interplm.sae.inference import load_sae_from_hf
import matplotlib.pyplot as plt
import numpy as np
# Device string for torch work
DEVICE = "cuda"

# Output folder for SAE results; created on first run, reused afterwards
DATA_DIR = Path("esm_sae_results")
DATA_DIR.mkdir(exist_ok=True)

# Input files (despite the *_DIR names, both of these point at single files)
SEQUENCES_DIR = Path("/home/ec2-user/InterPLM/data/uniprot/subset_25k.csv")
# Alternative annotations file: Path("uniprotkb_swissprot_annotations.tsv.gz")
ANNOTATIONS_DIR = Path("/home/ec2-user/InterPLM/subset_annotations.tsv.gz")


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import glob
# Collect the per-rank shard files in a stable (sorted) order so that
# drop_duplicates below always keeps the same copy of a repeated protein.
shard_paths = sorted(glob.glob(str(DATA_DIR / "sae_features_rank*.final.pkl")))
if not shard_paths:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate"
    raise FileNotFoundError(
        f"No 'sae_features_rank*.final.pkl' shards found in {DATA_DIR}"
    )

# NOTE: unpickling can execute arbitrary code; only load shards written by a trusted run.
parts = [pd.read_pickle(p) for p in shard_paths]

# One row per protein: concatenate all shards and keep the first row for each uniprot_id
features_all = pd.concat(parts, ignore_index=True).drop_duplicates(subset=["uniprot_id"])

# Persist the merged table next to the shards, then display its shape
features_all.to_pickle(DATA_DIR / "sae_features_all.pkl")
features_all.shape


(40000, 6)

In [5]:
# Preview the first rows of the merged per-protein feature table.
features_all.head()

Unnamed: 0,uniprot_id,length,features,max_activation,n_active_features,reconstruction_mse
0,Q9GL23,50,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.002...",1.265625,1876,45.19838
1,Q6GZU6,50,"[0.00023197175, 0.0, 0.0, 0.0, 0.0013056946, 0...",0.843262,2168,13.467114
2,P9WJG6,50,"[0.0, 0.00057144166, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.935059,1740,12.720748
3,P18924,51,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.000...",0.956543,1799,11.394856
4,Q08076,52,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.000...",1.139648,1772,24.694654


In [6]:
# Load the annotation table: a gzip-compressed, tab-separated file at ANNOTATIONS_DIR.
annotations_df = pd.read_csv(ANNOTATIONS_DIR, sep="\t", compression="gzip")

In [7]:
# Preview the first rows of the annotation table.
annotations_df.head()

Unnamed: 0,Entry,Reviewed,Protein names,Length,Sequence,EC number,Active site,Binding site,Cofactor,Disulfide bond,...,Helix,Turn,Beta strand,Coiled coil,Domain [CC],Compositional bias,Domain [FT],Motif,Region,Zinc finger
0,A0A009IHW8,reviewed,2' cyclic ADP-D-ribose synthase AbTIR (2'cADPR...,269,MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,3.2.2.-; 3.2.2.6,"ACT_SITE 208; /evidence=""ECO:0000255|PROSITE-P...","BINDING 143; /ligand=""NAD(+)""; /ligand_id=""ChE...",,,...,"HELIX 143..145; /evidence=""ECO:0007829|PDB:7UW...","TURN 146..149; /evidence=""ECO:0007829|PDB:7UWG...","STRAND 135..142; /evidence=""ECO:0007829|PDB:7U...","COILED 31..99; /evidence=""ECO:0000255""",DOMAIN: The TIR domain mediates NAD(+) hydrola...,,"DOMAIN 133..266; /note=""TIR""; /evidence=""ECO:0...",,,
1,A0A059WI14,reviewed,Trivalent organoarsenical cleaving enzyme (EC ...,161,MKYAHVGLNVTNLEKSIEFYSKLFGAEPVKVKPDYAKFLLESPGLN...,1.13.11.-,,"BINDING 5; /ligand=""Fe(2+)""; /ligand_id=""ChEBI...",COFACTOR: Name=Fe(2+); Xref=ChEBI:CHEBI:29033;...,,...,,,,,DOMAIN: The thiolates of the vicinal cysteine ...,,"DOMAIN 2..119; /note=""VOC""; /evidence=""ECO:000...",,,
2,A0A067XGX8,reviewed,"Phospho-2-dehydro-3-deoxyheptonate aldolase 2,...",512,MALTATATTRGGSALPNSCLQTPKFQSLQKPTFISSFPTNKKTKPR...,2.5.1.54,,"BINDING 126; /ligand=""Mn(2+)""; /ligand_id=""ChE...",COFACTOR: Name=Mn(2+); Xref=ChEBI:CHEBI:29035;...,,...,,,,,,,,,"REGION 37..57; /note=""Disordered""; /evidence=""...",
3,A0A067XH53,reviewed,"Phospho-2-dehydro-3-deoxyheptonate aldolase 1,...",533,MALSTNSTTSSLLPKTPLVQQPLLKNASLPTTTKAIRFIQPISAIH...,2.5.1.54,,"BINDING 145; /ligand=""Mn(2+)""; /ligand_id=""ChE...",COFACTOR: Name=Mn(2+); Xref=ChEBI:CHEBI:29035;...,,...,,,,,,"COMPBIAS 47..56; /note=""Polar residues""; /evid...",,,"REGION 47..70; /note=""Disordered""; /evidence=""...",
4,A0A0A1H8I4,reviewed,Aconitate isomerase (AI) (EC 5.3.3.7),262,MFPRLPTLALGALLLASTPLLAAQPVTTLTVLSSGGIMGTIREVAP...,5.3.3.7,,,,,...,,,,,,,,,,


In [8]:
import numpy as np
import pandas as pd
import random

# Parameters
N_FEATURES = 1200
BINS = np.arange(0, 1.1, 0.1)
SEED = 42  # single seed for feature selection and per-bin sampling

# Randomly (but reproducibly) select feature ids
all_feature_ids = list(range(len(features_all.iloc[0].features)))
print("num features", len(all_feature_ids))
selected_features = random.Random(SEED).sample(
    all_feature_ids, min(N_FEATURES, len(all_feature_ids))
)

print(f"Selected {len(selected_features)} features out of {len(all_feature_ids)}")

# Build dataset for each feature
feature_datasets = {}

# Predefine bin labels; the last one is the top-activation bin, which is sampled more densely
bin_labels = [f"{BINS[i]:.1f}-{BINS[i+1]:.1f}" for i in range(len(BINS)-1)]
top_bin_label = bin_labels[-1]

# Pull the selected columns out of every protein's vector in ONE pass over the
# table (instead of one full Python loop per feature).
# Resulting shape: (num_proteins, len(selected_features)).
selected_activations = np.stack(
    [np.asarray(vec)[selected_features] for vec in features_all["features"]]
)

# NOTE(review): this cell bins the RAW activations. BINS stops at 1.0, so any
# raw activation above 1.0 gets no bin (NaN) and is never sampled here.
for col, fid in enumerate(selected_features):
    df = pd.DataFrame({
        "uniprot_id": features_all["uniprot_id"],
        "activation": selected_activations[:, col],
    })

    # Assign bins
    df["bin"] = pd.cut(df["activation"], bins=BINS, labels=bin_labels, include_lowest=True)

    sampled = []

    # Sample proteins per bin from the NON-ZERO activations only. Exact zeros
    # land in the lowest bin (include_lowest=True) and are sampled separately
    # below; leaving them in here would let the same protein appear twice.
    nonzero_df = df[df["activation"] > 0]
    for label, bin_df in nonzero_df.groupby("bin", observed=True):
        n = 10 if label == top_bin_label else 2
        sampled.extend(
            bin_df.sample(min(len(bin_df), n), random_state=SEED).to_dict(orient="records")
        )

    # Add 10 random zero-activation proteins
    zero_df = df[df["activation"] == 0.0]
    if len(zero_df) > 0:
        sampled.extend(
            zero_df.sample(min(len(zero_df), 10), random_state=SEED).to_dict(orient="records")
        )

    # Merge with metadata from annotations_df. Passing the column list keeps
    # "uniprot_id" present even if nothing was sampled, so the merge cannot fail.
    sampled_df = pd.DataFrame(sampled, columns=list(df.columns))
    merged = sampled_df.merge(annotations_df, left_on="uniprot_id", right_on="Entry", how="left")

    feature_datasets[fid] = merged

# Example feature dataset
example_fid = selected_features[0]
feature_datasets[example_fid].head()


num features 10240
Selected 1200 features out of 10240


## Normalize features

In [None]:
import numpy as np

# Stack the per-protein vectors into one matrix of shape (num_proteins, num_features)
X = np.vstack(features_all["features"].to_numpy())

# Largest activation each feature reaches on any protein, shape (num_features,)
max_per_feature = X.max(axis=0)

# Features that never activate have max 0. Divide those columns by 1.0 so they
# stay 0. (A tiny epsilon such as 1e-12 is not safe for every dtype: in half
# precision it underflows to 0 and the division becomes 0/0 = NaN.)
max_safe = np.where(max_per_feature > 0, max_per_feature, 1.0)

# Normalize each feature column by its own maximum
X_norm = X / max_safe

# Save back as one normalized vector per protein, on a copy of the table
features_all = features_all.copy()
features_all["features_norm"] = list(X_norm)
print("Original max activation (feature 0):", X[:,0].max())
print("Normalized max activation (feature 0):", X_norm[:,0].max())

NameError: name 'features_all' is not defined

In [None]:
import numpy as np
import pandas as pd
import random

# Parameters
N_FEATURES = 1200
BINS = np.arange(0, 1.1, 0.1)
SEED = 42  # single seed for feature selection and per-bin sampling

# Randomly (but reproducibly) select feature ids
all_feature_ids = list(range(len(features_all.iloc[0].features)))
print("num features", len(all_feature_ids))
selected_features = random.Random(SEED).sample(
    all_feature_ids, min(N_FEATURES, len(all_feature_ids))
)

print(f"Selected {len(selected_features)} features out of {len(all_feature_ids)}")

# Build dataset for each feature
feature_datasets = {}

# Predefine bin labels; the last one is the top-activation bin, which is sampled more densely
bin_labels = [f"{BINS[i]:.1f}-{BINS[i+1]:.1f}" for i in range(len(BINS)-1)]
top_bin_label = bin_labels[-1]

# Pull the selected columns out of every protein's normalized vector in ONE pass
# over the table (instead of one full Python loop per feature).
# Resulting shape: (num_proteins, len(selected_features)).
selected_activations = np.stack(
    [np.asarray(vec)[selected_features] for vec in features_all["features_norm"]]
)

for col, fid in enumerate(selected_features):
    df = pd.DataFrame({
        "uniprot_id": features_all["uniprot_id"],
        "activation": selected_activations[:, col],
    })

    # Assign bins
    df["bin"] = pd.cut(df["activation"], bins=BINS, labels=bin_labels, include_lowest=True)

    sampled = []

    # Sample proteins per bin from the NON-ZERO activations only. Exact zeros
    # land in the lowest bin (include_lowest=True) and are sampled separately
    # below; leaving them in here would let the same protein appear twice.
    nonzero_df = df[df["activation"] > 0]
    for label, bin_df in nonzero_df.groupby("bin", observed=True):
        n = 10 if label == top_bin_label else 2
        sampled.extend(
            bin_df.sample(min(len(bin_df), n), random_state=SEED).to_dict(orient="records")
        )

    # Add 10 random zero-activation proteins
    zero_df = df[df["activation"] == 0.0]
    if len(zero_df) > 0:
        sampled.extend(
            zero_df.sample(min(len(zero_df), 10), random_state=SEED).to_dict(orient="records")
        )

    # Merge with metadata from annotations_df. Passing the column list keeps
    # "uniprot_id" present even if nothing was sampled, so the merge cannot fail.
    sampled_df = pd.DataFrame(sampled, columns=list(df.columns))
    merged = sampled_df.merge(annotations_df, left_on="uniprot_id", right_on="Entry", how="left")

    feature_datasets[fid] = merged

# Example feature dataset
example_fid = selected_features[0]
feature_datasets[example_fid].head()


In [None]:
# --- Setup (once per notebook) ---
# pip install anthropic python-dotenv
from anthropic import Anthropic
from dotenv import load_dotenv
import os, time, json, math, textwrap
import numpy as np
import pandas as pd
from typing import Dict

# Load .env (expects ANTHROPIC_API_KEY=...) and build the API client from that variable.
# os.getenv returns None when the variable is missing, so a missing key is not caught here.
load_dotenv()
client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

# --- Config ---
# Model used for the annotation calls.
# NOTE(review): the hello-world cell at the top of the notebook uses a different
# model id ("claude-3-7-sonnet-20250219") — confirm which one is intended.
MODEL_NAME = "claude-3-5-sonnet-20240620"
MAX_TOKENS = 800  # enough for description + summary
TEMPERATURE = 0.0 # lowest-variance sampling; NOTE(review): not necessarily bit-for-bit deterministic
CHECKPOINT_EVERY = 50  # write partial results to OUTPUT_PATH after every N processed features
OUTPUT_PATH = "claude_feature_annotations.parquet"  # parquet file for checkpoints and the final table

# Columns to show Claude (customize as you like)
# We'll include what exists; missing columns are auto-dropped
PREFERRED_COLS = [
    # keys/ids
    "uniprot_id", "Entry", "Protein names",
    # size/sequence shape
    "Length",
    # functional annotations
    "EC number", "Active site", "Binding site", "Cofactor", "Disulfide bond",
    "Helix", "Turn", "Beta strand", "Coiled coil",
    "Domain [CC]", "Compositional bias", "Domain [FT]", "Motif", "Region", "Zinc finger",
    # your per-feature fields
    "activation", "bin",
    # optional (only used if present)
    "activated_indices", "activated_aas"
]

# Limit rows and cell width so the table fits comfortably in context
MAX_ROWS = 80   # you can raise/lower if you hit token limits
TRUNCATE_STR_LEN = 120  # truncate long text fields so tables stay compact


def _coerce_and_trim_cols(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Return a compact, all-string copy of `df` for embedding in a prompt.

    Keeps the columns of `cols` that exist in `df` (or every column of `df`
    when none of them exist), keeps at most MAX_ROWS rows, collapses each run
    of whitespace to a single space, and cuts values longer than
    TRUNCATE_STR_LEN characters, marking the cut with an ellipsis.
    """
    def _clip(value: str) -> str:
        # Shorten over-long cell text and flag that it was shortened
        if len(value) > TRUNCATE_STR_LEN:
            return value[:TRUNCATE_STR_LEN] + "…"
        return value

    present = [name for name in cols if name in df.columns]
    # Nothing requested is available: fall back to showing every column
    wanted = present if present else list(df.columns)

    # Cut to the row budget first; the per-cell clean-up below is row-independent
    trimmed = df[wanted].head(MAX_ROWS).copy()

    for name in wanted:
        as_text = trimmed[name].astype(str).str.replace(r"\s+", " ", regex=True)
        trimmed[name] = as_text.apply(_clip)

    return trimmed


# Prompt sent to the model for every feature. The literal "{TABLE}" placeholder is
# swapped for the markdown table by build_prompt (via str.replace, not str.format).
# NOTE(review): the text refers to columns called 'Maximum activation value' and
# 'Amino acids of highest activated indices in protein', while PREFERRED_COLS names
# the per-feature columns "activation" / "bin" (and optionally "activated_aas") —
# confirm the model is expected to map between these names.
# NOTE(review): the quote marks around the first column name are mismatched (’…’).
PROMPT_TEMPLATE = """Generate description and summary
Analyze this protein dataset to determine what predicts the ’Maximum activation value’ and ‘Amino acids of
highest activated indices in protein’ columns. This description should be as concise as possible but sufficient to
predict these two columns on held-out data given only the description and the rest of the protein metadata
provided. The feature could be specific to a protein family, a structural motif, a sequence motif, a functional
role, etc. These WILL be used to predict how much unseen proteins are activated by the feature so only
highlight relevant factors for this.

Focus on:
• Properties of proteins from the metadata that are associated with high vs medium vs low activation.
• Where in the protein sequence activation occurs (in relation to the protein sequence, length, structure,
  or other properties)
• What functional annotations (binding sites, domains, etc.) and amino acids are present at or near the
  activated positions
• This description that will be used to help predict missing activation values should start with:
  “The activation patterns are characterized by:”

Then, in 1 sentence, summarize what biological feature or pattern this neural network activation is detecting.
This concise summary should start with “The feature activates on”.

Protein record:
{TABLE}
"""

def build_prompt(table_df: pd.DataFrame) -> str:
    """Render `table_df` as a markdown table (no index column) and insert it
    into PROMPT_TEMPLATE in place of the "{TABLE}" placeholder."""
    return PROMPT_TEMPLATE.replace("{TABLE}", table_df.to_markdown(index=False))

def call_claude(prompt: str) -> str:
    """Send `prompt` as a single user message and return the reply text.

    Uses the notebook-level `client`, MODEL_NAME, MAX_TOKENS and TEMPERATURE.

    Returns:
        The text of all text content blocks in the response, concatenated in order.

    Raises:
        ValueError: if the response contains no text block at all.
        Any exception raised by the client call is propagated unchanged.
    """
    resp = client.messages.create(
        model=MODEL_NAME,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        messages=[{"role": "user", "content": prompt}],
    )

    # The response body is a list of content blocks. Do not assume the first
    # block exists or is text: keep every text block and skip other block types.
    texts = [
        block.text
        for block in resp.content
        if getattr(block, "type", "text") == "text"
    ]
    if not texts:
        raise ValueError(
            "Claude returned no text content "
            f"(stop_reason={getattr(resp, 'stop_reason', None)!r})"
        )
    return "".join(texts)

# Characters a model commonly puts in front of a line (markdown emphasis,
# headings, bullets, list numbering, quotes). They are ignored when looking
# for the marker phrases, but are left untouched in the returned text.
_LINE_DECORATION = " \t*_#>-`\"'“”‘’•0123456789.)("
_DESCRIPTION_MARKER = "the activation patterns are characterized by"
_SUMMARY_MARKER = "the feature activates on"


def _line_starts_with(line: str, marker: str) -> bool:
    """True if `line`, ignoring leading decoration and case, starts with `marker`."""
    return line.lstrip(_LINE_DECORATION).lower().startswith(marker)


def parse_description_and_summary(text: str) -> dict:
    """
    Best-effort split of a model reply into its two requested parts.

    - "description": the block that begins at the first line starting with
      'The activation patterns are characterized by' and runs up to (not
      including) the next line starting with 'The feature activates on',
      or to the end of the text if there is none.
    - "summary": the first line starting with 'The feature activates on'
      (the whole line is kept; it is not cut to a single sentence).
    - "raw": the full reply, stripped.

    Matching ignores case and leading markdown decoration such as '**', '#',
    '-' or '1.'. Blank lines are dropped. A part that cannot be found is
    returned as an empty string; None or empty input yields three empty strings.
    """
    text = text or ""
    lines = [l.strip() for l in text.splitlines() if l.strip()]

    # Find where the description block starts
    start_idx = next(
        (i for i, l in enumerate(lines) if _line_starts_with(l, _DESCRIPTION_MARKER)),
        None,
    )

    desc = ""
    if start_idx is not None:
        # Collect lines until the summary line (or the end of the reply)
        buff = []
        for l in lines[start_idx:]:
            if _line_starts_with(l, _SUMMARY_MARKER):
                break
            buff.append(l)
        desc = "\n".join(buff).strip()

    # The summary is the first marker line anywhere in the reply
    summ = next((l for l in lines if _line_starts_with(l, _SUMMARY_MARKER)), "")

    return {
        "description": desc,
        "summary": summ,
        "raw": text.strip()
    }

# --- Main loop over feature datasets ---
# Expects: feature_datasets: Dict[int, pd.DataFrame]  (feature id -> sampled protein table)
results_rows = []

processed = 0
# Iterate over the dict itself. (Indexing it with [0] first would look up the
# feature whose id happens to be 0 — a KeyError unless that id was selected —
# and .items() on the resulting DataFrame would walk its columns, not features.)
for fid, df in feature_datasets.items():
    # Build a compact table for the model
    view = _coerce_and_trim_cols(df, PREFERRED_COLS)
    prompt = build_prompt(view)

    # Best-effort: one failed call must not abort the whole run. The error text
    # is kept in the raw_response column so failed features can be found and retried.
    try:
        text = call_claude(prompt)
        parsed = parse_description_and_summary(text)
    except Exception as e:
        parsed = {"description": "", "summary": "", "raw": f"[ERROR] {e}"}

    results_rows.append({
        "feature_id": fid,
        "n_rows_shown": len(view),
        "description": parsed["description"],
        "summary": parsed["summary"],
        "raw_response": parsed["raw"],
    })

    # Periodically write everything gathered so far, so an interrupted run keeps its progress
    processed += 1
    if processed % CHECKPOINT_EVERY == 0:
        pd.DataFrame(results_rows).to_parquet(OUTPUT_PATH, index=False)
        print(f"[checkpoint] saved {processed} → {OUTPUT_PATH}")

# Final save
df_results = pd.DataFrame(results_rows)
df_results.to_parquet(OUTPUT_PATH, index=False)
print(f"[done] {len(df_results)} features → {OUTPUT_PATH}")
df_results.head()


In [None]:
# Single-feature run: take the first feature id (dict insertion order) and its protein table
first_fid = [*feature_datasets][0]
df = feature_datasets[first_fid]

# Shrink the table and wrap it in the prompt
view = _coerce_and_trim_cols(df, PREFERRED_COLS)
prompt = build_prompt(view)

# Query the model and split the reply into description and summary
text = call_claude(prompt)
parsed = parse_description_and_summary(text)

print("Feature ID:", first_fid)
print("Description:\n", parsed["description"])
print("\nSummary:", parsed["summary"])
