In [None]:
# Automatically re-import edited modules (e.g. dmpe, eval_dmpe) before each cell runs.
%load_ext autoreload
%autoreload 2
# Provides the %lprun line-by-line profiling magic.
%reload_ext line_profiler

In [None]:
import pathlib

from functools import partial
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])
# jax.config.update("jax_enable_x64", True)
import diffrax

In [None]:
import dmpe
from dmpe.evaluation.experiment_utils import load_experiment_results
from dmpe.utils.env_utils.pmsm_utils import plot_constraints_induced_voltage
from dmpe.utils.density_estimation import build_grid, DensityEstimate

import eval_dmpe
import eval_goats
from eval_dmpe import setup_env
from dmpe_params import get_RLS_params, get_NODE_params, get_PM_params, get_target_distribution

from dmpe.evaluation.experiment_utils import get_experiment_ids, get_organized_experiment_ids

# Overview

- Plots the $i_\mathrm{dq}$ current data of all generated results together with the induced-voltage constraints.

In [None]:
def plot_results(full_results_path, rpms_to_plot=(0, 3000, 5000, 7000, 9000)):
    """Plot the physical i_d/i_q data of every experiment stored under a results path.

    For each selected experiment, the environment matching its rpm is built, the
    normalized d/q-current observations are denormalized, and the result is drawn
    with ``plot_constraints_induced_voltage`` and shown immediately.

    Args:
        full_results_path: Results directory, passed to ``get_organized_experiment_ids``
            and ``load_experiment_results``.
        rpms_to_plot: Collection of rpm values to plot; experiments at any other rpm
            are skipped. ``None`` plots every rpm found. The default reproduces the
            previously hard-coded selection.

    Prints a message and returns early if no entry under the key ``True`` exists.
    """
    try:
        # NOTE(review): only the entry under the key ``True`` is plotted — confirm in
        # get_organized_experiment_ids which subset of results this key selects.
        organized_experiment_ids = get_organized_experiment_ids(full_results_path)[True]
    except KeyError:
        print("No fitting results found at the specified location.")
        return

    num_plots = 0
    for rpm, experiment_ids in organized_experiment_ids.items():
        if rpms_to_plot is not None and rpm not in rpms_to_plot:
            continue

        for file_name in experiment_ids:
            # Only params and observations are needed for the current plot.
            params, observations, _, _ = load_experiment_results(
                exp_id=file_name,
                results_path=full_results_path,
            )

            # Rebuild the environment from the rpm stored in the experiment's params.
            env, _ = setup_env(params["rpm"])
            _, state = env.reset(env.env_properties)

            normalizations = env.env_properties.physical_normalizations
            physical_i_d = normalizations.i_d.denormalize(observations[..., 0])
            physical_i_q = normalizations.i_q.denormalize(observations[..., 1])

            omega_el = state.physical_state.omega_el
            print(f"{file_name}: omega_el = {omega_el}")

            plot_constraints_induced_voltage(
                env,
                physical_i_d,
                physical_i_q,
                w_el=omega_el,
                saturated=True,
                show_torque=False,
            )
            plt.show()
            num_plots += 1

    print("number of results:", num_plots)

## NODE-DMPE

In [None]:
# Root of all PMSM results; set DMPE_PMSM_RESULTS_DIR to use another machine's layout.
results_root = os.environ.get(
    "DMPE_PMSM_RESULTS_DIR",
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results",
)
full_results_path = os.path.join(results_root, "dmpe", "NODE")
plot_results(full_results_path)

## RLS-DMPE

In [None]:
# Root of all PMSM results; set DMPE_PMSM_RESULTS_DIR to use another machine's layout.
results_root = os.environ.get(
    "DMPE_PMSM_RESULTS_DIR",
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results",
)
full_results_path = os.path.join(results_root, "dmpe", "RLS")
plot_results(full_results_path)

## PM-DMPE

In [None]:
# Root of all PMSM results; set DMPE_PMSM_RESULTS_DIR to use another machine's layout.
results_root = os.environ.get(
    "DMPE_PMSM_RESULTS_DIR",
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results",
)
full_results_path = os.path.join(results_root, "dmpe", "PM")
plot_results(full_results_path)

## iGOATS

In [None]:
# Root of all PMSM results; set DMPE_PMSM_RESULTS_DIR to use another machine's layout.
results_root = os.environ.get(
    "DMPE_PMSM_RESULTS_DIR",
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results",
)
full_results_path = os.path.join(results_root, "igoats")
plot_results(full_results_path)

## Heuristics

### current_plane_sweep

In [None]:
# Root of all PMSM results; set DMPE_PMSM_RESULTS_DIR to use another machine's layout.
results_root = os.environ.get(
    "DMPE_PMSM_RESULTS_DIR",
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results",
)
results_path = os.path.join(results_root, "heuristics", "current_plane_sweep")
# Heuristic results are stored as "data_<experiment id>.json".
experiment_ids = ["data_" + exp_id for exp_id in get_experiment_ids(results_path)]

for exp_id in experiment_ids:
    # json.load accepts a binary file object and detects the text encoding itself.
    with open(pathlib.Path(results_path) / f"{exp_id}.json", "rb") as fp:
        data = json.load(fp)

    observations = jnp.array(data["observations"])

    # NOTE(review): assumes the rpm is the second "_"-separated field of the raw
    # experiment id (index 2 after the "data_" prefix) — confirm the naming scheme.
    rpm = int(exp_id.split("_")[2])

    env, _ = setup_env(rpm)
    _, state = env.reset(env.env_properties)

    normalizations = env.env_properties.physical_normalizations
    physical_i_d = normalizations.i_d.denormalize(observations[..., 0])
    physical_i_q = normalizations.i_q.denormalize(observations[..., 1])

    omega_el = state.physical_state.omega_el
    print(f"{exp_id}: omega_el = {omega_el}, number of steps = {observations.shape[0]}")

    plot_constraints_induced_voltage(
        env,
        physical_i_d,
        physical_i_q,
        w_el=omega_el,
        saturated=True,
        show_torque=False,
    )
    plt.show()
    

### random_walk

In [None]:
# Root of all PMSM results; set DMPE_PMSM_RESULTS_DIR to use another machine's layout.
results_root = os.environ.get(
    "DMPE_PMSM_RESULTS_DIR",
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results",
)
results_path = os.path.join(results_root, "heuristics", "random_walk")
# Heuristic results are stored as "data_<experiment id>.json".
experiment_ids = ["data_" + exp_id for exp_id in get_experiment_ids(results_path)]

for exp_id in experiment_ids:
    # json.load accepts a binary file object and detects the text encoding itself.
    with open(pathlib.Path(results_path) / f"{exp_id}.json", "rb") as fp:
        data = json.load(fp)

    observations = jnp.array(data["observations"])

    # NOTE(review): assumes the rpm is the second "_"-separated field of the raw
    # experiment id (index 2 after the "data_" prefix) — confirm the naming scheme.
    rpm = int(exp_id.split("_")[2])

    env, _ = setup_env(rpm)
    _, state = env.reset(env.env_properties)

    normalizations = env.env_properties.physical_normalizations
    physical_i_d = normalizations.i_d.denormalize(observations[..., 0])
    physical_i_q = normalizations.i_q.denormalize(observations[..., 1])

    omega_el = state.physical_state.omega_el
    print(f"{exp_id}: omega_el = {omega_el}, number of steps = {observations.shape[0]}")

    plot_constraints_induced_voltage(
        env,
        physical_i_d,
        physical_i_q,
        w_el=omega_el,
        saturated=True,
        show_torque=False,
    )
    plt.show()
    

# Detailed analysis

- Consideration of the feature-space distribution
- Gather/select data for qualitative plots