In [1]:
import tokenizer
import torch

# Load the object stored in "dictionary.pt" and hand it to the tokenizer.
# NOTE(review): torch.load relies on pickle, so this file must come from a
# trusted source. Presumably it holds the token vocabulary; confirm against
# tokenizer.GraphTokenizer.
token_dictionary = torch.load("dictionary.pt")
graph_tokenizer = tokenizer.GraphTokenizer(token_dictionary)

In [2]:
import h5py
import tqdm
import numpy as np
import torch
import data
import torch_geometric as tg

# Read every top-level group of the training file into memory. Each record
# keeps the two graphs of the group plus its scalar `y` value as a tensor.
all_data = []
with h5py.File('Data/train.h5', 'r') as f:
    for label, group in tqdm.tqdm(f.items()):
        all_data.append(
            {
                "graph1": data.read_graph(group['graph1']),
                "graph2": data.read_graph(group['graph2']),
                # `[()]` reads the full value of a scalar dataset.
                "y": torch.tensor(group["y"][()]),
            }
        )

# Show the first record as a quick sanity check.
all_data[0]

100%|█████████████████████████████████████| 500/500 [00:00<00:00, 1108.28it/s]


{'graph1': BlendData(x=[254, 9], edge_index=[2, 476], edge_attr=[476, 3], blend_batch=[28], mol_batch=[254]),
 'graph2': BlendData(x=[239, 9], edge_index=[2, 452], edge_attr=[452, 3], blend_batch=[28], mol_batch=[239]),
 'y': tensor(0.5769)}

In [3]:
# Tokenize the first graph of the first record and display the result.
first_graph = all_data[0]["graph1"]
graph_tokenizer.tokenize(first_graph)

BlendData(x=[254], edge_index=[2, 476], edge_attr=[476], blend_batch=[28], mol_batch=[254])

In [4]:
import aggregate
from torch_geometric.loader import DataLoader

# Shape check for BlendAggregator (first constructor argument True): run it
# once on a collated batch of two graphs and once on a single graph.
# NOTE(review): the meaning of the positional arguments (True, 9, 1, 1, 0) is
# defined in aggregate.py and cannot be confirmed from this notebook.
agg = aggregate.BlendAggregator(True, 9, 1, 1, 0)
pair = all_data[0]
batch = next(iter(DataLoader([pair["graph1"], pair["graph2"]], batch_size=2)))
for blend in (batch, pair["graph1"]):
    print(agg(blend.x, blend).shape)

torch.Size([2, 9])
torch.Size([1, 9])


In [5]:
import aggregate
from torch_geometric.loader import DataLoader

# Same shape check as the previous cell, with the first constructor argument
# of BlendAggregator set to False instead of True.
agg = aggregate.BlendAggregator(False, 9, 1, 1, 0)
graph_a, graph_b = (all_data[0][key] for key in ("graph1", "graph2"))
pair_loader = DataLoader([graph_a, graph_b], batch_size=2)
batch = next(iter(pair_loader))
# Collated two-graph batch first, then the single un-batched graph.
print(agg(batch.x, batch).shape)
print(agg(graph_a.x, graph_a).shape)

torch.Size([2, 9])
torch.Size([1, 9])


In [6]:
import mpnn

# Build a small MPNN from a Config and run it once on an example graph.
# The call returns a pair of tensors (see the output below); presumably
# updated node and edge features — confirm against mpnn.py.
exmpl = all_data[0]["graph1"]
config = mpnn.Config(
    node_out_feats=16,
    edge_hidden_feats=16,
    num_step_message_passing=3,
)
model = mpnn.from_config(
    config,
    node_in_feats=9,
    edge_in_feats=3,
    dropout=.1,
    do_edge_update=True,
)
model(exmpl, exmpl.x, exmpl.edge_attr)

(tensor([[-0.0061,  0.0380, -0.2471,  ..., -0.3648,  0.1038,  0.2279],
         [-0.2348, -0.0865, -0.1775,  ..., -0.3918,  0.0007, -0.0456],
         [-0.1419, -0.0472, -0.2339,  ..., -0.4211,  0.0457,  0.1520],
         ...,
         [-0.0000, -0.0282, -0.0769,  ..., -0.3449,  0.2760, -0.0175],
         [-0.1764, -0.1300, -0.0935,  ..., -0.4022,  0.1159,  0.0821],
         [ 0.0129,  0.0000, -0.2097,  ..., -0.3474,  0.1487,  0.2154]],
        grad_fn=<MulBackward0>),
 tensor([[-0.0000, -0.2004,  0.2614,  ...,  0.1953, -0.0359,  0.0978],
         [-0.0358, -0.1888,  0.0000,  ...,  0.1675, -0.0655,  0.0000],
         [-0.0336, -0.1829,  0.2747,  ...,  0.1635, -0.0626,  0.0949],
         ...,
         [-0.0365, -0.2037,  0.2410,  ...,  0.1521, -0.0518,  0.0906],
         [-0.0170, -0.2200,  0.2573,  ...,  0.1755, -0.0749,  0.0433],
         [-0.0000, -0.2078,  0.2550,  ...,  0.1729, -0.0000,  0.0830]],
        grad_fn=<MulBackward0>))

In [7]:
import utils

# Display the counts that utils.readout_counts reports for `model`
# (presumably parameter counts per sub-module; confirm in utils.py).
model_counts = utils.readout_counts(model)
model_counts

{'total': '12,224',
 'base': '6,480',
 'project_edge_feats': '64',
 'edge_update_network': '1,056',
 'gnn_layer': '4,624'}

In [8]:
import encoder
import torch

# Three mpnn.Config entries handed to the encoder as `mpnn_configs`.
mpnn_configs = [
    mpnn.Config(node_out_feats=16, edge_hidden_feats=8, num_step_message_passing=5),
    mpnn.Config(node_out_feats=64, edge_hidden_feats=32, num_step_message_passing=3),
    mpnn.Config(node_out_feats=128, edge_hidden_feats=64, num_step_message_passing=1),
]

# Keyword arguments for encoder.Encoder. Every key appears exactly once: a
# repeated key in a dict literal is silently collapsed to its last value,
# which hides mistakes when only one of the copies is edited.
config = {
    "mpnn_configs": mpnn_configs,
    "do_two_stage": True,
    "do_edge_update": True,
    "embedding_dim_x": 32,
    "embedding_dim_edge_attr": 64,
    "num_sabs": 8,
    "dropout": 0.1,
    "heads": 8,
    "warmup": .05,
    "lr": 1e-3,
    "weight_decay": .01,
    "betas": (.99, .999),
}

# No tokenizer is supplied here, and the encoder is run on `exmpl`, the
# untokenized example graph defined in an earlier cell.
ex_model = encoder.Encoder(graph_tokenizer=None, **config)
ex_model(exmpl)

tensor([[  1.7581,  -0.0954,  22.0972,   1.3918,   9.3610,   8.2376,   3.7030,
          -7.3447,   1.8093,  24.4488,  22.8326,  25.0683,  10.9185,  15.5346,
          25.4530,  -0.4636,  22.9850,  11.8873,   6.6919,  11.7493,  12.8346,
          25.8173,  14.5051,  -3.9931,   4.9809,  -0.6271,  24.2211,  -1.9439,
           5.2912,  11.0629,  -9.5516,  17.9782,   3.2383,   7.0570,  28.8752,
           9.3216,  -0.3137,  28.7258,  12.2864,   1.6070,  20.6726,  10.6713,
          12.6482,   3.1562,  26.2624,  22.5475,  16.7776,  24.5206,  11.5999,
           4.4933,  -3.4439,  17.9020,  11.0727,  25.0334,  26.2841,  -5.3051,
          11.9627,  10.0281,  13.1333,  13.1051,  15.6307,   8.1171,   9.0696,
          -2.6088,  -2.1163,  31.6892,   8.3284,  -0.5623,   5.5673,   9.2366,
           5.8159,   3.9972,  25.4645,  29.3814,  23.7107,   3.1713,  14.8862,
          13.1169,  11.0825,   9.8003,   9.1494,   7.2493,   5.1378,  12.6467,
          33.3584,   2.6590,   3.1000,   3.9050,  21

In [9]:
# Single-config encoder variant with two-stage mode and edge updates switched
# off, built with the tokenizer and run on a tokenized example graph.
mpnn_configs = [
    mpnn.Config(
        node_out_feats=64,
        edge_hidden_feats=32,
        num_step_message_passing=3,
    )
]
config = dict(
    mpnn_configs=mpnn_configs,
    do_two_stage=False,
    embedding_dim_x=32,
    embedding_dim_edge_attr=64,
    do_edge_update=False,
    num_sabs=8,
    dropout=0.1,
    heads=8,
    warmup=.05,
    lr=1e-3,
    weight_decay=.01,
    betas=(.99, .999),
)
ex_model = encoder.Encoder(graph_tokenizer=graph_tokenizer, **config)
exmpl_tokenized_graph = graph_tokenizer.tokenize(all_data[0]["graph1"])
ex_model(exmpl_tokenized_graph)

tensor([[ 0.7817,  2.3664,  2.5107,  1.4718,  0.1293,  0.3854,  1.5666,  2.9225,
          0.3086,  4.7423,  2.5347,  3.4020, -0.0377,  0.0152,  0.8125,  5.3271,
          0.4485, -0.2605,  1.5422,  2.2646,  3.4002,  0.6419,  0.4884,  2.3693,
          0.1999,  1.1988, -0.4371,  3.1834,  1.1290,  0.2870,  2.7848, -0.9172,
          1.6293,  2.7480,  0.7802,  0.6460,  1.4231,  0.9999,  1.5083,  0.1930,
          3.3473, -1.7230,  2.3767,  1.5664,  1.5493,  1.2255, -0.0850,  1.4479,
          1.3188,  1.2619,  2.5797,  0.9492,  2.9994,  0.3161, -0.2414,  1.8305,
          1.5627,  0.2711,  3.0488,  1.7639, -0.5238, -0.9707,  1.1103,  1.0575]],
       grad_fn=<ViewBackward0>)

In [10]:
class CrossEncoder(torch.nn.Module):
    """Score a pair of graphs using one shared encoder.

    Both graphs are embedded with the same ``encoder``. The pair is then
    scored either by the cosine similarity of the two embeddings or by a
    linear layer followed by a sigmoid over their concatenation.
    """

    def __init__(self, encoder, do_cosine_similarity, **kwargs):
        """Store the encoder and, if needed, create the pair readout layer.

        Args:
            encoder: Module used to embed each graph. When
                ``do_cosine_similarity`` is false it must expose
                ``readout.in_channels``.
            do_cosine_similarity: If true, pairs are scored by cosine
                similarity and no ``readout`` layer is created.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.encoder = encoder
        self.do_cosine_similarity = do_cosine_similarity
        if not do_cosine_similarity:
            # NOTE(review): this assumes each embedding has width
            # encoder.readout.in_channels — confirm against encoder.Encoder.
            pair_width = self.encoder.readout.in_channels * 2
            self.readout = torch.nn.Linear(pair_width, 1)

    def forward(self, graph1, graph2):
        """Return one score per pair for ``graph1`` and ``graph2``.

        The encoder is applied to ``graph1`` first, then to ``graph2``.
        """
        first, second = (self.encoder(graph) for graph in (graph1, graph2))

        if not self.do_cosine_similarity:
            # Concatenate along the feature axis, map to one logit, squash
            # with a sigmoid, and drop the trailing size-1 dimension.
            paired = torch.cat((first, second), dim=-1)
            return torch.sigmoid(self.readout(paired)).squeeze(dim=-1)

        return torch.nn.functional.cosine_similarity(first, second)

# Wrap the most recently built encoder; with cosine similarity disabled the
# pair score comes from the linear readout followed by a sigmoid.
m = CrossEncoder(ex_model, do_cosine_similarity=False)