# **Load model**

In [None]:
# Put the parent directory first on the module search path so the import
# below can resolve (presumably molecule_vae lives one level up -- confirm).
import sys
sys.path.insert(0, '../')

import molecule_vae

# Character-level VAE, built from its weights path with two_tower=False.
# `character_model` is the model evaluated at the bottom of this notebook.
character_weights = "../weights/CharVAE_L128.hdf5"
character_model = molecule_vae.CharacterModel(character_weights, two_tower=False)

# The remaining model variants are left disabled; uncomment a pair of lines
# (weights path + model) to evaluate that variant instead.

# two_tower_character_weights = "../weights/Two_tower_CharVAE_L128.hdf5"
# two_tower_character_model = molecule_vae.CharacterModel(two_tower_character_weights, two_tower=True)

# grammar_weights = "../weights/GrammarVAE_L128.hdf5"
# grammar_model = molecule_vae.GrammarModel(grammar_weights, two_tower=False)

# two_tower_grammar_weights = "../weights/Two_tower_GrammarVAE_L128.hdf5"
# two_tower_grammar_model = molecule_vae.GrammarModel(two_tower_grammar_weights, two_tower=True)

# **Load data**

In [None]:
import h5py
import numpy as np

grammar_data_path = '../grammar_dataset.h5'
char_data_path = '../char_dataset.h5'
features_data_path = '../features_dataset.h5'

# Read each 'data' dataset fully into memory. The `with` blocks guarantee the
# HDF5 handles are closed even if reading a dataset raises.
with h5py.File(grammar_data_path, 'r') as h5f:
    grammar_data = h5f['data'][:]

with h5py.File(char_data_path, 'r') as h5f:
    char_data = h5f['data'][:]

with h5py.File(features_data_path, 'r') as h5f:
    rdkit_features = h5f['data'][:]


# Delete NaN values: find every row of the feature matrix that contains at
# least one NaN and drop that row index from all three arrays.
# NOTE(review): this assumes the three datasets are row-aligned (row k of each
# describes the same molecule) -- confirm against how the datasets were built.
nan_indices = np.unique(np.argwhere(np.isnan(rdkit_features))[:, 0])
rdkit_features = np.delete(rdkit_features, nan_indices, axis=0)
char_data = np.delete(char_data, nan_indices, axis=0)
grammar_data = np.delete(grammar_data, nan_indices, axis=0)

# Select testing data: keep only the first 5000 rows of each array.
char_data = char_data[:5000]
grammar_data = grammar_data[:5000]
rdkit_features = rdkit_features[:5000]

# **Calculate reconstruction accuracy**

In [None]:
from tqdm.notebook import tqdm

# Character vocabulary; the position of a character in this list is its
# one-hot index. The final entry (' ') is the one `encode` uses as padding.
charlist = [
    'C', '(', ')', 'c', '1', '2', 'o', '=', 'O', 'N', '3', 'F', '[',
    '@', 'H', ']', 'n', '-', '#', 'S', 'l', '+', 's', 'B', 'r', '/',
    '4', '\\', '5', '6', '7', 'I', 'P', '8', ' ',
]

# Reverse lookup: character -> index into `charlist`.
_char_index = {symbol: position for position, symbol in enumerate(charlist)}

def encode(smiles):
    """Convert a list of SMILES strings into a padded one-hot array.

    Each string becomes a (120, len(charlist)) float32 matrix: position k is
    set at the index of the string's k-th character, and every position past
    the end of the string is set at the last index of `charlist`.

    Returns an array of shape (len(smiles), 120, len(charlist)).
    Raises KeyError for a character that is not in `charlist`.
    """
    token_ids = [
        np.array([_char_index[symbol] for symbol in text], dtype=int)
        for text in smiles
    ]
    one_hot = np.zeros((len(token_ids), 120, len(charlist)), dtype=np.float32)
    # Iterating the 3-D array yields views, so writes land in `one_hot`.
    for matrix, ids in zip(one_hot, token_ids):
        length = len(ids)
        matrix[np.arange(length), ids] = 1.
        # Fill the remaining positions with the last vocabulary entry.
        matrix[length:, -1] = 1.
    return one_hot

def reconstruction(model, model_type, smiles_data, features_data=None,
                   num_encodings=10, num_decodings=10):
    """Measure how often a VAE reproduces its input exactly.

    For every entry of `smiles_data` the input is passed through the encoder
    `num_encodings` times, each resulting encoding is decoded
    `num_decodings` times, and the fraction of the
    `num_encodings * num_decodings` decoded structures that equal the input
    one-hot exactly is recorded. A running mean/std is printed after every
    entry.

    Args:
        model: object exposing `_two_tower`, `vae.encoder.predict`,
            `vae.decoder.predict` and, for grammar models,
            `_sample_using_masks`.
        model_type: 'Grammar' or 'Character'; selects how decoder output is
            turned back into a one-hot structure.
        smiles_data: array of one-hot encoded structures, indexed along its
            first axis.
        features_data: per-structure features, indexed like `smiles_data`.
            Required when `model._two_tower` is true, ignored otherwise.
        num_encodings: number of encoder passes per structure (default 10).
        num_decodings: number of decodings per encoding (default 10).

    Returns:
        Tuple (mean, std) of the per-structure reconstruction fractions.

    Raises:
        ValueError: if `model_type` is not 'Grammar' or 'Character', if a
            sample count is smaller than 1, or if a two-tower model is given
            no `features_data`.
    """
    # Fail fast on bad arguments instead of after the first prediction.
    if model_type not in ('Grammar', 'Character'):
        raise ValueError('Invalid model_type. Must be either \'Grammar\' or \'Character\'')
    if num_encodings < 1 or num_decodings < 1:
        raise ValueError('num_encodings and num_decodings must both be at least 1')
    two_tower = model._two_tower
    if two_tower and features_data is None:
        raise ValueError('features_data is required for a two-tower model')

    attempts = num_encodings * num_decodings
    avg_reconstructions = []
    for i in tqdm(range(smiles_data.shape[0])):
        structure_one_hot = smiles_data[i]
        repeated_one_hot = np.array([structure_one_hot] * num_encodings)

        if two_tower:
            repeated_features = np.array([features_data[i]] * num_encodings)
            encodings = model.vae.encoder.predict([repeated_one_hot, repeated_features])
            repeat_encodings = np.tile(encodings, (num_decodings, 1))
            # The two-tower decoder returns a pair; only the first element
            # (the structure output) is used here.
            out, _ = model.vae.decoder.predict(repeat_encodings)
        else:
            encodings = model.vae.encoder.predict(repeated_one_hot)
            repeat_encodings = np.tile(encodings, (num_decodings, 1))
            out = model.vae.decoder.predict(repeat_encodings)

        if model_type == 'Grammar':
            X_hat = model._sample_using_masks(out)
        else:
            # Sample one character per position: argmax of log(out) plus
            # Gumbel noise (assumes `out` holds per-position probabilities
            # over `charlist` -- TODO confirm).
            noise = np.random.gumbel(size=out.shape)
            sampled_chars = np.argmax(np.log(out) + noise, axis=-1)
            char_matrix = np.array(charlist)[np.array(sampled_chars, dtype=int)]
            # Strip surrounding spaces and re-encode so the sample is padded
            # the same way as the reference one-hot before comparing.
            decoded_smiles = [''.join(chars).strip() for chars in char_matrix]
            X_hat = encode(decoded_smiles)

        num_correct = sum(
            1 for structure in X_hat
            if np.array_equal(structure, structure_one_hot)
        )

        avg_reconstructions.append(num_correct / attempts)
        print('Average accuracy: {} +/- {}'.format(np.average(avg_reconstructions), np.std(avg_reconstructions)))

    return np.average(avg_reconstructions), np.std(avg_reconstructions)

In [None]:
# Evaluate the character VAE on the test slice. The call is left as a bare
# expression so the notebook displays the returned (mean, std) tuple.
reconstruction(
    model=character_model,
    model_type='Character',
    smiles_data=char_data,
)
# reconstruction (two_tower_character_model, 'Character', char_data, rdkit_features)

# reconstruction (grammar_model, 'Grammar', grammar_data)
# reconstruction (two_tower_grammar_model, 'Grammar', grammar_data, rdkit_features)