# Neural Message Passing for Quantum Chemistry

Ref: https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Hidden states for edges are not updated (only for atoms).

## 1. Directed models

In [132]:
# Enable the autoreload extension and reload all modules before each cell executes.
# NOTE: re-running this cell in a live kernel prints an "already loaded" notice
# (see the recorded output below).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [162]:
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.autograd import Variable
from utils import rolling_mean
from data import DataPreprocessor
from mpnn.mpnn_directed import MPNNdirected
from mpnn.directed import Rd, Vd, Ud, Ed

In [134]:
# Sizes used to build the directed MPNN components.
# NOTE(review): 75 matches the 1x75 tensors printed by the training cell, so at_feat is
# presumably the per-atom feature length and edg_feat the per-bond feature length — confirm.
at_feat = 75
edg_feat = 6
# Number of message-passing steps; passed to MPNNdirected and to forward_pass.
passes = 4
# Components imported from mpnn.directed; their roles are not visible in this notebook.
# presumably Rd = readout, Vd = vertex/atom embedding, Ud = update, Ed = edge embedding — TODO confirm
rd = Rd(inp_size=at_feat*2, hid_size=32)
vd = Vd(inp_size=at_feat)
ud = Ud(inp_size=at_feat*2+edg_feat, out_size=at_feat)
ed = Ed(inp_size=edg_feat)

In [163]:
# Assemble the directed model from the rd/ud/vd/ed components; `passes` is the last positional argument.
mpnn = MPNNdirected(rd, ud, vd, ed, passes)

In [136]:
# Name of the dataset file handed to DataPreprocessor.
DATASET = 'data.test'

In [137]:
# Preprocessor for the dataset file, with both filters switched on.
# presumably filter_dots / filter_atoms drop SMILES containing '.' or unsupported atoms — confirm in data.py
data = DataPreprocessor(DATASET, filter_dots=True, filter_atoms=True)

In [138]:
# Read the dataset file (the recorded run reported 1698 lines).
data.load_dataset()

File data.test read. In total 1698 lines.


In [139]:
# Apply the configured filters (the recorded run removed 220 SMILES).
data.filter_data()

Data filtered, in total 220 smiles deleted


In [140]:
# Obtain the train / validation / test SMILES and labels.
# NOTE(review): the recorded output shows the file being read and filtered again here,
# followed by scaffold generation, so get_data() appears to load and filter on its own
# and to split by scaffold — confirm whether the explicit load/filter calls are still needed.
train_smiles, train_labels, valid_smiles, valid_labels, test_smiles, test_labels = data.get_data()

File data.test read. In total 1698 lines.
Data filtered, in total 220 smiles deleted
About to generate scaffolds
Generating scaffold 0/1478
Generating scaffold 1000/1478
About to sort in scaffold sets


In [141]:
# Values returned by each optimisation step; the training loop appends to this list.
losses = []
# TODO: when cat tries to concatenate the results of other cat calls, it raises an error. Fix this.

In [164]:
# Sweep over the first 100 training molecules, one molecule per optimisation step,
# and repeat the sweep 100 times; every returned loss is appended to `losses`.
# NOTE(review): the trailing 4 equals `passes`; presumably it is the number of
# message passes — confirm and reuse the constant if so.
# NOTE: the recorded run of this cell stopped with
# "AttributeError: 'Node' object has no attribute 'size'".
for epoch in range(100):
    for idx in range(100):
        smiles_batch = train_smiles[idx:idx + 1]
        print(smiles_batch)
        label_batch = train_labels[idx:idx + 1]
        loss = mpnn.make_opt_step_batched(smiles_batch, label_batch, 4)
        losses.append(loss)
    print(epoch)

['CCC']
*
0
1
2
3
11
***
fold_cat
(Variable containing:

Columns 0 to 12 
    1     0     0     0     0     0     0     0     0     0     0     0     0

Columns 13 to 25 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 26 to 38 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 39 to 51 
    0     0     0     0     0     0     1     0     0     0     0     0     0

Columns 52 to 64 
    0     0     0     0     0     0     1     0     0     0     0     1     0

Columns 65 to 74 
    0     0     0     0     0     0     1     0     0     0
[torch.FloatTensor of size 1x75]
, [0:0]V_0, [0:0]E)
step: 1
***
fold_cat
(Variable containing:

Columns 0 to 12 
    1     0     0     0     0     0     0     0     0     0     0     0     0

Columns 13 to 25 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 26 to 38 
    0     0     0     0     0     0     0     0     0     0     0     0  

AttributeError: 'Node' object has no attribute 'size'

In [None]:
# Smooth the recorded losses with the project's rolling_mean helper
# (second argument 100, presumably the window length) and plot the result.
smoothed_losses = rolling_mean(losses, 100)
plt.plot(smoothed_losses)

In [None]:
# Collect the model output for the first 100 training molecules.
# NOTE: forward_pass is defined in a later cell of this notebook; that cell has to be
# executed before this one, otherwise the name is undefined on a fresh kernel.
res = []
for i in range(100):
    prediction = forward_pass(mpnn, train_smiles[i], passes)
    # Keep only the element at position [0][0] of the output's underlying data.
    res.append(prediction.data[0][0])

In [None]:
# Histogram of the model outputs collected in `res`.
plt.hist(res)

In [None]:
def forward_pass(self, x, t):
    """Run ``t`` message passes for input ``x`` and return the readout.

    This is a free function: the model instance is passed explicitly as ``self``.

    Args:
        self: model object providing ``get_features_from_smiles``,
            ``single_message_pass`` and ``R``.
        x: input forwarded unchanged to ``get_features_from_smiles``
            (a single training SMILES entry in this notebook).
        t: number of message-passing steps; step indices run from 0 to ``t - 1``.

    Returns:
        The value returned by ``self.R`` for the pair
        (features after message passing, freshly built features).
    """
    graph, hidden = self.get_features_from_smiles(x)
    # Build the features a second time; only the second element is used, and it
    # is not passed through any message-passing step.
    _, hidden_initial = self.get_features_from_smiles(x)

    # The return value of single_message_pass is ignored, so it presumably
    # updates `hidden` in place — confirm against the model implementation.
    for step in range(t):
        self.single_message_pass(graph, hidden, step)

    return self.R(hidden, hidden_initial)

In [None]:
# Build the feature pair for a single training entry (index 10) for inspection.
# presumably g is the molecular graph and h the per-atom hidden states — confirm in the model code
g, h = mpnn.get_features_from_smiles(train_smiles[10])

In [None]:
# Average of the 50 most recently recorded losses.
recent_losses = np.array(losses[-50:])
recent_losses.mean()

In [None]:
# Average label value over the first 100 training entries.
first_labels = np.array(train_labels[:100])
first_labels.mean()

In [None]:
# Print index, label and model output side by side for the first 50 entries.
# NOTE(review): int() truncates toward zero; the next cell thresholds `res` at 0.5,
# so if the outputs are probabilities every value below 1 prints as 0 — confirm this is intended.
for i in range(50):
    label = int(train_labels[i])
    prediction = int(res[i])
    print(i, label, prediction)

In [None]:
# Binarise the collected outputs: 1 where the value exceeds 0.5, otherwise 0.
res_array = np.array(res)
r = (res_array > 0.5).astype(int)

In [96]:
# Show the values of the `h` mapping stored on the model
# (the recorded output is an odict_values of fold_non_lin nodes).
mpnn.h.values()

odict_values([[19:4]fold_non_lin, [23:3]fold_non_lin, [27:7]fold_non_lin, [27:8]fold_non_lin, [31:2]fold_non_lin, [35:5]fold_non_lin, [39:0]fold_non_lin, [43:0]fold_non_lin, [43:1]fold_non_lin, [35:6]fold_non_lin, [39:1]fold_non_lin, [43:2]fold_non_lin, [43:3]fold_non_lin, [43:4]fold_non_lin])

## 2. Undirected models

In [None]:
# Print, for every entry of the model's parameter list, whether it requires gradients.
# NOTE(review): `mpnn` is the MPNNdirected instance created earlier, although this
# cell sits under the "Undirected models" heading — confirm the intended model.
for param in mpnn.params:
    needs_grad = param.requires_grad
    print(needs_grad)

In [None]:
# Display the second-to-last entry of the model's parameter list.
mpnn.params[-2]

In [None]:
# Print the size of the underlying data of every entry in the model's parameter list.
for param in mpnn.params:
    param_size = param.data.size()
    print(param_size)