In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import os
# Expose only CUDA device 0 to this process and turn off XLA's client-side memory preallocation.
# Both variables are set before `jax` is imported in a later cell — presumably required for
# them to take effect; confirm before reordering cells.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Global matplotlib styling for all figures in this notebook:
# render text with LaTeX, enlarge the base font size, and load the `bm` package in the LaTeX preamble.
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10 * 2.54})
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
# Make the first device reported by JAX the default device for all computations.
# NOTE(review): despite the name `gpus`, `jax.devices()` returns whatever the default backend
# offers, which may not be a GPU — confirm on the target machine.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])
# jax.config.update('jax_platform_name', 'cpu')
import chex

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter
from dmpe.algorithms.algorithm_utils import interact_and_observe
from dmpe.utils.metrics import JSDLoss

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth, build_grid
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance,
    plot_2d_kde_as_contourf, plot_2d_kde_as_surface, plot_feature_combinations
)
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

In [None]:
import pmsm_utils
from pmsm_utils import ExcitingPMSM, plot_current_constraints, PMSM_penalty, plot_sequence_with_constraints
from rls import SimulationPMSM_RLS

---

In [None]:
# Single (non-batched) PMSM environment used for the whole experiment.
batch_size = 1

env = ExcitingPMSM(
    initial_rpm=2_000,
    batch_size=batch_size,
    saturated=True,
    LUT_motor_name="BRUSA",
    # Static motor parameters handed to the environment as-is.
    static_params=dict(
        p=3,
        r_s=17.932e-3,
        l_d=0.37e-3,
        l_q=1.2e-3,
        psi_p=65.65e-3,
        deadtime=0,
    ),
    solver=diffrax.Tsit5(),
)

In [None]:
seed = 5232367

# Hyperparameters of the excitation algorithm. "target_distribution" starts out
# as None and is filled in once the constrained target density has been built.
alg_params = {
    "bandwidth": 0.08,
    "n_prediction_steps": 5,
    "points_per_dim": 21,
    "grid_extend": 1.05,
    # previously tried alternative: optax.lbfgs()
    "excitation_optimizer": optax.adabelief(1e-2),
    # previously tried alternative: 25
    "n_opt_steps": 100,
    "start_optimizing": 5,
    "consider_action_distribution": True,
    "penalty_function": lambda observations, actions: PMSM_penalty(env, observations, actions),
    "target_distribution": None,
    "clip_action": False,
    "n_starts": 3,
    "reuse_proposed_actions": True,
}

# Dimension of the space covered by the density grid:
# 4 when the action distribution is considered as well, otherwise 2.
dim = 4 if alg_params["consider_action_distribution"] else 2

# Parameters forwarded to the RLS model.
model_params = {"lambda_": 0.9}

# Complete experiment description handed to the excitation routine.
exp_params = {
    "seed": seed,
    "n_time_steps": 5_000,
    "alg_params": alg_params,
    "model_trainer_params": None,
    "model_params": model_params,
    "model_class": SimulationPMSM_RLS,
}

# Redesign the target distribution: fit a density estimate to exactly those
# grid points for which the PMSM penalty is zero.
x_g = build_grid(
    dim,
    low=-exp_params["alg_params"]["grid_extend"],
    high=exp_params["alg_params"]["grid_extend"],
    points_per_dim=exp_params["alg_params"]["points_per_dim"],
)


def constr_func(grid_point):
    """Return the PMSM penalty of one grid point.

    The first two entries are passed as the observation and the remaining
    entries as the action, each with an extra axis inserted before the last one.
    """
    return PMSM_penalty(env, grid_point[..., None, :2], grid_point[..., None, 2:])


# Boolean mask over the grid: True where the penalty vanishes.
valid_grid_point = jax.vmap(constr_func)(x_g) == 0
constrained_data_points = x_g[jnp.where(valid_grid_point)]

target_distribution = DensityEstimate.from_dataset(
    constrained_data_points[None],
    points_per_dim=alg_params["points_per_dim"],
    bandwidth=alg_params["bandwidth"],
)

# Hand the resulting density values to the algorithm parameters.
exp_params["alg_params"]["target_distribution"] = target_distribution.p[0]

# Derive the individual PRNG keys of the experiment from its seed.
key = jax.random.PRNGKey(exp_params["seed"])
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)

# Stream of fresh keys for generating the initial signals.
data_rng = PRNGSequence(data_key)

# Initial guess for the proposed actions: one `aprbs` signal per action
# dimension (each drawn with its own key), stacked horizontally.
_initial_signals = []
for _ in range(env.action_dim):
    _signal = aprbs(alg_params["n_prediction_steps"], env.batch_size, 1, 10, next(data_rng))
    _initial_signals.append(_signal[0])
proposed_actions = jnp.hstack(_initial_signals)

In [None]:
# Run the complete excitation experiment through the library routine.
(
    observations,
    actions,
    model,
    density_estimate,
    losses,
    proposed_actions,
) = dmpe.algorithms.excite_with_dmpe(
    env,
    exp_params,
    proposed_actions,
    loader_key,
    expl_key,
    plot_every=1000,
)

In [None]:
# Plot the recorded observation/action sequences using the pmsm_utils helper;
# the second return value is not needed here.
fig, _ = plot_sequence_with_constraints(env, observations, actions)

In [None]:
def get_constraint_violations(observations, actions, length):
    """Compute the per-sample constraint penalty for consecutive data segments.

    The data is cut into segments ``[length[i], length[i + 1])``. For each
    segment, ``PMSM_penalty`` is evaluated on the matching slices of
    ``observations`` and ``actions`` and divided by the number of samples in
    the segment.

    Args:
        observations: Sliceable sequence of observations.
        actions: Sliceable sequence of actions, sliced with the same indices.
        length: Ordered sequence of segment boundary indices. Previously this
            argument was ignored and a global ``lengths`` was read instead,
            which only worked because callers happened to set that global.

    Returns:
        A list with one entry per segment (``len(length) - 1`` entries; empty
        if fewer than two boundaries are given).

    Note:
        Uses the notebook-level ``env`` and ``PMSM_penalty``. Segments of zero
        width are not treated specially (the division is performed as-is).
    """
    boundaries = length
    return [
        PMSM_penalty(env, observations[start:stop], actions[start:stop]) / (stop - start)
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]

# Log-scale view of the per-sample constraint violations over the whole run,
# once with a fine and once with a coarse segmentation.
for _n_boundaries in (151, 25):
    lengths = jnp.linspace(0, 5000, _n_boundaries, dtype=jnp.int32)
    plt.plot(lengths[:-1], np.log(get_constraint_violations(observations, actions, lengths)))

plt.grid()
plt.show()

## 

# Linear-scale view of the same quantity, starting at sample 1000.
for _n_boundaries in (151, 25):
    lengths = jnp.linspace(1000, 5000, _n_boundaries, dtype=jnp.int32)
    plt.plot(lengths[:-1], get_constraint_violations(observations, actions, lengths))

plt.grid()
plt.ylim(-0.1, None)

In [None]:
from pmsm_utils import plot_constraints_induced_voltage

In [None]:
# A freshly reset state supplies the electrical angular velocity for the plot.
_, state = env.reset(env.env_properties)

# Convert the normalized current observations back to physical values.
_normalizations = env.env_properties.physical_normalizations
i_d_normalizer, i_q_normalizer = _normalizations.i_d, _normalizations.i_q
physical_i_d = i_d_normalizer.denormalize(observations[..., 0])
physical_i_q = i_q_normalizer.denormalize(observations[..., 1])

plot_constraints_induced_voltage(
    env,
    physical_i_d,
    physical_i_q,
    w_el=state.physical_state.omega_el,
    saturated=True,
)

In [None]:
# Feature-combination plot in "contourf" mode. The last observation is dropped
# before the observations are concatenated with the actions.
_features = jnp.concatenate((observations[:-1], actions), axis=-1)
fig = plot_feature_combinations(
    _features,
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"],
    mode="contourf",
    bandwidth=0.05,
)
plt.show()
# plt.savefig("results/plots/focus_both.pdf") if alg_params["consider_action_distribution"] else plt.savefig("results/plots/focus_obs.pdf")

In [None]:
# Feature-combination plot with the plotting helper's default settings. The last
# observation is dropped before the observations are concatenated with the actions.
_features = jnp.concatenate((observations[:-1], actions), axis=-1)
fig = plot_feature_combinations(
    _features,
    labels=["$\\tilde{i}_d$", "$\\tilde{i}_q$", "$\\tilde{u}_d$", "$\\tilde{u}_q$"],
);
# plt.savefig("results/plots/focus_both.pdf") if alg_params["consider_action_distribution"] else plt.savefig("results/plots/focus_obs.pdf")

In [None]:
# A bare `raise` with no active exception fails with a RuntimeError, so "Run All" stops here.
# NOTE(review): presumably a deliberate guard so the manual step-by-step loop in the
# following cells only runs when executed by hand — confirm before removing.
raise

In [None]:
# Start the manual experiment from a freshly reset environment.
obs, state = env.reset(env.env_properties)
dim_obs_space = obs.shape[0]
dim_action_space = env.action_dim

# Total number of points in the density-estimate grid.
n_grid_points = exp_params["alg_params"]["points_per_dim"] ** dim

# Pre-allocated buffers (this overwrites the results of the library run above):
# the initial observation goes into the first row, and there is one action row
# fewer than observation rows.
observations = jnp.zeros((exp_params["n_time_steps"], dim_obs_space)).at[0].set(obs)
actions = jnp.zeros((exp_params["n_time_steps"] - 1, dim_action_space))

# Exciter for the manual loop, configured from exp_params["alg_params"].
# The optimizer comes from the "excitation_optimizer" entry: alg_params defines
# no "action_lr" key, so the former lookup of that key failed with a KeyError.
exciter = Exciter(
    loss_function=loss_function,
    # Value and gradient of the loss w.r.t. its positional argument 3.
    grad_loss_function=jax.value_and_grad(loss_function, argnums=3),
    excitation_optimizer=exp_params["alg_params"]["excitation_optimizer"],
    tau=env.tau,
    n_opt_steps=exp_params["alg_params"]["n_opt_steps"],
    consider_action_distribution=exp_params["alg_params"]["consider_action_distribution"],
    target_distribution=exp_params["alg_params"]["target_distribution"],
    penalty_function=exp_params["alg_params"]["penalty_function"],
    clip_action=exp_params["alg_params"]["clip_action"],
    n_starts=exp_params["alg_params"]["n_starts"],
    reuse_proposed_actions=exp_params["alg_params"]["reuse_proposed_actions"],
)


# Fresh RLS model for the manual loop.
model = SimulationPMSM_RLS(lambda_=model_params["lambda_"])

# Empty density estimate (all-zero density, zero observations) on the same
# grid specification as used for the target distribution.
_alg_cfg = exp_params["alg_params"]
_estimate_grid = build_grid(
    dim,
    low=-_alg_cfg["grid_extend"],
    high=_alg_cfg["grid_extend"],
    points_per_dim=_alg_cfg["points_per_dim"],
)
density_estimate = DensityEstimate(
    p=jnp.zeros([n_grid_points, 1]),
    x_g=_estimate_grid,
    bandwidth=jnp.array([_alg_cfg["bandwidth"]]),
    n_observations=jnp.array([0]),
)

In [None]:
# Manual step-by-step version of the excitation experiment: in every step an
# action is chosen, applied to the environment, the RLS model is updated and
# the losses are recorded.
prediction_losses = []
data_losses = []

# NOTE(review): the loop runs n_time_steps iterations, while `actions` was allocated
# with n_time_steps - 1 rows — confirm how `interact_and_observe` handles the last index.
for k in tqdm(range(exp_params["n_time_steps"])):

    # NOTE(review): alg_params defines start_optimizing=5, but optimization starts
    # here as soon as k > 1 — confirm which threshold is intended.
    if k > 1:
        action, proposed_actions, density_estimate, prediction_loss, expl_key = exciter.choose_action(
            obs=obs,
            state=state,
            model=model,
            density_estimate=density_estimate,
            proposed_actions=proposed_actions,
            expl_key=expl_key,
        )
        
    else:
        # First steps: apply the first proposed action without optimizing.
        action = proposed_actions[0]
        prediction_loss=0.0
    prediction_losses.append(prediction_loss)
    # raise
    
    # Apply the action; the helper returns the next observation and state together
    # with the updated action/observation buffers.
    next_obs, state, actions, observations = interact_and_observe(
        env=env, k=jnp.array([k]), action=action, state=state, actions=actions, observations=observations
    )

    ## update RLS model
    # Regressor: current observation, applied action and a constant 1, with a trailing axis added.
    rls_in = jnp.concatenate([jnp.squeeze(obs), action, jnp.ones(1)])[..., None]
    model = SimulationPMSM_RLS.update(model, x=rls_in, d=next_obs[..., None])

    obs = next_obs

    # JSDLoss between the density estimate and the target distribution,
    # each divided by its own sum beforehand.
    data_loss = JSDLoss(
        density_estimate.p / jnp.sum(density_estimate.p),
        exciter.target_distribution / jnp.sum(exciter.target_distribution),
    )
    data_losses.append(data_loss)

    # Every 1000 steps (except k == 0): print the latest losses and plot the
    # sequence with the model prediction plus the log of the data loss so far.
    if k % 1_000 == 0 and k > 0:
        print("last input opt loss:", prediction_losses[-1])
        print("current data loss:", data_loss)
        fig, axs = plot_sequence_and_prediction(
            observations=observations[: k + 2, :],
            actions=actions[: k + 1, :],
            tau=exciter.tau,
            obs_labels=env.obs_description,
            actions_labels=env.action_description,
            model=model,
            init_obs=obs,
            init_state=state,
            proposed_actions=proposed_actions,
        )
        plt.show()

        plt.plot(np.log(data_losses))
        plt.grid(True)
        plt.show()