In [None]:
import pandas as pd
import numpy as np
import seaborn as sb
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pickle
import tqdm
from sklearn.metrics import mean_squared_error

In [None]:
# Global Matplotlib styling shared by every figure in this notebook:
# a 16x9 canvas, serif fonts, and enlarged title/label/tick text.
plt.rcParams.update(
    {
        "figure.figsize": [16, 9],
        # "figure.dpi": 300,  # left disabled, as in the original configuration
        "font.size": 20,
        "axes.labelsize": 20,
        "axes.titlesize": 24,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "font.family": "serif",
    }
)

In [None]:
# Directory containing the real-data CSV files; the validation split
# `ohlc_DowJones_val.csv` is read from here.
DATA_REALS = "../data/ohlc_DowJones"

In [None]:
# Directory containing the per-epoch prediction pickles
# (`preds_epoch=<epoch>-target_price=mid_price-target_volume=None.pickle`).
# NOTE(review): `abllixzn` looks like an experiment/run identifier — confirm.
DATA_PREDS = "../storage/thesis-gan/abllixzn"

In [None]:
# Read the validation CSV and keep only the columns whose name matches "mid_price".
val_csv_path = f"{DATA_REALS}/ohlc_DowJones_val.csv"
reals = pd.read_csv(val_csv_path).filter(regex="mid_price")

In [None]:
# Relabel every column with the third underscore-separated token of its name
# (presumably the ticker symbol — TODO confirm against the CSV header).
cols = reals.columns.to_list()
d = {col: col.split("_")[2] for col in cols}

# Keep only the first 8490 rows.
# NOTE(review): 8490 presumably equals the number of time steps stored in the
# prediction pickles — confirm.
reals = reals.rename(columns=d).iloc[:8490]

# Pairwise correlation matrix of the real series; reference for every comparison below.
reals_corr = reals.corr(numeric_only=True)

In [None]:
# For each of the 300 saved epochs, compare the correlation matrix of the
# generated prices against the real one and record the mean squared error.
mses = []
for epoch in tqdm.tqdm(range(300)):
    PATH_PICKLE_PRICE = f"{DATA_PREDS}/preds_epoch={epoch}-target_price=mid_price-target_volume=None.pickle"
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # DATA_PREDS at files produced by trusted training runs.
    with open(PATH_PICKLE_PRICE, "rb") as handle:
        pred_prices_dict = pickle.load(handle)
    # Squeeze and transpose the stored predictions, then label the columns
    # exactly like the real data so both correlation matrices align.
    pred_prices = pred_prices_dict["pred_prices"].squeeze().numpy()
    preds = pd.DataFrame(pred_prices.T, columns=reals.columns)
    preds_corr = preds.corr(numeric_only=True)
    mse = mean_squared_error(reals_corr, preds_corr)
    mses.append(mse)

In [None]:
# Find the epoch with the smallest correlation MSE (first one on ties).
# The epoch index comes from `mses` itself, so the epoch count is not
# hard-coded a second time here.
min_epoch, min_mse = min(enumerate(mses), key=lambda pair: pair[1])
min_epoch, min_mse

In [None]:
# MSE-per-epoch curve with the minimum highlighted in red.
mse_fig, mse_ax = plt.subplots()
mse_ax.plot(mses)
mse_ax.scatter(min_epoch, min_mse, c="r", label=f"Minimum: Epoch={min_epoch} MSE={round(min_mse, 2)}")
mse_ax.legend(loc="upper center", fontsize=15)
mse_ax.set_title("Average Correlation Distance - Prices")
mse_ax.set_xlabel("Epoch")
mse_fig.tight_layout()
# To export the figure: mse_fig.savefig("../storage/thesis-gan/cross_corr_dist_DowJones.png")
plt.show()
plt.close(mse_fig)

In [None]:
# Ticker labels used as panel titles in the price grid, in plotting order.
stock_names = (
    "AAPL AMGN AXP BA CAT CRM CSCO CVX DIS GS "
    "HD HON IBM INTC JNJ JPM KO MCD MMM MRK "
    "MSFT NKE PG TRV UNH V VZ WBA WMT"
).split()

In [None]:
# Epoch whose predictions are loaded and plotted in the cells below.
# NOTE(review): hard-coded rather than taken from `min_epoch` — confirm this is intentional.
chosen_epoch = 197

In [None]:
# Display the correlation MSE recorded for the chosen epoch.
mses[chosen_epoch]

In [None]:
# Load the generated prices for `chosen_epoch` and label the columns like the real data.
PATH_PICKLE_PRICE = DATA_PREDS + f"/preds_epoch={chosen_epoch}-target_price=mid_price-target_volume=None.pickle"
# NOTE(review): pickle.load can execute arbitrary code; only open files
# produced by trusted training runs.
with open(PATH_PICKLE_PRICE, "rb") as handle:
    pred_prices_dict = pickle.load(handle)
preds = pd.DataFrame(pred_prices_dict["pred_prices"].squeeze().numpy().T)
preds.columns = reals.columns

# Recompute the correlation matrix for this epoch. Without this, `preds_corr`
# would still hold the matrix from the final iteration of the per-epoch loop,
# and the correlation heatmaps further down would not describe `chosen_epoch`.
preds_corr = preds.corr(numeric_only=True)

In [None]:
# Convert both frames to plain ndarrays for positional slicing below.
# np.asarray leaves an existing ndarray unchanged, so re-running this cell is
# harmless, whereas `.to_numpy()` fails once the names are bound to ndarrays.
reals = np.asarray(reals)
preds = np.asarray(preds)

In [None]:
# Display both array shapes side by side so a mismatch between real and
# generated data is visible before slicing.
reals.shape, preds.shape

In [None]:
# x-coordinates for plotting: the first 390 steps form the observed history,
# every later step belongs to the continuation.
n_steps = len(reals)
history_indexes = np.arange(0, 390)
continuation_indexes = np.arange(390, n_steps)
history_indexes.shape, continuation_indexes.shape

In [None]:
# Split the real data at step 390 into history and continuation, and take the
# matching continuation from the predictions. Each piece is transposed so that
# a single series can be picked by its first-axis position.
history, reals_continuation = (segment.T for segment in np.split(reals, [390]))
preds_continuation = preds[390:].T
history.shape, reals_continuation.shape, preds_continuation.shape

In [None]:
# 5x6 grid of panels: one per stock, showing the observed history, the real
# continuation and the predicted continuation.
fig, ax = plt.subplots(5, 6)

# Proxy artists for a single figure-level legend (the panels carry no legends).
legend_elements = [
    Line2D([0], [0], color="C0", lw=2, label="Observed"),
    Line2D([0], [0], color="C1", lw=2, label="Real continuation"),
    Line2D([0], [0], color="C2", lw=2, label="Predicted continuation"),
]

# `ax.flat` walks the grid row by row, so `index` equals row * 6 + column.
for index, panel in enumerate(ax.flat):
    panel.set_xticklabels([])
    panel.set_yticklabels([])
    # The grid has more panels than stocks; leave any surplus panel empty.
    # The guard is derived from the list length instead of a hard-coded slot.
    if index >= len(stock_names):
        continue
    panel.set_title(f"{stock_names[index]}", fontsize=20)
    panel.plot(history_indexes, history[index], color="C0")
    panel.plot(continuation_indexes, reals_continuation[index], color="C1")
    panel.plot(continuation_indexes, preds_continuation[index], color="C2")

fig.suptitle("Prices", fontsize=24, y=1.04)
fig.legend(handles=legend_elements, loc="upper center", ncol=3, fontsize=15, bbox_to_anchor=(0.5, 1))
fig.tight_layout()
# plt.savefig("../storage/thesis-gan/prices_DowJones.png")
plt.show()
plt.close(fig)

# Correlations

In [None]:
# Shrink the global font size for the annotated heatmaps that follow.
# This changes the process-wide rcParams, so it affects every later figure.
plt.rcParams["font.size"] = 8

In [None]:
# Annotated heatmap of the real correlation matrix.
real_heatmap_ax = sb.heatmap(data=reals_corr, annot=True, fmt=".2f", cmap="Blues")
real_heatmap_ax.set_title("Real correlations")
plt.show()
plt.close()

In [None]:
# Annotated heatmap of the generated-data correlation matrix.
# NOTE(review): `preds_corr` is not computed in this cell; verify that the
# value assigned upstream belongs to the epoch you intend to show.
pred_heatmap_ax = sb.heatmap(data=preds_corr, annot=True, fmt=".2f", cmap="Blues")
pred_heatmap_ax.set_title("Pred correlations")
plt.show()
plt.close()

In [None]:
# Element-wise relative difference between real and generated correlations:
# (real - predicted) / real.
diffs_corr = reals_corr.sub(preds_corr).div(reals_corr)
diff_heatmap_ax = sb.heatmap(data=diffs_corr, annot=True, fmt=".2f", cmap="Blues")
diff_heatmap_ax.set_title("Difference correlations")
plt.show()
plt.close()