In [1]:
import json
import h5py
import base64
import tqdm
import torch
import torch_geometric as tg
import numpy as np

def decode_smiles(encoded):
    return base64.urlsafe_b64decode(encoded.encode()).decode()

dataset = dict()
with h5py.File('Data/dataset.h5', 'r') as f:
    for i, encoded_pair in enumerate(tqdm.tqdm(f.keys())):
        pair_str = decode_smiles(encoded_pair)
        pair = json.loads(pair_str)
        
        group = f[encoded_pair]
        
        positives = group['positives'][:].astype(str).tolist()
        negatives = group['negatives'][:].astype(str).tolist()
        
        graph_group = group['graph']
        graph_data = {k: torch.tensor(np.array(v)) for k, v in graph_group.items()}
        graph_data = {k: v.float() if k != "edge_index" else v.long() for k, v in graph_data.items()}
        
        dataset[tuple(pair)] = {"positives":positives, "negatives":negatives, "graph": tg.data.Data(**graph_data)}
        
        if i > 10000:
            break

  6%|██                                 | 10001/166733 [00:15<04:08, 631.97it/s]


In [2]:
# The ordering of the graphs in this is arbitrary (alphabetically based on SMILES)
# but the logistic predictor does rely on this ordering.
class PairData(tg.data.Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index_anchor':
            return self.x_anchor.size(0)
        if key == 'edge_index_positive':
            return self.x_positive.size(0)
        return super().__inc__(key, value, *args, **kwargs)

    def __cat_dim__(self, key, value, *args, **kwargs):
        if key == 'y':
            return None
        return super().__cat_dim__(key, value, *args, **kwargs)
    
def make_pair_data(anchor_graph,positive_graph):
    return PairData(x_anchor=anchor_graph["x"].float(),
                  edge_attr_anchor=anchor_graph["edge_attr"].float(),
                  edge_index_anchor=anchor_graph["edge_index"],
                  x_positive=positive_graph["x"].float(),
                  edge_attr_positive=positive_graph["edge_attr"].float(),
                  edge_index_positive=positive_graph["edge_index"])

In [5]:
import random

def get_in_dataset(pair,key):
    positive = ""
    i = 0
    # TODO: Remove this once we load in all data
    while not positive in dataset:
        positive = tuple(random.choice(dataset[pair][key]))
        i += 1
        if i > 10000:
            print("FAILED")
            break
    positive_graph = dataset[positive]["graph"]
    return positive_graph

In [6]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.data import Dataset

def triplet_graph_generator(to_yield,num_negatives):
    for i, (pair, data) in enumerate(dataset.items()):
        if i >= to_yield:
            break

        anchor = data["graph"]
        
        positive = get_in_dataset(pair,"positives")
        
        negatives = [get_in_dataset(pair,"negatives") for _ in range(num_negatives)]
        
        yield anchor, positive, negatives
        

class TripletGraphDataset(Dataset):
    def __init__(self, generator_func, to_yield, num_negatives):
        super(TripletGraphDataset, self).__init__()
        self.generator_func = generator_func
        self.data_list = list(generator_func(to_yield, num_negatives))  # Converting generator to list for len() support

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        anchor, positive, negatives = self.data_list[idx]
        return anchor, positive, negatives


def triplet_collate_fn(batch):
    anchors, positives, negatives = [], [], []
    for anchor, positive, neg_list in batch:
        anchors.append(anchor)
        positives.append(positive)
        negatives.extend(neg_list)
    return Batch.from_data_list(anchors), Batch.from_data_list(positives), Batch.from_data_list(negatives)

# Need to switch the number of negatives over time (this will mean lowering bsz)
# Create an instance of the custom dataset
triplet_dataset = TripletGraphDataset(triplet_graph_generator,16,0)

# Create a DataLoader with a custom collate function
dataloader = DataLoader(triplet_dataset, batch_size=4, shuffle=True, collate_fn=triplet_collate_fn)

In [12]:
def count_parameters(module):
    return "{:,}".format(sum(p.numel() for p in module.parameters()))

def readout_counts(module):
    results = {"total":count_parameters(module)}
    for n, c in module.named_children():
        results[n] = count_parameters(c)
    return results
readout_counts(model.conv)

{'total': '2,230,912',
 'project_node_feats': '1,280',
 'gnn_layer': '2,130,560',
 'gru': '99,072',
 'final_dropout': '0'}

In [13]:
import torch.nn as nn
import torch.nn.functional as F

# Shamelessly stolen from (converted to PytorchGeometric)
# https://lifesci.dgl.ai/_modules/dgllife/model/gnn/mpnn.html
class MPNNGNN(nn.Module):
    """MPNN.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN and returns the updated node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_in_feats : int
        Size for the input edge features. Default to 128.
    edge_hidden_feats : int
        Size for the hidden edge representations.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6, dropout=0.1):
        super(MPNNGNN, self).__init__()

        # This should be changed to node wise dropout. But maybe not?
        # See https://arxiv.org/pdf/1411.4280        
        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.num_step_message_passing = num_step_message_passing
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(), # Could add dropout after this.
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats),
            nn.Dropout(dropout) # This one is after the largest by far.
        )

        self.gnn_layer = tg.nn.conv.NNConv(
            in_channels=node_out_feats,
            out_channels=node_out_feats,
            nn=edge_network,
            aggregator_type='sum'
        )

        # If we add a second layer, we could add dropout.
        self.gru = nn.GRU(node_out_feats, node_out_feats,bidirectional=False)
        self.final_dropout = nn.Dropout(dropout)


    def forward(self, graph):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = graph.x
        edge_feats = graph.edge_attr
        node_feats = self.project_node_feats(node_feats) # (V, node_out_feats)
        hidden_feats = node_feats.unsqueeze(0)           # (1, V, node_out_feats)

        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(node_feats, graph.edge_index, edge_feats))
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = self.final_dropout(node_feats.squeeze(0))

        return node_feats


class Encoder(torch.nn.Module):
    def __init__(self,node_out_feats=128,edge_hidden_feats=128,num_step_message_passing=5,dropout=0.1):
        super(Encoder, self).__init__()
        self.conv = MPNNGNN(node_in_feats=9,edge_in_feats=3,node_out_feats=node_out_feats,edge_hidden_feats=edge_hidden_feats,num_step_message_passing=num_step_message_passing,dropout=dropout)
        # https://github.com/davidbuterez/multi-fidelity-gnns-for-drug-discovery-and-quantum-mechanics/blob/3f39d12b66447f62960bf9e4b45070b266328555/schnet_multiple_fidelities/schnet_high_fidelity.py#L159
        self.readout = tg.nn.aggr.set_transformer.SetTransformerAggregation(node_out_feats,heads=8,num_encoder_blocks=2,num_decoder_blocks=2,dropout=dropout)

    def forward(self,graph):
        x = self.conv(graph)
        if "batch" in graph:
            return self.readout(x,graph.batch)
        return self.readout(x)
    
    def count_parameters(self):
        return {"total":count_parameters(self), "conv":count_parameters(self.conv), "readout":count_parameters(self.readout)}
    
# model = Encoder()
# model(example["graph"])

model = Encoder()
# Forward pass through the NNConv layer



example = next(iter(dataset.values()))
print(model(example["graph"]).shape)
print(readout_counts(model))
print(readout_counts(model.conv))


torch.Size([1, 128])
{'total': '2,660,352', 'conv': '2,230,912', 'readout': '429,440'}
{'total': '2,230,912', 'project_node_feats': '1,280', 'gnn_layer': '2,130,560', 'gru': '99,072', 'final_dropout': '0'}


In [14]:
import info_nce
loss_fn = info_nce.InfoNCE()
# Training loop
for anchors, positives, negatives in dataloader:
    anchor_embeds = model(anchors)
    positives_embeds = model(positives)
    loss = loss_fn(anchor_embeds,positives_embeds)
