In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
# Standard-library, plotting and JAX imports, plus GPU selection for JAX.
# NOTE(review): several names imported here (pathlib, partial, pickle, json, np,
# mpl, rc, jdc, tree_flatten/tree_unflatten, diffrax) are not used in the visible
# cells; confirm they are unneeded before pruning.
import pathlib

from functools import partial
import os
# These environment variables are set before `import jax` below so they are
# already in place when JAX initializes its GPU backend.
# Expose only physical GPU 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Disable JAX's default up-front preallocation of most of the GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# Make the first device JAX reports the default placement target. Given the
# CUDA_VISIBLE_DEVICES mask above this is presumably physical GPU 1 -- confirm on
# the target machine (without a GPU, jax.devices() would list CPU devices instead).
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
import dmpe
from dmpe.evaluation.experiment_utils import load_experiment_results
from dmpe.utils.env_utils.pmsm_utils import plot_constraints_induced_voltage
from dmpe.utils.density_estimation import build_grid, DensityEstimate

from eval_dmpe import run_experiment, setup_env
from dmpe_params import get_RLS_params, get_NODE_params, get_PM_params, get_target_distribution

In [None]:
# Run configuration read by the cells below.
rpm: int = 2_000  # passed to setup_env and stored (as float) in exp_params
model_name: str = "NODE"  # passed to run_experiment
consider_actions: bool = True  # forwarded as consider_action_distribution

In [None]:
# Build the environment for the configured rpm, then the NODE parameter bundle,
# which receives the penalty function produced by the environment setup.
# NOTE(review): get_NODE_params is hard-coded here and is not chosen from
# `model_name`; keep the two in sync when switching models.
env, penalty_function = setup_env(rpm)

(
    alg_params,
    model_params,
    model_class,
    model_trainer_params,
    model_env_wrapper,
) = get_NODE_params(
    penalty_function=penalty_function,
    consider_action_distribution=consider_actions,
)

In [None]:
# Inspect the target distribution: reshape the flat array from alg_params into a
# 4-D grid with `points_per_dim` entries along every axis, then show its shape.
target_distribution = alg_params["target_distribution"].reshape(
    (alg_params["points_per_dim"],) * 4
)
target_distribution.shape

In [None]:
# Sanity-check the 4-D target distribution on 2-D slices: the first two axes
# (labelled "currents") with the last two fixed, and the last two axes
# (labelled "voltages") with the first two fixed.
fixed_index = 10
# If target_distribution is a JAX array, an out-of-bounds index is clamped to the
# last element instead of raising, so the "10, 10" slice could silently be a
# different one. Guard the bound explicitly.
assert target_distribution.shape[2] > fixed_index and target_distribution.shape[3] > fixed_index, (
    f"fixed_index={fixed_index} is out of range for shape {target_distribution.shape}"
)

print("currents")
print(jnp.sum(target_distribution[:, :, 0, 0]))
print(jnp.sum(target_distribution[:, :, fixed_index, fixed_index]))
fig, ax = plt.subplots()
ax.imshow(target_distribution[:, :, 0, 0])
ax.set(
    title="currents: target_distribution[:, :, 0, 0]",
    xlabel="axis 1 index",
    ylabel="axis 0 index",
)
plt.show()

print("voltages")
fig, ax = plt.subplots()
ax.imshow(target_distribution[0, 0, :, :])
ax.set(
    title="voltages: target_distribution[0, 0, :, :]",
    xlabel="axis 3 index",
    ylabel="axis 2 index",
)
plt.show()

In [None]:
# 2-D slice over axes 0 and 2 with axes 1 and 3 fixed at index 10.
# If target_distribution is a JAX array, out-of-bounds indices are clamped rather
# than raising, so check the bound explicitly.
assert target_distribution.shape[1] > 10 and target_distribution.shape[3] > 10, (
    f"index 10 is out of range for shape {target_distribution.shape}"
)
fig, ax = plt.subplots()
ax.imshow(target_distribution[:, 10, :, 10])
ax.set(
    title="target_distribution[:, 10, :, 10]",
    xlabel="axis 2 index",
    ylabel="axis 0 index",
)
plt.show()

In [None]:
# Experiment configuration handed to run_experiment. The seed starts as None and
# is assigned right before a run.
exp_params = {
    "seed": None,
    "rpm": float(rpm),
    "n_time_steps": 5_000,
    "alg_params": alg_params,
    "model_params": model_params,
    "model_class": model_class,
    "model_trainer_params": model_trainer_params,
    "model_env_wrapper": model_env_wrapper,
}

In [None]:
# Launch a single experiment with a fixed seed.
seed = 222
exp_params.update(seed=int(seed))
# NOTE(review): the meaning of the positional `0` is defined by
# eval_dmpe.run_experiment -- confirm before changing it.
run_experiment(
    model_name,
    0,
    env,
    exp_params,
)

### Inspect experiment results

In [None]:
from dmpe.evaluation.experiment_utils import get_experiment_ids

def get_organized_experiment_ids(full_results_path, experiment_ids=None):
    """Group experiment ids by their action-consideration flag and their rpm.

    Both keys are parsed from the id string itself:

    * ``ca``: the token following ``"ca_"`` (up to the next ``"_"``) compared to
      ``"True"``. An id without a ``"ca_"`` marker is judged by its first
      ``"_"``-separated token instead, so it normally lands under ``False``.
    * ``rpm``: the token following ``"rpm_"`` (up to the next ``"_"``) as a float.

    Args:
        full_results_path: Results directory passed to ``get_experiment_ids``.
            Ignored when ``experiment_ids`` is given.
        experiment_ids: Optional iterable of ids to organize instead of listing
            ``full_results_path``.

    Returns:
        Nested dict ``{ca: {rpm: [experiment_id, ...]}}`` with ``ca`` a bool and
        ``rpm`` a float. Keys appear in first-seen order and ids keep their input
        order within each list.

    Raises:
        ValueError: if the rpm token of an id cannot be converted to float.
    """
    if experiment_ids is None:
        experiment_ids = get_experiment_ids(full_results_path)

    organized_experiment_ids = {}
    for experiment_id in experiment_ids:
        ca = experiment_id.split("ca_")[-1].split("_")[0] == "True"
        rpm = float(experiment_id.split("rpm_")[-1].split("_")[0])
        organized_experiment_ids.setdefault(ca, {}).setdefault(rpm, []).append(experiment_id)

    return organized_experiment_ids

# Organize the RLS experiment ids found in the results directory.
# Relative path: it is the same directory the loading cell reads via
# results_path="results/dmpe/RLS". A hard-coded absolute home-directory path is not
# portable and exposes the local user name. Like the other cells, this assumes the
# notebook runs from the directory that contains `results/`.
full_results_path = "results/dmpe/RLS"
organized_experiment_ids = get_organized_experiment_ids(full_results_path)
organized_experiment_ids

In [None]:
# Load one RLS run (ca=False, 7000 rpm; the last id listed for that group) and plot
# its currents against the induced-voltage constraint.
# The environment is bound to `rls_env` rather than `env`, so this cell does not
# replace the environment built for the configured `rpm` that run_experiment(...)
# uses when the experiment cell is re-run.
params, observations, actions, _ = load_experiment_results(
    exp_id=organized_experiment_ids[False][7000.0][-1],  # rpm keys are parsed as floats
    results_path="results/dmpe/RLS",
)

rls_env, _ = setup_env(params["rpm"])

_, state = rls_env.reset(rls_env.env_properties)
i_d_normalizer = rls_env.env_properties.physical_normalizations.i_d
i_q_normalizer = rls_env.env_properties.physical_normalizations.i_q

# Observation channels 0 and 1 are denormalized with the i_d / i_q normalizers.
physical_i_d = i_d_normalizer.denormalize(observations[..., 0])
physical_i_q = i_q_normalizer.denormalize(observations[..., 1])

print(state.physical_state.omega_el)

plot_constraints_induced_voltage(
    rls_env,
    physical_i_d,
    physical_i_q,
    w_el=state.physical_state.omega_el,
    saturated=True,
    show_torque=False,
)
plt.show()

In [None]:
# Plot every NODE experiment run without action consideration (ca=False), one
# constraint/current figure per experiment id.
# Cell-specific names (node_*, run_*) keep this loop from overwriting the global
# `rpm`, `env` and `organized_experiment_ids` that earlier cells read. Otherwise,
# re-running those cells afterwards would silently use the last loop values, and
# `organized_experiment_ids[False]` would fail on the rpm-keyed dict.
node_results_path = "results/dmpe/NODE"

node_experiment_ids = get_organized_experiment_ids(node_results_path)[False]

for run_rpm, ids_at_rpm in node_experiment_ids.items():
    for experiment_id in ids_at_rpm:
        print("rpm:", run_rpm)
        print(experiment_id)
        run_params, run_observations, _, _ = load_experiment_results(
            exp_id=experiment_id,
            results_path=node_results_path,
        )

        # Rebuild the environment at the run's stored rpm to obtain its physical
        # normalizations and the electrical speed (omega_el) passed to the plot.
        run_env, _ = setup_env(run_params["rpm"])
        _, run_state = run_env.reset(run_env.env_properties)
        run_i_d_normalizer = run_env.env_properties.physical_normalizations.i_d
        run_i_q_normalizer = run_env.env_properties.physical_normalizations.i_q

        # Observation channels 0 and 1 are denormalized with the i_d / i_q normalizers.
        run_physical_i_d = run_i_d_normalizer.denormalize(run_observations[..., 0])
        run_physical_i_q = run_i_q_normalizer.denormalize(run_observations[..., 1])

        print(run_state.physical_state.omega_el)

        plot_constraints_induced_voltage(
            run_env,
            run_physical_i_d,
            run_physical_i_q,
            w_el=run_state.physical_state.omega_el,
            saturated=True,
            show_torque=False,
        )
        plt.show()