In [16]:
from anthropic import Anthropic
from dotenv import load_dotenv
import os
import io, csv, json, math, textwrap
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional

# Connectivity smoke test for the Anthropic API: load credentials, build a
# client, send one short prompt and print the reply.

# Read variables from a local .env file into the process environment.
load_dotenv()

# os.getenv returns None when ANTHROPIC_API_KEY is not set; no check is made here.
api_key = os.getenv("ANTHROPIC_API_KEY")

# Client object used for the request below.
client = Anthropic(api_key=api_key)

# Single-turn request with a small token budget; the model id is hard-coded.
response = client.messages.create(
    model="claude-3-7-sonnet-20250219",
    max_tokens=200,
    messages=[
        {"role": "user", "content": "Say hello"}
    ]
)

# Print the text of the first content block of the reply.
print(response.content[0].text)


Hello! How can I assist you today?


In [2]:
from pathlib import Path
import pandas as pd, torch, os, gc
from interplm.sae.inference import load_sae_from_hf
import matplotlib.pyplot as plt
import numpy as np

# Device string passed to `.to(...)` when the ESM model and SAE are loaded.
DEVICE="cuda"
# NOTE(review): DTYPE is not referenced in the cells shown here — confirm it is still needed.
DTYPE=torch.float16

# Output directory for results, created relative to the working directory if missing.
DATA_DIR = Path("esm_sae_results"); DATA_DIR.mkdir(exist_ok=True)
# NOTE(review): despite the *_DIR suffix, the two names below hold paths to single
# files (.csv / .tsv.gz), and both are absolute, machine-specific paths.
SEQUENCES_DIR = Path("/home/ec2-user/SageMaker/InterPLM/data/uniprot/subset_25k.csv")
# ANNOTATIONS_DIR = Path("uniprotkb_swissprot_annotations.tsv.gz")
ANNOTATIONS_DIR = Path("/home/ec2-user/SageMaker/InterPLM/uniprotkb_swissprot_annotations.tsv.gz")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModel
# Load the ESM-2 650M tokenizer and model; the model is asked to return all
# hidden states, moved to DEVICE and switched to eval mode.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", do_lower_case=False)
model     = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D",
                                        output_hidden_states=True).to(DEVICE).eval()

# The SAE is loaded for one specific (plm_model, plm_layer) pair, so the layer
# chosen here must be the same layer whose hidden states are later fed to it.
plm_model = "esm2-650m"   # identifier passed to load_sae_from_hf
plm_layer = 24            # must equal the layer used when extracting ESM features
sae = load_sae_from_hf(plm_model=plm_model, plm_layer=plm_layer).to(DEVICE).eval()


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
import glob
features_all = pd.read_pickle("features_all.pkl")
features_all.shape


(40000, 6)

In [5]:
features_all.head()

Unnamed: 0,uniprot_id,length,features,max_activation,n_active_features,reconstruction_mse
0,Q9GL23,50,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.002...",1.265625,1876,45.19838
1,Q6GZU6,50,"[0.00023197175, 0.0, 0.0, 0.0, 0.0013056946, 0...",0.843262,2168,13.467114
2,P9WJG6,50,"[0.0, 0.00057144166, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.935059,1740,12.720748
3,P18924,51,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.000...",0.956543,1799,11.394856
4,Q08076,52,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.000...",1.139648,1772,24.694654


In [6]:
annotations_df = pd.read_csv(ANNOTATIONS_DIR, sep="\t", compression="gzip")

In [7]:
annotations_df.head()

Unnamed: 0,Entry,Reviewed,Protein names,Length,Sequence,EC number,Active site,Binding site,Cofactor,Disulfide bond,...,Helix,Turn,Beta strand,Coiled coil,Domain [CC],Compositional bias,Domain [FT],Motif,Region,Zinc finger
0,A0A009IHW8,reviewed,2' cyclic ADP-D-ribose synthase AbTIR (2'cADPR...,269,MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,3.2.2.-; 3.2.2.6,"ACT_SITE 208; /evidence=""ECO:0000255|PROSITE-P...","BINDING 143; /ligand=""NAD(+)""; /ligand_id=""ChE...",,,...,"HELIX 143..145; /evidence=""ECO:0007829|PDB:7UW...","TURN 146..149; /evidence=""ECO:0007829|PDB:7UWG...","STRAND 135..142; /evidence=""ECO:0007829|PDB:7U...","COILED 31..99; /evidence=""ECO:0000255""",DOMAIN: The TIR domain mediates NAD(+) hydrola...,,"DOMAIN 133..266; /note=""TIR""; /evidence=""ECO:0...",,,
1,A0A023I7E1,reviewed,"Glucan endo-1,3-beta-D-glucosidase 1 (Endo-1,3...",796,MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF...,3.2.1.39,"ACT_SITE 500; /evidence=""ECO:0000255|PROSITE-P...","BINDING 504; /ligand=""(1,3-beta-D-glucosyl)n"";...",,,...,"HELIX 42..44; /evidence=""ECO:0007829|PDB:4K35""...","TURN 287..289; /evidence=""ECO:0007829|PDB:4K35...","STRAND 56..58; /evidence=""ECO:0007829|PDB:4K35...",,,,"DOMAIN 31..759; /note=""GH81""; /evidence=""ECO:0...",,"REGION 31..276; /note=""beta-sandwich subdomain...",
2,A0A024B7W1,reviewed,Genome polyprotein [Cleaved into: Capsid prote...,3423,MKNPKKKSGGFRIVNMLKRGVARVSPFGGLKRLPAGLLLGHGPIRM...,2.1.1.56; 2.1.1.57; 2.7.7.48; 3.4.21.91; 3.6.1...,"ACT_SITE 1553; /note=""Charge relay system; for...","BINDING 1696..1703; /ligand=""ATP""; /ligand_id=...",,"DISULFID 350..406; /evidence=""ECO:0000250|UniP...",...,"HELIX 222..225; /evidence=""ECO:0007829|PDB:6CO...","TURN 237..241; /evidence=""ECO:0007829|PDB:6CO8...","STRAND 234..236; /evidence=""ECO:0007829|PDB:6C...",,DOMAIN: [Small envelope protein M]: The transm...,,"DOMAIN 1503..1680; /note=""Peptidase S7""; /evid...","MOTIF 1787..1790; /note=""DEAH box""; /evidence=...","REGION 1..25; /note=""Disordered""; /evidence=""E...",
3,A0A024RXP8,reviewed,"Exoglucanase 1 (EC 3.2.1.91) (1,4-beta-cellobi...",514,MYRKLAVISAFLATARAQSACTLQSETHPPLTWQKCSSGGTCTQQT...,3.2.1.91,"ACT_SITE 229; /note=""Nucleophile""; /evidence=""...",,,"DISULFID 21..89; /evidence=""ECO:0000250|UniPro...",...,,,,,DOMAIN: The enzyme consists of two functional ...,"COMPBIAS 401..437; /note=""Polar residues""; /ev...","DOMAIN 478..514; /note=""CBM1""; /evidence=""ECO:...",,"REGION 18..453; /note=""Catalytic""; /evidence=""...",
4,A0A024SC78,reviewed,Cutinase (EC 3.1.1.74),248,MRSLAILTTLLAGHAFAYPKPAPQSVNRRDWPSINEFLSELAKVMP...,3.1.1.74,"ACT_SITE 164; /note=""Nucleophile""; /evidence=""...",,,"DISULFID 55..91; /evidence=""ECO:0000269|PubMed...",...,"HELIX 51..69; /evidence=""ECO:0007829|PDB:4PSC""...","TURN 94..100; /evidence=""ECO:0007829|PDB:4PSC""...","STRAND 48..50; /evidence=""ECO:0007829|PDB:4PSC...",,"DOMAIN: In contract to classical cutinases, po...",,,,"REGION 31..70; /note=""Lid covering the active ...",


In [8]:
# import numpy as np
# import pandas as pd
# import random

# # Parameters
# N_FEATURES = 1200
# BINS = np.arange(0, 1.1, 0.1)

# # Randomly select feature ids
# all_feature_ids = list(range(len(features_all.iloc[0].features)))
# print("num features", len(all_feature_ids))
# selected_features = random.sample(all_feature_ids, N_FEATURES)

# print(f"Selected {len(selected_features)} features out of {len(all_feature_ids)}")

# # Build dataset for each feature
# feature_datasets = {}

# # Predefine bin labels
# bin_labels = [f"{BINS[i]:.1f}-{BINS[i+1]:.1f}" for i in range(len(BINS)-1)]

# for fid in selected_features:
#     # Extract activations for this feature
#     activations = [f[fid] for f in features_all["features"]]
#     df = pd.DataFrame({
#         "uniprot_id": features_all["uniprot_id"],
#         "activation": activations
#     })

#     # Assign bins
#     df["bin"] = pd.cut(df["activation"], bins=BINS, labels=bin_labels, include_lowest=True)

#     sampled = []

#     # Sample proteins per bin
#     for b in df["bin"].dropna().unique():
#         bin_df = df[df["bin"] == b]
#         n = 10 if b == "0.9-1.0" else 2
#         sampled.extend(bin_df.sample(min(len(bin_df), n), random_state=42).to_dict(orient="records"))

#     # Add 10 random zero-activation proteins 
#     zero_df = df[df["activation"] == 0.0]
#     if len(zero_df) > 0:
#         sampled.extend(zero_df.sample(min(len(zero_df), 10), random_state=42).to_dict(orient="records"))

#     # Merge with metadata from annotations_df
#     sampled_df = pd.DataFrame(sampled)
#     merged = sampled_df.merge(annotations_df, left_on="uniprot_id", right_on="Entry", how="left")

#     feature_datasets[fid] = merged

# # Example feature dataset
# example_fid = selected_features[0]
# feature_datasets[example_fid].head()


## Normalize features

In [8]:
import numpy as np

# Stack the per-protein feature vectors into one matrix of shape [num_proteins, num_features]
X = np.vstack(features_all["features"].to_numpy())  # (N, F)

# Largest activation of every feature over all proteins
max_per_feature = np.max(X, axis=0)  # (F,)

# Features whose maximum is not positive get a tiny divisor instead, so the
# division below never divides by zero
eps = 1e-12
max_safe = np.where(max_per_feature > 0, max_per_feature, eps)

# Scale each feature column by its (safe) maximum
X_norm = np.divide(X, max_safe)

# Attach the normalised rows as a new column on a copy of the frame
features_all = features_all.copy()
features_all["features_norm"] = list(X_norm)

# Sanity check on the first feature
print("Original max activation (feature 0):", X[:, 0].max())
print("Normalized max activation (feature 0):", X_norm[:, 0].max())

Original max activation (feature 0): 0.02611415
Normalized max activation (feature 0): 1.0


In [10]:
import numpy as np
import pandas as pd
import random

# Parameters
N_FEATURES = 10240
BINS = np.arange(0, 1.1, 0.1)

# Randomly select feature ids
# all_feature_ids = list(range(len(features_all.iloc[0].features)))
# print("num features", len(all_feature_ids))
# selected_features = random.sample(all_feature_ids, N_FEATURES)

# print(f"Selected {len(selected_features)} features out of {len(all_feature_ids)}")

# # Build dataset for each feature
# feature_datasets = {}

# # Predefine bin labels
# bin_labels = [f"{BINS[i]:.1f}-{BINS[i+1]:.1f}" for i in range(len(BINS)-1)]

# for fid in selected_features:
#     # Extract activations for this feature
#     activations = [f[fid] for f in features_all["features_norm"]]
#     df = pd.DataFrame({
#         "uniprot_id": features_all["uniprot_id"],
#         "activation": activations
#     })

#     # Assign bins
#     df["bin"] = pd.cut(df["activation"], bins=BINS, labels=bin_labels, include_lowest=True)

#     sampled = []

#     # Sample proteins per bin
#     for b in df["bin"].dropna().unique():
#         bin_df = df[df["bin"] == b]
#         n = 10 if b == "0.9-1.0" else 2
#         sampled.extend(bin_df.sample(min(len(bin_df), n), random_state=42).to_dict(orient="records"))

#     # Add 10 random zero-activation proteins 
#     zero_df = df[df["activation"] == 0.0]
#     if len(zero_df) > 0:
#         sampled.extend(zero_df.sample(min(len(zero_df), 10), random_state=42).to_dict(orient="records"))

#     # Merge with metadata from annotations_df
#     sampled_df = pd.DataFrame(sampled)
#     merged = sampled_df.merge(annotations_df, left_on="uniprot_id", right_on="Entry", how="left")

#     feature_datasets[fid] = merged

# # Example feature dataset
# example_fid = selected_features[0]
# feature_datasets[example_fid].head()


In [11]:
# import pickle
feature_datasets
# # Suppose feature_datasets is your dict of DataFrames
# with open("feature_datasets.pkl", "wb") as f:
#     pickle.dump(feature_datasets, f)

In [18]:
import pickle
with open("feature_datasets.pkl", "rb") as f:
    feature_datasets = pickle.load(f)


In [24]:
import random
import pickle

# Seed the global RNG so the same keys are drawn on every run
random.seed(42)

# Candidate feature ids are the keys of the dictionary
all_keys = list(feature_datasets)

# Draw 1200 distinct keys (random.sample never repeats an element)
sample_keys = random.sample(all_keys, 1200)

# Keep only the sampled entries
subset = {key: feature_datasets[key] for key in sample_keys}

# Persist the reduced dictionary
with open("feature_datasets_subset.pkl", "wb") as out_file:
    pickle.dump(subset, out_file)

## Add Amino Acid Indices and activations for better LLM annotations

In [9]:
from utils import extract_esm_features_batch

@torch.no_grad()
def extract_sae_features(hidden_states: torch.Tensor, sae):
    """
    Encode hidden states with a sparse autoencoder and reconstruct them.

    Args
    ----
    hidden_states : torch.Tensor
        Either a batched tensor or a 2-D tensor; a 2-D input gets a leading
        batch dimension of size 1 added. Intended layout is [B, L, d] / [L, d]
        (batch, sequence length, embedding width).
    sae :
        Object exposing ``encode(x)`` and ``decode(z)``. It is called on a
        float32 copy of the input.

    Returns
    -------
    sae_features : torch.Tensor
        Output of ``sae.encode`` (expected [B, L, F], F = number of SAE features).
    recon : torch.Tensor
        Output of ``sae.decode`` applied to ``sae_features`` (expected [B, L, d]).
    error : torch.Tensor
        Residual ``hidden_states - recon``, computed from the (batched) input in
        its original dtype.
    """
    # Promote an un-batched [L, d] input to [1, L, d]
    batched = hidden_states.unsqueeze(0) if hidden_states.dim() == 2 else hidden_states

    # The SAE is always fed float32, whatever precision the language model ran in
    latents = sae.encode(batched.to(torch.float32))
    reconstruction = sae.decode(latents)

    return latents, reconstruction, batched - reconstruction

#Config for extracting activated positions

TOP_K = 8 # How many positions to record per protein
MIN_ACT = 0.0 #Only consider positions with activation > MIN_ACT
BATCH_SIZE = 16 #For ESM/SAE Inference


def _batched(iterable, n):
    """Yield Successive n-sized chunks from iterable
    """
    it = list(iterable)
    for i in range(0, len(it), n):
        yield it[i:i+n]

import numpy as np
import torch
from typing import List, Tuple, Optional, Literal

TOP_K = 8
MIN_ACT = 0.0
BATCH_SIZE = 16

def _batched(iterable, n):
    it = list(iterable)
    for i in range(0, len(it), n):
        yield it[i:i+n]

def _normalize_1d(
    x: np.ndarray,
    mode: Literal["seq_max","feature_global_max","zscore","none"] = "seq_max",
    global_max: Optional[float] = None,
    eps: float = 1e-8
) -> np.ndarray:
    """
    Normalize a 1D activation vector x (valid positions only).
    """
    if mode == "none":
        return x.copy()

    if mode == "feature_global_max":
        if global_max is None or global_max <= eps:
            # fallback to seq_max if global not available/safe
            mode = "seq_max"
        else:
            return x / (global_max + eps)

    if mode == "seq_max":
        m = np.max(x) if x.size else 0.0
        return x / (m + eps)

    if mode == "zscore":
        mu = float(np.mean(x)) if x.size else 0.0
        sd = float(np.std(x)) if x.size else 0.0
        return (x - mu) / (sd + eps)

    # Fallback (shouldn't hit)
    return x.copy()

@torch.no_grad()
def compute_activated_positions_for_feature(
    fid: int,
    seqs: List[str],
    batch_size: int = BATCH_SIZE,
    max_per_feature: Optional[np.ndarray] = None,
    norm_mode: Literal["seq_max","feature_global_max","zscore","none"] = "seq_max",
    top_k: int = TOP_K,
    min_act: float = MIN_ACT,
    device: str = "cuda"
) -> Tuple[List[List[int]], List[List[str]], List[List[float]], List[List[float]]]:
    """
    For a list of sequences and a single SAE feature id (fid),
    return per-sequence top-K activated residue indices, AA identities,
    normalized scores (per selected position), and raw scores.

    This version reads the notebook-level globals ``plm_layer``, ``model``,
    ``tokenizer`` and ``sae`` rather than taking them as arguments.

    Parameters
    ----------
    fid : int
        Index of the SAE feature to inspect (last axis of the SAE output).
    seqs : List[str]
        Protein sequences; processed in chunks of ``batch_size``.
    batch_size : int
        Number of sequences per ESM/SAE forward pass.
    max_per_feature : Optional[np.ndarray]
        Per-feature maxima; only consulted when ``norm_mode`` is
        ``"feature_global_max"`` and ``fid`` is a valid index into it.
    norm_mode : str
        Normalisation mode forwarded to ``_normalize_1d``.
    top_k : int
        Maximum number of positions reported per sequence.
    min_act : float
        Only positions with activation strictly greater than this are considered.
    device : str
        Forwarded to ``extract_esm_features_batch``.

    Returns
    -------
    all_indices : List[List[int]]
        Per-sequence Top-K indices (0-based).
    all_aas : List[List[str]]
        Per-sequence amino acids at those indices.
    norm_vals : List[List[float]]
        Per-sequence normalized activations aligned with Top-K order.
    raw_vals : List[List[float]]
        Per-sequence raw activations aligned with Top-K order.

    Notes
    -----
    NOTE(review): indices are positions in the token axis returned by
    ``extract_esm_features_batch``. If that helper keeps tokenizer special
    tokens (e.g. a leading CLS token), indices would be shifted relative to
    ``seq`` and the reported amino acids would be off by one — confirm against
    the helper's implementation.
    """
    all_indices: List[List[int]] = []
    all_aas: List[List[str]] = []
    norm_vals: List[List[float]] = []
    raw_vals: List[List[float]] = []

    # Optional global max for this feature (for 'feature_global_max' mode);
    # a non-finite or out-of-range entry leaves it as None.
    global_max = None
    if norm_mode == "feature_global_max" and max_per_feature is not None:
        if 0 <= fid < len(max_per_feature):
            gm = float(max_per_feature[fid])
            global_max = gm if np.isfinite(gm) else None

    for chunk in _batched(seqs, batch_size):
        # ESM -> SAE: token representations and mask
        token_reps, attn_mask = extract_esm_features_batch(
            chunk, layer_sel=plm_layer, device=device, model=model, tokenizer=tokenizer
        )  # expected: token_reps [B, L, H], attn_mask [B, L] bool — TODO confirm against utils
        sae_feats, _, _ = extract_sae_features(token_reps, sae)  # [B, L, F]

        # Select feature channel -> [B, L] on CPU
        feat_act = sae_feats[..., fid].float().cpu()
        mask = attn_mask.cpu()  # [B, L] bool

        for seq, act_row, m in zip(chunk, feat_act, mask):
            # Number of unmasked positions; the first L entries of the row are
            # taken as the valid ones (assumes padding sits at the end — TODO confirm).
            L = int(m.sum().item())  # valid residues
            act_valid = act_row[:L].numpy() if L > 0 else np.array([], dtype=np.float32)

            # positions above threshold
            valid_idx = np.where(act_valid > min_act)[0]
            if valid_idx.size == 0:
                # Nothing fires for this sequence: record empty lists so the
                # outputs stay aligned one-to-one with `seqs`.
                all_indices.append([])
                all_aas.append([])
                norm_vals.append([])
                raw_vals.append([])
                continue

            # sort by activation desc within valid subset and take top_k
            order = np.argsort(-act_valid[valid_idx])
            chosen_local = valid_idx[order[:top_k]].tolist()

            # raw values for the chosen positions
            chosen_raw = act_valid[chosen_local]

            # normalization (done on the full valid slice, then gather chosen)
            norm_full = _normalize_1d(
                act_valid, mode=norm_mode, global_max=global_max
            )
            chosen_norm = norm_full[chosen_local]

            # map to AAs (defensive indexing): an index past the end of the
            # string is reported as "X" instead of raising.
            aas = [seq[i] if i < len(seq) else "X" for i in chosen_local]

            all_indices.append(chosen_local)
            all_aas.append(aas)
            norm_vals.append([float(v) for v in chosen_norm])
            raw_vals.append([float(v) for v in chosen_raw])

        # free tensors early
        del token_reps, sae_feats, feat_act, attn_mask
        torch.cuda.empty_cache()

    return all_indices, all_aas, norm_vals, raw_vals


## Using ray so this doesn't take an hour and a half

## Integrate with feature_datasets dict

In [29]:

# ----------------------------
# Core compute function (uses passed-in model/tokenizer/sae) — same logic you posted
# ----------------------------
@torch.no_grad()
def _compute_activated_positions_for_feature_wrapped(
    fid: int,
    seqs: List[str],
    *,
    batch_size: int,
    max_per_feature: Optional[np.ndarray] = None,
    norm_mode: Literal["seq_max", "feature_global_max", "zscore", "none"] = "seq_max",
    top_k: int = 5,
    min_act: float = 0.0,
    device: str = "cuda:0",
    plm_layer: int = 24,
    model=None,
    tokenizer=None,
    sae=None,
) -> Tuple[List[List[int]], List[List[str]], List[List[float]], List[List[float]]]:
    """
    Find, for one SAE feature, the most strongly activated positions of each sequence.

    Unlike the notebook-global variant, the language model, tokenizer and SAE
    are passed in explicitly so the function can run inside a worker process.

    Parameters
    ----------
    fid : index of the SAE feature (last axis of the SAE output).
    seqs : sequences to analyse, processed in chunks of ``batch_size``.
    max_per_feature : per-feature maxima, used only for ``"feature_global_max"``.
    norm_mode : normalisation mode forwarded to ``_normalize_1d``.
    top_k : maximum number of positions kept per sequence.
    min_act : positions must exceed this activation (strictly) to be kept.
    device, plm_layer, model, tokenizer : forwarded to ``extract_esm_features_batch``.
    sae : sparse autoencoder handed to ``extract_sae_features``.

    Returns
    -------
    Four lists, each with one entry per input sequence and aligned with each
    other: chosen indices (strongest first), the characters of the sequence at
    those indices ("X" when an index is past the end of the string), the
    normalised activations and the raw activations. Sequences with no position
    above ``min_act`` contribute four empty lists.

    NOTE(review): indices refer to the token axis returned by
    ``extract_esm_features_batch``; if special tokens are kept there, they are
    offset relative to the sequence string — confirm against that helper.
    """
    # Resolve the feature-wide maximum only when that normalisation is requested
    global_max = None
    wants_global = norm_mode == "feature_global_max" and max_per_feature is not None
    if wants_global and 0 <= fid < len(max_per_feature):
        candidate = float(max_per_feature[fid])
        if np.isfinite(candidate):
            global_max = candidate

    all_indices: List[List[int]] = []
    all_aas: List[List[str]] = []
    norm_vals: List[List[float]] = []
    raw_vals: List[List[float]] = []

    for chunk in _batched(seqs, batch_size):
        token_reps, attn_mask = extract_esm_features_batch(
            chunk, layer_sel=plm_layer, device=device, model=model, tokenizer=tokenizer
        )
        sae_feats, _, _ = extract_sae_features(token_reps, sae)

        # Keep only the requested feature channel, on the CPU in float32
        feat_act = sae_feats[..., fid].float().cpu()
        mask = attn_mask.cpu()

        for seq, act_row, row_mask in zip(chunk, feat_act, mask):
            n_valid = int(row_mask.sum().item())
            if n_valid > 0:
                act_valid = act_row[:n_valid].numpy()
            else:
                act_valid = np.array([], dtype=np.float32)

            above = np.where(act_valid > min_act)[0]
            if above.size == 0:
                # Nothing fires: keep the four outputs aligned with empty entries
                picked, picked_aas, picked_norm, picked_raw = [], [], [], []
            else:
                # Rank the surviving positions by activation, strongest first
                ranking = np.argsort(-act_valid[above])
                picked = above[ranking[:top_k]].tolist()

                raw_sel = act_valid[picked]
                # Normalise over the whole valid slice, then gather the picks
                norm_sel = _normalize_1d(act_valid, mode=norm_mode, global_max=global_max)[picked]

                picked_aas = [seq[pos] if pos < len(seq) else "X" for pos in picked]
                picked_norm = [float(value) for value in norm_sel]
                picked_raw = [float(value) for value in raw_sel]

            all_indices.append(picked)
            all_aas.append(picked_aas)
            norm_vals.append(picked_norm)
            raw_vals.append(picked_raw)

        # Release the per-batch tensors before the next forward pass
        del token_reps, sae_feats, feat_act, attn_mask
        torch.cuda.empty_cache()

    return all_indices, all_aas, norm_vals, raw_vals

# ray_feature_activation.py  (only the relevant parts shown)
import os
import ray
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Literal, Tuple
from transformers import AutoTokenizer, AutoModel
from interplm.sae.inference import load_sae_from_hf
from tqdm.auto import tqdm
from ray.util.queue import Queue

from utils import extract_esm_features_batch, _batched, _normalize_1d, extract_sae_features

# Output directory (created with parents if needed) and cache/allocator settings.
# setdefault leaves any value already present in the environment untouched.
# NOTE(review): transformers is imported earlier in this cell, before HF_HOME /
# TRANSFORMERS_CACHE are set — confirm the cache location takes effect where intended.
DATA_DIR = Path("esm_sae_results"); DATA_DIR.mkdir(exist_ok=True, parents=True)
os.environ.setdefault("HF_HOME", str(DATA_DIR / "hf_cache"))
os.environ.setdefault("TRANSFORMERS_CACHE", str(DATA_DIR / "hf_cache"))
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128")

# Tear down any Ray instance left over from a previous run of this cell, then start fresh.
ray.shutdown()
# Optional: silence worker stdout spam
ray.init(ignore_reinit_error=True, log_to_driver=False)

# One worker per GPU Ray reports, capped at 8 and never fewer than 1.
# NOTE(review): with zero GPUs WORLD_SIZE is still 1, while each FeatureWorker
# requests num_gpus=1 — presumably that actor can then never be scheduled; confirm.
N_GPUS = int(ray.available_resources().get("GPU", 0))
WORLD_SIZE = max(1, min(8, N_GPUS))

# --- write inputs (as you already do) ---
# The workers read their inputs from these two pickles instead of receiving
# the objects directly. FEATURE_PICKLE lives under DATA_DIR, which is a different
# file from the "feature_datasets.pkl" in the working directory loaded earlier.
FEATURE_PICKLE = DATA_DIR / "feature_datasets.pkl"
META_PICKLE    = DATA_DIR / "feature_meta.pkl"
# "none" means the stored per-position scores are not rescaled.
NORM_MODE = "none"
# Must be the layer the SAE was trained on.
ESM_LAYER_SEL = 24
pd.to_pickle(feature_datasets, FEATURE_PICKLE)
pd.to_pickle(
    dict(
        all_fids=list(feature_datasets.keys()),
        norm_mode=NORM_MODE,
        top_k=TOP_K,
        min_act=MIN_ACT,
        batch_size=BATCH_SIZE,
        # falls back to layer 24 when ESM_LAYER_SEL is not an int
        plm_layer=(ESM_LAYER_SEL if isinstance(ESM_LAYER_SEL, int) else 24),
        # max_safe comes from the normalisation cell (per-feature maxima with a
        # tiny epsilon substituted where the maximum is not positive)
        max_per_feature=np.asarray(max_safe) if max_safe is not None else None,
    ),
    META_PICKLE,
)


# ---------- Actor: NO tqdm here; report progress to the driver Queue ----------
@ray.remote(num_gpus=1, num_cpus=0)
class FeatureWorker:
    """Ray actor that owns one GPU, one ESM-2 model copy and one SAE.

    Each actor processes every ``world_size``-th feature id (its "shard"),
    writes one pickle per feature id and reports progress to the driver through
    a shared queue instead of drawing its own progress bar.
    """

    def __init__(self, progress: Queue, amp_dtype: torch.dtype = torch.float16, plm_layer: int = 24):
        """Load tokenizer, language model and SAE onto this actor's GPU.

        Parameters
        ----------
        progress : queue on which ``1`` is put after each finished feature id.
        amp_dtype : dtype the language model weights are loaded in.
        plm_layer : layer the SAE is loaded for.
        """
        self.progress = progress
        # The actor addresses its GPU as cuda:0.
        self.device = torch.device("cuda:0")
        self.dtype  = amp_dtype
        self.plm_layer = plm_layer

        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", do_lower_case=False)
        self.model = AutoModel.from_pretrained(
            "facebook/esm2_t33_650M_UR50D",
            output_hidden_states=True,
            torch_dtype=self.dtype,
            low_cpu_mem_usage=True,
        ).to(self.device).eval()

        self.sae = load_sae_from_hf(plm_model="esm2-650m", plm_layer=self.plm_layer).to(self.device).eval()

    def process_shard(
        self,
        rank: int,
        world_size: int,
        feature_pickle_path: str,
        meta_pickle_path: str,
        out_prefix: str = "feature_acts",
    ) -> List[str]:
        """Annotate this actor's share of the feature datasets with activated positions.

        Parameters
        ----------
        rank, world_size : this actor handles ``all_fids[rank::world_size]``.
        feature_pickle_path : pickle holding the ``{fid: DataFrame}`` mapping.
        meta_pickle_path : pickle holding the run settings (``all_fids``,
            ``norm_mode``, ``top_k``, ``min_act``, ``batch_size``, ``plm_layer``,
            ``max_per_feature``).
        out_prefix : file-name prefix for everything written under ``DATA_DIR``.

        Returns
        -------
        Paths of the per-feature pickles followed by the path of the shard
        summary pickle.

        Raises
        ------
        KeyError : a feature DataFrame has no ``Sequence`` column.
        ValueError : the layer in the metadata differs from the layer the SAE
            of this actor was loaded for.
        """
        fd: Dict[int, pd.DataFrame] = pd.read_pickle(feature_pickle_path)
        meta = pd.read_pickle(meta_pickle_path)

        all_fids: List[int] = meta["all_fids"]
        norm_mode: str      = meta["norm_mode"]
        top_k: int          = meta["top_k"]
        min_act: float      = meta["min_act"]
        batch_size: int     = meta["batch_size"]
        plm_layer: int      = meta["plm_layer"]
        max_per_feature     = meta["max_per_feature"]

        # The SAE was loaded for self.plm_layer; feeding it hidden states from a
        # different layer would give meaningless activations, so fail loudly.
        if plm_layer != self.plm_layer:
            raise ValueError(
                f"meta plm_layer={plm_layer} does not match the SAE layer "
                f"{self.plm_layer} this worker was initialised with."
            )

        # Round-robin split of the feature ids across actors.
        shard_fids = all_fids[rank::world_size]
        paths = []

        for fid in shard_fids:
            df = fd[fid].copy()
            if "Sequence" not in df.columns:
                raise KeyError("Expected a 'Sequence' column in feature_datasets[fid].")

            # Fill missing values BEFORE casting to str: astype(str) turns NaN
            # into the literal text "nan", which fillna can no longer detect and
            # which would then be embedded as if it were a protein sequence.
            seqs = df["Sequence"].fillna("").astype(str).tolist()

            # raw_vals is returned by the helper but not stored.
            idx_lists, aa_lists, norm_vals, raw_vals = _compute_activated_positions_for_feature_wrapped(
                fid=fid,
                seqs=seqs,
                batch_size=batch_size,
                max_per_feature=max_per_feature,
                norm_mode=norm_mode,
                top_k=top_k,
                min_act=min_act,
                device=str(self.device),
                plm_layer=plm_layer,
                model=self.model,
                tokenizer=self.tokenizer,
                sae=self.sae,
            )

            df["activated_indices"] = idx_lists
            df["activated_aas"] = aa_lists
            # The column name is fixed; what it holds depends on norm_mode.
            df["seq_max_activation_norm"] = norm_vals
            # Comma-joined string versions of the three list columns.
            df["activated_indices_str"] = df["activated_indices"].apply(lambda xs: ",".join(map(str, xs)) if xs else "")
            df["activated_aas_str"] = df["activated_aas"].apply(lambda xs: ",".join(xs) if xs else "")
            df["seq_max_activation_norm_str"] = df["seq_max_activation_norm"].apply(
                lambda xs: ",".join(map(str, xs)) if xs else ""
            )

            out_path = DATA_DIR / f"{out_prefix}_fid{fid}_rank{rank}.pkl"
            df.to_pickle(out_path)
            paths.append(str(out_path))

            # tell the driver we finished one FID
            self.progress.put(1)

        # Summary file listing everything this shard wrote.
        shard_bundle = DATA_DIR / f"{out_prefix}_rank{rank}.final.pkl"
        pd.to_pickle({"rank": rank, "paths": paths}, shard_bundle)
        return [str(p) for p in paths] + [str(shard_bundle)]


# ---------- Kick off workers; single tqdm on the driver ----------
def _drain_progress(queue) -> int:
    """Pop every queued progress increment without blocking and return their sum."""
    collected = 0
    while True:
        try:
            collected += queue.get_nowait()
        except Exception:
            # An empty queue (or any queue failure) ends this drain pass
            return collected


# Shared queue on which the workers report finished feature ids
progress_q = Queue(maxsize=100000)

# One actor per shard
actors = [
    FeatureWorker.remote(
        progress=progress_q,
        amp_dtype=torch.float16,
        plm_layer=(ESM_LAYER_SEL if isinstance(ESM_LAYER_SEL, int) else 24),
    )
    for _ in range(WORLD_SIZE)
]

# Launch every shard; the actor's position in the list is its rank
futs = [
    actor.process_shard.remote(
        rank=rank,
        world_size=WORLD_SIZE,
        feature_pickle_path=str(FEATURE_PICKLE),
        meta_pickle_path=str(META_PICKLE),
        out_prefix="feature_acts",
    )
    for rank, actor in enumerate(actors)
]

# The bar counts feature ids across all shards, not per shard
total_fids = len(pd.read_pickle(META_PICKLE)["all_fids"])
pbar = tqdm(total=total_fids, desc="Features (overall)", leave=True)

pending = set(futs)
while pending:
    # Move the bar by whatever the workers reported since the last pass
    drained = _drain_progress(progress_q)
    if drained:
        pbar.update(drained)

    # Short wait so the loop keeps polling the queue while tasks are running
    done, not_done = ray.wait(list(pending), num_returns=1, timeout=0.25)
    for finished in done:
        pending.discard(finished)

# Pick up increments that were queued after the last pass of the loop
drained = _drain_progress(progress_q)
if drained:
    pbar.update(drained)

pbar.close()

out_paths_per_rank = ray.get(futs)
flat_paths = [p for sub in out_paths_per_rank for p in sub]
print(f"[done] wrote {len(flat_paths)} files")



2025-09-14 02:44:01,309	INFO worker.py:1951 -- Started a local Ray instance.
Features (overall): 100%|██████████| 10240/10240 [09:27<00:00, 18.04it/s]

[done] wrote 10248 files





In [30]:
import re
def merge_outputs_into_feature_datasets(feature_datasets: Dict[int, pd.DataFrame], paths: List[str]) -> None:
    """
    Load every per-feature pickle listed in ``paths`` back into ``feature_datasets``.

    Only files whose base name contains "_fid" are considered; the feature id is
    the first run of digits following "fid" in that name. The dictionary is
    modified in place and nothing is returned.
    """
    fid_pattern = re.compile(r"fid(\d+)")
    per_fid_files = list(filter(lambda path: "_fid" in os.path.basename(path), paths))

    for path in tqdm(per_fid_files, desc="Merging results into feature_datasets"):
        match = fid_pattern.search(os.path.basename(path))
        if match:
            feature_datasets[int(match.group(1))] = pd.read_pickle(path)

In [31]:
out_paths_per_rank = ray.get(futs)
flat_paths = [p for sub in out_paths_per_rank for p in sub]
merge_outputs_into_feature_datasets(feature_datasets, flat_paths)
print("Updated feature_datasets in place.")

Merging results into feature_datasets: 100%|██████████| 10240/10240 [00:08<00:00, 1185.69it/s]

Updated feature_datasets in place.





In [21]:
import pandas as pd
from pathlib import Path

FINAL_PATH = Path("per_amino_acid_feature_datasets_with_annotations.pkl")

# # Save
# pd.to_pickle(feature_datasets, FINAL_PATH)  # uses highest protocol by default

# Load later
feature_datasets = pd.read_pickle(FINAL_PATH)


In [22]:
list(feature_datasets)[:5]  # first few feature IDs (keys)

{fid: df.shape for fid, df in list(feature_datasets.items())[:5]}  # shapes

{5590: (21, 34), 7300: (26, 34), 595: (29, 34), 2669: (19, 34), 6281: (23, 34)}

In [23]:
feature_datasets[5590]

Unnamed: 0,uniprot_id,activation,bin,Entry,Reviewed,Protein names,Length,Sequence,EC number,Active site,...,Domain [FT],Motif,Region,Zinc finger,activated_indices,activated_aas,seq_max_activation_norm,activated_indices_str,activated_aas_str,seq_max_activation_norm_str
0,A6LSD0,0.0,0.0-0.1,A6LSD0,reviewed,Acetyl-coenzyme A carboxylase carboxyl transfe...,287,MLKDLFVKRQYATVKSSTLKKSISEEKPNIPSGMWEKCDKCNSMIY...,2.1.3.15,,...,"DOMAIN 34..287; /note=""CoA carboxyltransferase...",,,"ZN_FING 38..60; /note=""C4-type""; /evidence=""EC...",[],[],[],,,
1,P46005,0.0,0.0-0.1,P46005,reviewed,Outer membrane usher protein AggC,842,MKTSSFIIVILLCFRIENVIAHTFSFDASLLNHGSGGIDLTLLEKG...,,,...,,,,,[],[],[],,,
2,Q5WEW8,0.104561,0.1-0.2,Q5WEW8,reviewed,Arginine biosynthesis bifunctional protein Arg...,408,MLTKQTTGQAWKQIKGSITDVKGFTTAGAHCGLKRKRLDIGAIFCD...,2.3.1.1; 2.3.1.35,"ACT_SITE 195; /note=""Nucleophile""; /evidence=""...",...,,,,,"[257, 258, 256, 277, 247, 254, 274, 255]","[W, A, D, K, N, H, K, P]","[0.07567556202411652, 0.048477932810783386, 0....",257258256277247254274255,"W,A,D,K,N,H,K,P","0.07567556202411652,0.048477932810783386,0.044..."
3,Q9Y5C1,0.116409,0.1-0.2,Q9Y5C1,reviewed,Angiopoietin-related protein 3 (Angiopoietin-5...,460,MFTIKLLLFIVPLVISSRIDQDNSSFDSLSPEPKSRFAMLDDVKIL...,,,...,"DOMAIN 237..455; /note=""Fibrinogen C-terminal""...",,"REGION 17..207; /note=""Sufficient to inhibit L...",,"[393, 413, 396, 392, 260, 394, 422, 385]","[C, N, G, N, A, P, K, H]","[0.04759537801146507, 0.04610564932227135, 0.0...",393413396392260394422385,"C,N,G,N,A,P,K,H","0.04759537801146507,0.04610564932227135,0.0401..."
4,Q06303,0.387987,0.3-0.4,Q06303,reviewed,Aerolysin-4 (Hemolysin-4),492,MKKLKITGLSLIISGLLMAQAQAAEPVYPDQLRLFSLGQEVCGDKY...,,,...,,,"REGION 68..84; /note=""Interaction with host N-...",,"[77, 69, 62, 68, 66, 59, 58, 146]","[V, I, G, Q, Q, N, S, G]","[0.08769077807664871, 0.08646269887685776, 0.0...",77696268665958146,"V,I,G,Q,Q,N,S,G","0.08769077807664871,0.08646269887685776,0.0838..."
5,P09166,0.311142,0.3-0.4,P09166,reviewed,Aerolysin,492,MKALKITGLSLIISATLAAQTNAAEPIYPDQLRLFSLGEDVCGTDY...,,,...,,,"REGION 68..84; /note=""Interaction with host N-...",,"[69, 60, 77, 68, 62, 61, 59, 78]","[I, I, V, Q, A, V, N, I]","[0.08975762128829956, 0.07494804263114929, 0.0...",6960776862615978,"I,I,V,Q,A,V,N,I","0.08975762128829956,0.07494804263114929,0.0744..."
6,P15465,0.218435,0.2-0.3,P15465,reviewed,Albumin-1 (WBA-1),175,ADDPVYDAEGNKLVNRGKYTIVSFSDGAGIDVVATGNENPEDPLSI...,,,...,,,,,"[55, 51, 63, 62, 65, 53, 52, 64]","[A, N, K, D, P, M, I, T]","[0.03894788399338722, 0.03494536876678467, 0.0...",5551636265535264,"A,N,K,D,P,M,I,T","0.03894788399338722,0.03494536876678467,0.0346..."
7,Q91F58,0.267097,0.2-0.3,Q91F58,reviewed,Putative MSV199 domain-containing protein 468L,376,MEMATKKCNIFGVDSIGEPEGVVKKALDESLSLLDIFKFIEITNFD...,,,...,,,,,"[106, 109, 113, 110, 108, 167, 104, 112]","[N, E, S, L, I, R, T, P]","[0.0693620815873146, 0.06051858887076378, 0.05...",106109113110108167104112,"N,E,S,L,I,R,T,P","0.0693620815873146,0.06051858887076378,0.05704..."
8,P09167,0.509103,0.5-0.6,P09167,reviewed,Aerolysin,493,MQKIKLTGLSLIISGLLMAQAQAAEPVYPDQLRLFSLGQGVCGDKY...,,,...,,,"REGION 68..84; /note=""Interaction with host N-...",,"[77, 140, 149, 146, 66, 68, 143, 126]","[V, Y, W, G, Q, Q, H, R]","[0.0952877476811409, 0.08448244631290436, 0.08...",771401491466668143126,"V,Y,W,G,Q,Q,H,R","0.0952877476811409,0.08448244631290436,0.08427..."
9,Q08676,1.0,0.9-1.0,Q08676,reviewed,Aerolysin (Hemolysin-3),489,MMNRIITANLANLASSLMLAQVLGWHEPVYPDQVKWAGLGTGVCAS...,,,...,,,"REGION 70..86; /note=""Interaction with host N-...",,"[79, 136, 82, 138, 148, 70, 37, 146]","[V, F, G, K, G, Q, G, Y]","[0.15741534531116486, 0.14548814296722412, 0.1...",79136821381487037146,"V,F,G,K,G,Q,G,Y","0.15741534531116486,0.14548814296722412,0.1404..."


In [27]:

# Show on-disk size
size_bytes = FINAL_PATH.stat().st_size
def sizeof(n):
    """Format a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until it is below 1024 or the largest
    supported unit (TB) has been reached; bigger values stay expressed in TB.
    """
    units = ("B","KB","MB","GB","TB")
    level = 0
    # Climb one unit at a time, stopping at the last unit no matter how large n is.
    while level < len(units) - 1 and not n < 1024:
        n /= 1024
        level += 1
    return f"{n:.2f} {units[level]}"
print(f"Saved to: {FINAL_PATH.resolve()}")
print(f"File size: {sizeof(size_bytes)}  ({size_bytes:,} bytes)")

Saved to: /home/ec2-user/SageMaker/InterPLM/per_amino_acid_feature_datasets_with_annotations.pkl
File size: 386.13 MB  (404,884,661 bytes)


In [None]:
# Suppose feature_datasets is your dict of DataFrames
with open("feature_datasets_with_amino_acids.pkl", "wb") as f:
    pickle.dump(feature_datasets, f)

In [32]:
import boto3
# --- Setup (once per notebook) ---
# pip install anthropic python-dotenv
from anthropic import Anthropic
from dotenv import load_dotenv
import os, time, json, math, textwrap
import numpy as np
import pandas as pd
from typing import Dict
from typing import Iterable, Optional, Union, List, Dict

# Load .env (expects ANTHROPIC_API_KEY=...)
load_dotenv()
client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

# --- Config ---
# NOTE(review): in this part of the notebook MODEL_NAME is not referenced at all
# (the active `call_claude` takes its model from MODEL_ID), and MAX_TOKENS /
# TEMPERATURE are only used by the commented-out `call_claude` variant; the active
# one defaults to max_tokens=512 and temperature=0.2. Confirm which settings are
# actually intended before relying on the values below.
MODEL_NAME = "claude-3-5-sonnet-20240620"
MAX_TOKENS = 800  # enough for description + summary
TEMPERATURE = 0.0 # deterministic
CHECKPOINT_EVERY = 50  # checkpoint interval used by the (currently commented-out) main loop
OUTPUT_PATH = "claude_feature_annotations.csv"  # CSV file the annotation results are written to

# Columns to show Claude (customize as you like)
# We'll include what exists; missing columns are auto-dropped
PREFERRED_COLS = [
    # keys/ids
    "uniprot_id", "Entry", "Protein names",
    # size/sequence shape
    "Length",
    # functional annotations
    "EC number", "Active site", "Binding site", "Cofactor", "Disulfide bond",
    "Helix", "Turn", "Beta strand", "Coiled coil",
    "Domain [CC]", "Compositional bias", "Domain [FT]", "Motif", "Region", "Zinc finger",
    # your per-feature fields
    "activation", "bin",
    # optional (only used if present)
    "activated_indices", "activated_aas"
]

# Limit rows/cols so the table fits comfortably in context
MAX_ROWS = 80   # you can raise/lower if you hit token limits
TRUNCATE_STR_LEN = 120  # truncate long text fields so tables stay compact


def _coerce_and_trim_cols(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Select existing columns, stringify, and truncate long strings so the table is compact."""
    use_cols = [c for c in cols if c in df.columns]
    if not use_cols:
        # Fallback: show whatever is available
        use_cols = list(df.columns)

    out = df[use_cols].copy()

    # Coerce to string and truncate long values
    for c in use_cols:
        out[c] = out[c].astype(str).str.replace(r"\s+", " ", regex=True)
        out[c] = out[c].apply(lambda s: s[:TRUNCATE_STR_LEN] + "…" if len(s) > TRUNCATE_STR_LEN else s)

    # Keep only first MAX_ROWS to control token usage
    return out.head(MAX_ROWS)


PROMPT_TEMPLATE = """Generate description and summary
Analyze this protein dataset to determine what predicts the ’Maximum activation value’ and ‘Amino acids of
highest activated indices in protein’ columns. This description should be as concise as possible but sufficient to
predict these two columns on held-out data given only the description and the rest of the protein metadata
provided. The feature could be specific to a protein family, a structural motif, a sequence motif, a functional
role, etc. These WILL be used to predict how much unseen proteins are activated by the feature so only
highlight relevant factors for this.

Focus on:
• Properties of proteins from the metadata that are associated with high vs medium vs low activation.
• Where in the protein sequence activation occurs (in relation to the protein sequence, length, structure,
  or other properties)
• What functional annotations (binding sites, domains, etc.) and amino acids are present at or near the
  activated positions
• This description that will be used to help predict missing activation values should start with:
  “The activation patterns are characterized by:”

Then, in 1 sentence, summarize what biological feature or pattern this neural network activation is detecting.
This concise summary should start with “The feature activates on”.

Protein record:
{TABLE}
"""

def build_prompt(table_df: pd.DataFrame) -> str:
    """Render ``table_df`` as a Markdown table and insert it into ``PROMPT_TEMPLATE``.

    The table (without its index) replaces the literal ``{TABLE}`` placeholder.
    """
    rendered_table = table_df.to_markdown(index=False)
    # Plain substring replacement rather than str.format, so any braces inside
    # the table text are left untouched.
    prompt = PROMPT_TEMPLATE.replace("{TABLE}", rendered_table)
    return prompt

# Configure once
BEDROCK_REGION = "us-east-1"  # AWS region passed to the Bedrock runtime client below
MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"  # change if you prefer another Claude on Bedrock


# Module-level Bedrock runtime client; `call_claude` sends its requests through it.
_bedrock = boto3.client("bedrock-runtime", region_name=BEDROCK_REGION)


# def call_claude(prompt: str) -> str:
#     """Call Claude, return raw text."""

#     #Build Bedrock/Anthropic messages pyalod
#     messages = [
#     {
#         "role": "user",
#         "content": [{"type":"text", "text": prompt}]
#     }
#     ]

#     body = {
#         "anthropic_version": "bedrock-2023-05-31",
#         "max_tokens": MAX_TOKENS,
#         "temperature" : TEMPERATURE,
#         "messages": messages
#     }

#     resp = _bedrock.invoke_model(
#         modelId=MODEL_ID,
#         body=json.dumps(body),
#         contentType = "application/json",
#         accept="application/json",
#     )
#     payload = json.loads(resp["body"].read())

#     #Concatenate all text content blocks(CLaude may return multiple)
#     text_parts = []
#     for part in payload.get("content", []):
#         if part.get("type") == "text":
#             text_parts.append(part.get("text", ""))
#     return "".join(text_parts)

def call_claude(
    prompt: str,
    *,
    max_tokens: int = 512,
    temperature: float = 0.2,
    top_p: Optional[float] = None,
    stop_sequences: Optional[Iterable[str]] = None,
    system: Optional[Union[str, List[Dict]]] = None,
    model_id: str = MODEL_ID,
) -> str:
    """
    Send a single-turn user prompt to Claude through the module-level Bedrock
    client and return the text of the reply.

    Args:
        prompt: text of the single user message.
        max_tokens: value sent as ``max_tokens`` in the request body.
        temperature: value sent as ``temperature`` in the request body.
        top_p: sent as ``top_p`` only when not None.
        stop_sequences: sent as a list under ``stop_sequences`` only when truthy.
        system: optional system instruction. A string is wrapped in a single
                text content block; anything else is forwarded unchanged.
        model_id: Bedrock model id passed to ``invoke_model``.

    Returns:
        The ``text`` of every content block of type "text" in the response,
        concatenated in order (empty string if there are none).

    NOTE(review): the defaults here (512 / 0.2) differ from the MAX_TOKENS and
    TEMPERATURE constants defined in the config section -- confirm which is intended.
    """
    request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ],
    }

    # Optional fields are only added when the caller supplied them.
    if top_p is not None:
        request["top_p"] = top_p
    if stop_sequences:
        request["stop_sequences"] = list(stop_sequences)
    if system:
        request["system"] = (
            [{"type": "text", "text": system}] if isinstance(system, str) else system
        )

    response = _bedrock.invoke_model(
        modelId=model_id,
        body=json.dumps(request),
        contentType="application/json",
        accept="application/json",
    )
    reply = json.loads(response["body"].read())

    # The reply may hold several content blocks; keep only the textual ones.
    return "".join(
        block.get("text", "")
        for block in reply.get("content", [])
        if block.get("type") == "text"
    )


_DESC_MARKER = "the activation patterns are characterized by:"
_SUMM_MARKER = "the feature activates on"
# Emphasis / quoting characters a model may wrap around a marker phrase.
_MD_EMPHASIS = "*_`\"'“”‘’"


def _marker_line(line: str, marker: str) -> Optional[str]:
    """Return ``line`` (minus Markdown decoration) if it starts with ``marker``, else None.

    A line that already starts with the marker is returned unchanged. Otherwise
    leading heading/blockquote characters and emphasis/quote characters are
    removed before matching; on a match, ``**`` bold markers are dropped too.
    """
    if line.lower().startswith(marker):
        return line
    # Heading/quote prefixes may be followed by spaces; emphasis markers hug the
    # text, so a "* bullet" item is deliberately NOT treated as a marker line.
    candidate = line.lstrip("#> \t").lstrip(_MD_EMPHASIS)
    if candidate.lower().startswith(marker):
        return candidate.replace("**", "").strip()
    return None


def parse_description_and_summary(text: str) -> dict:
    """
    Best-effort split of a model response into description and summary.

    The description starts at the first line beginning with
    "The activation patterns are characterized by:" and runs until the next
    line beginning with "The feature activates on" (or the end of the text).
    The summary is the first line beginning with "The feature activates on".
    Matching is case-insensitive and tolerates Markdown decoration around the
    marker phrase (e.g. ``**The activation patterns are characterized by:**``).

    Args:
        text: raw model response.

    Returns:
        dict with keys "description", "summary" and "raw". "description" and
        "summary" are empty strings when their marker is not found; "raw" is
        always the stripped input text.
    """
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]

    # Locate the first line that opens the description block.
    start_idx = next(
        (i for i, ln in enumerate(lines) if _marker_line(ln, _DESC_MARKER) is not None),
        None,
    )

    desc = ""
    if start_idx is not None:
        collected = [_marker_line(lines[start_idx], _DESC_MARKER)]
        for ln in lines[start_idx + 1:]:
            if _marker_line(ln, _SUMM_MARKER) is not None:
                break  # the summary line terminates the description
            collected.append(ln)
        desc = "\n".join(collected).strip()

    # The summary is the first marker line anywhere in the response.
    summ = next(
        (hit for hit in (_marker_line(ln, _SUMM_MARKER) for ln in lines) if hit is not None),
        "",
    )

    return {
        "description": desc or "",
        "summary": summ or "",
        "raw": text.strip()
    }


## Build summaries

In [None]:

# # --- Main loop over feature datasets ---
# # Expects: feature_datasets: Dict[int, pd.DataFrame]
# results_rows = []

# processed = 0
# for fid, df in tqdm(list(feature_datasets.items())[:1200], desc="llm to opus"):
#     # Build a compact table for the model
#     view = _coerce_and_trim_cols(df, PREFERRED_COLS)
#     prompt = build_prompt(view)

#     try:
#         text = call_claude(prompt)
#         parsed = parse_description_and_summary(text)
#     except Exception as e:
#         parsed = {"description": "", "summary": "", "raw": f"[ERROR] {e}"}

#     results_rows.append({
#         "feature_id": fid,
#         "n_rows_shown": len(view),
#         "description": parsed["description"],
#         "summary": parsed["summary"],
#         "raw_response": parsed["raw"],
#     })

#     processed += 1
#     if processed % CHECKPOINT_EVERY == 0:
#         pd.DataFrame(results_rows).to_csv(OUTPUT_PATH, index=False)

# # Final save
# df_results = pd.DataFrame(results_rows)
# df_results.to_csv(OUTPUT_PATH, index=False)
# print(f"[done] {len(df_results)} features → {OUTPUT_PATH}")
# df_results.head()


In [16]:
# Smoke test: run the annotation pipeline end-to-end for a single feature
# (table -> prompt -> Claude -> parsed description/summary -> CSV).
results_rows = []
processed = 0  # not used in this cell

fid = 2167
df  = feature_datasets[fid]          # <- this is a DataFrame

# Compact, all-string view of the feature table, embedded into the prompt.
view   = _coerce_and_trim_cols(df, PREFERRED_COLS)
prompt = build_prompt(view)

try:
    text   = call_claude(prompt)
    parsed = parse_description_and_summary(text)
except Exception as e:
    # Keep going on failure: the error text is stored in the raw field instead.
    parsed = {"description": "", "summary": "", "raw": f"[ERROR] {e}"}

results_rows.append({
    "feature_id": fid,
    "n_rows_shown": len(view),
    "description": parsed["description"],
    "summary": parsed["summary"],
    "raw_response": parsed["raw"],
})

# Save
# Writes OUTPUT_PATH with only the row(s) collected in this cell.
df_results = pd.DataFrame(results_rows)
df_results.to_csv(OUTPUT_PATH, index=False)
print(f"[done] {len(df_results)} features → {OUTPUT_PATH}")
display(df_results.head())


[done] 1 features → claude_feature_annotations.csv


Unnamed: 0,feature_id,n_rows_shown,description,summary,raw_response
0,2167,29,The activation patterns are characterized by: ...,The feature activates on zinc finger domains a...,**Description:**\n\nThe activation patterns ar...


## Validating feature descriptions via activation pattern prediction

In [34]:
import io, csv, json, math, textwrap
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional

# ---------- 1) Choose compact metadata columns ----------
META_COLS = [
    "Entry", "Reviewed", "Protein names", "Length", "EC number",
    "Active site", "Binding site", "Metal binding", "Site",
    "Domain [FT]", "Motif", "Region", "Zinc finger",
]

def _coerce_cols(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    keep = [c for c in cols if c in df.columns]
    out = df[keep].copy()
    # Normalize NaNs to empty strings for cleaner prompts
    for c in keep:
        out[c] = out[c].fillna("").astype(str).str.strip()
    return out

# ---------- 2) Build a compact, token-budget-friendly metadata block ----------
def make_metadata_block(df_query: pd.DataFrame, max_rows: int = 20) -> str:
    """
    Render the metadata columns of ``df_query`` as CSV-like text for a prompt.

    Only the columns listed in ``META_COLS`` that exist in ``df_query`` are
    used, and at most ``max_rows`` leading rows are emitted. The first output
    line is the comma-joined column names; each following line is one row.
    Inside cell values, newlines become spaces and commas become semicolons so
    every row keeps the same number of comma-separated fields.
    """
    meta = _coerce_cols(df_query, META_COLS).head(max_rows)
    header = ",".join(meta.columns)
    body = [
        ",".join(record[col].replace("\n", " ").replace(",", ";") for col in meta.columns)
        for _, record in meta.iterrows()
    ]
    return "\n".join([header, *body])


# ---------- 3) Prompt template for validation ----------
VALIDATION_INSTRUCTIONS = """\
You are scoring how strongly a single latent feature activates on proteins based on Swiss-Prot metadata.

Task: For each query protein, predict the *maximum feature activation value* in [0.0, 1.0]
according to how well it matches the described activation patterns.

Rules:
- Output ONLY a CSV with the header EXACTLY: "Entry,Maximum activation value"
- One row per Entry in the query table, same Entry IDs.
- Use a decimal in [0.0, 1.0]; you may use up to 3 decimals (e.g., 0.137).
- No extra text, explanations, units, JSON, code fences, or headers beyond that single CSV.
- If the protein is unrelated to the pattern, output 0.0.
- Base your judgment only on the provided Swiss-Prot metadata and the description below.

Scoring rubric (guidance, not to be output):
- Strong, explicit matches to motifs/sites/domains directly named in the description → closer to 0.8–1.0
- Partial matches or related patterns (e.g., same fold/family, weaker motif evidence) → ~0.3–0.7
- Weak/indirect hints only → ~0.05–0.3
- No match → 0.0
"""

def build_validation_prompt(
    feature_id: int,
    description: str,
    metadata_csv_block: str,
    query_entries: List[str],
) -> str:
    """Assemble the validation prompt sent to the model for one batch of proteins.

    The prompt consists of, in order: ``VALIDATION_INSTRUCTIONS``, the feature
    id, the (stripped) description wrapped in triple quotes, the metadata CSV
    block, and an empty "Entry,Maximum activation value" table listing
    ``query_entries`` for the model to fill in. Sections are separated by a
    blank line and every line starts at column 0.

    Args:
        feature_id: feature identifier, shown as ``f/<id>``.
        description: free-text description of the activation pattern.
        metadata_csv_block: CSV-like metadata text for the query proteins.
        query_entries: entry ids, one per row of the table to fill out.

    Returns:
        The complete prompt text.
    """
    # Build an empty query CSV the model should fill
    empty_table = "\n".join(["Entry,Maximum activation value", *map(str, query_entries)])

    # The sections are joined explicitly instead of dedenting an indented
    # f-string: the interpolated values are multi-line and start at column 0,
    # so textwrap.dedent would find no common indent and leave the template
    # lines indented while the CSV lines are not.
    sections = [
        VALIDATION_INSTRUCTIONS.strip(),
        f"Feature ID: f/{feature_id}",
        'The activation patterns are characterized by:\n"""' + description.strip() + '"""',
        "Swiss-Prot metadata for the query proteins (CSV):\n" + metadata_csv_block,
        "Table to fill out (return ONLY this table with predicted values):\n" + empty_table,
    ]
    return "\n\n".join(sections).strip()

# ---------- 4) Parser for the model's CSV output ----------
def parse_prediction_csv(csv_text: str) -> pd.DataFrame:
    """Parse the model's "Entry,Maximum activation value" CSV reply.

    Tolerates Markdown code fences and lines of preamble before the header.
    Each value is converted to a float clamped to [0.0, 1.0]; if the cell is
    not a clean number, the first number found inside it is used, and cells
    with no number at all (or NaN) become 0.0.

    Args:
        csv_text: raw text returned by the model.

    Returns:
        DataFrame with columns "Entry" and "pred_activation", one row per
        distinct entry (the first occurrence wins).

    Raises:
        ValueError: if no header row whose first cell is "Entry" (with at
            least two columns) is found.
    """
    import re  # local import: `re` is not among the notebook's visible imports

    # Be robust to accidental code fences or whitespace: backticks are stripped
    # from every line, so a fence line such as ```csv turns into a preamble line.
    cleaned_lines = [ln.strip().strip("`").strip() for ln in csv_text.strip().splitlines()]
    rows = [r for r in csv.reader(io.StringIO("\n".join(cleaned_lines))) if r]

    # The header is the first row that looks like "Entry,<something>"; anything
    # before it (language hints, stray prose) is ignored.
    header_at = next(
        (i for i, r in enumerate(rows) if len(r) >= 2 and r[0].strip().lower() == "entry"),
        None,
    )
    if header_at is None:
        raise ValueError("Missing or malformed CSV header. Expected 'Entry,Maximum activation value'.")

    number = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", re.ASCII)

    data = []
    for r in rows[header_at + 1:]:
        entry = r[0].strip()
        if not entry:
            continue
        val_raw = r[1].strip() if len(r) > 1 else ""
        try:
            v = float(val_raw)
        except ValueError:
            # Pull the first well-formed number out of cells like "~0.25." or
            # "0.8 (high)" instead of re-joining digits, which can produce
            # unparseable strings.
            found = number.search(val_raw)
            v = float(found.group(0)) if found else math.nan
        if math.isnan(v):
            v = 0.0
        v = max(0.0, min(1.0, v))
        data.append((entry, v))

    df = pd.DataFrame(data, columns=["Entry", "pred_activation"])
    # Deduplicate keeping the first occurrence
    df = df[~df["Entry"].duplicated(keep="first")].reset_index(drop=True)
    return df



In [35]:
import numpy as np
import pandas as pd
from typing import Tuple, Dict

def predict_activations_via_llm(
    feature_id: int,
    description: str,
    df_all: pd.DataFrame,
    *,
    n_query: int = 50,
    seed: int = 0,
    batch_size: int = 20,
    use_entry_col: bool = True,
    call_fn=call_claude,  # your Bedrock-backed call_claude(prompt)->str
) -> Tuple[pd.DataFrame, Dict[str, float]]:
    """
    Sample query proteins, ask the model to predict their max activation from the
    description, parse the predictions, and score them against the true values.

    Args:
        feature_id: feature identifier (only used in the prompt text).
        description: activation-pattern description given to the model.
        df_all: table with an id column ('Entry', or 'uniprot_id'), the metadata
            columns, and 'activation'. Rows lacking an id or activation are dropped.
        n_query: number of rows to sample (capped at the number of usable rows).
        seed: seed for the sampling RNG.
        batch_size: number of proteins per model call.
        use_entry_col: prefer 'Entry' as id column when it exists; otherwise
            'uniprot_id' is used.
        call_fn: callable ``(prompt, system=...) -> str`` that queries the model.

    Returns:
        (df_eval, metrics): ``df_eval`` has the id column, 'activation' and
        'pred_activation', sorted by id; ``metrics`` has 'pearson_r', 'mae' and
        'mse'. Entries the model did not return are scored as 0.0. 'pearson_r'
        is NaN when fewer than two rows were evaluated or either series is
        constant. With no usable rows, an empty frame and NaN metrics are returned.

    Raises:
        ValueError: propagated from ``parse_prediction_csv`` when a reply has no
            usable CSV header.
    """
    rng = np.random.default_rng(seed)
    id_col = "Entry" if (use_entry_col and "Entry" in df_all.columns) else "uniprot_id"

    # Keep rows that have an ID and activation
    df_all = df_all.dropna(subset=[id_col, "activation"]).copy()
    df_all[id_col] = df_all[id_col].astype(str)

    # Nothing to evaluate: return an empty result instead of failing later in
    # pd.concat on an empty list of batch predictions.
    if df_all.empty:
        empty_eval = pd.DataFrame(columns=[id_col, "activation", "pred_activation"])
        return empty_eval, {"pearson_r": np.nan, "mae": np.nan, "mse": np.nan}

    # Sample query set (held-out)
    n_query = min(n_query, len(df_all))
    query_idx = rng.choice(df_all.index.values, size=n_query, replace=False)
    df_query = df_all.loc[query_idx].copy()

    # Predict in batches to control tokens
    preds = []
    for start in range(0, len(df_query), batch_size):
        chunk = df_query.iloc[start:start+batch_size].copy()

        # Prepare metadata CSV block and query Entry list. max_rows is tied to
        # the chunk size so the metadata table always covers every queried
        # entry, even when batch_size exceeds make_metadata_block's default.
        meta_block = make_metadata_block(chunk, max_rows=len(chunk))
        entries = list(chunk[id_col].values)

        prompt = build_validation_prompt(
            feature_id=feature_id,
            description=description,
            metadata_csv_block=meta_block,
            query_entries=entries,
        )

        # Strong output constraint: ONLY CSV
        resp_text = call_fn(
            prompt,
            system="Return only the CSV requested. No extra text."
        )

        preds.append(parse_prediction_csv(resp_text))

    df_pred_all = pd.concat(preds, ignore_index=True)

    # Join with ground truth (use df_query here)
    df_eval = df_query[[id_col, "activation"]].merge(
        df_pred_all.rename(columns={"Entry": id_col}),
        on=id_col,
        how="left",
    )
    # Entries missing from the model's reply count as a prediction of 0.0.
    df_eval["pred_activation"] = df_eval["pred_activation"].fillna(0.0)

    # Metrics
    y = df_eval["activation"].astype(float).values
    yhat = df_eval["pred_activation"].astype(float).values
    # Pearson r is undefined for <2 points or a constant series; report NaN
    # directly rather than letting np.corrcoef divide by zero.
    if len(df_eval) > 1 and np.std(y) > 0 and np.std(yhat) > 0:
        pearson = float(np.corrcoef(y, yhat)[0, 1])
    else:
        pearson = np.nan
    mae = float(np.mean(np.abs(y - yhat))) if len(df_eval) else np.nan
    mse = float(np.mean((y - yhat) ** 2)) if len(df_eval) else np.nan
    metrics = {"pearson_r": pearson, "mae": mae, "mse": mse}

    return df_eval.sort_values(id_col).reset_index(drop=True), metrics


## Look at the big summary

In [36]:
# Load the saved annotation results (one row per feature) for inspection.
import pandas as pd
big_summary = pd.read_csv("final_claude_feature_annotations.csv")
big_summary.head()

Unnamed: 0,feature_id,n_rows_shown,description,summary,raw_response
0,5590,21,,,"Looking at this protein dataset, I can analyze..."
1,7300,26,,,"Looking at this protein dataset, I can identif..."
2,595,29,The activation patterns are characterized by: ...,The feature activates on disordered protein re...,**Description:**\n\nThe activation patterns ar...
3,2669,19,,,**The activation patterns are characterized by...
4,6281,23,The activation patterns are characterized by: ...,The feature activates on C-terminal extensions...,The activation patterns are characterized by: ...


In [37]:

def pick_10_features(big_summary, feature_datasets, min_rows=30, seed=0, n_features=10):
    """Choose up to ``n_features`` feature ids that have both text and usable data.

    A feature qualifies when its row in ``big_summary`` has a non-null
    'description' and 'summary', and its DataFrame in ``feature_datasets`` has
    an id column ('Entry' or 'uniprot_id') plus an 'activation' column with at
    least ``min_rows`` non-null values.

    Args:
        big_summary: DataFrame with 'feature_id', 'description', 'summary'.
        feature_datasets: mapping of feature id -> DataFrame.
        min_rows: minimum number of non-null activations required.
        seed: seed for the RNG used when sub-sampling.
        n_features: maximum number of ids to return (default 10).

    Returns:
        List of feature ids as plain ``int``. If no more than ``n_features``
        qualify they are returned in ``big_summary`` order; otherwise
        ``n_features`` of them are drawn at random without replacement.

    Raises:
        ValueError: if no feature qualifies.
    """
    rng = np.random.default_rng(seed)
    has_text = big_summary["description"].notna() & big_summary["summary"].notna()

    # Keep only features that have a usable df
    keep = []
    for raw_fid in big_summary.loc[has_text, "feature_id"]:
        fid = int(raw_fid)
        if fid not in feature_datasets:
            continue
        df = feature_datasets[fid]
        has_id = ("Entry" in df.columns) or ("uniprot_id" in df.columns)
        if not has_id or "activation" not in df.columns:
            continue
        if df["activation"].dropna().shape[0] < min_rows:
            continue
        keep.append(fid)

    keep = list(dict.fromkeys(keep))  # dedup, preserving first-seen order
    if not keep:
        raise ValueError("No valid features with descriptions + usable data.")
    if len(keep) > n_features:
        keep = list(rng.choice(keep, size=n_features, replace=False))
    # rng.choice yields NumPy integers; hand back plain ints.
    return [int(fid) for fid in keep]

# ---- Run it
feature_ids = pick_10_features(big_summary, feature_datasets, min_rows=30, seed=0)


In [38]:
feature_ids

[9093, 5511, 9734, 7274, 8939, 1068, 5297, 1243, 1297, 117]

In [40]:
# Validate each selected feature: have the model predict activations from the
# stored description, then collect per-feature metrics and per-protein rows.
results = []
all_eval = []
from tqdm import tqdm
for fid in tqdm(feature_ids, desc="features"):
    # First description stored for this feature id in big_summary.
    desc = big_summary.loc[big_summary["feature_id"] == fid, "description"].iloc[0]
    df_feat = feature_datasets[fid]

    df_eval, stats = predict_activations_via_llm(
        feature_id=fid,
        description=desc,
        df_all=df_feat,
        n_query=50,      # tweak as you like
        seed=42,
        batch_size=20,
    )
    # Tag both outputs with the feature id so they can be combined below.
    stats["feature_id"] = fid
    results.append(stats)
    df_eval["feature_id"] = fid
    all_eval.append(df_eval)

# One row per feature, best Pearson r first.
metrics_df = pd.DataFrame(results)[["feature_id","pearson_r","mae","mse"]].sort_values("pearson_r", ascending=False)
# Long table: one row per evaluated protein, across all features.
preview_eval = pd.concat(all_eval, ignore_index=True)

print(metrics_df)
preview_eval.head()


features: 100%|██████████| 10/10 [01:11<00:00,  7.15s/it]

   feature_id  pearson_r       mae       mse
3        7274   0.971164  0.053357  0.009148
0        9093   0.967111  0.068453  0.011185
6        5297   0.935163  0.083581  0.027240
7        1243   0.927260  0.101508  0.024594
8        1297   0.900699  0.111129  0.037984
4        8939   0.884668  0.087675  0.033477
9         117   0.792196  0.146848  0.051391
2        9734   0.717159  0.161985  0.064969
1        5511   0.689147  0.200065  0.098775
5        1068   0.541301  0.279320  0.156865





Unnamed: 0,Entry,activation,pred_activation,feature_id
0,A4YH98,0.944795,0.9,9093
1,A8AAB6,0.870041,0.9,9093
2,A8MBV3,0.915112,0.9,9093
3,B5VPM5,0.300445,0.0,9093
4,B5Z062,0.0,0.0,9093


In [41]:
preview_eval.columns

Index(['Entry', 'activation', 'pred_activation', 'feature_id'], dtype='object')

In [43]:
# How many proteins did we evaluate per feature?
preview_eval.groupby("feature_id").size().sort_values(ascending=False).head()



feature_id
8939    38
1243    38
9093    38
5297    35
5511    34
dtype: int64

In [44]:

# See only one feature’s rows
preview_eval.query("feature_id == 9093").head(20)



Unnamed: 0,Entry,activation,pred_activation,feature_id
0,A4YH98,0.944795,0.9,9093
1,A8AAB6,0.870041,0.9,9093
2,A8MBV3,0.915112,0.9,9093
3,B5VPM5,0.300445,0.0,9093
4,B5Z062,0.0,0.0,9093
5,C3MQ89,0.745696,0.6,9093
6,C3MQN7,0.916713,0.9,9093
7,C3MWN7,0.916713,0.9,9093
8,C3NEW5,0.916713,0.9,9093
9,C3NGS9,0.916713,0.9,9093


In [45]:
# Peek a few rows per feature
preview_eval.sort_values(["feature_id", "Entry"]).groupby("feature_id").head(3)



Unnamed: 0,Entry,activation,pred_activation,feature_id
307,A0A848M4Z0,0.0,0.1,117
308,A2X1A1,0.439998,0.65,117
309,A5N7J5,0.0,0.1,117
173,A0A1L1QK34,0.238514,0.9,1068
174,B1VH34,0.0,0.0,1068
175,B4SW28,0.0,0.0,1068
239,A0A977JPB5,0.482518,0.6,1243
240,A2X2K3,0.221492,0.2,1243
241,A4W8A9,0.948163,1.0,1243
277,A4GYP5,0.481105,0.5,1297


In [46]:
# Verify correlation for one feature (should match metrics_df)
fid = 9093
sub = preview_eval[preview_eval["feature_id"] == fid]
np.corrcoef(sub["activation"], sub["pred_activation"])[0,1]

0.9671108645836247

In [47]:
import os, textwrap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# preview_eval: long DF with columns ["Entry","activation","pred_activation","feature_id", ...]
# metrics_df:   DF with one row per feature, columns ["feature_id","pearson_r","mae","mse"]
# big_summary:  DF with ["feature_id","description", ...] (or "summary" if you prefer)

def _kde2d_unit_square(x, y, grid=150, bandwidth=0.08):
    """
    Very small 2D Gaussian KDE on [0,1]x[0,1] without SciPy.
    Fine for n ~ 50–200 points.
    """
    xi = np.linspace(0.0, 1.0, grid)
    yi = np.linspace(0.0, 1.0, grid)
    X, Y = np.meshgrid(xi, yi)
    Z = np.zeros_like(X)
    bw2 = 2 * (bandwidth ** 2)

    # Sum Gaussian bumps
    for x0, y0 in zip(x, y):
        Z += np.exp(-((X - x0) ** 2 + (Y - y0) ** 2) / bw2)

    # Normalization not critical for visualization; this keeps values sane
    Z /= (len(x) if len(x) else 1)
    return X, Y, Z

def plot_kde_for_feature(sub: pd.DataFrame, fid: int, pearson: float, desc: str,
                         out_dir: str = "kde_plots", bandwidth: float = 0.08):
    """Save a predicted-vs-true activation density plot for one feature.

    Args:
        sub: rows for this feature with 'pred_activation' and 'activation'.
        fid: feature id, used in the title and the file name.
        pearson: correlation value shown in the title.
        desc: optional description rendered as wrapped text above the plot.
        out_dir: output directory (created if it does not exist).
        bandwidth: kernel bandwidth forwarded to ``_kde2d_unit_square``.

    Returns:
        Path of the saved PNG, or None when ``sub`` has fewer than two rows.
    """
    os.makedirs(out_dir, exist_ok=True)

    pred = sub["pred_activation"].astype(float).to_numpy()
    truth = sub["activation"].astype(float).to_numpy()

    if len(pred) < 2:
        print(f"[skip] feature {fid}: not enough points")
        return None

    X, Y, Z = _kde2d_unit_square(pred, truth, grid=160, bandwidth=bandwidth)

    fig, ax = plt.subplots(figsize=(5.2, 4.6))
    ax.contourf(X, Y, Z, levels=24)            # density, default colormap
    ax.plot([0, 1], [0, 1])                    # y = x reference diagonal
    ax.scatter(pred, truth, s=15, alpha=0.5)   # light overlay of the raw points

    ax.set_xlabel("Predicted Activation")
    ax.set_ylabel("True Activation")
    ax.set_title(f"Feature {fid} (pearson r = {pearson:.2f})")

    # Description goes above the axes as a wrapped paragraph.
    if isinstance(desc, str) and desc.strip():
        fig.text(0.5, 1.02, textwrap.fill(desc.strip(), width=90), ha="center", va="bottom")

    out_path = os.path.join(out_dir, f"feature_{fid}.png")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    return out_path

def plot_kde_for_features(preview_eval: pd.DataFrame,
                          metrics_df: pd.DataFrame,
                          big_summary: pd.DataFrame,
                          feature_ids=None,
                          top_k: int = None,
                          out_dir: str = "kde_plots",
                          bandwidth: float = 0.08):
    """
    Save one KDE plot per feature and return the list of saved file paths.

    - If feature_ids is given, plots those.
    - Else if top_k is given, takes top_k by pearson_r.
    - Else plots all features present in metrics_df.

    Features with no rows in ``preview_eval`` are skipped with a message. The
    Pearson r shown in each title comes from ``metrics_df`` when available and
    is computed from the feature's rows otherwise. The text above each plot is
    the feature's 'description' from ``big_summary`` ('summary' if that column
    is absent).
    """
    # Build quick lookups
    pearson_map = metrics_df.set_index("feature_id")["pearson_r"].to_dict()
    # prefer 'description'; fall back to 'summary' if missing
    text_col = "description" if "description" in big_summary.columns else "summary"
    desc_map = big_summary.set_index("feature_id")[text_col].fillna("").to_dict()

    if feature_ids is None:
        fids = list(metrics_df["feature_id"])
        if top_k is not None and top_k < len(fids):
            fids = (metrics_df.sort_values("pearson_r", ascending=False)
                              .head(top_k)["feature_id"].tolist())
    else:
        fids = list(feature_ids)

    saved = []
    for fid in fids:
        sub = preview_eval[preview_eval["feature_id"] == fid].copy()
        if sub.empty:
            print(f"[skip] feature {fid}: no rows in preview_eval")
            continue
        # Only compute the correlation when metrics_df has no value for this
        # feature; as a dict.get default it would be evaluated every time.
        if fid in pearson_map:
            pearson = float(pearson_map[fid])
        else:
            pearson = float(np.corrcoef(
                sub["activation"].astype(float), sub["pred_activation"].astype(float)
            )[0,1])
        desc = desc_map.get(fid, "")
        path = plot_kde_for_feature(sub, fid, pearson, desc, out_dir=out_dir, bandwidth=bandwidth)
        if path:
            saved.append(path)
    return saved

# ---- Example usage ----
# Pick top 6 features by Pearson r and plot
saved_paths = plot_kde_for_features(preview_eval, metrics_df, big_summary, top_k=6, out_dir="kde_plots")
saved_paths


['kde_plots/feature_7274.png',
 'kde_plots/feature_9093.png',
 'kde_plots/feature_5297.png',
 'kde_plots/feature_1243.png',
 'kde_plots/feature_1297.png',
 'kde_plots/feature_8939.png']