In [None]:
import os
import sys
import pathlib

# Make the directory two levels up importable (presumably the repository root that
# contains the `rationality` package imported below).
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import jax
import jax.numpy as jnp
import jax.random as rnd
import tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches

from rationality import dynamics as dyn, objectives as obj, distributions as dst,\
    controllers as ctl, simulate as sim, geometry as geom, util

from mpl_toolkits.axes_grid1 import make_axes_locatable

from typing import Optional

Set up plotting configuration.

In [None]:
pathlib.Path('figures/').mkdir(parents=True, exist_ok=True)

plt.style.reload_library()
plt.style.use(['notebook'])

#plt.rcParams['text.latex.preamble'] = r'\usepackage{lmodern}'


%config InlineBackend.figure_format = 'svg'

figure_formats = {'png', 'pdf'}

In [None]:
# Workspace extent.
width, height = 1.0, 1.0

# Thickness of the vertical wall built from obstacles in the next cell.
obs_width = 0.05

# Wall layout: a middle block of height `middle_height`, shifted down by
# `middle_offset` from y = 0.5, plus top and bottom stubs of height `edge_height`.
middle_height = 0.4
middle_offset = 0.17
edge_height = 0.1

# Start position of every rollout (initial velocity is zero).
ic = jnp.array([0.15, 0.4])

# Gaussian over the 2-D open-loop inputs; arguments presumably (mean, covariance).
prior_params = dst.GaussianParams(jnp.zeros(2), 0.0008 * jnp.eye(2))
ol_dist = dst.gaussian(*prior_params)

trials = 1000        # number of successful trajectories to keep
batch_size = 10000   # rollouts per sampling batch
horizon = 100        # steps per rollout

# Maximum number of sample paths overlaid on each density plot (duplicates removed).
trajectories_to_draw = 5

# Grid cell size; overridden before the density computation.
grid_spacing = 0.05

In [None]:
# Obstacles: a vertical wall of width `obs_width` centered at x = 0.5, made of three
# axis-aligned boxes geom.aabb(centroid, dimensions): a bottom stub of height
# `edge_height`, a middle block of height `middle_height` centered `middle_offset`
# below y = 0.5, and a top stub of height `edge_height`.
obstacles = [geom.aabb(jnp.array([0.5, edge_height / 2]), jnp.array([obs_width, edge_height])),
             geom.aabb(jnp.array([0.5, 0.5 - middle_offset]), jnp.array([obs_width, middle_height])),
             geom.aabb(jnp.array([0.5, 1 - edge_height / 2]), jnp.array([obs_width, edge_height]))]

workspace = geom.workspace(width, height, obstacles)

# Heights of the two openings in the wall: gap1 lies between the bottom stub and the
# middle block, gap2 between the middle block and the top stub. Shifting the middle
# block down by `middle_offset` narrows gap1 and widens gap2 by the same amount
# (0.03 vs 0.37 with edge_height=0.1, middle_height=0.4, middle_offset=0.17).
gap1_height = 0.5 - edge_height - middle_height / 2 - middle_offset
gap2_height = 0.5 - edge_height - middle_height / 2 + middle_offset

# The openings as boxes: gaps[0] is the lower gap, gaps[1] the upper gap. They are not
# obstacles; they are used later to test which opening a trajectory crosses.
gaps = [geom.aabb(jnp.array([0.5, edge_height + gap1_height / 2]), jnp.array([obs_width, gap1_height])),
        geom.aabb(jnp.array([0.5, 1 - edge_height - gap2_height / 2]),
                                    jnp.array([obs_width, gap2_height])),]

# Goal: box in the bottom-right corner (formula assumes the unit-width workspace).
# Right of the wall the free strip starts at x = 0.5 + obs_width / 2 and has width
# 0.5 - obs_width / 2; the goal's centroid sits 3/4 of the way across that strip and
# its width is half the strip, so it covers the strip's right half, with y in [0, 0.2].
goal = geom.aabb(jnp.array([(0.5 + obs_width / 2) + 3 * (0.5 - obs_width / 2) / 4, 0.1]),
                 jnp.array([(0.5 - obs_width / 2) / 2, 0.2]))

In [None]:
def draw_workspace(workspace: geom.Workspace, ic: Optional[jnp.ndarray] = None,
                   goal: Optional[geom.Polytope] = None, ax: Optional[plt.Axes] = None) -> plt.Axes:
    """Draw a workspace's obstacles on an axes, plus an optional start point and goal.

    Args:
        workspace: workspace whose ``boundary.dimensions`` set the axis limits
            (starting from the origin) and whose batched ``obstacles`` are drawn.
        ic: optional start position, drawn as a black ``x`` at ``(ic[0], ic[1])``.
        goal: optional goal region, drawn with ``/`` hatching.
        ax: axes to draw on; defaults to the current axes (``plt.gca()``).

    Returns:
        The axes that were drawn on.
    """
    width, height = workspace.boundary.dimensions
    obstacles = workspace.obstacles

    if ax is None:
        ax = plt.gca()

    ax.set_xlim([0, width])
    ax.set_ylim([0, height])

    if ic is not None:
        ax.scatter(ic[0], ic[1], marker='x', color='k', s=180)

    # Obstacles are stored batched (one row per obstacle); rebuild each box to draw it.
    for centroid, dimensions in zip(obstacles.centroid, obstacles.dimensions):
        geom.draw(geom.aabb(centroid, dimensions), ax)

    if goal is not None:
        geom.draw(goal, ax, hatch='/')

    return ax

In [None]:
# Overview figure: start marker, goal region and obstacles, without axis ticks.
fig, ax = plt.subplots()
draw_workspace(workspace, ic, goal, ax)

goal_patch = patches.Patch(fill=False, edgecolor='k', hatch=r'/', label='Goal')
obs_patch = patches.Patch(fill=True, color='k', label='Obstacle')
ax.legend(handles=[goal_patch, obs_patch], loc=2, frameon=False)

ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')

util.savefig('figures/workspace', figure_formats)

In [None]:
State = tuple[jnp.ndarray, jnp.ndarray]

@jax.jit
def rollout(key: jnp.ndarray) -> tuple[State, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Simulate one open-loop rollout starting at ``ic`` with zero velocity.

    ``horizon`` inputs are drawn from ``ol_dist``. Each step first moves the
    position by the current velocity, then adds the input to the velocity.

    NOTE: ``ol_dist``, ``horizon``, ``ic``, ``goal`` and ``workspace`` are module
    globals that ``jax.jit`` bakes in as constants at trace time; re-assigning them
    later does not affect an already-compiled ``rollout``.

    Returns:
        ``((positions, velocities), dists, is_in_goal, is_in_free)``:
        positions/velocities stack the states column-wise with the initial state in
        column 0 (presumably shape ``(2, horizon + 1)``); ``dists`` is the L2 norm of
        each step's input; ``is_in_goal`` flags steps whose new position lies in
        ``goal``; ``is_in_free`` flags steps whose segment lies in free space.
    """
    controls = ol_dist.sample(horizon, key)

    def step(state: State, u: jnp.ndarray):
        pos, vel = state
        moved = (pos + vel, vel + u)
        next_pos = moved[0]
        per_step = (
            moved,
            jnp.linalg.norm(u, ord=2),
            goal.contains(next_pos),
            workspace.freespace_contains_segment(pos, next_pos),
        )
        return moved, per_step

    initial_state = (ic, jnp.zeros(2))
    _, (visited, dists, is_in_goal, is_in_free) = jax.lax.scan(step, initial_state, controls.T)

    positions = jnp.concatenate([ic.reshape((-1, 1)), visited[0].T], axis=1)
    velocities = jnp.concatenate([jnp.zeros((2, 1)), visited[1].T], axis=1)
    return (positions, velocities), dists, is_in_goal, is_in_free

In [None]:
from functools import partial

@jax.jit
def process_traj(dists: jnp.ndarray, is_in_goal: jnp.ndarray, is_in_free: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Summarize one rollout: goal-arrival index, collision index and path cost.

    The original decorator used ``static_argnums=3``, which does not name any of the
    three parameters (current JAX rejects it); no argument needs to be static.

    Args:
        dists: per-step costs (input magnitudes from ``rollout``), length ``n``.
        is_in_goal: per-step flags, True when that step's new position is in the goal.
        is_in_free: per-step flags, True when that step's segment is in free space.

    Returns:
        ``(reached_goal_idx, collision_idx, total_cost)``. For an event at 0-based
        step ``j`` the index is ``j + 2``: one past the trajectory column ``j + 1``
        where it happens (column 0 being the initial condition), so slicing
        ``[:idx]`` includes that point. ``n + 1`` means "never happened".
        NOTE: an event on the very last step also yields ``n + 1`` and is therefore
        indistinguishable from "never".
        ``total_cost`` sums ``dists`` (``inf`` for a step outside free space) over
        every step until the goal index becomes ``<= n``; later steps add nothing.
    """
    # Rollout length; equals the global `horizon` for rollouts built by `rollout`.
    n_steps = dists.shape[0]
    never = n_steps + 1

    def scanner(carry: tuple, step: tuple) -> tuple[tuple, None]:
        reached_goal_idx, collision_idx, total_cost = carry
        time, dist, in_goal, in_free = step

        # Keep the earliest step at which each event happened.
        reached_goal_idx = jnp.where(in_goal, jnp.minimum(reached_goal_idx, time + 1), reached_goal_idx)
        collision_idx = jnp.where(in_free, collision_idx, jnp.minimum(collision_idx, time + 1))

        step_cost = jnp.where(in_free, dist, jnp.inf)
        # Stop accumulating once the goal has been reached.
        total_cost = total_cost + jnp.where(reached_goal_idx <= n_steps, 0.0, step_cost)

        return (reached_goal_idx, collision_idx, total_cost), None

    times = jnp.arange(n_steps) + 1
    carry, _ = jax.lax.scan(scanner, (never, never, 0.0), (times, dists, is_in_goal, is_in_free))

    return carry



In [None]:
# Single exploratory batch of rollouts (seed 0); its cost array is inspected below.
# `rollout` is already jitted, so it is vmapped directly, and the batch size comes
# from the config constant instead of a repeated literal.
key = rnd.PRNGKey(0)
states, dists, in_goal, in_free = jax.vmap(rollout, out_axes=-1)(rnd.split(key, batch_size))
reached_goal_idxs, collision_idxs, total_costs = jax.vmap(process_traj, in_axes=-1)(dists, in_goal, in_free)

In [None]:
# Shape check on the per-rollout costs of the exploratory batch.
jnp.shape(total_costs)

In [None]:
# Rejection sampling: draw batches of open-loop rollouts until `trials` of them reach
# the goal within the horizon and strictly before colliding; keep their positions,
# costs and goal-arrival indices.
key = rnd.PRNGKey(0)

# Build the batched pipelines once so JAX compiles them a single time, instead of
# wrapping a fresh lambda (and recompiling) on every loop iteration.
batched_rollout = jax.jit(jax.vmap(rollout, out_axes=-1))
batched_process_traj = jax.jit(jax.vmap(process_traj, in_axes=-1))

found = 0

successful_trajectories = []
successful_distances = []
successful_stopping_time = []

with tqdm.tqdm(total=trials) as pbar:
    while True:
        key, subkey = rnd.split(key)
        states, dists, in_goal, in_free = batched_rollout(rnd.split(subkey, batch_size))
        reached_goal_idxs, collision_idxs, total_costs = batched_process_traj(dists, in_goal, in_free)
        # Success: goal reached within the horizon and strictly before any collision.
        successful_idxs = ((reached_goal_idxs < horizon + 1) & (reached_goal_idxs < collision_idxs))

        successful_trajectories.append(states[0][:, :, successful_idxs])
        successful_distances.append(total_costs[successful_idxs])
        successful_stopping_time.append(reached_goal_idxs[successful_idxs])

        n_successful = int(successful_idxs.sum())
        # Cap the advance so the progress counter ends exactly at `trials`.
        pbar.update(min(n_successful, trials - found))
        found += n_successful

        if found >= trials:
            break

# Keep exactly `trials` successes, in sampling order.
successful_trajectories = jnp.concatenate(successful_trajectories, axis=-1)[:, :, :trials]
successful_distances = jnp.concatenate(successful_distances, axis=0)[:trials]
successful_stopping_time = jnp.concatenate(successful_stopping_time, axis=0)[:trials]

In [None]:
# Shape check: (coordinate, time column, trajectory).
jnp.shape(successful_trajectories)

In [None]:
# `alphas` is only defined in the next cell, so `alphas.max()` here fails on a fresh
# kernel (and is trivially 1.0 once that cell has normalized it). Show the quantity
# the next cell normalizes by instead: the largest scaled inverse path cost.
jnp.max(1 / (0.001 * successful_distances))

In [None]:
# Overlay every successful trajectory on the workspace. Opacity is inversely
# proportional to path cost, normalized so the cheapest path is fully opaque.
best_n_trajectories = jnp.argsort(successful_distances)  # all trajectories, cheapest first
alphas = 1 / (0.001 * successful_distances)
alphas = alphas / jnp.max(alphas)

fig, ax = plt.subplots()
draw_workspace(workspace, ic, goal, ax)

goal_patch = patches.Patch(fill=False, edgecolor='k', hatch=r'/', label='Goal')
obs_patch = patches.Patch(fill=True, color='k', label='Obstacle')
ax.legend(handles=[goal_patch, obs_patch], loc=2, frameon=False)

for idx in best_n_trajectories:
    # Columns before the stopping index run from the start up to the first in-goal point.
    path = successful_trajectories[:, :successful_stopping_time[idx], idx]
    ax.plot(path[0], path[1], c='dodgerblue', alpha=float(alphas[idx]), linewidth=1)

ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')

In [None]:
# Column index of each trajectory's first in-goal point (column 0 is the initial
# condition). process_traj reports one past that column, so subtract 1. The original
# cell read an undefined `first_in_goal` and failed on a fresh kernel.
first_in_goal = successful_stopping_time - 1

# gap_numbers[i] is True when trajectory i crosses the lower gap (gaps[0]) on its way
# to the goal.
gap_numbers = []

for i in range(trials):
    # Consecutive point pairs = segments from the start up to the first in-goal point.
    starts = successful_trajectories[:, :first_in_goal[i], i]
    ends = successful_trajectories[:, 1:(first_in_goal[i] + 1), i]
    crosses_lower_gap = jax.vmap(lambda s, e: gaps[0].intersects(s, e))(starts.T, ends.T)

    gap_numbers.append(bool(jnp.any(crosses_lower_gap)))

gap_numbers = jnp.array(gap_numbers)

In [None]:
@jax.jit
def gap_prob(inv_temp: float) -> float:
    """Boltzmann-weighted probability that a successful trajectory used the lower gap.

    Trajectory ``i`` gets weight proportional to ``exp(-inv_temp * cost_i)``, where
    ``cost_i`` comes from ``successful_distances``. The result is the normalized weight
    of the trajectories flagged in ``gap_numbers``.
    """
    # softmax subtracts the max logit before exponentiating, so a large inv_temp or
    # long paths cannot underflow every weight to 0 and turn the ratio into 0/0 = NaN.
    probs = jax.nn.softmax(-inv_temp * successful_distances)

    return jnp.sum(jnp.where(gap_numbers, probs, 0.0))

def set_size(w=None, h=None, ax=None):
    """Resize the figure so the subplot region of ``ax`` is ``w`` x ``h`` inches.

    The figure size is the requested size divided by the fraction of the figure
    that the subplot parameters (left/right/top/bottom) leave for the axes, so this
    assumes ``ax`` fills that region (a single subplot).

    Args:
        w: target width in inches, or None to keep the current figure width.
        h: target height in inches, or None to keep the current figure height.
        ax: axes whose figure is resized; defaults to ``plt.gca()``.
    """
    if ax is None:
        ax = plt.gca()

    fig = ax.figure
    pars = fig.subplotpars
    current_w, current_h = fig.get_size_inches()

    # The original fell back to the subplot fraction itself (e.g. 0.775 in) when a
    # dimension was omitted; keep the current size instead.
    figw = float(w) / (pars.right - pars.left) if w is not None else current_w
    figh = float(h) / (pars.top - pars.bottom) if h is not None else current_h

    fig.set_size_inches(figw, figh)

In [None]:
# The 'ieee' stack is applied first; the second call overrides only what its styles set,
# so any remaining 'ieee' settings stay in effect.
plt.style.use(['science', 'ieee', 'notebook'])
plt.style.use(['science', 'notebook'])

# LaTeX text rendering with a Times-like serif font (ptm).
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': 'times',
    'text.latex.preamble': '\\usepackage{lmodern}\n\\renewcommand{\\rmdefault}{ptm}',
})

inv_temps = jnp.linspace(0.0, 10.0)
betas_to_check = jnp.array([1.0, 2.0, 4.0, 8.0])

label_font_size = 30
ticks_font_size = 26

fig, ax = plt.subplots()

# 1 - P(lower gap) over a sweep of beta, i.e. the weight on the other opening.
ax.plot(inv_temps, 1 - jax.vmap(gap_prob)(inv_temps), color='dodgerblue')

# Mark the betas used for the density plots, each with a dotted drop line to y = 0.
for i, beta in enumerate(betas_to_check):
    prob_large = 1 - gap_prob(beta)
    ax.scatter(beta, prob_large, c='k', zorder=3)
    ax.plot([beta, beta], [prob_large, 0], 'k:')

ax.set_xlabel('$\\beta$', fontsize=label_font_size)
plt.xticks(fontsize=ticks_font_size)

ax.set_ylabel('Prob. Large Gap Traversed', fontsize=label_font_size)
plt.yticks(fontsize=ticks_font_size)

ax.set_title('Robust Path Planning', fontsize=label_font_size)

ax.set_ylim([0.0, 1.0])

# Square plotting box regardless of the data ranges.
ax.set_aspect(1.0 / ax.get_data_ratio(), adjustable='box')
plt.tight_layout()

set_size(5.0, 5.0, ax)

util.savefig('figures/gap-prob-plot', figure_formats)

In [None]:
# For each beta, rasterize the Boltzmann-weighted density of path segments: every grid
# cell accumulates exp(-beta * cost) of each trajectory segment that intersects it,
# and the grid is then normalized to sum to one.
grid_spacing = 0.01
xs = jnp.arange(grid_spacing, 1.0, grid_spacing)
ys = jnp.arange(grid_spacing, 1.0, grid_spacing)

X, Y = jnp.meshgrid(xs, ys)

# Column index of each trajectory's first in-goal point (column 0 is the initial
# condition); process_traj reports one past it. The original cell read an undefined
# `first_in_goal` and failed on a fresh kernel.
first_in_goal = successful_stopping_time - 1

# Segment endpoints and per-segment trajectory costs do not depend on beta: build once.
starts = jnp.concatenate([successful_trajectories[:, :first_in_goal[i], i] for i in range(trials)], axis=1)
ends = jnp.concatenate([successful_trajectories[:, 1:(first_in_goal[i] + 1), i] for i in range(trials)], axis=1)
segment_distances = jnp.concatenate([successful_distances[i] * jnp.ones(first_in_goal[i]) for i in range(trials)])


@jax.jit
def compute_color(x: float, y: float, starts: jnp.ndarray, ends: jnp.ndarray, segment_logits: jnp.ndarray) -> float:
    """Sum of exp(segment_logits) over segments intersecting the grid cell at (x, y).

    The cell is a ``grid_spacing`` square whose lower-left corner is ``(x, y)``;
    ``starts``/``ends`` hold one segment endpoint per column.
    """
    centroid = jnp.array([x + grid_spacing / 2.0, y + grid_spacing / 2.0])
    dimensions = jnp.array([grid_spacing, grid_spacing])
    box = geom.aabb(centroid, dimensions)

    hits = jax.vmap(lambda s, e: box.intersects(s, e), in_axes=(1, 1))(starts, ends)
    return jnp.where(hits, jnp.exp(segment_logits), 0.0).sum()


# One grid column (fixed x, all ys), compiled once and reused for every x and beta.
column_colors = jax.jit(jax.vmap(compute_color, in_axes=(None, 0, None, None, None)))

C = []

for inv_temp in betas_to_check:
    segment_logits = -inv_temp * segment_distances
    # Shifting by the max logit avoids exp underflow; the grid is normalized below,
    # so the constant factor cancels.
    segment_logits = segment_logits - segment_logits.max()

    # Rows index ys and columns index xs, matching the meshgrid X, Y.
    C_for_inv_temp = jnp.stack([column_colors(x, ys, starts, ends, segment_logits) for x in xs], axis=-1)
    C.append(C_for_inv_temp / C_for_inv_temp.sum())

# Shared colour-scale cap: the 10th-largest distinct cell value, minimized over betas,
# presumably so a few very dense cells do not wash out the rest of each map.
vmax = jnp.min(jnp.array([jnp.unique(c)[-10] for c in C]))

In [None]:
# Same style stack as the gap-probability figure: remaining 'ieee' settings persist
# under the second call.
plt.style.use(['science', 'ieee', 'notebook'])
plt.style.use(['science', 'notebook'])

# LaTeX text rendering with a Times-like serif font (ptm).
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': 'times',
    'text.latex.preamble': '\\usepackage{lmodern}\n\\renewcommand{\\rmdefault}{ptm}',
})

# White-to-dodgerblue colormap for the segment-density heatmaps.
cmap = mpl.colors.LinearSegmentedColormap.from_list('Dodger Blue', ['white', 'dodgerblue'])

def density_plot(ax: plt.Axes, subkey: jnp.ndarray, C: jnp.ndarray, inv_temp: float) -> tuple[plt.Axes, Optional[plt.Line2D]]:
    """Draw a segment-density heatmap plus a few Boltzmann-sampled trajectories.

    Args:
        ax: axes to draw on.
        subkey: PRNG key used to sample which trajectories to overlay.
        C: density grid on the global ``X``/``Y`` mesh.
        inv_temp: beta; trajectory ``i`` is sampled with logit ``-inv_temp * cost_i``.

    Returns:
        ``(ax, line_handle)``, where ``line_handle`` is the last sample-path line (for
        use in a legend), or None if nothing was drawn.
    """
    logits = -inv_temp * successful_distances
    # Up to `trajectories_to_draw` distinct trajectories (repeat draws are merged).
    sampled_idxs = jnp.unique(rnd.categorical(subkey, logits, shape=(trajectories_to_draw,)))

    ax.pcolormesh(X, Y, C, cmap=cmap,
                  vmin=0.0, vmax=vmax, shading='gouraud')

    line_handle = None
    for idx in sampled_idxs:
        # The stopping index is one past the first in-goal column, so the goal point is
        # included. The original read an undefined `first_in_goal` (first_in_goal + 1
        # == stopping time).
        stop = successful_stopping_time[idx]
        (line_handle,) = ax.plot(successful_trajectories[0, :stop, idx],
                                 successful_trajectories[1, :stop, idx],
                                 c='k', linestyle=':', label='Sample Path')

    return ax, line_handle



# One density figure per beta. The last one also gets the legend and a colorbar and is
# saved after the loop, once the colorbar is attached.
for i, inv_temp in enumerate(betas_to_check):
    plt.figure()
    ax = plt.gca()

    key, subkey = rnd.split(key)
    ax, line_handle = density_plot(ax, subkey, C[i], inv_temp)
    draw_workspace(workspace, ic, goal, ax)

    is_last = inv_temp == betas_to_check[-1]

    # The original compared against `inv_temps[-1]` (the end of the 0..10 sweep), which
    # is never one of `betas_to_check`, so the legend was never drawn.
    if is_last:
        goal_patch = patches.Patch(fill=False, edgecolor='k', hatch=r'/', label='Goal')
        obs_patch = patches.Patch(fill=True, color='k', label='Obstacle')
        ax.legend(handles=[goal_patch, obs_patch, line_handle], loc=2, frameon=False, fontsize=24)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

    plt.xlabel(f'$\\beta = {inv_temp}$', fontsize=label_font_size, labelpad=ticks_font_size + 5)
    plt.tight_layout()

    set_size(5.0, 5.0, ax=ax)

    if not is_last:
        util.savefig(f'figures/density-beta-{inv_temp:.3f}', figure_formats)

# Enlarge the last figure slightly (presumably to make room for the colorbar appended
# next to the axes) before attaching it.
set_size(w=5.0 / 0.93, h=5.0 / 0.93, ax=ax)

the_divider = make_axes_locatable(ax)
color_axis = the_divider.append_axes('right', size='5%', pad='2%')

norm = mpl.colors.Normalize(vmin=0.0, vmax=vmax)
cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=color_axis)
# Scientific-notation tick labels; keep every other tick to avoid crowding.
cbar.formatter.set_powerlimits((0, 0))

cbar.set_ticks(cbar.ax.get_yticks()[0::2])

for t in cbar.ax.get_yticklabels():
    t.set_fontsize(ticks_font_size)

cbar.update_ticks()
util.savefig(f'figures/density-beta-{inv_temp:.3f}', figure_formats)