In [121]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import os
import re
import sys
sys.path.append("./utils")

from generate_features import *
from get_targets import *
from dnn_utils import *

from functools import reduce

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from catboost import CatBoostClassifier

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [74]:
DATA_PATH = "technology_1m_prices.csv"
TICKERS = ["AAPL", "MSFT", "GOOG", "TSLA", "NVDA"]
TARGET_TICKER = "AAPL"
ADDITIONAL_TICKERS = ["MSFT", "GOOG", "TSLA", "NVDA"]

In [75]:
data = pd.read_csv(DATA_PATH, index_col=0)
data.dropna(inplace=True)
data = data[[ticker in TICKERS for ticker in data["ticker"]]]
data

Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
0,1338056.0,58.2800,57.8425,58.6775,58.875,57.7000,2020-03-23 12:32:00,4458.0,AAPL
1,13767216.0,56.8116,58.6875,55.7975,58.875,55.5875,2020-03-23 13:02:00,24455.0,AAPL
2,41790168.0,55.7332,55.8009,55.2800,56.750,54.5000,2020-03-23 13:32:00,119429.0,AAPL
3,24935492.0,55.5366,55.2850,55.5364,56.170,55.0000,2020-03-23 14:02:00,69243.0,AAPL
4,20382612.0,55.3185,55.5625,55.0238,55.890,54.6500,2020-03-23 14:32:00,58723.0,AAPL
...,...,...,...,...,...,...,...,...,...
1850,61601.0,233.6102,234.6100,233.0100,234.620,232.9100,2022-02-18 22:30:00,1656.0,NVDA
1851,55422.0,232.5017,233.0000,232.2201,233.190,232.0100,2022-02-18 23:00:00,1049.0,NVDA
1852,52479.0,231.9978,232.2100,232.0700,232.310,231.7600,2022-02-18 23:30:00,1036.0,NVDA
1853,27356.0,232.2012,232.1000,231.9300,232.450,231.9000,2022-02-19 00:00:00,671.0,NVDA


In [76]:
target_data = data[data["ticker"] == TARGET_TICKER].set_index("t").drop(["ticker"], axis=1)
additional_data = data[[ticker in ADDITIONAL_TICKERS for ticker in data["ticker"]]].drop(["vw"], axis=1)
display(target_data)
display(additional_data)

Unnamed: 0_level_0,v,vw,o,c,h,l,n
t,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2020-03-23 12:32:00,1338056.0,58.2800,57.8425,58.6775,58.875,57.7000,4458.0
2020-03-23 13:02:00,13767216.0,56.8116,58.6875,55.7975,58.875,55.5875,24455.0
2020-03-23 13:32:00,41790168.0,55.7332,55.8009,55.2800,56.750,54.5000,119429.0
2020-03-23 14:02:00,24935492.0,55.5366,55.2850,55.5364,56.170,55.0000,69243.0
2020-03-23 14:32:00,20382612.0,55.3185,55.5625,55.0238,55.890,54.6500,58723.0
...,...,...,...,...,...,...,...
2022-02-18 22:30:00,42693.0,166.9412,167.0400,166.8000,167.040,166.8000,961.0
2022-02-18 23:00:00,29043.0,166.6105,166.8100,166.4200,166.880,166.4200,555.0
2022-02-18 23:30:00,23891.0,166.3759,166.4100,166.4700,166.500,166.3200,375.0
2022-02-19 00:00:00,19223.0,166.4545,166.4900,166.4400,166.530,166.3800,301.0


Unnamed: 0,v,o,c,h,l,t,n,ticker
0,204075.0,139.25,141.7500,142.4500,138.30,2020-03-23 12:33:00,2320.0,MSFT
1,4234855.0,141.61,137.4600,142.4000,134.88,2020-03-23 13:03:00,32231.0,MSFT
2,8973373.0,137.45,138.6600,139.1957,134.52,2020-03-23 13:33:00,90313.0,MSFT
3,5848586.0,138.70,137.7100,140.5700,137.18,2020-03-23 14:03:00,55368.0,MSFT
4,3906426.0,137.71,135.6349,138.6700,135.56,2020-03-23 14:33:00,42775.0,MSFT
...,...,...,...,...,...,...,...,...
1850,61601.0,234.61,233.0100,234.6200,232.91,2022-02-18 22:30:00,1656.0,NVDA
1851,55422.0,233.00,232.2201,233.1900,232.01,2022-02-18 23:00:00,1049.0,NVDA
1852,52479.0,232.21,232.0700,232.3100,231.76,2022-02-18 23:30:00,1036.0,NVDA
1853,27356.0,232.10,231.9300,232.4500,231.90,2022-02-19 00:00:00,671.0,NVDA


In [97]:
all_data = target_data
for ticker in additional_data["ticker"].unique():
    ticker_data = additional_data[additional_data["ticker"] == ticker].set_index("t").drop(["ticker"], axis=1)
    all_data = all_data.join(ticker_data, rsuffix=f"_{ticker}", how="left")

all_data = simple_indicators(all_data.reset_index())
all_data["target"] = get_direction_with_std_threshold_target(all_data, 10, 5, 5)
all_data.dropna(subset=["target"], inplace=True)

all_data

Unnamed: 0,volume,open,close,high,low,v_MSFT,o_MSFT,c_MSFT,h_MSFT,l_MSFT,...,macds,macdh,macd_xu_macds,macd_xd_macds,boll,boll_ub,boll_lb,high_x_boll_ub,low_x_boll_lb,target
0,1338056.0,57.8425,58.6775,58.875,57.7000,,,,,,...,0.000000,0.000000,False,False,58.677500,,,False,False,0.0
1,13767216.0,58.6875,55.7975,58.875,55.5875,,,,,,...,-0.035897,-0.028718,False,False,57.237500,61.310435,53.164565,False,True,0.0
2,41790168.0,55.8009,55.2800,56.750,54.5000,,,,,,...,-0.061505,-0.036875,False,False,56.585000,60.246076,52.923924,False,False,0.0
3,24935492.0,55.2850,55.5364,56.170,55.0000,,,,,,...,-0.075246,-0.026823,False,False,56.322850,59.490690,53.155010,False,False,0.0
4,20382612.0,55.5625,55.0238,55.890,54.6500,,,,,,...,-0.089250,-0.033072,False,False,56.063040,59.042374,53.083706,False,False,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15373,5434352.0,167.6900,166.6700,167.810,166.6300,2200209.0,289.2150,287.4701,289.6200,287.440,...,-0.467370,-0.224737,False,False,169.019525,171.020273,167.018777,False,False,0.0
15374,4383547.0,166.6700,166.3500,166.780,166.1900,1596211.0,287.5066,286.8199,287.7634,286.305,...,-0.533024,-0.262617,False,False,168.888525,171.218870,166.558180,False,False,0.0
15375,3898394.0,166.3400,167.0700,167.300,166.2995,1393481.0,286.7900,287.6300,287.9850,286.660,...,-0.588470,-0.221785,False,False,168.800025,171.268476,166.331574,False,False,0.0
15376,3976796.0,167.0600,168.0400,168.110,167.0300,1491604.0,287.6200,289.4250,289.5500,287.570,...,-0.617795,-0.117298,False,False,168.714525,171.162808,166.266242,False,True,0.0


In [98]:
all_data["target"].value_counts()

0.0    10423
1.0     2824
2.0     2131
Name: target, dtype: int64

In [140]:
TEST_RATIO = 0.2

X = all_data.drop(["target"], axis=1).to_numpy()
y = all_data["target"].to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_RATIO, shuffle=False)
X_train.shape, X_test.shape

((12302, 39), (3076, 39))

In [130]:
class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train) * np.array([2, 1, 1])

cb_classifier = CatBoostClassifier(class_weights=class_weight)
cb_classifier.fit(X_train, y_train)

Learning rate set to 0.089932
0:	learn: 1.0755774	total: 17ms	remaining: 17s
1:	learn: 1.0572714	total: 33ms	remaining: 16.5s
2:	learn: 1.0430949	total: 48ms	remaining: 15.9s
3:	learn: 1.0293306	total: 62.9ms	remaining: 15.7s
4:	learn: 1.0186597	total: 77.8ms	remaining: 15.5s
5:	learn: 1.0093664	total: 93.9ms	remaining: 15.6s
6:	learn: 1.0019055	total: 109ms	remaining: 15.4s
7:	learn: 0.9940935	total: 124ms	remaining: 15.4s
8:	learn: 0.9878541	total: 139ms	remaining: 15.3s
9:	learn: 0.9823344	total: 156ms	remaining: 15.4s
10:	learn: 0.9757266	total: 171ms	remaining: 15.4s
11:	learn: 0.9718653	total: 188ms	remaining: 15.5s
12:	learn: 0.9675535	total: 206ms	remaining: 15.6s
13:	learn: 0.9639099	total: 223ms	remaining: 15.7s
14:	learn: 0.9604773	total: 238ms	remaining: 15.6s
15:	learn: 0.9570834	total: 253ms	remaining: 15.5s
16:	learn: 0.9546924	total: 268ms	remaining: 15.5s
17:	learn: 0.9513751	total: 284ms	remaining: 15.5s
18:	learn: 0.9488283	total: 300ms	remaining: 15.5s
19:	learn: 0.

161:	learn: 0.7871765	total: 2.67s	remaining: 13.8s
162:	learn: 0.7865666	total: 2.69s	remaining: 13.8s
163:	learn: 0.7856865	total: 2.7s	remaining: 13.8s
164:	learn: 0.7848582	total: 2.72s	remaining: 13.8s
165:	learn: 0.7837496	total: 2.74s	remaining: 13.8s
166:	learn: 0.7825548	total: 2.76s	remaining: 13.8s
167:	learn: 0.7812720	total: 2.78s	remaining: 13.8s
168:	learn: 0.7808622	total: 2.8s	remaining: 13.8s
169:	learn: 0.7796472	total: 2.82s	remaining: 13.7s
170:	learn: 0.7785988	total: 2.83s	remaining: 13.7s
171:	learn: 0.7773232	total: 2.85s	remaining: 13.7s
172:	learn: 0.7768209	total: 2.87s	remaining: 13.7s
173:	learn: 0.7760795	total: 2.89s	remaining: 13.7s
174:	learn: 0.7753085	total: 2.91s	remaining: 13.7s
175:	learn: 0.7749184	total: 2.93s	remaining: 13.7s
176:	learn: 0.7741882	total: 2.94s	remaining: 13.7s
177:	learn: 0.7729316	total: 2.96s	remaining: 13.7s
178:	learn: 0.7722909	total: 2.98s	remaining: 13.7s
179:	learn: 0.7710506	total: 2.99s	remaining: 13.6s
180:	learn: 0.

324:	learn: 0.6683767	total: 5.51s	remaining: 11.5s
325:	learn: 0.6678579	total: 5.53s	remaining: 11.4s
326:	learn: 0.6672388	total: 5.55s	remaining: 11.4s
327:	learn: 0.6666996	total: 5.57s	remaining: 11.4s
328:	learn: 0.6658427	total: 5.59s	remaining: 11.4s
329:	learn: 0.6651842	total: 5.61s	remaining: 11.4s
330:	learn: 0.6648213	total: 5.62s	remaining: 11.4s
331:	learn: 0.6645375	total: 5.64s	remaining: 11.3s
332:	learn: 0.6637926	total: 5.66s	remaining: 11.3s
333:	learn: 0.6634679	total: 5.67s	remaining: 11.3s
334:	learn: 0.6626141	total: 5.69s	remaining: 11.3s
335:	learn: 0.6623273	total: 5.71s	remaining: 11.3s
336:	learn: 0.6617732	total: 5.72s	remaining: 11.3s
337:	learn: 0.6610356	total: 5.74s	remaining: 11.2s
338:	learn: 0.6606590	total: 5.76s	remaining: 11.2s
339:	learn: 0.6603492	total: 5.78s	remaining: 11.2s
340:	learn: 0.6596822	total: 5.79s	remaining: 11.2s
341:	learn: 0.6589623	total: 5.81s	remaining: 11.2s
342:	learn: 0.6585324	total: 5.83s	remaining: 11.2s
343:	learn: 

490:	learn: 0.5866924	total: 8.4s	remaining: 8.71s
491:	learn: 0.5863016	total: 8.42s	remaining: 8.69s
492:	learn: 0.5859331	total: 8.44s	remaining: 8.68s
493:	learn: 0.5855015	total: 8.45s	remaining: 8.66s
494:	learn: 0.5853124	total: 8.47s	remaining: 8.64s
495:	learn: 0.5848540	total: 8.49s	remaining: 8.62s
496:	learn: 0.5847008	total: 8.51s	remaining: 8.61s
497:	learn: 0.5840777	total: 8.52s	remaining: 8.59s
498:	learn: 0.5838314	total: 8.54s	remaining: 8.57s
499:	learn: 0.5831968	total: 8.56s	remaining: 8.56s
500:	learn: 0.5829347	total: 8.57s	remaining: 8.54s
501:	learn: 0.5824563	total: 8.59s	remaining: 8.52s
502:	learn: 0.5820050	total: 8.61s	remaining: 8.51s
503:	learn: 0.5815486	total: 8.63s	remaining: 8.49s
504:	learn: 0.5811815	total: 8.65s	remaining: 8.47s
505:	learn: 0.5810110	total: 8.66s	remaining: 8.46s
506:	learn: 0.5806037	total: 8.68s	remaining: 8.44s
507:	learn: 0.5804698	total: 8.7s	remaining: 8.43s
508:	learn: 0.5803384	total: 8.72s	remaining: 8.41s
509:	learn: 0.

655:	learn: 0.5261564	total: 11.3s	remaining: 5.9s
656:	learn: 0.5258163	total: 11.3s	remaining: 5.88s
657:	learn: 0.5254660	total: 11.3s	remaining: 5.87s
658:	learn: 0.5250272	total: 11.3s	remaining: 5.85s
659:	learn: 0.5247933	total: 11.3s	remaining: 5.84s
660:	learn: 0.5243303	total: 11.3s	remaining: 5.82s
661:	learn: 0.5237612	total: 11.4s	remaining: 5.8s
662:	learn: 0.5235284	total: 11.4s	remaining: 5.79s
663:	learn: 0.5229565	total: 11.4s	remaining: 5.77s
664:	learn: 0.5223629	total: 11.4s	remaining: 5.75s
665:	learn: 0.5219762	total: 11.4s	remaining: 5.73s
666:	learn: 0.5216723	total: 11.4s	remaining: 5.72s
667:	learn: 0.5212218	total: 11.5s	remaining: 5.7s
668:	learn: 0.5208982	total: 11.5s	remaining: 5.68s
669:	learn: 0.5205320	total: 11.5s	remaining: 5.67s
670:	learn: 0.5203383	total: 11.5s	remaining: 5.65s
671:	learn: 0.5200625	total: 11.5s	remaining: 5.63s
672:	learn: 0.5195933	total: 11.6s	remaining: 5.61s
673:	learn: 0.5192216	total: 11.6s	remaining: 5.6s
674:	learn: 0.51

818:	learn: 0.4758590	total: 14.1s	remaining: 3.12s
819:	learn: 0.4756017	total: 14.1s	remaining: 3.11s
820:	learn: 0.4754242	total: 14.2s	remaining: 3.09s
821:	learn: 0.4751839	total: 14.2s	remaining: 3.07s
822:	learn: 0.4748145	total: 14.2s	remaining: 3.06s
823:	learn: 0.4745108	total: 14.2s	remaining: 3.04s
824:	learn: 0.4741756	total: 14.2s	remaining: 3.02s
825:	learn: 0.4739093	total: 14.3s	remaining: 3s
826:	learn: 0.4737107	total: 14.3s	remaining: 2.99s
827:	learn: 0.4734837	total: 14.3s	remaining: 2.97s
828:	learn: 0.4731014	total: 14.3s	remaining: 2.95s
829:	learn: 0.4729332	total: 14.3s	remaining: 2.94s
830:	learn: 0.4723877	total: 14.4s	remaining: 2.92s
831:	learn: 0.4721787	total: 14.4s	remaining: 2.9s
832:	learn: 0.4720563	total: 14.4s	remaining: 2.88s
833:	learn: 0.4716831	total: 14.4s	remaining: 2.87s
834:	learn: 0.4714633	total: 14.4s	remaining: 2.85s
835:	learn: 0.4711454	total: 14.5s	remaining: 2.83s
836:	learn: 0.4709205	total: 14.5s	remaining: 2.82s
837:	learn: 0.47

982:	learn: 0.4327771	total: 17.4s	remaining: 301ms
983:	learn: 0.4326375	total: 17.4s	remaining: 284ms
984:	learn: 0.4323855	total: 17.5s	remaining: 266ms
985:	learn: 0.4322629	total: 17.5s	remaining: 248ms
986:	learn: 0.4320528	total: 17.5s	remaining: 231ms
987:	learn: 0.4317351	total: 17.5s	remaining: 213ms
988:	learn: 0.4314073	total: 17.6s	remaining: 195ms
989:	learn: 0.4312494	total: 17.6s	remaining: 178ms
990:	learn: 0.4310625	total: 17.6s	remaining: 160ms
991:	learn: 0.4307829	total: 17.6s	remaining: 142ms
992:	learn: 0.4306062	total: 17.6s	remaining: 124ms
993:	learn: 0.4304912	total: 17.6s	remaining: 107ms
994:	learn: 0.4302169	total: 17.7s	remaining: 88.8ms
995:	learn: 0.4298752	total: 17.7s	remaining: 71ms
996:	learn: 0.4295630	total: 17.7s	remaining: 53.2ms
997:	learn: 0.4294640	total: 17.7s	remaining: 35.5ms
998:	learn: 0.4293942	total: 17.7s	remaining: 17.7ms
999:	learn: 0.4291356	total: 17.7s	remaining: 0us


<catboost.core.CatBoostClassifier at 0x21975985ac0>

In [132]:
predictions = cb_classifier.predict(X_test)
probas = cb_classifier.predict_proba(X_test)
confusion_matrix(y_test, predictions)

array([[1432,   19,  668],
       [ 277,    7,  242],
       [ 213,    6,  212]], dtype=int64)

In [133]:
f1_score(y_test, predictions, average="weighted")

0.5307799671406075

Посмотрим на важности признаков

In [114]:
a = [(score, col) for score, col in zip(cb_classifier.get_feature_importance(), all_data.drop(["target"], axis=1).columns)]
a.sort(reverse=True)
a

[(11.577048358887064, 'macdh'),
 (8.580339348037821, 'macds'),
 (7.472517188945133, 'macd'),
 (7.374608886840692, 'volume'),
 (4.434299543111685, 'boll_ub'),
 (4.253077332090552, 'boll_lb'),
 (2.8283612634619977, 'o_TSLA'),
 (2.776636260089023, 'n_TSLA'),
 (2.7459282536240948, 'n_MSFT'),
 (2.6187528851024475, 'l_TSLA'),
 (2.5885612895780032, 'v_NVDA'),
 (2.460306569344442, 'l_MSFT'),
 (2.3913946499232366, 'l_NVDA'),
 (2.324718029157461, 'h_MSFT'),
 (2.318229073955199, 'v_MSFT'),
 (2.307749887879362, 'high'),
 (2.2491796189498707, 'boll'),
 (2.241470855292928, 'h_TSLA'),
 (2.233349941776797, 'c_TSLA'),
 (2.201390187035577, 'v_TSLA'),
 (2.091147578645298, 'low'),
 (1.9808850272936775, 'h_NVDA'),
 (1.959679251284039, 'n_NVDA'),
 (1.9313025116556735, 'c_NVDA'),
 (1.8280186886977439, 'open'),
 (1.7010462678050602, 'o_NVDA'),
 (1.5307905278652734, 'close'),
 (1.5114100212401849, 'o_MSFT'),
 (1.4312765919326522, 'c_MSFT'),
 (0.9716375933996533, 'n_GOOG'),
 (0.9380131426371401, 'l_GOOG'),
 (0.

# Сравним результаты без использования коррелированных акций

In [134]:
one_stock_data = simple_indicators(target_data.reset_index())
one_stock_data["target"] = get_direction_with_std_threshold_target(one_stock_data, 10, 5, 5)
one_stock_data.dropna(subset=["target"], inplace=True)

one_stock_data

Unnamed: 0,volume,open,close,high,low,macd,macds,macdh,macd_xu_macds,macd_xd_macds,boll,boll_ub,boll_lb,high_x_boll_ub,low_x_boll_lb,target
0,1338056.0,57.8425,58.6775,58.875,57.7000,0.000000,0.000000,0.000000,False,False,58.677500,,,False,False,0.0
1,13767216.0,58.6875,55.7975,58.875,55.5875,-0.064615,-0.035897,-0.028718,False,False,57.237500,61.310435,53.164565,False,True,0.0
2,41790168.0,55.8009,55.2800,56.750,54.5000,-0.098380,-0.061505,-0.036875,False,False,56.585000,60.246076,52.923924,False,False,0.0
3,24935492.0,55.2850,55.5364,56.170,55.0000,-0.102069,-0.075246,-0.026823,False,False,56.322850,59.490690,53.155010,False,False,0.0
4,20382612.0,55.5625,55.0238,55.890,54.6500,-0.122322,-0.089250,-0.033072,False,False,56.063040,59.042374,53.083706,False,False,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15373,5434352.0,167.6900,166.6700,167.810,166.6300,-0.692106,-0.467370,-0.224737,False,False,169.019525,171.020273,167.018777,False,False,0.0
15374,4383547.0,166.6700,166.3500,166.780,166.1900,-0.795642,-0.533024,-0.262617,False,False,168.888525,171.218870,166.558180,False,False,0.0
15375,3898394.0,166.3400,167.0700,167.300,166.2995,-0.810256,-0.588470,-0.221785,False,False,168.800025,171.268476,166.331574,False,False,0.0
15376,3976796.0,167.0600,168.0400,168.110,167.0300,-0.735093,-0.617795,-0.117298,False,False,168.714525,171.162808,166.266242,False,True,0.0


In [139]:
TEST_RATIO = 0.2

X = one_stock_data.drop(["target"], axis=1).to_numpy()
y = one_stock_data["target"].to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_RATIO, shuffle=False)
X_train.shape, X_test.shape

((12302, 15), (3076, 15))

In [136]:
class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train) * np.array([2, 1, 1])

cb_classifier = CatBoostClassifier(class_weights=class_weight)
cb_classifier.fit(X_train, y_train)

Learning rate set to 0.089932
0:	learn: 1.0751600	total: 24.5ms	remaining: 24.4s
1:	learn: 1.0567385	total: 31.8ms	remaining: 15.9s
2:	learn: 1.0413418	total: 38.9ms	remaining: 12.9s
3:	learn: 1.0275017	total: 46.1ms	remaining: 11.5s
4:	learn: 1.0175303	total: 53.3ms	remaining: 10.6s
5:	learn: 1.0086857	total: 60.6ms	remaining: 10s
6:	learn: 1.0002078	total: 67.9ms	remaining: 9.64s
7:	learn: 0.9927828	total: 74.7ms	remaining: 9.27s
8:	learn: 0.9874099	total: 81.5ms	remaining: 8.97s
9:	learn: 0.9820773	total: 89.2ms	remaining: 8.83s
10:	learn: 0.9772254	total: 96.3ms	remaining: 8.66s
11:	learn: 0.9731126	total: 103ms	remaining: 8.49s
12:	learn: 0.9695222	total: 110ms	remaining: 8.38s
13:	learn: 0.9660739	total: 117ms	remaining: 8.27s
14:	learn: 0.9625891	total: 124ms	remaining: 8.16s
15:	learn: 0.9597934	total: 131ms	remaining: 8.08s
16:	learn: 0.9571550	total: 139ms	remaining: 8.02s
17:	learn: 0.9546453	total: 145ms	remaining: 7.93s
18:	learn: 0.9531809	total: 152ms	remaining: 7.87s
19

178:	learn: 0.8137225	total: 1.28s	remaining: 5.86s
179:	learn: 0.8126602	total: 1.28s	remaining: 5.85s
180:	learn: 0.8123812	total: 1.29s	remaining: 5.84s
181:	learn: 0.8117255	total: 1.3s	remaining: 5.84s
182:	learn: 0.8111441	total: 1.31s	remaining: 5.84s
183:	learn: 0.8104958	total: 1.31s	remaining: 5.83s
184:	learn: 0.8100468	total: 1.32s	remaining: 5.82s
185:	learn: 0.8096105	total: 1.33s	remaining: 5.82s
186:	learn: 0.8090664	total: 1.34s	remaining: 5.81s
187:	learn: 0.8083581	total: 1.34s	remaining: 5.8s
188:	learn: 0.8075448	total: 1.35s	remaining: 5.79s
189:	learn: 0.8068395	total: 1.36s	remaining: 5.79s
190:	learn: 0.8062940	total: 1.36s	remaining: 5.78s
191:	learn: 0.8056039	total: 1.37s	remaining: 5.77s
192:	learn: 0.8049359	total: 1.38s	remaining: 5.76s
193:	learn: 0.8045312	total: 1.38s	remaining: 5.75s
194:	learn: 0.8039672	total: 1.39s	remaining: 5.74s
195:	learn: 0.8034492	total: 1.4s	remaining: 5.73s
196:	learn: 0.8027283	total: 1.4s	remaining: 5.72s
197:	learn: 0.80

353:	learn: 0.7160176	total: 2.56s	remaining: 4.67s
354:	learn: 0.7156283	total: 2.56s	remaining: 4.66s
355:	learn: 0.7150239	total: 2.57s	remaining: 4.65s
356:	learn: 0.7146566	total: 2.58s	remaining: 4.65s
357:	learn: 0.7143121	total: 2.59s	remaining: 4.64s
358:	learn: 0.7136906	total: 2.6s	remaining: 4.63s
359:	learn: 0.7134003	total: 2.6s	remaining: 4.63s
360:	learn: 0.7129434	total: 2.61s	remaining: 4.62s
361:	learn: 0.7125706	total: 2.62s	remaining: 4.61s
362:	learn: 0.7120397	total: 2.63s	remaining: 4.61s
363:	learn: 0.7116537	total: 2.63s	remaining: 4.6s
364:	learn: 0.7111208	total: 2.64s	remaining: 4.59s
365:	learn: 0.7105836	total: 2.65s	remaining: 4.59s
366:	learn: 0.7100095	total: 2.66s	remaining: 4.58s
367:	learn: 0.7096865	total: 2.66s	remaining: 4.58s
368:	learn: 0.7093520	total: 2.67s	remaining: 4.57s
369:	learn: 0.7091029	total: 2.68s	remaining: 4.56s
370:	learn: 0.7082706	total: 2.69s	remaining: 4.55s
371:	learn: 0.7078870	total: 2.69s	remaining: 4.55s
372:	learn: 0.7

526:	learn: 0.6434709	total: 3.83s	remaining: 3.43s
527:	learn: 0.6431691	total: 3.83s	remaining: 3.43s
528:	learn: 0.6427802	total: 3.84s	remaining: 3.42s
529:	learn: 0.6424415	total: 3.85s	remaining: 3.41s
530:	learn: 0.6421185	total: 3.85s	remaining: 3.4s
531:	learn: 0.6418527	total: 3.86s	remaining: 3.4s
532:	learn: 0.6413348	total: 3.87s	remaining: 3.39s
533:	learn: 0.6408774	total: 3.88s	remaining: 3.38s
534:	learn: 0.6405591	total: 3.88s	remaining: 3.38s
535:	learn: 0.6403206	total: 3.89s	remaining: 3.37s
536:	learn: 0.6400281	total: 3.9s	remaining: 3.36s
537:	learn: 0.6394853	total: 3.91s	remaining: 3.35s
538:	learn: 0.6393343	total: 3.91s	remaining: 3.35s
539:	learn: 0.6391558	total: 3.92s	remaining: 3.34s
540:	learn: 0.6387905	total: 3.93s	remaining: 3.33s
541:	learn: 0.6382586	total: 3.93s	remaining: 3.32s
542:	learn: 0.6380040	total: 3.94s	remaining: 3.32s
543:	learn: 0.6377492	total: 3.95s	remaining: 3.31s
544:	learn: 0.6374616	total: 3.96s	remaining: 3.3s
545:	learn: 0.63

698:	learn: 0.5887072	total: 5.3s	remaining: 2.28s
699:	learn: 0.5885079	total: 5.32s	remaining: 2.28s
700:	learn: 0.5882359	total: 5.33s	remaining: 2.27s
701:	learn: 0.5879250	total: 5.33s	remaining: 2.26s
702:	learn: 0.5875028	total: 5.34s	remaining: 2.26s
703:	learn: 0.5870847	total: 5.35s	remaining: 2.25s
704:	learn: 0.5868942	total: 5.36s	remaining: 2.24s
705:	learn: 0.5867886	total: 5.37s	remaining: 2.24s
706:	learn: 0.5863792	total: 5.38s	remaining: 2.23s
707:	learn: 0.5861050	total: 5.39s	remaining: 2.22s
708:	learn: 0.5857881	total: 5.4s	remaining: 2.22s
709:	learn: 0.5856083	total: 5.41s	remaining: 2.21s
710:	learn: 0.5853658	total: 5.42s	remaining: 2.2s
711:	learn: 0.5851660	total: 5.43s	remaining: 2.2s
712:	learn: 0.5849504	total: 5.44s	remaining: 2.19s
713:	learn: 0.5846465	total: 5.45s	remaining: 2.18s
714:	learn: 0.5844083	total: 5.46s	remaining: 2.17s
715:	learn: 0.5840026	total: 5.47s	remaining: 2.17s
716:	learn: 0.5838580	total: 5.48s	remaining: 2.16s
717:	learn: 0.58

866:	learn: 0.5449048	total: 6.79s	remaining: 1.04s
867:	learn: 0.5446519	total: 6.8s	remaining: 1.03s
868:	learn: 0.5444300	total: 6.81s	remaining: 1.03s
869:	learn: 0.5442394	total: 6.82s	remaining: 1.02s
870:	learn: 0.5441272	total: 6.82s	remaining: 1.01s
871:	learn: 0.5438666	total: 6.83s	remaining: 1s
872:	learn: 0.5436078	total: 6.84s	remaining: 995ms
873:	learn: 0.5434227	total: 6.84s	remaining: 987ms
874:	learn: 0.5431758	total: 6.85s	remaining: 979ms
875:	learn: 0.5429375	total: 6.86s	remaining: 971ms
876:	learn: 0.5427022	total: 6.87s	remaining: 963ms
877:	learn: 0.5424944	total: 6.88s	remaining: 955ms
878:	learn: 0.5422657	total: 6.88s	remaining: 947ms
879:	learn: 0.5419504	total: 6.89s	remaining: 939ms
880:	learn: 0.5417632	total: 6.9s	remaining: 932ms
881:	learn: 0.5415316	total: 6.9s	remaining: 924ms
882:	learn: 0.5411933	total: 6.91s	remaining: 916ms
883:	learn: 0.5409802	total: 6.92s	remaining: 908ms
884:	learn: 0.5406423	total: 6.92s	remaining: 900ms
885:	learn: 0.5404

<catboost.core.CatBoostClassifier at 0x21975a3e310>

In [137]:
predictions = cb_classifier.predict(X_test)
probas = cb_classifier.predict_proba(X_test)
confusion_matrix(y_test, predictions)

array([[1620,  178,  321],
       [ 324,   74,  128],
       [ 263,   66,  102]], dtype=int64)

In [138]:
f1_score(y_test, predictions, average="weighted")

0.5750384844397667