In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import jax.experimental.optimizers as opt
import jax.scipy as jsp

from tqdm import tqdm

from rationality import dynamics as dyn,\
                        objectives as obj,\
                        distributions as dst,\
                        controllers as ctl,\
                        simulate as sim, \
                        inference as inf

from typing import Tuple

In [3]:
save_data = False

In [4]:
key = rnd.PRNGKey(0)
prior_samples = 10000
trials = 100
horizon = 12 # try 10 to 12

percentile = 95

is_samples = 10000

svgd_samples = 16
svgd_bw = 'dynamic'
svgd_iters = 10000
svgd_opt = opt.adam(1e-4)


noise_style = 'fixed'
noise_scales = [0.2, 0.2, 0.02, 0.02, 0.02, 0.02]
noise_states = [0, 1, 2, 3, 4, 5]

inv_temps = jnp.array([0.0, 1.0, 10.0, 25.0, 50.0, 75.0, 100.0, 125.0, 150.0, 175.0, 200.0, 300.0, 400.0, 500.0, jnp.inf]) / 100
prior_ic_cov = jnp.diag(jnp.array([1e-1, 1e-1, 1e-3, 1e-2, 1e-2, 1e-4]) ** 2)

In [5]:
dt = 0.3 # try ~0.2

Q = jnp.eye(6)
R = 0.1 * jnp.eye(2)
Qf = 100 * jnp.eye(6)

ic = jnp.array([1.0, -1.0, 0.0, 0.0, 0.0, 0.0])
ic_cov = prior_ic_cov

In [6]:
objective = obj.quadratic(Q, R, Qf)
dynamics = dyn.crazyflie2d(dt)
linearized_dynamics = dyn.linear(*dyn.linearize(dynamics, jnp.zeros(6),
                                                jnp.array([dynamics.params.hover_force, 0.0]), 0))

prob = ctl.problem(linearized_dynamics, objective, horizon)

In [7]:
key, subkey = rnd.split(key)

n = prob.num_states
m = prob.num_inputs

mpc = ctl.lqr.create(prob)
prior_sim = sim.compile_simulation(prob, mpc)
prior_ics = jax.vmap(lambda k: rnd.multivariate_normal(k, ic, prior_ic_cov),
                     out_axes=-1)(rnd.split(subkey, trials))

prior_states, prior_inputs, prior_costs = jax.vmap(lambda x: sim.run(x, jnp.zeros((n, horizon)), prior_sim, prob,
                                                                     mpc), in_axes=1, out_axes=-1)(prior_ics)

In [8]:
key, subkey = rnd.split(key)

# prior_cov = jnp.diag(jnp.array([1e-2, 1e-5] * horizon) ** 2)
#
# prior_params = [dst.GaussianParams(jnp.pad(prior_inputs.mean(axis=2)[:, t:].flatten(order='F'),
#                                            (0, t * prob.num_inputs)), prior_cov) for t in range(horizon)]

prior_covs = jnp.stack([jsp.linalg.block_diag(jnp.cov(prior_inputs[:, t:, :].reshape((m * (horizon - t), trials), order='F')),
                                 0 * jnp.eye(t * m)) + 1e-11 * jnp.eye(horizon * m)
           for t in range(horizon)], axis=-1)

prior_params = [dst.GaussianParams(jnp.pad(prior_inputs.mean(axis=2)[:, t:].flatten(order='F'),
                                           (0, t * prob.num_inputs)), prior_covs[:, :, t])
                for t in range(horizon)]

isc = ctl.isc.create(prob, jnp.inf, is_samples, subkey, dst.GaussianPrototype(prob.num_inputs * horizon), prior_params)
svgdc = ctl.svgdc.create(prob, jnp.inf, subkey, svgd_bw, svgd_samples, dst.GaussianPrototype(prob.num_inputs * horizon),
                         prior_params, svgd_opt, svgd_iters)

In [9]:
est_noise = jnp.zeros((n, horizon, trials))

if noise_style.lower() == 'max':
    for state, scale in zip(noise_states, noise_scales):
        key, subkey = rnd.split(key)
        stddev = (scale * jnp.max(jnp.abs(prior_states[state, :, :])))
        est_noise = est_noise.at[state, :, :].set(stddev * rnd.normal(subkey, (horizon, trials)))
        print(f'Estimation Noise for state {state} is N(0, {stddev ** 2:.3f}).')

elif noise_style.lower() == 'fixed':
    for state, scale in zip(noise_states, noise_scales):
        key, subkey = rnd.split(key)
        est_noise = est_noise.at[state, :, :].set(scale * rnd.normal(subkey, (horizon, trials)))
        print(f'Estimation Noise for state {state} is N(0, {scale ** 2:.2e}).')
else:
    raise ValueError(f"Noise style must be one of: 'max', 'varying', 'fixed'")

Estimation Noise for state 0 is N(0, 4.00e-02).
Estimation Noise for state 1 is N(0, 4.00e-02).
Estimation Noise for state 2 is N(0, 4.00e-04).
Estimation Noise for state 3 is N(0, 4.00e-04).
Estimation Noise for state 4 is N(0, 4.00e-04).
Estimation Noise for state 5 is N(0, 4.00e-04).


In [10]:
key, subkey = rnd.split(key)

n = prob.num_states
m = prob.num_inputs

ic_samples = jax.vmap(lambda k: rnd.multivariate_normal(k, ic, ic_cov), out_axes=-1)(rnd.split(subkey, trials))

In [11]:
mpc_sim = sim.compile_simulation(prob, mpc)
isc_sim = sim.compile_simulation(prob, isc)
svgd_sim = sim.compile_simulation(prob, svgdc)

mpc_sim_with_noise = jax.jit(lambda ic_s, noise: mpc_sim(ic, noise, prob.params, mpc.params))

isc_sim_with_noise = jax.jit(lambda ic_s, inv_temp, key, noise:
                             isc_sim(ic, noise, prob.params, ctl.isc.ISCParams(inv_temp, key)))

svgdc_sim_with_noise = jax.jit(lambda ic_s, inv_temp, key, noise:
                                svgd_sim(ic, noise, prob.params,
                                         ctl.svgdc.SVGDCParams(inv_temp, key,
                                                               jnp.nan if svgd_bw == 'dynamic' else svgd_bw)))

In [12]:
def controller_stats(full_costs: jnp.ndarray, part_costs: jnp.ndarray,
                     percentile: float) -> Tuple[float, float, float, float]:
    full_cumm_costs = full_costs.sum(axis=0)
    full_inner_percentile = jnp.percentile(full_cumm_costs, percentile)
    full_selected = full_cumm_costs[full_cumm_costs <= full_inner_percentile]
    full_mean = full_selected.mean()
    full_std = full_selected.std()

    part_cumm_costs = part_costs.sum(axis=0)
    part_inner_percentile = jnp.percentile(part_cumm_costs, percentile)
    part_selected = part_cumm_costs[part_cumm_costs <= part_inner_percentile]
    part_mean = part_selected.mean()
    part_std = part_selected.std()

    return full_mean, full_std, part_mean, part_std


In [None]:
from IPython.display import clear_output

is_full_states = []
is_full_inputs = []
is_full_costs = []

is_part_states = []
is_part_inputs = []
is_part_costs = []

svgdc_full_states = []
svgdc_full_inputs = []
svgdc_full_costs = []

svgdc_part_states = []
svgdc_part_inputs = []
svgdc_part_costs = []

print(f'        name\t\t                part-obs        \t\tfull-obs        ')

mp_full_states, mp_full_inputs, mp_full_costs = jax.vmap(mpc_sim_with_noise, in_axes=(1, 2), out_axes=-1)(ic_samples, jnp.zeros((n, horizon, trials)))
mp_part_states, mp_part_inputs, mp_part_costs = jax.vmap(mpc_sim_with_noise, in_axes=(1, 2), out_axes=-1)(ic_samples, est_noise)

full_mean, full_std, part_mean, part_std = controller_stats(mp_full_costs, mp_part_costs, percentile)

print(f'{"lqr":^23}\t\t'
          f'{part_mean:>9.3f} ± {part_std:<9.4f}\t\t'
          f'{full_mean:>9.3f} ± {full_std:<9.4f}')


for inv_temp in tqdm(inv_temps):
    key, subkey = rnd.split(key)

    full_results = jax.vmap(isc_sim_with_noise, in_axes=(1, 0, 0, 2), out_axes=-1)(ic_samples,
                                                                                   inv_temp * jnp.ones(trials),
                                                                                   rnd.split(subkey, trials),
                                                                                   jnp.zeros((n, horizon, trials)))

    part_results = jax.vmap(isc_sim_with_noise, in_axes=(1, 0, 0, 2), out_axes=-1)(ic_samples,
                                                                                   inv_temp * jnp.ones(trials),
                                                                                   rnd.split(subkey, trials),
                                                                                   est_noise)

    is_full_states.append(full_results[0])
    is_full_inputs.append(full_results[1])
    is_full_costs.append(full_results[2])

    is_part_states.append(part_results[0])
    is_part_inputs.append(part_results[1])
    is_part_costs.append(part_results[2])

    full_mean, full_std, part_mean, part_std = controller_stats(is_full_costs[-1], is_part_costs[-1], percentile)
    name = f'isc-{inv_temp:.3}'

#     tqdm.write(f'{name:^23}\t\t'
#           f'{part_mean:>9.3f} ± {part_std:<9.4f}\t\t'
#           f'{full_mean:>9.3f} ± {full_std:<9.4f}')

    full_results = jax.vmap(svgdc_sim_with_noise, in_axes=(1, 0, 0, 2), out_axes=-1)(ic_samples,
                                                                                   inv_temp * jnp.ones(trials),
                                                                                   rnd.split(subkey, trials),
                                                                                   jnp.zeros((n, horizon, trials)))

    part_results = jax.vmap(svgdc_sim_with_noise, in_axes=(1, 0, 0, 2), out_axes=-1)(ic_samples,
                                                                                   inv_temp * jnp.ones(trials),
                                                                                   rnd.split(subkey, trials),
                                                                                   est_noise)

    svgdc_full_states.append(full_results[0])
    svgdc_full_inputs.append(full_results[1])
    svgdc_full_costs.append(full_results[2])

    svgdc_part_states.append(part_results[0])
    svgdc_part_inputs.append(part_results[1])
    svgdc_part_costs.append(part_results[2])

    full_mean, full_std, part_mean, part_std = controller_stats(svgdc_full_costs[-1], svgdc_part_costs[-1], percentile)
    name = f'svc-{inv_temp:.3}'

    tqdm.write(f'{name:^23}\t\t'
               f'{part_mean:>9.3f} ± {part_std:<9.4f}\t\t'
               f'{full_mean:>9.3f} ± {full_std:<9.4f}')

        name		                part-obs        		full-obs        
          lqr          		   13.370 ± 2.9278   		    6.245 ± 0.0000   


  0%|          | 0/12 [00:00<?, ?it/s]

In [None]:
# if save_data:
#     import os
#     import pandas as pd
#
#     def make_df(states: jnp.ndarray, inputs: jnp.ndarray, costs: jnp.ndarray,
#                 name: str, beta: float, vis: str) -> pd.DataFrame:
#         cumm_costs = costs.sum(axis=0)
#
#         return pd.concat([pd.DataFrame({
#             'Controller' : pd.Series([name] * (horizon + 1)),
#             'Visibility' : pd.Series([vis] * (horizon + 1)),
#             'Discarded' : pd.Series([costs[:, trial].sum() > jnp.percentile(cumm_costs, percentile)] * (horizon + 1)),
#             'beta' : pd.Series([beta] * (horizon + 1)),
#             'Trial' : pd.Series([trial] * (horizon + 1)),
#             't' : pd.Series(list(range(horizon + 1))),
#             'x' : pd.Series(states[0, :, trial]),
#             'y' : pd.Series(states[1, :, trial]),
#             'theta' : pd.Series(states[2, :, trial]),
#             'x_dot' : pd.Series(states[3, :, trial]),
#             'y_dot' : pd.Series(states[4, :, trial]),
#             'theta_dot' : pd.Series(states[5, :, trial]),
#             'Thrust' : pd.Series(inputs[0, :, trial]),
#             'Moment' : pd.Series(inputs[1, :, trial]),
#             'Costs' : pd.Series(costs[:, trial]),
#             'Cumm. Costs' : pd.Series(costs[:, trial].cumsum())
#         }) for trial in range(trials)])
#
#     mp_full_df = make_df(mp_full_states, mp_full_inputs, mp_full_costs, 'LQR', jnp.nan, 'Full')
#     mp_part_df = make_df(mp_part_states, mp_part_inputs, mp_part_costs, 'LQR', jnp.nan, 'Part')
#
#     full_dfs = [mp_full_df] + [make_df(is_full_states[i], is_full_inputs[i], is_full_costs[i], 'ISC', inv_temp, 'Full')
#                                for i, inv_temp in enumerate(inv_temps)]
#
#     part_dfs = [mp_part_df] + [make_df(is_part_states[i], is_part_inputs[i], is_part_costs[i], 'ISC', inv_temp, 'Part')
#                                for i, inv_temp in enumerate(inv_temps)]
#
#     df = pd.concat(full_dfs + part_dfs)
#
#     if os.path.exists('../data/linearized_quad2d/data.pkl'):
#         os.remove('../data/linearized_quad2d/data.pkl')
#
#     df.to_pickle('../data/linearized_quad2d/data.pkl')

In [None]:
print('hi')