In [1]:
import math
from matplotlib import pyplot as plt
import os

from biopandas.pdb import PandasPdb
import numpy as np

import torch

import scipy.spatial as spa
from tqdm import tqdm
import json

In [2]:
def rigid_transform_Kabsch_3D(A, B):
    """Find the rigid transform that best superimposes point set A onto B.

    Implements the Kabsch algorithm
    (https://en.wikipedia.org/wiki/Kabsch_algorithm): the returned ``R`` and
    ``t`` minimise ``rmsd(R @ A + t, B)``.

    Column ``i`` of ``A`` is paired with column ``i`` of ``B``, so both inputs
    must list the same points in the same order.

    Args:
        A: 3xN array of source points, one point per column.
        B: 3xN array of target points, one point per column.

    Returns:
        Tuple ``(R, t)`` where ``R`` is a 3x3 rotation matrix and ``t`` is a
        3x1 translation column vector.

    Raises:
        Exception: if ``A`` or ``B`` does not have exactly three rows.
        AssertionError: if the column counts differ, or if the determinant of
            the computed rotation is not within 1e-5 of 1.
    """
    assert A.shape[1] == B.shape[1]

    rows_a, cols_a = A.shape
    if rows_a != 3:
        raise Exception(f"matrix A is not 3xN, it is {rows_a}x{cols_a}")
    rows_b, cols_b = B.shape
    if rows_b != 3:
        raise Exception(f"matrix B is not 3xN, it is {rows_b}x{cols_b}")

    # Per-axis centroids, kept as 3x1 columns so they broadcast over the points.
    mu_a = np.mean(A, axis=1, keepdims=True)
    mu_b = np.mean(B, axis=1, keepdims=True)

    # 3x3 cross-covariance of the centred point sets.
    covariance = (A - mu_a) @ (B - mu_b).T

    U, _, Vt = np.linalg.svd(covariance)
    V = Vt.T
    R = V @ U.T

    # A negative determinant means the optimum found is a reflection; flipping
    # the sign of the last singular direction turns it into a proper rotation.
    if np.linalg.det(R) < 0:
        flip_last = np.diag([1.,1.,-1.])
        R = (V @ flip_last) @ U.T
    assert math.fabs(np.linalg.det(R) - 1) < 1e-5

    t = mu_b + (-R @ mu_a)
    return R, t

In [3]:
def compute_rmsd(pred, true):
    """Root-mean-square deviation between two paired coordinate arrays.

    Squared differences are summed along axis 1 (one value per row), averaged
    over the rows, and the square root of that average is returned.

    Args:
        pred: array of predicted coordinates, one point per row.
        true: array of reference coordinates with the same shape as ``pred``.

    Returns:
        The RMSD as a NumPy scalar.
    """
    per_point_sq_dist = np.sum((pred - true) ** 2, axis=1)
    return np.sqrt(np.mean(per_point_sq_dist))

In [4]:
def get_rmsd_summary(rmsds):
    """Summarise a collection of RMSD values.

    All statistics are computed in float64 and returned as built-in ``float``
    objects, so the resulting dict can be passed straight to ``json.dump``
    (NumPy scalars such as ``np.float32`` are rejected by the ``json`` module).

    Args:
        rmsds: sequence of RMSD values (Python or NumPy numbers).

    Returns:
        Dict with keys ``'mean'``, ``'median'``, ``'std'`` and ``'lt1'``,
        ``'lt2'``, ``'lt5'``, ``'lt10'``; the ``lt*`` entries are the
        percentage of values strictly below 1, 2, 5 and 10 respectively.
        For empty input every entry is NaN.
    """
    values = np.asarray(rmsds, dtype=np.float64)
    n_values = len(values)
    thresholds = (('lt1', 1.0), ('lt2', 2.0), ('lt5', 5.0), ('lt10', 10.0))

    if n_values == 0:
        # Nothing to summarise: report NaN explicitly instead of relying on
        # NumPy's empty-slice / divide-by-zero warnings to produce it.
        nan = float('nan')
        summary = {'mean': nan, 'median': nan, 'std': nan}
        summary.update((key, nan) for key, _ in thresholds)
        return summary

    summary = {
        'mean': float(np.mean(values)),
        'median': float(np.median(values)),
        'std': float(np.std(values)),
    }
    for key, cutoff in thresholds:
        summary[key] = 100 * int((values < cutoff).sum()) / n_values
    return summary

In [5]:
class RMSDComputer():
    """Computes and accumulates complex and ligand RMSD values.

    Every ``update_*`` call appends its result to the matching list so that
    ``summarize`` can report aggregate statistics afterwards.
    """

    def __init__(self):
        self.complex_rmsd_list = []
        self.ligand_rmsd_list = []
        # Not written to by any method of this class; kept so the attribute
        # stays available to existing users.
        self.receptor_rmsd_list = []

    def update_all_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true):
        """Compute and record both the complex and the ligand RMSD.

        Returns:
            Tuple ``(complex_rmsd, ligand_rmsd)``.
        """
        complex_rmsd = self.update_complex_rmsd(ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true)
        ligand_rmsd = self.update_ligand_rmsd(ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true)
        return complex_rmsd, ligand_rmsd

    def update_complex_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true):
        """RMSD of the whole complex after optimal superposition.

        Ligand and receptor coordinates are stacked along axis 0 (ligand
        first), the predicted stack is Kabsch-aligned onto the true stack, and
        the RMSD of the aligned prediction is recorded and returned.
        """
        stacked_pred = np.concatenate((ligand_coors_pred, receptor_coors_pred), axis=0)
        stacked_true = np.concatenate((ligand_coors_true, receptor_coors_true), axis=0)

        # The Kabsch helper works on 3xN matrices, hence the transposes.
        R, t = rigid_transform_Kabsch_3D(stacked_pred.T, stacked_true.T)
        stacked_pred_aligned = (R @ stacked_pred.T + t).T

        complex_rmsd = compute_rmsd(stacked_pred_aligned, stacked_true)
        self.complex_rmsd_list.append(complex_rmsd)
        return complex_rmsd

    def update_ligand_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true):
        """RMSD of the ligand once the receptors are superimposed.

        If the predicted receptor already coincides with the true receptor
        (``np.allclose`` with ``rtol=1e-6``) the ligand is compared as is.
        Otherwise the transform that aligns the predicted receptor onto the
        true receptor is applied to the predicted ligand first.
        """
        if np.allclose(receptor_coors_pred, receptor_coors_true, rtol=1e-6):
            ligand_aligned = ligand_coors_pred
        else:
            R, t = rigid_transform_Kabsch_3D(receptor_coors_pred.T, receptor_coors_true.T)
            ligand_aligned = (R @ ligand_coors_pred.T + t).T

        ligand_rmsd = compute_rmsd(ligand_aligned, ligand_coors_true)
        self.ligand_rmsd_list.append(ligand_rmsd)
        return ligand_rmsd

    def summarize(self):
        """Summarise everything recorded so far.

        Returns:
            Tuple ``(ligand_summary, complex_summary)``. Each element is the
            dict produced by ``get_rmsd_summary`` for the corresponding list,
            or ``None`` when that list is still empty.
        """
        ligand_rmsd_summarized = get_rmsd_summary(self.ligand_rmsd_list) if self.ligand_rmsd_list else None
        # Guarded the same way as the ligand list so that summarising before
        # any update does not run statistics on an empty list.
        complex_rmsd_summarized = get_rmsd_summary(self.complex_rmsd_list) if self.complex_rmsd_list else None
        return ligand_rmsd_summarized, complex_rmsd_summarized

In [6]:
def get_coords(pdb_file, atoms_to_keep):
    """Read the coordinates of selected atoms from a PDB file.

    Only records in the ``ATOM`` table of the parsed file are considered.

    Args:
        pdb_file: path of the PDB file to read.
        atoms_to_keep: collection of atom names to keep (e.g. ``("CA",)``).
            A single name given as a plain string is treated as a one-element
            collection, so it is matched exactly rather than by substring.

    Returns:
        ``float32`` array of shape ``(N, 3)`` holding the x, y, z coordinates
        of the selected atoms in file order. The array stays two-dimensional
        even when only one atom (or none) matches.
    """
    if isinstance(atoms_to_keep, str):
        atoms_to_keep = (atoms_to_keep,)

    atom_records = PandasPdb().read_pdb(pdb_file).df['ATOM']
    selected = atom_records[atom_records["atom_name"].isin(atoms_to_keep)]
    # No squeeze(): it would collapse a single matching atom to shape (3,),
    # which breaks callers that index rows or compare shape[0].
    return selected[['x_coord', 'y_coord', 'z_coord']].to_numpy().astype(np.float32)

In [7]:
# Atom names passed to get_coords by compute_rmsds: only atoms named "CA"
# contribute to the RMSD values computed in this notebook.
atoms_to_keep = ("CA",)

In [8]:
# The "./equidock_public/..." and "diffdock-protein/..." paths used below are
# relative to the directory three levels above the notebook's start directory.
# NOTE: this relative chdir is not idempotent - running the cell a second time
# in the same kernel moves up three more levels.
os.chdir("../../../")

In [21]:
def compute_rmsds(dataset, model, bound=True, interface_cutoff=8.):
    """Compute complex, ligand and interface RMSDs for one model on one dataset.

    Predictions are read from
    ``./equidock_public/test_sets_pdb/{dataset}_{model}_results/`` and ground
    truth from
    ``./equidock_public/test_sets_pdb/{dataset}_test_random_transformed/complexes/``.
    Each file ending in ``_l_{b|u}_{MODEL}.pdb`` is one predicted ligand; the
    matching ``_r_...`` file is the predicted receptor. When no predicted
    receptor file exists, the ground-truth receptor is used in its place.

    Only the atoms named in the module-level ``atoms_to_keep`` are used.

    Args:
        dataset: dataset name used to build the directory names.
        model: model name; its upper-cased form is the file-name suffix.
        bound: selects the ``_b_`` (True) or ``_u_`` (False) file variant.
        interface_cutoff: ligand/receptor distance, in the coordinate units of
            the PDB files, below which a pair of atoms in the ground-truth
            complex counts as part of the interface.

    Returns:
        Tuple ``(all_crmsd, all_lrmsd, all_irmsd)`` of lists holding the
        complex, ligand and interface RMSD of every processed ligand file, in
        sorted file-name order.

    Raises:
        AssertionError: if predicted and ground-truth structures contain a
            different number of selected atoms.
    """
    input_dir = f"./equidock_public/test_sets_pdb/{dataset}_{model}_results/"
    ground_truth_dir = f"./equidock_public/test_sets_pdb/{dataset}_test_random_transformed/complexes/"

    pdb_files = sorted(
        f for f in os.listdir(input_dir)
        if f.endswith('.pdb') and os.path.isfile(os.path.join(input_dir, f))
    )

    meter = RMSDComputer()
    Irmsd_meter = RMSDComputer()

    # File-name suffixes are the same for every file, so build them once.
    bound_letter = 'b' if bound else 'u'
    suffix = model.upper()
    ligand_model_tag = f'_l_{bound_letter}_{suffix}.pdb'
    receptor_model_tag = f'_r_{bound_letter}_{suffix}.pdb'
    ligand_gt_tag = f'_l_{bound_letter}_COMPLEX.pdb'
    receptor_gt_tag = f'_r_{bound_letter}_COMPLEX.pdb'

    all_crmsd = []
    all_irmsd = []
    all_lrmsd = []

    for file in tqdm(pdb_files):
        # Each complex is driven by its predicted-ligand file; skip the rest.
        if not file.endswith(ligand_model_tag):
            continue

        ligand_model_file = os.path.join(input_dir, file)
        ligand_gt_file = os.path.join(ground_truth_dir, file.replace(ligand_model_tag, ligand_gt_tag))
        receptor_gt_file = os.path.join(ground_truth_dir, file.replace(ligand_model_tag, receptor_gt_tag))

        receptor_model_file = os.path.join(input_dir, file.replace(ligand_model_tag, receptor_model_tag))
        if not os.path.exists(receptor_model_file):
            # No predicted receptor was written: treat the receptor as fixed.
            receptor_model_file = receptor_gt_file

        ligand_model_coords = get_coords(ligand_model_file, atoms_to_keep)
        receptor_model_coords = get_coords(receptor_model_file, atoms_to_keep)

        ligand_gt_coords = get_coords(ligand_gt_file, atoms_to_keep)
        receptor_gt_coords = get_coords(receptor_gt_file, atoms_to_keep)

        assert ligand_model_coords.shape[0] == ligand_gt_coords.shape[0], \
            f"ligand atom count mismatch for {file}"
        assert receptor_model_coords.shape[0] == receptor_gt_coords.shape[0], \
            f"receptor atom count mismatch for {file}"

        # Interface = atom pairs closer than the cutoff in the ground truth.
        # np.where yields one (ligand, receptor) index pair per contact, so an
        # atom with several partners appears several times and is weighted
        # accordingly in the interface RMSD.
        ligand_receptor_distance = spa.distance.cdist(ligand_gt_coords, receptor_gt_coords)
        active_ligand, active_receptor = np.where(ligand_receptor_distance < interface_cutoff)

        crmsd, lrmsd = meter.update_all_rmsd(ligand_model_coords, receptor_model_coords,
                                             ligand_gt_coords, receptor_gt_coords)

        irmsd = Irmsd_meter.update_complex_rmsd(ligand_model_coords[active_ligand, :],
                                                receptor_model_coords[active_receptor, :],
                                                ligand_gt_coords[active_ligand, :],
                                                receptor_gt_coords[active_receptor, :])

        all_crmsd.append(crmsd)
        all_lrmsd.append(lrmsd)
        all_irmsd.append(irmsd)

    return all_crmsd, all_lrmsd, all_irmsd

In [23]:
def fill_results(results, models, datasets, bound=True):
    """Evaluate every (model, dataset) pair and store the outcome in ``results``.

    ``results`` is updated in place: ``results[model][dataset]`` becomes a dict
    with a ``"summary"`` entry (sample count plus aggregate statistics for the
    complex, ligand and interface RMSD) and an ``"all"`` entry (the
    per-complex values and the matching prediction file names).

    Every number is stored as a built-in ``float`` so the structure can be
    serialised with ``json`` directly; NumPy scalars are not accepted by the
    ``json`` module.

    Args:
        results: dict to fill; existing entries for other models are kept.
        models: iterable of model names.
        datasets: iterable of dataset names.
        bound: forwarded to ``compute_rmsds``; also selects the ``_b_`` /
            ``_u_`` file-name variant listed under ``"pdb_files"``.
    """
    bound_letter = 'b' if bound else 'u'

    def _summarize(values):
        # Statistics of one RMSD list, with values cast to built-in floats.
        return {key: float(stat) for key, stat in get_rmsd_summary(values).items()}

    for model in models:
        model_results = results.setdefault(model, {})
        ligand_suffix = f'_l_{bound_letter}_{model.upper()}.pdb'

        for dataset in datasets:
            all_crmsd, all_lrmsd, all_irmsd = compute_rmsds(dataset=dataset, model=model, bound=bound)

            all_crmsd = [float(value) for value in all_crmsd]
            all_lrmsd = [float(value) for value in all_lrmsd]
            all_irmsd = [float(value) for value in all_irmsd]

            # Same directory, filter and ordering as compute_rmsds, so entry i
            # of pdb_files is the file behind entry i of the RMSD lists.
            input_dir = f"./equidock_public/test_sets_pdb/{dataset}_{model}_results/"
            pdb_files = sorted(
                f for f in os.listdir(input_dir)
                if f.endswith(ligand_suffix) and os.path.isfile(os.path.join(input_dir, f))
            )

            model_results[dataset] = {
                "summary": {
                    "num_test_files": len(all_crmsd),
                    "crmsd": _summarize(all_crmsd),
                    "lrmsd": _summarize(all_lrmsd),
                    "irmsd": _summarize(all_irmsd),
                },
                "all": {
                    "crmsd": all_crmsd,
                    "lrmsd": all_lrmsd,
                    "irmsd": all_irmsd,
                    "pdb_files": pdb_files
                }
            }

In [24]:
# Accumulator filled in place by fill_results:
# results[model][dataset] -> {"summary": ..., "all": ...}.
results = {}

# Every model/dataset pair below is read from
# "./equidock_public/test_sets_pdb/{dataset}_{model}_results/".
models = ["attract", "cluspro", "equidock_no_clashes", "equidock", "hdock", "patchdock"]
datasets = ["db5", "dips"]

# bound defaults to True, i.e. the "_b_" file-name variant is evaluated.
fill_results(results, models, datasets)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 18.17it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:11<00:00, 17.63it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:02<00:00,  9.40it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00,  9.08it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:02<00:00,  9.95it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [26]:
# Add two more models to the same results dict, evaluated on "db5" only.
fill_results(results, ["equidock_on_dips_no_clashes", "equidock_on_dips"], ["db5"])

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:02<00:00,  8.86it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:02<00:00,  9.43it/s]


In [27]:
# "db5_unbound" runs: bound=False makes the file names use the "_u_" variant.
fill_results(results, ["equidock", "equidock_no_clashes",  "equidock_on_dips", "equidock_on_dips_no_clashes"], ["db5_unbound"], bound=False)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:05<00:00,  4.85it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:03<00:00,  6.32it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:03<00:00,  7.04it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:03<00:00,  7.68it/s]


In [28]:
# Normalise results to plain JSON types. NumPy scalars that json cannot encode
# (e.g. np.float32) are converted with float(). This replaces a
# str(results)/quote-swapping round trip, which breaks on any value containing
# a quote, on nan, and on NumPy versions whose scalar repr is "np.float32(...)".
results = json.loads(json.dumps(results, default=float))

In [29]:
!ls

diffdock-protein	     equidock_public  rmwu
diffdock-protein-amine-copy  hub	      visualization


In [31]:
# Persist the full results (summaries plus per-complex values) as JSON.
# The path is relative to the working directory set by the earlier os.chdir.
with open("diffdock-protein/baselines/baselines.json", "w") as file:
    json.dump(results, file)

In [32]:
# Reload the saved file; from here on `results` holds the JSON-decoded data.
with open("diffdock-protein/baselines/baselines.json", "r") as file:
    results = json.load(file)

In [41]:
import copy

# Summary-only view of results: for every model and every dataset, keep all
# entries except the bulky per-complex "all" lists. Works for any dataset name,
# and the "all" lists are never copied.
summaries = {
    model_name: {
        dataset_name: {
            field: copy.deepcopy(content)
            for field, content in entry.items()
            if field != "all"
        }
        for dataset_name, entry in per_dataset.items()
    }
    for model_name, per_dataset in results.items()
}

with open("baselines_summary.json", "w") as file:
    json.dump(summaries, file)

In [42]:
# Reload the summary file; `summaries` now holds the JSON-decoded data.
with open("baselines_summary.json", "r") as file:
    summaries = json.load(file)

In [46]:
def draw_plots(model, dataset):
    """Draw ligand, complex and interface RMSD histograms for one result set.

    Reads ``results[model][dataset]["all"]`` from the module-level ``results``
    and creates one figure with three side-by-side 100-bin histograms. The
    figure is neither returned, shown nor saved here.

    Args:
        model: key of the model in ``results``.
        dataset: key of the dataset in ``results[model]``.
    """
    rmsds = results[model][dataset]["all"]

    fig, axes = plt.subplots(1, 3, figsize = (30, 6))

    # (key in rmsds, x-axis label) for each panel, left to right.
    panels = (
        ("lrmsd", "Ligand RMSD"),
        ("crmsd", "Complex RMSD"),
        ("irmsd", "Interface RMSD"),
    )
    for axis, (metric, label) in zip(axes, panels):
        axis.hist(rmsds[metric], bins=100)
        axis.set(xlabel=label)

    fig.suptitle(f'Model: {model} Dataset: {dataset}')

In [55]:
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("rmsd_plots", exist_ok=True)

# One PNG per (model, dataset) pair; each figure is closed after saving so
# figures do not pile up in memory.
models = sorted(results)
for model in models:
    datasets = sorted(results[model])
    for dataset in datasets:
        draw_plots(model, dataset)
        name = f"rmsd_plots/{dataset}_{model}_rmsd.png"
        plt.savefig(name)
        plt.close()

In [61]:
# Plain-text report: for every model and dataset (both in sorted order), print
# the sample count, median/mean ± std, and the percentage of samples below
# 1, 2, 5 and 10 for the ligand, complex and interface RMSD.
models = sorted(results)
for model in models:
    datasets = sorted(results[model])
    for dataset in datasets:
        print(f"Model: {model}\tDataset: {dataset}")
        summary = results[model][dataset]["summary"]

        lrmsd, crmsd, irmsd = summary["lrmsd"], summary["crmsd"], summary["irmsd"]

        print(f'Number of samples:\t\t{summary["num_test_files"]}')
        print()
        print(f"Ligand RMSD median/mean:\t{lrmsd['median']:.3}/{lrmsd['mean']:.3} ± {lrmsd['std']:.3}")
        print(f"Complex RMSD median/mean:\t{crmsd['median']:.3}/{crmsd['mean']:.3} ± {crmsd['std']:.3}")
        print(f"Interface RMSD median/mean:\t{irmsd['median']:.3}/{irmsd['mean']:.3} ± {irmsd['std']:.3}")
        print()
        print(f"Ligand lt1/lt2/lt5/lt10,:\t{lrmsd['lt1']:.3}%/{lrmsd['lt2']:.3}%/{lrmsd['lt5']:.3}%/{lrmsd['lt10']:.3}%")
        print(f"Complex lt1/lt2/lt5/lt10,:\t{crmsd['lt1']:.3}%/{crmsd['lt2']:.3}%/{crmsd['lt5']:.3}%/{crmsd['lt10']:.3}%")
        print(f"Interface lt1/lt2/lt5/lt10,:\t{irmsd['lt1']:.3}%/{irmsd['lt2']:.3}%/{irmsd['lt5']:.3}%/{irmsd['lt10']:.3}%")

        print('\n-------------------------------------------------------\n')

Model: attract	Dataset: db5
Number of samples:		25

Ligand RMSD median/mean:	23.1/24.9 ± 24.9
Complex RMSD median/mean:	9.55/10.1 ± 9.88
Interface RMSD median/mean:	7.48/10.7 ± 10.9

Ligand lt1/lt2/lt5/lt10,:	12.0%/40.0%/44.0%/44.0%
Complex lt1/lt2/lt5/lt10,:	44.0%/44.0%/44.0%/52.0%
Interface lt1/lt2/lt5/lt10,:	44.0%/44.0%/44.0%/52.0%

-------------------------------------------------------

Model: attract	Dataset: dips
Number of samples:		100

Ligand RMSD median/mean:	37.3/38.2 ± 29.5
Complex RMSD median/mean:	17.2/14.9 ± 10.4
Interface RMSD median/mean:	12.4/14.0 ± 11.8

Ligand lt1/lt2/lt5/lt10,:	9.0%/17.0%/20.0%/22.0%
Complex lt1/lt2/lt5/lt10,:	20.0%/20.0%/23.0%/33.0%
Interface lt1/lt2/lt5/lt10,:	20.0%/20.0%/22.0%/38.0%

-------------------------------------------------------

Model: cluspro	Dataset: db5
Number of samples:		25

Ligand RMSD median/mean:	7.69/20.5 ± 19.9
Complex RMSD median/mean:	3.38/8.26 ± 7.92
Interface RMSD median/mean:	2.32/8.72 ± 9.9

Ligand lt1/lt2/lt5/lt10,:	0