In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')
sys.path.append('../data_processing/')
sys.path.append('../evaluation/')

In [3]:
import pickle
import os
import numpy as np
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from rdkit import Chem
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib
import matplotlib.pylab as plt
%matplotlib inline

In [5]:
from data_processing.paired_data import PharmacophoreDataset, CombinedGraphDataset, CombinedSparseGraphDataset
from data_processing.reconstruction import get_atomic_number_from_index, is_aromatic_from_index, reconstruct_from_generated
from model.pp_bridge import PPBridge
from model.pp_bridge_sampler import PPBridgeSampler
from model.utils.utils_diffusion import center2zero_combined_graph, center2zero_with_mask, center2zero
from script_utils import load_data
from evaluation.utils_eval import build_pdb_dict

In [6]:
raw_data_root = '/home/conghao001/pharmacophore2drug/PP2Drug/data/cleaned_crossdocked_data'
split = 'test'
batch_size = 1000
num_workers = 0

In [7]:
pdb_dict, pdb_rev_dict = build_pdb_dict(raw_data_root)

In [8]:
dataset, dataloader = load_data('CombinedSparseGraphDataset', raw_data_root, split='test', batch_size=batch_size, aromatic=False)

In [9]:
# class OldDataset(CombinedSparseGraphDataset):
#     def __init__(self, root, split='train', transform=None, pre_transform=None, pre_filter=None, aromatic=False):
#         super(OldDataset, self).__init__(root, split, transform, pre_transform, pre_filter, aromatic=aromatic)
#         self.load(self.processed_paths[0])

#     @property
#     def processed_dir(self):
#         # Customize the processed directory name here
#         return os.path.join(self.root, 'processed_CoM_xT_with_noise_std0.01')

In [10]:
# dataset = OldDataset(raw_data_root, split=split, aromatic=True)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [11]:
one = next(iter(dataloader))
one.pos.size()

torch.Size([47288, 3])

In [10]:
batch = one.batch
batch.size()

torch.Size([47288])

In [11]:
root = '../../src/lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-05-31_14_11_45.077216'
filename = 'generation_res.pkl'

In [12]:
with open(os.path.join(root, filename), 'rb') as f:
    res = pickle.load(f)

In [13]:
res.keys()

dict_keys(['x', 'x_traj', 'h', 'h_traj', 'nfe'])

In [14]:
all_x_traj = res['x_traj']
len(all_x_traj)

15

In [15]:
# 1001 steps in each batch
x_traj = all_x_traj[0]
len(x_traj)

1001

In [16]:
# integrated graph containing 2000 individual graphs at first step
x_traj[0].size()

torch.Size([106946, 3])

In [17]:
idx = 50
batch_idx = batch==idx
batch_idx.sum()

tensor(26)

In [18]:
batch_idx

tensor([False, False, False,  ..., False, False, False])

In [19]:
Gt_mask_i = one.Gt_mask[batch_idx]
Gt_mask_i

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False])

In [20]:
x_traj_i = []
for x_traj_t in tqdm(x_traj):
    x_t_i = x_traj_t[batch_idx][Gt_mask_i]
    x_traj_i.append(x_t_i)

  0%|                                                                                                                                                                   | 0/1001 [00:00<?, ?it/s]


IndexError: The shape of the mask [47288] at index 0 does not match the shape of the indexed tensor [106946, 3] at index 0

In [None]:
x_i_arr = x_traj_i[-1].numpy()
x_i_arr

In [None]:
original_pos_i = one.pos[batch_idx][Gt_mask_i]
original_pos_0center_i = center2zero(original_pos_i)
original_pos_0center_i_arr = original_pos_0center_i.numpy()
original_pos_0center_i.size()

In [34]:
pp_pos_i = one.target_pos[batch_idx][~Gt_mask_i]
pp_pos_0center_i = center2zero(pp_pos_i)
pp_pos_0center_i_arr = pp_pos_0center_i.numpy()
pp_pos_0center_i.size()

torch.Size([23, 3])

In [35]:
fig = make_subplots(rows=1, cols=3, 
                   specs = [[{"type": "scatter3d"} for i in range(3)]]
                   )


fig.add_trace(
    go.Scatter3d(x=x_i_arr[:, 0], y=x_i_arr[:, 1], z=x_i_arr[:, 2], mode='markers'),
    row=1, col=1
)

fig.add_trace(
    go.Scatter3d(x=original_pos_0center_i_arr[:, 0], y=original_pos_0center_i_arr[:, 1], z=original_pos_0center_i_arr[:, 2], mode='markers'),
    row=1, col=2
)

fig.add_trace(
    go.Scatter3d(x=pp_pos_0center_i_arr[:, 0], y=pp_pos_0center_i_arr[:, 1], z=pp_pos_0center_i_arr[:, 2], mode='markers'),
    row=1, col=3
)

fig.update_layout(height=500, showlegend=False)
fig.show()

In [36]:
xt_sampled = [x_traj_i[i] for i in range(len(x_traj_i)) if i%70==0]
len(xt_sampled)

15

In [37]:
nc = 3
nr = len(xt_sampled)//nc + 1
fig = make_subplots(rows=nr, cols=nc, 
                   specs = [[{"type": "scatter3d"} for i in range(nc)] for i in range(nr)]
                   )

for i in range(len(xt_sampled)):
    arr = xt_sampled[i].cpu().numpy()
    fig.add_trace(
        go.Scatter3d(x=arr[:, 0], y=arr[:, 1], z=arr[:, 2], mode='markers'),
        row=i//nc + 1, col=i%nc + 1
    )

fig.update_layout(height=nr*350, showlegend=False)
fig.show()

In [None]:
fig = make_subplots(rows=1, cols=2, 
                   specs = [[{"type": "scatter3d"} for i in range(3)]]
                   )


fig.add_trace(
    go.Scatter3d(x=x_i_arr[:, 0], y=x_i_arr[:, 1], z=x_i_arr[:, 2], mode='markers'),
    row=1, col=1
)

fig.add_trace(
    go.Scatter3d(x=original_pos_0center_i_arr[:, 0], y=original_pos_0center_i_arr[:, 1], z=original_pos_0center_i_arr[:, 2], mode='markers'),
    row=1, col=2
)

fig.add_trace(
    go.Scatter3d(x=pp_pos_0center_i_arr[:, 0], y=pp_pos_0center_i_arr[:, 1], z=pp_pos_0center_i_arr[:, 2], mode='markers'),
    row=1, col=3
)

fig.update_layout(height=500, showlegend=False)
fig.show()