# Full Pipeline for Hyperparameter Optimization or Single Runs

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
import os
import jax
from jax import numpy as jnp
import jraph
from typing import Any
import warnings
import numpy as np
import haiku as hk
import pandas as pd
from hyperoptimization_utils import run_grid_hyperparam_pipeline, train_single_model

# ---------------------------------------------------------------------------
# Active hyperparameter configuration: SrTiO3
# ---------------------------------------------------------------------------
OVERWRITE = True  # passed to the grid pipeline as its OVERWRITE flag

FORMULA = "SrTiO3"
E_DIM = 24
R_SWITCH = 0.1
R_CUT = 4.0
DISTANCE_ENCODING_TYPE = "none"
FEATURES = [128, 64, 32, 8, 1]
NUM_PASSES = 3
ACTIVATION = "relu"
N_EPOCHS = 2000
LR = 1e-3
WEIGHT_DECAY = 1e-4

# ---------------------------------------------------------------------------
# Alternative configuration: IL (uncomment and comment out the block above)
# ---------------------------------------------------------------------------
# OVERWRITE = False
# #
# E_DIM = 16
# R_SWITCH = 0.1
# R_CUT = 5.0
# DISTANCE_ENCODING_TYPE = "root"
# FEATURES = [256,64,32,8,1]
# NUM_PASSES = 3
# ACTIVATION = "relu"
# N_EPOCHS = 5000
# LR = 1e-3
# WEIGHT_DECAY = 1e-5
# FORMULA = "C30H120N30O45"

# Baseline settings handed to run_grid_hyperparam_pipeline. Keep this key order:
# it is the same field order as CURRENT_INDEX and as the result-table index
# (e_dim, r_switch, r_cut, distance_encoding_type, features, num_passes,
# activation_fn, n_epochs, lr, wd).
DEFAULT_DICT = dict(
    E_DIM=E_DIM,
    R_SWITCH=R_SWITCH,
    R_CUT=R_CUT,
    DISTANCE_ENCODING_TYPE=DISTANCE_ENCODING_TYPE,
    FEATURES=FEATURES,
    NUM_PASSES=NUM_PASSES,
    ACTIVATION=ACTIVATION,
    N_EPOCHS=N_EPOCHS,
    LR=LR,
    WEIGHT_DECAY=WEIGHT_DECAY,
)

## Train a single model

In [2]:
# Hyperparameter tuple identifying this run. Its field order is the same as
# DEFAULT_DICT, and its str() has the same format as the saved model file names
# (compare MODEL_PATH in the visualization cell).
CURRENT_INDEX = (E_DIM, R_SWITCH, R_CUT, DISTANCE_ENCODING_TYPE, str(FEATURES), NUM_PASSES, ACTIVATION, N_EPOCHS, LR, WEIGHT_DECAY)

# Train one model with the configuration from the hyperparameter cell.
# OVERWRITE comes from that cell too, instead of a literal True, so that
# configurations which set OVERWRITE = False are actually respected.
model_results, batches = train_single_model(E_DIM=E_DIM,
                                            R_SWITCH=R_SWITCH,
                                            R_CUT=R_CUT,
                                            DISTANCE_ENCODING_TYPE=DISTANCE_ENCODING_TYPE,
                                            FEATURES=FEATURES,
                                            NUM_PASSES=NUM_PASSES,
                                            ACTIVATION=ACTIVATION,
                                            N_EPOCHS=N_EPOCHS,
                                            OVERWRITE=OVERWRITE,
                                            FORMULA=FORMULA,
                                            SAVE_MODEL=True,
                                            SAMPLE_SIZE=None,  # presumably None = use the whole dataset; confirm in hyperoptimization_utils
                                            LR=LR,
                                            WEIGHT_DECAY=WEIGHT_DECAY)

  0%|          | 0/500 [00:00<?, ?it/s]

Start training.
Current Parameter-Setting: (24, 0.1, 4.0, 'none', '[128, 64, 32, 8, 1]', 3, 'relu', 2000, 0.001, 1e-05)
Epoch: 0 -  Train-RMSE: 0.23939  -  Val-RMSE: 0.23699
Epoch: 1 -  Train-RMSE: 0.16515  -  Val-RMSE: 0.16053
Epoch: 2 -  Train-RMSE: 0.127  -  Val-RMSE: 0.12429
Epoch: 3 -  Train-RMSE: 0.09205  -  Val-RMSE: 0.08703
Epoch: 4 -  Train-RMSE: 0.07615  -  Val-RMSE: 0.07231
Epoch: 5 -  Train-RMSE: 0.06762  -  Val-RMSE: 0.06446
Epoch: 6 -  Train-RMSE: 0.05733  -  Val-RMSE: 0.05462
Epoch: 7 -  Train-RMSE: 0.0534  -  Val-RMSE: 0.05098
Epoch: 8 -  Train-RMSE: 0.05917  -  Val-RMSE: 0.05782
Epoch: 9 -  Train-RMSE: 0.04382  -  Val-RMSE: 0.04209
Epoch: 10 -  Train-RMSE: 0.04062  -  Val-RMSE: 0.03907
Epoch: 20 -  Train-RMSE: 0.03296  -  Val-RMSE: 0.03236
Epoch: 30 -  Train-RMSE: 0.03077  -  Val-RMSE: 0.03083
Epoch: 40 -  Train-RMSE: 0.02393  -  Val-RMSE: 0.02476
Epoch: 50 -  Train-RMSE: 0.02047  -  Val-RMSE: 0.02147
Epoch: 60 -  Train-RMSE: 0.02117  -  Val-RMSE: 0.02213
Epoch: 70 -  

## Run a complete hyperparameter optimization

In [3]:
# Grid-search overrides. Each key replaces the matching DEFAULT_DICT entry with a
# list of candidate values; entries not listed here keep their single default
# value (see the merged dict printed by the pipeline). Presumably every
# combination of the lists is trained; confirm in run_grid_hyperparam_pipeline.
#
# Earlier SrTiO3 sweep, kept for reference:
#   FEATURES:   [[8, 2, 1], [64, 32, 1], [256, 128, 1], [256, 64, 32, 8, 1]]
#   NUM_PASSES: [2, 3]
#
# NOTE(review): the recorded run of this sweep hit a GPU RESOURCE_EXHAUSTED error at
# R_CUT = 6.0. Memory still held by the single-model cell may contribute; consider
# restarting the kernel before a sweep.
OPTIM_DICT = dict(
    FEATURES=[[64, 32, 32, 1]],
    NUM_PASSES=[2],
    # R_SWITCH=[0.1],
    R_CUT=[6.0, 7.0],
    # LR=[1e-2, 1e-3],
    # WEIGHT_DECAY=[1e-5],
)

run_grid_hyperparam_pipeline(
    DEFAULT_DICT,
    OPTIM_DICT,
    OVERWRITE,
    FORMULA=FORMULA,
)

{'E_DIM': [24], 'R_SWITCH': [0.1], 'R_CUT': [6.0, 7.0], 'DISTANCE_ENCODING_TYPE': ['none'], 'FEATURES': [[64, 32, 32, 1]], 'NUM_PASSES': [2], 'ACTIVATION': ['relu'], 'N_EPOCHS': [800], 'LR': [0.001], 'WEIGHT_DECAY': [0]}


  0%|          | 0/500 [00:00<?, ?it/s]

Start training.
Current Parameter-Setting: (24, 0.1, 6.0, 'none', '[64, 32, 32, 1]', 2, 'relu', 800, 0.001, 0)


2022-10-26 22:46:36.615322: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.77GiB (rounded to 4050969088)requested by op 
2022-10-26 22:46:36.615614: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ****___********************************x**********************______________________________________
2022-10-26 22:46:36.616081: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4050969056 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  218.26MiB
              constant allocation:         4B
        maybe_live_out allocation:    1.05MiB
     preallocated temp allocation:    3.77GiB
  preallocated temp fragmentation:       684B (0.00%)
                 total allocation:    3.99GiB
Peak buffe

RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4050969056 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  218.26MiB
              constant allocation:         4B
        maybe_live_out allocation:    1.05MiB
     preallocated temp allocation:    3.77GiB
  preallocated temp fragmentation:       684B (0.00%)
                 total allocation:    3.99GiB
Peak buffers:
	Buffer 1:
		Size: 620.37MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/concatenate[dimension=1]" source_file="/home/magnuswagner/madsen/epnn/pipeline_utils.py" source_line=162
		XLA Label: fusion
		Shape: f32[204048,797]
		==========================

	Buffer 2:
		Size: 620.37MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/concatenate[dimension=1]" source_file="/home/magnuswagner/madsen/epnn/pipeline_utils.py" source_line=162
		XLA Label: fusion
		Shape: f32[204048,797]
		==========================

	Buffer 3:
		Size: 620.37MiB
		Operator: op_name="jit(update)/jit(main)/jit(transpose(jvp(rmse_loss)))/jit(jit_transpose(jvp(rmse_loss)))/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=178
		XLA Label: custom-call
		Shape: f32[204048,797]
		==========================

	Buffer 4:
		Size: 417.99MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/concatenate[dimension=1]" source_file="/home/magnuswagner/madsen/epnn/pipeline_utils.py" source_line=161
		XLA Label: fusion
		Shape: f32[204048,537]
		==========================

	Buffer 5:
		Size: 417.99MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/concatenate[dimension=1]" source_file="/home/magnuswagner/madsen/epnn/pipeline_utils.py" source_line=161
		XLA Label: fusion
		Shape: f32[204048,537]
		==========================

	Buffer 6:
		Size: 417.99MiB
		Operator: op_name="jit(update)/jit(main)/jit(transpose(jvp(rmse_loss)))/jit(jit_transpose(jvp(rmse_loss)))/slice[start_indices=(0, 260) limit_indices=(204048, 797) strides=None]" source_file="/home/magnuswagner/madsen/epnn/pipeline_utils.py" source_line=162
		XLA Label: slice
		Shape: f32[204048,537]
		==========================

	Buffer 7:
		Size: 215.61MiB
		Entry Parameter Subshape: f32[204048,277]
		==========================

	Buffer 8:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/max" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=125
		XLA Label: fusion
		Shape: f32[204048,64]
		==========================

	Buffer 9:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/max" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=125
		XLA Label: fusion
		Shape: f32[204048,64]
		==========================

	Buffer 10:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/max" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=125
		XLA Label: fusion
		Shape: f32[204048,64]
		==========================

	Buffer 11:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/max" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=125
		XLA Label: fusion
		Shape: f32[204048,64]
		==========================

	Buffer 12:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(transpose(jvp(rmse_loss)))/jit(jit_transpose(jvp(rmse_loss)))/select_n" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=125
		XLA Label: fusion
		Shape: f32[204048,64]
		==========================

	Buffer 13:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(transpose(jvp(rmse_loss)))/jit(jit_transpose(jvp(rmse_loss)))/select_n" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=125
		XLA Label: fusion
		Shape: f32[204048,64]
		==========================

	Buffer 14:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=178
		XLA Label: custom-call
		Shape: f32[204048,64]
		==========================

	Buffer 15:
		Size: 49.82MiB
		Operator: op_name="jit(update)/jit(main)/jit(jvp(rmse_loss))/jit(jit_jvp(rmse_loss))/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/magnuswagner/miniconda3/lib/python3.8/site-packages/haiku/_src/basic.py" source_line=178
		XLA Label: custom-call
		Shape: f32[204048,64]
		==========================



## Visualization of results

In [4]:
from infervis import visualize_results, infer

# Choose the saved model to evaluate. Model files live under models/<FORMULA>/ and
# are named after the hyperparameter index tuple (same format as CURRENT_INDEX).
# MODEL_PATH = "models/SrTiO3/(24, 0.1, 10.0, 'none', '[256, 64, 32, 8, 1]', 2, 'relu', 5000, 0.001, 1e-05).pkl"
MODEL_PATH = "models/SrTiO3/(24, 0.1, 4.0, 'none', '[128, 64, 32, 8, 1]', 3, 'relu', 2000, 0.001, 1e-05).pkl"
DB_PATH = "data/SrTiO3_500.db"
# NOTE: this rebinds the global FORMULA set in the config cell; the cells below
# (plot save names, result-table path) use whatever value is set here.
FORMULA = "SrTiO3"

# MODEL_PATH = "models/C30H120N30O45/(16, 0.1, 9.0, 'root', '[64, 32, 16, 8, 4, 2, 1]', 2, 'relu', 5000, 0.001, 1e-05).pkl"
# DB_PATH = "data/IL_charges.db"
# FORMULA = "C30H120N30O45"


# Use the model file's stem (no directory, no extension) as the plot name.
# os.path.splitext removes only the final extension, so the dots inside the
# tuple (e.g. "0.1") are kept. Unlike a `[:-4]` slice, it does not assume a
# 4-character extension such as ".pkl".
SAVE_NAME = os.path.splitext(os.path.basename(MODEL_PATH))[0]
visualize_results(MODEL_PATH, FORMULA = FORMULA, db_path = DB_PATH, xrange = [-1.5,2.1], save_name = SAVE_NAME)
# infer(MODEL_PATH, formula = FORMULA, db_path = DB_PATH)

Model loaded. Index = (24, 0.1, 4.0, 'none', '[128, 64, 32, 8, 1]', 3, 'relu', 2000, 0.001, 1e-05)


  0%|          | 0/500 [00:00<?, ?it/s]

MAE: 0.051878028


In [17]:
# Zoomed view of the trained model's results over x in [-1.5, -0.25].
# NOTE(review): this (model_results, batches, xrange) call form differs from the
# visualize_results(MODEL_PATH, FORMULA=..., db_path=..., xrange=..., save_name=...)
# imported from infervis in the cell above; after that import runs, one of the two
# call forms is presumably stale. Confirm before relying on Run All.
# NOTE(review): the zoom cells all pass the same save_name; if it becomes the
# output file name, each zoom overwrites the previous one.
visualize_results(
    model_results,
    batches,
    [-1.5, -0.25],
    save_name=f"{FORMULA}_{CURRENT_INDEX}",
)

0.010264749
0.010264747256324406
[-1.1592762 -1.1595217 -1.2132499 ...  1.9233024  1.8249884  1.824996 ]


In [18]:
# Zoomed view of the trained model's results over x in [1.13, 2.1].
# NOTE(review): this (model_results, batches, xrange) call form differs from the
# visualize_results(MODEL_PATH, FORMULA=..., db_path=..., xrange=..., save_name=...)
# imported from infervis earlier; one of the two call forms is presumably stale.
# NOTE(review): the zoom cells all pass the same save_name; if it becomes the
# output file name, each zoom overwrites the previous one.
visualize_results(
    model_results,
    batches,
    [1.13, 2.1],
    save_name=f"{FORMULA}_{CURRENT_INDEX}",
)

0.010264749
0.010264748709542412
[-1.1592762 -1.1595217 -1.2132499 ...  1.9233024  1.8249884  1.824996 ]


In [5]:
# Zoomed view of the trained model's results over x in [0.5, 0.85].
# NOTE(review): this (model_results, batches, xrange) call form differs from the
# visualize_results(MODEL_PATH, FORMULA=..., db_path=..., xrange=..., save_name=...)
# imported from infervis earlier; one of the two call forms is presumably stale.
# NOTE(review): the zoom cells all pass the same save_name; if it becomes the
# output file name, each zoom overwrites the previous one.
visualize_results(
    model_results,
    batches,
    [0.5, 0.85],
    save_name=f"{FORMULA}_{CURRENT_INDEX}",
)

0.008704517
0.0087045268139808
[-0.4611627   0.31432292  0.3330996  ... -0.44709992 -0.4833104
 -0.43786442]


## Result table sorted by test MAE (best first)

In [9]:
# FORMULA = "C30H120N30O45"

# Hyperparameter columns forming the table's index, in the same field order as
# DEFAULT_DICT and CURRENT_INDEX.
RESULT_INDEX_COLUMNS = ["e_dim", "r_switch", "r_cut", "distance_encoding_type", "features",
                        "num_passes", "activation_fn", "n_epochs", "lr", "wd"]

# Load the accumulated results for FORMULA. "Unnamed: 0" is the old row-number
# column written by a previous to_csv, so it is dropped.
result_table = (
    pd.read_csv(f"results/result_table_{FORMULA}.csv")
    .set_index(RESULT_INDEX_COLUMNS)
    .drop("Unnamed: 0", axis=1)
)

# Must carry all 10 index fields (including LR and WEIGHT_DECAY) to match
# result_table's index and the 10-field tuple built in the training cell.
# The previous 8-field version shadowed that tuple for every later cell.
CURRENT_INDEX = (E_DIM, R_SWITCH, R_CUT, DISTANCE_ENCODING_TYPE, str(FEATURES), NUM_PASSES, ACTIVATION, N_EPOCHS, LR, WEIGHT_DECAY)

print(FORMULA)
# Best (lowest) test MAE first. To rank by validation RMSE instead use:
#   result_table.sort_values(by="best_val_rmse")
result_table.sort_values(by="test_mae")



SrTiO3


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,time_needed,test_rmse,test_mae,steps,val_rmses,best_val_rmse
e_dim,r_switch,r_cut,distance_encoding_type,features,num_passes,activation_fn,n_epochs,lr,wd,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
24,0.1,10.0,none,"[256, 64, 32, 8, 1]",2,relu,5000,0.001,1e-05,30.54,0.016987,0.010265,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,...","[0.30529138445854187, 0.15349964797496796, 0.0...",0.012787
24,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,800,0.001,0.0,2.67,0.016739,0.010688,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.43773484230041504, 0.04919735714793205, 0.0...",0.014628
24,0.1,10.0,none,"[64, 32, 1]",2,relu,800,0.001,1e-05,4.37,0.018193,0.012453,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,...","[0.5964868664741516, 0.2387649118900299, 0.160...",0.014506
12,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,300,0.001,0.0,2.7,0.018734,0.012665,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.5771218538284302, 0.056612446904182434, 0.0...",0.018031
12,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,800,0.001,0.0,2.72,0.01929,0.012973,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.5771218538284302, 0.05629841610789299, 0.04...",0.016316
36,0.1,10.0,none,"[64, 32, 32, 1]",2,relu,800,0.001,0.0,28.34,0.019869,0.013053,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[1.033949613571167, 0.04289126396179199, 0.034...",0.016279
24,0.1,10.0,none,"[8, 2, 1]",2,relu,800,0.001,1e-05,3.8,0.021575,0.013381,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,...","[1.2241326570510864, 0.7975797057151794, 0.479...",0.019551
24,0.1,10.0,none,"[256, 128, 1]",2,relu,800,0.001,1e-05,7.01,0.018532,0.01405,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,...","[0.8956247568130493, 0.3030022978782654, 0.159...",0.014124
36,0.1,10.0,none,"[64, 32, 16, 8, 1]",3,relu,300,0.001,0.0,1.77,0.022455,0.014212,"[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1...","[0.6282362341880798, 0.10450755804777145, 0.06...",0.021596
24,0.1,10.0,none,"[256, 64, 32, 8, 1]",2,relu,800,0.001,1e-05,5.21,0.020122,0.014521,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,...","[0.30527162551879883, 0.15270952880382538, 0.0...",0.014059
