In [1]:
import tokenizer
import torch

# Load the saved dictionary from disk and hand it to the graph tokenizer.
# NOTE(review): torch.load may unpickle the file (depends on the installed
# PyTorch version's `weights_only` default), so dictionary.pt should only
# come from a trusted source.
loaded_dictionary = torch.load("dictionary.pt")
graph_tokenizer = tokenizer.GraphTokenizer(loaded_dictionary)

In [2]:
import h5py
import tqdm
import numpy as np
import torch
import data
import torch_geometric as tg

# Read every top-level group of the training file into a list of records.
# Each record holds the two graphs of the group and its scalar target "y".
all_data = []
with h5py.File('Data/train.h5', 'r') as train_file:
    for group_name in tqdm.tqdm(train_file.keys()):
        pair_group = train_file[group_name]
        record = {
            "graph1": data.read_graph(pair_group['graph1']),
            "graph2": data.read_graph(pair_group['graph2']),
            # "y" is a scalar dataset, so it is read by indexing with ().
            "y": torch.tensor(pair_group["y"][()]),
        }
        all_data.append(record)

# Display the first record as the cell output.
all_data[0]

100%|██████████████████████████████████████| 500/500 [00:00<00:00, 920.98it/s]


{'graph1': BlendData(x=[254, 9], edge_index=[2, 476], edge_attr=[476, 3], blend_batch=[28], mol_batch=[254]),
 'graph2': BlendData(x=[239, 9], edge_index=[2, 452], edge_attr=[452, 3], blend_batch=[28], mol_batch=[239]),
 'y': tensor(0.5769)}

In [3]:
# Tokenize the first graph of the first record; the tokenized graph is the
# cell output.
first_graph = all_data[0]["graph1"]
graph_tokenizer.tokenize(first_graph)

BlendData(x=[254], edge_index=[2, 476], edge_attr=[476], blend_batch=[28], mol_batch=[254])

In [4]:
import aggregate

from torch_geometric.loader import DataLoader

# Shape check for BlendAggregator with its first argument set to True:
# run it once on a batch of two graphs and once on a single graph.
# NOTE(review): the positional arguments (True, 9, 1, 1, 0) are opaque here;
# confirm their meaning against aggregate.BlendAggregator.
agg = aggregate.BlendAggregator(True, 9, 1, 1, 0)

# The aggregator is built before the loader is iterated, as in the original
# statement order.
first_pair = all_data[0]
pair_loader = DataLoader(
    [first_pair["graph1"], first_pair["graph2"]],
    batch_size=2,
)
batch = next(iter(pair_loader))

# First the two-graph batch, then the single unbatched graph.
for graph in (batch, first_pair["graph1"]):
    print(agg(graph.x, graph).shape)

torch.Size([2, 9])
torch.Size([1, 9])


In [5]:
import aggregate

from torch_geometric.loader import DataLoader

# Same shape check as the previous cell, but with the first BlendAggregator
# argument set to False.
# NOTE(review): the positional arguments (False, 9, 1, 1, 0) are opaque here;
# confirm their meaning against aggregate.BlendAggregator.
agg = aggregate.BlendAggregator(False, 9, 1, 1, 0)

# The aggregator is built before the loader is iterated, as in the original
# statement order.
first_pair = all_data[0]
pair_loader = DataLoader(
    [first_pair["graph1"], first_pair["graph2"]],
    batch_size=2,
)
batch = next(iter(pair_loader))

# First the two-graph batch, then the single unbatched graph.
for graph in (batch, first_pair["graph1"]):
    print(agg(graph.x, graph).shape)

torch.Size([2, 9])
torch.Size([1, 9])


In [6]:
import mpnn

# Build a small MPNN from a config and run it once on a raw (untokenized)
# example graph; the model's return value is the cell output.
exmpl = all_data[0]["graph1"]

config = mpnn.Config(
    node_out_feats=16,
    edge_hidden_feats=16,
    num_step_message_passing=3,
)
model = mpnn.from_config(
    config,
    node_in_feats=9,  # the raw graphs above display x=[N, 9]
    edge_in_feats=3,  # the raw graphs above display edge_attr=[E, 3]
    dropout=.1,
    do_edge_update=True,
)

node_feats, edge_feats = exmpl.x, exmpl.edge_attr
model(exmpl, node_feats, edge_feats)

(tensor([[-0.1041,  0.3365,  0.0174,  ...,  0.2844, -0.0000,  0.1039],
         [-0.1871,  0.3660, -0.0453,  ...,  0.3649, -0.1400,  0.2395],
         [-0.2525,  0.3784, -0.0000,  ...,  0.0000, -0.0000,  0.2323],
         ...,
         [-0.4904,  0.2055, -0.3225,  ...,  0.7239, -0.0000,  0.7296],
         [-0.3014,  0.3132, -0.1970,  ...,  0.3428, -0.1677,  0.0000],
         [-0.1738,  0.2647, -0.0656,  ...,  0.3084, -0.0000,  0.1603]],
        grad_fn=<MulBackward0>),
 tensor([[ 0.2465, -0.2977,  0.0166,  ...,  0.0000,  0.0756,  0.1721],
         [ 0.2352, -0.2759, -0.0417,  ...,  0.1757,  0.0000,  0.1873],
         [ 0.2047, -0.2888, -0.0505,  ...,  0.1421,  0.0282,  0.1918],
         ...,
         [ 0.2390, -0.2703, -0.0639,  ...,  0.1028,  0.0587,  0.1232],
         [ 0.2472, -0.2958, -0.0105,  ...,  0.1546,  0.0464,  0.1766],
         [ 0.2118, -0.2945,  0.0156,  ...,  0.0752,  0.0841,  0.1985]],
        grad_fn=<MulBackward0>))

In [7]:
import utils

# Display what utils.readout_counts reports for the MPNN built in the previous
# cell (presumably parameter counts per sub-module; confirm in utils).
model_readout = utils.readout_counts(model)
model_readout

{'total': '12,224',
 'base': '6,480',
 'project_edge_feats': '64',
 'edge_update_network': '1,056',
 'gnn_layer': '4,624'}

In [8]:
import encoder
import torch

# Encoder smoke test on the raw example graph `exmpl` (defined in the MPNN
# cell above), with no tokenizer and a stack of three MPNN configs.
mpnn_configs = [
    mpnn.Config(node_out_feats=16, edge_hidden_feats=8, num_step_message_passing=5),
    mpnn.Config(node_out_feats=64, edge_hidden_feats=32, num_step_message_passing=3),
    mpnn.Config(node_out_feats=128, edge_hidden_feats=64, num_step_message_passing=1),
]

# Keyword arguments forwarded to encoder.Encoder.
# The original literal listed "do_edge_update" twice; Python keeps only the
# last value for a repeated key, so the entry now appears exactly once.
config = {
    "mpnn_configs": mpnn_configs,
    "do_two_stage": True,
    "do_edge_update": True,
    "embedding_dim_x": 32,
    "embedding_dim_edge_attr": 64,
    "num_sabs": 8,
    "dropout": 0.1,
    "heads": 8,
    "warmup": .05,
    "lr": 1e-3,
    "weight_decay": .01,
    "betas": (.99, .999),
}

ex_model = encoder.Encoder(graph_tokenizer=None, **config)
ex_model(exmpl)

tensor([[ 14.5679,  15.0214,  15.6756,   6.8135,  32.3703,   8.4644,   2.6436,
          -3.7035,  27.7643,  39.0933,  -3.1549,  -7.7342,  44.6607,  37.7973,
          22.1152,  22.6417,  26.5870,  10.7395,  21.6800,  -8.5880,  16.7585,
          27.6666,  12.2636,  41.1406, -17.0898,  12.3411,  29.1615,  26.2636,
           7.9025,  23.8421, -11.4530,  19.8079,  19.0724,   6.8836,  15.0650,
          14.7421,   7.1672,   5.5850,   7.1261,  22.0255,  12.7941,   4.8541,
          34.4244,  26.3723,   9.7602,  18.0545,  -5.5065,  22.9658,  -7.0983,
           5.3882,   3.7933,   4.6332,  27.9746,  -5.7652,  26.5953,   3.8623,
         -12.1401,  -7.3063,  25.7088,   7.1268,  13.8138,  10.4985,  22.2109,
          10.4615,  11.3154,  24.7482,   3.4151,  -0.9278,  25.4153,  31.0019,
           2.1628,  21.3221,  15.6051,  11.0922,  -6.8726,   8.3026,  16.8213,
           2.9926,   8.1325,   9.4760,  -9.2109,  13.8579,  19.5858,  41.1380,
          16.4197,  21.3555,  23.7364,  24.5721,  26

In [9]:
# Encoder smoke test on a tokenized graph: a single MPNN config, two-stage
# mode and edge updates switched off, and the tokenizer passed to the encoder.
mpnn_configs = [
    mpnn.Config(
        node_out_feats=64,
        edge_hidden_feats=32,
        num_step_message_passing=3,
    ),
]

# Keyword arguments forwarded to encoder.Encoder.
config = {
    "mpnn_configs": mpnn_configs,
    "do_two_stage": False,
    "embedding_dim_x": 32,
    "embedding_dim_edge_attr": 64,
    "do_edge_update": False,
    "num_sabs": 8,
    "dropout": 0.1,
    "heads": 8,
    "warmup": .05,
    "lr": 1e-3,
    "weight_decay": .01,
    "betas": (.99, .999),
}
ex_model = encoder.Encoder(graph_tokenizer=graph_tokenizer, **config)

# Tokenize the first graph of the first record and encode it; the encoder
# output is the cell output.
exmpl_tokenized_graph = graph_tokenizer.tokenize(all_data[0]["graph1"])
ex_model(exmpl_tokenized_graph)

tensor([[ 3.6229e+00,  7.0717e+00,  1.8130e+00,  2.9501e+00,  3.0101e+00,
         -3.6405e-01,  2.5116e+00,  3.0838e+00,  3.6784e-01,  5.2213e+00,
          4.3386e+00,  1.0588e+00,  3.3996e+00,  5.7579e-01,  2.8739e+00,
          2.6644e+00,  1.5622e+00,  4.2907e+00,  4.6320e+00,  6.4612e+00,
          2.6568e+00,  5.7772e+00,  9.8764e-01,  2.2693e+00,  3.0668e+00,
          1.6809e+00,  3.1287e+00,  1.6481e+00,  1.4621e+00,  3.5973e+00,
          3.0528e+00,  1.5554e+00,  6.3562e+00,  4.0825e+00,  3.2474e+00,
          1.7612e-01, -1.3366e+00,  2.4514e+00,  4.7522e+00,  2.6687e+00,
          2.1462e+00,  4.3759e+00,  1.4140e+00, -9.0805e-01,  2.9532e+00,
          3.9732e+00,  5.7494e-01,  2.6475e-01,  3.6274e-01,  6.8068e-01,
          3.2150e+00,  3.8738e+00,  4.3919e-03,  1.5392e+00,  2.5508e+00,
          1.4269e+00,  3.5355e+00, -7.4297e-02,  2.2697e-01,  2.7027e+00,
         -2.5722e-01,  2.2003e+00,  2.0599e+00,  2.4785e-01]],
       grad_fn=<ViewBackward0>)