In [1]:
# Local development only: put the repository root (one directory up) at the
# front of sys.path so the project modules imported below (transition,
# language, node, ...) resolve from this checkout. The membership check keeps
# re-running this cell from adding duplicate entries.
import sys

repo_root = "../"
if not any(entry == repo_root for entry in sys.path):
    sys.path[:0] = [repo_root]

In [32]:
import numpy as np
import torch
from transition import RNNLanguageModel

#ref: https://discuss.pytorch.org/t/loading-tensorflow-grucell-weights-into-pytorch/113216/4

def convert_input_kernel(kernel):
    """Convert a Keras GRU input kernel to a PyTorch ``weight_ih_l{k}`` matrix.

    The kernel is split into three equal column blocks in Keras gate order
    (update z, reset r, candidate h). Each block is transposed, and the blocks
    are stacked row-wise in PyTorch's order (r, z, n). The first two gates are
    swapped; the candidate stays last.

    Args:
        kernel: 2-D array whose column count is divisible by 3.

    Returns:
        The transposed, gate-reordered array.

    Raises:
        ValueError: from ``np.hsplit`` if the columns cannot be split evenly in 3.
    """
    gate_blocks = np.hsplit(kernel, 3)  # [z, r, h] in Keras column order
    return np.concatenate([gate_blocks[i].T for i in (1, 0, 2)])
    
def convert_recurrent_kernel(kernel):
    """Convert a Keras GRU recurrent kernel to a PyTorch ``weight_hh_l{k}`` matrix.

    The gate-block layout is the same as for the input kernel: Keras packs the
    columns as (update z, reset r, candidate h). PyTorch expects transposed
    blocks stacked as (r, z, n).

    Args:
        kernel: 2-D array whose column count is divisible by 3.

    Returns:
        The transposed, gate-reordered array.
    """
    tf_update, tf_reset, tf_candidate = np.hsplit(kernel, 3)
    return np.concatenate([tf_reset.T, tf_update.T, tf_candidate.T])

def convert_bias(bias):
    """Reorder a Keras GRU bias into PyTorch's gate layout.

    A Keras GRU with ``reset_after=True`` stores its bias as a ``(2, 3*units)``
    array:
    - row 0 is the input bias and row 1 the recurrent bias;
    - each row is laid out in gate order (update z, reset r, candidate h).

    PyTorch's GRU expects (r, z, n), so the first two gate chunks are swapped
    in each row.

    Args:
        bias: array-like of shape ``(2, 3*units)``.

    Returns:
        Array of shape ``(2, 3*units)``: row 0 feeds ``bias_ih``, row 1 ``bias_hh``.

    Raises:
        ValueError: if ``bias`` is not 2-D with two rows and a width divisible
            by 3. In particular, a 1-D ``(3*units,)`` bias (Keras
            ``reset_after=False``) is rejected. Previously it could be silently
            reshaped into a meaningless result whenever ``units`` was even.
    """
    bias = np.asarray(bias)
    if bias.ndim != 2 or bias.shape[0] != 2 or bias.shape[1] % 3 != 0:
        raise ValueError(
            f"expected a Keras reset_after=True GRU bias of shape (2, 3*units), got {bias.shape}"
        )
    gates = bias.reshape(2, 3, -1)  # (input/recurrent, gate, unit)
    return gates[:, [1, 0, 2], :].reshape((2, -1))

def _copy_param(param, array):
    """Copy a numpy array into a torch parameter in place, requiring an exact shape match.

    ``Tensor.copy_`` broadcasts its source. A wrongly shaped but broadcastable
    array would otherwise be written silently, so mismatches fail loudly instead.
    """
    src = torch.from_numpy(array)
    if src.shape != param.shape:
        raise ValueError(f"shape mismatch: expected {tuple(param.shape)}, got {tuple(src.shape)}")
    param.copy_(src)


def copy_tf_gru_weights(p_rnn_layer, tf_kernel, tf_recurrent, tf_bias, layer_idx=0):
    """Copy one layer of Keras GRU weights into a PyTorch GRU module in place.

    Args:
        p_rnn_layer: PyTorch recurrent module exposing ``weight_ih_l{k}``,
            ``weight_hh_l{k}``, ``bias_ih_l{k}`` and ``bias_hh_l{k}``.
        tf_kernel: Keras input kernel (converted by ``convert_input_kernel``).
        tf_recurrent: Keras recurrent kernel (converted by ``convert_recurrent_kernel``).
        tf_bias: Keras bias; row 0 goes to ``bias_ih``, row 1 to ``bias_hh``.
        layer_idx: index ``k`` of the stacked layer to write.

    Raises:
        ValueError: if a converted array's shape differs from its target parameter.
    """
    suffix = f"_l{layer_idx}"
    # Convert once, then unpack the two rows (input bias, recurrent bias).
    bias_ih, bias_hh = convert_bias(tf_bias)
    with torch.no_grad():
        _copy_param(getattr(p_rnn_layer, f"weight_ih{suffix}"), convert_input_kernel(tf_kernel))
        _copy_param(getattr(p_rnn_layer, f"weight_hh{suffix}"), convert_recurrent_kernel(tf_recurrent))
        _copy_param(getattr(p_rnn_layer, f"bias_ih{suffix}"), bias_ih)
        _copy_param(getattr(p_rnn_layer, f"bias_hh{suffix}"), bias_hh)

def load_tf_weights_npz_to_gru_model(p_model, npz_path="tf_weights.npz", num_layers=None):
    """Load exported Keras GRU language-model weights from an ``.npz`` into ``p_model`` in place.

    The archive must hold positional arrays ``arr_0, arr_1, ...`` in this order:
    1. the embedding matrix;
    2. ``(kernel, recurrent_kernel, bias)`` for each GRU layer;
    3. the dense kernel and the dense bias.

    Arrays beyond those are ignored.

    Args:
        p_model: model with ``embed`` (embedding), ``rnn`` (stacked GRU) and
            ``fc`` (linear) submodules.
        npz_path: path to the ``.npz`` archive.
        num_layers: number of GRU layers to copy. Defaults to
            ``p_model.rnn.num_layers``, falling back to 2 (the previously
            hard-coded value).

    Raises:
        ValueError: if the archive holds fewer arrays than ``num_layers`` requires.
    """
    if num_layers is None:
        num_layers = getattr(p_model.rnn, "num_layers", 2)

    # The context manager closes the archive's file handle once the arrays are read.
    with np.load(npz_path) as data:
        weights = [data[f"arr_{i}"] for i in range(len(data.files))]

    needed = 3 + 3 * num_layers  # embedding + 3 per GRU layer + dense kernel/bias
    if len(weights) < needed:
        raise ValueError(
            f"{npz_path} holds {len(weights)} arrays; {needed} needed for {num_layers} GRU layer(s)"
        )

    with torch.no_grad():
        p_model.embed.weight.copy_(torch.from_numpy(weights[0]))
        for layer_idx in range(num_layers):
            base = 1 + 3 * layer_idx
            copy_tf_gru_weights(
                p_model.rnn, weights[base], weights[base + 1], weights[base + 2], layer_idx=layer_idx
            )
        fc_base = 1 + 3 * num_layers
        # Keras Dense kernels are (in, out); nn.Linear.weight is (out, in).
        p_model.fc.weight.copy_(torch.from_numpy(weights[fc_base].T))
        p_model.fc.bias.copy_(torch.from_numpy(weights[fc_base + 1]))
        
# Build the PyTorch model with the TF model's architecture, then overwrite its
# initial parameters with the exported TF weights.
# NOTE(review): vocab_size=64, but the token list used for generation has 66
# entries ("<PAD>" and "<UNKNOWN>" at ids 64 and 65, outside the embedding).
# pad_id=0 is also the id of "<EOS>" in that list. Confirm both are intended.
model = RNNLanguageModel(
    vocab_size=64,
    embed_size=64,
    hidden_size=256,
    num_layers=2,
    rnn_type="GRU",
    dropout=0.2,
    pad_id=0,
)
load_tf_weights_npz_to_gru_model(model, "tf_weights.npz")
# Inference only from here on. Eval mode disables dropout (dropout=0.2 above,
# presumably applied between the stacked GRU layers in training mode).
# Harmless if RNNTransition already switches the model to eval.
model.eval();

In [None]:
import os
from rdkit import RDLogger
from language import SMILES
from node import MolSentenceNode
from transition import RNNTransition
from filter import ValidityFilter
from generator import RandomGenerator
from utils import add_sep
RDLogger.DisableLog('rdApp.*')  # silence all RDKit log messages (rdApp.*) during generation

# Seed the RNGs the sampler may draw from, so the generation run below is
# reproducible. NOTE(review): which RNG RNNTransition / RandomGenerator actually
# use is not visible here; seeding Python, NumPy and PyTorch covers the usual cases.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Token vocabulary, presumably indexed by token id: the empty root encodes to
# [[2]], the position of "<BOS>". The model has vocab_size=64, so the first 64
# entries presumably mirror the TF vocabulary. "<PAD>" and "<UNKNOWN>" are appended.
tokens = ['<EOS>', '#', '<BOS>', '(', ')', '-', '/', '1', '2', '3', '4', '5', '6', '7', '8', '=', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', '[C@@H]', '[C@@]', '[C@H]', '[C@]', '[CH-]', '[CH2-]', '[N+]', '[N-]', '[NH+]', '[NH-]', '[NH2+]', '[NH3+]', '[O+]', '[O-]', '[OH+]', '[P+]', '[P@@H]', '[P@@]', '[P@]', '[PH+]', '[PH2]', '[PH]', '[S+]', '[S-]', '[S@@+]', '[S@@]', '[S@]', '[SH+]', '[n+]', '[n-]', '[nH+]', '[nH]', '[o+]', '[s+]', '\\', 'c', 'n', 'o', 's', "<PAD>", "<UNKNOWN>"]
lang = SMILES.load_tokens_list(tokens)
transition = RNNTransition(lang=lang, model=model, device="cpu", top_p=1)
transition.device = "cpu"  # NOTE(review): looks redundant with device="cpu" above; confirm before removing
root = MolSentenceNode.node_from_string(lang=lang, string="")
print(root.id_tensor)  # sanity check: the empty string should encode to just "<BOS>"
filters = [ValidityFilter()]
generator = RandomGenerator(root=root, transition=transition, filters=filters, info_interval=1)

# Sanity check of the ported weights: the candidate next tokens from the root
# and their probabilities.
for _action, child, prob in transition.transitions_with_probs(root):
    print(str(child), prob)

generator.generate(max_generations=1000)
generator.analyze()

Starting generation...


<best reward updated> order: 1, time: 0.00, reward: 0.2925, node: CC[C@@H](NC(=O)CN1C(=O)CC[C@H]1c1ccncc1)c1ccccc1
order: 2, time: 0.02, reward: 0.2169, node: C(=O)N[C@@H](C(=O)Nc1ccc(Br)cc1F)C(F)(F)F
<best reward updated> order: 3, time: 0.03, reward: 0.3058, node: N#Cc1c(-c2ccc(Cl)cc2Cl)ncnc1N1CCOCC1
order: 4, time: 0.04, reward: 0.2809, node: NC(=O)COc1ccccc1[C@H]1c2ccccc2CCN1S(=O)(=O)c1ccccc1
order: 5, time: 0.05, reward: 0.1323, node: NCc1cccc(-n2ccnc2)c1
<best reward updated> order: 6, time: 0.06, reward: 0.3975, node: CCCCC1CCN(C(=O)CN(C)C(=O)c2ccc(C(F)(F)F)cc2)CC1
order: 7, time: 0.07, reward: 0.0884, node: NC(=O)[C@H]1CC(O)(CNc2ccc3c(c2)OCCO3)CC1
order: 8, time: 0.10, reward: 0.3197, node: CCN1CCO[C@H](SCc2ccc(Cl)s2)C1
order: 9, time: 0.10, reward: -0.0486, node: CC[C@@H](O)[C@H]1CC[NH+]2CCCC[C@H]1[NH2+]C2(C)C
order: 10, time: 0.11, reward: 0.1210, node: C[C@H]1CN(C(=O)c2cnc([C@@H]3CCCO3)sc2=O)CCO1
order: 11, time: 0.13, reward: 0.0666, node: N1[C@@H]2CC[C@@H]1C[C@@H](c1ccccc1

tensor([[2]])
 3.829774414043641e-06
# 1.7877755453810096e-07
<BOS> 1.133628746430304e-09
( 2.8821691699931762e-08
) 7.896676379459677e-07
- 8.279344910988584e-06
/ 1.4682313121738844e-05
1 2.7848954253784086e-09
2 5.656488055727493e-10
3 3.112163304841431e-11
4 1.87349136204773e-10
5 2.6842549016858896e-11
6 2.9448263827447363e-09
7 1.1554129208946051e-08
8 1.320275000438187e-09
= 2.3425814106303733e-06
Br 0.0008404434192925692
C 0.8249887824058533
Cl 0.002484105760231614
F 0.004774454515427351
I 6.113618292147294e-05
N 0.02987510897219181
O 0.12642180919647217
P 1.4781974186917068e-06
S 0.0005980784771963954
[C@@H] 1.1548418115125969e-05
[C@@] 2.6533182335697347e-06
[C@H] 8.031827746890485e-06
[C@] 2.8282349830988096e-06
[CH-] 1.2842268359847253e-09
[CH2-] 3.1723461688670795e-06
[N+] 3.8536068132088985e-06
[N-] 5.7884779380401596e-05
[NH+] 1.2953648365510162e-05
[NH-] 8.671574505569879e-06
[NH2+] 0.00016074645100161433
[NH3+] 0.004418304190039635
[O+] 1.3568174361822116e-09
[O-] 0.00

order: 17, time: 0.19, reward: 0.3125, node: C(C)(C)[C@H]1OCCC[C@H]1C(=O)N1CCC2(CC1)OCCN(Sc1ccccn1)C2
order: 18, time: 0.21, reward: 0.0715, node: NC(=S)N1CCCCC1
order: 19, time: 0.21, reward: 0.3930, node: C(C)(C)[C@H](O)c1nc(-c2ccc(Cl)cc2)cs1
order: 20, time: 0.22, reward: 0.1161, node: C(C)[C@H](O)[C@H](CC)CCO
order: 21, time: 0.24, reward: 0.3385, node: NC(=O)COc1ccccc1-c1ccc(Cl)c(Cl)c1
order: 22, time: 0.25, reward: 0.3346, node: C1CN(c2nc(NCc3ccc4c(c3)OCO4)cc(-c3cccnc3)n2)CC1
order: 23, time: 0.26, reward: -0.1070, node: NC(=O)CNC(=O)C(=O)Nc1ccc2oc(=O)oc2c1
order: 24, time: 0.27, reward: 0.1124, node: NC(=O)[C@@H]1Cc2cc(Cl)ccc2O1
order: 25, time: 0.28, reward: 0.3770, node: CCc1ccc(OCCC(=O)Nc2ccccc2C)cc1
order: 26, time: 0.28, reward: 0.2529, node: CC[C@@H](C)C[NH2+][C@@H](C)C(=O)Nc1cc(C)ccc1Cl
<best reward updated> order: 27, time: 0.32, reward: 0.4032, node: Nc1c(Br)ccc2c1N[C@H](c1ccccc1Cl)C[C@@H]2O
order: 28, time: 0.33, reward: 0.2006, node: C(=O)CCCN1CC[NH+]([C@H]2CCCCN2C(=O

In [35]:
# Re-print the summary statistics of the generation run above
# (generated-node count, valid rate, unique rate, nodes per second).
generator.analyze()

number of generated nodes: 1000
valid rate: 0.8243021346469622
unique rate: 0.9967159277504105
node_per_sec: 98.5595695839109


In [31]:
# Save the token vocabulary and the ported model using the project's save helpers.
# Run this only after the weight-loading cell, so `model` holds the converted TF weights.
lang.save("ported.lang")
model.save("ported")