In [1]:
import os
import sys

# Absolute path of the directory two levels above the working directory, so the
# `rationality` package imported below can be found.
module_path = os.path.abspath('../../')

# Register it on the import path only once, so re-running this cell is harmless.
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import jax.experimental.optimizers as opt
import jax.scipy as jsp
import datetime

import numpy as np

from tqdm import tqdm

from rationality import dynamics as dyn, objectives as obj, distributions as dst, \
    controllers as ctl, simulate as sim, util as utils

from typing import Tuple

In [3]:
# Flag presumably meant to toggle writing results to disk.
# NOTE(review): no cell visible in this notebook reads it -- the .npz checkpoint and the
# final .npz file are written unconditionally. Confirm the intended behaviour.
save_data = False

In [4]:
# --- Experiment configuration -------------------------------------------------------
# Root PRNG key; later cells split it before every random draw.
key = rnd.PRNGKey(0)
# NOTE(review): `prior_samples` is not referenced by any cell visible here.
prior_samples = 1000
# Number of sampled initial conditions, i.e. rollouts per controller setting.
trials = 100
# Horizon handed to `ctl.problem`; also the length of the time axis of the noise arrays.
horizon = 12  # try 10 to 12

# Percentile cutoff passed to `controller_stats`: rollouts whose cumulative cost lies
# above it are excluded from the reported mean / std.
percentile = 95

# Passed to `ctl.isc.create` (presumably the number of samples drawn by the ISC controller).
isc_samples = 10000

# Settings forwarded to `ctl.svmpc.create`. With `svmpc_bw == 'dynamic'` the jitted
# simulation wrapper passes NaN as the third `SVMPCParams` field instead of a fixed value.
svmpc_samples = 16
svmpc_bw = 'dynamic'
svmpc_iters = 10000
svmpc_opt = opt.adam(1e-1)

# Optimizer and iteration count forwarded to `ctl.mpc.create`.
mpc_opt = opt.adam(1e-0)
mpc_iters = 1000

# Selects how the estimation noise is scaled; the noise cell handles 'max' and 'fixed'.
noise_style = 'fixed'

# Per-state relative noise standard deviations, paired element-wise with `noise_states`.
relative_noise_stds = jnp.array([0.1, 0.1, 0.01, 0.01, 0.01, 0.01])
# Indices of the states that receive estimation noise.
noise_states = [0, 1, 2, 3, 4, 5]
# Multiplier applied to `relative_noise_stds` in the 'fixed' noise style only.
noise_scale_coeff = 1.5

# Inverse temperatures swept in the main loop: infinity first, then 20 values
# exp(linspace(0, 12.5, 20)).
inv_temps = jnp.concatenate([jnp.array([jnp.inf]), jnp.exp(jnp.linspace(0, 12.5, 20))])
# Diagonal covariance (the listed values squared) used to sample initial conditions.
prior_ic_cov = jnp.diag(jnp.array([1e-1, 1e-1, 1e-3, 1e-2, 1e-2, 1e-4]) ** 2)

In [5]:
# First argument of `dyn.quad2d` (presumably the discretisation time step -- confirm).
dt = 0.3  # try ~0.2

# Weight matrices handed to `obj.quadratic` as (Q, R, Qf): 6x6, 2x2 and 6x6 scaled
# identities. Presumably stage-state, input and terminal-state weights -- confirm
# against `obj.quadratic`.
Q = 1 * jnp.eye(6)
R = 0.1 * jnp.eye(2)
Qf = 100 * jnp.eye(6)

# Mean of the initial-condition distribution sampled for the prior and the evaluation runs.
ic = jnp.array([1.0, -1.0, 0.0, 0.0, 0.0, 0.0])
# The evaluation initial conditions reuse the same covariance as the prior ones.
ic_cov = prior_ic_cov

In [6]:
# Quadratic objective built from the weights above. The first `input_offset` entry
# (9.82) matches the third argument of `dyn.quad2d` below -- presumably gravity, so that
# the cost is measured relative to a hover input; TODO confirm.
objective = obj.quadratic(Q, R, Qf, input_offset=jnp.array([9.82, 0.0]))
# Dynamics model created with the step `dt`; the meaning of the remaining positional
# arguments (1.0, 9.82, 1.0) is defined by `dyn.quad2d` -- presumably physical constants.
dynamics = dyn.quad2d(dt, 1.0, 9.82, 1.0)
# Bundle dynamics, objective and horizon into the problem object used by every controller.
prob = ctl.problem(dynamics, objective, horizon)

In [7]:
# Split off a subkey; the next cell uses it to sample the prior initial conditions.
key, subkey = rnd.split(key)

# Problem dimensions: number of states and number of inputs.
n, m = prob.num_states, prob.num_inputs

# Baseline MPC controller and the simulation compiled for it; both are used to
# generate the prior rollouts.
mpc = ctl.mpc.create(prob, mpc_opt, mpc_iters)
prior_sim = sim.compile_simulation(prob, mpc)

In [8]:
def _draw_prior_ic(k):
    """Sample one initial condition from N(ic, prior_ic_cov) using PRNG key `k`."""
    return rnd.multivariate_normal(k, ic, prior_ic_cov)


def _prior_mpc_rollout(x):
    """Run the prior MPC simulation from `x`, passing an all-zero (n, horizon) array
    as the second argument of `sim.run` (presumably the estimation noise)."""
    return sim.run(x, jnp.zeros((n, horizon)), prior_sim, prob, mpc)


# One initial condition per trial, stacked along the last axis.
prior_ics = jax.vmap(_draw_prior_ic, out_axes=-1)(rnd.split(subkey, trials))

# Roll out every initial condition; the trial index stays on the last axis of each result.
prior_states, prior_inputs, prior_costs = jax.vmap(_prior_mpc_rollout, in_axes=1, out_axes=-1)(prior_ics)

In [9]:
key, subkey = rnd.split(key)

# Fixed diagonal prior covariance. It used to feed a first `prior_params` list that was
# immediately overwritten by the empirical version below; that dead computation has been
# removed. The name is kept in case something outside this cell still refers to it.
prior_cov = jnp.diag(jnp.array([1e-2, 1e-5] * horizon) ** 2)

# Empirical covariance of the remaining prior inputs (steps t..horizon-1) across trials,
# one matrix per time step t. The block for the t already-elapsed steps is zero, and a
# 1e-8 ridge is added on the whole diagonal.
prior_covs = jnp.stack(
    [jsp.linalg.block_diag(jnp.cov(prior_inputs[:, t:, :].reshape((m * (horizon - t), trials), order='F')),
                           0 * jnp.eye(t * m)) + 1e-8 * jnp.eye(horizon * m)
     for t in range(horizon)], axis=-1)

# Mean prior input trajectory over trials; computed once instead of once per time step.
prior_mean_inputs = prior_inputs.mean(axis=2)

# Gaussian prior for each time step t: the mean is the remaining mean input sequence
# (column-major flattened) zero-padded at the end back to length horizon * num_inputs.
prior_params = [dst.GaussianParams(jnp.pad(prior_mean_inputs[:, t:].flatten(order='F'),
                                           (0, t * prob.num_inputs)), prior_covs[:, :, t])
                for t in range(horizon)]

# Both controllers are created with inverse temperature infinity and the same subkey.
# The jitted simulation wrappers defined later build fresh ISCParams / SVMPCParams with
# their own inverse temperature and key for every run.
isc = ctl.isc.create(prob, jnp.inf, isc_samples, subkey, dst.GaussianPrototype(prob.num_inputs * horizon), prior_params)
svmpc = ctl.svmpc.create(prob, jnp.inf, subkey, svmpc_bw, svmpc_samples,
                         dst.GaussianPrototype(prob.num_inputs * horizon),
                         prior_params, svmpc_opt, svmpc_iters)

In [10]:
# Estimation noise for every (state, time step, trial); states that are not listed in
# `noise_states` keep zero noise.
est_noise = jnp.zeros((n, horizon, trials))

if noise_style.lower() == 'max':
    # Std is a fraction of the largest absolute value the state reached in the prior rollouts.
    for state, scale in zip(noise_states, relative_noise_stds):
        key, subkey = rnd.split(key)
        stddev = (scale * jnp.max(jnp.abs(prior_states[state, :, :])))
        est_noise = est_noise.at[state, :, :].set(stddev * rnd.normal(subkey, (horizon, trials)))
        print(f'Estimation Noise for state {state} is N(0, {stddev ** 2:.3f}).')

elif noise_style.lower() == 'fixed':
    # Std is the configured relative std scaled by a single global coefficient.
    for state, rel_scale in zip(noise_states, relative_noise_stds):
        key, subkey = rnd.split(key)
        stddev = noise_scale_coeff * rel_scale
        est_noise = est_noise.at[state, :, :].set(stddev * rnd.normal(subkey, (horizon, trials)))
        print(f'Estimation Noise for state {state} is N(0, {stddev:.2e} ** 2).')
else:
    # Only the two styles handled above exist; the message previously also advertised an
    # unimplemented 'varying' style and did not say which value was rejected.
    raise ValueError(f"Noise style must be one of: 'max', 'fixed' (got {noise_style!r}).")

In [11]:
# Fresh subkey for the evaluation initial conditions.
key, subkey = rnd.split(key)

# Problem dimensions: number of states and number of inputs.
n, m = prob.num_states, prob.num_inputs


def _draw_ic(k):
    """Sample one initial condition from N(ic, ic_cov) using PRNG key `k`."""
    return rnd.multivariate_normal(k, ic, ic_cov)


# One initial condition per trial, stacked along the last axis.
ic_samples = jax.vmap(_draw_ic, out_axes=-1)(rnd.split(subkey, trials))

In [12]:
# One compiled simulation per controller.
mpc_sim = sim.compile_simulation(prob, mpc)
isc_sim = sim.compile_simulation(prob, isc)
svmpc_sim = sim.compile_simulation(prob, svmpc)


@jax.jit
def mpc_sim_with_noise(ic_samples, noise):
    """Simulate the MPC controller from `ic_samples` under the given `noise`."""
    return mpc_sim(ic_samples, noise, prob.params, mpc.params)


@jax.jit
def isc_sim_with_noise(ic_samples, inv_temp, key, noise):
    """Simulate the ISC controller with parameters ISCParams(inv_temp, key)."""
    return isc_sim(ic_samples, noise, prob.params, ctl.isc.ISCParams(inv_temp, key))


@jax.jit
def svmpc_sim_with_noise(ic_samples, inv_temp, key, noise):
    """Simulate the SVMPC controller with parameters SVMPCParams(inv_temp, key, bandwidth).

    The bandwidth entry is NaN when `svmpc_bw` is 'dynamic', otherwise `svmpc_bw` itself.
    """
    if svmpc_bw == 'dynamic':
        bandwidth = jnp.nan
    else:
        bandwidth = svmpc_bw
    return svmpc_sim(ic_samples, noise, prob.params,
                     ctl.svmpc.SVMPCParams(inv_temp, key, bandwidth))

In [13]:
def _trimmed_mean_std(costs, percentile):
    """Return (mean, std) of the cumulative costs that lie at or below the given percentile."""
    # Coerce first so JAX arrays and nested lists are handled by plain NumPy.
    cumulative = np.asarray(costs).sum(axis=0)
    cutoff = np.percentile(cumulative, percentile)
    kept = cumulative[cumulative <= cutoff]
    return kept.mean(), kept.std()


def controller_stats(full_costs: np.ndarray, part_costs: np.ndarray,
                     percentile: float) -> Tuple[float, float, float, float]:
    """Summarise rollout costs for the full- and partial-observation runs.

    Each cost array is summed over axis 0 to obtain one cumulative cost per remaining
    index. Cumulative costs above the `percentile`-th percentile are discarded, and the
    mean and standard deviation of the rest are reported.

    Args:
        full_costs: costs of the full-observation runs (anything `np.asarray` accepts).
        part_costs: costs of the partial-observation runs, treated the same way.
        percentile: percentile in [0, 100] used as the upper cutoff.

    Returns:
        (full_mean, full_std, part_mean, part_std)
    """
    full_mean, full_std = _trimmed_mean_std(full_costs, percentile)
    part_mean, part_std = _trimmed_mean_std(part_costs, percentile)

    return full_mean, full_std, part_mean, part_std


In [14]:
import numpy as np

def format_meanstd(mean: float, std: float) -> str:
    """Render 'mean ± std' with three decimals; the mean is right-aligned and the std
    left-aligned, each in a 10-character field."""
    mean_field = format(mean, '>10.3f')
    std_field = format(std, '<10.3f')
    return mean_field + ' ± ' + std_field

def format_row(name: str, part_obs: str, full_obs: str) -> str:
    """Build one table row: a leading tab, then three 23-character centred columns
    separated by ' | ', followed by a trailing space."""
    columns = [format(cell, '^23') for cell in (name, part_obs, full_obs)]
    return '\t' + ' | '.join(columns) + ' '

# Column titles of the results table.
header = format_row('β', 'part-obs', 'full-obs')

# Three empty lines of spacing, then the header and a dashed rule of the same length.
print(end='\n' * 3)
print(header)
print('\t' + len(header) * '-')


In [15]:
# Create the checkpoint directory up front. Previously a missing `data/` directory only
# surfaced as an error from `np.savez` after the first (expensive) sweep iteration.
os.makedirs('data', exist_ok=True)

# Per-inverse-temperature results; "full" runs use zero estimation noise, "part" runs
# use `est_noise`.
isc_full_states = []
isc_full_inputs = []
isc_full_costs = []

isc_part_states = []
isc_part_inputs = []
isc_part_costs = []

svmpc_full_states = []
svmpc_full_inputs = []
svmpc_full_costs = []

svmpc_part_states = []
svmpc_part_inputs = []
svmpc_part_costs = []

# Zero noise shared by every full-observation run (hoisted out of the loop).
zero_noise = jnp.zeros((n, horizon, trials))

# MPC baseline: one rollout per trial, trial index on the last axis.
mpc_full_states, \
mpc_full_inputs, \
mpc_full_costs = jax.vmap(mpc_sim_with_noise, in_axes=(1, 2), out_axes=-1)(ic_samples, zero_noise)

mpc_part_states, \
mpc_part_inputs, \
mpc_part_costs = jax.vmap(mpc_sim_with_noise, in_axes=(1, 2), out_axes=-1)(ic_samples, est_noise)

full_mean, full_std, part_mean, part_std = controller_stats(mpc_full_costs, mpc_part_costs, percentile)

print(format_row('mpc', format_meanstd(part_mean, part_std), format_meanstd(full_mean, full_std)))

partial_means = []

# Batched simulators: map over trials for initial conditions (axis 1), inverse
# temperatures and keys (axis 0) and noise (axis 2).
isc_batched = jax.vmap(isc_sim_with_noise, in_axes=(1, 0, 0, 2), out_axes=-1)
svmpc_batched = jax.vmap(svmpc_sim_with_noise, in_axes=(1, 0, 0, 2), out_axes=-1)

for i, inv_temp in tqdm(enumerate(inv_temps), total=len(inv_temps), position=0, leave=True):
    key, subkey = rnd.split(key)

    # The same per-trial keys and inverse temperatures are used for all four runs of
    # this iteration, exactly as before; they are now computed once instead of four times.
    trial_keys = rnd.split(subkey, trials)
    trial_inv_temps = inv_temp * jnp.ones(trials)

    full_results = isc_batched(ic_samples, trial_inv_temps, trial_keys, zero_noise)
    part_results = isc_batched(ic_samples, trial_inv_temps, trial_keys, est_noise)

    # Stored as NumPy arrays, consistent with the SVMPC results below.
    isc_full_states.append(np.asarray(full_results[0]))
    isc_full_inputs.append(np.asarray(full_results[1]))
    isc_full_costs.append(np.asarray(full_results[2]))

    isc_part_states.append(np.asarray(part_results[0]))
    isc_part_inputs.append(np.asarray(part_results[1]))
    isc_part_costs.append(np.asarray(part_results[2]))

    full_results = svmpc_batched(ic_samples, trial_inv_temps, trial_keys, zero_noise)
    part_results = svmpc_batched(ic_samples, trial_inv_temps, trial_keys, est_noise)

    svmpc_full_states.append(np.asarray(full_results[0]))
    svmpc_full_inputs.append(np.asarray(full_results[1]))
    svmpc_full_costs.append(np.asarray(full_results[2]))

    svmpc_part_states.append(np.asarray(part_results[0]))
    svmpc_part_inputs.append(np.asarray(part_results[1]))
    svmpc_part_costs.append(np.asarray(part_results[2]))

    # Only the SVMPC statistics are reported in the table and collected for plotting.
    full_mean, full_std, part_mean, part_std = controller_stats(svmpc_full_costs[-1], svmpc_part_costs[-1], percentile)
    name = f'svmpc-{inv_temp:<18.3e}'.strip()

    tqdm.write(format_row(name, format_meanstd(part_mean, part_std), format_meanstd(full_mean, full_std)))
    partial_means.append(part_mean)

    # Checkpoint everything gathered so far; `inv_temps` is truncated to the values
    # already processed so the file is self-consistent.
    np.savez(f'data/trajectories-in-progress.npz',
             mpc_full_states=mpc_full_states,
             mpc_full_inputs=mpc_full_inputs,
             mpc_full_costs=mpc_full_costs,

             mpc_part_states=mpc_part_states,
             mpc_part_inputs=mpc_part_inputs,
             mpc_part_costs=mpc_part_costs,

             svmpc_full_states=np.stack(svmpc_full_states, axis=-1),
             svmpc_full_inputs=np.stack(svmpc_full_inputs, axis=-1),
             svmpc_full_costs=np.stack(svmpc_full_costs, axis=-1),

             svmpc_part_states=np.stack(svmpc_part_states, axis=-1),
             svmpc_part_inputs=np.stack(svmpc_part_inputs, axis=-1),
             svmpc_part_costs=np.stack(svmpc_part_costs, axis=-1),

             isc_full_states=np.stack(isc_full_states, axis=-1),
             isc_full_inputs=np.stack(isc_full_inputs, axis=-1),
             isc_full_costs=np.stack(isc_full_costs, axis=-1),

             isc_part_states=np.stack(isc_part_states, axis=-1),
             isc_part_inputs=np.stack(isc_part_inputs, axis=-1),
             isc_part_costs=np.stack(isc_part_costs, axis=-1),

             inv_temps=inv_temps[:i + 1])


partial_means = np.array(partial_means)

In [16]:
import shutil

# Plot log(mean partial-observation SVMPC cost) against log(β): matplotlib inside a
# notebook, a gnuplot-backed terminal plot otherwise (when gnuplot is available).
if utils.in_ipynb():
    import matplotlib.pyplot as plt

    plt.figure()
    plt.plot(jnp.log(inv_temps), jnp.log(partial_means), 'ok')
    plt.title('Mean Part-Obs Cost (SVMPC)')
    plt.ylabel('Log Mean Cost')
    plt.xlabel('Log(β)')
    plt.show()
elif shutil.which('gnuplot') is None:
    print("The program `gnuplot' is not installed. Skipping terminal plot.")
else:
    import termplotlib as tpl

    fig = tpl.figure()
    fig.plot(jnp.log(inv_temps), jnp.log(partial_means), width=80, height=25, xlabel='Log(β)',
             title='Log Mean Part-Obs Cost (SVMPC)')
    fig.show()

In [17]:
# Timestamp used to give the final results file a unique name.
now = datetime.datetime.now()
nowstr = now.strftime('%Y.%m.%d.%H.%M.%S')

# Make sure the output directory exists so the save cannot fail on a missing folder.
os.makedirs('data', exist_ok=True)

# Written with `np.savez`, the same call the sweep loop uses for its checkpoint.
np.savez(f'data/trajectories-{nowstr}.npz',
         mpc_full_states=mpc_full_states,
         mpc_full_inputs=mpc_full_inputs,
         mpc_full_costs=mpc_full_costs,

         mpc_part_states=mpc_part_states,
         mpc_part_inputs=mpc_part_inputs,
         mpc_part_costs=mpc_part_costs,

         svmpc_full_states=np.stack(svmpc_full_states, axis=-1),
         svmpc_full_inputs=np.stack(svmpc_full_inputs, axis=-1),
         svmpc_full_costs=np.stack(svmpc_full_costs, axis=-1),

         svmpc_part_states=np.stack(svmpc_part_states, axis=-1),
         svmpc_part_inputs=np.stack(svmpc_part_inputs, axis=-1),
         svmpc_part_costs=np.stack(svmpc_part_costs, axis=-1),

         isc_full_states=np.stack(isc_full_states, axis=-1),
         isc_full_inputs=np.stack(isc_full_inputs, axis=-1),
         isc_full_costs=np.stack(isc_full_costs, axis=-1),

         isc_part_states=np.stack(isc_part_states, axis=-1),
         isc_part_inputs=np.stack(isc_part_inputs, axis=-1),
         isc_part_costs=np.stack(isc_part_costs, axis=-1),

         inv_temps=inv_temps)

# The checkpoint is redundant once the final file exists. Guard the removal so that
# re-running this cell (checkpoint already deleted) does not raise FileNotFoundError.
in_progress_path = 'data/trajectories-in-progress.npz'
if os.path.exists(in_progress_path):
    os.remove(in_progress_path)
