In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
# Process-level configuration. The environment variables are written before
# `jax` is imported in a later cell so that they are in place when the backend starts.
import os
# Restrict CUDA to the device with index 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Ask XLA not to preallocate GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
from functools import partial

import time
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render all figure text through LaTeX (requires a working LaTeX installation).
mpl.rcParams['text.usetex'] = True
# Global font size used by every figure in this notebook.
mpl.rcParams.update({'font.size': 10 * 2.54})
# Load the `bm` package into the LaTeX preamble used for figure text.
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}"
import plotly.express as px
import plotly.graph_objects as go

In [None]:
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
# `jax.devices()` lists the devices of the default backend; the first one is
# made the default device for all computations.
# NOTE(review): the name `gpus` assumes a GPU backend was found — confirm.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])
# jax.config.update('jax_platform_name', 'cpu')
import chex

import diffrax
import equinox as eqx
import optax

from haiku import PRNGSequence

In [None]:
import exciting_environments as excenvs

import dmpe
from dmpe.models import NeuralEulerODEPendulum, NeuralODEPendulum, NeuralEulerODE, NeuralEulerODECartpole
from dmpe.models.model_utils import simulate_ahead_with_env
from dmpe.models.model_training import ModelTrainer
from dmpe.excitation import loss_function, Exciter

from dmpe.utils.density_estimation import (
    update_density_estimate_single_observation, update_density_estimate_multiple_observations, DensityEstimate, select_bandwidth, build_grid
)
from dmpe.utils.signals import aprbs
from dmpe.evaluation.plotting_utils import (
    plot_sequence, append_predictions_to_sequence_plot, plot_sequence_and_prediction, plot_model_performance,
    plot_2d_kde_as_contourf, plot_2d_kde_as_surface, plot_feature_combinations
)
from dmpe.evaluation.experiment_utils import (
    get_experiment_ids, load_experiment_results, quick_eval, evaluate_experiment_metrics, evaluate_algorithm_metrics, evaluate_metrics
)

---

In [None]:
# setup PRNG
# Fixed seed so the randomly generated excitation signals are reproducible.
key = jax.random.PRNGKey(seed=222)

# Split the root key into four purpose-specific keys plus a fresh root key.
# Only `data_key` is consumed in the cells visible here.
data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
# Iterator-style key source: each `next(data_rng)` yields a new subkey.
data_rng = PRNGSequence(data_key)

In [None]:
import pmsm_utils
from pmsm_utils import ExcitingPMSM, plot_current_constraints

In [None]:
# Number of parallel environments; also used when generating the action signals.
batch_size=1

# PMSM environment, integrated with an explicit Euler solver.
# NOTE(review): `l_d`, `l_q` and `psi_p` are NaN — presumably placeholders that
# are replaced by the "BRUSA" look-up tables because `saturated=True`; confirm
# against `ExcitingPMSM`.
env = ExcitingPMSM(
    batch_size=batch_size,
    saturated=True,
    LUT_motor_name="BRUSA",
    static_params = {
        "p": 3,
        "r_s": 15e-3,
        "l_d": jnp.nan,
        "l_q": jnp.nan,
        "psi_p": jnp.nan,
        "deadtime": 0,
    },
    solver=diffrax.Euler()
)

In [None]:
# Roll the environment forward under a random excitation and log the first
# two observation components at every step (including the reset observation).
obs, state = env.vmap_reset()

n_steps = 99

# One independently drawn signal per action channel; the two PRNG keys are
# consumed in the same order as the channels are concatenated.
first_channel = aprbs(n_steps, batch_size, 1, 10, next(data_rng))
second_channel = aprbs(n_steps, batch_size, 1, 10, next(data_rng))
actions = jnp.concatenate([first_channel, second_channel], axis=-1)

observations = [obs[..., 0:2]]

for i in range(actions.shape[1]):
    # Apply the i-th action of every batch element, then log the new observation.
    obs, state = env.vmap_step(state, actions[:, i, :])
    observations.append(obs[..., 0:2])

In [None]:
# Plot the logged rollout: `np.vstack` stacks the per-step observation list and
# the batched action array into 2-D arrays before they are handed to the plotter.
plot_sequence(
    np.vstack(observations),
    np.vstack(actions),
    env.tau,
    env.obs_description[:2],
    env.action_description
)

In [None]:
class RLS(eqx.Module):
    """Recursive least squares (RLS) estimator with exponential forgetting.

    Based on the description given in [Brosch2021] and [Jakobeit2025].

    The module itself only stores hyperparameters. The quantities that change
    over time (weights and covariance) live in a separate ``RLS.State`` object
    that is passed in and returned by ``update``, which keeps the estimator
    purely functional and therefore compatible with ``jit``/``vmap``.

    Attributes:
        num_coefficients: Number of regression coefficients, i.e. the length
            of the regressor vector ``x``.
        lambda_: Forgetting factor. Values below 1 discount older samples.
    """

    num_coefficients: int
    lambda_: float

    class State(eqx.Module):
        """Time-varying part of the estimator.

        Attributes:
            w: Weight vector of shape ``(num_coefficients,)``.
            P: Covariance-like matrix of shape
                ``(num_coefficients, num_coefficients)``.
        """

        w: jax.Array
        P: jax.Array

    def __init__(self, num_coefficients, lambda_):
        self.num_coefficients = num_coefficients
        self.lambda_ = lambda_

    def reset(self):
        """Return the initial estimator state (zero weights, identity ``P``)."""
        return self.State(
            w=jnp.zeros((self.num_coefficients)),
            P=jnp.eye(self.num_coefficients)
        )

    @eqx.filter_jit
    def __call__(self, rls_state, x):
        """Predict the output for the regressor ``x`` using the weights in ``rls_state``."""
        y_pred = x @ rls_state.w
        return y_pred

    @eqx.filter_jit
    def update(self, rls_state, x, d):
        """Perform one RLS step and return the updated state.

        Args:
            rls_state: Current ``RLS.State``.
            x: Regressor vector of shape ``(num_coefficients,)``.
            d: Desired (measured) output for ``x``; a scalar or an array with a
                single element.

        Returns:
            A new ``RLS.State`` with the same shapes as the one from ``reset``.
        """
        P = rls_state.P
        w = rls_state.w

        # Gain vector: how strongly the prediction error corrects each weight.
        c = (P @ x) / (self.lambda_ + jnp.squeeze(x @ P @ x))
        # Correct the weights with the a-priori prediction error.
        w_new = w + c * jnp.squeeze(d - x @ w)
        # The covariance update needs the outer product c x^T. For 1-D vectors
        # `c @ x` would be the inner product (a scalar) and yield a wrong P.
        P_new = (jnp.eye(self.num_coefficients) - jnp.outer(c, x)) @ P / self.lambda_

        # Return the updated quantities, not a freshly reset state.
        return self.State(w=w_new, P=P_new)

In [None]:
# Estimator with five coefficients and a forgetting factor of 0.99.
rls = RLS(5, 0.99)

In [None]:
# Create two independent initial states by vmapping the argument-free `reset`
# over a batch axis of size 2, then display the result.
rls_state = eqx.filter_vmap(rls.reset, axis_size=2)()
rls_state

In [None]:
# Batched update: the state fields `w` and `P` and the targets `d` are mapped
# over their leading axis, while the regressor `x` (in_axes=None) is shared.
rls_state = eqx.filter_vmap(rls.update, in_axes=(RLS.State(w=0, P=0), None, 0))(rls_state, jnp.zeros((5)), jnp.zeros((2, 1)))

In [None]:
# Display the batched state returned by the update above.
rls_state

---
---
---

In [None]:
# NOTE: this rebinds the name `RLS`, shadowing the class defined earlier in this
# notebook; every cell below uses the implementations from the `rls` module.
from rls import RLS, PMSM_RLS, SimulationPMSM_RLS

In [None]:
# Uses the `RLS` imported from the `rls` module, whose constructor takes three
# arguments. NOTE(review): presumably (inputs, outputs, forgetting factor),
# judging by the shapes passed to `update` below — confirm in `rls.py`.
rls = RLS(5, 2, 0.99)

In [None]:
# The imported `RLS.update` is called on the class with the estimator passed in
# explicitly, and the returned estimator replaces the old one.
# Smoke test with an all-zero column regressor (5, 1) and target (2, 1).
rls = RLS.update(
    rls=rls,
    x=jnp.zeros((5, 1)),
    d=jnp.zeros((2, 1))
)

In [None]:
# Inspect the shape of the weight array after the update.
rls.w.shape

In [None]:
# Prediction for an all-zero column regressor of shape (5, 1).
RLS.predict(rls, jnp.zeros((5, 1)))

In [None]:
# Load previously recorded observation/action sequences from disk.
# NOTE: this overwrites the `observations` and `actions` created by the rollout
# cell further up; paths are relative to the notebook's working directory.
observations = jnp.array(np.load("results/pmsm_dmpe_observations.npy"))
actions = jnp.array(np.load("results/pmsm_dmpe_actions.npy"))

In [None]:
# Plot the first 500 samples of the recorded data set.
plot_sequence(
    observations[:500],
    actions[:500],
    env.tau,
    env.obs_description[:2],
    env.action_description
)

In [None]:
# predict what you just have learned
#
# For every recorded step the estimator is
#   1. updated with the transition (observations[i-1], actions[i-1]) -> observations[i],
#   2. asked to reproduce observations[i] from the same input ("hindsight"),
#   3. asked to predict observations[i+1] from (observations[i], actions[i]).

lambda_ = 0.9
print("lambda:", lambda_)
pmsm_rls = PMSM_RLS(lambda_=lambda_)

hindsight_errors = []
pred_errors = []
predictions = []

# The loop reads observations[i + 1] and actions[i]. JAX clamps out-of-range
# indices instead of raising, so the range is capped to the available data to
# avoid silently reusing the last sample on short recordings.
last_step = min(500, observations.shape[0] - 1, actions.shape[0])

for i in range(1, last_step):
    # learn from last
    rls_in = jnp.concatenate([observations[i-1], actions[i-1, :], jnp.ones(1)])[..., None]
    pmsm_rls = PMSM_RLS.update(pmsm_rls, x=rls_in, d=observations[i][..., None])

    # predict current in hindsight (same regressor that was just learned from)
    current_pred = pmsm_rls(rls_in)
    hindsight_errors.append(current_pred - observations[i][..., None])

    # predict next
    next_pred = pmsm_rls(jnp.concatenate([observations[i], actions[i, :], jnp.ones(1)])[..., None])
    pred_errors.append(next_pred - observations[i+1][..., None])

    predictions.append(next_pred)


# Stack the per-step column predictions into one row per step.
predictions = jnp.hstack(predictions).T

hindsight_errors = jnp.squeeze(jnp.stack(hindsight_errors))
plt.plot(jnp.linalg.norm(hindsight_errors, axis=-1))
plt.xlabel("step")
plt.ylabel("hindsight error norm")
plt.ylim(0, 1)
plt.show()

pred_errors = jnp.squeeze(jnp.stack(pred_errors))
plt.plot(jnp.linalg.norm(pred_errors, axis=-1))
plt.xlabel("step")
plt.ylabel("prediction error norm")
plt.ylim(0, 1)
plt.show()

In [None]:
# Align the predictions with the observation indices for plotting. The
# prediction stored in loop iteration i estimates observations[i + 1], and the
# loop starts at i = 1, so the first stored prediction belongs to sample 2.
# Two placeholder rows are therefore required; a single row would leave every
# prediction shifted one sample ahead of the measurement it is compared with.
# NOTE: not idempotent — re-run the cell above before re-running this one.
predictions = jnp.concatenate([jnp.zeros((2, 2)), predictions])

In [None]:
# One figure per observation component: measured signal versus its estimate,
# restricted to the first 500 samples.
component_labels = (
    (r"$i_d$", r"$\hat{i}_d$"),
    (r"$i_q$", r"$\hat{i}_q$"),
)

for component, (measured_label, estimated_label) in enumerate(component_labels):
    plt.plot(observations[:500, component], label=measured_label)
    plt.plot(predictions[:500, component], label=estimated_label)
    plt.grid()
    plt.legend()
    plt.show()

# Based on running data:

In [None]:
# setup PRNG
# Online experiment: the estimator predicts the next observation before each
# environment step and is updated with the true outcome afterwards.
key = jax.random.PRNGKey(seed=5555)

data_key, model_key, loader_key, expl_key, key = jax.random.split(key, 5)
data_rng = PRNGSequence(data_key)

obs, state = env.vmap_reset()

n_steps = 1000
# actions = jnp.concatenate([aprbs(n_steps, batch_size, 1, 10, next(data_rng)), aprbs(n_steps, batch_size, 1, 10, next(data_rng))], axis=-1)

# Uniformly distributed random actions in [-1, 1) for a single batch element.
actions = jax.random.uniform(next(data_rng), shape=(1, n_steps, 2), minval=-1, maxval=1)

# Index 0 holds the reset observation and a zero placeholder prediction, so that
# predictions[j] and observations[j] refer to the same sample.
observations = [obs[0, 0:2]]
predictions = [jnp.zeros(2)]

pmsm_rls = PMSM_RLS(lambda_=0.9)

for i in range(actions.shape[1]):

    # Regressor: observation of batch element 0, current action and a constant 1.
    # NOTE(review): the full observation is used here while only its first two
    # components are logged below — confirm that this width is what PMSM_RLS expects.
    rls_in = jnp.concatenate([obs[0,...], actions[0, i, :], jnp.ones(1)])[..., None]
    # Predict before stepping, i.e. without knowledge of the next observation.
    prediction = pmsm_rls(rls_in)
    
    next_obs, state = env.vmap_step(state, actions[:, i,:])
    observations.append(next_obs[0,0:2])
    predictions.append(prediction[..., 0])
    
    # Learn from the transition that was just observed.
    # NOTE(review): `next_obs.T` only forms a column vector for batch size 1.
    pmsm_rls = PMSM_RLS.update(pmsm_rls, x=rls_in, d=next_obs.T)
    obs = next_obs

observations = jnp.stack(observations)
predictions = jnp.stack(predictions)

In [None]:
# Time axis: one entry per logged sample, spaced by the environment step size.
num_samples = observations.shape[0]
t = jnp.linspace(0, num_samples - 1, num_samples) * env.tau

# One figure per observation component: measured signal versus its estimate.
# Only the second figure gets a fixed y-range.
component_labels = (
    (r"$i_d$", r"$\hat{i}_d$"),
    (r"$i_q$", r"$\hat{i}_q$"),
)

for component, (measured_label, estimated_label) in enumerate(component_labels):
    plt.plot(t, observations[..., component], label=measured_label)
    plt.plot(t, predictions[..., component], label=estimated_label)
    plt.grid()
    plt.legend()
    if component == 1:
        plt.ylim(-1, 1)
    plt.show()

In [None]:
# Train a SimulationPMSM_RLS on the first k steps of the action sequence from the
# cell above, using the non-vmapped environment interface.
# NOTE: rebinds `rls`, which previously held an `RLS` instance.
rls = SimulationPMSM_RLS(lambda_=0.9)

obs, state = env.reset(env.env_properties)

# Number of training steps; the multi-step prediction below starts at step k.
k = 200

for i in range(k):
    # Regressor: current observation, current action and a constant 1.
    rls_in = jnp.concatenate([obs, actions[0, i, :], jnp.ones(1)])[..., None]

    next_obs, state = env.step(state, actions[0, i,:], env.env_properties)

    # Learn from the transition that was just observed.
    rls = SimulationPMSM_RLS.update(rls, x=rls_in, d=next_obs[..., None])
    obs = next_obs

In [None]:
# Display the trained estimator.
rls

In [None]:
# Multi-step prediction: starting from the last training observation, the
# estimator is called with the next `sequence_length` actions.
# NOTE: overwrites `predictions` from the online experiment above.
sequence_length = 20
predictions = rls(obs, actions[0, k:k+sequence_length])

In [None]:
# Time axis covering the start sample plus `sequence_length` predicted steps.
t = jnp.linspace(0, sequence_length, sequence_length + 1) * env.tau

# Compare the logged observations from step k onwards with the multi-step
# prediction, one figure per component. Only the second figure gets a fixed
# y-range.
horizon = slice(k, k + sequence_length + 1)
component_labels = (
    (r"$i_d$", r"$\hat{i}_d$"),
    (r"$i_q$", r"$\hat{i}_q$"),
)

for component, (measured_label, estimated_label) in enumerate(component_labels):
    plt.plot(t, observations[horizon, component], label=measured_label)
    plt.plot(t, predictions[..., component], label=estimated_label)
    plt.grid()
    plt.legend()
    if component == 1:
        plt.ylim(-1, 1)
    plt.show()