In [3]:
import jax.numpy as jnp

def conditional_distribution(joint_mean, joint_covariance, x, dimensions):
    """Condition a multivariate Gaussian on observed values of some of its dimensions.

    The joint Gaussian N(mu, Sigma) is split into an observed part X (the indices
    listed in ``dimensions``) and the remaining part Y.  Returns the parameters
    of p(Y | X = x):

        mu_{Y|X}    = mu_Y + Sigma_YX Sigma_XX^{-1} (x - mu_X)
        Sigma_{Y|X} = Sigma_YY - Sigma_YX Sigma_XX^{-1} Sigma_XY

    Args:
        joint_mean: length-D mean vector of the joint distribution (array or sequence).
        joint_covariance: (D, D) covariance matrix of the joint distribution.
            Sigma_XX (the block selected by ``dimensions``) must be non-singular.
        x: observed values of the conditioning dimensions, in the same order as
            ``dimensions`` (array or sequence).
        dimensions: zero-based integer indices of the dimensions conditioned on
            (list, tuple, or integer array).

    Returns:
        Tuple ``(mean_y_given_x, cov_y_given_x)``: the conditional mean and
        covariance of the remaining dimensions, ordered by increasing index.
    """
    joint_mean = jnp.asarray(joint_mean)
    joint_covariance = jnp.asarray(joint_covariance)

    # Every index not conditioned on belongs to Y, kept in increasing order.
    observed = {int(d) for d in dimensions}
    y_dimensions = [i for i in range(len(joint_mean)) if i not in observed]

    # Explicit integer dtype: an empty Python list would otherwise become a
    # float array, which jnp.ix_ rejects as an index.
    x_idx = jnp.asarray(dimensions, dtype=jnp.int32)
    y_idx = jnp.asarray(y_dimensions, dtype=jnp.int32)

    mean_x = jnp.take(joint_mean, x_idx)
    mean_y = jnp.take(joint_mean, y_idx)

    # jnp.ix_ builds an open mesh, so each line selects a rectangular sub-block.
    cov_XX = joint_covariance[jnp.ix_(x_idx, x_idx)]
    cov_YY = joint_covariance[jnp.ix_(y_idx, y_idx)]
    cov_YX = joint_covariance[jnp.ix_(y_idx, x_idx)]
    cov_XY = joint_covariance[jnp.ix_(x_idx, y_idx)]

    # Gain K = Sigma_YX Sigma_XX^{-1}, obtained from one linear solve of
    # Sigma_XX^T K^T = Sigma_YX^T instead of forming inv(Sigma_XX) twice.
    gain = jnp.linalg.solve(cov_XX.T, cov_YX.T).T

    mean_y_given_x = mean_y + gain @ (jnp.asarray(x) - mean_x)
    cov_y_given_x = cov_YY - gain @ cov_XY

    return mean_y_given_x, cov_y_given_x

# Mean of the 19-dimensional joint Gaussian over (X, Y).
XY_mean = jnp.array([1000., 0.1, 5.2, 400., 0.7,
                     0.3, 3.0, 0.25, -0.1, 0.5,
                     1500., 0.08, 6.1, 0.8, 0.3,
                     3.0, 0.2, -0.1, 0.5])

# Joint covariance. Despite the name, XY_sigma holds variances: each diagonal
# entry is a per-dimension standard deviation squared, and every dimension
# starts out independent of the others.
XY_sigma = jnp.diag(jnp.array([1.0 ** 2, 0.02 ** 2, 1.0 ** 2, 200 ** 2, 0.1 ** 2,
                               0.1 ** 2, 0.5 ** 2, 0.1 ** 2, 0.02 ** 2, 0.2 ** 2,
                               1.0 ** 2, 0.02 ** 2, 1.0 ** 2, 0.1 ** 2, 0.05 ** 2,
                               1.0 ** 2, 0.05 ** 2, 0.02 ** 2, 0.2 ** 2]))

# Zero-based indices of the dimensions that are pairwise correlated, and the
# correlation coefficient they all share.
CORRELATED_DIMS = (4, 6, 13, 15)
CORRELATION = 0.6


def _add_pairwise_correlation(cov, dims, rho):
    """Return a new array equal to ``cov`` but with Cov[i, j] = Cov[j, i] = rho * sd_i * sd_j
    for every unordered pair i != j drawn from ``dims``.

    sd_i is read from the diagonal of ``cov``. Only off-diagonal entries are
    written, so the order of the updates does not affect the result.
    """
    for pos, i in enumerate(dims):
        for j in dims[pos + 1:]:
            cov_ij = rho * jnp.sqrt(cov[i, i]) * jnp.sqrt(cov[j, j])
            # Write both halves from the same value so the matrix stays exactly symmetric.
            cov = cov.at[i, j].set(cov_ij).at[j, i].set(cov_ij)
    return cov


XY_sigma = _add_pairwise_correlation(XY_sigma, CORRELATED_DIMS, CORRELATION)
assert bool(jnp.array_equal(XY_sigma, XY_sigma.T)), "XY_sigma must be symmetric"

# Condition on zero-based indices 4 and 13 (the 5th and 14th dimensions when
# counting from 1), observing the value 0.5 for each of them.
dimensions = [4, 13]
x = jnp.array([0.5, 0.5])

mean_y_given_x, cov_y_given_x = conditional_distribution(
    joint_mean=XY_mean,
    joint_covariance=XY_sigma,
    x=x,
    dimensions=dimensions,
)


In [4]:
# Re-derive the conditional distribution inline (the same computation as
# `conditional_distribution`) so the intermediate blocks (mean_x, cov_XX,
# cov_YX, ...) remain available as notebook globals for inspection.
joint_covariance = XY_sigma
joint_mean = XY_mean

# Every index not conditioned on belongs to Y, kept in increasing order.
y_dimensions = [i for i in range(len(joint_mean)) if i not in dimensions]

dimensions = jnp.array(dimensions)
y_dimensions = jnp.array(y_dimensions)

mean_x = jnp.take(joint_mean, dimensions)
mean_y = jnp.take(joint_mean, y_dimensions)

# jnp.ix_ builds an open mesh, so each line selects a rectangular sub-block.
cov_XX = joint_covariance[jnp.ix_(dimensions, dimensions)]
cov_YY = joint_covariance[jnp.ix_(y_dimensions, y_dimensions)]
cov_YX = joint_covariance[jnp.ix_(y_dimensions, dimensions)]
cov_XY = joint_covariance[jnp.ix_(dimensions, y_dimensions)]

# Gain K = Sigma_YX Sigma_XX^{-1}, obtained from one linear solve of
# Sigma_XX^T K^T = Sigma_YX^T instead of forming inv(cov_XX) twice.
gain = jnp.linalg.solve(cov_XX.T, cov_YX.T).T

mean_y_given_x = mean_y + gain @ (x - mean_x)
cov_y_given_x = cov_YY - gain @ cov_XY


In [8]:
import jax

# Fixed PRNG seed so the random draw below is reproducible across runs.
seed = 10
rng_key = jax.random.PRNGKey(seed)

# A: the integers 0..9.  B: ten integers drawn uniformly from [0, 10).
A = jnp.arange(0, 10)
B = jax.random.randint(rng_key, (10,), 0, 10)

In [9]:
# Display A, the 0..9 integer array built with jnp.arange(10).
A

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [10]:
# Display B, the ten random integers drawn with jax.random.randint from the seeded key.
B

Array([4, 6, 1, 6, 2, 9, 6, 6, 8, 6], dtype=int32)

In [11]:
# Element-wise maximum of A and B: stack them as two rows and take the max down each column.
jnp.stack([A, B]).max(axis=0)

Array([4, 6, 2, 6, 4, 9, 6, 7, 8, 9], dtype=int32)