In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from sklearn.metrics import mean_squared_error

In [None]:
# Notebook-wide matplotlib defaults: a 16x9 canvas and large serif text so figures
# are legible when embedded in the thesis. Individual plots may still override sizes.
plt.rcParams.update(
    {
        "figure.figsize": [16, 9],
        "font.size": 20,
        "axes.labelsize": 20,
        "axes.titlesize": 24,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "font.family": "serif",
    }
)

In [None]:
# Root of the thesis-gan project. All storage/ and plot_finali/ paths below are built from it.
# The default is relative to the current working directory; set THESIS_GAN_DIR to point at
# the project from anywhere else without editing the notebook.
PROJECT_FOLDER = os.environ.get("THESIS_GAN_DIR", "PycharmProjects/thesis-gan")

In [None]:
# Explicit matplotlib font sizes passed by the plotting cells (overriding rcParams).
# The trailing comments record the values used before they were enlarged.
(
    FONT_SIZE_TITLE_PLOT,  # figure suptitle; previously 40
    FONT_SIZE_TITLE_AX,  # per-subplot titles; previously 30
    FONT_SIZE_LABEL,  # axis labels; previously 24
    FONT_SIZE_TICKS,  # tick labels; previously 20
    FONT_SIZE_LEGEND,  # figure legend; previously 28
) = (48, 36, 30, 24, 32)

In [None]:
# Ticker symbols, one per subplot. The plotting cell pairs them positionally with the rows of
# the price arrays. NOTE(review): assumes the pickled rows use this same order — confirm.
stock_names = "KO PEP NVDA KSU".split()

In [None]:
# Real price series that every synthetic run is compared against.
PATH_PICKLE_REAL_PRICE = f"{PROJECT_FOLDER}/storage/thesis-gan/13v3dpxg/reals.pickle"

# NOTE: pickle.load can execute arbitrary code; only load locally produced, trusted files.
with open(PATH_PICKLE_REAL_PRICE, "rb") as reals_file:
    real_price_dict = pickle.load(reals_file)

real_prices = real_price_dict["prices"]
real_prices.shape

In [None]:
# Collect, for every run directory under diversity_val/, the name of its prediction pickle
# (the first file whose name starts with "preds"). RUN_ID_PRICES and FILE_NAMES stay index-aligned,
# because the next cell zips them together.
#  - Runs without a prediction file are skipped. Previously they appended [] and the loader
#    then tried to open ".../<run>/[]".
#  - Non-directory entries (e.g. .DS_Store) are ignored.
#  - Directory and file names are sorted so the run order, and hence the plot colours,
#    is deterministic.
RUNS_DIR = f"{PROJECT_FOLDER}/storage/thesis-gan/diversity_val"
RUN_ID_PRICES = list()
FILE_NAMES = list()
for run_id in sorted(os.listdir(RUNS_DIR)):
    run_dir = os.path.join(RUNS_DIR, run_id)
    if not os.path.isdir(run_dir):
        continue
    pred_file_names = sorted(f_name for f_name in os.listdir(run_dir) if f_name.startswith("preds"))
    if not pred_file_names:
        continue
    RUN_ID_PRICES.append(run_id)
    FILE_NAMES.append(pred_file_names[0])

In [None]:
# Load each run's synthetic prices and keep only arrays of the expected (4, 9360) shape.
# Runs with any other shape are dropped without notice. `l` is consumed by the stacking cell.
# NOTE: pickle.load can execute arbitrary code; only load locally produced, trusted files.
l = list()
for run_id, pred_file in zip(RUN_ID_PRICES, FILE_NAMES):
    pred_path = f"{PROJECT_FOLDER}/storage/thesis-gan/diversity_val/{run_id}/{pred_file}"
    with open(pred_path, "rb") as pred_handle:
        candidate = pickle.load(pred_handle)["pred_prices"]
    if candidate.shape != (4, 9360):
        continue
    l.append(candidate)

In [None]:
# Stack the per-run (4, 9360) arrays along a new leading "run" axis -> (n_runs, 4, 9360).
# np.stack raises ValueError if no run survived the shape filter.
all_pred_prices = np.stack(tuple(l), axis=0)
all_pred_prices.shape

In [None]:
# Score each synthetic run by how well it reproduces the cross-stock correlation structure:
# the MSE between the correlation matrix of the real prices and that of the run
# (np.corrcoef treats each row, i.e. each stock, as one variable).
MSE_THRESHOLD = 0.23  # runs with MSE <= this are kept as "good"

corrcoef_real = np.corrcoef(real_prices)
mses = list()
for pred_prices in all_pred_prices:
    corrcoef_pred = np.corrcoef(pred_prices)
    if np.isnan(corrcoef_pred).any():
        # A constant synthetic series gives NaN correlations, which mean_squared_error rejects
        # with a ValueError. Score such a run as infinitely bad so it is filtered out instead.
        mses.append(np.inf)
        continue
    mses.append(mean_squared_error(corrcoef_real, corrcoef_pred))
mses = np.asarray(mses)

good_indexes = np.where(mses <= MSE_THRESHOLD)
good_pred_prices = all_pred_prices[good_indexes]
good_mses = mses[good_indexes]
good_mses

In [None]:
# Divisor applied to the stored prices before plotting in dollars.
# NOTE(review): assumes the pickles store prices in 1/10000-dollar units — confirm.
PRICE_SCALE = 10000

price_real = real_prices / PRICE_SCALE
# Rebuild from the un-scaled selection rather than rescaling good_pred_prices in place,
# so re-running this cell does not divide by PRICE_SCALE twice or transpose back.
good_pred_prices = all_pred_prices[good_indexes] / PRICE_SCALE
# (n_good, 4, T) -> (4, n_good, T): the plot loop iterates per stock, then per synthetic run.
good_pred_prices = np.transpose(good_pred_prices, axes=[1, 0, 2])
price_real.shape, good_pred_prices.shape

In [None]:
# x-axis positions for the plot: steps [0, 390) are the observed history,
# steps [390, T) the continuation where real and synthetic paths are compared.
history_end = 390
history_indexes = np.arange(0, history_end)
continuation_indexes = np.arange(history_end, price_real.shape[1])
history_indexes.shape, continuation_indexes.shape

In [None]:
# 2x2 grid, one panel per stock: observed history, real continuation, and every retained
# synthetic continuation. The figure is written to plot_finali/multistock/diversity_prices.pdf.
fig, axes = plt.subplots(2, 2, figsize=(16, 9))
axes = axes.ravel()

# Step at which the observed history ends (derived from the index cell instead of repeating 390).
history_len = len(history_indexes)

# C0/C1 are reserved for observed/real, and synthetic paths start at C3. Matplotlib's "CN"
# colour specs wrap modulo the colour-cycle length, so a plain f"C{i+3}" silently reuses
# C0/C1 once there are enough synthetic paths. Wrap within C3.. instead.
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["k"])
n_synthetic_colors = max(len(cycle_colors) - 3, 1)

# Only the first panel's lines get labels, so fig.legend lists each entry once.
add_label = True
for ax, stock_name, real, good_pred_price in zip(axes, stock_names, price_real, good_pred_prices):
    ax.plot(history_indexes, real[:history_len], color="C0", label="Observed" if add_label else None)
    ax.plot(continuation_indexes, real[history_len:], color="C1", label="Real" if add_label else None)
    # One line per retained synthetic run for this stock (continuation segment only).
    for i, synthetic in enumerate(good_pred_price):
        ax.plot(
            continuation_indexes,
            synthetic[history_len:],
            color=f"C{3 + i % n_synthetic_colors}",
        )
    # Vertical marker at the history/continuation split.
    ax.axvline(x=history_len, color="r")

    ax.set_title(f"{stock_name}", fontsize=FONT_SIZE_TITLE_AX)
    ax.set_xlabel("Steps", fontsize=FONT_SIZE_LABEL)
    ax.set_ylabel("Price ($)", fontsize=FONT_SIZE_LABEL, rotation=90)
    ax.xaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
    ax.yaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
    # Keep the automatic ticks minus the first two and the last, and add a tick at the split.
    ax.set_xticks(list(ax.get_xticks()[2:-1]) + [history_len])

    add_label = False

# fig.suptitle("Prices - Diversity", fontsize=FONT_SIZE_TITLE_PLOT, y=1)
fig.legend(
    loc="upper center",
    ncol=2,
    fontsize=FONT_SIZE_LEGEND,
    frameon=False,
    # bbox_to_anchor=(0.5, 1.06)
)
# Reserve the top 5% of the figure for the legend placed at "upper center".
fig.tight_layout(rect=[0, 0, 1, 0.95])

output_path = PROJECT_FOLDER + "/plot_finali/multistock/diversity_prices.pdf"
# Create the output directory on a fresh checkout instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path)
# plt.show()
plt.close(fig)