In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
# Both variables must be in the environment before JAX initialises its backend,
# which is why they are set here, before `import jax` in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only physical GPU 1 to this process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"  # allocate GPU memory on demand instead of preallocating

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# All figure text is rendered through LaTeX (requires a LaTeX installation); labels
# must therefore be valid LaTeX, e.g. subscripts belong in math mode ($i_d$).
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10 * 2.54})  # NOTE(review): = 25.4 pt; intent of the 2.54 factor unclear
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"  # enables \bm{...} bold math in labels
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
gpus = jax.devices()
# Place new arrays on the first device of the default backend (with the
# CUDA_VISIBLE_DEVICES setting above this is physical GPU 1, if a GPU backend is available).
jax.config.update("jax_default_device", gpus[0])
# jax.config.update('jax_platform_name', 'cpu')
import chex

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth, build_grid
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance,
    plot_2d_kde_as_contourf, plot_2d_kde_as_surface, plot_feature_combinations
)
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

In [None]:
# Current limit i_d**2 + i_q**2 <= 1 for i_d in [-1, 0], drawn after mapping i_d
# onto [-1, 1] via 2 * i_d + 1 (the shift used for the i_d observation).
i_d = np.linspace(-1, 0, 1000)
i_q = np.sqrt(1**2 - i_d**2)
i_d = 2 * (i_d + 0.5)

for sign in (1, -1):
    plt.plot(i_d, sign * i_q, "k")
# closing vertical edge at the upper end of the i_d range (i_d = 0 -> 1)
plt.plot(np.ones(2), np.array([-1, 1]), "k")
plt.show()

---

In [None]:
# setup PRNG: one root key, split into purpose-specific keys
key = jax.random.PRNGKey(seed=222)
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, num=5)

# `next(data_rng)` yields a new key on each call (used for the APRBS signals below)
data_rng = PRNGSequence(data_key)

In [None]:
from exciting_environments.pmsm.pmsm_env import PMSM

In [None]:
class ExcitingPMSM(PMSM):
    """``PMSM`` variant used for the excitation experiments in this notebook.

    Compared with the base environment it
    * observes only the two normalised currents ``(i_d, i_q)``, and
    * always resets to the same deterministic operating point.
    """

    def generate_observation(self, system_state, env_properties):
        """Return ``[i_d_norm, i_q_norm]`` for the given system state.

        ``i_d_norm = (i_d + i_d_max / 2) / (i_d_max / 2)``, i.e. ``i_d = -i_d_max``
        maps to -1 and ``i_d = 0`` maps to +1 (assuming ``i_d_max`` is the positive
        limit ``physical_constraints.i_d``); ``i_q_norm = i_q / physical_constraints.i_q``.
        """
        physical_constraints = env_properties.physical_constraints
        physical_state = system_state.physical_state

        half_i_d_range = physical_constraints.i_d * 0.5
        return jnp.hstack(
            (
                (physical_state.i_d + half_i_d_range) / half_i_d_range,
                physical_state.i_q / physical_constraints.i_q,
                # The rotor angle is not observed; cos/sin of
                # `physical_state.epsilon` could be appended here if needed.
            )
        )

    def init_state(self, env_properties, rng=None, vmap_helper=None):
        """Return the fixed initial state used for every reset.

        The motor starts at ``i_d = -i_d_max / 2`` (normalised observation 0),
        ``i_q = 0`` and electrical speed ``2 * pi * 3 * 1000 / 60`` (presumably
        1000 rpm with 3 pole pairs - TODO confirm). ``rng`` and ``vmap_helper``
        are accepted but unused; presumably they keep the base-class signature.
        """
        phys = self.PhysicalState(
            u_d_buffer=0.0,
            u_q_buffer=0.0,
            epsilon=0.0,
            i_d=-env_properties.physical_constraints.i_d / 2,
            i_q=0.0,
            torque=0.0,
            omega_el=2 * jnp.pi * 3 * 1000 / 60,
        )
        subkey = jnp.nan  # no PRNG key is needed: the initial state is deterministic
        additions = None  # self.Optional(something=jnp.zeros(self.batch_size))
        # No reference trajectory: every reference entry is a NaN placeholder.
        ref = self.PhysicalState(
            u_d_buffer=jnp.nan,
            u_q_buffer=jnp.nan,
            epsilon=jnp.nan,
            i_d=jnp.nan,
            i_q=jnp.nan,
            torque=jnp.nan,
            omega_el=jnp.nan,
        )
        return self.State(physical_state=phys, PRNGKey=subkey, additions=additions, reference=ref)

In [None]:
batch_size = 1

# Saturated motor model driven by the BRUSA look-up tables, integrated with explicit Euler.
env = ExcitingPMSM(
    solver=diffrax.Euler(),
    batch_size=batch_size,
    saturated=True,
    LUT_motor_name="BRUSA",
    # l_d, l_q and psi_p are left as NaN; with saturated=True they are
    # presumably taken from the LUT instead (TODO confirm).
    static_params={
        "p": 3,
        "r_s": 15e-3,
        "l_d": jnp.nan,
        "l_q": jnp.nan,
        "psi_p": jnp.nan,
        "deadtime": 0,
    },
)

In [None]:
obs, state = env.vmap_reset()

# Random excitation: one APRBS per voltage component, joined along the last axis.
n_steps = 99
actions = jnp.concatenate(
    [aprbs(n_steps, batch_size, 1, 10, next(data_rng)) for _ in range(2)],
    axis=-1,
)

# Roll the environment forward and record the normalised current observations.
observations = [obs[..., 0:2]]
for i in range(actions.shape[1]):
    obs, state = env.vmap_step(state, actions[:, i, :])
    observations.append(obs[..., 0:2])

In [None]:
# Normalised currents and the applied APRBS voltages over time.
plot_sequence(
    np.concatenate(observations),
    np.concatenate(actions),
    env.tau,
    obs_labels=env.obs_description[:2],
    action_labels=["u_d", "u_q"],
)

In [None]:
from dmpe.algorithms import excite_with_dmpe, default_dmpe, default_dmpe_parameterization
from dmpe.models.models import NeuralEulerODEPMSM

In [None]:
def test(a, b, penalty_order=1):
    a = (a - 1)/ 2
    return jax.nn.relu(a**2 + b**2 - 0.8**2)**penalty_order

In [None]:
# Evaluate the `test` penalty on a 100 x 100 grid covering [-extend, extend]^2.
extend = 1.5
xx, yy = np.meshgrid(
    np.linspace(-extend, extend, 100),
    np.linspace(-extend, extend, 100),
)
zz = test(xx, yy)

In [None]:
# 3-D view of the penalty surface computed above.
fig = plt.figure(figsize=(6, 6))
axs = fig.add_subplot(111, projection="3d")
_ = axs.plot_surface(xx, yy, zz, cmap=plt.cm.coolwarm, alpha=0.8, antialiased=False)

In [None]:
from dmpe.excitation.excitation_utils import soft_penalty

In [None]:
def PMSM_penalty(observations, actions, penalty_order=2):
    """Soft-constraint penalty used by the PMSM excitation runs.

    The sum of three terms, scaled by 1e3:
      * ``relu((i_d/250)**2 + (i_q/250)**2 - 0.9)``, summed over all entries,
      * ``relu(i_d/250)``, summed (penalises positive i_d),
      * ``soft_penalty(actions, a_max=1, penalty_order=1)``.

    Observations are mapped back to physical currents with
    ``i_d = obs_d * i_d_max/2 - i_d_max/2`` and ``i_q = obs_q * i_q_max``, where
    the limits are read from the module-level ``env``.

    NOTE(review): ``penalty_order`` is accepted but never used; the action term is
    always order 1 and the current terms are linear in the violation.
    """
    constraints = env.env_properties.physical_constraints
    half_i_d_max = constraints.i_d * 0.5

    i_d = observations[..., 0] * half_i_d_max - half_i_d_max
    i_q = observations[..., 1] * constraints.i_q

    # currents relative to 250 (presumably the current limit in A - TODO confirm)
    i_d_rel = i_d / 250
    i_q_rel = i_q / 250

    circle_violation = jnp.sum(jax.nn.relu(i_d_rel**2 + i_q_rel**2 - 0.9))
    positive_i_d = jnp.sum(jax.nn.relu(i_d_rel))
    action_term = soft_penalty(actions, a_max=1, penalty_order=1)

    return (circle_violation + positive_i_d + action_term) * 1e3

In [None]:
# exp_params, init_actions, loader_key, expl_key = default_dmpe_parameterization(env, seed=0, featurize=None, model_class=NeuralEulerODEPMSM)

In [None]:
seed = 63463473

alg_params = dict(
    bandwidth=jnp.nan,  # placeholder, set below
    n_prediction_steps=5,
    points_per_dim=21,
    action_lr=1e-2,
    n_opt_steps=100,
    consider_action_distribution=True,
    penalty_function=PMSM_penalty,
    target_distribution=None,  # placeholder, set below
    clip_action=False,
    n_starts=3,
    reuse_proposed_actions=True,
)

# Grid dimensionality: (i_d, i_q), plus (u_d, u_q) when the action distribution is considered.
dim = 4 if alg_params["consider_action_distribution"] else 2
points_per_dim = alg_params["points_per_dim"]

# Uniform target, zeroed where xx**2 + yy**2 > 1 (xx spans [-1, 0], yy spans [-1, 1]),
# then normalised to sum to one.
target_distribution = np.ones(shape=(points_per_dim,) * dim + (1,))
xx, yy = np.meshgrid(np.linspace(-1, 0, points_per_dim), np.linspace(-1, 1, points_per_dim))
# NOTE(review): np.meshgrid defaults to indexing="xy", so the mask's axis 0 follows `yy`
# and axis 1 follows `xx`; confirm this matches the axis order of the density grid.
target_distribution[xx**2 + yy**2 > 1] = 0
target_distribution = target_distribution / jnp.sum(target_distribution)
alg_params["target_distribution"] = jnp.array(target_distribution.reshape(-1, 1))

# Kernel bandwidth for this grid (an alternative is dim=env.physical_state_dim + env.action_dim).
alg_params["bandwidth"] = float(
    select_bandwidth(delta_x=2, dim=dim, n_g=alg_params["points_per_dim"], percentage=0.3)
)
print("bw", alg_params["bandwidth"])

model_trainer_params = dict(
    start_learning=alg_params["n_prediction_steps"],
    training_batch_size=128,
    n_train_steps=1,
    sequence_length=alg_params["n_prediction_steps"],
    featurize=lambda x: x,  # identity: the model works on raw observations
    model_lr=1e-4,
)
model_params = dict(obs_dim=2, action_dim=env.action_dim, width_size=128, depth=3, key=None)

exp_params = dict(
    seed=seed,
    n_time_steps=1_000,
    model_class=NeuralEulerODEPMSM,
    env_params=None,
    alg_params=alg_params,
    model_trainer_params=model_trainer_params,
    model_params=model_params,
)

In [None]:
# Manual override of the bandwidth computed above. `exp_params["alg_params"]` is this
# same dict object, so the experiment picks up the new value too.
alg_params.update(bandwidth=0.07)

In [None]:
# Keep only the grid points with zero PMSM_penalty (all constraints satisfied);
# they span the feasible region used to build the target density.
x_g = build_grid(dim, low=-1, high=1, points_per_dim=exp_params["alg_params"]["points_per_dim"])


def constr_func(point):
    """Penalty of one grid point, split into its (i_d, i_q) and (u_d, u_q) parts."""
    return PMSM_penalty(point[..., None, :2], point[..., None, 2:])


valid_grid_point = jax.vmap(constr_func, in_axes=0)(x_g) == 0
# Boolean-mask indexing directly; `jnp.where(mask == True)` was a roundabout equivalent.
constrained_data_points = x_g[valid_grid_point]
constrained_data_points.shape

In [None]:
# Pairwise KDE views of the feasible (constraint-satisfying) grid points.
fig = plot_feature_combinations(
    constrained_data_points,
    mode="contourf",
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"],
    bandwidth=alg_params["bandwidth"],
    points_per_dim=alg_params["points_per_dim"],
);

In [None]:
# Target distribution = KDE of the feasible grid points (the `[None]` adds a batch axis);
# the first batch entry of its density replaces the uniform target set earlier.
target_distribution = DensityEstimate.from_dataset(
    constrained_data_points[None],
    bandwidth=alg_params["bandwidth"],
    points_per_dim=alg_params["points_per_dim"],
)
exp_params["alg_params"]["target_distribution"] = target_distribution.p[0]

In [None]:
# helper_density_estimate = DensityEstimate.from_dataset(
#     jnp.concatenate([observations[0:-1, :2], actions], axis=-1)[None],
#     points_per_dim=alg_params["points_per_dim"],
#     bandwidth=alg_params["bandwidth"],
# )
# exp_params["alg_params"]["target_distribution"] = (
#     exp_params["alg_params"]["target_distribution"] 
#     + (exp_params["alg_params"]["target_distribution"] - helper_density_estimate.p[0])
# )
# exp_params["alg_params"]["target_distribution"] = jax.nn.relu(exp_params["alg_params"]["target_distribution"])

In [None]:
# Derive all keys for this run from the experiment seed.
key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, num=5)

data_rng = PRNGSequence(data_key)
exp_params["model_params"]["key"] = model_key

# Initial guess: one APRBS per action dimension (first batch entry), stacked column-wise.
proposed_actions = jnp.hstack(
    [aprbs(alg_params["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))[0] for _ in range(env.action_dim)]
)

In [None]:
# Run the DMPE excitation loop (outputs named as unpacked below).
(observations, actions, model, density_estimate, losses, proposed_actions) = excite_with_dmpe(
    env,
    exp_params,
    proposed_actions,
    loader_key,
    expl_key,
    plot_every=1000,
)

In [None]:
# Undo the observation normalisation to obtain un-normalised currents.
physical_i_d = (
    observations[..., 0] * (env.env_properties.physical_constraints.i_d * 0.5)
    - (env.env_properties.physical_constraints.i_d * 0.5)
)
physical_i_q = observations[..., 1] * env.env_properties.physical_constraints.i_q

In [None]:
# Currents and voltages in physical units (actions rescaled by the voltage limits).
fig, axs = plot_sequence(
    observations=jnp.vstack([physical_i_d, physical_i_q]).T,
    actions=actions * jnp.hstack([env.env_properties.action_constraints.u_d, env.env_properties.action_constraints.u_q]),
    tau=env.tau,
    obs_labels=["i_d", "i_q"],
    action_labels=["u_d", "u_q"],
)
# The +/- i_d and u_d limit lines can be overlaid as in the PM-DMPE plot further below.

axs[1].set_xlim(-270, 0)
axs[1].set_ylim(-270, 270)

plt.show()

In [None]:
# Pairwise KDE (contour) views of the collected (observation, action) samples;
# observation k is paired with the action applied at step k.
fig = plot_feature_combinations(
    jnp.concatenate([observations[:-1], actions], axis=-1),
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"],
    bandwidth=0.05,
    mode="contourf",
);
plt.show()
# To save: plt.savefig("results/plots/focus_both.pdf" if alg_params["consider_action_distribution"] else "results/plots/focus_obs.pdf")

In [None]:
# Same samples with the helper's default plotting mode.
fig = plot_feature_combinations(
    jnp.concatenate([observations[:-1], actions], axis=-1),
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"],
);
# To save: plt.savefig("results/plots/focus_both.pdf" if alg_params["consider_action_distribution"] else "results/plots/focus_obs.pdf")

In [None]:
# np.save does not create missing directories, so make sure `results/` exists first.
results_dir = pathlib.Path("results")
results_dir.mkdir(parents=True, exist_ok=True)
np.save(results_dir / "pmsm_dmpe_observations.npy", observations)
np.save(results_dir / "pmsm_dmpe_actions.npy", actions)

In [None]:
# Intentional stop: keeps "Run All" from continuing into the PM-DMPE experiment below.
# (A bare `raise` here only produced "RuntimeError: No active exception to reraise".)
raise RuntimeError("Intentional stop: the following cells are a separate PM-DMPE experiment.")

### PM-DMPE:

In [None]:
from dmpe.models.model_utils import ModelEnvWrapperPMSM

In [None]:
seed = 41152145125

alg_params = dict(
    bandwidth=jnp.nan,  # placeholder, set below
    n_prediction_steps=3,
    points_per_dim=21,
    action_lr=1e-2,
    n_opt_steps=50,
    consider_action_distribution=True,
    penalty_function=PMSM_penalty,
    target_distribution=None,  # placeholder, set below
    clip_action=False,
    n_starts=5,
    reuse_proposed_actions=True,
)

# Grid dimensionality: (i_d, i_q), plus (u_d, u_q) when the action distribution is considered.
dim = 4 if alg_params["consider_action_distribution"] else 2
points_per_dim = alg_params["points_per_dim"]

# Uniform target, zeroed where xx**2 + yy**2 > 0.9 (xx spans [-1, 0], yy spans [-1, 1]),
# then normalised to sum to one. (`np.ones(...) ** dim` was the same as `np.ones(...)`.)
target_distribution = np.ones(shape=(points_per_dim,) * dim + (1,))
xx, yy = np.meshgrid(np.linspace(-1, 0, points_per_dim), np.linspace(-1, 1, points_per_dim))
# NOTE(review): np.meshgrid defaults to indexing="xy", so the mask's axis 0 follows `yy`
# and axis 1 follows `xx`; confirm this matches the axis order of the density grid.
target_distribution[xx**2 + yy**2 > 0.9] = 0
target_distribution = target_distribution / jnp.sum(target_distribution)
alg_params["target_distribution"] = jnp.array(target_distribution.reshape(-1, 1))

alg_params["bandwidth"] = float(
    select_bandwidth(
        delta_x=2,
        dim=dim,
        n_g=alg_params["points_per_dim"],
        percentage=0.05,
    )
)

# PM-DMPE: no learned model; ModelEnvWrapperPMSM is supplied as the model-env wrapper.
exp_params = dict(
    seed=seed,  # was hard-coded to int(1), which silently ignored `seed` above
    n_time_steps=5_000,
    model_class=None,
    env_params=None,
    alg_params=alg_params,
    model_trainer_params=None,
    model_params=None,
    model_env_wrapper=ModelEnvWrapperPMSM,
)

In [None]:
# 2-D slice of the target through the first two (current) axes, at action-grid index 0
# (assumes consider_action_distribution=True, i.e. a 4-D grid plus trailing axis).
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(target_distribution[:, :, 0, 0, 0])
ax.invert_yaxis()

In [None]:
# Bandwidth chosen by select_bandwidth(percentage=0.05) in the configuration cell above.
alg_params["bandwidth"]

In [None]:
# Display the environment object.
env

In [None]:
# Fresh keys for this run; the model and loader keys are discarded (no learned model).
key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, _, _, expl_key, key = jax.random.split(key, num=5)
data_rng = PRNGSequence(data_key)

# Initial guess: one APRBS per action dimension (first batch entry), stacked column-wise.
proposed_actions = jnp.hstack(
    [aprbs(alg_params["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))[0] for _ in range(env.action_dim)]
)

# run excitation algorithm; `None` is passed in the loader-key position
observations, actions, model, density_estimate, losses, proposed_actions = excite_with_dmpe(
    env, exp_params, proposed_actions, None, expl_key, plot_every=10
)

In [None]:
# Post-mortem debugger for the most recent exception (interactive use only).
%debug

In [None]:
fig = plot_feature_combinations(
    jnp.concatenate([observations[:-1], actions], axis=-1),
    labels=["i_d", "i_q", "u_d", "u_q"],
    mode="contourf",
);
# File name reflects whether the action distribution was part of the target.
plt.savefig(
    "focus_both_contourf.png" if alg_params["consider_action_distribution"] else "focus_obs_contourf.png"
)

In [None]:
fig = plot_feature_combinations(
    jnp.concatenate([observations[:-1], actions], axis=-1),
    labels=["i_d", "i_q", "u_d", "u_q"],
);
# File name reflects whether the action distribution was part of the target.
plt.savefig("focus_both.png" if alg_params["consider_action_distribution"] else "focus_obs.png")

In [None]:
# Intentional stop: the cells below are optional plots/analysis of the PM-DMPE run.
# (A bare `raise` here only produced "RuntimeError: No active exception to reraise".)
raise RuntimeError("Intentional stop: the following cells are optional analysis of the PM-DMPE run.")

In [None]:
# Undo the observation normalisation to obtain un-normalised currents.
physical_i_d = (
    observations[..., 0] * (env.env_properties.physical_constraints.i_d * 0.5)
    - (env.env_properties.physical_constraints.i_d * 0.5)
)
physical_i_q = observations[..., 1] * env.env_properties.physical_constraints.i_q

In [None]:
fig, axs = plot_sequence(
    observations=jnp.vstack([physical_i_d, physical_i_q]).T,
    actions=actions * jnp.hstack([env.env_properties.action_constraints.u_d, env.env_properties.action_constraints.u_q]),
    tau=env.tau,
    obs_labels=["i_d", "i_q"],
    action_labels=["u_d", "u_q"],
)

# Overlay +/- limits: i_d on axs[0] (one sample per observation) and u_d on axs[2]
# (one sample per action, hence the time vector shortened by one).
t = jnp.linspace(0, observations.shape[0] - 1, observations.shape[0]) * env.tau
for sign in (1, -1):
    axs[0].plot(t, sign * np.ones(observations.shape[0]) * env.env_properties.physical_constraints.i_d)

t = t[:-1]
for sign in (1, -1):
    axs[2].plot(t, sign * np.ones(actions.shape[0]) * env.env_properties.action_constraints.u_d)

plt.savefig("test_200vdc.png")

In [None]:
from dmpe.evaluation.plotting_utils import plot_2d_kde_as_contourf, plot_2d_kde_as_surface

In [None]:
# KDE of the (i_d, i_q) observations on a fine 100-point grid.
obs_density = DensityEstimate.from_dataset(
    observations,
    actions=actions,
    use_actions=False,
    points_per_dim=100,
    bandwidth=exp_params["alg_params"]["bandwidth"],
)
obs_labels = ["i_d", "i_q"]
plot_2d_kde_as_contourf(obs_density.p, obs_density.x_g, obs_labels)

plt.savefig(f"{'-'.join(obs_labels)}.png")

In [None]:
# KDE of the (u_d, u_q) actions; `actions=` is presumably ignored since use_actions=False.
act_density = DensityEstimate.from_dataset(
    actions,
    actions=observations,
    use_actions=False,
    points_per_dim=100,
    bandwidth=exp_params["alg_params"]["bandwidth"],
)

# (the variable keeps its name `obs_labels` but holds the action labels here)
obs_labels = ["u_d", "u_q"]
plot_2d_kde_as_contourf(act_density.p, act_density.x_g, obs_labels)

plt.savefig(f"{'-'.join(obs_labels)}.png")

In [None]:
# Joint KDE of (i_d, u_d): observation k is paired with the action applied at step k.
density_estimate_test = DensityEstimate.from_dataset(
    jnp.stack([observations[:-1, 0], actions[..., 0]], axis=-1),
    actions=actions,
    use_actions=False,
    points_per_dim=100,
    bandwidth=exp_params["alg_params"]["bandwidth"],
)

obs_labels = ["i_d", "u_d"]
plot_2d_kde_as_contourf(density_estimate_test.p, density_estimate_test.x_g, obs_labels)

plt.savefig(f"{'-'.join(obs_labels)}.png")

In [None]:
# Joint KDE of (i_q, u_q): observation k is paired with the action applied at step k.
density_estimate_test = DensityEstimate.from_dataset(
    jnp.stack([observations[:-1, 1], actions[..., 1]], axis=-1),
    actions=actions,
    use_actions=False,
    points_per_dim=100,
    bandwidth=exp_params["alg_params"]["bandwidth"],
)

obs_labels = ["i_q", "u_q"]
plot_2d_kde_as_contourf(density_estimate_test.p, density_estimate_test.x_g, obs_labels)

plt.savefig(f"{'-'.join(obs_labels)}.png")

In [None]:
# Joint KDE of (i_d, u_q): observation k is paired with the action applied at step k.
density_estimate_test = DensityEstimate.from_dataset(
    jnp.stack([observations[:-1, 0], actions[..., 1]], axis=-1),
    actions=actions,
    use_actions=False,
    points_per_dim=100,
    bandwidth=exp_params["alg_params"]["bandwidth"],
)

obs_labels = ["i_d", "u_q"]
plot_2d_kde_as_contourf(density_estimate_test.p, density_estimate_test.x_g, obs_labels)

plt.savefig(f"{'-'.join(obs_labels)}.png")

In [None]:
# Joint KDE of (i_q, u_d): observation k is paired with the action applied at step k.
density_estimate_test = DensityEstimate.from_dataset(
    jnp.stack([observations[:-1, 1], actions[..., 0]], axis=-1),
    actions=actions,
    use_actions=False,
    points_per_dim=100,
    bandwidth=exp_params["alg_params"]["bandwidth"],
)

obs_labels = ["i_q", "u_d"]
plot_2d_kde_as_contourf(density_estimate_test.p, density_estimate_test.x_g, obs_labels)

plt.savefig(f"{'-'.join(obs_labels)}.png")

In [None]:
# Post-mortem debugger for the most recent exception (interactive use only).
%debug

---

In [None]:
# Intentional stop: the cells below are scratch checks (model vs. environment predictions, LUTs).
# (A bare `raise` here only produced "RuntimeError: No active exception to reraise".)
raise RuntimeError("Intentional stop: the following cells are scratch checks of predictions and LUTs.")

In [None]:
from copy import deepcopy

In [None]:
from dmpe.models.model_utils import simulate_ahead_with_env

In [None]:
# Compare L-step predictions of the learned model, `simulate_ahead_with_env` and the
# environment itself, starting from each of the first 100 states of a rollout that
# replays the recorded `actions`.
L = 20

obs, state = env.reset()

model_predictions = []
env_predictions = []
sim_ahead_predictions = []

for i in range(100):

    model_pred_observations = model(obs[0, :2], actions[i:i+L, :], env.tau)
    model_predictions.append(model_pred_observations)

    sim_ahead_predictions.append(jax.vmap(simulate_ahead_with_env, in_axes=(None, None, 0, None))(env, obs[0, :2], state, actions[i:i+L, :])[0][0])

    # Ground truth: step a copy of the current state L times. The look-ahead uses its
    # own observation name so the current `obs` is never overwritten mid-iteration
    # (previously it was clobbered and only repaired by the outer step below).
    _state = deepcopy(state)
    env_obs = [obs[..., 0:2]]
    for j in range(L):
        lookahead_obs, _state = env.vmap_step(_state, actions[i+j, :][None])
        env_obs.append(lookahead_obs[..., 0:2])

    # Advance the reference rollout by one step.
    obs, state = env.vmap_step(state, actions[i, :][None])

    env_predictions.append(jnp.vstack(env_obs))


env_predictions = jnp.stack(env_predictions)
model_predictions = jnp.stack(model_predictions)
sim_ahead_predictions = jnp.stack(sim_ahead_predictions)

In [None]:
# Scratch: reset the environment and display the returned (obs, state).
env.reset()

In [None]:
# Scratch value: 2 * pi * 3 * 4700 / 60 ~= 1476.5486 (see the speed conversions below).
1476.5486

In [None]:
# Electrical angular speed for 4700 (presumably rpm) with p = 3 pole pairs.
# No trailing comma: `..., ` built a 1-tuple instead of the scalar value.
2 * jnp.pi * 3 * 4700 / 60

In [None]:
# For each of the first L prediction windows, overlay the environment trajectory,
# the learned-model prediction and the simulate-ahead prediction.
for i in range(L):
    # Subscripts go in math mode: with text.usetex=True a bare "_" in a label
    # (e.g. "env_i_d") is not valid LaTeX and makes rendering fail.
    plt.plot(env_predictions[i, ...], label=[r"env $i_d$", r"env $i_q$"])
    plt.plot(model_predictions[i, ...], label=[r"model $i_d$", r"model $i_q$"])
    plt.plot(sim_ahead_predictions[i, ...], label=[r"sim-ahead $i_d$", r"sim-ahead $i_q$"])
    plt.xlabel("prediction step")
    plt.title(f"prediction window starting at step {i}")
    plt.legend()
    plt.show()

In [None]:
# Learned-model prediction for the first window (starting at rollout step 0).
model_predictions[0, ...]

In [None]:
# Display both prediction stacks. Previously only the last expression was shown and
# the first result was silently discarded.
jnp.stack(env_predictions), jnp.stack(model_predictions)

In [None]:
# Electrical speed stored in the current state (compare with the conversions nearby).
state.physical_state.omega_el

In [None]:
# Electrical angular speed for 4700 (presumably rpm) with p = 3 pole pairs.
# No trailing comma: `..., ` built a 1-tuple instead of the scalar value.
2 * jnp.pi * 3 * 4700 / 60

In [None]:
# Inverse conversion for p = 3: electrical speed -> speed (presumably rpm); gives ~4700.
1476.5486 * 60 / (2* jnp.pi * 3)

In [None]:
# Query the environment's L_dd interpolator at a single point.
# NOTE(review): if points are ordered (i_d, i_q) as for the local interpolators below,
# i_d = 200 lies outside the LUT range [-i_max, 0] and the value is extrapolated.
env.LUT_interpolators["L_dd"](jnp.array([200,100]))

In [None]:
# Inspect the environment's action description.
env.action_description

In [None]:
# Saturation-dependent quantities stored in the motor LUT.
saturated_quants = ["L_dd", "L_dq", "L_qd", "L_qq", "Psi_d", "Psi_q"]

i_max = 250

# Grid size taken from the first LUT (assumed shared by all; the last cell prints the shapes).
# Convention used here: rows follow i_q, columns follow i_d.
n_grid_points_y, n_grid_points_x = env.pmsm_lut[saturated_quants[0]].shape

x = np.linspace(-i_max, 0, n_grid_points_x)      # i_d axis
y = np.linspace(-i_max, i_max, n_grid_points_y)  # i_q axis

In [None]:
# Linear interpolators over the (i_d, i_q) grid; each LUT is transposed so its first
# axis matches `x`. fill_value=None extrapolates outside the grid instead of failing.
LUT_interpolators = {
    name: jax.scipy.interpolate.RegularGridInterpolator(
        (x, y),
        env.pmsm_lut[name].T,
        method="linear",
        bounds_error=False,
        fill_value=None,
    )
    for name in saturated_quants
}

In [None]:
# Print each saturated quantity's LUT entry at index [0, 0].
for q in saturated_quants:
    print(q, env.pmsm_lut[q][0,0])

In [None]:
# L_dd at LUT index [26, 0]; under the axis convention used for `x` above, column 0 is i_d = -i_max.
env.pmsm_lut["L_dd"][26, 0]

In [None]:
# L_dd at LUT index [26, -1]; under the axis convention used for `x` above, the last column is i_d = 0.
env.pmsm_lut["L_dd"][26, -1]

In [None]:
# L_qq at LUT index [26, -1]; under the axis convention used for `x` above, the last column is i_d = 0.
env.pmsm_lut["L_qq"][26, -1]

In [None]:
# L_qq at LUT index [26, 0]; under the axis convention used for `x` above, column 0 is i_d = -i_max.
env.pmsm_lut["L_qq"][26, 0]

In [None]:
n_points = n_grid_points_y * 1  # resolution factor 1: as many points per axis as LUT rows

test_x, test_y = np.linspace(-i_max, 0, n_points), np.linspace(-i_max, i_max, n_points)
xx, yy = jnp.meshgrid(test_x, test_y)
zz = jnp.stack([xx, yy], axis=-1)  # (n_points, n_points, 2) grid of (i_d, i_q) pairs


In [None]:
# Filled contours of the first four saturated quantities (L_dd, L_dq, L_qd, L_qq),
# evaluated with the local interpolators on the test grid.
for q in saturated_quants[:4]:
    res = LUT_interpolators[q](zz.reshape(-1, 2))

    fig, axs = plt.subplots(figsize=(6.75, 6))

    cax = axs.contourf(
        zz[..., 0],
        zz[..., 1],
        res.reshape(n_points, n_points),
        antialiased=False,
        levels=100,
        alpha=0.9,
        cmap=plt.cm.coolwarm,
    )
    # Without a title and colour scale the four figures were indistinguishable.
    fig.colorbar(cax, ax=axs)
    # Math mode keeps the underscore valid under text.usetex=True ("L_dd" -> $L_{dd}$).
    base, sub = q.split("_")
    axs.set_title(f"${base}_{{{sub}}}$")

    # NOTE(review): with text.usetex=True these labels must be valid LaTeX; confirm
    # obs_description contains no bare underscores.
    axs.set_xlabel(env.obs_description[0])
    axs.set_ylabel(env.obs_description[1])

In [None]:
# 3-D surfaces of the first four saturated quantities (L_dd, L_dq, L_qd, L_qq)
# on the test grid.
for q in saturated_quants[:4]:
    res = LUT_interpolators[q](zz.reshape(-1, 2))
    fig = plt.figure(figsize=(6, 6))
    axs = fig.add_subplot(111, projection="3d")

    _ = axs.plot_surface(
        zz[..., 0],
        zz[..., 1],
        res.reshape(n_points, n_points),
        antialiased=False,
        alpha=1,
        cmap=plt.cm.coolwarm,
    )

    axs.view_init(30,225)
    # Title so the four figures can be told apart; math mode keeps "_" valid under usetex.
    base, sub = q.split("_")
    axs.set_title(f"${base}_{{{sub}}}$")

    # NOTE(review): with text.usetex=True these labels must be valid LaTeX; confirm
    # obs_description contains no bare underscores.
    axs.set_xlabel(env.obs_description[0])
    axs.set_ylabel(env.obs_description[1])

In [None]:
# Print every LUT's shape (the grid above assumes they all match the first one).
for q in saturated_quants:
    print(env.pmsm_lut[q].shape)