In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from signed_distance import *
from nlp_builder import NLPBuilder
from nlp_solver import SQPSolver
from functools import partial

In [2]:
# problem
num_steps = 10
center_obs = jnp.array([0, 0.1])
r_obs = 0.2

dim_robot = 2
r_robot = 0.1
safe_dist = 0.1

q_init = jnp.array([-1, -1])
q_goal = jnp.array([1, 1])
qu, ql = jnp.full(2, 1.1), jnp.full(2, -1.1)

In [3]:
#util functions
def to_mat(x):
    return x.reshape(-1, dim_robot)
def to_vec(mat):
    return mat.flatten()
def at_timestep(i, x):
    return to_mat(x)[i]
def to_vel(x):
    return (to_mat(x)[1:] - to_mat(x)[:-1]).flatten()

In [4]:
# define robot
x_robot = r_robot*jnp.cos(jnp.linspace(0, jnp.pi*2, 10))
y_robot = r_robot*jnp.sin(jnp.linspace(0, jnp.pi*2, 10))
p_robot = jnp.vstack([x_robot, y_robot]).T

In [5]:
# functions
obs = Circle(center_obs, r_obs)
env = EnvSDF((obs,), safe_dist) # exclude r_robot in safe_dist
assign_points = lambda q: p_robot + q
assign_points_path = lambda x: jnp.vstack(jax.vmap(assign_points)(to_mat(x)))

def min_dist_cost(x):
    v = to_vel(x)
    return v @ v
state_init = partial(at_timestep, 0)
state_goal = partial(at_timestep,-1)
penetration = lambda x: env.penetrations(assign_points_path(x))

xl = jnp.tile(ql, num_steps)
xu = jnp.tile(qu, num_steps)
x0 = jnp.zeros_like(xu)
#jnp.linspace(q_init, q_goal, num_steps).flatten()

In [6]:
#build NLP problem 
dim = num_steps * dim_robot
nlp = NLPBuilder(dim=dim)
nlp.set_f(min_dist_cost)
nlp.add_eq_const(state_init, q_init)
nlp.add_eq_const(state_goal, q_goal)
nlp.add_eq_const(penetration, 0.)
nlp.set_state_bound(xl, xu)

In [7]:
solver = SQPSolver.from_builder(nlp)
solver.prebuild()

In [8]:
xsol = solver.solve(x0, tol=0.04, verbose=True, save_history=True)

0: grad:0.0611 | viol:0.0302 | alpha:1.0000
1: grad:0.1809 | viol:0.0034 | alpha:1.0000
2: grad:0.1539 | viol:0.0012 | alpha:0.8000
3: grad:0.1629 | viol:0.0028 | alpha:1.0000
4: grad:0.1597 | viol:0.0028 | alpha:0.1678
5: grad:0.0398 | viol:0.0009 | alpha:1.0000
SQP solved !


In [9]:
xs = [s.x for s in solver.history]
len(xs)

6

In [10]:
i= 0

In [17]:
import matplotlib.pyplot as plt
qs = to_mat(xs[i])

fig, ax = plt.subplots()
robots = []
if i>=1:
    qs_prev = to_mat(xs[i-1])
    ax.scatter(*assign_points_path(qs_prev).T, color='gray')
ax.scatter(*assign_points_path(qs).T, color='blue')

circle = plt.Circle(center_obs, r_obs, color='r')
xlim, ylim = jnp.vstack([ql, qu]).T
ax.add_patch(circle)
#ax = plt.scatter(*qs.T)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
i += 1

IndexError: list index out of range