In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from inv_pendulum import inv_pend
import hj_reachability as hj
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
from matplotlib import MatplotlibDeprecationWarning
warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import numpy as np

import matplotlib
import pickle as pkl
import utils
from scipy.interpolate import interp1d
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import HTML
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# Value-function domain and its discretization (111 x 111 x 5 nodes).
#   dim 0: angle in [-pi, pi]; marked periodic so the grid wraps at +/-pi.
#   dim 1: [-1, 1].
#   dim 2: [1, 2] -- presumably the parameter v (cf. params['v'] below); confirm.
# Boundary conditions matter for accuracy at the domain edges; only dim 0 is
# declared periodic, the remaining dims use the library default.
_grid_domain = hj.sets.Box(
    lo=np.array([-jnp.pi, -1, 1]),
    hi=np.array([jnp.pi, 1, 2]),
)
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    domain=_grid_domain,
    shape=(111, 111, 5),
    periodic_dims=0,
)

# Start state (not referenced elsewhere in this cell; the solver below is given
# `x_init_for_solving` instead).
x_init = None
# Former 2-D circular target (center point + radius), kept for reference only.
#x_target = [jnp.pi, 0]
#r_target = 1
direction = 'backward'

# Target region: one [lo, hi] row per grid dimension (3 rows for the 3-D grid),
# presumably an axis-aligned box in (dim 0, dim 1, v).
# NOTE(review): confirm against utils.build_target_sdf, including what the 0.1
# argument controls.
target = jnp.array([[-0.1, 0.1], [-0.05, 0.05], [0.5,2.5]])
sdf = utils.build_target_sdf(target, 0.1)

# Evaluate the target SDF at every grid node (vmapped over all grid dimensions),
# so sdf_values is presumably shaped like grid.shape.
sdf_values = hj.utils.multivmap(sdf, jnp.arange(grid.ndim))(grid.states)

# Debug plot of the target SDF. NOTE: the grid is 3-D, so a 2-D slice such as
# sdf_values[..., k] is needed before these calls can work.
#cf = plt.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], sdf_values.T, levels=100)
#plt.contour(grid.coordinate_vectors[0], grid.coordinate_vectors[1], sdf_values.T, levels=[0.0], colors='k', linewidths=3.0)
# add colorbar
#plt.colorbar(cf)

# Time stamps: 120 samples from t=0 down to t=-4 (negative times, matching the
# 'backward' direction chosen above).
times = np.linspace(0, -4, 120)

# Input bounds handed to the dynamics model.
uMax = 10
dMin, dMax = 1, 1  # NOTE(review): equal bounds leave no disturbance range -- confirm intended.
params = {'v': [1., 2.]}

# The solve starts from the target SDF.
initial_values = sdf_values
# control_mode="min" / disturbance_mode="max": presumably control minimizes and
# the disturbance maximizes the value (see inv_pend).
dynamics = inv_pend(
    params=params,
    uMax=uMax,
    dMin=dMin,
    dMax=dMax,
    control_mode="min",
    disturbance_mode="max",
)

# Clamp the Hamiltonian at zero so the solve produces a backward reachable
# tube (BRT).
hamiltonian_postprocessor = lambda x : jnp.minimum(x,0)
obstacles = None

accuracy = "high"
artificial_dissipation_scheme = hj.artificial_dissipation.global_lax_friedrichs
x_init_for_solving = None

# Pass-through postprocessor: returns its last positional argument unchanged.
identity = lambda *x: x[-1]
# At module level locals() is the module namespace. Keep any postprocessor that
# is already defined (e.g. by an earlier cell); otherwise fall back to identity.
hamiltonian_postprocessor = locals().get('hamiltonian_postprocessor', identity)
value_postprocessor = locals().get('value_postprocessor', identity)

solver_settings = hj.SolverSettings.with_accuracy(
    accuracy=accuracy,
    x_init=x_init_for_solving,
    artificial_dissipation_scheme=artificial_dissipation_scheme,
    hamiltonian_postprocessor=hamiltonian_postprocessor,
    value_postprocessor=value_postprocessor,
)

# Solve over `times`, starting from the target SDF, and save the result.
all_values = utils.param_solve(
   solver_settings, dynamics, grid, times, initial_values)
np.save('valuefunc.npy', all_values)

# Number of grid nodes with a positive value, summed over axes 1 and 2.
# Named `positive_counts` (not `sum`) so the builtin `sum` stays usable in
# later cells.
# NOTE(review): assumes all_values is (len(times), *grid.shape) -- confirm; then
# this prints one row per time step and one column per v index.
positive_counts = (all_values > 0.0).sum(axis=2).sum(axis=1)
print(positive_counts)


# --- Animate the value function over HJR time --------------------------------
# The grid is 3-D, but contour plots need 2-D data, so every frame shows the
# (dim 0, dim 1) plane at a fixed index along the last axis.
# NOTE(review): assumes all_values has shape (len(times), *grid.shape) -- confirm
# against utils.param_solve.
v_idx = 0
plot_x = grid.coordinate_vectors[0]
plot_y = grid.coordinate_vectors[1]


def _plane(values):
    """Return a 2-D (dim 0, dim 1) slice of ``values`` laid out for contourf.

    3-D inputs are sliced at ``v_idx`` along the last axis; 2-D inputs are used
    as-is. The result is transposed so rows follow dim 1 and columns dim 0,
    i.e. the (len(y), len(x)) layout that contour/contourf expect.
    """
    values = np.asarray(values)
    if values.ndim == 3:
        values = values[..., v_idx]
    return values.T


def _remove_contour_set(cs):
    """Remove every artist belonging to the ContourSet ``cs`` from its axes."""
    try:
        # matplotlib >= 3.8: ContourSet is itself an Artist.
        cs.remove()
    except AttributeError:
        # Older matplotlib without ContourSet.remove: one collection per level.
        for coll in cs.collections:
            coll.remove()


def _frame_title(idi):
    """Title for frame ``idi``: the plotted v value (3-D grids) and |t|."""
    t = abs(float(times[idi]))
    if grid.ndim == 3:
        v = float(grid.coordinate_vectors[2][v_idx])
        return f'$v={v:.2f}$, HJR time $t={t:.2f}$'
    return f'HJR time $t={t:.2f}$'


# Set up the figure and axis.
fig = plt.figure()
ax = fig.add_subplot(111)
# Dedicated colorbar axes, so rebuilding the colorbar every frame does not keep
# shrinking `ax`.
div = make_axes_locatable(ax)
cax = div.append_axes('right', size='5%', pad=0.05)

cf = ax.contourf(plot_x, plot_y, _plane(all_values[0]))
# Zero level set of the value function (redrawn every frame).
cont = ax.contour(plot_x, plot_y, _plane(all_values[0]),
                  levels=[0], colors='green', linewidths=2)
# Zero level set of the target SDF (static).
zerolvl = ax.contour(plot_x, plot_y, _plane(sdf_values),
                     levels=[0], colors='black', linewidths=2)
cbar = fig.colorbar(cf, cax=cax)
cbar.ax.set_ylabel('Value')
# NOTE(review): labels assume the state is (theta, theta_dot); the previous
# "$y$ (Horizontal)" / "$z$ (Vertical)" labels looked copied from another model.
ax.set_xlabel(r'$\theta$')
ax.set_ylabel(r'$\dot{\theta}$')
tx = ax.set_title(_frame_title(0))


def update(idi):
    """Redraw frame ``idi``: filled values, green zero contour, colorbar, title."""
    global cf, cont, cbar
    arr = _plane(all_values[idi])
    # Drop the previous frame's contours so artists do not accumulate.
    _remove_contour_set(cf)
    _remove_contour_set(cont)
    cf = ax.contourf(plot_x, plot_y, arr)
    cont = ax.contour(plot_x, plot_y, arr, levels=[0], colors='green', linewidths=2)
    # Filled-contour levels are recomputed per frame, so rebuild the colorbar
    # inside the same dedicated axes.
    cax.clear()
    cbar = fig.colorbar(cf, cax=cax)
    cbar.ax.set_ylabel('Value')
    tx.set_text(_frame_title(idi))


# One frame per entry of `times`.
ani = FuncAnimation(fig, update, frames=range(len(times)))
ani.save('ani.mp4')