# Neural Message Passing for Quantum Chemistry

Ref: https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Hidden states for atoms are not updated (only for atoms).

## 1. Directed models

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import torch

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.autograd import Variable
from utils.utils import rolling_mean, CUDA_wrapper
from utils.data import DatasetSmiles
from mpnn.mpnn_directed import MPNNdirected
from mpnn.directed import Rd, Vd, Ud, Ed
from copy import deepcopy



In [38]:
AT_FEAT = 75
EDG_FEAT = 6
PASSES = 4
HID_SIZE = 32
CUDA = True
DATASET = 'data/data.test'
FLT_DOTS = True
FLT_ATOMS = True
BATCH_SIZE = 32
SHUFFLE = True
N_EPOCHS = 1000

In [39]:
train_dataset = DatasetSmiles(DATASET, cuda=CUDA, filter_atoms=FLT_ATOMS, filter_dots=FLT_DOTS)

File "data/data.test" read. In total 5000 lines.
Data filtered, in total 1162 smiles deleted
Features calculated and datasets prepared. Number of items in dataset: 3838


In [None]:
rd = Rd(inp_size=AT_FEAT, hid_size=HID_SIZE)
vd = Vd(inp_size=AT_FEAT)
ud = Ud(inp_size=AT_FEAT*2+EDG_FEAT, out_size=AT_FEAT)
ed = Ed(inp_size=EDG_FEAT)
mpnn = MPNNdirected(rd, ud, vd, ed, t=PASSES, cuda=CUDA)
losses = []

In [None]:
# TODO: consider replacing collate_fn with smth that prepares folds
# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
# def collate_fn(lst):
#     batch_x = [x[0] for x in lst]
#     batch_y = [x[1] for x in lst]
#     fold, folded_nodes = mpnn.batch_operations(batch_x, PASSES)
#     return batch_y, fold, folded_nodes

for i in range(N_EPOCHS):
    train_data_loader = DataLoader(deepcopy(train_dataset), batch_size=BATCH_SIZE, collate_fn=lambda x: x, 
                                  shuffle=SHUFFLE)
    batch_num = 0
    for batch in tqdm(train_data_loader):
        batch_x = [x[0] for x in batch]
        batch_y = [x[1] for x in batch]
        fold, folded_nodes = mpnn.batch_operations(batch_x, PASSES)
        #batch_y, fold, folded_nodes = batch
        result = fold.apply(mpnn, folded_nodes)
        loss = mpnn.make_opt_step_batched(result, batch_y)
        losses.append(loss)
        batch_num += 1
    print('epoch: {}, loss: {:.3f}'.format(i, np.array(loss).mean()))

100%|██████████| 120/120 [01:20<00:00,  1.49it/s]


epoch: 0, loss: 0.151


100%|██████████| 120/120 [01:23<00:00,  1.44it/s]


epoch: 1, loss: 0.270


  2%|▎         | 3/120 [00:01<00:56,  2.06it/s]

In [None]:
plt.plot(rolling_mean(losses, 50))

## 2. Undirected models