In [31]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
import mlflow
from lib.config import AppConfig
from experiment_analysis.experiment_data_utils import get_full_runs_df
from IPython.display import display, HTML
import pandas as pd

# Application settings used throughout the notebook (results path, MLflow tracking URI).
config = AppConfig()
tracking_uri = config.mlflow_tracking_uri
mlflow.set_tracking_uri(tracking_uri)

In [33]:
# Load the optimisation run table from the configured results path.
results_path = config.optimization_experiment_results_path
runs = get_full_runs_df(results_path)

In [34]:
# Total number of loaded runs (before cleaning / filtering).
runs.shape[0]

15936

In [35]:
from lib.reproduction import major_oxides

# Number of CV folds; the per-fold metrics metrics.rmse_cv_1 .. metrics.rmse_cv_{n_splits} are plotted below.
n_splits = 4
# Oxide examined in the single-target ("Univariate") section.
analysis_target = "SiO2"
assert analysis_target in major_oxides, (
    f"{analysis_target} is not a valid oxide. Please choose from {major_oxides}"
)

In [36]:
from experiment_analysis.experiment_data_utils import clean_experiment_data

# Drop runs whose CV RMSE exceeds this cutoff — presumably diverged/degenerate trials (TODO confirm threshold).
rmse_cv_cutoff = 50
cleaned_runs = clean_experiment_data(runs)
filtered_runs = cleaned_runs.loc[cleaned_runs['metrics.rmse_cv'] <= rmse_cv_cutoff]

In [37]:
# Number of retained runs per oxide.
filtered_runs.loc[:, "params.oxide"].value_counts()

params.oxide
Al2O3    1943
FeOT     1928
SiO2     1926
CaO      1903
TiO2     1892
Na2O     1891
K2O      1883
MgO      1796
Name: count, dtype: int64

In [38]:
def display_table_with_options(df, max_columns=10, max_rows=100, display_func=lambda x: display(x)):
    """Render ``df`` with temporary pandas column/row display limits.

    The limits are in effect only while ``display_func`` runs. The previous
    ``display.max_columns`` / ``display.max_rows`` settings are restored
    afterwards, even when ``display_func`` raises.

    Args:
        df: Object to render (typically a DataFrame).
        max_columns: Value for ``display.max_columns`` while rendering.
        max_rows: Value for ``display.max_rows`` while rendering.
        display_func: Callable that renders ``df``. Defaults to IPython's
            ``display``, which the lambda looks up lazily at call time.
    """
    # option_context restores the prior values on exit, including on exceptions;
    # manual set/reset would leave the options changed if display_func failed.
    with pd.option_context('display.max_columns', max_columns, 'display.max_rows', max_rows):
        display_func(df)

# Multivariate

In [39]:
# Columns shown in the per-oxide overview tables.
overview_columns = ["params.oxide", "params.model_type", "params.transformer_type", "params.pca_type", "params.scaler_type", "metrics.rmse_cv", "metrics.std_dev_cv"]
# The column selection does not depend on the oxide, so do it once instead of on every iteration.
overview_source = filtered_runs[overview_columns]

overview_list = []
for oxide in major_oxides:
    # Best run per model type for this oxide: sort ascending by RMSE CV, then keep
    # the first (lowest-RMSE) row for each model type.
    overview_df = overview_source[overview_source['params.oxide'] == oxide].sort_values(by='metrics.rmse_cv')
    unique_model_types_df = overview_df.drop_duplicates(subset=['params.model_type'])
    overview_list.append(unique_model_types_df)

for oxide, df in zip(major_oxides, overview_list):
    display(HTML(f"<h2>{oxide}</h2>"))
    display_table_with_options(df, max_columns=10, max_rows=100)


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
1213,SiO2,pls,none,kernel_pca,min_max_scaler,4.552021,4.551431
1687,SiO2,svr,none,none,min_max_scaler,4.591604,4.588081
1817,SiO2,gbr,none,none,norm3_scaler,4.651792,4.646464
628,SiO2,lasso,power_transformer,pca,norm3_scaler,4.736667,4.738305
1034,SiO2,xgboost,quantile_transformer,none,norm3_scaler,4.791,4.781044
239,SiO2,elasticnet,quantile_transformer,none,norm3_scaler,4.841359,4.843702
808,SiO2,ngboost,power_transformer,none,norm3_scaler,4.859746,4.850966
489,SiO2,ridge,power_transformer,none,norm3_scaler,4.94048,4.938138
1415,SiO2,extra_trees,power_transformer,none,norm3_scaler,5.141448,5.118042
73,SiO2,random_forest,none,none,norm3_scaler,5.203819,5.191888


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
3762,TiO2,svr,power_transformer,none,norm3_scaler,0.408701,0.406385
3869,TiO2,gbr,power_transformer,none,norm3_scaler,0.40973,0.408748
3145,TiO2,xgboost,none,none,robust_scaler,0.410616,0.41037
2184,TiO2,random_forest,quantile_transformer,none,norm3_scaler,0.422384,0.420629
2470,TiO2,elasticnet,none,none,robust_scaler,0.423295,0.423306
3417,TiO2,extra_trees,power_transformer,none,standard_scaler,0.425908,0.426013
2668,TiO2,ridge,none,none,min_max_scaler,0.427579,0.426859
2922,TiO2,lasso,power_transformer,none,standard_scaler,0.430815,0.429989
2320,TiO2,ngboost,none,none,robust_scaler,0.431336,0.430507
3364,TiO2,pls,power_transformer,kernel_pca,robust_scaler,0.440842,0.440607


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
5062,Al2O3,xgboost,power_transformer,none,norm3_scaler,2.075363,2.067163
5843,Al2O3,gbr,power_transformer,none,robust_scaler,2.092107,2.089475
4063,Al2O3,ngboost,power_transformer,none,robust_scaler,2.121021,2.11273
5646,Al2O3,svr,quantile_transformer,none,min_max_scaler,2.178905,2.1764
4707,Al2O3,ridge,quantile_transformer,none,norm3_scaler,2.217908,2.21135
4506,Al2O3,elasticnet,quantile_transformer,none,norm3_scaler,2.224726,2.218543
5351,Al2O3,pls,quantile_transformer,none,robust_scaler,2.247312,2.243696
4852,Al2O3,lasso,quantile_transformer,none,norm3_scaler,2.249181,2.242429
5475,Al2O3,extra_trees,power_transformer,none,min_max_scaler,2.28846,2.261357
4218,Al2O3,random_forest,power_transformer,none,max_abs_scaler,2.302139,2.295194


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
7779,FeOT,svr,quantile_transformer,none,norm3_scaler,2.242387,2.243303
7388,FeOT,pls,power_transformer,none,standard_scaler,2.701247,2.669157
6520,FeOT,ridge,quantile_transformer,none,norm3_scaler,2.707415,2.687497
7964,FeOT,gbr,power_transformer,none,max_abs_scaler,2.748689,2.749705
7142,FeOT,xgboost,none,none,max_abs_scaler,2.749488,2.743086
6248,FeOT,elasticnet,power_transformer,none,max_abs_scaler,2.861748,2.831049
6640,FeOT,lasso,quantile_transformer,none,norm3_scaler,2.874875,2.862145
7573,FeOT,extra_trees,none,none,max_abs_scaler,2.899789,2.902789
6941,FeOT,ngboost,none,none,robust_scaler,2.97996,2.952596
6154,FeOT,random_forest,quantile_transformer,none,norm3_scaler,3.079249,3.043626


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
9527,MgO,svr,power_transformer,none,robust_scaler,1.321517,1.320551
9245,MgO,pls,none,kernel_pca,norm3_scaler,1.327285,1.321169
8517,MgO,ridge,power_transformer,none,robust_scaler,1.447727,1.443408
8300,MgO,elasticnet,power_transformer,none,robust_scaler,1.466054,1.462077
9766,MgO,gbr,quantile_transformer,none,norm3_scaler,1.468274,1.464279
9311,MgO,extra_trees,power_transformer,none,norm3_scaler,1.53322,1.522105
8726,MgO,lasso,none,kernel_pca,min_max_scaler,1.60443,1.595999
8945,MgO,xgboost,none,none,norm3_scaler,1.618271,1.610429
8216,MgO,random_forest,quantile_transformer,none,norm3_scaler,1.640273,1.630457
8071,MgO,ngboost,none,none,robust_scaler,1.939806,1.910566


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
11589,CaO,svr,quantile_transformer,none,min_max_scaler,1.193312,1.191626
11158,CaO,pls,quantile_transformer,none,max_abs_scaler,1.269539,1.262997
11777,CaO,gbr,quantile_transformer,none,norm3_scaler,1.281066,1.280441
11359,CaO,extra_trees,none,none,norm3_scaler,1.308235,1.308956
10998,CaO,xgboost,power_transformer,none,norm3_scaler,1.362519,1.360737
10124,CaO,elasticnet,quantile_transformer,none,norm3_scaler,1.383963,1.376509
10341,CaO,ridge,quantile_transformer,none,norm3_scaler,1.405815,1.400113
9974,CaO,random_forest,none,none,norm3_scaler,1.43888,1.435206
10859,CaO,ngboost,none,none,robust_scaler,1.488171,1.481055
10534,CaO,lasso,power_transformer,none,min_max_scaler,1.528803,1.514428


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
13665,Na2O,svr,power_transformer,none,norm3_scaler,0.777123,0.775457
13136,Na2O,pls,power_transformer,none,norm3_scaler,0.844636,0.842301
13830,Na2O,gbr,quantile_transformer,none,norm3_scaler,0.904053,0.894815
12937,Na2O,xgboost,quantile_transformer,none,max_abs_scaler,0.952315,0.942629
13374,Na2O,extra_trees,quantile_transformer,none,norm3_scaler,0.965194,0.95308
12133,Na2O,elasticnet,quantile_transformer,none,standard_scaler,0.993755,0.989845
12531,Na2O,lasso,quantile_transformer,none,max_abs_scaler,0.994894,0.991154
12801,Na2O,ngboost,quantile_transformer,none,norm3_scaler,0.999736,0.992898
11964,Na2O,random_forest,quantile_transformer,none,norm3_scaler,1.002221,0.995072
12374,Na2O,ridge,quantile_transformer,none,norm3_scaler,1.010849,1.000998


Unnamed: 0,params.oxide,params.model_type,params.transformer_type,params.pca_type,params.scaler_type,metrics.rmse_cv,metrics.std_dev_cv
15190,K2O,pls,none,none,norm3_scaler,0.586905,0.585668
15823,K2O,gbr,quantile_transformer,none,min_max_scaler,0.59025,0.587453
15576,K2O,svr,quantile_transformer,none,norm3_scaler,0.593129,0.593192
15078,K2O,xgboost,power_transformer,none,standard_scaler,0.599887,0.598787
14247,K2O,elasticnet,power_transformer,none,robust_scaler,0.602289,0.601509
14731,K2O,ngboost,quantile_transformer,none,max_abs_scaler,0.602461,0.600255
14538,K2O,lasso,power_transformer,none,norm3_scaler,0.606846,0.606405
14406,K2O,ridge,power_transformer,none,norm3_scaler,0.610802,0.611258
14000,K2O,random_forest,power_transformer,none,norm3_scaler,0.674821,0.668918
15420,K2O,extra_trees,power_transformer,none,robust_scaler,0.713629,0.709484


In [69]:
# Comma-separated list of all column names, for a quick look at which metrics/params were logged.
# join avoids quadratic string concatenation and the trailing ", " separator.
cols = ", ".join(map(str, filtered_runs.columns))

cols

'Unnamed: 0, run_id, experiment_id, status, artifact_uri, start_time, end_time, metrics.std_dev_cv_4, metrics.std_dev_cv_3, metrics.rmse_cv_4, metrics.std_dev_cv, metrics.mse, metrics.rmse_cv_1, metrics.std_dev_cv_2, metrics.std_dev_cv_1, metrics.rmse, metrics.std_dev, metrics.rmse_cv_3, metrics.rmse_cv_2, metrics.rmse_cv, params.fold_4_val_size, params.random_forest_min_samples_split, params.pca_type, params.fold_1_val_size, params.fold_3_train_size, params.model_type, params.fold_2_train_size, params.random_forest_n_estimators, params.fold_1_train_size, params.test_size, params.fold_3_val_size, params.fold_2_val_size, params.trial_number, params.transformer_type, params.train_size, params.scaler_type, params.fold_4_train_size, params.random_forest_max_depth, params.random_forest_min_samples_leaf, params.random_forest_max_features, params.robust_scaler_with_centering, params.robust_scaler_quantile_range, params.method, params.standardize, params.whiten, params.n_components, params.sta

In [78]:
from experiment_analysis.experiment_data_utils import pretty_format_params

# Columns used to rank configurations; selected once rather than once per oxide.
overview_columns = ["params.oxide", "params.model_type", "params.transformer_type", "params.pca_type", "params.scaler_type", "metrics.rmse_cv", "metrics.std_dev_cv"]
overview_source = filtered_runs[overview_columns]

overview_list = []
for oxide in major_oxides:
    # Best run per model type for this oxide (lowest RMSE CV first).
    overview_df = overview_source[overview_source['params.oxide'] == oxide].sort_values(by='metrics.rmse_cv')
    unique_model_types_df = overview_df.drop_duplicates(subset=['params.model_type'])
    overview_list.append(unique_model_types_df)

for oxide, df in zip(major_oxides, overview_list):
    display(HTML(f"<h2>{oxide}</h2>"))
    display(HTML("<h3>Top 3 Configurations</h3>"))

    # The overview rows keep filtered_runs' index labels, so each label points back to
    # the full run record with all of its logged params and metrics.
    for run_label in df.index[:3]:
        data_row = filtered_runs.loc[run_label]
        print(pretty_format_params(data_row))
        print(f"RMSEP: {data_row['metrics.rmse']}")
        print(f"Std.Dev: {data_row['metrics.std_dev']}")
        print(f"RMSE CV: {data_row['metrics.rmse_cv']}")
        print(f"STD Dev CV: {data_row['metrics.std_dev_cv']}")
        print("\n")
    print("\n")


Model: pls
Model Parameters:
  pls_n_components: 1.0

Scaler: min_max_scaler
Scaler Parameters:
  min_max_scaler_feature_range: (0, 1)

PCA: kernel_pca
PCA Parameters:
  n_components: 100.0
  gamma: 0.0025184866615475
  kernel: cosine
  degree: 2.0
RMSEP: 4.084361878000858
Std.Dev: 4.087395463105326
RMSE CV: 4.552020500638808
STD Dev CV: 4.551431354622924


Model: svr
Model Parameters:
  svr_C: 0.101270914859041
  svr_kernel: poly
  svr_degree: 5.0
  svr_gamma: auto
  svr_coef0: 5.982869617857073
  svr_epsilon: 0.1075038351628729
  svr_max_iter: 20000000.0

Scaler: min_max_scaler
Scaler Parameters:
  min_max_scaler_feature_range: (-1, 1)
RMSEP: 3.5328962583557093
Std.Dev: 3.53731678643022
RMSE CV: 4.591603852578772
STD Dev CV: 4.588081003183366


Model: gbr
Model Parameters:
  gbr_learning_rate: 0.0195372695066511
  gbr_subsample: 0.6333895654431646
  gbr_max_depth: 3.0
  gbr_max_features: sqrt
  gbr_n_estimators: 932.0

Scaler: norm3_scaler
Scaler Parameters:
RMSEP: 3.719811372910595


Model: svr
Model Parameters:
  svr_C: 0.0092802848242038
  svr_kernel: poly
  svr_degree: 3.0
  svr_gamma: scale
  svr_coef0: 8.63601100525176
  svr_epsilon: 0.0028037787477313
  svr_max_iter: 20000000.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: True
RMSEP: 0.3970526971649357
Std.Dev: 0.3852958402975567
RMSE CV: 0.4087012424132404
STD Dev CV: 0.406384652663879


Model: gbr
Model Parameters:
  gbr_learning_rate: 0.0285922209309325
  gbr_subsample: 0.5585632955924456
  gbr_max_depth: 5.0
  gbr_max_features: sqrt
  gbr_n_estimators: 898.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: True
RMSEP: 0.3324705144508141
Std.Dev: 0.3238442570736441
RMSE CV: 0.4097301363390033
STD Dev CV: 0.408748159212592


Model: xgboost
Model Parameters:
  xgboost_learning_rate: 0.2118592173320599
  xgboost_reg_lambda: 73.91710980996

Model: xgboost
Model Parameters:
  xgboost_learning_rate: 0.0264099388873174
  xgboost_reg_lambda: 0.0034397452764153
  xgboost_colsample_bytree: 0.5558889734016561
  xgboost_reg_alpha: 0.5011206377632488
  xgboost_n_estimators: 761.0
  xgboost_max_depth: 5.0
  xgboost_subsample: 0.7374993842567144
  xgboost_gamma: 0.157189926185756

Scaler: norm3_scaler
Scaler Parameters:

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: True
RMSEP: 1.740205182028688
Std.Dev: 1.7402550985875918
RMSE CV: 2.075363296387038
STD Dev CV: 2.067162586836948


Model: gbr
Model Parameters:
  gbr_learning_rate: 0.0085500994881957
  gbr_subsample: 0.6017741121087732
  gbr_max_depth: 4.0
  gbr_max_features: sqrt
  gbr_n_estimators: 821.0

Scaler: robust_scaler
Scaler Parameters:
  robust_scaler_with_centering: False
  robust_scaler_quantile_range: (35.0, 65.0)

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: False
RMSEP: 1.9865

Model: svr
Model Parameters:
  svr_C: 16.477142954470164
  svr_kernel: rbf
  svr_degree: 5.0
  svr_gamma: scale
  svr_coef0: 6.252397271092422
  svr_epsilon: 0.0149272838875138
  svr_max_iter: 20000000.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: quantile_transformer
Transformer Parameters:
  subsample: 27549.0
  n_quantiles: 665.0
  random_state: 42.0
  output_distribution: uniform
RMSEP: 1.8031705312361492
Std.Dev: 1.7779132683335346
RMSE CV: 2.2423872222329617
STD Dev CV: 2.2433027369935683


Model: pls
Model Parameters:
  pls_n_components: 30.0

Scaler: standard_scaler
Scaler Parameters:
  standard_scaler_with_std: True
  standard_scaler_with_mean: True

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: False
RMSEP: 2.0629778885569223
Std.Dev: 2.059337627335091
RMSE CV: 2.70124691165507
STD Dev CV: 2.6691567198201827


Model: ridge
Model Parameters:
  ridge_alpha: 55.16248016653623

Scaler: norm3_scaler
Scaler Parameters:

Trans

Model: svr
Model Parameters:
  svr_C: 0.0892694115469055
  svr_kernel: poly
  svr_degree: 3.0
  svr_gamma: auto
  svr_coef0: 9.35850494507051
  svr_epsilon: 0.0124423713270124
  svr_max_iter: 20000000.0

Scaler: robust_scaler
Scaler Parameters:
  robust_scaler_with_centering: False
  robust_scaler_quantile_range: (10.0, 90.0)

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: True
RMSEP: 0.7911495308140506
Std.Dev: 0.7919051923377272
RMSE CV: 1.3215169628193475
STD Dev CV: 1.320551279086084


Model: pls
Model Parameters:
  pls_n_components: 3.0

Scaler: norm3_scaler
Scaler Parameters:

PCA: kernel_pca
PCA Parameters:
  n_components: 80.0
  gamma: 0.0224292044706105
  kernel: rbf
  degree: 2.0
RMSEP: 0.993173264098897
Std.Dev: 0.994444846436946
RMSE CV: 1.327285053654815
STD Dev CV: 1.3211689975417684


Model: ridge
Model Parameters:
  ridge_alpha: 49.00255709869574

Scaler: robust_scaler
Scaler Parameters:
  robust_scaler_with_centering: True
 

Model: svr
Model Parameters:
  svr_C: 0.0834973126025444
  svr_kernel: linear
  svr_degree: 1.0
  svr_gamma: auto
  svr_coef0: 3.035393902393877
  svr_epsilon: 0.0858748187215623
  svr_max_iter: 20000000.0

Scaler: min_max_scaler
Scaler Parameters:
  min_max_scaler_feature_range: (0, 1)

Transformer: quantile_transformer
Transformer Parameters:
  subsample: 80879.0
  n_quantiles: 692.0
  random_state: 42.0
  output_distribution: uniform
RMSEP: 1.5997081619249127
Std.Dev: 1.5900013197250518
RMSE CV: 1.1933122331886648
STD Dev CV: 1.1916256780194558


Model: pls
Model Parameters:
  pls_n_components: 22.0

Scaler: max_abs_scaler
Scaler Parameters:

Transformer: quantile_transformer
Transformer Parameters:
  subsample: 64610.0
  n_quantiles: 714.0
  random_state: 42.0
  output_distribution: uniform
RMSEP: 1.7677424075223047
Std.Dev: 1.7673108101209158
RMSE CV: 1.2695386865040792
STD Dev CV: 1.2629967130447104


Model: gbr
Model Parameters:
  gbr_learning_rate: 0.0170321215492168
  gbr_subs

Model: svr
Model Parameters:
  svr_C: 0.0075575590330982
  svr_kernel: poly
  svr_degree: 4.0
  svr_gamma: scale
  svr_coef0: 7.578932896276198
  svr_epsilon: 0.0026366262560171
  svr_max_iter: 20000000.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: True
RMSEP: 0.39298616902088
Std.Dev: 0.3886397253571709
RMSE CV: 0.7771225118516447
STD Dev CV: 0.7754565846890182


Model: pls
Model Parameters:
  pls_n_components: 30.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: power_transformer
Transformer Parameters:
  method: yeo-johnson
  standardize: True
RMSEP: 0.5607339537742186
Std.Dev: 0.5592673086762922
RMSE CV: 0.8446357265634863
STD Dev CV: 0.8423007488974287


Model: gbr
Model Parameters:
  gbr_learning_rate: 0.0110063295655558
  gbr_subsample: 0.9527088410971416
  gbr_max_depth: 5.0
  gbr_max_features: sqrt
  gbr_n_estimators: 957.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: qua

Model: pls
Model Parameters:
  pls_n_components: 30.0

Scaler: norm3_scaler
Scaler Parameters:
RMSEP: 0.7238371861539327
Std.Dev: 0.7172772606716364
RMSE CV: 0.5869051882900327
STD Dev CV: 0.5856683301775469


Model: gbr
Model Parameters:
  gbr_learning_rate: 0.0358163597255004
  gbr_subsample: 0.6765783752228338
  gbr_max_depth: 4.0
  gbr_max_features: sqrt
  gbr_n_estimators: 712.0

Scaler: min_max_scaler
Scaler Parameters:
  min_max_scaler_feature_range: (-1, 1)

Transformer: quantile_transformer
Transformer Parameters:
  subsample: 51417.0
  n_quantiles: 139.0
  random_state: 42.0
  output_distribution: uniform
RMSEP: 0.4226257050202812
Std.Dev: 0.4131298711863923
RMSE CV: 0.5902502338938991
STD Dev CV: 0.587452616172343


Model: svr
Model Parameters:
  svr_C: 446.9244924838029
  svr_kernel: rbf
  svr_degree: 1.0
  svr_gamma: auto
  svr_coef0: 1.5472933504186566
  svr_epsilon: 0.0094011000249323
  svr_max_iter: 20000000.0

Scaler: norm3_scaler
Scaler Parameters:

Transformer: quant

In [41]:
# Summary statistics of CV RMSE for the oxide under analysis.
filtered_runs.loc[filtered_runs['params.oxide'] == analysis_target, "metrics.rmse_cv"].describe()

count    1926.000000
mean        7.964783
std         4.262402
min         4.552021
25%         5.320587
50%         6.033474
75%         8.831299
max        46.779825
Name: metrics.rmse_cv, dtype: float64

In [42]:
# Columns that together identify one pipeline configuration for one oxide.
config_keys = ['params.model_type', 'params.transformer_type', 'params.pca_type', 'params.scaler_type', 'params.oxide']

# Best (lowest RMSE CV) run for each configuration/oxide combination.
# Sort once and keep the first row of every group instead of using groupby().apply().
# In pandas 2.2, apply over a frame that still contains the grouping columns is deprecated.
# A future release excludes those columns, after which reset_index(drop=True) would
# discard the keys and the pivot below would raise a KeyError.
grouped_runs = (
    filtered_runs
    .sort_values(by='metrics.rmse_cv')
    .groupby(config_keys)
    .head(1)
    .reset_index(drop=True)
)

# One row per pipeline configuration, one column per oxide, each cell = best RMSE CV
# (NaN when no retained run exists for that combination).
pivot_table = grouped_runs.pivot_table(
    index=['params.model_type', 'params.transformer_type', 'params.scaler_type', 'params.pca_type'],
    columns='params.oxide',
    values='metrics.rmse_cv',
    aggfunc='first'
)

display(pivot_table)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,params.oxide,Al2O3,CaO,FeOT,K2O,MgO,Na2O,SiO2,TiO2
params.model_type,params.transformer_type,params.scaler_type,params.pca_type,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
elasticnet,none,max_abs_scaler,kernel_pca,5.316901,3.080646,3.715408,1.243959,2.863895,1.688812,7.597098,0.494604
elasticnet,none,max_abs_scaler,none,3.690606,,3.193771,0.867601,1.909102,2.063873,5.262049,0.454056
elasticnet,none,max_abs_scaler,pca,5.929297,5.271999,3.135924,1.784689,5.261134,2.063225,11.566500,0.471508
elasticnet,none,min_max_scaler,kernel_pca,4.752615,3.153506,4.565382,1.457239,3.895911,1.775487,8.709502,0.582199
elasticnet,none,min_max_scaler,none,,1.953705,3.283608,,1.772372,1.377255,5.386139,0.440074
...,...,...,...,...,...,...,...,...,...,...,...
xgboost,quantile_transformer,robust_scaler,none,2.155769,1.492869,3.918840,0.665765,2.100242,1.007344,5.230273,0.415299
xgboost,quantile_transformer,robust_scaler,pca,3.019718,2.156729,3.433276,1.353425,2.327160,1.200463,6.380162,0.476932
xgboost,quantile_transformer,standard_scaler,kernel_pca,3.094951,5.324183,3.516543,0.823987,2.140626,1.093817,6.409288,0.528229
xgboost,quantile_transformer,standard_scaler,none,2.302543,1.565003,2.824020,0.638709,1.962214,1.028113,5.437699,0.447938


# Univariate

In [None]:
# From here on, analyse only `analysis_target`.
# NOTE: this rebinds `filtered_runs`, so the multi-oxide cells above give different
# results if they are re-run after this cell.
is_target_oxide = filtered_runs['params.oxide'] == analysis_target
filtered_runs = filtered_runs.loc[is_target_oxide]
filtered_runs.shape[0]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Total number of loaded runs (all oxides, before cleaning), used for context in the title.
total_runs = len(runs)

sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=filtered_runs, x='params.model_type', y='metrics.rmse', ax=ax)
ax.set_title(f"{analysis_target}: RMSE for each model type - {len(filtered_runs)} runs out of {total_runs} total runs")
ax.set_xlabel("Model Type")
ax.set_ylabel("RMSEP")
plt.show()



In [None]:
# For each metric, idxmin gives the index label of the run with the lowest value;
# the same run may appear more than once if it wins several metrics.
target_metrics = ['metrics.rmse', 'metrics.rmse_cv', 'metrics.std_dev', 'metrics.std_dev_cv']
best_labels = filtered_runs[target_metrics].idxmin()
optimal_runs = filtered_runs.loc[best_labels]

optimal_runs[target_metrics + ['params.model_type']]


In [None]:
# Setting up visualization style
sns.set(style="whitegrid")

# Plotting RMSE CV
plt.figure(figsize=(12, 7))
sns.boxplot(x='params.model_type', y='metrics.rmse_cv', data=filtered_runs)
plt.title(f'{analysis_target}: Average Cross-Validation RMSE by Model Type')
plt.ylabel('Average RMSE (Cross-Validation)')
plt.show()

# Plotting Standard Deviation of RMSE CV
plt.figure(figsize=(12, 7))
sns.boxplot(x='params.model_type', y='metrics.std_dev_cv', data=filtered_runs)
plt.title(f'{analysis_target}: Standard Deviation of Errors (Cross-Validation) by Model Type')
plt.ylabel('Standard Deviation of Errors (CV)')
plt.show()


In [None]:
# Prepare a melted DataFrame for seaborn plotting
melted_df = filtered_runs.melt(id_vars=['params.model_type'], value_vars=[f'metrics.rmse_cv_{i+1}' for i in range(n_splits)],
                               var_name='CV Fold', value_name='Fold RMSE')

# Plotting without outliers
plt.figure(figsize=(14, 8))
sns.boxplot(x='params.model_type', y='Fold RMSE', hue='CV Fold', data=melted_df, showfliers=False)
plt.title(f'{analysis_target}: Distribution of RMSE Across CV Folds by Model Type')
plt.show()


In [None]:
cv_columns = [
    'metrics.rmse_cv', 'params.model_type', 'params.scaler_type',
    'params.transformer_type', 'params.pca_type'
]
# Restrict to the oxide under analysis. The charts built from this frame are titled with
# `analysis_target`, but raw `runs` holds every oxide (whose RMSE scales differ by about an
# order of magnitude) and has not been through clean_experiment_data.
target_runs = filtered_runs[filtered_runs['params.oxide'] == analysis_target]
filtered_runs_new = target_runs[cv_columns]
filtered_runs_new = filtered_runs_new[filtered_runs_new['metrics.rmse_cv'] <= 50]


# Rename columns for clarity: 'metrics.rmse_cv' -> 'rmse_cv', 'params.model_type' -> 'model_type', ...
rename_dict = {col: col.split('.')[-1] for col in cv_columns}
filtered_runs_new = filtered_runs_new.rename(columns=rename_dict)

In [None]:
sns.set(style="whitegrid")

# Pipeline components that define one configuration.
config_params = ['model_type', 'scaler_type', 'transformer_type', 'pca_type']

# Individual parameters: mean RMSE (CV) for each value of each component.
for parameter in config_params:
    fig, ax = plt.subplots(figsize=(10, 6))
    chart = sns.barplot(x=parameter, y='rmse_cv', data=filtered_runs_new, ax=ax)
    # Rotate the labels seaborn already placed, instead of rebuilding ticks from unique(),
    # whose count can disagree with the drawn bars (e.g. when NaN is present).
    plt.setp(chart.get_xticklabels(), rotation=45, horizontalalignment='right')
    chart.set_title(f'{analysis_target}: Average RMSE (CV) by {parameter.capitalize()}')
    chart.set_ylabel('Average RMSE (CV)')
    plt.show()

# Combinations of parameters: mean RMSE (CV) per full configuration, best first.
# Select rmse_cv before aggregating so other (possibly non-numeric) columns are never averaged.
combination_data = (
    filtered_runs_new
    .groupby(config_params)['rmse_cv']
    .mean()
    .reset_index()
    .sort_values(by='rmse_cv', ascending=True)
)

# Top 10 combinations
display(combination_data.head(10))

combination_data_top10 = combination_data.head(10)
# Readable label per configuration, omitting components set to 'none'.
combination_labels = combination_data_top10.apply(
    lambda row: ', '.join(str(row[param]) for param in config_params if row[param] != 'none'),
    axis=1,
)
fig, ax = plt.subplots(figsize=(14, 8))
sns.barplot(x='rmse_cv', y=combination_labels, data=combination_data_top10, orient='h', ax=ax)
ax.set_title(f'{analysis_target}: Top 10 Combinations for RMSE Performance')
ax.set_xlabel('Average RMSE (Cross-Validation)')
ax.set_ylabel('Combinations')
plt.show()

In [None]:
# Aggregate the data to compute mean and standard deviation of RMSE for each configuration
# Lower RMSE (lower is better) and lower STD RMSE (lower is better for consistency)
aggregated_data = filtered_runs_new.groupby(['model_type', 'scaler_type', 'transformer_type', 'pca_type']).agg({
    'rmse_cv': ['mean', 'std']
}).reset_index()

# Flatten the columns (multi-level index after aggregation)
aggregated_data.columns = ['Model Type', 'Scaler Type', 'Transformer Type', 'PCA Type', 'Mean RMSECV', 'STD RMSECV']

# Sort configurations first by mean RMSE (ascending, lower is better) and then by STD RMSE (ascending, lower is better for consistency)
sorted_data = aggregated_data.sort_values(by=['Mean RMSECV', 'STD RMSECV'], ascending=[True, True])

# Display the top 10 consistently good configurations
print(sorted_data.head(10))

In [None]:
sns.set(style="whitegrid")

# Plot the top configurations (by Mean RMSECV) broken down by each component.
top_n = 50
top_rows = sorted_data.head(top_n)
for parameter in ['Model Type', 'Scaler Type', 'Transformer Type', 'PCA Type']:
    # Create a new figure for every chart. With the inline backend, plt.show() closes the
    # current figure, so a single figure created before the loop would only size the
    # first chart, and later ones would fall back to the default size.
    fig, ax = plt.subplots(figsize=(12, 8))
    top_configurations = sns.barplot(x='Mean RMSECV', y=parameter, hue=parameter, data=top_rows, dodge=False, ax=ax)
    ax.set_title(f'{analysis_target}: Top {top_n} Configurations by Mean RMSECV and Their Consistency')
    ax.set_xlabel('Mean RMSECV')
    ax.set_ylabel(parameter)
    # Annotate each bar with its Mean RMSECV value. Skip NaN-width placeholder patches
    # (hue/category combinations with no data), which would otherwise be labelled "nan".
    for patch in top_configurations.patches:
        width = patch.get_width()
        if pd.isna(width):
            continue
        # The 0.2 vertical offset is kept from the original layout; its purpose is not documented (TODO confirm).
        ax.text(width + 0.01, patch.get_y() + 0.2 + patch.get_height() / 2, f'{width:.2f}', ha='left', va='center')
    plt.show()


In [None]:
# Lowest-RMSE-CV run in filtered_runs, showing only its non-missing fields.
first_row = filtered_runs.sort_values(by="metrics.rmse_cv").iloc[0]
non_none_columns = first_row.index[first_row.notna().to_numpy()].tolist()
first_row.loc[non_none_columns]
