In [35]:
import json
import h5py
import base64
import tqdm
import torch
import torch_geometric as tg
import numpy as np

def decode_smiles(encoded):
    return base64.urlsafe_b64decode(encoded.encode()).decode()

dataset = dict()
with h5py.File('Data/dataset.h5', 'r') as f:
    for i, encoded_pair in enumerate(tqdm.tqdm(f.keys())):
        pair_str = decode_smiles(encoded_pair)
        pair = json.loads(pair_str)
        
        group = f[encoded_pair]
        
        positives = group['positives'][:].astype(str).tolist()
        negatives = group['negatives'][:].astype(str).tolist()
        
        graph_group = group['graph']
        graph_data = {k: torch.tensor(np.array(v)) for k, v in graph_group.items()}
        graph_data = {k: v.float() if k != "edge_index" else v.long() for k, v in graph_data.items()}
        
        dataset[tuple(pair)] = {"positives":positives, "negatives":negatives, "graph": tg.data.Data(**graph_data)}
        
        if i > 100:
            break

  0%|                                    | 101/166733 [00:05<2:30:44, 18.42it/s]


In [74]:
example = next(iter(dataset.values()))
example

{'positives': [['CC(C)C1=CC=C(C=C1)CC=O', 'CC(C)CC1=NC=CS1'],
  ['CC(=CCC/C(=C/COCC=O)/C)C', 'CCCCCCOC(=O)C(=CC)C'],
  ['CC(=CC/C=C(/C)\\C=C)C', 'CC1=CCOC(C1)C=C(C)C'],
  ['CCCCCCC/C=C/C=C/CO', 'CCCCCCOC(=O)CCC'],
  ['CC(CC(C)(C)C1=CC=CC=C1)O', 'CCCCCCCCOCC=O'],
  ['CCCCCCC(OC)OC', 'CCCCCCC/C=C/C=O'],
  ['CC/C=C\\CCOCC(=C)C', 'CC1CC(=C)CC(O1)C2=CC=CC=C2'],
  ['CC1=C(SC(=N1)C)C', 'CCCCCC=CC=CC=C'],
  ['CC1(CC[C@@H]([C@](O1)(C)C=C)O)C', 'CCCCCC(/C(=C/CC)/C)O'],
  ['CC/C=C\\CCOC(=O)C(C)CC', 'CCOC1CC(CC(C1)(C)C)C'],
  ['C=CCOC(=O)COC1CCCCC1', 'CCCCCC(=O)OCC(C)C'],
  ['CC(=CCC/C(=C\\CCC(=C)C=C)/C)C', 'CCCCCCOC(C)OCC'],
  ['CC(C)OCCC1=CC=CC=C1', 'CC(CCCC(=C)C)CCOC=O'],
  ['CC/C=C\\CC/C=C/C=O', 'CC\\C=C/CCOC(C)OCC\\C=C/CC'],
  ['CC(C)C(=O)OCCOC1=CC=CC=C1', 'CC=CC(=O)C1=C(CCCC1(C)C)C'],
  ['CCCCCCC(OC)OC', 'CCOC(C)OC(=O)C'],
  ['C1=CC=C(C=C1)C(C2=CC=CC=C2)O', 'CCOC(=O)\\C=C(/C)\\CCC=C(C)C'],
  ['CC1=CCC(CC1)C(C)(C)OC', 'CC1=CCOC(C1)C=C(C)C'],
  ['CCCCCCC(OC)OC', 'CCCCCCCCCC(=O)O'],
  ['C1CC(OC

In [86]:
import torch.nn as nn
import torch.nn.functional as F

# Shamelessly stolen from (converted to PytorchGeometric)
# https://lifesci.dgl.ai/_modules/dgllife/model/gnn/mpnn.html
class MPNNGNN(nn.Module):
    """MPNN.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN and returns the updated node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_in_feats : int
        Size for the input edge features. Default to 128.
    edge_hidden_feats : int
        Size for the hidden edge representations.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()

        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = tg.nn.conv.NNConv(
            in_channels=node_out_feats,
            out_channels=node_out_feats,
            nn=edge_network,
            aggregator_type='sum'
        )
        self.gru = nn.GRU(node_out_feats, node_out_feats)


    def forward(self, graph):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = graph.x
        edge_feats = graph.edge_attr
        node_feats = self.project_node_feats(node_feats) # (V, node_out_feats)
        hidden_feats = node_feats.unsqueeze(0)           # (1, V, node_out_feats)

        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(node_feats, graph.edge_index, edge_feats))
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = node_feats.squeeze(0)

        return node_feats


class Encoder(torch.nn.Module):
    def __init__(self,node_out_feats=64,edge_hidden_feats=128,num_step_message_passing=5):
        super(Encoder, self).__init__()
        self.conv = MPNNGNN(node_in_feats=9,edge_in_feats=3,node_out_feats=node_out_feats,edge_hidden_feats=edge_hidden_feats,num_step_message_passing=num_step_message_passing)
        # https://github.com/davidbuterez/multi-fidelity-gnns-for-drug-discovery-and-quantum-mechanics/blob/3f39d12b66447f62960bf9e4b45070b266328555/schnet_multiple_fidelities/schnet_high_fidelity.py#L159
        self.readout = tg.nn.aggr.set_transformer.SetTransformerAggregation(node_out_feats,heads=8,num_encoder_blocks=2,num_decoder_blocks=2)

    def forward(self,graph):
        return self.readout(self.conv(graph))
    
# model = Encoder()
# model(example["graph"])

model = Encoder()
# Forward pass through the NNConv layer

def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


model(example["graph"]).shape, "{:,}".format(count_parameters(model))

(torch.Size([1, 64]), '666,880')