In [1]:
import pandas as pd
import numpy as np
from aeon.transformations.collection.dictionary_based import SAX
from utils import vae_encoding, Params
from model import VAE
import torch

In [2]:
# Daily close prices: one column per ticker, one row per date.
prices_df = pd.read_csv("../datasets/stocks/nasdaq_prices.csv", index_col=0)
prices_df.index = pd.to_datetime(prices_df.index)

# TODO: split data into in-sample and out-of-sample periods before any fitting.

# Daily simple returns. Prices are forward-filled explicitly and pct_change is
# called with fill_method=None: relying on the implicit fill_method="pad" default
# is deprecated since pandas 2.1 (FutureWarning whenever prices have gaps), and
# ffill + fill_method=None reproduces that old default exactly.
# dropna() removes every row where any ticker has no return (the first row, and
# any rows before a ticker's first recorded price), so the panel stays aligned.
daily_returns_df = (prices_df
                    .ffill()
                    .pct_change(fill_method=None)
                    .dropna()
                    )

# Monthly returns: compound the daily returns inside each calendar month,
# prod(1 + r) - 1, with bins labelled by month start ("MS"). The vectorised
# resample().prod() gives the same numbers as a per-group Python lambda.
# NOTE(review): the final month is partial if the price data ends mid-month.
returns_df = (daily_returns_df
              .add(1)
              .resample('MS')
              .prod()
              .sub(1)
              )

# aeon collection layout (n_cases, n_channels, n_timepoints):
# one univariate series per ticker -> (n_tickers, 1, n_months).
returns_np = np.expand_dims(returns_df.T.to_numpy(), axis=1)
assert returns_np.shape == (returns_df.shape[1], 1, returns_df.shape[0])
assert not np.isnan(returns_np).any(), "monthly returns contain NaN"

returns_df.head()

Unnamed: 0_level_0,AAPL,ADBE,ADI,ADP,ADSK,AEP,ALGN,AMAT,AMD,AMGN,...,SIRI,SNPS,TMUS,TSLA,TXN,VRSK,VRTX,WBA,WBD,XEL
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2011-01-01,0.029584,0.056248,0.023189,0.018066,0.035905,-0.015724,0.048314,0.115931,-0.075561,-0.008461,...,-0.04142,0.005187,-0.010711,-0.094666,0.039992,-0.018566,0.080578,0.028484,-0.07165,0.000425
2011-02-01,0.040935,0.043873,0.027041,0.043842,0.033677,0.015768,0.00096,0.051731,0.176245,-0.068083,...,0.117284,0.021747,0.113689,-0.008714,0.050133,-0.043749,0.200051,0.07616,0.105385,0.015697
2011-03-01,-0.013307,-0.038841,-0.00696,0.033513,0.048989,-0.017887,-0.017746,-0.0493,-0.066232,0.041301,...,-0.088398,-0.002525,0.127778,0.161574,-0.029486,0.012674,0.026998,-0.073835,-0.074461,0.008684
2011-04-01,0.004649,0.011761,0.023617,0.059443,0.019723,0.038133,0.178711,0.004482,0.05814,0.063611,...,0.206061,-0.009403,0.03633,-0.005405,0.031803,0.004274,0.14855,0.064275,0.109273,0.018418
2011-05-01,-0.006569,0.032191,0.027686,0.013796,-0.044464,0.060456,0.014913,-0.116552,-0.046154,0.064908,...,0.180904,-0.001825,0.063577,0.092029,-0.006473,0.034954,-0.019255,0.025352,-0.015816,0.016851


#### SAX encoding

In [3]:
# SAX parameters: number of PAA segments per series and size of the symbol alphabet.
n_segments_sax = 32
alphabet_size_sax = 16

# Every SAX segment needs at least one monthly observation to average over.
assert returns_np.shape[-1] >= n_segments_sax, "fewer months than SAX segments"

sax = SAX(n_segments=n_segments_sax, alphabet_size=alphabet_size_sax)
# Per aeon's docs SAX returns (n_cases, n_channels, n_segments) — confirm for the
# installed version. returns_np has a single channel, so drop only that axis:
# a bare .squeeze() would also drop the case axis when there is a single ticker
# and silently change the rank of returns_sax.
returns_sax = sax.fit_transform(returns_np).squeeze(axis=1)


#### VAE encoding

In [4]:
model_path = "../baseline_models/fc/model.pt"
params_path = "../baseline_models/fc/params.json"

params = Params(params_path)

# VAE hyperparameters stored alongside the trained checkpoint.
patch_len = params.patch_len
alphabet_size = params.alphabet_size
n_latent = params.n_latent
arch = params.arch

# Build the VAE with the same hyperparameters and load the trained weights.
vae = VAE(patch_len, alphabet_size, n_latent, arch)
# map_location="cpu": the checkpoint may have been saved from a GPU run; this lets
# it load on CPU-only machines (the model here is built without an explicit device).
# weights_only=True: restrict unpickling to tensors/containers, which is all a
# state_dict needs (keyword available since torch 1.13).
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
vae.load_state_dict(state_dict)
vae.eval()

# Encoding is pure inference: disable autograd so no computation graph is built.
with torch.no_grad():
    returns_vae = vae_encoding(vae, returns_np, patch_len)