In [1]:
import torch

import tokenizer

# Build the graph tokenizer from the saved vocabulary/dictionary file.
# NOTE(review): torch.load unpickles its input; only load dictionary.pt from a
# trusted source.
graph_tokenizer = tokenizer.GraphTokenizer(
    torch.load("dictionary.pt")
)

In [2]:
import h5py
import numpy as np
import torch
import torch_geometric as tg
import tqdm

import data

# Each top-level key of the HDF5 file is one example: two serialized graphs
# ("graph1", "graph2") and a scalar target ("y").
all_data = []
with h5py.File('Data/train.h5', 'r') as f:
    for label in tqdm.tqdm(f.keys()):
        group = f[label]
        sample = {
            "graph1": data.read_graph(group['graph1']),
            "graph2": data.read_graph(group['graph2']),
            # `[()]` reads a scalar dataset as a value instead of slicing it.
            "y": torch.tensor(group["y"][()]),
        }
        all_data.append(sample)

all_data[0]

100%|██████████████████████████████████████| 500/500 [00:00<00:00, 957.19it/s]


{'graph1': BlendData(x=[254, 9], edge_index=[2, 476], edge_attr=[476, 3], blend_batch=[28], mol_batch=[254]),
 'graph2': BlendData(x=[239, 9], edge_index=[2, 452], edge_attr=[452, 3], blend_batch=[28], mol_batch=[239]),
 'y': tensor(0.5769)}

In [3]:
# In the recorded output, the tokenized graph has x=[254] and edge_attr=[476]
# instead of the raw x=[254, 9] and edge_attr=[476, 3].
first_graph = all_data[0]["graph1"]
graph_tokenizer.tokenize(first_graph)

BlendData(x=[254], edge_index=[2, 476], edge_attr=[476], blend_batch=[28], mol_batch=[254])

In [4]:
import aggregate
from torch_geometric.loader import DataLoader

# NOTE(review): BlendAggregator's positional arguments are not named here;
# confirm their meaning in aggregate.py (first flag is True in this cell).
agg = aggregate.BlendAggregator(True, 9, 1, 1, 0)

# Shape check: a batch of two graphs vs. a single unbatched graph.
pair = [all_data[0]["graph1"], all_data[0]["graph2"]]
batch = next(iter(DataLoader(pair, batch_size=2)))
print(agg(batch.x, batch).shape)
single = all_data[0]["graph1"]
print(agg(single.x, single).shape)

torch.Size([2, 9])
torch.Size([1, 9])


In [5]:
import aggregate
from torch_geometric.loader import DataLoader

# NOTE(review): BlendAggregator's positional arguments are not named here;
# confirm their meaning in aggregate.py (first flag is False in this cell).
agg = aggregate.BlendAggregator(False, 9, 1, 1, 0)

# Shape check: a batch of two graphs vs. a single unbatched graph.
pair = [all_data[0]["graph1"], all_data[0]["graph2"]]
batch = next(iter(DataLoader(pair, batch_size=2)))
print(agg(batch.x, batch).shape)
single = all_data[0]["graph1"]
print(agg(single.x, single).shape)

torch.Size([2, 9])
torch.Size([1, 9])


In [6]:
import mpnn

config = mpnn.Config(
    node_out_feats=16,
    edge_hidden_feats=16,
    num_step_message_passing=3,
)
model = mpnn.from_config(
    config,
    node_in_feats=9,   # raw graphs have x=[num_nodes, 9]
    edge_in_feats=3,   # raw graphs have edge_attr=[num_edges, 3]
    dropout=.1,
    do_edge_update=False,
    act_mode="relu",
    aggr_mode="mean",
)
# NOTE(review): the model is never put in eval mode, so dropout is presumably
# active and the output is not reproducible across runs.
exmpl = all_data[0]["graph1"]
model(exmpl, exmpl.x, exmpl.edge_attr)

(tensor([[-1.1086e-01,  5.9793e-01, -2.2948e-01,  ..., -0.0000e+00,
          -9.4011e-02,  8.7698e-02],
         [-4.5254e-01,  4.4808e-01,  4.9272e-04,  ..., -0.0000e+00,
          -4.1328e-02,  9.4875e-01],
         [-2.9034e-01,  4.4947e-01,  7.5658e-02,  ..., -3.9582e-01,
           1.7581e-02,  0.0000e+00],
         ...,
         [-3.3554e-01,  5.5025e-01, -4.5601e-02,  ..., -5.9421e-01,
           1.7330e-01,  9.7878e-01],
         [-4.4461e-01,  5.9406e-01, -4.6533e-02,  ..., -5.5005e-01,
           0.0000e+00,  9.8289e-01],
         [-3.6731e-01,  5.6484e-01, -6.6888e-04,  ..., -5.0887e-01,
           4.2269e-02,  1.1149e+00]], grad_fn=<MulBackward0>),
 tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]))

In [7]:
import mpnn

config = mpnn.Config(
    node_out_feats=16,
    edge_hidden_feats=16,
    num_step_message_passing=3,
)
model = mpnn.from_config(
    config,
    node_in_feats=9,   # raw graphs have x=[num_nodes, 9]
    edge_in_feats=3,   # raw graphs have edge_attr=[num_edges, 3]
    dropout=.1,
    do_edge_update=True,
    act_mode="silu",
    aggr_mode="mean",
)
# NOTE(review): the model is never put in eval mode, so dropout is presumably
# active and the output is not reproducible across runs.
exmpl = all_data[0]["graph1"]
model(exmpl, exmpl.x, exmpl.edge_attr)

(tensor([[ 0.0651,  0.6191,  0.0000,  ..., -0.6613,  0.4398,  0.1102],
         [ 0.3710,  0.1486,  0.4202,  ..., -0.7344,  0.5642, -0.0541],
         [ 0.4385, -0.4399,  1.1991,  ..., -0.7480,  0.5107, -0.2213],
         ...,
         [ 0.2609,  0.4220,  0.0000,  ..., -0.5616,  0.4564,  0.1654],
         [ 0.2520,  0.6360,  1.4052,  ..., -0.7652,  0.5367,  0.0000],
         [ 0.3235,  0.8571,  1.3611,  ..., -0.7862,  0.5557,  0.0000]],
        grad_fn=<MulBackward0>),
 tensor([[-0.0917, -0.1025,  0.0140,  ..., -0.0713,  0.0266,  0.0646],
         [-0.1151, -0.0911, -0.0175,  ..., -0.0892, -0.0077,  0.0854],
         [-0.1225, -0.1051, -0.0367,  ..., -0.0712, -0.0018,  0.0951],
         ...,
         [-0.0941, -0.1129,  0.0032,  ..., -0.0701,  0.0014,  0.0922],
         [-0.0922, -0.1213,  0.0118,  ..., -0.0000,  0.0109,  0.0993],
         [-0.0000, -0.1283,  0.0011,  ..., -0.0947,  0.0329,  0.0954]],
        grad_fn=<MulBackward0>))

In [8]:
import utils

# Parameter counts (as formatted strings) for `model` and its submodules.
param_counts = utils.readout_counts(model)
param_counts

{'total': '7,536',
 'project_node_feats': '160',
 'project_edge_feats': '64',
 'edge_update_network': '1,056',
 'gnn_layer': '4,624',
 'gru': '1,632',
 'final_dropout': '0',
 'act_fn': '0'}

In [9]:
import encoder
import mpnn
import torch

# Three MPNN configs, passed together with do_two_stage=True; how they are
# combined is defined in encoder.py.
mpnn_configs = [
    mpnn.Config(node_out_feats=16, edge_hidden_feats=8, num_step_message_passing=5),
    mpnn.Config(node_out_feats=64, edge_hidden_feats=32, num_step_message_passing=3),
    mpnn.Config(node_out_feats=128, edge_hidden_feats=64, num_step_message_passing=1),
]
# Each key appears exactly once: a repeated key in a dict literal silently
# keeps only its last value, which hides edits to the earlier occurrence.
config = {
    "mpnn_configs": mpnn_configs,
    "do_two_stage": True,
    "do_edge_update": True,
    "embedding_dim_x": 32,
    "embedding_dim_edge_attr": 64,
    "num_sabs": 8,
    "dropout": 0.1,
    "heads": 8,
    "warmup": .05,
    "lr": 1e-3,
    "weight_decay": .01,
    # NOTE(review): beta1=0.99 is well above the common 0.9 default; confirm intended.
    "betas": (.99, .999),
    "act_mode": "silu",
    "aggr_mode": "mean",
}
ex_model = encoder.Encoder(graph_tokenizer=None, **config)

# Bind the example explicitly instead of relying on an earlier cell's leftover state.
exmpl = all_data[0]["graph1"]
ex_model(exmpl)

tensor([[ 18.2376,  12.8071,  -5.5295,  16.5835,  18.4512,   4.5106,  18.9346,
           5.3940,   5.7368,  -4.6714,  -8.5467,  29.2297,   0.4495,  18.6136,
           7.4170,   3.7315,  31.7972,   3.0480,  11.4155,   4.5478,   7.8203,
           4.8087,  19.6080,  -9.3006,   5.8667,  -8.6544,   4.4850,   9.0722,
          22.7579,   1.4889,  -0.8683,  26.1983,  28.2484,   6.8676,   1.8244,
          13.4864,  17.5423,  16.9630,   0.1558,   9.0989,  32.4871,  30.0749,
           3.7551,   3.4496,  14.7507,  35.4850,  45.2909,  18.2363,  -4.7399,
          20.7594,  -0.4542,   2.3403,  29.2756,  17.5551,   9.3151,   8.3835,
          21.2074,   5.4204,   3.1834,   1.5830,  10.1714,  20.9592,   5.7934,
           2.1979,   3.4059,   1.6745,  10.9144,  22.0496,  43.2983,   9.6749,
           2.0627,  27.1440,  -4.7370,  26.5000,  13.2379,  10.7729,   0.1793,
          -2.9773,  -7.3232,  21.0099,  29.1560,  32.1109,   4.8592,   7.1054,
          14.8848,   6.1785,  16.9974,  19.2201,  10

In [10]:
mpnn_configs = [
    mpnn.Config(node_out_feats=64, edge_hidden_feats=32, num_step_message_passing=3),
]
config = dict(
    mpnn_configs=mpnn_configs,
    do_two_stage=False,
    embedding_dim_x=32,
    embedding_dim_edge_attr=64,
    do_edge_update=False,
    num_sabs=8,
    dropout=0.1,
    heads=8,
    warmup=.05,
    lr=1e-3,
    weight_decay=.01,
    betas=(.99, .999),
    act_mode="gelu",
    aggr_mode="max",
)
# This encoder is constructed with graph_tokenizer and called on a tokenized graph.
ex_model = encoder.Encoder(graph_tokenizer=graph_tokenizer, **config)
exmpl_tokenized_graph = graph_tokenizer.tokenize(all_data[0]["graph1"])
ex_model(exmpl_tokenized_graph)

tensor([[ 1.3085, -0.6367,  3.6747,  2.3550,  0.2019,  0.1382,  1.9521,  2.8015,
          3.1114,  2.4208,  1.3376,  0.5894,  2.9270, -0.7420,  2.3614,  0.5792,
          0.2500,  1.0394,  0.5326,  0.2194,  2.6679, -0.6281,  1.6159, -0.4010,
          2.0637,  2.2164,  3.0328,  0.8726,  0.1433,  2.5497,  0.8665,  3.3314,
          2.1336,  1.2047,  0.8032, -0.9606,  2.7784, -0.0494, -0.2224,  0.1641,
          0.7583,  1.1312,  2.1103,  2.1527,  2.4876,  1.7718,  1.5835,  1.3554,
          1.6235,  1.0931,  1.5464, -0.2125,  1.2835,  4.1847,  1.9146,  1.0801,
          1.0353,  4.0354,  1.0553,  1.0515,  4.6782,  1.6404,  1.4699,  0.9885]],
       grad_fn=<ViewBackward0>)

In [11]:
class CrossEncoder(torch.nn.Module):
    """Score a pair of graphs using one shared encoder for both inputs.

    Both graphs are embedded by the same ``encoder`` module. The pair score is
    either the cosine similarity of the two embeddings or a learned linear
    readout over their concatenation, passed through a sigmoid.

    Args:
        encoder: module mapping a graph to an embedding tensor. When
            ``do_cosine_similarity`` is False it must expose
            ``encoder.readout.in_channels``, used as the embedding width.
        do_cosine_similarity: if True, ``forward`` returns the cosine
            similarity of the embeddings (range [-1, 1]) and no trainable
            readout is created. Otherwise it returns a sigmoid score in (0, 1).
        **kwargs: accepted for call-site compatibility; currently ignored.
    """

    def __init__(self, encoder, do_cosine_similarity, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.do_cosine_similarity = do_cosine_similarity
        if not self.do_cosine_similarity:
            # NOTE(review): this assumes the encoder's output width equals
            # readout.in_channels. If the encoder's readout changes the width,
            # this Linear will not match the concatenated embeddings; confirm
            # in encoder.py.
            embedding_dim = self.encoder.readout.in_channels
            self.readout = torch.nn.Linear(embedding_dim * 2, 1)

    def forward(self, graph1, graph2):
        """Return the pair score(s) for ``graph1`` and ``graph2``.

        Cosine path: ``cosine_similarity`` along dim 1 (one value per row for
        (batch, dim) embeddings). Readout path: sigmoid of the linear readout,
        with the trailing size-1 dimension squeezed away.
        NOTE(review): the targets seen in this notebook (e.g. 0.5769) look like
        [0, 1] scores, while the cosine path can be negative; confirm the loss
        used with it.
        """
        embed1 = self.encoder(graph1)
        embed2 = self.encoder(graph2)

        if self.do_cosine_similarity:
            return torch.nn.functional.cosine_similarity(embed1, embed2)

        logits = self.readout(torch.cat([embed1, embed2], dim=-1))
        # torch.sigmoid is the canonical op; torch.nn.functional.sigmoid is a legacy alias.
        return torch.sigmoid(logits).squeeze(dim=-1)

# Readout-head variant (not cosine similarity) over the most recent ex_model.
m = CrossEncoder(ex_model, do_cosine_similarity=False)