In [None]:
from pathlib import Path
import pickle
import numpy as np
import pandas as pd
import scipy

from rdkit.Chem.rdchem import Mol
from rdkit.Chem import RemoveAllHs

from typing import Optional

from torch_geometric.data import HeteroData

from utils.molecules_utils import get_symmetry_rmsd

In [None]:
def filter_positions_to_only_none_Hs(
        positions: np.ndarray,
        feature_vec_x: np.ndarray,
    ) -> np.ndarray:
    """Keep only the coordinates of atoms whose first feature is non-zero.

    The atom mask is ``feature_vec_x[:, 0] != 0``.
    NOTE(review): presumably a value of 0 in feature column 0 marks a
    hydrogen atom in the featurisation used upstream -- confirm.

    Args:
        positions: Either a 2-D array whose axis 0 enumerates atoms, or a
            3-D array whose axis 1 enumerates atoms.
        feature_vec_x: 2-D per-atom feature matrix; only column 0 is read.

    Returns:
        ``positions`` restricted to the selected atoms along its atom axis.

    Raises:
        AssertionError: If the atom axis of ``positions`` does not have the
            same length as ``feature_vec_x``.
        ValueError: If ``positions`` is neither 2-D nor 3-D.
    """
    not_Hs = np.not_equal(feature_vec_x[:, 0], 0)
    n_dims = len(positions.shape)
    if n_dims == 2:
        assert positions.shape[0] == len(not_Hs), (
            f"positions has {positions.shape[0]} atoms but feature_vec_x has {len(not_Hs)}"
        )
        return positions[not_Hs, :]
    if n_dims == 3:
        assert positions.shape[1] == len(not_Hs), (
            f"positions has {positions.shape[1]} atoms but feature_vec_x has {len(not_Hs)}"
        )
        return positions[:, not_Hs, :]
    raise ValueError(
        f"positions must be 2-D or 3-D, got an array with shape {tuple(positions.shape)}"
    )


def calculate_rmsds(
        orig_positions: np.ndarray,
        predicted_ligand_positions: np.ndarray,
        mol: Mol,
):
    """Compute one RMSD per predicted pose against the reference coordinates.

    For each reference conformer in ``orig_positions`` the symmetry-aware
    helper ``get_symmetry_rmsd`` is tried first. If it raises, the error is
    printed and a plain coordinate RMSD (sum of squares over axis 2, mean
    over axis 1 of ``predicted_ligand_positions``) is used instead.

    Args:
        orig_positions: Reference coordinates; axis 0 enumerates reference
            conformers and must currently have length 1.
        predicted_ligand_positions: Predicted coordinates; axis 0 enumerates
            poses (assumed shape ``(n_poses, n_atoms, 3)`` -- TODO confirm).
        mol: Molecule forwarded to ``get_symmetry_rmsd``.

    Returns:
        Element-wise minimum over reference conformers of the per-pose RMSDs.

    Raises:
        RuntimeError: If ``orig_positions`` holds anything other than exactly
            one reference conformer.
    """
    n_references = orig_positions.shape[0]
    if n_references != 1:
        raise RuntimeError("Have never seen this before, investigate that it works")

    per_reference = []
    for ref_idx in range(n_references):
        reference = orig_positions[ref_idx]
        try:
            pose_rmsds = get_symmetry_rmsd(mol, reference, list(predicted_ligand_positions))
        except Exception as err:
            # Best-effort fallback: report the failure and ignore symmetry.
            print("Using non corrected RMSD because of the error:", err)
            squared_error = (predicted_ligand_positions - reference) ** 2
            pose_rmsds = np.sqrt(squared_error.sum(axis=2).mean(axis=1))
        per_reference.append(pose_rmsds)

    return np.min(np.asarray(per_reference), axis=0)


def calculate_rmsds_on_pred_output(
        orig_complex_graph: HeteroData,
        predicted_ligand_positions: np.ndarray,
        remove_all_Hs: bool = True,
):
    """Compute per-pose RMSDs between predicted and original ligand positions.

    The original ligand coordinates are read from the complex graph and
    shifted by the graph's ``original_center`` before the comparison.
    NOTE(review): ``predicted_ligand_positions`` is assumed to already be
    expressed in that centred frame -- confirm against the inference code.

    Args:
        orig_complex_graph: Graph providing ``["ligand"]["orig_pos"]``,
            ``["original_center"]``, ``["ligand"]["x"]`` and ``["mol"]``.
        predicted_ligand_positions: Predicted ligand coordinates, one entry
            per pose along axis 0.
        remove_all_Hs: When true, atoms flagged by
            ``filter_positions_to_only_none_Hs`` are dropped from both
            coordinate sets and ``RemoveAllHs`` is applied to the molecule
            before the RMSD is computed.

    Returns:
        The result of ``calculate_rmsds`` for the prepared inputs.
    """
    orig_positions = np.array(orig_complex_graph["ligand"]["orig_pos"])
    orig_center = orig_complex_graph["original_center"].numpy()
    feature_vec_x = orig_complex_graph["ligand"]["x"].numpy()
    # Only the first stored molecule is used.
    mol = orig_complex_graph["mol"][0]

    # Translate original molecule into new frame
    orig_positions = orig_positions - orig_center

    if remove_all_Hs:
        # Remove all Hs from molecules when calculating rmsd; the same atom
        # mask is applied to both coordinate sets so atoms stay aligned.
        orig_positions = filter_positions_to_only_none_Hs(
            orig_positions,
            feature_vec_x,
        )
        predicted_ligand_positions = filter_positions_to_only_none_Hs(
            predicted_ligand_positions,
            feature_vec_x,
        )
        mol = RemoveAllHs(mol)

    return calculate_rmsds(
        orig_positions,
        predicted_ligand_positions,
        mol
    )

In [None]:
def get_rmsd_data_from_path(
    complex_pkl_path,
):
    """Load one pickled inference result and tabulate the RMSD of each sample.

    Args:
        complex_pkl_path: Path to a pickle holding a mapping with the keys
            ``"orig_complex_graph"`` and ``"predicted_ligand_pos"``.

    Returns:
        DataFrame with columns ``sim_idx`` (0..n-1) and ``rmsd``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only open trusted files.
    with open(complex_pkl_path, "rb") as pkl_file:
        payload = pickle.load(pkl_file)

    per_sample_rmsd = calculate_rmsds_on_pred_output(
        payload["orig_complex_graph"],
        payload["predicted_ligand_pos"],
        remove_all_Hs=True,
    )

    table = {
        "sim_idx": np.arange(len(per_sample_rmsd)),
        "rmsd": per_sample_rmsd,
    }
    return pd.DataFrame(table)

def get_rmsd_df_for_complexes(
    complexes_dir,
    all_complexes,
):
    """Collect the per-sample RMSD tables of several complexes into one frame.

    Args:
        complexes_dir: Directory object (must provide ``joinpath``) containing
            one ``<complex_name>.pkl`` file per complex.
        all_complexes: Iterable of complex names to load.

    Returns:
        DataFrame indexed by ``(complex_name, sim_idx)``.
    """
    per_complex_frames = [
        get_rmsd_data_from_path(
            complexes_dir.joinpath(f"{name}.pkl")
        ).assign(complex_name=name)
        for name in all_complexes
    ]
    stacked = pd.concat(per_complex_frames, ignore_index=True)
    return stacked.set_index(["complex_name", "sim_idx"])

def _descending_rank(scores):
    """Return each element's 0-based rank when sorted from highest to lowest."""
    order = np.argsort(-scores, kind="stable")
    # argsort yields the sorting permutation; invert it so that position i
    # holds the rank of element i (ties keep their original order).
    ranks = np.empty(len(order), dtype=order.dtype)
    ranks[order] = np.arange(len(order))
    return ranks


def get_confidence_pred_from_path(
    complex_pkl_path,
):
    """Load one pickled inference result and tabulate its confidence scores.

    Two scores are derived from the 2-D ``"confidence"`` array stored in the
    pickle: the raw value in column 0 ("first logit") and column 0 of the
    row-wise softmax. For each score the table holds the value, a boolean
    prediction (logit > 0, softmax > 0.5) and the sample's rank, where rank 0
    is the highest-scoring sample.

    Args:
        complex_pkl_path: Path to a pickle holding a mapping with the key
            ``"confidence"``.

    Returns:
        DataFrame with one row per sample and the columns ``sim_idx``,
        ``first_logit_conf``, ``first_logit_pred``, ``first_logit_rank``,
        ``softmax_conf``, ``softmax_pred`` and ``softmax_rank``.
    """
    # Imported locally: a bare ``import scipy`` does not guarantee that the
    # ``scipy.special`` submodule is loaded on every SciPy version.
    from scipy.special import softmax

    # NOTE: pickle.load can execute arbitrary code -- only open trusted files.
    with open(complex_pkl_path, "rb") as f:
        complex_data = pickle.load(f)
    confidence = np.asarray(complex_data["confidence"])

    first_logit_conf = confidence[:, 0]
    first_logit_pred = first_logit_conf > 0
    first_logit_rank = _descending_rank(first_logit_conf)

    softmax_conf = softmax(confidence, axis=1)[:, 0]
    softmax_pred = softmax_conf > 0.5
    softmax_rank = _descending_rank(softmax_conf)

    return pd.DataFrame(
        {
            "sim_idx": np.arange(len(first_logit_conf)),
            "first_logit_conf": first_logit_conf,
            "first_logit_pred": first_logit_pred,
            "first_logit_rank": first_logit_rank,
            "softmax_conf": softmax_conf,
            "softmax_pred": softmax_pred,
            "softmax_rank": softmax_rank,
        }
    )

def get_confidence_df_for_complexes(
    complexes_dir,
    all_complexes,
):
    """Collect the confidence tables of several complexes into one frame.

    Args:
        complexes_dir: Directory object (must provide ``joinpath``) containing
            one ``<complex_name>.pkl`` file per complex.
        all_complexes: Iterable of complex names to load.

    Returns:
        DataFrame indexed by ``(complex_name, sim_idx)``.
    """
    per_complex_frames = [
        get_confidence_pred_from_path(
            complexes_dir.joinpath(f"{name}.pkl")
        ).assign(complex_name=name)
        for name in all_complexes
    ]
    stacked = pd.concat(per_complex_frames, ignore_index=True)
    return stacked.set_index(["complex_name", "sim_idx"])

In [None]:
# Output folder of the inference run under analysis and its per-complex pickles.
inference_dir = Path("workdir/out/diffdock_inference/2024-03-22T175642")
complexes_dir = inference_dir / "complexes_out"

# Complexes to evaluate; each needs a "<name>.pkl" file in complexes_dir.
all_complexes = ["6qqw", "6d08"]

# Per-sample RMSD, plus a boolean label for samples with RMSD below 2.
rmsd_df = get_rmsd_df_for_complexes(complexes_dir=complexes_dir, all_complexes=all_complexes)
rmsd_df["rmsd<2"] = rmsd_df["rmsd"].lt(2)

# Per-sample confidence scores, indexed the same way as rmsd_df.
confidence_df = get_confidence_df_for_complexes(complexes_dir=complexes_dir, all_complexes=all_complexes)

In [None]:
# Fraction of samples whose first-logit prediction matches the "rmsd<2" label.
(confidence_df["first_logit_pred"] == rmsd_df["rmsd<2"]).mean()

In [None]:
# Fraction of samples whose softmax prediction matches the "rmsd<2" label.
(confidence_df["softmax_pred"] == rmsd_df["rmsd<2"]).mean()

In [None]:
# Side-by-side view of RMSD and confidence columns, joined on the shared
# (complex_name, sim_idx) index.
rmsd_df.join(confidence_df)