In [1]:
import logging
from collections import Counter
from pathlib import Path
from typing import List

import numpy as np
import torch
from rdkit import Chem
from tqdm.auto import tqdm
from utils import misc, reconstruct, transforms
from utils.evaluation import (analyze, eval_atom_type, eval_bond_length,
                              scoring_func)
from utils.evaluation.docking_vina import PrepLig, VinaDock


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load a saved set of generated ligands (6ed6 pocket run) for inspection.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code - only load files you trust.
# NOTE(review): in the visible cells this object is only echoed in the next cell;
# the evaluation configuration cell later rebinds `ligands` to a 1pxx file.
ligands = torch.load("results/ligands/6ed6_15_clean_ligands_xlarge_params_from_paper_context_-16_clip_3e-3.pt", weights_only=False)

In [3]:
# Echo the loaded samples for a quick look at their structure.
# NOTE(review): this prints the entire dict (the saved output starts with a
# 'pos' list of coordinate arrays); a short summary would keep the notebook readable.
ligands

{'pos': [array([[30.52312851, 43.56502151, 55.15967941],
         [24.8726368 , 46.59749222, 51.24454117],
         [28.23488426, 43.60247803, 51.50905609],
         [26.26976204, 49.39509964, 56.22756958],
         [24.35996819, 44.25444794, 51.43669128],
         [28.51873398, 46.76353836, 55.73935318],
         [26.85598564, 50.75447845, 55.98934174],
         [27.04326439, 44.80458069, 53.86606216],
         [26.23090363, 44.01645279, 52.89334106],
         [26.54616165, 48.33630753, 55.24793243],
         [24.79938126, 44.14829254, 52.75789642],
         [23.04829788, 44.31542206, 51.07193375],
         [26.79449272, 46.07875061, 54.21271515],
         [25.36587143, 48.25063705, 54.27602386],
         [25.16547394, 45.34003448, 50.73765564],
         [26.62654877, 45.03405762, 50.69040298],
         [28.11418724, 44.03699493, 54.38352966],
         [24.97289085, 49.43567657, 56.74623871],
         [28.00700569, 42.71325302, 54.23492813],
         [26.82138634, 43.73265076, 51.5523

In [5]:
def evaluate_single_ligand(
    ligand_pdbqt: str,
    protein_pdbqt: str,
    center: List[float],
    box_size: List[float],
    mode: str = "dock",
    exhaustiveness: int = 8,
):
    """Score or dock one ligand against one protein through ``VinaDock``.

    Args:
        ligand_pdbqt: Path to the ligand PDBQT file handed to ``VinaDock``.
        protein_pdbqt: Path to the protein PDBQT file handed to ``VinaDock``.
        center: Value assigned to the wrapper's ``pocket_center`` attribute.
        box_size: Value assigned to the wrapper's ``box_size`` attribute.
        mode: Forwarded to ``VinaDock.dock``; this notebook uses
            ``"score_only"``, ``"minimize"`` and ``"dock"``.
        exhaustiveness: Forwarded unchanged to ``VinaDock.dock``.

    Returns:
        A dict with exactly two keys: ``"affinity"`` and ``"pose"``, holding the
        first and second element returned by ``VinaDock.dock``.
    """
    vina = VinaDock(ligand_pdbqt, protein_pdbqt)
    # The search box is configured through attributes, not constructor arguments.
    vina.pocket_center, vina.box_size = center, box_size

    affinity, docked_pose = vina.dock(
        score_func="vina",
        mode=mode,
        exhaustiveness=exhaustiveness,
        save_pose=True,
    )
    return dict(affinity=affinity, pose=docked_pose)

In [6]:
def print_dict(d, logger):
    """Log every ``key: value`` pair of ``d`` at INFO level, one line per entry.

    Values are formatted with four decimal places; a ``None`` value is logged
    as the literal text ``None``.
    """
    for k, v in d.items():
        message = f'{k}:\tNone' if v is None else f'{k}:\t{v:.4f}'
        logger.info(message)


def print_ring_ratio(all_ring_sizes, logger):
    """Log, for each ring size from 3 to 9, the fraction of molecules containing it.

    Args:
        all_ring_sizes: Sized collection with one container per molecule (the
            notebook passes ``Counter`` objects); a molecule counts for a ring
            size when ``ring_size in container`` is true.
        logger: Object exposing ``info`` and ``warning`` methods.

    When ``all_ring_sizes`` is empty a warning is logged and nothing else is
    reported, instead of failing with ``ZeroDivisionError``.
    """
    n_total = len(all_ring_sizes)
    if n_total == 0:
        # Happens when no ligand survived evaluation; there is no ratio to report.
        logger.warning('ring size ratio skipped: no molecules to analyse')
        return

    for ring_size in range(3, 10):
        n_mol = sum(1 for counter in all_ring_sizes if ring_size in counter)
        logger.info(f'ring size: {ring_size} ratio: {n_mol / n_total:.3f}')

In [39]:
# NOTE(review): disabled scratch cell kept for reference. It built SMILES for the
# 1pxx ligands, dropped fragmented ones (SMILES containing "."), and wrote the
# rest to CSV. Re-enabling it requires `compose_smiles` and `pd`, neither of
# which is defined or imported in the visible part of this notebook.
# ligands = torch.load("results/ligands/1pxx_10_ligands_large.pt", weights_only=False)
# smiles_list = []
# for (pos, v) in zip(ligands["pos"], ligands["v"]):
#     try:
#         smiles_list.append(compose_smiles({"pos": pos, "v": v}))
#     except Exception as e:
#         print(e)
#         continue
# # for smiles, pos in zip(smiles_list, ligands["pos"]):
#     # print("." in smiles, len(pos), smiles)
# smiles_filtered = [smiles for smiles in smiles_list if "." not in smiles]
# pd.DataFrame({"smiles": smiles_filtered}).to_csv("results/smiles/1pxx_10_smiles_large.csv")

In [12]:
# --- Evaluation configuration (1pxx pocket) ----------------------------------
log_dir = "results/evaluation/1pxx"
# Create the log directory up front in case misc.get_logger does not do so
# (its implementation is not visible in this notebook).
Path(log_dir).mkdir(parents=True, exist_ok=True)

# Re-running this cell previously produced every log line several times (see the
# triplicated messages in the saved output of the evaluation cell).
# Assumption to confirm: misc.get_logger attaches fresh handlers to
# logging.getLogger('evaluate') on every call. Detach and close any handlers
# left over from earlier runs before asking for the logger again.
_stale_logger = logging.getLogger('evaluate')
for _handler in list(_stale_logger.handlers):
    _stale_logger.removeHandler(_handler)
    _handler.close()
logger = misc.get_logger('evaluate', log_dir=log_dir)

# Docking settings. `docking_mode` and `verbose` are not read by the visible
# evaluation cell; they are kept so that other cells relying on them still work.
docking_mode = "dock"
exhaustiveness = 128
box_size = [10, 10, 10]
center = [27.116, 24.090, 14.936]

# NOTE(review): machine-specific absolute paths; adjust when running elsewhere.
# protein_root = '/mnt/5tb/tsypin/EyeDrops/BADGER-SBDD/data/crossdocked'
protein_pdbqt = Path("/mnt/5tb/tsypin/EyeDrops/BADGER-SBDD/pockets/1pxx_clean.pdbqt").expanduser().resolve().as_posix()
ligand_path = "/mnt/5tb/tsypin/EyeDrops/BADGER-SBDD/results/ligands_sdf/1pxx_10_large"
# The evaluation loop writes one SDF/PDBQT pair per ligand into this directory,
# so make sure it exists before the first write.
Path(ligand_path).mkdir(parents=True, exist_ok=True)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code - only load files you trust.
ligands = torch.load("results/ligands/1pxx_10_ligands_large.pt", weights_only=False)
verbose = True

# Accumulators filled by the evaluation loop.
num_samples = len(ligands["pos"])
all_mol_stable, all_atom_stable, all_n_atom = 0, 0, 0
n_recon_success, n_eval_success, n_complete = 0, 0, 0
results = []
all_pair_dist, all_bond_dist = [], []
all_atom_types = Counter()
success_pair_dist, success_atom_types = [], Counter()

In [13]:
# Reset every accumulator at the top of this cell as well, so that re-running
# the cell on its own starts from zero instead of adding to a previous run.
num_samples = len(ligands["pos"])
all_mol_stable, all_atom_stable, all_n_atom = 0, 0, 0
n_recon_success, n_eval_success, n_complete = 0, 0, 0
results = []
all_pair_dist, all_bond_dist = [], []
all_atom_types = Counter()
success_pair_dist, success_atom_types = [], Counter()

# Vina protocols evaluated for each complete molecule, in execution order.
vina_modes = ("score_only", "minimize", "dock")

for i in tqdm(range(num_samples), desc="Evaluating ligands"):
    pred_pos, pred_v = ligands["pos"][i], ligands["v"][i]
    pred_atom_type = transforms.get_atomic_number_from_index(pred_v, mode="add_aromatic")
    all_atom_types += Counter(pred_atom_type)

    # check_stability returns a 3-item result that is accumulated into the
    # molecule-stable, atom-stable and atom-count totals respectively.
    r_stable = analyze.check_stability(pred_pos, pred_atom_type)
    all_mol_stable += r_stable[0]
    all_atom_stable += r_stable[1]
    all_n_atom += r_stable[2]

    pair_dist = eval_bond_length.pair_distance_from_pos_v(pred_pos, pred_atom_type)
    all_pair_dist += pair_dist

    # reconstruction
    try:
        pred_aromatic = transforms.is_aromatic_from_index(pred_v, mode="add_aromatic")
        mol = reconstruct.reconstruct_from_generated(pred_pos, pred_atom_type, pred_aromatic)
        smiles = Chem.MolToSmiles(mol)
    except reconstruct.MolReconsError:
        logger.warning(f'Reconstruct failed {i}')
        continue
    n_recon_success += 1

    # A '.' in the SMILES means the molecule is split into several fragments;
    # only single-fragment ("complete") molecules are docked.
    if '.' in smiles:
        continue
    n_complete += 1

    # convert ligand to pdbqt
    ligand_rdmol = Chem.AddHs(mol, addCoords=True)
    cur_ligand_path = ligand_path + f"/ligand_{i}.sdf"
    sdf_writer = Chem.SDWriter(cur_ligand_path)
    try:
        sdf_writer.write(ligand_rdmol)
    finally:
        # Always release the file handle, even if writing the molecule fails.
        sdf_writer.close()

    ligand_pdbqt = cur_ligand_path[:-4] + ".pdbqt"
    lig = PrepLig(cur_ligand_path, "sdf")
    lig.get_pdbqt(ligand_pdbqt)

    try:
        chem_results = scoring_func.get_chem(mol)
        # One evaluate_single_ligand call per protocol, keyed by mode name.
        vina_results = {
            vina_mode: evaluate_single_ligand(
                ligand_pdbqt=ligand_pdbqt,
                protein_pdbqt=protein_pdbqt,
                center=center,
                box_size=box_size,
                mode=vina_mode,
                exhaustiveness=exhaustiveness,
            )
            for vina_mode in vina_modes
        }
        n_eval_success += 1
    except Exception as e:
        # This handler also catches docking failures, so include the actual
        # error in the log instead of discarding it.
        logger.warning(f'Chemistry check failed {i}: {e!r}')
        continue


#     if docking_mode == 'qvina':
#         vina_task = QVinaDockingTask.from_generated_mol(
#             mol, r['data'].ligand_filename, protein_root=protein_root)
#         vina_results = vina_task.run_sync()

    # now we only consider complete molecules as success
    bond_dist = eval_bond_length.bond_distance_from_mol(mol)
    all_bond_dist += bond_dist
    success_pair_dist += pair_dist
    success_atom_types += Counter(pred_atom_type)

    results.append({
        'mol': mol,
        'smiles': smiles,
        'ligand_filename': cur_ligand_path,
        'pred_pos': pred_pos,
        'chem_results': chem_results,
        'pred_v': pred_v,
        'vina': vina_results
    })
logger.info(f'Evaluate done! {num_samples} samples in total.')


# max(..., 1) keeps the ratios defined (as 0.0) when no sample or atom was counted.
sample_denominator = max(num_samples, 1)
fraction_mol_stable = all_mol_stable / sample_denominator
fraction_atm_stable = all_atom_stable / max(all_n_atom, 1)
fraction_recon = n_recon_success / sample_denominator
fraction_eval = n_eval_success / sample_denominator
fraction_complete = n_complete / sample_denominator
validity_dict = {
    'mol_stable': fraction_mol_stable,
    'atm_stable': fraction_atm_stable,
    'recon_success': fraction_recon,
    'eval_success': fraction_eval,
    'complete': fraction_complete
}
print_dict(validity_dict, logger)

c_bond_length_profile = eval_bond_length.get_bond_length_profile(all_bond_dist)
c_bond_length_dict = eval_bond_length.eval_bond_length_profile(c_bond_length_profile)
logger.info('JS bond distances of complete mols: ')
print_dict(c_bond_length_dict, logger)

success_pair_length_profile = eval_bond_length.get_pair_length_profile(success_pair_dist)
success_js_metrics = eval_bond_length.eval_pair_length_profile(success_pair_length_profile)
print_dict(success_js_metrics, logger)

atom_type_js = eval_atom_type.eval_atom_type_distribution(success_atom_types)
logger.info('Atom type JS: %.4f' % atom_type_js)

# if args.save:
#     eval_bond_length.plot_distance_hist(success_pair_length_profile,
#                                         metrics=success_js_metrics,
#                                         save_path=os.path.join(result_path, f'pair_dist_hist_{args.eval_step}.png'))

logger.info('Number of reconstructed mols: %d, complete mols: %d, evaluated mols: %d' % (
    n_recon_success, n_complete, len(results)))

if results:
    qed = [r['chem_results']['qed'] for r in results]
    sa = [r['chem_results']['sa'] for r in results]
    logger.info('QED:   Mean: %.3f Median: %.3f' % (np.mean(qed), np.median(qed)))
    logger.info('SA:    Mean: %.3f Median: %.3f' % (np.mean(sa), np.median(sa)))
    # if args.docking_mode == 'qvina':
    #     vina = [r['vina'][0]['affinity'] for r in results]
    #     logger.info('Vina:  Mean: %.3f Median: %.3f' % (np.mean(vina), np.median(vina)))
    vina_score_only = [r['vina']['score_only']['affinity'] for r in results]
    vina_min = [r['vina']['minimize']['affinity'] for r in results]
    logger.info('Vina Score:  Mean: %.3f Median: %.3f' % (np.mean(vina_score_only), np.median(vina_score_only)))
    logger.info('Vina Min  :  Mean: %.3f Median: %.3f' % (np.mean(vina_min), np.median(vina_min)))
    vina_dock = [r['vina']['dock']['affinity'] for r in results]
    logger.info('Vina Dock :  Mean: %.3f Median: %.3f' % (np.mean(vina_dock), np.median(vina_dock)))

    # check ring distribution
    print_ring_ratio([r['chem_results']['ring_size'] for r in results], logger)
else:
    # Means over an empty list are NaN and the ring ratio would divide by zero,
    # so report the situation explicitly instead.
    logger.warning('No ligand was evaluated successfully; skipping QED/SA/Vina/ring statistics.')


  Failed to kekulize aromatic bonds in OBMol::PerceiveBondOrders

Evaluating ligands:  33%|███▎      | 33/100 [03:22<06:52,  6.15s/it][14:40:18] Explicit valence for atom # 7 F, 2, is greater than permitted
[14:41:09] Explicit valence for atom # 8 Cl, 2, is greater than permitted
[14:41:09] Explicit valence for atom # 1 F, 2, is greater than permitted
Evaluating ligands: 100%|██████████| 100/100 [09:48<00:00,  5.88s/it]
[2025-07-15 14:46:43,894::evaluate::INFO] Evaluate done! 100 samples in total.
[2025-07-15 14:46:43,894::evaluate::INFO] Evaluate done! 100 samples in total.
[2025-07-15 14:46:43,894::evaluate::INFO] Evaluate done! 100 samples in total.
[2025-07-15 14:46:43,896::evaluate::INFO] mol_stable:	0.1900
[2025-07-15 14:46:43,896::evaluate::INFO] mol_stable:	0.1900
[2025-07-15 14:46:43,896::evaluate::INFO] mol_stable:	0.1900
[2025-07-15 14:46:43,897::evaluate::INFO] atm_stable:	0.3333
[2025-07-15 14:46:43,897::evaluate::INFO] atm_stable:	0.3333
[2025-07-15 14:46:43,897::evaluate

In [33]:
# Echo the collected evaluation records.
# NOTE(review): the saved output below uses keys ('ligand', 'affinity',
# 'vina_mode', 'chem', 'pose') that differ from the dicts appended by the
# evaluation loop ('mol', 'smiles', 'ligand_filename', 'pred_pos',
# 'chem_results', 'pred_v', 'vina'), so it appears to come from an earlier
# version of the code - re-run from a fresh kernel to refresh it.
results

[{'ligand': 'ligand_3.pdbqt',
  'smiles': '[C]C1([C])[C][C][C]N[C]1',
  'affinity': np.float64(-5.033),
  'vina_mode': 'dock',
  'chem': {'qed': 0.4501691133696031,
   'sa': 0.55,
   'logp': -0.007660000000000139,
   'lipinski': np.int64(5),
   'ring_size': Counter({6: 1})},
  'pose': 'MODEL 1\nREMARK VINA RESULT:    -5.033      0.000      0.000\nREMARK INTER + INTRA:          -5.033\nREMARK INTER:                  -5.033\nREMARK INTRA:                   0.000\nREMARK UNBOUND:                 0.000\nREMARK Flexibility Score: inf\nREMARK Active torsions [ 0 ] -> [ 0 ]\nROOT\nATOM      1  C1  LIG L   1      23.943  23.582  15.257  1.00  0.00     0.012 C \nATOM      2  C2  LIG L   1      27.496  23.814  15.641  1.00  0.00     0.016 C \nATOM      3  N1  LIG L   1      26.359  22.933  13.790  1.00  0.00    -0.315 NA\nATOM      4  C3  LIG L   1      25.811  22.158  14.844  1.00  0.00     0.087 C \nATOM      5  C4  LIG L   1      26.158  24.533  15.882  1.00  0.00     0.009 C \nATOM      6  C

In [20]:
# Quick check of the loop counters: ligands whose reconstruction succeeded vs.
# those whose SMILES contains no '.' (i.e. a single connected fragment).
n_recon_success, n_complete

(97, 50)