In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn

import selfies as sf
device = "cpu"
from chemistry_vae import VAEDecoder, VAEEncoder


In [2]:
def get_smiles_encodings_for_dataset(file_path):
    
    df = pd.read_csv(file_path)

    smiles_list = np.asanyarray(df.smiles)

    smiles_alphabet = list(set(''.join(smiles_list)))
    smiles_alphabet.append(' ')  # for padding
    largest_smiles_len = len(max(smiles_list, key=len))

    return smiles_list, smiles_alphabet, largest_smiles_len

def get_selfies_encodings_for_dataset(smiles_list, largest_smiles_len):

    print("Largest smiles len", largest_smiles_len)
    print('--> Translating SMILES to SELFIES...')
    selfies_list = list(map(sf.encoder, smiles_list))
    f = open("datasets/0SelectedSMILES_QM9.sf.txt", "w")
    f.write("\n".join(selfies_list))
    all_selfies_symbols = sf.get_alphabet_from_selfies(selfies_list)
    all_selfies_symbols.add('[nop]')
    selfies_alphabet = list(all_selfies_symbols)

    largest_selfies_len = max(sf.len_selfies(s) for s in selfies_list)
    print("Largest selfies len", largest_selfies_len)

    print('Finished translating SMILES to SELFIES.')
    return selfies_list, selfies_alphabet, largest_selfies_len

In [3]:


def selfies_to_hot(selfie, largest_selfie_len, alphabet):
    """Go from a single selfies string to a one-hot encoding.
    """

    symbol_to_int = dict((c, i) for i, c in enumerate(alphabet))

    # pad with [nop]
    selfie += '[nop]' * (largest_selfie_len - sf.len_selfies(selfie))

    # integer encode
    symbol_list = sf.split_selfies(selfie)
    integer_encoded = [symbol_to_int[symbol] for symbol in symbol_list]

    # one hot-encode the integer encoded selfie
    onehot_encoded = list()
    for index in integer_encoded:
        letter = [0] * len(alphabet)
        letter[index] = 1
        onehot_encoded.append(letter)

    return integer_encoded, np.array(onehot_encoded)


def multiple_selfies_to_hot(selfies_list, largest_molecule_len, alphabet):
    """Convert a list of selfies strings to a one-hot encoding
    """

    hot_list = []
    for s in selfies_list:
        _, onehot_encoded = selfies_to_hot(s, largest_molecule_len, alphabet)
        hot_list.append(onehot_encoded)
    return np.array(hot_list)


In [4]:

def load_models(epoch):
    print("loading models")
    out_dir = './saved_models/{}'.format(epoch)
    encoder = torch.load('{}/E'.format(out_dir), map_location=torch.device('cpu'))
    encoder.eval()
    decoder = torch.load('{}/D'.format(out_dir), map_location=torch.device('cpu'))
    decoder.eval()
    return encoder, decoder
  


In [5]:
# get all the inputs
smiles_list, smiles_alphabet, largest_smiles_len = get_smiles_encodings_for_dataset("datasets/0SelectedSMILES_QM9.txt")
selfies_list, selfies_alphabet, largest_selfies_len = get_selfies_encodings_for_dataset(smiles_list, largest_smiles_len)
# load the modedl
vae_encoder, vae_decoder = load_models(4999)


Largest smiles len 22
--> Translating SMILES to SELFIES...
Largest selfies len 21
Finished translating SMILES to SELFIES.
loading models


In [6]:
# Convert all the selfies strings to one-hot encoding
batch_size = 1
data = multiple_selfies_to_hot(selfies_list, largest_selfies_len,
                                       selfies_alphabet)
data = torch.tensor(data, dtype=torch.float).to(device)
# Pick just the first molecule [C]
batch = data[:batch_size]
inp_flat_one_hot = batch.flatten(start_dim=1)
# Encode it to a vector in latent space
latent_points, mus, log_vars = vae_encoder(inp_flat_one_hot)
latent_points = latent_points.unsqueeze(0)


torch.Size([1, 1, 50])

In [7]:
gathered_atoms = []
hidden = vae_decoder.init_hidden(batch_size=batch_size)
# Decode the vector one position at a time
for seq_index in range(batch.shape[1]):
    out_one_hot, hidden = vae_decoder(latent_points, hidden)
    out_one_hot = out_one_hot.flatten().detach()
    soft = nn.Softmax(0)
    out_one_hot = soft(out_one_hot)
    out_index = out_one_hot.argmax(0)
    gathered_atoms.append(out_index.data.cpu().tolist())
print(gathered_atoms)
print("".join([selfies_alphabet[idx] for idx in gathered_atoms]))


[1, 8, 11, 11, 11, 1, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 14, 11, 14]
[Ring1][Branch2][nop][nop][nop][Ring1][nop][nop][nop][nop][nop][nop][nop][nop][nop][nop][nop][nop][=O][nop][=O]
