In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import random
import torch
import numpy as np
import torch.nn as nn
import deepchem as dc
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from sklearn.metrics import r2_score
import numpy as np

import random
from collections import OrderedDict
from scipy.stats import pearsonr

from collections import OrderedDict
from torch.autograd import Variable
from rdkit import Chem, DataStructs

from data import prepare_datasets



In [3]:
random.seed(2)
torch.manual_seed(2)
np.random.seed(2)

In [4]:
DATASET = 'data.csv'
T = 3
BATCH_SIZE = 2
MAXITER = 40000
LIMIT = 0
LR = 5e-4

In [5]:
R = nn.Linear(150, 128)   # function R
U = {0: nn.Linear(156, 75), 1: nn.Linear(156, 75), 2: nn.Linear(156, 75)}  # function M
V = {0: nn.Linear(75, 75), 1: nn.Linear(75, 75), 2: nn.Linear(75, 75)}  # function U (but without edge features)
E = nn.Linear(6, 6)  # function U (but without atom features)

In [6]:
def readout(h, h2):
    catted_reads = map(lambda x: torch.cat([h[x[0]], h2[x[1]]], 1), zip(h2.keys(), h.keys()))
    activated_reads = map(lambda x: F.selu(R(x)), catted_reads)
    readout = Variable(torch.zeros(1, 128))
    for read in activated_reads:
        readout = readout + read
    return F.tanh(readout)

In [7]:
def message_pass(g, h, k):
    for v in g.keys():
        neighbors = g[v]
        for neighbor in neighbors:
            e_vw = neighbor[0]  # edge feature variable
            w = neighbor[1]  # number of connected atom
            m_w = V[k](h[w])  # calc hidden variable of atom
            m_e_vw = E(e_vw)  # calc hidden variable of edge 
            reshaped = torch.cat((h[v], m_w, m_e_vw), 1)  # calculating concatenated hid states of atoms and edge
            h[v] = F.selu(U[k](reshaped))  

In [8]:
def get_input_features(smile):
    """
    Get input features for edges (g) and atoms (h).
    """
    g = OrderedDict({})
    h = OrderedDict({})
    molecule = Chem.MolFromSmiles(smile)
    for i in range(0, molecule.GetNumAtoms()):
        atom_i = molecule.GetAtomWithIdx(i)
        h[i] = Variable(torch.FloatTensor(dc.feat.graph_features.atom_features(atom_i).astype(np.float32))).view(1, 75)  # mk: added astype
        for j in range(0, molecule.GetNumAtoms()):
            e_ij = molecule.GetBondBetweenAtoms(i, j)
            if e_ij != None:
                e_ij = list(map(lambda x: 1 if x == True else 0,    # mk: added list
                           dc.feat.graph_features.bond_features(e_ij)))  # ADDED edge feat
                e_ij = Variable(torch.FloatTensor(e_ij).view(1, 6))
                atom_j = molecule.GetAtomWithIdx(j)
                if i not in g:
                    g[i] = []
                    g[i].append((e_ij, j))
    return g, h

In [9]:
train_smiles, train_labels, val_smiles, val_labels = prepare_datasets(DATASET)

About to generate scaffolds
Generating scaffold 0/1638
Generating scaffold 1000/1638
About to sort in scaffold sets




In [10]:
linear = nn.Linear(128, 1)
params = [{'params': R.parameters()},
         {'params': U[0].parameters()},
         {'params': U[1].parameters()},
         {'params': U[2].parameters()},
         {'params': E.parameters()},
         {'params': V[0].parameters()},
         {'params': V[1].parameters()},
         {'params': V[2].parameters()},
         {'params': linear.parameters()}]

In [11]:
num_epoch = 0
optimizer = optim.Adam(params, lr=LR, weight_decay=1e-4)

In [20]:
for i in range(0, MAXITER):
    optimizer.zero_grad()
    train_loss = Variable(torch.zeros(1, 1))
    y_hats_train = []
    for j in range(0, BATCH_SIZE):
        sample_index = random.randint(0, len(train_smiles) - 2)
        smile = train_smiles[sample_index]
        g, h = get_input_features(smile)  # TODO: cache this

        g2, h2 = get_input_features(smile)

        for k in range(0, T):
            message_pass(g, h, k)

        x = readout(h, h2)
        # x = F.selu( fc(x) )
        y_hat = linear(x)
        y = train_labels[sample_index]

        y_hats_train.append(y_hat)

        error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([BATCH_SIZE])).view(1, 1)
        train_loss = train_loss + error

    train_loss.backward()
    optimizer.step()
    
    print(i)
    
    if i % 10 == 0: #int(len(train_smiles) / BATCH_SIZE) == 0:
        val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
        y_hats_val = []
        for j in range(0, len(val_smiles)):
            g, h = get_input_features(val_smiles[j])
            g2, h2 = get_input_features(val_smiles[j])

            for k in range(0, T):
                message_pass(g, h, k)

            x = readout(h, h2)
            # x = F.selu( fc(x) )
            y_hat = linear(x)
            y = val_labels[j]

            y_hats_val.append(y_hat)

            error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
            val_loss = val_loss + error

        y_hats_val = np.array(list(map(lambda x: x.data.numpy(), y_hats_val)))
        y_val = np.array(list(map(lambda x: x.data.numpy(), val_labels)))
        y_hats_val = y_hats_val.reshape(-1, 1)
        y_val = y_val.reshape(-1, 1)

        r2_val_old = r2_score(y_val, y_hats_val)
        r2_val_new = pearsonr(y_val, y_hats_val)[0][0] ** 2

        train_loss_ = train_loss.data.numpy()[0][0]
        val_loss_ = val_loss.data.numpy()[0][0]
        print('epoch [{}/{}] train_loss [{}] val_loss [{}] r2_val_old [{}], r2_val_new [{}]'.format(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new))
        num_epoch += 1

0
epoch [10/100] train_loss [0.15870541334152222] val_loss [1.4665929079055786] r2_val_old [-0.18785826061897115], r2_val_new [0.02333693757918809]
1
2
3
4
5
6
7
8
9
10
epoch [11/100] train_loss [0.44134455919265747] val_loss [1.286192774772644] r2_val_old [-0.04174443652015025], r2_val_new [0.02541236044262818]
11
12
13
14
15
16
17
18
19
20
epoch [12/100] train_loss [0.46483224630355835] val_loss [1.2226706743240356] r2_val_old [0.009705196392872262], r2_val_new [0.016342409636799227]
21
22
23
24
25
26
27
28
29
30


KeyboardInterrupt: 

In [15]:
print(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new)

8 100 [ 2.79654908] [ 1.84487927] -0.494249230909 [ 0.08840133]


In [16]:
pearsonr(y_val, y_hats_val)

(array([ 0.29732361], dtype=float32), array([ 0.00011051], dtype=float32))

In [22]:
y_val == y_hats_val

array([], shape=(164, 0), dtype=bool)

In [25]:
val_loss = Variable(torch.zeros(1, 1), requires_grad=False)
y_hats_val = []
for j in range(0, len(val_smiles)):
    g, h = get_input_features(val_smiles[j])
    g2, h2 = get_input_features(val_smiles[j])

    for k in range(0, T):
        message_pass(g, h, k)

    x = readout(h, h2)
    # x = F.selu( fc(x) )
    y_hat = linear(x)
    y = val_labels[j]

    y_hats_val.append(y_hat)

    error = (y_hat - y) * (y_hat - y) / Variable(torch.FloatTensor([len(val_smiles)])).view(1, 1)
    val_loss = val_loss + error

y_hats_val = np.array(list(map(lambda x: x.data.numpy(), y_hats_val)))
y_val = np.array(list(map(lambda x: x.data.numpy(), val_labels)))
y_hats_val = y_hats_val.reshape(-1, 1)
y_val = y_val.reshape(-1, 1)

r2_val_old = r2_score(y_val, y_hats_val)
r2_val_new = pearsonr(y_val, y_hats_val)[0][0] ** 2

train_loss_ = train_loss.data.numpy()[0][0]
val_loss_ = val_loss.data.numpy()[0][0]
print('epoch [{}/{}] train_loss [{}] val_loss [{}] r2_val_old [{}], r2_val_new [{}]'.format(num_epoch, 100, train_loss_, val_loss_, r2_val_old, r2_val_new))
num_epoch += 1

epoch [13/100] train_loss [1.6652722358703613] val_loss [1.2062128782272339] r2_val_old [0.023035097786379355], r2_val_new [0.024278961348900063]


In [31]:
np.allclose(y_val, y_hats_val)

False

In [32]:
y_val

array([[ 1.71632838],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [ 1.71632838],
       [-0.5826391 ],
       [ 1.71632838],
       [-0.5826391 ],
       [-0.5826391 ],
       [ 1.71632838],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [ 1.71632838],
       [-0.5826391 ],
       [-0.5826391 ],
       [ 1.71632838],
       [ 1.71632838],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [ 1.71632838],
       [ 1.71632838],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [-0.5826391 ],
       [ 1.71632838],
       [ 1.71632838],
       [-0.5826391 ],
       [-0

In [33]:
y_hats_val

array([[ 0.02747247],
       [ 0.44075501],
       [ 0.36345196],
       [ 0.32603884],
       [ 0.37638909],
       [ 0.08382586],
       [ 0.36810356],
       [ 0.38451442],
       [ 0.40436491],
       [ 0.17843622],
       [ 0.48779029],
       [-0.35172018],
       [ 0.360596  ],
       [ 0.26199052],
       [ 0.35432297],
       [ 0.14550577],
       [ 0.34735271],
       [ 0.09480251],
       [ 0.27808893],
       [ 0.15175366],
       [ 0.18478948],
       [ 0.04963461],
       [ 0.06729124],
       [ 0.15761252],
       [ 0.0389532 ],
       [ 0.0073465 ],
       [ 0.29882187],
       [ 0.40100855],
       [ 0.36948037],
       [ 0.44961333],
       [ 0.31739041],
       [ 0.24621597],
       [ 0.11953473],
       [-0.03070324],
       [ 0.14854182],
       [ 0.04740882],
       [ 0.4221108 ],
       [ 0.28519773],
       [ 0.2099995 ],
       [ 0.12519707],
       [ 0.25319862],
       [ 0.25211006],
       [ 0.34313831],
       [ 0.09453893],
       [ 0.27000189],
       [ 0