In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import pathlib

import time
from functools import partial
import os
# These environment variables are read when JAX initializes its backend, so they must be set
# before `import jax` further down in this cell.
# Expose only CUDA device 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Disable JAX/XLA's default up-front GPU memory preallocation (allocate on demand instead).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from tqdm.notebook import tqdm
import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# Make the first device of the default backend the default placement for computations.
# NOTE(review): with CUDA_VISIBLE_DEVICES='1' this is presumably physical GPU 1; if no GPU is
# visible, jax.devices() falls back to the CPU backend despite the variable name.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
import dmpe

from dmpe.evaluation.experiment_utils import (
    load_experiment_results, get_experiment_ids, get_organized_experiment_ids
)
from dmpe.evaluation.metrics_utils import default_ae, default_jsd, default_mcudsa, default_ksfc

from dmpe.evaluation.plotting_utils import plot_sequence

from dmpe.utils.env_utils.pmsm_utils import plot_constraints_induced_voltage
from dmpe.utils.density_estimation import build_grid, DensityEstimate

from dmpe.algorithms.algorithm_utils import interact_and_observe

In [None]:
import eval_dmpe
from eval_dmpe import setup_env

In [None]:
def _filter_valid_points(data_points, penalty_function):

    valid_points_bool = jax.vmap(penalty_function, in_axes=0)(data_points) == 0
    return data_points[jnp.where(valid_points_bool == True)]

filter_valid_points = lambda observations: _filter_valid_points(observations, penalty_function=lambda x: penalty_function(x, None))

## random-walk:

In [None]:
# Random-walk experiment configuration.
seed = 0
rpm = 3000
n_time_steps = 15_000
n_tries = 4000  # candidate actions sampled per step

env, penalty_function = setup_env(rpm)

obs, state = env.reset(env.env_properties)
dim_obs_space = obs.shape[0]
dim_action_space = env.action_dim

# Preallocate the trajectory: row 0 of `observations` holds the initial observation and there is
# one action fewer than observations.
observations = jnp.zeros((n_time_steps, dim_obs_space)).at[0].set(obs)
actions = jnp.zeros((n_time_steps - 1, dim_action_space))

# Initial action drawn from a standard normal; `key` is carried forward for the walk.
key, action_key = jax.random.split(jax.random.key(seed))
action = jax.random.normal(action_key, shape=(dim_action_space,))

In [None]:
# @partial(jax.jit, static_argnums=(0,))
# def choose_action(env, proposed_actions, state, choice_key):
#     test_obs, test_state = jax.vmap(env.step, in_axes=(None, 0, None))(state, proposed_actions, env.env_properties)    
#     valid_points_bool = jax.vmap(penalty_function, in_axes=(0, 0))(test_obs, proposed_actions) == 0
#     prob_points = valid_points_bool.astype(jnp.float32) / jnp.sum(valid_points_bool)
#     return jax.random.choice(choice_key, proposed_actions, p=prob_points, axis=0)

In [None]:
@partial(jax.jit, static_argnums=(0, 1))
def choose_action(env, penalty_function, proposed_actions, state, choice_key):
    """Choose randomly among the proposed actions that keep the system within bounds for the next step.

    If none of the proposed actions keeps the system within bounds, apply the one that causes the
    least penalty.

    This is a heuristic implementation that uses an oracle (a one-step simulation with ``env.step``)
    to ensure compliance with the bounds, but otherwise chooses randomly among the actions.

    Args:
        env: Environment providing ``step(state, action, env_properties)``. Static, so it must be hashable.
        penalty_function: Callable ``(obs, action) -> penalty``; a penalty of 0 means within bounds.
            Static, so it must be hashable.
        proposed_actions: Candidate actions stacked along axis 0.
        state: Current environment state, shared by all candidates.
        choice_key: PRNG key for the random choice among the valid candidates.

    Returns:
        One row of ``proposed_actions``.
    """
    # Oracle: simulate one step for every candidate, starting from the same state.
    test_obs, _ = jax.vmap(env.step, in_axes=(None, 0, None))(state, proposed_actions, env.env_properties)
    penalty_values = jax.vmap(penalty_function, in_axes=(0, 0))(test_obs, proposed_actions)

    def least_penalty(key, candidates, penalties):
        """No candidate stays within bounds: apply the one with the least penalty."""
        return candidates[jnp.argmin(penalties)]

    def random_valid(key, candidates, penalties):
        """Some candidates stay within bounds: pick one of them uniformly at random."""
        is_valid = penalties == 0
        probabilities = is_valid.astype(jnp.float32) / jnp.sum(is_valid)
        return jax.random.choice(key, candidates, p=probabilities, axis=0)

    return jax.lax.cond(
        jnp.all(penalty_values != 0), least_penalty, random_valid, choice_key, proposed_actions, penalty_values
    )

In [None]:
# Random walk: at every step, perturb the previous action with Gaussian noise (n_tries candidates)
# and let `choose_action` pick one that keeps the system within bounds.
# `actions` was preallocated with n_time_steps - 1 rows (one per transition between consecutive
# observations), so only n_time_steps - 1 steps are taken. The former range(n_time_steps) made one
# extra step whose update falls outside the preallocated arrays.
# NOTE(review): assumes interact_and_observe stores the action at index k and the resulting
# observation at k + 1 — confirm against dmpe.algorithms.algorithm_utils.
for k in tqdm(range(n_time_steps - 1)):
    key, action_key, choice_key = jax.random.split(key, 3)
    proposed_actions = action + jax.random.normal(action_key, shape=(n_tries, env.action_dim))

    action = choose_action(env, penalty_function, proposed_actions, state, choice_key)

    next_obs, next_state, actions, observations = interact_and_observe(
        env=env, k=jnp.array([k]), action=action, state=state, actions=actions, observations=observations
    )

    state = next_state
    obs = next_obs

    # Periodically show the trajectory recorded so far.
    if k % 5000 == 0 and k > 0:
        plot_sequence(observations[:k], actions[:k], env.tau, env.obs_description[:2], env.action_description)
        plt.show()

In [None]:
# Plot the full random-walk trajectory; plt.show() renders it without echoing the (fig, axs) tuple.
fig, axs = plot_sequence(observations, actions, env.tau, env.obs_description[:2], env.action_description)
plt.show()

In [None]:
import json
import argparse
import datetime
import os

In [None]:
# Convert the recorded trajectory to plain lists so it can be written as JSON.
results = {"observations": observations.tolist(), "actions": actions.tolist()}

def safe_json_dump(obj, fp):
    """Write ``obj`` as JSON to the open text file ``fp``.

    Values the ``json`` module cannot encode are written as the placeholder string
    ``"<<non-serializable: TypeName>>"`` instead of raising ``TypeError``.
    """

    def _placeholder(value):
        return f"<<non-serializable: {type(value).__qualname__}>>"

    return json.dump(obj, fp, default=_placeholder)

file_name = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
# open(..., "w") does not create missing parent directories, so create the output folder first.
# NOTE(review): ':' in the timestamp is not a valid file-name character on Windows.
random_walk_dir = pathlib.Path("./results/heuristics/random_walk")
random_walk_dir.mkdir(parents=True, exist_ok=True)
with open(random_walk_dir / f"random_walk_rpm_{rpm}_{file_name}.json", "w") as fp:
    safe_json_dump(results, fp)

## FOC-PID:

In [None]:
from dmpe.utils.env_utils.foc_pid import ClassicController

In [None]:
@partial(jax.jit, static_argnums=(0, 1))
def run_pid_experiment(env, pid, references_norm, init_obs, init_state, init_pid_state):
    """Simulate the closed loop of ``pid`` and ``env`` over a sequence of references with ``lax.scan``.

    At each step the controller receives ``[obs, epsilon, reference]`` as a single-row batch, and its
    (squeezed) action is applied with ``env.step``.

    Args:
        env: Environment with ``step(state, action, env_properties)``. Static, so it must be hashable.
        pid: Controller callable ``pid(pid_obs, pid_state) -> (action, next_pid_state)``. Static, so it
            must be hashable.
        references_norm: References stacked along axis 0; one simulation step per row.
        init_obs: Initial observation.
        init_state: Initial environment state; must expose ``physical_state.epsilon``.
        init_pid_state: Initial controller state.

    Returns:
        ``(observations, actions)``, each stacked along axis 0 with one row per reference. Row ``t``
        of ``observations`` is the observation the action in row ``t`` was computed from.
    """

    def body_fun(carry, reference_norm):
        obs, state, pid_state = carry

        # Controller input as a single-row batch: [observation, epsilon, reference].
        currents_norm = obs[None]
        eps = state.physical_state.epsilon * jnp.ones((1, 1))
        pid_obs = jnp.concatenate([currents_norm, eps, reference_norm[None, ...]], axis=-1)

        action, next_pid_state = pid(pid_obs, pid_state)

        next_obs, next_state = env.step(state, jnp.squeeze(action), env.env_properties)
        # Emit obs and action as separate scan outputs, so they need not share shape or dtype.
        return (next_obs, next_state, next_pid_state), (jnp.squeeze(obs), jnp.squeeze(action))

    _, (observations, actions) = jax.lax.scan(
        body_fun, (init_obs, init_state, init_pid_state), references_norm
    )
    return observations, actions

In [None]:
# references = build_grid(2, low=-0.9, high=0.9, points_per_dim=20)
# references = jnp.flip(references, axis=1)

# `filter_valid_points` reads the module-level `penalty_function`, so refresh it from an env at 0 rpm.
_, penalty_function = setup_env(0)

points_per_dim = 100

# Grid over the normalized current plane, flipped along both axes.
references = build_grid(2, low=-0.95, high=0.95, points_per_dim=points_per_dim)
references = jnp.flip(jnp.flip(references, axis=1), axis=0)

# Snake ordering: reverse every even-indexed row of the (points_per_dim x points_per_dim) grid.
references = np.array(references).reshape(points_per_dim, points_per_dim, 2)
references[::2] = references[::2, ::-1].copy()
# Flatten back to a list of points and drop the first points_per_dim // 2 of them.
references = jnp.array(references.reshape(points_per_dim**2, 2)[points_per_dim // 2:])

# Hold each reference for 5 steps, keep only points with zero penalty, then prepend 1000 steps
# at the first remaining reference.
references_norm = jnp.repeat(references[:, None, :], 5, axis=1).reshape(-1, 2)
references_norm = filter_valid_points(references_norm)
warmup = jnp.repeat(references_norm[:1], 1000, axis=0)
references_norm = jnp.concatenate([warmup, references_norm], axis=0)

print("number_steps:", references_norm.shape)

for column, label in enumerate(["$i_d$", "$i_q$"]):
    plt.plot(references_norm[:, column], label=label)
    plt.legend()
    plt.show()

plt.plot(references_norm[:, 0], references_norm[:, 1])
plt.xlabel("$i_d$")
plt.ylabel("$i_q$")
plt.show()

In [None]:
# references = build_grid(2, low=-0.9, high=0.9, points_per_dim=20)
# references_norm = references[:, None, :].repeat(100, axis=1).reshape(-1, 2)

# references_norm = filter_valid_points(references_norm)
# references_norm = jnp.concatenate([references_norm[0, :][None].repeat(1000, axis=0), references_norm], axis=0)

# print("number_steps:", references_norm.shape)

# plt.plot(references_norm[:, 0])
# plt.plot(references_norm[:, 1])

# plt.show()

In [None]:
def induced_voltage_constr(x_g, env, w, u_dc=400.0):
    """Penalty for exceeding the voltage limit at a normalized current operating point.

    Uses the steady-state voltage equations (no current-derivative terms):
    ``u_d = r_s * i_d - w * psi_q`` and ``u_q = r_s * i_q + w * psi_d``, with the flux linkages taken
    from the environment's lookup tables.

    Args:
        x_g: Normalized currents; ``x_g[0]`` is i_d and ``x_g[1]`` is i_q.
        env: Environment providing ``env_properties.static_params.r_s``, the i_d/i_q entries of
            ``env_properties.physical_normalizations`` (with ``denormalize``), and the
            ``'Psi_d'``/``'Psi_q'`` interpolators in ``env.LUT_interpolators``.
        w: Angular velocity multiplying the flux linkages. ``filter_voltage_constraints`` passes
            ``rpm * static_params.p * 2 * pi / 60``.
        u_dc: Voltage from which the magnitude limit ``u_dc / sqrt(3)`` is derived (presumably the
            DC-link voltage). Defaults to 400, the previously hard-coded value.

    Returns:
        Scalar ``max(0, |u| - u_dc / sqrt(3))``; 0 means the operating point satisfies the limit.
    """
    r_s = env.env_properties.static_params.r_s

    normalizations = env.env_properties.physical_normalizations
    physical_i_d = normalizations.i_d.denormalize(x_g[0])
    physical_i_q = normalizations.i_q.denormalize(x_g[1])

    physical_currents = jnp.array([physical_i_d, physical_i_q])
    psid = env.LUT_interpolators['Psi_d'](physical_currents)
    psiq = env.LUT_interpolators['Psi_q'](physical_currents)

    ud = r_s * physical_i_d - w * psiq
    uq = r_s * physical_i_q + w * psid
    u = jnp.sqrt(ud**2 + uq**2)

    return jnp.squeeze(jax.nn.relu(u - u_dc / jnp.sqrt(3)))


def filter_voltage_constraints(env, rpm, references):
    """Keep only the references that satisfy the voltage limit at the given speed.

    Args:
        env: Environment passed through to ``induced_voltage_constr``.
        rpm: Speed in revolutions per minute.
        references: Normalized current references stacked along axis 0.

    Returns:
        The rows of ``references`` with zero voltage-constraint penalty, in original order. Uses
        boolean-mask indexing, so this must be called eagerly (outside of ``jax.jit``).
    """
    # rpm -> rad/s, scaled by static_params.p.
    w = rpm * env.env_properties.static_params.p * 2 * jnp.pi / 60
    penalty_values = jax.vmap(induced_voltage_constr, in_axes=(0, None, None))(references, env, w)

    return references[penalty_values == 0]

In [None]:
results = {}

for rpm in [0, 3000, 5000, 7000, 9000]:
    print("RPM:", rpm)

    env, penalty_function = setup_env(rpm)

    pid = ClassicController(
        motor=env,
        rpm=rpm,
        saturated=env.env_properties.saturated,
        a=4,
        decoupling=True,
        tau=env.tau,
    )

    references_filtered_for_voltage_constraints = filter_voltage_constraints(env, rpm, references_norm)

    init_obs, init_state = env.reset(env.env_properties)
    init_pid_state = pid.reset(1)

    start = time.perf_counter()
    observations, actions = run_pid_experiment(
        env, pid, references_filtered_for_voltage_constraints, init_obs, init_state, init_pid_state
    )
    # JAX dispatches work asynchronously: wait for the results before stopping the clock, otherwise
    # only tracing/compilation and dispatch are measured. The time still includes jit compilation,
    # since `env` and `pid` are static arguments and are rebuilt for every rpm.
    jax.block_until_ready((observations, actions))
    end = time.perf_counter()
    print("computation_time:", round(end - start, 4), "s")

    fig, axs = plot_sequence(observations, actions, env.tau, env.obs_description[:2], env.action_description)
    # Overlay the references that were fed to the controller.
    axs[1].scatter(
        references_filtered_for_voltage_constraints[:, 0], references_filtered_for_voltage_constraints[:, 1], s=1
    )
    plt.show()

    results[rpm] = {
        "obs": np.array(observations).tolist(),
        "act": np.array(actions).tolist(),
    }

    print("------------------------------------------")
    print("\n")


In [None]:
# Show only the shapes of the stored trajectories; the raw nested lists are far too long to display.
{rpm: {name: np.shape(values) for name, values in run.items()} for rpm, run in results.items()}

## load and evaluate:

In [None]:
import json
import argparse
import datetime
import os

In [None]:
def safe_json_dump(obj, fp):
    """Write ``obj`` as JSON to the open text file ``fp``.

    Values the ``json`` module cannot encode are written as the placeholder string
    ``"<<non-serializable: TypeName>>"`` instead of raising ``TypeError``. Same definition as in
    the random-walk section, repeated so this section does not depend on that one having run.
    """

    def _placeholder(value):
        return f"<<non-serializable: {type(value).__qualname__}>>"

    return json.dump(obj, fp, default=_placeholder)

file_name = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
# open(..., "w") does not create missing parent directories, so create the output folder first.
# NOTE(review): ':' in the timestamp is not a valid file-name character on Windows.
plane_sweep_dir = pathlib.Path("./results/heuristics/current_plane_sweep")
plane_sweep_dir.mkdir(parents=True, exist_ok=True)
with open(plane_sweep_dir / f"current_plane_sweep_{file_name}.json", "w") as fp:
    safe_json_dump(results, fp)

In [None]:
from dmpe.evaluation.experiment_utils import get_experiment_ids

In [None]:
# Presumably lists the experiment ids stored in the current-plane-sweep results folder; pick one
# for `exp_id` in the next cell.
get_experiment_ids("./results/heuristics/current_plane_sweep")

In [None]:
results_path = "./results/heuristics/current_plane_sweep"
exp_id = 'current_plane_sweep_2025-03-11_14:59:16'

# Load one stored sweep (JSON mapping rpm -> {"obs": ..., "act": ...}); JSON keys are strings.
with open(pathlib.Path(results_path) / f"{exp_id}.json", "rb") as fp:
    data = json.load(fp)

# Convert every stored list back into a jnp array.
array_data = {}
for rpm, rpm_results in data.items():
    array_data[rpm] = {name: jnp.array(values) for name, values in rpm_results.items()}

In [None]:
# Plot each loaded sweep, skipping the first 500 steps. 1e-4 is passed in the slot where the
# earlier cells pass env.tau, since the JSON file does not store it.
for rpm, sweep in array_data.items():
    print("rpm: ", rpm)
    fig, axs = plot_sequence(
        sweep["obs"][500:],
        sweep["act"][500:],
        1e-4,
        [r"$i_\mathrm{d}$", r"$i_\mathrm{q}$"],
        [r"$u_\mathrm{d}$", r"$u_\mathrm{q}$"],
    )
    plt.show()