In [None]:
# Automatically re-import edited project modules (e.g. the dmpe package) before each cell runs.
%load_ext autoreload
%autoreload 2
# Provides the %lprun line-by-line profiler magic.
%reload_ext line_profiler

In [None]:
# Imports and global configuration for this notebook.
import pathlib

from functools import partial
import os
# These environment variables must be set BEFORE jax is imported below:
# expose only GPU 1 to this process, and stop XLA from pre-allocating most of
# the GPU memory up front (allocate on demand instead).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc
# rc('font',**{'family':'serif','serif':['Helvetica']})
# Render all figure text with LaTeX (requires a local TeX installation); the
# preamble makes \bm and amsmath available in labels.
# NOTE(review): the 2.54 factor on font sizes presumably pairs with the
# cm-like figure width used for plots further down — confirm.
mpl.rcParams['text.usetex'] = True
mpl.rcParams.update({'font.size': 10 * 2.54})
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# jax.config.update('jax_platform_name', 'cpu')

# jax.config.update("jax_debug_nans", True)

# Make the first device JAX sees the default one. With CUDA_VISIBLE_DEVICES='1'
# above this is presumably physical GPU 1 (falls back to CPU if no GPU backend).
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
from dmpe.evaluation.experiment_utils import extract_metrics_over_timesteps, extract_metrics_over_timesteps_via_interpolation
from dmpe.evaluation.plotting_utils import plot_metrics_by_sequence_length_for_all_algos
from dmpe.evaluation.experiment_utils import get_experiment_ids

from dmpe.utils.density_estimation import select_bandwidth
from dmpe.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

In [None]:
from math import e

In [None]:
def get_organized_experiment_ids(full_results_path):
    """Group the experiment ids found under ``full_results_path`` by their tags.

    Two tags are parsed from every id returned by ``get_experiment_ids``:

    * ``ca``  -- the text after the last ``"ca_"`` up to the next ``"_"``;
      it becomes ``True`` only if that text is exactly ``"True"``.
      (If the id has no ``"ca_"`` at all, its first ``"_"``-separated segment
      is compared instead, which normally yields ``False``.)
    * ``rpm`` -- the text after the last ``"rpm_"`` up to the next ``"_"``,
      converted with ``float`` (raises ``ValueError`` if not numeric).

    Returns:
        Nested dict ``{ca: {rpm: [experiment_id, ...]}}`` preserving the order
        in which ids were encountered.
    """
    grouped = {}
    for exp_id in get_experiment_ids(full_results_path):
        considers_actions = exp_id.split("ca_")[-1].split("_")[0] == "True"
        rpm_value = float(exp_id.split("rpm_")[-1].split("_")[0])
        grouped.setdefault(considers_actions, {}).setdefault(rpm_value, []).append(exp_id)
    return grouped

# NOTE(review): hard-coded, user-specific absolute path — this cell only runs on the
# author's machine; consider deriving it from a configurable results root.
full_results_path = "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/scripts/pmsm/results/dmpe/NODE/before_target_fix"
organized_experiment_ids = get_organized_experiment_ids(full_results_path)
# Sanity check: the experiment ids with ca == True recorded at 2000 rpm
# (raises KeyError if that combination is not present in the folder).
print(organized_experiment_ids[True][2000.0])

In [None]:
def _build_metrics_for_ca(metrics, metric_params, ca, penalty_function=None):
    """Bind per-metric parameters and the ``ca`` flag to each metric function.

    Args:
        metrics: mapping metric name -> metric function.
        metric_params: mapping metric name -> dict of keyword parameters.
        ca: value forwarded as ``ca=`` (and to the JSD target distribution).
        penalty_function: fallback used for the JSD target distribution when
            ``metric_params["jsd"]`` has no ``"penalty_function"`` entry.

    Returns:
        Dict metric name -> callable (a ``functools.partial`` for "jsd",
        "mcudsa" and "ksfc"; the original function for every other metric).

    Raises:
        KeyError: if "jsd" is requested but no penalty function is available,
            or a required JSD parameter is missing.
    """
    specific_metrics = {}
    for metric_name, metric_function in metrics.items():
        params = metric_params.get(metric_name, {})
        if metric_name == "jsd":
            if "penalty_function" in params:
                penalty = params["penalty_function"]
            elif penalty_function is not None:
                penalty = penalty_function
            else:
                raise KeyError(
                    "metric 'jsd' needs a 'penalty_function' (in metric_params['jsd'] "
                    "or as the penalty_function argument)"
                )
            # get_target_distribution is imported in a later notebook cell
            # (from dmpe_params); that cell must have run before this is called.
            target_distribution = get_target_distribution(
                params["points_per_dim"],
                params["bandwidth"],
                params["grid_extend"],
                ca,
                penalty,
            )
            specific_metrics[metric_name] = partial(
                metric_function,
                points_per_dim=params["points_per_dim"],
                bandwidth=params["bandwidth"],
                target_distribution=target_distribution,
                ca=ca,
            )
        elif metric_name in ("mcudsa", "ksfc"):
            # Previously compared against the typo "kfsc", so the ksfc parameters
            # (points_per_dim, variance, eps) and ca were silently dropped.
            specific_metrics[metric_name] = partial(metric_function, ca=ca, **params)
        else:
            specific_metrics[metric_name] = metric_function
    return specific_metrics


def extract_results(
    lengths,
    raw_results_path,
    algo_names,
    interpolate_to_lengths,
    metrics=None,
    metric_params=None,
    extra_folders=None,
    penalty_function=None,
):
    """Compute the metrics for every algorithm / ``ca`` / rpm combination on disk.

    Results for each algorithm are read from ``raw_results_path / algo_name``
    (plus ``extra_folders`` below it, if given) and grouped by
    ``get_organized_experiment_ids``.

    Args:
        lengths: sequence lengths at which the metrics are evaluated.
        raw_results_path: root folder containing one sub-folder per algorithm.
        algo_names: algorithm (sub-folder) names.
        interpolate_to_lengths: one flag per algorithm; if True the metrics are
            computed with ``extract_metrics_over_timesteps_via_interpolation``,
            otherwise with ``extract_metrics_over_timesteps``.
        metrics: mapping metric name -> metric function. "jsd" gets
            points_per_dim, bandwidth, a target distribution and ``ca`` bound;
            "mcudsa"/"ksfc" get all of their ``metric_params`` entry and ``ca``
            bound; other metrics are used unchanged.
        metric_params: mapping metric name -> keyword parameters.
        extra_folders: optional sub-path appended below each algorithm folder.
        penalty_function: fallback penalty function for the JSD target
            distribution when ``metric_params["jsd"]`` has no "penalty_function".

    Returns:
        Nested dict ``{algo_name: {ca: {rpm: result}}}`` where ``result`` is
        whatever the selected ``extract_metrics_over_timesteps*`` returns.

    Raises:
        ValueError: if ``metrics`` is None, or the numbers of algorithm names
            and interpolation flags differ.
    """
    if metrics is None:
        raise ValueError("extract_results needs a `metrics` mapping (metric name -> function).")
    metric_params = {} if metric_params is None else metric_params
    # Materialize so iterables can be traversed twice and their lengths compared.
    algo_names = list(algo_names)
    interpolate_to_lengths = list(interpolate_to_lengths)
    if len(algo_names) != len(interpolate_to_lengths):
        raise ValueError(
            f"Got {len(algo_names)} algo_names but {len(interpolate_to_lengths)} "
            "interpolate_to_lengths flags; they must match one-to-one."
        )

    raw_results_path = pathlib.Path(raw_results_path)
    all_results_by_metric = {algo_name: {} for algo_name in algo_names}
    # The bound metrics depend only on `ca`, not on the algorithm: build them once per
    # `ca` so the (potentially expensive) JSD target distribution is not recomputed.
    metrics_by_ca = {}

    for algo_name, use_interpolation in zip(algo_names, interpolate_to_lengths):
        full_results_path = raw_results_path / algo_name
        if extra_folders is not None:
            full_results_path = full_results_path / pathlib.Path(extra_folders)

        print("Extract results for", algo_name, "\n at", full_results_path)

        organized_experiment_ids = get_organized_experiment_ids(full_results_path)

        for ca, ids_by_rpm in organized_experiment_ids.items():
            print("Momentary experiments consider action distribution:", ca)

            if ca not in metrics_by_ca:
                metrics_by_ca[ca] = _build_metrics_for_ca(metrics, metric_params, ca, penalty_function)
            specific_metrics = metrics_by_ca[ca]

            results_for_ca = all_results_by_metric[algo_name].setdefault(ca, {})

            for rpm, experiment_ids in ids_by_rpm.items():
                print(f"Momentary experiments run at {int(rpm)} rpm.")

                if use_interpolation:
                    results_for_ca[rpm] = extract_metrics_over_timesteps_via_interpolation(
                        experiment_ids=experiment_ids,
                        results_path=full_results_path,
                        target_lengths=lengths,
                        metrics=specific_metrics,
                    )
                else:
                    results_for_ca[rpm] = extract_metrics_over_timesteps(
                        experiment_ids=experiment_ids,
                        results_path=full_results_path,
                        lengths=lengths,
                        metrics=specific_metrics,
                    )
                print("\n")
    return all_results_by_metric

## PMSM: evaluation metrics over sequence length

In [None]:
# Figure width passed to matplotlib. NOTE(review): presumably a LaTeX text width in cm,
# used as inches together with the 2.54-scaled font sizes — confirm.
full_column_width = 16.5719132056

def plot_metrics_by_sequence_length_for_all_algos(data_per_algo, lengths, algo_names, use_log=False):
    """Plot mean +/- std of every metric against the sequence length, one row per metric.

    NOTE: this definition shadows the function of the same name imported from
    ``dmpe.evaluation.plotting_utils`` earlier in the notebook.

    Args:
        data_per_algo: one mapping per algorithm from metric name to an array.
            Mean and std are taken over axis 0 (presumably repeated experiments);
            the reduced result is plotted against ``lengths``. The metric rows
            are taken from the keys of the first mapping.
        lengths: x-axis values (sequence lengths k).
        algo_names: legend labels, in the same order as ``data_per_algo``.
        use_log: if True, mean/std are computed on the natural log of the data.

    Returns:
        The created matplotlib Figure.

    Raises:
        AssertionError: if ``data_per_algo`` and ``algo_names`` differ in length.
        ValueError: if ``data_per_algo`` is empty.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    if len(data_per_algo) == 0:
        raise ValueError("data_per_algo must contain the results of at least one algorithm")

    metric_keys = list(data_per_algo[0].keys())

    # squeeze=False keeps a 2-D array even for a single metric, so axs[idx] always works.
    fig, axs = plt.subplots(
        len(metric_keys), figsize=(full_column_width, 18), sharex=True, squeeze=False
    )
    axs = axs[:, 0]
    colors = plt.rcParams["axes.prop_cycle"]()

    for algo_name, data in zip(algo_names, data_per_algo):
        color = next(colors)["color"]
        # Skip the default-cycle red.
        if color == '#d62728':
            color = next(colors)["color"]

        for metric_idx, metric_key in enumerate(metric_keys):
            values = jnp.log(data[metric_key]) if use_log else data[metric_key]
            mean = jnp.nanmean(values, axis=0)
            std = jnp.nanstd(values, axis=0)

            ax = axs[metric_idx]
            ax.plot(lengths, mean, label=algo_name, color=color)
            ax.fill_between(lengths, mean - std, mean + std, color=color, alpha=0.1)

    if metric_keys[-1] == "ksfc":
        axs[-1].set_yscale('log', base=10)

    for idx, metric_key in enumerate(metric_keys):
        # Raw f-string: avoids the invalid "\m" escape while producing the same label text.
        axs[idx].set_ylabel(rf"$\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")

    axs[-1].set_xlabel("$k$")
    axs[-1].set_xlim(lengths[0], lengths[-1])
    for ax in axs:
        ax.grid(True, which="both")
    axs[0].legend(prop={'size': 8 * 2.54})

    # NOTE: tight_layout recomputes the subplot spacing (including hspace), so the
    # hspace set here is presumably overridden.
    fig.subplots_adjust(hspace=0.02)
    fig.tight_layout(pad=0.05)

    for ax in axs:
        ax.tick_params(axis="y", direction='in')
        ax.tick_params(axis="x", direction='in')

    fig.align_ylabels(axs)

    return fig

In [None]:
# Sequence lengths k at which every metric is evaluated: 1000, 2000, ..., 5000.
lengths = jnp.arange(1000, 5001, 1000, dtype=jnp.int32)
lengths

In [None]:
# from dmpe.utils.density_estimation import build_grid, DensityEstimate
# from dmpe.utils.env_utils.pmsm_utils import PMSM_penalty, ExcitingPMSM

In [None]:
from dmpe_params import get_target_distribution
from eval_dmpe import setup_env

In [None]:
# Build the PMSM environment and its penalty function; the penalty function is passed
# on to the JSD metric's target distribution in the extraction cell below.
# NOTE(review): the meaning of the argument 1000 is not visible here (see
# eval_dmpe.setup_env) — consider giving it a named constant.
env, penalty_function = setup_env(1000)

In [None]:
# points_per_dim = 22  # works for < 30?
# dim = 4
# grid_extend = 1.05
# bandwidth = 0.08

# target_distribution = get_target_distribution(
#     points_per_dim,
#     bandwidth,
#     grid_extend,
#     consider_action_distribution=True,
#     penalty_function=penalty_function
# )
# target_distribution.shape

In [None]:
system_name = "pmsm"

# Extract all metrics for the three DMPE model variants. The penalty function reaches
# the JSD target distribution through metric_params["jsd"]; the previous extra
# `penalty_function=` keyword was not accepted by extract_results (TypeError).
# NOTE(review): hard-coded, user-specific absolute path; only works on the author's machine.
all_pmsm_results_by_metric = extract_results(
    lengths=lengths,
    raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/scripts/pmsm/results/dmpe/"),
    algo_names=["NODE", "PM", "RLS"],
    interpolate_to_lengths=[False, False, False],
    extra_folders=None,
    metrics={
        "jsd": default_jsd,
        "ae": default_ae,
        "mcudsa": default_mcudsa,
        "ksfc": default_ksfc,
    },
    metric_params={
        "jsd": dict(
            points_per_dim=22,
            dim=4,  # not read by extract_results; kept for reference
            grid_extend=1.05,
            bandwidth=0.08,
            penalty_function=penalty_function,
        ),
        "ksfc": dict(points_per_dim=30, variance=0.1, eps=1e-6),
        "mcudsa": dict(points_per_dim=30),
    },
)
# with open("results/pmsm_results_jsd_only.pickle", "wb") as handle:
#     pickle.dump(all_pmsm_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# One figure per (ca, rpm) combination, comparing the three DMPE model variants.
# The (ca, rpm) combinations are taken from the PM results; the NODE and RLS results
# are expected to contain the same combinations (KeyError otherwise).
for ca, pm_results_by_rpm in all_pmsm_results_by_metric["PM"].items():
    for rpm in pm_results_by_rpm:
        print("Considers_actions:", ca)
        print("rpm:", rpm)
        NODE_results = all_pmsm_results_by_metric["NODE"][ca][rpm]
        RLS_results = all_pmsm_results_by_metric["RLS"][ca][rpm]
        PM_results = pm_results_by_rpm[rpm]

        fig = plot_metrics_by_sequence_length_for_all_algos(
            data_per_algo=[NODE_results, RLS_results, PM_results],
            lengths=lengths,
            # Raw strings: "\m" is not a valid escape sequence; the text is unchanged.
            algo_names=[r"$\mathrm{NODE-DMPE}$", r"$\mathrm{RLS-DMPE}$", r"$\mathrm{PM-DMPE}$"],
            use_log=False,
        )
        # To export, save *before* plt.show(), one file per combination, e.g.
        # fig.savefig(f"metrics_per_sequence_length_{system_name}_ca_{ca}_rpm_{int(rpm)}.pdf")
        plt.show()
        # Release the figure; this loop creates one figure per (ca, rpm) combination.
        plt.close(fig)


In [None]:
# FIXME(review): stale cell. The extraction above uses algo_names ["NODE", "PM", "RLS"],
# so the keys "dmpe", "igoats" and "perfect_model_dmpe" are not produced and this cell
# raises KeyError on Restart & Run All. extract_results also nests its results as
# [algo][ca][rpm], not [algo] -> metrics. Update the keys/nesting or remove this cell.
dmpe_results_by_metric = all_pmsm_results_by_metric["dmpe"]
igoats_results_by_metric = all_pmsm_results_by_metric["igoats"]["interp"]
pm_dmpe_results_by_metric = all_pmsm_results_by_metric["perfect_model_dmpe"]
dmpe_results_by_metric.keys()

In [None]:
# FIXME(review): the inputs below come from legacy result keys ("dmpe", "igoats",
# "perfect_model_dmpe") that the current extraction (algo_names NODE/PM/RLS) does not
# produce; this comparison only runs if those results are provided.
fig = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    # Raw strings: "\m" is not a valid escape sequence; the text is unchanged.
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{iGOATS}$"],
    use_log=False,
)
# Save the figure returned above explicitly rather than whatever figure is "current".
fig.savefig(f"metrics_per_sequence_length_{system_name}.pdf")