In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import jax.experimental.optimizers as opt

import tqdm

from rationality import dynamics as dyn,\
                        objectives as obj,\
                        distributions as dst,\
                        controllers as ctl,\
                        simulate as sim, \
                        inference as inf

from typing import Tuple

In [3]:
save_data = True

In [4]:
key = rnd.PRNGKey(0)
prior_samples = 10000
trials = 100
horizon = 6 # try 10 to 12

percentile = 95

is_samples = 100000

svgd_samples = 2
svgd_kernel = jax.jit(lambda a, b, c: inf.rbf_kernel(a, b, c, 1.0))
svgd_iters = 1000
svgd_opt = opt.adam(0.005)

noise_style = 'fixed'
noise_scales = [0.05, 0.05, 0.02, 0.1, 0.1, 0.1] #[0.0, 0.1, 0.075, 0.1, 0.1, 0.1]
noise_states = [0, 1, 2, 3, 4, 5] #[0, 1, 2, 3, 4, 5]

inv_temps = jnp.array([0.0, 0.1, 0.5, 1.0, 5.0, 10.0, 50.0, jnp.inf])
prior_ic_cov = jnp.diag(jnp.array([0.25, 0.25, 0.05, 0.1, 0.1, 1e-12]) ** 2)



In [5]:
dt = 0.3 # try ~0.2

Q = jnp.eye(6)
R = 0.1 * jnp.eye(2)
Qf = 100 * jnp.eye(6)

ic = jnp.array([1.0, -1.0, 0.0, 0.0, 0.0, 0.0])

In [6]:
objective = obj.quadratic(Q, R, Qf)
dynamics = dyn.crazyflie2d(dt)
linearized_dynamics = dyn.linear(*dyn.linearize(dynamics, jnp.zeros(6),
                                                jnp.array([dynamics.params.hover_force, 0.0]), 0))

prob = ctl.problem(linearized_dynamics, objective, horizon)

In [7]:
key, subkey = rnd.split(key)

n = prob.num_states
m = prob.num_inputs

mpc = ctl.lqr(prob)
prior_sim = sim.compile_simulation(prob, mpc)
prior_ics = jax.vmap(lambda k: rnd.multivariate_normal(k, ic, prior_ic_cov),
                     out_axes=-1)(rnd.split(subkey, trials))

prior_states, prior_inputs, prior_costs = jax.vmap(lambda x: sim.run(x, jnp.zeros((n, horizon)), prior_sim, prob,
                                                                     mpc), in_axes=1, out_axes=-1)(prior_ics)

In [8]:
prior_cov = jnp.diag(jnp.array([5e-2, 2e-5] * horizon) ** 2)

prior_params = [dst.GaussianParams(jnp.pad(prior_inputs.mean(axis=2)[:, t:].flatten(order='F'),
                                           (0, t * prob.num_inputs)), prior_cov) for t in range(horizon)]

isc = ctl.isc(prob, jnp.inf, is_samples, 0, dst.GaussianPrototype(prob.num_inputs), prior_params)

In [9]:
est_noise = jnp.zeros((n, horizon, trials))

if noise_style.lower() == 'max':
    for state, scale in zip(noise_states, noise_scales):
        key, subkey = rnd.split(key)
        stddev = (scale * jnp.max(jnp.abs(prior_states[state, :, :])))
        est_noise = est_noise.at[state, :, :].set(stddev * rnd.normal(subkey, (horizon, trials)))
        print(f'Estimation Noise for state {state} is N(0, {stddev ** 2:.3f}).')

elif noise_style.lower() == 'fixed':
    for state, scale in zip(noise_states, noise_scales):
        key, subkey = rnd.split(key)
        est_noise = est_noise.at[state, :, :].set(scale * rnd.normal(subkey, (horizon, trials)))
        print(f'Estimation Noise for state {state} is N(0, {scale ** 2:.2e}).')
else:
    raise ValueError(f"Noise style must be one of: 'max', 'varying', 'fixed'")

Estimation Noise for state 0 is N(0, 2.50e-03).
Estimation Noise for state 1 is N(0, 2.50e-03).
Estimation Noise for state 2 is N(0, 4.00e-04).
Estimation Noise for state 3 is N(0, 1.00e-02).
Estimation Noise for state 4 is N(0, 1.00e-02).
Estimation Noise for state 5 is N(0, 1.00e-02).


In [10]:
mpc_sim = sim.compile_simulation(prob, mpc)
isc_sim = sim.compile_simulation(prob, isc)

mpc_sim_with_noise = jax.jit(lambda noise: mpc_sim(ic, noise, prob.params, mpc.params))
isc_sim_with_noise = jax.jit(lambda inv_temp, key, noise:
                             isc_sim(ic, noise, prob.params, ctl.ISCParams(inv_temp, key)))

In [11]:
@jax.jit
def controller_stats(full_costs: jnp.ndarray, part_costs: jnp.ndarray) -> Tuple[float, float, float, float]:
    full_cumm_costs = full_costs.sum(axis=0)
    full_inner_percentile = jnp.percentile(full_cumm_costs, percentile)
    full_mean = jnp.take(full_cumm_costs, full_cumm_costs <= full_inner_percentile).mean()
    full_std = jnp.take(full_cumm_costs, full_cumm_costs <= full_inner_percentile).std()

    part_cumm_costs = part_costs.sum(axis=0)
    part_inner_percentile = jnp.percentile(part_cumm_costs, percentile)
    part_mean = jnp.take(part_cumm_costs, part_cumm_costs <= part_inner_percentile).mean()
    part_std = jnp.take(part_cumm_costs, part_cumm_costs <= part_inner_percentile).std()

    return full_mean, full_std, part_mean, part_std


In [12]:
is_full_states = []
is_full_inputs = []
is_full_costs = []


is_part_states = []
is_part_inputs = []
is_part_costs = []

print(f'        name\t\t                part-obs        \t\tfull-obs        ')

mp_full_states, mp_full_inputs, mp_full_costs = jax.vmap(mpc_sim_with_noise, in_axes=2, out_axes=-1)(jnp.zeros((n, horizon, trials)))
mp_part_states, mp_part_inputs, mp_part_costs = jax.vmap(mpc_sim_with_noise, in_axes=2, out_axes=-1)(est_noise)

full_mean, full_std, part_mean, part_std = controller_stats(mp_full_costs, mp_part_costs)

print(f'{"lqr":^23}\t\t'
          f'{part_mean:>9.3f} ± {part_std:<9.4f}\t\t'
          f'{full_mean:>9.3f} ± {full_std:<9.4f}')

for inv_temp in tqdm.tqdm(inv_temps):
    key, subkey = rnd.split(key)

    full_results = jax.vmap(isc_sim_with_noise, in_axes=(0, 0, 2), out_axes=-1)(inv_temp * jnp.ones(trials),
                                                                                rnd.split(subkey, trials),
                                                                                jnp.zeros((n, horizon, trials)))

    part_results = jax.vmap(isc_sim_with_noise, in_axes=(0, 0, 2), out_axes=-1)(inv_temp * jnp.ones(trials),
                                                                                rnd.split(subkey, trials), est_noise)

    is_full_states.append(full_results[0])
    is_full_inputs.append(full_results[1])
    is_full_costs.append(full_results[2])

    is_part_states.append(part_results[0])
    is_part_inputs.append(part_results[1])
    is_part_costs.append(part_results[2])

    full_mean, full_std, part_mean, part_std = controller_stats(is_full_costs[-1], is_part_costs[-1])
    name = f'isc-{inv_temp:.3}'

    print(f'{name:^23}\t\t'
          f'{part_mean:>9.3f} ± {part_std:<9.4f}\t\t'
          f'{full_mean:>9.3f} ± {full_std:<9.4f}')

        name		                part-obs        		full-obs        
          lqr          		   23.825 ± 2.7994   		    8.637 ± 0.0000   
        isc-0.0        		  303.048 ± 179.6336 		  303.048 ± 179.6335 
        isc-0.1        		   11.117 ± 2.4124   		   13.346 ± 1.3300   
        isc-0.5        		   16.605 ± 3.5161   		   10.979 ± 0.2451   
        isc-1.0        		   21.760 ± 0.3864   		   11.094 ± 0.3800   
        isc-5.0        		   17.654 ± 1.7343   		    9.435 ± 0.1308   
       isc-10.0        		   24.992 ± 3.5591   		    8.937 ± 0.0061   
       isc-50.0        		   20.793 ± 2.1092   		    9.444 ± 0.1337   
        isc-inf        		   21.654 ± 2.5415   		    9.610 ± 0.1201   


100%|██████████| 8/8 [13:55<00:00, 104.48s/it]


In [52]:
if save_data:
    import os
    import pandas as pd

    def make_df(states: jnp.ndarray, inputs: jnp.ndarray, costs: jnp.ndarray,
                name: str, beta: float, vis: str) -> pd.DataFrame:
        return pd.concat([pd.DataFrame({
            'Controller' : pd.Series([name] * (horizon + 1)),
            'Visibility' : pd.Series([vis] * (horizon + 1)),
            'beta' : pd.Series([beta] * (horizon + 1)),
            'Trial' : pd.Series([trial] * (horizon + 1)),
            't' : pd.Series(list(range(horizon + 1))),
            'x' : pd.Series(states[0, :, trial]),
            'y' : pd.Series(states[1, :, trial]),
            'theta' : pd.Series(states[2, :, trial]),
            'x_dot' : pd.Series(states[3, :, trial]),
            'y_dot' : pd.Series(states[4, :, trial]),
            'theta_dot' : pd.Series(states[5, :, trial]),
            'Thrust' : pd.Series(inputs[0, :, trial]),
            'Moment' : pd.Series(inputs[1, :, trial]),
            'Costs' : pd.Series(costs[:, trial])
        }) for trial in range(trials)])

    mp_full_df = make_df(mp_full_states, mp_full_inputs, mp_full_costs, 'LQR', jnp.nan, 'Full')
    mp_part_df = make_df(mp_part_states, mp_part_inputs, mp_part_costs, 'LQR', jnp.nan, 'Part')

    full_dfs = [mp_full_df] + [make_df(is_full_states[i], is_full_inputs[i], is_full_costs[i], 'ISC', inv_temp, 'Full')
                               for i, inv_temp in enumerate(inv_temps)]

    part_dfs = [mp_part_df] + [make_df(is_part_states[i], is_part_inputs[i], is_part_costs[i], 'ISC', inv_temp, 'Part')
                               for i, inv_temp in enumerate(inv_temps)]

    df = pd.concat(full_dfs + part_dfs)

    if os.path.exists('../data/linearized_quad2d/data.pkl'):
        os.remove('../data/linearized_quad2d/data.pkl')

    df.to_pickle('../data/linearized_quad2d/data.pkl')