In [None]:
%run Featurize_Input.ipynb

Using backend: pytorch


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NNConv
from rdkit import Chem
from tqdm.notebook import tqdm

In [None]:
from dgllife.model.gnn.mpnn import MPNNGNN
from dgllife.model.readout.mlp_readout import MLPNodeReadout
from dgllife.model.readout.attentivefp_readout import AttentiveFPReadout
from dgllife.model.readout.weighted_sum_and_max import WeightedSumAndMax

In [None]:
from dgllife.model.model_zoo.mpnn_predictor import MPNNPredictor

In [None]:
class Model(nn.Module):
    """Graph-level regressor: MPNN encoder -> MLP node readout -> MLP head.

    Node features are refined by ``MPNNGNN`` message passing, pooled into one
    vector per graph by ``MLPNodeReadout``, and mapped to ``n_tasks`` outputs
    by a small MLP.

    Parameters
    ----------
    node_in_feats : int
        Width of the input node features.
    edge_in_feats : int
        Width of the input edge features.
    node_out_feats : int
        Width of the node representations produced by the GNN; also used as
        the width of the pooled graph representation.
    edge_hidden_feats : int
        Hidden width of the edge network inside ``MPNNGNN``; also reused as
        the hidden width of the readout MLP.
    n_tasks : int
        Number of regression targets per graph.
    num_step_message_passing : int
        Number of message-passing steps performed by ``MPNNGNN``.
    """

    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6):
        super().__init__()
        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)

        # The readout pools the *GNN output*, whose width is node_out_feats
        # (not the raw node_in_feats of the input features).
        self.readout = MLPNodeReadout(node_feats=node_out_feats,
                                      hidden_feats=edge_hidden_feats,
                                      graph_feats=node_out_feats)
        # Alternative readouts tried earlier; they must also consume the GNN
        # output width. NOTE(review): verify each one's output width matches
        # the first Linear of self.predict before swapping it in.
        # self.readout = AttentiveFPReadout(feat_size=node_out_feats)
        # self.readout = WeightedSumAndMax(in_feats=node_out_feats)

        self.predict = nn.Sequential(
            nn.Linear(node_out_feats, node_out_feats),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(node_out_feats, n_tasks)
        )

    def forward(self, g, nodes, edges):
        """Predict graph-level targets for a (batched) graph.

        Parameters
        ----------
        g : DGLGraph
            The graph (or batch of graphs) to run message passing on.
        nodes : torch.Tensor
            Input node features; presumably (num_nodes, node_in_feats) float
            -- TODO confirm against the featurizer.
        edges : torch.Tensor
            Input edge features; presumably (num_edges, edge_in_feats)
            -- TODO confirm against the featurizer.

        Returns
        -------
        torch.Tensor
            Output of the prediction head with ``n_tasks`` columns, presumably
            one row per graph in the batch.
        """
        # Message passing refines the node features; the readout must pool
        # these refined features, otherwise the GNN is bypassed entirely.
        node_feats = self.gnn(g, nodes, edges)
        graph_feats = self.readout(g, node_feats)
        return self.predict(graph_feats)

In [None]:
# Build the custom MPNN model: node and edge input features are each 1 wide.
# (An earlier experiment used dgllife's MPNNPredictor(1, 1) here instead.)
model = Model(node_in_feats=1, edge_in_feats=1)

In [None]:
def train(epochs, net=None, loader=None, optimizer=None, lr=0.001):
    """Train a graph regression model with Adam on a squared-error objective.

    Parameters
    ----------
    epochs : int
        Number of full passes over ``loader``.
    net : torch.nn.Module, optional
        Model to train. Defaults to the notebook-level ``model``.
    loader : iterable of (graph_batch, targets), optional
        Defaults to the notebook-level ``dataloader`` (presumably defined by
        Featurize_Input.ipynb -- TODO confirm). Each graph batch must expose
        ``ndata['atomic']`` and ``edata['type']``.
    optimizer : torch.optim.Optimizer, optional
        Pass an existing optimizer to keep its state across repeated calls.
        By default a fresh Adam optimizer is created on every call.
    lr : float
        Learning rate for the default Adam optimizer.

    Returns
    -------
    list of float
        Training MSE (mean squared error per target element) for each epoch.

    Raises
    ------
    ValueError
        If the loader yields no targets during an epoch.
    """
    # Resolve defaults at call time so a re-assigned global is picked up.
    net = model if net is None else net
    loader = dataloader if loader is None else loader
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    history = []
    for epoch in tqdm(range(epochs)):
        net.train()
        sse = 0.0
        n_targets = 0
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            atoms = batch_x.ndata['atomic']
            edges = batch_x.edata['type']
            y_pred = net(batch_x, atoms, edges)
            # Flatten both sides: subtracting a (B, 1) target from a (B,)
            # prediction would silently broadcast to a (B, B) matrix.
            residual = y_pred.reshape(-1) - batch_y.reshape(-1)
            # The summed squared error is the back-propagated objective
            # (unchanged); it is also accumulated for the per-target MSE.
            loss = (residual ** 2).sum()
            loss.backward()
            optimizer.step()
            sse += loss.item()
            n_targets += residual.numel()

        if n_targets == 0:
            raise ValueError("loader yielded no targets; cannot compute training loss")
        # Normalize by the number of targets (not batches) so the reported
        # value is independent of batch size and a short final batch.
        epoch_mse = sse / n_targets
        history.append(epoch_mse)
        print(f"Epoch {epoch + 1}/{epochs} - train MSE: {epoch_mse:.4f}")
    return history

In [None]:
# Run the baseline training schedule.
train(epochs=10)

In [None]:
#train(2) # mlp + 1 ELU

In [None]:
#train(2) # mlp + 2 ReLU

In [None]:
#train(2) # mlp + 2 ELU

In [None]:
#train(2) #readout mlp + 1 ReLU (original)

In [None]:
#train(2) #readout =  attentive

In [None]:
#train(2) #readout = weighted

In [None]:
#train(2) #mpnn predictor

In [None]:
# Marker printed once every preceding cell has run.
print("done")