In [35]:
import json
import h5py
import base64
import tqdm
import torch
import torch_geometric as tg
import numpy as np

def decode_smiles(encoded):
    """Decode a URL-safe base64 key back into its original text.

    The HDF5 group keys in this dataset are URL-safe base64 encodings of
    JSON strings; despite its name, this helper returns that decoded text
    (which the loader then parses with ``json.loads``).

    Args:
        encoded: URL-safe base64 text.

    Returns:
        The decoded payload as a ``str`` (UTF-8).
    """
    raw_bytes = base64.urlsafe_b64decode(encoded.encode())
    return raw_bytes.decode()

# Load a subset of the HDF5 file into memory as
#   {tuple(pair): {"positives": [...], "negatives": [...], "graph": tg.data.Data}}.
# Each top-level HDF5 key is a URL-safe base64 encoding of a JSON-encoded pair.
MAX_PAIRS = 100  # number of pairs to load for quick iteration; set to None to load all


def _read_str_list(dset):
    """Read an HDF5 string dataset as a list of Python ``str``.

    With h5py >= 3, variable-length strings are returned as ``bytes``
    objects, and ``.astype(str)`` on those yields ``"b'...'"`` literals, so
    string-typed datasets are decoded via ``.asstr()`` (which handles both
    fixed- and variable-length strings). Any other dtype keeps the plain
    ``.astype(str)`` conversion.
    """
    if h5py.check_string_dtype(dset.dtype) is not None:
        return dset.asstr()[:].tolist()
    return dset[:].astype(str).tolist()


dataset = dict()
with h5py.File('Data/dataset.h5', 'r') as f:
    for i, encoded_pair in enumerate(tqdm.tqdm(f.keys())):
        # Check the limit before reading anything so exactly MAX_PAIRS entries are loaded.
        if MAX_PAIRS is not None and i >= MAX_PAIRS:
            break

        pair = json.loads(decode_smiles(encoded_pair))
        group = f[encoded_pair]

        positives = _read_str_list(group['positives'])
        negatives = _read_str_list(group['negatives'])

        # Every array stored under 'graph' becomes a keyword argument of
        # tg.data.Data; edge_index must be integer (long), the rest float32.
        graph_data = {}
        for k, v in group['graph'].items():
            tensor = torch.tensor(np.array(v))
            graph_data[k] = tensor.long() if k == "edge_index" else tensor.float()

        dataset[tuple(pair)] = {"positives": positives, "negatives": negatives, "graph": tg.data.Data(**graph_data)}

  0%|                                    | 101/166733 [00:05<2:30:44, 18.42it/s]


In [47]:
import dgl

# Take the first loaded record and convert its PyG graph into a DGL graph.
example = next(iter(dataset.values()))
pyg_data = example["graph"]

# PyG stores edges as a (2, num_edges) edge_index; unpack into source and
# destination node-id tensors for DGL.
src, dst = pyg_data.edge_index

# Pass num_nodes explicitly: otherwise DGL infers it as max(node id) + 1, so
# a molecule whose highest-numbered atom has no bonds would get too few
# nodes and the ndata assignment below would fail on a size mismatch.
dgl_graph = dgl.graph((src, dst), num_nodes=pyg_data.num_nodes)

# Copy node and edge features under the "feat" key that Encoder.forward reads.
dgl_graph.ndata['feat'] = pyg_data.x
dgl_graph.edata['feat'] = pyg_data.edge_attr

In [54]:
from dgllife.model.gnn import MPNNGNN

class Encoder(torch.nn.Module):
    """Graph-level encoder: MPNN message passing followed by a Set Transformer readout.

    Node embeddings are computed by dgllife's ``MPNNGNN`` and pooled into one
    vector per graph by PyG's ``SetTransformerAggregation``.

    Args:
        node_out_feats: Width of the node embeddings and of the pooled graph
            embedding. Should be divisible by the 8 attention heads used in
            the readout.
        edge_hidden_feats: Hidden width of MPNNGNN's edge network.
        num_step_message_passing: Number of message-passing steps.
        node_in_feats: Width of ``g.ndata["feat"]`` (default 9, matching the
            example graphs in this notebook).
        edge_in_feats: Width of ``g.edata["feat"]`` (default 3).
    """

    def __init__(self, node_out_feats=64, edge_hidden_feats=128, num_step_message_passing=6,
                 node_in_feats=9, edge_in_feats=3):
        super().__init__()
        self.conv = MPNNGNN(
            node_in_feats=node_in_feats,
            edge_in_feats=edge_in_feats,
            node_out_feats=node_out_feats,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
        )
        # https://github.com/davidbuterez/multi-fidelity-gnns-for-drug-discovery-and-quantum-mechanics/blob/3f39d12b66447f62960bf9e4b45070b266328555/schnet_multiple_fidelities/schnet_high_fidelity.py#L159
        self.readout = tg.nn.aggr.SetTransformerAggregation(
            node_out_feats, heads=8, num_encoder_blocks=2, num_decoder_blocks=2
        )

    def forward(self, x):
        """Encode a DGL graph (single or built with ``dgl.batch``).

        Args:
            x: DGL graph with node features in ``ndata["feat"]`` and edge
                features in ``edata["feat"]``.

        Returns:
            Tensor of shape ``(x.batch_size, node_out_feats)`` — one embedding
            per graph (``(1, node_out_feats)`` for a single graph).
        """
        node_feats: torch.Tensor = x.ndata["feat"]
        edge_feats: torch.Tensor = x.edata["feat"]
        node_embeddings = self.conv(x, node_feats, edge_feats)

        # Assign each node to its graph so the readout pools per graph;
        # without an index every node of a batched graph would be pooled
        # into one vector. dgl.batch stores each graph's nodes contiguously,
        # so this index is sorted as the dense-batch readout expects.
        nodes_per_graph = x.batch_num_nodes()
        graph_index = torch.repeat_interleave(
            torch.arange(x.batch_size, device=nodes_per_graph.device),
            nodes_per_graph,
        )
        return self.readout(node_embeddings, index=graph_index, dim_size=x.batch_size)
    
# Smoke test: encode the single example DGL graph built above.
# Expected output shape with the default widths: (1, 64).
model = Encoder()

# Only the output shape is inspected, so skip building an autograd graph.
with torch.no_grad():
    graph_embedding = model(dgl_graph)

graph_embedding.shape

torch.Size([1, 64])

Data(x=[24, 9], edge_index=[2, 44], edge_attr=[44, 3])