In [1]:
import os
import sys

# Absolute path of the directory two levels above the current working directory.
# It is put on the import path once (presumably so the `rationality` package
# imported in the next cell can be found -- verify against the repo layout).
module_path = os.path.abspath('../../')

if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import jax.experimental.optimizers as opt
import jax.scipy as jsp

import datetime
import shutil

import numpy as np

from tqdm import tqdm

from rationality import dynamics as dyn, objectives as obj, distributions as dst, \
    controllers as ctl, simulate as sim, util as utl, types as typ

In [None]:
# ---- Experiment configuration: everything a reader might want to tune ----

# Root PRNG key (seed 0); later cells split it whenever fresh randomness is needed.
key = rnd.PRNGKey(0)
# Number of sampled initial conditions / noise sequences per experiment.
trials = 1000
# Planning horizon handed to the control problem; also sizes the noise arrays.
horizon = 12 # try 10 to 12

# Trials whose cumulative cost lies above this percentile are dropped before the
# mean/std are computed in trajectory_stats; 100 keeps every trial.
percentile = 100

# Settings forwarded to ctl.svmpc.create (exact semantics are defined by that
# constructor -- the names suggest particle count, kernel bandwidth rule,
# optimiser iterations and gradient clipping; TODO confirm).
# NOTE(review): `opt` is jax.experimental.optimizers, which later JAX releases
# relocated to jax.example_libraries.optimizers -- confirm the pinned JAX version.
svmpc_samples = 16
svmpc_bw = 'dynamic'
svmpc_iters = 2000
svmpc_opt = opt.adam(1e-1)
svmpc_clip = 1.0
svmpc_clip_ord = jnp.inf

# Optimiser and iteration count forwarded to ctl.mpc.create.
mpc_opt = opt.adam(1e-1)
mpc_iters = 2000

# NOTE(review): not referenced anywhere in the visible part of the notebook.
noise_style = 'fixed'

# Per-state noise standard deviations (squared into a diagonal covariance later).
relative_noise_stds = jnp.array([0.25, 0.25, 0.1, 0.25, 0.25, 0.1])
# Each experiment multiplies the sampled relative noise by one of these factors.
noise_scale_coeffs = jnp.array([3.0, 2.0, 1.0, 0.0])

# Inverse temperatures swept for SVMPC: 0, then 30 log-spaced values from 1e3 to
# 1e6, then +inf (32 values in total).
inv_temps = jnp.concatenate([jnp.array([0.0]), 10 ** (jnp.linspace(3, 6, 30)), jnp.array([jnp.inf])])

In [None]:
# First argument handed to dyn.quad2d below (presumably the time step -- TODO confirm).
dt = 0.3  # try ~0.2

# Quadratic cost weights: 6x6 stage/terminal weights and a 2x2 input weight,
# all scaled identities.
Q = 1.0 * jnp.eye(6)
Qf = 100 * jnp.eye(6)
R = 0.1 * jnp.eye(2)

# Gaussian initial-condition distribution: mean `ic` and a diagonal covariance
# built from per-state standard deviations.
ic = jnp.array([1.0, -1.0, 0.0, 0.0, 0.0, 0.0])
ic_stds = jnp.array([1e-1, 1e-1, 1e-3, 1e-2, 1e-2, 1e-4])
ic_cov = jnp.diag(ic_stds ** 2)

In [None]:
# 2-D quadrotor model from the project's dynamics module. `dt` is the first
# argument; the remaining constants are passed positionally (presumably physical
# parameters such as mass / gravity / inertia -- TODO confirm against dyn.quad2d).
dynamics = dyn.quad2d(dt, 1.0, 9.82, 1.0)
# Nominal input: the model's hover force on the first channel, zero on the second.
hover_input = jnp.array([dynamics.params.hover_force, 0.0])

# Quadratic objective built from Q, R, Qf, with `hover_input` supplied as the
# input offset.
objective = obj.quadratic(Q, R, Qf, input_offset=hover_input)
# Finite-horizon control problem combining the dynamics and the objective.
prob = ctl.problem(dynamics, objective, horizon)

In [None]:
# Problem dimensions, reused below to size noise and prior arrays.
n = prob.num_states
m = prob.num_inputs

# (m, horizon) initial guess for the MPC optimiser: `hover_input` broadcast
# across every column (assumes hover_input has m entries).
mpc_initial_inputs = jnp.zeros((m, horizon)) + hover_input.reshape((-1, 1))

# Deterministic MPC controller using the optimiser and iteration count from the
# configuration cell.
mpc = ctl.mpc.create(prob, mpc_opt, mpc_iters, initial_inputs=mpc_initial_inputs)
# NOTE(review): `prior_sim` is not used in the visible part of the notebook; the
# same simulation is compiled again later as `mpc_sim`.
prior_sim = sim.compile_simulation(prob, mpc)

In [None]:
# Linearise the quadrotor dynamics about the zero state and the hover input
# (the trailing 0 is passed positionally to dyn.linearize -- presumably a time
# index; TODO confirm) and wrap the result as a linear dynamics model.
linearized_dynamics = dyn.linear(*dyn.linearize(dynamics, jnp.zeros(6), hover_input, 0))

# Same objective and horizon as `prob`, but with the linearised dynamics.
lin_prob = ctl.problem(linearized_dynamics, objective, horizon)

# Input statistics of the LQR solution for the Gaussian initial condition
# (ic, ic_cov); the following cell reads `.mean` and `.cov` from this object.
prior_params = ctl.lqr.input_stats(lin_prob, ic, ic_cov)

In [None]:
# One (horizon*m x horizon*m) diagonal prior covariance per time step t.
# For step t, the diagonals of prior_params.cov[s] for s = t .. horizon-1 are
# concatenated and zero-padded at the end with m*t entries, so every matrix has
# the same size; a 1e-12 ridge is added along the whole diagonal so the padded
# entries are not exactly zero.
prior_covs = jnp.stack([1e-12 * jnp.eye(horizon * m) + jnp.diag(jnp.pad(jnp.concatenate([jnp.diag(prior_params.cov[s, :, :]) for s in range(t, horizon)]), ((0, m * t)))) for t in range(horizon)])
# Global scale factor on the prior covariances (currently 1, i.e. a no-op).
prior_covs = 1 * prior_covs

# NOTE(review): this rebinds `prior_params` from the LQR statistics object to a
# list of per-step Gaussian parameters, so this cell fails if re-run without
# first re-running the cell that computes the LQR statistics.
# For step t the mean is prior_params.mean[t:, :] plus the hover input, flattened
# row by row (transpose + Fortran-order flatten) and zero-padded at the end with
# t*m entries; the covariance is prior_covs[t].
prior_params = [dst.GaussianParams(jnp.pad((prior_params.mean[t:, :] + hover_input.reshape((1, -1))).T.flatten(order='F'),
                                           (0, t * m)), prior_covs[t, :, :]) for t in range(horizon)]

In [None]:
# SVMPC controller over the full input sequence (a Gaussian prototype of
# dimension num_inputs * horizon) using the per-step priors built above.
# The second positional argument (jnp.inf) is presumably the default inverse
# temperature, which later cells override per run via SVMPCParams -- TODO confirm.
svmpc = ctl.svmpc.create(prob, jnp.inf, svmpc_bw, svmpc_samples,
                         dst.GaussianPrototype(prob.num_inputs * horizon),
                         prior_params, svmpc_opt, svmpc_iters,
                         clip=svmpc_clip, clip_ord=svmpc_clip_ord)

In [None]:
# Diagonal noise covariance: the per-state standard deviations squared.
relative_noise_cov = jnp.diag(relative_noise_stds ** 2)

In [None]:
# Derive independent keys for the initial conditions and the noise, and advance
# the root key so later cells never reuse these streams.
key, ic_subkey, noise_subkey = rnd.split(key, 3)

# One initial condition per trial, stacked along the last axis (one column per trial).
ic_samples = jax.vmap(lambda k: rnd.multivariate_normal(k, ic, ic_cov), out_axes=-1)(rnd.split(ic_subkey, trials))

# One noise sequence per trial, stacked along the last axis. Each draw is
# (horizon, 6) and is transposed so time runs along the second axis.
# Each trial must sample with its own key `k`: sampling with the outer `key`
# would give every trial the same noise sequence and ignore `noise_subkey`.
relative_noise_samples = jax.vmap(lambda k: rnd.multivariate_normal(k, jnp.zeros(6), relative_noise_cov, shape=(horizon,)).T, out_axes=-1)(rnd.split(noise_subkey, trials))

In [None]:
# Compile a closed-loop simulation of the problem for each controller.
mpc_sim = sim.compile_simulation(prob, mpc)
svmpc_sim = sim.compile_simulation(prob, svmpc)

# JIT-compiled wrappers with flat positional arguments so they can be vmapped
# over trials in the experiment loop.
mpc_sim_with_noise = jax.jit(lambda ic, noise: mpc_sim(ic, noise))
# The SVMPC wrapper additionally takes the inverse temperature (wrapped in
# SVMPCParams) and a PRNG key for the run.
svmpc_sim_with_noise = jax.jit(lambda ic, inv_temp, key, noise: svmpc_sim.run_with_params(ic, noise, ctl.svmpc.SVMPCParams(inv_temp), key))

In [None]:
def trajectory_stats(traj: typ.Trajectory, percentile: int) -> tuple[float, float]:
    """Summarise the cumulative costs of a batch of trajectories.

    The costs in ``traj.costs`` are summed along axis 1 (assumed to be the
    time axis, giving one total per trial -- TODO confirm). Totals above the
    requested percentile of all totals are discarded, and the mean and standard
    deviation of the remaining totals are returned.

    Args:
        traj: Batched trajectory exposing a ``costs`` array.
        percentile: Percentile (0-100) used as the inclusive upper cut-off;
            100 keeps every trial.

    Returns:
        ``(mean, std)`` of the retained cumulative costs.
    """
    totals = traj.costs.sum(axis=1)
    cutoff = np.percentile(totals, percentile)
    kept = totals[totals <= cutoff]

    return kept.mean(), kept.std()


In [None]:
def format_meanstd(mean: float, std: float) -> str:
    """Render ``mean ± std`` with three decimals in two 10-character fields.

    The mean is right-aligned and the standard deviation left-aligned, so the
    ``±`` signs line up when several rows are printed underneath each other.
    """
    mean_field = format(mean, '>10.3f')
    std_field = format(std, '<10.3f')
    return mean_field + ' ± ' + std_field

def format_row(name: str, part_obs: str, full_obs: str) -> str:
    """Build one table row: a leading tab and three centred 23-character cells.

    Cells are separated by ``' | '`` and the row ends with a single space.
    Values longer than 23 characters are not truncated.
    """
    cells = (format(cell, '^23') for cell in (name, part_obs, full_obs))
    return '\t' + ' | '.join(cells) + ' '

In [None]:
print(f'Inverse Temperatures: {jnp.array_str(inv_temps, precision=3)}\n\n')

# np.savez does not create missing parent directories. Make sure the output
# folder exists up front, so a long simulation is not lost to a missing directory.
os.makedirs('data', exist_ok=True)

# One experiment per noise scale: simulate MPC and the full SVMPC inverse-temperature
# sweep on the same initial conditions and (scaled) noise, then save the results.
for j, noise_scale_coeff in enumerate(noise_scale_coeffs):
    if j == 0:
        # One-off noise-free sanity check, run before the first experiment only:
        # report the mean terminal state of MPC and of SVMPC at infinite inverse
        # temperature.
        key, subkey = rnd.split(key)
        mpc_full = jax.vmap(mpc_sim_with_noise, in_axes=(1, 2))(ic_samples, jnp.zeros((n, horizon, trials)))
        svmpc_inf_full = jax.vmap(svmpc_sim_with_noise, in_axes=(1, 0, 0, 2))(ic_samples,
                                                                              jnp.inf * jnp.ones(trials),
                                                                              rnd.split(subkey, trials),
                                                                              jnp.zeros((n, horizon, trials)))

        print(f'Fully-Observable MPC Mean Terminal State:       {jnp.array_str(mpc_full.states.mean(axis=0)[:, -1])}')
        print(f'Fully-Observable SVMPC-INF Mean Terminal State: {jnp.array_str(svmpc_inf_full.states.mean(axis=0)[:, -1])}')

    print(f'Conducting Experiments for Noise Scale Coeff {noise_scale_coeff:.2f}\n\n\n')
    noise_samples = noise_scale_coeff * relative_noise_samples
    full_file_name = f'data/noise-scale-coeff-{noise_scale_coeff:.2f}.npz'
    tmp_file_name = f'data/noise-scale-coeff-{noise_scale_coeff:.2f}-in-progress.npz'

    # Structured results of every SVMPC run so far, in the order of `inv_temps`.
    svmpc_part = []

    # MPC baseline on the noisy problem; trials are batched over the last axis
    # of the initial conditions and of the noise.
    mpc_part = jax.vmap(mpc_sim_with_noise, in_axes=(1, 2))(ic_samples, noise_samples)
    part_mean, part_std = trajectory_stats(mpc_part, percentile)

    # NOTE: the full-observation sweep is currently disabled, so the 'full-obs'
    # column only shows placeholder zeros.
    header = format_row('β', 'part-obs', 'full-obs')
    print('\n' * 3)
    print(header)
    print('\t' + '-' * len(header))
    print(format_row('mpc', format_meanstd(part_mean, part_std), format_meanstd(0.0, 0.0)))

    partial_means = []

    for i, inv_temp in tqdm(enumerate(inv_temps), total=len(inv_temps), position=0, leave=True):
        # Fresh randomness for every inverse temperature; one key per trial.
        key, subkey = rnd.split(key)

        part_results = jax.vmap(svmpc_sim_with_noise, in_axes=(1, 0, 0, 2))(ic_samples,
                                                                            inv_temp * jnp.ones(trials),
                                                                            rnd.split(subkey, trials),
                                                                            noise_samples)

        svmpc_part.append(part_results.asnumpy().structured())

        part_mean, part_std = trajectory_stats(part_results, percentile)
        name = f'svmpc-{inv_temp:<18.3e}'.strip()

        # tqdm.write keeps the table rows from clobbering the progress bar.
        tqdm.write(format_row(name, format_meanstd(part_mean, part_std), format_meanstd(0.0, 0.0)))
        partial_means.append(part_mean)

        # Checkpoint after every inverse temperature. `svmpc_part` now holds
        # i + 1 results, so the saved temperatures must be sliced to i + 1 as
        # well; otherwise the checkpoint's arrays would not line up.
        np.savez(tmp_file_name,
                 mpc=mpc_part.asnumpy().structured(),
                 svmpc=svmpc_part,
                 noise_scale_coeffs=noise_scale_coeffs,
                 relative_noise_stds=relative_noise_stds,
                 inv_temps=inv_temps[:i + 1])

    partial_means = np.array(partial_means)

    # Plot log mean cost against log inverse temperature: in the terminal via
    # gnuplot/termplotlib when run as a script, with matplotlib inside a notebook.
    if not utl.in_ipynb():
        if shutil.which('gnuplot') is not None:
            import termplotlib as tpl

            print('\n' * 3)

            fig = tpl.figure()
            fig.plot(jnp.log(inv_temps), jnp.log(partial_means), width=80, height=25, xlabel='Log(β)',
                     title='Log Mean Part-Obs Cost (SVMPC)')
            fig.show()
        else:
            print("The program `gnuplot' is not installed. Skipping terminal plot.")
    else:
        import matplotlib.pyplot as plt

        plt.figure()
        plt.plot(jnp.log(inv_temps), jnp.log(partial_means), 'ok')
        plt.title('Mean Part-Obs Cost (SVMPC)')
        plt.ylabel('Log Mean Cost')
        plt.xlabel('Log(β)')
        plt.show()

    now = datetime.datetime.now()
    nowstr = now.strftime('%Y.%m.%d.%H.%M.%S')

    # Write the complete results, then drop the in-progress checkpoint.
    np.savez(full_file_name,
             mpc=mpc_part.asnumpy().structured(),
             svmpc=svmpc_part,
             noise_scale_coeffs=noise_scale_coeffs,
             relative_noise_stds=relative_noise_stds,
             inv_temps=inv_temps)

    os.remove(tmp_file_name)