In [None]:
import timeit
from jax import jit
from jax import lax
import jax.numpy as jnp 
from numba import njit

vel = 1.0
accl = 4.0
v_switch = 1.0
a_max = 5.0
v_min = 0.5
v_max = 20.0

@jit
def accl_constraints(vel, accl, v_switch, a_max, v_min, v_max):
    """Acceleration constraints, adjusts the acceleration based on constraints.

        Args:
            vel (float): current velocity of the vehicle
            accl (float): unconstraint desired acceleration
            v_switch (float): switching velocity (velocity at which the acceleration is no longer able to
                create wheel spin)
            a_max (float): maximum allowed acceleration
            v_min (float): minimum allowed velocity
            v_max (float): maximum allowed velocity

        Returns:
            accl (float): adjusted acceleration
        """
    pos_limit = lax.cond(vel > v_switch, lambda _: a_max * v_switch / vel, lambda _: a_max, vel)
    accl_limit = lax.cond(
        lax.bitwise_and(vel <= v_min, vel <= 0) | lax.bitwise_and(vel >= v_max, vel >= 0),
        lambda _: 0.,
        lambda _: lax.cond(
            vel <= v_min,
            lambda _: lax.cond(accl <= -a_max, lambda _: -a_max, lambda _: pos_limit, accl),
            lambda _: lax.cond(accl >= pos_limit, lambda _: pos_limit, lambda _: accl, accl),
            vel
        ),
        vel
    )
    return accl_limit


def multiple_calls_jax(vel, accl, v_switch, a_max, v_min, v_max):  
    for i in range(1, 1000):
        vel = i/100
        accl_limit = accl_constraints(vel, accl, v_switch, a_max, v_min, v_max)
    return accl_limit



callable_object = lambda: multiple_calls_jax(vel, accl, v_switch, a_max, v_min, v_max)
time = timeit.timeit(callable_object, number=100)
print(f"Time of function with floats: {time}")
time = timeit.timeit(callable_object, number=100)
print(f"Time of function with floats: {time}")
time = timeit.timeit(callable_object, number=100)
print(f"Time of function with floats: {time}")

vel = jnp.array([1.0])
accl = jnp.array([4.0])
v_switch = jnp.array([1.0])
a_max = jnp.array([5.0])
v_min = jnp.array([0.5])
v_max = jnp.array([20.0])


def array_accl_constraints(vel, accl, v_switch, a_max, v_min, v_max):
    """Acceleration constraints, adjusts the acceleration based on constraints.

        Args:
            vel (float): current velocity of the vehicle
            accl (float): unconstraint desired acceleration
            v_switch (float): switching velocity (velocity at which the acceleration is no longer able to
                create wheel spin)
            a_max (float): maximum allowed acceleration
            v_min (float): minimum allowed velocity
            v_max (float): maximum allowed velocity

        Returns:
            accl (float): adjusted acceleration
        """
    import pdb
    pdb.set_trace()
    pos_limit = lax.cond(vel > v_switch, lambda _: a_max * v_switch / vel, lambda _: a_max, vel)
    accl_limit = lax.cond(
        lax.bitwise_and(vel <= v_min, vel <= 0) | lax.bitwise_and(vel >= v_max, vel >= 0),
        lambda _: 0.,
        lambda _: lax.cond(
            vel <= v_min,
            lambda _: lax.cond(accl <= -a_max, lambda _: -a_max, lambda _: pos_limit, accl),
            lambda _: lax.cond(accl >= pos_limit, lambda _: pos_limit, lambda _: accl, accl),
            vel
        ),
        vel
    )
    return accl_limit

def multiple_calls_array(vel, accl, v_switch, a_max, v_min, v_max):  
    for i in range(1, 1000):
        vel = i/100
        accl_limit = array_accl_constraints(vel, accl, v_switch, a_max, v_min, v_max)
    return accl_limit

callable_object = lambda: multiple_calls_array(vel, accl, v_switch, a_max, v_min, v_max)
time = timeit.timeit(callable_object, number=100)
print(f"Time of function with arrays: {time}")
time = timeit.timeit(callable_object, number=100)
print(f"Time of function with arrays: {time}")
time = timeit.timeit(callable_object, number=100)
print(f"Time of function with arrays: {time}")

In [4]:
import jax.numpy as jnp
import numpy as np
from jax import jit 
from jax import lax 


@jit
def accl_constraints(vel, accl, v_switch, a_max, v_min, v_max):
    """Acceleration constraints, adjusts the acceleration based on constraints.

        Args:
            vel (float): current velocity of the vehicle
            accl (float): unconstraint desired acceleration
            v_switch (float): switching velocity (velocity at which the acceleration is no longer able to
                create wheel spin)
            a_max (float): maximum allowed acceleration
            v_min (float): minimum allowed velocity
            v_max (float): maximum allowed velocity

        Returns:
            accl (float): adjusted acceleration
        """
    pos_limit = lax.cond(vel > v_switch, lambda _: a_max * v_switch / vel, lambda _: a_max, vel)
    accl_limit = lax.cond(
        lax.bitwise_and(vel <= v_min, vel <= 0) | lax.bitwise_and(vel >= v_max, vel >= 0),
        lambda _: 0.,
        lambda _: lax.cond(
            vel <= v_min,
            lambda _: lax.cond(accl <= -a_max, lambda _: -a_max, lambda _: pos_limit, accl),
            lambda _: lax.cond(accl >= pos_limit, lambda _: pos_limit, lambda _: accl, accl),
            vel
        ),
        vel
    )
    return accl_limit


@jit
def steering_constraint(steering_angle, steering_velocity, s_min, s_max, sv_min, sv_max):
    """Steering constraints, adjusts the steering velocity based on constraints.

    Args:
        steering_angle (float): current steering_angle of the vehicle
        steering_velocity (float): unconstraint desired steering_velocity
        s_min (float): minimum steering angle
        s_max (float): maximum steering angle
        sv_min (float): minimum steering velocity
        sv_max (float): maximum steering velocity

    Returns:
        steering_velocity (float): adjusted steering velocity
    """

    # constraint steering velocity

    steering_velocity = lax.cond(
        (steering_angle <= s_min) | (steering_angle >= s_max),
        lambda _: 0.,
        lambda _: lax.cond(
            steering_velocity <= sv_min,
            lambda _: lax.cond(steering_velocity <= sv_min, lambda _: sv_min, lambda _: steering_velocity, steering_velocity),
            lambda _: lax.cond(steering_velocity >= sv_max, lambda _: sv_max, lambda _: steering_velocity, steering_velocity),
            steering_velocity
        ),
        steering_angle
    )
    return steering_velocity

@jit
def vehicle_dynamics_ks(x, u_init, mu, C_Sf, C_Sr, lf, lr, h, m, I, s_min, s_max, sv_min, sv_max, v_switch, a_max,
                        v_min, v_max):
    """Single Track Kinematic Vehicle Dynamics.

    Args:
        x (numpy.ndarray (3, )): vehicle state vector (x1, x2, x3, x4, x5)
            x1: x position in global coordinates
            x2: y position in global coordinates
            x3: steering angle of front wheels
            x4: velocity in x direction
            x5: yaw angle
        u (numpy.ndarray (2, )): control input vector (u1, u2)
            u1: steering angle velocity of front wheels
            u2: longitudinal acceleration

    Returns:
        f (numpy.ndarray): right hand side of differential equations
    """
    # wheelbase
    lwb = lf + lr

    # constraints
    u = jnp.array([steering_constraint(x[2], u_init[0], s_min, s_max, sv_min, sv_max),
                  accl_constraints(x[3], u_init[1], v_switch, a_max, v_min, v_max)])
    # this is legacy code for now, i don't know if these arrays are mutated in place somewhere, so am using regular numpy arrays for now

    # system dynamics
    f = jnp.array([x[3] * jnp.cos(x[4]),
                  x[3] * jnp.sin(x[4]),
                  u[0],
                  u[1],
                  x[3] / lwb * jnp.tan(x[2])])

    return f

@jit
def vehicle_dynamics_st(x, u_init, mu, C_Sf, C_Sr, lf, lr, h, m, I, s_min, s_max, sv_min, sv_max, v_switch, a_max,
                        v_min, v_max):
    """Single Track Dynamic Vehicle Dynamics.

    Args:
        x (numpy.ndarray (3, )): vehicle state vector (x1, x2, x3, x4, x5, x6, x7)
            x1: x position in global coordinates
            x2: y position in global coordinates
            x3: steering angle of front wheels
            x4: velocity in x direction
            x5: yaw angle
            x6: yaw rate
            x7: slip angle at vehicle center
        u (numpy.ndarray (2, )): control input vector (u1, u2)
            u1: steering angle velocity of front wheels
            u2: longitudinal acceleration

    Returns:
        f (numpy.ndarray): right hand side of differential equations
    """
    # Note that the indexing of the x vector in the above legacy documentation is a bit misleading because
    # indexing of vectors of course starts at 0

    # gravity
    g = 9.81

    # constraints
    u = jnp.array(
        [
            steering_constraint(x[2], u_init[0], s_min, s_max, sv_min, sv_max),
            accl_constraints(x[3], u_init[1], v_switch, a_max, v_min, v_max)
        ]
    )
    # u_init only used here
    '''
    A vector ƒrom all the functions arguments is created here to enable passing within lax.cond()
    vec[0] : x1
    vec[1] : x2
    vec[2] : x3
    vec[3] : x4
    vec[4] : x5
    vec[5] : x6
    vec[6] : x7
    vec[7] : u1
    vec[8] : u2
    vec[9] : mu
    vec[10] : C_Sf
    vec[11] : C_Sr
    vec[12] : lf
    vec[13] : lr
    vec[14] : h 
    vec[15] : m
    vec[16] : I
    vec[17] : s_min
    vec[18] : s_max
    vec[19] : sv_min 
    vec[20] : sv_max
    vec[21] : v_switch
    vec[22] : a_max
    vec[23] : v_min
    vec[24] : v_max
    '''
    vec = jnp.array([
        x[0], x[1], x[2], x[3], x[4], x[5], x[6], u[0], u[1], mu, C_Sf, C_Sr, lf, lr, h, m, I, s_min, s_max,
         sv_min, sv_max, v_switch, a_max,
         v_min, v_max
         ])

    def true_fun(vec):
        x_ks = vec[0:5]
        u = vec[7:9]
        mu = vec[9]
        C_Sf = vec[10]
        C_Sr = vec[11]
        lf = vec[12]
        lr = vec[13]
        h = vec[14]
        m = vec[15]
        I = vec[16]
        s_min = vec[17]
        s_max = vec[18]
        sv_min = vec[19]
        sv_max = vec[20]
        v_switch = vec[21]
        a_max = vec[22]
        v_min = vec[23]
        v_max = vec[24]
        # wheelbase
        lwb = lf + lr

        # vehicle dynamics
        f_ks = vehicle_dynamics_ks(x_ks, u, mu, C_Sf, C_Sr, lf, lr, h, m, I, s_min, s_max, sv_min, sv_max, v_switch,
                                   a_max, v_min, v_max)

        f = jnp.hstack((f_ks, jnp.array([u[1] / lwb * jnp.tan(x[2]) + x[3] / (lwb * jnp.cos(x[2]) ** 2) * u[0],
                                       0])))

        return f

    def false_fun(vec):
        u = vec[7:9]
        mu = vec[9]
        C_Sf = vec[10]
        C_Sr = vec[11]
        lf = vec[12]
        lr = vec[13]
        h = vec[14]
        m = vec[15]
        I = vec[16]

        # system dynamics
        f = jnp.array([
            x[3] * jnp.cos(x[6] + x[4]),
            x[3] * jnp.sin(x[6] + x[4]),
            u[0],
            u[1],
            x[5],
            -mu * m / (x[3] * I * (lr + lf)) * (
                    lf ** 2 * C_Sf * (g * lr - u[1] * h) + lr ** 2 * C_Sr * (g * lf + u[1] * h)) * x[5] \
            + mu * m / (I * (lr + lf)) * (lr * C_Sr * (g * lf + u[1] * h) - lf * C_Sf * (g * lr - u[1] * h)) * x[6] \
            + mu * m / (I * (lr + lf)) * lf * C_Sf * (g * lr - u[1] * h) * x[2],
            (mu / (x[3] ** 2 * (lr + lf) + 1e-5) * (
                    C_Sr * (g * lf + u[1] * h) * lr - C_Sf * (g * lr - u[1] * h) * lf) - 1) *
            x[5] \
            - mu / (x[3] * (lr + lf)) * (C_Sr * (g * lf + u[1] * h) + C_Sf * (g * lr - u[1] * h)) * x[6] \
            + mu / (x[3] * (lr + lf)) * (C_Sf * (g * lr - u[1] * h)) * x[2]])

        return f

    # switch to kinematic model for small velocities
    f = lax.cond(vec[3] < 1, true_fun, false_fun, vec)
    return f

vel = 1.0
accl = 4.0
v_switch = 1.0
a_max = 5.0
v_min = 0.5
v_max = 20.0
u_init = [2.0, 2.0]
mu = 3.0
C_Sf = 1.0
C_Sr = 1.0
lf = 1.0
lr = 1.0
h = 2.0
m = 10
I = 3.5
s_min = 2.0
s_max = 10.0
sv_min = 1.0
sv_max = 10.0
x = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]


call_vehicle_dynamics = vehicle_dynamics_st(x, u_init, mu, C_Sf, C_Sr, lf, lr, h, m, I, s_min, s_max, sv_min, sv_max, v_switch, a_max,
                        v_min, v_max)
print(f"{call_vehicle_dynamics}")

[ -0.41614684   0.9092974    0.           2.           1.
 -24.900002    -9.715061  ]


In [16]:
import jax.numpy as jnp 

state = jnp.array([0, 0, 0, 0, 0, 0, 0])
steer_buffer = jnp.array([0, 0])
steer_buffer_size = jnp.array([2])
pose_of_other_vehicles = jnp.zeros((6, 3))
a_x = jnp.array([0])
a_y = jnp.array([0])

agent_array = jnp.concatenate(state, steer_buffer, steer_buffer_size, a_x, a_y)
agent_array2 = jnp.concatenate((state, steer_buffer, steer_buffer_size, a_x, a_y))

agents_array = jnp.array([agent_array, agent_array2])

print(f"agent_array: {agent_array}")
print(f"agent_array shape: {agent_array.shape}")
print(f"agents_array: {agents_array}")
print(f"agents_array shape: {agents_array.shape}")

TypeError: concatenate() takes from 1 to 3 positional arguments but 5 were given

In [7]:
import jax.numpy as jnp 

arr1 = jnp.array([1])
arr2 = jnp.array([3,4])

print(f"arr1: {arr1}")
print(f"arr2: {arr2}")

arr1: [1]
arr2: [3 4]


In [33]:
import jax.numpy as jnp

poses = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

n = poses.shape[0]
opp_poses = jnp.zeros((n, n-1, 3))

for i in range(n): 
    _ = jnp.delete(poses, i, 0)
    opp_poses = opp_poses.at[i].set(_) 
    
print(f"opp_poses: {opp_poses}")



opp_poses: [[[4. 5. 6.]
  [7. 8. 9.]]

 [[1. 2. 3.]
  [7. 8. 9.]]

 [[1. 2. 3.]
  [4. 5. 6.]]]


In [34]:
import jax.numpy as jnp

# Original array
original_array = jnp.array([1, 2, 3, 4, 5])

# Arrays representing changes
change1 = jnp.array([0, 0, 0, 0, 0])
change2 = jnp.array([0, 0, 10, 0, 0])
change3 = jnp.array([0, 0, 0, 0, 5])

# Combine changes with the original array
combined_array = original_array + change1 + change2 + change3

print(combined_array)


[ 1  2 13  4 10]


In [78]:
import jax.numpy as jnp 
from jax.lax import scan 

arr1 = jnp.array([1, 1, 1, 4, 0])
arr2 = jnp.array([1, 2, 1, 4, 4])
arr3 = jnp.array([1, 1, 3, 4, 4])
arr4 = jnp.array([1, 1, 1, 4, 5])


def combine(original, input):
    mask = original != input
    output = original
    for row_mask, row_input in zip(mask,input): 
        output = jnp.where(row_mask, row_input, output)

    return output
    

res = combine(arr1, jnp.array([arr2, arr3, arr4]))
print(res)

[1 2 3 4 5]
