In [None]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import os
import numpy as np
import pandas as pd
from rdkit import Chem
from utils.docking import cal_docking
from utils.metric import *
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
from rdkit.Chem.Descriptors import MolLogP, qed
from utils.sascorer import compute_sa_score
from utils.misc import *
import argparse
from multiprocessing import Pool, set_start_method
import matplotlib.pyplot as plt
from tabulate import tabulate

In [None]:
# Load configs
# `config` is used as a notebook-global later on (gen_results_file reads
# config.dataset.path). load_config / seed_all come from the star-imported
# utils modules -- exact origin not visible here.
config = load_config('./configs/sample.yml')
# Presumably seeds every RNG with the configured seed so runs are reproducible
# -- TODO confirm against utils.misc.seed_all.
seed_all(config.sample.seed)

In [29]:
def cal_metrics(mol_list, pdb_path, save_path):
    """Score a batch of molecules and append the results to ``save_path + '.csv'``.

    Molecules that are ``None`` (e.g. ``Chem.MolFromMolFile`` failed upstream),
    that RDKit cannot write as SMILES, or whose SMILES cannot be parsed back are
    dropped. For every remaining molecule the function records QED, SA score,
    ``lipinski`` and a docking score from ``cal_docking``. Docking runs against
    ``pdb_path`` after adding hydrogens and a UFF optimisation.

    Args:
        mol_list: iterable of RDKit molecules; entries may be ``None``.
        pdb_path: receptor file passed to ``cal_docking``.
        save_path: output path *without* the ``.csv`` suffix.

    The CSV is opened in append mode. The header row is written only when the
    file does not exist yet (or is empty), so repeated calls do not inject
    header lines into the data. A failed docking run is recorded as NaN.
    """
    rdmol_list, smiles_list = [], []
    for mol in mol_list:
        if mol is None:
            continue
        try:
            smiles = Chem.MolToSmiles(mol)
        except Exception:
            continue
        # MolFromSmiles signals a parse failure by returning None, not by raising,
        # so the result has to be checked explicitly.
        if Chem.MolFromSmiles(smiles) is None:
            continue
        rdmol_list.append(mol)
        smiles_list.append(smiles)

    qed_list, sa_list, Lip_list, qvina_list = [], [], [], []
    for mol in rdmol_list:
        qed_list.append(qed(mol))
        sa_list.append(compute_sa_score(mol))
        Lip_list.append(lipinski(mol))

        try:
            ligand_rdmol = Chem.AddHs(mol, addCoords=True)
            UFFOptimizeMolecule(ligand_rdmol)
            qvina_list.append(cal_docking(ligand_rdmol, pdb_path))
        except Exception:
            # Docking is best-effort: keep the molecule, but record a numeric NaN
            # so the Qvina column stays numeric.
            qvina_list.append(float('nan'))

    df = pd.DataFrame({'SMILES': smiles_list, 'Qvina': qvina_list, 'QED': qed_list,
                       'SA': sa_list, 'Lip': Lip_list})
    csv_path = save_path + '.csv'
    # Append mode lets several calls accumulate into one file; only the first write
    # may emit the header, otherwise header rows end up mixed into the data.
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    df.to_csv(csv_path, index=False, mode='a', header=write_header)

def gen_results_file(path, save_result_path):
    """Score every generated-molecule folder under ``path``.

    Each sub-directory of ``path`` is expected to hold:

    - the generated ``.sdf`` files;
    - a ``pocket_info.txt`` whose first line names the receptor file, relative
      to ``config.dataset.path``. ``config`` is the notebook-global loaded in the
      config cell.

    Results for sub-directory ``d`` are appended to ``<save_result_path>/<d>.csv``
    by ``cal_metrics``. Sub-directories without ``pocket_info.txt`` are skipped
    with a warning instead of reusing the previous folder's receptor. Entries are
    processed in sorted order so the CSV row order is deterministic.
    """
    os.makedirs(save_result_path, exist_ok=True)
    for result_dir in sorted(os.listdir(path)):
        dir_path = os.path.join(path, result_dir)
        if not os.path.isdir(dir_path):
            continue

        # Reset per folder so a previous target's receptor is never reused.
        pdb_path = None
        mol_list = []
        for file_name in sorted(os.listdir(dir_path)):
            file_path = os.path.join(dir_path, file_name)
            if file_name == 'pocket_info.txt':
                with open(file_path, 'r') as file:
                    # strip() removes the trailing newline, which would otherwise
                    # become part of the receptor file name.
                    pdb_name = file.readline().strip()
                pdb_path = os.path.join(config.dataset.path, pdb_name)
            elif file_name.endswith('.sdf'):
                mol_list.append(Chem.MolFromMolFile(file_path))

        if pdb_path is None:
            warnings.warn(f"No pocket_info.txt in {dir_path!r}; skipping.")
            continue

        cal_metrics(mol_list, pdb_path, os.path.join(save_result_path, result_dir))

def gen_results_file_resgen(path, save_result_path):
    """Score every ResGen output folder under ``path``.

    ResGen folders are expected to contain the receptor ``.pdb`` file directly,
    and the generated molecules as files in an ``SDF/`` sub-directory. Results
    for folder ``d`` are appended to ``<save_result_path>/<d>.csv`` by
    ``cal_metrics``.

    Folders without a ``.pdb`` file are skipped with a warning instead of
    reusing the previous folder's receptor. When a folder holds several ``.pdb``
    files, the last one in sorted order is used.
    """
    os.makedirs(save_result_path, exist_ok=True)
    for result_dir in sorted(os.listdir(path)):
        dir_path = os.path.join(path, result_dir)
        if not os.path.isdir(dir_path):
            continue

        pdb_path = None
        for file_name in sorted(os.listdir(dir_path)):
            if file_name.endswith('.pdb'):
                pdb_path = os.path.join(dir_path, file_name)
        if pdb_path is None:
            warnings.warn(f"No .pdb receptor file in {dir_path!r}; skipping.")
            continue

        mol_list = []
        sdf_dir = os.path.join(dir_path, 'SDF')
        if os.path.isdir(sdf_dir):
            sdf_names = sorted(os.listdir(sdf_dir))
            # NOTE(review): folders with a single SDF entry are skipped, exactly as
            # before -- confirm whether `> 1` (rather than `> 0`) is intended.
            if len(sdf_names) > 1:
                mol_list = [Chem.MolFromMolFile(os.path.join(sdf_dir, sdf_name))
                            for sdf_name in sdf_names]

        cal_metrics(mol_list, pdb_path, os.path.join(save_result_path, result_dir))



In [30]:
def cal_high_affinity(results_path, top_num, test_path='./baselines_results/testset/'):
    """Compute High Affinity (HA) and MPBG of generated molecules against the test set.

    Pairing and inputs:

    - Result CSVs and test-set CSVs are paired when the part of their file name
      before the first ``'-'`` matches.
    - For each pair, the reference score is the mean ``Qvina`` of the test-set
      file.
    - The generated molecules are the first 100 rows of the result file, sorted
      by ``Qvina`` (ascending, NaN last); the ``top_num`` best are kept.

    Metrics:

    - HA   = (# top molecules with Qvina < reference) / top_num
    - MPBG = mean over the non-NaN top molecules of (reference - Qvina) / reference

    Args:
        results_path: directory with per-target result CSVs (``Qvina`` column).
        top_num: number of best-scoring molecules per target to evaluate.
        test_path: directory with per-target test-set CSVs (``Qvina`` column).

    Returns:
        ``(HA mean, HA std, MPBG mean, MPBG std)`` across matched targets.
        A target whose top molecules all lack a docking score counts as 0 hits
        for HA, and is left out of MPBG (previously a ZeroDivisionError).
    """
    high_affinity_result = []
    mpbg_list = []
    test_files = sorted(f for f in os.listdir(test_path) if f.endswith('.csv'))
    result_files = sorted(f for f in os.listdir(results_path) if f.endswith('.csv'))

    for test_name in test_files:
        target = test_name.split('-')[0]
        matches = [r for r in result_files if r.split('-')[0] == target]
        if not matches:
            continue

        # Coerce to numeric so stray strings / repeated header rows become NaN
        # instead of breaking the sort, as read_and_process_file does.
        test_qvina = pd.to_numeric(
            pd.read_csv(os.path.join(test_path, test_name))['Qvina'], errors='coerce')
        # NOTE(review): the original code named this `min_test_value` but used the
        # mean; the mean is kept so reported numbers are unchanged -- confirm.
        ref_value = test_qvina.mean()

        for results_name in matches:
            gen_qvina = pd.to_numeric(
                pd.read_csv(os.path.join(results_path, results_name))['Qvina'],
                errors='coerce')
            top_qvina = gen_qvina.head(100).sort_values().head(top_num)

            # MPBG: relative gap to the reference, averaged over docked molecules only.
            gaps = [(ref_value - q) / ref_value for q in top_qvina.values if not np.isnan(q)]
            if gaps:
                mpbg_list.append(sum(gaps) / len(gaps))

            # HA: NaN never compares as lower, so undocked molecules count as misses.
            high_affinity = (top_qvina < ref_value).sum()
            high_affinity_result.append(high_affinity / top_num)

    return np.mean(high_affinity_result), np.std(high_affinity_result), np.mean(mpbg_list), np.std(mpbg_list)

def read_and_process_file(file_path, nrows=100):
    """Read one result CSV, coerce its metric columns to numbers and sort by Qvina.

    Args:
        file_path: path to a CSV as written by ``cal_metrics``.
        nrows: number of leading rows to read (default 100, as before).

    Returns:
        A new DataFrame sorted by ``Qvina`` ascending (NaN last). Every column
        except ``SMILES`` is numeric; unparsable cells (e.g. repeated header rows
        from appended writes) become NaN.
    """
    df = pd.read_csv(file_path, nrows=nrows)
    # Assign column by column: `df.iloc[:, mask] = ...` sets values in place on
    # pandas >= 2 and would leave object-dtype columns as object.
    for col in df.columns[~df.columns.isin(['SMILES'])]:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df.sort_values(by='Qvina')

def get_top_n_dfs(df, top_n_list):
    """Map each N in ``top_n_list`` to the first N rows of ``df``.

    ``df`` is expected to be sorted already, so its leading rows are the best
    ones. A repeated N keeps a single entry.
    """
    top_dfs = {}
    for n in top_n_list:
        top_dfs[n] = df.head(n)
    return top_dfs

def calculate_statistics(path, top_dfs, top_nums):
    """Print a Top-N summary table (mean ± std) for the pooled result frames.

    Args:
        path: results directory, forwarded to ``cal_high_affinity`` for the HA
            and MPBG columns (compared against its default test-set path).
        top_dfs: mapping ``N -> DataFrame`` of pooled Top-N rows, with columns
            ``SMILES``, ``Qvina``, ``QED``, ``SA`` and ``Lip``.
        top_nums: the N values to report, one table row each, in order.

    Means are rounded to 3 decimals and standard deviations to 2. The MPBG
    standard deviation is returned by ``cal_high_affinity`` but shown as
    ``N/A``.
    """
    all_stats = []

    for top_num in top_nums:
        # Aggregate only the metric columns. Dropping SMILES explicitly (rather than
        # relying on pandas to skip "nuisance" string columns, which pandas >= 2.0
        # no longer does) keeps mean()/std() from raising TypeError. Coercion turns
        # any leftover non-numeric cells into NaN.
        metrics = (
            top_dfs[top_num]
            .drop(columns=['SMILES'], errors='ignore')
            .apply(pd.to_numeric, errors='coerce')
        )
        top_mean = metrics.mean().round(3)
        top_std = metrics.std().round(2)

        high_affinity_mean, high_affinity_std, mpbg_mean, _ = cal_high_affinity(path, top_num)

        stats = {
            'Top': f'Top {top_num}',
            'Qvina': f"{top_mean['Qvina']} ± {top_std['Qvina']}",
            'HA': f"{high_affinity_mean.round(3)} ± {high_affinity_std.round(2)}",
            'MPBG': f"{mpbg_mean.round(3)} ± N/A",
            'QED': f"{top_mean['QED']} ± {top_std['QED']}",
            'SA': f"{top_mean['SA']} ± {top_std['SA']}",
            'Lipinski': f"{top_mean['Lip']} ± {top_std['Lip']}",
        }

        all_stats.append(stats)

    stats_df = pd.DataFrame(all_stats).set_index('Top')
    print(tabulate(stats_df, headers='keys', tablefmt='pretty', showindex=True))
    
def cal_result(path, top_n_list=(1, 3, 5, 10, 100)):
    """Print the Top-N evaluation table for all result CSVs in ``path``.

    Each ``.csv`` file (one per target) is loaded and sorted by
    ``read_and_process_file``. Its best N rows are taken for every N in
    ``top_n_list``. The Top-N rows of all targets are then pooled and summarised
    by ``calculate_statistics``.

    Args:
        path: directory containing per-target result CSVs.
        top_n_list: N values to report (default ``(1, 3, 5, 10, 100)``).

    Raises:
        ValueError: if ``path`` contains no ``.csv`` files.
    """
    top_n_list = list(top_n_list)
    result_files = sorted(f for f in os.listdir(path) if f.endswith('.csv'))
    if not result_files:
        raise ValueError(f"No .csv result files found in {path!r}")

    per_file_tops = {n: [] for n in top_n_list}
    for file_name in result_files:
        df = read_and_process_file(os.path.join(path, file_name))
        for n, top_df in get_top_n_dfs(df, top_n_list).items():
            per_file_tops[n].append(top_df)

    # Pool the Top-N rows of every target into one frame per N.
    top_dfs = {n: pd.concat(dfs) for n, dfs in per_file_tops.items()}
    calculate_statistics(path, top_dfs, top_n_list)



## Ours

In [31]:
# Evaluate the ckpt_200 results of our model (directory name kept exactly as on disk).
results_path = 'results/new/RL_np_reward_recovvery_ligand&protein/ckpt_200_results/'
cal_result(path=results_path)

+---------+----------------+--------------+--------------+--------------+--------------+--------------+
|   Top   |     Qvina      |      HA      |     MPBG     |     QED      |      SA      |   Lipinski   |
+---------+----------------+--------------+--------------+--------------+--------------+--------------+
|  Top 1  | -10.201 ± 1.62 | 0.98 ± 0.14  | -1.519 ± N/A | 0.524 ± 0.15 | 0.514 ± 0.1  | 4.82 ± 0.39  |
|  Top 3  | -9.765 ± 1.51  | 0.973 ± 0.15 | -1.425 ± N/A | 0.538 ± 0.15 | 0.514 ± 0.11 | 4.893 ± 0.31 |
|  Top 5  | -9.508 ± 1.47  | 0.958 ± 0.17 | -1.366 ± N/A | 0.553 ± 0.15 | 0.512 ± 0.1  | 4.91 ± 0.29  |
| Top 10  | -9.102 ± 1.42  | 0.928 ± 0.22 | -1.273 ± N/A | 0.56 ± 0.14  | 0.515 ± 0.11 | 4.932 ± 0.25 |
| Top 100 |  -6.364 ± 2.3  | 0.492 ± 0.3  | -0.625 ± N/A | 0.543 ± 0.14 | 0.574 ± 0.14 | 4.976 ± 0.16 |
+---------+----------------+--------------+--------------+--------------+--------------+--------------+


## FLAG

In [12]:
# Evaluate the FLAG baseline from its result CSVs.
results_path = 'baselines_results/FLAG/'
cal_result(path=results_path)

+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|   Top   |     Qvina     |      HA      |     MPBG     |     QED      |      SA      |   Lipinski   |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|  Top 1  | -9.837 ± 1.6  |  0.96 ± 0.2  | -1.445 ± N/A | 0.436 ± 0.17 | 0.397 ± 0.13 | 4.28 ± 0.94  |
|  Top 3  | -9.414 ± 1.48 | 0.93 ± 0.23  | -1.345 ± N/A | 0.458 ± 0.17 | 0.406 ± 0.12 | 4.377 ± 0.83 |
|  Top 5  | -9.159 ± 1.42 | 0.908 ± 0.26 | -1.288 ± N/A | 0.468 ± 0.16 | 0.406 ± 0.12 | 4.432 ± 0.81 |
| Top 10  | -8.756 ± 1.37 | 0.873 ± 0.29 | -1.186 ± N/A | 0.476 ± 0.16 | 0.415 ± 0.12 | 4.493 ± 0.76 |
| Top 100 | -5.865 ± 3.59 | 0.46 ± 0.31  | -0.504 ± N/A | 0.495 ± 0.14 | 0.485 ± 0.14 | 4.794 ± 0.54 |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+


## ResGen

In [11]:
# Evaluate the ResGen baseline from its result CSVs.
results_path = 'baselines_results/ResGen/'
cal_result(path=results_path)

+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|   Top   |     Qvina     |      HA      |     MPBG     |     QED      |      SA      |   Lipinski   |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|  Top 1  | -8.693 ± 2.3  | 0.91 ± 0.29  | -0.956 ± N/A | 0.567 ± 0.15 | 0.796 ± 0.12 | 4.94 ± 0.28  |
|  Top 3  | -8.429 ± 2.25 | 0.883 ± 0.31 | -0.886 ± N/A | 0.566 ± 0.16 | 0.804 ± 0.12 | 4.913 ± 0.33 |
|  Top 5  | -8.292 ± 2.21 | 0.856 ± 0.32 | -0.85 ± N/A  | 0.571 ± 0.16 | 0.802 ± 0.11 | 4.912 ± 0.34 |
| Top 10  | -8.083 ± 2.16 | 0.824 ± 0.35 |  -0.8 ± N/A  | 0.579 ± 0.16 |  0.8 ± 0.11  | 4.908 ± 0.34 |
| Top 100 | -6.591 ± 2.3  | 0.486 ± 0.36 | -0.492 ± N/A | 0.586 ± 0.16 | 0.794 ± 0.11 | 4.874 ± 0.42 |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+


## DecompDiff

In [14]:
# Evaluate the DecompDiff baseline from its result CSVs.
results_path = 'baselines_results/DecompDiff/'
cal_result(path=results_path)

+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|   Top   |     Qvina     |      HA      |     MPBG     |     QED      |      SA      |   Lipinski   |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|  Top 1  | -8.485 ± 1.99 | 0.95 ± 0.22  | -0.994 ± N/A | 0.552 ± 0.19 | 0.656 ± 0.15 | 4.51 ± 0.82  |
|  Top 3  | -8.116 ± 2.03 | 0.92 ± 0.25  | -0.778 ± N/A | 0.544 ± 0.2  | 0.66 ± 0.15  | 4.517 ± 0.85 |
|  Top 5  | -7.946 ± 2.01 | 0.906 ± 0.27 | -0.718 ± N/A | 0.541 ± 0.2  | 0.662 ± 0.15 | 4.492 ± 0.89 |
| Top 10  | -7.715 ± 1.97 | 0.864 ± 0.3  | -0.654 ± N/A | 0.539 ± 0.21 | 0.666 ± 0.15 | 4.435 ± 0.96 |
| Top 100 | -6.495 ± 1.81 | 0.458 ± 0.34 | -0.424 ± N/A | 0.495 ± 0.22 | 0.653 ± 0.14 | 4.219 ± 1.1  |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+


## TargetDiff

In [16]:
# Evaluate the TargetDiff baseline from its result CSVs.
results_path = 'baselines_results/TargetDiff/'
cal_result(path=results_path)

+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|   Top   |     Qvina     |      HA      |     MPBG     |     QED      |      SA      |   Lipinski   |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|  Top 1  | -9.278 ± 1.75 | 0.94 ± 0.24  | -1.221 ± N/A | 0.426 ± 0.2  | 0.476 ± 0.12 | 4.42 ± 0.82  |
|  Top 3  | -8.892 ± 1.68 | 0.923 ± 0.25 | -1.121 ± N/A | 0.457 ± 0.21 | 0.502 ± 0.12 | 4.453 ± 0.88 |
|  Top 5  | -8.686 ± 1.65 | 0.908 ± 0.27 | -1.072 ± N/A | 0.453 ± 0.21 | 0.511 ± 0.12 | 4.446 ± 0.87 |
| Top 10  | -8.37 ± 1.61  | 0.885 ± 0.29 | -0.998 ± N/A | 0.466 ± 0.2  | 0.523 ± 0.12 | 4.519 ± 0.79 |
| Top 100 | -6.461 ± 1.93 | 0.477 ± 0.34 | -0.521 ± N/A | 0.455 ± 0.19 | 0.581 ± 0.13 | 4.488 ± 0.79 |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+


## Pocket2Mol

In [18]:
# Evaluate the Pocket2Mol baseline from its result CSVs.
results_path = 'baselines_results/Pocket2Mol/'
cal_result(path=results_path)

+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|   Top   |     Qvina     |      HA      |     MPBG     |     QED      |      SA      |   Lipinski   |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+
|  Top 1  | -8.785 ± 2.88 | 0.88 ± 0.32  | -0.95 ± N/A  | 0.517 ± 0.14 | 0.778 ± 0.13 | 4.95 ± 0.22  |
|  Top 3  | -8.525 ± 2.9  | 0.827 ± 0.35 | -0.858 ± N/A | 0.525 ± 0.15 | 0.777 ± 0.12 | 4.947 ± 0.24 |
|  Top 5  | -8.395 ± 2.88 | 0.804 ± 0.37 | -0.813 ± N/A | 0.528 ± 0.15 | 0.781 ± 0.12 | 4.946 ± 0.24 |
| Top 10  | -8.19 ± 2.83  | 0.771 ± 0.39 | -0.76 ± N/A  | 0.53 ± 0.15  | 0.78 ± 0.12  | 4.941 ± 0.28 |
| Top 100 | -6.662 ± 2.48 | 0.465 ± 0.36 | -0.452 ± N/A | 0.55 ± 0.14  | 0.786 ± 0.12 | 4.958 ± 0.25 |
+---------+---------------+--------------+--------------+--------------+--------------+--------------+
