# Masked Likelihoods

In [None]:
import sys, os
from pyprojroot import here

# Walk up the directory tree to find the project root, identified here by a
# `.home` marker file.
root = here(project_files=[".home"])

# Put the project root on the import path (presumably so the local
# `filterjax` package imported below can be resolved).
sys.path.append(str(root))

# Enable IPython's autoreload extension so edited modules are re-imported
# automatically instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import chex
from einops import repeat, rearrange
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

from filterjax._src.ops.masks import (
    mask_observation_noise,
    mask_observation_operator,
    mvn_logpdf,
)

### Data

### Observation Mask


$$
\mathbf{H}_t = 
\begin{bmatrix}
\mathbf{H}_t^{\text{obs}} \\
\mathbf{H}_t^{\text{missing}}
\end{bmatrix} =
\begin{bmatrix}
\mathbf{H}_t^{\text{obs}} \\
\mathbf{0}
\end{bmatrix}
$$

In [None]:
# IPython `??` introspection: display the source of the operator-masking helper.
mask_observation_operator??

In [None]:
# Toy problem size: the operator has one row per observation dimension and
# one column per state dimension.
state_dim, obs_dim = 5, 2

# Random observation operator
# (use np.ones((obs_dim, state_dim)) instead for a deterministic example).
operator = np.random.randn(obs_dim, state_dim)

# All-zero mask: no observation dimension is flagged.
mask = np.zeros(shape=obs_dim)
print(mask)

In [None]:
# Apply the current mask to the observation operator.
operator_masked = mask_observation_operator(operator, mask)

# The masked operator must keep the input's shape and, for this mask, its
# values as well.
chex.assert_equal_shape([operator_masked, operator])
chex.assert_trees_all_close(operator_masked, operator)

print(operator_masked)

Nothing was masked, because with an all-zero mask we see all of the values :).

In [None]:
# partial mask: flag the first of the two observation dimensions
# (presumably 1 = masked/missing, since the all-zero mask above left the
# operator unchanged)
mask = np.array([1, 0])

print(mask)

In [None]:
# Re-apply the helper, now with the partial mask.
operator_masked = mask_observation_operator(operator, mask)

# Only the shape is checked here; no assertion is made about the values.
chex.assert_equal_shape([operator_masked, operator])

print(operator_masked)

As we can see, now we have the operator masked for one of the dimensions.

### Noise Covariance Matrix

$$
\mathbf{R}_t = 
\begin{bmatrix}
\mathbf{R}_{11t}^{\text{obs}} & \mathbf{R}_{12t}^{\text{cross}}\\
\mathbf{R}_{21t}^{\text{cross}} & \mathbf{R}_{22t}^{\text{missing}}
\end{bmatrix} =
\begin{bmatrix}
\mathbf{R}_{11t}^{\text{obs}} & \mathbf{0}\\
\mathbf{0} & \mathbf{I}
\end{bmatrix}
$$

In [None]:
# IPython `??` introspection: display the source of the noise-masking helper.
mask_observation_noise??

In [None]:
state_dim = 10
obs_dim = 4
# Seeded generator so the example matrix is identical on every run.
rng = np.random.RandomState(3)

# Draw from the seeded generator rather than the global NumPy random state;
# otherwise the seed above has no effect on the printed matrix.
# (Use np.ones((obs_dim, obs_dim)) instead for a constant example.)
# NOTE: a raw Gaussian matrix is generally not symmetric, so this is only a
# stand-in for a noise covariance, used to show which entries get replaced.
noise_mat = rng.randn(obs_dim, obs_dim)
print(noise_mat)

In [None]:
# All-zero mask: no observation dimension is flagged.
mask = np.zeros(obs_dim)

print(mask)

In [None]:
# Apply the current mask to the noise matrix.
noise_mat_masked = mask_observation_noise(noise_mat, mask)

# The masked matrix must keep the input's shape and, for this mask, its
# values as well.
chex.assert_equal_shape([noise_mat_masked, noise_mat])
chex.assert_trees_all_close(noise_mat_masked, noise_mat)

print(noise_mat_masked)

Again, we don't see any changes because the mask was 0s everywhere, i.e. nothing was masked.

In [None]:
# Flag the last two of the four observation dimensions
# (presumably 1 = masked/missing; the all-zero mask left the matrix unchanged).
mask = np.array([0.0, 0.0, 1.0, 1.0])

noise_mat_masked = mask_observation_noise(noise_mat, mask)

# Only the shape is checked here; no assertion is made about the values.
chex.assert_equal_shape([noise_mat_masked, noise_mat])

print(noise_mat_masked)

Here, we see the structure we showed above.

In [None]:
# Flag the first and last observation dimensions, so the flagged entries are
# not contiguous.
mask = np.array([1.0, 0.0, 0.0, 1.0])

noise_mat_masked = mask_observation_noise(noise_mat, mask)

# Masking must preserve the shape of the input matrix, so compare the result
# against the unmasked `noise_mat` (comparing the result with itself would
# pass trivially and check nothing).
chex.assert_equal_shape([noise_mat_masked, noise_mat])

print(noise_mat_masked)

### Likelihood Term

In [None]:
obs_dim = 5
rng = np.random.RandomState(123)

# The tuple is evaluated left to right, so `x` is drawn before `mean`.
x, mean = rng.randn(obs_dim), rng.randn(obs_dim)

# A @ A^T of a random square matrix gives a symmetric positive
# semi-definite covariance.
cov = rng.randn(obs_dim, obs_dim)
cov = cov @ cov.T

In [None]:
# Evaluate the log-density without supplying a mask (None).
log_prob = mvn_logpdf(x, mean, cov, None)
# The result must be a scalar and must not be NaN.
assert log_prob.shape == ()
assert np.isnan(log_prob).sum() == 0
log_prob

In [None]:
# Evaluate the log-density with an all-zero mask of the same shape as `x`.
log_prob = mvn_logpdf(x, mean, cov, jnp.zeros_like(x))
# The result must be a scalar and must not be NaN.
assert log_prob.shape == ()
assert np.isnan(log_prob).sum() == 0
log_prob

In [None]:
# Evaluate the log-density with an all-ones mask of the same shape as `x`
# (presumably every dimension flagged as missing; confirm against mvn_logpdf).
log_prob = mvn_logpdf(x, mean, cov, jnp.ones_like(x))
# The result must be a scalar and must not be NaN.
assert log_prob.shape == ()
assert np.isnan(log_prob).sum() == 0
log_prob

In [None]:
# Random binary mask: randint's upper bound is exclusive, so every entry is
# 0 or 1; the result is cast to float32.
mask = rng.randint(
    0,
    2,
    size=(obs_dim,),
).astype(np.float32)

In [None]:
# Display the sampled mask.
mask

In [None]:
# Evaluate the log-density with the randomly sampled binary mask.
log_prob = mvn_logpdf(x, mean, cov, mask)
# The result must be a scalar and must not be NaN.
assert log_prob.shape == ()
assert np.isnan(log_prob).sum() == 0
log_prob

##### Batches

In [None]:
obs_dim = 2
n_batch = 10
# Each array carries a leading batch axis of length n_batch.
x = rng.randn(n_batch, obs_dim)
mean = rng.randn(n_batch, obs_dim)
cov = rng.randn(n_batch, obs_dim, obs_dim)
# Batched A @ A^T (the transpose swaps the last two axes), giving one
# symmetric positive semi-definite matrix per batch element.
cov = jnp.matmul(cov, cov.transpose((0, 2, 1)))

In [None]:
# Vectorise over the leading axis of x, mean and cov; the mask argument is
# not mapped (its in_axes entry is None), so the same value is passed to
# every batch element.
fn = jax.vmap(mvn_logpdf, in_axes=(0, 0, 0, None))
# fn_ = jax.vmap(mvn_logpdf_custom, in_axes=(0, 0, 0, None))

# No mask is supplied for the batched evaluation.
log_probs = fn(x, mean, cov, None)
# log_probs_ = fn_(x, mean, cov, None)
# chex.assert_trees_all_close(log_probs, log_probs_)
# Expect one log-probability per batch element, none of them NaN.
assert log_probs.shape == (n_batch,)
assert np.isnan(log_probs).sum() == 0

In [None]:
# Display the batched log-probabilities.
log_probs