In [52]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import os
import re
import sys
sys.path.append("./utils")

from generate_features import *
from get_targets import *

from sklearn.model_selection import train_test_split
from sklearn.model_selection import TimeSeriesSplit
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import balanced_accuracy_score

from catboost import CatBoostClassifier

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
DATA_PATH = "./data/prices/tech_companies_15minute_train.csv"

TICKERS = ["AAPL", "MSFT", "GOOG", "TSLA", "NVDA", "BABA", "AMD", "ATVI", "ZG"]
TARGET_TICKER = "AAPL"
ADDITIONAL_TICKERS = ["MSFT", "GOOG", "TSLA", "NVDA", "BABA", "AMD", "ATVI", "ZG"]

In [55]:
data = pd.read_csv(DATA_PATH, index_col=0)
data.dropna(inplace=True)
data = data[[ticker in TICKERS for ticker in data["ticker"]]]
data

Unnamed: 0,volume,open,close,high,low,date,ticker
0,67612.0,62.9025,62.6025,63.0275,62.6025,2020-04-06 07:59:00,AAPL
1,36348.0,62.6100,62.7500,62.7750,62.5125,2020-04-06 08:14:00,AAPL
2,27440.0,62.8500,62.7875,62.8500,62.4875,2020-04-06 08:29:00,AAPL
3,46384.0,62.8000,62.8675,62.8875,62.6375,2020-04-06 08:44:00,AAPL
4,24700.0,62.8000,62.7175,62.8475,62.5525,2020-04-06 08:59:00,AAPL
...,...,...,...,...,...,...,...
427458,105229.0,66.1100,66.4000,66.4700,66.1010,2021-11-09 20:30:00,ZG
427459,175685.0,66.4400,66.6300,66.7300,66.3100,2021-11-09 20:45:00,ZG
427460,8144.0,66.6200,66.6200,66.6200,66.6200,2021-11-09 21:00:00,ZG
427461,100.0,66.1900,66.1900,66.1900,66.1900,2021-11-09 21:30:00,ZG


In [56]:
target_data = data[data["ticker"] == TARGET_TICKER]
additional_data = data[[ticker in ADDITIONAL_TICKERS for ticker in data["ticker"]]]
display(target_data)
display(additional_data)

Unnamed: 0,volume,open,close,high,low,date,ticker
0,67612.0,62.9025,62.6025,63.0275,62.6025,2020-04-06 07:59:00,AAPL
1,36348.0,62.6100,62.7500,62.7750,62.5125,2020-04-06 08:14:00,AAPL
2,27440.0,62.8500,62.7875,62.8500,62.4875,2020-04-06 08:29:00,AAPL
3,46384.0,62.8000,62.8675,62.8875,62.6375,2020-04-06 08:44:00,AAPL
4,24700.0,62.8000,62.7175,62.8475,62.5525,2020-04-06 08:59:00,AAPL
...,...,...,...,...,...,...,...
25569,3613.0,150.5200,150.5000,150.5500,150.5000,2021-11-09 23:45:00,AAPL
25570,8250.0,150.4900,150.6300,150.6300,150.4900,2021-11-10 00:00:00,AAPL
25571,10936.0,150.6000,150.5500,150.6000,150.4600,2021-11-10 00:15:00,AAPL
25572,9524.0,150.5500,150.6000,150.6700,150.5000,2021-11-10 00:30:00,AAPL


Unnamed: 0,volume,open,close,high,low,date,ticker
31863,1703.0,1138.36,1138.00,1138.36,1138.000,2020-04-06 11:00:00,GOOG
31864,397.0,1138.00,1138.10,1138.10,1138.000,2020-04-06 11:15:00,GOOG
31865,1671.0,1138.00,1138.50,1138.50,1135.880,2020-04-06 11:30:00,GOOG
31866,1591.0,1138.22,1138.22,1138.22,1138.220,2020-04-06 12:00:00,GOOG
31867,1267.0,1135.00,1136.00,1138.00,1135.000,2020-04-06 12:15:00,GOOG
...,...,...,...,...,...,...,...
427458,105229.0,66.11,66.40,66.47,66.101,2021-11-09 20:30:00,ZG
427459,175685.0,66.44,66.63,66.73,66.310,2021-11-09 20:45:00,ZG
427460,8144.0,66.62,66.62,66.62,66.620,2021-11-09 21:00:00,ZG
427461,100.0,66.19,66.19,66.19,66.190,2021-11-09 21:30:00,ZG


In [57]:
additional_data["ticker"].unique()

array(['GOOG', 'MSFT', 'TSLA', 'NVDA', 'BABA', 'AMD', 'ATVI', 'ZG'],
      dtype=object)

In [58]:
def get_sample_weight(y: np.ndarray):
    num_observations = len(y)
    sample_weight = np.zeros(num_observations)
    for class_idx in np.unique(y):
        idx = (y == class_idx)
        sample_weight[idx] = idx.sum() / num_observations

    return sample_weight

In [59]:
def compute_balanced_accuracy(y_true: np.ndarray, y_pred: np.ndarray):
    assert y_true.shape == y_pred.shape, f"{y_true.shape}, {y_pred.shape}"
    sample_weight = get_sample_weight(y_true)
    assert y_true.shape == sample_weight.shape, f"{y_true.shape}, {sample_weight.shape}"
    return balanced_accuracy_score(y_true, y_pred, sample_weight=sample_weight)

In [60]:
def cross_validate_catboost(X: np.ndarray, y: np.ndarray, metric):
    tscv = TimeSeriesSplit()
    
    metrics = []
    for fold_idx, (train_idx, test_idx) in enumerate(tscv.split(X)):
        X_train = X[train_idx]
        X_test = X[test_idx]
        
        y_train = y[train_idx]
        y_test = y[test_idx]
        
        assert X_train.shape[0] == y_train.shape[0]
        assert X_test.shape[0] == y_test.shape[0]
        
        class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)
        model = CatBoostClassifier(class_weights=class_weight, verbose=False)
        model.fit(X_train, y_train)
        predictions = model.predict(X_test).flatten()
        assert y_test.shape == predictions.shape, f"{y_test.shape}, {predictions.shape}"
        
        fold_metric = metric(y_test, predictions)
        metrics.append(fold_metric)
        print(f"Fold {fold_idx + 1}: {fold_metric}")
    
    return np.mean(metrics)

In [92]:
THRESHOLD = 1.0
MAX_TIME_LAG = 30

all_data = simple_indicators(target_data).drop(["ticker"], axis=1).set_index("date")
for ticker in additional_data["ticker"].unique():
    ticker_data = additional_data[additional_data["ticker"] == ticker]
    ticker_indicators = simple_indicators(ticker_data).drop(["ticker"], axis=1).set_index("date")
    all_data = all_data.join(ticker_indicators, rsuffix=f"_{ticker}", how="left")

dates = pd.Series(all_data.index).apply(lambda val: pd.Timestamp(val))
all_data.reset_index(drop=True, inplace=True)
all_data["weekday"] = dates.dt.weekday
all_data["day"] = dates.dt.day
all_data["minute"] = dates.dt.minute
all_data["hour"] = dates.dt.hour
all_data["target"] = get_first_threshold_bump(all_data, threshold=THRESHOLD, max_time_lag=MAX_TIME_LAG)
all_data.dropna(subset=["target"], inplace=True)

print(all_data.columns)
all_data

Index(['volume', 'open', 'close', 'high', 'low', 'macd', 'macds', 'macdh',
       'macd_xu_macds', 'macd_xd_macds',
       ...
       'low_x_boll_lb_ZG', 'rs_14_ZG', 'rsi_ZG', 'chop_ZG', 'mfi_ZG',
       'weekday', 'day', 'minute', 'hour', 'target'],
      dtype='object', length=176)


Unnamed: 0,volume,open,close,high,low,macd,macds,macdh,macd_xu_macds,macd_xd_macds,...,low_x_boll_lb_ZG,rs_14_ZG,rsi_ZG,chop_ZG,mfi_ZG,weekday,day,minute,hour,target
0,67612.0,62.9025,62.6025,63.0275,62.6025,0.000000,0.000000,0.000000,False,False,...,,,,,,0,6,59,7,2.0
1,36348.0,62.6100,62.7500,62.7750,62.5125,0.003309,0.001838,0.001471,True,False,...,,,,,,0,6,14,8,2.0
2,27440.0,62.8500,62.7875,62.8500,62.4875,0.005380,0.003290,0.002090,False,False,...,,,,,,0,6,29,8,2.0
3,46384.0,62.8000,62.8675,62.8875,62.6375,0.009078,0.005250,0.003827,False,False,...,,,,,,0,6,44,8,2.0
4,24700.0,62.8000,62.7175,62.8475,62.5525,0.004927,0.005154,-0.000227,False,True,...,,,,,,0,6,59,8,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25569,3613.0,150.5200,150.5000,150.5500,150.5000,-0.046170,-0.013098,-0.033072,False,False,...,,,,,,1,9,45,23,2.0
25570,8250.0,150.4900,150.6300,150.6300,150.4900,-0.041052,-0.018689,-0.022363,False,False,...,,,,,,2,10,0,0,2.0
25571,10936.0,150.6000,150.5500,150.6000,150.4600,-0.042956,-0.023542,-0.019414,False,False,...,,,,,,2,10,15,0,2.0
25572,9524.0,150.5500,150.6000,150.6700,150.5000,-0.039970,-0.026828,-0.013142,False,False,...,False,1.055489,51.349786,52.388492,0.73503,2,10,30,0,2.0


In [87]:
all_data["target"].value_counts()

2.0    9827
1.0    8583
0.0    7164
Name: target, dtype: int64

In [88]:
TEST_RATIO = 0.2

X = all_data.drop(["target"], axis=1).to_numpy()
y = all_data["target"].to_numpy()

print(X.shape, y.shape)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_RATIO, shuffle=False)
X_train.shape, X_test.shape

(25574, 174) (25574,)


((20459, 174), (5115, 174))

In [89]:
cross_validate_catboost(X, y, compute_balanced_accuracy)

Fold 1: 0.3550900447036714
Fold 2: 0.36236024687913754
Fold 3: 0.34745971771178435
Fold 4: 0.45116659487487026
Fold 5: 0.46361496392059065


0.39593831361801085

In [90]:
cb_classifier = CatBoostClassifier()
cb_classifier.fit(X_train, y_train)

Learning rate set to 0.092295
0:	learn: 1.0818823	total: 56ms	remaining: 55.9s
1:	learn: 1.0656405	total: 111ms	remaining: 55.5s
2:	learn: 1.0541788	total: 162ms	remaining: 53.8s
3:	learn: 1.0439211	total: 216ms	remaining: 53.9s
4:	learn: 1.0342338	total: 271ms	remaining: 53.9s
5:	learn: 1.0249767	total: 326ms	remaining: 54.1s
6:	learn: 1.0167362	total: 382ms	remaining: 54.2s
7:	learn: 1.0076992	total: 439ms	remaining: 54.5s
8:	learn: 1.0012368	total: 494ms	remaining: 54.4s
9:	learn: 0.9938058	total: 551ms	remaining: 54.5s
10:	learn: 0.9889032	total: 610ms	remaining: 54.8s
11:	learn: 0.9838304	total: 666ms	remaining: 54.8s
12:	learn: 0.9776126	total: 723ms	remaining: 54.9s
13:	learn: 0.9737854	total: 775ms	remaining: 54.6s
14:	learn: 0.9700742	total: 823ms	remaining: 54.1s
15:	learn: 0.9657280	total: 877ms	remaining: 54s
16:	learn: 0.9626338	total: 936ms	remaining: 54.1s
17:	learn: 0.9581198	total: 992ms	remaining: 54.1s
18:	learn: 0.9546340	total: 1.05s	remaining: 54s
19:	learn: 0.952

161:	learn: 0.7161801	total: 10s	remaining: 51.9s
162:	learn: 0.7153204	total: 10.1s	remaining: 51.9s
163:	learn: 0.7140763	total: 10.2s	remaining: 52s
164:	learn: 0.7126325	total: 10.3s	remaining: 51.9s
165:	learn: 0.7114385	total: 10.3s	remaining: 52s
166:	learn: 0.7101989	total: 10.4s	remaining: 52s
167:	learn: 0.7091304	total: 10.5s	remaining: 52s
168:	learn: 0.7081756	total: 10.6s	remaining: 52s
169:	learn: 0.7074822	total: 10.6s	remaining: 51.9s
170:	learn: 0.7063004	total: 10.7s	remaining: 51.8s
171:	learn: 0.7040606	total: 10.8s	remaining: 51.9s
172:	learn: 0.7029293	total: 10.8s	remaining: 51.8s
173:	learn: 0.7017962	total: 10.9s	remaining: 51.7s
174:	learn: 0.7006004	total: 11s	remaining: 51.7s
175:	learn: 0.6999948	total: 11s	remaining: 51.7s
176:	learn: 0.6984957	total: 11.1s	remaining: 51.6s
177:	learn: 0.6976488	total: 11.2s	remaining: 51.6s
178:	learn: 0.6965231	total: 11.2s	remaining: 51.5s
179:	learn: 0.6954484	total: 11.3s	remaining: 51.5s
180:	learn: 0.6942299	total:

323:	learn: 0.5683782	total: 21s	remaining: 43.8s
324:	learn: 0.5679552	total: 21.1s	remaining: 43.7s
325:	learn: 0.5668476	total: 21.1s	remaining: 43.6s
326:	learn: 0.5661588	total: 21.2s	remaining: 43.6s
327:	learn: 0.5656750	total: 21.2s	remaining: 43.5s
328:	learn: 0.5648683	total: 21.3s	remaining: 43.4s
329:	learn: 0.5643982	total: 21.3s	remaining: 43.3s
330:	learn: 0.5634072	total: 21.4s	remaining: 43.2s
331:	learn: 0.5622451	total: 21.5s	remaining: 43.2s
332:	learn: 0.5615180	total: 21.5s	remaining: 43.1s
333:	learn: 0.5609809	total: 21.6s	remaining: 43s
334:	learn: 0.5604162	total: 21.6s	remaining: 43s
335:	learn: 0.5600943	total: 21.7s	remaining: 42.9s
336:	learn: 0.5590961	total: 21.8s	remaining: 42.8s
337:	learn: 0.5583962	total: 21.8s	remaining: 42.7s
338:	learn: 0.5578506	total: 21.9s	remaining: 42.7s
339:	learn: 0.5572352	total: 22s	remaining: 42.6s
340:	learn: 0.5567577	total: 22s	remaining: 42.6s
341:	learn: 0.5560430	total: 22.1s	remaining: 42.6s
342:	learn: 0.5557209	

484:	learn: 0.4763305	total: 32.3s	remaining: 34.3s
485:	learn: 0.4760498	total: 32.4s	remaining: 34.2s
486:	learn: 0.4756399	total: 32.4s	remaining: 34.2s
487:	learn: 0.4753749	total: 32.5s	remaining: 34.1s
488:	learn: 0.4750526	total: 32.5s	remaining: 34s
489:	learn: 0.4748853	total: 32.6s	remaining: 33.9s
490:	learn: 0.4744858	total: 32.6s	remaining: 33.8s
491:	learn: 0.4741218	total: 32.7s	remaining: 33.8s
492:	learn: 0.4733921	total: 32.8s	remaining: 33.7s
493:	learn: 0.4728968	total: 32.8s	remaining: 33.6s
494:	learn: 0.4722793	total: 32.9s	remaining: 33.5s
495:	learn: 0.4717492	total: 32.9s	remaining: 33.4s
496:	learn: 0.4713072	total: 33s	remaining: 33.4s
497:	learn: 0.4708676	total: 33s	remaining: 33.3s
498:	learn: 0.4703153	total: 33.1s	remaining: 33.2s
499:	learn: 0.4700108	total: 33.1s	remaining: 33.1s
500:	learn: 0.4696314	total: 33.2s	remaining: 33.1s
501:	learn: 0.4692182	total: 33.3s	remaining: 33s
502:	learn: 0.4689328	total: 33.3s	remaining: 32.9s
503:	learn: 0.468706

645:	learn: 0.4099572	total: 43.4s	remaining: 23.8s
646:	learn: 0.4096328	total: 43.5s	remaining: 23.7s
647:	learn: 0.4094057	total: 43.6s	remaining: 23.7s
648:	learn: 0.4089190	total: 43.7s	remaining: 23.6s
649:	learn: 0.4083617	total: 43.7s	remaining: 23.5s
650:	learn: 0.4081911	total: 43.8s	remaining: 23.5s
651:	learn: 0.4079311	total: 43.9s	remaining: 23.4s
652:	learn: 0.4074758	total: 44s	remaining: 23.4s
653:	learn: 0.4071937	total: 44s	remaining: 23.3s
654:	learn: 0.4068909	total: 44.1s	remaining: 23.2s
655:	learn: 0.4067723	total: 44.2s	remaining: 23.2s
656:	learn: 0.4065307	total: 44.2s	remaining: 23.1s
657:	learn: 0.4061040	total: 44.3s	remaining: 23s
658:	learn: 0.4058200	total: 44.4s	remaining: 23s
659:	learn: 0.4054120	total: 44.4s	remaining: 22.9s
660:	learn: 0.4052039	total: 44.5s	remaining: 22.8s
661:	learn: 0.4046399	total: 44.6s	remaining: 22.8s
662:	learn: 0.4042793	total: 44.7s	remaining: 22.7s
663:	learn: 0.4039694	total: 44.7s	remaining: 22.6s
664:	learn: 0.403718

805:	learn: 0.3618129	total: 54.2s	remaining: 13s
806:	learn: 0.3616273	total: 54.3s	remaining: 13s
807:	learn: 0.3614365	total: 54.3s	remaining: 12.9s
808:	learn: 0.3611159	total: 54.4s	remaining: 12.8s
809:	learn: 0.3609530	total: 54.5s	remaining: 12.8s
810:	learn: 0.3607358	total: 54.5s	remaining: 12.7s
811:	learn: 0.3603654	total: 54.6s	remaining: 12.6s
812:	learn: 0.3598214	total: 54.6s	remaining: 12.6s
813:	learn: 0.3594213	total: 54.7s	remaining: 12.5s
814:	learn: 0.3592705	total: 54.8s	remaining: 12.4s
815:	learn: 0.3588172	total: 54.9s	remaining: 12.4s
816:	learn: 0.3585768	total: 54.9s	remaining: 12.3s
817:	learn: 0.3583893	total: 55s	remaining: 12.2s
818:	learn: 0.3579279	total: 55.1s	remaining: 12.2s
819:	learn: 0.3574963	total: 55.1s	remaining: 12.1s
820:	learn: 0.3570317	total: 55.2s	remaining: 12s
821:	learn: 0.3567604	total: 55.3s	remaining: 12s
822:	learn: 0.3564077	total: 55.4s	remaining: 11.9s
823:	learn: 0.3562273	total: 55.4s	remaining: 11.8s
824:	learn: 0.3556166	

968:	learn: 0.3218001	total: 1m 5s	remaining: 2.08s
969:	learn: 0.3215244	total: 1m 5s	remaining: 2.02s
970:	learn: 0.3213545	total: 1m 5s	remaining: 1.95s
971:	learn: 0.3211144	total: 1m 5s	remaining: 1.88s
972:	learn: 0.3209154	total: 1m 5s	remaining: 1.81s
973:	learn: 0.3205781	total: 1m 5s	remaining: 1.75s
974:	learn: 0.3201797	total: 1m 5s	remaining: 1.68s
975:	learn: 0.3200013	total: 1m 5s	remaining: 1.61s
976:	learn: 0.3197373	total: 1m 5s	remaining: 1.54s
977:	learn: 0.3195636	total: 1m 5s	remaining: 1.48s
978:	learn: 0.3193328	total: 1m 5s	remaining: 1.41s
979:	learn: 0.3192507	total: 1m 5s	remaining: 1.34s
980:	learn: 0.3191533	total: 1m 5s	remaining: 1.28s
981:	learn: 0.3189967	total: 1m 5s	remaining: 1.21s
982:	learn: 0.3188665	total: 1m 6s	remaining: 1.14s
983:	learn: 0.3185917	total: 1m 6s	remaining: 1.07s
984:	learn: 0.3184664	total: 1m 6s	remaining: 1.01s
985:	learn: 0.3182989	total: 1m 6s	remaining: 940ms
986:	learn: 0.3180876	total: 1m 6s	remaining: 873ms
987:	learn: 

<catboost.core.CatBoostClassifier at 0x1f8dab66b20>

In [91]:
features = [(imp, feat) for imp, feat in zip(cb_classifier.get_feature_importance(), all_data.columns)]
sorted(features, reverse=True)

[(11.254020220496924, 'day'),
 (6.437007139977336, 'hour'),
 (3.608612134557145, 'macds'),
 (2.791229677798349, 'boll_ub'),
 (2.479532150697117, 'rs_14'),
 (2.1693611005628446, 'macd'),
 (2.0454134715910466, 'close'),
 (1.8239336243928923, 'boll_ub_BABA'),
 (1.804436757674946, 'macdh'),
 (1.7766282689874664, 'macds_BABA'),
 (1.6672565292919161, 'boll_lb'),
 (1.653092886452245, 'volume'),
 (1.46960506216008, 'boll'),
 (1.3515574779911055, 'boll_ub_TSLA'),
 (1.3229364179553937, 'boll_lb_AMD'),
 (1.3129387824392738, 'macds_TSLA'),
 (1.302955959290274, 'boll_ub_MSFT'),
 (1.2448206209995265, 'boll_ub_AMD'),
 (1.2390330146047146, 'low'),
 (1.2312414596038066, 'high'),
 (1.2290119312844094, 'macds_MSFT'),
 (1.2241049388039733, 'mfi'),
 (1.1837079082063946, 'boll_AMD'),
 (1.1816707220483127, 'chop'),
 (1.0960542107806461, 'boll_MSFT'),
 (1.0919371318653803, 'low_BABA'),
 (1.0792955106618074, 'boll_lb_TSLA'),
 (1.0616228937121805, 'boll_lb_BABA'),
 (1.0573742298600117, 'macds_AMD'),
 (0.9958698