In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext line_profiler

In [None]:
import pathlib

from functools import partial
import os
# Device/memory settings for JAX. They are set here, before `jax` is imported and
# initialised below, so the backend picks them up.
# Only expose physical GPU 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Allocate GPU memory on demand instead of preallocating a large share of the device.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

from copy import deepcopy
import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc
# rc('font',**{'family':'serif','serif':['Helvetica']})
# Render all figure text with LaTeX (labels below use \mathcal / \mathrm).
mpl.rcParams['text.usetex'] = True
# NOTE(review): the 2.54 factor presumably compensates for figure sizes being given in cm
# (see `full_column_width`) while matplotlib interprets them as inches — confirm.
mpl.rcParams.update({'font.size': 10 * 2.54})
# Extra LaTeX packages made available to every rendered label.
mpl.rcParams['text.latex.preamble']=r"\usepackage{bm}\usepackage{amsmath}"

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_flatten, tree_unflatten

# jax.config.update('jax_platform_name', 'cpu')

# jax.config.update("jax_debug_nans", True)

# Pin computations to the first visible device (given the mask above, presumably
# physical GPU 1 — assumes a GPU backend is available).
gpus = jax.devices()
jax.config.update("jax_default_device", gpus[0])

import diffrax

In [None]:
from dmpe.evaluation.experiment_utils import extract_metrics_over_timesteps, extract_metrics_over_timesteps_via_interpolation
from dmpe.evaluation.plotting_utils import plot_metrics_by_sequence_length_for_all_algos
from dmpe.evaluation.experiment_utils import get_experiment_ids

from dmpe.utils.density_estimation import select_bandwidth
from dmpe.evaluation.experiment_utils import default_jsd, default_ae, default_mcudsa, default_ksfc

In [None]:
from math import e

In [None]:
from dmpe.evaluation.experiment_utils import get_organized_experiment_ids
from dmpe_params import get_target_distribution
from eval_dmpe import setup_env

# Quick look at how the current-plane-sweep experiments are organised
# (consider-actions flag -> rpm -> experiment ids).
full_results_path = (
    "/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/heuristics/"
    "current_plane_sweep"
)
organized_experiment_ids = get_organized_experiment_ids(
    full_results_path,
    force_consider_actions=True,
)

In [None]:
# Number of experiments found for each rpm (consider-actions setting only).
for rpm, experiment_ids in organized_experiment_ids[True].items():
    print(len(experiment_ids))

In [None]:
def _configure_metrics(metrics, metric_params, ca):
    """Bind the setting-specific parameters (including ``ca``) to each metric function.

    Args:
        metrics: Mapping ``metric name -> metric function``.
        metric_params: Mapping ``metric name -> dict of parameters``.
        ca: Consider-actions flag of the experiment group being evaluated.

    Returns:
        Mapping ``metric name -> callable`` ready to be passed to the extraction helpers.
    """
    configured = {}
    for metric_name, metric_function in metrics.items():
        if metric_name == "jsd":
            jsd_params = metric_params[metric_name]
            target_distribution = get_target_distribution(
                jsd_params["points_per_dim"],
                jsd_params["bandwidth"],
                jsd_params["grid_extend"],
                ca,
                jsd_params["penalty_function"],
            )
            configured[metric_name] = partial(
                metric_function,
                points_per_dim=jsd_params["points_per_dim"],
                bandwidth=jsd_params["bandwidth"],
                target_distribution=target_distribution,
                ca=ca,
            )
        elif metric_name in ("mcudsa", "ksfc"):
            # The key must be spelled "ksfc" (as in the metrics dicts); otherwise the
            # ksfc parameters and `ca` silently fall through to the plain branch below.
            configured[metric_name] = partial(
                metric_function, ca=ca, **metric_params.get(metric_name, {})
            )
        else:
            configured[metric_name] = metric_function
    return configured


def extract_results(
    lengths,
    raw_results_path,
    algo_names,
    interpolate_to_lengths,
    metrics=None,
    metric_params=None,
    extra_folders=None,
    force_consider_actions=False,
):
    """Evaluate metrics for all stored experiments, grouped by algorithm, ca flag and rpm.

    Args:
        lengths: Sequence lengths (time steps) at which the metrics are evaluated.
        raw_results_path: Root folder; results of ``algo_name`` are read from
            ``raw_results_path / algo_name`` (plus ``extra_folders`` if given).
        algo_names: Algorithm sub-folder names to evaluate.
        interpolate_to_lengths: One flag per entry of ``algo_names``; ``True`` uses
            ``extract_metrics_over_timesteps_via_interpolation`` instead of
            ``extract_metrics_over_timesteps``.
        metrics: Mapping ``metric name -> metric function`` (required).
        metric_params: Mapping ``metric name -> parameter dict``. The ``"jsd"`` entry must
            contain points_per_dim, bandwidth, grid_extend and penalty_function; the
            ``"mcudsa"`` / ``"ksfc"`` entries are forwarded as keyword arguments.
        extra_folders: Optional extra sub-path appended below each algorithm folder.
        force_consider_actions: Forwarded to ``get_organized_experiment_ids``.

    Returns:
        Nested dict ``results[algo_name][ca][rpm]`` holding whatever the extraction helper
        returns for that group of experiments.

    Raises:
        ValueError: if ``metrics`` is missing or ``interpolate_to_lengths`` does not have
            one entry per algorithm.
    """
    if metrics is None:
        raise ValueError("`metrics` must map metric names to metric functions.")
    if metric_params is None:
        metric_params = {}
    if len(algo_names) != len(interpolate_to_lengths):
        raise ValueError("`interpolate_to_lengths` needs exactly one flag per entry of `algo_names`.")

    all_results_by_metric = {algo_name: {} for algo_name in algo_names}
    # The configured metrics depend only on `ca`, so build them once per setting and reuse
    # them for every algorithm (the jsd target distribution does not need rebuilding).
    metrics_per_ca = {}

    for algo_name, use_interpolation in zip(algo_names, interpolate_to_lengths):
        full_results_path = raw_results_path / pathlib.Path(algo_name)
        if extra_folders is not None:
            full_results_path = full_results_path / pathlib.Path(extra_folders)

        print("Extract results for", algo_name, "\n at", full_results_path)

        organized_experiment_ids = get_organized_experiment_ids(
            full_results_path, force_consider_actions=force_consider_actions
        )

        for ca, experiment_ids_per_rpm in organized_experiment_ids.items():
            print("Momentary experiments consider action distribution:", ca)

            if ca not in metrics_per_ca:
                metrics_per_ca[ca] = _configure_metrics(metrics, metric_params, ca)
            specific_metrics = metrics_per_ca[ca]

            results_for_ca = all_results_by_metric[algo_name].setdefault(ca, {})

            for rpm, experiment_ids in experiment_ids_per_rpm.items():
                print(f"Momentary experiments run at {int(rpm)} rpm.")

                if use_interpolation:
                    results_for_ca[rpm] = extract_metrics_over_timesteps_via_interpolation(
                        experiment_ids=experiment_ids,
                        results_path=full_results_path,
                        target_lengths=lengths,
                        metrics=specific_metrics,
                    )
                else:
                    results_for_ca[rpm] = extract_metrics_over_timesteps(
                        experiment_ids=experiment_ids,
                        results_path=full_results_path,
                        lengths=lengths,
                        metrics=specific_metrics,
                    )
                print("\n")
    return all_results_by_metric

# PMSM:

In [None]:
# Figure widths, presumably in cm (two-column paper layout) — see the 2.54 font scaling above.
full_column_width = 18.2
half_column_width = 8.89


def plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo, lengths, algo_names, use_log=False, plot_log=False, exclude_metrics=("ae",)
):
    """Plot mean +/- std of each metric over the sequence length, one subplot row per metric.

    Args:
        data_per_algo: One mapping ``metric name -> values`` per algorithm. Values are reduced
            with nan-mean / nan-std over axis 0 (presumably the run axis — TODO confirm).
            The metric keys are taken from the first entry.
        lengths: x-values (sequence lengths) matching the reduced values.
        algo_names: Legend label per algorithm (LaTeX strings).
        use_log: Average the natural log of the values instead of the raw values.
        plot_log: Use a log10 y-axis on every subplot.
        exclude_metrics: Metric keys that are not plotted; absent keys are ignored.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: if ``data_per_algo`` is empty.
    """
    assert len(data_per_algo) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    if not data_per_algo:
        raise ValueError("`data_per_algo` must contain at least one algorithm.")

    metric_keys = [key for key in data_per_algo[0].keys() if key not in exclude_metrics]

    # squeeze=False keeps the axes indexable even when only one metric is plotted.
    fig, axs = plt.subplots(
        len(metric_keys), figsize=(full_column_width / 3, 8), sharex=True, squeeze=False
    )
    axs = axs[:, 0]
    colors = plt.rcParams["axes.prop_cycle"]()

    for algo_name, data in zip(algo_names, data_per_algo):
        color = next(colors)["color"]
        # NOTE(review): the default red is skipped; the reason is not stated here.
        if color == '#d62728':
            color = next(colors)["color"]

        for ax, metric_key in zip(axs, metric_keys):
            values = jnp.log(data[metric_key]) if use_log else data[metric_key]
            mean = jnp.nanmean(values, axis=0)
            std = jnp.nanstd(values, axis=0)

            ax.plot(lengths, mean, label=algo_name, color=color)
            ax.fill_between(lengths, mean - std, mean + std, color=color, alpha=0.1)

    if plot_log:
        for ax in axs:
            ax.set_yscale('log', base=10)

    for ax, metric_key in zip(axs, metric_keys):
        ax.set_ylabel(rf"$\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")

    axs[-1].set_xlabel("$k$")
    axs[-1].set_xlim(lengths[0], lengths[-1])
    for ax in axs:
        ax.grid(True, which="both")
    axs[0].legend(prop={'size': 8 * 2.54}, framealpha=0.5)

    # tight_layout recomputes all spacings, so no separate subplots_adjust is needed.
    fig.tight_layout(pad=0.05)

    for ax in axs:
        ax.tick_params(axis="y", direction='in')
        ax.tick_params(axis="x", direction='in')

    fig.align_ylabels(axs)

    return fig

In [None]:
# Evaluate the metrics after 1000, 2000, ..., 15000 steps (15 equally spaced lengths).
lengths = jnp.linspace(start=1000, stop=15000, num=15, dtype=jnp.int32)
# Only the penalty function is needed here; it parameterises the jsd target distribution.
_env, penalty_function = setup_env(1000)

lengths

# Extract metrics from the raw experiment results:

## DMPE results:

In [None]:
system_name = "pmsm"

# Evaluate the three DMPE variants (NODE, PM, RLS) without interpolation.
# The jsd parameters build the target distribution; ksfc / mcudsa parameters are bound
# to the metric functions inside `extract_results`.
dmpe_pmsm_results_by_metric = extract_results(
    lengths=lengths,
    raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/dmpe/"),
    algo_names=["NODE", "PM", "RLS"],
    interpolate_to_lengths=[False] * 3,
    extra_folders=None,
    metrics=dict(
        jsd=default_jsd,
        ae=default_ae,
        mcudsa=default_mcudsa,
        ksfc=default_ksfc,
    ),
    metric_params=dict(
        jsd=dict(
            points_per_dim=22,
            dim=4,
            grid_extend=1.05,
            bandwidth=0.08,
            penalty_function=penalty_function,
        ),
        ksfc=dict(points_per_dim=30, variance=0.1, eps=1e-6),
        mcudsa=dict(points_per_dim=30),
    ),
)


In [None]:
# Create the output folder if needed so a fresh checkout does not fail with FileNotFoundError.
pathlib.Path("results/quantitative_metrics_per_time").mkdir(parents=True, exist_ok=True)
with open("results/quantitative_metrics_per_time/dmpe_pmsm_results.pickle", "wb") as handle:
    pickle.dump(dmpe_pmsm_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# NODE-DMPE vs. RLS-DMPE, one figure per (consider-actions flag, rpm).
for ca, rls_results_per_rpm in dmpe_pmsm_results_by_metric["RLS"].items():
    for rpm, RLS_results in rls_results_per_rpm.items():
        print("Considers_actions:", ca)
        print("rpm:", rpm)
        NODE_results = dmpe_pmsm_results_by_metric["NODE"][ca][rpm]

        # Raw strings: "\m" is an invalid escape sequence in a normal string literal.
        fig = plot_metrics_by_sequence_length_for_all_algos(
            data_per_algo=[NODE_results, RLS_results],
            lengths=lengths,
            algo_names=[r"$\mathrm{NODE-DMPE}$", r"$\mathrm{RLS-DMPE}$"],
            use_log=False,
        )
        plt.show()
        # Release the figure so many loop iterations do not accumulate open figures.
        plt.close(fig)

## iGOATS:

In [None]:
system_name = "pmsm"

# iGOATS baseline, evaluated with the same metrics and metric parameters as the DMPE runs.
igoats_pmsm_results_by_metric = extract_results(
    lengths=lengths,
    raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results"),
    algo_names=["igoats"],
    interpolate_to_lengths=[False],
    extra_folders=None,
    metrics=dict(
        jsd=default_jsd,
        ae=default_ae,
        mcudsa=default_mcudsa,
        ksfc=default_ksfc,
    ),
    metric_params=dict(
        jsd=dict(
            points_per_dim=22,
            dim=4,
            grid_extend=1.05,
            bandwidth=0.08,
            penalty_function=penalty_function,
        ),
        ksfc=dict(points_per_dim=30, variance=0.1, eps=1e-6),
        mcudsa=dict(points_per_dim=30),
    ),
)


In [None]:
# Create the output folder if needed so a fresh checkout does not fail with FileNotFoundError.
pathlib.Path("results/quantitative_metrics_per_time").mkdir(parents=True, exist_ok=True)
with open("results/quantitative_metrics_per_time/igoats_pmsm_results.pickle", "wb") as handle:
    pickle.dump(igoats_pmsm_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Heuristics:

In [None]:
system_name = "pmsm"

# Heuristic baselines. `force_consider_actions=True` is forwarded to
# `get_organized_experiment_ids` when grouping the stored experiments.
heuristics_pmsm_results_by_metric = extract_results(
    lengths=lengths,
    raw_results_path=pathlib.Path("/home/hvater@uni-paderborn.de/projects/forks/DMPE/eval/pmsm/results/heuristics"),
    algo_names=["random_walk", "current_plane_sweep"],
    interpolate_to_lengths=[False] * 2,
    extra_folders=None,
    metrics=dict(
        jsd=default_jsd,
        ae=default_ae,
        mcudsa=default_mcudsa,
        ksfc=default_ksfc,
    ),
    metric_params=dict(
        jsd=dict(
            points_per_dim=22,
            dim=4,
            grid_extend=1.05,
            bandwidth=0.08,
            penalty_function=penalty_function,
        ),
        ksfc=dict(points_per_dim=30, variance=0.1, eps=1e-6),
        mcudsa=dict(points_per_dim=30),
    ),
    force_consider_actions=True,
)


In [None]:
# Create the output folder if needed so a fresh checkout does not fail with FileNotFoundError.
pathlib.Path("results/quantitative_metrics_per_time").mkdir(parents=True, exist_ok=True)
with open("results/quantitative_metrics_per_time/heuristics_pmsm_results.pickle", "wb") as handle:
    pickle.dump(heuristics_pmsm_results_by_metric, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# PI-sweep vs. random walk, one figure per (consider-actions flag, rpm).
for ca, random_walk_results_per_rpm in heuristics_pmsm_results_by_metric["random_walk"].items():
    for rpm, random_walk_results in random_walk_results_per_rpm.items():
        print("Considers_actions:", ca)
        print("rpm:", rpm)
        current_plane_sweep_results = heuristics_pmsm_results_by_metric["current_plane_sweep"][ca][rpm]

        fig = plot_metrics_by_sequence_length_for_all_algos(
            data_per_algo=[current_plane_sweep_results, random_walk_results],
            lengths=lengths,
            # Same labels as the combined figures. TeX ignores plain spaces inside \mathrm,
            # so "random walk" would render as "randomwalk"; hence the hyphen.
            algo_names=[r"$\mathrm{PI-sweep}$", r"$\mathrm{random-walk}$"],
            use_log=False,
        )
        plt.show()
        plt.close(fig)

# Plot:

In [None]:
def _unpickle_results(path):
    """Return the object stored in the pickle file at ``path``.

    Only use this on files written by this notebook: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Reload the extracted metrics so the plots below do not require re-running the extraction.
dmpe_pmsm_results_by_metric = _unpickle_results("results/quantitative_metrics_per_time/dmpe_pmsm_results.pickle")
heuristics_pmsm_results_by_metric = _unpickle_results("results/quantitative_metrics_per_time/heuristics_pmsm_results.pickle")
igoats_pmsm_results_by_metric = _unpickle_results("results/quantitative_metrics_per_time/igoats_pmsm_results.pickle")

## All algorithms together (one figure per rpm):

In [None]:
# All algorithms (consider-actions setting), one figure per rpm, saved as PDF and PNG.
for rpm, random_walk_results in heuristics_pmsm_results_by_metric["random_walk"][True].items():
    print("Considers_actions:", True)
    print("rpm:", rpm)

    current_plane_sweep_results = heuristics_pmsm_results_by_metric["current_plane_sweep"][True][rpm]
    NODE_results = dmpe_pmsm_results_by_metric["NODE"][True][rpm]
    RLS_results = dmpe_pmsm_results_by_metric["RLS"][True][rpm]

    fig = plot_metrics_by_sequence_length_for_all_algos(
        data_per_algo=[current_plane_sweep_results, random_walk_results, NODE_results, RLS_results],
        lengths=lengths,
        algo_names=[r"$\mathrm{PI-sweep}$", r"$\mathrm{random-walk}$", r"$\mathrm{NODE-DMPE}$", r"$\mathrm{RLS-DMPE}$"],
        use_log=False,
        plot_log=True,
    )
    # Save the figure object explicitly instead of relying on pyplot's "current figure".
    fig.savefig(f"results/quantitative_metrics_per_time/{rpm}_all_algos_all_metrics.pdf")
    fig.savefig(f"results/quantitative_metrics_per_time/{rpm}_all_algos_all_metrics.png", dpi=200)
    plt.show()
    plt.close(fig)

In [None]:
# Same per-rpm figures, but with crashed NODE-DMPE runs removed.
# Slices along axis 0 (the run axis) of the NODE-DMPE results to KEEP per rpm;
# every other rpm keeps all runs.
# NOTE(review): at 9000 rpm this keeps runs [0, 1, 3] + [7:], whereas the combined
# all-rpm figure further below keeps [0, 1, 3, 4] + [7:]. Confirm which is intended.
node_kept_run_slices_per_rpm = {
    7000: (slice(None, 5), slice(7, None)),
    9000: (slice(None, 2), slice(3, 4), slice(7, None)),
}

for rpm, random_walk_results in heuristics_pmsm_results_by_metric["random_walk"][True].items():
    print("Considers_actions:", True)
    print("rpm:", rpm)

    current_plane_sweep_results = heuristics_pmsm_results_by_metric["current_plane_sweep"][True][rpm]

    # Build a new dict so the loaded results themselves are never modified.
    node_all_runs = dmpe_pmsm_results_by_metric["NODE"][True][rpm]
    kept_slices = node_kept_run_slices_per_rpm.get(rpm)
    if kept_slices is None:
        NODE_results = dict(node_all_runs)
    else:
        NODE_results = {
            key: jnp.concatenate([values[s] for s in kept_slices], axis=0)
            for key, values in node_all_runs.items()
        }

    print(len(NODE_results["jsd"]))

    RLS_results = dmpe_pmsm_results_by_metric["RLS"][True][rpm]

    fig = plot_metrics_by_sequence_length_for_all_algos(
        data_per_algo=[current_plane_sweep_results, random_walk_results, NODE_results, RLS_results],
        lengths=lengths,
        algo_names=[r"$\mathrm{PI-sweep}$", r"$\mathrm{random-walk}$", r"$\mathrm{NODE-DMPE}$", r"$\mathrm{RLS-DMPE}$"],
        use_log=False,
        plot_log=True,
    )
    fig.savefig(f"results/quantitative_metrics_per_time/{rpm}_all_algos_all_metrics_Excluding_crashes.pdf")
    fig.savefig(f"results/quantitative_metrics_per_time/{rpm}_all_algos_all_metrics_Excluding_crashes.png", dpi=200)
    plt.show()
    plt.close(fig)

## All algorithms and all rpm in a single figure:

In [None]:
import matplotlib.ticker as ticker

def custom_formatter(val, pos):
    """Format a tick value in math mode: two decimals below 1.0, one decimal otherwise.

    ``pos`` is part of the ``ticker.FuncFormatter`` callback signature and is unused.
    """
    decimals = 2 if val < 1.0 else 1
    return rf"${val:.{decimals}f}$"

def plot_metrics_by_sequence_length_for_all_algos_for_all_rpm(
    data_per_algo_with_rpm,
    lengths,
    algo_names,
    use_log=False,
    plot_log=False,
    time_step=1e-4,
):
    """Plot mean +/- std of every metric over time: one row per metric, one column per rpm.

    Args:
        data_per_algo_with_rpm: One mapping ``rpm -> {metric name -> values}`` per algorithm.
            Values are reduced with nan-mean / nan-std over axis 0 (presumably the run
            axis — TODO confirm). The rpm and metric layout is taken from the first entry.
        lengths: Sequence lengths (time steps) the metrics were evaluated at.
        algo_names: Legend label per algorithm (LaTeX strings); some labels also select a
            dedicated line style.
        use_log: Average the natural log of the values instead of the raw values.
        plot_log: Use a log10 y-axis with custom tick formatting.
        time_step: Seconds per step, used to convert ``lengths`` to the time axis.

    Returns:
        ``(fig, legend)`` — the figure and its shared figure-level legend.

    Raises:
        ValueError: if ``data_per_algo_with_rpm`` is empty.
    """
    assert len(data_per_algo_with_rpm) == len(algo_names), "Mismatch in number of algo results and number of algo names"
    if not data_per_algo_with_rpm:
        raise ValueError("`data_per_algo_with_rpm` must contain at least one algorithm.")

    rpms = list(data_per_algo_with_rpm[0].keys())
    metric_keys = list(data_per_algo_with_rpm[0][rpms[0]].keys())
    times = lengths * time_step

    linestyle_by_algo = {
        r"$\mathrm{RLS-DMPE}$": "dashed",
        r"$\mathrm{iGOATS}$": "dashdot",
        r"$\mathrm{PM-DMPE}$": "dotted",
    }

    # squeeze=False keeps `axs` 2-D even for a single metric or a single rpm.
    fig, axs = plt.subplots(
        len(metric_keys), len(rpms), figsize=(full_column_width, 14), sharex=True, squeeze=False
    )

    for rpm_idx, rpm in enumerate(rpms):
        # Restart the colour cycle per column so each algorithm keeps its colour across rpm.
        colors = plt.rcParams["axes.prop_cycle"]()

        for algo_name, data_per_rpm in zip(algo_names, data_per_algo_with_rpm):
            data = data_per_rpm[rpm]
            color = next(colors)["color"]
            # NOTE(review): the default red is skipped; the reason is not stated here.
            if color == '#d62728':
                color = next(colors)["color"]
            style = linestyle_by_algo.get(algo_name)

            for metric_idx, metric_key in enumerate(metric_keys):
                values = jnp.log(data[metric_key]) if use_log else data[metric_key]
                mean = jnp.nanmean(values, axis=0)
                std = jnp.nanstd(values, axis=0)

                ax = axs[metric_idx, rpm_idx]
                ax.plot(
                    times,
                    mean,
                    # Label only once so the figure legend lists each algorithm once.
                    label=algo_name if metric_idx == 0 and rpm_idx == 0 else None,
                    color=color,
                    linewidth=2.5,
                    linestyle=style,
                )
                ax.fill_between(times, mean - std, mean + std, color=color, alpha=0.1)

    for idx, metric_key in enumerate(metric_keys):
        axs[idx, 0].set_ylabel(rf"$\mathcal{{L}}_\mathrm{{{metric_key.upper()}}}$")

    for ax in axs[-1]:
        ax.set_xlabel(r"$t$ $\mathrm{in}$ $s$")
        ax.set_xlim(times[0] - 0.02, times[-1] + 0.02)
        ax.set_xticks((times[0], 0.5, 1.0, 1.5))

        # Keep the outermost tick labels inside the panel.
        xtick_labels = ax.get_xticklabels()
        xtick_labels[0].set_ha('left')
        xtick_labels[-1].set_ha('right')

    for ax in axs.flat:
        ax.grid(True, which="both", alpha=0.3)
        ax.tick_params(which='both', axis="y", direction='in')
        ax.tick_params(which='both', axis="x", direction='in')

    legend = fig.legend(
        prop={'size': 8 * 2.54},
        framealpha=0.5,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.03),
        fancybox=True,
        shadow=False,
        ncol=len(algo_names),
    )

    # Column titles derived from the rpm keys so they always match the plotted data.
    for ax, rpm in zip(axs[0], rpms):
        ax.set_title(rf"${int(rpm)}$ $\mathrm{{min}}^{{-1}}$")

    if plot_log:
        for ax in axs.flat:
            ax.set_yscale('log', base=10)

        # Hand-tuned tick layout for all rows except the last one.
        for idx_y, row in enumerate(axs[:-1]):
            for idx_x, ax in enumerate(row):
                ax.yaxis.set_major_formatter(ticker.FuncFormatter(custom_formatter))
                ax.yaxis.set_minor_formatter(ticker.FuncFormatter(custom_formatter))
                ax.yaxis.set_major_locator(ticker.LogLocator(numticks=2))

                if idx_y > 0:
                    ax.yaxis.set_minor_locator(ticker.LogLocator(subs=(0.5,), numticks=4))
                elif idx_x != 4:
                    # The first-row panel in column 5 keeps the default minor locator.
                    ax.yaxis.set_minor_locator(ticker.LogLocator(subs=(0.3, 0.6), numticks=4))

    fig.tight_layout(h_pad=-0.1, w_pad=0.35)
    fig.align_ylabels(axs)

    return fig, legend

In [None]:
# Consider-actions results of every algorithm, keyed by rpm.
current_plane_sweep_results = heuristics_pmsm_results_by_metric["current_plane_sweep"][True]
random_walk_results = heuristics_pmsm_results_by_metric["random_walk"][True]

RLS_results = dmpe_pmsm_results_by_metric["RLS"][True]
PM_results = dmpe_pmsm_results_by_metric["PM"][True]

iGOATS_results = igoats_pmsm_results_by_metric["igoats"][True]

# Slices along axis 0 (the run axis) of the NODE-DMPE results to KEEP per rpm, which
# excludes crashed runs; every other rpm keeps all runs.
# NOTE(review): at 9000 rpm this keeps runs [0, 1, 3, 4] + [7:], whereas the per-rpm
# "Excluding_crashes" cell above keeps [0, 1, 3] + [7:]. Confirm which is intended.
node_kept_run_slices = {
    7000: (slice(None, 5), slice(7, None)),
    9000: (slice(None, 2), slice(3, 5), slice(7, None)),
}

# Build new dicts so the loaded results themselves are never modified.
NODE_results = {}
for rpm, node_metrics in dmpe_pmsm_results_by_metric["NODE"][True].items():
    kept_slices = node_kept_run_slices.get(rpm)
    if kept_slices is None:
        NODE_results[rpm] = dict(node_metrics)
    else:
        NODE_results[rpm] = {
            key: jnp.concatenate([values[s] for s in kept_slices], axis=0)
            for key, values in node_metrics.items()
        }

    print(len(NODE_results[rpm]["jsd"]))

fig, legend = plot_metrics_by_sequence_length_for_all_algos_for_all_rpm(
    data_per_algo_with_rpm=[current_plane_sweep_results, random_walk_results, NODE_results, RLS_results, iGOATS_results, PM_results],
    lengths=lengths,
    algo_names=[r"$\mathrm{PI-sweep}$", r"$\mathrm{random-walk}$", r"$\mathrm{NODE-DMPE}$", r"$\mathrm{RLS-DMPE}$", r"$\mathrm{iGOATS}$", r"$\mathrm{PM-DMPE}$"],
    use_log=False,
    plot_log=True,
)

# bbox_inches='tight' grows the saved area to include the legend placed below the axes.
fig.savefig("results/quantitative_metrics_per_time/all_metrics_all_rpm_all_algos.pdf", bbox_inches='tight')

## Scratch: additional comparison plots

In [None]:
# Compare the three DMPE variants per (consider-actions flag, rpm). Their results are in
# `dmpe_pmsm_results_by_metric`, keyed by algorithm -> ca -> rpm
# (`all_pmsm_results_by_metric` is not defined anywhere in this notebook).
for ca, pm_results_per_rpm in dmpe_pmsm_results_by_metric["PM"].items():
    for rpm, PM_results in pm_results_per_rpm.items():
        print("Considers_actions:", ca)
        print("rpm:", rpm)
        NODE_results = dmpe_pmsm_results_by_metric["NODE"][ca][rpm]
        RLS_results = dmpe_pmsm_results_by_metric["RLS"][ca][rpm]

        fig = plot_metrics_by_sequence_length_for_all_algos(
            data_per_algo=[NODE_results, RLS_results, PM_results],
            lengths=lengths,
            algo_names=[r"$\mathrm{NODE-DMPE}$", r"$\mathrm{RLS-DMPE}$", r"$\mathrm{PM-DMPE}$"],
            use_log=False,
        )
        plt.show()
        plt.close(fig)


In [None]:
# NOTE(review): stale cell — `all_pmsm_results_by_metric` is not defined anywhere in this
# notebook, so this raises NameError on Restart & Run All. The keys used here ("dmpe",
# "igoats" -> "interp", "perfect_model_dmpe") also do not match the result dicts produced
# by `extract_results` above, which are keyed by algorithm name -> ca flag -> rpm.
# Update to the current dicts or delete this cell.
dmpe_results_by_metric = all_pmsm_results_by_metric["dmpe"]
igoats_results_by_metric = all_pmsm_results_by_metric["igoats"]["interp"]
pm_dmpe_results_by_metric = all_pmsm_results_by_metric["perfect_model_dmpe"]
dmpe_results_by_metric.keys()

In [None]:
# NOTE(review): the inputs come from the cell above, which references the undefined
# `all_pmsm_results_by_metric`; this cell only runs once that is resolved.
fig = plot_metrics_by_sequence_length_for_all_algos(
    data_per_algo=[pm_dmpe_results_by_metric, dmpe_results_by_metric, igoats_results_by_metric],
    lengths=lengths,
    algo_names=[r"$\mathrm{PM-DMPE}$", r"$\mathrm{DMPE}$", r"$\mathrm{iGOATS}$"],
    use_log=False,
)
# Save the returned figure explicitly instead of relying on pyplot's "current figure".
fig.savefig(f"metrics_per_sequence_length_{system_name}.pdf")