In [3]:
import pandas as pd
import numpy as np
import catboost
import optuna

In [4]:
from catboost import CatBoostRegressor, Pool
from optuna.integration import CatBoostPruningCallback
from sklearn.metrics import mean_squared_error

from sklearn.model_selection import train_test_split

In [5]:
import os
import matplotlib.pyplot as plt#visualization
%matplotlib inline
import seaborn as sns#visualization

In [6]:
df_train = pd.read_csv('data/train.csv')
df_val = pd.read_csv('data/test.csv')

print(df_train.shape)
print(df_val.shape) 

(300000, 16)
(200000, 15)


In [7]:
X_train, X_test, y_train, y_test = train_test_split(
    df_train.drop(labels=['target', 'id'], axis=1),  # drop the target
    df_train['target'],  # just the target
    test_size=0.2,
    random_state=0)

X_train.shape, X_test.shape

((240000, 14), (60000, 14))

In [8]:
def objective(trial):
    global cbr
    X_train_cbr = X_train
    X_test_cbr = X_test
    y_train_cbr = y_train
    y_test_cbr = y_test
    
    param = {
        'learning_rate': trial.suggest_loguniform('learning_rate', 0.005, 0.1),
        'n_estimators': trial.suggest_int('n_estimators', 1000, 25000, step=1),
        'max_depth': trial.suggest_int('max_depth', 1, 20, step=1),
        'subsample': trial.suggest_float('subsample', 0.4, 1),
        'colsample_bylevel': trial.suggest_float('colsample_bylevel', 0.5, 0.99),
        'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 0.01, 15.0),
        'max_bin': trial.suggest_int('max_bin', 200, 400),
        'eval_metric': 'RMSE'
    }
    
    cbr = CatBoostRegressor(**param)

    pruning_callback = CatBoostPruningCallback(trial, "RMSE")
    
    cbr.fit(
        X_train_cbr,
        y_train_cbr,
        eval_set=[(X_test_cbr, y_test_cbr)],
        verbose=1,
        early_stopping_rounds=100,
        callbacks=[pruning_callback],
    )
    trial.set_user_attr(key="best_booster", value=cbr)
    # evoke pruning manually.
    pruning_callback.check_pruned()
    
    y_pred = cbr.predict(X_test_cbr)
    rmse = mean_squared_error(y_test, y_pred, squared=False)
    return rmse

In [9]:
def callback(study, trial):
    if study.best_trial.number == trial.number:
        study.set_user_attr(key="best_booster", value=trial.user_attrs["best_booster"])

In [10]:
study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=5), direction="minimize")
study.optimize(objective, n_trials=5, timeout=600, callbacks=[callback])

[32m[I 2022-10-03 14:30:30,019][0m A new study created in memory with name: no-name-d7eb8a4e-c664-4044-b161-54d5eb8a005c[0m
  'learning_rate': trial.suggest_loguniform('learning_rate', 0.005, 0.1),
  pruning_callback = CatBoostPruningCallback(trial, "RMSE")


0:	learn: 0.7317711	test: 0.7315576	best: 0.7315576 (0)	total: 2.29s	remaining: 7h 8m 28s
1:	learn: 0.7303805	test: 0.7307905	best: 0.7307905 (1)	total: 4.21s	remaining: 6h 34m 4s
2:	learn: 0.7291040	test: 0.7300658	best: 0.7300658 (2)	total: 6.27s	remaining: 6h 30m 51s
3:	learn: 0.7278128	test: 0.7293094	best: 0.7293094 (3)	total: 8.15s	remaining: 6h 21m 7s
4:	learn: 0.7265283	test: 0.7285586	best: 0.7285586 (4)	total: 9.97s	remaining: 6h 13m 13s
5:	learn: 0.7253544	test: 0.7279286	best: 0.7279286 (5)	total: 11.8s	remaining: 6h 8m 54s
6:	learn: 0.7241643	test: 0.7273765	best: 0.7273765 (6)	total: 13.9s	remaining: 6h 10m 41s
7:	learn: 0.7230484	test: 0.7268440	best: 0.7268440 (7)	total: 16s	remaining: 6h 14m 15s
8:	learn: 0.7219393	test: 0.7262326	best: 0.7262326 (8)	total: 17.8s	remaining: 6h 8m 58s
9:	learn: 0.7209080	test: 0.7257275	best: 0.7257275 (9)	total: 19.8s	remaining: 6h 10m 53s
10:	learn: 0.7198722	test: 0.7252181	best: 0.7252181 (10)	total: 21.8s	remaining: 6h 10m 15s
11:	

89:	learn: 0.6784915	test: 0.7103735	best: 0.7103735 (89)	total: 3m 6s	remaining: 6h 23m 59s
90:	learn: 0.6781480	test: 0.7103163	best: 0.7103163 (90)	total: 3m 8s	remaining: 6h 23m 37s
91:	learn: 0.6777496	test: 0.7102515	best: 0.7102515 (91)	total: 3m 10s	remaining: 6h 23m 23s
92:	learn: 0.6774992	test: 0.7101770	best: 0.7101770 (92)	total: 3m 12s	remaining: 6h 23m 48s
93:	learn: 0.6771979	test: 0.7101029	best: 0.7101029 (93)	total: 3m 14s	remaining: 6h 24m 6s
94:	learn: 0.6768076	test: 0.7100399	best: 0.7100399 (94)	total: 3m 16s	remaining: 6h 23m 55s
95:	learn: 0.6763902	test: 0.7099797	best: 0.7099797 (95)	total: 3m 18s	remaining: 6h 23m 30s
96:	learn: 0.6759777	test: 0.7099072	best: 0.7099072 (96)	total: 3m 20s	remaining: 6h 23m 13s
97:	learn: 0.6755364	test: 0.7098402	best: 0.7098402 (97)	total: 3m 22s	remaining: 6h 22m 39s
98:	learn: 0.6752288	test: 0.7097819	best: 0.7097819 (98)	total: 3m 24s	remaining: 6h 22m 23s
99:	learn: 0.6749199	test: 0.7097257	best: 0.7097257 (99)	total

175:	learn: 0.6506441	test: 0.7067798	best: 0.7067798 (175)	total: 5m 57s	remaining: 6h 13m 59s
176:	learn: 0.6503851	test: 0.7067586	best: 0.7067586 (176)	total: 5m 59s	remaining: 6h 13m 51s
177:	learn: 0.6502160	test: 0.7067246	best: 0.7067246 (177)	total: 6m 1s	remaining: 6h 13m 45s
178:	learn: 0.6499413	test: 0.7066846	best: 0.7066846 (178)	total: 6m 3s	remaining: 6h 13m 51s
179:	learn: 0.6497281	test: 0.7066613	best: 0.7066613 (179)	total: 6m 4s	remaining: 6h 13m 20s
180:	learn: 0.6493677	test: 0.7066456	best: 0.7066456 (180)	total: 6m 6s	remaining: 6h 13m 22s
181:	learn: 0.6491442	test: 0.7066224	best: 0.7066224 (181)	total: 6m 9s	remaining: 6h 13m 21s
182:	learn: 0.6488260	test: 0.7066044	best: 0.7066044 (182)	total: 6m 11s	remaining: 6h 13m 30s
183:	learn: 0.6485829	test: 0.7065611	best: 0.7065611 (183)	total: 6m 13s	remaining: 6h 13m 21s
184:	learn: 0.6483052	test: 0.7065446	best: 0.7065446 (184)	total: 6m 14s	remaining: 6h 13m 6s
185:	learn: 0.6479796	test: 0.7065226	best: 0.

261:	learn: 0.6281959	test: 0.7051990	best: 0.7051990 (261)	total: 8m 56s	remaining: 6h 14m 17s
262:	learn: 0.6278290	test: 0.7051806	best: 0.7051806 (262)	total: 8m 58s	remaining: 6h 14m 16s
263:	learn: 0.6275229	test: 0.7051780	best: 0.7051780 (263)	total: 9m	remaining: 6h 14m 23s
264:	learn: 0.6271632	test: 0.7051504	best: 0.7051504 (264)	total: 9m 3s	remaining: 6h 14m 31s
265:	learn: 0.6269241	test: 0.7051397	best: 0.7051397 (265)	total: 9m 5s	remaining: 6h 14m 40s
266:	learn: 0.6265825	test: 0.7051348	best: 0.7051348 (266)	total: 9m 7s	remaining: 6h 14m 46s
267:	learn: 0.6262853	test: 0.7051430	best: 0.7051348 (266)	total: 9m 9s	remaining: 6h 14m 46s
268:	learn: 0.6260127	test: 0.7051251	best: 0.7051251 (268)	total: 9m 11s	remaining: 6h 14m 40s
269:	learn: 0.6258058	test: 0.7051203	best: 0.7051203 (269)	total: 9m 13s	remaining: 6h 14m 27s
270:	learn: 0.6255763	test: 0.7050942	best: 0.7050942 (270)	total: 9m 15s	remaining: 6h 14m 19s
271:	learn: 0.6253338	test: 0.7050767	best: 0.70

347:	learn: 0.6062618	test: 0.7042822	best: 0.7042722 (346)	total: 11m 47s	remaining: 6h 8m 36s
348:	learn: 0.6059707	test: 0.7042792	best: 0.7042722 (346)	total: 11m 49s	remaining: 6h 8m 32s
349:	learn: 0.6057383	test: 0.7042765	best: 0.7042722 (346)	total: 11m 51s	remaining: 6h 8m 27s
350:	learn: 0.6054871	test: 0.7042715	best: 0.7042715 (350)	total: 11m 53s	remaining: 6h 8m 21s
351:	learn: 0.6052708	test: 0.7042509	best: 0.7042509 (351)	total: 11m 55s	remaining: 6h 8m 22s
352:	learn: 0.6048409	test: 0.7042572	best: 0.7042509 (351)	total: 11m 57s	remaining: 6h 8m 14s
353:	learn: 0.6046975	test: 0.7042481	best: 0.7042481 (353)	total: 11m 58s	remaining: 6h 8m 9s
354:	learn: 0.6044826	test: 0.7042444	best: 0.7042444 (354)	total: 12m 1s	remaining: 6h 8m 7s
355:	learn: 0.6042591	test: 0.7042471	best: 0.7042444 (354)	total: 12m 2s	remaining: 6h 8m 2s
356:	learn: 0.6040821	test: 0.7042423	best: 0.7042423 (356)	total: 12m 5s	remaining: 6h 8m 1s
357:	learn: 0.6039288	test: 0.7042317	best: 0.7

433:	learn: 0.5864811	test: 0.7037424	best: 0.7037382 (432)	total: 14m 43s	remaining: 6h 6m 12s
434:	learn: 0.5863260	test: 0.7037323	best: 0.7037323 (434)	total: 14m 44s	remaining: 6h 5m 59s
435:	learn: 0.5861228	test: 0.7037245	best: 0.7037245 (435)	total: 14m 46s	remaining: 6h 5m 55s
436:	learn: 0.5859373	test: 0.7037233	best: 0.7037233 (436)	total: 14m 48s	remaining: 6h 5m 53s
437:	learn: 0.5856350	test: 0.7037132	best: 0.7037132 (437)	total: 14m 50s	remaining: 6h 5m 50s
438:	learn: 0.5854567	test: 0.7037088	best: 0.7037088 (438)	total: 14m 52s	remaining: 6h 5m 44s
439:	learn: 0.5852292	test: 0.7036983	best: 0.7036983 (439)	total: 14m 54s	remaining: 6h 5m 45s
440:	learn: 0.5850188	test: 0.7037039	best: 0.7036983 (439)	total: 14m 57s	remaining: 6h 5m 52s
441:	learn: 0.5847589	test: 0.7036966	best: 0.7036966 (441)	total: 14m 59s	remaining: 6h 5m 55s
442:	learn: 0.5845157	test: 0.7036883	best: 0.7036883 (442)	total: 15m 1s	remaining: 6h 5m 52s
443:	learn: 0.5841590	test: 0.7036800	bes

519:	learn: 0.5671530	test: 0.7035163	best: 0.7034665 (482)	total: 17m 41s	remaining: 6h 4m 23s
520:	learn: 0.5669285	test: 0.7035127	best: 0.7034665 (482)	total: 17m 44s	remaining: 6h 4m 33s
521:	learn: 0.5667372	test: 0.7035086	best: 0.7034665 (482)	total: 17m 46s	remaining: 6h 4m 43s
522:	learn: 0.5665738	test: 0.7035156	best: 0.7034665 (482)	total: 17m 49s	remaining: 6h 4m 45s
523:	learn: 0.5664285	test: 0.7035130	best: 0.7034665 (482)	total: 17m 51s	remaining: 6h 4m 52s
524:	learn: 0.5661761	test: 0.7035158	best: 0.7034665 (482)	total: 17m 54s	remaining: 6h 5m 2s
525:	learn: 0.5660153	test: 0.7035133	best: 0.7034665 (482)	total: 17m 56s	remaining: 6h 5m 5s
526:	learn: 0.5657549	test: 0.7035080	best: 0.7034665 (482)	total: 17m 58s	remaining: 6h 5m 6s
527:	learn: 0.5654885	test: 0.7035222	best: 0.7034665 (482)	total: 18m 1s	remaining: 6h 5m 14s
528:	learn: 0.5652865	test: 0.7035264	best: 0.7034665 (482)	total: 18m 3s	remaining: 6h 5m 15s
529:	learn: 0.5650804	test: 0.7035262	best: 0

605:	learn: 0.5492005	test: 0.7034663	best: 0.7034526 (596)	total: 20m 39s	remaining: 6h 2m 7s
606:	learn: 0.5490425	test: 0.7034674	best: 0.7034526 (596)	total: 20m 41s	remaining: 6h 2m
607:	learn: 0.5488691	test: 0.7034687	best: 0.7034526 (596)	total: 20m 43s	remaining: 6h 1m 56s
608:	learn: 0.5486722	test: 0.7034648	best: 0.7034526 (596)	total: 20m 44s	remaining: 6h 1m 49s
609:	learn: 0.5484857	test: 0.7034652	best: 0.7034526 (596)	total: 20m 46s	remaining: 6h 1m 43s
610:	learn: 0.5483015	test: 0.7034653	best: 0.7034526 (596)	total: 20m 48s	remaining: 6h 1m 33s
611:	learn: 0.5480670	test: 0.7034701	best: 0.7034526 (596)	total: 20m 50s	remaining: 6h 1m 34s
612:	learn: 0.5477846	test: 0.7034847	best: 0.7034526 (596)	total: 20m 52s	remaining: 6h 1m 29s
613:	learn: 0.5475448	test: 0.7034840	best: 0.7034526 (596)	total: 20m 54s	remaining: 6h 1m 27s
614:	learn: 0.5473498	test: 0.7034837	best: 0.7034526 (596)	total: 20m 56s	remaining: 6h 1m 23s
615:	learn: 0.5471149	test: 0.7034805	best: 0

692:	learn: 0.5315609	test: 0.7034545	best: 0.7034345 (685)	total: 23m 49s	remaining: 6h 2m 22s
693:	learn: 0.5313596	test: 0.7034686	best: 0.7034345 (685)	total: 23m 52s	remaining: 6h 2m 26s
694:	learn: 0.5310869	test: 0.7034827	best: 0.7034345 (685)	total: 23m 54s	remaining: 6h 2m 29s
695:	learn: 0.5309722	test: 0.7034846	best: 0.7034345 (685)	total: 23m 57s	remaining: 6h 2m 29s
696:	learn: 0.5307867	test: 0.7034865	best: 0.7034345 (685)	total: 23m 59s	remaining: 6h 2m 30s
697:	learn: 0.5305550	test: 0.7034865	best: 0.7034345 (685)	total: 24m 1s	remaining: 6h 2m 34s
698:	learn: 0.5304211	test: 0.7034802	best: 0.7034345 (685)	total: 24m 4s	remaining: 6h 2m 46s
699:	learn: 0.5301973	test: 0.7034801	best: 0.7034345 (685)	total: 24m 6s	remaining: 6h 2m 40s
700:	learn: 0.5300468	test: 0.7034744	best: 0.7034345 (685)	total: 24m 8s	remaining: 6h 2m 43s
701:	learn: 0.5298231	test: 0.7034733	best: 0.7034345 (685)	total: 24m 11s	remaining: 6h 2m 42s
702:	learn: 0.5296578	test: 0.7034751	best: 

778:	learn: 0.5144791	test: 0.7035088	best: 0.7034345 (685)	total: 27m	remaining: 6h 2m 23s
779:	learn: 0.5142318	test: 0.7035042	best: 0.7034345 (685)	total: 27m 2s	remaining: 6h 2m 22s
780:	learn: 0.5140512	test: 0.7034964	best: 0.7034345 (685)	total: 27m 4s	remaining: 6h 2m 19s
781:	learn: 0.5138224	test: 0.7035060	best: 0.7034345 (685)	total: 27m 6s	remaining: 6h 2m 13s
782:	learn: 0.5136531	test: 0.7035038	best: 0.7034345 (685)	total: 27m 8s	remaining: 6h 2m 9s
783:	learn: 0.5135489	test: 0.7034987	best: 0.7034345 (685)	total: 27m 10s	remaining: 6h 2m 8s
784:	learn: 0.5133602	test: 0.7035023	best: 0.7034345 (685)	total: 27m 12s	remaining: 6h 2m 6s
785:	learn: 0.5131023	test: 0.7035172	best: 0.7034345 (685)	total: 27m 14s	remaining: 6h 2m 2s
Stopped by overfitting detector  (100 iterations wait)

bestTest = 0.7034344991
bestIteration = 685

Shrink model to first 686 iterations.


[32m[I 2022-10-03 14:58:01,514][0m Trial 0 finished with value: 0.7034344983038484 and parameters: {'learning_rate': 0.03523702825185894, 'n_estimators': 11230, 'max_depth': 16, 'subsample': 0.5113544318455522, 'colsample_bylevel': 0.7875155120093605, 'l2_leaf_reg': 6.692328199156744, 'max_bin': 223}. Best is trial 0 with value: 0.7034344983038484.[0m


In [14]:
best_model=study.user_attrs["best_booster"]
print(best_model.get_params())

{'loss_function': 'RMSE', 'learning_rate': 0.03523702825185894, 'l2_leaf_reg': 6.692328199156744, 'eval_metric': 'RMSE', 'subsample': 0.5113544318455522, 'max_depth': 16, 'n_estimators': 11230, 'colsample_bylevel': 0.7875155120093605, 'max_bin': 223}


In [15]:
best_model.save_model('catboost')

In [16]:
rmse = mean_squared_error(y_test, best_model.predict(X_test), squared=False)
print('RMSE: ', rmse)

RMSE:  0.7034344983038484


In [17]:
df_features_importance = pd.DataFrame(
    {'feature_names': best_model.feature_names_,
     'feature_importances': best_model.feature_importances_
    })

In [18]:
df_features_importance.sort_values('feature_importances', ascending=False)

Unnamed: 0,feature_names,feature_importances
13,cont14,8.751266
1,cont2,8.5271
3,cont4,8.242854
4,cont5,8.017134
7,cont8,7.61118
2,cont3,7.590197
12,cont13,7.412545
6,cont7,6.881245
0,cont1,6.743386
9,cont10,6.604544


In [23]:
y_val = best_model.predict(df_val)

In [24]:
df_t = df_val[['id']]
df_val.drop('id', axis=1, inplace=True)
df_t.insert(1, "target", y_val.tolist(), True)
df_t.to_csv("catboost_reg.csv", index=False)
### Score: 0.70447
### Public score: 0.70522

In [26]:
df_t.describe()

Unnamed: 0,id,target
count,200000.0,200000.0
mean,250261.031215,7.903851
std,144128.894365,0.210196
min,0.0,7.032549
25%,125538.25,7.760895
50%,250389.5,7.893861
75%,375240.25,8.035628
max,499990.0,9.36153
