In [1]:
import os
import sys

# If the working-directory path contains "ntbk", move up one level so that
# the relative paths used further down resolve from the parent directory.
if "ntbk" in os.getcwd():
    os.chdir(os.pardir)
print(os.getcwd())

# Make the "otgnn" directory under the working directory importable.
sys.path.append(os.path.join(os.getcwd(), "otgnn"))


%load_ext autoreload
%autoreload 2

from otgnn.models import GCN
from otgnn.graph import MolGraph
from otgnn.utils import save_model, load_model, StatsTracker
from otgnn.graph import SYMBOLS, FORMAL_CHARGES, BOND_TYPES, get_bt_index

from mol_opt.mol_opt import MolOpt
from mol_opt.data_mol_opt import MolOptDataset
from mol_opt.data_mol_opt import get_loader
from mol_opt.decoder_mol_opt import MolOptDecoder
from mol_opt.arguments import get_args
from mol_opt.train_mol_opt import main, get_latest_model
from mol_opt.ot_utils import encode_target, FGW
from mol_opt.train_mol_opt import ft
from mol_opt.ot_utils import Penalty as PenaltyOld

from rdkit.Chem import MolFromSmiles

from molgen.metrics.Penalty import Penalty as PenaltyNew

import torch
from torch import nn
import numpy as np
import time

/home/octav/gitrepos/tum-thesis


In [2]:
# Replace the process arguments with a minimal list carrying the "-cuda" flag
# before asking for the argument namespace (presumably parsed by get_args()).
sys.argv = ["", "-cuda"]
args = get_args()

# Checkpoint selection: run tag and iteration number.
model = "pointwise10-onebatch"
model_iter = 1000
args.output_dir = f"mol_opt/output_{model}/"

model_name = f"model_{model}_{model_iter}"
model_decode_name = f"model_{model}_decode_{model_iter}"

# Restore the MolOpt checkpoint and the matching decoder checkpoint.
molopt, config = load_model(args.output_dir + model_name, MolOpt, args.device)
print(molopt, config)

molopt_decoder, config_decoder = load_model(
    args.output_dir + model_decode_name, MolOptDecoder, args.device
)
print(molopt_decoder, config_decoder)

loss = FGW(alpha=0.5)

n_data = 36
data_loader = get_loader("iclr19-graph2graph/data/qed", "train_pairs", n_data, True)

# Only the first item yielded by the loader is used: its two entries are
# wrapped as MolGraph objects X and Y, then iteration stops.
for i in data_loader:
    X, Y = MolGraph(i[0]), MolGraph(i[1])
    break

MolOpt(
  (GCN): GCN(
    (W_message_i): Linear(in_features=100, out_features=100, bias=False)
    (W_message_h): Linear(in_features=100, out_features=100, bias=False)
    (W_message_o): Linear(in_features=193, out_features=70, bias=True)
    (W_mol_h): Linear(in_features=70, out_features=100, bias=True)
    (W_mol_o): Linear(in_features=100, out_features=1, bias=True)
    (dropout_gcn): Dropout(p=0.0, inplace=False)
    (dropout_ffn): Dropout(p=0.0, inplace=False)
  )
  (opt0): Linear(in_features=70, out_features=100, bias=True)
  (opt1): Linear(in_features=100, out_features=70, bias=True)
) Namespace(N_transformer=6, agg_func='sum', annealing_rate=0.0005, batch_norm=False, batch_size=6, connectivity=True, connectivity_hard=False, connectivity_lambda=5e-05, cuda=True, device='cuda:0', dim_tangent_space=40, dropout_ffn=0.0, dropout_gcn=0.0, dropout_transformer=0.1, euler_characteristic_penalty=True, euler_lambda=0.0001, ffn_activation='LeakyReLU', init_decoder_model='pointwise10-onebat

In [3]:
# Penalty object (molgen implementation) built from the run arguments,
# starting from epoch 0.
pen = PenaltyNew(args, prev_epoch=0)

In [4]:
# Run X through the model, decode the result against X and Y, and turn the
# decoder outputs into discrete labels.
yhat_embedding = molopt.forward(X)
yhat_logits = molopt_decoder.forward(yhat_embedding, X, Y)
yhat_labels = molopt_decoder.discretize(*yhat_logits)

# Positional arguments for the penalty call: a (labels, logits, scope) triple
# followed by Y. The call sites unpack this with *pred_pack.
pred_pack = ((yhat_labels, yhat_logits, Y.scope), Y)

In [5]:
for idx in range(50):
    # A new tracker is created on every pass, so each printout reflects only
    # the values recorded in that pass.
    stats_tracker = StatsTracker()
    con_loss, val_loss, eul_loss = pen(*pred_pack, idx)

    # Record the three penalty terms, each weighted by the batch size.
    for stat_name, stat_value in (
        ('conn_penalty', con_loss),
        ('val_penalty', val_loss),
        ('euler_penalty', eul_loss),
    ):
        stats_tracker.add_stat(stat_name, stat_value.item(), n_data)

    stats_tracker.print_stats(f"epoch={idx}")
    pen.log()

epoch=0
 conn_penalty:0.3837645
 val_penalty:0.9759998
 euler_penalty:0.1092538
Penalty params: tau=1.00000 conn_l=0.02500 val_l=0.07000 euler_l=0.30000 epoch=1
epoch=1
 conn_penalty:0.0000004
 val_penalty:0.9690413
 euler_penalty:0.0350247
Penalty params: tau=1.00000 conn_l=0.02500 val_l=0.07000 euler_l=0.30000 epoch=1
epoch=2
 conn_penalty:1.1512926
 val_penalty:1.2964340
 euler_penalty:0.0571473
Penalty params: tau=0.99990 conn_l=0.02500 val_l=0.07001 euler_l=0.30003 epoch=2
epoch=3
 conn_penalty:1.9188207
 val_penalty:1.0524705
 euler_penalty:0.0266666
Penalty params: tau=0.99980 conn_l=0.02501 val_l=0.07001 euler_l=0.30006 epoch=3
epoch=4
 conn_penalty:0.3837645
 val_penalty:0.8931156
 euler_penalty:0.0173650
Penalty params: tau=0.99970 conn_l=0.02501 val_l=0.07002 euler_l=0.30009 epoch=4
epoch=5
 conn_penalty:1.1512926
 val_penalty:1.1441570
 euler_penalty:0.0000000
Penalty params: tau=0.99960 conn_l=0.02501 val_l=0.07003 euler_l=0.30012 epoch=5
epoch=6
 conn_penalty:2.3025852
 v

In [39]:
# Log the parameters of a freshly constructed penalty object for reference.
PenaltyNew(args, prev_epoch=0).log()

Penalty params: tau=1.00000 conn_l=0.02500 val_l=0.07000 euler_l=0.30000 epoch=1


In [79]:
# NOTE(review): the result is bound to a single name here, while the loop in
# the earlier cell unpacks three values from the same call. The cells that
# follow treat `bonds` as one tensor (.sum(axis=1), .shape), so this cell
# presumably ran against a different, auto-reloaded revision of the penalty
# code -- confirm before relying on it. `idx` is whatever value the earlier
# loop left behind.
bonds = pen(*pred_pack, idx)

In [80]:
# Sum each row of `bonds` across its second dimension.
bonds.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000], device='cuda:0',
       grad_fn=<SumBackward1>)

In [91]:
# Display the dimensions of `bonds` as it was last assigned.
bonds.shape

torch.Size([2414, 5])

In [11]:
# TODO: build an adjacency matrix from these discretized predictions.
# For now the cell only displays the tuple returned by
# molopt_decoder.discretize(...).
yhat_labels

(tensor([3, 1, 0, 0, 0, 0, 2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 7, 0,
         0, 0, 0, 7, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 0, 0, 2, 0, 0, 0, 1, 0, 0, 2,
         0, 0, 0, 1, 0, 0, 0, 4, 0, 0, 0, 2, 0, 2, 0, 1, 0, 0, 0, 8, 1, 0, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2,
         1, 0, 3, 2, 2, 1, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        device='cuda:0'),
 tensor([4, 0, 3, 1, 3, 3, 4, 0, 1, 0, 4, 2, 1, 1, 0, 1, 3, 3, 1, 1, 2, 2, 3, 1,
         3, 3, 2, 4, 1, 1, 0, 4, 1, 4, 0, 3, 4, 0, 2, 1, 4, 4, 3, 1, 2, 3, 0, 3,
         3, 0, 1, 3, 1, 3, 4, 4, 4, 1, 4, 2, 4, 0, 0, 2, 0, 0, 1, 3, 0, 1, 4, 2,
         3, 2, 2, 2, 3, 1, 2, 4, 1, 1, 0, 1, 4, 1, 1, 1, 0, 1, 4, 3, 1, 3, 3, 3,
         0, 1, 4, 4, 2, 4, 3, 3, 0, 3, 0, 2, 1, 3, 1, 4, 0, 3, 1, 0, 4, 2, 4, 3],
        device='cuda:0'),
 tensor([4, 4, 4,  ..., 4, 4, 4], device='cuda:0'))

In [41]:
# Scratch cell: draws a single random integer from NumPy's global generator.
# NOTE(review): numpy is already imported as `np` in the first cell, so this
# import is redundant. No seed is set, so the value differs between runs.
import numpy as np
np.random.randint(100)

75

In [42]:
# NOTE(review): this rebinds `bonds`, which another cell assigns from the
# `pen(...)` call. Here it becomes the "BOND_TYPES" entry of the first element
# returned by Y.get_graph_outputs(), so which object `bonds` holds depends on
# execution order.
bonds = Y.get_graph_outputs()[0]["BOND_TYPES"]

In [43]:
# One minus the last channel along the third axis of `bonds`.
# Presumably that channel marks "no bond", which would make adjM an adjacency
# matrix -- TODO confirm against the BOND_TYPES encoding.
adjM = 1 - bonds[:, :, -1]

In [61]:
# Evaluate the penalty object's conn_penalty on adjM. The saved output of this
# cell equals the value shown for the hand-written clamped-log expression in
# the notebook's last cell.
pen.conn_penalty(adjM)

tensor(1.5259e-05)

In [46]:
N = adjM.shape[0]
device = adjM.device

# Laplacian: diagonal matrix of the row sums of adjM, minus adjM itself.
L = torch.diag(torch.matmul(adjM, torch.ones(N, device=device))) - adjM
# Shift every entry by 1/N.
L_mod = L + torch.ones_like(L, device=device) / N

# Alternative noted in the original cell (log-determinant instead of the
# rescaled eigenvalues), kept for reference only:
#   -torch.logdet(L_mod + conn_eps * torch.eye(N, device=device))

# Eigenvalues of L_mod. torch.symeig is deprecated and raises in current
# PyTorch releases, so use torch.linalg.eigvalsh where it exists and fall back
# to symeig only on old installations. UPLO="U" keeps symeig's default
# convention of reading the upper triangle. Only eigenvalues are needed, so no
# eigenvectors are computed on the new path.
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigvalsh"):
    eigvals = torch.linalg.eigvalsh(L_mod, UPLO="U")
else:
    eigvals = torch.symeig(L_mod, eigenvectors=True)[0]

# Displayed value: sum of exp(-eigenvalue) over the spectrum.
torch.sum(torch.exp(-eigvals))
        

tensor(5.7386)

In [57]:
# Clamp bounds applied to the eigenvalues before taking logs.
eps = 1e-9
beta = 1e-3

# Negative sum of log(clamped eigenvalue), offset by len(eigvals) * log(beta).
# An eigenvalue at or above beta contributes log(beta) to the sum, which the
# offset cancels, so only eigenvalues below beta add to the result.
-torch.log(eigvals.clamp(min=eps, max=beta)).sum() + len(eigvals) * np.log(beta)

tensor(1.5259e-05)