In [75]:
# Execute the featurization notebook inside this kernel so the names it defines
# become available here. NOTE(review): presumably this is where `dataloader`
# (used by train() below) comes from — confirm against Featurize_Input.ipynb.
%run Featurize_Input.ipynb

done


In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NNConv
from rdkit import Chem
from tqdm.notebook import tqdm

In [95]:
from dgllife.model.gnn.mpnn import MPNNGNN
from dgllife.model.readout.mlp_readout import MLPNodeReadout
from dgllife.model.readout.attentivefp_readout import AttentiveFPReadout
from dgllife.model.readout.weighted_sum_and_max import WeightedSumAndMax

In [96]:
from dgllife.model.model_zoo.mpnn_predictor import MPNNPredictor

In [124]:
class Model(nn.Module):
    """Graph-level regressor: MPNN encoder -> MLP node readout -> MLP head.

    Parameters
    ----------
    node_in_feats : int
        Width of the input node features passed to ``forward`` as ``nodes``.
    edge_in_feats : int
        Width of the input edge features passed to ``forward`` as ``edges``.
    node_out_feats : int, default 64
        Width of the node embeddings produced by the MPNN; also used as the
        width of the graph embedding and of the hidden layer of the head.
    edge_hidden_feats : int, default 128
        Hidden width of the MPNN edge network; reused as the hidden width of
        the node readout MLP.
    n_tasks : int, default 1
        Number of values predicted per graph.
    num_step_message_passing : int, default 6
        Number of message-passing steps run by the MPNN.
    """

    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6):

        super(Model, self).__init__()
        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)

        # The readout consumes the embeddings produced by ``self.gnn``, so its
        # input width must be ``node_out_feats`` (not ``node_in_feats``).
        self.readout = MLPNodeReadout(node_feats=node_out_feats,
                                      hidden_feats=edge_hidden_feats,
                                      graph_feats=node_out_feats)
        # Alternative readouts that were tried. NOTE(review): WeightedSumAndMax
        # returns 2 * in_feats features per graph, so the first Linear layer of
        # ``self.predict`` would need that input width — confirm before enabling.
        #self.readout = AttentiveFPReadout(feat_size=node_out_feats)
        #self.readout = WeightedSumAndMax(in_feats=node_out_feats)

        self.predict = nn.Sequential(
            nn.Linear(node_out_feats, node_out_feats),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(node_out_feats, n_tasks)
        )

    def forward(self, g, nodes, edges):
        """Predict ``n_tasks`` values for every graph in the batch ``g``.

        ``nodes`` and ``edges`` are the input node and edge features of ``g``.
        Returns the output of the prediction head applied to the graph
        embeddings (presumably shaped ``(num_graphs, n_tasks)`` — TODO confirm).
        """
        # Encode atoms with message passing, then pool the *encoded* features.
        # Pooling the raw ``nodes`` instead would bypass the GNN entirely and
        # leave its parameters without any gradient.
        node_feats = self.gnn(g, nodes, edges)
        graph_feats = self.readout(g, node_feats)
        return self.predict(graph_feats)

In [121]:
# Baseline alternative from dgllife:
#model = MPNNPredictor(1,1)
# One input feature per node and one per edge; every other size uses the
# defaults declared by Model.
model = Model(node_in_feats=1, edge_in_feats=1)

In [122]:
def train(epochs, lr=0.001, net=None, loader=None):
    """Train a model with Adam on a summed-squared-error loss.

    Parameters
    ----------
    epochs : int
        Number of full passes over ``loader``.
    lr : float, default 0.001
        Learning rate handed to ``torch.optim.Adam``.
    net : nn.Module, optional
        Model to optimise; falls back to the notebook-level ``model``.
    loader : iterable, optional
        Yields ``(batch_x, batch_y)`` pairs; falls back to the notebook-level
        ``dataloader``. ``batch_x`` must expose ``ndata['atomic']`` and
        ``edata['type']``.

    Prints one "Train loss" line per epoch: the per-batch summed squared
    error averaged over the number of batches (so the value depends on the
    batch size and is not a per-sample mean).
    """
    # Resolve the notebook globals lazily so that ``train(epochs)`` keeps
    # working exactly as before when no explicit model/loader is given.
    net = model if net is None else net
    loader = dataloader if loader is None else loader

    # A fresh optimizer is created on every call, so Adam's moment estimates
    # do not carry over between successive train() invocations.
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in tqdm(range(epochs)):
        net.train()
        running_loss = 0.
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            atoms = batch_x.ndata['atomic']
            edges = batch_x.edata['type']
            y_pred = net(batch_x, atoms, edges)
            # Sum (not mean) of squared errors over the batch.
            sse = ((y_pred.reshape(-1) - batch_y) ** 2).sum()
            running_loss += sse.item()
            sse.backward()
            optimizer.step()

        running_loss /= len(loader)
        print("Train loss: ", running_loss)

In [None]:
# Full training run: 20 passes over the dataloader.
train(epochs=20)

In [9]:
# Sanity check (disabled): summed squared error of the model on the first batch only
# for batch_x, batch_y in dataloader:
#     model.eval()
#     atoms = batch_x.ndata['atomic']
#     edges = batch_x.edata['type']
#     y_pred = model(batch_x, atoms, edges)
#     mse = ((y_pred.reshape(-1) - batch_y)**2).sum()
#     break
# print(mse)

tensor(1.3691e+08, dtype=torch.float64, grad_fn=<SumBackward0>)


In [123]:
#train(2) # mlp + 1 ELU

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  137018193.2235191
Train loss:  70259238.37893505


In [115]:
#train(2) # mlp + 2 ReLU

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  145059593.34239236
Train loss:  68584692.3672683


In [119]:
#train(2) # mlp + 2 ELU

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  146496884.24099994
Train loss:  81367387.03909391


In [108]:
#train(2) #readout mlp + 1 ReLU (original)

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  136093413.00844288
Train loss:  61622441.00163776


In [94]:
#train(2) # readout = AttentiveFPReadout

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  152534715.5615262
Train loss:  147530688.93964967


In [104]:
#train(2) # readout = WeightedSumAndMax

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  154704668.90200973
Train loss:  152357714.10068044


In [111]:
#train(2) # model = MPNNPredictor

  0%|          | 0/2 [00:00<?, ?it/s]

Train loss:  154432868.16419235
Train loss:  152550736.06438985


In [None]:
# Completion marker, emitted when the last cell of the notebook is reached.
print("done")