In [1]:
import tokenizer
import torch

graph_tokenizer = tokenizer.GraphTokenizer(torch.load("dictionary.pt"))

In [2]:
import h5py
import tqdm
import numpy as np
import torch
import data

all_data = []
with h5py.File('Data/train.h5', 'r') as f:
    for label in tqdm.tqdm(f.keys()):
      group = f[label]

      graph_group = group['graph']
      graph_data = {k: torch.tensor(np.array(v)) for k, v in graph_group.items()}
      graph_data = {k: v.float() if not k.startswith("edge_index") else v.long() for k, v in graph_data.items()}
      all_data.append(data.PairData(**graph_data))

all_data[0]

100%|█████████████████████████████████████| 500/500 [00:00<00:00, 1289.40it/s]


PairData(y=0.5769230127334595, edge_attr_s=[476, 3], edge_attr_t=[452, 3], edge_index_s=[2, 476], edge_index_t=[2, 452], x_s=[254, 9], x_t=[239, 9])

In [3]:
graph_tokenizer.tokenize(data.PairData.split(all_data[0])[0])

Data(x=[254], edge_index=[2, 476], edge_attr=[476])

In [4]:
import mpnn

config = mpnn.Config(node_out_feats=16,
                 edge_hidden_feats=16, num_step_message_passing=3)
model = mpnn.from_config(config,node_in_feats=9, edge_in_feats=3,dropout=.1, do_edge_update=True)
exmpl = data.PairData.split(all_data[0])[0]
model(exmpl,exmpl.x,exmpl.edge_attr)

(tensor([[ 0.0505,  0.5706, -0.1154,  ...,  0.0723,  0.0000, -0.0000],
         [ 0.0000,  0.3452, -0.1039,  ...,  0.0000,  0.8823, -0.3158],
         [ 0.0000,  0.3657, -0.0000,  ...,  0.1043,  0.6915, -0.2505],
         ...,
         [ 0.3025,  0.4391, -0.1137,  ..., -0.1171,  0.5915,  0.2551],
         [ 0.3834,  0.4089,  0.1342,  ...,  0.2130,  0.7842, -0.1940],
         [ 0.0841,  0.5722, -0.0849,  ...,  0.0548,  0.5465, -0.2707]],
        grad_fn=<MulBackward0>),
 tensor([[-0.1903, -0.0086, -0.1255,  ..., -0.0937, -0.0921, -0.0336],
         [-0.2319, -0.0844, -0.0714,  ..., -0.0000, -0.0527, -0.0490],
         [-0.2093, -0.0560, -0.1125,  ..., -0.0245, -0.0471, -0.0514],
         ...,
         [-0.2409, -0.0669, -0.1180,  ..., -0.0098, -0.0623, -0.0614],
         [-0.1927, -0.0199, -0.1403,  ...,  0.0205, -0.0436, -0.0061],
         [-0.2569, -0.1423, -0.0227,  ..., -0.1516, -0.0206, -0.0996]],
        grad_fn=<MulBackward0>))

In [5]:
import utils

utils.readout_counts(model)

{'total': '12,224',
 'base': '6,480',
 'project_edge_feats': '64',
 'edge_update_network': '1,056',
 'gnn_layer': '4,624'}

In [6]:
import encoder
import torch

mpnn_configs = [mpnn.Config(node_out_feats=16,
                 edge_hidden_feats=8, num_step_message_passing=5), mpnn.Config(node_out_feats=64,
                 edge_hidden_feats=32, num_step_message_passing=3), mpnn.Config(node_out_feats=128,
                 edge_hidden_feats=64, num_step_message_passing=1)]
config = {"mpnn_configs":mpnn_configs, "do_edge_update":True, "embedding_dim_x":32, "embedding_dim_edge_attr": 64, "do_edge_update":True, "num_sabs":8,"dropout":0.1, "heads":8, "warmup":.05, "lr": 1e-3, "weight_decay":.01, "betas":(.99,.999)}
ex_model = encoder.Encoder(graph_tokenizer=None,**config)
ex_model(exmpl)

tensor([[ 0.6101, -0.4512,  1.5945,  0.8752,  0.7099,  1.6768,  0.2967,  1.6356,
         -1.2052, -0.3205,  0.8781,  1.8801,  0.7371,  0.3144,  0.4874,  1.1919,
          0.5892, -0.0384,  0.9751,  0.7648,  0.3897,  0.3725,  0.1954,  0.5741,
          1.2732,  0.4859,  0.9891,  1.1286,  0.9149,  0.3023,  1.1488,  3.7451,
          1.0321,  0.8561,  0.0465,  2.9972,  0.9058,  0.2281,  2.1526, -0.7293,
          1.7258,  0.6981,  0.3712,  0.6742,  0.4356,  0.9451,  0.9083,  1.5207,
          0.6516,  0.6826,  1.1785,  1.2185, -0.2631,  0.4959,  0.9786,  0.6892,
          0.9494,  1.8249,  1.5897,  0.9381,  1.0205,  0.5508,  1.4961,  1.7059,
          0.4662,  0.5898,  0.3639,  0.8537,  0.9042, -0.4059, -0.7011,  1.0326,
          0.4363,  2.0677,  0.0569,  0.2997, -0.1143,  0.6037,  2.2834, -0.8241,
          0.4871,  1.6056,  0.7004,  0.3769,  2.0509,  0.9937,  1.9255,  0.0805,
          0.9496,  0.4200,  2.0577,  0.8032,  1.6570, -0.0720,  1.7626,  1.4303,
          0.4679,  0.9930,  

In [7]:
mpnn_configs = [mpnn.Config(node_out_feats=64,
                 edge_hidden_feats=32, num_step_message_passing=3)]
config = {"mpnn_configs":mpnn_configs, "embedding_dim_x":32, "embedding_dim_edge_attr": 64, "do_edge_update":False, "num_sabs":8,"dropout":0.1, "heads":8, "warmup":.05, "lr": 1e-3, "weight_decay":.01, "betas":(.99,.999)}
ex_model = encoder.Encoder(graph_tokenizer=graph_tokenizer,**config)
exmpl_tokenized_graph = graph_tokenizer.tokenize(data.PairData.split(all_data[0])[0])
ex_model(exmpl_tokenized_graph)

tensor([[ 0.5638,  1.5537,  0.2933,  2.7888,  3.3211,  1.7558,  2.5972,  3.8185,
          2.5779,  3.8431,  2.1207,  1.7362,  3.2465,  0.7290,  1.7333,  3.3889,
          6.7661,  2.0897, -0.1893, -1.1457,  4.9700, -0.5566,  3.8269,  0.6197,
          3.3572,  0.5504,  3.6545, -1.1053,  1.0019,  0.3341,  0.9272,  6.9728,
         -1.0765,  4.7533,  5.4569,  4.3734,  2.7634,  1.4379,  1.9102,  2.3873,
          4.4836,  3.6509,  1.3289,  0.1755,  1.1336,  1.7797,  2.9211,  5.2131,
          0.4265,  2.1700,  3.0682,  0.3907,  0.3684,  4.7401,  3.6901,  2.0178,
          1.4420,  0.3862,  5.8757,  2.6989, -0.2608,  0.3131, -0.1164,  5.6350]],
       grad_fn=<ViewBackward0>)