In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
# Process environment, set here ahead of the `import jax` further down.
# NOTE(review): XLA_PYTHON_CLIENT_PREALLOCATE="false" is understood to turn off
# JAX's up-front GPU memory preallocation -- confirm against the JAX docs.
os.environ.update(
    CUDA_VISIBLE_DEVICES="1",
    XLA_PYTHON_CLIENT_PREALLOCATE="false",
)

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Global matplotlib styling: render text through LaTeX, load the `bm` package
# in the LaTeX preamble, and use a base font size of 10 * 2.54.
mpl.rcParams.update(
    {
        'text.usetex': True,
        'font.size': 10 * 2.54,
        'text.latex.preamble': r"\usepackage{bm}",
    }
)
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# Optional switch (currently disabled): enable 64-bit mode in JAX.
# jax.config.update("jax_enable_x64", True)
# Make the first device reported by JAX the default device.
# NOTE(review): `gpus` holds whatever `jax.devices()` returns; it is only a
# list of GPUs when a GPU backend is actually active -- confirm on the target machine.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])
# Optional switch (currently disabled): force the CPU platform instead.
# jax.config.update('jax_platform_name', 'cpu')
import chex

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from dmpe.models.models import NeuralEulerODEPMSM
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth, build_grid
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance,
    plot_2d_kde_as_contourf, plot_2d_kde_as_surface, plot_feature_combinations
)
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)
from dmpe.algorithms import excite_with_dmpe, default_dmpe, default_dmpe_parameterization

In [None]:
import pmsm_utils
from pmsm_utils import ExcitingPMSM, plot_current_constraints, plot_sequence_with_constraints

---

In [None]:
batch_size = 1

# Values forwarded to the environment as `static_params`.
pmsm_static_params = {
    "p": 3,
    "r_s": 17.932e-3,
    "l_d": 0.37e-3,
    "l_q": 1.2e-3,
    "psi_p": 65.65e-3,
    "deadtime": 0,
}

# PMSM environment used throughout the notebook, integrated with Tsit5.
env = ExcitingPMSM(
    initial_rpm=2_000,
    batch_size=batch_size,
    saturated=True,
    LUT_motor_name="BRUSA",
    static_params=pmsm_static_params,
    solver=diffrax.Tsit5(),
)


def PMSM_penalty(observations, actions):
    """Evaluate `pmsm_utils.PMSM_penalty` for the notebook-level `env`.

    `env` is looked up when the function is called, so the penalty always
    refers to the environment currently bound to that name.
    """
    return pmsm_utils.PMSM_penalty(env, observations, actions)

In [None]:
seed = 41241

# Settings of the excitation algorithm. "bandwidth" starts as a NaN placeholder
# and is assigned further down in this cell.
alg_params = {
    "bandwidth": jnp.nan,
    "n_prediction_steps": 5,
    "points_per_dim": 21,
    "grid_extend": 1.05,
    # "excitation_optimizer": optax.lbfgs(),
    "excitation_optimizer": optax.adabelief(1e-2),
    "n_opt_steps": 100,
    # "n_opt_steps": 25,
    "start_optimizing": 5,
    "consider_action_distribution": True,
    "penalty_function": PMSM_penalty,
    "target_distribution": None,
    "clip_action": False,
    "n_starts": 3,
    "reuse_proposed_actions": True,
}

# Number of feature dimensions: 4 when the action distribution is considered, else 2.
if alg_params["consider_action_distribution"]:
    dim = 4
else:
    dim = 2

# Disabled alternative: hand-built target distribution.
# points_per_dim = alg_params["points_per_dim"]
# target_distribution = (np.ones(shape=[points_per_dim for _ in range(dim)]) ** dim)[..., None]
# xx, yy = np.meshgrid(np.linspace(-1, 0, points_per_dim), np.linspace(-1, 1, points_per_dim))
# target_distribution[xx**2 + yy**2 > 1] = 0
# target_distribution = target_distribution / jnp.sum(target_distribution)
# alg_params["target_distribution"] = jnp.array(target_distribution.reshape(-1, 1))

# Bandwidth suggested by `select_bandwidth` for this grid resolution.
suggested_bandwidth = select_bandwidth(
    delta_x=2,
    dim=dim,
    n_g=alg_params["points_per_dim"],
    percentage=0.3,
)
alg_params["bandwidth"] = float(suggested_bandwidth)

# NOTE(review): the suggested value above is replaced right away by a fixed
# bandwidth; only 0.08 is ever used downstream.
alg_params["bandwidth"] = 0.08
print("bw", alg_params["bandwidth"])

####################################################################################################


# model_trainer_params = dict(
#     start_learning=alg_params["n_prediction_steps"],
#     training_batch_size=64,
#     n_train_steps=1,
#     sequence_length=alg_params["n_prediction_steps"],
#     featurize=lambda x: x,
#     model_lr=1e-4,
# )
# model_params = dict(obs_dim=2, action_dim=env.action_dim, width_size=32, depth=4, key=None)

# Model-training settings; the start of learning and the training sequence
# length both follow the algorithm's prediction horizon.
model_trainer_params = {
    "start_learning": alg_params["n_prediction_steps"],
    "training_batch_size": 128,
    "n_train_steps": 5,
    "sequence_length": alg_params["n_prediction_steps"],
    "featurize": lambda x: x,  # identity: observations are used as-is
    "model_lr": 1e-4,
}

# Constructor arguments of the model; "key" is a placeholder at this point.
model_params = {
    "obs_dim": 2,
    "action_dim": env.action_dim,
    "width_size": 128,
    "depth": 4,
    "key": None,
}

####################################################################################################

# Everything describing one experiment, bundled into a single dictionary.
exp_params = {
    "seed": seed,
    "n_time_steps": 5_000,
    "model_class": NeuralEulerODEPMSM,
    "env_params": None,
    "alg_params": alg_params,
    "model_trainer_params": model_trainer_params,
    "model_params": model_params,
}

# redesign target distribution
# Grid over the feature space with symmetric limits +/- grid_extend.
grid_limit = exp_params["alg_params"]["grid_extend"]
x_g = build_grid(
    dim,
    low=-grid_limit,
    high=grid_limit,
    points_per_dim=exp_params["alg_params"]["points_per_dim"],
)


def constr_func(grid_point):
    """Penalty of a single grid point.

    The first two entries are passed to `PMSM_penalty` as observations and
    the remaining entries as actions, each with an extra axis inserted in
    front of the feature axis.
    """
    return PMSM_penalty(grid_point[..., None, :2], grid_point[..., None, 2:])


# A grid point counts as valid when its penalty is exactly zero.
valid_grid_point = jax.vmap(constr_func, in_axes=0)(x_g) == 0
# `valid_grid_point` is already boolean, so it is used as the mask directly.
constrained_data_points = x_g[jnp.where(valid_grid_point)]

# Density estimate of the valid grid points; serves as the target distribution.
target_distribution = DensityEstimate.from_dataset(
    constrained_data_points[None],
    points_per_dim=alg_params["points_per_dim"],
    bandwidth=alg_params["bandwidth"],
)

exp_params["alg_params"]["target_distribution"] = target_distribution.p[0]

# setup prng
root_key = jax.random.PRNGKey(seed=exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(root_key, 5)

data_rng = PRNGSequence(data_key)
exp_params["model_params"]["key"] = model_key

# initial guess
# One APRBS signal per action dimension, each drawn with a fresh key from
# `data_rng`, stacked horizontally.
initial_action_channels = []
for action_channel in range(env.action_dim):
    channel_signal = aprbs(
        alg_params["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng)
    )[0]
    initial_action_channels.append(channel_signal)
proposed_actions = jnp.hstack(initial_action_channels)

# Run the excitation algorithm and unpack its six return values.
dmpe_results = excite_with_dmpe(
    env, exp_params, proposed_actions, loader_key, expl_key, plot_every=1000
)
observations, actions, model, density_estimate, losses, proposed_actions = dmpe_results

In [None]:
# Plot the observation/action sequences of the run via
# `plot_sequence_with_constraints`; only the first return value is kept (as `fig`).
fig, _ = plot_sequence_with_constraints(env, observations, actions)

In [None]:
def get_constraint_violations(observations, actions, length):
    """Average constraint penalty over consecutive index windows.

    Args:
        observations: Sequence of observations, sliced along its first axis.
        actions: Sequence of actions, sliced with the same index windows.
        length: Ordered window boundaries (e.g. from ``jnp.linspace``).
            Window ``i`` covers the indices from ``length[i]`` up to, but not
            including, ``length[i + 1]``.

    Returns:
        A list with one entry per window: ``PMSM_penalty`` evaluated on the
        window's observations and actions, divided by the window size.
    """
    # Read the boundaries from the argument. The body previously used a
    # notebook-global `lengths`, which made the result depend on kernel state.
    boundaries = length
    return [
        PMSM_penalty(observations[start:stop], actions[start:stop]) / (stop - start)
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]

# Log of the windowed constraint violations over the whole run, at a fine
# (151 boundaries) and a coarse (25 boundaries) resolution.
for n_boundaries in (151, 25):
    lengths = jnp.linspace(0, 5000, n_boundaries, dtype=jnp.int32)
    violations = get_constraint_violations(observations, actions, lengths)
    plt.plot(lengths[:-1], np.log(violations))

plt.grid()
plt.show()

##

# Same two resolutions on a linear scale, restricted to indices 1000..5000.
for n_boundaries in (151, 25):
    lengths = jnp.linspace(1000, 5000, n_boundaries, dtype=jnp.int32)
    violations = get_constraint_violations(observations, actions, lengths)
    plt.plot(lengths[:-1], violations)

plt.grid()
plt.ylim(-0.1, None)

In [None]:
from pmsm_utils import plot_constraints_induced_voltage

In [None]:
# A freshly reset state provides the electrical angular velocity used below.
_, state = env.reset(env.env_properties)

physical_normalizations = env.env_properties.physical_normalizations
i_d_normalizer = physical_normalizations.i_d
i_q_normalizer = physical_normalizations.i_q

# Denormalize the first two observation features with the i_d / i_q normalizers.
physical_i_d = i_d_normalizer.denormalize(observations[..., 0])
physical_i_q = i_q_normalizer.denormalize(observations[..., 1])

plot_constraints_induced_voltage(
    env,
    physical_i_d,
    physical_i_q,
    w_el=state.physical_state.omega_el,
    saturated=True,
)

In [None]:
# Join observations (all but the last) with the actions along the feature axis
# and draw the pairwise feature combinations as filled contours.
excitation_features = jnp.concatenate([observations[0:-1], actions], axis=-1)
feature_labels = ["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"]

fig = plot_feature_combinations(
    excitation_features,
    labels=feature_labels,
    mode="contourf",
    bandwidth=0.05,
)
plt.show()
# plt.savefig("results/plots/focus_both.pdf") if alg_params["consider_action_distribution"] else plt.savefig("results/plots/focus_obs.pdf")

In [None]:
# Same feature matrix as the contour plot, drawn with the plotting function's
# default mode and bandwidth.
obs_action_pairs = jnp.concatenate([observations[0:-1], actions], axis=-1)
pair_labels = ["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"]
fig = plot_feature_combinations(obs_action_pairs, labels=pair_labels)
# plt.savefig("results/plots/focus_both.pdf") if alg_params["consider_action_distribution"] else plt.savefig("results/plots/focus_obs.pdf")

In [None]:
# np.save("results/pmsm_dmpe_observations.npy", observations)
# np.save("results/pmsm_dmpe_actions.npy", actions)

In [None]:
# Deliberate stop so that "Run All" halts at this cell. A bare `raise` outside
# an exception handler also produces a RuntimeError, but with the unhelpful
# message "No active exception to reraise"; state the intent explicitly.
raise RuntimeError("Intentional stop: run the cells below manually.")

In [None]:
# Indices at which the physical d-current is positive.
print(np.where(physical_i_d > 0))

# Index window (start, stop) to zoom into.
plot_range = (2400, 2430)
window = slice(*plot_range)

# Physical currents and voltages inside the window, one column per quantity.
windowed_currents = jnp.vstack([physical_i_d[window], physical_i_q[window]]).T
action_normalizations = env.env_properties.action_normalizations
windowed_voltages = jnp.vstack(
    [
        action_normalizations.u_d.denormalize(actions[window, 0]),
        action_normalizations.u_q.denormalize(actions[window, 1]),
    ]
).T

fig, axs = plot_sequence(
    observations=windowed_currents,
    actions=windowed_voltages,
    tau=env.tau,
    obs_labels=["i_d", "i_q"],
    action_labels=["u_d", "u_q"],
)

plot_current_constraints(fig, axs[1], i_d_normalizer, i_q_normalizer)

plt.plot(physical_i_d[window], physical_i_q[window])

In [None]:
# Show the repr of the current `model` as this cell's output.
model

In [None]:
from dmpe.evaluation.plotting_utils import plot_model_performance

In [None]:
from dmpe.models.model_training import precompute_starting_points, load_single_batch, fit, ModelTrainer

In [None]:
# Trainer built from the experiment's settings, except that the training
# sequence length is fixed to 100 here.
trainer_cfg = exp_params["model_trainer_params"]
model_trainer = ModelTrainer(
    start_learning=trainer_cfg["start_learning"],
    training_batch_size=trainer_cfg["training_batch_size"],
    n_train_steps=trainer_cfg["n_train_steps"],
    sequence_length=100,
    featurize=trainer_cfg["featurize"],
    model_optimizer=optax.adabelief(trainer_cfg["model_lr"]),
    tau=env.tau,
)

loader_key = jax.random.key(0)

# NOTE: rebinding `model` replaces the model returned by the excitation run
# with a freshly constructed one.
model_class = exp_params["model_class"]
model = model_class(**exp_params["model_params"])
trainable_leaves = eqx.filter(model, eqx.is_inexact_array)
opt_state_model = model_trainer.model_optimizer.init(trainable_leaves)

# Fit on the full recorded data set for 5000 trainer calls.
for train_step in tqdm(range(5000)):
    model, opt_state_model, loader_key = model_trainer.fit(
        model=model,
        k=jnp.array([observations.shape[0]]),
        observations=observations,
        actions=actions,
        opt_state=opt_state_model,
        loader_key=loader_key,
    )

In [None]:
# NOTE: `batch_size` is rebound here; the environment cell near the top of the
# notebook uses the same name with the value 1.
batch_size = 10
sequence_length = 20

# Starting indices for a single batch of evaluation sequences.
starting_points, loader_key = precompute_starting_points(
    n_train_steps=1,
    k=observations.shape[0],
    sequence_length=sequence_length,
    training_batch_size=batch_size,
    loader_key=jax.random.key(1),
)

first_batch_starts = starting_points[0, ...]
batched_obs, batched_act = load_single_batch(
    observations, actions, first_batch_starts, sequence_length
)

# One model-performance figure per sequence in the batch.
for i in range(batched_obs.shape[0]):
    obs_sequence = batched_obs[i, ...]
    act_sequence = batched_act[i, ...]
    plot_model_performance(
        model, obs_sequence, act_sequence, env.tau, ["i_d", "i_q"], ["u_d", "u_q"]
    )
    plt.show()