In [1]:
import pandas as pd 
import numpy as np 
from pathlib import Path
from pycaret.regression import *

In [2]:
data = np.load(Path().resolve().parents[1] / f"data/9chrome_data.npy", allow_pickle=True)[()]
df = pd.DataFrame(data['X'].astype('float64'), columns=data['features'])
df['CT_RT'] = data['y'].astype('float64')
del data

In [4]:
df.to_csv(Path().resolve().parents[1] / "data/chrome.csv")

In [3]:
exp = setup(data = df, 
            target = 'CT_RT', 
            session_id=123,
            normalize = True, 
            transformation = True, 
            #transform_target = True, 
            combine_rare_levels = True, 
            rare_level_threshold = 0.05,
            remove_multicollinearity = True, 
            multicollinearity_threshold = 0.95, 
            #train_size=0.8,
            log_experiment = True, 
            fold=5,
            experiment_name = '9chrome')

Unnamed: 0,Description,Value
0,session_id,123
1,Target,CT_RT
2,Original Data,"(836, 28)"
3,Missing Values,False
4,Numeric Features,26
5,Categorical Features,1
6,Ordinal Features,False
7,High Cardinality Features,False
8,High Cardinality Method,
9,Transformed Train Set,"(585, 26)"


In [4]:
compare_models()


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE,TT (Sec)
et,Extra Trees Regressor,4397.0502,207634566.9152,13859.5225,0.7667,0.417,0.4357,0.088
catboost,CatBoost Regressor,4265.7744,227242286.5036,14274.1206,0.7474,0.7689,1.4435,1.23
xgboost,Extreme Gradient Boosting,4737.4378,235264348.8,14958.3264,0.7261,0.5444,0.5488,0.408
lightgbm,Light Gradient Boosting Machine,5353.4755,264628294.3765,15475.5796,0.7038,0.9775,2.7952,0.028
gbr,Gradient Boosting Regressor,5351.8081,288858830.2392,16247.6144,0.6774,0.9435,1.6707,0.034
rf,Random Forest Regressor,5384.1689,291775139.0844,16185.4009,0.6733,0.3498,0.2845,0.096
ada,AdaBoost Regressor,8900.1147,305683952.1038,16539.9283,0.6666,1.8515,24.7884,0.04
br,Bayesian Ridge,11633.3772,394721908.3691,19396.8073,0.5436,2.0224,77.123,0.3
llar,Lasso Least Angle Regression,11918.0274,403550588.7743,19654.8498,0.5311,2.0395,78.2941,0.466
ridge,Ridge Regression,11910.1145,403930441.6,19666.0893,0.5306,2.041,77.937,0.31


ExtraTreesRegressor(bootstrap=False, ccp_alpha=0.0, criterion='mse',
                    max_depth=None, max_features='auto', max_leaf_nodes=None,
                    max_samples=None, min_impurity_decrease=0.0,
                    min_impurity_split=None, min_samples_leaf=1,
                    min_samples_split=2, min_weight_fraction_leaf=0.0,
                    n_estimators=100, n_jobs=-1, oob_score=False,
                    random_state=123, verbose=0, warm_start=False)

In [5]:
ct = create_model('catboost', fold = 5)


Unnamed: 0,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,3233.988,41069900.73,6408.5802,0.9256,0.7684,2.0107
1,3997.6386,344714331.4893,18566.4841,0.6419,0.6549,0.7485
2,5238.3447,400975862.7067,20024.3817,0.6497,0.6082,0.7035
3,4944.6232,199269066.7767,14116.2696,0.7081,0.8097,1.3641
4,3914.2772,150182270.8153,12254.8876,0.8117,1.0031,2.3905
Mean,4265.7744,227242286.5036,14274.1206,0.7474,0.7689,1.4435
SD,730.3094,130662880.0523,4846.8305,0.1078,0.1381,0.6716


In [6]:
tuned_ct = tune_model(ct)


Unnamed: 0,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,5960.7794,87392533.2405,9348.3974,0.8417,1.5553,17.4692
1,6158.0917,259915005.8571,16121.8797,0.73,1.389,6.6837
2,7531.3639,467304626.3603,21617.2298,0.5918,1.0458,2.2162
3,6868.6003,247338812.1297,15727.009,0.6377,1.4843,14.2838
4,7572.9283,305795639.7008,17487.0135,0.6166,1.752,19.0572
Mean,6818.3527,273549323.4577,16060.3059,0.6835,1.4453,11.942
SD,671.0694,121744140.211,3951.6957,0.0918,0.2326,6.4638


In [7]:
predict_model(tuned_ct)

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,CatBoost Regressor,6024.1202,147857724.7944,12159.6762,0.7826,1.3215,5.7368


Unnamed: 0,Fe,C,Cr,Mn,Si,Ni,Co,Mo,W,Nb,...,Normal,Temper1,AGS No.,CT_Temp,CT_EL,CT_RA,log_CT_CS,log_CT_MCR,CT_RT,Label
0,-1.029508,0.411686,0.560337,1.015361,-1.445352,0.652960,-0.403757,-0.797836,1.273721,-0.022017,...,0.088150,0.177758,-0.581429,1.163347,-0.734136,-0.854391,-1.026196,0.244343,4107.500000,2711.717033
1,1.053935,-0.504832,-0.985890,-1.529638,1.073945,-0.853185,-0.403757,0.943752,-1.011419,1.379526,...,0.088150,1.469816,1.398017,-0.577624,-0.209444,0.746437,0.521789,-0.394251,15761.799805,14075.025808
2,1.053935,-0.504832,-0.985890,-1.529638,1.073945,-0.853185,-0.403757,0.943752,-1.011419,1.379526,...,0.088150,1.469816,1.398017,0.098771,-0.006467,0.005299,-0.441432,-0.568245,18502.400391,22726.490625
3,-1.029508,0.411686,0.560337,1.015361,-1.445352,0.652960,-0.403757,-0.797836,1.273721,-0.022017,...,0.088150,0.177758,-0.581429,2.267060,1.750803,0.561645,-1.956473,0.885817,1043.199951,1028.788807
4,-0.906366,0.705470,0.512300,0.816244,0.201748,0.652960,-0.403757,-0.936363,1.174950,-0.274663,...,0.088150,0.177758,-0.897795,0.098771,-1.278540,-1.532960,-0.317396,-1.227077,33459.898438,42910.579490
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
246,1.044871,-0.504832,-0.867924,-0.638874,-1.445352,-1.607561,-0.403757,0.910537,-1.011419,1.103408,...,0.503261,-0.272719,0.536270,0.098771,0.481186,1.151637,0.131664,0.961350,895.299988,4825.661297
247,1.227708,0.012406,-1.149014,-0.879179,-1.157421,-1.407515,-0.403757,0.811343,-1.011419,1.379526,...,-0.100465,0.753496,1.331409,0.098771,1.129440,1.151637,0.131664,1.147347,529.000000,4825.661297
248,0.119547,-0.623167,-0.263660,-0.879179,-0.179302,-0.767557,-0.403757,-0.853441,1.044797,0.270344,...,2.582183,0.753496,-0.202655,0.447959,1.129440,0.652575,0.131664,1.077242,490.500000,4825.661297
249,-1.368662,0.012406,1.309595,0.746492,-0.633148,0.523780,-0.403757,-0.908785,1.125134,0.168521,...,0.088150,1.469816,0.480781,0.803178,0.191770,0.306221,0.025315,1.711943,84.599998,1028.788807
