In [5]:
from pathlib import Path

import numpy as np
import torch
import torch_geometric
from torch_geometric.nn import MessagePassing

from molecular_mpns.config import data_dir
from molecular_mpns.data import MarkovMolGraph

In [6]:
class EdgeFeature(torch.nn.Module):
    """Expand scalar edge attributes (distances) into learned per-edge filters.

    Each scalar ``d`` is projected onto ``n_rbf`` Gaussian radial basis
    functions ``exp(-gamma * (mu_k - d)**2)``. The centers ``mu_k`` are evenly
    spaced over ``rbf_range`` (inclusive). The result then passes through two
    Linear + SiLU layers.

    Args:
        n_rbf: number of Gaussian basis functions.
        rbf_range: ``(start, end)`` of the evenly spaced centers.
        gamma: Gaussian width parameter (larger means narrower basis functions).
        h_dim: output feature dimension.
    """

    def __init__(self, n_rbf, rbf_range, gamma, h_dim):
        super(EdgeFeature, self).__init__()

        self.register_buffer('centers', torch.linspace(rbf_range[0], rbf_range[1], n_rbf))
        # Keep gamma floating point so Module.double()/.float()/.to(dtype) cast it
        # together with `centers`. An int gamma would otherwise stay an int64 buffer.
        self.register_buffer('gamma', torch.as_tensor(gamma, dtype=torch.get_default_dtype()))
        self.dense1 = torch.nn.Linear(n_rbf, h_dim)
        self.dense2 = torch.nn.Linear(h_dim, h_dim)

    def forward(self, edge_attr):
        """Map edge scalars of shape (E, 1) or (E,) to features of shape (E, h_dim)."""
        # A 1-D input would broadcast elementwise against `centers` (an error, or a
        # silently wrong (n_rbf,) result when E == n_rbf). Add the feature axis
        # so that every edge gets its own row of basis values.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)

        center_dists = self.centers - edge_attr  # (E, n_rbf)
        a = torch.exp(-self.gamma * center_dists ** 2)
        a = torch.nn.functional.silu(self.dense1(a))
        a = torch.nn.functional.silu(self.dense2(a))
        return a
    
class CFConv(MessagePassing):
    """Continuous-filter convolution: sum over incoming edges of ``filter(edge) * x_j``.

    The per-edge filter comes from :class:`EdgeFeature` applied to the edge
    attributes. Messages are aggregated by summation (``aggr='add'``).

    Args:
        n_rbf, rbf_range, gamma: forwarded to :class:`EdgeFeature`.
        h_dim: filter width. It must match the feature width of ``x`` for the
            elementwise product in :meth:`message`.
    """

    def __init__(self, n_rbf, rbf_range, gamma, h_dim):
        super(CFConv, self).__init__(aggr='add')

        self.edge_features = EdgeFeature(n_rbf, rbf_range, gamma, h_dim)

    def forward(self, edge_index, edge_attr, x):
        """Return node features aggregated from neighbour messages.

        Args:
            edge_index: edge connectivity as accepted by ``MessagePassing.propagate``.
            edge_attr: one scalar per edge, aligned with the columns of ``edge_index``.
            x: node features, one row per node.
        """
        a = self.edge_features(edge_attr)
        return self.propagate(edge_index=edge_index, a=a, x=x)

    def message(self, x_j, a):
        # Only x_j is requested. Asking PyG for x_i too would gather a second,
        # unused per-edge feature tensor on every call.
        return a * x_j
    
class Interaction(torch.nn.Module):
    """Interaction block with a residual connection.

    Computes ``x + W3(silu(W2(cfconv(W1(x)))))``. Here W1, W2 and W3 are
    atom-wise ``Linear(dim, dim)`` layers, and ``cfconv`` is a
    :class:`CFConv` over the graph edges.
    """

    def __init__(self, n_rbf, rbf_range, gamma, dim):
        super(Interaction, self).__init__()
        # Creation order (and so RNG use during init) and attribute names
        # (so state_dict keys) match the layer order used in forward().
        self.atomwise1 = torch.nn.Linear(dim, dim)
        self.cfconv = CFConv(n_rbf, rbf_range, gamma, dim)
        self.atomwise2 = torch.nn.Linear(dim, dim)
        self.atomwise3 = torch.nn.Linear(dim, dim)

    def forward(self, edge_index, edge_attr, x):
        """Return updated node features with the same shape as ``x``."""
        convolved = self.cfconv(edge_index, edge_attr, self.atomwise1(x))
        activated = torch.nn.functional.silu(self.atomwise2(convolved))
        residual = self.atomwise3(activated)
        return x + residual
    
class ProtoNet(torch.nn.Module):
    """Scalar potential ``V`` of a molecular graph.

    Atom types ``Z`` are embedded and refined by three :class:`Interaction`
    blocks over distance-featurised edges. The result is passed through an
    atom-wise readout and summed over all atoms.

    Args:
        emb_dim: width of the atom embeddings and interaction features.
        intermediate_dim: hidden width of the atom-wise readout.
        n_rbf, rbf_range, gamma: radial-basis settings shared by all interactions.
        n_atom_types: number of rows in the embedding table, so valid ``Z`` values
            are ``0 .. n_atom_types - 1``. Defaults to 5 (the previous hard-coded size).
    """

    def __init__(self, emb_dim, intermediate_dim, n_rbf, rbf_range, gamma, n_atom_types=5):
        super(ProtoNet, self).__init__()

        self.emb_dim = emb_dim
        self.embedding = torch.nn.Embedding(n_atom_types, emb_dim)

        self.interaction1 = Interaction(n_rbf, rbf_range, gamma, emb_dim)
        self.interaction2 = Interaction(n_rbf, rbf_range, gamma, emb_dim)
        self.interaction3 = Interaction(n_rbf, rbf_range, gamma, emb_dim)

        self.atomwise1 = torch.nn.Linear(emb_dim, intermediate_dim)
        self.atomwise2 = torch.nn.Linear(intermediate_dim, 1)

    @staticmethod
    def _edge_distances(r, edge_index):
        """Return the Euclidean length of every edge in ``edge_index`` as an (E, 1) tensor.

        Distances are computed per edge, so they line up with ``edge_index``
        regardless of edge order or batching. The old flattened ``cdist(r, r)``
        only matched when the edges were all N*N pairs in row-major order.

        The result is twice differentiable. ``torch.cdist``'s small-input path
        has no double backward, which breaks training on
        ``autograd.grad(V, r, create_graph=True)``.

        Coincident endpoints (e.g. self-loops) get distance 0 and zero
        gradients of every order. The double ``where`` keeps ``sqrt`` away
        from 0, where its derivative is infinite.
        """
        diff = r[edge_index[0]] - r[edge_index[1]]
        sq_dist = (diff * diff).sum(dim=-1, keepdim=True)
        positive = sq_dist > 0
        safe_sq = torch.where(positive, sq_dist, torch.ones_like(sq_dist))
        return torch.where(positive, torch.sqrt(safe_sq), torch.zeros_like(sq_dist))

    def forward(self, mol_graph):
        """Return the scalar potential for ``mol_graph``.

        Reads ``mol_graph.edge_index``, ``mol_graph.Z`` (atom-type indices) and
        ``mol_graph.r_current`` (positions, one row per atom).
        """
        edge_index, x, r_current = mol_graph.edge_index, mol_graph.Z, mol_graph.r_current
        edge_attr = self._edge_distances(r_current, edge_index)

        h = self.embedding(x)
        # Z may carry a trailing singleton axis (the example passes an (N, 1)
        # column), so flatten the embedding to (N, emb_dim).
        h = h.view(x.shape[0], self.emb_dim)
        h = self.interaction1(edge_index, edge_attr, h)
        h = self.interaction2(edge_index, edge_attr, h)
        h = self.interaction3(edge_index, edge_attr, h)
        h = self.atomwise1(h)
        # NOTE(review): there is no nonlinearity between the two readout layers,
        # so together they are a single affine map. Confirm this is intended.
        h = self.atomwise2(h)

        V = torch.sum(h)

        return V
        

In [7]:
# Example data: data[0] is used as the current frame and data[1] as the next one.
data = np.load(Path(data_dir) / 'proto_mol_traj.npy')
r_current, r_next = data[0, :], data[1, :]
Z = np.arange(5).reshape(-1, 1)  # atom-type indices 0..4 as a (5, 1) column
ex = MarkovMolGraph(r_current, r_next, Z)

# Build and test the network.
torch.manual_seed(0)  # fixed weight init so the outputs below are reproducible
n_rbf, rbf_range, gamma, h_dim = 8, [0, 4], 2, 16
emb_dim, intermediate_dim = 4, 4  # NOTE: h_dim is not used by ProtoNet in this cell

proto = ProtoNet(emb_dim, intermediate_dim, n_rbf, rbf_range, gamma).double()
tst2 = proto(ex)  # scalar potential V
grd = torch.autograd.grad(tst2, ex.r_current, create_graph=True)
assert grd[0].shape == ex.r_current.shape
assert torch.isfinite(grd[0]).all(), "non-finite gradient of V w.r.t. positions"