In [4]:
import os
import sys

# Clear environment flags that could conflict with a CPU-only configuration.
os.environ.pop("ENABLE_PJRT_COMPATIBILITY", None)
os.environ.pop("JAX_METAL_ENABLE", None)  # Metal flag, if it was set

# Restrict JAX to the CPU platform and expose 8 host devices.
# These must be in the environment before JAX is imported below.
os.environ.update(
    JAX_PLATFORMS="cpu",
    XLA_FLAGS="--xla_force_host_platform_device_count=8",
)

# Forget the top-level package entries so the imports below execute again.
# NOTE(review): only the "jax"/"jaxlib" keys are removed; already-imported
# submodules stay cached, so a kernel restart remains the reliable reset.
sys.modules.pop("jax", None)
sys.modules.pop("jaxlib", None)

# Import JAX only now, after the environment has been prepared.
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

# Report the active devices and refuse to continue if any of them is Metal.
devices = jax.devices()
print(f"JAX {jax.__version__} using: {devices}")
if not all("metal" not in str(d).lower() for d in devices):
    raise RuntimeError("Still using Metal! Need to restart kernel completely.")

# Import remaining packages
import numpy as np
import hj_reachability as hj

JAX 0.4.26 using: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [5]:
import jax
import jax.numpy as jnp
import numpy as np
import hj_reachability as hj

First, we define our 2D drone system.

In [6]:
from hj_reachability import dynamics
from hj_reachability import sets

class Drone2D(dynamics.ControlAndDisturbanceAffineDynamics):
  """Two-state drone model with a scalar control input.

  The methods below supply a drift term ``[v, -g]``, a control Jacobian
  ``[[0], [k]]`` and an all-zero disturbance Jacobian. The second state
  therefore accelerates at ``k * u - g`` (assuming the base class combines
  these terms affinely, as its name suggests).
  """

  def __init__(self, k, g=9.8, u_bar=1.):
    """Store the model constants and configure the base class.

    Args:
      k: Control gain; the non-zero entry of the control Jacobian.
      g: Constant subtracted from the second state's derivative.
      u_bar: Bound on the control, which ranges over ``[-u_bar, u_bar]``.
    """
    self.k = k
    self.g = g
    # The disturbance set is the single point {0}.
    no_disturbance = sets.Box(jnp.array([0.]), jnp.array([0.]))
    admissible_controls = sets.Box(jnp.array([-u_bar]), jnp.array([u_bar]))
    # Positional order: control mode, disturbance mode, control set, disturbance set.
    super().__init__('max', 'min', admissible_controls, no_disturbance)

  def open_loop_dynamics(self, state, time):
    """Return the control-independent drift ``[v, -g]``; ``time`` is unused."""
    _z, v_z = state
    return jnp.array([v_z, -self.g])

  def control_jacobian(self, state, time):
    """Return the constant 2x1 control Jacobian ``[[0], [k]]``."""
    return jnp.array([[0.], [self.k]])

  def disturbance_jacobian(self, state, time):
    """Return a 2x1 zero matrix: the disturbance has no effect on the state."""
    return jnp.array([[0.], [0.]])

Here, we set the k values, state grid, failure function, and times that we want to compute the BRTs for.

We set solver parameters as well.

In [7]:
# Candidate values of the constant k; one value function is computed per entry
ks = np.linspace(6, 12, 11)

# State-space grid on which the PDE is solved numerically:
# 51 x 51 points over the box [-2, 2] x [-4, 4]
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    hj.sets.Box(np.array([-2., -4.]), np.array([+2., +4.])),
    (51, 51),
)

# Implicit function l(x) for the failure set; it is negative exactly where
# the first state coordinate has magnitude greater than 1.5
failure_values = 1.5 - jnp.abs(grid.states[..., 0])

# Solver settings: 11 time stamps from 0 down to -5
times = np.linspace(0, -5, 11)
solver_settings = hj.SolverSettings.with_accuracy(
    'very_high',
    hamiltonian_postprocessor=hj.solver.backwards_reachable_tube,
)

Next, we compute the corresponding value functions.

In [8]:
# Compute the value function by solving the PDE for each K
from tqdm import tqdm

# Result buffer indexed as [k index, time index, *grid indices]. The trailing
# dimensions are taken from `failure_values` (the initial values handed to the
# solver) instead of repeating the grid resolution as literals.
values = np.full((len(ks), len(times), *failure_values.shape), fill_value=np.nan)

# Wrap the iterable itself (not `enumerate`) so tqdm knows the total number of
# iterations and can draw a real progress bar.
for i, k in enumerate(tqdm(ks)):
  # Deliberately not named `dynamics`: that would shadow the imported
  # `hj_reachability.dynamics` module for the rest of the session.
  drone = Drone2D(k)
  values[i] = hj.solve(solver_settings, drone, grid, times, failure_values)

100%|##########|  5.0000/5.0 [00:00<00:00, 45.49sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 43.33sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 43.35sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 41.74sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 41.24sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 40.50sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 39.85sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 39.01sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 37.22sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 37.23sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00, 35.81sim_s/s]
11it [00:05,  2.00it/s]


Finally, we visualize the computed value functions and corresponding BRTs.

In [9]:
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider

# Symmetric colour-scale limit shared by the value-function heatmaps
vbar = 1.5
def plot_value_function(k, t):
  """Plot the value function for one (k, t) pair together with its zero level sets.

  Draws a heatmap of ``values`` for the requested k and time, the zero level
  set of that value function in black, and the zero level set of
  ``failure_values`` in red.

  Args:
    k: Constant to display; must match an entry of ``ks`` (up to ``np.isclose``).
    t: Time to display; must match an entry of ``times`` (up to ``np.isclose``).

  Raises:
    ValueError: If ``k`` or ``t`` does not match exactly one precomputed entry
      (raised by ``.item()`` on a result that is not a single element).
  """
  ki = np.argwhere(np.isclose(ks, k)).item()
  ti = np.argwhere(np.isclose(times, t)).item()

  # `values` is indexed [z, v_z]; pcolormesh/contour expect (rows=y, cols=x),
  # so v_z goes on the horizontal axis and z on the vertical axis.
  vz_axis = grid.coordinate_vectors[1]
  z_axis = grid.coordinate_vectors[0]
  value_slice = values[ki, ti]

  plt.figure()
  plt.title(f'$V(x, {t})$ for k={k}')
  plt.xlabel('$v_z$ (m/s)')
  plt.ylabel('$z$ (m)')

  plt.pcolormesh(
      vz_axis,
      z_axis,
      value_slice,
      cmap='RdBu',
      vmin=-vbar, vmax=vbar
  )
  plt.colorbar()
  # `levels` must be a list to request the level value 0; a bare integer is
  # interpreted by matplotlib as a *count* of automatically chosen levels.
  plt.contour(
      vz_axis,
      z_axis,
      value_slice,
      levels=[0],
      colors='k'
  )
  plt.contour(
      vz_axis,
      z_axis,
      failure_values,
      levels=[0],
      colors='r'
  )
  plt.show()

# Slider ranges and steps match the spacing of `ks` (6..12 in steps of 0.6)
# and `times` (-5..0 in steps of 0.5). The trailing semicolon stops the
# notebook from echoing interact's return value below the widget.
interact(
    plot_value_function,
    k=FloatSlider(value=12., min=6., max=12., step=0.6),
    t=FloatSlider(value=0., min=-5., max=0., step=0.5)
);

interactive(children=(FloatSlider(value=12.0, description='k', max=12.0, min=6.0, step=0.6), FloatSlider(value…

<function __main__.plot_value_function(k, t)>

Let's simulate the (time-invariant) optimal safety controller.

\begin{align}
u^*(x)&=\arg\max_{u\in[-1, 1]}\nabla V(x)^Tf(x,u)\\
&=\arg\max_{u\in[-1,1]}\{\beta_1(x)v_z + \beta_2(x)(ku-g)\}\\
&=\begin{cases}
-1 & \text{if }\beta_2(x)k<0\\
1 & \text{otherwise},
\end{cases}
\end{align}
where $\nabla V(x)^T=[\beta_1(x), \beta_2(x)]$.

We take the case when $k=12$.

In [10]:
# Value function for the last k (k=12) at the last time stamp (t=-5)
values_k12 = values[-1, -1]
grads_k12 = grid.grad_values(values_k12, solver_settings.upwind_scheme)
# Component of the gradient along the second state (v_z) at every grid point
beta2s_k12 = grads_k12[:, :, 1]

from scipy.interpolate import interpn

# Convert to NumPy once here instead of on every controller call.
_grid_axes_np = tuple(np.array(v) for v in grid.coordinate_vectors)
_beta2s_k12_np = np.array(beta2s_k12)

def optimal_safety_controller(x):
  """Return the bang-bang safety control for state ``x`` when k=12.

  The v_z-component of the value-function gradient is linearly interpolated
  at ``x`` (and extrapolated outside the grid, because ``fill_value=None``).
  The control is ``-1`` when that component is negative and ``+1`` otherwise,
  including when it is exactly zero.

  Args:
    x: A single state ``[z, v_z]``.

  Returns:
    ``-1.0`` or ``1.0``.
  """
  beta2_k12 = interpn(
      _grid_axes_np,
      _beta2s_k12_np,
      x,
      method='linear',
      bounds_error=False,
      fill_value=None
  ).item()
  # k = 12 > 0, so the sign of beta2 * k is the sign of beta2. np.sign would
  # return 0 for a zero gradient, which is not part of the control law.
  return -1.0 if beta2_k12 < 0 else 1.0

Next, we simulate the system with $k=12$ for 5 seconds from a given initial state
under the optimal safety controller, integrating the dynamics with the forward Euler method.

In [11]:
# Simulation horizon and Euler step size, both in seconds
T = 5
dt = 0.01
def simulate(x0, controller=None, k=12., g=9.8):
  """Integrate the closed-loop drone dynamics with the forward Euler method.

  Each step applies ``x_next = x + dt * [v_z, k * u - g]`` with
  ``u = controller(x)``.

  Args:
    x0: Initial state ``[z, v_z]``.
    controller: Callable mapping a state to a scalar control. Defaults to
      ``optimal_safety_controller``, which is looked up at call time.
    k: Control gain of the simulated system. The default controller was
      derived for k=12, so other values simulate a model mismatch.
    g: Constant subtracted from the acceleration.

  Returns:
    Array of shape ``(nt, 2)`` with ``nt = T / dt`` rows; row ``i`` holds the
    state at time ``i * dt``, so the last row is at ``T - dt``.
  """
  if controller is None:
    controller = optimal_safety_controller
  # round() guards against T / dt landing just below an integer in floating point
  nt = int(round(T / dt))
  xs = np.full((nt, 2), fill_value=np.nan)
  xs[0] = x0
  for i in range(1, nt):
    x = xs[i-1]
    u = controller(x)
    xs[i] = x + dt*np.array([x[1], k*u - g])
  return xs

Finally, we visualize the state trajectory.

In [None]:
def plot_trajectory(z, vz):
  """Simulate from ``[z, vz]`` and plot the trajectory over the k=12 value function.

  Draws the simulated state trajectory in purple on top of a heatmap of
  ``values_k12``, the zero level set of ``values_k12`` in black, and the zero
  level set of ``failure_values`` in red.

  Args:
    z: Initial value of the first state coordinate.
    vz: Initial value of the second state coordinate.
  """
  xs = simulate(np.array([z, vz]))

  # Horizontal axis is v_z (state column 1), vertical axis is z (state column 0).
  vz_axis = grid.coordinate_vectors[1]
  z_axis = grid.coordinate_vectors[0]

  plt.figure()
  plt.plot(xs[:, 1], xs[:, 0], linewidth=4, color='purple')
  plt.title(f'Optimally Safe Trajectory from $z={z}$, $v_z={vz}$')
  plt.xlabel('$v_z$ (m/s)')
  plt.ylabel('$z$ (m)')
  plt.pcolormesh(
      vz_axis,
      z_axis,
      values_k12,
      cmap='RdBu',
      vmin=-vbar, vmax=vbar
  )
  plt.colorbar()
  # `levels` must be a list to request the level value 0; a bare integer is
  # interpreted by matplotlib as a *count* of automatically chosen levels.
  plt.contour(
      vz_axis,
      z_axis,
      values_k12,
      levels=[0],
      colors='k'
  )
  plt.contour(
      vz_axis,
      z_axis,
      failure_values,
      levels=[0],
      colors='r'
  )
  plt.show()

# Slider ranges cover the computation grid ([-2, 2] in z, [-4, 4] in v_z).
# The trailing semicolon stops the notebook from echoing interact's return
# value below the widget.
interact(
    plot_trajectory,
    z=FloatSlider(value=0.9, min=-2., max=2., step=0.1),
    vz=FloatSlider(value=0., min=-4., max=4., step=0.1)
);

interactive(children=(FloatSlider(value=0.9, description='z', max=2.0, min=-2.0), FloatSlider(value=0.0, descr…

<function __main__.plot_trajectory(z, vz)>