In [2]:
import pandas as pd
from pycaret.regression import setup, RegressionExperiment, compare_models, predict_model
import numpy as np

In [1]:
def prepare_data(df):
    expanded_df = pd.DataFrame(df['flux'].tolist(), index=df.index)
    df = df.drop(columns=['source_id','spectraltype_esphs','flux'])
    df = pd.concat([df, expanded_df], axis=1)
    return df

In [5]:
train_data = pd.read_parquet('../../../data/Gaia DR3/LM_mass.parquet')

train_df = prepare_data(train_data)

In [6]:
X = train_df.drop(columns='mass_flame')
y_train = train_df['mass_flame']

num_fluxes = X.shape[1]

print('Each spectrum contains ' + str(num_fluxes) + ' wavelength bins')
print('Training set includes ' + str(X.shape[0]) + ' spectra.')

Each spectrum contains 343 wavelength bins
Training set includes 10000 spectra.


In [7]:
s = setup(data=train_df, target='mass_flame', session_id=123, n_jobs=1)

Unnamed: 0,Description,Value
0,Session id,123
1,Target,mass_flame
2,Target type,Regression
3,Original data shape,"(10000, 344)"
4,Transformed data shape,"(10000, 344)"
5,Transformed train set shape,"(7000, 344)"
6,Transformed test set shape,"(3000, 344)"
7,Numeric features,343
8,Preprocess,True
9,Imputation type,simple


In [10]:
exp = RegressionExperiment()
exp.setup(data=train_df, target='mass_flame', session_id=123)

Unnamed: 0,Description,Value
0,Session id,123
1,Target,mass_flame
2,Target type,Regression
3,Original data shape,"(10000, 344)"
4,Transformed data shape,"(10000, 344)"
5,Transformed train set shape,"(7000, 344)"
6,Transformed test set shape,"(3000, 344)"
7,Numeric features,343
8,Preprocess,True
9,Imputation type,simple


<pycaret.regression.oop.RegressionExperiment at 0x1572189d8d0>

In [11]:
best = compare_models()

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE,TT (Sec)
lightgbm,Light Gradient Boosting Machine,0.2687,0.1956,0.442,0.6722,0.1483,0.1843,4.43
xgboost,Extreme Gradient Boosting,0.2671,0.1967,0.4429,0.6711,0.1497,0.1821,9.922
knn,K Neighbors Regressor,0.2972,0.2602,0.5099,0.5629,0.1721,0.212,0.259
br,Bayesian Ridge,0.5524,0.5992,0.7734,-0.0012,0.261,0.3791,0.461
rf,Random Forest Regressor,0.5522,0.5993,0.7734,-0.0012,0.261,0.3788,12.841
dt,Decision Tree Regressor,0.5524,0.5992,0.7734,-0.0012,0.261,0.3791,0.355
gbr,Gradient Boosting Regressor,0.5524,0.5992,0.7734,-0.0012,0.261,0.3791,21.171
et,Extra Trees Regressor,0.5524,0.5992,0.7734,-0.0012,0.261,0.3791,0.69
lasso,Lasso Regression,0.5524,0.5992,0.7734,-0.0012,0.261,0.3791,0.166
omp,Orthogonal Matching Pursuit,0.5524,0.5992,0.7734,-0.0012,0.261,0.3791,0.159


## Predict

In [12]:
predict_model(best)

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Light Gradient Boosting Machine,0.261,0.1782,0.4222,0.6887,0.1428,0.1748


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,335,336,337,338,339,340,341,342,mass_flame,prediction_label
2656,3.261346e-17,3.208211e-17,3.191596e-17,2.864753e-17,2.501117e-17,2.517226e-17,2.959747e-17,3.243251e-17,2.946856e-17,2.415390e-17,...,2.563199e-17,2.376246e-17,2.215072e-17,2.081729e-17,1.955281e-17,1.950488e-17,1.971968e-17,2.147762e-17,0.942859,1.030746
445,4.664591e-17,3.998625e-17,3.829824e-17,3.990560e-17,4.054429e-17,3.882924e-17,3.822021e-17,4.066460e-17,4.467936e-17,4.578236e-17,...,4.501409e-17,4.444265e-17,4.426366e-17,4.420858e-17,4.337071e-17,4.383337e-17,4.314738e-17,4.396473e-17,1.116834,1.186844
9505,2.762769e-17,2.581370e-17,2.633392e-17,2.684024e-17,2.660487e-17,2.553572e-17,2.414175e-17,2.256132e-17,2.228626e-17,2.444628e-17,...,1.295339e-17,1.243369e-17,1.223368e-17,1.231704e-17,1.245083e-17,1.322531e-17,1.387203e-17,1.513953e-17,1.208841,1.128289
332,3.750044e-17,4.455583e-17,4.321733e-17,3.779516e-17,3.547266e-17,3.573379e-17,3.523281e-17,3.408556e-17,3.664467e-17,4.142156e-17,...,5.372923e-17,5.281144e-17,5.235093e-17,5.201430e-17,5.073027e-17,5.093726e-17,4.979145e-17,5.038919e-17,1.034442,1.076405
4168,5.639429e-17,5.652484e-17,5.287723e-17,4.953052e-17,4.829877e-17,4.716395e-17,4.633942e-17,4.722725e-17,5.065964e-17,5.303881e-17,...,3.880129e-17,3.672569e-17,3.531618e-17,3.455277e-17,3.393839e-17,3.526509e-17,3.663526e-17,4.018681e-17,1.120123,1.056340
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2708,1.997118e-16,1.820768e-16,1.739837e-16,1.773602e-16,1.838042e-16,1.814691e-16,1.728953e-16,1.674313e-16,1.767075e-16,1.948580e-16,...,1.251182e-16,1.204107e-16,1.177117e-16,1.166139e-16,1.150141e-16,1.185977e-16,1.207362e-16,1.284813e-16,1.409340,1.140305
8232,2.131098e-17,1.820812e-17,1.860339e-17,2.001373e-17,2.015218e-17,1.916435e-17,1.906422e-17,1.998830e-17,2.029067e-17,1.870770e-17,...,1.351004e-17,1.325842e-17,1.318060e-17,1.322025e-17,1.312368e-17,1.352947e-17,1.368159e-17,1.438905e-17,1.228488,1.137038
5835,4.966837e-17,5.391177e-17,4.145197e-17,3.274213e-17,3.495223e-17,3.715022e-17,3.425730e-17,3.478972e-17,4.318306e-17,4.509815e-17,...,2.590357e-16,2.558042e-16,2.560348e-16,2.582245e-16,2.569167e-16,2.641338e-16,2.647739e-16,2.744048e-16,2.501397,2.324480
6689,4.546188e-18,3.105398e-18,4.935675e-18,5.685854e-18,4.539743e-18,3.863436e-18,4.757187e-18,5.135200e-18,3.864532e-18,3.403119e-18,...,4.524077e-17,4.474668e-17,4.501498e-17,4.580153e-17,4.612697e-17,4.811413e-17,4.897031e-17,5.147380e-17,2.784038,2.404197
