In [None]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

import functools

from stljax.formula import Predicate, Eventually, Always, DifferentiableAlways
from stljax.viz import make_stl_graph
from stljax.utils import anneal

from matplotlib import rc
# Typeset figure text with LaTeX using a Palatino serif font.
# NOTE(review): usetex renders text through an external LaTeX install — confirm one is available.
rc('font', family='serif', serif=['Palatino'])
rc('text', usetex=True)

# Enable 64-bit JAX types; done here, before any JAX arrays are created below.
jax.config.update("jax_enable_x64", True)


Helper functions: single-integrator dynamics, trajectory rollout, and distance predicates.

In [None]:
# define single integrator dynamics
@jax.jit
def dynamics_discrete_step(state, control, dt=0.1):
    '''Single integrator dynamics: one discrete step ``state + control * dt``.'''
    return state + control * dt

# roll out states for a sequence of controls
@jax.jit
def simulate_dynamics(controls, state0, dt):
    '''Roll out an initial state under a sequence of controls.

    Args:
        controls: array of shape (T, n); row k is applied at step k.
        state0: initial state of shape (1, n). A 1-D state of shape (n,) is
            promoted to (1, n).
        dt: time step.

    Returns:
        Array of shape (T + 1, n): ``state0`` followed by the state after each
        control step.
    '''
    # The scan below emits new_state[0] and the result is concatenated along
    # axis 0, both of which assume a leading singleton axis; a plain (n,)
    # state would otherwise silently yield a wrong-shaped trajectory.
    state0 = jnp.atleast_2d(state0)

    def scan_fn(state, control):
        new_state = dynamics_discrete_step(state, control, dt)
        return new_state, new_state[0]

    _, rolled_out = jax.lax.scan(scan_fn, state0, controls)
    return jnp.concatenate([state0, rolled_out], axis=0)

# predicate functions: planar distance of each state to a given point / to the origin
@jax.jit
def compute_distance_to_point(states, point):
    '''Euclidean distance between the (x, y) part of each state and ``point``.

    Only the first two entries along the last axis of ``states`` are used, and
    the result keeps a trailing axis of size 1.
    '''
    planar_position = states[..., :2]
    offset = planar_position - point
    return jnp.linalg.norm(offset, axis=-1, keepdims=True)

@jax.jit
def compute_distance_to_origin(states):
    '''Distance of the (x, y) part of each state to the origin.'''
    origin = jnp.zeros(2)
    return compute_distance_to_point(states, origin)


Define the STL formula

In [None]:
# environment parameters
# the target region can be treated as an obstacle or as a region to visit
target_center = jnp.array([[0., 2.]])  # target location, shape (1, 2)
reach_radius = 0.1  # radius considered to have reached the goal
target_radius = 0.5  # radius of the target + agent circular footprint


distance_to_origin = Predicate("distance to origin", compute_distance_to_origin)
distance_to_target = Predicate("distance to target", functools.partial(compute_distance_to_point, point=target_center))

reach = Eventually(distance_to_origin < reach_radius)  # eventually reach the goal
avoid = Always(distance_to_target > target_radius)  # always avoid the target (not part of `formula` below)
# use this `stay` instead if you don't want differentiable time intervals:
# stay = Eventually(Always(distance_to_target < target_radius, interval=[0, 7]), interval=[0, 20])
stay = DifferentiableAlways(distance_to_target < target_radius)  # stay inside the target region

formula = reach & stay
# alternative specification (note: `Until` is not imported above):
# formula = Until(distance_to_target > 0.5, Always(distance_to_origin < 0.5), interval=[40,45])

make_stl_graph(formula)


Define the cost function

In [None]:
def exponenial_penalty(x):
    '''Exponential penalty ``exp(x)``, applied elementwise.

    The (misspelled) name is kept because the loss function calls it by this name.
    '''
    penalty = jnp.exp(x)
    return penalty

@functools.partial(jax.jit, static_argnames=("approx_method",))
def loss(controls, t_start, t_end, scale, state0, umax, dt, coeffs=jnp.array([1., 0.1, 5., 0.]), approx_method="true", temperature=None):
    '''Weighted sum of the trajectory-optimization loss terms (see paper for details).

    Args:
        controls: control sequence with time along axis -2 (shape (T, 2) in this notebook).
        t_start, t_end: interval endpoints passed to ``formula.robustness``; the
            optimization loop passes sigmoid-transformed values in (0, 1).
        scale: passed through to ``formula.robustness``.
        state0: initial state rolled out with ``simulate_dynamics``.
        umax: control-norm limit; norms above it are penalized.
        dt: time step of the rollout.
        coeffs: weights for [robustness, smoothness, control limits, interval] terms.
        approx_method: static (compile-time) argument passed to ``formula.robustness``.
        temperature: passed to ``formula.robustness`` together with ``approx_method``.

    Returns:
        ``coeffs · [robustness, smoothness, control limits, interval]``.
    '''
    # generate trajectory from the control sequence
    traj = simulate_dynamics(controls, state0, dt)

    # penalize negative robustness only; positive robustness is not penalized
    loss_robustness = jax.nn.relu(-formula.robustness(traj, t_start=t_start, t_end=t_end, scale=scale, approx_method=approx_method, temperature=temperature))

    # smoothness: total variation over TIME (differences between consecutive
    # controls, i.e. along axis -2) plus mean squared control norm
    total_variation_weight = 0.1
    if controls.shape[-2] > 1:
        total_variation = jnp.abs(jnp.diff(controls, axis=-2)).sum(-1).mean()
    else:
        # a single control step has no variation (and avoids the mean of an empty array)
        total_variation = 0.0
    loss_control_smoothness = total_variation_weight * total_variation + (controls**2).sum(-1).mean()

    # penalize for control limit violation
    loss_control_limits = jax.nn.relu(jnp.linalg.norm(controls, axis=-1) - umax).mean()

    # encourage the time duration in the target region to be as long as possible,
    # and at least min_interval
    min_interval = 0.2
    interval_difference = min_interval - (t_end - t_start)  # negative is good

    # cost vector
    cost_array = jnp.array([
        loss_robustness,
        loss_control_smoothness,
        loss_control_limits,
        exponenial_penalty(2 * interval_difference)
    ])

    # multiply with weighting coefficients
    return jnp.dot(coeffs, cost_array)
    
# gradient of `loss` w.r.t. its first three positional args (controls, t_start, t_end);
# approx_method is kept static (compile-time), matching `loss`
grad_jit = jax.jit(jax.grad(loss, argnums=(0, 1, 2)), static_argnames=("approx_method",))


# robustness of the rolled-out trajectory, without any approximation arguments
@jax.jit
def true_robustness(controls, t_start, t_end, scale, state0, dt):
    '''Mean robustness of ``formula`` on the rollout of ``controls`` from ``state0``.

    Unlike ``loss``, no ``approx_method``/``temperature`` is passed, so
    ``formula.robustness`` runs with its own defaults.
    '''
    trajectory = simulate_dynamics(controls, state0, dt)
    rho = formula.robustness(trajectory, t_start=t_start, t_end=t_end, scale=scale)
    return rho.mean()

@jax.jit
def schedule(i, i_max, start, end):
    '''Interpolate from ``start`` to ``end`` along the ``anneal`` curve.

    Progress is ``i / i_max``; the returned value is
    ``start + anneal(i / i_max) * (end - start)``.
    '''
    progress = i / i_max
    weight = anneal(progress)
    return weight * (end - start) + start


Setting up parameters to begin the gradient descent routine

In [None]:
np.random.seed(123)  # makes the random initial control sequence reproducible
T = 51  # time horizon (number of control steps)
dt = 0.1  # time step size
ts = jnp.arange(T) * dt  # time of each control step: 0, dt, ..., (T-1)*dt
umax = 2.  # max control limit

controls = jnp.array(np.random.randn(T, 2))  # initial random control sequence, shape (T, 2)
state0 = jnp.ones(2).reshape([1, 2]) * 3.  # initial state (3, 3), shape (1, 2)
states_ = [simulate_dynamics(controls, state0, dt)]  # state trajectory after each gradient step (starts with the initial guess)

# initial values for the time interval (before the sigmoid maps them into (0, 1))
t_start = -1.8
t_end = 1.5

lr = 1E-2  # learning rate
approx_method = "logsumexp"
n_steps = 10000  # number of gradient steps

# start and end values for annealing temperature and scale
start_temp = 1
end_temp = 100

start_scale = 10
end_scale = 100

coeffs = jnp.array([1.1, 0.5, 2., 0.05])  # weights for [robustness, smoothness, control limits, interval] loss terms

Run each function once as a smoke test (this also triggers JIT compilation before the optimization loop).


In [None]:
# Smoke test: call each jitted function once (compiling it) and show the results.
scale = 0.1
smoke_test_results = (
    loss(controls, t_start, t_end, scale, state0, umax, dt),
    loss(controls, t_start, t_end, scale, state0, umax, dt, approx_method="softmax", temperature=5),
    true_robustness(controls, t_start, t_end, scale, state0, dt),
)
grad_jit(controls, t_start, t_end, scale, state0, umax, dt, coeffs, approx_method, 0.2)  # compile the gradient as well
smoke_test_results



Run optimization loop!

In [None]:

# Gradient descent on the controls and the unconstrained interval parameters.
# NOTE: this cell updates `controls`, `t_start`, `t_end` and `states_` in place,
# so re-running it continues from the previous result instead of restarting.
for i in range(n_steps):

    # get annealed temperature and scale
    temperature = schedule(i, n_steps, start_temp, end_temp)
    scale = schedule(i, n_steps, start_scale, end_scale)

    # map t_start and t_end into the (0, 1) range
    t_start_ = jax.nn.sigmoid(t_start)
    t_end_ = jax.nn.sigmoid(t_end)

    # gradient of the loss with respect to controls, t_start_, t_end_
    g1, g2, g3 = grad_jit(controls, t_start_, t_end_, scale, state0, umax, dt, coeffs, approx_method, temperature)

    # stop if the control gradient is small, or if ANY gradient is NaN
    # (a NaN interval gradient would otherwise silently corrupt t_start / t_end)
    has_nan = jnp.isnan(g1).any() | jnp.isnan(g2) | jnp.isnan(g3)
    if (jnp.linalg.norm(g1) / T / 2) < 5E-6 or has_nan:
        break

    # update controls and time interval
    controls -= g1 * lr
    # chain rule through the sigmoid: d sigmoid(x)/dx = s * (1 - s)
    t_start -= g2 * lr * t_start_ * (1 - t_start_)
    t_end -= g3 * lr * t_end_ * (1 - t_end_)

    # collect state trajectory for the current controls
    states_.append(simulate_dynamics(controls, state0, dt))

    # print progress every 50th step
    if (i % 50) == 0:
        t_start_ = jax.nn.sigmoid(t_start)
        t_end_ = jax.nn.sigmoid(t_end)
        eval_scale = 1000.
        rho_eval = true_robustness(controls, t_start_, t_end_, eval_scale, state0, dt)
        # individual loss terms: select one coefficient at a time (unit vectors e1, e2, e3)
        smooth_eval, limits_eval, interval_eval = (
            loss(controls, t_start_, t_end_, eval_scale, state0, umax, dt, coeffs=jnp.eye(4)[k])
            for k in (1, 2, 3)
        )
        print("%3i -- true robustness: %.2f   smoothness: %.2f    control limits: %.2f    interval: %.2f t_start: %.2f    t_end: %.2f"
              % (i, rho_eval, smooth_eval, limits_eval, interval_eval, t_start_, t_end_))


## Visualize the final solution

In [None]:
# Three panels for the final solution: trajectory, distances over time, controls.
fig, axs = plt.subplots(1, 3, figsize=(15, 4))

final_traj = states_[-1]
# row k of a trajectory is the state after k control steps, i.e. at time k * dt
ts_states = jnp.arange(len(final_traj)) * dt

# --- trajectory ---
ax = axs[0]
# NOTE(review): the goal is drawn with radius 0.2 while the reach predicate uses reach_radius — confirm which is intended
ax.add_patch(plt.Circle((0, 0), 0.2, color='C2', alpha=0.4))
ax.add_patch(plt.Circle(target_center[0], target_radius, color='C1', alpha=0.4))

N = 250  # draw every N-th intermediate trajectory
for s in states_[::N]:
    ax.plot(*s.T, color="k", alpha=0.2)
ax.plot(*states_[0].T, color="blue", label="Initial traj")
ax.plot(*final_traj.T, '.-', color="r", markersize=10, label="Final traj")

ax.scatter(final_traj[0, :1], final_traj[0, 1:], marker="^", c='yellow', edgecolor="k", s=100, label="start", zorder=4)
ax.scatter(final_traj[-1, :1], final_traj[-1, 1:], marker="*", c='yellow', edgecolor="k", s=200, label="end", zorder=4)

ax.set_xlabel("x position")
ax.set_ylabel("y position")
ax.grid()
ax.legend()
ax.axis("equal")
ax.set_title("Trajectory")

# --- distances to the goal and target regions, aligned with each state's time ---
ax = axs[1]
ax.plot(ts_states, distance_to_origin.predicate_function(final_traj).squeeze(), label="distance to origin")
ax.hlines(reach_radius, ts_states[0], ts_states[-1], color="C0", linestyle="--", label="reach radius")
ax.plot(ts_states, distance_to_target.predicate_function(final_traj).squeeze(), label="distance to target")
ax.hlines(target_radius, ts_states[0], ts_states[-1], color="C1", linestyle="--", label="target radius")
ax.grid()
ax.axis("equal")
ax.legend()
ax.set_xlabel("Time (s)")
ax.set_ylabel("Distance")
ax.set_title("Distance to regions over time")

# --- control signal (controls[k] is applied at ts[k]) ---
ax = axs[2]
ax.plot(ts, controls[:, :1], label="x control")
ax.plot(ts, controls[:, 1:], label="y control")
ax.plot(ts, jnp.linalg.norm(controls, axis=-1).squeeze(), label="control norm")
ax.grid()
ax.axis("equal")
ax.legend(ncol=3)  # `ncol` works on all matplotlib versions (`ncols` needs >= 3.6)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Controls")
ax.set_title("Control sequence")

plt.tight_layout()


In [None]:
fontsize = 14
fig, ax = plt.subplots(figsize=(5, 5))

def visualize_solution(ax, i):
    '''Draw trajectory ``states_[i]`` together with the goal and target regions.

    Args:
        ax: matplotlib Axes to draw on.
        i: index into ``states_`` (``-1`` is the final solution).

    The title shows the robustness of ``formula`` on that trajectory, evaluated
    with the current global ``t_start``/``t_end``. The interval is not stored per
    iteration, so intermediate frames are scored with the latest interval.
    '''
    traj = states_[i]

    # labels are attached to the artists themselves, so legend entries cannot be
    # mismatched by matplotlib's internal artist ordering
    ax.scatter(traj[0, :1], traj[0, 1:], zorder=10, label="Start", color="yellow", edgecolor="black", marker="^", s=100)
    ax.scatter(traj[-1, :1], traj[-1, 1:], zorder=10, label="End", color="yellow", edgecolor="black", marker="*", s=200)
    ax.plot(*traj.T, 'o-', color="k", alpha=0.4, zorder=9, linewidth=2, label="Solution $i=%i$" % i)

    # NOTE(review): the goal is drawn with radius 0.2 while the reach predicate uses reach_radius — confirm which is intended
    ax.add_patch(plt.Circle((0, 0), 0.2, color='C2', alpha=0.7))
    ax.add_patch(plt.Circle(target_center[0], target_radius, color='C1', alpha=0.7))

    ax.annotate("Goal", (-0.4, -0.4), fontsize=fontsize-2)
    ax.annotate("Target", (-0.4, 2.7), fontsize=fontsize-2)

    ax.set_xlabel("$x$ position [m]", fontsize=fontsize, labelpad=-2)
    ax.set_ylabel("$y$ position [m]", fontsize=fontsize)
    ax.set_title("Robustness $\\rho$ = %.2f"%formula.robustness(traj, t_start=jax.nn.sigmoid(t_start), t_end=jax.nn.sigmoid(t_end), scale=1000.), fontsize=fontsize)
    ax.grid(zorder=-6, alpha=0.5)
    ax.legend(ncol=1, fontsize=fontsize-3)
    ax.axis("equal")
    ax.set_xlim(-1.0, 3.5)
    ax.set_ylim(-1.0, 3.5)
    ax.figure.tight_layout()  # the figure that owns `ax`, not whichever figure is current

visualize_solution(ax, -1)


In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
N = 250  # show intermediate solutions at every N iterations

def animate(idx):
    '''Redraw the axes with the solution stored at ``states_[idx]``.'''
    ax.clear()
    visualize_solution(ax, idx)

# every N-th iteration, always ending on the final solution
# (range(0, len(states_) - 1, N) alone never reaches the last index)
last_idx = len(states_) - 1
frames = list(range(0, last_idx, N)) + [last_idx]
ani = animation.FuncAnimation(fig, animate, frames=frames, interval=100, repeat=False)

plt.close(fig)  # Prevent duplicate static plot in notebook output
HTML(ani.to_jshtml()) # Display the animation in the notebook

In [None]:
# generate a static plot of the final solution, from paper

fontsize = 14
fig, ax = plt.subplots(figsize=(5, 5))

N = 250  # show intermediate solutions at every N iterations

initial_line, = ax.plot(*states_[0].T, color="blue", label="Initial traj", alpha=0.6)
final_line, = ax.plot(*states_[-1].T, '.-', color="r", label="Final traj", zorder=10, linewidth=2, markersize=8)
start_marker = ax.scatter(states_[-1][0, :1], states_[-1][0, 1:], zorder=10, label="start", color="yellow", edgecolor="black", marker="^", s=100)
end_marker = ax.scatter(states_[-1][-1, :1], states_[-1][-1, 1:], zorder=10, label="end", color="yellow", edgecolor="black", marker="*", s=200)
for s in states_[::N]:
    ax.plot(*s.T, color="k", alpha=0.2, label="Iterations", zorder=0)

# NOTE(review): the goal is drawn with radius 0.2 while the reach predicate uses reach_radius — confirm which is intended
ax.add_patch(plt.Circle((0, 0), 0.2, color='C2', alpha=0.7))
ax.add_patch(plt.Circle(target_center[0], target_radius, color='C1', alpha=0.7))

ax.annotate("Goal", (-0.4, -0.4), fontsize=fontsize-2)
ax.annotate("Target", (-0.4, 2.7), fontsize=fontsize-2)

ax.set_xlabel("$x$ position [m]", fontsize=fontsize, labelpad=-2)
ax.set_ylabel("$y$ position [m]", fontsize=fontsize)
ax.set_title("Robustness $\\rho$ = %.2f"%formula.robustness(states_[-1], t_start=jax.nn.sigmoid(t_start), t_end=jax.nn.sigmoid(t_end), scale=1000.), fontsize=fontsize)
ax.grid(zorder=-6, alpha=0.5)
# pass handles explicitly so each label is paired with the intended artist,
# independent of matplotlib's internal artist ordering
ax.legend([initial_line, final_line, start_marker, end_marker],
          ["Initial guess", "Final trajectory", "Start", "End"], ncol=1, fontsize=fontsize-3)
ax.axis("equal")
plt.tight_layout()
