# Neural Message Passing for Quantum Chemistry

Ref: https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Hidden states for edges are not updated (only for atoms).

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt

from mpnn.

In [3]:
# Seed Python's, PyTorch's and NumPy's random number generators (in that
# order) with the same value so sampling and weight initialisation repeat
# from run to run.
SEED = 2
for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    seed_rng(SEED)

In [None]:
# Run configuration. Of these, only DATASET, T, BATCH_SIZE, MAXITER and LR are
# read by the cells visible in this notebook; the others are not referenced there.
DATASET = 'data.csv'  # argument handed to prepare_datasets()
T = 3  # number of message-passing rounds applied to each molecule
BATCH_SIZE = 1  # molecules sampled per optimizer step
MAXITER = 40000  # number of optimizer steps run by the training loop
LIMIT = 0
LR = 5e-4  # learning rate given to Adam
NUM_ATOM_FEAT = 75  # same width as the (1, 75) atom feature tensors built in get_input_features
NUM_EDGE_FEAT = 6  # same width as the (1, 6) bond feature tensors built in get_input_features
HID_SIZE_ATOM = 25
HID_SIZE_EDGE = 5

In [None]:
# Trainable layers used by readout() and message_pass(). U and V hold one
# layer per message-passing step and are indexed by the step number (0, 1, 2).
R = nn.Linear(150, 128)   # readout layer: input is one atom's state after message passing concatenated with a second per-atom state (75 + 75 = 150)
U = {0: nn.Linear(156, 75), 1: nn.Linear(156, 75), 2: nn.Linear(156, 75)}  # per-step state update: input is cat(h[v], V[k](h[w]), E(e_vw)), i.e. 75 + 75 + 6 = 156
V = {0: nn.Linear(75, 75), 1: nn.Linear(75, 75), 2: nn.Linear(75, 75)}  # per-step transform of the neighbouring atom's state (no edge features involved)
E = nn.Linear(6, 6)  # transform of the bond features (no atom features involved); the same layer is used at every step

In [None]:
def readout(h, h2):
    """
    Collapse the per-atom states of one molecule into a single graph vector.

    For every atom ``v`` the state ``h[v]`` is concatenated with ``h2[v]``
    along dimension 1, passed through the linear layer ``R`` followed by SELU,
    and the per-atom results are summed. ``tanh`` is applied to the sum.

    Args:
        h: mapping from atom index to a (1, 75) state tensor. The training
            loop passes the states after message passing.
        h2: mapping from atom index to a (1, 75) state tensor for the same
            atoms. The training loop passes a freshly built, untouched copy
            of the input features.

    Returns:
        A (1, R.out_features) tensor with values in (-1, 1).

    Raises:
        KeyError: if an atom index present in ``h`` is missing from ``h2``.
    """
    # Start from zeros of the readout width so the function follows R if its
    # output size is changed.
    graph_vector = Variable(torch.zeros(1, R.out_features))
    # Look both mappings up with the same atom key so each atom is always
    # paired with its own entry, whatever the key order of the two mappings.
    for v in h:
        atom_read = F.selu(R(torch.cat([h[v], h2[v]], 1)))
        graph_vector = graph_vector + atom_read
    # torch.tanh replaces the deprecated F.tanh; the result is the same.
    return torch.tanh(graph_vector)

In [None]:
def message_pass(g, h, k):
    """
    Run message-passing step ``k`` over one molecule, updating ``h`` in place.

    Args:
        g: mapping from atom index ``v`` to a list of entries whose first item
            is the bond feature tensor and whose second item is the index of
            the bonded atom.
        h: mapping from atom index to its current state tensor; entries are
            overwritten by this function.
        k: step number, used to select the layers ``U[k]`` and ``V[k]``.

    Returns:
        None. The result is the mutated ``h``.

    Notes:
        ``h[v]`` is reassigned once per neighbour, so each later neighbour of
        ``v``, and every atom visited afterwards, already sees the updated
        state. Atoms that have no entry in ``g`` keep their state unchanged.
    """
    for v, bonds in g.items():
        for bond in bonds:
            e_vw, w = bond[0], bond[1]  # bond features, index of the bonded atom
            atom_msg = V[k](h[w])  # transformed state of the neighbouring atom
            bond_msg = E(e_vw)  # transformed bond features
            # Current state of v, neighbour message and bond message side by side.
            stacked = torch.cat((h[v], atom_msg, bond_msg), 1)
            h[v] = F.selu(U[k](stacked))

In [None]:
def get_input_features(smile):
    """
    Build the graph inputs for one molecule from its SMILES string.

    Args:
        smile: SMILES string of the molecule.

    Returns:
        A tuple ``(g, h)`` of OrderedDicts keyed by atom index:

        * ``g[i]`` is a list with one ``(bond_features, j)`` entry for every
          atom ``j`` bonded to atom ``i``, in ascending order of ``j``;
          ``bond_features`` is a (1, 6) float tensor. Atoms without any bond
          have no entry in ``g``.
        * ``h[i]`` is the (1, 75) float tensor of atom features of atom ``i``.
    """
    g = OrderedDict({})
    h = OrderedDict({})
    molecule = Chem.MolFromSmiles(smile)
    num_atoms = molecule.GetNumAtoms()
    for i in range(num_atoms):
        atom_i = molecule.GetAtomWithIdx(i)
        # Cast to float32 before wrapping the features in a FloatTensor.
        h[i] = Variable(torch.FloatTensor(dc.feat.graph_features.atom_features(atom_i).astype(np.float32))).view(1, 75)
        for j in range(num_atoms):
            bond = molecule.GetBondBetweenAtoms(i, j)
            if bond is None:
                continue  # atoms i and j are not bonded
            # Encode each bond feature as 1 when it equals True and 0 otherwise.
            e_ij = [1 if flag == True else 0  # noqa: E712 - keep the exact comparison
                    for flag in dc.feat.graph_features.bond_features(bond)]
            e_ij = Variable(torch.FloatTensor(e_ij).view(1, 6))
            # Append for every bond, not only when the list for atom i is first
            # created, so atoms with several neighbours keep all of them.
            g.setdefault(i, []).append((e_ij, j))
    return g, h

In [None]:
# Load the SMILES strings and their labels, already split into a training and
# a validation part. prepare_datasets is defined outside the visible cells.
train_smiles, train_labels, val_smiles, val_labels = prepare_datasets(DATASET)

In [None]:
# Output head: maps the (1, 128) readout vector to a single predicted value.
linear = nn.Linear(128, 1)

# One optimizer parameter group per trainable layer, in this fixed order:
# R, U[0..2], E, V[0..2], linear.
trainable_layers = (R, U[0], U[1], U[2], E, V[0], V[1], V[2], linear)
params = [{'params': layer.parameters()} for layer in trainable_layers]

In [None]:
# Counter shown in the progress line; it is incremented once per validation
# pass (every 10th training iteration), not once per pass over the data.
num_epoch = 0
# Adam over all parameter groups, with weight decay 1e-4.
optimizer = optim.Adam(params, lr=LR, weight_decay=1e-4)

In [None]:
for i in range(0, MAXITER):
    # ---- Training: one optimizer step on BATCH_SIZE randomly drawn molecules ----
    optimizer.zero_grad()
    train_loss = Variable(torch.zeros(1, 1))
    y_hats_train = []
    for j in range(0, BATCH_SIZE):
        # randrange excludes its upper bound, so every index from 0 to
        # len(train_smiles) - 1 can be drawn, including the last sample.
        sample_index = random.randrange(len(train_smiles))
        smile = train_smiles[sample_index]
        g, h = get_input_features(smile)  # TODO: cache this

        # message_pass overwrites h in place, so build a second, untouched
        # copy of the input features to hand to the readout.
        g2, h2 = get_input_features(smile)

        for k in range(0, T):
            message_pass(g, h, k)

        x = readout(h, h2)
        y_hat = linear(x)
        y = train_labels[sample_index]

        y_hats_train.append(y_hat)

        # Squared error divided by the batch size: train_loss ends up as the
        # mean squared error over the batch.
        error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([BATCH_SIZE])).view(1, 1)
        train_loss = train_loss + error

    train_loss.backward()
    optimizer.step()

    print(i)

    # ---- Validation: every 10th iteration, score the whole validation set ----
    if i % 10 == 0:
        val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
        y_hats_val = []
        for j in range(0, len(val_smiles)):
            g, h = get_input_features(val_smiles[j])
            g2, h2 = get_input_features(val_smiles[j])

            for k in range(0, T):
                message_pass(g, h, k)

            x = readout(h, h2)
            y_hat = linear(x)
            y = val_labels[j]

            y_hats_val.append(y_hat)

            # Squared error divided by the validation set size: val_loss ends
            # up as the mean squared error over the validation set.
            error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
            val_loss = val_loss + error

        # Convert predictions and labels to (n, 1) NumPy columns for the metrics.
        y_hats_val = np.array(list(map(lambda x: x.data.numpy(), y_hats_val)))
        y_val = np.array(list(map(lambda x: x.data.numpy(), val_labels)))
        y_hats_val = y_hats_val.reshape(-1, 1)
        y_val = y_val.reshape(-1, 1)

        r2_val_old = r2_score(y_val, y_hats_val)
        r2_val_new = pearsonr(y_val, y_hats_val)[0][0] ** 2

        train_loss_ = train_loss.data.numpy()[0][0]
        val_loss_ = val_loss.data.numpy()[0][0]
        print('epoch [{}/{}] train_loss [{}] val_loss [{}] r2_val_old [{}], r2_val_new [{}]'.format(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new))
        num_epoch += 1

In [None]:
print(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new)

In [None]:
pearsonr(y_val, y_hats_val)

In [None]:
y_val == y_hats_val

In [None]:
# Standalone validation pass: same computation as the validation branch of the
# training loop, run once with the current weights. It reuses train_loss from
# the last training iteration for the progress line.
val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
y_hats_val = []
for j in range(0, len(val_smiles)):
    g, h = get_input_features(val_smiles[j])
    # message_pass overwrites h in place, so a second, untouched copy of the
    # input features is built for the readout.
    g2, h2 = get_input_features(val_smiles[j])

    for k in range(0, T):
        message_pass(g, h, k)

    x = readout(h, h2)
    # x = F.selu( fc(x) )
    y_hat = linear(x)
    y = val_labels[j]

    y_hats_val.append(y_hat)

    # Squared error divided by the validation set size: val_loss ends up as
    # the mean squared error over the validation set.
    error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
    val_loss = val_loss + error

# Convert predictions and labels to (n, 1) NumPy columns for the metrics.
y_hats_val = np.array(list(map(lambda x: x.data.numpy(), y_hats_val)))
y_val = np.array(list(map(lambda x: x.data.numpy(), val_labels)))
y_hats_val = y_hats_val.reshape(-1, 1)
y_val = y_val.reshape(-1, 1)

# Two scores: the value returned by r2_score, and the square of the first
# value returned by pearsonr (the correlation coefficient if this is
# scipy.stats.pearsonr - TODO confirm).
r2_val_old = r2_score(y_val, y_hats_val)
r2_val_new = pearsonr(y_val, y_hats_val)[0][0] ** 2

train_loss_ = train_loss.data.numpy()[0][0]
val_loss_ = val_loss.data.numpy()[0][0]
print('epoch [{}/{}] train_loss [{}] val_loss [{}] r2_val_old [{}], r2_val_new [{}]'.format(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new))
num_epoch += 1

In [None]:
np.allclose(y_val, y_hats_val)

In [None]:
y_val

In [None]:
y_hats_val