In [1]:
import jax
import jax.numpy as jnp

In [2]:
a = jnp.array(list(range(1, 7)))[:, None]  # column vector [[1], [2], ..., [6]]

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [4]:
a  # display the (6, 1) column vector

Array([[1],
       [2],
       [3],
       [4],
       [5],
       [6]], dtype=int32)

In [6]:
b = jnp.reshape(a, (3, 2), order='F')  # fill the 3x2 matrix column by column (Fortran order)

In [7]:
b  # display the (3, 2) Fortran-order reshape of a

Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

In [9]:
jnp.reshape(b, (-1, 1), order='F')  # column-major flatten undoes the Fortran reshape, giving back a

Array([[1],
       [2],
       [3],
       [4],
       [5],
       [6]], dtype=int32)

In [14]:
arr1 = jnp.array([[1.0]])
arr2 = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
arr3 = jnp.array([[6.0, 7.0, 8.0, 9.0, 10.0]])
# Join the three row vectors side by side along the column axis -> shape (1, 11).
arr = jnp.concatenate([arr1, arr2, arr3], axis=1)
for item in (arr[0], arr.shape, arr):
    print(item)

[ 1.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]
(1, 11)
[[ 1.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]]


In [15]:
def get_mean( sigma_points, weights ):
    '''
    Weighted mean of a set of sigma points.

    input:
    sigma_points: (n, N) - one point per column
    weights: (1, N) - only row 0 is used

    output:
    mu: (n, 1)
    '''
    w = weights[0]
    return (sigma_points * w).sum(axis=1).reshape(-1, 1)


In [16]:
def get_mean_cov(sigma_points, weights):
    '''
    Weighted mean and diagonal covariance of a set of sigma points.

    input:
    sigma_points: (n, N) - one point per column
    weights: (1, N) - only row 0 is used

    output:
    mu: (n, 1)
    cov: (n, n) diagonal matrix of per-dimension weighted variances
         (cross-covariances are not computed)
    '''
    w = weights[0]
    mu = jnp.sum(sigma_points * w, axis=1).reshape(-1, 1)
    deviations = sigma_points - mu
    variances = jnp.sum(deviations**2 * w, axis=1)
    return mu, jnp.diag(variances)


In [17]:
def get_ut_cov_root_diagonal(cov, offset=0.0):
    '''
    Square root of a (diagonal) covariance matrix, returned as a diagonal matrix.

    Only the diagonal of `cov` is used; off-diagonal entries are ignored.

    input:
    cov: (n, n) covariance, treated as diagonal
    offset: value added to every variance before the square root. The default 0.0
        reproduces the previous behaviour; a small positive value (e.g. 1e-9) keeps the
        gradient finite when a variance is exactly 0 (d/dx sqrt(x) is infinite at x = 0).

    output:
    root_term: (n, n) diagonal matrix with entries sqrt(cov[i, i] + offset)

    NOTE: this function is defined again in the next cell, and that later definition
    shadows this one; consider deleting one copy.
    '''
    variances = jnp.diagonal(cov)
    return jnp.diag(jnp.sqrt(variances + offset))


In [18]:
def get_ut_cov_root_diagonal(cov, offset=0.0):
    '''
    Square root of a (diagonal) covariance matrix, returned as a diagonal matrix.

    Only the diagonal of `cov` is used; off-diagonal entries are ignored.

    input:
    cov: (n, n) covariance, treated as diagonal
    offset: value added to every variance before the square root. The default 0.0
        reproduces the previous behaviour; a small positive value (e.g. 1e-9) keeps the
        gradient finite when a variance is exactly 0 (d/dx sqrt(x) is infinite at x = 0).

    output:
    root_term: (n, n) diagonal matrix with entries sqrt(cov[i, i] + offset)

    NOTE: this re-defines get_ut_cov_root_diagonal from the previous cell; being later,
    this definition is the one in effect.
    '''
    variances = jnp.diagonal(cov)
    return jnp.diag(jnp.sqrt(variances + offset))
def get_mean_cov_skew_kurt( sigma_points, weights ):
    '''
    Weighted mean, diagonal covariance and raw third / fourth central moments.

    The "skewness" and "kurt" returned here are NOT divided by sigma**3 / sigma**4
    (get_mean_cov_skew_kurt_for_generation returns the normalised versions).

    input:
    sigma_points: (n, N) - one point per column
    weights: (1, N) - only row 0 is used

    output:
    mu: (n, 1), cov: (n, n) diagonal, skewness: (n, 1), kurt: (n, 1)
    '''
    w = weights[0]
    mu = jnp.sum(sigma_points * w, axis=1).reshape(-1, 1)
    d = sigma_points - mu
    # Weighted central moments of order 2, 3 and 4, one value per dimension.
    var, third, fourth = [jnp.sum(d**p * w, axis=1) for p in (2, 3, 4)]
    return mu, jnp.diag(var), third.reshape(-1, 1), fourth.reshape(-1, 1)

def get_mean_cov_skew_kurt_for_generation( sigma_points, weights ):
    '''
    Weighted mean, diagonal covariance and normalised skewness / kurtosis per dimension.

    The third and fourth weighted central moments are divided by sigma**3 and sigma**4
    (no -3 excess-kurtosis correction). A dimension with zero variance yields nan/inf.

    input:
    sigma_points: (n, N) - one point per column
    weights: (1, N) - only row 0 is used

    output:
    mu: (n, 1), cov: (n, n) diagonal, skewness: (n, 1), kurt: (n, 1)
    '''
    w = weights[0]
    mu = jnp.sum(sigma_points * w, axis=1).reshape(-1, 1)
    d = sigma_points - mu
    variance = jnp.sum(d**2 * w, axis=1)
    cov = jnp.diag(variance)
    skewness = jnp.sum(d**3 * w, axis=1) / variance**(3/2)
    kurt = jnp.sum(d**4 * w, axis=1) / variance**(4/2)
    return mu, cov, skewness.reshape(-1, 1), kurt.reshape(-1, 1)

def generate_sigma_points_gaussian( mu, cov_root, base_term, factor, alpha=1.0, k=1.0 ):
    '''
    Standard UT: generate 2n+1 sigma points (in total) around an (n, 1) mean.

    input:
    mu: (n,1) mean
    cov_root: (n,n) covariance square root; column j is the j-th spread direction
    base_term: added to every point after scaling (broadcast against (n,1))
    factor: multiplies the mean-plus-spread part of every point
    alpha: UT spread parameter (default 1.0, the previously hard-coded value)
    k: UT secondary scaling parameter (default 1.0, the previously hard-coded value)

    output:
    new_points: (n, 2n+1) = [mu | mu + sqrt(n+Lambda)*cov_root | mu - sqrt(n+Lambda)*cov_root]
                            (each transformed as base_term + (.)*factor)
    new_weights: (1, 2n+1) mean weights; they sum to 1

    Only mean weights are produced, so the UT beta parameter (which affects only the
    covariance weight of the centre point) has no effect here and is omitted.
    '''
    n = mu.shape[0]
    # Scaled-UT lambda; with the defaults alpha = k = 1 this is Lambda = 1.
    Lambda = alpha**2 * ( n+k ) - n
    spread = jnp.sqrt(n+Lambda)

    points0 = base_term + mu * factor
    points1 = base_term + (mu + spread * cov_root)*factor
    points2 = base_term + (mu - spread * cov_root)*factor

    weights0 = jnp.array([[ 1.0*Lambda/(n+Lambda) ]])
    weights1 = jnp.ones((1,n)) * 1.0/(n+Lambda)/2.0
    weights2 = jnp.ones((1,n)) * 1.0/(n+Lambda)/2.0

    new_points = jnp.concatenate((points0, points1, points2), axis=1)
    new_weights = jnp.concatenate((weights0, weights1, weights2), axis=1)

    return new_points, new_weights

def generate_sigma_points_gaussian_GenUT( mu, cov_root, skewness, kurt, base_term, factor ):
    '''
    Generalized UT: build 2n+1 sigma points matching the given mean, covariance root and
    per-dimension skewness / kurtosis.

    input:
    mu: (n,1) mean
    cov_root: (n,n) covariance square root
    skewness, kurt: (n,1) per-dimension moments (presumably the normalised values from
        get_mean_cov_skew_kurt_for_generation); 4*kurt - 3*skewness**2 must be >= 0,
        otherwise the square root below is nan
    base_term: added to every point after scaling (broadcast against (n,1))
    factor: multiplies the mean-plus-offset part of every point

    output:
    new_points: (n, 2n+1) = [mu | mu - cov_root @ diag(u) | mu + cov_root @ diag(v)]
    new_weights: (1, 2n+1) = [w0 | w1 (for the "-u" points) | w2 (for the "+v" points)]
    '''
    root = jnp.sqrt( 4 * kurt - 3 * skewness**2 )
    u = 0.5 * ( root - skewness )
    v = skewness + u

    w_plus = (1.0 / v) / (u + v)
    w_minus = (w_plus * v) / u
    w_centre = 1 - jnp.sum(w_minus) - jnp.sum(w_plus)

    centre = base_term + mu * factor
    lower = base_term + (mu - cov_root @ jnp.diag(u[:, 0])) * factor
    upper = base_term + (mu + cov_root @ jnp.diag(v[:, 0])) * factor

    points = jnp.concatenate([centre, lower, upper], axis=1)
    weight_row = jnp.concatenate(
        [jnp.reshape(w_centre, (1, 1)), w_minus.reshape(1, -1), w_plus.reshape(1, -1)],
        axis=1,
    )
    return points, weight_row

def sigma_point_expand_with_mean_cov( mus, covs, weights ):
    '''
    Expand each of N parent sigma points into its own standard-UT cloud of 2n+1 children.

    input:
    mus: (n, N) - column i is the mean of parent i
    covs: (n, N) - column i holds the diagonal of parent i's covariance
    weights: (1, N) - weight of each parent

    output:
    new_points: (n, N*(2n+1)) - children of parent i occupy columns i*(2n+1) ... (i+1)*(2n+1)-1
    new_weights: (1, N*(2n+1)) - each child's UT weight multiplied by its parent's weight
    '''
    n, N = mus.shape
    num_children = 2*n + 1
    # Column i stores the column-major flattening of parent i's children, so the final
    # Fortran-order reshape lays the N clouds out side by side.
    new_points = jnp.zeros((n*num_children, N))
    new_weights = jnp.zeros((num_children, N))

    def body(i, inputs):
        new_points, new_weights = inputs
        mu = mus[:, i].reshape(-1, 1)
        cov = jnp.diag(covs[:, i])

        root_term = get_ut_cov_root_diagonal(cov)
        temp_points, temp_weights = generate_sigma_points_gaussian( mu, root_term, jnp.zeros((n,1)), 1.0 )

        new_points = new_points.at[:, i].set( temp_points.reshape(-1, order='F') )
        new_weights = new_weights.at[:, i].set( temp_weights.reshape(-1, order='F') * weights[0, i] )
        return new_points, new_weights

    # `lax` is not imported in this notebook; use the fully-qualified jax.lax.
    new_points, new_weights = jax.lax.fori_loop(0, N, body, (new_points, new_weights))
    # There are N*(2n+1) child weights (one per child point), not n*(2n+1); the two only
    # coincide when N == n.
    return (new_points.reshape((n, N*num_children), order='F'),
            new_weights.reshape((1, N*num_children), order='F'))

def sigma_point_compress( sigma_points, weights ):
    '''
    Summarise a weighted sigma-point cloud by its mean and diagonal covariance, then
    regenerate a fresh standard-UT set of 2n+1 points from those two moments.

    output: (new_points, new_weights) as returned by generate_sigma_points_gaussian
    '''
    mean, covariance = get_mean_cov( sigma_points, weights )
    root = get_ut_cov_root_diagonal( covariance )
    zeros = jnp.zeros(mean.shape)
    unit_factor = jnp.array([1.0])
    return generate_sigma_points_gaussian( mean, root, zeros, unit_factor )

def sigma_point_compress_GenUT( sigma_points, weights ):
    '''
    Summarise a weighted sigma-point cloud by its mean, diagonal covariance and normalised
    per-dimension skewness / kurtosis, then regenerate 2n+1 Generalized-UT points from them.

    output: (new_points, new_weights) as returned by generate_sigma_points_gaussian_GenUT
    '''
    mu, cov, skewness, kurt = get_mean_cov_skew_kurt_for_generation( sigma_points, weights )
    return generate_sigma_points_gaussian_GenUT(
        mu,
        get_ut_cov_root_diagonal( cov ),
        skewness,
        kurt,
        jnp.zeros(mu.shape),
        jnp.array([1.0]),
    )

In [26]:
key = jax.random.PRNGKey(0)

# Generate a 6x13 array of sigma points (one point per column)
sigma_points = jax.random.uniform(key, (6, 13))

# Generate a new key for the next random array.
# NOTE: `key` was already consumed by `uniform` above; JAX recommends splitting before use.
# Left unchanged so the recorded outputs below stay reproducible.
key, subkey = jax.random.split(key)

# Generate a 1x13 array of weights
weights = jax.random.uniform(subkey, (1, 13))
print("weights: ", weights)
print("weights[0]: ", weights[0])
mu = get_mean(sigma_points, weights)
print(mu)

# Reuse the same sigma_points / weights (regenerating them from the same seed, as this
# cell previously did, produced identical arrays).
mu, cov = get_mean_cov(sigma_points, weights)
print(mu)
print(cov)


root_cov = get_ut_cov_root_diagonal(cov)
print(root_cov)


weights:  [[0.9158251  0.21620357 0.26835215 0.6011201  0.437374   0.8539797
  0.7195103  0.13791871 0.3185042  0.7317047  0.9511024  0.3303691
  0.24733603]]
weights[0]:  [0.9158251  0.21620357 0.26835215 0.6011201  0.437374   0.8539797
 0.7195103  0.13791871 0.3185042  0.7317047  0.9511024  0.3303691
 0.24733603]
[[3.8098347]
 [2.6275427]
 [4.102255 ]
 [3.8639886]
 [4.2684965]
 [2.830641 ]]
[[3.8098347]
 [2.6275427]
 [4.102255 ]
 [3.8639886]
 [4.2684965]
 [2.830641 ]]
[[71.174065  0.        0.        0.        0.        0.      ]
 [ 0.       33.973057  0.        0.        0.        0.      ]
 [ 0.        0.       82.579926  0.        0.        0.      ]
 [ 0.        0.        0.       73.28244   0.        0.      ]
 [ 0.        0.        0.        0.       89.51134   0.      ]
 [ 0.        0.        0.        0.        0.       39.613228]]
[[8.436472 0.       0.       0.       0.       0.      ]
 [0.       5.828641 0.       0.       0.       0.      ]
 [0.       0.       9.08735  0. 

In [22]:

# NOTE: `weights` here are raw uniform samples that do not sum to 1, so these are not the
# moments of a normalised distribution (e.g. the "kurtosis" comes out below 1).
mu, cov, skew, kurt = get_mean_cov_skew_kurt_for_generation(sigma_points, weights)
for moment in (mu, cov, skew, kurt):
    print(moment)

[[3.8098347]
 [2.6275427]
 [4.102255 ]
 [3.8639886]
 [4.2684965]
 [2.830641 ]]
[[71.174065  0.        0.        0.        0.        0.      ]
 [ 0.       33.973057  0.        0.        0.        0.      ]
 [ 0.        0.       82.579926  0.        0.        0.      ]
 [ 0.        0.        0.       73.28244   0.        0.      ]
 [ 0.        0.        0.        0.       89.51134   0.      ]
 [ 0.        0.        0.        0.        0.       39.613228]]
[[-0.3885442 ]
 [-0.39047277]
 [-0.389005  ]
 [-0.3891043 ]
 [-0.3896621 ]
 [-0.39276117]]
[[0.15176381]
 [0.1537108 ]
 [0.15225746]
 [0.1523434 ]
 [0.1529323 ]
 [0.15589994]]


In [28]:
import numpy as np
def generate_state_vector(key, n):
    '''Draw an (n, 1) column vector of standard-normal samples using the PRNG `key`.'''
    shape = (n, 1)
    return jax.random.normal(key, shape)

n = 6  # Size of the state vector
key = jax.random.PRNGKey(0)  # Initialize the random key
state_vector = generate_state_vector(key, n)  # (n, 1) standard-normal draw
def initialize_sigma_points(X):
    '''
    Returns 2n+1 equally weighted sigma particles, all placed at X.

    input:
    X: (n, 1) column vector (assumed; np.repeat repeats every column of X)

    output:
    sigma_points: (n, 2n+1) - X repeated in every column
    weights: (1, 2n+1) numpy array, each entry 1/(2n+1)
    '''
    count = 2 * X.shape[0] + 1
    replicated = np.repeat(X, count, axis=1)
    uniform_weights = np.full((1, count), 1.0 / count)
    return replicated, uniform_weights
initialize_sigma_points(state_vector)  # display the (6, 13) repeated points and the 1/13 weights

(Array([[ 0.18784384,  0.18784384,  0.18784384,  0.18784384,  0.18784384,
          0.18784384,  0.18784384,  0.18784384,  0.18784384,  0.18784384,
          0.18784384,  0.18784384,  0.18784384],
        [-1.2833426 , -1.2833426 , -1.2833426 , -1.2833426 , -1.2833426 ,
         -1.2833426 , -1.2833426 , -1.2833426 , -1.2833426 , -1.2833426 ,
         -1.2833426 , -1.2833426 , -1.2833426 ],
        [ 0.6494181 ,  0.6494181 ,  0.6494181 ,  0.6494181 ,  0.6494181 ,
          0.6494181 ,  0.6494181 ,  0.6494181 ,  0.6494181 ,  0.6494181 ,
          0.6494181 ,  0.6494181 ,  0.6494181 ],
        [ 1.2490594 ,  1.2490594 ,  1.2490594 ,  1.2490594 ,  1.2490594 ,
          1.2490594 ,  1.2490594 ,  1.2490594 ,  1.2490594 ,  1.2490594 ,
          1.2490594 ,  1.2490594 ,  1.2490594 ],
        [ 0.24447003,  0.24447003,  0.24447003,  0.24447003,  0.24447003,
          0.24447003,  0.24447003,  0.24447003,  0.24447003,  0.24447003,
          0.24447003,  0.24447003,  0.24447003],
        [-0.117

In [29]:

# Circle trajectories: parallel lists, index i of each list describes circle variant i
# (state_ref below uses index 0).
cir_radius = [0.3, 0.3, 0.4, 0.4, 0.4, 0.4]
cir_angular_vel = [1.5, 1.5, 2.0, 2.5, 3.0, 3.0]
###### in world frame NOT NED ######
cir_origin_x = [0.0, 0.4, 0.4, 0.6, 0.8, 1.0]
cir_origin_y = [0.4, 0.0, 0.0, 0.0, 0.0, 0.0]

# Figure-8 trajectories: parallel lists, index i of each list describes figure-8 variant i.
figure8_radius = [0.2, 0.2, 0.2, 0.4, 0.4, 0.4, 0.4, 0.4, 0.6]
figure8_angular_vel = [1.5, 2.0, 2.5, 1.5, 1.5, 2.0, 2.0, 2.5, 1.5]
figure8_origin_x = [0.8, 1.2, 1.2, 1.2, 1.0, 0.4, 1.0, 1.0, 1.0]
figure8_origin_y = [0.0] * 9
def circle_pos_vel_acc(deltaT, radius, angular_vel, origin_x, origin_y):
    '''
    Calculate reference pos, vel, and acc for the drone flying in a circular trajectory in NED frame.

    World-frame parameterisation (w = angular_vel, t = deltaT):
        x(t) = radius*cos(w t) + origin_x
        y(t) = radius*sin(w t) + origin_y
    Each returned 3-vector is ordered [y, x, z]: position uses a fixed z = -0.4, and the
    vertical velocity and acceleration are 0.

    output:
    ref_pos, ref_vel, ref_acc: jnp arrays of shape (3,) for scalar inputs
    '''
    phase = angular_vel * deltaT
    cos_p = jnp.cos(phase)
    sin_p = jnp.sin(phase)
    speed = radius * angular_vel          # tangential speed r*w
    centripetal = radius * (angular_vel**2)  # centripetal acceleration magnitude r*w^2

    ref_pos = jnp.array([radius * sin_p + origin_y, radius * cos_p + origin_x, -0.4])
    ref_vel = jnp.array([speed * cos_p, -speed * sin_p, 0])
    ref_acc = jnp.array([-centripetal * sin_p, -centripetal * cos_p, 0])
    return ref_pos, ref_vel, ref_acc



def figure8_pos_vel_acc(deltaT, radius, angular_vel, origin_x, origin_y):
    '''
    Calculate reference pos, vel, and acc for the drone flying in figure8 trajectory in NED frame.

    World-frame parameterisation (w = angular_vel, t = deltaT):
        x(t) = radius*sin(w t) + origin_x
        y(t) = radius*sin(w t)*cos(w t) + origin_y
    Each returned 3-vector is ordered [y, x, z]: position uses a fixed z = -0.4, and the
    vertical velocity and acceleration are 0.

    output:
    pos, vel, acc: jnp arrays of identical shape ((3,) for scalar inputs)
    '''
    phase = angular_vel * deltaT
    sin_p = jnp.sin(phase)
    cos_p = jnp.cos(phase)

    pos = jnp.array([
        radius * sin_p * cos_p + origin_y,
        radius * sin_p + origin_x,
        -0.4,
    ])
    vel = jnp.array([
        radius * angular_vel * (cos_p**2 - sin_p**2),
        radius * angular_vel * cos_p,
        0,
    ])
    acc = jnp.array([
        -radius * 4 * (angular_vel**2) * sin_p * cos_p,
        -radius * (angular_vel**2) * sin_p,
        0,
    ])
    assert pos.shape == vel.shape == acc.shape, "output shapes are different"
    return pos, vel, acc

def state_ref(t):
    '''
    Reference (pos, vel, acc) at time t, each as a (3, 1) column vector, taken from the
    first circle trajectory (index 0 of the cir_* lists).
    '''
    idx = 0
    trajectory = circle_pos_vel_acc(
        t, cir_radius[idx], cir_angular_vel[idx], cir_origin_x[idx], cir_origin_y[idx]
    )
    return tuple(component.reshape(-1, 1) for component in trajectory)
def policy( t, states, policy_params):
    '''
    Position/velocity feedback with acceleration feed-forward and gravity compensation.

    Expects multiple states as input, one state per column, and returns one control
    input (column) per state.

    input:
    t: time at which the reference trajectory (state_ref) is evaluated
    states: (6, K) - rows 0:3 position, rows 3:6 velocity (presumably NED, matching state_ref)
    policy_params: policy_params[0] = position gain kx, policy_params[1] = velocity gain kv
        (earlier hard-coded values: kx = 14, kv = 7.4)

    output:
    thrust / m: (3, K) commanded force divided by the mass
    pos_ref, vel_ref: (3, 1) reference position and velocity at time t
    '''
    m = 0.641 #kg
    g = 9.8066

    kx = policy_params[0]
    kv = policy_params[1]

    pos_ref, vel_ref, acc_ref = state_ref(t)

    ex = states[0:3] - pos_ref
    ev = states[3:6] - vel_ref
    # Gravity acts only along the NED z (down) axis, so it is compensated on that row only.
    # Subtracting the bare scalar m*g would broadcast it onto the x and y rows as well.
    e3 = jnp.array([[0.0], [0.0], [1.0]])
    thrust = - kx * ex - kv * ev + m * acc_ref - m * g * e3
    return thrust / m, pos_ref, vel_ref

def generate_time(start_time, end_time, step_size):
    '''
    Generate a time vector to simulate trajectories:
    start_time, start_time + step_size, ... up to and including end_time when end_time
    lies on the grid (never past it).

    The number of samples is computed explicitly rather than via
    arange(start, end + step, step): with a floating-point step that call can produce an
    extra sample beyond end_time (e.g. arange(1, 1.2 + 0.1, 0.1) also contains 1.3).

    input:
    start_time, end_time: first and last time (same units as step_size)
    step_size: non-zero spacing between samples

    output:
    1-D jnp array of sample times (empty when end_time lies before start_time
    in the direction of step_size)

    raises:
    ValueError: if step_size is zero
    '''
    import numpy as np

    if step_size == 0:
        raise ValueError("step_size must be non-zero")
    # The small tolerance absorbs round-off in the division (e.g. 0.3 / 0.1 = 2.9999999999999996).
    num_intervals = int(np.floor(float(end_time - start_time) / float(step_size) + 1e-9))
    num_points = max(num_intervals + 1, 0)
    # Build the grid in numpy, then convert to a jnp array.
    return jnp.asarray(start_time + step_size * np.arange(num_points))

time_array = generate_time(start_time=0, end_time=5, step_size=0.1)  # 0.0, 0.1, ..., 5.0
