# Message Passing

In this notebook, we introduce Graph Convolutional Neural Networks, focusing specifically on Message Passing Neural Networks (MPNs).

In [1]:
# Load the first column (the SMILES string) of up to N_PREVIEW_SMILES data rows.
N_PREVIEW_SMILES = 120
with open('../data/01_raw/moses/dataset_v1.csv', 'r') as f:
    next(f, None)  # skip the CSV header row (no-op on an empty file)
    # zip() stops as soon as either side is exhausted, so a file with fewer
    # rows yields a shorter list instead of being padded with '' entries
    # (which is what readline() returns at end-of-file).
    smiles = [
        line.split(',')[0].strip()
        for _, line in zip(range(N_PREVIEW_SMILES), f)
    ]
# The `with` statement closes the file; no explicit f.close() is required.

from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import Draw
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions
# Set the bond line width on RDKit's DrawingOptions for molecule drawings.
DrawingOptions.bondLineWidth=1.8

import os,sys,inspect
# Put the project source directory first on the import path.
# NOTE(review): this is a hard-coded absolute path; presumably the project
# packages imported in later cells (structures, models, utils) live here --
# confirm, and consider making it configurable.
sys.path.insert(0,'/home/icarus/app/src') 

# Handle to RDKit's logger.
lg = RDLogger.logger()

# Only let CRITICAL-level RDKit messages through.
lg.setLevel(RDLogger.CRITICAL)

import torch as torch
#torch.multiprocessing.set_start_method('spawn')



In [1]:
%load_ext line_profiler

## Feature Extraction

In [2]:
from structures.moltree import MolTree
from structures.mol_features import N_ATOM_FEATS, N_BOND_FEATS
# Build a MolTree from the first SMILES string loaded above.
mt = MolTree(smiles[0])

import matplotlib.pyplot as plt
import networkx as nx

# Convert once, then draw into the first slot of a 1x2 subplot grid.
mol_tree_nx = mt.to_networkx()
plt.subplot(1, 2, 1)
nx.draw(mol_tree_nx, with_labels=True)
plt.show()

<Figure size 640x480 with 1 Axes>

In [3]:
# Encode the molecule tree; only the shapes of the two feature tensors are shown.
# NOTE(review): the keyword is spelled `recssemble` exactly as in the original
# call -- confirm against the MolTree.encode signature.
graph, atom_features, bond_features = mt.encode(recssemble=True)
print(atom_features.shape, bond_features.shape, sep='\n')

# The feature tensors are not needed beyond this point.
del atom_features, bond_features

torch.Size([19, 97])
torch.Size([40, 13])


In [4]:
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph, mean_nodes

## Loopy Belief Propagation



In [5]:
# DGL built-in message function: copies the source's 'msg' feature into the
# message field 'msg'.
# NOTE(review): presumably applied on a line graph so that each "node" is a
# directed bond -- confirm against the caller.
mpn_loopy_bp_msg = fn.copy_src(src='msg', out='msg')
# DGL built-in reduce function: sums the incoming 'msg' messages into the
# node field 'accum_msg' (read by LoopyBeliefProp_Update.forward).
mpn_loopy_bp_reduce = fn.sum(msg='msg', out='accum_msg')

class LoopyBeliefProp_Update(nn.Module):
    """Node-update step of the loopy belief propagation message passing.

    For each node batch it computes
    ``msg = relu(msg_input + W_h(accum_msg))`` from the node data fields
    ``'msg_input'`` and ``'accum_msg'``.

    Args:
        hidden_size: Width of the message vectors; ``W_h`` is a square,
            bias-free linear map of this size.
        args: Optional settings object exposing ``device`` and ``use_cuda``.
            When omitted (or when an attribute is missing) the module stays
            on the CPU.
    """

    def __init__(self, hidden_size: int, args = None):
        super(LoopyBeliefProp_Update, self).__init__()
        # `args` defaults to None, so it must not be dereferenced blindly.
        self.device = getattr(args, 'device', torch.device('cpu'))
        self.use_cuda = bool(getattr(args, 'use_cuda', False))
        self.hidden_size = hidden_size

        self.W_h = nn.Linear(  # y = xA^T + b
            in_features=hidden_size,
            out_features=hidden_size,
            bias=False
        )
        if self.use_cuda:
            self.W_h = self.W_h.cuda()

    def forward(self, nodes):
        """Return ``{'msg': relu(msg_input + W_h(accum_msg))}`` for ``nodes``.

        Args:
            nodes: Object whose ``data`` mapping holds the ``'msg_input'`` and
                ``'accum_msg'`` tensors.

        Returns:
            A dict with the single key ``'msg'``.
        """
        # Bring both inputs onto the device that actually holds the weights,
        # so the tensor that was moved is also the one that gets used.
        target = self.W_h.weight.device
        msg_input = nodes.data['msg_input'].to(target)
        accum_msg = nodes.data['accum_msg'].to(target)
        msg_delta = self.W_h(accum_msg)
        msg = F.relu(msg_input + msg_delta)
        return {'msg': msg}

## Reduce

In [6]:
# DGL built-in message function: copies the edge feature 'msg' into the
# message field 'msg'.
mpn_gather_msg = fn.copy_edge(edge='msg', out='msg')
# DGL built-in reduce function: sums the incoming 'msg' messages into the
# node field 'm' (read by MPN_Gather_Update.forward).
mpn_gather_reduce = fn.sum(msg='msg', out='m')

class MPN_Gather_Update(nn.Module):
    """Node-update step that combines node features with gathered messages.

    Computes ``h = relu(W_o([x, m]))`` where ``x`` and ``m`` are read from the
    node data fields ``'x'`` and ``'m'`` and concatenated along dimension 1.

    Args:
        hidden_size: Width of ``m`` and of the output ``h``.
        device: Device for the layer. The module is configured from this
            parameter rather than from a notebook-level ``args`` global,
            so it can be constructed on a fresh kernel.
    """

    def __init__(self, hidden_size: int, device=torch.device('cpu')):
        super(MPN_Gather_Update, self).__init__()
        # Accept either a torch.device or a device string such as 'cuda:0'.
        self.device = torch.device(device)
        self.hidden_size = hidden_size
        self.use_cuda = self.device.type == 'cuda'

        # Input is the concatenation [x, m]: atom features followed by messages.
        self.W_o = nn.Linear(N_ATOM_FEATS + hidden_size, hidden_size).to(self.device)

    def forward(self, nodes):
        """Return ``{'h': relu(W_o(cat([x, m], 1)))}`` for ``nodes``.

        Args:
            nodes: Object whose ``data`` mapping holds the ``'m'`` and ``'x'``
                tensors.

        Returns:
            A dict with the single key ``'h'``.
        """
        # Move inputs to wherever the weights live; the result of W_o is then
        # already on that device, so no further transfer is needed.
        target = self.W_o.weight.device
        m = nodes.data['m'].to(target)
        x = nodes.data['x'].to(target)
        h = F.relu(self.W_o(torch.cat([x, m], 1)))

        return {
            'h': h
        }

## MPN

In [7]:
from models.modules import GraphConvNet

ModuleNotFoundError: No module named 'jtnn_vae'

In [None]:
from torch.utils.data import Dataset
from structures import Vocab
from typing import List, Tuple

from utils.data import JTNNCollator

In [None]:
import structures
import structures.mol_features as mf
import torch as torch



In [None]:
from utils.data import JTNNDataset

In [None]:
from torch.utils.data import DataLoader

class ArgsTemp:
    """Minimal stand-in for the experiment arguments used in this notebook.

    Stores the given hidden size, depth and device, and records as
    ``use_cuda`` whether CUDA is available when the object is created.
    """

    def __init__(self, hidden_size, depth, device):
        self.hidden_size, self.depth, self.device = hidden_size, depth, device
        # Queried from the runtime, independently of the `device` argument.
        self.use_cuda = torch.cuda.is_available()
        
# Use the first GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
args = ArgsTemp(200, 3, device=torch.device(device_name))
print(args.depth)

dataset = JTNNDataset(data='valid', vocab='vocab', training=True, intermediates=True)
vocab = dataset.vocab

# Shuffled batches of 10 items from a single worker; incomplete batches dropped.
collate = JTNNCollator(vocab, True, intermediates=True)
dataloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    num_workers=1,
    collate_fn=collate,
    drop_last=True,
    worker_init_fn=None,
)


In [None]:
# Pull a single collated batch from the loader; `data_iter` is reused by later cells.
data_iter = iter(dataloader)
batch = next(data_iter)
#for key in batch.keys():
    #print("{}: {}".format(key, batch[key]))
# Keep the first molecule tree and the batched molecular graph for inspection.
mol_tree = batch['mol_trees'][0]
graph_batch = batch['mol_graph_batch']
# NOTE(review): presumably `graph_batch` is a DGL graph, and backtracking=False
# excludes edge pairs that immediately reverse direction -- confirm against the
# DGL version in use.
line_graph = graph_batch.line_graph(backtracking=False, shared=True)

#mols = [Chem.MolFromSmiles(mg.smiles) for mg in batch]

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import numpy as np
from utils.cuda import cuda, move_dgl_to_cuda
#plt.subplot(121)
#nx.draw(mol_tree.to_networkx(), with_labels=True)
#plt.show()
#plt.subplot(121)
#nx.draw(line_graph.to_networkx(), with_labels=True)
#plt.show()

from rdkit.Chem import Draw
#smile_img = Draw.MolToImage(Chem.MolFromSmiles(batch['smiles'][0]))
#imshow(smile_img)
#Draw.MolsToGridImage(mols[0:6],molsPerRow=3,subImgSize=(300,200),legends=smiles)

def move_to_cuda(mol_batch):
    """Apply ``move_dgl_to_cuda`` to every graph object held in ``mol_batch``.

    Processed in order: each entry of ``'mol_trees'``, then
    ``'mol_graph_batch'``, then ``'cand_graph_batch'`` if that key is present,
    then ``'stereo_cand_graph_batch'`` if it is present and not None.
    Returns nothing.
    """
    for tree in mol_batch['mol_trees']:
        move_dgl_to_cuda(tree)
    move_dgl_to_cuda(mol_batch['mol_graph_batch'])

    # Candidate graphs are optional: checked by key presence only.
    if 'cand_graph_batch' in mol_batch:
        move_dgl_to_cuda(mol_batch['cand_graph_batch'])

    # Stereo candidates are skipped when the key is absent or maps to None.
    stereo_cands = mol_batch.get('stereo_cand_graph_batch')
    if stereo_cands is not None:
        move_dgl_to_cuda(stereo_cands)

In [None]:
"""
%matplotlib inline
for i,img in enumerate(batch['smiles_img']):
    fig, axs = plt.subplots(len(batch['smiles_img']))
    fig.suptitle('Candidates')
    axs[0].imshow(img)
    plt.imsave('mol'+str(i)+'.png', np.array(img))
    axs[1].imshow(batch['img_grid'][i])
    plt.imsave('mol'+str(i)+'cands.png', np.array(batch['img_grid'][i]))
"""

import datetime
from tqdm import tqdm
from models.modules.jtgcn import DGLJTMPN
from models.modules.jtnn_enc import DGLJTNNEncoder

# Build the embedding table and the encoder modules from `vocab` and `args`.
# NOTE(review): `jtmpn` is not used again in the visible part of the notebook.
embedding = nn.Embedding(vocab.size(), args.hidden_size)
mpn = GraphConvNet(args)
jtmpn = DGLJTMPN(args.hidden_size, args.depth)
jtenc = DGLJTNNEncoder(vocab, args.hidden_size, embedding)

# NOTE(review): `programmers`, `dates` and `z` are not referenced anywhere else
# in the visible notebook -- this looks like leftover demo data for a heatmap;
# confirm before removing.
programmers = ['Alex','Nicole','Sara','Etienne','Chelsea','Jody','Marianne']

base = datetime.datetime.today()
dates = base - np.arange(180) * datetime.timedelta(days=1)
z = np.random.poisson(size=(len(programmers), len(dates)))

#_ = [next(data_iter) for _ in range(200)]
#print(mpn(batch['mol_graph_batch'])[:1,0])
#fig, ax = plt.subplots()
#plt.ion()
#fig.show()
#fig.tight_layout()
#ax.clear()

# Number of batches to sample. The KL cell further down indexes
# historical_t / historical_j with range(batch_size), so the loop below must
# use the same count rather than a separately hard-coded 5.
batch_size=5

# Collect one row per sampled batch and concatenate once at the end, instead
# of re-allocating the history tensors with torch.cat on every iteration.
mpn_rows, tree_rows = [], []
for i in tqdm(range(batch_size)):
    mol_batch = next(data_iter)
    # Keep only the first row of each encoder's output, detached and on the CPU.
    t = mpn(mol_batch['mol_graph_batch'])[:1].cpu().detach()
    j_batch, j = jtenc(mol_batch['mol_trees'])
    j = j[:1].cpu().detach()
    mpn_rows.append(t)
    tree_rows.append(j)

historical_t = torch.cat(mpn_rows, dim=0)
historical_j = torch.cat(tree_rows, dim=0)

# One heatmap per encoder: rows are the sampled batches, columns the output features.
for title, history in (
    ('MPN output (first row of each sampled batch)', historical_t),
    ('Tree encoder output (first row of each sampled batch)', historical_j),
):
    plt.imshow(history, cmap='viridis')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('feature index')
    plt.ylabel('sampled batch')
    plt.show()

In [None]:
# Sizes of the encoder output and of the full latent vector; the tree part and
# the graph part each get half of the latent dimensions.
hidden_size, latent_size = 200, 72
half_latent = latent_size // 2

# Four independent projections, created in the order T_mean, T_var, G_mean, G_var.
T_mean, T_var, G_mean, G_var = (
    nn.Linear(hidden_size, half_latent) for _ in range(4)
)

def sample(tree_vec, mol_vec, e1=None, e2=None):
    """Sample latent tree/graph vectors as ``mean + exp(log_var / 2) * noise``.

    Uses the module-level linear layers ``T_mean``, ``T_var``, ``G_mean`` and
    ``G_var``. Works for a single vector or for a batch, because the means and
    log-variances are concatenated along the last dimension.

    Args:
        tree_vec: Tree encoding fed to ``T_mean`` / ``T_var``.
        mol_vec: Graph encoding fed to ``G_mean`` / ``G_var``.
        e1: Optional noise for the tree part; standard normal noise is drawn
            when omitted.
        e2: Optional noise for the graph part; standard normal noise is drawn
            when omitted.

    Returns:
        Tuple ``(tree_vec, mol_vec, z_mean, z_log_var)``: the two sampled
        vectors, then the concatenated means and concatenated log-variances.
    """
    tree_mean = T_mean(tree_vec)
    # -|.| keeps each log-variance non-positive.
    tree_log_var = -torch.abs(T_var(tree_vec))
    mol_mean = G_mean(mol_vec)
    mol_log_var = -torch.abs(G_var(mol_vec))

    # randn_like draws the noise with the same shape, dtype and device as the
    # means, so it can never end up on a different device than the layers.
    tree_noise = torch.randn_like(tree_mean) if e1 is None else e1
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * tree_noise
    mol_noise = torch.randn_like(mol_mean) if e2 is None else e2
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * mol_noise

    # dim=-1 equals dim 0 for 1-D inputs and keeps the batch axis for 2-D inputs.
    z_mean = torch.cat([tree_mean, mol_mean], dim=-1)
    z_log_var = torch.cat([tree_log_var, mol_log_var], dim=-1)

    return tree_vec, mol_vec, z_mean, z_log_var

In [None]:

# For each stored row, sample once and print -0.5 * sum(1 + log_var - mean^2 - exp(log_var)),
# divided by batch_size.
for i in range(batch_size):
    tree_vec, mol_vec, z_mean, z_log_var = sample(historical_j[i], historical_t[i])
    kl_terms = 1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)
    print(-0.5 * torch.sum(kl_terms) / batch_size)