In [36]:
import pickle
import torch
from tqdm import tqdm
import os
import random

In [37]:
# Pin the state of Python's global `random` generator so that random draws made
# through the `random` module after this cell are repeatable.
SEED = 2024
random.seed(SEED)

In [7]:
# Lookup from this dataset's integer pharmacophore-type index to the feature
# label used in the `.posp` files. Indices mapped to None have no label and are
# skipped by the conversion loops.
# NOTE(review): presumably HYBL/AROM/POSC/HDON/HACC stand for hydrophobe /
# aromatic / cation / H-bond donor / H-bond acceptor -- confirm against PGMG.
_PGMG_LABEL_BY_INDEX = (None, 'HYBL', 'AROM', 'POSC', None, 'HDON', 'HACC', None)
to_pgmg = dict(enumerate(_PGMG_LABEL_BY_INDEX))

In [3]:
# Load the pickled per-complex pharmacophore metadata
# (presumably the CrossDocked test split, judging by the file name -- confirm).
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
metadata_path = '../../data/cleaned_crossdocked_data/metadata_HDBSCAN_non_filtered/test_pp_info.pkl'
with open(metadata_path, 'rb') as pkl_file:
    metadata = pickle.load(pkl_file)

In [22]:
# Take one complex as a worked example and reduce each row of its first
# `pp_types` tensor to the index of its largest entry.
example = metadata['2z9y_A_rec_2z9y_ddr_lig_tt_min_0']
example_type_tensor = example['pp_types'][0]
pp_types = example_type_tensor.argmax(dim=-1).numpy()
pp_types

array([4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5])

In [18]:
# Coordinates of the example's pharmacophore nodes: the first `pp_positions`
# tensor converted to a NumPy array (one row of three values per node).
example_position_tensor = example['pp_positions'][0]
pp_positions = example_position_tensor.numpy()
pp_positions

array([[  9.2817 ,  29.60305,  -2.03555],
       [  6.06995,  28.6977 ,  -1.95665],
       [ 10.608  ,  29.9543 ,  -3.682  ],
       [ 10.6176 ,  28.7669 ,  -4.6239 ],
       [ 11.0311 ,  29.2232 ,  -6.0101 ],
       [ 11.3883 ,  28.029  ,  -6.8773 ],
       [ 10.1891 ,  27.575  ,  -7.6878 ],
       [ 10.6624 ,  26.7021 ,  -8.8305 ],
       [  9.4764 ,  26.2818 ,  -9.6707 ],
       [  9.9597 ,  25.5057 , -10.8767 ],
       [  8.9943 ,  24.3769 , -11.1695 ],
       [  5.9186 ,  27.376  ,  -3.6346 ],
       [  7.3107 ,  27.144  ,  -4.1952 ],
       [  7.2478 ,  26.2074 ,  -5.3869 ],
       [  6.646  ,  26.9239 ,  -6.5805 ],
       [  6.4477 ,  25.9412 ,  -7.7138 ],
       [  5.3355 ,  26.4342 ,  -8.6191 ],
       [  5.5933 ,  26.0112 , -10.0574 ],
       [  6.0605 ,  27.1842 , -10.9068 ],
       [  4.9146 ,  28.1441 , -11.1655 ],
       [  7.5198 ,  31.1574 ,   1.2522 ]], dtype=float32)

In [23]:
# Sanity check: there must be exactly one position row per type entry.
assert len(pp_types) == len(pp_positions)

In [24]:
# Build the example's node list as (label, x, y, z) tuples, dropping every
# node whose type index has no label in `to_pgmg`.
labelled_nodes = ((to_pgmg[type_idx], xyz) for type_idx, xyz in zip(pp_types, pp_positions))
pp_l = [
    (label, xyz[0], xyz[1], xyz[2])
    for label, xyz in labelled_nodes
    if label is not None
]

pp_l

[('HYBL', 10.608, 29.9543, -3.682),
 ('HYBL', 10.6176, 28.7669, -4.6239),
 ('HYBL', 11.0311, 29.2232, -6.0101),
 ('HYBL', 11.3883, 28.029, -6.8773),
 ('HYBL', 10.1891, 27.575, -7.6878),
 ('HYBL', 10.6624, 26.7021, -8.8305),
 ('HYBL', 9.4764, 26.2818, -9.6707),
 ('HYBL', 9.9597, 25.5057, -10.8767),
 ('HYBL', 8.9943, 24.3769, -11.1695),
 ('HYBL', 5.9186, 27.376, -3.6346),
 ('HYBL', 7.3107, 27.144, -4.1952),
 ('HYBL', 7.2478, 26.2074, -5.3869),
 ('HYBL', 6.646, 26.9239, -6.5805),
 ('HYBL', 6.4477, 25.9412, -7.7138),
 ('HYBL', 5.3355, 26.4342, -8.6191),
 ('HYBL', 5.5933, 26.0112, -10.0574),
 ('HYBL', 6.0605, 27.1842, -10.9068),
 ('HYBL', 4.9146, 28.1441, -11.1655),
 ('HDON', 7.5198, 31.1574, 1.2522)]

In [43]:
# Convert every complex in `metadata` into a list of (label, x, y, z)
# pharmacophore nodes, keyed by the complex name.
MAX_PP_NODES = 8  # longer node lists are randomly subsampled down to this size
MIN_PP_NODES = 2  # complexes left with fewer labelled nodes are dropped
# NOTE(review): presumably these bounds mirror what PGMG accepts -- confirm.

# A private, freshly seeded generator makes this cell idempotent: re-running
# it, or running other cells that draw random numbers beforehand, no longer
# changes which nodes get sampled.
node_sampler = random.Random(2024)

pgmg_dict = {}
for k, v in metadata.items():
    # Index of the largest entry in each row of the first `pp_types` tensor.
    pp_types = torch.argmax(v['pp_types'][0], dim=-1).numpy()
    pp_positions = v['pp_positions'][0].numpy()
    assert pp_types.shape[0] == pp_positions.shape[0]

    # Keep only nodes whose type index has a label in `to_pgmg`.
    # Coordinates stay NumPy scalars so their text form is unchanged.
    labelled_nodes = ((to_pgmg[type_idx], xyz) for type_idx, xyz in zip(pp_types, pp_positions))
    pp_l = [
        (label, xyz[0], xyz[1], xyz[2])
        for label, xyz in labelled_nodes
        if label is not None
    ]

    if len(pp_l) > MAX_PP_NODES:
        pp_l = node_sampler.sample(pp_l, MAX_PP_NODES)

    if len(pp_l) < MIN_PP_NODES:
        continue
    pgmg_dict[k] = pp_l

In [44]:
# Spot-check the converted node list of a single complex.
spot_check_key = '1a9p_A_rec_1b8n_img_lig_tt_min_0'
pgmg_dict[spot_check_key]

[('AROM', 66.92267, -48.447903, 96.63404),
 ('HDON', 66.9997, -45.2465, 101.2257),
 ('HDON', 69.3724, -47.9308, 97.6562),
 ('HDON', 67.855, -44.2498, 99.1071),
 ('HYBL', 67.9754, -46.1456, 97.5954),
 ('HDON', 64.3444, -48.7586, 95.6304),
 ('HACC', 68.5602, -50.527, 96.7065),
 ('HDON', 65.6878, -42.2237, 97.9142)]

In [45]:
# Directory that receives one `.posp` file per complex.
# The default is the original machine-specific location; set the
# PGMG_SAVE_PATH environment variable to redirect the output elsewhere
# without editing the notebook.
save_path = os.environ.get(
    'PGMG_SAVE_PATH',
    '/home2/conghao001/pharmacophore2drug/PGMG/data/crossdocked',
)

In [47]:
# Write one `<complex name>.posp` file per entry of `pgmg_dict`, with one
# "LABEL x y z" line per pharmacophore node.
# Create the target directory first so a fresh checkout or a new machine does
# not fail with FileNotFoundError on the first open().
os.makedirs(save_path, exist_ok=True)

for k, v in tqdm(pgmg_dict.items()):
    fn = os.path.join(save_path, k + '.posp')
    with open(fn, 'w') as posp_file:
        for label, x, y, z in v:
            posp_file.write(f"{label} {x} {y} {z}\n")

100%|██████████████████████████████████████████████████████████████████████| 15074/15074 [00:01<00:00, 9421.93it/s]
