In [1]:
import jax
import jax.numpy as jnp

import pandas as pd

import os

In [2]:
def make_df(states: jnp.ndarray, inputs: jnp.ndarray, costs: jnp.ndarray,
            name: str, beta: float, vis: str, percentile: float) -> pd.DataFrame:
    """Flatten one batch of rollouts into a long-format DataFrame.

    One row is produced per (trial, time step). The per-trial frames are
    concatenated without resetting the index, so the index restarts at 0 for
    every trial (the 't' column carries the same value).

    Args:
        states: State trajectories indexed as ``states[k, t, trial]``; rows 0-5
            are exported as x, y, theta, x_dot, y_dot, theta_dot. The number of
            trials is taken from axis 2.
        inputs: Control inputs indexed as ``inputs[k, t, trial]``; rows 0-1 are
            exported as Thrust and Moment. Axis 1 defines ``horizon``.
        costs: Per-step costs indexed as ``costs[t, trial]``.
        name: Label written to the 'Controller' column.
        beta: Value written to the 'beta' column.
        vis: Label written to the 'Visibility' column.
        percentile: Percentile (0-100) of the per-trial total costs; trials whose
            total cost is strictly above it are flagged ``True`` in 'Discarded'.

    Returns:
        DataFrame with ``horizon + 1`` rows per trial. Series shorter than
        ``horizon + 1`` (e.g. ``inputs`` along axis 1) are padded with NaN by
        pandas index alignment.
    """
    horizon = inputs.shape[1]
    trials = states.shape[2]
    n_rows = horizon + 1
    steps = list(range(n_rows))

    # Total cost of each trial. The outlier threshold depends only on these
    # totals, so compute it once rather than once per trial.
    cumm_costs = costs.sum(axis=0)
    threshold = jnp.percentile(cumm_costs, percentile)

    frames = []
    for trial in range(trials):
        # bool(...) turns the 0-d JAX array produced by the comparison into a
        # plain Python bool, so the column gets a real bool dtype (usable as a
        # mask) instead of object dtype holding JAX arrays.
        discarded = bool(cumm_costs[trial] > threshold)
        trial_costs = costs[:, trial]
        frames.append(pd.DataFrame({
            'Controller': pd.Series([name] * n_rows),
            'Visibility': pd.Series([vis] * n_rows),
            'Discarded': pd.Series([discarded] * n_rows),
            'beta': pd.Series([beta] * n_rows),
            'Trial': pd.Series([trial] * n_rows),
            't': pd.Series(steps),
            'x': pd.Series(states[0, :, trial]),
            'y': pd.Series(states[1, :, trial]),
            'theta': pd.Series(states[2, :, trial]),
            'x_dot': pd.Series(states[3, :, trial]),
            'y_dot': pd.Series(states[4, :, trial]),
            'theta_dot': pd.Series(states[5, :, trial]),
            'Thrust': pd.Series(inputs[0, :, trial]),
            'Moment': pd.Series(inputs[1, :, trial]),
            'Costs': pd.Series(trial_costs),
            'Cumm. Costs': pd.Series(trial_costs.cumsum())
        }))

    return pd.concat(frames)

In [3]:
# Percentile cut-off passed to make_df: trials whose total cost lies strictly
# above this percentile of all trials in the same batch are flagged 'Discarded'.
percentile = float(95)

In [4]:
# Load every array stored in data.npz and bind each one as a notebook-level
# variable named after its key. Later cells read names such as mp_full_states
# and inv_temps that are not defined anywhere else in the notebook, so they
# presumably come from here.
# NOTE(review): implicit name injection hides dependencies; an explicit unpack
# of the needed keys would be easier to follow.
with jnp.load('data.npz') as npz:
    # Materialise all arrays before the archive is closed; keep `npzfile` as a
    # mapping so code that indexes it by key keeps working.
    npzfile = dict(npz)

# globals() rather than locals(): writes to locals() are only reflected at
# module scope, and this cell is meant to populate the notebook namespace.
globals().update(npzfile)

In [5]:
# Baseline LQR rollouts under full and partial visibility; LQR has no inverse
# temperature, so 'beta' is NaN for these rows.
mp_full_df, mp_part_df = (
    make_df(mp_full_states, mp_full_inputs, mp_full_costs,
            'LQR', jnp.nan, 'Full', percentile),
    make_df(mp_part_states, mp_part_inputs, mp_part_costs,
            'LQR', jnp.nan, 'Part', percentile),
)

In [6]:
def sweep_dfs(states, inputs, costs, name, vis):
    """Build one make_df frame per inverse temperature in ``inv_temps``.

    The last axis of ``states``/``inputs`` (axis 3) and ``costs`` (axis 2)
    indexes the sweep: slice ``i`` belongs to ``inv_temps[i]``, which is written
    to the 'beta' column.
    """
    return [make_df(states[:, :, :, i], inputs[:, :, :, i], costs[:, :, i],
                    name, inv_temp, vis, percentile)
            for i, inv_temp in enumerate(inv_temps)]


# Previously the svgdc_* sweeps were labelled 'ISC' (and svgdc_full as 'Part'),
# which made them indistinguishable from the importance-sampling controller and
# from the partial-visibility runs in the combined frame.
full_dfs = ([mp_full_df]
            + sweep_dfs(is_full_states, is_full_inputs, is_full_costs, 'ISC', 'Full')
            + sweep_dfs(svgdc_full_states, svgdc_full_inputs, svgdc_full_costs, 'SVGDC', 'Full'))

part_dfs = ([mp_part_df]
            + sweep_dfs(is_part_states, is_part_inputs, is_part_costs, 'ISC', 'Part')
            + sweep_dfs(svgdc_part_states, svgdc_part_inputs, svgdc_part_costs, 'SVGDC', 'Part'))

df = pd.concat(full_dfs + part_dfs)

# DataFrame.to_pickle opens the target for writing and replaces any existing
# file, so there is no need to delete it first.
df.to_pickle('dataframe.pkl')