In [96]:
rng_key = jax.random.PRNGKey(0)
A = jax.random.normal(rng_key, shape=(3, 2))

B = jnp.einsum('ij,jk->ijk', A, jnp.eye(2))

print(A)

[[ 0.18784384 -1.2833426 ]
 [ 0.6494181   1.2490594 ]
 [ 0.24447003 -0.11744965]]


In [97]:
print(B)

[[[ 0.18784384  0.        ]
  [-0.         -1.2833426 ]]

 [[ 0.6494181   0.        ]
  [ 0.          1.2490594 ]]

 [[ 0.24447003  0.        ]
  [-0.         -0.11744965]]]


In [99]:
B[0, :, :]

Array([[ 0.18784384,  0.        ],
       [-0.        , -1.2833426 ]], dtype=float32)

In [83]:
rng_key = jax.random.PRNGKey(0)
A = jax.random.normal(rng_key, shape=(3, 3, 2, 2))
print(A)

[[[[-0.19596423  0.254582  ]
   [ 0.573146    0.44064867]]

  [[-0.847186    0.31778416]
   [ 0.64643025  0.03368271]]

  [[-0.88885146 -0.26579142]
   [-1.5609733  -0.63806945]]]


 [[[-0.44171792  0.9098043 ]
   [-0.01651609  0.7582043 ]]

  [[ 1.0892068  -0.8457174 ]
   [ 1.490981    0.07877276]]

  [[-1.222362    0.951683  ]
   [ 0.21038245  1.3863313 ]]]


 [[[-0.33807516  2.9521072 ]
   [-0.99476653 -0.515637  ]]

  [[ 0.2918976  -0.14347766]
   [ 1.614347    1.6433928 ]]

  [[ 0.11422886  0.25447085]
   [-1.3060236  -2.473075  ]]]]


In [84]:
A[:, :, 0, 0]

Array([[-0.19596423, -0.847186  , -0.88885146],
       [-0.44171792,  1.0892068 , -1.222362  ],
       [-0.33807516,  0.2918976 ,  0.11422886]], dtype=float32)

In [85]:
jnp.diagonal(A, axis1=0, axis2=1).shape

(2, 2, 3)

In [86]:
jnp.diagonal(A, axis1=0, axis2=1).transpose((2, 0, 1))

Array([[[-0.19596423,  0.254582  ],
        [ 0.573146  ,  0.44064867]],

       [[ 1.0892068 , -0.8457174 ],
        [ 1.490981  ,  0.07877276]],

       [[ 0.11422886,  0.25447085],
        [-1.3060236 , -2.473075  ]]], dtype=float32)

In [87]:
jnp.diagonal(A, axis1=0, axis2=1).transpose((2, 0, 1))[:, 0, 0]

Array([-0.19596423,  1.0892068 ,  0.11422886], dtype=float32)

(2, 4)

In [24]:
A + B

Array([[[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]],

       [[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]],

       [[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]]], dtype=float32)

In [28]:
jnp.linalg.det(B).shape

(3,)

In [120]:
import jax.numpy as jnp

# Example usage:
N1 = 3
N2 = 4
N3 = 2
D = 6

rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(N1, N2, D))

rng_key, _ = jax.random.split(rng_key)
mean = jax.random.normal(rng_key, shape=(N3, D))

rng_key, _ = jax.random.split(rng_key)
covariance = jax.random.uniform(rng_key, shape=(N3, D))
covariance = jnp.stack([jnp.diag(c) for c in covariance.reshape(-1, D)], axis=0).reshape(N3, D, D)


x_expanded = jnp.expand_dims(x, 2)
mean_expanded = jnp.expand_dims(mean, (0, 1))
# covariance_expanded = jnp.expand_dims(covariance, 0)

diff = x_expanded - mean_expanded
precision_matrix = jnp.linalg.inv(covariance)
exponent = -0.5 * jnp.einsum('nijk,jkl,nijl->nij', diff, precision_matrix, diff)
normalization = -0.5 * (D * jnp.log(2 * jnp.pi) - 0.5 * jnp.log(jnp.linalg.det(covariance)))
# normalization = 0

log_likelihood = exponent + normalization

print(log_likelihood, log_likelihood.shape)

[[[ -44.08798    -29.82048  ]
  [ -16.644535   -11.3195305]
  [ -47.749825   -22.784111 ]
  [ -17.861132   -19.58884  ]]

 [[ -54.307503   -15.372691 ]
  [-222.3221     -15.377424 ]
  [ -37.44995    -13.018433 ]
  [ -20.321373   -38.71425  ]]

 [[ -61.85912    -16.834034 ]
  [ -34.70679    -15.578803 ]
  [ -48.57657    -33.38154  ]
  [ -78.95834    -14.548405 ]]] (3, 4, 2)


In [121]:
log_likelihood = jnp.zeros([N1, N2, N3])

for n1 in range(N1):
    for n2 in range(N2):
        for n3 in range(N3):
            diff = x[n1, n2, :] - mean[n3, :]
            precision_matrix = jnp.linalg.inv(covariance[n3, :, :])
            exponent = -0.5 * jnp.dot(diff.T, jnp.dot(precision_matrix, diff))
            normalization = -0.5 * (D * jnp.log(2 * jnp.pi) + 0.5 * jnp.log(jnp.linalg.det(precision_matrix)))
            log_likelihood = log_likelihood.at[n1, n2, n3].set(exponent + normalization)
            
print(log_likelihood)


[[[ -44.087975  -29.820477]
  [ -16.644535  -11.319531]
  [ -47.74982   -22.784111]
  [ -17.861134  -19.58884 ]]

 [[ -54.307503  -15.372691]
  [-222.3221    -15.377424]
  [ -37.44995   -13.018433]
  [ -20.321373  -38.71425 ]]

 [[ -61.859123  -16.834034]
  [ -34.70679   -15.578803]
  [ -48.57657   -33.38154 ]
  [ -78.95834   -14.548405]]]


In [113]:
log_likelihood.shape

(3, 4, 2)

In [57]:
normalization.shape

(4, 2)

In [54]:
covariance.shape

(4, 2, 3, 3)