# Full Pipeline for Hyperparameter Optimization or Single Runs

In [2]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
import os
import jax
from jax import numpy as jnp
import jraph
from typing import Any
import warnings
import numpy as np
import haiku as hk
import pandas as pd
from hyperoptimization_utils import run_grid_hyperparam_pipeline, train_single_model

# ----------------------------------------------------------------------
# Active preset: SrTiO3
# ----------------------------------------------------------------------
# OVERWRITE and FORMULA are passed to the pipeline helpers separately;
# they are deliberately not part of DEFAULT_DICT below.
OVERWRITE = True
FORMULA = "SrTiO3"

E_DIM, R_SWITCH, R_CUT = 24, 0.1, 10.0
DISTANCE_ENCODING_TYPE = "none"
FEATURES = [256, 64, 32, 8, 1]
NUM_PASSES = 2
ACTIVATION = "relu"
N_EPOCHS = 800
LR, WEIGHT_DECAY = 1e-3, 1e-5

# ----------------------------------------------------------------------
# Alternative preset: IL (FORMULA "C30H120N30O45"), kept for reference.
# ----------------------------------------------------------------------
# OVERWRITE = False
# #
# E_DIM = 16
# R_SWITCH = 0.1
# R_CUT = 9.0
# DISTANCE_ENCODING_TYPE = "root"
# FEATURES = [256,64,32,8,1]
# NUM_PASSES = 3
# ACTIVATION = "relu"
# N_EPOCHS = 5000
# LR = 1e-3
# WEIGHT_DECAY = 1e-5
# FORMULA = "C30H120N30O45"

# Baseline hyperparameters, one scalar value per name. This dict is the
# first argument of run_grid_hyperparam_pipeline further down.
DEFAULT_DICT = dict(
    E_DIM=E_DIM,
    R_SWITCH=R_SWITCH,
    R_CUT=R_CUT,
    DISTANCE_ENCODING_TYPE=DISTANCE_ENCODING_TYPE,
    FEATURES=FEATURES,
    NUM_PASSES=NUM_PASSES,
    ACTIVATION=ACTIVATION,
    N_EPOCHS=N_EPOCHS,
    LR=LR,
    WEIGHT_DECAY=WEIGHT_DECAY,
)

## Train a single model

In [None]:
# Tuple identifying the current configuration. FEATURES is stringified, and
# the field order follows the index columns of the result table read later
# in this notebook (e_dim ... lr, wd).
CURRENT_INDEX = (
    E_DIM,
    R_SWITCH,
    R_CUT,
    DISTANCE_ENCODING_TYPE,
    str(FEATURES),
    NUM_PASSES,
    ACTIVATION,
    N_EPOCHS,
    LR,
    WEIGHT_DECAY,
)

# Keyword arguments for a single training run, taken from the config cell.
# Note: OVERWRITE is fixed to True here, independent of the OVERWRITE constant.
single_run_kwargs = dict(
    E_DIM=E_DIM,
    R_SWITCH=R_SWITCH,
    R_CUT=R_CUT,
    DISTANCE_ENCODING_TYPE=DISTANCE_ENCODING_TYPE,
    FEATURES=FEATURES,
    NUM_PASSES=NUM_PASSES,
    ACTIVATION=ACTIVATION,
    N_EPOCHS=N_EPOCHS,
    OVERWRITE=True,
    FORMULA=FORMULA,
    SAVE_MODEL=True,
    SAMPLE_SIZE=None,
    LR=LR,
    WEIGHT_DECAY=WEIGHT_DECAY,
)

model_results, batches = train_single_model(**single_run_kwargs)

## Run a complete hyperparameter optimization

In [3]:
# Log of grid searches already run for SrTiO3:
#   1. "FEATURES": [[8,2,1],[64,32,1],[256,128,1],[256,64,32,8,1]],
#      "NUM_PASSES": [2,3]

# Candidate values per hyperparameter that is varied in this search.
feature_grid = [
    [8, 2, 1],
    [64, 32, 1],
    [256, 128, 1],
    [256, 64, 32, 8, 1],
]
num_passes_grid = [2, 3]

OPTIM_DICT = {
    "FEATURES": feature_grid,
    "NUM_PASSES": num_passes_grid,
    # Further axes that can be enabled:
    # "R_SWITCH": [4.0],
    # "R_CUT": [10.0,12.0]
    # "LR": [1e-2,1e-3],
    # "WEIGHT_DECAY": [1e-5],
}

# Presumably every parameter missing from OPTIM_DICT is taken from
# DEFAULT_DICT by the helper -- confirm in hyperoptimization_utils.
run_grid_hyperparam_pipeline(DEFAULT_DICT, OPTIM_DICT, OVERWRITE, FORMULA=FORMULA)

{'E_DIM': [24], 'R_SWITCH': [0.1], 'R_CUT': [10.0], 'DISTANCE_ENCODING_TYPE': ['none'], 'FEATURES': [[8, 2, 1], [64, 32, 1], [256, 128, 1], [256, 64, 32, 8, 1]], 'NUM_PASSES': [2, 3], 'ACTIVATION': ['relu'], 'N_EPOCHS': [800], 'LR': [0.001], 'WEIGHT_DECAY': [1e-05]}


  0%|          | 0/500 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [3]:
### Load model
# NOTE(review): `load_model` is neither imported nor defined in the visible
# part of this notebook, so this cell raises NameError on a fresh kernel.
# Import it from the module that provides it before running this cell
# (TODO: confirm which module that is).

# Select the system whose saved model should be loaded. The alternative is
# kept as a comment; it used to be a dead assignment that was overwritten
# on the very next line.
# FORMULA = "C30H120N30O45"
FORMULA = "SrTiO3"

model_results, index = load_model("test", FORMULA)

NameError: name 'load_model' is not defined

# Visualization of results

In [5]:
from infervis import visualize_results, infer

# Alternative target (SrTiO3), kept for reference:
# MODEL_PATH = "models/SrTiO3/(24, 0.1, 10.0, 'none', '[256, 64, 32, 8, 1]', 2, 'relu', 5000, 0.001, 1e-05).pkl"
# DB_PATH = "data/SrTiO3_500.db"
# FORMULA = "SrTiO3"

# Active target: the C30H120N30O45 model and its database.
FORMULA = "C30H120N30O45"
DB_PATH = "data/IL_charges.db"
MODEL_PATH = "models/C30H120N30O45/(16, 0.1, 9.0, 'root', '[64, 32, 16, 8, 4, 2, 1]', 2, 'relu', 5000, 0.001, 1e-05).pkl"

# Save name = last path component with its final four characters (".pkl") removed.
model_file_name = MODEL_PATH.split("/")[-1]
SAVE_NAME = model_file_name[:-4]

plot_xrange = [-1.5, 2.1]
visualize_results(
    MODEL_PATH,
    FORMULA=FORMULA,
    db_path=DB_PATH,
    xrange=plot_xrange,
    save_name=SAVE_NAME,
)
# infer(MODEL_PATH, formula = FORMULA, db_path = DB_PATH)

Model loaded. Index = (16, 0.1, 9.0, 'root', '[64, 32, 16, 8, 4, 2, 1]', 2, 'relu', 5000, 0.001, 1e-05)


  0%|          | 0/368 [00:00<?, ?it/s]

49/368 were infered.
99/368 were infered.
149/368 were infered.
199/368 were infered.
249/368 were infered.
299/368 were infered.
349/368 were infered.
367/368 were infered.
(368, 225)


[array([-0.4749943 ,  0.3414318 ,  0.32547483,  0.30840367,  0.03941292,
         0.08963099,  0.09464467, -0.45628655,  0.15105863,  0.14928281,
         0.15295546, -0.47029215,  0.3303365 ,  0.3217677 ,  0.32624826,
         0.04861046,  0.08602774,  0.0930823 , -0.49913156,  0.14130868,
         0.14438221,  0.1433258 , -0.45934966,  0.32140973,  0.30920297,
         0.3272708 ,  0.03144807,  0.09045399,  0.08026497, -0.4472653 ,
         0.14373253,  0.14088489,  0.1375393 , -0.4669689 ,  0.34097067,
         0.30874792,  0.33444986,  0.03615157,  0.08549786,  0.10194817,
        -0.45412233,  0.14609201,  0.15460938,  0.13836603, -0.5015221 ,
         0.3356153 ,  0.33574414,  0.33607027,  0.03349363,  0.08699757,
         0.08938645, -0.44077086,  0.14534451,  0.13680524,  0.16587725,
        -0.48924997,  0.322082  ,  0.29335603,  0.3321885 ,  0.03623088,
         0.08238042,  0.09734111, -0.47286275,  0.14730176,  0.14697313,
         0.14866932, -0.4915495 ,  0.32495964,  0.3

In [17]:
# Presumably restricts the plot to the window [-1.5, -0.25] -- confirm.
# NOTE(review): this positional (model_results, batches, range) call differs
# from the keyword call that passes MODEL_PATH / db_path / xrange elsewhere in
# this notebook; confirm which signature infervis.visualize_results accepts.
window_low = [-1.5, -0.25]
visualize_results(model_results, batches, window_low, save_name=f"{FORMULA}_{CURRENT_INDEX}")

0.010264749
0.010264747256324406
[-1.1592762 -1.1595217 -1.2132499 ...  1.9233024  1.8249884  1.824996 ]


In [18]:
# Presumably restricts the plot to the window [1.13, 2.1] -- confirm.
# NOTE(review): this positional (model_results, batches, range) call differs
# from the keyword call that passes MODEL_PATH / db_path / xrange elsewhere in
# this notebook; confirm which signature infervis.visualize_results accepts.
window_high = [1.13, 2.1]
visualize_results(model_results, batches, window_high, save_name=f"{FORMULA}_{CURRENT_INDEX}")

0.010264749
0.010264748709542412
[-1.1592762 -1.1595217 -1.2132499 ...  1.9233024  1.8249884  1.824996 ]


In [5]:
# Presumably restricts the plot to the window [0.5, 0.85] -- confirm.
# NOTE(review): this positional (model_results, batches, range) call differs
# from the keyword call that passes MODEL_PATH / db_path / xrange elsewhere in
# this notebook; confirm which signature infervis.visualize_results accepts.
window_mid = [0.5, 0.85]
visualize_results(model_results, batches, window_mid, save_name=f"{FORMULA}_{CURRENT_INDEX}")

0.008704517
0.0087045268139808
[-0.4611627   0.31432292  0.3330996  ... -0.44709992 -0.4833104
 -0.43786442]


# Result table sorted by best performance

In [12]:
# FORMULA = "C30H120N30O45"

# Columns that together identify one hyperparameter configuration; they form
# the MultiIndex of the result table.
RESULT_INDEX_COLUMNS = [
    "e_dim",
    "r_switch",
    "r_cut",
    "distance_encoding_type",
    "features",
    "num_passes",
    "activation_fn",
    "n_epochs",
    "lr",
    "wd",
]

# Load the stored results for FORMULA. "Unnamed: 0" is the row-number column
# that appears when the table was saved together with its default index;
# errors="ignore" keeps the cell working for files saved without it.
result_table = (
    pd.read_csv(f"results/result_table_{FORMULA}.csv")
    .drop(columns="Unnamed: 0", errors="ignore")
    .set_index(RESULT_INDEX_COLUMNS)
)

# Key of the current configuration. It carries all ten index fields
# (including LR and WEIGHT_DECAY) so that it matches RESULT_INDEX_COLUMNS and
# the CURRENT_INDEX built in the single-model training cell.
CURRENT_INDEX = (
    E_DIM,
    R_SWITCH,
    R_CUT,
    DISTANCE_ENCODING_TYPE,
    str(FEATURES),
    NUM_PASSES,
    ACTIVATION,
    N_EPOCHS,
    LR,
    WEIGHT_DECAY,
)

# One-off maintenance snippets, kept for reference:
# result_table["lr"]=1e-3
# result_table["wd"]=0.0
# result_table.reset_index().to_csv(f"results/result_table_{FORMULA}.csv")
# indices = result_table.sort_values(by="best_val_rmse").index[:3]

print(FORMULA)
# Last expression of the cell: table ordered by ascending best validation RMSE.
result_table.sort_values(by="best_val_rmse")



SrTiO3


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,time_needed,test_rmse,test_mae,steps,val_rmses,best_val_rmse
e_dim,r_switch,r_cut,distance_encoding_type,features,num_passes,activation_fn,n_epochs,lr,wd,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
24,0.1,10.0,none,"[256, 64, 32, 8, 1]",2,relu,5000,0.001,1e-05,30.54,0.016987,0.010265,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,...","[0.30529138445854187, 0.15349964797496796, 0.0...",0.012787
24,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,800,0.001,0.0,2.67,0.016739,0.010688,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.43773484230041504, 0.04919735714793205, 0.0...",0.014628
36,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,800,0.001,0.0,28.34,0.019869,0.013053,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[1.033949613571167, 0.04289126396179199, 0.034...",0.016279
12,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,800,0.001,0.0,2.72,0.01929,0.012973,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.5771218538284302, 0.05629841610789299, 0.04...",0.016316
36,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,300,0.001,0.0,1.12,0.023337,0.01635,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[1.0342873334884644, 0.043234094977378845, 0.0...",0.017399
36,0.1,10.0,none,"[64, 32, 32, 1]",5,relu,300,0.001,0.0,3.42,0.022672,0.015433,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.6070751547813416, 0.08307366073131561, 0.06...",0.017983
12,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,300,0.001,0.0,2.7,0.018734,0.012665,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.5771218538284302, 0.056612446904182434, 0.0...",0.018031
36,0.1,10.0,none,"[64, 32, 32, 1]",3,relu,300,0.001,0.0,3.43,0.023029,0.016385,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.8580541014671326, 0.07854203879833221, 0.06...",0.018867
36,0.1,10.0,none,"[64, 32, 32, 1]",3,switch,300,0.001,0.0,1.65,0.023369,0.016988,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.8580860495567322, 0.07682593166828156, 0.07...",0.01888
6,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,300,0.001,0.0,2.81,0.022475,0.014857,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.6277487874031067, 0.057092297822237015, 0.0...",0.019193
