In [1]:
# Auto reload local files
# %load_ext loads the IPython autoreload extension and %reload_ext reloads it
# (relevant when this cell is re-run in a live kernel). Mode 2 asks autoreload
# to re-import modules before each cell executes, so edits to local source
# files are picked up without restarting the kernel.
%load_ext autoreload
%reload_ext autoreload
%autoreload 2
# Put the src/ directory at the front of the import path (once) so that the
# files in it can be imported from this notebook.
import sys

SRC_DIR = '../src'
if sys.path.count(SRC_DIR) == 0:
    sys.path.insert(0, SRC_DIR)

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from datetime import datetime, timedelta

# Suppress ta warnings
# NOTE: filterwarnings("ignore") is called without a category or module
# filter, so it silences every Python warning raised in this kernel, not only
# the ones coming from `ta`.
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Seed the global RNG before the first shuffle so that every random ticker
# selection in this notebook is reproducible under Restart Kernel -> Run All.
# Change RANDOM_SEED to draw a different (but still repeatable) selection.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Read the watchlist and SPY constituent symbols from CSV.
# NOTE(review): the watchlist is read from '../data' while the SPY constituents
# come from '../../data' -- confirm both relative paths are intended.
watchlist = pd.read_csv('../data/watchlist.csv', header=0)['symbol'].tolist()
spy_constituents = pd.read_csv('../../data/spy_constituents.csv', header=0)['Symbol'].tolist()
random.shuffle(spy_constituents)

# Current tickers we have in database (alphabetical, grouped by initial).
all_tickers = [
    "A","AAL","AAP","AAPL","ABBV","ABC","ACN","ADBE","ADI","ADP","ADSK","ADX","AEE","AEP","AES","AFL","AG","AIV","AIZ","AJG","AJRD","AJX","AJXA","AKAM","AL","ALB","ALGN","ALK","ALL","ALLE",
    "AMAT","AMCR","AMD","AME","AMGN","AMN","AMP","AMT","ANET","ANSS","AON","AOS","APA","APD","APH","APT","APTV","ARA","ARE","ASX","ATGE","ATO","ATUS","ATVI","AU","AUY","AVGO","AVY","AWK","AXP","AZO","AZZ",
    "BA","BABA","BB","BBWI","BBY","BEN","BIIB","BK","BKI","BKN","BKNG","BKR","BLK","BMY","BR","BRO","BSIG","BSL","BSX","BWA","BYM","BZH",
    "C","CAAP","CAG","CARR","CAT","CBOE","CBRE","CBT","CCI","CCL","CDNS","CDW","CE","CERN","CF","CFG","CHD","CHRW","CHT","CHTR","CI","CIA","CINF","CL","CLPR","CLX","CMA","CMCSA","CME","CMG","CMI","CMO","CMS","CNC","CNR",
    "COF","COP","COST","COTY","CP","CPB","CPRT","CRL","CRM","CSCO","CSX","CTAS","CTLT","CTRA","CTSH","CTVA","CUK","CULP","CURO","CUZ","CVA","CVEO","CVS","CVX","CWH","CYD","CZR",
    "DAL","DAVA","DCI","DD","DDS","DEI","DFS","DG","DGX","DHI","DHR","DHY","DIS","DISCA","DISCK","DISH","DLTR","DNB","DNOW","DOV","DOW","DPW","DPZ","DRE","DRI","DRQ","DS","DTE","DUK","DVA","DVN","DXC","DXCM","DXF",
    "ECL","ED","EEX","EFL","EFX","EIX","EL","ELC","ELLO","EMN","EMR","ENPH","EOG","EPAM","EQR","ES","ESS","ETB","ETN","ETR","ETSY","EVA","EVM","EVRG","EW","EXC","EXPD","EXR",
    "F","FANG","FAST","FBC","FBHS","FCX","FDX","FE","FFC","FFIV","FISV","FLC","FLT","FMC","FOX","FPAC","FRC","FRT","FTV",
    "GE","GEF","GEO","GILD","GIS","GL","GLU","GLW","GM","GME","GNRC","GOOG","GOOGL","GPC","GPM","GPN","GPS","GRMN","GS","GWW",
    "HAL","HAS","HASI","HBAN","HCA","HD","HES","HESM","HIG","HOLX","HON","HOV","HPE","HPP","HPQ","HQL","HRI","HRL","HST","HSY","HTD","HUBB","HUM","HUN","HUYA","HWM",
    "IBM","ICE","IDXX","IEX","IFF","IGA","IIM","INFO","INTC","INTU","IP","IPG","IPGP","IQV","IR","IRM","ISD","ISRG","IT","ITW","IVZ",
    "J","JAX","JBHT","JBT","JCI","JDD","JHB","JKHY","JLL","JMM","JMP","JNJ","JPM","JT",
    "K","KEY","KEYS","KHC","KIM","KLAC","KMB","KMI","KMPR","KMX","KO","KRO","KSU",
    "LDOS","LEN","LH","LHX","LIN","LLY","LNC","LND","LNT","LOW","LRCX","LUMN","LUV","LVS","LYV",
    "M","MA","MAA","MAC","MAN","MANU","MAR","MAS","MAXR","MC","MCD","MCHP","MCR","MDLZ","MDT","MEG","MET","MFA","MGM","MHK","MKC","MLM","MMI","MMM","MNR","MO","MPC","MPWR","MRK","MRO",
    "MS","MSC","MSCI","MSFT","MSGE","MSGS","MSI","MSM","MTB","MTCH","MTD","MTOR","MVT","MYN",
    "NAD","NAN","NAVB","NBB","NCV","NEE","NEM","NFJ","NFLX","NHS","NI","NJV","NLOK","NLSN","NML","NNA","NOC","NOW","NRG","NRT","NSC","NTAP","NTIP","NTP","NUE","NVDA","NVR","NWL","NWS","NWSA","NXPI","NXRT","NYV",
    "O","OCFT","ODFL","OGN","OKE","OMC","OOMA","ORA","ORLY","OTIS","OXY",
    "PAA","PAYC","PAYX","PBCT","PBI","PBY","PCAR","PE","PEAK","PEG","PEN","PENN","PFE","PFG","PG","PGR","PGTI","PH","PHM","PHX","PKG","PKI","PLYM","PNC","PNR","PNW","POOL","POR","PPG","PPL","PRU","PSX","PUMP","PVH","PWR","PXD","PYPL","PYS",
    "QCOM","QRVO",
    "RCA","RCL","RCUS","REG","REGN","RELX","REXR","RF","RHI","RIO","RJF","RL","RMD","ROK","ROL","ROP","ROST","RSG",
    "SBAC","SCHW","SEE","SHW","SIVB","SLB","SLCA","SMLP","SMTS","SNA","SNPS","SO","SOR","SPG","SPGI","SPH","SPY","SRE","SRL","STE","STG","STL","STT","STZ","SUP","SWK","SWKS","SWZ","SYY",
    "T","TAP","TBC","TDG","TDS","TDY","TECH","TEL","TEN","TER","TFC","TFX","TGB","THC","THO","TJX","TMO","TMQ","TNK","TPR","TPVG","TPX","TRMB","TROW","TRV","TSCO","TSLA","TSN","TT","TTM","TTWO","TVC","TWTR","TX","TXN","TXT","TY","TYL",
    "UA","UAA","UAL","UBP","UBS","UDR","ULTA","UNF","UNH","UNP","UPS","URI","USB","UTL","UZC",
    "V","VCIF","VCV","VFC","VLO","VMC","VMM","VNO","VPV","VRSK","VRSN","VRTX","VTI","VTR","VTRS","VZ",
    "WAB","WAT","WBA","WDC","WEA","WEC","WELL","WFC","WHR","WIT","WM","WMB","WMK","WMT","WRB","WRK","WST","WU","WY","WYNN",
    "XLNX","XOM","XPO","XRAY","XTNT","XYL",
    "YCBD","YEXT","YUM","YUMC",
    "ZBH","ZBRA","ZION","ZNH","ZTR","ZTS",
]

# Work with a random subset of at most 100 tickers. Sampling into a new list
# leaves `all_tickers` intact, so the full universe stays available.
tickers = random.sample(all_tickers, k=min(100, len(all_tickers)))
# Date boundaries: model.train() is given (train_start, test_start) and the
# backtest starts at test_start. end_time is captured once, when this cell runs.
train_start = pd.Timestamp("2022-06-01")
test_start = pd.Timestamp("2022-09-01")
end_time = datetime.now()

# Label horizon expressed in 5-minute candles: 24 hours * 12 candles per hour.
predict_window = 24 * (60 // 5)  # = 288

In [4]:
from model import Model, SignalExpr, TorchBackend, SklearnBackend
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
import torch
import torch.nn as nn

def _close_5min():
    """Return a fresh SignalExpr selecting 'close' from the 5-minute candles."""
    return SignalExpr('candles_5min', {}, select='close')


# (indicator, windows) pairs, in the order the features are emitted.
# Currently disabled windows: rsi 480 / 960 / 1920, kama 1920, macd 1920 and
# percent_change -1920.
# NOTE(review): percent_change features use negative windows while the label
# uses a positive one -- presumably negative means "look back"; confirm in
# signal_library.
_indicator_windows = (
    ('rsi', (14, 30, 60, 120, 240)),
    ('kama', (14, 30, 60, 120, 240, 480, 960)),
    ('macd', (14, 30, 60, 120, 240, 480, 960)),
    ('percent_change', (-14, -30, -60, -120, -240, -480, -960)),
)

# One SignalExpr per (indicator, window); each gets its own base expression.
feature_exprs = [
    SignalExpr(indicator, {'window': window, 'base': _close_5min()})
    for indicator, windows in _indicator_windows
    for window in windows
]

# Single prediction target: percent change of the 5-minute close, with the
# window set to predict_window.
_label_base = SignalExpr('candles_5min', {}, select='close')
_label_params = {'window': predict_window, 'base': _label_base}

label_exprs = [SignalExpr('percent_change', _label_params)]

# Disabled alternative: a torch MLP (LazyLinear -> 256 -> 128 -> 32 -> 16 -> 8
# -> n_outputs) wrapped in TorchBackend. Everything below sits inside a bare
# string literal, so none of it is executed. TorchBackend is handed both the
# built network and a source snippet that re-declares the same network.
# NOTE(review): that snippet refers to n_outputs without defining it -- confirm
# how TorchBackend uses the snippet before re-enabling this block.
"""
n_outputs = 1

net = nn.Sequential(
    nn.LazyLinear(256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
    nn.ReLU(),
    nn.Linear(16, 8),
    nn.ReLU(),
    nn.Linear(8, n_outputs),
)

model = Model(
    'TorchTI-SignalExpr-1', 'Test',
    TorchBackend(net, '''
        import torch.nn as nn
        net = nn.Sequential(
            nn.LazyLinear(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, n_outputs),
        )
    '''),
    feature_exprs, label_exprs,
    StandardScaler(), StandardScaler()
)
"""

# Active model: a scikit-learn random forest regressor wrapped in
# SklearnBackend, built from the feature and label expressions above.
# NOTE(review): with warm_start=True, scikit-learn keeps the already-fitted
# trees on later fit() calls unless n_estimators is increased -- confirm that
# SklearnBackend accounts for this.
forest = RandomForestRegressor(warm_start=True)
backend = SklearnBackend(forest)

# Two independent StandardScaler instances are passed positionally (presumably
# the feature scaler, then the label scaler -- verify against Model).
first_scaler = StandardScaler()
second_scaler = StandardScaler()

model = Model(
    'TIRF-test', 'TIRF Test',
    backend,
    feature_exprs, label_exprs,
    first_scaler, second_scaler,
)

In [None]:
from signal_library import FetchOptions
from model import TrainOptions

# Train for a single epoch on the selected tickers, fetching data for the
# (train_start, test_start) range.
train_fetch = FetchOptions(tickers, (train_start, test_start))
train_opts = TrainOptions(epochs=1)
model.train(train_fetch, train_opts)

In [None]:
# Evaluate on held-out data only. The window opens at test_start, which is
# where the range given to model.train() ends, so no training-period candles
# are scored. It closes at the end_time captured in the config cell, so the
# evaluation window does not drift with the moment this cell is executed.
model.eval(FetchOptions(tickers, (test_start, end_time)))

In [None]:
# List every feature with the importance reported by the fitted backend
# estimator, smallest importance first.
importances = model.backend.model.feature_importances_
ranked = sorted(zip(model.features, importances), key=lambda pair: pair[1])
for feature, weight in ranked:
    print(feature.qualified_id(), '=', weight)

In [None]:
from strategy import ModelStrategy

def _share_count(x):
    """Return the input multiplied by 100; passed as the share_count callback."""
    return x * 100


# NOTE(review): 0.4 and 0. are positional ModelStrategy arguments whose meaning
# is not visible here (presumably thresholds) -- confirm against strategy.py.
strategy = ModelStrategy(model, 0.4, 0., share_count=_share_count)

In [None]:
import backtest as bt
from importlib import reload
# Explicitly re-import the backtest module so the latest version is used.
reload(bt)

# Backtest on (up to) five randomly chosen tickers from test_start onwards.
# random.sample returns a new list instead of reordering `tickers` in place,
# so re-running this cell does not alter the list the other cells rely on.
# The window ends at the configured end_time rather than a fresh
# datetime.now(), keeping it tied to the config cell.
backtest_tickers = random.sample(tickers, k=min(5, len(tickers)))
bt.comprehensive_backtest(strategy, backtest_tickers, (test_start, end_time), processes=1)