In [8]:
import time
import csv
from Bio import Entrez

# --- CONFIGURATION ---
INPUT_FILE = "tournament_wt_id.txt"  # Your file name
OUTPUT_FILE = "taxonomy_map.csv"
EMAIL = "justin.seyedmoomenkashi@mail.mcgill.ca"     # REQUIRED: Change this
# ---------------------

Entrez.email = EMAIL

def get_taxonomy_robust(input_path, output_path):
    # 1. Read IDs
    with open(input_path, "r") as f:
        id_list = [line.strip() for line in f if line.strip()]

    print(f"Loaded {len(id_list)} IDs. Fetching taxonomy safely...")

    with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Original_ID", "Accession", "Organism", "TaxId"])

        # 2. Process in small batches using ESEARCH (safer for mixed IDs)
        batch_size = 20  # Smaller batch size for Search URL limits
        
        for i in range(0, len(id_list), batch_size):
            batch = id_list[i:i+batch_size]
            print(f"Processing batch {i} to {i+len(batch)}...")

            try:
                # Step A: Search for the IDs to get their numeric GI numbers
                # We join them with " OR " to search multiple at once
                search_term = " OR ".join(batch)
                search_handle = Entrez.esearch(db="protein", term=search_term, retmax=batch_size)
                search_results = Entrez.read(search_handle)
                
                # Get the list of numeric UIDs (GI numbers) found
                uids = search_results["IdList"]

                if not uids:
                    print(f"  Warning: No results found for batch starting at {batch[0]}")
                    continue

                # Step B: Fetch details using the numeric UIDs (Guaranteed to work)
                summary_handle = Entrez.esummary(db="protein", id=",".join(uids))
                records = Entrez.read(summary_handle)

                for record in records:
                    # AccessionVersion is the clean ID (e.g. WP_123.1)
                    acc = record.get("AccessionVersion", "Unknown")
                    organism = record.get("Organism", "Unknown")
                    taxid = record.get("TaxId", "Unknown")
                    
                    # We try to match it back to your original input list if possible
                    # (Note: The order might shuffle slightly due to 'OR' search)
                    writer.writerow([acc, acc, organism, taxid])

            except Exception as e:
                print(f"Error on batch {i}: {e}")
                # Fallback: Try one-by-one if a batch fails completely
                print("  -> Retrying this batch one-by-one...")
                for single_id in batch:
                    try:
                        sh = Entrez.esearch(db="protein", term=single_id)
                        sr = Entrez.read(sh)
                        if sr["IdList"]:
                            sumh = Entrez.esummary(db="protein", id=sr["IdList"][0])
                            rec = Entrez.read(sumh)[0]
                            writer.writerow([single_id, rec.get("AccessionVersion"), rec.get("Organism"), rec.get("TaxId")])
                        else:
                            writer.writerow([single_id, "Not Found", "-", "-"])
                    except:
                        writer.writerow([single_id, "Error", "-", "-"])
            
            time.sleep(0.5)

    print(f"Done! Results saved to {output_path}")

if __name__ == "__main__":
    get_taxonomy_robust(INPUT_FILE, OUTPUT_FILE)

Loaded 313 IDs. Fetching taxonomy safely...
Processing batch 0 to 20...
Processing batch 20 to 40...
Processing batch 40 to 60...
Processing batch 60 to 80...
Processing batch 80 to 100...
Processing batch 100 to 120...
Processing batch 120 to 140...
Processing batch 140 to 160...
Processing batch 160 to 180...
Processing batch 180 to 200...
Processing batch 200 to 220...
Processing batch 220 to 240...
Processing batch 240 to 260...
Processing batch 260 to 280...
Processing batch 280 to 300...
Processing batch 300 to 313...
Done! Results saved to taxonomy_map.csv


# **1. Tournament Dataset Analysis**

In [15]:
os.chdir("..")

In [None]:
# ============================================================
# PETase Tournament FASTA — QC & Stats Notebook Template
# ============================================================
# Assumptions:
# - Input: wild-type PETase FASTA (600k+ seqs)
# - Goal: sanity checks, redundancy, motifs, readiness metrics
# - Modular: each cell runnable independently
# ============================================================

# -------------------------
# 0) CONFIG
# -------------------------
FASTA = "tournament_wt.fasta"
OUTDIR = "qc_outputs"
REF_FASTA = "reference_petase.fasta"  # optional (e.g., IsPETase)
MIN_LEN, MAX_LEN = 280, 320

import os, sys, subprocess, json, math
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq

os.makedirs(OUTDIR, exist_ok=True)

# -------------------------
# 1) LOAD FASTA + BASIC COUNTS
# -------------------------
records = list(SeqIO.parse(FASTA, "fasta"))
N_total = len(records)

seqs = [str(r.seq) for r in records]
lengths = np.array([len(s) for s in seqs])

basic_stats = {
    "total_sequences": N_total,
    "unique_sequences": len(set(seqs)),
    "min_len": int(lengths.min()),
    "max_len": int(lengths.max()),
    "mean_len": float(lengths.mean()),
    "median_len": float(np.median(lengths)),
    "IQR_len": float(np.percentile(lengths, 75) - np.percentile(lengths, 25)),
    "pct_in_expected_window": float(((lengths>=MIN_LEN)&(lengths<=MAX_LEN)).mean()*100),
}
basic_stats

# -------------------------
# 2) RESIDUE QUALITY
# -------------------------
ambiguous = set("XBZJUO")
def has_amb(s): return any(c in ambiguous for c in s)

qc_residue = {
    "pct_with_ambiguous": 100 * np.mean([has_amb(s) for s in seqs]),
    "pct_with_stop": 100 * np.mean(["*" in s for s in seqs]),
}
qc_residue

# -------------------------
# 3) REDUNDANCY (CD-HIT)
# -------------------------
# Requires: cd-hit in PATH
def run(cmd):
    subprocess.run(cmd, shell=True, check=True)

run(f"cd-hit -i {FASTA} -o {OUTDIR}/cdhit_100.fasta -c 1.00 -n 5 -d 0")
run(f"cd-hit -i {FASTA} -o {OUTDIR}/cdhit_95.fasta  -c 0.95 -n 5 -d 0")
run(f"cd-hit -i {FASTA} -o {OUTDIR}/cdhit_90.fasta  -c 0.90 -n 5 -d 0")

def fasta_count(fp): return sum(1 for _ in SeqIO.parse(fp, "fasta"))

redundancy = {
    "clusters_100": fasta_count(f"{OUTDIR}/cdhit_100.fasta"),
    "clusters_95":  fasta_count(f"{OUTDIR}/cdhit_95.fasta"),
    "clusters_90":  fasta_count(f"{OUTDIR}/cdhit_90.fasta"),
}
redundancy

# -------------------------
# 4) MOTIF / SITE SANITY (PETase)
# -------------------------
# crude motif windows; adjust numbering if aligned
def motif_present(seq, motif): return motif in seq

motifs = {
    "has_GXSXG_like": lambda s: any(s[i]=="G" and s[i+2]=="S" for i in range(len(s)-2)),
    "has_W185_like":  lambda s: "W" in s,  # proxy; replace with alignment-based check
    "has_Cys_pair":  lambda s: s.count("C") >= 2,
}

motif_stats = {k: 100*np.mean([fn(s) for s in seqs]) for k,fn in motifs.items()}
motif_stats

# -------------------------
# 5) SIGNAL PEPTIDE / TM (external tools)
# -------------------------
# Optional; placeholders for batch runs
# signalp, tmhmm typically run outside notebook
# Save FASTA subsets if needed
with open(f"{OUTDIR}/all.fasta","w") as fh:
    SeqIO.write(records, fh, "fasta")

# -------------------------
# 6) DOMAIN CHECK (InterProScan)
# -------------------------
# Requires interproscan.sh
# Example (run externally for scale):
# interproscan.sh -i all.fasta -f tsv -dp -o interpro.tsv

# -------------------------
# 7) STRUCTURAL PROXIES (cheap)
# -------------------------
# disorder proxy: fraction of low-complexity residues
low_complex = set("PASTEG")
def lc_frac(s): return sum(c in low_complex for c in s)/len(s)

struct_proxy = {
    "mean_low_complex_frac": float(np.mean([lc_frac(s) for s in seqs])),
}
struct_proxy

# -------------------------
# 8) REFERENCE IDENTITY (optional)
# -------------------------
# blastp vs reference; DIAMOND recommended
# diamond makedb --in reference_petase.fasta -d ref_db
# diamond blastp -q tournament_wt.fasta -d ref_db -o ref_hits.tsv -f 6 qseqid pident length

# -------------------------
# 9) DATASET READINESS SUMMARY
# -------------------------
summary = {}
summary.update(basic_stats)
summary.update(qc_residue)
summary.update(redundancy)
summary.update(motif_stats)
summary.update(struct_proxy)

pd.Series(summary).to_csv(f"{OUTDIR}/summary_metrics.csv")
summary

# -------------------------
# 10) FINAL NOTES
# -------------------------
# - Enforce filters (length, ambiguity, motifs) to get trainable set
# - Recompute stats post-filter
# - Freeze deduplicated FASTA for benchmarking

# **2. Loading external stability and expression datasets**

**NESG Solubility** 
(https://loschmidt.chemi.muni.cz/soluprot/?page=download)
* 10k proteins
* Labels: exp, sol, uniprot id or local ID 
* Units: integer 

**Soluprot Solubility**
(https://loschmidt.chemi.muni.cz/soluprot/?page=download)
* 11k training, 3k test
* Label: solubility, number IDs with no conversion map (has seq)
* Unit: 0/1

**Price Solubility**
(https://pmc.ncbi.nlm.nih.gov/articles/PMC3372292/)
* 7k proteins 
* Label: usability. uniprot id
* Unit: 0/1

**PSI Solubility** 
(https://academic.oup.com/bioinformatics/article/36/18/4691/5860015?login=false)
* 11k proteins
* Label: solubility, Aa0000 ID scheme (has seq)
* Unit: 0/1
* Note: ecoli with custom IDs, dropped for now 

**Meltome Stability** 
(https://meltomeatlas.proteomics.wzw.tum.de/master_meltomeatlasapp/)
* 1M variants 
* Label: temperature, meltpoint, fold_change, uniprot id 
* Note: ecoli with custom IDs, dropped for now  

**FireprotDB Stability** 
(https://loschmidt.chemi.muni.cz/fireprotdb/)
* 53k variants
* Label: ddG, dTm, pH, Tm, mutation_effect, uniprot id 

**ThermomutDB Stability**
(https://biosig.lab.uq.edu.au/thermomutdb/downloads)
* 12k variants
* Label: pH, ddG, temperature, dTm, uniprot/pdb id 
* Note: these genes were not retrieved from the database due to removal from uniprotkb:  A0A410ZNC6 D0WVP7 G7LSK3 GQ884175 M5A5Y8 Q9REI6

**CAFA** 
(https://www.kaggle.com/competitions/cafa-5-protein-function-prediction/code)
* 142k variants

**Novozyme**
(https://www.kaggle.com/code/jinyuansun/eda-and-finetune-esm)
* 31k variants

**Protsol Solubility**
(https://huggingface.co/datasets/AI4Protein/ProtSolM)
* 71k proteins
* Label: solubility, no ID but has sequence
* Unit: 0/1 

In [None]:
#we will now load "masterdb.csv" found under data
import pandas as pd 
import os 
path = "data/masterdb.tsv"
df = pd.read_csv(path,sep="\t")
df.columns = (
    df.columns
    .str.strip()
    .str.lower()
    .str.replace(" ", "_")
    .str.replace("(", "")
    .str.replace(")", "")
    .str.replace("?", "")
)
df2 = pd.DataFrame()
df2["id"]=df["name"]
df2["sequence"] = df["protein_sequence"].astype(str)
print(df2)

def dftofasta(df,outfile):
    with open(outfile,"w") as f:
        for index,row in df.iterrows():
            f.write(f">{row['id']}\n")
            f.write(f"{row['sequence']}\n")
    return outfile 

#dftofasta(df2,"data/masterdb.fasta")

In [112]:
############################################
#Setting up functions 
############################################
import pandas as pd
import json
import numpy as np
from pathlib import Path
import requests
from Bio.SeqIO.FastaIO import SimpleFastaParser

def read_fasta_dict(path: str):
    seqs = {}
    with open(path) as fh:
        for header, seq in SimpleFastaParser(fh):
            sid = header.split()[0].strip()
            seqs[sid] = seq.strip()
    return seqs

def load_nesg(csv_path: str, fasta_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)  # uses CSV header row directly: id, exp, sol
    seqs = read_fasta_dict(fasta_path)
    #"sid" "usability" "fasta" 
    df["sequence"] = df["id"].map(seqs)
    return df 

def load_psi(csv_path: str, psi_detail_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)                        # has sid + fasta + labels
    psi_all = pd.read_csv(psi_detail_path, sep="\t")  # extra metadata

    # merge on sid
    df = df.merge(psi_all, on="sid", how="left")

    # sequence is already in the "fasta" column
    df["sequence"] = df["fasta"]

    return df
def load_price(csv_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    df["sequence"] = df["fasta"]
    # "sid" "usability" "fasta" 
    return df

def load_soluprot(train_csv: str, test_csv: str,
                  train_fasta_path: str, test_fasta_path: str) -> pd.DataFrame:

    # load FASTA → dict, keys exactly as in FASTA headers
    train_fasta = read_fasta_dict(train_fasta_path)
    test_fasta  = read_fasta_dict(test_fasta_path)

    # merge FASTA dicts
    fasta = {**train_fasta, **test_fasta}

    # load CSVs
    df1 = pd.read_csv(train_csv)
    df2 = pd.read_csv(test_csv)

    df = pd.concat([df1, df2], ignore_index=True)

    # map sequences using the exact ids
    df["sequence"] = df["sid"].map(fasta)


    return df


def load_meltome(csv_path: str, fasta_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)

    # load fasta into dict: {uniprot_id: sequence}
    fasta = read_fasta_dict(fasta_path)

    # extract uniprot prefix from Protein_ID (before "_")
    df["uniprot_id"] = df["Protein_ID"].astype(str).apply(lambda x: x.split("_")[0])

    # map sequences
    df["sequence"] = df["uniprot_id"].map(fasta)

    return df

def load_fireprot(csv_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    # "uniprot_id" "pdb_id" "muutation" "ddG" "dTm" "pH" "tm" "mutation_effect" "sequence"
    return df 

def load_thermomut(json_path: str, fasta_path: str) -> pd.DataFrame:
    # load JSON metadata
    with open(json_path) as fh:
        data = json.load(fh)
    df = pd.DataFrame(data)

    # load FASTA sequences
    fasta_dict = read_fasta_dict(fasta_path)

    # map UniProt → sequence
    # JSON column is "uniprot"
    df["sequence"] = df["uniprot"].map(fasta_dict)

    # ensure required labels exist even if missing in JSON
    for col in ["ph","ddg","temperature","dtm","PDB_wild",
                "pdb_mutant","mutation_code","mutated chain","effect"]:
        if col not in df.columns:
            df[col] = None

    return df

def load_protsolm(train_csv: str, test_csv: str) -> pd.DataFrame:
    df1 = pd.read_csv(train_csv)
    df2 = pd.read_csv(test_csv)
    df = pd.concat([df1, df2], ignore_index=True)
    # "aa_seq" "detail"
    return df

def fasta_merger(fasta_paths: list, outpath: str):
    with open(outpath, "w") as out:
        for path in fasta_paths:
            with open(path) as fh:
                for line in fh:
                    out.write(line)

def parse_cd_hit_clusters(clstr_path):
    clusters = {}
    current = None

    with open(clstr_path) as f:
        for line in f:
            line = line.strip()
            if line.startswith(">Cluster"):
                current = line.split()[1]
                clusters[current] = []
            else:
                # Example: "0       50aa, >SEQ123... *"
                sid = line.split(">")[1].split("...")[0]
                clusters[current].append(sid)

    return clusters

def fasta_merger_from_dfs(datasets, outpath):
    with open(outpath, "w") as out:
        for i, df in enumerate(datasets):
            for idx, row in df.iterrows():
                seq = row["sequence"]
                if seq is None or pd.isna(seq):
                    continue
                header = f"{df.__class__.__name__}_{i}_{idx}"
                out.write(f">{header}\n{seq}\n")

def fetch_uniprot_fasta(ids, out_fasta, delay=0.15):
    """Fetch FASTA for many UniProt IDs with error handling."""
    with open(out_fasta, "w") as out:
        for uid in ids:
            url = f"https://rest.uniprot.org/uniprotkb/{uid}.fasta"
            r = requests.get(url, timeout=10)

            if r.status_code == 200 and r.text.startswith(">"):
                out.write(r.text.strip() + "\n")
            else:
                # write placeholder for failed fetch
                out.write(f">{uid}\nFAILED_FETCH\n")

            time.sleep(delay)  # rate-limit to avoid 500 errors



def load_novozymes(train_path: str, test_path: str, test_labels_path: str) -> pd.DataFrame:
    train_df = pd.read_csv(train_path)         # seq_id, protein_sequence, pH, Tm
    test_df = pd.read_csv(test_path)           # seq_id, protein_sequence, pH
    test_labels = pd.read_csv(test_labels_path)  # seq_id, Tm

    # merge test with its labels
    test_df = test_df.merge(test_labels, on="seq_id", how="left")

    # unify column names to match your other datasets
    train_df["sequence"] = train_df["protein_sequence"]
    test_df["sequence"]  = test_df["protein_sequence"]

    # combine train + test into one dataframe
    df = pd.concat([train_df, test_df], ignore_index=True)

    return df

In [None]:
############################################
#setting the paths, loading them into dataframes, and making the merged fasta file to cdhit 
############################################


#PSI_PATH      = "data/benchmark/sol_benchmark/PSI_Biology_solubility_trainset.csv"
#psi_detail_path = "data/benchmark/sol_benchmark/PSI_all_data_esol.tab"
NESG_PATH     = "data/benchmark/sol_benchmark/nesg/nesg.csv"
nesg_fasta_path = "data/benchmark/sol_benchmark/nesg/nesg.fasta"
PRICE_PATH    = "data/benchmark/sol_benchmark/Price_usability_trainset.csv"
soluprot_train_path = "data/benchmark/sol_benchmark/soluprot_data/training_set.csv"
soluprot_test_path = "data/benchmark/sol_benchmark/soluprot_data/test_set.csv" 
soluprot_train_fasta = "data/benchmark/sol_benchmark/soluprot_data/training_set.fasta"
soluprot_test_fasta = "data/benchmark/sol_benchmark/soluprot_data/test_set.fasta" 
#meltome_path = "data/benchmark/stab_benchmark/meltome_cross-species.csv"
#meltome_fasta_path = "data/benchmark/stab_benchmark/meltome_fasta.fasta"
fireprot_path = "data/benchmark/stab_benchmark/fireprotdb_results_stability.csv"
thermomutdb_path = "data/benchmark/stab_benchmark/thermomutdb.json"
thermomutdb_fasta = "data/benchmark/stab_benchmark/thermomutdb.fasta"
protsol_train_path = "data/benchmark/protsolm_data/protsolm_train.csv"
protsol_test_path = "data/benchmark/protsolm_data/protsolm_test.csv"
novozyme_test_path = "data/benchmark/novozymes-enzyme-stability-prediction/test.csv"
novozyme_train_path = "data/benchmark/novozymes-enzyme-stability-prediction/train.csv"
novozyme_test_labels_path = "data/benchmark/novozymes-enzyme-stability-prediction/test_labels.csv"


#LOAD DATASETS 
#psi = load_psi("data/benchmark/sol_benchmark/PSI_Biology_solubility_trainset.csv","data/benchmark/sol_benchmark/PSI_all_data_esol.tab")
nesg = load_nesg(NESG_PATH, nesg_fasta_path) #no seq col 
price = load_price(PRICE_PATH) #fasta 
soluprot = load_soluprot(
    soluprot_train_path,
    soluprot_test_path,
    soluprot_train_fasta,
    soluprot_test_fasta    
)
fireprot = load_fireprot(fireprot_path) #sequence
thermomut = load_thermomut(thermomutdb_path,thermomutdb_fasta)
#meltome = load_meltome(meltome_path,meltome_fasta_path) 
protsolm = load_protsolm(protsol_train_path, protsol_test_path) #aa_seq
novozymes = load_novozymes(
    novozyme_train_path, novozyme_test_path, novozyme_test_labels_path
)

#ADD FASTA TO DATAFRAMES IF THE CSV DID NOT HAVE IT
nesg_fasta = read_fasta_dict(nesg_fasta_path)
nesg["sequence"] = nesg["id"].map(nesg_fasta)
price["sequence"] = price["fasta"]
protsolm["sequence"] = protsolm["aa_seq"]


#MERGE ALL FASTA FILES FOR CD-HIT
def fasta_merger_from_dfs(datasets, outpath):
    with open(outpath, "w") as out:
        for i, df in enumerate(datasets):
            for idx, row in df.iterrows():
                seq = row["sequence"]
                if seq is None or pd.isna(seq):
                    continue
                header = f"{df.__class__.__name__}_{i}_{idx}"
                out.write(f">{header}\n{seq}\n")

datasets = [nesg, price, soluprot, fireprot, protsolm,thermomut, novozymes]
fasta_merger_from_dfs(datasets, "allbenchmarks.fasta")

#SETTING  CANONICAL IDS
for i, df in enumerate(datasets):
    df["canonical_id"] = [
        f"{df.__class__.__name__}_{i}_{idx}" for idx in df.index
    ]

  df = pd.read_csv(csv_path)


In [74]:
############################################
#Running CD-HIT, parsing output and merging the final merged_df
############################################

#cd-hit -i all_sequences.fasta -o all_sequences_100.fasta -c 1.0 -n 5 -d 0
#PARSING THE CLUSTERS FROM CD-HIT INTO ONE MERGED DATAFRAME 
datasets = {
    "nesg": nesg,
    "price": price,
    "soluprot": soluprot,
    "fireprot": fireprot,
    "protsolm": protsolm,
    "thermomutdb":thermomut,
    "novozymes": novozymes
}

# Ensure every df has canonical_id as previously assigned
for name, df in datasets.items():
    if "canonical_id" not in df.columns:
        raise ValueError(f"{name} missing canonical_id")


clusters = parse_cd_hit_clusters("data/benchmark/benchmark_cdhit100_cluster.txt")
merged_rows = []
for clust_id, members in clusters.items():
    rep = members[0]
    collected = []
    for name, df in datasets.items():
        sub = df[df["canonical_id"].isin(members)]
        if len(sub) > 0:
            sub = sub.copy()
            sub["source_dataset"] = name
            collected.append(sub)

    if collected:
        merged_block = pd.concat(collected, ignore_index=True)
    else:
        merged_block = pd.DataFrame()

    merged_rows.append({
        "cluster_id": clust_id,
        "canonical_ids": members,
        "representative": rep,
        "merged_block": merged_block
    })

# this is your final merged output
merged_df = pd.DataFrame(merged_rows)

print(merged_df.shape)
merged_df.head()

(96118, 4)


Unnamed: 0,cluster_id,canonical_ids,representative,merged_block
0,0,"[DataFrame_3_25912, DataFrame_3_27807, DataFra...",DataFrame_3_25912,experiment_id protein_name uniprot_id pdb_i...
1,1,[DataFrame_6_28079],DataFrame_6_28079,seq_id pr...
2,2,[DataFrame_6_28080],DataFrame_6_28080,seq_id pr...
3,3,[DataFrame_6_28081],DataFrame_6_28081,seq_id pr...
4,4,[DataFrame_6_28082],DataFrame_6_28082,seq_id pr...


In [113]:
print(merged_df)

      cluster_id                                      canonical_ids  \
0              0  [DataFrame_3_25912, DataFrame_3_27807, DataFra...   
1              1                                [DataFrame_6_28079]   
2              2                                [DataFrame_6_28080]   
3              3                                [DataFrame_6_28081]   
4              4                                [DataFrame_6_28082]   
...          ...                                                ...   
96113      96113                                 [DataFrame_6_2392]   
96114      96114                                [DataFrame_6_31072]   
96115      96115                                 [DataFrame_6_1497]   
96116      96116                                [DataFrame_6_30387]   
96117      96117                                 [DataFrame_6_2806]   

          representative                                       merged_block  
0      DataFrame_3_25912     experiment_id protein_name uniprot_id pd

In [114]:
KEEP_COLS = {
    "fireprot": [
        "uniprot_id","pdb_id","chain","wild_type","position","mutation",
        "pH","tm","dTm","ddG","interpro_families","is_essential","sequence"
    ],
    "protsolm": ["name","label","detail","sequence"],
    "novozymes": ["seq_id","pH","tm","protein_sequence","sequence"],
    "nesg": ["sid","exp","sol","solubility","fasta","sequence"],
    "price": ["sid","Usability|0=NotUsable|1=Usable","fasta","sequence"]
}

def clean_block(df):
    if df.empty:
        return df
    src = df["source_dataset"].iloc[0]
    if src not in KEEP_COLS:
        return df
    keep = KEEP_COLS[src]
    # always preserve canonical_id + source_dataset
    keep = [c for c in keep if c in df.columns] + ["canonical_id","source_dataset"]
    return df[keep]

merged_df["merged_block"] = merged_df["merged_block"].apply(clean_block)

In [None]:
############################################################
# Attempt to get structure of this benchmark database to make multimodal db but will need to filter the proteins
############################################################

missing_pdb_uniprot = set()
has_pdb_ids = set()
no_uniprot_no_pdb_seqs = set()

for block in merged_df["merged_block"]:
    if not isinstance(block, pd.DataFrame) or block.empty:
        continue

    cols = block.columns

    # 1. uniprot_id present but pdb_id missing
    if "uniprot_id" in cols:
        b = block[["uniprot_id", "pdb_id"]].copy()
        b = b[b["uniprot_id"].notna()]  # entries with uniprot
        b = b[b["pdb_id"].isna()]       # missing pdb
        missing_pdb_uniprot.update(b["uniprot_id"].dropna().astype(str).tolist())

    # 2. unique pdb ids
    if "pdb_id" in cols:
        ids = block["pdb_id"].dropna().astype(str).tolist()
        has_pdb_ids.update(ids)

    # 3. no uniprot, no pdb → need sequences
    if "sequence" in cols:
        mask = pd.Series(True, index=block.index)
        if "uniprot_id" in cols:
            mask &= block["uniprot_id"].isna()
        if "pdb_id" in cols:
            mask &= block["pdb_id"].isna()
        seqs = block.loc[mask, "sequence"].dropna().tolist()
        for s in seqs:
            no_uniprot_no_pdb_seqs.add(s)


############################################################
# WRITE OUTPUT FILES
############################################################

# 1. uniprot present but no pdb
with open("missing_pdb_uniprot_ids.txt", "w") as f:
    for uid in sorted(missing_pdb_uniprot):
        f.write(uid + "\n")

# 2. pdb ids
with open("has_pdb_ids.txt", "w") as f:
    for pid in sorted(has_pdb_ids):
        f.write(pid + "\n")

# 3. fasta for no uniprot AND no pdb
with open("no_uniprot_no_pdb.fasta", "w") as f:
    for i, seq in enumerate(sorted(no_uniprot_no_pdb_seqs)):
        f.write(f">seq_{i}\n{seq}\n")

print("DONE:")
print(" missing_pdb_uniprot_ids.txt")
print(" has_pdb_ids.txt")
print(" no_uniprot_no_pdb.fasta")


#Run mmseqs2 to seq2struct to find top hits, if not, use AF2

DONE:
 missing_pdb_uniprot_ids.txt
 has_pdb_ids.txt
 no_uniprot_no_pdb.fasta


In [120]:
all_pdb_ids = set()

for block in merged_df["merged_block"]:
    if isinstance(block, pd.DataFrame) and "pdb_id" in block.columns:
        all_pdb_ids.update(block["pdb_id"].dropna().unique())
print(all_pdb_ids)

{'1FEP', '1UHG', '1AZP|1AZP|1AZP|1BF4', '1IMQ', '3VUB|3VUB', '2ADA', '1BNI|1A2P', '1AXB|1BTL|1XPB|1ZG4', '1YPI|1YPI', '1IDS|1IDS|1IDS|1IDS', '2TRX', '2IMM', '1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ|1ZNJ', '1CYC', '1CAH', '1DPM', '1B5M', '1UZC', '3MBP|1SVX|1SVX', '3PG0', '1RX4|1DDR|1DDR|1DYJ|5DFR', '1DKT|1DKT', '1HME', '1C52', '1P2P|1P2P', '1AAR|1AAR', '1PDO', '1JU3', '1LBI|1LBI|1LBI|1LBI', '1RIS|1RIS|1RIS|1RIS', '1C9O|1C9O', '3BLS|1KE4', '1YYX', '1BNL', '3HHR|1HGU', '1KCQ', '1QGD|1QGD', '1E21', '1LVE', '1BCX', '1TIN', '1ROP|1ROP', '1ZYM', '1ARR|1ARR', '1HNG|1HNG|1CDC|1CDC', '1I4N', '1RTP', '1TPK', '1IR3|1IR3|1IR3|1IR3', '1TTG|1FNA|1FNF', '1APS', '1BP2|1G4I', '3D2A|1ISP|3D2C', '1CF3', '1HK0', '1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON|1AON', '1WIT', '1M21|1M21', '1CYO', '8TIM', '451C', '1KFW', '2Q98', '4ZLU|4ZLU', '2ABD', '3PGK', '3WP4', '1LUC', '1DIL|3SIL', '2CBR', '1OLR', '1BAH', '1HZ6', '1AM7|1AM7|1AM7', '1K9Q|1

# 4. Fine-tuning esm2/3 model