In [1]:
import os
import sys

# Put the directory two levels above the working directory (presumably the
# repository root) on the import path so the project-local `rationality`
# package imported below resolves without installation.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import jax.experimental.optimizers as opt
import jax.scipy as jsp

from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

from rationality import dynamics as dyn, objectives as obj, distributions as dst, \
    controllers as ctl, simulate as sim, util as utils

In [3]:
# Toggle for persisting results. It is not read by any of the cells shown here;
# presumably consulted by later cells — TODO confirm.
save_data = False

In [4]:
key = rnd.PRNGKey(0)  # root PRNG key; later cells derive every subkey from it by splitting
prior_samples = 10000
trials = 100  # number of sampled initial conditions / rollouts per simulation
horizon = 12  # try 10 to 12

batch_size = 1000

# Estimation-noise configuration: state noise_states[i] is perturbed with zero-mean
# Gaussian noise governed by noise_scales[i]; noise_style selects how the scale is used.
noise_style = 'fixed'
noise_scales = [0.1, 0.1, 0.01, 0.01, 0.01, 0.01]
noise_states = [0, 1, 2, 3, 4, 5]
# The noise cell pairs these lists with zip(), which silently truncates on a length
# mismatch, so fail loudly here instead.
assert len(noise_scales) == len(noise_states), 'noise_scales and noise_states must have equal length'

# Inverse temperatures (beta) log-spaced over [e^-7, e^7].
inv_temps = jnp.exp(jnp.linspace(-7, 7, 100))
# Prior initial-condition covariance from per-state standard deviations (hence the ** 2).
prior_ic_cov = jnp.diag(jnp.array([1e-1, 1e-1, 1e-3, 1e-2, 1e-2, 1e-4]) ** 2)

In [5]:
dt = 0.3  # time step handed to the dynamics model; try ~0.2

# Quadratic cost weights: identity state penalty, light input penalty,
# and a heavy terminal-state penalty.
Q = jnp.identity(6)
R = jnp.eye(2) * 0.1
Qf = jnp.eye(6) * 100

# Nominal initial state (mean of the sampled initial conditions) and its spread,
# which reuses the prior initial-condition covariance.
ic = jnp.array([1.0, -1.0] + [0.0] * 4)
ic_cov = prior_ic_cov

In [6]:
# Quadratic objective from the weights above and the `crazyflie2d` dynamics model.
objective = obj.quadratic(Q, R, Qf)
dynamics = dyn.crazyflie2d(dt)

# Linearization point: zero state and input [hover_force, 0.0].
lin_state = jnp.zeros(6)
lin_input = jnp.array([dynamics.params.hover_force, 0.0])
linearized_dynamics = dyn.linear(*dyn.linearize(dynamics, lin_state, lin_input, 0))

# Control problem over `horizon` steps built on the linearized model.
prob = ctl.problem(linearized_dynamics, objective, horizon)

In [7]:
key, subkey = rnd.split(key)

n = prob.num_states
m = prob.num_inputs

# Baseline LQR controller and its compiled closed-loop simulator.
lqr = ctl.lqr.create(prob)
prior_sim = sim.compile_simulation(prob, lqr)


def _draw_prior_ic(k):
    """Sample one initial state from N(ic, prior_ic_cov) with key `k`."""
    return rnd.multivariate_normal(k, ic, prior_ic_cov)


def _rollout_from(x0):
    """Run the LQR simulation from `x0` with an all-zero (n, horizon) second argument."""
    return sim.run(x0, jnp.zeros((n, horizon)), prior_sim, prob, lqr)


# One initial condition per trial, stacked along the last axis -> (len(ic), trials).
prior_ics = jax.vmap(_draw_prior_ic, out_axes=-1)(rnd.split(subkey, trials))
# Roll out every column of prior_ics; each output gains a trailing trial axis.
prior_states, prior_inputs, prior_costs = jax.vmap(_rollout_from, in_axes=1, out_axes=-1)(prior_ics)

In [8]:
# `subkey` is not used in this cell, but the split still advances `key`; removing it
# would shift every random draw in later cells, so it is kept for reproducibility.
key, subkey = rnd.split(key)

# Small diagonal jitter, presumably so each empirical covariance is not exactly singular
# when an input barely varies across trials.
COV_JITTER = 1e-11

# Empirical (m, m) covariance of the LQR inputs across trials at each time step,
# stacked along the last axis -> (m, m, horizon). Rows of prior_inputs[:, t, :] are
# input components and columns are trials, which is what jnp.cov expects.
prior_covs = jnp.stack([jnp.cov(prior_inputs[:, t, :]) + COV_JITTER * jnp.eye(m) for t in range(horizon)],
                       axis=-1)

# Trial-averaged input at every time step, computed once instead of once per loop iteration.
prior_means = prior_inputs.mean(axis=2)

# One Gaussian prior per time step: mean input and its empirical covariance.
prior_params = [dst.GaussianParams(prior_means[:, t], prior_covs[:, :, t]) for t in range(horizon)]

In [9]:
# Additive estimation noise for every state, time step and trial; states not listed in
# noise_states stay noise-free.
est_noise = jnp.zeros((n, horizon, trials))

# Normalize once and reject unsupported styles before drawing any random numbers.
noise_mode = noise_style.lower()
if noise_mode not in ('max', 'fixed'):
    raise ValueError(f"Noise style must be one of: 'max', 'fixed' (got {noise_style!r})")

for state, scale in zip(noise_states, noise_scales):
    key, subkey = rnd.split(key)
    if noise_mode == 'max':
        # Standard deviation relative to the largest magnitude this state reaches in the
        # prior rollouts.
        stddev = scale * jnp.max(jnp.abs(prior_states[state, :, :]))
    else:  # 'fixed': the scale itself is the standard deviation
        stddev = scale
    est_noise = est_noise.at[state, :, :].set(stddev * rnd.normal(subkey, (horizon, trials)))
    print(f'Estimation Noise for state {state} is N(0, {stddev ** 2:.2e}).')

Estimation Noise for state 0 is N(0, 1.00e-02).
Estimation Noise for state 1 is N(0, 1.00e-02).
Estimation Noise for state 2 is N(0, 1.00e-04).
Estimation Noise for state 3 is N(0, 1.00e-04).
Estimation Noise for state 4 is N(0, 1.00e-04).
Estimation Noise for state 5 is N(0, 1.00e-04).


In [10]:
# Fresh subkey for the LQBR controller; the global `subkey` bound here is also read by
# the cost-to-go closures defined in the next cell.
key, subkey = rnd.split(key)
# NOTE(review): the positional 1.0 looks like an initial inverse temperature (beta), the
# first field of LQBRParams as used in later cells — confirm against ctl.lqbr.create.
lqbr = ctl.lqbr.create(prob, prior_params, 1.0, subkey)
# Compiled closed-loop simulator for the LQBR controller.
lqbr_sim = sim.compile_simulation(prob, lqbr)

In [11]:
# Freeze the controller key now. `subkey` is reassigned by many cells, and a jitted
# closure reads globals when it is first traced (first call), not when it is defined.
ctg_key = subkey


def _lqbr_params(beta):
    """LQBR parameters for inverse temperature `beta`, reusing the controller's prior."""
    return ctl.lqbr.LQBRParams(beta, ctg_key, lqbr.params.prior_params)


# Cost-to-go from (ic, ic_cov) as a function of beta, without estimation noise.
ctg_full_obs = jax.jit(lambda beta: ctl.lqbr.cost_to_go(prob, _lqbr_params(beta), ic, ic_cov))

# Estimation noise is `scale * N(0, 1)` per state, so its covariance holds the squared
# scales (variances), just as prior_ic_cov squares its standard deviations. Indexing by
# noise_states puts each variance on the diagonal entry of the state it perturbs.
# NOTE(review): this matches noise_style == 'fixed'; under 'max' the actual std devs are
# rescaled by prior state magnitudes and this covariance would not match them.
noise_var = jnp.zeros(n).at[jnp.array(noise_states)].set(jnp.array(noise_scales) ** 2)
noise_cov = jnp.repeat(jnp.diag(noise_var)[:, :, None], horizon, axis=-1)  # (n, n, horizon)

# Same cost-to-go, additionally given the per-step noise covariance (presumably the
# partially observed case).
ctg_part_obs = jax.jit(lambda beta: ctl.lqbr.cost_to_go(prob, _lqbr_params(beta), ic, ic_cov, noise_cov))

In [12]:
@jax.jit
def vector_metric(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Distance between two arrays via their outer products and their difference.

    Both inputs are flattened (row-major) to vectors ``u`` and ``v`` and compared as

        0.5 * ||u u^T - v v^T||_F + ||u - v||_2

    Args:
        x: Array of any shape; flattened before comparison.
        y: Array with the same number of elements as ``x``.

    Returns:
        Scalar distance.
    """
    u = jnp.ravel(x)
    v = jnp.ravel(y)
    # Frobenius norm of the gap between the rank-one outer products.
    outer_gap = jnp.linalg.norm(jnp.outer(u, u) - jnp.outer(v, v), ord='fro')
    # Plain vector 2-norm; taking ord=2 of a (k, 1) column gives the same value but
    # computes it through an SVD (spectral norm).
    diff_norm = jnp.linalg.norm(u - v)
    return 0.5 * outer_gap + diff_norm

@jax.jit
def traj_metric(x: jnp.ndarray, y: jnp.ndarray, weights: jnp.ndarray) -> jnp.ndarray:
    """Weighted combination of ``vector_metric`` over paired slices of two arrays.

    Args:
        x, y: Arrays whose last axis is mapped over pairwise (presumably time steps).
        weights: Weights applied to the per-slice distances (length of that last axis).

    Returns:
        ``sum_t weights[t] * vector_metric(x[..., t], y[..., t])``.
    """
    per_slice_metric = jax.vmap(vector_metric, in_axes=(-1, -1))
    slice_distances = per_slice_metric(x, y)
    return slice_distances @ weights

In [13]:
# Evaluate both cost-to-go curves over the whole inverse-temperature grid.
full_obs, part_obs = (jax.vmap(ctg)(inv_temps) for ctg in (ctg_full_obs, ctg_part_obs))

# Grid point with the smallest cost-to-go under estimation noise.
best_inv_temp = inv_temps[part_obs.argmin()]

In [16]:
key, ic_key, ctl_key = rnd.split(key, 3)
ctl_keys = rnd.split(ctl_key, trials)
# One sampled initial condition per trial, stacked along the last axis -> (len(ic), trials).
ics = jax.vmap(lambda k: rnd.multivariate_normal(k, ic, ic_cov), out_axes=-1)(rnd.split(ic_key, trials))


def _lqbr_rollout(x0, trial_key):
    """Simulate the LQBR controller from `x0`, using `trial_key` as the controller key."""
    # NOTE(review): the inverse temperature here is 0.0 rather than best_inv_temp from the
    # sweep, and the second argument is all zeros (same shape as one trial's slice of
    # est_noise) rather than the estimation noise drawn earlier — confirm both are intended.
    return lqbr_sim(x0, jnp.zeros((n, horizon)), prob.params,
                    ctl.lqbr.LQBRParams(0.0, trial_key, lqbr.params.prior_params))


# Map over the columns of `ics` (one per trial) together with one key per trial.
states, inputs, costs = jax.vmap(_lqbr_rollout, in_axes=(1, 0), out_axes=-1)(ics, ctl_keys)

In [None]:
# Refresh matplotlib's font cache. `font_manager._rebuild` is a private helper that no
# longer exists in newer matplotlib (3.6+), so only call it when it is available.
if hasattr(fm, '_rebuild'):
    fm._rebuild()
plt.style.reload_library()
plt.style.use(['science', 'notebook', 'ieee'])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

plt.figure()
# Reuse the sweeps computed earlier instead of re-evaluating both cost-to-go functions.
plt.plot(jnp.log(inv_temps), full_obs, label='Full observation')
plt.plot(jnp.log(inv_temps), part_obs, label='Partial observation')

# Mark the inverse temperature that minimizes the cost under estimation noise.
plt.vlines(jnp.log(best_inv_temp), 0.0, part_obs.max(), colors='gray', linestyles='dotted')

# Raw strings keep the LaTeX backslashes intact ('\m' is an invalid escape in a normal string).
plt.xlabel(r'$\log(\beta)$')
plt.ylabel(r'$\mathbb{E}[C]$')