In [78]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plot
from fastai.tabular.all import *
from pathlib import Path
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor, export_graphviz
from dtreeviz.trees import *

pd.options.display.max_rows = 60


In [79]:
save_dir = Path("../market/save")
df = pd.read_csv(save_dir / f"rdf_output.csv")


In [80]:
df.shape

(10000, 57)

In [81]:
df[['ns_since_open', 'ts_in_delta', 'ts_eventElapsed']]

Unnamed: 0,ns_since_open,ts_in_delta,ts_eventElapsed
0,1.247944e+12,167064,1.735916e+09
1,3.668683e+11,167870,1.735915e+09
2,4.511887e+11,169914,1.735915e+09
3,9.572814e+11,170682,1.735916e+09
4,2.028339e+12,168396,1.735917e+09
...,...,...,...
9995,1.139725e+12,168145,1.735916e+09
9996,1.930373e+12,167420,1.735917e+09
9997,2.169029e+12,166449,1.735917e+09
9998,1.338696e+12,186805,1.735916e+09


In [82]:
df.columns

Index(['ts_recv', 'ts_event', 'rtype', 'publisher_id', 'instrument_id',
       'action', 'side', 'depth', 'price', 'size', 'flags', 'ts_in_delta',
       'sequence', 'bid_px_00', 'ask_px_00', 'symbol', 'ns_since_open',
       'ts_eventYear', 'ts_eventMonth', 'ts_eventWeek', 'ts_eventDay',
       'ts_eventDayofweek', 'ts_eventDayofyear', 'ts_eventIs_month_end',
       'ts_eventIs_month_start', 'ts_eventIs_quarter_end',
       'ts_eventIs_quarter_start', 'ts_eventIs_year_end',
       'ts_eventIs_year_start', 'ts_eventElapsed', 'spread', 'mid',
       'bid_weight', 'ask_weight', 'bid_weight_log', 'ask_weight_log',
       'traded_bid_size', 'traded_ask_size', 'is_trade_bid', 'is_trade_ask',
       'rolling_30s_bid_size', 'rolling_30s_ask_size', 'rolling_30s_bid_cnt',
       'rolling_30s_ask_cnt', 'rolling_5min_bid_size', 'rolling_5min_ask_size',
       'rolling_5min_bid_cnt', 'rolling_5min_ask_cnt', 'hedge_buy_stop_idx',
       'hedge_sell_stop_idx', 'ideal_sell_price', 'ideal_buy_price_sp

#### Simplify the data further

In [83]:
columns_to_keep = ['instrument_id', 'symbol',
                   'ns_since_open', 'ts_eventYear', 'ts_eventMonth',
       'ts_eventWeek', 'ts_eventDay', 'ts_eventDayofweek', 'ts_eventDayofyear',
       'ts_eventIs_month_end', 'ts_eventIs_month_start',
       'ts_eventIs_quarter_end', 'ts_eventIs_quarter_start',
       'ts_eventIs_year_end', 'ts_eventIs_year_start', 'ts_eventElapsed',
       'spread', 
       #'mid', 
       'bid_px_00', 'ask_px_00',
       'bid_weight', 'ask_weight', 'bid_weight_log',
       'ask_weight_log', 
       #'traded_bid_size', 'traded_ask_size', 'is_trade_bid', 'is_trade_ask', 
       'rolling_30s_bid_size', 'rolling_30s_ask_size',
       'rolling_30s_bid_cnt', 'rolling_30s_ask_cnt', 'rolling_5min_bid_size',
       'rolling_5min_ask_size', 'rolling_5min_bid_cnt', 'rolling_5min_ask_cnt',
       #'hedge_buy_stop_idx', 'hedge_sell_stop_idx', 
       'is_buy'
]
dep_vars = ['ideal_price_spread'] 

#       'ideal_buy_price', 'ideal_sell_price', 
#       'ideal_buy_price_spread', 'ideal_sell_price_spread']

rdf = df[columns_to_keep + dep_vars]

In [84]:
cont_cols,cat_cols = cont_cat_split(rdf, max_card=9000, dep_var=dep_vars)

In [85]:
print("********* Continuous columns")
[print(x) for x in cont_cols] 
print("********* Categorical columns")
_= [print(x) for x in cat_cols] 
#cont_cols

********* Continuous columns
ns_since_open
ts_eventElapsed
spread
bid_px_00
ask_px_00
bid_weight
ask_weight
bid_weight_log
ask_weight_log
rolling_30s_bid_size
rolling_30s_ask_size
rolling_30s_bid_cnt
rolling_30s_ask_cnt
rolling_5min_bid_size
rolling_5min_ask_size
rolling_5min_bid_cnt
rolling_5min_ask_cnt
********* Categorical columns
instrument_id
symbol
ts_eventYear
ts_eventMonth
ts_eventWeek
ts_eventDay
ts_eventDayofweek
ts_eventDayofyear
ts_eventIs_month_end
ts_eventIs_month_start
ts_eventIs_quarter_end
ts_eventIs_quarter_start
ts_eventIs_year_end
ts_eventIs_year_start
is_buy


In [86]:
rdf.columns

Index(['instrument_id', 'symbol', 'ns_since_open', 'ts_eventYear',
       'ts_eventMonth', 'ts_eventWeek', 'ts_eventDay', 'ts_eventDayofweek',
       'ts_eventDayofyear', 'ts_eventIs_month_end', 'ts_eventIs_month_start',
       'ts_eventIs_quarter_end', 'ts_eventIs_quarter_start',
       'ts_eventIs_year_end', 'ts_eventIs_year_start', 'ts_eventElapsed',
       'spread', 'bid_px_00', 'ask_px_00', 'bid_weight', 'ask_weight',
       'bid_weight_log', 'ask_weight_log', 'rolling_30s_bid_size',
       'rolling_30s_ask_size', 'rolling_30s_bid_cnt', 'rolling_30s_ask_cnt',
       'rolling_5min_bid_size', 'rolling_5min_ask_size',
       'rolling_5min_bid_cnt', 'rolling_5min_ask_cnt', 'is_buy',
       'ideal_price_spread'],
      dtype='object')

### FastAI Tabular

In [87]:
procs_nn = [Categorify, FillMissing, Normalize]
test_set = rdf.sample(frac=0.2)
train_set = rdf.drop(test_set.index)
splits = (list(train_set.index), list(test_set.index))


to = TabularPandas(rdf, procs_nn, cat_cols, cont_cols,
                      splits=splits, y_names='ideal_price_spread')

In [None]:
from sklearn.metrics import mean_squared_error

xs,y = to.train.xs,to.train.y
m = DecisionTreeRegressor(max_leaf_nodes=20)
m.fit(xs, y)

accuracy = m.score(xs, y)
test_xs, test_y = to.valid.xs, to.valid.y
predictions = m.predict(test_xs)
mse = mean_squared_error(test_y, predictions)
print(f"Accuracy: {accuracy}")
print(f"Mean Squared Error: {mse}")

Accuracy: 0.9292127118547251
Mean Squared Error: 137.5607095923853


In [89]:
predictions

array([  -9.06443909, -159.48888889,  -56.3824956 , ..., -116.98639456,
        -56.3824956 ,  -22.57431457])

In [90]:
def draw_tree(t, df, size=10, ratio=0.6, precision=0, **kwargs):
    s=export_graphviz(t, out_file=None, feature_names=df.columns, filled=True, rounded=True,
                      special_characters=True, rotate=False, precision=precision, **kwargs)
    import graphviz
    return graphviz.Source(re.sub('Tree {', f'Tree {{ size={size}; ratio={ratio}', s))

#draw_tree(m, xs, size=20, leaves_parallel=True, precision=4, ratio=0.6)

## Scikit Learn

In [95]:
from sklearn.ensemble import HistGradientBoostingRegressor

est = HistGradientBoostingRegressor(learning_rate=0.1, max_depth=6)
est.fit(xs, y)

accuracy = est.score(to.valid.xs, to.valid.y)

predictions = est.predict(to.valid.xs)

mse = mean_squared_error(to.valid.y, predictions)
print(f"Accuracy: {accuracy}")
print(f"Mean Squared Error: {mse}")

results_df = pd.DataFrame({'Predictions': predictions, 'Actual': to.valid.y.values})
results_df


Accuracy: 0.9931711762821236
Mean Squared Error: 12.516253777567503


Unnamed: 0,Predictions,Actual
0,-6.549329,-7.666667
1,-175.108178,-175.000000
2,-48.052847,-42.000000
3,-50.849556,-52.000000
4,-7.617375,-9.500000
...,...,...
1995,-5.591790,-8.000000
1996,-35.606038,-36.000000
1997,-137.446016,-140.000000
1998,-51.582478,-47.000000


In [92]:
type(to)
to.train.xs.ts_eventIs_quarter_start.unique()
to.train.xs.shape, to.valid.xs.shape
to.valid.xs.head(10)

Unnamed: 0,instrument_id,symbol,ts_eventYear,ts_eventMonth,ts_eventWeek,ts_eventDay,ts_eventDayofweek,ts_eventDayofyear,ts_eventIs_month_end,ts_eventIs_month_start,...,bid_weight_log,ask_weight_log,rolling_30s_bid_size,rolling_30s_ask_size,rolling_30s_bid_cnt,rolling_30s_ask_cnt,rolling_5min_bid_size,rolling_5min_ask_size,rolling_5min_bid_cnt,rolling_5min_ask_cnt
419,1,1,1,1,1,1,1,1,1,1,...,0.989316,-0.929643,-0.598467,-0.972831,-0.740311,-0.966506,-1.166131,-0.380371,-1.216715,-0.631056
8058,1,1,1,1,1,1,1,1,1,1,...,0.531724,-0.319442,-0.411706,1.058706,-0.513813,0.643564,0.309216,1.85108,0.591268,1.777178
5277,1,1,1,1,1,1,1,1,1,1,...,-0.277267,1.514075,-0.221419,-1.252122,-0.330996,-1.267519,0.009125,1.299003,0.188605,1.221509
9786,1,1,1,1,1,1,1,1,1,1,...,0.19818,-0.135618,-0.150224,-1.440153,-0.347175,-1.701537,-1.097669,-0.698937,-1.168271,-0.703666
4669,1,1,1,1,1,1,1,1,1,1,...,0.82863,0.239296,-0.388763,-0.693851,-0.447481,-1.141513,0.32692,1.74737,0.596157,1.649102
9372,1,1,1,1,1,1,1,1,1,1,...,0.113656,-0.21251,-0.494273,-0.618461,-0.481456,-0.854501,-0.850118,0.86555,-0.739386,0.52062
8853,1,1,1,1,1,1,1,1,1,1,...,-0.339624,-0.49584,-0.510098,0.265641,-0.597941,0.482557,-0.396322,-1.23197,-0.335835,-0.816615
5874,1,1,1,1,1,1,1,1,1,1,...,-2.113849,-0.468058,2.030004,0.454517,2.843215,1.378596,-0.440604,-1.199216,-0.572277,-1.987452
3138,1,1,1,1,1,1,1,1,1,1,...,-0.958099,-1.153927,0.036587,-0.862679,-0.161123,1.140586,0.724039,0.546168,1.109485,0.583146
1160,1,1,1,1,1,1,1,1,1,1,...,2.629162,-0.349639,-0.486957,0.663796,-0.41836,0.384553,-0.805086,1.155745,-0.706498,0.753578
