# Neural Message Passing for Quantum Chemistry

Ref: https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Hidden states for bonds (edges) are not updated; only the atom hidden states are updated during message passing.

## 1. Directed models

In [1]:
# Automatically reload edited modules (e.g. utils, data, mpnn) before code runs,
# so changes to the project .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [13]:
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import torch

from torch.autograd import Variable
from utils import rolling_mean
from data import DataPreprocessor
from mpnn.mpnn_directed import MPNNdirected
from mpnn.directed import Rd, Vd, Ud, Ed

In [17]:
# --- Model dimensions ---------------------------------------------------
AT_FEAT = 75      # presumably the width of one atom feature vector
EDG_FEAT = 6      # presumably the width of one edge (bond) feature vector
HID_SIZE = 32     # hidden width handed to the readout network Rd
PASSES = 4        # number of message passes (passed as t= to MPNNdirected)

# --- Runtime and data options -------------------------------------------
CUDA = False          # forwarded as cuda= to MPNNdirected
DATASET = 'data.test' # file read by DataPreprocessor
FLT_DOTS = True       # forwarded as filter_dots= to DataPreprocessor
FLT_ATOMS = True      # forwarded as filter_atoms= to DataPreprocessor

In [15]:
# Readout network: input width is two atom-sized vectors
# (presumably final and initial hidden states concatenated -- TODO confirm in mpnn.directed).
rd = Rd(hid_size=HID_SIZE, inp_size=2 * AT_FEAT)
# Network over atom-feature-sized inputs.
vd = Vd(inp_size=AT_FEAT)
# Network taking two atom vectors plus one edge vector and producing one atom-sized vector.
ud = Ud(out_size=AT_FEAT, inp_size=2 * AT_FEAT + EDG_FEAT)
# Network over edge-feature-sized inputs.
ed = Ed(inp_size=EDG_FEAT)

In [16]:
# Assemble the directed MPNN from the sub-networks built above. t= is the number of
# message passes; cuda= presumably toggles GPU execution (both from the config cell).
mpnn = MPNNdirected(rd, ud, vd, ed, t=PASSES, cuda=CUDA)

In [18]:
# Dataset loader/splitter; the file name and both filter flags come from the config cell.
data = DataPreprocessor(
    DATASET,
    filter_atoms=FLT_ATOMS,
    filter_dots=FLT_DOTS,
)

In [19]:
# Read the raw dataset file.
# NOTE(review): judging by its printed output, the data.get_data() call further down
# re-reads and re-filters the file itself, so this cell may be redundant -- confirm.
data.load_dataset()

File data.test read. In total 1698 lines.


In [20]:
# Apply the filters configured on the preprocessor (filter_dots / filter_atoms).
# NOTE(review): the data.get_data() call further down prints the same
# "Data filtered" message again, so this step appears to be repeated there -- confirm.
data.filter_data()

Data filtered, in total 220 smiles deleted


In [21]:
# Six parallel sequences: presumably SMILES strings and labels for the train,
# validation and test splits. The printed output shows the split is scaffold-based.
(
    train_smiles, train_labels,
    valid_smiles, valid_labels,
    test_smiles, test_labels,
) = data.get_data()

File data.test read. In total 1698 lines.
Data filtered, in total 220 smiles deleted
About to generate scaffolds
Generating scaffold 0/1478
Generating scaffold 1000/1478
About to sort in scaffold sets


In [10]:
# Loss history: the training loop appends one entry per optimisation step.
losses = list()

In [11]:
# Train for 100 outer iterations, timing each optimisation step with %time.
# NOTE(review): the inner loop runs exactly once (i == 0), so every step uses the
# same batch train_smiles[0:100] / train_labels[0:100] -- presumably an
# overfitting sanity check. The trailing argument 4 is undocumented here; it may be
# the number of passes (cf. PASSES) -- confirm in mpnn.mpnn_directed.
for j in range(100):
    for i in range(0, 1):
        #print(i)
        %time loss = mpnn.make_opt_step_batched(train_smiles[i:i+100], train_labels[i:i+100], 4)
        losses.append(loss)
    # Progress counter: index of the finished outer iteration.
    print(j)

CPU times: user 5.4 s, sys: 124 ms, total: 5.53 s
Wall time: 1.84 s
0
CPU times: user 5.4 s, sys: 112 ms, total: 5.51 s
Wall time: 1.77 s
1
CPU times: user 5.48 s, sys: 96 ms, total: 5.57 s
Wall time: 1.76 s
2
CPU times: user 6.1 s, sys: 76 ms, total: 6.17 s
Wall time: 1.85 s
3
CPU times: user 5.8 s, sys: 104 ms, total: 5.9 s
Wall time: 1.82 s
4
CPU times: user 6.21 s, sys: 88 ms, total: 6.3 s
Wall time: 1.89 s
5


KeyboardInterrupt: 

6
CPU times: user 5.78 s, sys: 96 ms, total: 5.88 s
Wall time: 1.92 s
7
CPU times: user 5.33 s, sys: 68 ms, total: 5.4 s
Wall time: 1.66 s
8
CPU times: user 6.23 s, sys: 80 ms, total: 6.31 s
Wall time: 1.92 s
9
CPU times: user 5.09 s, sys: 64 ms, total: 5.16 s
Wall time: 1.72 s
10
CPU times: user 5.85 s, sys: 88 ms, total: 5.94 s
Wall time: 1.86 s
11
CPU times: user 6.5 s, sys: 76 ms, total: 6.57 s
Wall time: 2.04 s
12
CPU times: user 4.76 s, sys: 80 ms, total: 4.84 s
Wall time: 1.66 s
13
CPU times: user 5.48 s, sys: 84 ms, total: 5.56 s
Wall time: 1.76 s
14
CPU times: user 5.49 s, sys: 96 ms, total: 5.59 s
Wall time: 1.78 s
15
CPU times: user 5.49 s, sys: 96 ms, total: 5.58 s
Wall time: 1.77 s
16
CPU times: user 6.23 s, sys: 76 ms, total: 6.31 s
Wall time: 1.89 s
17
CPU times: user 4.76 s, sys: 56 ms, total: 4.82 s
Wall time: 1.65 s
18
CPU times: user 5.46 s, sys: 96 ms, total: 5.56 s
Wall time: 1.77 s
19
CPU times: user 6.33 s, sys: 92 ms, total: 6.42 s
Wall time: 1.92 s
20
CPU times

In [None]:
# Smoothed training curve: rolling mean of the per-step losses (second argument
# presumably the window size). NOTE(review): with fewer than 100 recorded steps the
# smoothed series may be short or empty, depending on utils.rolling_mean.
plt.plot(rolling_mean(losses, 100))
plt.xlabel('step')
plt.ylabel('loss (rolling mean, window=100)')
plt.title('Training loss');

In [None]:
# Collect raw model outputs for the first 100 training molecules, running the
# configured number of message passes (PASSES; the undefined name `passes` was a bug).
# NOTE: forward_pass is defined in a later cell -- run that cell first.
res = []
for i in range(100):
    res.append(forward_pass(mpnn, train_smiles[i], PASSES).data[0][0])

In [None]:
# Distribution of the raw predictions stored in `res`.
plt.hist(res)
plt.xlabel('model output')
plt.ylabel('count')
plt.title('Predictions on the first 100 training molecules');

In [None]:
def forward_pass(self, x, t):
    """Run the MPNN on a single molecule and return the readout.

    The model is passed explicitly as the first argument (named ``self``), so the
    function is called as ``forward_pass(mpnn, smiles, PASSES)``.

    Args:
        self: the model; must provide ``get_features_from_smiles``,
            ``single_message_pass`` and ``R``.
        x: the molecule input, presumably a SMILES string.
        t: number of message passes to run.

    Returns:
        Whatever ``self.R`` returns for the passed-through hidden states and a
        freshly featurised (never passed-through) copy of them.
    """
    # Graph and hidden states that the message passes below operate on.
    g, h = self.get_features_from_smiles(x)
    # Featurise again to get hidden states that no message pass touches --
    # presumably because single_message_pass updates h in place. The graph from
    # this second call is not needed.
    _, h_init = self.get_features_from_smiles(x)
    for k in range(t):
        self.single_message_pass(g, h, k)
    return self.R(h, h_init)

In [None]:
# Scratch: featurise one training molecule (index 10) to inspect the graph g and
# the hidden states h as returned before any message pass.
g, h = mpnn.get_features_from_smiles(train_smiles[10])

In [None]:
# Average loss over the 50 most recent optimisation steps.
np.asarray(losses[-50:]).mean()

In [None]:
# Mean label of the 100 training molecules used above -- a baseline to compare
# the predictions against.
np.asarray(train_labels[:100]).mean()

In [None]:
# Side-by-side check on the first 50 training molecules. Columns: index, true label,
# prediction thresholded at 0.5 (the cut-off also used when building `r`), and the
# raw prediction. int(res[i]) truncated toward zero, so every prediction in [0, 1)
# showed as 0.
for i in range(50):
    print(i, int(train_labels[i]), int(res[i] > 0.5), round(float(res[i]), 3))

In [None]:
# Binarise the raw predictions: 1 where the output exceeds 0.5, else 0.
r = np.where(np.array(res) > 0.5, 1, 0)

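## 2. Undirected models

**TODO:** not implemented yet. The cells below only inspect the parameters of the directed `mpnn` built in section 1.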

In [None]:
# Scratch: print whether each model parameter is trainable (requires_grad).
for p in mpnn.params:
    print(p.requires_grad)

In [None]:
# Scratch: display the second-to-last entry of mpnn.params.
mpnn.params[-2]

In [None]:
# Scratch: print the size of every model parameter's underlying data.
for p in mpnn.params:
    print(p.data.size())