In [1]:
# Change the working directory to the parent directory so that the relative
# 'models/' and 'data/' paths used in the cells below resolve.
# NOTE(review): `%cd ..` is relative to the current directory, so re-running this
# cell without restarting the kernel moves up one more level each time.
%cd ..

/home/OSDA_Generator


In [2]:
# imports
from ddc_pub import ddc_v3 as ddc
import pickle
import numpy as np
import os
import tensorflow as tf
from rdkit import Chem
import pandas as pd

In [3]:
# Load the fitted zeolite feature normalizer and the OSDA/zeolite data table.
# The normalizer is read inside a `with` block so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# normalizer pickles that come from a trusted source.
with open('models/zeolite_normalizer.pkl', 'rb') as normalizer_file:
    zeo_norm = pickle.load(normalizer_file)
data = pd.read_excel('data/Jensen_et_al_CentralScience_OSDA_Zeolite_data.xlsx', engine='openpyxl')

In [8]:
# Build the normalized feature vector for the chosen zeolite framework code.
zeolite = 'CHA'
zeolite_vector = []
for i, row in data.iterrows():
    # Only rows describing the requested framework are of interest.
    if row['Code'] != zeolite:
        continue
    features = [
        row[column]
        for column in ('FD', 'max_ring_size', 'channel_dim', 'inc_vol', 'accvol', 'maxarea', 'minarea')
    ]
    # Rows with any missing feature are ignored; the last complete row wins.
    if not any([pd.isnull(value) for value in features]):
        zeolite_vector = features
if len(zeolite_vector) == 7:
    # Scale the raw features with the pre-fitted normalizer (expects a 2-D input).
    zeolite_vector = list(zeo_norm.transform([zeolite_vector])[0])
    print(zeolite_vector)
else:
    print('Problem', zeolite, 'not in data, Please select a zeolite in the data')

[0.5597014925373133, 0.08333333333333331, 1.0, 0.09512669725261097, 0.3793103448275862, 0.08952817905635811, 0.08952817905635811]


In [9]:
# Get synthesis vector
def featurize_synthesis(syns):
    """Encode a collection of synthesis components as a 17-element 0/1 vector.

    Layout of the returned list:
      * 8 slots: presence of each common framework element
        (Si, Al, P, Ge, B, Ti, Ga, Fe),
      * 2 slots: presence of each common cation (Na, K),
      * 1 slot:  presence of 'F',
      * 6 slots: whether any component is a less common framework element,
        a less common cation, a seed, a solvent, an acid, or "other".

    A component counts as "other" when it matches none of the five groups,
    contains fewer than two spaces, is not one of the known junk strings and
    is longer than two characters.

    Args:
        syns: iterable of component strings, e.g. ['Si', 'Al', 'Na'].

    Returns:
        The 17-element list of ints, or None when `syns` is empty/falsy.
    """
    if not syns:
        return None

    # Components that each get a dedicated presence slot, in output order.
    dedicated = ['Si', 'Al', 'P', 'Ge', 'B', 'Ti', 'Ga', 'Fe',  # common frameworks
                 'Na', 'K',                                     # common cations
                 'F']
    syn_vector = [1 if element in syns else 0 for element in dedicated]

    # Grouped categories, checked in this order; the first matching group wins.
    rare_frameworks = ['Co', 'Mn', 'Cu', 'Zn', 'Cd', 'Cr', 'V', 'Ce', 'Nd', 'Sn', 'Zr', 'Ni',
                       'S', 'Sm', 'Dy', 'Y', 'La', 'Gd', 'In', 'Nb', 'Te', 'As', 'Hf', 'W',
                       'Se']
    rare_cations = ['Mg', 'Rb', 'Li', 'Cs', 'Sr', 'Ba', 'Be', 'Ca']
    seed_terms = ['seed', 'SAPO-56 seeds', 'SSZ-57', 'FAU', 'seeded with magadiite', 'seeds']
    solvent_terms = ['ethylene glycol', 'hexanol', '2-propanol', 'triethylene glycol', 'triglycol',
                     'polyethylene glycol', 'n-hexanol', 'glycol', 'propane-1,3-diol', 'butanol',
                     'glycerol', 'isobutylamine', 'tetraethylene glycol', '1-hexanol',
                     'sec-butanol', 'iso-butanol', 'ethylene glycol monomethyl ether', 'ethanol']
    acid_terms = ['H2SO4', 'acetic acid', 'oxalic acid', 'succinic acid', 'arsenic acid', 'HNO3', 'HCl',
                  'SO4']
    junk_terms = ['pictures', 'need access', 'also called azepane', 'SMILES code']
    ordered_groups = (rare_frameworks, rare_cations, seed_terms, solvent_terms, acid_terms)

    # One flag per group plus a trailing flag for "other".
    group_flags = [0] * (len(ordered_groups) + 1)
    for component in syns:
        for position, members in enumerate(ordered_groups):
            if component in members:
                group_flags[position] = 1
                break
        else:
            # No group matched: decide whether this is a meaningful "other" entry.
            if component.count(' ') < 2 and component not in junk_terms and len(component) > 2:
                group_flags[-1] = 1

    syn_vector.extend(group_flags)
    return syn_vector

In [10]:
# Synthesis components chosen by the user.
# Example choices:
#   ['Si', 'Al', 'Na'] - zeolite synthesized with sodium
#   ['Si', 'Al', 'P']  - SAPO
#   ['Si', 'Ge', 'F']  - common germanium zeotype
synthesis = ['Si', 'Al', 'Na']
synthesis_vector = featurize_synthesis(synthesis)
if synthesis_vector is not None:
    # Show the length and contents of the encoded synthesis vector.
    print(len(synthesis_vector))
    print(synthesis_vector)
else:
    print('Problem', synthesis, 'is not valid, Please select synthesis components from the data set')

17
[1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]


In [11]:
# Concatenate the zeolite features with the synthesis features and shape the
# result as a single-row batch of 24 values for the model.
model_input = zeolite_vector + synthesis_vector
model_input = np.array(model_input).reshape((1, 24))
print(model_input.shape)

(1, 24)


In [13]:
# Load the pre-trained DDC model stored under 'models/OSDA_model_full'
# (path is relative to the working directory set at the top of the notebook).
# `model` is used below through model.predict(latent=..., temp=...).
model = ddc.DDC(model_name='models/OSDA_model_full')

Initializing model in test mode.
Loading model.
'mol_to_latent_model' not found, setting to None.
Loading finished in 2 seconds.
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Latent_Input (InputLayer)       [(None, 24)]         0                                            
__________________________________________________________________________________________________
Decoder_Inputs (InputLayer)     [(None, 132, 31)]    0                                            
__________________________________________________________________________________________________
latent_to_states_model (Model)  [(None, 256), (None, 44544       Latent_Input[0][0]               
__________________________________________________________________________________________________
batch_model (Model)             (None, 132, 31)      1359647  

In [14]:
# Generate the OSDAs: sample `number_to_generate` SMILES strings from the model,
# keeping each sample's negative log-likelihood alongside it.
number_to_generate = 100
# Print progress roughly every 10% of the run. The step is an integer: a float
# step such as 30 * 0.1 == 3.0000000000000004 makes `i % step == 0` true only
# for i == 0, and steps below 1 (fewer than 10 samples) would print every index.
progress_step = max(1, number_to_generate // 10)
generated_smiles, neg_log_like = [], []
for i in range(number_to_generate):
    if i % progress_step == 0:
        print(i)
    # Each call draws one sample conditioned on the same (1, 24) input vector.
    smiles, nll = model.predict(latent=model_input, temp=1)
    generated_smiles.append(smiles)
    neg_log_like.append(nll)
print(len(generated_smiles), len(neg_log_like))

0
10
20
30
40
50
60
70
80
90
100 100


In [17]:
# Persist every generated SMILES string and its matching NLL as pickle files.
for target_path, payload in (
    ('data/generated_osdas.pickle', generated_smiles),
    ('data/generated_osdas_nlls.pickle', neg_log_like),
):
    with open(target_path, 'wb') as f:
        pickle.dump(payload, f)

In [20]:
# Get unique smiles: deduplicate the generated molecules on their RDKit
# canonical SMILES, keeping the first occurrence of each molecule together
# with the negative log-likelihood of that occurrence.
unique_generated_smiles, unique_neg_log_like = [], []
seen_canonical = set()  # O(1) membership test instead of scanning the list
invalid_count = 0
for smile, nll in zip(generated_smiles, neg_log_like):
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # MolFromSmiles returns None for SMILES it cannot parse; passing None
        # on to MolToSmiles would raise, so such samples are counted and skipped.
        invalid_count += 1
        continue
    canonical = Chem.MolToSmiles(mol, canonical=True)
    if canonical not in seen_canonical:
        seen_canonical.add(canonical)
        unique_generated_smiles.append(canonical)
        # Store the NLL (not the SMILES) so the two lists stay parallel.
        unique_neg_log_like.append(nll)
if invalid_count:
    print('Skipped', invalid_count, 'generated SMILES that RDKit could not parse')
print(len(unique_generated_smiles), len(unique_neg_log_like))

21 21


In [21]:
# Persist the deduplicated SMILES list and its companion list as pickle files.
for target_path, payload in (
    ('data/generated_osdas_unique.pickle', unique_generated_smiles),
    ('data/generated_osdas_nlls_unique.pickle', unique_neg_log_like),
):
    with open(target_path, 'wb') as f:
        pickle.dump(payload, f)