# Neural Message Passing for Quantum Chemistry

Reference: Gilmer et al., "Neural Message Passing for Quantum Chemistry" (2017), https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Only the hidden states of atoms are updated during message passing; edge (bond) states are not updated.

## 1. Directed models

In [None]:
# Automatically re-import edited project modules (utils, data, mpnn) before
# each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.autograd import Variable
from utils import rolling_mean
from data import DataPreprocessor
from mpnn.mpnn_directed import MPNNdirected
from mpnn.directed import Rd, Vd, Ud, Ed

In [None]:
# Feature sizes: at_feat per atom, edg_feat per edge. Presumably these must
# match what get_features_from_smiles produces -- TODO confirm in mpnn/data.
at_feat = 75
edg_feat = 6
# Number of message-passing steps; the same value is passed again to
# make_opt_step and forward_pass below.
passes = 4
# Readout R: input is two atom-sized vectors (at_feat*2), presumably the
# final and initial hidden states combined -- TODO confirm in mpnn.directed.
rd = Rd(inp_size=at_feat*2, hid_size=100)
# V: takes an atom-sized input; its role inside MPNNdirected is not visible here.
vd = Vd(inp_size=at_feat)
# Update U: maps 2 atom vectors + 1 edge vector back to one atom-sized state.
ud = Ud(inp_size=at_feat*2+edg_feat, out_size=at_feat)
# Edge module E: operates on edge-feature vectors.
ed = Ed(inp_size=edg_feat)

In [None]:
# Note the argument order (R, U, V, E) differs from the construction order above.
mpnn = MPNNdirected(rd, ud, vd, ed, passes)

In [None]:
# Dataset location handed to DataPreprocessor.
DATASET = 'data.test'

In [None]:
# filter_dots / filter_atoms presumably select which molecules filter_data()
# drops -- semantics live in data.py; TODO confirm.
data = DataPreprocessor(DATASET, filter_dots=True, filter_atoms=True)

In [None]:
data.load_dataset()

In [None]:
data.filter_data()

In [None]:
# Six parallel sequences: SMILES strings and labels for train/valid/test.
# Only the train split is used in the cells below.
train_smiles, train_labels, valid_smiles, valid_labels, test_smiles, test_labels = data.get_data()

In [None]:
# Per-step training losses, filled by the training loop and summarised later.
losses = []

In [None]:
# 10 epochs over the first 50 training molecules, one molecule per
# optimisation step. `losses` is not reset here, so re-running this cell
# keeps appending to the same list.
# NOTE(review): if fewer than 50 training molecules survive filtering, the
# later slices are empty -- confirm make_opt_step tolerates that.
for _ in range(10):
    for i in range(50):
        # Slicing (not indexing) passes a length-1 sequence, i.e. a batch of one.
        loss = mpnn.make_opt_step(train_smiles[i:i+1], train_labels[i:i+1], passes)
        print(i, loss)
        losses.append(loss)

In [None]:
# Smoothed training-loss curve; presumably a 50-step rolling window --
# TODO confirm rolling_mean's signature in utils.
plt.plot(rolling_mean(losses, 50))

In [None]:
# NOTE: forward_pass is defined in a later cell. On a fresh top-to-bottom
# run this cell raises NameError -- execute the forward_pass cell first.
# Predictions are for the same 50 molecules used in training (in-sample).
res = []
for i in range(50):
    # .data[0][0] takes element [0][0] of the readout output; assumes R
    # returns a 2-D (1, 1) tensor -- TODO confirm.
    res.append(forward_pass(mpnn, train_smiles[i], passes).data[0][0])

In [None]:
# Distribution of the in-sample predictions.
plt.hist(res)

In [None]:
def forward_pass(self, x, t, *, verbose=False):
    """Run ``t`` message-passing steps on input ``x`` and return the readout.

    Written as a free function that takes the model as ``self``, so it is
    called as ``forward_pass(mpnn, smiles, passes)``.

    Args:
        self: model object exposing ``get_features_from_smiles``,
            ``single_message_pass`` and ``R`` (e.g. the MPNNdirected
            instance built earlier in this notebook).
        x: input for ``get_features_from_smiles`` -- presumably a single
            SMILES string; TODO confirm.
        t: number of message-passing steps to run.
        verbose: if true, print a separator and ``h[0]`` after every step
            (the former always-on debug output). Defaults to off so that
            calling this in a loop does not flood the notebook.

    Returns:
        Whatever ``self.R(h, h0)`` returns for the final hidden states ``h``
        and the initial hidden states ``h0``.
    """
    g, h = self.get_features_from_smiles(x)
    # Featurise a second time to get an independent snapshot of the initial
    # hidden states. single_message_pass's return value is ignored, so it
    # presumably updates `h` in place, while the readout needs both the
    # final and the initial states.
    _, h0 = self.get_features_from_smiles(x)
    for k in range(t):
        self.single_message_pass(g, h, k)
        if verbose:
            print('*' * 33)
            print(h[0])
    return self.R(h, h0)

In [None]:
# Featurise one training molecule for manual inspection of g and h.
g, h = mpnn.get_features_from_smiles(train_smiles[10])

In [None]:
# Mean loss over the last 50 steps, i.e. the final epoch of the training loop.
np.array(losses[-50:]).mean()

In [None]:
# Mean label of the same 50 molecules -- a baseline to compare against.
np.array(train_labels[:50]).mean()

In [None]:
# Label vs. prediction side by side; int() truncates toward zero, so this
# is only a coarse comparison.
for i in range(50):
    print(i, int(train_labels[i]), int(res[i]))

In [None]:
# Display the raw prediction list.
res

## 2. Undirected models

*TODO:* the cells below still inspect the parameters of the directed model `mpnn` from section 1.

In [None]:
# Check which parameters of the directed model `mpnn` (section 1) are trainable.
for p in mpnn.params:
    print(p.requires_grad)

In [None]:
# Display the second-to-last parameter tensor.
mpnn.params[-2]

In [None]:
# Print the shape of every parameter tensor.
for p in mpnn.params:
    print(p.data.size())