In [None]:
%load_ext autoreload
# Reload edited modules (e.g. the dmpe package) automatically before executing each cell.
%autoreload 2
# NOTE(review): line_profiler (%lprun) is not used by any visible cell of this notebook.
%reload_ext line_profiler

In [None]:
# NOTE(review): itertools, partial, time, tqdm, px and go are not used by the active cells of
# this notebook (itertools only appears in commented-out code).
import itertools
import os
# These must be set before jax is imported (jax is imported in a later cell):
# expose only GPU 0 to CUDA, and stop XLA from preallocating most of the GPU memory up front.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render all figure text with LaTeX (requires a local LaTeX installation).
mpl.rcParams['text.usetex'] = True
# NOTE(review): a 10 pt font scaled by 2.54 presumably compensates for figure widths given in cm
# but passed to matplotlib's `figsize` (which expects inches) — confirm.
mpl.rcParams.update({'font.size': 10 * 2.54})
# Load bm (bold math symbols) and amsmath for use in axis labels.
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax.numpy as jnp

In [None]:
import exciting_environments as excenvs

from dmpe.utils.density_estimation import select_bandwidth
from dmpe.utils.signals import aprbs
from dmpe.models import NeuralEulerODE
from dmpe.algorithms import excite_with_dmpe
from dmpe.related_work.algorithms import excite_with_sGOATS
from dmpe.utils.density_estimation import DensityEstimate
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

In [None]:
# Load the iGOATS PMSM experiment. The results root can be overridden with the DMPE_RESULTS_DIR
# environment variable instead of editing this user-specific default path.
igoats_results_path = pathlib.Path(
    os.environ.get("DMPE_RESULTS_DIR", "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/results")
) / "igoats" / "pmsm"

# NOTE(review): assumes get_experiment_ids returns ids in chronological order, so [-1] is the
# most recent run — confirm.
igoats_identifier = get_experiment_ids(igoats_results_path)[-1]
# Algorithm-prefixed names so the DMPE cell below does not silently overwrite these results.
igoats_params, igoats_observations, igoats_actions, igoats_model = load_experiment_results(
    exp_id=igoats_identifier, results_path=igoats_results_path, model_class=None
)

In [None]:
# Load the DMPE PMSM experiment. The results root can be overridden with the DMPE_RESULTS_DIR
# environment variable instead of editing this user-specific default path.
dmpe_results_path = pathlib.Path(
    os.environ.get("DMPE_RESULTS_DIR", "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/results")
) / "dmpe" / "pmsm"

# NOTE(review): assumes get_experiment_ids returns ids in chronological order, so [-1] is the
# most recent run — confirm.
dmpe_identifier = get_experiment_ids(dmpe_results_path)[-1]
# Algorithm-prefixed names so these results do not clobber the iGOATS results loaded above.
dmpe_params, dmpe_observations, dmpe_actions, dmpe_model = load_experiment_results(
    exp_id=dmpe_identifier, results_path=dmpe_results_path, model_class=None
)


In [None]:
# what should the qualitative plots look like?

In [None]:
# NOTE(review): presumably the sampling period in seconds — confirm; not used by any active cell.
tau = 1e-4
# Axis labels for the stacked features (d/q currents followed by d/q voltages), written in the
# same notation as the labels passed to the plotting calls in this notebook.
labels = [r"$\tilde{i}_d$", r"$\tilde{i}_q$", r"$\tilde{v}_d$", r"$\tilde{v}_q$"]

# Figure widths. NOTE(review): these values look like column widths in cm (18.2 cm / 8.89 cm),
# but they are passed to matplotlib's `figsize`, which is in inches — confirm this is intended.
full_column_width = 18.2
half_column_width = 8.89
# Backward-compatible alias for the original (misspelled) name used by the plotting code.
half_colmun_width = half_column_width

In [None]:
# Stack states and inputs into one feature matrix per algorithm: observation columns first,
# followed by action columns.
igoats_data = jnp.concatenate([igoats_observations, igoats_actions], axis=-1)
# The DMPE trajectory has one more observation than actions, so the last observation is dropped
# to align both sequences along the time axis.
dmpe_data = jnp.concatenate([dmpe_observations[:-1], dmpe_actions], axis=-1)
assert dmpe_data.shape[-1] == igoats_data.shape[-1], "DMPE and iGOATS data must have the same features"

all_data = [dmpe_data, igoats_data]
# Raw strings: "\m" is an invalid escape sequence in a regular string literal
# (SyntaxWarning on Python >= 3.12); the resulting values are unchanged.
algo_names = [r"$\mathrm{DMPE}$", r"$\mathrm{iGOATS}$"]

In [None]:
# def plot_feature_combinations(all_data, labels, algo_names):

#     n_features = all_data[0].shape[-1]
#     for data in all_data:
#         assert data.shape[-1] == n_features

#     n_algos = len(all_data)    

#     all_combinations = list(itertools.combinations(np.arange(n_features), 2))
#     n_combinations = len(all_combinations)
    
#     # fig, axs = plt.subplots(nrows=n_algos, ncols=n_combinations, figsize=(full_column_width, full_column_width/3), sharex=True)
#     fig, axs = plt.subplots(nrows=n_combinations, ncols=n_algos, figsize=(full_column_width/3, full_column_width), sharey=True)
    
#     for comb_idx, (i, j) in enumerate(all_combinations):      

#         density_estimates = []
        
#         for algo_idx, (algo_name, data) in enumerate(zip(algo_names, all_data)):
#             density_estimate = DensityEstimate.from_dataset(
#                 jnp.concatenate([data[..., i][..., None], data[..., j][..., None]], axis=-1)[None],
#                 points_per_dim=100,
#                 bandwidth=0.05,
#             )
#             density_estimates.append(density_estimate)

#         stacked_p = jnp.concatenate([density_estimate.p for density_estimate in density_estimates], axis=0)
#         maximum_p_value = jnp.max(stacked_p)
#         levels = np.linspace(0, maximum_p_value, 50).round(4)

#         print(maximum_p_value)
        
#         for algo_idx, (algo_name, density_estimate) in enumerate(zip(algo_names, density_estimates)):
            
#             p_est = density_estimate.p
#             x = density_estimate.x_g

#             grid_len_per_dim = int(np.sqrt(x.shape[0]))
#             x_plot = x.reshape((grid_len_per_dim, grid_len_per_dim, 2))

#             cax = axs[comb_idx, algo_idx].contourf(
#             #cax = axs[algo_idx, comb_idx].contourf(
#                 x_plot[..., 0],
#                 x_plot[..., 1],
#                 p_est.reshape(x_plot.shape[:-1]),
#                 antialiased=False,
#                 levels=levels,
#                 alpha=1.0,
#                 cmap=plt.cm.coolwarm,
#             )
        
#             axs[comb_idx, 0].set_ylabel(labels[j])
#             axs[comb_idx, algo_idx].set_xlabel(labels[i])
#         #fig.colorbar(cax)
#     plt.subplots_adjust(hspace=0.02)
    
#     plt.tight_layout(pad=0.05)

In [None]:
# fig = plot_feature_combinations(all_data, labels, algo_names)
# #mpl.rcParams.update({'figure.autolayout': True})

# plt.savefig("pmsm_qualitative_comparison.pdf")

In [None]:
# from dmpe.evaluation.plotting_utils import plot_feature_combinations

In [None]:
def get_contour_levels(all_data, n_levels=50, points_per_dim=100, bandwidth=0.05, margin=0.03):
    """Compute shared contour levels for every feature pair across several data sets.

    For each ordered feature pair ``(i, j)``, a 2D density estimate of columns ``i`` and ``j`` is
    built for every data set. The levels for that pair are ``n_levels`` evenly spaced values from 0
    to the largest density value over all data sets plus ``margin``. Contour plots of different
    data sets drawn with these levels therefore share one color scale.

    Args:
        all_data: Non-empty sequence of arrays whose last axis indexes features (2D
            ``(n_samples, n_features)`` in this notebook); all must have the same number of features.
        n_levels: Number of contour levels per feature pair.
        points_per_dim: Grid resolution passed to ``DensityEstimate.from_dataset``.
        bandwidth: Kernel bandwidth passed to ``DensityEstimate.from_dataset``.
        margin: Headroom added above the largest density value (presumably so the peak falls
            strictly inside the top level).

    Returns:
        ``np.ndarray`` of shape ``(n_features, n_features, n_levels)``; entry ``[i, j]`` holds the
        levels for feature ``i`` (x-axis) against feature ``j`` (y-axis).

    Raises:
        ValueError: If ``all_data`` is empty or the data sets differ in their number of features.
    """
    if len(all_data) == 0:
        raise ValueError("all_data must contain at least one data set")
    n_features = all_data[0].shape[-1]
    if any(data.shape[-1] != n_features for data in all_data):
        raise ValueError("all data sets must have the same number of features")

    all_levels = np.zeros((n_features, n_features, n_levels))

    for i in range(n_features):
        for j in range(n_features):
            # Density values of the (i, j) feature pair for every data set.
            densities = [
                DensityEstimate.from_dataset(
                    # (n_samples, 2) pair of columns with a leading batch axis of size 1.
                    jnp.stack([data[..., i], data[..., j]], axis=-1)[None],
                    points_per_dim=points_per_dim,
                    bandwidth=bandwidth,
                ).p
                for data in all_data
            ]
            maximum_p_value = jnp.max(jnp.concatenate(densities, axis=0)) + margin
            all_levels[i, j] = np.linspace(0, maximum_p_value, n_levels)

    return all_levels

In [None]:
# Shared contour levels over both data sets (all_data holds dmpe_data and igoats_data), so the
# DMPE and iGOATS density plots can be drawn on a common color scale.
all_levels = get_contour_levels(all_data)

In [None]:
def plot_feature_combinations(data, labels, mode="plot", all_levels=None):
    """Plot every pair of features of a data set in an ``n_features x n_features`` grid.

    The panel in row ``j`` and column ``i`` shows feature ``i`` on the x-axis against feature ``j``
    on the y-axis. It is drawn either as a scatter plot (``mode="plot"``) or as filled contours of
    a 2D density estimate (``mode="contourf"``). All panels share axes limited to ``[-1.02, 1.02]``.

    Args:
        data: 2D array of shape ``(n_samples, n_features)``.
        labels: One axis label per feature.
        mode: ``"plot"`` for scatter plots or ``"contourf"`` for density contours.
        all_levels: Optional array indexed ``[i, j, :]`` with the contour levels for each feature
            pair (e.g. from ``get_contour_levels``), so several figures share one color scale.
            If ``None``, matplotlib chooses 50 levels independently for every panel.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        ValueError: If ``mode`` is neither ``"plot"`` nor ``"contourf"``.
    """
    assert data.shape[-1] == len(labels)
    assert data.ndim == 2
    # Fail before creating a figure instead of silently drawing empty panels.
    if mode not in ("plot", "contourf"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'plot' or 'contourf'.")

    n_features = data.shape[-1]

    # squeeze=False keeps `axs` a 2D array even for a single feature (1x1 grid).
    fig, axs = plt.subplots(
        nrows=n_features,
        ncols=n_features,
        figsize=(half_colmun_width, half_colmun_width),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    for i in range(n_features):
        for j in range(n_features):
            ax = axs[j, i]
            if mode == "plot":
                ax.scatter(data[..., i], data[..., j], s=0.1)
            else:
                density_estimate = DensityEstimate.from_dataset(
                    # (n_samples, 2) pair of columns with a leading batch axis of size 1.
                    jnp.stack([data[..., i], data[..., j]], axis=-1)[None],
                    points_per_dim=100,
                    bandwidth=0.05,
                )
                x = density_estimate.x_g
                # The grid points are stored flat; restore the square (len, len, 2) grid layout.
                grid_len_per_dim = int(np.sqrt(x.shape[0]))
                x_plot = np.asarray(x).reshape((grid_len_per_dim, grid_len_per_dim, 2))
                ax.contourf(
                    x_plot[..., 0],
                    x_plot[..., 1],
                    np.asarray(density_estimate.p).reshape(x_plot.shape[:-1]),
                    antialiased=False,
                    levels=50 if all_levels is None else all_levels[i, j, :],
                    alpha=0.9,
                    cmap=plt.cm.coolwarm,
                )

            ax.grid(True)
            ax.set_xlim(-1.02, 1.02)
            ax.set_ylim(-1.02, 1.02)

    # Label only the outer panels: y-labels on the left column, x-labels on the bottom row.
    for k, label in enumerate(labels):
        axs[k, 0].set_ylabel(label)
        axs[-1, k].set_xlabel(label)
    fig.tight_layout(pad=0.01)

    return fig

In [None]:
# DMPE: pairwise density contours. Use the shared `all_levels` (computed from the DMPE and iGOATS
# data together) so this figure has the same color scale as the iGOATS figure.
qualitative_dir = pathlib.Path("results/qualitative")
qualitative_dir.mkdir(parents=True, exist_ok=True)

fig = plot_feature_combinations(
    dmpe_data,
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{v}_d$", "$\\tilde{v}_q$"],
    mode="contourf",
    all_levels=all_levels,
)
fig.savefig(qualitative_dir / "pmsm_dmpe_qualitative.pdf")
# fig.savefig(qualitative_dir / "pmsm_dmpe_qualitative.png", dpi=200)

In [None]:
import dmpe

In [None]:
# Reference bandwidth suggested by `select_bandwidth` for the 4 stacked features. This value is
# only displayed; the plots above use a fixed bandwidth of 0.05.
# NOTE(review): delta_x=2 presumably is the width of the normalized [-1, 1] range — confirm.
select_bandwidth(delta_x=2, dim=4, n_g=100, percentage=0.5)

In [None]:
# iGOATS: pairwise density contours with the precomputed shared `all_levels`, so the color scale
# matches the DMPE figure. Saved under its own file name so it does not overwrite the DMPE figure.
qualitative_dir = pathlib.Path("results/qualitative")
qualitative_dir.mkdir(parents=True, exist_ok=True)

fig = plot_feature_combinations(
    igoats_data,
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{v}_d$", "$\\tilde{v}_q$"],
    mode="contourf",
    all_levels=all_levels,
)
fig.savefig(qualitative_dir / "pmsm_igoats_qualitative.pdf")
# fig.savefig(qualitative_dir / "pmsm_igoats_qualitative.png", dpi=200)