# 2. Loading external stability and expression datasets

**NESG Solubility** 
(https://loschmidt.chemi.muni.cz/soluprot/?page=download)
* 10k proteins
* Labels: exp, sol, uniprot id or local ID 
* Units: integer 

**SoluProt Solubility**
(https://loschmidt.chemi.muni.cz/soluprot/?page=download)
* 11k training, 3k test
* Label: solubility, number IDs with no conversion map (has seq)
* Unit: 0/1

**Price Solubility**
(https://pmc.ncbi.nlm.nih.gov/articles/PMC3372292/)
* 7k proteins 
* Label: usability; UniProt ID
* Unit: 0/1

**PSI Solubility** 
(https://academic.oup.com/bioinformatics/article/36/18/4691/5860015?login=false)
* 11k proteins
* Label: solubility, Aa0000 ID scheme (has seq)
* Unit: 0/1
* Note: ecoli with custom IDs, dropped for now 

**Meltome Stability** 
(https://meltomeatlas.proteomics.wzw.tum.de/master_meltomeatlasapp/)
* 1M variants 
* Label: temperature, meltpoint, fold_change, uniprot id 
* Note: ecoli with custom IDs, dropped for now  

**FireprotDB Stability** 
(https://loschmidt.chemi.muni.cz/fireprotdb/)
* 53k variants
* Label: ddG, dTm, pH, Tm, mutation_effect, uniprot id 

**ThermomutDB Stability**
(https://biosig.lab.uq.edu.au/thermomutdb/downloads)
* 12k variants
* Label: pH, ddG, temperature, dTm, uniprot/pdb id 
* Note: these entries could not be retrieved because they are no longer in UniProtKB: A0A410ZNC6, D0WVP7, G7LSK3, GQ884175, M5A5Y8, Q9REI6

**CAFA** 
(https://www.kaggle.com/competitions/cafa-5-protein-function-prediction/code)
* 142k proteins

**Novozymes**
(https://www.kaggle.com/code/jinyuansun/eda-and-finetune-esm)
* 31k variants

**Protsol Solubility**
(https://huggingface.co/datasets/AI4Protein/ProtSolM)
* 71k proteins
* Label: solubility, no ID but has sequence
* Unit: 0/1 

**MaveDB** 
- https://mavedb.org/search?target-organism-name=Escherichia+coli+K-12 

**ProteinGym**
- (source link and details to be added)

**Align2023**
- There are 4 enzymes, each with a train.csv dataset in its own folder. All rows have mutation codes, but some are multi-mutants, so only the single-mutation rows will be kept. Every dataset includes the sequence; we will only look at beta-glucosidase B and alpha-amylase.

**Prothermdb** 
- Email sent requesting access.

In [None]:
# benchmark.ipynb — Build unified WT–mutant pair table with explicit per-dataset input paths
# Edit DATASET_INPUTS to point to your local files (csv/tsv/parquet/xlsx/jsonl).
# Outputs:
#   data/benchmark/derived/benchmark_wt_mutant_pairs.tsv
#   data/benchmark/derived/benchmark_wt_parents.tsv

from __future__ import annotations
from pathlib import Path
import re
import pandas as pd
import numpy as np

# Output directory for every derived table written by this cell; it is created
# as soon as this line runs (module-level side effect).
OUT_DIR = Path("data/benchmark/derived"); OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 1) EDIT THESE INPUT PATHS
# ----------------------------
# Each entry's "paths" can be:
#   - a file path string
#   - a Path
#   - a glob string (e.g. "align2023/**/train.csv")
#   - a list of any of the above
#
# kind:
#   - "variant": dataset contains mutations and/or mutant sequences (WT↔mutant pairs)
#   - "protein": dataset is protein-level only (no mutants); becomes WT-only rows
#
# NOTE(review): these globs do not match the locations used by the later
# loading cell (e.g. data/benchmark/sol_benchmark/nesg/nesg.csv,
# data/benchmark/stab_benchmark/fireprotdb_results_stability.csv,
# data/benchmark/stab_benchmark/thermomutdb.json). ThermoMutDB is JSON there,
# which neither these patterns nor read_any's suffix list picks up. NESG and
# SoluProt keep their sequences in separate FASTA files, which this ingest
# does not read. Confirm which directory layout is current before trusting
# this cell's output.
DATASET_INPUTS = [
    # Variant datasets
    {"name": "FireProtDB",  "kind": "variant", "paths": ["data/benchmark/fireprotdb/*.tsv", "data/benchmark/fireprotdb/*.csv"]},
    {"name": "ThermoMutDB", "kind": "variant", "paths": ["data/benchmark/thermomutdb/*.tsv", "data/benchmark/thermomutdb/*.csv"]},
    {"name": "Align2023",   "kind": "variant", "paths": ["data/benchmark/align2023/**/train.csv"]},
    {"name": "ProteinGym",  "kind": "variant", "paths": ["data/benchmark/proteingym/**/*.csv", "data/benchmark/proteingym/**/*.tsv", "data/benchmark/proteingym/**/*.parquet"]},
    {"name": "MaveDB",      "kind": "variant", "paths": ["data/benchmark/mavedb/**/*.csv", "data/benchmark/mavedb/**/*.tsv", "data/benchmark/mavedb/**/*.parquet"]},
    {"name": "CAFA5",       "kind": "variant", "paths": ["data/benchmark/cafa5/**/*.csv", "data/benchmark/cafa5/**/*.tsv", "data/benchmark/cafa5/**/*.parquet"]},
    {"name": "Novozyme",    "kind": "variant", "paths": ["data/benchmark/novozymes/**/*.csv", "data/benchmark/novozymes/**/*.tsv", "data/benchmark/novozymes/**/*.parquet"]},

    # Protein-level datasets (WT-only rows)
    {"name": "NESG_Sol",    "kind": "protein", "paths": ["data/benchmark/nesg/*.tsv", "data/benchmark/nesg/*.csv"]},
    {"name": "SoluProt",    "kind": "protein", "paths": ["data/benchmark/soluprot/*.tsv", "data/benchmark/soluprot/*.csv"]},
    {"name": "ProtSolM",    "kind": "protein", "paths": ["data/benchmark/protsol/*.parquet", "data/benchmark/protsol/*.csv", "data/benchmark/protsol/*.tsv"]},
    {"name": "Price_Sol",   "kind": "protein", "paths": ["data/benchmark/price_sol/*.tsv", "data/benchmark/price_sol/*.csv"]},
    {"name": "PSI_Sol",     "kind": "protein", "paths": ["data/benchmark/psi_sol/*.tsv", "data/benchmark/psi_sol/*.csv"]},
]

# Optional curated WT mapping file to override inference for variant datasets.
# Tab-separated with columns: protein_id\twt_seq. It is only read if it exists.
WT_MAP_PATH = Path("data/benchmark/wt_map.tsv")

# ----------------------------
# 2) HELPERS
# ----------------------------
def read_any(path: Path) -> pd.DataFrame:
    """Load a tabular file into a DataFrame, picking the reader by file suffix.

    Supported suffixes (case-insensitive): ``.tsv``/``.tab`` (tab-separated),
    ``.csv``, ``.parquet``, ``.jsonl`` (one JSON record per line) and
    ``.xlsx``/``.xls``.

    Raises:
        ValueError: if the suffix is not one of the supported ones.
    """
    readers = {
        ".tsv": lambda p: pd.read_csv(p, sep="\t"),
        ".tab": lambda p: pd.read_csv(p, sep="\t"),
        ".csv": pd.read_csv,
        ".parquet": pd.read_parquet,
        ".jsonl": lambda p: pd.read_json(p, lines=True),
        ".xlsx": pd.read_excel,
        ".xls": pd.read_excel,
    }
    reader = readers.get(path.suffix.lower())
    if reader is None:
        raise ValueError(f"Unsupported file: {path}")
    return reader(path)

def canon_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with normalized column names.

    Each name is stringified, stripped and lower-cased. Runs of whitespace become
    a single ``_``, and the characters ``(``, ``)`` and ``?`` are removed, e.g.
    ``"Tm (C)"`` -> ``"tm_c"``. The input frame is left untouched.
    """
    drop_chars = str.maketrans("", "", "()?")

    def _clean(name) -> str:
        collapsed = re.sub(r"\s+", "_", str(name).strip().lower())
        return collapsed.translate(drop_chars)

    renamed = df.copy()
    renamed.columns = [_clean(name) for name in renamed.columns]
    return renamed

def pick_col(cols, candidates):
    """Return the first column in *cols* that matches one of *candidates*.

    Matching is case-insensitive and happens in two passes.

    1. Exact name match, trying candidates in priority order.
    2. Substring match: the first candidate (in priority order) that occurs
       inside any column name, scanning *cols* in their given order.

    Returns the original column name, or None when nothing matches.
    """
    by_lower = {name.lower(): name for name in cols}

    for wanted in candidates:
        hit = by_lower.get(wanted.lower())
        if hit is not None:
            return hit

    return next(
        (name for wanted in candidates for name in cols if wanted.lower() in name.lower()),
        None,
    )

def normalize_mut_str(s) -> str:
    """Normalize a raw mutation-code cell into a ``;``-separated string.

    Missing values (None, float NaN, ``pd.NA``) and wild-type sentinels return
    ``""``. Sentinels are matched case-insensitively and ignoring internal
    whitespace: ``WT``, ``wildtype``, ``wild-type``, ``wild_type``, ``0``,
    ``nan`` and ``None``.

    Otherwise the separators ``,``, ``|``, ``/`` and ``:`` become ``;``, all
    whitespace is removed, repeated ``;`` are collapsed, and leading/trailing
    ``;`` are stripped, e.g. ``" A12G, C3T "`` -> ``"A12G;C3T"``.
    """
    if s is None or s is pd.NA or (isinstance(s, float) and np.isnan(s)):
        return ""
    s = str(s).strip()
    # Case/whitespace-insensitive so spellings like "Wild-Type", "WildType",
    # "wild type" or "NaN" are treated as WT instead of as bogus mutation codes.
    token = re.sub(r"\s+", "", s).casefold()
    if token in {"", "wt", "wildtype", "wild-type", "wild_type", "0", "nan", "none"}:
        return ""
    s = s.replace(",", ";").replace("|", ";").replace("/", ";").replace(":", ";")
    s = re.sub(r"\s+", "", s)
    s = re.sub(r";{2,}", ";", s).strip(";")
    return s

def split_muts(muts: str) -> list[str]:
    """Split a raw mutation-code cell into individual codes.

    The value is first cleaned with ``normalize_mut_str``. WT or missing values
    give ``[]``, and empty fragments between separators are dropped.
    """
    cleaned = normalize_mut_str(muts)
    if not cleaned:
        return []
    return [code for code in cleaned.split(";") if code]

AA1 = set("ACDEFGHIKLMNPQRSTVWY")
_mut_re = re.compile(r"^([ACDEFGHIKLMNPQRSTVWY])(\d+)([ACDEFGHIKLMNPQRSTVWY])$")

def diff_to_mutations(wt: str, mut: str) -> str:
    if not isinstance(wt, str) or not isinstance(mut, str) or not wt or not mut or len(wt) != len(mut):
        return ""
    out = []
    for i, (a, b) in enumerate(zip(wt, mut), start=1):
        if a != b and (a in AA1) and (b in AA1):
            out.append(f"{a}{i}{b}")
    return ";".join(out)

def apply_mutations(wt_seq: str, muts: list[str]):
    """Apply single-substitution codes (e.g. ``"A12G"``) to a WT sequence.

    Each code must match ``_mut_re``, use a 1-based position inside the
    sequence, and name the residue currently at that position. Codes are
    applied in order, so a later code at the same position is checked against
    the already-substituted residue. Codes that fail a check are skipped and
    counted.

    Returns:
        ``(mutant_seq, ok, n_mismatch, n_invalid)``. Here ``ok`` is True only if
        every code applied cleanly. ``n_mismatch`` counts codes whose WT residue
        disagrees with the sequence. ``n_invalid`` counts unparsable or
        out-of-range codes. For an empty or non-string *wt_seq* the result is
        ``("", False, 0, len(muts))``.
    """
    if not (isinstance(wt_seq, str) and wt_seq):
        return "", False, 0, len(muts)

    residues = list(wt_seq)
    n_mismatch = n_invalid = 0
    for code in muts:
        parsed = _mut_re.match(code.strip())
        if parsed is None:
            n_invalid += 1
            continue
        ref_aa, pos_txt, alt_aa = parsed.groups()
        idx = int(pos_txt) - 1
        if not 0 <= idx < len(residues):
            n_invalid += 1
        elif residues[idx] != ref_aa:
            n_mismatch += 1
        else:
            residues[idx] = alt_aa

    return "".join(residues), n_invalid == 0 and n_mismatch == 0, n_mismatch, n_invalid

def infer_wt_seq(df: pd.DataFrame, protein_id_col: str, observed_seq_col: str, muts_col: str | None) -> dict:
    wt_map = {}
    if not protein_id_col or not observed_seq_col:
        return wt_map
    gcols = [protein_id_col, observed_seq_col] + ([muts_col] if muts_col else [])
    d = df[gcols].copy()
    if muts_col:
        d["_muts_norm"] = d[muts_col].map(normalize_mut_str)
    for pid, g in d.groupby(protein_id_col, dropna=True):
        pid = str(pid)
        if muts_col:
            wt_rows = g[g["_muts_norm"].eq("")]
            if len(wt_rows) > 0:
                seqs = wt_rows[observed_seq_col].dropna().astype(str)
                if len(seqs) > 0:
                    wt_map[pid] = seqs.value_counts().idxmax()
                    continue
        seqs = g[observed_seq_col].dropna().astype(str)
        if len(seqs) > 0:
            wt_map[pid] = seqs.value_counts().idxmax()
    return wt_map

def infer_label_col(df: pd.DataFrame):
    """Guess which column of *df* holds the target value.

    Delegates to ``pick_col`` with a fixed priority list. Generic names come
    first ("label", "y", "fitness", "score", "assay_value", "assay"), then
    stability ("ddg", "dtm", "tm"), solubility/expression, and finally
    "temperature". Returns None when nothing matches.

    Because ``pick_col`` falls back to substring matching, short candidates
    such as "y" or "tm" can match unrelated columns. Check the ``label_col``
    value recorded in the output.
    """
    priority = [
        "label", "y", "fitness", "score", "assay_value", "assay",
        "ddg", "dtm", "tm", "solubility", "sol", "exp", "expression", "temperature",
    ]
    return pick_col(list(df.columns), priority)

def expand_paths(paths):
    """Resolve a path spec into a list of existing files.

    Args:
        paths: A str/Path, or a list of them. Entries containing ``*``, ``?``
            or ``[`` are glob patterns: ``**`` matches any directory depth, and
            absolute patterns are allowed. Wildcards do not match hidden (dot)
            files. Any other entry is kept only if it exists.

    Returns:
        Paths in spec order, with each pattern's matches sorted. Directories
        matched by a pattern are dropped. Duplicates are removed (first
        occurrence wins), so overlapping patterns never ingest a file twice.
    """
    import glob  # glob.glob, unlike Path().glob, accepts absolute patterns

    if isinstance(paths, (str, Path)):
        paths = [paths]
    out = []
    seen = set()
    for p in paths:
        p = Path(p)
        if any(ch in str(p) for ch in ["*", "?", "["]):
            matches = sorted(Path(m) for m in glob.glob(str(p), recursive=True))
            candidates = [m for m in matches if m.is_file()]
        else:
            candidates = [p] if p.exists() else []
        for c in candidates:
            if c not in seen:
                seen.add(c)
                out.append(c)
    return out

# WT override map: protein_id -> curated wt_seq, used before per-file inference.
# Every column is read as str so numeric-looking IDs keep their exact spelling
# (e.g. "00123", never 123 or "123.0"). Incomplete rows are dropped, so a
# missing value never turns into the literal string "nan".
WT_MAP = {}
if WT_MAP_PATH.exists():
    wm = pd.read_csv(WT_MAP_PATH, sep="\t", dtype=str).dropna(subset=["protein_id", "wt_seq"])
    WT_MAP = dict(zip(wm["protein_id"], wm["wt_seq"]))

# ----------------------------
# 3) NORMALIZE ONE FILE
# ----------------------------
def build_pairs_from_df(df: pd.DataFrame, source_dataset: str, kind: str) -> pd.DataFrame:
    """Normalize one raw table into the unified WT–mutant pair schema.

    Column roles (protein id, variant id, WT/mutant/generic sequence, mutation
    codes, label) are detected with ``pick_col`` after ``canon_cols``.

    Args:
        df: Raw table, e.g. as returned by ``read_any``.
        source_dataset: Dataset name. It is written to ``source_dataset`` and
            used for fallback ids. Names starting with "align2023" keep only
            rows with 0 or 1 mutations.
        kind: ``"protein"`` makes one WT-only row per input row. Any other
            value is handled as a variant table.

    Returns:
        A DataFrame with columns source_dataset, protein_id, variant_id,
        wt_seq, mut_seq, mutations, n_mut, label, label_col, mut_apply_ok,
        mut_apply_mismatch_n and mut_apply_invalid_n. Missing sequences are
        ``""`` (never the string "nan").

    Raises:
        ValueError: if a variant table has neither a sequence column nor a
            mutation column.
    """
    df = canon_cols(df)
    cols = list(df.columns)

    protein_id_col = pick_col(cols, ["protein_id","uniprot_id","uniprot","accession","acc","pdb_id","pdb","target_id","gene","name","id"])
    variant_id_col = pick_col(cols, ["variant_id","mutant_id","mutation_id","sample_id","entry_id","seq_id","name","id"])

    wt_seq_col  = pick_col(cols, ["wt_seq","wildtype_sequence","wt_sequence","sequence_wt"])
    mut_seq_col = pick_col(cols, ["mut_seq","mutant_sequence","variant_sequence","sequence_mut"])
    seq_col     = pick_col(cols, ["sequence","protein_sequence","aa_seq","fasta","seq"])

    # Search for the mutation-code column only among columns not already claimed
    # as a sequence or id column. Otherwise pick_col's substring fallback would
    # match e.g. "mut" in "mut_seq" or "variant" in "variant_id", and sequences
    # or ids would be parsed as mutation codes.
    claimed = {c for c in (wt_seq_col, mut_seq_col, seq_col, protein_id_col, variant_id_col) if c}
    muts_col = pick_col(
        [c for c in cols if c not in claimed],
        ["mutations","mutation_code","mutation","mut","variant","aa_substitutions"],
    )

    label_col = infer_label_col(df)

    if variant_id_col is None:
        df["variant_id"] = [f"{source_dataset}:{i}" for i in range(len(df))]
        variant_id_col = "variant_id"
    if protein_id_col is None:
        # Protein-level rows are distinct proteins. A single dataset-wide id
        # would merge them all into one WT parent downstream, so reuse the
        # per-row variant id. Variant tables without any id column are treated
        # as covering a single protein.
        df["protein_id"] = df[variant_id_col].astype(str) if kind == "protein" else source_dataset
        protein_id_col = "protein_id"

    def _text(col):
        # Column as str with missing values as "". astype(str) alone turns NaN
        # into "nan", which would then count as a real sequence. The index
        # matches df so the output frame aligns row-for-row.
        if col is None:
            return pd.Series("", index=df.index, dtype=object)
        return df[col].fillna("").astype(str)

    if kind == "protein":
        seqs = _text(wt_seq_col or seq_col or mut_seq_col)
        out = pd.DataFrame({
            "source_dataset": source_dataset,
            "protein_id": df[protein_id_col].astype(str),
            "variant_id": df[variant_id_col].astype(str),
            "wt_seq": seqs,
            "mut_seq": seqs,
            "mutations": "",
            "n_mut": 0,
            "label": df[label_col] if label_col else np.nan,
            "label_col": label_col or "",
            "mut_apply_ok": True,
            "mut_apply_mismatch_n": 0,
            "mut_apply_invalid_n": 0,
        })
        return out

    # kind == "variant"
    observed_seq_col = mut_seq_col or wt_seq_col or seq_col
    if observed_seq_col is None and muts_col is None:
        raise ValueError(f"{source_dataset}: need at least a sequence column or mutations column")

    # Base sequences
    wt_seq = _text(wt_seq_col)
    mut_seq = _text(mut_seq_col or seq_col)
    muts_raw = (
        df[muts_col].map(normalize_mut_str) if muts_col
        else pd.Series("", index=df.index, dtype=object)
    )

    # WT inference per protein_id if there is no explicit WT column
    local_wt_map = {}
    if (wt_seq_col is None) and (observed_seq_col is not None):
        local_wt_map = infer_wt_seq(df, protein_id_col, observed_seq_col, muts_col)

    def get_wt(pid: str) -> str:
        # The curated WT_MAP wins over per-file inference.
        pid = str(pid)
        if pid in WT_MAP:
            return WT_MAP[pid]
        if pid in local_wt_map:
            return local_wt_map[pid]
        return ""

    wt_out, mut_out, ok_out, mm_out, inv_out, nmut_out, muts_out = [], [], [], [], [], [], []
    for pid, w, m, mstr in zip(df[protein_id_col].astype(str), wt_seq, mut_seq, muts_raw):
        wt = w if (isinstance(w, str) and w) else get_wt(pid)
        mut = m if (isinstance(m, str) and m) else ""

        mlist = split_muts(mstr)
        derived_from_diff = False
        # With both sequences of equal length, derive mutation codes from the
        # diff (more reliable than some dataset strings).
        if wt and mut and len(wt) == len(mut):
            derived = diff_to_mutations(wt, mut)
            if derived:
                mlist = split_muts(derived)
                mstr = derived
                derived_from_diff = True

        # With WT + mutation codes but the mutant sequence missing or equal to
        # WT, synthesize the mutant sequence.
        if wt and mlist and (not mut or mut == wt):
            newseq, ok, mmc, invc = apply_mutations(wt, mlist)
            mut = newseq if newseq else mut
        else:
            # Nothing was applied. Diff-derived codes agree with the sequences
            # by construction. Codes that could not be checked (no WT, or an
            # observed mutant sequence that differs from WT) stay unverified.
            ok, mmc, invc = (not mlist) or derived_from_diff, 0, 0

        # WT row normalization
        if not mlist:
            mut = wt if wt else mut

        wt_out.append(wt)
        mut_out.append(mut)
        ok_out.append(ok)
        mm_out.append(mmc)
        inv_out.append(invc)
        muts_out.append(mstr)
        nmut_out.append(len(mlist))

    out = pd.DataFrame({
        "source_dataset": source_dataset,
        "protein_id": df[protein_id_col].astype(str),
        "variant_id": df[variant_id_col].astype(str),
        "wt_seq": wt_out,
        "mut_seq": mut_out,
        "mutations": muts_out,
        "n_mut": nmut_out,
        "label": df[label_col] if label_col else np.nan,
        "label_col": label_col or "",
        "mut_apply_ok": ok_out,
        "mut_apply_mismatch_n": mm_out,
        "mut_apply_invalid_n": inv_out,
    })

    # Keep only WT and single mutants for Align2023
    if source_dataset.lower().startswith("align2023"):
        out = out[out["n_mut"].isin([0, 1])].reset_index(drop=True)

    return out

# ----------------------------
# 4) RUN INGEST + EXPORT
# ----------------------------
all_rows = []
manifest = []

for spec in DATASET_INPUTS:
    ds_name = spec["name"]
    ds_kind = spec["kind"]
    files = expand_paths(spec["paths"])

    if not files:
        manifest.append({"dataset": ds_name, "kind": ds_kind, "files_found": 0, "note": "no files matched"})
        continue

    for fp in files:
        # Best-effort ingest: one unreadable or unparsable file must not abort
        # the whole run. Failures go to the manifest and are printed below.
        try:
            df = read_any(fp)
            norm = build_pairs_from_df(df, source_dataset=ds_name, kind=ds_kind)
        except Exception as e:
            manifest.append({"dataset": ds_name, "kind": ds_kind, "file": str(fp), "error": repr(e)})
            continue
        norm["source_file"] = str(fp)
        all_rows.append(norm)
        manifest.append({"dataset": ds_name, "kind": ds_kind, "file": str(fp), "rows_in": len(df), "rows_out": len(norm)})

pairs = pd.concat(all_rows, ignore_index=True) if all_rows else pd.DataFrame(columns=[
    "source_dataset","protein_id","variant_id","wt_seq","mut_seq","mutations","n_mut","label","label_col",
    "mut_apply_ok","mut_apply_mismatch_n","mut_apply_invalid_n","source_file"
])

# QC flags
pairs["has_wt_seq"] = pairs["wt_seq"].astype(str).str.len().gt(0)
pairs["has_mut_seq"] = pairs["mut_seq"].astype(str).str.len().gt(0)
pairs["len_match"]  = pairs["has_wt_seq"] & pairs["has_mut_seq"] & (pairs["wt_seq"].astype(str).str.len() == pairs["mut_seq"].astype(str).str.len())
pairs["is_wt_row"]  = pairs["mutations"].astype(str).eq("") | pairs["n_mut"].astype(int).eq(0)

pairs_path = OUT_DIR / "benchmark_wt_mutant_pairs.tsv"
pairs.to_csv(pairs_path, sep="\t", index=False)

# One row per protein_id: its most common WT sequence, distinct variant count
# and contributing datasets.
wt_parents = (
    pairs[pairs["has_wt_seq"]]
    .groupby("protein_id", as_index=False)
    .agg(
        wt_seq=("wt_seq", lambda x: x.value_counts().idxmax()),
        n_variants=("variant_id", "nunique"),
        sources=("source_dataset", lambda x: ";".join(sorted(set(map(str, x)))))
    )
)
wt_path = OUT_DIR / "benchmark_wt_parents.tsv"
wt_parents.to_csv(wt_path, sep="\t", index=False)

manifest_df = pd.DataFrame(manifest)
manifest_path = OUT_DIR / "ingest_manifest.tsv"
manifest_df.to_csv(manifest_path, sep="\t", index=False)

print("Wrote:", pairs_path)
print("Wrote:", wt_path)
print("Wrote:", manifest_path)
print("pairs rows:", len(pairs), "unique proteins:", pairs["protein_id"].nunique(), "WT parents:", len(wt_parents))
print("missing WT seq rows:", int((~pairs["has_wt_seq"]).sum()))
print("length mismatch rows:", int((pairs["has_wt_seq"] & pairs["has_mut_seq"] & ~pairs["len_match"]).sum()))
print("mutation apply mismatches:", int((pairs["mut_apply_mismatch_n"] > 0).sum()), "invalid muts:", int((pairs["mut_apply_invalid_n"] > 0).sum()))

# Report problems that would otherwise be visible only inside ingest_manifest.tsv.
if "note" in manifest_df.columns:
    unmatched = manifest_df.loc[manifest_df["note"].eq("no files matched"), "dataset"].tolist()
    if unmatched:
        print("datasets with no matching files:", ", ".join(unmatched))
if "error" in manifest_df.columns:
    failed = manifest_df[manifest_df["error"].notna()]
    print("files that failed to ingest:", len(failed))
    for _, rec in failed.iterrows():
        print(f"  {rec['dataset']}: {rec['file']} -> {rec['error']}")

RuntimeError: No input dataframe found. Either set INTEGRATED_PATH or ensure your dataset DFs are loaded.

In [None]:
# Load the tab-separated master table data/masterdb.tsv and build a two-column
# (id, sequence) frame for FASTA export.
import pandas as pd 
import os 
path = "data/masterdb.tsv"
df = pd.read_csv(path,sep="\t")
# Header normalization: strip, lowercase, spaces -> "_", and drop "(", ")"
# and "?" (literal replacements, not regular expressions).
df.columns = (
    df.columns
    .str.strip()
    .str.lower()
    .str.replace(" ", "_", regex=False)
    .str.replace("(", "", regex=False)
    .str.replace(")", "", regex=False)
    .str.replace("?", "", regex=False)
)
# Keep only rows that actually have a sequence. astype(str) on a missing value
# would otherwise yield the literal "nan", which would end up as a FASTA record.
has_seq = df["protein_sequence"].notna()
df2 = pd.DataFrame({
    "id": df.loc[has_seq, "name"],
    "sequence": df.loc[has_seq, "protein_sequence"].astype(str),
})
print(f"dropped {int((~has_seq).sum())} rows without protein_sequence")
print(df2)

def dftofasta(df,outfile):
    """Write an ``id``/``sequence`` DataFrame to a FASTA file.

    Each kept row becomes ``>{id}`` followed by its sequence, with surrounding
    whitespace stripped. Rows whose sequence is None/NaN, blank, or the literal
    ``"nan"`` (what ``astype(str)`` makes of a missing value) are skipped. This
    matches ``fasta_merger_from_dfs``, which also skips missing sequences.

    Returns:
        *outfile*, unchanged.
    """
    with open(outfile, "w") as f:
        for seq_id, seq in zip(df["id"], df["sequence"]):
            if seq is None or pd.isna(seq):
                continue
            seq = str(seq).strip()
            if not seq or seq == "nan":
                continue
            f.write(f">{seq_id}\n{seq}\n")
    return outfile

#dftofasta(df2,"data/masterdb.fasta")

In [None]:
############################################
#setting the paths, loading them into dataframes, and making the merged fasta file to cdhit 
############################################


# Input locations for each benchmark. PSI and Meltome are disabled: the
# dataset notes say they use custom E. coli IDs and are dropped for now.
#
# NOTE(review): the load_* / read_fasta_dict helpers used below are defined in
# the "Setting up functions" cell that appears later in this notebook. Run that
# cell first, or this one raises NameError.
#PSI_PATH      = "data/benchmark/sol_benchmark/PSI_Biology_solubility_trainset.csv"
#psi_detail_path = "data/benchmark/sol_benchmark/PSI_all_data_esol.tab"
NESG_PATH     = "data/benchmark/sol_benchmark/nesg/nesg.csv"
nesg_fasta_path = "data/benchmark/sol_benchmark/nesg/nesg.fasta"
PRICE_PATH    = "data/benchmark/sol_benchmark/Price_usability_trainset.csv"
soluprot_train_path = "data/benchmark/sol_benchmark/soluprot_data/training_set.csv"
soluprot_test_path = "data/benchmark/sol_benchmark/soluprot_data/test_set.csv" 
soluprot_train_fasta = "data/benchmark/sol_benchmark/soluprot_data/training_set.fasta"
soluprot_test_fasta = "data/benchmark/sol_benchmark/soluprot_data/test_set.fasta" 
#meltome_path = "data/benchmark/stab_benchmark/meltome_cross-species.csv"
#meltome_fasta_path = "data/benchmark/stab_benchmark/meltome_fasta.fasta"
fireprot_path = "data/benchmark/stab_benchmark/fireprotdb_results_stability.csv"
thermomutdb_path = "data/benchmark/stab_benchmark/thermomutdb.json"
thermomutdb_fasta = "data/benchmark/stab_benchmark/thermomutdb.fasta"
protsol_train_path = "data/benchmark/protsolm_data/protsolm_train.csv"
protsol_test_path = "data/benchmark/protsolm_data/protsolm_test.csv"
novozyme_test_path = "data/benchmark/novozymes-enzyme-stability-prediction/test.csv"
novozyme_train_path = "data/benchmark/novozymes-enzyme-stability-prediction/train.csv"
novozyme_test_labels_path = "data/benchmark/novozymes-enzyme-stability-prediction/test_labels.csv"


#LOAD DATASETS 
# Each loader returns one DataFrame. The trailing comments name the column that
# carries the sequence, or note its absence, in the raw file.
#psi = load_psi("data/benchmark/sol_benchmark/PSI_Biology_solubility_trainset.csv","data/benchmark/sol_benchmark/PSI_all_data_esol.tab")
nesg = load_nesg(NESG_PATH, nesg_fasta_path) # CSV has no sequence column; load_nesg maps it from the FASTA by "id"
price = load_price(PRICE_PATH) # sequence is in the "fasta" column
soluprot = load_soluprot(
    soluprot_train_path,
    soluprot_test_path,
    soluprot_train_fasta,
    soluprot_test_fasta    
)
fireprot = load_fireprot(fireprot_path) # sequence column is "sequence"
thermomut = load_thermomut(thermomutdb_path,thermomutdb_fasta) # sequence mapped from FASTA by the JSON "uniprot" column
#meltome = load_meltome(meltome_path,meltome_fasta_path) 
protsolm = load_protsolm(protsol_train_path, protsol_test_path) # sequence is in the "aa_seq" column
novozymes = load_novozymes(
    novozyme_train_path, novozyme_test_path, novozyme_test_labels_path
)

#ADD FASTA TO DATAFRAMES IF THE CSV DID NOT HAVE IT
# Give every dataset a "sequence" column for the FASTA merge below. The NESG
# and Price assignments repeat what load_nesg / load_price already do (same
# values). Only the ProtSolM one adds a new column.
nesg_fasta = read_fasta_dict(nesg_fasta_path)
nesg["sequence"] = nesg["id"].map(nesg_fasta)
price["sequence"] = price["fasta"]
protsolm["sequence"] = protsolm["aa_seq"]


#MERGE ALL FASTA FILES FOR CD-HIT
def fasta_merger_from_dfs(datasets, outpath):
    """Write the ``sequence`` column of several DataFrames into one FASTA file.

    Each record header is ``{type name}_{position in datasets}_{row index}``.
    For pandas frames the type name is always "DataFrame", so headers look like
    ``DataFrame_2_137``. The canonical-ID loop in this notebook builds the same
    string, so the two formats must stay in sync.

    Rows whose sequence is None/NaN or empty/whitespace-only are skipped, and
    sequences are written with surrounding whitespace stripped.
    """
    with open(outpath, "w") as out:
        for i, df in enumerate(datasets):
            prefix = f"{df.__class__.__name__}_{i}_"
            # zip over index + column instead of iterrows: same values, without
            # building a Series per row.
            for idx, seq in zip(df.index, df["sequence"]):
                if seq is None or pd.isna(seq):
                    continue
                seq = str(seq).strip()
                if not seq:
                    continue
                out.write(f">{prefix}{idx}\n{seq}\n")

# Fixed dataset order: each frame's position i here is baked into both the
# FASTA headers and canonical_id, so reordering this list changes every ID.
datasets = [nesg, price, soluprot, fireprot, protsolm,thermomut, novozymes]
fasta_merger_from_dfs(datasets, "allbenchmarks.fasta")

#SETTING  CANONICAL IDS
# canonical_id reproduces the FASTA header string "DataFrame_{i}_{idx}"
# (__class__.__name__ is "DataFrame" for every pandas frame). Presumably this
# is so CD-HIT cluster member ids can be joined back to rows. Rows without a
# sequence are absent from the FASTA but still receive a canonical_id.
# NOTE(review): the loop variable `df` remains bound to the last dataset
# (novozymes) after this cell.
for i, df in enumerate(datasets):
    df["canonical_id"] = [
        f"{df.__class__.__name__}_{i}_{idx}" for idx in df.index
    ]

In [None]:
############################################
#Setting up functions 
############################################
import pandas as pd
import json
import numpy as np
from pathlib import Path
import requests
from Bio.SeqIO.FastaIO import SimpleFastaParser

def read_fasta_dict(path: str):
    """Read a FASTA file into ``{record_id: sequence}``.

    The record id is the first whitespace-delimited token of the header (``""``
    for an empty header). Later duplicate ids overwrite earlier ones.

    For UniProt-style ids (``sp|P12345|NAME_HUMAN`` or ``tr|...``), the bare
    accession (``P12345``) is also registered as an alias. This lets tables
    keyed by plain accessions map into FASTA downloaded from the UniProt REST
    API (see ``fetch_uniprot_fasta``). An alias never replaces an id that
    appears verbatim as a header.
    """
    seqs = {}
    aliases = {}
    with open(path) as fh:
        for header, seq in SimpleFastaParser(fh):
            tokens = header.split()
            sid = tokens[0].strip() if tokens else ""
            seq = seq.strip()
            seqs[sid] = seq
            parts = sid.split("|")
            if len(parts) >= 3 and parts[0] in ("sp", "tr") and parts[1]:
                aliases.setdefault(parts[1], seq)
    # Add aliases last so an explicit header id always wins, whatever the order.
    for acc, seq in aliases.items():
        seqs.setdefault(acc, seq)
    return seqs

def load_nesg(csv_path: str, fasta_path: str) -> pd.DataFrame:
    """Load the NESG solubility table and attach sequences from its FASTA.

    The CSV header row is used as-is (id, exp, sol). ``sequence`` is looked up
    from the FASTA by the ``id`` column. Ids missing from the FASTA get NaN.
    """
    table = pd.read_csv(csv_path)
    id_to_seq = read_fasta_dict(fasta_path)
    table["sequence"] = table["id"].map(id_to_seq)
    return table

def load_psi(csv_path: str, psi_detail_path: str) -> pd.DataFrame:
    """Load the PSI Biology solubility set with its extra metadata.

    The main CSV (sid, fasta, labels) is left-joined on ``sid`` with the
    tab-separated detail table. ``sequence`` is a copy of the ``fasta`` column.
    """
    labels = pd.read_csv(csv_path)
    details = pd.read_csv(psi_detail_path, sep="\t")
    merged = labels.merge(details, on="sid", how="left")
    merged["sequence"] = merged["fasta"]
    return merged
def load_price(csv_path: str) -> pd.DataFrame:
    """Load the Price usability set (sid, usability, fasta).

    ``sequence`` is a copy of the ``fasta`` column.
    """
    table = pd.read_csv(csv_path)
    return table.assign(sequence=table["fasta"])

def load_soluprot(train_csv: str, test_csv: str,
                  train_fasta_path: str, test_fasta_path: str) -> pd.DataFrame:
    """Load the SoluProt train and test tables and attach sequences.

    The two CSVs are concatenated (train first, index reset). Each row's
    ``sid`` is looked up in the union of both FASTA files, where test entries
    win on duplicate ids. Unmatched rows get NaN in ``sequence``.

    The lookup uses ``sid`` as a string. SoluProt ids may be numeric, and
    ``pd.read_csv`` then parses them as integers. Integers never equal the
    string keys from ``read_fasta_dict``, so mapping the raw column would leave
    every sequence NaN. The ``sid`` column itself keeps its parsed dtype.
    """
    # FASTA -> dict, keys exactly as in the FASTA headers
    train_fasta = read_fasta_dict(train_fasta_path)
    test_fasta = read_fasta_dict(test_fasta_path)
    fasta = {**train_fasta, **test_fasta}

    df1 = pd.read_csv(train_csv)
    df2 = pd.read_csv(test_csv)
    df = pd.concat([df1, df2], ignore_index=True)

    df["sequence"] = df["sid"].astype(str).map(fasta)
    return df


def load_meltome(csv_path: str, fasta_path: str) -> pd.DataFrame:
    """Load the Meltome table and attach sequences keyed by UniProt accession.

    ``uniprot_id`` is the part of ``Protein_ID`` before the first ``_``, and
    ``sequence`` is looked up from the FASTA by that id.
    """
    table = pd.read_csv(csv_path)
    acc_to_seq = read_fasta_dict(fasta_path)
    table["uniprot_id"] = table["Protein_ID"].astype(str).str.split("_").str[0]
    table["sequence"] = table["uniprot_id"].map(acc_to_seq)
    return table

def load_fireprot(csv_path: str) -> pd.DataFrame:
    """Load the FireProtDB stability export unchanged.

    The file already carries its own ``sequence`` column, so no FASTA lookup
    is needed.
    """
    return pd.read_csv(csv_path)

def load_thermomut(json_path: str, fasta_path: str) -> pd.DataFrame:
    """Load ThermoMutDB records from JSON and attach sequences from FASTA.

    ``sequence`` is looked up by the JSON ``uniprot`` column. The label and
    metadata columns listed below are created as all-None when the JSON lacks
    them, so downstream code can rely on their presence.
    """
    with open(json_path) as fh:
        records = json.load(fh)
    table = pd.DataFrame(records)

    uniprot_to_seq = read_fasta_dict(fasta_path)
    table["sequence"] = table["uniprot"].map(uniprot_to_seq)

    expected = ["ph", "ddg", "temperature", "dtm", "PDB_wild",
                "pdb_mutant", "mutation_code", "mutated chain", "effect"]
    for name in expected:
        if name not in table.columns:
            table[name] = None
    return table

def load_protsolm(train_csv: str, test_csv: str) -> pd.DataFrame:
    """Load and stack the ProtSolM train and test CSVs (train first).

    The result has a fresh 0..n-1 index. The sequence stays in the ``aa_seq``
    column. No ``sequence`` column is added here.
    """
    parts = [pd.read_csv(path) for path in (train_csv, test_csv)]
    return pd.concat(parts, ignore_index=True)

def fasta_merger(fasta_paths: list, outpath: str):
    """Concatenate FASTA files, in the given order, into *outpath*.

    Lines are copied verbatim. When a file does not end with a newline, one is
    inserted so the next file's first header is not glued onto the previous
    file's last sequence line.
    """
    with open(outpath, "w") as out:
        for path in fasta_paths:
            last_line = "\n"
            with open(path) as fh:
                for line in fh:
                    out.write(line)
                    last_line = line
            if not last_line.endswith("\n"):
                out.write("\n")

def parse_cd_hit_clusters(clstr_path):
    """Parse a CD-HIT ``.clstr`` file into ``{cluster_number: [member ids]}``.

    Cluster keys are the number after ``>Cluster`` as a string ("0", "1", ...).
    Member lines look like ``0       50aa, >SEQ123... *``, and the member id is
    the text between the first ``>`` and the following ``...``. Blank lines,
    and lines before the first ``>Cluster`` header or without a ``>``, are
    skipped instead of raising.

    NOTE(review): CD-HIT may truncate long sequence ids in its output,
    depending on its description-length option. Confirm that ids round-trip
    before joining them back.
    """
    clusters = {}
    current = None

    with open(clstr_path) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">Cluster"):
                current = line.split()[1]
                clusters[current] = []
                continue
            if current is None or ">" not in line:
                continue
            sid = line.split(">", 1)[1].split("...", 1)[0]
            clusters[current].append(sid)

    return clusters

def fasta_merger_from_dfs(datasets, outpath):
    """Write the ``sequence`` column of several DataFrames into one FASTA file.

    Each record header is ``{type name}_{position in datasets}_{row index}``.
    For pandas frames the type name is always "DataFrame", so headers look like
    ``DataFrame_2_137``. The canonical-ID loop in this notebook builds the same
    string, so the two formats must stay in sync.

    Rows whose sequence is None/NaN or empty/whitespace-only are skipped, and
    sequences are written with surrounding whitespace stripped.
    """
    with open(outpath, "w") as out:
        for i, df in enumerate(datasets):
            prefix = f"{df.__class__.__name__}_{i}_"
            # zip over index + column instead of iterrows: same values, without
            # building a Series per row.
            for idx, seq in zip(df.index, df["sequence"]):
                if seq is None or pd.isna(seq):
                    continue
                seq = str(seq).strip()
                if not seq:
                    continue
                out.write(f">{prefix}{idx}\n{seq}\n")

def fetch_uniprot_fasta(ids, out_fasta, delay=0.15):
    """Fetch FASTA records for UniProt accessions into one file.

    Each id is requested from ``https://rest.uniprot.org/uniprotkb/{id}.fasta``
    with a 10 s timeout. A successful response (HTTP 200 whose body starts with
    ``>``) is written with surrounding whitespace stripped. Anything else is
    written as the placeholder record ``>{id}`` / ``FAILED_FETCH``. This covers
    non-200 status, an unexpected body, and network errors such as timeouts or
    connection failures. One bad id therefore never aborts the batch. Readers
    of *out_fasta* should filter out ``FAILED_FETCH`` sequences.

    Args:
        ids: Iterable of UniProt accessions.
        out_fasta: Output path; it is overwritten.
        delay: Seconds to sleep after each request, as a rate limit.
    """
    import time  # not imported at the top of this cell; without it the sleep raises NameError

    with open(out_fasta, "w") as out:
        for uid in ids:
            url = f"https://rest.uniprot.org/uniprotkb/{uid}.fasta"
            try:
                r = requests.get(url, timeout=10)
                body = r.text if r.status_code == 200 else ""
            except requests.RequestException:
                body = ""

            if body.startswith(">"):
                out.write(body.strip() + "\n")
            else:
                # placeholder for a failed fetch
                out.write(f">{uid}\nFAILED_FETCH\n")

            time.sleep(delay)  # rate-limit to avoid 500 errors



def load_novozymes(train_path: str, test_path: str, test_labels_path: str) -> pd.DataFrame:
    """Load the Novozymes enzyme-stability train/test split as one table.

    The test rows are left-joined with their labels on ``seq_id``. Both parts
    get ``sequence`` copied from ``protein_sequence``. Train rows come first,
    and the result has a fresh 0..n-1 index.
    """
    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path).merge(
        pd.read_csv(test_labels_path), on="seq_id", how="left"
    )

    stacked = []
    for part in (train_df, test_df):
        part["sequence"] = part["protein_sequence"]
        stacked.append(part)
    return pd.concat(stacked, ignore_index=True)