In [None]:
import numpy as np
import matplotlib.pyplot as plt

from src.load import load_data

# Load the training inputs (X), training labels (y) and test inputs (X_test)
# through the project loader.
# NOTE(review): what normalize=True / filter=True actually do is defined in
# src.load and is not visible from this notebook — confirm there.
X, y, X_test = load_data(normalize=True, filter=True)
# Feature names. The cells below iterate over this list with enumerate() and
# use the position as the index into the last axis of X / X_test, so the order
# here is expected to match the column order of the loaded arrays
# (TODO confirm against src.load).
features = [
    "venue",
    "order_id",
    "action",
    "side",
    "price",
    "bid",
    "ask",
    "bid_size",
    "ask_size",
    "trade",
    "flux",
]

In [None]:
# Quick sanity check on the dimensions of the training inputs and labels.
print(*(arr.shape for arr in (X, y)))

In [None]:
def plot_sample(features, label, name="train"):
    """Draw a two-panel overview of a single sample.

    Left panel: column 5 ("bid") and column 6 ("ask") as lines, with column 4
    ("price") overlaid as small black dots. The y-range is clamped to
    [min(bid) - 0.01, max(ask) + 0.01].
    Right panel: column 7 ("bid_size"), column 8 ("ask_size") and the running
    sum of column 10 (plotted as "cumulated_flux").

    Parameters
    ----------
    features : 2-D array
        One sample, indexed as ``features[:, column]``; rows are plotted in
        order along the x-axis. Any number of rows is accepted.
    label
        Value shown in the figure title.
    name : str, optional
        Tag shown in parentheses in the title (defaults to "train").

    Returns
    -------
    None
        The figure is created through pyplot and left open for the notebook
        to render.
    """
    # Column positions inside a sample, named after the legend entries.
    price_col, bid_col, ask_col = 4, 5, 6
    bid_size_col, ask_size_col, flux_col = 7, 8, 10

    fig, (ax_quotes, ax_sizes) = plt.subplots(ncols=2, figsize=(8, 4))

    bid = features[:, bid_col]
    ask = features[:, ask_col]
    ax_quotes.plot(bid, label="bid")
    ax_quotes.plot(ask, label="ask")
    # x positions follow the actual sample length instead of assuming 100 rows,
    # otherwise scatter() raises on samples of any other length.
    ax_quotes.scatter(
        np.arange(len(features)),
        features[:, price_col],
        label="price",
        s=1,
        color="black",
    )
    ax_quotes.set_ylim(np.min(bid) - 0.01, np.max(ask) + 0.01)
    ax_quotes.legend()

    ax_sizes.plot(features[:, bid_size_col], label="bid_size")
    ax_sizes.plot(features[:, ask_size_col], label="ask_size")
    cumulated_flux = np.cumsum(features[:, flux_col])
    ax_sizes.plot(cumulated_flux, label="cumulated_flux")
    ax_sizes.legend()

    fig.suptitle(f"label : {label} ({name})")
    fig.tight_layout()

In [None]:
# Eyeball the first ten training samples together with their labels.
for sample_idx in range(10):
    plot_sample(X[sample_idx], y[sample_idx])

In [None]:
# Distributions between train and test
#
# One boxplot panel per feature, comparing every value of that feature in the
# training inputs against every value in the test inputs.

nb_features = len(features)
nb_cols = 3
# Ceiling division: just enough rows to hold every feature, with no extra
# empty row when nb_features is a multiple of nb_cols.
nb_rows = -(-nb_features // nb_cols)

# squeeze=False keeps axs two-dimensional even for a single-row grid.
fig, axs = plt.subplots(
    nb_rows, nb_cols, figsize=(20, 4 * nb_rows), squeeze=False
)
panels = axs.ravel()  # row-major, i.e. panel i sits at [i // nb_cols, i % nb_cols]

for i, col in enumerate(features):
    ax = panels[i]
    ax.boxplot(
        [X[:, :, i].reshape(-1), X_test[:, :, i].reshape(-1)],
        meanline=True,
        showmeans=True,
    )
    # Set tick labels explicitly: boxplot's `labels=` keyword is deprecated in
    # recent Matplotlib releases, while set_xticklabels works on all versions.
    ax.set_xticklabels(["Train", "Test"])
    ax.set_title(col)

# Hide the leftover panels of the grid that carry no feature.
for ax in panels[nb_features:]:
    ax.set_visible(False)

plt.show()

In [None]:
# Distributions between labels
#
# One boxplot panel per feature; inside a panel, one box per label value
# showing every value of that feature over the training samples with that label.

nb_features = len(features)
nb_cols = 2
nb_labels = 24  # labels are compared for j in 0..23
# Ceiling division: just enough rows to hold every feature, with no extra
# empty row when nb_features is a multiple of nb_cols.
nb_rows = -(-nb_features // nb_cols)

# squeeze=False keeps axs two-dimensional even for a single-row grid.
fig, axs = plt.subplots(
    nb_rows, nb_cols, figsize=(20, 4 * nb_rows), squeeze=False
)
panels = axs.ravel()  # row-major, i.e. panel i sits at [i // nb_cols, i % nb_cols]

# Row selectors per label, built once instead of once per (feature, label).
label_masks = [y == j for j in range(nb_labels)]
tick_names = [str(j) for j in range(nb_labels)]

for i, col in enumerate(features):
    ax = panels[i]
    # Take the single-feature view first so the mask only copies that feature
    # rather than every feature of the selected samples.
    feature_values = X[:, :, i]
    data = [feature_values[mask].reshape(-1) for mask in label_masks]

    ax.boxplot(
        data,
        meanline=True,
        showmeans=True,
    )
    # Set tick labels explicitly: boxplot's `labels=` keyword is deprecated in
    # recent Matplotlib releases, while set_xticklabels works on all versions.
    ax.set_xticklabels(tick_names)
    ax.set_title(col)

# Hide the leftover panels of the grid that carry no feature.
for ax in panels[nb_features:]:
    ax.set_visible(False)

plt.show()