# Neural Message Passing for Quantum Chemistry

Ref: https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Hidden states for edges are not updated (only those for atoms are).

In [1]:
# Reload imported modules automatically before each cell is executed, so that
# edits to the local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
import torch.optim as optim
import matplotlib.pyplot as plt

from data import prepare_datasets
from mpnn.M import MfuncMLP
from mpnn.R import RfuncMLP
from mpnn.U import UfuncMLP
from mpnn.mpnn import MPNN



In [3]:
# Sizes shared by the message (M), update (U) and readout (R) networks.
atom_feat_dim = 75  # size of the raw per-atom input features
edge_feat_dim = 6   # size of the per-edge input features
m_dim = 32          # size of the message produced by the M networks
h_dim = 24          # size of the atom hidden state produced by the U networks

# The "start" networks consume the raw atom features; the "hid" networks
# consume hidden states of size h_dim instead.
Mfunc_start = MfuncMLP(inp_atom_features=atom_feat_dim, inp_edge_features=edge_feat_dim, out_size_atom=m_dim)
Mfunc_hid = MfuncMLP(inp_atom_features=h_dim, inp_edge_features=edge_feat_dim, out_size_atom=m_dim)
Ufunc_start = UfuncMLP(inp_atom_features=atom_feat_dim, inp_atom_m_state=m_dim, out_size_atom=h_dim)
Ufunc_hid = UfuncMLP(inp_atom_features=h_dim, inp_atom_m_state=m_dim, out_size_atom=h_dim)
# Readout network operating on hidden states of size h_dim.
Rfunc = RfuncMLP(inp_size=h_dim, hid=10)

In [4]:
# Assemble the complete network from the M/U/R components.
# NOTE(review): the meaning of the trailing positional 10 is not visible in
# this notebook - confirm against MPNN.__init__.
mpnn = MPNN(Mfunc_start, Mfunc_hid, Ufunc_start, Ufunc_hid, Rfunc, 10)

In [5]:
# Dataset identifier handed to prepare_datasets.
DATASET = 'data.test'

In [6]:
# Load the train / validation SMILES strings and labels.
# NOTE(review): filter_dots=True presumably drops SMILES containing '.'
# (the recorded output reports "212 dots filtered") - confirm in data.py.
train_smiles, train_labels, val_smiles, val_labels = prepare_datasets(DATASET, filter_dots=True)

About to generate scaffolds
Generating scaffold 0/1478
Generating scaffold 1000/1478
About to sort in scaffold sets
212 dots filtered




In [7]:
# Train on one molecule per optimisation step.
# NOTE(review): in the recorded output make_opt_step prints the loss itself
# and the value printed here is None, so `losses` only collects None values.
# Return the loss from make_opt_step if the curve is meant to be plotted.
# NOTE(review): the meaning of the second argument (10) is not visible here.
losses = []
# Never step past the end of the training set: beyond it the length-1 slices
# below would be empty.
for i in range(min(1000, len(train_smiles))):
    loss = mpnn.make_opt_step((train_smiles[i:i+1], train_labels[i:i+1]), 10)
    print(i, loss)
    losses.append(loss)

1.3787574768066406
0 None
1.40968918800354
1 None
1.4649206399917603
2 None
1.3943042755126953
3 None
1.3482781648635864
4 None
1.3850282430648804
5 None
1.329404354095459
6 None
1.423426628112793
7 None
1.3147790431976318
8 None
1.3662606477737427
9 None
1.3410942554473877
10 None
1.2949105501174927
11 None
1.3160746097564697
12 None
1.3111985921859741
13 None
1.4209195375442505
14 None
1.3350473642349243
15 None
1.316087245941162
16 None
0.9491663575172424
17 None
1.2637760639190674
18 None
1.2933021783828735
19 None
1.2859702110290527
20 None
1.3557792901992798
21 None
1.2897652387619019
22 None
1.2610492706298828
23 None
1.294655203819275
24 None
1.294860601425171
25 None
1.1975672245025635
26 None
1.1975528001785278
27 None
1.2655541896820068
28 None
1.2296606302261353
29 None
1.2459627389907837
30 None
1.2945871353149414
31 None
1.2714999914169312
32 None
1.2439707517623901
33 None
1.2446495294570923
34 None
1.2633967399597168
35 None
1.2923825979232788
36 None
1.2914263010025024

RuntimeError: size mismatch, m1: [1 x 75], m2: [24 x 10] at /opt/conda/conda-bld/pytorch_1503965122592/work/torch/lib/TH/generic/THTensorMath.c:1293

In [None]:
# Plot the per-step values collected in `losses` by the training loop.
plt.plot(losses)

In [9]:
# Inspect one training sample as a length-1 slice, the same form that is
# passed to make_opt_step.
# NOTE(review): the recorded output is '[Al+3]', a lone atom; possibly related
# to the size-mismatch error recorded for the training loop - confirm.
i = 54
train_smiles[i:i+1]

array(['[Al+3]'], 
      dtype='<U464')

In [None]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt

from mpnn.M import MfuncMLP
from mpnn.U import UfuncMLP
from mpnn.R import RfuncMLP
from mpnn.mpnn import MPNN

In [None]:
# Give Python's `random`, PyTorch and NumPy the same fixed seed so that
# sampling and weight initialisation are repeatable between runs.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(2)

In [None]:
# Dataset identifier handed to prepare_datasets.
DATASET = 'data.csv'
# Number of message-passing steps run per molecule.
T = 3
# Number of randomly drawn molecules accumulated into one optimizer step.
BATCH_SIZE = 1
# Number of optimizer steps performed by the training loop.
MAXITER = 40000
# NOTE(review): not referenced anywhere in the visible code.
LIMIT = 0
# Learning rate given to Adam.
LR = 5e-4
# Sizes of the per-atom and per-bond feature vectors; they match the
# view(1, 75) / view(1, 6) shapes used in get_input_features.
NUM_ATOM_FEAT = 75
NUM_EDGE_FEAT = 6
# NOTE(review): the two hidden sizes below are not referenced anywhere in the
# visible code.
HID_SIZE_ATOM = 25
HID_SIZE_EDGE = 5

In [None]:
def input_transpose(sents, pad_token):
    """Transpose a batch of variable-length sequences into padded, position-major form.

    Args:
        sents: Sequence of sequences, one per batch item.
        pad_token: Value used to fill positions past the end of a shorter
            sequence.

    Returns:
        A ``(sents_t, masks)`` pair of lists of lists. Both have one row per
        position (as many rows as the longest sequence has items) and one
        column per batch item. ``sents_t[i][k]`` is ``sents[k][i]`` when that
        item exists and ``pad_token`` otherwise; ``masks[i][k]`` is 1 for a
        real item and 0 for padding. An empty batch yields ``([], [])``.
    """
    lengths = [len(s) for s in sents]
    # max() raises on an empty sequence; an empty batch simply has no rows.
    max_len = max(lengths, default=0)

    sents_t = []
    masks = []
    for i in range(max_len):
        sents_t.append([s[i] if n > i else pad_token for s, n in zip(sents, lengths)])
        masks.append([1 if n > i else 0 for n in lengths])

    return sents_t, masks

In [None]:
# Quick manual check of input_transpose on three sequences of different length.
sents = [[1,2,3], [1,2,3,4,5,6], [1,2]]
input_transpose(sents, 'pad')

In [None]:
# Readout layer: readout() applies it to the concatenation of two per-atom
# vectors of NUM_ATOM_FEAT entries each.
R = nn.Linear(2 * NUM_ATOM_FEAT, 128)
# Update layers, one per message-passing step k. message_pass() feeds U[k]
# with [h_v, V[k](h_w), E(e_vw)], hence 2 * NUM_ATOM_FEAT + NUM_EDGE_FEAT inputs.
U = {k: nn.Linear(2 * NUM_ATOM_FEAT + NUM_EDGE_FEAT, NUM_ATOM_FEAT) for k in range(T)}
# Per-step transform of the neighbouring atom's state (no edge features).
V = {k: nn.Linear(NUM_ATOM_FEAT, NUM_ATOM_FEAT) for k in range(T)}
# Transform of the bond features (no atom features), shared by all steps.
E = nn.Linear(NUM_EDGE_FEAT, NUM_EDGE_FEAT)

In [None]:
def readout(h, h2):
    """Collapse the per-atom states of one molecule into a single vector.

    For every atom key in ``h`` the vectors ``h[key]`` and ``h2[key]`` are
    concatenated along dimension 1, passed through ``R`` and SELU, and the
    results are summed over all atoms; tanh is applied to the sum.

    Args:
        h: Mapping from atom key to that atom's state. The callers in this
            notebook pass the states after message passing.
        h2: Mapping with the same keys as ``h``. The callers in this notebook
            pass the untouched input features.

    Returns:
        A ``(1, R.out_features)`` Variable.

    Raises:
        KeyError: If ``h2`` lacks one of the keys of ``h``.
    """
    total = Variable(torch.zeros(1, R.out_features))
    # Index both mappings with the same key. The previous version took keys
    # from h2 to index h and vice versa, which only worked because both
    # mappings happened to list identical keys in identical order.
    for key in h:
        pair = torch.cat([h[key], h2[key]], 1)
        total = total + F.selu(R(pair))
    return F.tanh(total)

In [None]:
def message_pass(g, h, k):
    """Run message-passing step ``k`` over a molecule, updating ``h`` in place.

    Args:
        g: Mapping from atom key ``v`` to a list of ``(edge_features, w)``
            pairs, one per recorded neighbour ``w`` of ``v``.
        h: Mapping from atom key to atom state; the entry of every atom in
            ``g`` is replaced once per listed neighbour.
        k: Index selecting the step-specific layers ``V[k]`` and ``U[k]``.

    Updates take effect immediately, so atoms and neighbours processed later
    in the same call already see the states written earlier in that call.
    """
    for v, bonds in g.items():
        for e_vw, w in bonds:
            neighbour_msg = V[k](h[w])  # transformed state of neighbour w
            edge_msg = E(e_vw)  # transformed features of edge (v, w)
            # Concatenate own state, neighbour message and edge message.
            stacked = torch.cat((h[v], neighbour_msg, edge_msg), 1)
            h[v] = F.selu(U[k](stacked))

In [None]:
def get_input_features(smile):
    """
    Get input features for edges (g) and atoms (h) of one molecule.

    Args:
        smile: SMILES string of the molecule.

    Returns:
        A pair ``(g, h)`` of ordered dicts keyed by atom index.
        ``g[i]`` is a list of ``(edge_features, j)`` tuples, one for every
        atom ``j`` bonded to atom ``i``, with ``edge_features`` a (1, 6)
        Variable; atoms without bonds have no entry in ``g``.
        ``h[i]`` is the (1, 75) feature Variable of atom ``i``.

    Raises:
        ValueError: If the SMILES string cannot be parsed.
    """
    g = OrderedDict()
    h = OrderedDict()
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:
        # Fail with a clear message rather than an AttributeError on None below.
        raise ValueError('could not parse SMILES: {!r}'.format(smile))

    num_atoms = molecule.GetNumAtoms()
    for i in range(num_atoms):
        atom_i = molecule.GetAtomWithIdx(i)
        # astype: the tensor constructor needs float32 input.
        atom_feats = dc.feat.graph_features.atom_features(atom_i).astype(np.float32)
        h[i] = Variable(torch.FloatTensor(atom_feats)).view(1, 75)
        for j in range(num_atoms):
            bond = molecule.GetBondBetweenAtoms(i, j)
            if bond is None:
                continue
            # NOTE(review): bond_features is expected to yield boolean flags,
            # which are turned into 0/1 here - confirm against deepchem.
            bond_feats = [1 if flag else 0
                          for flag in dc.feat.graph_features.bond_features(bond)]
            e_ij = Variable(torch.FloatTensor(bond_feats).view(1, 6))
            # Record every bond of atom i. Previously the append sat inside
            # the "first time atom i is seen" branch, so only the first
            # neighbour of each atom was ever stored.
            g.setdefault(i, []).append((e_ij, j))
    return g, h

In [None]:
# Load the train / validation SMILES strings and labels for DATASET.
train_smiles, train_labels, val_smiles, val_labels = prepare_datasets(DATASET)

In [None]:
# Final regression layer: maps the 128-wide readout vector to one scalar.
linear = nn.Linear(128, 1)
# One optimizer parameter group per layer. The per-step layers are collected
# from U and V themselves (in key order), so no layer is left untrained if
# the number of message-passing steps changes.
_trainable_layers = (
    [R]
    + [U[k] for k in sorted(U)]
    + [E]
    + [V[k] for k in sorted(V)]
    + [linear]
)
params = [{'params': layer.parameters()} for layer in _trainable_layers]

In [None]:
# Counter of validation reports; the training loop increments it after each one.
num_epoch = 0
# Adam over all parameter groups, with weight decay.
optimizer = optim.Adam(params, lr=LR, weight_decay=1e-4)

In [None]:
def _predict(smile):
    """Return the network's (1, 1) prediction for one SMILES string."""
    g, h = get_input_features(smile)
    # message_pass() rebinds the entries of h instead of modifying the
    # tensors, so a shallow copy keeps the untouched input features that
    # readout() needs - no second featurisation of the molecule required.
    h_input = h.copy()
    for k in range(T):
        message_pass(g, h, k)
    return linear(readout(h, h_input))


# Main training loop: one optimizer step per iteration, with a full
# validation report every 10 iterations.
for i in range(MAXITER):
    optimizer.zero_grad()
    train_loss = Variable(torch.zeros(1, 1))
    y_hats_train = []
    for j in range(BATCH_SIZE):
        # randrange covers every index. The previous randint(0, len - 2) has
        # an inclusive upper bound and could never draw the last sample.
        sample_index = random.randrange(len(train_smiles))
        y_hat = _predict(train_smiles[sample_index])
        y = train_labels[sample_index]

        y_hats_train.append(y_hat)

        # Squared error, averaged over the batch.
        error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([BATCH_SIZE])).view(1, 1)
        train_loss = train_loss + error

    train_loss.backward()
    optimizer.step()

    print(i)

    if i % 10 == 0:
        # Mean squared error and R^2 over the whole validation set.
        val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
        y_hats_val = []
        for j in range(len(val_smiles)):
            y_hat = _predict(val_smiles[j])
            y = val_labels[j]

            y_hats_val.append(y_hat)

            error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
            val_loss = val_loss + error

        # Convert predictions and labels to (n, 1) NumPy columns for the metrics.
        y_hats_val = np.array([pred.data.numpy() for pred in y_hats_val])
        y_val = np.array([label.data.numpy() for label in val_labels])
        y_hats_val = y_hats_val.reshape(-1, 1)
        y_val = y_val.reshape(-1, 1)

        r2_val_old = r2_score(y_val, y_hats_val)
        # NOTE(review): the [0][0] indexing relies on pearsonr returning an
        # array-valued coefficient for (n, 1) inputs - confirm for the
        # installed scipy version.
        r2_val_new = pearsonr(y_val, y_hats_val)[0][0] ** 2

        train_loss_ = train_loss.data.numpy()[0][0]
        val_loss_ = val_loss.data.numpy()[0][0]
        print('epoch [{}/{}] train_loss [{}] val_loss [{}] r2_val_old [{}], r2_val_new [{}]'.format(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new))
        num_epoch += 1

In [None]:
# Re-print the metrics of the most recent validation report.
print(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new)

In [None]:
# Pearson correlation between validation labels and predictions
# (both were reshaped to (n, 1) columns by the validation code).
pearsonr(y_val, y_hats_val)

In [None]:
# Element-wise exact equality between labels and predictions.
y_val == y_hats_val

In [None]:
# Stand-alone validation pass: predicts every validation molecule, then
# reports the mean squared error and two R^2 variants next to the last
# training loss.
val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
y_hats_val = []
for j in range(len(val_smiles)):
    # Featurise twice: (g, h) is consumed by message passing, while h2 keeps
    # the untouched input features for the readout.
    g, h = get_input_features(val_smiles[j])
    g2, h2 = get_input_features(val_smiles[j])

    for k in range(T):
        message_pass(g, h, k)

    x = readout(h, h2)
    y_hat = linear(x)
    y = val_labels[j]
    y_hats_val.append(y_hat)

    # Squared error of this sample, divided by the validation set size.
    error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
    val_loss = val_loss + error

# Predictions and labels as (n, 1) NumPy columns.
y_hats_val = np.array([pred.data.numpy() for pred in y_hats_val]).reshape(-1, 1)
y_val = np.array([label.data.numpy() for label in val_labels]).reshape(-1, 1)

r2_val_old = r2_score(y_val, y_hats_val)
r2_val_new = pearsonr(y_val, y_hats_val)[0][0] ** 2

train_loss_ = train_loss.data.numpy()[0][0]
val_loss_ = val_loss.data.numpy()[0][0]
print('epoch [{}/{}] train_loss [{}] val_loss [{}] r2_val_old [{}], r2_val_new [{}]'.format(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new))
num_epoch += 1

In [None]:
# True only if every prediction matches its label within np.allclose's
# default tolerances.
np.allclose(y_val, y_hats_val)

In [None]:
# Display the validation labels.
y_val

In [None]:
# Display the validation predictions.
y_hats_val