In [2]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

In [3]:
x = jnp.array([0, 0, 1, 0, 2])
y = jnp.array([0,0,0])
x[jnp.nonzero(x)]
jnp.concat((x, y[jnp.nonzero(y)]))

Array([0, 0, 1, 0, 2], dtype=int32)

In [4]:
# jnp.where(x, True, False)
arr = jnp.array([1,2,3])
# arr[jnp.array([True, False, False, False, True, False])]
jnp.unique(jnp.array([2,2,1,1,0,0]))

Array([0, 1, 2], dtype=int32)

In [5]:
jnp.nonzero(x, size=1)[0] + jnp.array([0,1,2,3])

Array([2, 3, 4, 5], dtype=int32)

In [6]:
jnp.array([0,1,2,3])[jnp.array([True, False, True, False])]

Array([0, 2], dtype=int32)

In [7]:
@dataclass
class Gaussian:
    info: jnp.ndarray
    precision: jnp.ndarray
    dims: jnp.ndarray 

    @property
    def shape(self):
        return {
            "info": self.info.shape,
            "precision": self.precision.shape,
            "dims": self.dims.shape
        }
 
    @property
    def mean(self) -> jnp.ndarray:
        return jnp.linalg.inv(self.precision) @ self.info
    
    @property
    def covariance(self) -> jnp.ndarray:
        return jnp.linalg.inv(self.precision)
    
    @staticmethod
    def identity(variable: int) -> jnp.ndarray:
        dims = jnp.array([variable, variable, variable, variable])
        return Gaussian(jnp.zeros(4), jnp.eye(4), dims)
    
    def __getitem__(self, index) -> "Gaussian":
        return Gaussian(self.info[index], self.precision[index], self.dims[index])

    @jax.jit
    def __mul__(self, other: 'Gaussian') -> 'Gaussian':
        if other is None:
            return self.copy()

        # Merge dims
        # dims = [i for i in self.dims]
        # for d in other.dims:
        #     if d not in dims:
        #         dims.append(d)
        """
        Iterate through unique values in other dims
        Check if it's not in current dims
        If so, then True indices, else False
        Extract the elements that are True, then concatenate with current dims
        """

        dims = self.dims.copy()
        other_dims = other.dims.copy()
        unique_dim_values = jnp.unique(self.dims, size=self.dims.shape[0]//4)
        unique_other_dim_values = jnp.unique(other_dims, size=other_dims.shape[0]//4)

        def check_missing_index(carry, x):
            mask = jnp.where(carry == x, 1.0, 0.0)
            result = jax.lax.select(mask.sum() == 0, jnp.full((4,), True), jnp.full((4,), False))
            return carry, result 

        def get_indices(carry, x):
            mask = jnp.where(carry == x, x, 0.0)
            return carry, jnp.array([0.,1.,2.,3.]) + jnp.nonzero(mask, size=1)[0]

        _, missing_axes_values = jax.lax.scan(check_missing_index, dims, unique_other_dim_values)
        missing_axes_values = other_dims[missing_axes_values.flatten()] 
        unique_missing_axes_values = jnp.unique(missing_axes_values, size=missing_axes_values.shape[0]//4)
        _, missing_axes = jax.lax.scan(get_indices, other_dims, unique_missing_axes_values)
        dims = jnp.concat((dims, other_dims[missing_axes.flatten().astype(int)])) # nonzero usage will break vmap
        jax.debug.print("dims: {}", dims)
        
        # Extend self matrix
        prec_self = jnp.zeros((len(dims), len(dims)))
        info_self = jnp.zeros((len(dims), 1))
        # idxs_self = jnp.array([dims.index(d) for d in self.dims]) # here, need to fix this # [0, 1, 2, 3]
        _, idxs_self = jax.lax.scan(get_indices, dims, unique_dim_values)
        idxs_self = idxs_self.flatten().astype(int)
        jax.debug.print("idxs_self: {}", idxs_self)
        prec_self = prec_self.at[jnp.ix_(idxs_self, idxs_self)].set(self.precision)
        info_self = info_self.at[jnp.ix_(idxs_self,jnp.array([0]))].set(self.info.reshape(-1,1))

        # Extend other matrix
        prec_other = jnp.zeros((len(dims), len(dims)))
        info_other = jnp.zeros((len(dims), 1))
        # idxs_other = jnp.array([dims.index(d) for d in other.dims]) # here, need to fix this
        _, idxs_other = jax.lax.scan(get_indices, dims, unique_other_dim_values)
        idxs_other = idxs_other.flatten().astype(int)
        jax.debug.print("idxs_other: {}", idxs_other)
        prec_other = prec_other.at[jnp.ix_(idxs_other, idxs_other)].set(other.precision)
        info_other = info_other.at[jnp.ix_(idxs_other, jnp.array([0]))].set(other.info.reshape(-1,1))

        # Add
        prec = prec_other + prec_self
        info = (info_other + info_self).squeeze(-1)
        return Gaussian(info, prec, dims)

    def __imul__(self, other: 'Gaussian') -> 'Gaussian':
        return self.__mul__(other)
    
    @jax.jit
    def marginalize(self, marginalized_dims: jnp.ndarray) -> "Gaussian":
        info, prec = self.info, self.precision
        info = info.reshape(-1,1)
        unique_dim_values = jnp.unique(marginalized_dims, size=marginalized_dims.shape[0]//4)
        # axis_a = jnp.array([idx for idx, d in enumerate(self.dims) if d not in dims])
        # axis_b = jnp.array([idx for idx, d in enumerate(self.dims) if d in dims])
        
        def find_axis_a_values(carry, x):
            mask = jnp.where(carry == x, False, True)
            return carry, mask

        def find_axis_b(carry, x):
            mask = jnp.where(carry == x, True, False)
            return carry, mask
        
        def get_indices(carry, x):
            mask = jnp.where(carry == x, 1.0, 0.0)
            return carry, jnp.array([0., 1., 2., 3.]) + jnp.nonzero(mask, size=1)[0]
        _, axis_a_values = jax.lax.scan(find_axis_a_values, self.dims, unique_dim_values)
        axis_a_values = self.dims[axis_a_values.flatten()]
        unique_axis_a_values = jnp.unique(axis_a_values, size=axis_a_values.shape[0]//4)
        _, axis_a = jax.lax.scan(get_indices, self.dims, unique_axis_a_values)
        axis_a = axis_a.flatten().astype(int)

        _, axis_b_values = jax.lax.scan(find_axis_b, self.dims, unique_dim_values)
        axis_b_values = self.dims[axis_b_values.flatten()]
        unique_axis_b_values = jnp.unique(axis_b_values, size=axis_b_values.shape[0]//4)
        _, axis_b = jax.lax.scan(get_indices, self.dims, unique_axis_b_values)
        axis_b = axis_b.flatten().astype(int)

        jax.debug.print("axis a: {}", axis_a)
        jax.debug.print("axis b: {}", axis_b)

        info_a = info[jnp.ix_(axis_a, jnp.array([0]))]
        prec_aa = prec[jnp.ix_(axis_a, axis_a)]
        info_b = info[jnp.ix_(axis_b, jnp.array([0]))]
        prec_ab = prec[jnp.ix_(axis_a, axis_b)]
        prec_ba = prec[jnp.ix_(axis_b, axis_a)]
        prec_bb = prec[jnp.ix_(axis_b, axis_b)]

        prec_bb_inv = jnp.linalg.inv(prec_bb)
        info_ = info_a - prec_ab @ prec_bb_inv @ info_b
        prec_ = prec_aa - prec_ab @ prec_bb_inv @ prec_ba

        # new_dims = tuple(i for i in self.dims if i not in marginalized_dims)
        # other_dim_unique_values = jnp.unique(marginalized_dims, size=marginalized_dims.shape[0]//4)
        # def find_new_dims(carry, x):
        #     mask = jnp.
        # _, new_dims_values = jax.lax.scan(find_axis_a_values, self.dims, other_dim_unique_values)
        # new_dims_values = new_dims_values.flatten()
        return Gaussian(info_.squeeze(-1), prec_, axis_a_values)

In [8]:
same_g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
same_g1 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
same_g0 * same_g1

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

In [10]:
# same_g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 0)
# same_g1 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones((4,)) * 1)
g1 = Gaussian(jnp.ones(4,) * 2, jnp.eye(4) * 2, jnp.ones((4,)) * 2)
g0 * g1 

dims: [1. 1. 1. 1. 2. 2. 2. 2.]
idxs_self: [0 1 2 3]
idxs_other: [4 5 6 7]


Gaussian(info=Array([1., 1., 1., 1., 2., 2., 2., 2.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2.]], dtype=float32), dims=Array([1., 1., 1., 1., 2., 2., 2., 2.], dtype=float32))

In [11]:
g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
g1 = Gaussian(jnp.ones(4,) * 2, jnp.eye(4) * 2, jnp.ones(4,) * 2)
g2 = Gaussian(jnp.ones(4,) * 3, jnp.eye(4) * 3, jnp.ones(4,) * 3)
combined_gaussian = g0 * g1 * g2
combined_gaussian

dims: [1. 1. 1. 1. 2. 2. 2. 2.]
idxs_self: [0 1 2 3]
idxs_other: [4 5 6 7]
dims: [1. 1. 1. 1. 2. 2. 2. 2. 3. 3. 3. 3.]
idxs_self: [0 1 2 3 4 5 6 7]
idxs_other: [ 8  9 10 11]


Gaussian(info=Array([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3.]], dtype=float32), dims=Array([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.], dtype=float32))

In [12]:
combined_gaussian * g2

dims: [1. 1. 1. 1. 2. 2. 2. 2. 3. 3. 3. 3.]
idxs_self: [ 0  1  2  3  4  5  6  7  8  9 10 11]
idxs_other: [ 8  9 10 11]


Gaussian(info=Array([1., 1., 1., 1., 2., 2., 2., 2., 6., 6., 6., 6.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 6., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 6., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6.]], dtype=float32), dims=Array([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.], dtype=float32))

In [13]:
combined_gaussian.marginalize(jnp.array([2.,2.,2.,2.]))

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[12])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

In [None]:
combined_gaussian.marginalize(jnp.array([3.,3.,3.,3.]))

axis a: [0 1 2 3 4 5 6 7]
axis b: [ 8  9 10 11]


Gaussian(info=Array([1., 1., 1., 1., 2., 2., 2., 2.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2.]], dtype=float32), dims=Array([1., 1., 1., 1., 2., 2., 2., 2.], dtype=float32))