In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
os.chdir('..')

In [13]:
import pygad
import torch
import pickle
import json
import pandas as pd
import numpy as np
from modules.functions import data_generator_vec, check_cuda
from sklearn.neighbors import KernelDensity


In [15]:
model_path = 'saved_models'
if os.path.exists(os.path.join(model_path, 'GAN_generator.pt')):
    generator = torch.jit.load(os.path.join(model_path, 'GAN_generator.pt'), map_location='cpu')
else:
    print('train model first!')

el_list_loc = 'misc/element_order_uts.pkl'
with open(el_list_loc,'rb') as fid:
    el_list = pickle.load(fid)

with open('misc/scaler_y.pkl','rb') as fid:
    uts_scaler = pickle.load(fid)

with open('misc/starting_comp.json','r') as fid:
    json_dict = json.load(fid)

In [7]:
dataset = pd.read_csv('dataset/synthetic_dataset.csv', index_col = 0)
to_train_df = dataset[dataset['uts1200C']>0].copy()
comp_dset = data_generator_vec(to_train_df['Composition'], el_list=el_list)
vec_comps = comp_dset.real_data
y = to_train_df['uts1200C'].values.reshape(-1,1).astype('float32')
y_scaled = uts_scaler.transform(y)
kde = KernelDensity(kernel='gaussian',bandwidth=0.5)
v = kde.fit(y_scaled)

def prop_sampler(n_samples):
    return kde.sample(n_samples).astype('float32')

def noise_sampler(N, z_dim):
    return np.random.normal(size=[N, z_dim]).astype('float32')

latent_dim = 4
cuda = check_cuda()

In [None]:
def ga_inputs(N, z_dim = latent_dim):
    prop = torch.from_numpy(prop_sampler(N))
    noise = torch.from_numpy(noise_sampler(N,z_dim))
    return(torch.cat([noise,prop], dim=-1))

functional_inputs = ga_inputs(1)
desired_output = json_dict['start_comp']

In [79]:
def fitness_func(solution, solution_idx):
    functional_input = solution
    output = generator(functional_inputs[:,:-1].reshape(-1,latent_dim),functional_inputs[:,-1].reshape(-1,1))
    output = output.to('cpu').detach().numpy()
    fitness = 1/np.linalg.norm(output - desired_output)
    return fitness

In [107]:
fitness_function = fitness_func

num_generations = 100
num_parents_mating = 2

sol_per_pop = 20
num_genes = len(functional_inputs)

initial_population = ga_inputs(10)

init_range_low = -2
init_range_high = 5

parent_selection_type = "sss"
keep_parents = 1

crossover_type = "single_point"

mutation_type = "adaptive"
mutation_probability = (0.25,0.1)
mutation_percent_genes = (10,1)

In [108]:
ga_instance = pygad.GA(num_generations=num_generations,
                       num_parents_mating=num_parents_mating,
                       fitness_func=fitness_function,
                       sol_per_pop=sol_per_pop,
                       num_genes=num_genes,
                       init_range_low=init_range_low,
                       init_range_high=init_range_high,
                       parent_selection_type=parent_selection_type,
                       keep_parents=keep_parents,
                       crossover_type=crossover_type,
                       mutation_type=mutation_type,
                       mutation_percent_genes=mutation_percent_genes,
                       mutation_probability=mutation_probability,
                       initial_population=initial_population)
ga_instance.run()

If you do not want to mutate any gene, please set mutation_type=None.


In [109]:
solution, solution_fitness, solution_idx = ga_instance.best_solution()

In [110]:
out = generator(torch.from_numpy(solution[:-1].reshape(-1,latent_dim).astype('float32')),torch.from_numpy(solution[-1].reshape(-1,1).astype('float32')))
out

tensor([[6.2857e-19, 3.4425e-18, 1.4116e-17, 1.5785e-01, 1.0405e-17, 4.5400e-18,
         1.3517e-01, 1.3711e-18, 5.3952e-02, 1.0567e-01, 1.9060e-18, 2.8429e-17,
         2.0204e-18, 2.3981e-02, 9.2355e-02, 1.8019e-02, 1.0077e-01, 3.1223e-01]],
       grad_fn=<DifferentiableGraphBackward>)

In [67]:
np.array(json_dict['start_comp'])

array([0.        , 0.        , 0.        , 0.19628505, 0.        ,
       0.        , 0.17855167, 0.        , 0.0589374 , 0.06548531,
       0.        , 0.        , 0.        , 0.        , 0.07380948,
       0.        , 0.07025729, 0.35667381])