In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10 * 2.54})
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])
# jax.config.update('jax_platform_name', 'cpu')
import chex

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from dmpe.models.models import NeuralEulerODEPMSM
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth, build_grid
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance,
    plot_2d_kde_as_contourf, plot_2d_kde_as_surface, plot_feature_combinations
)
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)
from dmpe.algorithms import excite_with_dmpe, default_dmpe, default_dmpe_parameterization

In [None]:
import pmsm_utils
from pmsm_utils import ExcitingPMSM, plot_current_constraints

---

In [None]:
# Number of parallel environments simulated by ExcitingPMSM.
batch_size = 1

# Saturated PMSM driven by the "BRUSA" lookup tables; the static parameters below are
# passed straight through to ExcitingPMSM (key meanings/units defined there — presumably
# pole pairs, stator resistance, d/q inductances, PM flux, dead time; TODO confirm).
env = ExcitingPMSM(
    initial_rpm=6000,
    batch_size=batch_size,
    saturated=True,
    LUT_motor_name="BRUSA",
    static_params={
        "p": 3,
        "r_s": 17.932e-3,
        "l_d": 0.37e-3,
        "l_q": 1.2e-3,
        "psi_p": 65.65e-3,
        "deadtime": 0,
    },
    solver=diffrax.Tsit5(),
)


def PMSM_penalty(observations, actions):
    """Evaluate `pmsm_utils.PMSM_penalty` for the notebook's `env`.

    Adapts the helper to a two-argument ``(observations, actions)`` signature.
    `env` is resolved at call time (as with the former lambda), so re-running the
    cell above that rebuilds `env` is picked up without redefining this function.

    Args:
        observations: Observations forwarded unchanged to `pmsm_utils.PMSM_penalty`.
        actions: Actions forwarded unchanged to `pmsm_utils.PMSM_penalty`.

    Returns:
        Whatever `pmsm_utils.PMSM_penalty(env, observations, actions)` returns.
    """
    return pmsm_utils.PMSM_penalty(env, observations, actions)

In [None]:
# List the quantities available in the motor LUT; "Psi_d" and "Psi_q" are used below.
env.pmsm_lut.keys()

In [None]:
# Bound used for both interpolation axes below.
# NOTE(review): hard-coded; presumably the current limit the LUT was sampled up to —
# TODO confirm against the env's physical constraints.
i_max = 250

# Extend the Psi_d table across its last column. This assumes the last column is the
# zero point of the first interpolation axis and that the raw table only covers
# [-i_max, 0] along it (presumably i_d) — TODO confirm.
# Psi_d is reflected point-symmetrically about that column:
#     mirrored[:, k] = 2 * lut[:, -1] - lut[:, -1 - k]
# mirrored[:, 0] reproduces lut[:, -1] exactly, so it is dropped as a duplicate.
psi_d_lut = env.pmsm_lut["Psi_d"]
psi_d_mirrored = (2 * psi_d_lut[:, -1:] - jnp.flip(psi_d_lut, axis=1))[:, 1:]

# Keep every raw column so the shared center column appears exactly once. Dropping it
# from both halves would lose the center column and misalign the remaining columns
# with the linspace grid built below.
psi_d = jnp.concatenate([psi_d_lut, psi_d_mirrored], axis=1)
assert psi_d.shape[1] == 2 * psi_d_lut.shape[1] - 1, "center column must appear exactly once"
n_grid_points_y, n_grid_points_x = psi_d.shape

# Columns (axis 1) map to the first interpolation coordinate x, rows to y.
# fill_value=None: queries outside the grid are extrapolated instead of filled.
x, y = np.linspace(-i_max, i_max, n_grid_points_x), np.linspace(-i_max, i_max, n_grid_points_y)
psi_d_interp = jax.scipy.interpolate.RegularGridInterpolator(
    (x, y), psi_d.T, method="linear", bounds_error=False, fill_value=None
)

# Square evaluation grid with as many points per axis as the table has rows.
n_points = n_grid_points_y
test_x, test_y = np.linspace(-i_max, i_max, n_points), np.linspace(-i_max, i_max, n_points)
xx, yy = jnp.meshgrid(test_x, test_y)
zz = jnp.stack([xx, yy], axis=-1)  # (n_points, n_points, 2) query points

res = psi_d_interp(zz.reshape(-1, 2))

fig = plt.figure(figsize=(6, 6))
axs = fig.add_subplot(111, projection="3d")
_ = axs.plot_surface(
    zz[..., 0],
    zz[..., 1],
    res.reshape(n_points, n_points),
    antialiased=False,
    alpha=1,
    cmap=plt.cm.coolwarm,
)
axs.view_init(30, 225)
axs.set_xlabel(env.obs_description[0])
axs.set_ylabel(env.obs_description[1])
axs.set_zlabel(r"$\psi_d$")
axs.set_title(r"Mirrored $\psi_d$ LUT");

In [None]:
# Extend the Psi_q table across its last column (same layout assumption as for Psi_d:
# the last column is the zero point of the first interpolation axis — TODO confirm).
# Psi_q is reflected evenly about that column: mirrored[:, k] = lut[:, -1 - k].
# mirrored[:, 0] equals lut[:, -1], so it is dropped as a duplicate.
psi_q_lut = env.pmsm_lut["Psi_q"]
psi_q_mirrored = jnp.flip(psi_q_lut, axis=1)[:, 1:]

# Keep every raw column so the shared center column appears exactly once. Dropping it
# from both halves would lose the center column and misalign the remaining columns
# with the linspace grid built below.
psi_q = jnp.concatenate([psi_q_lut, psi_q_mirrored], axis=1)
assert psi_q.shape[1] == 2 * psi_q_lut.shape[1] - 1, "center column must appear exactly once"
n_grid_points_y, n_grid_points_x = psi_q.shape

# Columns (axis 1) map to the first interpolation coordinate x, rows to y.
# fill_value=None: queries outside the grid are extrapolated instead of filled.
# i_max is defined in the Psi_d cell above.
x, y = np.linspace(-i_max, i_max, n_grid_points_x), np.linspace(-i_max, i_max, n_grid_points_y)
psi_q_interp = jax.scipy.interpolate.RegularGridInterpolator(
    (x, y), psi_q.T, method="linear", bounds_error=False, fill_value=None
)

# Square evaluation grid with as many points per axis as the table has rows.
n_points = n_grid_points_y
test_x, test_y = np.linspace(-i_max, i_max, n_points), np.linspace(-i_max, i_max, n_points)
xx, yy = jnp.meshgrid(test_x, test_y)
zz = jnp.stack([xx, yy], axis=-1)  # (n_points, n_points, 2) query points

res = psi_q_interp(zz.reshape(-1, 2))

fig = plt.figure(figsize=(6, 6))
axs = fig.add_subplot(111, projection="3d")
_ = axs.plot_surface(
    zz[..., 0],
    zz[..., 1],
    res.reshape(n_points, n_points),
    antialiased=False,
    alpha=1,
    cmap=plt.cm.coolwarm,
)
axs.view_init(30, 225)
axs.set_xlabel(env.obs_description[0])
axs.set_ylabel(env.obs_description[1])
axs.set_zlabel(r"$\psi_q$")
axs.set_title(r"Mirrored $\psi_q$ LUT");