In [22]:
from functools import lru_cache

import pandas as pd
import requests
from Bio.Align import PairwiseAligner
from tqdm import tqdm

# Load the domain table and add the output columns that the mapping loop
# fills in row by row: text columns start as "", coordinates start as None.
df = pd.read_csv("../data/subset.csv").assign(
    protein_sequence="",
    domain_start=None,
    domain_end=None,
    note="",
)

# PDB ID substitutions applied before downloading sequences.
# NOTE(review): presumably obsolete/superseded entries mapped to the entries
# that replaced them — confirm against RCSB. Keys must be lower-case because
# corrected_pdb_id() lower-cases its input before the lookup.
pdb_id_replacements = {
    "1vw4": "3j6b",
    "4gns": "4yg8",
    "1vs9": "4v4i",
    "3p9d": "4v81",
    "4d8q": "4v94",
    "4a17": "4v8p",
}

def corrected_pdb_id(original_id):
    """Return the lower-cased PDB ID, swapped for its replacement if one is listed.

    Args:
        original_id: PDB ID in any letter case.

    Returns:
        The replacement ID from ``pdb_id_replacements`` when present,
        otherwise ``original_id`` lower-cased.
    """
    normalized = original_id.lower()
    return pdb_id_replacements.get(normalized, normalized)

def get_fasta_sequence(pdb_id, chain_id):
    url = f"https://www.rcsb.org/fasta/entry/{pdb_id}"
    try:
        response = requests.get(url)
        if response.status_code != 200:
            print(f"Failed to fetch FASTA for {pdb_id}")
            return ""

        chain_id = chain_id.upper()
        fasta_blocks = response.text.strip().split(">")
        for block in fasta_blocks:
            lines = block.strip().splitlines()
            if not lines:
                continue
            header = lines[0]
            sequence = "".join(lines[1:])

            if "|Chain " in header or "|Chains " in header:
                chain_field = header.split("|")[1]
                parts = chain_field.replace("Chains ", "").replace("Chain ", "").split(",")
                for part in parts:
                    part = part.strip()
                    if "[auth " in part:
                        model_id = part.split("[auth")[0].strip().upper()
                        auth_id = part.split("[auth")[1].replace("]", "").strip().upper()
                        if chain_id == model_id or chain_id == auth_id:
                            return sequence
                    else:
                        if chain_id == part.upper():
                            return sequence

        print(f"Chain {chain_id} not found in FASTA for {pdb_id.upper()}")
        return ""
    except Exception as e:
        print(f"Exception {e}")
        return ""

def collect_sw_mismatches(domain_seq, matched_seq, domain_start):
    """Describe every position where two aligned sequences differ.

    The sequences are compared position by position (``zip`` stops at the
    shorter one), so callers should pass ungapped, equal-length segments.

    Args:
        domain_seq: Residues from the domain side of the alignment.
        matched_seq: Residues from the other side, in register with ``domain_seq``.
        domain_start: Number reported for the first position; later positions
            count up from it.

    Returns:
        Comma-separated entries ``"<pos>(<domain residue>;<matched residue>)"``,
        or "" when the segments agree.
    """
    return ", ".join(
        f"{domain_start + offset}({dom_res};{other_res})"
        for offset, (dom_res, other_res) in enumerate(zip(domain_seq, matched_seq))
        if dom_res != other_res
    )

# Local (Smith–Waterman) aligner, used when the domain sequence is not an
# exact substring of the chain sequence.
aligner = PairwiseAligner()
aligner.mode = "local"
# Linear scoring: +2 per identical pair, -1 per substitution.
aligner.match_score, aligner.mismatch_score = 2, -1
# Affine gap scoring: -0.5 to open a gap, -0.1 per extension.
# NOTE(review): gaps are cheap relative to a -1 substitution, so alignments
# may prefer gaps over mismatches — confirm these scores are intended.
aligner.open_gap_score, aligner.extend_gap_score = -0.5, -0.1

# Main loop: locate each domain sequence inside its full chain sequence and
# record 1-based, inclusive start/end coordinates on the chain.
for idx, row in tqdm(df.iterrows(), total=len(df)):
    domain_id = row["domain_id"]
    # Need at least 4 PDB-ID characters plus 1 chain character.
    if not isinstance(domain_id, str) or len(domain_id) < 5:
        continue

    pdb_id_raw = domain_id[:4]
    pdb_id = corrected_pdb_id(pdb_id_raw)
    chain_id = domain_id[4]

    # Empty CSV cells arrive as NaN (a float); str(NaN) would be "nan", so
    # only real strings count as sequences.
    raw_seq = row.get("sequence", "")
    domain_seq = raw_seq.strip() if isinstance(raw_seq, str) else ""

    if not domain_seq:
        print(f"Missing domain sequence for {domain_id}")
        continue

    protein_seq = get_fasta_sequence(pdb_id, chain_id)
    if not protein_seq:
        continue

    df.at[idx, "protein_sequence"] = protein_seq

    # Exact match first
    pos = protein_seq.find(domain_seq)
    if pos != -1:
        df.at[idx, "domain_start"] = pos + 1
        df.at[idx, "domain_end"] = pos + len(domain_seq)
        continue

    # Smith-Waterman fallback. Take the first optimal alignment lazily:
    # truth-testing the PairwiseAlignments object calls len(), which counts
    # every co-optimal local alignment. That count can be astronomically large.
    best = next(iter(aligner.align(protein_seq, domain_seq)), None)
    if best is None or len(best.aligned[0]) == 0:
        print(f"Domain not found in chain {chain_id} of {pdb_id}")
        print(f"Domain seq: {domain_seq}")
        print(f"PDB seq (start): {protein_seq}")
        continue

    prot_blocks = best.aligned[0]
    dom_blocks = best.aligned[1]

    # The alignment may contain gaps, so compare residues only inside each
    # ungapped aligned block. Comparing the whole spanned slices would shift
    # the two sides out of register after the first gap. Positions are
    # reported in 1-based chain coordinates.
    block_notes = (
        collect_sw_mismatches(
            domain_seq[d_start:d_end],
            protein_seq[p_start:p_end],
            int(p_start) + 1,
        )
        for (p_start, p_end), (d_start, d_end) in zip(prot_blocks, dom_blocks)
    )

    df.at[idx, "domain_start"] = int(prot_blocks[0][0]) + 1
    df.at[idx, "domain_end"] = int(prot_blocks[-1][1])
    df.at[idx, "note"] = ", ".join(note for note in block_notes if note)

df.to_csv("../data/subset_with_protein_mapping.csv", index=False)

100%|██████████| 1000/1000 [08:04<00:00,  2.06it/s]
