In [None]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
from keras import backend as K
from keras.models import load_model
import tensorflow as tf
import pandas as pd

# Load Models

In [None]:
# Width of the latent vector; also used by sample_z_batch and passed to load_model.
n_z = 32


def vae_loss(y_true, y_pred):
    """Weighted sum of a reconstruction term and a KL-divergence term.

    ``y_true`` is reshaped to ``(-1, 89, 11)`` before being compared with
    ``y_pred``.

    NOTE(review): ``mu`` and ``log_sigma`` are free names that are not defined
    in the visible part of this notebook, so calling this function as-is would
    raise ``NameError``. In the visible code it is only handed to
    ``load_model`` through ``custom_objects``.
    """
    # Reconstruction term: cross-entropy along the last axis, summed over
    # axis 1, averaged over the batch, weighted by 10.
    targets = tf.reshape(y_true, [-1, 89, 11])
    cross_entropy = K.categorical_crossentropy(targets, y_pred, axis=2)
    recon = 10 * K.mean(K.sum(cross_entropy, axis=1))

    # KL term: summed over the latent axis, weighted by 0.05.
    kl_terms = K.exp(log_sigma) + K.square(mu) - 1. - log_sigma
    kl = 0.05 * K.sum(kl_terms, axis=1)

    return recon + kl

In [None]:
# Restore the saved decoder. n_z and vae_loss are supplied as custom objects
# so that Keras can look these names up while deserialising the model.
decoder = load_model(
    '../CVAE/ding_decoder_best.h5',
    custom_objects={'n_z': n_z, 'vae_loss': vae_loss},
)

In [None]:
def sample_z_batch(batch_size, seed=None):
    """Draw a batch of latent vectors from a standard normal distribution.

    Parameters
    ----------
    batch_size : int
        Number of latent vectors to draw.
    seed : int, optional
        Forwarded to ``K.random_normal``. Leave as ``None`` (the default) for
        the previous unseeded behaviour; pass an integer to make the draw
        repeatable.

    Returns
    -------
    A backend tensor of shape ``(batch_size, n_z)`` with mean 0 and standard
    deviation 1; evaluate it (e.g. with ``K.eval``) to obtain a NumPy array.
    """
    return K.random_normal(shape=(batch_size, n_z), mean=0., stddev=1., seed=seed)

In [None]:
# One saved Keras model per predicted property; each is called on the
# flattened decoded encoding inside decode_mol_encoding.
pred_delta_e = load_model(
    '../Predictors/delta_e_best_model.h5'
)
pred_energy_pa = load_model(
    '../Predictors/energy_pa_best_model.h5'
)
pred_volume_pa = load_model(
    '../Predictors/volume_pa_best_model.h5'
)

In [None]:
# Element symbols, in the order used to label the 89 rows of a decoded
# encoding (row i <-> element_list[i]). The list jumps from Bi straight to Ac.
element_list = [
    # H .. Ne
    'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
    # Na .. Ar
    'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
    # K .. Kr
    'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
    'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
    # Rb .. Xe
    'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd',
    'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
    # Cs .. Bi
    'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd',
    'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W',
    'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi',
    # Ac .. Pu
    'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu',
]

# Check continuity of latent space

In [None]:
# Conditioning values fed to the decoder together with every latent point.
# NOTE(review): the ordering presumably matches what the decoder was trained
# with -- confirm against the CVAE training notebook.
property_desired = [
    -2,
    12,
    -6.5,
]

In [None]:
def _formula_from_encoding(mol_enc):
    """Build a composition string (e.g. ``"H2O1"``) from a thresholded (89, 11) encoding."""
    counts = {}
    for symbol, row in zip(element_list, mol_enc):
        # The position of the largest entry in a row is used as that element's
        # number in the formula. An all-zero row yields 0, i.e. element absent.
        count = int(row.argmax())
        if count:
            counts[symbol] = count
    # Symbols are emitted in alphabetical order.
    return "".join(symbol + str(counts[symbol]) for symbol in sorted(counts))


def decode_mol_encoding(z, target_properties=None):
    """Decode one latent vector into a composition string and predicted properties.

    Parameters
    ----------
    z : numpy.ndarray
        A single latent vector; it is reshaped to one row before decoding.
    target_properties : sequence of numbers, optional
        Conditioning values passed to the decoder. When omitted, the
        notebook-level ``property_desired`` is used, as before.

    Returns
    -------
    (mol_string, properties)
        ``mol_string`` is the decoded composition and ``properties`` is
        ``np.array([delta_e, volume_pa, energy_pa])`` as returned by the three
        predictor models for the decoded encoding.
    """
    if target_properties is None:
        target_properties = property_desired
    condition = np.array(target_properties).reshape(1, -1)

    mol_enc = decoder.predict([z.reshape(1, -1), condition]).reshape(89, 11)
    # Discard low-confidence entries before the encoding is used any further.
    mol_enc[mol_enc < 0.5] = 0

    # All three predictors take the same flattened single-row input.
    flat_enc = mol_enc.reshape(1, -1)
    delta_e = pred_delta_e.predict(flat_enc).flatten()[0]
    volume_pa = pred_volume_pa.predict(flat_enc).flatten()[0]
    energy_pa = pred_energy_pa.predict(flat_enc).flatten()[0]
    properties = np.array([delta_e, volume_pa, energy_pa])

    return _formula_from_encoding(mol_enc), properties

## Generate materials for the initial and final point

In [None]:
# Draw two latent points, evaluate them to a NumPy array, and print what each
# one decodes to (composition string and predicted properties).
z = K.eval(sample_z_batch(2))
for latent_point in z:
    print(decode_mol_encoding(latent_point))

In [None]:
# Endpoints of the walk: the first and second rows of z.
z_initial, z_final = z[0, :], z[1, :]

## Walk between the points

In [None]:
# Walk in a straight line through latent space from z_initial to z_final,
# decoding every intermediate point and keeping each distinct composition the
# first time it appears.
z_diff = z_final - z_initial
step_length = 0.01
# Number of steps needed to cover the whole segment at this step length.
n_steps = int(round(1 / step_length))

molecule_names = []
molecules_delta_e = []
molecules_volume_pa = []
molecules_energy_pa = []

# n_steps + 1 iterations so that the last one lands on z_final itself;
# range(n_steps) alone would stop one step short of it.
for i in range(n_steps + 1):
    z_intermediate = z_initial + z_diff*step_length*i
    mol_string, current_properties = decode_mol_encoding(z_intermediate)

    # Only the first occurrence of a composition is recorded.
    if mol_string in molecule_names:
        continue
    molecule_names.append(mol_string)
    molecules_delta_e.append(current_properties[0])
    molecules_volume_pa.append(current_properties[1])
    molecules_energy_pa.append(current_properties[2])

# One row per property, one column per distinct composition (in order of
# first appearance along the walk).
mol_data = {
    'delta_e': molecules_delta_e,
    'volume_pa': molecules_volume_pa,
    'energy_pa': molecules_energy_pa,
}
mol_trans = pd.DataFrame.from_dict(data=mol_data, orient='index', columns=molecule_names)

In [None]:
# Display the walk summary: rows are the predicted properties
# (delta_e, volume_pa, energy_pa), columns are the distinct compositions.
mol_trans