In [None]:
import numpy as np
from tqdm import tqdm
from itertools import product
from select_deepc.deepc_utils import DeePCDims, load_data_from_folder
from select_deepc.deepc_controller import IsoMapEmbeddedDistances
import matplotlib.pyplot as plt

In [None]:
# Global matplotlib text settings for every figure created afterwards:
# LaTeX typesetting, serif family, font size 10.
# NOTE(review): "Times" is placed in the *sans-serif* font list although the
# selected family is serif — confirm whether "font.serif" was intended.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.sans-serif"] = "Times"
plt.rcParams["font.size"] = 10

In [None]:
# Load the three rocket trajectory sets in a fixed order. Each entry is the
# (folder, name) pair passed to load_data_from_folder; the name is presumably
# stored as the set's `dataset_name` — confirm against load_data_from_folder.
(
    iid_trajectories,
    random_walk_trajectories,
    random_walk_trajectories_2,
) = (
    load_data_from_folder(folder, name)
    for folder, name in (
        ("data/rocket/dataset_iid", "iid"),
        ("data/rocket/dataset_random_walk", "random_walk"),
        ("data/rocket/dataset_random_walk_2", "random_walk_2"),
    )
)

def isomap_test(trajectory_set):
    """Sweep Isomap settings and record the reconstruction error of every fit.

    One ``IsoMapEmbeddedDistances`` embedder is built on ``trajectory_set`` for
    each combination of neighbor count (10 log-spaced integers from ``2**1`` to
    ``2**6.6``) and embedding dimension (4 log-spaced integers from ``2**4`` to
    ``2**7``). The neighbor count is the outer loop variable.

    Args:
        trajectory_set: Trajectory data handed unchanged to
            ``IsoMapEmbeddedDistances``.

    Returns:
        tuple: ``(settings, errors)`` — two lists of equal length. ``settings``
        holds ``(n_components, n_neighbors)`` tuples and ``errors`` holds the
        matching ``reconstruction_error()`` value taken from the last element
        of the last entry in ``embedder._embedding.steps``.
    """
    neighbor_grid = np.logspace(1, 6.6, 10, base=2, dtype=int)
    component_grid = np.logspace(4, 7, 4, base=2, dtype=int)

    tried_settings = []
    measured_errors = []

    sweep = tqdm(
        product(neighbor_grid, component_grid),
        # One progress tick per (neighbors, components) combination.
        total=neighbor_grid.size * component_grid.size,
    )
    for neighbor_count, component_count in sweep:
        embedder = IsoMapEmbeddedDistances(
            trajectory_set,
            DeePCDims(1, 30, 6, 3),
            n_components=component_count,
            n_neighbors=neighbor_count,
        )
        # The error is reported by the final stage of the embedding's steps.
        last_step = embedder._embedding.steps[-1]
        final_stage = last_step[-1]

        tried_settings.append((component_count, neighbor_count))
        measured_errors.append(final_stage.reconstruction_error())

    return tried_settings, measured_errors


def plot_isomap(fig, ax, settings, error, dataset):
    """Plot reconstruction error against neighbor count, one curve per dimension.

    The data is taken exclusively from the ``settings`` and ``error``
    arguments, so the function does not depend on any notebook-level variable
    names.

    Args:
        fig: Unused by this function; kept so existing call sites keep working.
        ax: Axes-like object. Receives one ``loglog`` call per distinct
            embedding dimension, followed by a ``set_title`` call.
        settings: Iterable of ``(n_components, n_neighbors)`` pairs, e.g. a
            list of tuples or an array with two columns.
        error: Iterable of reconstruction errors, aligned element-wise with
            ``settings``.
        dataset: Object whose ``dataset_name`` attribute is used as the title.

    Raises:
        ValueError: If ``settings`` and ``error`` differ in length.
    """
    pairs = [(components, neighbors) for components, neighbors in settings]
    errors = list(error)
    if len(pairs) != len(errors):
        raise ValueError(
            f"settings has {len(pairs)} entries but error has {len(errors)}"
        )

    palette = ["navy", "steelblue", "royalblue", "lightsteelblue"]
    distinct_dims = sorted({components for components, _ in pairs})

    # zip stops at the shorter sequence: only as many dimensions as there are
    # palette colors are drawn.
    for curr_n_components, color in zip(distinct_dims, palette):
        curve = [
            (neighbors, err)
            for (components, neighbors), err in zip(pairs, errors)
            if components == curr_n_components
        ]
        ax.loglog(
            [neighbors for neighbors, _ in curve],
            [err for _, err in curve],
            label=f"Dim {int(curr_n_components)}",
            color=color,
        )
    ax.set_title(dataset.dataset_name)


# The CSV files read below were presumably produced by the sweep kept here for
# reference.
# NOTE(review): this snippet writes to data/isomap_dataset_test/ while the
# loader below reads from data/rocket/isomap_dataset_test/ — confirm which
# location is current before re-enabling it.
# for dataset in [random_walk_trajectories, random_walk_trajectories_2, iid_trajectories]:
#     isomap_settings, reconstruction_error = isomap_test(dataset)
#     np.savetxt(f"data/isomap_dataset_test/{dataset.dataset_name}_normalized_settings.csv", np.array(isomap_settings), delimiter=",")
#     np.savetxt(f"data/isomap_dataset_test/{dataset.dataset_name}_normalized_error.csv", np.array(reconstruction_error), delimiter=",")

# Panel titles keyed by the dataset name used in the CSV file names. Looking
# the title up here (instead of renaming the dataset object) keeps the file
# lookup valid when this code is run again on the same objects.
panel_titles = {"iid": "IID", "random_walk": "Random Walk"}

fig, axs = plt.subplots(1, 2, figsize=(5.6, 3.5), sharex=True, sharey=True)
for idx, dataset in enumerate([random_walk_trajectories, iid_trajectories]):
    isomap_settings = np.loadtxt(
        f"data/rocket/isomap_dataset_test/{dataset.dataset_name}_normalized_settings.csv",
        delimiter=",",
    )
    reconstruction_error = np.loadtxt(
        f"data/rocket/isomap_dataset_test/{dataset.dataset_name}_normalized_error.csv",
        delimiter=",",
    )

    plot_isomap(fig, axs[idx], isomap_settings, reconstruction_error, dataset)
    # Replace the raw dataset name with its display label; unknown names are
    # shown unchanged.
    axs[idx].set_title(
        panel_titles.get(dataset.dataset_name, dataset.dataset_name)
    )

fig.suptitle("Isomap Embedding Error")
axs[0].set_ylabel("Reconstruction Error")
# Raw string: the backslash must reach LaTeX untouched so that "\#" renders
# as a literal "#"; in a normal string "\#" is an invalid escape sequence.
for panel in axs:
    panel.set_xlabel(r"\# Graph Neighbors")

axs[0].legend(loc="upper left")
fig.tight_layout()
fig.savefig(
    "figures/isomap_reconstruction_error/isomap_error_combined.pdf",
    bbox_inches="tight",
)
plt.show()