In [None]:
from pncbf.dyn.doubleint_wall import DoubleIntWall
from pncbf.dyn.sim_cts import SimCtsReal
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

In [None]:
import jax

# Sanity-check the JAX installation: report its version and the devices it can see.
jax_version = jax.__version__
print('JAX version:', jax_version)
print('Available devices:', jax.devices())

In [None]:
system = DoubleIntWall()

# Summarize the task's dimensions and the labels it assigns to each
# state / control / constraint component, one item per output line.
summary_lines = [
    f"state dim is: {system.nx}",
    f"control dim is: {system.nu}",
    f"state label is: {system.x_labels}",
    f"control label is: {system.u_labels}",
    f"constraint label is: {system.h_labels}",
]
print("\n".join(summary_lines))

In [None]:
# Draw the first 2D phase-plane setup provided by the task onto an explicit axes.
fig, ax = plt.subplots(figsize=(8, 6))

system_setup = system.phase2d_setups()[0]
system_setup.plot(ax)

# Style the axes/figure we created explicitly, rather than whatever pyplot
# considers "current" after the opaque `system_setup.plot(ax)` call.
ax.set_title("Double Integrator Phase Space")
fig.tight_layout()
plt.show()


In [None]:
# Nominal feedback gain K of shape (1, 2), applied to the state [pos, vel]: u = -K x.
K_NOMINAL = jnp.array([[0.5, 0.8]])


def nominal_policy(state, k=None):
    """Linear state-feedback nominal controller ``u = -K x``.

    Written with ``jax.numpy`` so the function can be traced/jitted when it is
    handed to a JAX-based simulator (presumably the case for ``SimCtsReal`` --
    confirm), while still accepting plain NumPy inputs.

    Args:
        state: state vector ``[pos, vel]`` of shape ``(2,)``.
        k: optional gain matrix of shape ``(1, 2)``; defaults to ``K_NOMINAL``.

    Returns:
        Control vector of shape ``(1,)``.
    """
    gain = K_NOMINAL if k is None else jnp.asarray(k)
    # Unary minus binds tighter than `@`, i.e. this is (-K) @ x.
    return -gain @ jnp.asarray(state)


# test against an arbitrary state
test_state = np.array([-0.5, 0.1])  # [pos, vel]
control = nominal_policy(test_state)
print(f"state: {test_state}, control: {control}")

In [None]:
# Closed-loop rollout of the nominal policy from a single initial state.
simulation_time = 5.0  # final rollout time (same time units as system.dt)
dt = system.dt
sim = SimCtsReal(system, nominal_policy, simulation_time, dt)

# initial state
x0 = np.array([-0.5, 1.0])  # [pos, vel]

# simulation; the third returned value is not used in this notebook
T_states, T_times, _ = sim.rollout_plot(x0)

# plot trajectory on top of the phase-plane setup drawn earlier
fig, ax = plt.subplots(figsize=(8, 6))
system_setup.plot(ax)
ax.plot(T_states[:, 0], T_states[:, 1], 'r', linewidth=2, label='Trajectory')
ax.plot(x0[0], x0[1], 'o', markersize=8, label='Initial State')
ax.legend()
ax.set_title("Double Integrator Trajectory")
fig.tight_layout()
plt.show()

# Safety value h(x) at every visited state, computed in one batched call
# instead of a Python loop of per-state calls.
T_h = jax.vmap(system.h)(jnp.asarray(T_states))

fig_h, ax_h = plt.subplots(figsize=(8, 4))
ax_h.plot(T_times, T_h)
ax_h.axhline(y=0, color='r', linestyle='-', alpha=0.5)  # safety boundary h = 0
ax_h.set_xlabel('Time')
ax_h.set_ylabel('Safety Value')
ax_h.set_title('Safety Value Along Trajectory (h < 0 is safe)')
ax_h.grid(True)
fig_h.tight_layout()
plt.show()