In [None]:
# add path (for local): prepend the repository root (one directory up) to
# sys.path so project modules such as `transition` can be imported.
import sys

repo_root = "../"
# Insert only when absent so re-running this cell does not stack duplicates.
if not sys.path.count(repo_root):
    sys.path.insert(0, repo_root)

In [32]:
import numpy as np
import torch
from transition import RNNLanguageModel

#ref: https://discuss.pytorch.org/t/loading-tensorflow-grucell-weights-into-pytorch/113216/4

def convert_input_kernel(kernel):
    """Convert a TF/Keras GRU input kernel into PyTorch's ``weight_ih_l{k}`` layout.

    The kernel is split column-wise into three equal gate blocks. The first two
    blocks are swapped, and every block is transposed before the blocks are
    stacked row-wise. This maps the TF gate order (z, r, h) onto PyTorch's
    (r, z, n) order. For a 2-D ``(in_features, 3 * hidden)`` input, the result
    has shape ``(3 * hidden, in_features)``.
    """
    # TF stores the blocks as update (z), reset (r), candidate (h).
    update_block, reset_block, candidate_block = np.hsplit(kernel, 3)
    return np.concatenate((reset_block.T, update_block.T, candidate_block.T))
    
def convert_recurrent_kernel(kernel):
    """Convert a TF/Keras GRU recurrent kernel into PyTorch's ``weight_hh_l{k}`` layout.

    The recurrent kernel uses the same column-wise gate-block layout as the
    input kernel. For a 2-D ``(hidden, 3 * hidden)`` input, it becomes
    ``(3 * hidden, hidden)``, with the first two gate blocks swapped. The
    conversion is therefore identical. This function delegates to
    ``convert_input_kernel`` instead of duplicating the logic, so the two
    paths cannot drift apart.
    """
    return convert_input_kernel(kernel)

def convert_bias(bias):
    """Convert a TF/Keras GRU bias into PyTorch's (input, recurrent) bias pair.

    ``bias`` is viewed as 2 rows of 3 equal gate blocks. This presumably matches
    the Keras ``reset_after=True`` bias of shape ``(2, 3 * hidden)``; confirm for
    other configurations. Within each row, the first two gate blocks are swapped
    to go from TF's (z, r, h) order to PyTorch's (r, z, n) order.

    Returns:
        A 2-row array. Row 0 feeds ``bias_ih_l{k}`` and row 1 feeds ``bias_hh_l{k}``.
    """
    per_gate = bias.reshape(2, 3, -1)
    swapped = np.take(per_gate, [1, 0, 2], axis=1)
    return swapped.reshape((2, -1))

def copy_tf_gru_weights(p_rnn_layer, tf_kernel, tf_recurrent, tf_bias, layer_idx=0):
    """Copy one layer of TF/Keras GRU weights into a PyTorch GRU module, in place.

    Args:
        p_rnn_layer: Module exposing ``weight_ih_l{k}``, ``weight_hh_l{k}``,
            ``bias_ih_l{k}`` and ``bias_hh_l{k}`` parameters (e.g. ``torch.nn.GRU``).
        tf_kernel: TF input kernel; converted with ``convert_input_kernel``.
        tf_recurrent: TF recurrent kernel; converted with ``convert_recurrent_kernel``.
        tf_bias: TF bias; converted with ``convert_bias``. Row 0 becomes the
            input bias and row 1 the recurrent bias.
        layer_idx: Index ``k`` of the stacked GRU layer to overwrite.

    Raises:
        AttributeError: if ``p_rnn_layer`` has no parameters for ``layer_idx``.
        ValueError: if a converted array's shape differs from its target parameter.
    """
    suffix = f"_l{layer_idx}"

    # Convert everything up front, and the bias only once. A conversion error
    # then cannot leave the layer half-overwritten.
    bias_ih, bias_hh = convert_bias(tf_bias)
    converted = {
        f"weight_ih{suffix}": convert_input_kernel(tf_kernel),
        f"weight_hh{suffix}": convert_recurrent_kernel(tf_recurrent),
        f"bias_ih{suffix}": bias_ih,
        f"bias_hh{suffix}": bias_hh,
    }
    targets = {name: getattr(p_rnn_layer, name) for name in converted}

    # Tensor.copy_ broadcasts its source. Require exact shapes (checked before any
    # copy) so a wrong layout fails loudly instead of being silently expanded.
    for name, value in converted.items():
        expected = tuple(targets[name].shape)
        if expected != tuple(value.shape):
            raise ValueError(
                f"{name}: expected shape {expected}, got {tuple(value.shape)}"
            )

    with torch.no_grad():
        for name, value in converted.items():
            targets[name].copy_(torch.from_numpy(value))

def load_tf_weights_npz_to_gru_model(p_model, npz_path="tf_weights.npz"):
    """Load TF/Keras language-model weights from an ``.npz`` archive into ``p_model``.

    Arrays are read positionally (``arr_0``, ``arr_1``, ...), i.e. as written by
    ``np.savez(path, *weights)``. They must be laid out as follows:

        arr_0                          embedding matrix -> p_model.embed.weight
        arr_{1+3k}, arr_{2+3k}, arr_{3+3k}
                                       GRU layer k kernel, recurrent kernel, bias
                                       -> p_model.rnn (layer k)
        second-to-last                 dense kernel (transposed) -> p_model.fc.weight
        last                           dense bias -> p_model.fc.bias

    The number of GRU layers is inferred from the array count. The original
    9-array, 2-layer archive loads exactly as before.

    Args:
        p_model: Model with ``embed``, ``rnn`` and ``fc`` submodules.
        npz_path: Path to the ``.npz`` archive.

    Raises:
        ValueError: if the archive does not hold ``3 + 3 * num_layers`` arrays
            with ``num_layers >= 1``.
    """
    # Read every array eagerly inside ``with`` so the underlying zip file is closed.
    with np.load(npz_path) as data:
        weights = [data[f"arr_{i}"] for i in range(len(data.files))]

    num_layers, remainder = divmod(len(weights) - 3, 3)
    if remainder or num_layers < 1:
        raise ValueError(
            f"{npz_path}: expected 3 + 3 * num_layers arrays, got {len(weights)}"
        )

    embedding, fc_kernel, fc_bias = weights[0], weights[-2], weights[-1]

    with torch.no_grad():
        p_model.embed.weight.copy_(torch.from_numpy(embedding))
        for layer_idx in range(num_layers):
            start = 1 + 3 * layer_idx
            kernel, recurrent, bias = weights[start:start + 3]
            copy_tf_gru_weights(p_model.rnn, kernel, recurrent, bias, layer_idx=layer_idx)
        # TF dense kernels are (in, out); torch Linear weights are (out, in).
        p_model.fc.weight.copy_(torch.from_numpy(fc_kernel.T))
        p_model.fc.bias.copy_(torch.from_numpy(fc_bias))
        
# Build the PyTorch GRU language model, then overwrite its parameters with the
# converted TF weights. The hyperparameters must match the TF model whose
# weights live in tf_weights.npz.
# NOTE(review): the token list defined later has 66 entries (64 + "<PAD>",
# "<UNKNOWN>"), and index 0 there is "<EOS>", not "<PAD>". Confirm that
# vocab_size=64 and pad_id=0 are intended.
model = RNNLanguageModel(
    vocab_size=64,
    embed_size=64,
    hidden_size=256,
    num_layers=2,
    rnn_type="GRU",
    dropout=0.2,
    pad_id=0,
)
load_tf_weights_npz_to_gru_model(model, npz_path="tf_weights.npz")

In [41]:
import os
from rdkit import RDLogger
from language import SMILES
from node import MolSentenceNode
from transition import RNNTransition
from filter import ValidityFilter
from generator import RandomGenerator
from utils import add_sep
# Silence RDKit's 'rdApp.*' log output.
RDLogger.DisableLog('rdApp.*')

# Token vocabulary. The printed transitions below follow this list's order, so
# index i presumably corresponds to model output i. Keep the order unchanged.
tokens = [
    '<EOS>', '#', '<BOS>', '(', ')', '-', '/',
    '1', '2', '3', '4', '5', '6', '7', '8', '=',
    'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S',
    '[C@@H]', '[C@@]', '[C@H]', '[C@]', '[CH-]', '[CH2-]',
    '[N+]', '[N-]', '[NH+]', '[NH-]', '[NH2+]', '[NH3+]',
    '[O+]', '[O-]', '[OH+]',
    '[P+]', '[P@@H]', '[P@@]', '[P@]', '[PH+]', '[PH2]', '[PH]',
    '[S+]', '[S-]', '[S@@+]', '[S@@]', '[S@]', '[SH+]',
    '[n+]', '[n-]', '[nH+]', '[nH]', '[o+]', '[s+]',
    '\\', 'c', 'n', 'o', 's', "<PAD>", "<UNKNOWN>",
]
lang = SMILES.load_tokens_list(tokens)

transition = RNNTransition(lang=lang, model=model, device="cpu", top_p=1)
transition.device = "cpu"

root = MolSentenceNode.node_from_string(lang=lang, string="CCC")
filters = [ValidityFilter()]
generator = RandomGenerator(root=root, transition=transition, filters=filters, info_interval=1)

# Show each one-token extension of the root together with its probability.
for action, child, prob in transition.transitions_with_probs(root):
    print(str(child), prob)

CCC 8.32641955383906e-09
CCC# 9.991409024223685e-05
CCC<BOS> 6.931043933233738e-11
CCC( 0.1411621868610382
CCC) 3.924635620933259e-07
CCC- 8.372875299755833e-07
CCC/ 0.0015471169026568532
CCC1 0.04327229782938957
CCC2 2.444878737151157e-05
CCC3 1.9797342076799396e-07
CCC4 8.553368679997675e-09
CCC5 1.4628530049624888e-14
CCC6 5.568572472168694e-14
CCC7 2.6848251052924432e-11
CCC8 3.119293831479325e-13
CCC= 1.6574918845435604e-05
CCCBr 2.2564977371075656e-06
CCCC 0.258292019367218
CCCCl 1.7377755284542218e-05
CCCF 3.6342458997751237e-07
CCCI 1.23610206514968e-07
CCCN 0.11371456831693649
CCCO 0.05759301409125328
CCCP 1.0832563020812813e-05
CCCS 0.01867462694644928
CCC[C@@H] 0.07076003402471542
CCC[C@@] 0.00791863538324833
CCC[C@H] 0.0585256963968277
CCC[C@] 0.006021983455866575
CCC[CH-] 1.8221066322432478e-10
CCC[CH2-] 4.4599465787165116e-10
CCC[N+] 0.00015137640002649277
CCC[N-] 9.144595480847784e-08
CCC[NH+] 0.03324654698371887
CCC[NH-] 9.537796247238717e-11
CCC[NH2+] 0.054703917354345

In [42]:
# Run the random generator for at most 1000 generations, then analyze the results.
generator.generate(max_generations=1_000)
generator.analyze()

Starting generation...
<best reward updated> order: 1, time: 0.01, reward: 0.0891, node: CCCNC(=O)N1CCC(N2C[C@H](COC)n3cc[nH+]c32)CC1
<best reward updated> order: 2, time: 0.05, reward: 0.1793, node: CCCN1CC[C@@H](NC(=O)c2ccc(N3CCCC3)nc2)CC1=O
order: 3, time: 0.06, reward: 0.1290, node: CCC[n+]1ccc(CN(C)C#N)cc1
order: 4, time: 0.07, reward: 0.1681, node: CCCc1ccc(C(=O)N2NC(=O)NC(C)C2)cc1
<best reward updated> order: 5, time: 0.08, reward: 0.2395, node: CCCCCNC(=O)C(=O)Nc1c(C)ccc(C(C)=O)c1
<best reward updated> order: 6, time: 0.09, reward: 0.2921, node: CCCS(=O)(=O)N[C@@H](C)c1cccc(NC(F)(F)F)c1
<best reward updated> order: 7, time: 0.10, reward: 0.3675, node: CCC(=O)Nc1nc2ccc(N[C@@H](C)CC)cc2s1
order: 8, time: 0.11, reward: 0.2098, node: CCCC1(NC(=O)CNc2cccc(C(F)(F)F)c2)CC[NH2+]CC1
order: 9, time: 0.17, reward: 0.1273, node: CCCO/C([O-])=C1/C(C)=N[C@H]2C(=O)N[C@](C)(C(=O)Nc3ccccc3C)[C@H]12
order: 10, time: 0.18, reward: 0.3395, node: CCCOc1ccc(C(=O)NCn2c(=S)oc3ccccc32)nc1
order: 11, ti

In [31]:
# Save the token vocabulary ("ported.lang") and the ported PyTorch model ("ported").
# NOTE(review): this cell's execution count (31) is lower than those of the cells
# that define `model` (32) and `lang` (41). The saved files may therefore come from
# an earlier kernel state. Re-run after a fresh "Restart & Run All" before relying on them.
lang.save("ported.lang")
model.save("ported")