In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Global matplotlib defaults for every figure in this notebook:
# 16x9 canvas, serif font, enlarged base/label/title/tick text.
# Re-enable the "figure.dpi" entry for high-resolution exports.
plt.rcParams.update(
    {
        "figure.figsize": [16, 9],
        # "figure.dpi": 300,
        "font.size": 20,
        "axes.labelsize": 20,
        "axes.titlesize": 24,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "font.family": "serif",
    }
)

In [None]:
# Font sizes passed explicitly to the plotting cells below (they override the
# rcParams defaults). The trailing "alt" values are the smaller alternatives that
# were kept as comments (apparently earlier settings).
FONT_SIZE_TITLE_PLOT, FONT_SIZE_TITLE_AX = 48, 36  # alt: 40, 30
FONT_SIZE_LABEL, FONT_SIZE_TICKS = 30, 24  # alt: 24, 20
FONT_SIZE_LEGEND = 32  # alt: 28

In [None]:
# Project root, resolved relative to the kernel's working directory. It can be
# overridden via the THESIS_GAN_PROJECT_FOLDER environment variable, so the notebook
# runs from another location without editing this cell.
PROJECT_FOLDER = os.environ.get("THESIS_GAN_PROJECT_FOLDER", "PycharmProjects/thesis-gan")

In [None]:
# Tickers to plot, in the same order as the series in the loaded predictions.
# The full set is ["PEP", "KO", "NVDA", "KSU"]; currently only KO is selected.
stock_names = ["KO"]
n_stocks = len(stock_names)

In [None]:
# Identifier of the experiment run whose pickled outputs are loaded below;
# it is used as a directory name under storage/thesis-gan/.
RUN_ID_PRICE: str = "3m9c18s6"

In [None]:
# Pickle files for run RUN_ID_PRICE: the real price series, and the predicted
# prices without / with the perturbation.
PATH_PICKLE_PRICE_REALS, PATH_PICKLE_PRICE_NOPER, PATH_PICKLE_PRICE_PER = (
    f"{PROJECT_FOLDER}/storage/thesis-gan/{RUN_ID_PRICE}/reals.pickle",
    f"{PROJECT_FOLDER}/storage/thesis-gan/{RUN_ID_PRICE}/perturbations/no_perturbation.pickle",
    f"{PROJECT_FOLDER}/storage/thesis-gan/{RUN_ID_PRICE}/perturbations/perturbation.pickle",
)

In [None]:
def _unpickle(path):
    """Open ``path`` in binary mode and return the first pickled object it contains.

    Security: ``pickle.load`` can execute arbitrary code, so only point this at
    files produced by this project's own runs.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


dict_reals = _unpickle(PATH_PICKLE_PRICE_REALS)
dict_no_per = _unpickle(PATH_PICKLE_PRICE_NOPER)
dict_per = _unpickle(PATH_PICKLE_PRICE_PER)

In [None]:
# Predicted prices without / with the perturbation, converted with `.numpy()`
# (presumably the pickles hold torch tensors — TODO confirm), with singleton axes dropped.
price_no_per = dict_no_per["pred_prices"].numpy().squeeze()
price_per = dict_per["pred_prices"].numpy().squeeze()
# Both predictions are drawn against the same continuation x-range (built from
# len(price_per)), so fail early with a clear message if they disagree.
assert price_no_per.shape == price_per.shape, (
    f"prediction shapes differ: {price_no_per.shape} vs {price_per.shape}"
)
# Real series truncated to the prediction length (only used by the optional "Real" line).
price_reals = dict_reals["prices"].squeeze()[: len(price_per)]

In [None]:
# Steps [0, N_HISTORY_STEPS) are plotted as the observed history; the rest of the
# predicted series is the continuation. 390 = 6.5 h x 60 min — presumably one
# trading session of 1-minute bars (TODO confirm).
N_HISTORY_STEPS = 390
history_indexes = np.arange(N_HISTORY_STEPS)
continuation_indexes = np.arange(N_HISTORY_STEPS, len(price_per))
history_indexes.shape, continuation_indexes.shape

In [None]:
# Single-stock figure: observed history followed by the predicted continuations
# without and with the perturbation, saved to disk.
fig, ax = plt.subplots(1, 1, figsize=(16, 9))

split_step = len(history_indexes)  # first continuation step (390)
plot_stock = stock_names[0]

ax.plot(history_indexes, price_no_per[:split_step], color="C0", label="Observed")
# Optional: overlay the real continuation for comparison.
# ax.plot(continuation_indexes, price_reals[split_step:], color="C1", label="Real")
ax.plot(continuation_indexes, price_no_per[split_step:], color="C2", label="W/o perturbation")
ax.plot(continuation_indexes, price_per[split_step:], color="C3", label="W/ perturbation")
# NOTE(review): fixed markers at steps 1140 and 1290 — presumably the bounds of
# the perturbation window; confirm against the perturbation config.
ax.axvline(x=1140, color="r")
ax.axvline(x=1290, color="r")

ax.set_title(plot_stock, fontsize=FONT_SIZE_TITLE_AX)
ax.set_xlabel("Steps", fontsize=FONT_SIZE_LABEL)
ax.set_ylabel("Price ($)", fontsize=FONT_SIZE_LABEL, rotation=90)
ax.xaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
ax.yaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
# Keep the automatic ticks except the first two and the last, and add one at the split step.
ax.set_xticks(list(ax.get_xticks()[2:-1]) + [split_step])

fig.suptitle("Prices", fontsize=FONT_SIZE_TITLE_PLOT, y=1)
fig.legend(loc="upper center", ncol=4, fontsize=FONT_SIZE_LEGEND, frameon=False, bbox_to_anchor=(0.5, 0.97))
fig.tight_layout()

# Save this figure explicitly (not pyplot's "current figure"); the file name
# follows the plotted ticker and the directory is created if missing.
output_dir = f"{PROJECT_FOLDER}/plot_finali/multistock/perturbations"
os.makedirs(output_dir, exist_ok=True)
fig.savefig(f"{output_dir}/perturbed_{plot_stock}.png")
plt.close(fig)

In [None]:
# Multi-stock grid: one panel per entry of `stock_names`, in the same style as the
# single-stock figure. Time is on axis 0 of the prediction arrays (as everywhere
# above); if they are 2-D, each column is presumably one stock in `stock_names`
# order — TODO confirm.
split_step = len(history_indexes)  # first continuation step (390)

# One row per stock: a 1-D (T,) series becomes (1, T), a (T, n) array becomes (n, T).
# New names are used so the single-stock globals price_no_per / price_per stay intact.
prices_no_per = price_no_per.reshape(len(price_no_per), -1).T
prices_per = price_per.reshape(len(price_per), -1).T
assert len(prices_no_per) == n_stocks, (
    f"{len(prices_no_per)} predicted series but {n_stocks} stock names"
)

# Up to two panels per row; size the grid to the number of stocks instead of a fixed 2x2.
n_cols = min(2, n_stocks)
n_rows = -(-n_stocks // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 9), squeeze=False)
axes = axes.ravel()

for panel_idx, (ax, stock_name, series_no_per, series_per) in enumerate(
    zip(axes, stock_names, prices_no_per, prices_per)
):
    # Label only the first panel so the shared figure legend lists each line once.
    add_label = panel_idx == 0
    ax.plot(history_indexes, series_no_per[:split_step], color="C0", label="Observed" if add_label else None)
    ax.plot(continuation_indexes, series_no_per[split_step:], color="C1", label="W/o perturbation" if add_label else None)
    ax.plot(continuation_indexes, series_per[split_step:], color="C2", label="W/ perturbation" if add_label else None)
    # NOTE(review): the single-stock figure also marks step 1290; confirm whether it belongs here too.
    ax.axvline(x=1140, color="r")

    ax.set_title(stock_name, fontsize=FONT_SIZE_TITLE_AX)
    ax.set_xlabel("Steps", fontsize=FONT_SIZE_LABEL)
    ax.set_ylabel("Price ($)", fontsize=FONT_SIZE_LABEL, rotation=90)
    ax.xaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
    ax.yaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
    # Keep the automatic ticks except the first two and the last, and add one at the split step.
    ax.set_xticks(list(ax.get_xticks()[2:-1]) + [split_step])

# Hide grid cells left over when n_stocks does not fill the last row.
for unused_ax in axes[n_stocks:]:
    unused_ax.set_visible(False)

fig.suptitle("Prices", fontsize=FONT_SIZE_TITLE_PLOT, y=1)
fig.legend(loc="upper center", ncol=3, fontsize=FONT_SIZE_LEGEND, frameon=False, bbox_to_anchor=(0.5, 0.97))
fig.tight_layout()
# fig.savefig(f"{PROJECT_FOLDER}/plot_finali/multistock/perturbations/perturbed_multistock.pdf")
plt.show()
plt.close(fig)