In [1]:
import os
# Switch into the otgnn directory, presumably so that the project-local imports
# further down (utils, models, datasets, graph) resolve.
# NOTE(review): this is a substring test on the whole path, so any path that merely
# contains "otgnn" (e.g. in a parent folder name) skips the chdir — confirm intended.
if "otgnn" not in os.getcwd():
    os.chdir("../otgnn")

import numpy as np
import torch
# Run everything on the CPU. To use a GPU when one is available, replace the
# assignment with: device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

from utils import load_model
from models import GCN, compute_ot
from datasets import PropDataset, get_loader
from graph import MolGraph

from multiprocessing import Pool

In [2]:
# Restore the trained GCN checkpoint; load_model also hands back the namespace
# that is printed in the following cell.
model_path = "output/gcnsum_5layers_balanced/run_9/models/model_best"
gcn_model, namespace = load_model(
    model_path,
    model_class=GCN,
    device=device,
)

In [3]:
# Dump the settings that were returned alongside the checkpoint.
print(namespace)

Namespace(agg_func='sum', batch_norm=False, batch_size=48, combined=False, cost_distance='l2', cuda=False, data='pseudonomas', data_dir='data/pseudonomas', device='cpu', distance_metric='wasserstein', dropout_ffn=0.0, dropout_gcn=0.0, epoch=199, ffn_activation='LeakyReLU', fgw_alpha=0.5, fgw_c1_c2_dist='diff', fgw_path=False, fgw_pc_path=False, grid_dir='', grid_hp_file='', grid_hp_idx='', grid_model=None, grid_splits='', gromov_max_it_inner=None, gromov_max_it_outer=100, gromov_opt_method='basic', gromov_opt_submethod='None', gw_L_tensor_dist='dot', hp_dir='', hp_model='', hp_num_iters=20, init_method='none', init_model=None, init_num=0, init_source='none', latest_train_stat=-1.0, latest_val_stat=-0.8220930232558139, linear_out=False, log_tb=False, lr=0.0003, lr_pc=0.01, max_grad_norm=10, model_dir='output/gcnsum_5layers_balanced/run_9/models', mult_num_atoms=True, n_epochs=200, n_ffn_hidden=100, n_hidden=50, n_labels=1, n_layers=5, n_pc=10, n_splits=10, name='', nce_coef=0.01, nce_ma

In [4]:
# Load the QED data from the iclr19-graph2graph checkout.
# Alternative dataset: PropDataset("data/sol/", data_type="train")
dataset = PropDataset(
    "../iclr19-graph2graph/data/qed",
    data_type="wengong",
)

../iclr19-graph2graph/data/qed   wengong ; split= 0  num total=  88306  num pos= 0


In [5]:
# Build a MolGraph for every dataset entry in parallel.
# The pool size is capped at the number of CPUs reported by the OS so that
# machines with fewer than 24 cores are not oversubscribed; the resulting list
# is the same regardless of the worker count.
n_workers = min(24, os.cpu_count() or 1)
with Pool(n_workers) as p:
    mol_graphs = p.map(MolGraph, dataset)

In [6]:
# Atom counts of the two molecules held by each graph.
mol_graphs_lens = [
    (len(graph.mols[0].atoms), len(graph.mols[1].atoms))
    for graph in mol_graphs
]
# Keep only the graphs whose two molecules have the same number of atoms.
mol_graphs_same = [
    graph
    for graph, (n_first, n_second) in zip(mol_graphs, mol_graphs_lens)
    if n_first == n_second
]
# Display: (number of equal-size graphs, total number of graphs).
len(mol_graphs_same), len(mol_graphs)

(5975, 88306)

In [7]:
# Work with the first graph whose two molecules have equal atom counts.
dp = mol_graphs_same[0]

lenx = len(dp.mols[0].atoms)
leny = len(dp.mols[1].atoms)

# Uniform weight vectors over the atoms of each molecule (each sums to 1).
Hx = np.ones(lenx) / lenx
Hy = np.ones(leny) / leny

# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks, and this matches how the model is invoked later on.
dp_embedding = gcn_model(dp)
# NOTE(review): the slicing assumes the rows of dp_embedding[0] list the atoms of
# molecule 0 first, followed by those of molecule 1 — confirm against GCN.
dp_x = dp_embedding[0][:lenx, :]
dp_y = dp_embedding[0][lenx:lenx + leny, :]

In [8]:
# Optimal transport between the two sets of atom embeddings, with the uniform
# weights Hx / Hy as marginals.
# NOTE(review): the sinkhorn_* arguments are presumably unused when
# opt_method='emd' — confirm against compute_ot.
OT_xy = compute_ot(
    dp_x,
    dp_y,
    opt_method='emd',
    sinkhorn_max_it=100,
    H_1=Hx,
    H_2=Hy,
    sinkhorn_entropy=0.1,
)

# When x and y have the same cardinality, the third returned element scaled by
# lenx is a permutation matrix; use it to reorder the rows of dp_y.
pmatrix = OT_xy[2] * lenx
dp_y_perm = pmatrix.mm(dp_y)

In [12]:
# Number of rows in dp_x (it was sliced to lenx rows when it was created).
dp_x.shape[0]

22

In [10]:
# Residual between the reordered y embeddings and the x embeddings; rows that are
# ~0 (up to float rounding) are atoms whose matched embeddings coincide.
dp_y_perm - dp_x

tensor([[ 3.2609e-02,  2.5362e-01,  1.8659e-01,  ..., -2.8552e-01,
         -2.4307e-02, -8.7463e-02],
        [-9.4216e-02,  1.7581e-02,  8.3255e-02,  ..., -1.2304e-01,
          1.2050e-01, -1.3111e-01],
        [-4.1250e-02,  1.6160e-01,  7.7837e-02,  ..., -1.1122e-01,
          9.3742e-02, -1.1690e-02],
        ...,
        [ 0.0000e+00, -1.4901e-08, -7.4506e-09,  ...,  1.8626e-08,
          0.0000e+00,  0.0000e+00],
        [ 2.9802e-08, -7.4506e-09,  7.4506e-09,  ..., -7.4506e-09,
          0.0000e+00,  7.4506e-09],
        [-6.4046e-03,  1.2067e-02,  1.6502e-02,  ..., -1.4632e-02,
         -1.7838e-02, -2.9542e-02]], grad_fn=<SubBackward0>)

In [11]:
# Inspect the embeddings sliced out for the first molecule.
dp_x

tensor([[-0.0249, -0.2365, -0.0811,  ...,  0.0842, -0.2385, -0.0374],
        [ 0.1374,  0.2087, -0.2858,  ...,  0.1649, -0.1999, -0.2026],
        [ 0.0915, -0.0769, -0.1070,  ...,  0.1143, -0.1884, -0.0709],
        ...,
        [-0.2960, -0.0722,  0.0546,  ...,  0.0597,  0.2668,  0.0544],
        [-0.2466, -0.0783,  0.0545,  ...,  0.0452,  0.2429, -0.0373],
        [ 0.0917,  0.2418, -0.1777,  ...,  0.0212,  0.0194, -0.1757]],
       grad_fn=<SliceBackward>)

In [12]:
# Inspect the embeddings sliced out for the second molecule.
dp_y

tensor([[ 0.0077,  0.0172,  0.1055,  ..., -0.2013, -0.2629, -0.1248],
        [ 0.1376,  0.2057, -0.0800,  ..., -0.0134, -0.4696, -0.4551],
        [ 0.1426,  0.2933, -0.0659,  ..., -0.1418, -0.1216, -0.2257],
        ...,
        [-0.2466, -0.0783,  0.0545,  ...,  0.0452,  0.2429, -0.0373],
        [ 0.0853,  0.2538, -0.1612,  ...,  0.0066,  0.0015, -0.2053],
        [ 0.1162,  0.0466, -0.0533,  ..., -0.0094, -0.0949, -0.0499]],
       grad_fn=<SliceBackward>)

In [14]:
# Shape of the first molecule's embedding matrix (rows x embedding width).
# The recorded output of this cell is a torch.Size, which is what `.shape`
# returns; a bare `dp_x` would display the full tensor instead.
dp_x.shape

torch.Size([22, 50])

In [17]:
# Freshly initialised linear layer mapping 50 features to 50 features.
# NOTE: despite its name, `opt` is a layer, not an optimizer.
opt = torch.nn.Linear(in_features=50, out_features=50)

In [19]:
# Apply the linear layer to every row of dp_x in one call.
opt(dp_x)

tensor([[ 0.0452,  0.0452, -0.2337,  ...,  0.0500,  0.1039, -0.0787],
        [-0.0911,  0.1315, -0.0566,  ...,  0.0493,  0.1781, -0.0868],
        [ 0.0914,  0.0710, -0.0827,  ...,  0.1487,  0.1193, -0.0727],
        ...,
        [-0.2943,  0.0308,  0.2118,  ...,  0.2272, -0.0664, -0.1165],
        [-0.2589,  0.0491,  0.1853,  ...,  0.2679, -0.0512, -0.0547],
        [-0.1980,  0.2009,  0.0461,  ...,  0.1865,  0.0866, -0.0772]],
       grad_fn=<AddmmBackward>)

In [21]:
# Show the layer's configuration (feature sizes and whether it has a bias).
opt

Linear(in_features=50, out_features=50, bias=True)

In [28]:
# Apply the layer to a single row (index 2) of dp_x; the values agree with
# row 2 of the batched result shown earlier.
opt(dp_x[2,:])

tensor([ 0.0914,  0.0710, -0.0827,  0.2312, -0.2916,  0.0719, -0.1599,  0.0610,
         0.0023, -0.1936,  0.1506,  0.2091, -0.0643, -0.3261,  0.0815,  0.0889,
        -0.0095, -0.1008, -0.0126,  0.0097, -0.0660,  0.0469,  0.1013,  0.1156,
         0.0449, -0.0513, -0.1620, -0.2501,  0.0678,  0.2904,  0.0553,  0.2455,
         0.2584,  0.0030, -0.0568,  0.1300,  0.0974, -0.0355, -0.0374,  0.1742,
         0.1388,  0.2372, -0.0791,  0.1651,  0.1196, -0.0741, -0.2013,  0.1487,
         0.1193, -0.0727], grad_fn=<AddBackward0>)

In [29]:
# data loading

In [42]:
# Molecule objects held by the first MolGraph.
mol_graphs[0].mols

[<graph.mol_graph.Molecule at 0x7efcd82ec9b0>,
 <graph.mol_graph.Molecule at 0x7efcc1c6d908>]

In [46]:
# Shape of the first element returned by the model for the first MolGraph.
# The module is called directly (not via .forward) so registered hooks run.
gcn_model(mol_graphs[0])[0].shape

torch.Size([49, 50])

In [50]:
# Pull a single batch (batch size 48) from the loader and build its MolGraph.
# next(iter(...)) raises StopIteration on an empty loader instead of silently
# leaving `y` undefined (or stale from an earlier run), as a for/break loop would.
data_loader = get_loader("../iclr19-graph2graph/data/qed", "wengong", 48)
x = next(iter(data_loader))
# NOTE(review): x[1] is presumably the batch field holding the molecules —
# confirm against get_loader.
y = MolGraph(x[1])

../iclr19-graph2graph/data/qed   wengong ; split= 0  num total=  88306  num pos= 0


In [52]:
# Run the model on the batch graph. It returns a tuple: the first element is an
# embedding matrix and the second a single-column tensor (presumably one
# prediction per molecule in the batch — confirm against GCN).
gcn_model(y)

(tensor([[-0.0870, -0.1251,  0.0091,  ..., -0.0373, -0.2293, -0.1346],
         [-0.0864,  0.2576, -0.1332,  ...,  0.0244, -0.2223, -0.3507],
         [-0.0322,  0.0882, -0.1250,  ...,  0.0517, -0.0901, -0.0536],
         ...,
         [ 0.0359,  0.0439, -0.0475,  ..., -0.0063, -0.0784, -0.0943],
         [ 0.0518,  0.0954, -0.0771,  ...,  0.0004,  0.0100,  0.0281],
         [ 0.0845,  0.1678, -0.3848,  ...,  0.0621,  0.0439,  0.0274]],
        grad_fn=<AddmmBackward>),
 tensor([[-21.8511],
         [-15.1406],
         [ -5.9249],
         [-14.9901],
         [-19.7632],
         [-26.0849],
         [-11.1604],
         [-21.0078],
         [-16.2720],
         [-20.9298],
         [-16.8503],
         [-10.3132],
         [-17.0383],
         [-19.7530],
         [-18.0468],
         [-17.3954],
         [-19.9377],
         [-12.3606],
         [-14.1296],
         [-22.1965],
         [-20.9364],
         [-16.4210],
         [-17.5517],
         [-13.1477],
         [-14.4498],
