# Neural Message Passing for Quantum Chemistry

Reference: Gilmer et al., *Neural Message Passing for Quantum Chemistry* (ICML 2017), https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Only atom hidden states are updated during message passing; edge (bond) features are not updated.

## 1. Directed models

In [1]:
%load_ext autoreload
# Re-import the project modules (utils, data, mpnn) automatically when their source changes.
%autoreload 2
%matplotlib inline

In [131]:
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import torch

from torch.autograd import Variable
from utils import rolling_mean, CUDA_wrapper
from data import DataPreprocessor
from mpnn.mpnn_directed import MPNNdirected
from mpnn.directed import Rd, Vd, Ud, Ed

In [59]:
# --- Model / data configuration ---------------------------------------------
AT_FEAT = 75       # length of each atom's feature vector
EDG_FEAT = 6       # length of each edge feature vector
PASSES = 4         # number of message-passing steps (passed as t=)
HID_SIZE = 32      # hidden size of the readout network Rd
DATASET = 'data.test'
FLT_DOTS = True    # forwarded to DataPreprocessor(filter_dots=...)
FLT_ATOMS = True   # forwarded to DataPreprocessor(filter_atoms=...)

# Use the GPU only when one is actually present, so the notebook also runs on
# CPU-only machines instead of failing on the first .cuda() call.
CUDA = torch.cuda.is_available()

# Fix RNG seeds so weight initialisation and training are repeatable.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Build the four message-passing components from the config constants.
# Construction order is kept (rd, vd, ud, ed) in case weight init consumes RNG state.
# Readout R: input width 2 * AT_FEAT (presumably current + initial atom states).
rd = Rd(hid_size=HID_SIZE, inp_size=2 * AT_FEAT)
# V component sized to a single atom feature vector.
vd = Vd(inp_size=AT_FEAT)
# Update U: 2 * AT_FEAT + EDG_FEAT inputs -> AT_FEAT outputs
# (presumably two atom states plus one edge vector — TODO confirm).
ud = Ud(out_size=AT_FEAT, inp_size=2 * AT_FEAT + EDG_FEAT)
# E component sized to a single edge feature vector.
ed = Ed(inp_size=EDG_FEAT)

In [132]:
# Assemble the directed MPNN; components go in positionally as R, U, V, E.
mpnn = MPNNdirected(rd, ud, vd, ed, cuda=CUDA, t=PASSES)

In [6]:
# Preprocessor for the configured dataset file, with both filters from the config cell.
data = DataPreprocessor(DATASET, filter_atoms=FLT_ATOMS, filter_dots=FLT_DOTS)

In [7]:
# Read the raw dataset file into the preprocessor.
# NOTE(review): the get_data() cell's output shows the file being read and
# filtered again, so this cell and the filter_data() cell may be redundant — confirm.
data.load_dataset()

File data.test read. In total 1698 lines.


In [8]:
# Apply the configured filters (FLT_DOTS / FLT_ATOMS) to the loaded SMILES.
# NOTE(review): get_data() prints the same "Data filtered" message again, so it
# appears to filter on its own — confirm whether this call is needed.
data.filter_data()

Data filtered, in total 220 smiles deleted


In [9]:
# Train / validation / test split; the printed log shows it is built from molecular scaffolds.
(train_smiles, train_labels,
 valid_smiles, valid_labels,
 test_smiles, test_labels) = data.get_data()

File data.test read. In total 1698 lines.
Data filtered, in total 220 smiles deleted
About to generate scaffolds
Generating scaffold 0/1478
Generating scaffold 1000/1478
About to sort in scaffold sets


In [133]:
# Two independent featurisations of every training molecule. The second copy is
# kept separate presumably because message passing updates hidden states in place
# (compare forward_pass, which reads both the passed-over `h` and an untouched `h2`).
train_x = [mpnn.get_features_from_smiles(smi, cuda=CUDA) for smi in train_smiles]
train_x2 = [mpnn.get_features_from_smiles(smi, cuda=CUDA) for smi in train_smiles]
# Short aliases used by the timing cell.
train_y, smiles = train_labels, train_smiles

In [134]:
i = 0
n = 100
%time res = mpnn.make_opt_step_batched(smiles[i:i+n], train_x[i:i+n], train_x2[i:i+n], train_y[i:i+n], 4)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
CCC
3
CC
2
CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C
14
C(C(C(C(C(CO)O)O)O)O)O
12
C(C(=O)O)C(CC(=O)O)(C(=O)O)O
13
CCCCCC(C=CC=CC=CC=CC(C(CCCC(=O)O)O)O)O
25
CC(=CCCC(=CCCC(=CCCC(=O)OCC=C(C)CCC=C(C)C)C)C)C
29
CCCCCON=O
8
C(CCl)N(C(=O)NC(C=O)C(C(C(CO)O)O)O)N=O
20
C(CCl)NP(=O)(NCCCl)O
11
CC(C(=O)NC(CCC(=O)NC(CCCC(C(=O)O)N)C(=O)NCC(=O)O)C(=O)O)NC(=O)C(C)O
36
CS(=O)(=O)CS(=O)(=O)OCCCl
12
CCCCCCCCCCCCCCCCOP(=O)([O-])OCC[N+](C)(C)C
27
CN(CC(=O)N)C(=O)N(CCCl)N=O
14
C(C=[Se])C(C(=O)O)N
8
C(CC(=O)O)C(=O)CN
9
CCCCCCCCCCCCCCCCC(C[N+](C)(C)C)OP(=O)(O)[O-]
27
CC(COCCO)OCCO
11
C(C(CO[N+](=O)O)O[N+](=O)O)O[N+](=O)O
15
CCC(C)C(C(=NC(C(C)O)C(=NC(C(C)O)C(=NC(CCCCN)C(=NC(CC(=O)O)C(=NC(CC(C)C)C(=NC(CCCCN)C(=NC(CCC(=O)O)C(=NC(CCCCN)

In [118]:
# Per-molecule entry counts from the batched step's second output, with a running
# total (presumably one entry per atom: res[0][k] has AT_FEAT columns and as many
# rows as res[1][k] has entries).
# Distinct loop names are used so the batch offset `i` set in the timing cell is not
# clobbered — a later cell displays smiles[i:i+n] and would show the wrong batch.
# NOTE(review): res[1] has 76 entries for a 100-SMILES batch, and the repeated
# 'C(CCl)NP(=O)(NCCCl)O' appears only once — duplicates (and possibly other
# molecules) seem to be dropped; confirm labels stay aligned with predictions.
total_atoms = 0
for mol_idx, mol_atoms in enumerate(res[1]):
    total_atoms += len(mol_atoms)
    print(mol_idx, total_atoms, len(mol_atoms))

0 3 3
1 5 2
2 19 14
3 31 12
4 44 13
5 69 25
6 98 29
7 106 8
8 126 20
9 137 11
10 173 36
11 185 12
12 212 27
13 226 14
14 234 8
15 243 9
16 270 27
17 281 11
18 296 15
19 513 217
20 525 12
21 540 15
22 550 10
23 561 11
24 570 9
25 582 12
26 584 2
27 602 18
28 629 27
29 640 11
30 649 9
31 653 4
32 659 6
33 669 10
34 675 6
35 679 4
36 693 14
37 706 13
38 718 12
39 740 22
40 750 10
41 757 7
42 767 10
43 791 24
44 798 7
45 805 7
46 824 19
47 848 24
48 858 10
49 884 26
50 897 13
51 921 24
52 932 11
53 969 37
54 986 17
55 992 6
56 1023 31
57 1036 13
58 1045 9
59 1057 12
60 1072 15
61 1087 15
62 1096 9
63 1126 30
64 1141 15
65 1161 20
66 1172 11
67 1180 8
68 1186 6
69 1200 14
70 1235 35
71 1270 35
72 1275 5
73 1282 7
74 1286 4
75 1326 40


In [127]:
# Spot-check one molecule's output tensor (index 69); the column count should equal AT_FEAT.
res[0][69].size()

torch.Size([14, 75])

In [136]:
# The SMILES of the batch timed above (depends on `i` / `n` from that cell).
# NOTE(review): this displays all n strings; slice further (e.g. [:5]) to keep output short.
smiles[i:i+n]

['CCC',
 'CC',
 'CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C',
 'C(C(C(C(C(CO)O)O)O)O)O',
 'C(C(=O)O)C(CC(=O)O)(C(=O)O)O',
 'CCCCCC(C=CC=CC=CC=CC(C(CCCC(=O)O)O)O)O',
 'CC(=CCCC(=CCCC(=CCCC(=O)OCC=C(C)CCC=C(C)C)C)C)C',
 'CCCCCON=O',
 'C(CCl)N(C(=O)NC(C=O)C(C(C(CO)O)O)O)N=O',
 'C(CCl)NP(=O)(NCCCl)O',
 'CC(C(=O)NC(CCC(=O)NC(CCCC(C(=O)O)N)C(=O)NCC(=O)O)C(=O)O)NC(=O)C(C)O',
 'CS(=O)(=O)CS(=O)(=O)OCCCl',
 'CCCCCCCCCCCCCCCCOP(=O)([O-])OCC[N+](C)(C)C',
 'CN(CC(=O)N)C(=O)N(CCCl)N=O',
 'C(CCl)NP(=O)(NCCCl)O',
 'C(C=[Se])C(C(=O)O)N',
 'C(CC(=O)O)C(=O)CN',
 'CCCCCCCCCCCCCCCCC(C[N+](C)(C)C)OP(=O)(O)[O-]',
 'CC(COCCO)OCCO',
 'C(C(CO[N+](=O)O)O[N+](=O)O)O[N+](=O)O',
 'CCC(C)C(C(=NC(C(C)O)C(=NC(C(C)O)C(=NC(CCCCN)C(=NC(CC(=O)O)C(=NC(CC(C)C)C(=NC(CCCCN)C(=NC(CCC(=O)O)C(=NC(CCCCN)C(=NC(CCCCN)C(=NC(CCC(=O)O)C(=NC(C(C)C)C(=NC(C(C)C)C(=NC(CCC(=O)O)C(=NC(CCC(=O)O)C(=NC(C)C(=NC(CCC(=O)O)C(=NC(CC(=N)O)C(=O)O)O)O)O)O)O)O)O)O)O)O)O)O)O)O)O)O)O)N=C(C(CCC(=O)O)N=C(C(CO)N=C(C(CO)N=C(C(C(C)O)N=C(C(CC(=O)O)N=C(C(C(C)C)N=C(C(C)N

In [None]:
# Loss history appended to by the training loop below; re-running this cell resets it.
losses = []

In [None]:
# Training loop: one optimisation step per epoch on a fixed set of batches.
# The call uses the same 5-argument form as the timing cell
# (smiles, features, second feature copy, labels, 4).
# NOTE(review): the timing cell indexes the return value as res[0] / res[1];
# confirm it is a scalar loss before plotting `losses`.
N_EPOCHS = 100
BATCH_SIZE = 100
N_BATCHES = 1  # only the first batch for now; increase to train on more of train_smiles

for epoch in range(N_EPOCHS):
    for start in range(0, N_BATCHES * BATCH_SIZE, BATCH_SIZE):
        batch = slice(start, start + BATCH_SIZE)
        loss = mpnn.make_opt_step_batched(
            train_smiles[batch], train_x[batch], train_x2[batch], train_labels[batch], 4)
        losses.append(loss)
    print(epoch)

In [None]:
# Smoothed training loss. The second rolling_mean argument is presumably the window
# size — with one batch per epoch, `losses` holds only a few hundred points at most,
# so a window of 100 leaves little to see; shrink it if the curve is empty.
fig, ax = plt.subplots()
ax.plot(rolling_mean(losses, 100))
ax.set(title='Training loss (rolling mean)', xlabel='step', ylabel='loss');

In [None]:
# Model output for each of the first 100 training molecules.
# Uses the PASSES constant (the original referenced an undefined name `passes`).
# NOTE: `forward_pass` is defined in a later cell — run that cell first.
# This replaces the `res` returned by the timing cell.
res = [forward_pass(mpnn, smi, PASSES).data[0][0] for smi in train_smiles[:100]]

In [None]:
# Distribution of the model's raw outputs computed in the prediction cell.
fig, ax = plt.subplots()
ax.hist(res)
ax.set(title='Distribution of predictions', xlabel='prediction', ylabel='count');

In [None]:
def forward_pass(self, x, t, cuda=None):
    """Run `t` message-passing steps for one SMILES string and return the readout.

    Called as a free function with the model as first argument,
    e.g. ``forward_pass(mpnn, smi, PASSES)``.

    Args:
        self: the MPNN model providing get_features_from_smiles,
            single_message_pass and R.
        x: SMILES string to featurise.
        t: number of message passes; pass k receives index k (0 .. t-1).
        cuda: if given, forwarded as ``cuda=`` to get_features_from_smiles so
            features are built the same way as in training; ``None`` keeps the
            featuriser's own default (the previous behaviour).

    Returns:
        Whatever ``self.R(h, h2)`` returns (the model's prediction).
    """
    feat_kwargs = {} if cuda is None else {'cuda': cuda}
    g, h = self.get_features_from_smiles(x, **feat_kwargs)
    # Second, independent featurisation: `h` is handed to single_message_pass on
    # every step (presumably updated in place), while `h2` keeps the initial states
    # for the readout. Its graph part is not needed.
    _, h2 = self.get_features_from_smiles(x, **feat_kwargs)
    for k in range(t):
        self.single_message_pass(g, h, k)
    return self.R(h, h2)

In [None]:
# Inspect the graph / hidden-state pair for one molecule (index 10), featurised with
# the same cuda flag as the training features so devices match.
g, h = mpnn.get_features_from_smiles(train_smiles[10], cuda=CUDA)

In [None]:
# Mean of the 50 most recent recorded losses.
np.mean(losses[-50:])

In [None]:
# Mean label over the first 100 training molecules — for 0/1 labels this is the
# positive rate, a baseline to compare the predictions against.
np.mean(train_labels[:100])

In [None]:
# Side by side: index, true label, thresholded prediction, raw score.
# int(score) would truncate every score below 1.0 to 0, hiding the model's decisions;
# threshold at 0.5 instead, as the binarisation cell below does.
for idx, (label, score) in enumerate(zip(train_labels[:50], res[:50])):
    print(idx, int(label), int(score > 0.5), score)

In [None]:
# Binarise predictions at 0.5 (assumes res holds scores for a binary target — confirm).
r = (np.array(res) > 0.5).astype(int)
# res was computed for the first len(res) training molecules, so align labels the same way.
accuracy = (r == np.array(train_labels[:len(r)]).astype(int)).mean()
accuracy

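## 2. Undirected models

**TODO:** not implemented yet — the cells below still only inspect the parameters of the directed `mpnn` built in section 1.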

In [None]:
# requires_grad flag of every model parameter, one per line, in mpnn.params order.
for flag in (param.requires_grad for param in mpnn.params):
    print(flag)

In [None]:
# Inspect the second-to-last parameter (which component it belongs to is not visible here).
mpnn.params[-2]

In [None]:
# Shape of every model parameter, one per line, in mpnn.params order.
for param in mpnn.params:
    print(param.data.size())