In [23]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import os
import sys
sys.path.append("./utils")

from generate_features import simple_indicators
from get_targets import get_direction_with_std_threshold_target
from dnn_utils import StockSequenceDataset, StockLSTMModel, train_model, get_test_logits

from catboost import CatBoostClassifier

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.metrics import balanced_accuracy_score

import torch
from torch import nn
from torch.utils.data import DataLoader

Current device: cuda
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
DATA_PATH = "./data/prices/technology_prices_15minute_train.csv"
TICKERS = ["AAPL", "MSFT", "GOOG", "TSLA", "NVDA"]
TARGET_TICKER = "AAPL"
ADDITIONAL_TICKERS = ["MSFT", "GOOG", "TSLA", "NVDA"]

In [3]:
data = pd.read_csv(DATA_PATH, index_col=0)
data.dropna(inplace=True)
data = data[[ticker in TICKERS for ticker in data["ticker"]]]
data

Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
0,25452.0,61.7090,61.750,61.3800,61.750,61.3800,2020-03-30 07:50:00,36,AAPL
1,28416.0,61.5333,61.750,61.5875,61.750,61.4225,2020-03-30 08:05:00,61,AAPL
2,17296.0,61.6269,61.540,61.7500,61.750,61.5250,2020-03-30 08:20:00,43,AAPL
3,6776.0,61.8962,61.675,61.9100,61.930,61.6750,2020-03-30 08:35:00,23,AAPL
4,17188.0,61.9149,62.000,61.7875,62.025,61.7875,2020-03-30 08:50:00,63,AAPL
...,...,...,...,...,...,...,...,...,...
131912,1630.0,196.5607,196.500,196.6000,196.600,196.5000,2021-08-11 22:30:00,18,NVDA
131913,1609.0,196.4820,196.500,196.3200,196.500,196.2100,2021-08-11 22:45:00,93,NVDA
131914,200.0,196.3200,196.320,196.3200,196.320,196.3200,2021-08-11 23:00:00,1,NVDA
131915,4302.0,196.3767,196.540,196.6000,196.600,196.1200,2021-08-11 23:30:00,52,NVDA


In [4]:
target_data = data[data["ticker"] == TARGET_TICKER]
additional_data = data[[ticker in ADDITIONAL_TICKERS for ticker in data["ticker"]]]
display(target_data)
display(additional_data)

Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
0,25452.0,61.7090,61.750,61.3800,61.750,61.3800,2020-03-30 07:50:00,36,AAPL
1,28416.0,61.5333,61.750,61.5875,61.750,61.4225,2020-03-30 08:05:00,61,AAPL
2,17296.0,61.6269,61.540,61.7500,61.750,61.5250,2020-03-30 08:20:00,43,AAPL
3,6776.0,61.8962,61.675,61.9100,61.930,61.6750,2020-03-30 08:35:00,23,AAPL
4,17188.0,61.9149,62.000,61.7875,62.025,61.7875,2020-03-30 08:50:00,63,AAPL
...,...,...,...,...,...,...,...,...,...
21910,5261.0,145.6801,145.700,145.6800,145.700,145.6700,2021-08-11 22:45:00,164,AAPL
21911,1284.0,145.6924,145.700,145.6800,145.700,145.6800,2021-08-11 23:00:00,19,AAPL
21912,3649.0,145.6700,145.670,145.6700,145.710,145.6600,2021-08-11 23:15:00,28,AAPL
21913,9026.0,145.6643,145.670,145.6500,145.690,145.6500,2021-08-11 23:30:00,64,AAPL


Unnamed: 0,v,vw,o,c,h,l,t,n,ticker
31822,184.0,1119.9946,1120.0000,1120.00,1120.0000,1120.00,2020-03-30 11:21:00,4,GOOG
31823,997.0,1115.8603,1116.6600,1115.94,1116.6600,1115.94,2020-03-30 11:51:00,34,GOOG
31824,5765.0,1110.9206,1114.2000,1110.71,1114.2000,1110.71,2020-03-30 12:21:00,6,GOOG
31825,610.0,1124.6381,1122.3500,1127.00,1127.0000,1122.35,2020-03-30 12:36:00,49,GOOG
31826,673.0,1127.0979,1127.0002,1127.00,1127.0002,1127.00,2020-03-30 12:51:00,31,GOOG
...,...,...,...,...,...,...,...,...,...
131912,1630.0,196.5607,196.5000,196.60,196.6000,196.50,2021-08-11 22:30:00,18,NVDA
131913,1609.0,196.4820,196.5000,196.32,196.5000,196.21,2021-08-11 22:45:00,93,NVDA
131914,200.0,196.3200,196.3200,196.32,196.3200,196.32,2021-08-11 23:00:00,1,NVDA
131915,4302.0,196.3767,196.5400,196.60,196.6000,196.12,2021-08-11 23:30:00,52,NVDA


In [5]:
additional_data["ticker"].unique()

array(['GOOG', 'MSFT', 'TSLA', 'NVDA'], dtype=object)

In [6]:
all_data = simple_indicators(target_data).drop(["vw", "n", "ticker"], axis=1).set_index("t")
for ticker in additional_data["ticker"].unique():
    ticker_data = additional_data[additional_data["ticker"] == ticker]
    ticker_indicators = simple_indicators(ticker_data).drop(["vw", "ticker", "n"], axis=1).set_index("t")
    all_data = all_data.join(ticker_indicators, rsuffix=f"_{ticker}", how="left")

all_data.reset_index(drop=True, inplace=True)
all_data["target"] = get_direction_with_std_threshold_target(all_data, 10, 5, 5)
all_data.dropna(subset=["target"], inplace=True)

all_data

Unnamed: 0,volume,open,close,high,low,macd,macds,macdh,macd_xu_macds,macd_xd_macds,...,macds_NVDA,macdh_NVDA,macd_xu_macds_NVDA,macd_xd_macds_NVDA,boll_NVDA,boll_ub_NVDA,boll_lb_NVDA,high_x_boll_ub_NVDA,low_x_boll_lb_NVDA,target
0,25452.0,61.7500,61.3800,61.7500,61.3800,0.000000,0.000000,0.000000,False,False,...,,,,,,,,,,0.0
1,28416.0,61.7500,61.5875,61.7500,61.4225,0.004655,0.002586,0.002069,True,False,...,,,,,,,,,,0.0
2,17296.0,61.5400,61.7500,61.7500,61.5250,0.010971,0.006023,0.004948,False,False,...,,,,,,,,,,0.0
3,6776.0,61.6750,61.9100,61.9300,61.6750,0.019351,0.010538,0.008814,False,False,...,,,,,,,,,,0.0
4,17188.0,62.0000,61.7875,62.0250,61.7875,0.018623,0.012943,0.005680,False,False,...,,,,,,,,,,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21900,17731.0,145.7100,145.6800,145.7500,145.6600,-0.008033,0.051588,-0.059621,False,False,...,-0.318244,0.168126,False,False,196.31558,197.631764,194.999396,False,False,0.0
21901,10326.0,145.6800,145.7800,145.8000,145.6800,-0.017429,0.037784,-0.055214,False,False,...,-0.278440,0.159218,False,False,196.42158,197.571797,195.271363,False,False,0.0
21902,5767.0,145.7301,145.7000,145.7301,145.7000,-0.030974,0.024033,-0.055007,False,False,...,-0.242440,0.143998,False,False,196.51025,197.506493,195.514007,False,False,0.0
21903,38578.0,145.7100,145.7200,145.8600,145.7000,-0.039638,0.011299,-0.050936,False,False,...,-0.210799,0.126567,False,False,196.57975,197.460632,195.698868,False,False,0.0


In [7]:
all_data.target.value_counts()

0.0    15026
1.0     3846
2.0     3033
Name: target, dtype: int64

In [8]:
all_data_preprocessed = all_data.ffill().bfill()
all_data_preprocessed.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 21905 entries, 0 to 21904
Data columns (total 76 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   volume               21905 non-null  float64
 1   open                 21905 non-null  float64
 2   close                21905 non-null  float64
 3   high                 21905 non-null  float64
 4   low                  21905 non-null  float64
 5   macd                 21905 non-null  float64
 6   macds                21905 non-null  float64
 7   macdh                21905 non-null  float64
 8   macd_xu_macds        21905 non-null  bool   
 9   macd_xd_macds        21905 non-null  bool   
 10  boll                 21905 non-null  float64
 11  boll_ub              21905 non-null  float64
 12  boll_lb              21905 non-null  float64
 13  high_x_boll_ub       21905 non-null  bool   
 14  low_x_boll_lb        21905 non-null  bool   
 15  volume_GOOG          21905 non-null 

In [13]:
CATBOOST_TRAIN_RATIO = 0.4

catboost_train_data, lstm_train_data = train_test_split(
    all_data_preprocessed,
    train_size=CATBOOST_TRAIN_RATIO,
    shuffle=False
)

In [14]:
X_catboost = catboost_train_data.drop(["target"], axis=1).to_numpy()
y_catboost = catboost_train_data["target"].to_numpy()

class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_catboost), y=y_catboost)
print(class_weight)

[0.49528009 1.7509992  2.43998886]


In [15]:
catboost_classifier = CatBoostClassifier(class_weights=class_weight)
catboost_classifier.fit(X_catboost, y_catboost)

Learning rate set to 0.088389
0:	learn: 1.0876995	total: 203ms	remaining: 3m 22s
1:	learn: 1.0778606	total: 225ms	remaining: 1m 52s
2:	learn: 1.0707465	total: 244ms	remaining: 1m 21s
3:	learn: 1.0622129	total: 263ms	remaining: 1m 5s
4:	learn: 1.0554423	total: 282ms	remaining: 56.2s
5:	learn: 1.0479660	total: 302ms	remaining: 50.1s
6:	learn: 1.0410370	total: 323ms	remaining: 45.8s
7:	learn: 1.0341046	total: 342ms	remaining: 42.5s
8:	learn: 1.0276252	total: 361ms	remaining: 39.8s
9:	learn: 1.0230672	total: 381ms	remaining: 37.7s
10:	learn: 1.0181523	total: 400ms	remaining: 36s
11:	learn: 1.0133076	total: 421ms	remaining: 34.6s
12:	learn: 1.0089367	total: 443ms	remaining: 33.6s
13:	learn: 1.0045006	total: 462ms	remaining: 32.6s
14:	learn: 1.0019402	total: 481ms	remaining: 31.6s
15:	learn: 0.9971458	total: 500ms	remaining: 30.8s
16:	learn: 0.9926532	total: 520ms	remaining: 30s
17:	learn: 0.9880024	total: 539ms	remaining: 29.4s
18:	learn: 0.9853237	total: 559ms	remaining: 28.9s
19:	learn: 0

162:	learn: 0.7464852	total: 3.53s	remaining: 18.1s
163:	learn: 0.7452708	total: 3.55s	remaining: 18.1s
164:	learn: 0.7437653	total: 3.57s	remaining: 18.1s
165:	learn: 0.7415376	total: 3.59s	remaining: 18s
166:	learn: 0.7404409	total: 3.61s	remaining: 18s
167:	learn: 0.7397110	total: 3.63s	remaining: 18s
168:	learn: 0.7389956	total: 3.65s	remaining: 17.9s
169:	learn: 0.7383544	total: 3.67s	remaining: 17.9s
170:	learn: 0.7372728	total: 3.69s	remaining: 17.9s
171:	learn: 0.7365222	total: 3.71s	remaining: 17.9s
172:	learn: 0.7357391	total: 3.73s	remaining: 17.8s
173:	learn: 0.7341217	total: 3.75s	remaining: 17.8s
174:	learn: 0.7330008	total: 3.77s	remaining: 17.8s
175:	learn: 0.7318086	total: 3.79s	remaining: 17.8s
176:	learn: 0.7303909	total: 3.81s	remaining: 17.7s
177:	learn: 0.7287483	total: 3.83s	remaining: 17.7s
178:	learn: 0.7273446	total: 3.85s	remaining: 17.7s
179:	learn: 0.7256441	total: 3.87s	remaining: 17.6s
180:	learn: 0.7240385	total: 3.9s	remaining: 17.6s
181:	learn: 0.72332

325:	learn: 0.6001605	total: 7.88s	remaining: 16.3s
326:	learn: 0.5994777	total: 7.91s	remaining: 16.3s
327:	learn: 0.5992401	total: 7.93s	remaining: 16.2s
328:	learn: 0.5984365	total: 7.95s	remaining: 16.2s
329:	learn: 0.5978063	total: 7.97s	remaining: 16.2s
330:	learn: 0.5973282	total: 7.99s	remaining: 16.1s
331:	learn: 0.5969936	total: 8.01s	remaining: 16.1s
332:	learn: 0.5963696	total: 8.03s	remaining: 16.1s
333:	learn: 0.5958838	total: 8.05s	remaining: 16s
334:	learn: 0.5955253	total: 8.07s	remaining: 16s
335:	learn: 0.5950193	total: 8.09s	remaining: 16s
336:	learn: 0.5942872	total: 8.11s	remaining: 16s
337:	learn: 0.5938078	total: 8.13s	remaining: 15.9s
338:	learn: 0.5933622	total: 8.15s	remaining: 15.9s
339:	learn: 0.5930635	total: 8.17s	remaining: 15.9s
340:	learn: 0.5925131	total: 8.19s	remaining: 15.8s
341:	learn: 0.5914179	total: 8.21s	remaining: 15.8s
342:	learn: 0.5909751	total: 8.23s	remaining: 15.8s
343:	learn: 0.5902892	total: 8.25s	remaining: 15.7s
344:	learn: 0.589214

489:	learn: 0.5158530	total: 11.7s	remaining: 12.2s
490:	learn: 0.5154660	total: 11.7s	remaining: 12.2s
491:	learn: 0.5150546	total: 11.8s	remaining: 12.1s
492:	learn: 0.5145044	total: 11.8s	remaining: 12.1s
493:	learn: 0.5140188	total: 11.8s	remaining: 12.1s
494:	learn: 0.5135672	total: 11.8s	remaining: 12.1s
495:	learn: 0.5128067	total: 11.8s	remaining: 12s
496:	learn: 0.5125562	total: 11.9s	remaining: 12s
497:	learn: 0.5118528	total: 11.9s	remaining: 12s
498:	learn: 0.5115362	total: 11.9s	remaining: 12s
499:	learn: 0.5111977	total: 11.9s	remaining: 11.9s
500:	learn: 0.5109917	total: 12s	remaining: 11.9s
501:	learn: 0.5101914	total: 12s	remaining: 11.9s
502:	learn: 0.5097740	total: 12s	remaining: 11.9s
503:	learn: 0.5094409	total: 12s	remaining: 11.8s
504:	learn: 0.5086595	total: 12s	remaining: 11.8s
505:	learn: 0.5080256	total: 12.1s	remaining: 11.8s
506:	learn: 0.5076250	total: 12.1s	remaining: 11.8s
507:	learn: 0.5072423	total: 12.1s	remaining: 11.7s
508:	learn: 0.5067837	total: 1

648:	learn: 0.4544986	total: 15.6s	remaining: 8.41s
649:	learn: 0.4542064	total: 15.6s	remaining: 8.39s
650:	learn: 0.4535490	total: 15.6s	remaining: 8.37s
651:	learn: 0.4533353	total: 15.6s	remaining: 8.34s
652:	learn: 0.4529824	total: 15.6s	remaining: 8.31s
653:	learn: 0.4528331	total: 15.7s	remaining: 8.29s
654:	learn: 0.4526465	total: 15.7s	remaining: 8.26s
655:	learn: 0.4524310	total: 15.7s	remaining: 8.24s
656:	learn: 0.4520137	total: 15.7s	remaining: 8.21s
657:	learn: 0.4516577	total: 15.8s	remaining: 8.19s
658:	learn: 0.4512090	total: 15.8s	remaining: 8.16s
659:	learn: 0.4507651	total: 15.8s	remaining: 8.14s
660:	learn: 0.4504032	total: 15.8s	remaining: 8.11s
661:	learn: 0.4501314	total: 15.8s	remaining: 8.09s
662:	learn: 0.4495084	total: 15.9s	remaining: 8.06s
663:	learn: 0.4492037	total: 15.9s	remaining: 8.04s
664:	learn: 0.4489405	total: 15.9s	remaining: 8.01s
665:	learn: 0.4486146	total: 15.9s	remaining: 7.99s
666:	learn: 0.4483717	total: 15.9s	remaining: 7.96s
667:	learn: 

811:	learn: 0.4078865	total: 20s	remaining: 4.63s
812:	learn: 0.4077909	total: 20s	remaining: 4.6s
813:	learn: 0.4076320	total: 20s	remaining: 4.58s
814:	learn: 0.4074925	total: 20.1s	remaining: 4.55s
815:	learn: 0.4073961	total: 20.1s	remaining: 4.53s
816:	learn: 0.4073110	total: 20.1s	remaining: 4.5s
817:	learn: 0.4070434	total: 20.1s	remaining: 4.47s
818:	learn: 0.4068000	total: 20.1s	remaining: 4.45s
819:	learn: 0.4066747	total: 20.2s	remaining: 4.42s
820:	learn: 0.4065206	total: 20.2s	remaining: 4.4s
821:	learn: 0.4061081	total: 20.2s	remaining: 4.37s
822:	learn: 0.4058903	total: 20.2s	remaining: 4.35s
823:	learn: 0.4056701	total: 20.2s	remaining: 4.32s
824:	learn: 0.4054589	total: 20.3s	remaining: 4.3s
825:	learn: 0.4051793	total: 20.3s	remaining: 4.27s
826:	learn: 0.4049158	total: 20.3s	remaining: 4.25s
827:	learn: 0.4045773	total: 20.3s	remaining: 4.22s
828:	learn: 0.4044276	total: 20.3s	remaining: 4.2s
829:	learn: 0.4041247	total: 20.4s	remaining: 4.17s
830:	learn: 0.4040039	t

976:	learn: 0.3679071	total: 24.1s	remaining: 567ms
977:	learn: 0.3678004	total: 24.1s	remaining: 542ms
978:	learn: 0.3673952	total: 24.1s	remaining: 517ms
979:	learn: 0.3672015	total: 24.1s	remaining: 493ms
980:	learn: 0.3669580	total: 24.2s	remaining: 468ms
981:	learn: 0.3668277	total: 24.2s	remaining: 443ms
982:	learn: 0.3667130	total: 24.2s	remaining: 419ms
983:	learn: 0.3664836	total: 24.2s	remaining: 394ms
984:	learn: 0.3663510	total: 24.3s	remaining: 369ms
985:	learn: 0.3661348	total: 24.3s	remaining: 345ms
986:	learn: 0.3658445	total: 24.3s	remaining: 320ms
987:	learn: 0.3654432	total: 24.3s	remaining: 296ms
988:	learn: 0.3653484	total: 24.4s	remaining: 271ms
989:	learn: 0.3651939	total: 24.4s	remaining: 246ms
990:	learn: 0.3649624	total: 24.4s	remaining: 222ms
991:	learn: 0.3648384	total: 24.4s	remaining: 197ms
992:	learn: 0.3646535	total: 24.4s	remaining: 172ms
993:	learn: 0.3645235	total: 24.5s	remaining: 148ms
994:	learn: 0.3643393	total: 24.5s	remaining: 123ms
995:	learn: 

<catboost.core.CatBoostClassifier at 0x1cb43671940>

In [39]:
def get_sample_weight(num_observations, y):
    sample_weight = np.zeros(num_observations)
    for class_idx in np.unique(y_catboost):
        idx = (y == class_idx)
        sample_weight[idx] = idx.sum() / num_observations

    return sample_weight

In [35]:
predictions = catboost_classifier.predict(lstm_train_data.drop(["target"], axis=1))
probas = catboost_classifier.predict_proba(lstm_train_data.drop(["target"], axis=1))

sample_weight = get_sample_weight(lstm_train_data.shape[0], lstm_train_data["target"])

print(balanced_accuracy_score(lstm_train_data["target"], predictions, sample_weight=sample_weight))
confusion_matrix(lstm_train_data["target"], predictions)

0.4296810253924943


array([[5363, 1058, 2708],
       [ 725,  649,  804],
       [ 613,  482,  741]], dtype=int64)

In [17]:
lstm_train_data = pd.concat([lstm_train_data.reset_index(drop=True), pd.DataFrame(probas)], axis=1)
lstm_train_data

Unnamed: 0,volume,open,close,high,low,macd,macds,macdh,macd_xu_macds,macd_xd_macds,...,macd_xd_macds_NVDA,boll_NVDA,boll_ub_NVDA,boll_lb_NVDA,high_x_boll_ub_NVDA,low_x_boll_lb_NVDA,target,0,1,2
0,16862.0,121.4300,121.35,121.4700,121.28,-0.042554,-0.100333,0.057779,False,False,...,False,142.719375,143.406384,142.032366,False,False,0.0,0.701924,0.231374,0.066702
1,43497.0,121.4200,120.70,121.5700,120.60,-0.115375,-0.103341,-0.012034,False,True,...,False,142.719625,143.406267,142.032983,False,False,0.0,0.522010,0.447591,0.030399
2,43496.0,120.6500,121.00,121.2900,120.60,-0.147183,-0.112110,-0.035073,False,False,...,True,142.724875,143.405462,142.044288,False,False,0.0,0.462736,0.423742,0.113522
3,23107.0,121.1000,120.85,121.1000,120.70,-0.182392,-0.126166,-0.056226,False,False,...,False,142.721750,143.407241,142.036259,False,False,0.0,0.456106,0.495087,0.048808
4,18738.0,120.8200,120.90,121.1000,120.70,-0.203910,-0.141715,-0.062195,False,False,...,False,142.700125,143.447836,141.952414,False,False,0.0,0.499755,0.446175,0.054070
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13138,17731.0,145.7100,145.68,145.7500,145.66,-0.008033,0.051588,-0.059621,False,False,...,False,196.315580,197.631764,194.999396,False,False,0.0,0.434435,0.203958,0.361608
13139,10326.0,145.6800,145.78,145.8000,145.68,-0.017429,0.037784,-0.055214,False,False,...,False,196.421580,197.571797,195.271363,False,False,0.0,0.553766,0.189007,0.257227
13140,5767.0,145.7301,145.70,145.7301,145.70,-0.030974,0.024033,-0.055007,False,False,...,False,196.510250,197.506493,195.514007,False,False,0.0,0.578312,0.172742,0.248946
13141,38578.0,145.7100,145.72,145.8600,145.70,-0.039638,0.011299,-0.050936,False,False,...,False,196.579750,197.460632,195.698868,False,False,0.0,0.362496,0.144970,0.492533


In [18]:
LSTM_TEST_RATIO = 0.2
LSTM_VAL_RATIO = 0.2

X_lstm = lstm_train_data.drop(["target"], axis=1)
y_lstm = lstm_train_data["target"]

X_lstm_train_val, X_lstm_test, y_lstm_train_val, y_lstm_test = train_test_split(
    X_lstm, y_lstm,
    test_size=LSTM_TEST_RATIO,
    shuffle=False
)
X_lstm_train, X_lstm_val, y_lstm_train, y_lstm_val = train_test_split(
    X_lstm_train_val, 
    y_lstm_train_val, 
    test_size=LSTM_VAL_RATIO,
    shuffle=False
)

y_lstm_train.value_counts()

0.0    5835
1.0    1407
2.0    1169
Name: target, dtype: int64

In [19]:
SEQUENCE_LEN = 30
BATCH_SIZE = 50

train_dataset = StockSequenceDataset(
    torch.tensor(X_lstm_train.to_numpy(dtype=float)),
    torch.tensor(y_lstm_train.to_numpy(dtype=float)),
    sequence_length=SEQUENCE_LEN
)
val_dataset = StockSequenceDataset(
    torch.tensor(X_lstm_val.to_numpy(dtype=float)),
    torch.tensor(y_lstm_val.to_numpy(dtype=float)),
    sequence_length=SEQUENCE_LEN
)
test_dataset = StockSequenceDataset(
    torch.tensor(X_lstm_test.to_numpy(dtype=float)),
    torch.tensor(y_lstm_test.to_numpy(dtype=float)),
    sequence_length=SEQUENCE_LEN
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [24]:
NUM_EPOCH = 10

class_weight = compute_class_weight(class_weight="balanced", classes=np.unique(y_lstm_train), y=y_lstm_train)

model = StockLSTMModel(
    input_size=X_lstm_train.shape[1],
    num_classes=len(np.unique(y_lstm_train)),
    sequence_length=SEQUENCE_LEN
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weight, dtype=torch.float32))

train_model(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    validation_loader=val_loader,
    num_epoch=NUM_EPOCH
)

////////////////////////////////////////
// Epoch: 1
// Train loss: 1.0955236355463664
// Validation loss: 1.0736486911773682
////////////////////////////////////////
////////////////////////////////////////
// Epoch: 2
// Train loss: 1.0558247168858845
// Validation loss: 1.0486690998077393
////////////////////////////////////////
////////////////////////////////////////
// Epoch: 3
// Train loss: 1.025233229001363
// Validation loss: 1.019029974937439
////////////////////////////////////////
////////////////////////////////////////
// Epoch: 4
// Train loss: 1.0044305125872295
// Validation loss: 1.0119693577289581
////////////////////////////////////////
////////////////////////////////////////
// Epoch: 5
// Train loss: 0.9693718651930491
// Validation loss: 1.0284196734428406
////////////////////////////////////////
////////////////////////////////////////
// Epoch: 6
// Train loss: 0.9624430537223816
// Validation loss: 1.0321577489376068
////////////////////////////////////////


In [31]:
logits, targets = get_test_logits(model, test_loader)
logits_view = logits.reshape(-1, 3)
targets = targets.flatten()

print(logits_view.shape, targets.shape)

torch.Size([2610, 3]) torch.Size([2610])


In [44]:
predictions = logits_view.argmax(axis=1)
sample_weight = get_sample_weight(len(targets), targets)

print(balanced_accuracy_score(targets, predictions, sample_weight=sample_weight))
confusion_matrix(targets, predictions)

0.39363766006407647


array([[845, 462, 577],
       [ 47, 165, 206],
       [ 76, 128, 104]], dtype=int64)