In [1]:
import tokenizer
import torch

graph_tokenizer = tokenizer.GraphTokenizer(torch.load("dictionary.pt"))

In [2]:
import h5py
import tqdm
import numpy as np
import torch
import data

all_data = []
with h5py.File('Data/train.h5', 'r') as f:
    for label in tqdm.tqdm(f.keys()):
      group = f[label]

      graph_group = group['graph']
      graph_data = {k: torch.tensor(np.array(v)) for k, v in graph_group.items()}
      graph_data = {k: v.float() if not k.startswith("edge_index") else v.long() for k, v in graph_data.items()}
      all_data.append(data.PairData(**graph_data))

all_data[0]

100%|█████████████████████████████████████| 500/500 [00:00<00:00, 1292.27it/s]


PairData(y=0.5769230127334595, edge_attr_s=[476, 3], edge_attr_t=[452, 3], edge_index_s=[2, 476], edge_index_t=[2, 452], x_s=[254, 9], x_t=[239, 9])

In [3]:
graph_tokenizer.tokenize(data.PairData.split(all_data[0])[0])

Data(x=[254], edge_index=[2, 476], edge_attr=[476])

In [4]:
import mpnn

config = mpnn.Config(node_out_feats=16,
                 edge_hidden_feats=16, num_step_message_passing=3)
model = mpnn.from_config(config,node_in_feats=9, edge_in_feats=3,dropout=.1, do_edge_update=True)
exmpl = data.PairData.split(all_data[0])[0]
model(exmpl,exmpl.x,exmpl.edge_attr)

(tensor([[ 0.4714, -0.0111,  0.0000,  ...,  0.4893, -0.1524, -0.7537],
         [ 0.5443, -0.2311,  3.2477,  ...,  0.3374,  0.1596, -0.7088],
         [ 0.6122, -0.0874,  0.0000,  ...,  0.4163, -0.0220, -0.6850],
         ...,
         [ 0.4853, -0.1381,  2.9273,  ...,  0.3530,  0.0585, -0.0000],
         [ 0.0000, -0.0978,  3.8597,  ...,  0.2326,  0.2355, -0.7976],
         [ 0.2437,  0.0425,  3.5459,  ...,  0.4893, -0.0366, -0.7458]],
        grad_fn=<MulBackward0>),
 tensor([[ 0.0529,  0.3552, -0.0714,  ...,  0.4663, -0.4724,  0.0916],
         [ 0.2328,  0.6689, -0.5839,  ...,  0.1626, -0.6247, -0.0924],
         [ 0.2554,  0.5018, -0.0000,  ...,  0.2723, -0.5082, -0.0468],
         ...,
         [ 0.3301,  0.5117, -0.3790,  ...,  0.3733, -0.3686, -0.0179],
         [ 0.3062,  0.6033, -0.3145,  ...,  0.3638, -0.4779,  0.1043],
         [ 0.0000,  0.6174, -0.4303,  ...,  0.3565, -0.5251,  0.0165]],
        grad_fn=<MulBackward0>))

In [5]:
import utils

utils.readout_counts(model)

{'total': '12,224',
 'base': '6,480',
 'project_edge_feats': '64',
 'edge_update_network': '1,056',
 'gnn_layer': '4,624'}

In [6]:
import encoder
import torch

mpnn_configs = [mpnn.Config(node_out_feats=16,
                 edge_hidden_feats=8, num_step_message_passing=5), mpnn.Config(node_out_feats=64,
                 edge_hidden_feats=32, num_step_message_passing=3), mpnn.Config(node_out_feats=128,
                 edge_hidden_feats=64, num_step_message_passing=1)]
config = {"mpnn_configs":mpnn_configs, "do_edge_update":True, "embedding_dim_x":32, "embedding_dim_edge_attr": 64, "do_edge_update":True, "num_sabs":8,"dropout":0.1, "heads":8, "warmup":.05, "lr": 1e-3, "weight_decay":.01, "betas":(.99,.999)}
ex_model = encoder.Encoder(graph_tokenizer=None,**config)
ex_model(exmpl)

tensor([[ 0.3629,  1.9652,  0.9449,  0.7582,  1.1357,  1.0366,  0.5773,  0.1636,
          0.1302,  1.8411,  1.2199,  1.5705, -0.6150,  0.3423,  0.2715,  1.5571,
         -0.4466,  2.0710, -0.2432,  1.7037,  0.5280,  2.2168,  0.7261,  2.0933,
          1.0107,  1.0249,  1.9450,  1.1366, -0.2466,  0.7996,  1.5958,  0.2197,
          0.4320,  0.5953,  1.8889,  0.5723, -0.5251,  2.0812, -0.4864,  2.0457,
          1.4566,  1.2685, -0.1298,  1.6062,  0.8319,  0.2482, -0.0858,  2.4688,
          1.6206,  1.1771,  1.7183, -1.8192,  1.8075,  1.5534,  1.1015,  3.1736,
          0.0186,  0.8398,  0.7455,  2.1844,  1.5703, -0.2572,  2.0397,  1.8072,
          1.6626,  0.8331,  0.6384,  0.1090,  0.7643, -0.2922,  1.1658, -0.2924,
          1.7136,  1.3961, -0.3477,  1.3909,  0.9249,  2.7707,  0.8840,  0.9375,
         -0.4439,  1.2272,  0.3573,  1.4213,  1.5085, -0.8358,  0.9378,  1.2565,
          0.1258,  0.6108,  1.9812,  2.3055,  0.1109,  0.9012,  0.6665,  0.1384,
          1.0472,  2.1546,  

In [7]:
mpnn_configs = [mpnn.Config(node_out_feats=64,
                 edge_hidden_feats=32, num_step_message_passing=3)]
config = {"mpnn_configs":mpnn_configs, "embedding_dim_x":32, "embedding_dim_edge_attr": 64, "do_edge_update":False, "num_sabs":8,"dropout":0.1, "heads":8, "warmup":.05, "lr": 1e-3, "weight_decay":.01, "betas":(.99,.999)}
ex_model = encoder.Encoder(graph_tokenizer=graph_tokenizer,**config)
exmpl_tokenized_graph = graph_tokenizer.tokenize(data.PairData.split(all_data[0])[0])
ex_model(exmpl_tokenized_graph)

tensor([[ 2.6978,  1.3747,  2.1392, -0.1023,  2.7157,  2.2231,  0.5449,  0.9221,
          4.3321,  1.5619, -0.6350, -0.1828, -0.2767,  2.8312,  1.1922,  2.1519,
         -1.0704, -0.5354, -0.4266,  0.0557,  2.3580,  2.4410,  1.2161,  1.7591,
         -0.7637,  4.0394,  0.3276,  0.4908,  3.8681,  2.6026,  0.6065,  5.5445,
          1.3901,  1.0080,  1.4422, -0.4697,  0.0873,  0.9109,  0.9345,  1.5691,
          1.6160,  1.7157,  1.8750,  5.2876,  5.9620,  2.8573,  1.3516,  1.9110,
          0.4763,  2.9129,  2.1059,  0.9725,  0.4200,  0.5850, -0.2953, -1.2477,
          1.7547,  0.6012,  2.6529,  1.4475,  0.5265,  3.1546,  3.1475, -0.1677]],
       grad_fn=<ViewBackward0>)