Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
91 lines (81 sloc) 3.27 KB
#!/usr/bin/env python
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch.nn as nn
import torch.nn.functional as F
from ...contrib.deprecation import deprecated
from ...nn.pytorch import Set2Set, NNConv
class MPNNModel(nn.Module):
    """
    MPNN from
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__

    The model projects the input node features to ``node_hidden_dim``, runs
    ``num_step_message_passing`` rounds of ``NNConv`` message passing whose
    results are fed through a GRU, reads the node states out with ``Set2Set``
    and finally applies two linear layers to produce the prediction.

    Parameters
    ----------
    node_input_dim : int
        Dimension of input node feature, default to be 15.
    edge_input_dim : int
        Dimension of input edge feature, default to be 5.
    output_dim : int
        Dimension of prediction, default to be 12.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 64.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 128.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    num_step_set2set : int
        Number of set2set steps, default to be 6.
    num_layer_set2set : int
        Number of set2set layers, default to be 3.
    """
    @deprecated('Import MPNNPredictor from dgllife.model instead.', 'class')
    def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
                 output_dim=12,
                 node_hidden_dim=64,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        super(MPNNModel, self).__init__()

        self.num_step_message_passing = num_step_message_passing

        # Input projection: node_input_dim -> node_hidden_dim.
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)

        # Edge network handed to NNConv as ``edge_func``. For every edge it
        # emits node_hidden_dim * node_hidden_dim values, i.e. enough entries
        # for one (in_feats x out_feats) weight matrix per edge.
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(in_feats=node_hidden_dim,
                           out_feats=node_hidden_dim,
                           edge_func=edge_network,
                           aggregator_type='sum')

        # The GRU updates node states from the aggregated messages.
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        # Readout followed by a two-layer prediction head. ``lin1`` takes
        # 2 * node_hidden_dim inputs, which is the width it expects from the
        # Set2Set readout.
        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)

    def forward(self, g, n_feat, e_feat):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        n_feat : tensor of dtype float32 and shape (B1, D1)
            Node features. B1 for number of nodes and D1 for
            the node feature size.
        e_feat : tensor of dtype float32 and shape (B2, D2)
            Edge features. B2 for number of edges and D2 for
            the edge feature size.

        Returns
        -------
        res : Predicted labels
            Output of the final linear layer; its last dimension has size
            ``output_dim``.
        """
        out = F.relu(self.lin0(n_feat))                 # (B1, H1)
        # Initial GRU hidden state: the projected node features with a
        # leading axis of size 1 added.
        h = out.unsqueeze(0)                            # (1, B1, H1)

        # The step index itself is not needed, only the number of rounds.
        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, e_feat))       # (B1, H1)
            # Messages are fed to the GRU as a length-1 sequence; ``h`` is
            # carried over to the next round.
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        # Readout over the nodes, then the prediction head.
        out = self.set2set(g, out)
        out = F.relu(self.lin1(out))
        out = self.lin2(out)
        return out
You can’t perform that action at this time.