In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# (Re)load the line_profiler extension, which provides the %lprun magic.
%reload_ext line_profiler

In [None]:
import pathlib

import time
from functools import partial
import os
# Both variables are set before `import jax` further down in this cell —
# presumably so that they are already in place when JAX initialises its backend.
# Expose only the CUDA device with index 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Ask the XLA client not to preallocate GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from tqdm.notebook import tqdm
import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# Use the first device reported by JAX as the default device.
# NOTE(review): despite the variable name, `jax.devices()` lists the devices of
# the default backend, which need not be GPUs — confirm on the target machine.
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
import dmpe

from dmpe.evaluation.experiment_utils import (
    load_experiment_results, get_experiment_ids, get_organized_experiment_ids
)
from dmpe.evaluation.metrics_utils import default_ae, default_jsd, default_mcudsa, default_ksfc

from dmpe.evaluation.plotting_utils import plot_sequence

from dmpe.utils.env_utils.pmsm_utils import plot_constraints_induced_voltage
from dmpe.utils.density_estimation import build_grid, DensityEstimate

from dmpe.algorithms.algorithm_utils import interact_and_observe

In [None]:
import eval_dmpe
from eval_dmpe import setup_env

## random-walk:

In [None]:
# Configuration of the random-walk rollout.
seed = 0
rpm = 5000
n_time_steps = 15_000

# Build the environment from the configured `rpm` (previously the literal 5000
# was repeated here, so editing `rpm` above had no effect).
env, penalty_function = setup_env(rpm)

# ---- #

obs, state = env.reset(env.env_properties)
dim_obs_space = obs.shape[0]
dim_action_space = env.action_dim

# Pre-allocated trajectory buffers: one observation row per time step (row 0 is
# the observation returned by reset) and one action row fewer.
observations = jnp.zeros((n_time_steps, dim_obs_space))
observations = observations.at[0].set(obs)
actions = jnp.zeros((n_time_steps - 1, dim_action_space))


key = jax.random.key(seed)
key, action_key = jax.random.split(key)

# Starting point of the random walk: a standard-normal sample per action dimension.
action = jax.random.normal(action_key, shape=(env.action_dim,))

In [None]:
# Random-walk rollout. The original loop bound `time_steps` is not defined
# anywhere in the notebook (NameError on a fresh kernel). `actions` was
# allocated with `n_time_steps - 1` rows, so exactly that many steps are run.
for k in range(n_time_steps - 1):
    # Presumably applies `action` to the environment and records it together
    # with the resulting observation in the buffers — verify against
    # `interact_and_observe`.
    next_obs, next_state, actions, observations = interact_and_observe(
        env=env, k=jnp.array([k]), action=action, state=state, actions=actions, observations=observations
    )

    # Random walk: add a fresh standard-normal increment to the previous action.
    key, action_key = jax.random.split(key)
    action = action + jax.random.normal(action_key, shape=(env.action_dim,))

    state = next_state
    obs = next_obs

In [None]:
# Plot the recorded random-walk trajectory; only the first two observation
# descriptions are passed along for labelling.
plot_sequence(observations, actions, env.tau, env.obs_description[:2], env.action_description)

## FOC-PID:

In [None]:
from dmpe.utils.env_utils.foc_pid import ClassicController

In [None]:
def _filter_valid_points(data_points, penalty_function):
    
    valid_points_bool = jax.vmap(penalty_function, in_axes=0)(data_points) == 0
    return data_points[jnp.where(valid_points_bool == True)]

def filter_valid_points(observations):
    """Drop every entry of ``observations`` that has a non-zero penalty.

    The module-level ``penalty_function`` is looked up when this function is
    called (not when it is defined), and is invoked with ``None`` as its
    second argument.
    """

    def _penalty_of_point(x):
        return penalty_function(x, None)

    return _filter_valid_points(observations, penalty_function=_penalty_of_point)

In [None]:
@partial(jax.jit, static_argnums=(0, 1))
def run_pid_experiment(env, pid, references_norm, init_obs, init_state, init_pid_state):
    """Roll out the controller ``pid`` in ``env`` along a reference sequence.

    The rollout is a single ``jax.lax.scan`` over the leading axis of
    ``references_norm``. ``env`` and ``pid`` are static arguments of the jitted
    function.

    Args:
        env: Environment providing ``step`` and ``env_properties``.
        pid: Callable controller, ``pid(pid_obs, pid_state) -> (action, next_pid_state)``.
        references_norm: Reference sequence, one entry per simulation step.
        init_obs: Observation the rollout starts from.
        init_state: Environment state the rollout starts from.
        init_pid_state: Controller state the rollout starts from.

    Returns:
        Tuple ``(observations, actions)``. For every step it holds the
        observation the controller acted on and the action it produced.
    """

    def scan_step(carry, reference):
        current_obs, current_state, controller_state = carry

        # Controller input: observation, epsilon and reference, each with a
        # leading axis of length one, joined along the last axis.
        controller_input = jnp.concatenate(
            [
                current_obs[None],
                current_state.physical_state.epsilon * jnp.ones((1, 1)),
                reference[None, ...],
            ],
            axis=-1,
        )

        control_action, new_controller_state = pid(controller_input, controller_state)
        stepped_obs, stepped_state = env.step(current_state, jnp.squeeze(control_action), env.env_properties)

        # Log the pre-step observation together with the applied action.
        logged = jnp.array([jnp.squeeze(current_obs), jnp.squeeze(control_action)])
        return (stepped_obs, stepped_state, new_controller_state), logged

    initial_carry = (init_obs, init_state, init_pid_state)
    # The final carry is not needed; only the per-step log is used.
    _, trajectory = jax.lax.scan(scan_step, initial_carry, references_norm)

    return trajectory[:, 0, :], trajectory[:, 1, :]

In [None]:
# Sweep the FOC-PID controller over a grid of references for several speeds
# and collect the resulting trajectories per rpm.
results = {}

for rpm in [0, 3000, 5000, 7000, 9000]:
    print("RPM:", rpm)

    env, penalty_function = setup_env(rpm)

    pid = ClassicController(
        motor=env,
        saturated=env.env_properties.saturated,
        a=4,
        decoupling=True,
        tau=env.tau
    )

    # Every grid point is held for 200 consecutive steps.
    references = build_grid(2, low=-0.9, high=0.9, points_per_dim=10)
    references_norm = references[:, None, :].repeat(200, axis=1).reshape(-1, 2)

    # Drop references with a non-zero penalty, then hold the first remaining
    # reference for an additional 1000 steps at the start of the sequence.
    references_norm = filter_valid_points(references_norm)
    references_norm = jnp.concatenate([references_norm[0, :][None].repeat(1000, axis=0), references_norm], axis=0)

    print("number_steps:", references_norm.shape)

    plt.plot(references_norm[:, 0])
    plt.plot(references_norm[:, 1])

    plt.show()

    init_obs, init_state = env.reset(env.env_properties)
    init_pid_state = pid.reset(1)

    start = time.time()
    observations, actions = run_pid_experiment(env, pid, references_norm, init_obs, init_state, init_pid_state)
    # JAX dispatches work asynchronously: wait for the results before stopping
    # the clock, otherwise the measured time may not cover the computation.
    # NOTE(review): `env` and `pid` are static jit arguments and are re-created
    # on every iteration, so this time presumably includes compilation — confirm.
    observations, actions = jax.block_until_ready((observations, actions))
    end = time.time()
    print("computation_time:", round(end - start, 4), "s")

    fig = plot_sequence(observations, actions, env.tau, env.obs_description[:2], env.action_description)
    plt.show()


    # Stored as nested Python lists rather than arrays.
    results[rpm] = {
        "obs": np.array(observations).tolist(),
        "act": np.array(actions).tolist(),
    }

    print("------------------------------------------")
    print("\n")


In [None]:
# Show a compact per-rpm summary (shape of each stored sequence) instead of
# echoing the full nested lists, which would flood the notebook output.
{
    speed: {name: np.shape(values) for name, values in entry.items()}
    for speed, entry in results.items()
}

In [None]:
import json
import argparse
import datetime
import os

In [None]:
# def safe_json_dump(obj, fp):
#     default = lambda o: f"<<non-serializable: {type(o).__qualname__}>>"
#     return json.dump(obj, fp, default=default)

# file_name = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
# with open(f"./results/heuristics/current_plane_sweep_{file_name}.json", "w") as fp:
#     safe_json_dump(results, fp)