In [None]:
!pip install --upgrade hj-reachability

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from IPython.display import HTML
import matplotlib.animation as anim
import matplotlib.pyplot as plt
import plotly.graph_objects as go

import hj_reachability as hj

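# Double Integrator

This notebook uses `hj_reachability` to compute the backward reachable tube of a double integrator, $\dot{x} = v + d$, $\dot{v} = u$. The acceleration is bounded ($|u| \le 1$ by default), and $d$ is a bounded position disturbance that is zero by default.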

## Define dynamics class

In [None]:
 class DoubleIntegrator(hj.ControlAndDisturbanceAffineDynamics):
    def __init__(self,
                 max_acceleration=1.,
                 max_position_disturbance=0,
                 control_mode="min",
                 disturbance_mode="max",
                 control_space=None,
                 disturbance_space=None):
        if control_space is None:
            control_space = hj.sets.Box(lo=jnp.array([-max_acceleration]), hi=jnp.array([max_acceleration]))
        if disturbance_space is None:
            disturbance_space = hj.sets.Ball(jnp.zeros(1), max_position_disturbance)
        super().__init__(control_mode, disturbance_mode, control_space, disturbance_space)

    def open_loop_dynamics(self, state, time):
        _, v = state
        return jnp.array([v, 0.])

    def control_jacobian(self, state, time):
        return jnp.array([
            [0.],
            [1.]
        ])

    def disturbance_jacobian(self, state, time):
        return jnp.array([
            [1.],
            [0.]
        ])

## Initialize solver

In [None]:
dynamics = DoubleIntegrator()

# Computational domain in the (position x, velocity v) plane.
grid_domain = hj.sets.Box(lo=jnp.array([-5., -3.]), hi=jnp.array([5., 3.]))
grid_shape = (100, 100)
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(domain=grid_domain, shape=grid_shape)

# Initial value function: an implicit representation of the target set, i.e. the
# region where `values <= 0`. The signed distance `||state - center|| - radius`
# makes that region a disc with non-empty interior. A plain squared distance
# `||state - center||**2` is never negative, so its zero sublevel set is at most
# the single point `center`, and the resulting tube is degenerate.
# NOTE(review): this center sits on the corner of `grid_domain`, so only a quarter
# of the disc lies inside the grid. An earlier version used the origin; confirm
# which target is intended.
TARGET_CENTER = jnp.array([5., 3.])
TARGET_RADIUS = 1.
values = jnp.linalg.norm(grid.states - TARGET_CENTER, axis=-1) - TARGET_RADIUS

solver_settings = hj.SolverSettings.with_accuracy("very_high",
                                                  hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)

## Solve

In [None]:
# Propagate the value function from t = 0 back to t = -2 (target_time < time).
time, target_time = 0., -2.0
target_values = hj.step(solver_settings, dynamics, grid,
                        time, values, target_time)

## Plot Results

In [None]:
# hj_reachability values are indexed [i_x, i_v], while contourf(X, Y, Z) expects
# Z[i_y, i_x] (rows follow Y); hence the transposes below.
fig, ax = plt.subplots(figsize=(13, 8))
filled = ax.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], target_values.T)
fig.colorbar(filled, ax=ax)
# The zero level set outlines the backward reachable tube (values <= 0).
ax.contour(grid.coordinate_vectors[0],
           grid.coordinate_vectors[1],
           target_values.T,
           levels=[0.],
           colors="black",
           linewidths=3)
ax.set(xlabel="position $x$", ylabel="velocity $v$", title=f"Value function at t = {target_time}")
plt.show()

## TODO: Animation

In [None]:
# Solve on 57 evenly spaced time stamps from t = 0 down to t = -2.8 (step -0.05),
# starting from the initial value function defined above.
initial_values = values
times = np.linspace(0, -2.8, num=57)
all_values = hj.solve(solver_settings, dynamics, grid, times, initial_values)

In [None]:
# Fixed colour scale across all frames. floor/ceil (rather than round) make the
# outer levels bracket every value; contourf leaves values outside `levels` unfilled.
vmin, vmax = float(all_values.min()), float(all_values.max())
level_lo, level_hi = np.floor(vmin), np.ceil(vmax)
if level_hi <= level_lo:  # constant field: contourf needs at least two levels
    level_hi = level_lo + 1
levels = np.linspace(level_lo, level_hi, int(level_hi - level_lo) + 1)
fig, ax = plt.subplots(figsize=(13, 8))


def render_frame(i, colorbar=False):
    """Draw animation frame ``i``.

    Shows the filled value function ``all_values[i]`` and overlays the zero level
    set of ``target_values`` in black.

    Args:
        i: Index into the leading (time) axis of ``all_values``.
        colorbar: If True, attach a colorbar for the filled contours.
    """
    # Start from a clean axes so contour artists don't accumulate frame after frame.
    ax.clear()
    # The grid is 2-D, so each frame is all_values[i] (shape x-by-v); transpose
    # because contourf expects Z[i_y, i_x].
    filled = ax.contourf(grid.coordinate_vectors[0],
                         grid.coordinate_vectors[1],
                         all_values[i].T,
                         vmin=vmin,
                         vmax=vmax,
                         levels=levels)
    if colorbar:
        fig.colorbar(filled, ax=ax)
    ax.contour(grid.coordinate_vectors[0],
               grid.coordinate_vectors[1],
               target_values.T,
               levels=[0.],
               colors="black",
               linewidths=3)
    ax.set(xlabel="position $x$", ylabel="velocity $v$", title=f"t = {float(times[i]):.2f}")


render_frame(0, True)
animation = HTML(anim.FuncAnimation(fig, render_frame, all_values.shape[0], interval=50).to_html5_video())
plt.close(fig)
animation