In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sb
import tqdm
from matplotlib.lines import Line2D
from sklearn.metrics import mean_squared_error

In [None]:
# Global matplotlib defaults applied to every figure in this notebook:
# a 16:9 canvas, serif font, and enlarged text/tick sizes.
plt.rcParams.update(
    {
        "figure.figsize": [16, 9],
        # "figure.dpi": 300,
        "font.size": 20,
        "axes.labelsize": 20,
        "axes.titlesize": 24,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "font.family": "serif",
    }
)

In [None]:
FONT_SIZE_TITLE_PLOT = 48  # 40
FONT_SIZE_TITLE_AX = 36  # 30
FONT_SIZE_LABEL = 30  # 24
FONT_SIZE_TICKS = 24  # 20
FONT_SIZE_LEGEND = 32  # 28

In [None]:
# Root of the thesis-gan project. The default is a relative path, so it is
# resolved against the kernel's working directory; set the environment variable
# THESIS_GAN_PROJECT_FOLDER to point elsewhere without editing the notebook.
PROJECT_FOLDER = os.environ.get("THESIS_GAN_PROJECT_FOLDER", "PycharmProjects/thesis-gan")

In [None]:
# Per-epoch validation cross-correlation distance, averaged over every
# unordered pair of the four stocks (KO, KSU, NVDA, PEP -> 6 pairs).
# Each pair must be selected exactly once: pandas returns a duplicated column
# label twice, and mean(axis=1) would then over-weight that pair.
CROSS_CORR_RUN_PREFIX = "New data, multistock, prices, conv - val_corr_dist/"
CROSS_CORR_PAIRS = [
    ("KO", "KSU"),
    ("KO", "NVDA"),
    ("KO", "PEP"),
    ("NVDA", "KSU"),
    ("PEP", "KSU"),
    ("PEP", "NVDA"),
]
assert len({frozenset(pair) for pair in CROSS_CORR_PAIRS}) == len(CROSS_CORR_PAIRS), "duplicate stock pair"

cross_corr_columns = [
    f"{CROSS_CORR_RUN_PREFIX}{first}_mid_price-{second}_mid_price" for first, second in CROSS_CORR_PAIRS
]

df_cross_corr_prices = pd.read_csv(f"{PROJECT_FOLDER}/data/wandb_export_2023-04-17T15_51_58.484+02_00.csv")
df_cross_corr_prices = df_cross_corr_prices[cross_corr_columns]
# Row-wise mean across pairs -> one value per logged step (NaNs are skipped).
cross_corr_distance_price = df_cross_corr_prices.mean(axis=1).values

In [None]:
# Single-panel line plot of the per-epoch average cross-correlation distance
# stored in cross_corr_distance_price.
fig = plt.figure(1, figsize=(16, 9))
ax = fig.gca()  # the same axes pyplot's plt.plot would draw into

ax.plot(cross_corr_distance_price, color="C8", label="Price")
ax.set_xlabel("Epoch", fontsize=FONT_SIZE_LABEL)
ax.set_ylabel(r"$MSE(\rho(\cdot, \cdot), \rho(\cdot, \cdot))$", fontsize=FONT_SIZE_LABEL, rotation=90)
for axis in (ax.xaxis, ax.yaxis):
    axis.set_tick_params(labelsize=FONT_SIZE_TICKS)

# Figure-level title and a frameless legend anchored just below it.
fig.suptitle("Average Cross-Correlation Distance", fontsize=FONT_SIZE_TITLE_AX, y=1)
fig.legend(loc="upper center", ncol=2, fontsize=FONT_SIZE_LEGEND, frameon=False, bbox_to_anchor=(0.5, 0.97))
fig.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# 2x2 grid, one panel per stock: the observed history, the real continuation
# and the synthetic continuation, with a red vertical line where history ends.
# NOTE(review): stock_names, price_real_, price_pred_, history_indexes and
# continuation_indexes are not defined in the visible cells; the cell that
# builds them must run first. Assumes len(history_indexes) == HISTORY_STEPS
# and that continuation_indexes covers real[HISTORY_STEPS:] — TODO confirm.
HISTORY_STEPS = 390  # observed steps before the continuation starts

fig, axes = plt.subplots(2, 2, figsize=(16, 9))
axes = axes.ravel()

n_panels = 0
for panel_idx, (ax, stock_name, real, synthetic) in enumerate(zip(axes, stock_names, price_real_, price_pred_)):
    # Only the first panel contributes labels, so the shared figure legend
    # lists each series once.
    add_label = panel_idx == 0
    ax.plot(history_indexes, real[:HISTORY_STEPS], color="C0", label="Observed" if add_label else None)
    ax.plot(
        continuation_indexes, real[HISTORY_STEPS:], color="C1", label="Real continuation" if add_label else None
    )
    ax.plot(
        continuation_indexes,
        synthetic[HISTORY_STEPS:],
        color="C2",
        label="Synthetic continuation" if add_label else None,
    )
    ax.axvline(x=HISTORY_STEPS, color="r")

    ax.set_title(stock_name, fontsize=FONT_SIZE_TITLE_AX)
    ax.set_xlabel("Steps", fontsize=FONT_SIZE_LABEL)
    ax.set_ylabel("Price ($)", fontsize=FONT_SIZE_LABEL, rotation=90)
    ax.xaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
    ax.yaxis.set_tick_params(labelsize=FONT_SIZE_TICKS)
    # Drop the outermost auto ticks and always mark the split point; the set
    # avoids a duplicate tick if the locator already placed one there.
    ax.set_xticks(sorted(set(ax.get_xticks()[1:-1]) | {HISTORY_STEPS}))

    n_panels += 1

# Hide grid cells left empty when fewer than four stocks are plotted.
for unused_ax in axes[n_panels:]:
    unused_ax.set_visible(False)

fig.suptitle("Prices", fontsize=FONT_SIZE_TITLE_PLOT, y=1)
fig.legend(loc="upper center", ncol=3, fontsize=FONT_SIZE_LEGEND, frameon=False, bbox_to_anchor=(0.5, 0.97))
fig.tight_layout()
# plt.savefig(f"{PROJECT_FOLDER}/plot_finali/multistock/prices.pdf")
plt.show()
plt.close(fig)