In [None]:
# Make modules in the parent directory importable when the notebook is run locally.
# The path is relative, so it resolves against the kernel's current working directory.
import sys
repo_root = "../"
# Guard keeps re-running this cell from adding the same entry again.
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

In [None]:
import numpy as np
import torch
from transition import RNNLanguageModel

# Reference for the TF -> PyTorch GRU weight reordering used below:
# https://discuss.pytorch.org/t/loading-tensorflow-grucell-weights-into-pytorch/113216/4

def convert_input_kernel(kernel):
    """Rearrange a TF GRU input kernel into a row-stacked, gate-swapped layout.

    The kernel is cut into three equal column blocks. The first two blocks
    trade places, every block is transposed, and the three results are
    stacked along axis 0.

    Assumes a Keras-style (z, r, h) gate order in the source and the
    (r, z, n) order of ``torch.nn.GRU`` in the target — TODO confirm against
    the model that exported the weights.

    Args:
        kernel: NumPy array whose column count is divisible by 3
            (presumably shaped (input_dim, 3 * units) — TODO confirm).

    Returns:
        NumPy array holding the transposed blocks in (second, first, third)
        order, stacked vertically.
    """
    first, second, third = np.hsplit(kernel, 3)
    swapped = (second, first, third)
    return np.concatenate([block.T for block in swapped])
    
def convert_recurrent_kernel(kernel):
    """Rearrange a TF GRU recurrent kernel into a row-stacked, gate-swapped layout.

    Performs the same permutation as the input-kernel conversion: split the
    columns into three equal blocks, swap the first two, transpose each
    block, and stack the results along axis 0.

    Assumes a Keras-style (z, r, h) gate order in the source and the
    (r, z, n) order of ``torch.nn.GRU`` in the target — TODO confirm against
    the model that exported the weights.

    Args:
        kernel: NumPy array whose column count is divisible by 3
            (presumably shaped (units, 3 * units) — TODO confirm).

    Returns:
        NumPy array holding the transposed blocks in (second, first, third)
        order, stacked vertically.
    """
    blocks = np.hsplit(kernel, 3)
    # Indices 1, 0, 2: the first two gate blocks swap, the last one stays put.
    transposed = [blocks[idx].T for idx in (1, 0, 2)]
    return np.concatenate(transposed)

def convert_bias(bias):
    """Swap the first two gate segments of a stacked TF GRU bias.

    The bias is viewed as (2, 3, k): two bias vectors, each made of three
    equal gate segments. Within each vector the first two segments trade
    places, then the result is flattened back to two rows.

    Args:
        bias: NumPy array whose total size is divisible by 6
            (presumably shaped (2, 3 * units) — TODO confirm).

    Returns:
        NumPy array of shape (2, 3 * k) with gate segments in (1, 0, 2) order.
    """
    gate_order = [1, 0, 2]
    per_gate = bias.reshape(2, 3, -1)
    reordered = per_gate[:, gate_order, :]
    return reordered.reshape((2, -1))

def copy_tf_gru_weights(p_rnn_layer, tf_kernel, tf_recurrent, tf_bias, layer_idx=0):
    """Copy one layer of converted TF GRU weights into a PyTorch recurrent module.

    The parameters ``weight_ih_l{layer_idx}``, ``weight_hh_l{layer_idx}``,
    ``bias_ih_l{layer_idx}`` and ``bias_hh_l{layer_idx}`` of ``p_rnn_layer``
    are overwritten in place.

    Args:
        p_rnn_layer: Module exposing the four parameters named above
            (presumably a ``torch.nn.GRU`` — TODO confirm).
        tf_kernel: TF input kernel, passed through ``convert_input_kernel``.
        tf_recurrent: TF recurrent kernel, passed through
            ``convert_recurrent_kernel``.
        tf_bias: TF bias, passed through ``convert_bias``; row 0 of the
            result goes to ``bias_ih`` and row 1 to ``bias_hh``.
        layer_idx: Index used to build the ``_l{layer_idx}`` parameter suffix.
    """

    def _assign(prefix, array):
        # Look the parameter up by name and overwrite its values in place.
        target = getattr(p_rnn_layer, f"{prefix}_l{layer_idx}")
        target.copy_(torch.from_numpy(array))

    # In-place parameter writes must not be tracked by autograd.
    with torch.no_grad():
        _assign("weight_ih", convert_input_kernel(tf_kernel))
        _assign("weight_hh", convert_recurrent_kernel(tf_recurrent))
        _assign("bias_ih", convert_bias(tf_bias)[0])
        _assign("bias_hh", convert_bias(tf_bias)[1])

def load_tf_weights_npz_to_gru_model(p_model, npz_path="tf_weights.npz"):
    """Copy TF weights stored in an ``.npz`` archive into a two-layer GRU model.

    Arrays are read positionally through the keys ``arr_0`` ... ``arr_{n-1}``
    and mapped as follows:

    * 0       -> ``p_model.embed.weight``
    * 1, 2, 3 -> GRU layer 0 (kernel, recurrent kernel, bias)
    * 4, 5, 6 -> GRU layer 1 (kernel, recurrent kernel, bias)
    * 7       -> ``p_model.fc.weight`` (transposed before copying)
    * 8       -> ``p_model.fc.bias``

    Args:
        p_model: Model exposing ``embed``, ``rnn`` and ``fc`` submodules whose
            parameter shapes match the stored arrays.
        npz_path: Path to the ``.npz`` archive holding the TF weights.

    Raises:
        KeyError: If the archive does not use the default ``arr_i`` key names.
        IndexError: If the archive holds fewer than nine arrays.
    """
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once every array is in memory.
    with np.load(npz_path) as data:
        weights = [data[f'arr_{i}'] for i in range(len(data.files))]

    with torch.no_grad():
        p_model.embed.weight.copy_(torch.from_numpy(weights[0]))
        copy_tf_gru_weights(p_model.rnn, weights[1], weights[2], weights[3], layer_idx=0)
        copy_tf_gru_weights(p_model.rnn, weights[4], weights[5], weights[6], layer_idx=1)
        # The TF dense kernel is stored transposed relative to the target weight.
        p_model.fc.weight.copy_(torch.from_numpy(weights[7].T))
        p_model.fc.bias.copy_(torch.from_numpy(weights[8]))
        
# Instantiate the target model, then overwrite its parameters with the
# weights stored in tf_weights.npz.
model = RNNLanguageModel(
    vocab_size=64,
    embed_size=64,
    hidden_size=256,
    num_layers=2,
    rnn_type="GRU",
    dropout=0.2,
    pad_id=0,
)
load_tf_weights_npz_to_gru_model(model, "tf_weights.npz")

In [None]:
import os
from rdkit import RDLogger
from language import SMILES
from node import MolSentenceNode
from transition import RNNTransition
from filter import ValidityFilter
from generator import RandomGenerator
from utils import add_sep
# Turn off RDKit's rdApp.* log channels.
RDLogger.DisableLog('rdApp.*')

# Token vocabulary; list position is what gets passed to load_tokens_list.
# NOTE(review): this list has 66 entries while the model above is built with
# vocab_size=64, so "<PAD>" and "<UNKNOWN>" fall outside the model's vocabulary
# range — confirm this is intended.
tokens = [
    '<EOS>', '#', '<BOS>', '(', ')', '-', '/',
    '1', '2', '3', '4', '5', '6', '7', '8', '=',
    'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S',
    '[C@@H]', '[C@@]', '[C@H]', '[C@]', '[CH-]', '[CH2-]',
    '[N+]', '[N-]', '[NH+]', '[NH-]', '[NH2+]', '[NH3+]',
    '[O+]', '[O-]', '[OH+]',
    '[P+]', '[P@@H]', '[P@@]', '[P@]', '[PH+]', '[PH2]', '[PH]',
    '[S+]', '[S-]', '[S@@+]', '[S@@]', '[S@]', '[SH+]',
    '[n+]', '[n-]', '[nH+]', '[nH]', '[o+]', '[s+]',
    '\\', 'c', 'n', 'o', 's',
    "<PAD>", "<UNKNOWN>",
]
lang = SMILES.load_tokens_list(tokens)

transition = RNNTransition(
    lang=lang,
    model=model,
    device="cpu",
    top_p=1,
)
# NOTE(review): device is already passed to the constructor above; this
# assignment may be redundant — confirm before removing.
transition.device = "cpu"

root = MolSentenceNode.node_from_string(lang=lang, string="CCC")
filters = [ValidityFilter()]
generator = RandomGenerator(
    root=root,
    transition=transition,
    filters=filters,
    info_interval=1,
)

# Print every transition offered from the root: the node and the value paired
# with it (presumably its probability — confirm).
for a, n, p in transition.transitions(root):
    print(str(n), p)

In [None]:
# Run the generator with max_generations=1000, then invoke its analysis step.
# NOTE(review): what analyze() reports is defined in the generator module,
# which is not visible from this notebook.
generator.generate(max_generations=1000)
generator.analyze()

In [None]:
# Save the token language and the ported model.
# NOTE(review): output format and location are decided by lang.save / model.save,
# defined outside this notebook — confirm "ported" is the intended model path.
lang.save("ported.lang")
model.save("ported")