## Constants

In [1]:
import sys, os
import pandas as pd
import polars as pl
import numpy as np
import subprocess
import gc
import optuna
from datetime import datetime, timezone
import warnings
import xgboost as xgb
import joblib as jl
from sklearn.model_selection import train_test_split
import warnings
from sklearn.metrics import matthews_corrcoef
from mlflow.models import infer_signature
import mlflow
import random
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelBinarizer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import StratifiedKFold

today = datetime.now(timezone.utc).strftime("%Y_%m_%d")
warnings.filterwarnings("ignore")

from hyper_params import (
    mushroom_tuning_2024_08_06_1722934727_params,
)

SEED = 108
random.seed(SEED)
N_FOLDS = 12
# model
is_tunning = True
try:
    rs = subprocess.check_output("nvidia-smi")
    device = "cuda" if rs is not None else "cpu"
except (
    Exception
):  # this command not being found can raise quite a few different errors depending on the configuration
    print("No Nvidia GPU in system!")
    device = "cpu"

best_params = {
    "device": device,
    "verbosity": 0,
    "objective": "binary:logistic",
}
best_params.update(mushroom_tuning_2024_08_06_1722934727_params)
best_params

{'device': 'cuda',
 'verbosity': 0,
 'objective': 'binary:logistic',
 'tree_method': 'hist',
 'eta': 0.0696294726051571,
 'max_depth': 0,
 'min_child_weight': 1,
 'gamma': 0.044230646284796976,
 'subsample': 0.9405269471473167,
 'colsample_bytree': 0.2999355523666192,
 'lambda': 0.9746051811186938,
 'alpha': 4.210861941737071}

## Prepare data

In [2]:
y_train_pkl = jl.load("../y_train.pkl")
X_train_pkl = jl.load("../X_train.pkl")

print(f"train size: {X_train_pkl.shape}")

train size: (3116945, 294)


## CV

In [3]:
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score


clf: xgb.XGBClassifier = xgb.XGBClassifier(
    **best_params,
    n_estimators=4000,
    early_stopping_rounds=50,
    enable_categorical=True,
)

In [4]:
from tqdm import tqdm

gc.collect()
skf = StratifiedKFold(n_splits=N_FOLDS)

y_preds = []
y_trues = []
for train_index, test_index in tqdm(skf.split(X_train_pkl, y_train_pkl)):
    X_train, X_test = X_train_pkl[train_index], X_train_pkl[test_index]
    y_train, y_test = y_train_pkl[train_index], y_train_pkl[test_index]

    clf.fit(X=X_train, y=y_train, eval_set=[(X_test, y_test)])

    y_pred = clf.predict(X_test)
    y_preds.append(y_pred)
    y_trues.append(y_test)

    del X_train, X_test, y_train, y_test, y_pred
    gc.collect()
# Concatenate the predictions and true labels
y_preds_concat = np.concatenate(y_preds)
y_trues_concat = np.concatenate(y_trues)
mcc = matthews_corrcoef(y_trues_concat, y_preds_concat)
print(f"Validation mcc score: {mcc}")
jl.dump(clf, "../clf.pkl")

0it [00:00, ?it/s]

[0]	validation_0-logloss:0.62808
[1]	validation_0-logloss:0.57444
[2]	validation_0-logloss:0.52920
[3]	validation_0-logloss:0.49783
[4]	validation_0-logloss:0.45835
[5]	validation_0-logloss:0.42422
[6]	validation_0-logloss:0.40108
[7]	validation_0-logloss:0.37320
[8]	validation_0-logloss:0.34676
[9]	validation_0-logloss:0.32602
[10]	validation_0-logloss:0.30922
[11]	validation_0-logloss:0.28825
[12]	validation_0-logloss:0.27167
[13]	validation_0-logloss:0.25550
[14]	validation_0-logloss:0.24140
[15]	validation_0-logloss:0.22823
[16]	validation_0-logloss:0.21667
[17]	validation_0-logloss:0.20659
[18]	validation_0-logloss:0.19646
[19]	validation_0-logloss:0.18624
[20]	validation_0-logloss:0.17659
[21]	validation_0-logloss:0.16629
[22]	validation_0-logloss:0.15773
[23]	validation_0-logloss:0.15169
[24]	validation_0-logloss:0.14326
[25]	validation_0-logloss:0.13687
[26]	validation_0-logloss:0.13249
[27]	validation_0-logloss:0.12580
[28]	validation_0-logloss:0.11958
[29]	validation_0-loglos

1it [01:06, 66.38s/it]

[0]	validation_0-logloss:0.62812
[1]	validation_0-logloss:0.57448
[2]	validation_0-logloss:0.52941
[3]	validation_0-logloss:0.49803
[4]	validation_0-logloss:0.45862
[5]	validation_0-logloss:0.42448
[6]	validation_0-logloss:0.40133
[7]	validation_0-logloss:0.37342
[8]	validation_0-logloss:0.34697
[9]	validation_0-logloss:0.32625
[10]	validation_0-logloss:0.30942
[11]	validation_0-logloss:0.28851
[12]	validation_0-logloss:0.27145
[13]	validation_0-logloss:0.25540
[14]	validation_0-logloss:0.24132
[15]	validation_0-logloss:0.22817
[16]	validation_0-logloss:0.21665
[17]	validation_0-logloss:0.20659
[18]	validation_0-logloss:0.19653
[19]	validation_0-logloss:0.18632
[20]	validation_0-logloss:0.17670
[21]	validation_0-logloss:0.16642
[22]	validation_0-logloss:0.15790
[23]	validation_0-logloss:0.15192
[24]	validation_0-logloss:0.14352
[25]	validation_0-logloss:0.13716
[26]	validation_0-logloss:0.13280
[27]	validation_0-logloss:0.12618
[28]	validation_0-logloss:0.11997
[29]	validation_0-loglos

2it [02:10, 65.07s/it]

[0]	validation_0-logloss:0.62815
[1]	validation_0-logloss:0.57454
[2]	validation_0-logloss:0.52929
[3]	validation_0-logloss:0.49788
[4]	validation_0-logloss:0.45844
[5]	validation_0-logloss:0.42430
[6]	validation_0-logloss:0.40123
[7]	validation_0-logloss:0.37337
[8]	validation_0-logloss:0.34697
[9]	validation_0-logloss:0.32625
[10]	validation_0-logloss:0.30945
[11]	validation_0-logloss:0.28852
[12]	validation_0-logloss:0.27357
[13]	validation_0-logloss:0.25731
[14]	validation_0-logloss:0.24305
[15]	validation_0-logloss:0.22989
[16]	validation_0-logloss:0.21823
[17]	validation_0-logloss:0.20804
[18]	validation_0-logloss:0.19787
[19]	validation_0-logloss:0.18753
[20]	validation_0-logloss:0.17781
[21]	validation_0-logloss:0.16745
[22]	validation_0-logloss:0.15885
[23]	validation_0-logloss:0.15277
[24]	validation_0-logloss:0.14432
[25]	validation_0-logloss:0.13791
[26]	validation_0-logloss:0.13349
[27]	validation_0-logloss:0.12680
[28]	validation_0-logloss:0.12054
[29]	validation_0-loglos

3it [03:12, 63.68s/it]

[0]	validation_0-logloss:0.62813
[1]	validation_0-logloss:0.57452
[2]	validation_0-logloss:0.52939
[3]	validation_0-logloss:0.49799
[4]	validation_0-logloss:0.45856
[5]	validation_0-logloss:0.42437
[6]	validation_0-logloss:0.40120
[7]	validation_0-logloss:0.37331
[8]	validation_0-logloss:0.34690
[9]	validation_0-logloss:0.32614
[10]	validation_0-logloss:0.30932
[11]	validation_0-logloss:0.28836
[12]	validation_0-logloss:0.27132
[13]	validation_0-logloss:0.25525
[14]	validation_0-logloss:0.24117
[15]	validation_0-logloss:0.22802
[16]	validation_0-logloss:0.21649
[17]	validation_0-logloss:0.20642
[18]	validation_0-logloss:0.19634
[19]	validation_0-logloss:0.18611
[20]	validation_0-logloss:0.17649
[21]	validation_0-logloss:0.16624
[22]	validation_0-logloss:0.15770
[23]	validation_0-logloss:0.15169
[24]	validation_0-logloss:0.14331
[25]	validation_0-logloss:0.13694
[26]	validation_0-logloss:0.13256
[27]	validation_0-logloss:0.12592
[28]	validation_0-logloss:0.11971
[29]	validation_0-loglos

4it [04:16, 63.70s/it]

[0]	validation_0-logloss:0.62815
[1]	validation_0-logloss:0.57452
[2]	validation_0-logloss:0.52934
[3]	validation_0-logloss:0.49793
[4]	validation_0-logloss:0.45846
[5]	validation_0-logloss:0.42431
[6]	validation_0-logloss:0.40114
[7]	validation_0-logloss:0.37325
[8]	validation_0-logloss:0.34683
[9]	validation_0-logloss:0.32617
[10]	validation_0-logloss:0.30935
[11]	validation_0-logloss:0.28839
[12]	validation_0-logloss:0.27133
[13]	validation_0-logloss:0.25523
[14]	validation_0-logloss:0.24115
[15]	validation_0-logloss:0.22799
[16]	validation_0-logloss:0.21645
[17]	validation_0-logloss:0.20637
[18]	validation_0-logloss:0.19626
[19]	validation_0-logloss:0.18603
[20]	validation_0-logloss:0.17649
[21]	validation_0-logloss:0.16621
[22]	validation_0-logloss:0.15768
[23]	validation_0-logloss:0.15166
[24]	validation_0-logloss:0.14327
[25]	validation_0-logloss:0.13689
[26]	validation_0-logloss:0.13250
[27]	validation_0-logloss:0.12587
[28]	validation_0-logloss:0.11966
[29]	validation_0-loglos

5it [05:20, 64.06s/it]

[0]	validation_0-logloss:0.62814
[1]	validation_0-logloss:0.57452
[2]	validation_0-logloss:0.52934
[3]	validation_0-logloss:0.49788
[4]	validation_0-logloss:0.45837
[5]	validation_0-logloss:0.42420
[6]	validation_0-logloss:0.40102
[7]	validation_0-logloss:0.37313
[8]	validation_0-logloss:0.34670
[9]	validation_0-logloss:0.32598
[10]	validation_0-logloss:0.30913
[11]	validation_0-logloss:0.28813
[12]	validation_0-logloss:0.27108
[13]	validation_0-logloss:0.25497
[14]	validation_0-logloss:0.24087
[15]	validation_0-logloss:0.22771
[16]	validation_0-logloss:0.21615
[17]	validation_0-logloss:0.20609
[18]	validation_0-logloss:0.19601
[19]	validation_0-logloss:0.18580
[20]	validation_0-logloss:0.17615
[21]	validation_0-logloss:0.16589
[22]	validation_0-logloss:0.15734
[23]	validation_0-logloss:0.15131
[24]	validation_0-logloss:0.14292
[25]	validation_0-logloss:0.13653
[26]	validation_0-logloss:0.13216
[27]	validation_0-logloss:0.12551
[28]	validation_0-logloss:0.11929
[29]	validation_0-loglos

6it [06:20, 62.59s/it]

[0]	validation_0-logloss:0.62814
[1]	validation_0-logloss:0.57457
[2]	validation_0-logloss:0.52940
[3]	validation_0-logloss:0.49809
[4]	validation_0-logloss:0.45860
[5]	validation_0-logloss:0.42446
[6]	validation_0-logloss:0.40142
[7]	validation_0-logloss:0.37351
[8]	validation_0-logloss:0.34707
[9]	validation_0-logloss:0.32635
[10]	validation_0-logloss:0.30952
[11]	validation_0-logloss:0.28856
[12]	validation_0-logloss:0.27148
[13]	validation_0-logloss:0.25541
[14]	validation_0-logloss:0.24133
[15]	validation_0-logloss:0.22820
[16]	validation_0-logloss:0.21667
[17]	validation_0-logloss:0.20661
[18]	validation_0-logloss:0.19654
[19]	validation_0-logloss:0.18633
[20]	validation_0-logloss:0.17670
[21]	validation_0-logloss:0.16642
[22]	validation_0-logloss:0.15789
[23]	validation_0-logloss:0.15189
[24]	validation_0-logloss:0.14350
[25]	validation_0-logloss:0.13714
[26]	validation_0-logloss:0.13278
[27]	validation_0-logloss:0.12613
[28]	validation_0-logloss:0.11993
[29]	validation_0-loglos

7it [07:20, 61.82s/it]

[0]	validation_0-logloss:0.62820
[1]	validation_0-logloss:0.57476
[2]	validation_0-logloss:0.52956
[3]	validation_0-logloss:0.49823
[4]	validation_0-logloss:0.45878
[5]	validation_0-logloss:0.42463
[6]	validation_0-logloss:0.40155
[7]	validation_0-logloss:0.37367
[8]	validation_0-logloss:0.34729
[9]	validation_0-logloss:0.32656
[10]	validation_0-logloss:0.30973
[11]	validation_0-logloss:0.28874
[12]	validation_0-logloss:0.27334
[13]	validation_0-logloss:0.25712
[14]	validation_0-logloss:0.24294
[15]	validation_0-logloss:0.22965
[16]	validation_0-logloss:0.21804
[17]	validation_0-logloss:0.20790
[18]	validation_0-logloss:0.19774
[19]	validation_0-logloss:0.18742
[20]	validation_0-logloss:0.17771
[21]	validation_0-logloss:0.16738
[22]	validation_0-logloss:0.15878
[23]	validation_0-logloss:0.15274
[24]	validation_0-logloss:0.14428
[25]	validation_0-logloss:0.13789
[26]	validation_0-logloss:0.13348
[27]	validation_0-logloss:0.12677
[28]	validation_0-logloss:0.12052
[29]	validation_0-loglos

8it [08:29, 64.12s/it]

[0]	validation_0-logloss:0.62829
[1]	validation_0-logloss:0.57476
[2]	validation_0-logloss:0.53037
[3]	validation_0-logloss:0.49904
[4]	validation_0-logloss:0.45948
[5]	validation_0-logloss:0.42525
[6]	validation_0-logloss:0.40204
[7]	validation_0-logloss:0.37407
[8]	validation_0-logloss:0.34761
[9]	validation_0-logloss:0.32688
[10]	validation_0-logloss:0.31002
[11]	validation_0-logloss:0.28901
[12]	validation_0-logloss:0.27193
[13]	validation_0-logloss:0.25583
[14]	validation_0-logloss:0.24170
[15]	validation_0-logloss:0.22855
[16]	validation_0-logloss:0.21709
[17]	validation_0-logloss:0.20699
[18]	validation_0-logloss:0.19689
[19]	validation_0-logloss:0.18667
[20]	validation_0-logloss:0.17701
[21]	validation_0-logloss:0.16673
[22]	validation_0-logloss:0.15817
[23]	validation_0-logloss:0.15214
[24]	validation_0-logloss:0.14372
[25]	validation_0-logloss:0.13735
[26]	validation_0-logloss:0.13296
[27]	validation_0-logloss:0.12633
[28]	validation_0-logloss:0.12011
[29]	validation_0-loglos

9it [09:34, 64.29s/it]

[0]	validation_0-logloss:0.62817
[1]	validation_0-logloss:0.57457
[2]	validation_0-logloss:0.52938
[3]	validation_0-logloss:0.49810
[4]	validation_0-logloss:0.45864
[5]	validation_0-logloss:0.42444
[6]	validation_0-logloss:0.40129
[7]	validation_0-logloss:0.37336
[8]	validation_0-logloss:0.34695
[9]	validation_0-logloss:0.32619
[10]	validation_0-logloss:0.30934
[11]	validation_0-logloss:0.28835
[12]	validation_0-logloss:0.27127
[13]	validation_0-logloss:0.25517
[14]	validation_0-logloss:0.24105
[15]	validation_0-logloss:0.22790
[16]	validation_0-logloss:0.21634
[17]	validation_0-logloss:0.20625
[18]	validation_0-logloss:0.19615
[19]	validation_0-logloss:0.18593
[20]	validation_0-logloss:0.17628
[21]	validation_0-logloss:0.17122
[22]	validation_0-logloss:0.16236
[23]	validation_0-logloss:0.15609
[24]	validation_0-logloss:0.14732
[25]	validation_0-logloss:0.14066
[26]	validation_0-logloss:0.13608
[27]	validation_0-logloss:0.12916
[28]	validation_0-logloss:0.12266
[29]	validation_0-loglos

10it [10:40, 64.70s/it]

[0]	validation_0-logloss:0.62815
[1]	validation_0-logloss:0.57457
[2]	validation_0-logloss:0.52933
[3]	validation_0-logloss:0.49796
[4]	validation_0-logloss:0.45848
[5]	validation_0-logloss:0.42438
[6]	validation_0-logloss:0.40127
[7]	validation_0-logloss:0.37340
[8]	validation_0-logloss:0.34699
[9]	validation_0-logloss:0.32630
[10]	validation_0-logloss:0.30948
[11]	validation_0-logloss:0.28852
[12]	validation_0-logloss:0.27147
[13]	validation_0-logloss:0.25535
[14]	validation_0-logloss:0.24128
[15]	validation_0-logloss:0.22811
[16]	validation_0-logloss:0.21657
[17]	validation_0-logloss:0.20645
[18]	validation_0-logloss:0.19636
[19]	validation_0-logloss:0.18613
[20]	validation_0-logloss:0.17650
[21]	validation_0-logloss:0.16621
[22]	validation_0-logloss:0.15766
[23]	validation_0-logloss:0.15164
[24]	validation_0-logloss:0.14324
[25]	validation_0-logloss:0.13733
[26]	validation_0-logloss:0.13294
[27]	validation_0-logloss:0.12625
[28]	validation_0-logloss:0.12002
[29]	validation_0-loglos

11it [11:42, 64.06s/it]

[0]	validation_0-logloss:0.62817
[1]	validation_0-logloss:0.57455
[2]	validation_0-logloss:0.52933
[3]	validation_0-logloss:0.49795
[4]	validation_0-logloss:0.45851
[5]	validation_0-logloss:0.42436
[6]	validation_0-logloss:0.40119
[7]	validation_0-logloss:0.37328
[8]	validation_0-logloss:0.34687
[9]	validation_0-logloss:0.32615
[10]	validation_0-logloss:0.30932
[11]	validation_0-logloss:0.28836
[12]	validation_0-logloss:0.27128
[13]	validation_0-logloss:0.25517
[14]	validation_0-logloss:0.24105
[15]	validation_0-logloss:0.22791
[16]	validation_0-logloss:0.21637
[17]	validation_0-logloss:0.20631
[18]	validation_0-logloss:0.19624
[19]	validation_0-logloss:0.18601
[20]	validation_0-logloss:0.17637
[21]	validation_0-logloss:0.16608
[22]	validation_0-logloss:0.15754
[23]	validation_0-logloss:0.15151
[24]	validation_0-logloss:0.14307
[25]	validation_0-logloss:0.13670
[26]	validation_0-logloss:0.13233
[27]	validation_0-logloss:0.12567
[28]	validation_0-logloss:0.11945
[29]	validation_0-loglos

12it [12:47, 63.93s/it]


Validation mcc score: 0.9848005518849403


['../clf.pkl']

In [5]:
submit_df = jl.load("../submit_df.pkl")
X_test_pkl = jl.load("../X_test.pkl")
lb = jl.load("../lb.pkl")
y_preds = clf.predict(X_test_pkl)

In [6]:
pred_classes = lb.inverse_transform(y_preds)
submit_df["class"] = pred_classes
submit_df.to_csv("submission.csv", index=False)