In [2]:
from rdkit import RDLogger
import pandas as pd
import numpy as np
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from dgllife.utils import ConcatFeaturizer
from dgllife.utils import atom_degree, atom_is_aromatic, atomic_number, atom_is_in_ring, BaseAtomFeaturizer
from dgllife.utils import bond_type_one_hot, bond_is_conjugated, bond_is_in_ring, BaseBondFeaturizer
import torch
from tqdm.notebook import tqdm
from dgl.data.utils import load_graphs
from dgl.data.utils import save_graphs
from dgl.dataloading.pytorch import GraphDataLoader
from torch.utils.data import DataLoader
from rdkit import Chem
from torch.utils.data import Dataset
import torch.nn as nn
from dgl.nn.pytorch import Set2Set
#from dgl.nn.pytorch import MPNNGNN
#from dgllife.model.model_zoo.mpnn_predictor import MPNNPredictor
from dgllife.model.gnn.mpnn import MPNNGNN
from dgllife.model.readout.mlp_readout import MLPNodeReadout
import wandb

Using backend: pytorch


In [7]:
# Per-element energy lookup: read the symbol/E table from CSV, then collapse
# it into a plain {symbol: E} dict (the name `elements` is rebound on purpose).
elements = pd.read_csv('indiv_energies.csv')
elements = {symbol: energy for symbol, energy in zip(elements['symbol'], elements['E'])}

def get_E_pred(smiles):
    """Sum the per-element energies from ``elements`` over all atoms of a molecule.

    The SMILES string is parsed with RDKit and hydrogens are made explicit
    with ``Chem.AddHs`` so that they are included in the sum.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.

    Returns
    -------
    number
        Sum of ``elements[symbol]`` over every atom (``0`` if the molecule
        has no atoms).

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``.
    KeyError
        If an atom's symbol is missing from ``elements``.
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        # MolFromSmiles reports a parse failure by returning None; fail here
        # with the offending string instead of deep inside Chem.AddHs.
        raise ValueError('Could not parse SMILES: %r' % (smiles,))
    m = Chem.AddHs(m)
    return sum(elements[atom.GetSymbol()] for atom in m.GetAtoms())

In [8]:
# Silence RDKit's log output so featurization warnings do not flood the
# notebook -- incompletely featurized molecules are meant to be removed later.
RDLogger.DisableLog('rdApp.*')

# Stack the per-dataset CSV files row-wise and drop rows with missing values.
dat = pd.concat(
    [pd.read_csv('../data/dataset-%s-E0.csv' % dataset)
     for dataset in ['B', 'C', 'D', 'G', 'H', 'I']],
    axis=0,
).dropna()

smiles, Y = dat['SMILES'], dat['E_0']
# Training target: E_0 minus the summed per-element energies from get_E_pred.
E_diff = Y - smiles.apply(get_E_pred)

# Only the extracted columns are needed from here on.
del dat

In [3]:
class Dataset(Dataset):
    """Pairs the saved molecular graphs with their ``E_diff`` targets.

    Graphs are read from ``canonical_feat_graphs`` via ``load_graphs``;
    targets come from the module-level ``E_diff`` and are stored as a
    float64 tensor.

    NOTE(review): this class rebinds the imported ``Dataset`` base-class
    name, and nothing here checks that the number of graphs matches the
    number of targets -- confirm both are intended.
    """

    def __init__(self):
        # load_graphs returns (graphs, labels); only the graphs are used.
        graphs, _ = load_graphs('canonical_feat_graphs')
        self.x = graphs
        self.y = torch.tensor(list(E_diff), dtype=torch.float64)

    def __len__(self):
        # The dataset size is defined by the graph list, not by the targets.
        return len(self.x)

    def __getitem__(self, idx):
        graph, target = self.x[idx], self.y[idx]
        return graph, target

In [9]:
# Build the dataset: Dataset.__init__ loads the saved graphs and reads the
# targets from the global E_diff, so the data-loading cell must have run first.
data = Dataset()

In [10]:
# Iterate over the dataset in batches of 1536 graphs.
# NOTE(review): collate_fn=None presumably selects the loader's default
# collation, and no shuffle argument is passed -- confirm a fixed batch order
# is intended for training.
gdl = GraphDataLoader(data, batch_size=1536, collate_fn=None)

In [11]:
class OurMPNNPredictor(nn.Module):
    """Graph-level MPNN predictor: message passing, MLP node readout, MLP head.

    Modified by HB, MG, JW, JW. MPNN is introduced in `Neural Message Passing
    for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__.

    The network chains three stages: an ``MPNNGNN`` that updates node
    representations, an ``MLPNodeReadout`` that pools them into a graph
    vector of size ``2 * node_out_feats``, and a two-layer MLP that maps the
    graph vector to ``n_tasks`` outputs.

    Parameters
    ----------
    node_in_feats : int
        Size of the input node features.
    edge_in_feats : int
        Size of the input edge features.
    node_out_feats : int
        Size of the output node representations. Defaults to 16.
    edge_hidden_feats : int
        Size of the hidden edge representations. Defaults to 16.
    n_tasks : int
        Number of tasks, which is also the output size. Defaults to 1.
    num_step_message_passing : int
        Number of message passing steps. Defaults to 2.
    """

    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=16,
                 edge_hidden_feats=16,
                 n_tasks=1,
                 num_step_message_passing=2):
        super().__init__()

        # The readout output and the head input both use twice the node width.
        graph_width = 2 * node_out_feats

        # Sub-modules are created in the order gnn -> readout -> predict and
        # keep these attribute names so saved state dicts stay loadable.
        self.gnn = MPNNGNN(
            node_in_feats=node_in_feats,
            node_out_feats=node_out_feats,
            edge_in_feats=edge_in_feats,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
        )
        self.readout = MLPNodeReadout(
            node_feats=node_out_feats,
            hidden_feats=graph_width,
            graph_feats=graph_width,
        )
        self.predict = nn.Sequential(
            nn.Linear(graph_width, node_out_feats),
            nn.ReLU(),
            nn.Linear(node_out_feats, n_tasks),
        )

    def forward(self, g, node_feats, edge_feats):
        """Predict one output vector per graph in the batch.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features.

        Returns
        -------
        float32 tensor of shape (G, n_tasks)
            Prediction for the graphs in the batch, where G is the number of
            graphs.
        """
        node_repr = self.gnn(g, node_feats, edge_feats)
        return self.predict(self.readout(g, node_repr))


In [12]:
# Rebuild the predictor with its default hidden sizes and resume from the
# weights saved in 'down-best_diff2.params'.
# NOTE(review): 74 / 12 presumably match the canonical atom / bond featurizer
# sizes used to build the saved graphs -- confirm.
model = OurMPNNPredictor(node_in_feats=74, edge_in_feats=12)
model.load_state_dict(torch.load('down-best_diff2.params'))
# Switch to training mode; as the cell's last expression this also displays
# the module structure.
model.train()

OurMPNNPredictor(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=74, out_features=16, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=12, out_features=16, bias=True)
        (1): ReLU()
        (2): Linear(in_features=16, out_features=256, bias=True)
      )
    )
    (gru): GRU(16, 16)
  )
  (readout): MLPNodeReadout(
    (in_project): Linear(in_features=16, out_features=32, bias=True)
    (out_project): Linear(in_features=32, out_features=32, bias=True)
  )
  (predict): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
)

In [13]:
# best for diff prediction so far: 500 AdamW with 0.004, then 200 AdamW with 0.001, then 200 AdamW with 0.0001

def train(epochs, lr=0.0001):
    """Train the global ``model`` on ``gdl`` and checkpoint the best epoch.

    Every batch is optimised with AdamW on its summed squared error. After
    each epoch the squared errors accumulated over all batches are divided by
    ``len(data)`` and logged to Weights & Biases. Whenever that value improves
    on the best seen so far in this call, the model's state dict is written
    to ``down-best_inprogress.params``.

    Parameters
    ----------
    epochs : int
        Number of passes over ``gdl``.
    lr : float
        AdamW learning rate. Defaults to 0.0001, the value that used to be
        hard-coded, so a learning-rate schedule can be run without editing
        the function.
    """
    run = wandb.init(project='mlchem_hw8', entity='weiss')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    best_score = None

    try:
        for epoch in tqdm(range(epochs)):

            model.train()

            running_loss = 0.

            for batch_x, batch_y in gdl:
                optimizer.zero_grad()
                atoms = batch_x.ndata['h']
                edges = batch_x.edata['e']
                # Call the module itself rather than .forward() so that any
                # registered hooks run as well.
                y_pred = model(batch_x, atoms, edges)
                sse = ((torch.squeeze(y_pred) - batch_y)**2).sum()
                # Accumulate across batches: assigning here would keep only
                # the last batch's error, making the epoch loss (and the
                # checkpoint choice below) depend on a single batch.
                running_loss += sse.item()
                sse.backward()
                optimizer.step()

            # Summed squared error over the epoch -> mean per dataset item.
            running_loss /= len(data)

            model.eval()

            if best_score is None or running_loss < best_score:
                best_score = running_loss
                torch.save(model.state_dict(), 'down-best_inprogress.params')

            wandb.log({'Train loss': running_loss,
                       'Best loss': best_score})
    finally:
        # Close the W&B run even if training fails or is interrupted.
        run.finish()

In [None]:
# Run 200 training epochs; the best epoch's weights end up in
# 'down-best_inprogress.params'.
train(200) # this is going to take a while

In [None]:
# Load the checkpoint written by train() into a fresh predictor and put it in
# evaluation mode for inference.
best_model = OurMPNNPredictor(node_in_feats=74, edge_in_feats=12)
best_model.load_state_dict(torch.load('down-best_inprogress.params'))
best_model.eval()

def model_from_smiles(smiles):
    """Predict a molecule's energy from its SMILES string.

    The SMILES is converted to a bigraph with explicit hydrogens,
    ``best_model`` is evaluated on it, and the per-element sum from
    ``get_E_pred`` is added to the model output.

    NOTE(review): ``featurize_atoms`` and ``featurize_bonds`` are not defined
    in the cells visible here -- they are expected to exist elsewhere in the
    notebook.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.

    Returns
    -------
    numpy.ndarray
        ``best_model``'s output for the single graph plus the value of
        ``get_E_pred(smiles)``.

    Raises
    ------
    ValueError
        If no graph could be built from ``smiles``.
    """
    g = smiles_to_bigraph(smiles, node_featurizer=featurize_atoms, edge_featurizer=featurize_bonds, explicit_hydrogens=True)
    if g is None:
        # Fail with the offending input instead of an AttributeError on None.
        raise ValueError('Could not build a graph from SMILES: %r' % (smiles,))
    # Inference only: skip autograd bookkeeping. The result then carries no
    # gradient history, so it converts to NumPy without an explicit detach.
    with torch.no_grad():
        correction = best_model(g, g.ndata['h'], g.edata['e'])
    return correction.numpy() + np.array(get_E_pred(smiles))

In [None]:
# NOTE(review): `test` is not defined in the cells visible here -- presumably
# an evaluation helper defined elsewhere; confirm. The "diff2 params" label
# does not match the checkpoint loaded into best_model above
# ('down-best_inprogress.params'), so this cell was presumably run with a
# different checkpoint loaded -- confirm before re-running.
test(model_from_smiles,'Down') # diff2 params

In [None]:
# Evaluate model_from_smiles with best_model holding the in-progress
# checkpoint ('down-best_inprogress.params').
# NOTE(review): `test` is not defined in the cells visible here -- confirm
# where it comes from.
test(model_from_smiles,'Down') # inprogress params