In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import pathlib

from functools import partial
import os
# Device / memory configuration through environment variables. These lines come
# before `import jax` below on purpose: JAX reads them when it initializes its backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only CUDA device 1 to this process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"  # do not preallocate GPU memory; allocate on demand

import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# Make the first visible device the default for JAX computations.
# NOTE(review): with CUDA_VISIBLE_DEVICES='1' this is presumably the single visible GPU; on a
# machine without a usable GPU jax.devices() falls back to CPU devices — confirm intended.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
import dmpe

from dmpe.evaluation.experiment_utils import (
    load_experiment_results, get_experiment_ids, get_organized_experiment_ids
)
from dmpe.evaluation.metrics_utils import default_ae, default_jsd, default_mcudsa, default_ksfc

from dmpe.utils.env_utils.pmsm_utils import plot_constraints_induced_voltage
from dmpe.utils.density_estimation import build_grid, DensityEstimate

In [None]:
import eval_dmpe
from eval_dmpe import setup_env

In [None]:
# Overview of the NODE experiment ids, as grouped by get_organized_experiment_ids.
full_results_path = str(
    pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe") / "NODE"
)
organized_experiment_ids = get_organized_experiment_ids(full_results_path)
organized_experiment_ids

In [None]:
# Load the last NODE experiment of the [True][5000] group, rebuild its environment,
# and plot the denormalized i_d / i_q with plot_constraints_induced_voltage.
full_results_path = str(
    pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe") / "NODE"
)
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[True][5000][-1],
    results_path=full_results_path,
)

# Environment and penalty function are rebuilt for the rpm stored with this experiment.
env, penalty_function = setup_env(params["rpm"])
_, state = env.reset(env.env_properties)

# Denormalize observation channels 0 and 1 into physical i_d / i_q values.
normalizations = env.env_properties.physical_normalizations
i_d_normalizer, i_q_normalizer = normalizations.i_d, normalizations.i_q
physical_i_d, physical_i_q = (
    normalizer.denormalize(observations[..., channel])
    for channel, normalizer in enumerate((i_d_normalizer, i_q_normalizer))
)

omega_el = state.physical_state.omega_el
print(omega_el)

plot_constraints_induced_voltage(
    env,
    physical_i_d,
    physical_i_q,
    w_el=omega_el,
    saturated=True,
    show_torque=False,
)

In [None]:
# Load the last RLS experiment of the [True][5000] group, rebuild its environment,
# and plot the denormalized i_d / i_q with plot_constraints_induced_voltage.
full_results_path = str(
    pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe") / "RLS"
)
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[True][5000][-1],
    results_path=full_results_path,
)

# Environment and penalty function are rebuilt for the rpm stored with this experiment.
env, penalty_function = setup_env(params["rpm"])
_, state = env.reset(env.env_properties)

# Denormalize observation channels 0 and 1 into physical i_d / i_q values.
normalizations = env.env_properties.physical_normalizations
i_d_normalizer, i_q_normalizer = normalizations.i_d, normalizations.i_q
physical_i_d, physical_i_q = (
    normalizer.denormalize(observations[..., channel])
    for channel, normalizer in enumerate((i_d_normalizer, i_q_normalizer))
)

omega_el = state.physical_state.omega_el
print(omega_el)

plot_constraints_induced_voltage(
    env,
    physical_i_d,
    physical_i_q,
    w_el=omega_el,
    saturated=True,
    show_torque=False,
)

In [None]:
def _filter_valid_points(data_points, penalty_function):
    
    valid_points_bool = jax.vmap(penalty_function, in_axes=0)(data_points) == 0
    return data_points[jnp.where(valid_points_bool == True)]

filter_valid_points = lambda observations: _filter_valid_points(observations, penalty_function=lambda x: penalty_function(x, None))

## AE (Audze–Eglais)

In [None]:
from dmpe.utils.metrics import audze_eglais

In [None]:
# Audze-Eglais through the evaluation default wrapper, on the `observations` / `actions`
# currently in the namespace (the last loaded experiment when cells run in order).
ae_default = default_ae(observations, actions)
ae_default

In [None]:
# AE on all observations, then only on those with zero penalty.
ae_eps = 1e-16
print(audze_eglais(observations, eps=ae_eps))
valid_observations = filter_valid_points(observations)
print(audze_eglais(valid_observations, eps=ae_eps))

In [None]:
full_results_path = str(
    pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe") / "NODE"
)
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[True][5000][-1],
    results_path=full_results_path,
)

# How AE reacts to its `eps` argument ...
ae_eps_values = (0, 1e-16, 1e-3)
for ae_eps in ae_eps_values:
    print(audze_eglais(observations, eps=ae_eps))

# ... and to one exactly duplicated point (a copy of the last observation is appended).
observations_with_duplicate = jnp.concatenate([observations, observations[-1:]], axis=0)
for ae_eps in ae_eps_values:
    print(audze_eglais(observations_with_duplicate, eps=ae_eps))

In [None]:
full_results_path = str(
    pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe") / "RLS"
)
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[True][5000][-1],
    results_path=full_results_path,
)

# AE for two `eps` values on the RLS experiment.
for ae_eps in (1e-16, 1e-3):
    print(audze_eglais(observations, eps=ae_eps))

- AE appears to be a poor metric for comparing datasets (compare the `eps` and duplicated-point experiments above).
- It might still be a reasonable optimization objective.
- **Conclusion: there is no point in using AE in the evaluation.**

## MCUDSA (Monte-Carlo uniform sampling distribution approximation)

In [None]:
from dmpe.utils.metrics import MC_uniform_sampling_distribution_approximation as mcudsa

In [None]:
# NOTE(review): unlike the plotting/AE cells, this path goes through `eval/scripts/pmsm/...`
# and the [False] group is used — confirm this is the intended result set.
full_results_path = "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/scripts/pmsm/results/dmpe/NODE"
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[False][5000][-1],
    results_path=full_results_path,
)

print(default_mcudsa(observations, actions, ca=True))

# Sweep the grid resolution; label each value so the printout shows which
# points_per_dim produced it.
for points_per_dim in range(10, 50):
    mcudsa_value = default_mcudsa(observations, actions, points_per_dim=points_per_dim, ca=False)
    print(f"points_per_dim={points_per_dim}: {mcudsa_value}")

In [None]:
# NOTE(review): unlike the plotting/AE cells, this path goes through `eval/scripts/pmsm/...`
# and the [False] group is used — confirm this is the intended result set.
full_results_path = "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/scripts/pmsm/results/dmpe/RLS"
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[False][5000][-1],
    results_path=full_results_path,
)

print(default_mcudsa(observations, actions, ca=True))

# Sweep the grid resolution; label each value so the printout shows which
# points_per_dim produced it.
for points_per_dim in range(10, 50):
    mcudsa_value = default_mcudsa(observations, actions, points_per_dim=points_per_dim, ca=False)
    print(f"points_per_dim={points_per_dim}: {mcudsa_value}")

- MCUDSA seems to give sensible results across the `points_per_dim` sweep.

## KSFC

In [None]:
# NOTE(review): unlike the plotting/AE cells, this path goes through `eval/scripts/pmsm/...`
# and the [False] group is used — confirm this is the intended result set.
full_results_path = "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/scripts/pmsm/results/dmpe/NODE"
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[False][5000][-1],
    results_path=full_results_path,
)

print(default_ksfc(observations, actions, ca=True))
print(default_ksfc(observations, actions, ca=False))

# Sweep the grid resolution. Keep the swept values so the x-axis shows the actual
# points_per_dim (10..49) instead of the list index (0..39).
points_per_dim_sweep = list(range(10, 50))
values = [
    default_ksfc(observations, actions, points_per_dim=n_points, variance=0.1, eps=1e-12, ca=False)
    for n_points in points_per_dim_sweep
]

fig, ax = plt.subplots()
ax.plot(points_per_dim_sweep, values, marker=".")
ax.set(xlabel="points_per_dim", ylabel="KSFC (ca=False)", title="KSFC vs. grid resolution (NODE)");

In [None]:
# NOTE(review): unlike the plotting/AE cells, this path goes through `eval/scripts/pmsm/...`
# and the [False] group is used — confirm this is the intended result set.
full_results_path = "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/scripts/pmsm/results/dmpe/RLS"
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[False][5000][-1],
    results_path=full_results_path,
)

print(default_ksfc(observations, actions, ca=True))
print(default_ksfc(observations, actions, ca=False))

# Sweep the grid resolution. Keep the swept values so the x-axis shows the actual
# points_per_dim (10..49) instead of the list index (0..39).
points_per_dim_sweep = list(range(10, 50))
values = [
    default_ksfc(observations, actions, points_per_dim=n_points, variance=0.1, eps=1e-12, ca=False)
    for n_points in points_per_dim_sweep
]

fig, ax = plt.subplots()
ax.plot(points_per_dim_sweep, values, marker=".")
ax.set(xlabel="points_per_dim", ylabel="KSFC (ca=False)", title="KSFC vs. grid resolution (RLS)");

In [None]:
from dmpe.utils.density_estimation import gaussian_kernel

In [None]:
# Sanity check: compare dmpe's `gaussian_kernel` with a hand-written, unnormalized
# Gaussian exp(-x^2 / (2 h^2)) on a 1-D grid.
# NOTE(review): confirm whether gaussian_kernel's second argument is a bandwidth or a
# variance, and whether the kernel is normalized — either would make the curves differ.
h = jnp.array([0.1])
x = jnp.linspace(-1, 1, 100)[..., None]  # shape (100, 1): 100 one-dimensional points

y = jnp.exp(-x**2 / (2 * h**2))

# vmap over the points; the same h is shared by every evaluation.
y_gk = jnp.squeeze(jax.vmap(gaussian_kernel, in_axes=(0, None))(x, h))

fig, ax = plt.subplots()
ax.plot(x, y, label="exp(-x^2 / (2 h^2))")
ax.plot(x, y_gk, "--", label="gaussian_kernel(x, h)")
ax.set(xlabel="x", ylabel="kernel value", title="Gaussian kernel comparison, h = 0.1")
ax.legend();

## JSD (Jensen–Shannon divergence)

- Investigate how the metric behaves: where do the large values come from?

In [None]:
from dmpe_params import get_target_distribution

In [None]:
# JSD hyper-parameters. The grid resolution and bandwidth are shared by the target
# distribution and the metric, so both operate on the same grid.
# NOTE(review): `penalty_function` is read from the notebook namespace. In this notebook
# it was last assigned by `setup_env` in the RLS plotting cell, while the metric is
# applied to a NODE experiment below — confirm both experiments share the same rpm.
# NOTE(review): the "dim" entry is not read in this cell.
jsd_params = {
    "points_per_dim": 22,
    "dim": 4,
    "grid_extend": 1.05,
    "bandwidth": 0.08,
    "penalty_function": penalty_function,
}

target_distribution = get_target_distribution(
    *(jsd_params[key] for key in ("points_per_dim", "bandwidth", "grid_extend")),
    True,  # positional flag; its meaning is defined by get_target_distribution — TODO confirm
    jsd_params["penalty_function"],
)

shared_jsd_kwargs = {key: jsd_params[key] for key in ("points_per_dim", "bandwidth")}
metric = partial(
    default_jsd,
    **shared_jsd_kwargs,
    target_distribution=target_distribution,
    ca=True,
)

In [None]:
# Load the NODE experiment the JSD metric is evaluated on (last id of the [True][5000] group).
full_results_path = str(
    pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe") / "NODE"
)
organized_experiment_ids = get_organized_experiment_ids(full_results_path)

latest_exp_id = organized_experiment_ids[True][5000][-1]
params, observations, actions, _ = load_experiment_results(
    exp_id=latest_exp_id,
    results_path=full_results_path,
)

In [None]:
# JSD of the loaded observations/actions against the target distribution built above.
jsd_value = metric(observations, actions)
jsd_value