In [None]:
import os
import sys
import pathlib

# Make the directory two levels up importable (presumably the repo root that holds
# the `rationality` package imported below — confirm if the notebook is moved).
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))

if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import jax.experimental.optimizers as opt
import tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches

from rationality import dynamics as dyn, objectives as obj, distributions as dst,\
    controllers as ctl, simulate as sim, geometry as geom, util, inference as inf

from mpl_toolkits.axes_grid1 import make_axes_locatable

from typing import Optional, Callable
from functools import partial

In [None]:
# Gradient-ascent settings for the disturbance search.
disturbance_iters = 100
step_size = jnp.array(0.1)

# Backtracking settings: btl_iters candidates per ascent step, each shrinking the
# step by a further factor of btl_scale.
btl_iters = 1_000
btl_scale = jnp.array(0.9)

# Radius of the feasible set used by find_disturbance.
r = 10.0

# Beta values to evaluate. For a sweep, e.g. jnp.exp(jnp.linspace(-1.0, 3.0, 20)).
betas = jnp.array([1.0])
prior_std = 1.0

# Sample counts: prior samples shared by all trials, and trials per beta.
number_of_prior_samples = 100
trials = 100

In [None]:
@jax.jit
def metric(x: float, x_hat: float) -> float:
    """Discrepancy between ``x`` and ``x_hat``.

    Computes ``0.5 * |x**2 - x_hat**2| + |x - x_hat|``: half the absolute gap between
    the squares plus the absolute gap between the values. Symmetric in its arguments.
    """
    squares_gap = jnp.abs(x ** 2 - x_hat ** 2)
    values_gap = jnp.abs(x - x_hat)
    return 0.5 * squares_gap + values_gap

@partial(jax.jit, static_argnums=0)
def btl(pred: Callable[[jnp.ndarray], bool], feasible_point: jnp.ndarray, new_point: jnp.ndarray, scaling: float):
    """Generate backtracking candidates between a feasible point and a proposed point.

    Candidate ``k`` (for ``k = 0 .. btl_iters - 1``) is
    ``feasible_point + scaling**k * (new_point - feasible_point)``: the full step first,
    then geometrically shorter steps. ``feasible_point`` itself is appended as a final
    candidate, so callers always have a fallback.

    Args:
        pred: Feasibility predicate. It is a static argument under ``jit``, so it must be hashable.
        feasible_point: Start of the segment (scalar or array).
        new_point: Proposed end of the segment, same shape as ``feasible_point``.
        scaling: Shrink factor applied to the step fraction between consecutive candidates.

    Returns:
        ``(feasible, points)``. ``feasible`` has shape ``(btl_iters + 1,)``. ``points`` has
        shape ``(btl_iters + 1,) + shape(feasible_point)``, with candidates stacked along
        a new leading axis.
    """
    # Loop-invariant: the search direction does not depend on the step fraction.
    direction = new_point - feasible_point

    def btl_scanner(fraction: float, _: None) -> tuple[float, tuple[bool, jnp.ndarray]]:
        test_point = feasible_point + fraction * direction
        return fraction * scaling, (pred(test_point), test_point)

    feas, points = jax.lax.scan(btl_scanner, 1.0, None, length=btl_iters)[1]

    # Stack the fallback along a new leading axis. jnp.append would flatten both
    # operands, which is only correct when the points are scalars.
    fallback_feas = jnp.asarray(pred(feasible_point))[jnp.newaxis]
    fallback_point = jnp.asarray(feasible_point)[jnp.newaxis]
    return jnp.concatenate([feas, fallback_feas]), jnp.concatenate([points, fallback_point])


@jax.jit
def find_disturbance(x_hat: float, u: float, r: float, prior_samples: jnp.ndarray, beta: float) -> float:
    """Search for a disturbance ``d`` that maximizes ``objective`` within a metric ball.

    Performs ``disturbance_iters`` steps of gradient ascent starting from ``d = 0``.
    Each step works as follows:

    - Propose ``d + step_size * grad(d)``.
    - Use ``btl`` to generate backtracked candidates between the current and proposed
      disturbance.
    - Take the feasible candidate with the highest objective as the next iterate.

    A disturbance is feasible when ``metric(x_hat, x_hat - d) <= r``.

    Args:
        x_hat: Observed state.
        u: Not used by this implementation; kept so existing call sites keep working.
        r: Radius of the feasible set.
        prior_samples: Samples of the control over which ``objective`` averages.
        beta: Coefficient on the Hamiltonian inside ``log_prob``.

    Returns:
        The disturbance held after the final ascent step.
    """
    def predicate(test_point: jnp.ndarray) -> bool:
        # Feasible when the implied state x_hat - d stays within metric distance r of x_hat.
        return metric(x_hat, x_hat - test_point) <= r

    def objective(d: jnp.ndarray) -> float:
        # log( sum_i w_i * H(x_hat - d, u_i) / sum_i w_i ) with w_i = exp(log_prob(x_hat - d, u_i, beta)):
        # the log of a self-normalized weighted mean of the Hamiltonian over prior_samples.
        # NOTE(review): prior_samples are presumably drawn from the N(0, prior_std) prior,
        # and log_prob also adds the prior log-density, so the prior may be counted twice — confirm.
        x = x_hat - d
        prior_hamiltonian_values = jax.vmap(lambda prior: hamiltonian(x, prior))(prior_samples)
        logits = jax.vmap(lambda prior: log_prob(x, prior, beta))(prior_samples)
        unnormalized = jax.scipy.special.logsumexp(logits, b=prior_hamiltonian_values)

        return unnormalized - jax.scipy.special.logsumexp(logits)

    objective_grad = jax.grad(objective)

    def opt_scanner(current_disturbance: jnp.ndarray, _: None) -> tuple[jnp.ndarray, tuple[float, jnp.ndarray]]:
        test_disturbance = current_disturbance + step_size * objective_grad(current_disturbance)
        feasible, disturbances = btl(predicate, current_disturbance, test_disturbance, btl_scale)
        values = jax.vmap(objective)(disturbances)

        # objective is a log-value and can be any real number (including below -1), so
        # infeasible candidates are masked with -inf rather than a finite sentinel. The
        # last candidate is current_disturbance itself, which stays feasible whenever the
        # previous iterate was, so a feasible choice is normally available.
        best_idx = jnp.argmax(jnp.where(feasible, values, -jnp.inf))
        best = disturbances[best_idx]

        # Per-step trajectory record: (objective value, chosen disturbance).
        return best, (values[best_idx], best)

    best_disturbance, _trajectory = jax.lax.scan(opt_scanner, jnp.array(0.0), None, length=disturbance_iters)

    return best_disturbance


@jax.jit
def hamiltonian(x: float, u: float) -> float:
    """Quadratic cost ``0.5 * (x - u)**2``, returned as an array of JAX's default float dtype."""
    residual = x - u
    return jnp.asarray(0.5 * residual ** 2, dtype=float)

@jax.jit
def lipschitz(u: float) -> float:
    """Return ``max(|u|, 0.5)``.

    NOTE(review): presumably a Lipschitz bound associated with ``u``; confirm how it is used.
    """
    magnitude = jnp.abs(u)
    return jnp.maximum(0.5, magnitude)

In [None]:
@jax.jit
def log_prob(x: float, u: float, beta: float) -> float:
    """Unnormalized log-weight of control ``u`` at state ``x``.

    Equals the log-density of ``u`` under a zero-mean normal with scale ``prior_std``,
    minus ``beta`` times ``hamiltonian(x, u)``.
    """
    log_prior = jax.scipy.stats.norm.logpdf(u, scale=prior_std)
    cost = hamiltonian(x, u)
    return log_prior - beta * cost

In [None]:


@jax.jit
def conduct_trial(beta: float, key: jnp.ndarray, prior_samples: jnp.ndarray) -> tuple[float, float, float]:
    """Run one trial: observe, choose a control, then apply the disturbance search.

    Steps:

    - Draw ``x_hat`` from a standard normal.
    - Obtain a control ``u`` from ``inf.sir`` with weights from ``log_prob``.
    - Subtract the disturbance returned by ``find_disturbance``.

    Returns:
        ``(hamiltonian(x_hat, u), hamiltonian(x, u), metric(x, x_hat))`` where
        ``x = x_hat - d`` is the disturbed state.
    """
    _, obs_key, sir_key = rnd.split(key, 3)

    observed = rnd.normal(obs_key)
    control = inf.sir(lambda candidate: log_prob(observed, candidate, beta), prior_samples, sir_key)
    disturbance = find_disturbance(observed, control, r, prior_samples, beta)
    disturbed = observed - disturbance

    nominal_cost = hamiltonian(observed, control)
    disturbed_cost = hamiltonian(disturbed, control)
    return nominal_cost, disturbed_cost, metric(disturbed, observed)

# Seeded PRNG state. One split produces the fixed prior samples shared by every trial;
# the remaining `key` is consumed by `helper` for the trial keys.
key, subkey = rnd.split(rnd.PRNGKey(0))
prior_samples = rnd.normal(subkey, shape=(number_of_prior_samples,)) * prior_std

In [None]:
def helper(beta: float) -> tuple[float, float]:
    """Average the two Hamiltonians returned by ``conduct_trial`` over ``trials`` trials.

    Trial keys are split from the module-level ``key``, so every ``beta`` sees the same
    trial keys. The global ``prior_samples`` are shared by all trials.

    Returns:
        ``(mean hamiltonian at x_hat, mean hamiltonian at the disturbed state)``.
    """
    trial_keys = rnd.split(key, trials)
    batched_trial = jax.vmap(conduct_trial, in_axes=(None, 0, None))
    observed_costs, disturbed_costs, _ = batched_trial(beta, trial_keys, prior_samples)

    return observed_costs.mean(), disturbed_costs.mean()

In [None]:
# Evaluate every beta in parallel; each entry is (mean cost at x_hat, mean cost after disturbance).
(fakes, reals) = jax.vmap(helper, in_axes=0)(betas)

In [None]:
# Mean Hamiltonian at the observed state vs. after the disturbance, for each beta.
plt.figure()
# Markers keep both series visible even when `betas` holds a single value: a one-point
# line drawn with the default '-' style renders nothing.
plt.plot(betas, fakes, marker='o', label='Fully Observable')
plt.plot(betas, reals, marker='o', label='Worst Case')
plt.xlabel(r'$\beta$')
plt.ylabel('Mean Hamiltonian')
# Without a legend the `label=` strings above are never displayed.
plt.legend()