In [2]:
import sys
sys.path.append('../')
sys.path.append('../data_processing/')
sys.path.append('../model/')

In [22]:
import torch
from torch_geometric.loader import DataLoader, DenseDataLoader
from torch_geometric.utils import to_dense_batch
from data_processing.utils import make_edge_mask

In [20]:
from data_processing.paired_data import CombinedGraphDataset
from model.utils.utils_diffusion import append_dims, vp_logs, vp_logsnr, mean_flat, center2zero, center2zero_with_mask, center2zero_combined_graph, sample_zero_center_gaussian, sample_zero_center_gaussian_with_mask

In [5]:
dataset = CombinedGraphDataset(
    root='../../data/small_dataset',  # relative path: resolved against the kernel's working directory
    split='all',
)

Processing...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:00<00:00, 83.15it/s]
Done!


In [6]:
data_loader = DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,  # iterate in dataset order so the inspected batch is reproducible
    num_workers=24,
)

In [8]:
# Take the first batch (shuffle=False above, so it is always the same one) and display it.
batch = next(iter(data_loader))
batch

DataBatch(x=[5, 172, 13], pos=[5, 172, 3], original_x=[114, 13], original_pos=[114, 3], target_x=[5, 172, 13], target_pos=[5, 172, 3], CoM=[5, 3], node_mask=[5, 172], Gt_mask=[860], edge_mask=[5, 29584], node_pp_index=[1, 860], ligand_name=[5], num_nodes=114, batch=[114], ptr=[6])

In [10]:
# Dense, padded tensors from the batch: features (h*) and positions (x*) for the
# source graph (x / pos) and the target graph (target_x / target_pos).
h0, x0 = batch.x, batch.pos
hT, xT = batch.target_x, batch.target_pos

CoM = batch.CoM
edge_mask = batch.edge_mask
num_nodes = batch.num_nodes
batch_info = batch.batch

# node_mask is (bs, n_max) in the batch; add a trailing singleton dim, presumably so it
# broadcasts against the per-node feature/coordinate tensors.
node_mask = batch.node_mask.unsqueeze(-1)
# Gt_mask is stored flat (bs * n_max entries); reshape it to (bs, n_max, -1) to line up with node_mask.
Gt_mask = batch.Gt_mask.view(*node_mask.shape[:2], -1)

In [12]:
# Raw (not yet centred) dense positions, printed one padded graph at a time.
for graph_positions in x0:
    print(graph_positions)

tensor([[ -2.0808,  21.8273, -28.2589],
        [ -1.9501,  22.5794, -29.2635],
        [ -2.7814,  23.8029, -29.4013],
        [ -3.6636,  24.2354, -28.3832],
        [ -3.7971,  23.5298, -27.1929],
        [ -4.4254,  25.3948, -28.5608],
        [ -4.3235,  26.1345, -29.7399],
        [ -5.0836,  27.2837, -29.9065],
        [ -3.4573,  25.7181, -30.7504],
        [ -2.6932,  24.5616, -30.5812],
        [ -3.2968,  22.7502, -27.0518],
        [ -5.0150,  27.7801, -30.6982],
        [ -0.9639,  22.2320, -30.3006],
        [ -0.5110,  20.9901, -30.4588],
        [  0.4726,  20.6300, -31.4891],
        [  0.7762,  19.2791, -31.6876],
        [  1.7074,  18.8985, -32.6590],
        [  1.9978,  17.5528, -32.8438],
        [  1.1108,  21.6036, -32.2738],
        [  2.0418,  21.2254, -33.2452],
        [  2.3415,  19.8740, -33.4395],
        [  3.2681,  19.5067, -34.4071],
        [  2.6210,  17.2929, -33.4934],
        [  3.6890,  20.1636, -34.9260],
        [ -2.0808,  21.8273, -28.2589],


In [25]:
def preprocess(x0, xT, h0, hT, node_mask, Gt_mask=None, num_node=None, batch_info=None, use_mass=False):
    """Centre the combined graphs, then convert the dense, padded batch into a sparse layout.

    Parameters
    ----------
    x0, xT : Tensor
        Dense node positions indexed as ``[graph, node, ...]``. ``x0`` (which already
        contains the xT part of the combined graph) is re-centred with
        ``center2zero_combined_graph`` before flattening; ``xT`` is used as given.
    h0, hT : Tensor
        Dense node features with the same leading ``[graph, node]`` layout.
    node_mask : Tensor
        Mask of real (non-padding) nodes, ``(bs, n_max, 1)`` as built in this notebook.
        A trailing singleton dim is dropped for the sparse conversion.
    Gt_mask : Tensor, optional
        Dense G_t mask. It is only forwarded to ``center2zero_combined_graph``. The
        returned mask is rebuilt from the per-graph node counts (see Returns).
    num_node, batch_info, use_mass :
        Unused; kept so existing call sites keep working. ``batch_info`` is recomputed.

    Returns
    -------
    x0, xT, h0, hT : Tensor
        The real nodes of every graph, concatenated graph by graph (padding removed).
    Gt_mask : BoolTensor
        One entry per returned node. True for the first ``N // 2`` real nodes of each
        graph with ``N`` real nodes (the first half of the combined graph), else False.
    batch_info : LongTensor
        Graph index (0 .. bs-1) of every returned node.

    The returned mask and index tensors are created on ``node_mask``'s device.
    Previously they were always on CPU.

    NOTE(review): the processed_x0 output recorded in this notebook shows row-to-row
    differences twice as large as in the raw positions. A pure translation would not
    do that, so ``center2zero_combined_graph`` may be scaling positions — TODO verify.
    """
    # Centre on the dense layout, before padding is dropped.
    x0 = center2zero_combined_graph(x0, node_mask, Gt_mask)

    # The sparse conversion needs a (bs, n_max) boolean mask.
    if node_mask.dim() == 3:
        node_mask = node_mask.squeeze(-1)
    node_mask = node_mask.bool()

    # Boolean-mask indexing flattens in row-major order: graph 0's real nodes in order,
    # then graph 1's, and so on. This is the same result as per-graph slicing followed
    # by torch.cat, but with no Python loop or host syncs, and it works for bs == 0.
    x0 = x0[node_mask]
    xT = xT[node_mask]
    h0 = h0[node_mask]
    hT = hT[node_mask]

    # Graph index of every real node, broadcast over the node axis and then masked.
    graph_index = torch.arange(node_mask.size(0), device=node_mask.device).unsqueeze(1)
    batch_info = graph_index.expand_as(node_mask)[node_mask]

    # Rank of each node among the real nodes of its own graph (0 .. N-1).
    rank_in_graph = node_mask.long().cumsum(dim=1) - 1
    # The first N // 2 real nodes of each combined graph form G_t.
    half_sizes = node_mask.sum(dim=1, keepdim=True) // 2
    Gt_mask = (rank_in_graph < half_sizes)[node_mask]

    return x0, xT, h0, hT, Gt_mask, batch_info

In [26]:
# Convert the dense batch into the sparse layout used downstream.
(
    processed_x0,
    processed_xT,
    processed_h0,
    processed_hT,
    processed_Gt_mask,
    processed_batch_info,
) = preprocess(
    x0, xT, h0, hT, node_mask,
    Gt_mask=Gt_mask,
    num_node=num_nodes,
    batch_info=batch_info,
)

In [27]:
# Centred positions after preprocessing: real nodes of every graph, concatenated graph by graph.
# NOTE(review): in the recorded output, row1 - row0 in y is 1.5042 (-43.9428 - -45.4470).
# The raw difference is 0.7521 (22.5794 - 21.8273), so the centred differences are twice
# the raw ones. Raw row 24 equals raw row 0, yet here row 24 is zero while row 0 is not
# (assuming graph 0's real nodes are its leading rows). A pure translation would preserve
# both properties, so center2zero_combined_graph may be scaling positions — TODO verify.
processed_x0

tensor([[ -0.3321, -45.4470,  67.0281],
        [ -0.0707, -43.9428,  65.0189],
        [ -1.7333, -41.4958,  64.7433],
        [ -3.4977, -40.6308,  66.7795],
        [ -3.7647, -42.0420,  69.1601],
        [ -5.0213, -38.3120,  66.4243],
        [ -4.8175, -36.8326,  64.0661],
        [ -6.3377, -34.5342,  63.7329],
        [ -3.0851, -37.6654,  62.0451],
        [ -1.5569, -39.9784,  62.3835],
        [ -2.7641, -43.6012,  69.4423],
        [ -6.2005, -33.5414,  62.1495],
        [  1.9017, -44.6376,  62.9447],
        [  2.8075, -47.1214,  62.6283],
        [  4.7747, -47.8416,  60.5677],
        [  5.3819, -50.5434,  60.1707],
        [  7.2443, -51.3046,  58.2279],
        [  7.8251, -53.9960,  57.8583],
        [  6.0511, -45.8944,  58.9983],
        [  7.9131, -46.6508,  57.0555],
        [  8.5125, -49.3536,  56.6669],
        [ 10.3657, -50.0882,  54.7317],
        [  9.0715, -54.5158,  56.5591],
        [ 11.2075, -48.7744,  53.6939],
        [ -0.0000,   0.0000,  -0.0000],


In [28]:
# Graph index of every node in the sparse layout (recomputed by preprocess from node_mask).
processed_batch_info

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])

In [29]:
# Sparse G_t mask: True for the first N // 2 real nodes of each graph (N = its real-node count), False for the rest.
processed_Gt_mask

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, 

# Bugs
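1. The VP bridge model keeps producing NaNs: check that ***sigma max*** is < 1. This happens when the sigma values meant for the VE bridge are used (presumably because a variance-preserving schedule needs 1 − σ² ≥ 0, so σ > 1 leads to the square root of a negative number).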