In [13]:
import jax
import jax.numpy as jnp

In [14]:
# Sanity check of joining three (1, k) row arrays side by side along axis 1.
arr1 = jnp.array([[1.0]])
arr2 = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
arr3 = jnp.array([[6.0, 7.0, 8.0, 9.0, 10.0]])
arr = jnp.hstack([arr1, arr2, arr3])  # for 2-D inputs this is concatenate(..., axis=1) -> (1, 11)
for shown in (arr[0], arr.shape, arr):
    print(shown)

[ 1.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]
(1, 11)
[[ 1.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]]


In [15]:
def get_mean( sigma_points, weights ):
    """
    Weighted mean of a set of sigma points.

    input:
    sigma_points: (n, N) -- one point per column
    weights: (1, N) -- only row 0 is read

    output:
    (n, 1) column vector sum_j weights[0, j] * sigma_points[:, j]
    (a true average only if the weights sum to 1)
    """
    point_weights = weights[0]
    row_totals = jnp.sum(sigma_points * point_weights, axis=1)
    return row_totals.reshape(-1, 1)


In [16]:
def get_mean_cov(sigma_points, weights):
    """
    Weighted mean and DIAGONAL covariance of a set of sigma points.

    Only per-dimension variances are computed; cross-covariances are left at 0.

    input:
    sigma_points: (n, N) -- one point per column
    weights: (1, N) -- only row 0 is read

    output:
    mu: (n, 1) weighted sum of the points
    cov: (n, n) diagonal matrix of weighted squared deviations from mu
    """
    w = weights[0]
    mu = jnp.sum(sigma_points * w, 1).reshape(-1, 1)

    deviations = sigma_points - mu
    variances = jnp.sum(deviations**2 * w, axis=1)
    return mu, jnp.diag(variances)


In [17]:
# NOTE(review): get_ut_cov_root_diagonal is defined again in the next cell (In [18]),
# and that later definition shadows this one. Consider keeping a single definition.
def get_ut_cov_root_diagonal(cov):
    """Return diag(sqrt(diag(cov) + jitter)); off-diagonal entries of cov are ignored."""
    variances = jnp.diagonal(cov)
    jitter = 0.000  # TODO: make sure not zero here
    return jnp.diag(jnp.sqrt(variances + jitter))


In [18]:
def get_ut_cov_root_diagonal(cov, offset=0.0):
    """
    Diagonal square-root factor of a covariance, used as the `cov_root` argument
    of the sigma-point generators in this cell.

    Only the diagonal of `cov` is read: the result is diag(sqrt(diag(cov) + offset)),
    so any off-diagonal (correlation) terms are ignored.

    input:
    cov: (n, n) covariance (only its diagonal is used)
    offset: jitter added to every variance before the square root. The default 0.0
            reproduces the previous hard-coded behaviour; pass a small positive value
            to keep the root away from zero for (near-)degenerate dimensions.

    output:
    (n, n) diagonal matrix of standard deviations
    """
    return jnp.diag(jnp.sqrt(jnp.diagonal(cov) + offset))
def get_mean_cov_skew_kurt( sigma_points, weights ):
    """
    Weighted mean, diagonal covariance, and RAW 3rd/4th central moments per dimension.

    The 3rd and 4th moments are not standardised: they are not divided by
    sigma**3 / sigma**4 and no "-3" excess correction is applied (divide by
    cov[i,i]**(3/2) and cov[i,i]**2 to compare with scipy's skew/kurtosis).

    input:
    sigma_points: (n, N) -- one point per column
    weights: (1, N) -- only row 0 is read

    output:
    mu: (n, 1), cov: (n, n) diagonal, skewness: (n, 1) raw 3rd moment, kurt: (n, 1) raw 4th moment
    """
    w = weights[0]
    mu = jnp.sum(sigma_points * w, 1).reshape(-1, 1)
    dev = sigma_points - mu
    cov = jnp.diag(jnp.sum(dev**2 * w, axis=1))

    third_moment, fourth_moment = (
        jnp.sum(dev**power * w, axis=1).reshape(-1, 1) for power in (3, 4)
    )
    return mu, cov, third_moment, fourth_moment

def get_mean_cov_skew_kurt_for_generation( sigma_points, weights ):
    """
    Weighted mean, diagonal covariance, and STANDARDISED skewness/kurtosis per dimension.

    skewness_i = m3_i / var_i**1.5 and kurt_i = m4_i / var_i**2, where m3/m4 are the
    weighted 3rd/4th central moments. No "-3" excess correction is applied. These are
    the per-dimension statistics passed to generate_sigma_points_gaussian_GenUT by
    sigma_point_compress_GenUT. A dimension with zero variance divides by zero (nan/inf).

    input:
    sigma_points: (n, N) -- one point per column
    weights: (1, N) -- only row 0 is read

    output:
    mu: (n, 1), cov: (n, n) diagonal, skewness: (n, 1), kurt: (n, 1)
    """
    w = weights[0]
    mu = jnp.sum(sigma_points * w, 1).reshape(-1, 1)
    dev = sigma_points - mu
    variances = jnp.sum(dev**2 * w, axis=1)
    cov = jnp.diag(variances)

    m3 = jnp.sum(dev**3 * w, axis=1)
    m4 = jnp.sum(dev**4 * w, axis=1)
    skewness = m3 / variances**1.5
    kurt = m4 / variances**2.0

    return mu, cov, skewness.reshape(-1, 1), kurt.reshape(-1, 1)

def generate_sigma_points_gaussian( mu, cov_root, base_term, factor, alpha=1.0, kappa=1.0 ):
    '''
    Standard UT: generate 2n+1 sigma points around an n-dimensional mean.

    With Lambda = alpha**2 * (n + kappa) - n and s = sqrt(n + Lambda), the points are
    base_term + factor * [mu, mu + s*cov_root[:, j], mu - s*cov_root[:, j]] for j = 0..n-1.

    input:
    mu: (n,1) mean
    cov_root: (n,n) square-root factor of the covariance; column j spreads points along it
    base_term: (n,1) or scalar, added to every point after scaling by `factor`
    factor: scalar or (1,) array multiplying (mu +/- offset)
    alpha: spread parameter (default 1.0, the previously hard-coded value)
    kappa: secondary scaling parameter (default 1.0, the previously hard-coded value)

    output:
    new_points: (n, 2n+1)
    new_weights: (1, 2n+1) mean weights [Lambda/(n+Lambda), 1/(2(n+Lambda)) x 2n]; they sum to 1
    '''
    n = mu.shape[0]
    Lambda = alpha**2 * ( n + kappa ) - n
    spread = jnp.sqrt(n + Lambda)

    points0 = base_term + mu * factor
    # mu (n,1) broadcasts against cov_root (n,n): column j is mu +/- spread * cov_root[:, j]
    points1 = base_term + (mu + spread * cov_root) * factor
    points2 = base_term + (mu - spread * cov_root) * factor

    weights0 = jnp.array([[ 1.0 * Lambda / (n + Lambda) ]])
    weights_side = jnp.ones((1, n)) * 1.0 / (n + Lambda) / 2.0

    new_points = jnp.concatenate((points0, points1, points2), axis=1)
    new_weights = jnp.concatenate((weights0, weights_side, weights_side), axis=1)

    return new_points, new_weights

def generate_sigma_points_gaussian_GenUT( mu, cov_root, skewness, kurt, base_term, factor ):
    '''
    Generalized UT (GenUT): 2n+1 sigma points built from per-dimension skewness and kurtosis.

    For each dimension i:
        u_i = (-S_i + sqrt(4 K_i - 3 S_i**2)) / 2,   v_i = u_i + S_i
        w2_i = 1 / (v_i (u_i + v_i)),   w1_i = w2_i v_i / u_i,   w0 = 1 - sum(w1) - sum(w2)
    and the points are base_term + factor * [mu, mu - u_i*cov_root[:, i], mu + v_i*cov_root[:, i]].
    Note w0 can be negative, e.g. S=0, K=3 gives w1_i = w2_i = 1/6 and w0 = 1 - n/3.

    input:
    mu: (n,1) mean
    cov_root: (n,n) square-root factor of the covariance
    skewness: (n,1) or (n,) standardised skewness S
    kurt: (n,1) or (n,) standardised kurtosis K; 4K - 3S**2 must be positive, otherwise u is nan
    base_term: (n,1) or scalar, added to every point after scaling by `factor`
    factor: scalar or (1,) array multiplying (mu +/- offset)

    output:
    new_points: (n, 2n+1)
    new_weights: (1, 2n+1) ordered [w0, w1_0..w1_{n-1}, w2_0..w2_{n-1}]
    '''
    # Accept both column vectors and flat vectors for the higher moments.
    skewness = jnp.reshape(skewness, (-1, 1))
    kurt = jnp.reshape(kurt, (-1, 1))

    u = 0.5 * ( - skewness + jnp.sqrt( 4 * kurt - 3 * ( skewness )**2 ) )
    v = u + skewness

    w2 = (1.0 / v) / (u + v)
    w1 = (w2 * v) / u
    w0 = 1 - jnp.sum(w1) - jnp.sum(w2)

    # cov_root @ diag(u) scales column i of cov_root by u_i
    points0 = base_term + mu * factor
    points1 = base_term + (mu - cov_root @ jnp.diag(u[:, 0])) * factor
    points2 = base_term + (mu + cov_root @ jnp.diag(v[:, 0])) * factor
    new_points = jnp.concatenate( (points0, points1, points2), axis=1 )
    new_weights = jnp.concatenate( (jnp.reshape(w0, (1, 1)), w1.reshape(1, -1), w2.reshape(1, -1)), axis=1 )

    return new_points, new_weights

def sigma_point_expand_with_mean_cov( mus, covs, weights ):
    '''
    Expand every (mean, diagonal covariance) column into its own 2n+1 standard-UT sigma points.

    input:
    mus: (n, N) -- column i is the mean of component i
    covs: (n, N) -- column i is used as the diagonal (variances) of component i's covariance
    weights: (1, N) -- weight of component i; it multiplies all of that component's child
             weights (presumably the component weights sum to 1 -- TODO confirm with callers)

    output:
    new_points: (n, N*(2n+1)) -- the (n, 2n+1) blocks of components 0..N-1, side by side
    new_weights: (1, N*(2n+1)) -- matching weights in the same order
    '''
    n, N = mus.shape  # e.g. n=6 dimensions, N=13 components
    n_sigma = 2 * n + 1
    zero_base = jnp.zeros((n, 1))

    def body(i, carry):
        points_acc, weights_acc = carry
        mu = mus[:, i].reshape(-1, 1)
        cov = jnp.diag(covs[:, i])

        root_term = get_ut_cov_root_diagonal(cov)
        temp_points, temp_weights = generate_sigma_points_gaussian( mu, root_term, zero_base, 1.0 )

        # Store each (n, 2n+1) block column-major so the final F-order reshape lays the
        # blocks side by side in component order.
        points_acc = points_acc.at[:, i].set( temp_points.reshape(-1, order='F') )
        weights_acc = weights_acc.at[:, i].set( temp_weights.reshape(-1, order='F') * weights[:, i] )
        return points_acc, weights_acc

    init = (jnp.zeros((n * n_sigma, N)), jnp.zeros((n_sigma, N)))
    # `lax` is not imported in the visible notebook cells, so use the fully-qualified name.
    new_points, new_weights = jax.lax.fori_loop(0, N, body, init)
    # N components x (2n+1) children each: the weight row has N*(2n+1) entries
    # (reshaping to n*(2n+1) only worked by accident when N == n).
    return (new_points.reshape((n, N * n_sigma), order='F'),
            new_weights.reshape((1, N * n_sigma), order='F'))

def sigma_point_compress( sigma_points, weights ):
    """
    Re-summarise a weighted point cloud as a fresh set of 2n+1 standard-UT sigma points.

    The cloud's weighted mean and diagonal covariance (get_mean_cov) are re-expanded with
    generate_sigma_points_gaussian, using a zero base term and a unit factor.

    input:
    sigma_points: (n, N), weights: (1, N)

    output:
    (new_points (n, 2n+1), new_weights (1, 2n+1))
    """
    mean, cov = get_mean_cov(sigma_points, weights)
    return generate_sigma_points_gaussian(
        mean,
        get_ut_cov_root_diagonal(cov),
        jnp.zeros(mean.shape),
        jnp.array([1.0]),
    )

def sigma_point_compress_GenUT( sigma_points, weights ):
    """
    Re-summarise a weighted point cloud as 2n+1 GenUT sigma points.

    Like sigma_point_compress, but the cloud's per-dimension standardised skewness and
    kurtosis (get_mean_cov_skew_kurt_for_generation) are also passed to the GenUT
    generator, with a zero base term and a unit factor.

    input:
    sigma_points: (n, N), weights: (1, N)

    output:
    (new_points (n, 2n+1), new_weights (1, 2n+1))
    """
    mean, cov, skewness, kurt = get_mean_cov_skew_kurt_for_generation(sigma_points, weights)
    return generate_sigma_points_gaussian_GenUT(
        mean,
        get_ut_cov_root_diagonal(cov),
        skewness,
        kurt,
        jnp.zeros(mean.shape),
        jnp.array([1.0]),
    )

In [26]:
# Demo: random (6, 13) sigma points and (1, 13) weights.
# NOTE(review): the weights are uniform draws and are NOT normalised to sum to 1, so `mu`
# below is a weighted sum rather than a weighted average -- which is why it lies well outside
# the [0, 1) range of the points and the "variances" are large.
key = jax.random.PRNGKey(0)
sigma_points = jax.random.uniform(key, (6, 13))

# NOTE(review): `key` was already consumed by the draw above; JAX recommends splitting a key
# before using it. Kept as-is so the recorded outputs below stay reproducible.
key, subkey = jax.random.split(key)
weights = jax.random.uniform(subkey, (1, 13))
print("weights: ", weights)
print("weights[0]: ", weights[0])

mu = get_mean(sigma_points, weights)
print(mu)

# A second, identical regeneration of key / sigma_points / weights (same PRNGKey(0)) was
# removed here; the inputs to the calls below are unchanged.
mu, cov = get_mean_cov(sigma_points, weights)
print(mu)
print(cov)

root_cov = get_ut_cov_root_diagonal(cov)
print(root_cov)


weights:  [[0.9158251  0.21620357 0.26835215 0.6011201  0.437374   0.8539797
  0.7195103  0.13791871 0.3185042  0.7317047  0.9511024  0.3303691
  0.24733603]]
weights[0]:  [0.9158251  0.21620357 0.26835215 0.6011201  0.437374   0.8539797
 0.7195103  0.13791871 0.3185042  0.7317047  0.9511024  0.3303691
 0.24733603]
[[3.8098347]
 [2.6275427]
 [4.102255 ]
 [3.8639886]
 [4.2684965]
 [2.830641 ]]
[[3.8098347]
 [2.6275427]
 [4.102255 ]
 [3.8639886]
 [4.2684965]
 [2.830641 ]]
[[71.174065  0.        0.        0.        0.        0.      ]
 [ 0.       33.973057  0.        0.        0.        0.      ]
 [ 0.        0.       82.579926  0.        0.        0.      ]
 [ 0.        0.        0.       73.28244   0.        0.      ]
 [ 0.        0.        0.        0.       89.51134   0.      ]
 [ 0.        0.        0.        0.        0.       39.613228]]
[[8.436472 0.       0.       0.       0.       0.      ]
 [0.       5.828641 0.       0.       0.       0.      ]
 [0.       0.       9.08735  0. 

In [22]:

# Standardised higher moments of the demo cloud defined in the cell above.
mu, cov, skew, kurt = get_mean_cov_skew_kurt_for_generation(sigma_points, weights)
for stat in (mu, cov, skew, kurt):
    print(stat)

[[3.8098347]
 [2.6275427]
 [4.102255 ]
 [3.8639886]
 [4.2684965]
 [2.830641 ]]
[[71.174065  0.        0.        0.        0.        0.      ]
 [ 0.       33.973057  0.        0.        0.        0.      ]
 [ 0.        0.       82.579926  0.        0.        0.      ]
 [ 0.        0.        0.       73.28244   0.        0.      ]
 [ 0.        0.        0.        0.       89.51134   0.      ]
 [ 0.        0.        0.        0.        0.       39.613228]]
[[-0.3885442 ]
 [-0.39047277]
 [-0.389005  ]
 [-0.3891043 ]
 [-0.3896621 ]
 [-0.39276117]]
[[0.15176381]
 [0.1537108 ]
 [0.15225746]
 [0.1523434 ]
 [0.1529323 ]
 [0.15589994]]


In [None]:
# Leftover inspection cell: evaluating the bare name only displays the function object's repr.
# TODO(review): delete, or replace with help(generate_sigma_points_gaussian) to show its docstring.
generate_sigma_points_gaussian