In [None]:
import os, re, json, argparse, warnings
from typing import List, Dict, Optional
import time
from sklearn.model_selection import StratifiedKFold, GridSearchCV

In [20]:
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

In [None]:
from transformers import AutoTokenizer, AutoModel, AutoConfig, T5EncoderModel
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, roc_auc_score,
    matthews_corrcoef, confusion_matrix, average_precision_score, make_scorer
)
from sklearn.model_selection import train_test_split

In [22]:
def safe_model_tag(model_id: str) -> str:
    """Turn a model id into a filesystem-safe tag.

    Every run of characters outside ``[a-zA-Z0-9_.-]`` (e.g. the ``/`` in
    ``org/model``) collapses to a single underscore.
    """
    disallowed_run = re.compile(r"[^a-zA-Z0-9_.-]+")
    return disallowed_run.sub("_", model_id)

In [23]:
def clean_aa(seq: str) -> str:
    """Normalize an amino-acid string: strip outer whitespace, uppercase, and map
    the rare/ambiguous residues U, Z, O and B to X.

    Other characters (including interior whitespace) are left unchanged.
    """
    rare_to_x = str.maketrans("UZOB", "XXXX")
    normalized = seq.strip().upper()
    return normalized.translate(rare_to_x)

In [24]:
def needs_space_separated(model_id: str) -> bool:
    """Return True when the model's tokenizer expects residues separated by spaces.

    Matching is a case-insensitive substring test against known Rostlab /
    ProtT5 / ProtBert / DistilProtBert id fragments.
    """
    lowered = model_id.lower()
    for marker in ("rostlab", "prot_t5", "prot_bert", "distilprotbert"):
        if marker in lowered:
            return True
    return False

In [25]:
def mean_pool(last_hidden: torch.Tensor,
              attn_mask: torch.Tensor,
              special_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Masked mean over the token axis (dim 1).

    Args:
        last_hidden: hidden states, assumed [B, L, H].
        attn_mask: [B, L] mask; nonzero marks tokens to include.
        special_mask: optional [B, L] mask; nonzero positions are additionally
            excluded. Some tokenizers produce odd special-token masks (even all
            ones), so any row this would leave empty falls back to ``attn_mask``
            alone instead of pooling to an all-zero vector.

    Returns:
        [B, H] mean of the kept positions. Rows with nothing kept are zero
        vectors (the denominator is clamped to 1).
    """
    keep = attn_mask.bool()
    if special_mask is not None:
        trimmed = keep & ~special_mask.bool()
        # Per-row safety net: restore the plain attention mask where the special mask removed everything.
        emptied = trimmed.sum(dim=1).eq(0).unsqueeze(1)          # [B, 1]
        keep = torch.where(emptied, keep, trimmed)

    weights = keep.to(torch.float32).unsqueeze(-1)              # [B, L, 1]
    totals = torch.sum(last_hidden * weights, dim=1)            # [B, H]
    counts = torch.clamp(weights.sum(dim=1), min=1.0)           # [B, 1]
    return totals / counts

In [26]:
def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Binary-classification metrics for labels encoded as 0/1.

    Args:
        y_true: ground-truth labels.
        y_prob: scores/probabilities for class 1 (used for AUC and AP).
        y_pred: hard predictions (used for every other metric).

    Returns:
        Dict with ACC, BACC, Sn (TP / (TP + FN), 0.0 if undefined),
        Sp (TN / (TN + FP), 0.0 if undefined), MCC, AUC (NaN when
        ``roc_auc_score`` raises ValueError, e.g. only one class present) and AP.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    def _safe_ratio(numerator, denominator):
        if denominator > 0:
            return numerator / denominator
        return 0.0

    try:
        auc_value = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc_value = float("nan")

    return {
        "ACC": accuracy_score(y_true, y_pred),
        "BACC": balanced_accuracy_score(y_true, y_pred),
        "Sn": _safe_ratio(tp, tp + fn),
        "Sp": _safe_ratio(tn, tn + fp),
        "MCC": matthews_corrcoef(y_true, y_pred),
        "AUC": auc_value,
        "AP": average_precision_score(y_true, y_prob),
    }

In [27]:
def load_model_tokenizer_transformers(model_id: str, device: str):
    """Load a Hugging Face tokenizer and encoder for ``model_id``, moved to ``device`` in eval mode.

    Class and precision selection:
      - ANKH checkpoints (any id containing "ankh", case-insensitive) load as
        ``T5EncoderModel`` in float32 regardless of device.
      - Other configs with ``model_type == "t5"`` (e.g. ProtT5) load as ``T5EncoderModel``.
      - Everything else loads via ``AutoModel`` with ``trust_remote_code=True``.
      Non-ANKH models use float16 when ``device`` starts with "cuda" and float32 otherwise.

    Returns:
        (tokenizer, model)
    """
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    # Use the same "ankh" substring test as the pooling code so that every ANKH id
    # (e.g. Synthyra/ANKH_* and ElnaggarLab/ankh-*) takes the FP32 path, not only
    # ids starting with "synthyra/ankh".
    is_ankh = "ankh" in model_id.lower()

    if is_ankh:
        # Force FP32 for ANKH. NOTE(review): presumably fp16 is numerically unreliable
        # for these checkpoints — confirm.
        torch_dtype = torch.float32
        model = T5EncoderModel.from_pretrained(model_id, torch_dtype=torch_dtype)
    else:
        torch_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
        if getattr(cfg, "model_type", "") == "t5":
            model = T5EncoderModel.from_pretrained(model_id, torch_dtype=torch_dtype)
        else:
            model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch_dtype)

    model.to(device).eval()
    return tok, model

In [28]:
def embed_sequences_esmc(model_id: str, seqs: List[str], device: str) -> np.ndarray:
    """Mean-pooled ESM C embeddings, one row per input sequence.

    The checkpoint is picked from ``model_id`` (case-insensitive): "300m" selects
    esmc_300m; otherwise "600m" selects esmc_600m. Each sequence is cleaned with
    ``clean_aa`` and embedded individually (no batching). If the per-token
    embedding has exactly len(sequence) + 2 rows, the first and last rows
    (presumably BOS/EOS) are dropped before averaging.

    Returns:
        Array [len(seqs), D] with NaN/inf replaced by 0.

    Raises:
        ValueError: if neither "300m" nor "600m" appears in ``model_id``.
    """
    lowered_id = model_id.lower()
    size_table = (("300m", "esmc_300m"), ("600m", "esmc_600m"))
    esm_name = next((name for key, name in size_table if key in lowered_id), None)
    if esm_name is None:
        raise ValueError(f"Unknown ESMC size in model_id: {model_id}")

    from esm.models.esmc import ESMC
    from esm.sdk.api import ESMProtein, LogitsConfig

    model = ESMC.from_pretrained(esm_name).to(device)
    model.eval()

    pooled_rows = []
    with torch.no_grad():
        for raw_seq in tqdm(seqs, desc=f"Embedding (ESMC:{esm_name})"):
            residues = clean_aa(raw_seq)
            encoded = model.encode(ESMProtein(sequence=residues))
            result = model.logits(encoded, LogitsConfig(sequence=True, return_embeddings=True))

            # The embeddings container differs across esm versions: attribute, dict key, or the tensor itself.
            container = result.embeddings
            per_token = getattr(container, "sequence", None)
            if per_token is None and isinstance(container, dict):
                per_token = container.get("sequence", None)
            if per_token is None:
                per_token = container

            if not torch.is_tensor(per_token):
                per_token = torch.tensor(per_token)
            if per_token.dim() == 3:
                per_token = per_token[0]            # drop batch axis -> [L, D]
            if per_token.size(0) == len(residues) + 2:
                per_token = per_token[1:-1]         # strip the two extra (boundary) positions

            pooled_rows.append(per_token.mean(dim=0).float().cpu().numpy())

    stacked = np.stack(pooled_rows, axis=0)
    return np.nan_to_num(stacked, nan=0.0, posinf=0.0, neginf=0.0)

In [29]:
def embed_sequences_esm3(seqs: List[str], device: str) -> np.ndarray:
    """Mean-pooled per-residue embeddings from ESM3 (esm3_sm_open_v1), one row per sequence.

    Each sequence is cleaned with ``clean_aa`` and run individually through
    ``forward_and_sample`` with per-residue embeddings requested. If the
    embedding has exactly len(sequence) + 2 rows, the first and last rows
    (presumably BOS/EOS) are dropped before averaging.

    Returns:
        Array [len(seqs), D] with NaN/inf replaced by 0.

    Raises:
        RuntimeError: if the output has neither ``per_residue_embedding`` nor
            ``per_residue_embeddings``.
    """
    from esm.models.esm3 import ESM3
    from esm.sdk.api import ESMProtein, SamplingConfig

    model = ESM3.from_pretrained("esm3_sm_open_v1").to(device)
    model.eval()

    def _residue_embeddings(output):
        # The attribute name differs across esm versions (singular vs plural).
        for attr in ("per_residue_embedding", "per_residue_embeddings"):
            found = getattr(output, attr, None)
            if found is not None:
                return found
        raise RuntimeError("ESM3 output has no per_residue_embedding(s).")

    rows = []
    with torch.no_grad():
        for raw_seq in tqdm(seqs, desc="Embedding (ESM3:esm3_sm_open_v1)"):
            residues = clean_aa(raw_seq)
            encoded = model.encode(ESMProtein(sequence=residues))
            output = model.forward_and_sample(
                encoded,
                SamplingConfig(return_per_residue_embeddings=True)
            )

            per_res = _residue_embeddings(output)
            if per_res.dim() == 3:
                per_res = per_res[0]                # drop batch axis -> [L, D]
            if per_res.size(0) == len(residues) + 2:
                per_res = per_res[1:-1]

            rows.append(per_res.mean(dim=0).float().cpu().numpy())

    matrix = np.stack(rows, axis=0)
    return np.nan_to_num(matrix, nan=0.0, posinf=0.0, neginf=0.0)

In [30]:
@torch.no_grad()
def embed_sequences_transformers(model_id: str,
                                 seqs: List[str],
                                 device: str,
                                 batch_size: int,
                                 max_length: int) -> np.ndarray:
    """Embed sequences with a Hugging Face encoder and mean-pool each to one vector.

    Per batch:
      - sequences are cleaned with ``clean_aa``;
      - they are space-separated when ``needs_space_separated(model_id)`` is true;
      - they are tokenized with padding and truncation to ``max_length``;
      - the model output is mean-pooled with ``mean_pool``.

    For T5-like encoders (id contains "ankh", or ``config.model_type == "t5"``) only
    the attention mask is used for pooling. Otherwise special tokens are also
    excluded, with ``mean_pool``'s fallback if that would empty a row.

    Args:
        model_id: model id, loaded via ``load_model_tokenizer_transformers``.
        seqs: amino-acid sequences.
        device: torch device string for the model and inputs.
        batch_size: sequences per forward pass.
        max_length: tokenizer truncation length (in tokens).

    Returns:
        float32 array [len(seqs), hidden] with NaN/inf replaced by 0.
    """
    tok, model = load_model_tokenizer_transformers(model_id, device)

    # Per-model decisions do not change between batches; compute them once.
    space_separated = needs_space_separated(model_id)
    unk_id = getattr(tok, "unk_token_id", None)
    model_cfg = getattr(model, "config", None)
    is_t5_like = ("ankh" in model_id.lower()) or bool(
        model_cfg and getattr(model_cfg, "model_type", "") == "t5"
    )

    all_vecs = []
    for i in tqdm(range(0, len(seqs), batch_size), desc=f"Embedding {model_id}"):
        batch_seqs = [clean_aa(s) for s in seqs[i:i + batch_size]]
        batch_text = [" ".join(s) for s in batch_seqs] if space_separated else batch_seqs

        enc = tok(
            batch_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            return_special_tokens_mask=True,
        )

        # Tokenizer-side sanity checks on CPU tensors, before moving to device.
        if unk_id is not None:
            ids_cpu = enc["input_ids"]
            attn_cpu = enc.get("attention_mask", None)
            real = attn_cpu.bool() if attn_cpu is not None else torch.ones_like(ids_cpu, dtype=torch.bool)
            # Rate over real tokens only: counting padding positions would dilute it in
            # mixed-length batches and hide tokenizer/vocabulary mismatches.
            unk_rate = (ids_cpu.eq(unk_id) & real).sum().item() / max(int(real.sum().item()), 1)
            if unk_rate > 0.2:
                print(f"[WARN] high UNK rate={unk_rate:.3f} for {model_id}")

        special_cpu = enc.get("special_tokens_mask", None)
        if special_cpu is not None:
            all_special = (special_cpu.sum(dim=1) == special_cpu.size(1)).any().item()
            if all_special:
                print(f"[WARN] special_tokens_mask marks ALL tokens as special for {model_id} (will auto-fallback).")

        enc = {k: v.to(device) for k, v in enc.items()}
        attn_mask = enc.get("attention_mask", None)
        special_mask = enc.get("special_tokens_mask", None)

        # Only forward the inputs every encoder accepts (special_tokens_mask / token_type_ids are dropped).
        out = model(**{k: enc[k] for k in ("input_ids", "attention_mask") if k in enc})
        last_hidden = out.last_hidden_state

        if attn_mask is None:
            attn_mask = torch.ones(last_hidden.shape[:2], device=device, dtype=torch.long)

        # T5/ANKH: do not pass special_mask; pool over the attention mask only.
        vec = mean_pool(last_hidden, attn_mask, special_mask=None if is_t5_like else special_mask)
        all_vecs.append(vec.detach().cpu().float().numpy())

    X = np.concatenate(all_vecs, axis=0)
    return np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

In [31]:
def _embed_dplm_byprot(model_id: str, seqs: List[str], device: str) -> np.ndarray:
    """Embed ``seqs`` with byprot's DiffusionProteinLanguageModel, mean-pooling each sequence.

    Any exception (byprot missing, load failure, API mismatch, unstackable output)
    propagates to the caller.
    """
    from byprot.models.lm.dplm import DiffusionProteinLanguageModel  # type: ignore

    dplm = DiffusionProteinLanguageModel.from_pretrained(model_id)
    dplm.to(device).eval()

    # NOTE: byprot's API varies across versions, so try the known entry points in order.
    vecs = []
    with torch.no_grad():
        for seq in tqdm(seqs, desc=f"Embedding (DPLM:{model_id})"):
            s = clean_aa(seq)
            if hasattr(dplm, "encode"):
                rep = dplm.encode([s])
            elif hasattr(dplm, "get_representation"):
                rep = dplm.get_representation([s])
            else:
                # Last resort: assume the model tokenizes internally (may fail).
                rep = dplm([s])

            # Expected [1, L, D] or [1, D] — unverified across byprot versions.
            rep = torch.as_tensor(rep).squeeze(0)
            if rep.dim() == 2:
                rep = rep.mean(dim=0)
            vecs.append(rep.float().detach().cpu().numpy())

    X = np.stack(vecs, axis=0)
    return np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)


def embed_sequences_dplm(model_id: str, seqs: List[str], device: str,
                         batch_size: int = 2, max_length: int = 256) -> np.ndarray:
    """Embed with DPLM via byprot, falling back to the transformers path on any error.

    Args:
        model_id: DPLM checkpoint id.
        seqs: amino-acid sequences.
        device: torch device string.
        batch_size: forwarded to ``embed_sequences_transformers`` on fallback
            (default 2, as previously hard-coded).
        max_length: forwarded to ``embed_sequences_transformers`` on fallback
            (default 256, as previously hard-coded).

    Returns:
        Array [len(seqs), D] with NaN/inf replaced by 0.
    """
    try:
        return _embed_dplm_byprot(model_id, seqs, device)
    except Exception as e:  # deliberate best-effort: any byprot failure triggers the fallback
        warnings.warn(f"[DPLM] byprot path failed ({type(e).__name__}: {e}). Fallback to transformers.")

    # Fall back only after leaving the except block. The exception's traceback references
    # the helper's frame (and thus any DPLM model it loaded), so loading the fallback
    # model inside `except` would keep both models alive at the same time.
    return embed_sequences_transformers(model_id, seqs, device, batch_size=batch_size, max_length=max_length)

In [32]:
def embed_sequences(model_id: str,
                    seqs: List[str],
                    device: str,
                    batch_size: int,
                    cache_dir: str,
                    max_length: int) -> np.ndarray:
    """Embed ``seqs`` with ``model_id``, caching the result as ``.npy`` under ``cache_dir``.

    The cache file is ``{safe_model_tag(model_id)}_{key}.npy``, where ``key`` is a
    short SHA-256 digest of ``max_length`` and the exact sequence list. A different
    dataset or truncation length therefore never reuses another run's embeddings.
    A cached array whose row count does not match ``len(seqs)`` triggers a warning
    and is recomputed.

    Routing:
      - ``EvolutionaryScale/esmc-*`` uses ESM C;
      - ``EvolutionaryScale/esm3-sm-open-v1`` uses ESM3;
      - ``airkingbd/dplm_650m`` uses DPLM;
      - everything else (including SaProt) goes through the transformers path.

    Returns:
        Array [len(seqs), D].
    """
    import hashlib

    os.makedirs(cache_dir, exist_ok=True)

    # Key the cache on the actual inputs, not just the model. Keying on the model
    # alone would silently serve stale or misaligned embeddings when the CSV changes.
    hasher = hashlib.sha256()
    hasher.update(f"max_length={max_length};n={len(seqs)}".encode("utf-8"))
    for s in seqs:
        hasher.update(b"\x00")
        hasher.update(str(s).encode("utf-8"))
    cache_key = hasher.hexdigest()[:16]
    cache_path = os.path.join(cache_dir, f"{safe_model_tag(model_id)}_{cache_key}.npy")

    if os.path.exists(cache_path):
        cached = np.load(cache_path)
        if cached.shape[0] == len(seqs):
            return cached
        warnings.warn(f"[cache] {cache_path} has {cached.shape[0]} rows, expected {len(seqs)}; recomputing.")

    # ---- model-specific routing ----
    if model_id.startswith("EvolutionaryScale/esmc-"):
        X = embed_sequences_esmc(model_id, seqs, device)

    elif model_id == "EvolutionaryScale/esm3-sm-open-v1":
        X = embed_sequences_esm3(seqs, device)

    elif model_id == "airkingbd/dplm_650m":
        X = embed_sequences_dplm(model_id, seqs, device)

    elif model_id.startswith("westlake-repl/SaProt_"):
        # SaProt usually expects sequence + structure tokens; with plain AA sequences
        # the results may be close to the baseline.
        print("[NOTE] SaProt usually expects structure-aware tokens (often containing '#'). "
              "If you only have AA sequences, performance may collapse to majority baseline.")
        X = embed_sequences_transformers(model_id, seqs, device, batch_size, max_length)

    else:
        # ESM2 / ANKH / Mistral-Prot
        X = embed_sequences_transformers(model_id, seqs, device, batch_size, max_length)

    np.save(cache_path, X)
    return X

In [None]:
def solver_cv_select_strict_mcc(X_tr, y_tr, seed=42, out_csv="solver_cv_table.csv", n_jobs=8,
                                n_splits=5, c_grid=(0.1, 1, 10)):
    """STRICT model selection for an impute -> MinMax -> LogisticRegression pipeline.

    - Uses ONLY the training data (X_tr / y_tr) to compare solvers and pick
      hyperparameters; the held-out test set is never touched here.
    - Each solver group gets its own GridSearchCV refit on MCC. The group with the
      highest mean CV MCC wins; ties keep the earlier group in dict order.
    - All groups share one StratifiedKFold(shuffle=True, random_state=seed), so
      every group is scored on identical folds.

    Args:
        X_tr, y_tr: training features and labels.
        seed: seeds the CV shuffling and LogisticRegression.
        out_csv: path of the per-group summary table (CSV).
        n_jobs: parallelism passed to GridSearchCV.
        n_splits: number of CV folds (default 5).
        c_grid: values of C searched in every solver group (default 0.1, 1, 10).

    Returns:
        (best_estimator_fitted_on_Xtr, cv_table_df, best_info_dict)

    Raises:
        RuntimeError: if no solver group produced a finite CV MCC. The CSV is
            still written first, for diagnosis.
    """
    c_values = list(c_grid)

    base_pipe = Pipeline(steps=[
        ("imp", SimpleImputer(strategy="constant", fill_value=0.0)),
        ("scaler", MinMaxScaler()),
        ("clf", LogisticRegression(
            max_iter=5000,
            class_weight="balanced",
            random_state=seed
        ))
    ])

    scoring = {
        "mcc": make_scorer(matthews_corrcoef),
        "auc": "roc_auc",
        "bacc": "balanced_accuracy",
    }

    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    solver_grids = {
        "lbfgs": [
            {"clf__solver": ["lbfgs"], "clf__penalty": ["l2"], "clf__C": c_values}
        ],
        "liblinear": [
            {"clf__solver": ["liblinear"], "clf__penalty": ["l1", "l2"], "clf__C": c_values}
        ],
        "newton-cg": [
            {"clf__solver": ["newton-cg"], "clf__penalty": ["l2"], "clf__C": c_values}
        ],
        "saga": [
            {"clf__solver": ["saga"], "clf__penalty": ["l1", "l2"], "clf__C": c_values},
            {"clf__solver": ["saga"], "clf__penalty": ["elasticnet"], "clf__C": c_values,
             "clf__l1_ratio": [0.2, 0.5, 0.8]},
        ],
    }

    rows = []
    best_overall = None
    best_info = {"best_solver_group": None, "best_params": None, "cv_best_mcc": -1e18}

    for solver_name, grid in solver_grids.items():
        t0 = time.time()

        gs = GridSearchCV(
            estimator=base_pipe,
            param_grid=grid,
            scoring=scoring,
            refit="mcc",      # STRICT: select by MCC
            cv=cv,
            n_jobs=n_jobs,
            verbose=0,
            return_train_score=False
        )
        gs.fit(X_tr, y_tr)

        best_mcc = float(gs.best_score_)
        # For reference only: AUC / BACC of the candidate that MCC selected.
        best_idx = gs.best_index_
        results = gs.cv_results_
        best_auc = float(results["mean_test_auc"][best_idx]) if "mean_test_auc" in results else float("nan")
        best_bacc = float(results["mean_test_bacc"][best_idx]) if "mean_test_bacc" in results else float("nan")

        dt = time.time() - t0
        rows.append({
            "solver_group": solver_name,
            "cv_best_mcc": best_mcc,
            "cv_auc_at_best_mcc": best_auc,
            "cv_bacc_at_best_mcc": best_bacc,
            **{f"best_{k}": v for k, v in gs.best_params_.items()},
            "seconds": dt,
        })
        print(f"[CV:{solver_name}] best_mcc={best_mcc:.4f} best_params={gs.best_params_} time={dt:.1f}s")

        # A NaN score never compares greater, so NaN groups are skipped here.
        if best_mcc > best_info["cv_best_mcc"]:
            best_info["cv_best_mcc"] = best_mcc
            best_info["best_solver_group"] = solver_name
            best_info["best_params"] = gs.best_params_
            best_overall = gs.best_estimator_

    df_out = pd.DataFrame(rows).sort_values("cv_best_mcc", ascending=False)
    df_out.to_csv(out_csv, index=False)
    print("Saved CV table:", out_csv)

    if best_overall is None:
        raise RuntimeError(
            f"No solver group produced a finite CV MCC; cannot select a model (see {out_csv})."
        )

    print("Selected (by MCC):", best_info["best_solver_group"], "cv_best_mcc=", best_info["cv_best_mcc"])
    return best_overall, df_out, best_info

In [33]:
import argparse

def build_parser():
    """Create the argparse parser for the PLM-embedding + logistic-regression benchmark.

    ``--device`` defaults to "cuda" when ``torch.cuda.is_available()`` and to "cpu"
    otherwise, so the defaults also work on CPU-only machines.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # prog is optional; do not pass the argv list here (argv goes to parse_args).
    ap = argparse.ArgumentParser(
        prog="plm_lr_benchmark",
        description="Embed protein sequences with a protein language model and "
                    "benchmark logistic regression on the embeddings.",
    )
    ap.add_argument("--data_csv", required=True, help="Input CSV with sequence and label columns.")
    ap.add_argument("--seq_col", default="sequence", help="Column holding amino-acid sequences.")
    ap.add_argument("--label_col", default="label", help="Column holding binary 0/1 labels.")
    ap.add_argument("--model_id", required=True, help="Embedding model id (Hugging Face / ESM).")
    ap.add_argument("--device", default=default_device, help="Torch device, e.g. 'cuda', 'cuda:0' or 'cpu'.")
    ap.add_argument("--batch_size", type=int, default=8, help="Sequences per forward pass (batched models).")
    ap.add_argument("--max_length", type=int, default=256, help="Tokenizer truncation length in tokens.")
    ap.add_argument("--test_size", type=float, default=0.2, help="Fraction of rows held out for testing.")
    ap.add_argument("--seed", type=int, default=42, help="Seed for the split, CV folds and classifier.")
    ap.add_argument("--cache_dir", default="cache", help="Directory for cached embeddings (.npy).")
    ap.add_argument("--n_jobs", type=int, default=8, help="Parallel jobs for GridSearchCV.")
    ap.add_argument("--out_dir", default="results", help="Directory for the split, CV table and metrics JSON.")
    return ap

In [None]:
def main(argv=None):
    """Run the benchmark end to end.

    Steps: embed, make one fixed stratified train/test split, run CV model
    selection on train only, then evaluate once on the held-out test set.

    Args:
        argv: optional list of CLI tokens (see ``build_parser``); ``None`` lets
            argparse read ``sys.argv``.

    Writes to ``args.out_dir``:
        - ``split_seed{seed}_test{test_size}.npz`` (train/test indices);
        - ``solver_cv_{tag}_seed{seed}_test{test_size}.csv`` (per-solver CV summary);
        - ``{tag}.json`` (selection info and test metrics).

    Raises:
        KeyError: if ``seq_col`` / ``label_col`` is missing from the CSV.
        ValueError: if labels are not binary 0/1, or if the embeddings do not
            have exactly one row per CSV row.
    """
    ap = build_parser()
    args = ap.parse_args(argv)
    print(args)

    os.makedirs(args.out_dir, exist_ok=True)

    df = pd.read_csv(args.data_csv)
    for opt_name, col in (("seq_col", args.seq_col), ("label_col", args.label_col)):
        if col not in df.columns:
            raise KeyError(f"{opt_name} '{col}' not found. Columns={list(df.columns)}")

    seqs = df[args.seq_col].astype(str).tolist()
    y = df[args.label_col].astype(int).to_numpy()
    # Column 1 of predict_proba is treated as the positive class and compute_metrics
    # uses labels=[0, 1]; any other label encoding would be scored incorrectly.
    unexpected_labels = sorted(set(np.unique(y).tolist()) - {0, 1})
    if unexpected_labels:
        raise ValueError(
            f"label_col '{args.label_col}' must be binary 0/1; found other values {unexpected_labels}"
        )

    X = embed_sequences(args.model_id, seqs, args.device, args.batch_size, args.cache_dir, args.max_length)
    # Guard against silently misaligned features/labels (e.g. a cache built from a different CSV).
    if X.shape[0] != len(y):
        raise ValueError(
            f"Embedding rows ({X.shape[0]}) != CSV rows ({len(y)}); "
            f"check for a stale embedding cache in '{args.cache_dir}'."
        )

    print("X shape:", X.shape)
    print("X mean std over dims:", X.std(axis=0).mean())
    print("Unique rows (rounded):", np.unique(X.round(6), axis=0).shape[0], "/", X.shape[0])

    # 1) One fixed split (test is held-out and used ONCE)
    idx = np.arange(len(y))
    tr_idx, te_idx = train_test_split(
        idx,
        test_size=args.test_size,
        random_state=args.seed,
        shuffle=True,
        stratify=y
    )

    split_path = os.path.join(args.out_dir, f"split_seed{args.seed}_test{args.test_size}.npz")
    np.savez(split_path, train_idx=tr_idx, test_idx=te_idx)
    print("Saved split:", split_path)

    X_tr, X_te = X[tr_idx], X[te_idx]
    y_tr, y_te = y[tr_idx], y[te_idx]

    # 2) STRICT model selection (ONLY on training data) using MCC
    tag = safe_model_tag(args.model_id)
    cv_csv = os.path.join(args.out_dir, f"solver_cv_{tag}_seed{args.seed}_test{args.test_size}.csv")

    best_pipe, df_cv, best_info = solver_cv_select_strict_mcc(
        X_tr, y_tr,
        seed=args.seed,
        out_csv=cv_csv,
        n_jobs=args.n_jobs
    )

    # 3) Final evaluation on held-out test (used ONCE)
    y_prob = best_pipe.predict_proba(X_te)[:, 1]
    y_pred = best_pipe.predict(X_te)  # consistent with MCC scorer (uses predict)
    print("y positive rate:", y.mean())
    print("pred positive rate:", y_pred.mean(), "prob std:", y_prob.std())

    m = compute_metrics(y_te, y_prob, y_pred)

    out = {
        "model_id": args.model_id,
        "n": int(len(df)),
        "dim": int(X.shape[1]),
        "selection": {
            "metric": "MCC",
            "cv_table_csv": os.path.basename(cv_csv),
            "best_solver_group": best_info["best_solver_group"],
            "best_params": best_info["best_params"],
            "cv_best_mcc": float(best_info["cv_best_mcc"]),
        },
        "test_metrics": m,
        "split": {"method": "train_test_split", "test_size": args.test_size, "seed": args.seed},
    }

    out_path = os.path.join(args.out_dir, f"{tag}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    print("Saved:", out_path, flush=True)
    print(json.dumps(out["selection"], ensure_ascii=False, indent=2), flush=True)
    print(json.dumps(m, ensure_ascii=False, indent=2), flush=True)

In [35]:
# Example run: ESM C 300M embeddings on data.csv; unspecified options keep build_parser defaults.
RUN_ARGV = [
    "--data_csv", "data.csv",
    "--seq_col", "sequence",
    "--label_col", "label",
    "--model_id", "EvolutionaryScale/esmc-300m-2024-12",
    "--batch_size", "8",
    "--max_length", "256",
]
main(RUN_ARGV)

Namespace(data_csv='data.csv', seq_col='sequence', label_col='label', model_id='EvolutionaryScale/esmc-300m-2024-12', device='cuda', batch_size=8, max_length=256, test_size=0.2, seed=42, cache_dir='cache', out_dir='results')
X shape: (3444, 960)
X mean std over dims: 0.014620587
Unique rows (rounded): 3444 / 3444
Saved split: results\split_seed42_test0.2.npz
Fitting 5 folds for each of 24 candidates, totalling 120 fits
y positive rate: 0.6173054587688734
pred positive rate: 0.5703918722786647 prob std: 0.28252706
Saved: results\EvolutionaryScale_esmc-300m-2024-12.json
{
  "ACC": 0.7474600870827286,
  "BACC": 0.7443538324420678,
  "Sn": 0.7576470588235295,
  "Sp": 0.7310606060606061,
  "MCC": 0.4799582357794938,
  "AUC": 0.8236096256684493,
  "AP": 0.8776901289796043
}
