In [None]:
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import mean_squared_error

In [None]:
# Global matplotlib defaults applied to every figure in this notebook:
# a 16:9 canvas at 300 dpi, serif font, and enlarged text/tick labels.
plt.rcParams.update(
    {
        "figure.figsize": [16, 9],
        "figure.dpi": 300,
        "font.size": 20,
        "axes.labelsize": 20,
        "axes.titlesize": 24,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "font.family": "serif",
    }
)

In [None]:
DATA_REALS = "../data/ohlc_DowJones"
!ls $DATA_REALS/

In [None]:
DATA_PREDS = "../storage/thesis-gan/abllixzn"
!ls $DATA_PREDS/

In [None]:
# Validation split of the real data, restricted to the columns whose name contains "mid_price".
reals = (
    pd.read_csv(f"{DATA_REALS}/ohlc_DowJones_val.csv")
    .filter(regex="mid_price")
)

In [None]:
# Rename each mid-price column to its third "_"-separated token, trim the rows,
# and compute the real correlation matrix used as the reference below.
# NOTE(review): assumes column names look like "mid_price_<TICKER>" so token [2]
# is the ticker — confirm against the CSV header.

# Presumably the number of time steps covered by the prediction pickles — TODO confirm.
N_ROWS_EVAL = 8490

reals = reals.rename(
    # Only touch columns that still carry "mid_price": on a re-run the columns are
    # already bare tickers, and splitting them again raised IndexError.
    columns=lambda col: col.split("_")[2] if "mid_price" in col else col
)
reals = reals.iloc[:N_ROWS_EVAL, :]
reals_corr = reals.corr(numeric_only=True)

In [None]:
# Checkpoint epochs whose prediction pickles are evaluated below (descending order).
EPOCHS = [
    280, 279, 268, 258, 252, 241, 235, 230, 229, 223,
    217, 216, 211, 206, 196, 194, 189, 185, 182, 181,
    171, 167, 160, 159, 146, 129, 128, 127, 126, 125,
]

In [None]:
# For every checkpoint epoch: load its predicted mid prices, build the predicted
# correlation matrix, and score it against the real one (reals_corr).
# Result: epoch2corrs[epoch] = {"preds_corr", "diffs_corr", "mse"}.
epoch2corrs = dict()
for epoch in EPOCHS:
    PATH_PICKLE_PRICE = (
        f"{DATA_PREDS}/preds_epoch={epoch}-target_price=mid_price-target_volume=None.pickle"
    )
    # pickle.load can execute arbitrary code: only load artifacts you produced yourself.
    with open(PATH_PICKLE_PRICE, "rb") as handle:
        pred_prices_dict = pickle.load(handle)

    # NOTE(review): presumably a torch tensor whose squeezed, transposed form has one
    # column per ticker in the same order as reals.columns — TODO confirm.
    preds = pd.DataFrame(pred_prices_dict["pred_prices"].squeeze().numpy().T)
    preds.columns = reals.columns  # pandas raises ValueError if the column count differs

    preds_corr = preds.corr(numeric_only=True)
    # Keep full precision. Rounding to 2 decimals here collapsed nearby epochs into
    # ties, so the arg-min over epochs just returned the first listed epoch.
    mse = mean_squared_error(reals_corr, preds_corr)
    # Signed relative error per entry, (real - pred) / real. Entries whose real
    # correlation is close to 0 become very large (or +/-inf when exactly 0).
    diffs_corr = (reals_corr - preds_corr) / reals_corr

    epoch2corrs[epoch] = {
        "preds_corr": preds_corr,
        "diffs_corr": diffs_corr,
        "mse": mse,
    }

In [None]:
# Epoch whose predicted correlation matrix is closest (lowest MSE) to the real one.
# Read the pairs straight from epoch2corrs so each MSE stays tied to its own epoch,
# even if EPOCHS was edited or re-run after the evaluation loop.
best_epoch, best_mse = min(
    ((epoch, results["mse"]) for epoch, results in epoch2corrs.items()),
    key=lambda pair: pair[1],
)
best_epoch, best_mse

# Correlations

In [None]:
# Smaller base font for the correlation heatmaps below; their annot=True cell labels use it.
plt.rc("font", size=8)

In [None]:
# Unpack the stored results of one checkpoint and display its correlation MSE.
# NOTE(review): the epoch is hard-coded; compare it with the arg-min computed above.
EPOCH = 241
preds_corr, diffs_corr, mse = (
    epoch2corrs[EPOCH][field] for field in ("preds_corr", "diffs_corr", "mse")
)
mse

In [None]:
# Annotated heatmap of the real mid-price correlation matrix.
ax = sb.heatmap(
    data=reals_corr,
    cmap="Blues",
    annot=True,
    fmt=".2f",
)
ax.set_title("Real correlations")
plt.show()
plt.close()

In [None]:
# Annotated heatmap of the predicted correlation matrix for the selected epoch.
ax = sb.heatmap(
    data=preds_corr,
    cmap="Blues",
    annot=True,
    fmt=".2f",
)
ax.set_title("Pred correlations")
plt.show()
plt.close()

In [None]:
# Heatmap of the relative error of each predicted correlation, (real - pred) / real.
# The values are signed, so use a diverging colormap centred on 0 (perfect agreement).
# A sequential map gives no visual reference for zero, which makes over- and
# under-estimates hard to tell apart.
ax = sb.heatmap(
    data=diffs_corr,
    cmap="RdBu_r",
    center=0,
    annot=True,
    fmt=".2f",
)
ax.set_title("Relative difference of correlations (real - pred) / real")
plt.show()
plt.close()