In [33]:
import pandas as pd
from pycaret.regression import setup, RegressionExperiment, compare_models, predict_model
from sklearn.decomposition import PCA
import numpy as np

In [5]:
def prepare_data(df):
    expanded_df = pd.DataFrame(df['flux'].tolist(), index=df.index)
    df = df.drop(columns=['source_id','spectraltype_esphs','Cat','flux'])
    df = pd.concat([df, expanded_df], axis=1)
    return df

In [6]:
train_data = pd.read_parquet('../../../data/Gaia DR3/train.parquet')
test_data = pd.read_parquet('../../../data/Gaia DR3/test.parquet')

train_df = prepare_data(train_data)
test_df = prepare_data(test_data)

pca = PCA(n_components=1)

In [30]:
x_train = train_df.iloc[:,3:]
y_train = np.vstack(train_df[['teff_gspphot','logg_gspphot','mh_gspphot']].values)

x_test = test_df.iloc[:,3:]
y_test = np.vstack(test_df[['teff_gspphot','logg_gspphot','mh_gspphot']].values)

# standardize the labels
mean_train = np.mean(y_train, axis=0)
std_train = np.std(y_train, axis=0)

mean_test = np.mean(y_test, axis=0)
std_test = np.std(y_test, axis=0)

y_train = (y_train - mean_train) / std_train
y_test = (y_test - mean_test) / std_test

# Transform into a single value
y_train = pca.fit_transform(y_train)
y_test = pca.fit_transform(y_test)


# Define the number of output labels
num_labels = y_train.shape[1]
num_fluxes = x_train.shape[1]

print('Each spectrum contains ' + str(num_fluxes) + ' wavelength bins')
print('Training set includes ' + str(x_train.shape[0]) + ' spectra.')
print('Test set includes ' + str(x_test.shape[0]) + ' spectra.')

Each spectrum contains 343 wavelength bins
Training set includes 14101 spectra.
Test set includes 3526 spectra.


In [31]:
df_train = pd.concat([pd.DataFrame(x_train), pd.DataFrame(y_train, columns=['target'])], axis=1)
df_test = pd.concat([pd.DataFrame(x_test), pd.DataFrame(y_test, columns=['target'])], axis=1)

In [25]:
s = setup(data=df_train, target='target', session_id=123, n_jobs=1)

Unnamed: 0,Description,Value
0,Session id,123
1,Target,target
2,Target type,Regression
3,Original data shape,"(14101, 344)"
4,Transformed data shape,"(14101, 344)"
5,Transformed train set shape,"(9870, 344)"
6,Transformed test set shape,"(4231, 344)"
7,Numeric features,343
8,Preprocess,True
9,Imputation type,simple


In [26]:
exp = RegressionExperiment()
exp.setup(data=df_train, target='target', session_id=123)

Unnamed: 0,Description,Value
0,Session id,123
1,Target,target
2,Target type,Regression
3,Original data shape,"(14101, 344)"
4,Transformed data shape,"(14101, 344)"
5,Transformed train set shape,"(9870, 344)"
6,Transformed test set shape,"(4231, 344)"
7,Numeric features,343
8,Preprocess,True
9,Imputation type,simple


<pycaret.regression.oop.RegressionExperiment at 0x2531483e200>

In [27]:
best = compare_models()

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE,TT (Sec)
lightgbm,Light Gradient Boosting Machine,0.552,0.8096,0.8988,0.3525,0.3515,2.6454,4.971
xgboost,Extreme Gradient Boosting,0.5676,0.8725,0.9334,0.3018,0.3412,3.6637,9.493
knn,K Neighbors Regressor,0.5849,0.9374,0.9677,0.2495,0.3398,3.8082,0.271
br,Bayesian Ridge,0.8074,1.2516,1.1184,-0.0014,0.6256,1.0522,0.519
rf,Random Forest Regressor,0.8071,1.2516,1.1184,-0.0014,0.6236,1.0752,18.847
dt,Decision Tree Regressor,0.8074,1.2516,1.1184,-0.0014,0.6256,1.0522,0.485
gbr,Gradient Boosting Regressor,0.8074,1.2516,1.1184,-0.0014,0.6256,1.0522,30.746
et,Extra Trees Regressor,0.8074,1.2516,1.1184,-0.0014,0.6256,1.0522,0.909
lasso,Lasso Regression,0.8074,1.2516,1.1184,-0.0014,0.6256,1.0522,0.195
omp,Orthogonal Matching Pursuit,0.8074,1.2516,1.1184,-0.0014,0.6256,1.0522,0.198


## Predict

In [34]:
predict_model(best, data=df_test)

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Light Gradient Boosting Machine,0.5489,0.8018,0.8955,0.358,0.3468,2.1137


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,335,336,337,338,339,340,341,342,target,prediction_label
0,2.140608e-17,1.365276e-17,8.998306e-18,9.277905e-18,1.133764e-17,1.236834e-17,1.210512e-17,1.104840e-17,1.005849e-17,1.006457e-17,...,1.875544e-17,1.888428e-17,1.936134e-17,2.002518e-17,2.037434e-17,2.127749e-17,2.145453e-17,2.210674e-17,-0.211531,0.043237
1,5.743483e-16,5.637122e-16,5.316359e-16,5.080654e-16,5.123901e-16,5.203310e-16,5.116850e-16,4.899089e-16,4.942565e-16,5.246064e-16,...,1.788671e-16,1.732059e-16,1.707564e-16,1.708296e-16,1.701658e-16,1.769846e-16,1.812655e-16,1.934178e-16,-0.277604,-0.362509
2,9.540365e-16,9.370028e-16,8.942922e-16,8.675982e-16,8.753908e-16,8.860598e-16,8.863876e-16,8.799443e-16,9.082285e-16,9.587669e-16,...,8.802934e-16,8.855333e-16,9.017898e-16,9.214689e-16,9.221377e-16,9.441346e-16,9.314883e-16,9.383808e-16,-1.705701,-1.786159
3,4.322730e-17,3.599028e-17,3.116138e-17,3.147355e-17,3.380242e-17,3.458265e-17,3.315480e-17,3.060560e-17,3.091396e-17,3.615521e-17,...,1.779722e-17,1.638334e-17,1.512725e-17,1.408430e-17,1.316753e-17,1.320554e-17,1.359484e-17,1.523197e-17,0.362757,0.293711
4,1.134171e-15,1.104958e-15,1.022256e-15,9.641852e-16,9.621754e-16,9.662636e-16,9.512777e-16,9.316988e-16,9.557092e-16,9.853166e-16,...,1.517024e-16,1.455515e-16,1.424176e-16,1.418621e-16,1.413263e-16,1.477437e-16,1.527918e-16,1.651610e-16,-0.864140,-0.900459
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3521,2.481654e-17,2.224579e-17,2.078612e-17,1.823804e-17,1.631545e-17,1.758840e-17,2.190646e-17,2.471268e-17,2.304838e-17,1.938278e-17,...,5.743371e-17,5.699121e-17,5.724090e-17,5.781109e-17,5.744140e-17,5.878628e-17,5.845928e-17,5.990392e-17,-0.270653,0.652693
3522,2.100507e-14,2.111412e-14,1.965781e-14,1.800614e-14,1.759250e-14,1.766060e-14,1.680955e-14,1.511526e-14,1.526317e-14,1.799432e-14,...,3.149228e-15,3.032727e-15,2.977191e-15,2.971463e-15,2.959201e-15,3.082453e-15,3.164294e-15,3.382208e-15,-0.988735,-0.673666
3523,4.950771e-17,4.735013e-17,4.774364e-17,4.771116e-17,4.583606e-17,4.318559e-17,4.302986e-17,4.524361e-17,4.818025e-17,5.006902e-17,...,4.230342e-17,4.132579e-17,4.100531e-17,4.113372e-17,4.087202e-17,4.214290e-17,4.251873e-17,4.444370e-17,-0.240389,-0.049454
3524,4.047548e-14,4.102307e-14,3.667214e-14,3.226985e-14,3.214448e-14,3.417008e-14,3.371515e-14,3.013685e-14,2.969429e-14,3.469891e-14,...,3.287420e-15,3.079038e-15,2.927607e-15,2.829557e-15,2.743425e-15,2.813259e-15,2.885701e-15,3.129216e-15,-1.497687,-0.809178
