In [11]:
import rdkit
from rdkit import Chem
import pandas as pd
import numpy as np
import os
import glob
from tqdm import tqdm
import sys
sys.path.append('..')
from torch_geometric.loader import DataLoader
import torch

In [12]:
from data_processing.paired_data import CombinedSparseGraphDataset

In [13]:
# Directory holding the raw data; each entry listed here is treated as a
# folder of .sdf files by the globbing cell further down.
raw_data_path = '../../data/cleaned_crossdocked_data/raw'
docked = os.listdir(raw_data_path)
# Number of entries found directly under the raw directory.
len(docked)

2402

In [4]:
# Gather every .sdf file found one level below the raw directory, walking the
# entries of `docked` in the order os.listdir returned them.
raw_files = [
    sdf_path
    for folder in docked
    for sdf_path in glob.glob(os.path.join(raw_data_path, folder, '*.sdf'))
]

len(raw_files)

242360

In [5]:
# Parse every raw .sdf file without sanitization, keeping one entry per file
# so that len(mols) matches len(raw_files).
mols = []
for sdf_path in tqdm(raw_files):
    mols.append(Chem.MolFromMolFile(sdf_path, sanitize=False))
len(mols)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 242360/242360 [01:27<00:00, 2759.51it/s]


242360

In [6]:
def check_p_s_cl(mols):
    """Count how many molecules contain phosphorus, sulfur or chlorine atoms.

    Args:
        mols: Iterable of molecule objects exposing ``GetAtoms()``, whose
            atoms expose ``GetAtomicNum()``. ``None`` entries (what
            ``Chem.MolFromMolFile`` yields for a file it cannot parse) are
            skipped instead of raising ``AttributeError``.

    Returns:
        Tuple ``(n_mols_with_p, n_mols_with_s, n_mols_with_cl, n_all)``,
        where ``n_all`` is the number of molecules containing at least one
        of the three elements. The same four numbers are also printed.
    """
    n_mols_with_p = 0
    n_mols_with_s = 0
    n_mols_with_cl = 0
    n_all = 0
    n_skipped = 0

    for mol in tqdm(mols):
        if mol is None:
            # Unparsable input file: nothing to inspect, so leave it out of
            # every counter rather than crashing the whole scan.
            n_skipped += 1
            continue

        # A set makes each membership test below O(1) and is built once,
        # instead of rescanning a list for every element check.
        atomic_nums = {atom.GetAtomicNum() for atom in mol.GetAtoms()}
        has_p = 15 in atomic_nums   # phosphorus
        has_s = 16 in atomic_nums   # sulfur
        has_cl = 17 in atomic_nums  # chlorine

        if has_p:
            n_mols_with_p += 1
        if has_s:
            n_mols_with_s += 1
        if has_cl:
            n_mols_with_cl += 1
        if has_p or has_s or has_cl:
            n_all += 1

    if n_skipped:
        # Only emitted when something was skipped, so the output for fully
        # parsable inputs is unchanged.
        print(f'skipped {n_skipped} molecule(s) that failed to parse')
    print(n_mols_with_p, n_mols_with_s, n_mols_with_cl, n_all)
    return n_mols_with_p, n_mols_with_s, n_mols_with_cl, n_all

In [7]:
# Tally P / S / Cl occurrence over the molecules parsed from the raw .sdf files.
check_p_s_cl(mols)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 242360/242360 [00:44<00:00, 5500.41it/s]

23945 47718 29103 91547





(23945, 47718, 29103, 91547)

In [14]:
# Directory containing the generated molecule files of one logged run.
gen_res_path = '../lightning_logs/vp_bridge_2024-05-13_00_09_54.680600/vp'
gen_filenames = os.listdir(gen_res_path)

# Turn each bare file name into a path relative to the notebook.
gen_files = []
for name in gen_filenames:
    gen_files.append(os.path.join(gen_res_path, name))
len(gen_files)

1466

In [15]:
# Parse every generated file without sanitization, keeping one entry per file
# so that len(gen_mols) matches len(gen_files).
gen_mols = []
for gen_path in tqdm(gen_files):
    gen_mols.append(Chem.MolFromMolFile(gen_path, sanitize=False))
len(gen_mols)

  0%|                                                                                                                                                                  | 0/1466 [00:00<?, ?it/s][15:13:00] atom 17 has specified valence (2) smaller than the drawn valence 3.
[15:13:00] atom 18 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 32 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 21 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 10 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 3 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 5 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 4 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 14 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 16 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 24 ha

 34%|███████████████████████████████████████████████████▎                                                                                                  | 502/1466 [00:00<00:00, 5018.96it/s][15:13:00] atom 3 has specified valence (5) smaller than the drawn valence 6.
[15:13:00] atom 7 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 18 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 22 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 10 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 15 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 0 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 21 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 20 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 8 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 10 has

 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                            | 1032/1466 [00:00<00:00, 5182.97it/s][15:13:00] atom 4 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 14 has specified valence (2) smaller than the drawn valence 3.
[15:13:00] atom 24 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 27 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 32 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 9 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 15 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 16 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 10 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 11 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 17 

[15:13:00] atom 1 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 11 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 14 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 19 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 12 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 20 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 24 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 10 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 15 has specified valence (1) smaller than the drawn valence 2.
[15:13:00] atom 6 has specified valence (3) smaller than the drawn valence 4.
[15:13:00] atom 7 has specified valence (2) smaller than the drawn valence 3.
[15:13:00] atom 5 has specified valence (4) smaller than the drawn valence 5.
[15:13:00] atom 13 has specified valence (1) smaller tha

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1466/1466 [00:00<00:00, 5199.41it/s]


1466

In [16]:
# Tally P / S / Cl occurrence over the generated molecules.
check_p_s_cl(gen_mols)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1466/1466 [00:00<00:00, 6226.69it/s]

1293 1315 619 1448





(1293, 1315, 619, 1448)

# Check the ratio of P, S, and Cl in the processed dataset

In [34]:
# Dataset root directory and the split to load.
root = '../../data/cleaned_crossdocked_data'
split = 'train'
dataset = CombinedSparseGraphDataset(root, split)
# batch_size=1 and shuffle=False: every batch holds a single dataset item,
# visited in dataset order.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [35]:
# Pull the first batch to inspect which fields a sample carries.
one = next(iter(dataloader))
one

DataBatch(x=[32, 13], pos=[32, 3], target_x=[32, 13], target_pos=[32, 3], Gt_mask=[32], ligand_name=[1], batch=[32], ptr=[2])

In [38]:
# Keep only the rows selected by Gt_mask, then take, per row, the index of the
# largest feature entry.
masked_features = one.x[one.Gt_mask]
one_class = masked_features.argmax(dim=-1)
one_class

tensor([3, 1, 5, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 5, 0])

In [39]:
# Class indices (argmax over the feature columns of ``batch.x``) that this
# notebook counts as P, S and Cl respectively.
# NOTE(review): the index -> element mapping comes from the dataset's feature
# encoding, which is not visible here — confirm against the dataset code.
P_CLASSES = {8, 9}
S_CLASSES = {10, 11}
CL_CLASSES = {12}
P_S_CL_CLASSES = P_CLASSES | S_CLASSES | CL_CLASSES  # same ids as range(8, 13)

n_all = 0
n_p = 0
n_s = 0
n_cl = 0
# Counts are per batch; with the dataloader's batch_size=1 that is one count
# per dataset item.
for batch in tqdm(dataloader):
    x = batch.x[batch.Gt_mask]
    # Sanity check: the mask is expected to select exactly half of the rows.
    assert x.size(0) == batch.x.size(0)//2
    x_class = torch.argmax(x, dim=-1)

    # Collapse to the distinct class ids as plain Python ints once, so the
    # membership checks below are set operations instead of repeated
    # element-by-element tensor comparisons.
    classes_present = set(x_class.tolist())

    if classes_present & P_S_CL_CLASSES:
        n_all += 1
        if classes_present & P_CLASSES:
            n_p += 1
        if classes_present & S_CLASSES:
            n_s += 1
        if classes_present & CL_CLASSES:
            n_cl += 1

print(n_p, n_s, n_cl, n_all)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 184026/184026 [11:59<00:00, 255.76it/s]

14788 36058 20708 65477





In [40]:
# Number of items in the loaded split.
len(dataset)

184026