In [29]:
import os
# Later cells import project modules (utils, models, datasets, graph) and use
# relative paths such as "output/..." that are expected to resolve from the
# otgnn directory, so move there unless we are already in it.
# Compare the final path component exactly: a substring test ("otgnn" in cwd)
# also matches e.g. a repo cloned as ".../otgnn/notebooks" or ".../otgnn_exp",
# which would skip the chdir and break the project imports.
if os.path.basename(os.path.normpath(os.getcwd())) != "otgnn":
    os.chdir("../otgnn/")

import numpy as np
import torch

# Set USE_CUDA = True to run on the first GPU when one is available;
# the default keeps everything on the CPU.
USE_CUDA = False
device = 'cuda:0' if USE_CUDA and torch.cuda.is_available() else 'cpu'

from utils import load_model
from models import GCN, compute_ot
from datasets import PropDataset
from graph import MolGraph

from multiprocessing import Pool

In [3]:
# Restore the best checkpoint of the 5-layer, sum-aggregation GCN run
# (see the printed namespace: n_layers=5, agg_func='sum').
model_path = "output/gcnsum_5layers_balanced/run_9/models/model_best"
gcn_model, namespace = load_model(model_path, model_class=GCN, device=device)

In [4]:
print(namespace)  # run configuration returned alongside the model by load_model

Namespace(agg_func='sum', batch_norm=False, batch_size=48, combined=False, cost_distance='l2', cuda=False, data='pseudonomas', data_dir='data/pseudonomas', device='cpu', distance_metric='wasserstein', dropout_ffn=0.0, dropout_gcn=0.0, epoch=199, ffn_activation='LeakyReLU', fgw_alpha=0.5, fgw_c1_c2_dist='diff', fgw_path=False, fgw_pc_path=False, grid_dir='', grid_hp_file='', grid_hp_idx='', grid_model=None, grid_splits='', gromov_max_it_inner=None, gromov_max_it_outer=100, gromov_opt_method='basic', gromov_opt_submethod='None', gw_L_tensor_dist='dot', hp_dir='', hp_model='', hp_num_iters=20, init_method='none', init_model=None, init_num=0, init_source='none', latest_train_stat=-1.0, latest_val_stat=-0.8220930232558139, linear_out=False, log_tb=False, lr=0.0003, lr_pc=0.01, max_grad_norm=10, model_dir='output/gcnsum_5layers_balanced/run_9/models', mult_num_atoms=True, n_epochs=200, n_ffn_hidden=100, n_hidden=50, n_labels=1, n_layers=5, n_pc=10, n_splits=10, name='', nce_coef=0.01, nce_ma

In [5]:
# QED data from the sibling iclr19-graph2graph checkout, "wengong" split type.
# Previously used instead: PropDataset("data/sol/", data_type="train")
qed_data_dir = "../iclr19-graph2graph/data/qed"
dataset = PropDataset(qed_data_dir, data_type="wengong")

../iclr19-graph2graph/data/qed   wengong ; split= 0  num total=  88306  num pos= 0


In [6]:
# Build a MolGraph for every dataset entry in parallel. Pool.map preserves
# input order, so mol_graphs[i] corresponds to the i-th dataset entry.
# Keep the original cap of 24 workers, but never spawn more processes than
# there are CPUs (os.cpu_count() may return None, hence the fallback to 1).
n_workers = min(24, os.cpu_count() or 1)
with Pool(n_workers) as p:
    mol_graphs = p.map(MolGraph, dataset)

In [7]:
# Atom counts of the two molecules held by each graph, then keep only the
# graphs whose molecules are the same size: for those, the OT plan computed
# later can be rescaled into a permutation matrix.
mol_graphs_lens = [
    (len(pair.mols[0].atoms), len(pair.mols[1].atoms))
    for pair in mol_graphs
]
mol_graphs_same = [
    pair
    for pair, (n_atoms_0, n_atoms_1) in zip(mol_graphs, mol_graphs_lens)
    if n_atoms_0 == n_atoms_1
]
len(mol_graphs_same), len(mol_graphs)

(5975, 88306)

In [144]:
# Embed one equal-size molecule pair with the trained GCN and split the
# embedding rows back into the two molecules.
dp = mol_graphs_same[0]

lenx = len(dp.mols[0].atoms)
leny = len(dp.mols[1].atoms)

# Uniform marginals (equal mass per atom), passed as H_1/H_2 to compute_ot.
Hx = np.ones(lenx) / lenx
Hy = np.ones(leny) / leny

# Call the module itself rather than .forward() so nn.Module hooks run.
dp_embedding = gcn_model(dp)
# NOTE(review): assumes dp_embedding[0] stacks the rows of mols[0] followed by
# those of mols[1] along dim 0 — confirm against GCN.forward.
atom_embeddings = dp_embedding[0]
# Slicing would silently truncate dp_y if there were fewer rows than atoms.
assert atom_embeddings.shape[0] >= lenx + leny, \
    "GCN output has fewer rows than the atoms of the molecule pair"
dp_x = atom_embeddings[:lenx, :]
dp_y = atom_embeddings[lenx:lenx + leny, :]

In [145]:
# The permutation reading below is only valid for equal cardinalities;
# fail before doing any OT work if that precondition is broken.
assert lenx == leny, "permutation interpretation requires equal atom counts"

# Exact (EMD) optimal transport between the two sets of atom embeddings.
# NOTE(review): the sinkhorn_* arguments are presumably ignored when
# opt_method='emd' — TODO confirm in compute_ot.
OT_xy = compute_ot(dp_x, dp_y, opt_method='emd', sinkhorn_max_it=100,
                   H_1=Hx, H_2=Hy, sinkhorn_entropy=0.1)

# NOTE(review): OT_xy[2] is taken to be the transport plan (rows follow x,
# columns follow y) — confirm against compute_ot's return value.
# With uniform equal-size marginals, a vertex EMD solution has entries 1/n,
# so scaling by n yields a 0/1 permutation matrix.
pmatrix = OT_xy[2] * lenx
# Row i of dp_y_perm is the y-embedding matched to x-atom i.
dp_y_perm = torch.mm(pmatrix, dp_y)

In [147]:
# Rescaled transport plan; for an equal-size pair each row should hold a single 1.
pmatrix

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 

In [148]:
# Residual between each x-atom embedding and its matched y-atom embedding;
# near-zero rows mean the matched atoms received (almost) identical embeddings.
torch.sub(dp_y_perm, dp_x)

tensor([[ 3.2609e-02,  2.5362e-01,  1.8659e-01,  ..., -2.8552e-01,
         -2.4307e-02, -8.7463e-02],
        [-9.4216e-02,  1.7581e-02,  8.3255e-02,  ..., -1.2304e-01,
          1.2050e-01, -1.3111e-01],
        [-4.1250e-02,  1.6160e-01,  7.7837e-02,  ..., -1.1122e-01,
          9.3742e-02, -1.1690e-02],
        ...,
        [ 0.0000e+00, -1.4901e-08, -7.4506e-09,  ...,  1.8626e-08,
          0.0000e+00,  0.0000e+00],
        [ 2.9802e-08, -7.4506e-09,  7.4506e-09,  ..., -7.4506e-09,
          0.0000e+00,  7.4506e-09],
        [-6.4046e-03,  1.2067e-02,  1.6502e-02,  ..., -1.4632e-02,
         -1.7838e-02, -2.9542e-02]], grad_fn=<SubBackward0>)

In [118]:
# Embedding rows of the first molecule (rows 0..lenx-1 of the GCN output).
dp_x

tensor([[-0.0404, -0.2395, -0.1349,  ...,  0.0630, -0.1939, -0.0365],
        [ 0.0579,  0.0018, -0.1433,  ...,  0.0499, -0.2140, -0.0689],
        [-0.0036,  0.2639, -0.2290,  ..., -0.0037, -0.2387, -0.1792],
        ...,
        [ 0.0034, -0.0015, -0.1184,  ..., -0.0100, -0.2122, -0.2220],
        [-0.0851, -0.1026, -0.0239,  ...,  0.0192, -0.2096, -0.0698],
        [ 0.0600,  0.0789,  0.0065,  ..., -0.0646, -0.1643, -0.1547]],
       grad_fn=<SliceBackward>)

In [111]:
# Embedding rows of the second molecule (rows lenx..lenx+leny-1 of the GCN output).
dp_y

tensor([[-0.0851, -0.1026, -0.0239,  ...,  0.0192, -0.2096, -0.0698],
        [ 0.0034, -0.0015, -0.1184,  ..., -0.0100, -0.2122, -0.2220],
        [-0.0535,  0.3274, -0.0577,  ...,  0.0237, -0.1292, -0.2605],
        ...,
        [-0.0158,  0.3111, -0.0802,  ..., -0.0422, -0.2071, -0.3211],
        [-0.0781, -0.1207,  0.0157,  ..., -0.0265, -0.2068, -0.1380],
        [-0.0241,  0.2868, -0.1817,  ..., -0.0227, -0.1362, -0.1939]],
       grad_fn=<SliceBackward>)

(tensor([[-4.7938e-02, -1.9188e-01, -6.3034e-02, -3.2365e-02,  2.9387e-02,
           7.6850e-02,  1.1982e-01,  8.4841e-02,  6.0202e-02,  1.3848e-01,
           5.5980e-02, -1.5844e-01, -3.5167e-01, -9.9443e-02, -7.1597e-02,
           1.2279e-01, -1.5567e-01,  1.6104e-01, -8.8660e-02, -2.6265e-01,
           9.6059e-02, -2.6767e-01, -1.1301e-01, -1.2001e-02, -7.8244e-02,
          -3.9489e-02, -1.3516e-01, -8.6357e-02, -6.8217e-02,  4.2147e-01,
          -3.6685e-02, -1.6600e-02, -1.4952e-01, -2.3193e-01, -5.8785e-01,
          -3.7551e-02,  2.1026e-01,  4.1773e-01,  3.8300e-02, -6.1371e-02,
           1.5349e-01, -2.9925e-01, -2.8491e-02,  1.1181e-01, -7.6826e-03,
          -1.2164e-01,  1.0976e-01,  3.0522e-02, -1.7944e-01, -4.3544e-02],
         [-2.7892e-02,  2.0410e-01, -1.0660e-01, -1.0270e-01,  5.2191e-02,
           7.6625e-02,  1.7560e-01, -5.0244e-02,  1.2810e-01,  2.0907e-01,
           1.9086e-01, -3.5129e-03, -1.9149e-01,  6.5191e-02,  8.1955e-03,
           7.7935e-02, -