## ML models

This notebook trains the following ML models:

1. Logistic Regressor
2. Decision Tree
3. Support-Vector Machine
4. K-Nearest Neighbours
5. Random Forests

as well as two boosting methods:

1. Extreme Gradient Boosting Machine
2. Light Gradient Boosting Machine

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import h5py
%matplotlib inline

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (RandomForestClassifier, VotingClassifier)
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from lightgbm import (LGBMClassifier as lgb, plot_importance)
from xgboost import XGBClassifier as xgb

from sklearn.utils import shuffle
from sklearn.metrics import (
    accuracy_score, classification_report, mean_squared_error, roc_auc_score, confusion_matrix)
from sklearn.model_selection import (cross_validate, RepeatedStratifiedKFold, cross_val_score, train_test_split)
import optuna

In [17]:
kfold = RepeatedStratifiedKFold(n_splits=10, n_repeats=2, random_state=42)

In [28]:
def get_data(name:str='', SHUFFLE_FLAG:bool=False, NORM_FLAG:bool=True, random_state:int=42):
    '''
    Function to select data

    Arguments
    ---------
    name: str, (required)
        name of dataset to be returned
    SHUFFLE_FLAG: bool, (optional)
        Flag for if the data should be shuffled
    NORM_FLAG: bool, (optional)
        If the data should be normalized
    random_state: int, (optional)
        random_state
    
    Returns
    -------
    X: numpy.ndarray 
        training set 
    y: numpy.ndarray 
        test set
    '''
    
    if name is None:
        raise ValueError("Required argument 'name' is missing.")
    
    if name == "gaia":
        dir = '../data/Gaia DR3/gaia_lm_m_stars.parquet'
        data = pd.read_parquet(dir)
        if SHUFFLE_FLAG:
            df = shuffle(data)
        else:
            df = data
        X = np.vstack(df['flux'])
        y = np.vstack(df['Cat'])
        
        y = np.where(y == 'M', 1, y)
        y = np.where(y == 'LM', 0, y)

        y = y.astype(int)

        if NORM_FLAG:
            norm = np.linalg.norm(X,keepdims=True)
            X = X/norm
            

    elif name == 'apogee':
        dir = '../data/APOGEE'
        train_dir = dir + '/training_data.h5'
        tets_dir = dir +'/test_data.h5'

        with h5py.File(train_dir, 'r') as f:
            X = f['spectrum'][:]
            y = np.hstack((f['TEFF'],
                        f['LOGG'],
                        f['FE_H']))
        
        #TODO: add shuffle

        if NORM_FLAG:
            norm_dir = dir + '/mean_and_std.npy'
            norm_data = np.load(norm_dir)
            
            mean = norm_data[0]
            std = norm_data[1]
            y = (y-mean)/std

    return X, y

In [29]:
X, y = get_data('gaia', SHUFFLE_FLAG=True)
#X, y = get_data('apogee')

num_samples = X.shape[0]
spectrum_width = X.shape[1]

num_samples_m = np.count_nonzero(y)
num_samples_lm = len(y) - num_samples_m
num_classes = len(np.unique(y))

print("Total number of spectra:", num_samples)
print("Number of bins in each spectra:", spectrum_width)
print("In the dataset, we have", num_samples_lm, "spectra for low mass stars and", num_samples_m, "spectra for high mass stars.")

Total number of spectra: 17627
Number of bins in each spectra: 343
In the dataset, we have 11026 spectra for low mass stars and 6601 spectra for high mass stars.


In [7]:
#naive splitting methods

split = 0.8

train_size = int(split * num_samples)

x_train, x_test = np.split(X, [train_size])
y_train, y_test = np.split(y, [train_size])

#x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2)


In [8]:
print("The dataset is divided into", len(x_train), "training samples and", len(x_test),"testing samples.")

The dataset is divided into 14101 training samples and 3526 testing samples.


## Logistic Regression

In [35]:
# training with no hyperparamater tuning

accuracy_scores = []
auc_roc_scores = []

for train_idx, test_idx in kfold.split(X, y):
    
    x_train, x_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    model = LogisticRegression()
    model.fit(x_train, y_train.squeeze(1))

    y_pred = model.predict(x_test)
    y_probs = model.predict_proba(x_test)[:,1]
    accuracy = accuracy_score(y_test.squeeze(1), y_pred)
    ra_score = roc_auc_score(y_test.squeeze(1), y_probs)

    auc_roc_scores.append(ra_score)
    accuracy_scores.append(accuracy)

print(np.mean(auc_roc_scores))
print(np.mean(accuracy_scores))

0.945127931496158
0.625517656095179


## Decision Trees

In [22]:
model = DecisionTreeClassifier()
model.fit(x_train, y_train.squeeze(1))

In [23]:
y_pred = model.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)

report = classification_report(y_test, y_pred)
print(report)

              precision    recall  f1-score   support

           0       0.93      0.95      0.94      2226
           1       0.91      0.88      0.90      1300

    accuracy                           0.93      3526
   macro avg       0.92      0.92      0.92      3526
weighted avg       0.93      0.93      0.93      3526



## Random Forest

In [None]:
def objective(trial):

    n_estimators = trial.suggest_int('n_estimators', 100, 1000)
    max_depth = trial.suggest_int('max_depth', 1 , 50)
    min_samples_split = trial.suggest_int('min_samples_split', 1, 32)
    min_samples_leaf = trial.suggest_int('min_samples_leaf', 1, 32)

    model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf)

    rkf = RepeatedStratifiedKFold(n_splits = 5)
    score = cross_val_score(model, x_train, y_train.squeeze(1), cv=rkf, scoring='accuracy')

    return score


study = optuna.create_study(direction='maximize', study_name='xgb_model_training')
study.optimize(objective, n_trials=100)

In [66]:
model = RandomForestClassifier()
model.fit(x_train, y_train.squeeze(1))

In [67]:
y_pred = model.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)

report = classification_report(y_test, y_pred)
print(report)

              precision    recall  f1-score   support

           0       0.97      0.96      0.97      2212
           1       0.94      0.95      0.94      1314

    accuracy                           0.96      3526
   macro avg       0.95      0.96      0.95      3526
weighted avg       0.96      0.96      0.96      3526



## Support-Vector Machines

In [64]:
model = SVC()
model.fit(x_train, y_train.squeeze(1))

In [65]:
y_pred = model.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)

report = classification_report(y_test, y_pred)
print(report)

              precision    recall  f1-score   support

           0       0.80      0.96      0.87      2212
           1       0.90      0.59      0.71      1314

    accuracy                           0.82      3526
   macro avg       0.85      0.78      0.79      3526
weighted avg       0.84      0.82      0.81      3526



## K-Nearest Neighbours

In [61]:
model = KNeighborsClassifier()
model.fit(x_train, y_train.squeeze(1))

In [62]:
y_pred = model.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)

report = classification_report(y_test, y_pred)
print(report)

              precision    recall  f1-score   support

           0       0.98      0.95      0.97      2212
           1       0.93      0.97      0.95      1314

    accuracy                           0.96      3526
   macro avg       0.95      0.96      0.96      3526
weighted avg       0.96      0.96      0.96      3526



## Light Gradient Boosting Machine

In [103]:
#-------------------initial naive implementation, needs a lot more tuning-------------------------------

model = lgb(n_estimators=1200, random_state=42, learning_rate = 0.01,reg_lambda=50, min_child_samples=2400, num_leaves=95, colsample_bytree=0.19, max_bins=65, device='gpu')
model.fit(x_train, y_train.squeeze(1))

[LightGBM] [Info] Number of positive: 5287, number of negative: 8814
[LightGBM] [Info] This is the GPU trainer!!
[LightGBM] [Info] Total Bins 22295
[LightGBM] [Info] Number of data points in the train set: 14101, number of used features: 343
[LightGBM] [Info] Using GPU Device: NVIDIA GeForce RTX 3050 Laptop GPU, Vendor: NVIDIA Corporation
[LightGBM] [Info] Compiling OpenCL Kernel with 256 bins...
[LightGBM] [Info] GPU programs have been built
[LightGBM] [Info] Size of histogram bin entry: 8
[LightGBM] [Info] 343 dense feature groups (4.63 MB) transferred to GPU in 0.010345 secs. 0 sparse feature groups
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.374938 -> initscore=-0.511090
[LightGBM] [Info] Start training from score -0.511090


In [104]:
y_pred = model.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred)

print(report)

              precision    recall  f1-score   support

           0       0.98      0.91      0.94      2212
           1       0.86      0.98      0.92      1314

    accuracy                           0.93      3526
   macro avg       0.92      0.94      0.93      3526
weighted avg       0.94      0.93      0.93      3526



In [100]:
#check overfitting
y_pred_train = model.predict(x_train)
acc_train = accuracy_score(y_train, y_pred_train)
print(accuracy, acc_train)

0.9293817356778219 0.9282320402808312


In [None]:
def objective(trial):
    return

## Extreme Gradient Boosting Machine

In [27]:
#-------------------initial naive implementation, needs a lot more tuning-------------------------------

accuracy_scores = []
auc_roc_scores = []
for train_idx, test_idx in kfold.split(X, y):
    
    x_train, x_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

    model = xgb(n_estimators=1000, learning_rate=0.05, early_stopping_rounds=5, device='gpu')
    model.fit(x_train, y_train, eval_set=[(x_test, y_test)])

    y_preds = model.predict(x_test)
    y_probs = model.predict_proba(x_test)[:, 1]
    accuracy = accuracy_score(y_test, y_preds)
    ra_score = roc_auc_score(y_test, y_probs)

    accuracy_scores.append(accuracy)
    auc_roc_scores.append(ra_score)

print("Mean accuracy:", np.mean(accuracy_scores))
print("Mean ROC-AUC scores:", np.mean(auc_roc_scores))

Parameters: { "verbose" } are not used.



[0]	validation_0-logloss:0.62383
[1]	validation_0-logloss:0.58983
[2]	validation_0-logloss:0.55926
[3]	validation_0-logloss:0.53148
[4]	validation_0-logloss:0.50602
[5]	validation_0-logloss:0.48258
[6]	validation_0-logloss:0.46087
[7]	validation_0-logloss:0.44120
[8]	validation_0-logloss:0.42302
[9]	validation_0-logloss:0.40613
[10]	validation_0-logloss:0.39053
[11]	validation_0-logloss:0.37612
[12]	validation_0-logloss:0.36189
[13]	validation_0-logloss:0.34937
[14]	validation_0-logloss:0.33709
[15]	validation_0-logloss:0.32616
[16]	validation_0-logloss:0.31531
[17]	validation_0-logloss:0.30585
[18]	validation_0-logloss:0.29647
[19]	validation_0-logloss:0.28787
[20]	validation_0-logloss:0.28003
[21]	validation_0-logloss:0.27235
[22]	validation_0-logloss:0.26553
[23]	validation_0-logloss:0.25866
[24]	validation_0-logloss:0.25250
[25]	validation_0-logloss:0.24642
[26]	validation_0-logloss:0.24085
[27]	validation_0-logloss:0.23558
[28]	validation_0-logloss:0.23067
[29]	validation_0-loglos

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44105
[8]	validation_0-logloss:0.42283
[9]	validation_0-logloss:0.40588
[10]	validation_0-logloss:0.39028
[11]	validation_0-logloss:0.37568
[12]	validation_0-logloss:0.36214
[13]	validation_0-logloss:0.34937
[14]	validation_0-logloss:0.33743
[15]	validation_0-logloss:0.32635
[16]	validation_0-logloss:0.31560
[17]	validation_0-logloss:0.30571
[18]	validation_0-logloss:0.29609
[19]	validation_0-logloss:0.28731
[20]	validation_0-logloss:0.27911
[21]	validation_0-logloss:0.27150
[22]	validation_0-logloss:0.26407
[23]	validation_0-logloss:0.25725
[24]	validation_0-logloss:0.25061
[25]	validation_0-logloss:0.24443
[26]	validation_0-logloss:0.23850
[27]	validation_0-logloss:0.23276
[28]	validation_0-logloss:0.22739
[29]	validation_0-logloss:0.22246
[30]	validation_0-logloss:0.21784
[31]	validation_0-logloss:0.21350
[32]	validation_0-logloss:0.20959
[33]	validation_0-logloss:0.20576
[34]	validation_0-logloss:0.20223
[35]	validation_0-logloss:0.19889
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44220
[8]	validation_0-logloss:0.42355
[9]	validation_0-logloss:0.40622
[10]	validation_0-logloss:0.39111
[11]	validation_0-logloss:0.37702
[12]	validation_0-logloss:0.36379
[13]	validation_0-logloss:0.35155
[14]	validation_0-logloss:0.34025
[15]	validation_0-logloss:0.32944
[16]	validation_0-logloss:0.31928
[17]	validation_0-logloss:0.30938
[18]	validation_0-logloss:0.29974
[19]	validation_0-logloss:0.29164
[20]	validation_0-logloss:0.28379
[21]	validation_0-logloss:0.27578
[22]	validation_0-logloss:0.26877
[23]	validation_0-logloss:0.26175
[24]	validation_0-logloss:0.25533
[25]	validation_0-logloss:0.24944
[26]	validation_0-logloss:0.24411
[27]	validation_0-logloss:0.23870
[28]	validation_0-logloss:0.23357
[29]	validation_0-logloss:0.22909
[30]	validation_0-logloss:0.22498
[31]	validation_0-logloss:0.22101
[32]	validation_0-logloss:0.21691
[33]	validation_0-logloss:0.21300
[34]	validation_0-logloss:0.20990
[35]	validation_0-logloss:0.20635
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44154
[8]	validation_0-logloss:0.42296
[9]	validation_0-logloss:0.40625
[10]	validation_0-logloss:0.39078
[11]	validation_0-logloss:0.37649
[12]	validation_0-logloss:0.36250
[13]	validation_0-logloss:0.34955
[14]	validation_0-logloss:0.33758
[15]	validation_0-logloss:0.32669
[16]	validation_0-logloss:0.31597
[17]	validation_0-logloss:0.30613
[18]	validation_0-logloss:0.29742
[19]	validation_0-logloss:0.28873
[20]	validation_0-logloss:0.28067
[21]	validation_0-logloss:0.27313
[22]	validation_0-logloss:0.26610
[23]	validation_0-logloss:0.25920
[24]	validation_0-logloss:0.25279
[25]	validation_0-logloss:0.24667
[26]	validation_0-logloss:0.24128
[27]	validation_0-logloss:0.23587
[28]	validation_0-logloss:0.23066
[29]	validation_0-logloss:0.22585
[30]	validation_0-logloss:0.22115
[31]	validation_0-logloss:0.21687
[32]	validation_0-logloss:0.21276
[33]	validation_0-logloss:0.20901
[34]	validation_0-logloss:0.20538
[35]	validation_0-logloss:0.20248
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44305
[8]	validation_0-logloss:0.42475
[9]	validation_0-logloss:0.40801
[10]	validation_0-logloss:0.39195
[11]	validation_0-logloss:0.37683
[12]	validation_0-logloss:0.36333
[13]	validation_0-logloss:0.35082
[14]	validation_0-logloss:0.33941
[15]	validation_0-logloss:0.32834
[16]	validation_0-logloss:0.31798
[17]	validation_0-logloss:0.30809
[18]	validation_0-logloss:0.29919
[19]	validation_0-logloss:0.29039
[20]	validation_0-logloss:0.28216
[21]	validation_0-logloss:0.27417
[22]	validation_0-logloss:0.26710
[23]	validation_0-logloss:0.26026
[24]	validation_0-logloss:0.25383
[25]	validation_0-logloss:0.24781
[26]	validation_0-logloss:0.24207
[27]	validation_0-logloss:0.23637
[28]	validation_0-logloss:0.23113
[29]	validation_0-logloss:0.22629
[30]	validation_0-logloss:0.22189
[31]	validation_0-logloss:0.21745
[32]	validation_0-logloss:0.21332
[33]	validation_0-logloss:0.20960
[34]	validation_0-logloss:0.20593
[35]	validation_0-logloss:0.20233
[36]	validation_0

Parameters: { "verbose" } are not used.



[6]	validation_0-logloss:0.45810
[7]	validation_0-logloss:0.43808
[8]	validation_0-logloss:0.41953
[9]	validation_0-logloss:0.40248
[10]	validation_0-logloss:0.38634
[11]	validation_0-logloss:0.37161
[12]	validation_0-logloss:0.35753
[13]	validation_0-logloss:0.34453
[14]	validation_0-logloss:0.33232
[15]	validation_0-logloss:0.32092
[16]	validation_0-logloss:0.30992
[17]	validation_0-logloss:0.29979
[18]	validation_0-logloss:0.29035
[19]	validation_0-logloss:0.28163
[20]	validation_0-logloss:0.27325
[21]	validation_0-logloss:0.26570
[22]	validation_0-logloss:0.25848
[23]	validation_0-logloss:0.25177
[24]	validation_0-logloss:0.24491
[25]	validation_0-logloss:0.23865
[26]	validation_0-logloss:0.23281
[27]	validation_0-logloss:0.22689
[28]	validation_0-logloss:0.22156
[29]	validation_0-logloss:0.21677
[30]	validation_0-logloss:0.21212
[31]	validation_0-logloss:0.20761
[32]	validation_0-logloss:0.20340
[33]	validation_0-logloss:0.19946
[34]	validation_0-logloss:0.19568
[35]	validation_0-

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44766
[8]	validation_0-logloss:0.42964
[9]	validation_0-logloss:0.41302
[10]	validation_0-logloss:0.39764
[11]	validation_0-logloss:0.38358
[12]	validation_0-logloss:0.37053
[13]	validation_0-logloss:0.35802
[14]	validation_0-logloss:0.34667
[15]	validation_0-logloss:0.33570
[16]	validation_0-logloss:0.32552
[17]	validation_0-logloss:0.31622
[18]	validation_0-logloss:0.30718
[19]	validation_0-logloss:0.29860
[20]	validation_0-logloss:0.29074
[21]	validation_0-logloss:0.28343
[22]	validation_0-logloss:0.27653
[23]	validation_0-logloss:0.26942
[24]	validation_0-logloss:0.26331
[25]	validation_0-logloss:0.25763
[26]	validation_0-logloss:0.25148
[27]	validation_0-logloss:0.24664
[28]	validation_0-logloss:0.24187
[29]	validation_0-logloss:0.23694
[30]	validation_0-logloss:0.23243
[31]	validation_0-logloss:0.22803
[32]	validation_0-logloss:0.22374
[33]	validation_0-logloss:0.22000
[34]	validation_0-logloss:0.21658
[35]	validation_0-logloss:0.21311
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44153
[8]	validation_0-logloss:0.42355
[9]	validation_0-logloss:0.40663
[10]	validation_0-logloss:0.39113
[11]	validation_0-logloss:0.37637
[12]	validation_0-logloss:0.36232
[13]	validation_0-logloss:0.34905
[14]	validation_0-logloss:0.33711
[15]	validation_0-logloss:0.32611
[16]	validation_0-logloss:0.31529
[17]	validation_0-logloss:0.30549
[18]	validation_0-logloss:0.29604
[19]	validation_0-logloss:0.28746
[20]	validation_0-logloss:0.27928
[21]	validation_0-logloss:0.27192
[22]	validation_0-logloss:0.26475
[23]	validation_0-logloss:0.25831
[24]	validation_0-logloss:0.25205
[25]	validation_0-logloss:0.24611
[26]	validation_0-logloss:0.24070
[27]	validation_0-logloss:0.23551
[28]	validation_0-logloss:0.23048
[29]	validation_0-logloss:0.22599
[30]	validation_0-logloss:0.22179
[31]	validation_0-logloss:0.21783
[32]	validation_0-logloss:0.21408
[33]	validation_0-logloss:0.21022
[34]	validation_0-logloss:0.20696
[35]	validation_0-logloss:0.20379
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44178
[8]	validation_0-logloss:0.42334
[9]	validation_0-logloss:0.40639
[10]	validation_0-logloss:0.39101
[11]	validation_0-logloss:0.37622
[12]	validation_0-logloss:0.36310
[13]	validation_0-logloss:0.35048
[14]	validation_0-logloss:0.33905
[15]	validation_0-logloss:0.32804
[16]	validation_0-logloss:0.31751
[17]	validation_0-logloss:0.30801
[18]	validation_0-logloss:0.29886
[19]	validation_0-logloss:0.29046
[20]	validation_0-logloss:0.28247
[21]	validation_0-logloss:0.27493
[22]	validation_0-logloss:0.26768
[23]	validation_0-logloss:0.26045
[24]	validation_0-logloss:0.25376
[25]	validation_0-logloss:0.24740
[26]	validation_0-logloss:0.24166
[27]	validation_0-logloss:0.23637
[28]	validation_0-logloss:0.23089
[29]	validation_0-logloss:0.22597
[30]	validation_0-logloss:0.22170
[31]	validation_0-logloss:0.21751
[32]	validation_0-logloss:0.21373
[33]	validation_0-logloss:0.21000
[34]	validation_0-logloss:0.20624
[35]	validation_0-logloss:0.20289
[36]	validation_0

Parameters: { "verbose" } are not used.



[6]	validation_0-logloss:0.46042
[7]	validation_0-logloss:0.44041
[8]	validation_0-logloss:0.42180
[9]	validation_0-logloss:0.40449
[10]	validation_0-logloss:0.38859
[11]	validation_0-logloss:0.37394
[12]	validation_0-logloss:0.35991
[13]	validation_0-logloss:0.34720
[14]	validation_0-logloss:0.33516
[15]	validation_0-logloss:0.32386
[16]	validation_0-logloss:0.31307
[17]	validation_0-logloss:0.30292
[18]	validation_0-logloss:0.29347
[19]	validation_0-logloss:0.28442
[20]	validation_0-logloss:0.27640
[21]	validation_0-logloss:0.26893
[22]	validation_0-logloss:0.26157
[23]	validation_0-logloss:0.25465
[24]	validation_0-logloss:0.24794
[25]	validation_0-logloss:0.24152
[26]	validation_0-logloss:0.23569
[27]	validation_0-logloss:0.23037
[28]	validation_0-logloss:0.22495
[29]	validation_0-logloss:0.22013
[30]	validation_0-logloss:0.21517
[31]	validation_0-logloss:0.21093
[32]	validation_0-logloss:0.20668
[33]	validation_0-logloss:0.20291
[34]	validation_0-logloss:0.19900
[35]	validation_0-

Parameters: { "verbose" } are not used.



[0]	validation_0-logloss:0.62380
[1]	validation_0-logloss:0.58948
[2]	validation_0-logloss:0.55870
[3]	validation_0-logloss:0.53102
[4]	validation_0-logloss:0.50530
[5]	validation_0-logloss:0.48198
[6]	validation_0-logloss:0.46035
[7]	validation_0-logloss:0.44107
[8]	validation_0-logloss:0.42313
[9]	validation_0-logloss:0.40613
[10]	validation_0-logloss:0.39073
[11]	validation_0-logloss:0.37589
[12]	validation_0-logloss:0.36252
[13]	validation_0-logloss:0.35017
[14]	validation_0-logloss:0.33845
[15]	validation_0-logloss:0.32722
[16]	validation_0-logloss:0.31684
[17]	validation_0-logloss:0.30718
[18]	validation_0-logloss:0.29786
[19]	validation_0-logloss:0.28954
[20]	validation_0-logloss:0.28120
[21]	validation_0-logloss:0.27358
[22]	validation_0-logloss:0.26611
[23]	validation_0-logloss:0.25937
[24]	validation_0-logloss:0.25287
[25]	validation_0-logloss:0.24677
[26]	validation_0-logloss:0.24075
[27]	validation_0-logloss:0.23498
[28]	validation_0-logloss:0.22994
[29]	validation_0-loglos

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.43634
[8]	validation_0-logloss:0.41726
[9]	validation_0-logloss:0.40005
[10]	validation_0-logloss:0.38426
[11]	validation_0-logloss:0.36944
[12]	validation_0-logloss:0.35585
[13]	validation_0-logloss:0.34259
[14]	validation_0-logloss:0.33014
[15]	validation_0-logloss:0.31916
[16]	validation_0-logloss:0.30839
[17]	validation_0-logloss:0.29829
[18]	validation_0-logloss:0.28885
[19]	validation_0-logloss:0.27975
[20]	validation_0-logloss:0.27151
[21]	validation_0-logloss:0.26373
[22]	validation_0-logloss:0.25668
[23]	validation_0-logloss:0.24979
[24]	validation_0-logloss:0.24342
[25]	validation_0-logloss:0.23700
[26]	validation_0-logloss:0.23116
[27]	validation_0-logloss:0.22542
[28]	validation_0-logloss:0.22024
[29]	validation_0-logloss:0.21550
[30]	validation_0-logloss:0.21047
[31]	validation_0-logloss:0.20619
[32]	validation_0-logloss:0.20190
[33]	validation_0-logloss:0.19772
[34]	validation_0-logloss:0.19384
[35]	validation_0-logloss:0.19012
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44249
[8]	validation_0-logloss:0.42456
[9]	validation_0-logloss:0.40790
[10]	validation_0-logloss:0.39279
[11]	validation_0-logloss:0.37834
[12]	validation_0-logloss:0.36528
[13]	validation_0-logloss:0.35281
[14]	validation_0-logloss:0.34109
[15]	validation_0-logloss:0.32984
[16]	validation_0-logloss:0.31987
[17]	validation_0-logloss:0.30999
[18]	validation_0-logloss:0.30102
[19]	validation_0-logloss:0.29265
[20]	validation_0-logloss:0.28432
[21]	validation_0-logloss:0.27654
[22]	validation_0-logloss:0.26930
[23]	validation_0-logloss:0.26242
[24]	validation_0-logloss:0.25611
[25]	validation_0-logloss:0.25007
[26]	validation_0-logloss:0.24433
[27]	validation_0-logloss:0.23890
[28]	validation_0-logloss:0.23370
[29]	validation_0-logloss:0.22889
[30]	validation_0-logloss:0.22456
[31]	validation_0-logloss:0.22079
[32]	validation_0-logloss:0.21711
[33]	validation_0-logloss:0.21336
[34]	validation_0-logloss:0.20974
[35]	validation_0-logloss:0.20631
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44194
[8]	validation_0-logloss:0.42314
[9]	validation_0-logloss:0.40588
[10]	validation_0-logloss:0.38977
[11]	validation_0-logloss:0.37529
[12]	validation_0-logloss:0.36186
[13]	validation_0-logloss:0.34912
[14]	validation_0-logloss:0.33729
[15]	validation_0-logloss:0.32597
[16]	validation_0-logloss:0.31530
[17]	validation_0-logloss:0.30564
[18]	validation_0-logloss:0.29646
[19]	validation_0-logloss:0.28790
[20]	validation_0-logloss:0.27975
[21]	validation_0-logloss:0.27234
[22]	validation_0-logloss:0.26533
[23]	validation_0-logloss:0.25840
[24]	validation_0-logloss:0.25166
[25]	validation_0-logloss:0.24590
[26]	validation_0-logloss:0.23988
[27]	validation_0-logloss:0.23445
[28]	validation_0-logloss:0.22918
[29]	validation_0-logloss:0.22426
[30]	validation_0-logloss:0.21975
[31]	validation_0-logloss:0.21546
[32]	validation_0-logloss:0.21131
[33]	validation_0-logloss:0.20757
[34]	validation_0-logloss:0.20378
[35]	validation_0-logloss:0.20045
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44166
[8]	validation_0-logloss:0.42368
[9]	validation_0-logloss:0.40687
[10]	validation_0-logloss:0.39113
[11]	validation_0-logloss:0.37657
[12]	validation_0-logloss:0.36311
[13]	validation_0-logloss:0.35070
[14]	validation_0-logloss:0.33905
[15]	validation_0-logloss:0.32836
[16]	validation_0-logloss:0.31789
[17]	validation_0-logloss:0.30832
[18]	validation_0-logloss:0.29966
[19]	validation_0-logloss:0.29104
[20]	validation_0-logloss:0.28268
[21]	validation_0-logloss:0.27424
[22]	validation_0-logloss:0.26651
[23]	validation_0-logloss:0.25917
[24]	validation_0-logloss:0.25178
[25]	validation_0-logloss:0.24520
[26]	validation_0-logloss:0.23917
[27]	validation_0-logloss:0.23333
[28]	validation_0-logloss:0.22774
[29]	validation_0-logloss:0.22246
[30]	validation_0-logloss:0.21766
[31]	validation_0-logloss:0.21326
[32]	validation_0-logloss:0.20902
[33]	validation_0-logloss:0.20480
[34]	validation_0-logloss:0.20080
[35]	validation_0-logloss:0.19714
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44067
[8]	validation_0-logloss:0.42178
[9]	validation_0-logloss:0.40490
[10]	validation_0-logloss:0.38906
[11]	validation_0-logloss:0.37386
[12]	validation_0-logloss:0.36007
[13]	validation_0-logloss:0.34743
[14]	validation_0-logloss:0.33580
[15]	validation_0-logloss:0.32484
[16]	validation_0-logloss:0.31435
[17]	validation_0-logloss:0.30475
[18]	validation_0-logloss:0.29534
[19]	validation_0-logloss:0.28683
[20]	validation_0-logloss:0.27901
[21]	validation_0-logloss:0.27141
[22]	validation_0-logloss:0.26415
[23]	validation_0-logloss:0.25721
[24]	validation_0-logloss:0.25093
[25]	validation_0-logloss:0.24536
[26]	validation_0-logloss:0.23959
[27]	validation_0-logloss:0.23426
[28]	validation_0-logloss:0.22975
[29]	validation_0-logloss:0.22528
[30]	validation_0-logloss:0.22077
[31]	validation_0-logloss:0.21634
[32]	validation_0-logloss:0.21268
[33]	validation_0-logloss:0.20926
[34]	validation_0-logloss:0.20559
[35]	validation_0-logloss:0.20222
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44655
[8]	validation_0-logloss:0.42904
[9]	validation_0-logloss:0.41298
[10]	validation_0-logloss:0.39761
[11]	validation_0-logloss:0.38336
[12]	validation_0-logloss:0.37018
[13]	validation_0-logloss:0.35802
[14]	validation_0-logloss:0.34657
[15]	validation_0-logloss:0.33558
[16]	validation_0-logloss:0.32589
[17]	validation_0-logloss:0.31674
[18]	validation_0-logloss:0.30822
[19]	validation_0-logloss:0.29968
[20]	validation_0-logloss:0.29196
[21]	validation_0-logloss:0.28441
[22]	validation_0-logloss:0.27722
[23]	validation_0-logloss:0.27048
[24]	validation_0-logloss:0.26429
[25]	validation_0-logloss:0.25859
[26]	validation_0-logloss:0.25324
[27]	validation_0-logloss:0.24859
[28]	validation_0-logloss:0.24365
[29]	validation_0-logloss:0.23944
[30]	validation_0-logloss:0.23525
[31]	validation_0-logloss:0.23141
[32]	validation_0-logloss:0.22746
[33]	validation_0-logloss:0.22393
[34]	validation_0-logloss:0.22029
[35]	validation_0-logloss:0.21690
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44051
[8]	validation_0-logloss:0.42215
[9]	validation_0-logloss:0.40517
[10]	validation_0-logloss:0.38948
[11]	validation_0-logloss:0.37488
[12]	validation_0-logloss:0.36139
[13]	validation_0-logloss:0.34855
[14]	validation_0-logloss:0.33699
[15]	validation_0-logloss:0.32610
[16]	validation_0-logloss:0.31555
[17]	validation_0-logloss:0.30566
[18]	validation_0-logloss:0.29610
[19]	validation_0-logloss:0.28699
[20]	validation_0-logloss:0.27862
[21]	validation_0-logloss:0.27070
[22]	validation_0-logloss:0.26328
[23]	validation_0-logloss:0.25637
[24]	validation_0-logloss:0.25021
[25]	validation_0-logloss:0.24403
[26]	validation_0-logloss:0.23877
[27]	validation_0-logloss:0.23393
[28]	validation_0-logloss:0.22852
[29]	validation_0-logloss:0.22355
[30]	validation_0-logloss:0.21902
[31]	validation_0-logloss:0.21451
[32]	validation_0-logloss:0.21015
[33]	validation_0-logloss:0.20618
[34]	validation_0-logloss:0.20242
[35]	validation_0-logloss:0.19892
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44349
[8]	validation_0-logloss:0.42549
[9]	validation_0-logloss:0.40920
[10]	validation_0-logloss:0.39336
[11]	validation_0-logloss:0.37897
[12]	validation_0-logloss:0.36547
[13]	validation_0-logloss:0.35336
[14]	validation_0-logloss:0.34130
[15]	validation_0-logloss:0.33062
[16]	validation_0-logloss:0.32070
[17]	validation_0-logloss:0.31123
[18]	validation_0-logloss:0.30194
[19]	validation_0-logloss:0.29335
[20]	validation_0-logloss:0.28518
[21]	validation_0-logloss:0.27784
[22]	validation_0-logloss:0.27029
[23]	validation_0-logloss:0.26322
[24]	validation_0-logloss:0.25671
[25]	validation_0-logloss:0.25079
[26]	validation_0-logloss:0.24506
[27]	validation_0-logloss:0.23947
[28]	validation_0-logloss:0.23438
[29]	validation_0-logloss:0.22958
[30]	validation_0-logloss:0.22496
[31]	validation_0-logloss:0.22078
[32]	validation_0-logloss:0.21642
[33]	validation_0-logloss:0.21218
[34]	validation_0-logloss:0.20853
[35]	validation_0-logloss:0.20495
[36]	validation_0

Parameters: { "verbose" } are not used.



[7]	validation_0-logloss:0.44046
[8]	validation_0-logloss:0.42206
[9]	validation_0-logloss:0.40497
[10]	validation_0-logloss:0.38916
[11]	validation_0-logloss:0.37499
[12]	validation_0-logloss:0.36134
[13]	validation_0-logloss:0.34893
[14]	validation_0-logloss:0.33681
[15]	validation_0-logloss:0.32537
[16]	validation_0-logloss:0.31480
[17]	validation_0-logloss:0.30473
[18]	validation_0-logloss:0.29518
[19]	validation_0-logloss:0.28646
[20]	validation_0-logloss:0.27797
[21]	validation_0-logloss:0.27034
[22]	validation_0-logloss:0.26303
[23]	validation_0-logloss:0.25587
[24]	validation_0-logloss:0.24943
[25]	validation_0-logloss:0.24342
[26]	validation_0-logloss:0.23768
[27]	validation_0-logloss:0.23235
[28]	validation_0-logloss:0.22733
[29]	validation_0-logloss:0.22253
[30]	validation_0-logloss:0.21815
[31]	validation_0-logloss:0.21380
[32]	validation_0-logloss:0.20993
[33]	validation_0-logloss:0.20600
[34]	validation_0-logloss:0.20219
[35]	validation_0-logloss:0.19864
[36]	validation_0

In [88]:
def objective(trial):
    params = {
        'booster': trial.suggest_categorical('booster', ['gbtree','gblinear']),
        'device': 'cuda',
        'grow_policy': trial.suggest_categorical('grow_policy', ['depthwise','lossguide']),
        'learning_rate': trial.suggest_float('learning_rate', 0.005, 0.1),
        'gamma' : trial.suggest_float('gamma', 1e-5, 0.5, log=True),
        'subsample': trial.suggest_float('subsample', 0.3, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'max_depth': trial.suggest_int('max_depth', 3, 15),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
        'lambda': trial.suggest_float('lambda', 1e-3, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-3, 10.0, log=True),
    }

    params['n_estimators'] = 3000
    params['early_stopping_rounds'] = 50
    params['booster'] = 'gbtree'
    params["verbosity"] = 0
    params['tree_method'] = "hist"
    
    auc_scores = []

    for train_idx, valid_idx in cv.split(X, y):

        X_train_fold, X_valid_fold = pd.DataFrame(X).iloc[train_idx], pd.DataFrame(X).iloc[valid_idx]
        y_train_fold, y_valid_fold = pd.Series(y.squeeze(1)).iloc[train_idx], pd.Series(y.squeeze(1)).iloc[valid_idx]
                
        # Create and fit the model
        model = xgb(**params)
        model.fit(X_train_fold, y_train_fold, eval_set=[(X_valid_fold, y_valid_fold)],verbose=False)

        # Predict class probabilities
        y_prob = model.predict_proba(X_valid_fold)

        # Compute the AUC for each class and take the average
        average_auc = roc_auc_score(pd.Series(y.squeeze(1)).iloc[valid_idx], y_prob.squeeze(1), multi_class="ovr", average="macro")
        auc_scores.append(average_auc)

    # Return the average AUC score across all folds
    return np.mean(auc_scores)

study = optuna.create_study(direction='maximize', study_name='xgb_model_training')
study.optimize(objective, n_trials=200)