In [1]:
from Bio.PDB import PDBParser, MMCIFParser
from Bio.SVDSuperimposer import SVDSuperimposer
import numpy as np

def load_structure(filename, struct_id="structure"):
    """Parse a PDB or mmCIF file into a Biopython Structure.

    The parser is chosen from the file extension, matched case-insensitively
    so that e.g. ``MODEL.PDB`` or ``model.CIF`` are accepted too.

    Parameters
    ----------
    filename : str or os.PathLike
        Path to a ``.pdb``, ``.cif`` or ``.mmcif`` file.
    struct_id : str
        Identifier given to the returned structure.

    Returns
    -------
    Bio.PDB.Structure.Structure

    Raises
    ------
    ValueError
        If the extension is not one of the supported formats.
    """
    # str() lets pathlib.Path objects through; lower() makes the check case-insensitive.
    path_str = str(filename)
    lowered = path_str.lower()
    if lowered.endswith(".pdb"):
        parser = PDBParser(QUIET=True)
    elif lowered.endswith((".cif", ".mmcif")):
        parser = MMCIFParser(QUIET=True)
    else:
        raise ValueError("Unsupported file format: must be .pdb or .cif/.mmcif")
    return parser.get_structure(struct_id, path_str)

def get_ligand_coords(structure, include_water=False):
    """Collect the coordinates of every atom in hetero (non-standard) residues.

    A residue counts as hetero when its hetero flag ``residue.id[0]`` is not
    blank. Biopython flags waters with ``"W"``. They are skipped by default, so
    that crystallographic waters do not define a binding pocket.

    NOTE(review): ions and any other HETATM groups are still treated as
    "ligand"; confirm this is intended for the inputs used here.

    Parameters
    ----------
    structure : Bio.PDB.Structure.Structure
        Iterated as models -> chains -> residues -> atoms (all models are used).
    include_water : bool
        If True, water residues are also included (the original behaviour).

    Returns
    -------
    list of numpy.ndarray
        One coordinate vector per hetero atom, in iteration order.
    """
    lig_coords = []
    for model in structure:
        for chain in model:
            for residue in chain:
                hetflag = residue.id[0]
                if hetflag == " ":
                    continue  # standard (protein) residue
                if hetflag == "W" and not include_water:
                    continue  # water, not a ligand
                lig_coords.extend(atom.get_coord() for atom in residue)
    return lig_coords

def find_pocket_residues(structure, lig_coords, cutoff=4.0):
    """Return the standard residues with any atom within ``cutoff`` of a ligand atom.

    Distances are Euclidean, in the coordinate units of the structure.
    Each residue is tested against all ligand atoms at once with numpy.
    Residues already found are skipped, so no further atoms are scanned for them.

    Parameters
    ----------
    structure : Bio.PDB.Structure.Structure
        Iterated as models -> chains -> residues -> atoms.
    lig_coords : sequence of array-like of length 3
        Ligand atom coordinates, e.g. from ``get_ligand_coords``.
    cutoff : float
        Inclusive distance threshold (``d <= cutoff``).

    Returns
    -------
    set of (str, int)
        ``(chain.id, residue sequence number)`` keys. Insertion codes are not
        part of the key.
    """
    pocket_res = set()
    # (n_lig, 3); an empty input becomes shape (0, 3) and can never match.
    lig = np.asarray(lig_coords).reshape(-1, 3)
    if lig.shape[0] == 0:
        return pocket_res
    for model in structure:
        for chain in model:
            for residue in chain:
                if residue.id[0] != " ":
                    continue  # only standard (protein) residues can be pocket residues
                key = (chain.id, residue.id[1])
                if key in pocket_res:
                    continue
                atom_xyz = [atom.get_coord() for atom in residue]
                if not atom_xyz:
                    continue
                atom_xyz = np.asarray(atom_xyz)
                # Pairwise (n_atoms, n_lig) distance matrix for this residue.
                dists = np.linalg.norm(atom_xyz[:, None, :] - lig[None, :, :], axis=-1)
                if (dists <= cutoff).any():
                    pocket_res.add(key)
    return pocket_res

def get_ca_coords(structure, include_hetero=False):
    """Collect the coordinates of every C-alpha atom, in iteration order.

    By default only standard residues (blank hetero flag ``residue.id[0]``)
    are used. This matches the "protein" definition of ``find_pocket_residues``
    and ``get_sidechain_coords``, and keeps out hetero groups that have an atom
    named ``CA``, such as calcium ions or ligand atoms.

    NOTE(review): modified amino acids stored as HETATM (e.g. MSE) are also
    excluded unless ``include_hetero=True``.

    Parameters
    ----------
    structure : Bio.PDB.Structure.Structure
        Iterated as models -> chains -> residues (all models are used).
    include_hetero : bool
        If True, hetero residues with a ``CA`` atom are included as well
        (the original behaviour).

    Returns
    -------
    list of numpy.ndarray
        One coordinate vector per C-alpha. Callers pair these by position,
        so both structures must yield them in the same order.
    """
    ca_coords = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if not include_hetero and residue.id[0] != " ":
                    continue
                if 'CA' in residue:
                    ca_coords.append(residue['CA'].get_coord())
    return ca_coords

def get_sidechain_coords(structure, res_id_set):
    """Map each non-hydrogen side-chain atom of the selected residues to its coordinates.

    Backbone atoms (N, CA, C, O, OXT) and hydrogens are skipped. An atom counts
    as a hydrogen when its name starts with ``H`` or its element is H/D. The
    element check also catches legacy names such as ``1HB``.

    Parameters
    ----------
    structure : Bio.PDB.Structure.Structure
        Iterated as models -> chains -> residues -> atoms.
    res_id_set : collection of (str, int)
        ``(chain.id, residue sequence number)`` keys to include. Only standard
        residues (blank hetero flag) are considered.

    Returns
    -------
    dict
        ``{((chain_id, resseq), atom_name): coordinate}``.
    """
    backbone = {'N', 'CA', 'C', 'O', 'OXT'}
    sc_coords = {}
    for model in structure:
        for chain in model:
            for residue in chain:
                key = (chain.id, residue.id[1])
                if key not in res_id_set or residue.id[0] != " ":
                    continue
                for atom in residue:
                    name = atom.get_name()
                    element = (getattr(atom, "element", None) or "").strip().upper()
                    is_hydrogen = name.startswith('H') or element in ('H', 'D')
                    if name in backbone or is_hydrogen:
                        continue
                    sc_coords[(key, name)] = atom.get_coord()
    return sc_coords


def kabsch_rmsd(coords1, coords2):
    """RMSD between two paired coordinate sets after optimal superposition.

    ``coords1`` is the reference and ``coords2`` is superimposed onto it with
    Biopython's SVD-based superimposer. Points are paired by position, so both
    inputs must have the same length and order.

    Returns
    -------
    float
        The RMSD reported by ``SVDSuperimposer.get_rms()`` after fitting.
    """
    superimposer = SVDSuperimposer()
    reference = np.array(coords1)
    mobile = np.array(coords2)
    superimposer.set(reference, mobile)
    superimposer.run()
    return superimposer.get_rms()

def calculate_CA_rmsd(ref_fn, cmp_fn):
    """C-alpha RMSD of the structure in ``cmp_fn`` against the one in ``ref_fn``.

    Both files are parsed with ``load_structure``. Their C-alpha coordinates
    (from ``get_ca_coords``) are superimposed by position with ``kabsch_rmsd``,
    so the two structures must yield C-alphas in the same count and order.
    """
    ref_struct, cmp_struct = (load_structure(fn) for fn in (ref_fn, cmp_fn))
    return kabsch_rmsd(get_ca_coords(ref_struct), get_ca_coords(cmp_struct))
    
def calculate_sc_rmsd(ref_fn, cmp_fn, cutoff=4.0):
    """Side-chain RMSD over the ligand pocket of the reference structure.

    The pocket is the set of standard residues in ``ref_fn`` with any atom
    within ``cutoff`` of a hetero-residue atom. Side-chain atoms of those
    residues are matched between the two structures by ``(residue key, atom
    name)``. Only atoms present in both are used, in sorted tag order, and they
    are superimposed on each other with ``kabsch_rmsd``.

    Raises
    ------
    ValueError
        If no side-chain atom is shared between the two structures' pockets.
    """
    ref_struct = load_structure(ref_fn)
    cmp_struct = load_structure(cmp_fn)
    pocket = find_pocket_residues(ref_struct, get_ligand_coords(ref_struct), cutoff=cutoff)
    ref_sc = get_sidechain_coords(ref_struct, pocket)
    cmp_sc = get_sidechain_coords(cmp_struct, pocket)

    shared_tags = [tag for tag in sorted(ref_sc) if tag in cmp_sc]
    if not shared_tags:
        raise ValueError("Sidechain coordinate extraction failed (no pocket residues or atoms)")

    ref_xyz = [ref_sc[tag] for tag in shared_tags]
    cmp_xyz = [cmp_sc[tag] for tag in shared_tags]
    return kabsch_rmsd(ref_xyz, cmp_xyz)

In [2]:
import pandas as pd
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from functools import partial

def rmsd_worker(row, model, col='lmpnn'):
    """Compute (CA RMSD, pocket side-chain RMSD) for one DataFrame row.

    The reference structure is ``row[(col, 'path')]`` and the compared one is
    ``row[(model, 'path')]``. The two metrics are computed independently, so a
    failure in one (e.g. mismatched C-alpha counts) does not discard the other.
    A failing metric is reported on stdout and returned as ``np.nan``.

    Returns
    -------
    tuple of (float, float)
        ``(ca_rmsd, sc_rmsd)``.
    """
    lmpnn_path = row[(col, 'path')]
    model_path = row[(model, 'path')]
    try:
        ca = calculate_CA_rmsd(lmpnn_path, model_path)
    except Exception as e:
        print(f"Error processing CA RMSD for {lmpnn_path} and {model_path}: {e}")
        ca = np.nan
    try:
        sc = calculate_sc_rmsd(lmpnn_path, model_path)
    except Exception as e:
        print(f"Error processing sidechain RMSD for {lmpnn_path} and {model_path}: {e}")
        sc = np.nan
    return (ca, sc)

def process(df, model, col='lmpnn', pool_size=cpu_count()):
    """Add CA and pocket side-chain RMSD columns for ``model`` vs. the ``col`` reference.

    For every row, ``rmsd_worker`` compares ``row[(col, 'path')]`` (reference)
    with ``row[(model, 'path')]`` in a process pool. The results are written to
    the new columns ``(model, 'ca_rmsd')`` and ``(model, 'sc_rmsd')``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with MultiIndex columns containing ``(col, 'path')`` and
        ``(model, 'path')``.
    model : str
        Top-level column label of the predictor to evaluate.
    col : str
        Top-level column label of the reference structures.
    pool_size : int
        Number of worker processes (default evaluated once at definition time).

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, modified in place with the two new columns.

    NOTE(review): the pool must be able to pickle ``rmsd_worker``, which is
    defined in the notebook. That works with the 'fork' start method but may
    fail with 'spawn' (the macOS/Windows default); confirm on the target platform.
    """
    # Only the two path columns are shipped to the workers, not whole rows
    # (which carry sequences and many score columns). dict.fromkeys dedupes
    # the labels in case col == model.
    path_cols = list(dict.fromkeys([(col, 'path'), (model, 'path')]))
    rows = [row for _, row in df.loc[:, path_cols].iterrows()]
    if rows:
        worker = partial(rmsd_worker, model=model, col=col)
        with Pool(pool_size) as pool:
            results = list(tqdm(pool.imap(worker, rows), total=len(rows)))
        ca_rmsd, sc_rmsd = zip(*results)
    else:
        # zip(*[]) yields nothing to unpack; an empty frame just gets empty columns.
        ca_rmsd, sc_rmsd = (), ()
    df[(model, 'ca_rmsd')] = np.asarray(ca_rmsd, dtype=float)
    df[(model, 'sc_rmsd')] = np.asarray(sc_rmsd, dtype=float)
    return df

path = 'lmpnn_filt_data_added.parquet'
df = pd.read_parquet(path)
# Score each predictor's models against the LigandMPNN reference structures, in turn.
for model_name in ('af3', 'boltz'):
    df = process(df, model=model_name, col='lmpnn', pool_size=8)

100%|██████████| 36/36 [00:00<00:00, 53.18it/s]
100%|██████████| 36/36 [00:00<00:00, 56.09it/s]


In [3]:
# Persist the frame with the new RMSD columns, then show it.
df.to_parquet(path='rmsd_1st.parquet')
df

Unnamed: 0_level_0,diffusion,diffusion,lmpnn,lmpnn,lmpnn,lmpnn,lmpnn,lmpnn,lmpnn,lmpnn,...,boltz,boltz,boltz,boltz,boltz,boltz,af3,af3,boltz,boltz
Unnamed: 0_level_1,id,batch,tag,ddg,fa_rep,res_totalscore,totalscore,seq,path,relaxed_path,...,path,iptm,ptm,score,affinity_ic50,affinity_bind,ca_rmsd,sc_rmsd,ca_rmsd,sc_rmsd
0,result_7_packed_3_1,pht_demo,result_7_packed_3_1,-27.749462,95.805969,-1.583578,-199.530838,SLEEIIAKIRASDPATVDWGAHFREFCKAAGVAEVTPEERALAEKA...,../3_lmpnn/output/packed/result_7_packed_3_1.pdb,../3_lmpnn/output/packed/result_7_packed_3_1_b...,...,../6_boltz/output/boltz_results_0/predictions/...,0.769996,0.739775,0.712917,1.609957,0.187223,1.091246,1.289126,12.82891,6.293877
1,result_61_packed_8_1,pht_demo,result_61_packed_8_1,-26.640675,104.494186,-1.441779,-181.664185,SEELLAAIKAAFRKIAGDLLTDRVDLDELAQFILDTLTLSEEERAR...,../3_lmpnn/output/packed/result_61_packed_8_1.pdb,../3_lmpnn/output/packed/result_61_packed_8_1_...,...,../6_boltz/output/boltz_results_0/predictions/...,0.581382,0.760046,0.718728,1.729165,0.146521,14.882061,7.540875,4.103964,3.004177
2,result_7_packed_8_1,pht_demo,result_7_packed_8_1,-25.789537,98.276131,-1.57081,-197.922104,SLAEILAEIRAADPATVDWEAHFRRFCEAAGVEAVTPEERELAARA...,../3_lmpnn/output/packed/result_7_packed_8_1.pdb,../3_lmpnn/output/packed/result_7_packed_8_1_b...,...,../6_boltz/output/boltz_results_0/predictions/...,0.86967,0.904298,0.692104,1.849106,0.117982,2.475292,3.455196,11.608226,5.500517
3,result_7_packed_2_1,pht_demo,result_7_packed_2_1,-25.547806,102.521637,-1.61018,-202.882706,SLAELIQEIRDADPKTIDWEAFFRRFAEAAGVAAVTPEQRALAARM...,../3_lmpnn/output/packed/result_7_packed_2_1.pdb,../3_lmpnn/output/packed/result_7_packed_2_1_b...,...,../6_boltz/output/boltz_results_0/predictions/...,0.790683,0.717135,0.668814,1.888733,0.13917,10.356954,7.56935,5.136358,2.721634
4,result_16_packed_4_1,pht_demo,result_16_packed_4_1,-25.318375,91.409203,-1.448796,-160.816406,ALSDEVKAMLRRMAPAAERLGTEGLLRRMQELGVVPEVTPDLLKAF...,../3_lmpnn/output/packed/result_16_packed_4_1.pdb,../3_lmpnn/output/packed/result_16_packed_4_1_...,...,../6_boltz/output/boltz_results_0/predictions/...,0.873105,0.93513,0.949214,1.517867,0.142706,0.899024,1.500364,1.442684,2.219334
5,result_16_packed_7_1,pht_demo,result_16_packed_7_1,-24.98737,88.052048,-1.478883,-164.156021,MLSETVKNMLKRLAPAAERLGTEGLLRRMIEAGVIPEVTPELLKAL...,../3_lmpnn/output/packed/result_16_packed_7_1.pdb,../3_lmpnn/output/packed/result_16_packed_7_1_...,...,../6_boltz/output/boltz_results_0/predictions/...,0.828962,0.883267,0.866744,1.793561,0.134394,1.203721,1.934954,1.422562,1.649351
6,result_7_packed_4_1,pht_demo,result_7_packed_4_1,-23.823254,103.83667,-1.429071,-180.062973,SLAEILAEIRASDPATADWLALARRFAEAAGVDEVTPEERELAAKA...,../3_lmpnn/output/packed/result_7_packed_4_1.pdb,../3_lmpnn/output/packed/result_7_packed_4_1_b...,...,../6_boltz/output/boltz_results_0/predictions/...,0.607069,0.774949,0.785562,1.833631,0.143905,8.709874,4.62572,2.397142,2.042227
7,result_29_packed_7_1,pht_demo,result_29_packed_7_1,-22.68211,75.388725,-1.675448,-177.597519,SAAFRAILRAMCEAFAELAPGLTLSDEELELVLNPDDEELRKRLNV...,../3_lmpnn/output/packed/result_29_packed_7_1.pdb,../3_lmpnn/output/packed/result_29_packed_7_1_...,...,../6_boltz/output/boltz_results_0/predictions/...,0.709298,0.809468,0.882201,1.662283,0.137814,5.64005,3.841254,7.460195,2.651895
8,result_59_packed_1_1,pht_demo,result_59_packed_1_1,-22.405994,70.00029,-1.624849,-180.358276,LATEAFLRTFIQSAEALELMRARGTAAAAEIAALVLAALKAKGVSS...,../3_lmpnn/output/packed/result_59_packed_1_1.pdb,../3_lmpnn/output/packed/result_59_packed_1_1_...,...,../6_boltz/output/boltz_results_0/predictions/...,0.390886,0.405878,0.538988,1.754866,0.172469,10.367922,5.525122,6.260994,3.597751
9,result_59_packed_6_1,pht_demo,result_59_packed_6_1,-21.15624,66.255737,-1.786232,-198.271744,SATEAFLRLVIASPEALELMRTRGTAAADEIAALMLAALEAKGISA...,../3_lmpnn/output/packed/result_59_packed_6_1.pdb,../3_lmpnn/output/packed/result_59_packed_6_1_...,...,../6_boltz/output/boltz_results_0/predictions/...,0.592455,0.621564,0.68668,1.901321,0.184374,19.843605,19.514959,4.15033,2.965234
