In [136]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import os
import re
import sys
sys.path.append("./utils")

from generate_features import *
from get_targets import *
from dnn_utils import *

from functools import reduce

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from catboost import CatBoostClassifier

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [137]:
DATA_PATH = "./data/prices/technology_prices_15minute_train.csv"
TICKERS = ["AAPL", "MSFT", "GOOG", "TSLA", "NVDA"]
TARGET_TICKER = "AAPL"
ADDITIONAL_TICKERS = ["MSFT", "GOOG", "TSLA", "NVDA"]

In [138]:
data = pd.read_csv(DATA_PATH, index_col=0)
data.dropna(inplace=True)
data = data[[ticker in TICKERS for ticker in data["ticker"]]]
data

Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
0,25452.0,61.7090,61.750,61.3800,61.750,61.3800,2020-03-30 07:50:00,36,AAPL
1,28416.0,61.5333,61.750,61.5875,61.750,61.4225,2020-03-30 08:05:00,61,AAPL
2,17296.0,61.6269,61.540,61.7500,61.750,61.5250,2020-03-30 08:20:00,43,AAPL
3,6776.0,61.8962,61.675,61.9100,61.930,61.6750,2020-03-30 08:35:00,23,AAPL
4,17188.0,61.9149,62.000,61.7875,62.025,61.7875,2020-03-30 08:50:00,63,AAPL
...,...,...,...,...,...,...,...,...,...
131912,1630.0,196.5607,196.500,196.6000,196.600,196.5000,2021-08-11 22:30:00,18,NVDA
131913,1609.0,196.4820,196.500,196.3200,196.500,196.2100,2021-08-11 22:45:00,93,NVDA
131914,200.0,196.3200,196.320,196.3200,196.320,196.3200,2021-08-11 23:00:00,1,NVDA
131915,4302.0,196.3767,196.540,196.6000,196.600,196.1200,2021-08-11 23:30:00,52,NVDA


In [139]:
target_data = data[data["ticker"] == TARGET_TICKER]
additional_data = data[[ticker in ADDITIONAL_TICKERS for ticker in data["ticker"]]]
display(target_data)
display(additional_data)

Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
0,25452.0,61.7090,61.750,61.3800,61.750,61.3800,2020-03-30 07:50:00,36,AAPL
1,28416.0,61.5333,61.750,61.5875,61.750,61.4225,2020-03-30 08:05:00,61,AAPL
2,17296.0,61.6269,61.540,61.7500,61.750,61.5250,2020-03-30 08:20:00,43,AAPL
3,6776.0,61.8962,61.675,61.9100,61.930,61.6750,2020-03-30 08:35:00,23,AAPL
4,17188.0,61.9149,62.000,61.7875,62.025,61.7875,2020-03-30 08:50:00,63,AAPL
...,...,...,...,...,...,...,...,...,...
21910,5261.0,145.6801,145.700,145.6800,145.700,145.6700,2021-08-11 22:45:00,164,AAPL
21911,1284.0,145.6924,145.700,145.6800,145.700,145.6800,2021-08-11 23:00:00,19,AAPL
21912,3649.0,145.6700,145.670,145.6700,145.710,145.6600,2021-08-11 23:15:00,28,AAPL
21913,9026.0,145.6643,145.670,145.6500,145.690,145.6500,2021-08-11 23:30:00,64,AAPL


Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
31822,184.0,1119.9946,1120.0000,1120.00,1120.0000,1120.00,2020-03-30 11:21:00,4,GOOG
31823,997.0,1115.8603,1116.6600,1115.94,1116.6600,1115.94,2020-03-30 11:51:00,34,GOOG
31824,5765.0,1110.9206,1114.2000,1110.71,1114.2000,1110.71,2020-03-30 12:21:00,6,GOOG
31825,610.0,1124.6381,1122.3500,1127.00,1127.0000,1122.35,2020-03-30 12:36:00,49,GOOG
31826,673.0,1127.0979,1127.0002,1127.00,1127.0002,1127.00,2020-03-30 12:51:00,31,GOOG
...,...,...,...,...,...,...,...,...,...
131912,1630.0,196.5607,196.5000,196.60,196.6000,196.50,2021-08-11 22:30:00,18,NVDA
131913,1609.0,196.4820,196.5000,196.32,196.5000,196.21,2021-08-11 22:45:00,93,NVDA
131914,200.0,196.3200,196.3200,196.32,196.3200,196.32,2021-08-11 23:00:00,1,NVDA
131915,4302.0,196.3767,196.5400,196.60,196.6000,196.12,2021-08-11 23:30:00,52,NVDA


In [140]:
additional_data["ticker"].unique()

array(['GOOG', 'MSFT', 'TSLA', 'NVDA'], dtype=object)

In [141]:
all_data = simple_indicators(target_data).drop(["vw", "n", "ticker"], axis=1).set_index("t")
for ticker in additional_data["ticker"].unique():
    ticker_data = additional_data[additional_data["ticker"] == ticker]
    ticker_indicators = simple_indicators(ticker_data).drop(["vw", "ticker", "n"], axis=1).set_index("t")
    all_data = all_data.join(ticker_indicators, rsuffix=f"_{ticker}", how="left")

all_data.reset_index(drop=True, inplace=True)
all_data["target"] = get_direction_with_std_threshold_target(all_data, 10, 5, 5)
all_data.dropna(subset=["target"], inplace=True)

all_data

Unnamed: 0,volume,open,close,high,low,macd,macds,macdh,macd_xu_macds,macd_xd_macds,...,macds_NVDA,macdh_NVDA,macd_xu_macds_NVDA,macd_xd_macds_NVDA,boll_NVDA,boll_ub_NVDA,boll_lb_NVDA,high_x_boll_ub_NVDA,low_x_boll_lb_NVDA,target
0,25452.0,61.7500,61.3800,61.7500,61.3800,0.000000,0.000000,0.000000,False,False,...,,,,,,,,,,0.0
1,28416.0,61.7500,61.5875,61.7500,61.4225,0.004655,0.002586,0.002069,True,False,...,,,,,,,,,,0.0
2,17296.0,61.5400,61.7500,61.7500,61.5250,0.010971,0.006023,0.004948,False,False,...,,,,,,,,,,0.0
3,6776.0,61.6750,61.9100,61.9300,61.6750,0.019351,0.010538,0.008814,False,False,...,,,,,,,,,,0.0
4,17188.0,62.0000,61.7875,62.0250,61.7875,0.018623,0.012943,0.005680,False,False,...,,,,,,,,,,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21900,17731.0,145.7100,145.6800,145.7500,145.6600,-0.008033,0.051588,-0.059621,False,False,...,-0.318244,0.168126,False,False,196.31558,197.631764,194.999396,False,False,0.0
21901,10326.0,145.6800,145.7800,145.8000,145.6800,-0.017429,0.037784,-0.055214,False,False,...,-0.278440,0.159218,False,False,196.42158,197.571797,195.271363,False,False,0.0
21902,5767.0,145.7301,145.7000,145.7301,145.7000,-0.030974,0.024033,-0.055007,False,False,...,-0.242440,0.143998,False,False,196.51025,197.506493,195.514007,False,False,0.0
21903,38578.0,145.7100,145.7200,145.8600,145.7000,-0.039638,0.011299,-0.050936,False,False,...,-0.210799,0.126567,False,False,196.57975,197.460632,195.698868,False,False,0.0


In [142]:
all_data["target"].value_counts()

0.0    15026
1.0     3846
2.0     3033
Name: target, dtype: int64

In [143]:
TEST_RATIO = 0.2

X = all_data.drop(["target"], axis=1).to_numpy()
y = all_data["target"].to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_RATIO, shuffle=False)
X_train.shape, X_test.shape

((17524, 75), (4381, 75))

In [144]:
class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train) * np.array([2, 1, 1])

cb_classifier = CatBoostClassifier(class_weights=class_weight)
cb_classifier.fit(X_train, y_train)

Learning rate set to 0.091569
0:	learn: 1.0770121	total: 29.8ms	remaining: 29.7s
1:	learn: 1.0584038	total: 53.7ms	remaining: 26.8s
2:	learn: 1.0433872	total: 78.6ms	remaining: 26.1s
3:	learn: 1.0299652	total: 103ms	remaining: 25.7s
4:	learn: 1.0200339	total: 128ms	remaining: 25.4s
5:	learn: 1.0123764	total: 151ms	remaining: 25s
6:	learn: 1.0047761	total: 177ms	remaining: 25.1s
7:	learn: 0.9976585	total: 209ms	remaining: 26s
8:	learn: 0.9931574	total: 232ms	remaining: 25.5s
9:	learn: 0.9879755	total: 255ms	remaining: 25.3s
10:	learn: 0.9828330	total: 279ms	remaining: 25.1s
11:	learn: 0.9789774	total: 304ms	remaining: 25s
12:	learn: 0.9753150	total: 332ms	remaining: 25.2s
13:	learn: 0.9712033	total: 356ms	remaining: 25.1s
14:	learn: 0.9664689	total: 381ms	remaining: 25s
15:	learn: 0.9639771	total: 404ms	remaining: 24.9s
16:	learn: 0.9605984	total: 432ms	remaining: 25s
17:	learn: 0.9570530	total: 458ms	remaining: 25s
18:	learn: 0.9544889	total: 486ms	remaining: 25.1s
19:	learn: 0.9524670

164:	learn: 0.7758791	total: 4.38s	remaining: 22.2s
165:	learn: 0.7747200	total: 4.42s	remaining: 22.2s
166:	learn: 0.7735424	total: 4.45s	remaining: 22.2s
167:	learn: 0.7721987	total: 4.48s	remaining: 22.2s
168:	learn: 0.7711158	total: 4.51s	remaining: 22.2s
169:	learn: 0.7705075	total: 4.54s	remaining: 22.2s
170:	learn: 0.7694538	total: 4.57s	remaining: 22.2s
171:	learn: 0.7689290	total: 4.6s	remaining: 22.2s
172:	learn: 0.7676630	total: 4.63s	remaining: 22.1s
173:	learn: 0.7666502	total: 4.65s	remaining: 22.1s
174:	learn: 0.7658229	total: 4.68s	remaining: 22.1s
175:	learn: 0.7644634	total: 4.71s	remaining: 22.1s
176:	learn: 0.7634568	total: 4.74s	remaining: 22s
177:	learn: 0.7626332	total: 4.76s	remaining: 22s
178:	learn: 0.7620938	total: 4.79s	remaining: 22s
179:	learn: 0.7612988	total: 4.81s	remaining: 21.9s
180:	learn: 0.7600412	total: 4.84s	remaining: 21.9s
181:	learn: 0.7594472	total: 4.86s	remaining: 21.9s
182:	learn: 0.7587429	total: 4.89s	remaining: 21.8s
183:	learn: 0.75732

327:	learn: 0.6557324	total: 8.74s	remaining: 17.9s
328:	learn: 0.6553184	total: 8.77s	remaining: 17.9s
329:	learn: 0.6549060	total: 8.8s	remaining: 17.9s
330:	learn: 0.6542245	total: 8.83s	remaining: 17.8s
331:	learn: 0.6534488	total: 8.86s	remaining: 17.8s
332:	learn: 0.6526488	total: 8.89s	remaining: 17.8s
333:	learn: 0.6521022	total: 8.91s	remaining: 17.8s
334:	learn: 0.6513885	total: 8.94s	remaining: 17.8s
335:	learn: 0.6505825	total: 8.97s	remaining: 17.7s
336:	learn: 0.6496601	total: 9s	remaining: 17.7s
337:	learn: 0.6492996	total: 9.03s	remaining: 17.7s
338:	learn: 0.6487214	total: 9.05s	remaining: 17.7s
339:	learn: 0.6481636	total: 9.08s	remaining: 17.6s
340:	learn: 0.6475790	total: 9.11s	remaining: 17.6s
341:	learn: 0.6469931	total: 9.14s	remaining: 17.6s
342:	learn: 0.6464088	total: 9.16s	remaining: 17.6s
343:	learn: 0.6457414	total: 9.19s	remaining: 17.5s
344:	learn: 0.6453516	total: 9.22s	remaining: 17.5s
345:	learn: 0.6444775	total: 9.25s	remaining: 17.5s
346:	learn: 0.64

488:	learn: 0.5755479	total: 13.1s	remaining: 13.7s
489:	learn: 0.5753131	total: 13.2s	remaining: 13.7s
490:	learn: 0.5749511	total: 13.2s	remaining: 13.7s
491:	learn: 0.5745971	total: 13.2s	remaining: 13.7s
492:	learn: 0.5741563	total: 13.3s	remaining: 13.6s
493:	learn: 0.5736455	total: 13.3s	remaining: 13.6s
494:	learn: 0.5732531	total: 13.3s	remaining: 13.6s
495:	learn: 0.5729489	total: 13.3s	remaining: 13.6s
496:	learn: 0.5726522	total: 13.4s	remaining: 13.5s
497:	learn: 0.5718551	total: 13.4s	remaining: 13.5s
498:	learn: 0.5712054	total: 13.4s	remaining: 13.5s
499:	learn: 0.5706568	total: 13.5s	remaining: 13.5s
500:	learn: 0.5703955	total: 13.5s	remaining: 13.4s
501:	learn: 0.5698691	total: 13.5s	remaining: 13.4s
502:	learn: 0.5694974	total: 13.5s	remaining: 13.4s
503:	learn: 0.5690512	total: 13.6s	remaining: 13.4s
504:	learn: 0.5686997	total: 13.6s	remaining: 13.3s
505:	learn: 0.5682935	total: 13.6s	remaining: 13.3s
506:	learn: 0.5679099	total: 13.7s	remaining: 13.3s
507:	learn: 

648:	learn: 0.5167084	total: 18s	remaining: 9.75s
649:	learn: 0.5164473	total: 18.1s	remaining: 9.73s
650:	learn: 0.5161650	total: 18.1s	remaining: 9.71s
651:	learn: 0.5156350	total: 18.1s	remaining: 9.68s
652:	learn: 0.5152858	total: 18.2s	remaining: 9.65s
653:	learn: 0.5148927	total: 18.2s	remaining: 9.62s
654:	learn: 0.5144110	total: 18.2s	remaining: 9.6s
655:	learn: 0.5141652	total: 18.2s	remaining: 9.57s
656:	learn: 0.5139511	total: 18.3s	remaining: 9.54s
657:	learn: 0.5136681	total: 18.3s	remaining: 9.51s
658:	learn: 0.5132507	total: 18.3s	remaining: 9.48s
659:	learn: 0.5126916	total: 18.3s	remaining: 9.45s
660:	learn: 0.5125006	total: 18.4s	remaining: 9.42s
661:	learn: 0.5122002	total: 18.4s	remaining: 9.39s
662:	learn: 0.5116691	total: 18.4s	remaining: 9.37s
663:	learn: 0.5113748	total: 18.5s	remaining: 9.34s
664:	learn: 0.5110871	total: 18.5s	remaining: 9.32s
665:	learn: 0.5108194	total: 18.5s	remaining: 9.29s
666:	learn: 0.5105166	total: 18.6s	remaining: 9.27s
667:	learn: 0.5

809:	learn: 0.4706430	total: 23.5s	remaining: 5.51s
810:	learn: 0.4703991	total: 23.5s	remaining: 5.48s
811:	learn: 0.4701129	total: 23.5s	remaining: 5.45s
812:	learn: 0.4698631	total: 23.6s	remaining: 5.42s
813:	learn: 0.4696220	total: 23.6s	remaining: 5.39s
814:	learn: 0.4691751	total: 23.6s	remaining: 5.36s
815:	learn: 0.4687257	total: 23.6s	remaining: 5.33s
816:	learn: 0.4686347	total: 23.7s	remaining: 5.3s
817:	learn: 0.4684303	total: 23.7s	remaining: 5.27s
818:	learn: 0.4682620	total: 23.7s	remaining: 5.24s
819:	learn: 0.4680059	total: 23.8s	remaining: 5.21s
820:	learn: 0.4678430	total: 23.8s	remaining: 5.18s
821:	learn: 0.4675830	total: 23.8s	remaining: 5.16s
822:	learn: 0.4674411	total: 23.8s	remaining: 5.13s
823:	learn: 0.4672893	total: 23.9s	remaining: 5.1s
824:	learn: 0.4670924	total: 23.9s	remaining: 5.07s
825:	learn: 0.4668470	total: 23.9s	remaining: 5.04s
826:	learn: 0.4664243	total: 23.9s	remaining: 5.01s
827:	learn: 0.4660807	total: 24s	remaining: 4.98s
828:	learn: 0.46

974:	learn: 0.4324806	total: 29.3s	remaining: 752ms
975:	learn: 0.4323275	total: 29.3s	remaining: 722ms
976:	learn: 0.4321239	total: 29.4s	remaining: 692ms
977:	learn: 0.4319150	total: 29.4s	remaining: 661ms
978:	learn: 0.4316529	total: 29.4s	remaining: 631ms
979:	learn: 0.4314597	total: 29.5s	remaining: 601ms
980:	learn: 0.4313173	total: 29.5s	remaining: 571ms
981:	learn: 0.4311972	total: 29.5s	remaining: 541ms
982:	learn: 0.4309876	total: 29.6s	remaining: 511ms
983:	learn: 0.4308178	total: 29.6s	remaining: 481ms
984:	learn: 0.4306332	total: 29.6s	remaining: 451ms
985:	learn: 0.4304878	total: 29.7s	remaining: 421ms
986:	learn: 0.4303096	total: 29.7s	remaining: 391ms
987:	learn: 0.4299529	total: 29.8s	remaining: 361ms
988:	learn: 0.4295879	total: 29.8s	remaining: 331ms
989:	learn: 0.4294430	total: 29.8s	remaining: 301ms
990:	learn: 0.4291339	total: 29.9s	remaining: 271ms
991:	learn: 0.4290416	total: 29.9s	remaining: 241ms
992:	learn: 0.4287776	total: 30s	remaining: 211ms
993:	learn: 0.

<catboost.core.CatBoostClassifier at 0x1db0c25c2b0>

In [145]:
predictions = cb_classifier.predict(X_test)
probas = cb_classifier.predict_proba(X_test)
confusion_matrix(y_test, predictions)

array([[2366,  205,  493],
       [ 367,  134,  225],
       [ 301,  118,  172]], dtype=int64)

In [146]:
f1_score(y_test, predictions, average="weighted")

0.6115920309952221

Посмотрим на важности признаков

In [147]:
a = [(score, col) for score, col in zip(cb_classifier.get_feature_importance(), all_data.drop(["target"], axis=1).columns)]
a.sort(reverse=True)
a

[(8.227388720240922, 'macdh'),
 (8.18537333329686, 'volume'),
 (6.494151646826219, 'macds'),
 (6.353197242691733, 'macd'),
 (3.5990901622731544, 'macds_TSLA'),
 (3.390193858786731, 'macdh_TSLA'),
 (3.315007688658888, 'macdh_MSFT'),
 (3.251034706261007, 'boll_ub'),
 (3.1598793895471293, 'macd_TSLA'),
 (3.0836244156734867, 'macds_MSFT'),
 (2.7537426246297727, 'macdh_NVDA'),
 (2.6528709438669367, 'boll_lb'),
 (2.5274857618545377, 'volume_MSFT'),
 (2.476629442428361, 'volume_TSLA'),
 (2.300796781074994, 'macds_NVDA'),
 (2.0704533991592147, 'macd_MSFT'),
 (2.061747898696431, 'boll'),
 (1.8994181189335038, 'boll_ub_MSFT'),
 (1.7001651041207497, 'macd_NVDA'),
 (1.5777768478475078, 'high'),
 (1.5215942755850975, 'boll_lb_NVDA'),
 (1.4675451496962721, 'volume_NVDA'),
 (1.4335322665779129, 'boll_lb_TSLA'),
 (1.4141856932218415, 'boll_lb_MSFT'),
 (1.2905304639416417, 'low'),
 (1.224073239015579, 'boll_ub_TSLA'),
 (1.178239262456766, 'boll_ub_NVDA'),
 (1.0769506803984659, 'close'),
 (1.04267287579

# Сравним результаты без использования коррелированных акций

In [152]:
one_stock_data = simple_indicators(target_data.reset_index(drop=True)).drop(["vw", "t", "ticker", "n"], axis=1)
one_stock_data["target"] = get_direction_with_std_threshold_target(one_stock_data, 10, 5, 5)
one_stock_data.dropna(subset=["target"], inplace=True)

one_stock_data

Unnamed: 0,volume,open,close,high,low,macd,macds,macdh,macd_xu_macds,macd_xd_macds,boll,boll_ub,boll_lb,high_x_boll_ub,low_x_boll_lb,target
0,25452.0,61.7500,61.3800,61.7500,61.3800,0.000000,0.000000,0.000000,False,False,61.380000,,,False,False,0.0
1,28416.0,61.7500,61.5875,61.7500,61.4225,0.004655,0.002586,0.002069,True,False,61.483750,61.777199,61.190301,False,True,0.0
2,17296.0,61.5400,61.7500,61.7500,61.5250,0.010971,0.006023,0.004948,False,False,61.572500,61.943411,61.201589,False,False,0.0
3,6776.0,61.6750,61.9100,61.9300,61.6750,0.019351,0.010538,0.008814,False,False,61.656875,62.110332,61.203418,False,False,0.0
4,17188.0,62.0000,61.7875,62.0250,61.7875,0.018623,0.012943,0.005680,False,False,61.683000,62.092716,61.273284,False,False,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21900,17731.0,145.7100,145.6800,145.7500,145.6600,-0.008033,0.051588,-0.059621,False,False,146.030925,146.464263,145.597587,False,False,0.0
21901,10326.0,145.6800,145.7800,145.8000,145.6800,-0.017429,0.037784,-0.055214,False,False,146.028425,146.467187,145.589663,False,False,0.0
21902,5767.0,145.7301,145.7000,145.7301,145.7000,-0.030974,0.024033,-0.055007,False,False,146.018925,146.478061,145.559789,False,False,0.0
21903,38578.0,145.7100,145.7200,145.8600,145.7000,-0.039638,0.011299,-0.050936,False,False,145.993175,146.459004,145.527346,False,False,0.0


In [153]:
TEST_RATIO = 0.2

X = one_stock_data.drop(["target"], axis=1).to_numpy()
y = one_stock_data["target"].to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_RATIO, shuffle=False)
X_train.shape, X_test.shape

((17524, 15), (4381, 15))

In [154]:
class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train) * np.array([2, 1, 1])

cb_classifier = CatBoostClassifier(class_weights=class_weight)
cb_classifier.fit(X_train, y_train)

Learning rate set to 0.091569
0:	learn: 1.0768078	total: 10.8ms	remaining: 10.8s
1:	learn: 1.0573511	total: 21ms	remaining: 10.5s
2:	learn: 1.0423452	total: 29.9ms	remaining: 9.93s
3:	learn: 1.0300260	total: 37.7ms	remaining: 9.38s
4:	learn: 1.0186388	total: 45ms	remaining: 8.96s
5:	learn: 1.0097189	total: 53ms	remaining: 8.78s
6:	learn: 1.0022105	total: 61.3ms	remaining: 8.69s
7:	learn: 0.9960799	total: 69ms	remaining: 8.56s
8:	learn: 0.9893197	total: 76.7ms	remaining: 8.44s
9:	learn: 0.9836079	total: 85.4ms	remaining: 8.45s
10:	learn: 0.9787778	total: 95.5ms	remaining: 8.58s
11:	learn: 0.9746735	total: 105ms	remaining: 8.66s
12:	learn: 0.9699945	total: 115ms	remaining: 8.7s
13:	learn: 0.9665636	total: 123ms	remaining: 8.69s
14:	learn: 0.9632594	total: 132ms	remaining: 8.69s
15:	learn: 0.9601537	total: 141ms	remaining: 8.68s
16:	learn: 0.9571205	total: 149ms	remaining: 8.64s
17:	learn: 0.9543257	total: 158ms	remaining: 8.59s
18:	learn: 0.9517313	total: 166ms	remaining: 8.56s
19:	learn

169:	learn: 0.8309173	total: 1.47s	remaining: 7.17s
170:	learn: 0.8304171	total: 1.48s	remaining: 7.16s
171:	learn: 0.8297807	total: 1.49s	remaining: 7.15s
172:	learn: 0.8290897	total: 1.49s	remaining: 7.14s
173:	learn: 0.8287033	total: 1.5s	remaining: 7.14s
174:	learn: 0.8282716	total: 1.51s	remaining: 7.13s
175:	learn: 0.8277677	total: 1.52s	remaining: 7.13s
176:	learn: 0.8266429	total: 1.53s	remaining: 7.13s
177:	learn: 0.8257925	total: 1.54s	remaining: 7.12s
178:	learn: 0.8253793	total: 1.55s	remaining: 7.11s
179:	learn: 0.8246512	total: 1.56s	remaining: 7.1s
180:	learn: 0.8236866	total: 1.57s	remaining: 7.1s
181:	learn: 0.8229185	total: 1.58s	remaining: 7.09s
182:	learn: 0.8222056	total: 1.59s	remaining: 7.08s
183:	learn: 0.8217810	total: 1.59s	remaining: 7.08s
184:	learn: 0.8205300	total: 1.6s	remaining: 7.07s
185:	learn: 0.8200988	total: 1.61s	remaining: 7.06s
186:	learn: 0.8192401	total: 1.62s	remaining: 7.05s
187:	learn: 0.8183266	total: 1.63s	remaining: 7.05s
188:	learn: 0.81

335:	learn: 0.7377927	total: 2.97s	remaining: 5.86s
336:	learn: 0.7375809	total: 2.97s	remaining: 5.85s
337:	learn: 0.7371842	total: 2.98s	remaining: 5.84s
338:	learn: 0.7365559	total: 2.99s	remaining: 5.83s
339:	learn: 0.7362026	total: 3s	remaining: 5.82s
340:	learn: 0.7355945	total: 3.01s	remaining: 5.81s
341:	learn: 0.7352870	total: 3.02s	remaining: 5.8s
342:	learn: 0.7349075	total: 3.03s	remaining: 5.8s
343:	learn: 0.7343253	total: 3.04s	remaining: 5.79s
344:	learn: 0.7338834	total: 3.05s	remaining: 5.78s
345:	learn: 0.7335542	total: 3.06s	remaining: 5.78s
346:	learn: 0.7328651	total: 3.06s	remaining: 5.77s
347:	learn: 0.7323498	total: 3.08s	remaining: 5.76s
348:	learn: 0.7319513	total: 3.08s	remaining: 5.75s
349:	learn: 0.7316238	total: 3.09s	remaining: 5.75s
350:	learn: 0.7311862	total: 3.1s	remaining: 5.74s
351:	learn: 0.7306496	total: 3.11s	remaining: 5.73s
352:	learn: 0.7300971	total: 3.12s	remaining: 5.72s
353:	learn: 0.7296416	total: 3.13s	remaining: 5.71s
354:	learn: 0.7294

507:	learn: 0.6726287	total: 4.47s	remaining: 4.33s
508:	learn: 0.6724199	total: 4.48s	remaining: 4.32s
509:	learn: 0.6722120	total: 4.49s	remaining: 4.32s
510:	learn: 0.6718470	total: 4.5s	remaining: 4.31s
511:	learn: 0.6712808	total: 4.51s	remaining: 4.3s
512:	learn: 0.6709077	total: 4.53s	remaining: 4.3s
513:	learn: 0.6706091	total: 4.54s	remaining: 4.29s
514:	learn: 0.6702329	total: 4.55s	remaining: 4.29s
515:	learn: 0.6698935	total: 4.56s	remaining: 4.28s
516:	learn: 0.6697666	total: 4.58s	remaining: 4.27s
517:	learn: 0.6694766	total: 4.58s	remaining: 4.27s
518:	learn: 0.6692185	total: 4.59s	remaining: 4.26s
519:	learn: 0.6687540	total: 4.61s	remaining: 4.25s
520:	learn: 0.6684317	total: 4.62s	remaining: 4.25s
521:	learn: 0.6679766	total: 4.63s	remaining: 4.24s
522:	learn: 0.6676680	total: 4.64s	remaining: 4.23s
523:	learn: 0.6672278	total: 4.66s	remaining: 4.23s
524:	learn: 0.6667756	total: 4.67s	remaining: 4.22s
525:	learn: 0.6664368	total: 4.68s	remaining: 4.22s
526:	learn: 0.6

677:	learn: 0.6190018	total: 7.1s	remaining: 3.37s
678:	learn: 0.6188389	total: 7.12s	remaining: 3.36s
679:	learn: 0.6183278	total: 7.13s	remaining: 3.35s
680:	learn: 0.6180062	total: 7.14s	remaining: 3.35s
681:	learn: 0.6176359	total: 7.15s	remaining: 3.33s
682:	learn: 0.6173314	total: 7.16s	remaining: 3.33s
683:	learn: 0.6170876	total: 7.17s	remaining: 3.31s
684:	learn: 0.6169208	total: 7.18s	remaining: 3.3s
685:	learn: 0.6168230	total: 7.2s	remaining: 3.29s
686:	learn: 0.6166156	total: 7.21s	remaining: 3.28s
687:	learn: 0.6164179	total: 7.22s	remaining: 3.27s
688:	learn: 0.6160773	total: 7.23s	remaining: 3.26s
689:	learn: 0.6157920	total: 7.24s	remaining: 3.25s
690:	learn: 0.6155728	total: 7.25s	remaining: 3.24s
691:	learn: 0.6152350	total: 7.26s	remaining: 3.23s
692:	learn: 0.6149318	total: 7.27s	remaining: 3.22s
693:	learn: 0.6146675	total: 7.29s	remaining: 3.21s
694:	learn: 0.6144605	total: 7.3s	remaining: 3.2s
695:	learn: 0.6142170	total: 7.31s	remaining: 3.19s
696:	learn: 0.614

845:	learn: 0.5786218	total: 8.58s	remaining: 1.56s
846:	learn: 0.5784403	total: 8.59s	remaining: 1.55s
847:	learn: 0.5782632	total: 8.6s	remaining: 1.54s
848:	learn: 0.5780414	total: 8.61s	remaining: 1.53s
849:	learn: 0.5779291	total: 8.61s	remaining: 1.52s
850:	learn: 0.5777859	total: 8.62s	remaining: 1.51s
851:	learn: 0.5775732	total: 8.63s	remaining: 1.5s
852:	learn: 0.5774523	total: 8.64s	remaining: 1.49s
853:	learn: 0.5770415	total: 8.65s	remaining: 1.48s
854:	learn: 0.5768194	total: 8.66s	remaining: 1.47s
855:	learn: 0.5765253	total: 8.67s	remaining: 1.46s
856:	learn: 0.5762916	total: 8.68s	remaining: 1.45s
857:	learn: 0.5760563	total: 8.68s	remaining: 1.44s
858:	learn: 0.5757181	total: 8.69s	remaining: 1.43s
859:	learn: 0.5754368	total: 8.7s	remaining: 1.42s
860:	learn: 0.5752452	total: 8.71s	remaining: 1.41s
861:	learn: 0.5749826	total: 8.72s	remaining: 1.4s
862:	learn: 0.5748287	total: 8.73s	remaining: 1.39s
863:	learn: 0.5746905	total: 8.74s	remaining: 1.38s
864:	learn: 0.57

<catboost.core.CatBoostClassifier at 0x1db0b5db460>

In [155]:
predictions = cb_classifier.predict(X_test)
probas = cb_classifier.predict_proba(X_test)
confusion_matrix(y_test, predictions)

array([[2230,  286,  548],
       [ 403,  153,  170],
       [ 318,  107,  166]], dtype=int64)

In [156]:
f1_score(y_test, predictions, average="weighted")

0.588808406180549