In [1]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

In [2]:
x = jnp.array([0, 0, 1, 0, 2])
y = jnp.array([0,0,0])
x[jnp.nonzero(x)]
jnp.concat((x, y[jnp.nonzero(y)]))

Array([0, 0, 1, 0, 2], dtype=int32)

In [54]:
jnp.where(x, True, False)

Array([False, False,  True, False,  True], dtype=bool)

In [28]:
jnp.nonzero(x, size=1)[0] + jnp.array([0,1,2,3])

Array([2, 3, 4, 5], dtype=int32)

In [53]:
jnp.array([0,1,2,3])[jnp.array([True, False, True, False])]

Array([0, 2], dtype=int32)

In [67]:
@dataclass
class Gaussian:
    info: jnp.ndarray
    precision: jnp.ndarray
    dims: jnp.ndarray 

    @property
    def shape(self):
        return {
            "info": self.info.shape,
            "precision": self.precision.shape,
            "dims": self.dims.shape
        }
 
    @property
    def mean(self) -> jnp.ndarray:
        return jnp.linalg.inv(self.precision) @ self.info
    
    @property
    def covariance(self) -> jnp.ndarray:
        return jnp.linalg.inv(self.precision)
    
    @staticmethod
    def identity(variable: int) -> jnp.ndarray:
        dims = jnp.array([variable, variable, variable, variable])
        return Gaussian(jnp.zeros(4), jnp.eye(4), dims)
    
    def __getitem__(self, index) -> "Gaussian":
        return Gaussian(self.info[index], self.precision[index], self.dims[index])

    def __mul__(self, other: 'Gaussian') -> 'Gaussian':
        if other is None:
            return self.copy()

        # Merge dims
        # dims = [i for i in self.dims]
        # for d in other.dims:
        #     if d not in dims:
        #         dims.append(d)

        # my version
        dims = self.dims.copy()
        other_dims = other.dims.copy()
        unique_dim_values = jnp.unique(self.dims, size=self.dims.shape[0]//4)
        unique_other_dim_values = jnp.unique(other_dims, size=other_dims.shape[0]//4)

        def check_missing_index(carry, x):
            mask = jnp.where(carry == x, x, 0.)
            result = jax.lax.select(mask.sum() == 0, jnp.array([x]), jnp.array([0.]))
            return carry, result

        def get_indices(carry, x):
            mask = jnp.where(carry == x, x, 0.0)
            return carry, jnp.array([0.,1.,2.,3.]) + jnp.nonzero(mask, size=1)[0]

        _, missing_axes = jax.lax.scan(check_missing_index, dims, other_dims)
        missing_axes = missing_axes.flatten()
        dims = jnp.concat((dims, missing_axes[jnp.nonzero(missing_axes)])) # nonzero usage will break vmap
        jax.debug.print("dims: {}", dims)
        # my version end
        
        # Extend self matrix
        prec_self = jnp.zeros((len(dims), len(dims)))
        info_self = jnp.zeros((len(dims), 1))
        # idxs_self = jnp.array([dims.index(d) for d in self.dims]) # here, need to fix this # [0, 1, 2, 3]
        _, idxs_self = jax.lax.scan(get_indices, dims, unique_dim_values)
        idxs_self = idxs_self.flatten().astype(int)
        jax.debug.print("idxs_self: {}", idxs_self)
        prec_self = prec_self.at[jnp.ix_(idxs_self, idxs_self)].set(self.precision)
        info_self = info_self.at[jnp.ix_(idxs_self,jnp.array([0]))].set(self.info.reshape(-1,1))

        # Extend other matrix
        prec_other = jnp.zeros((len(dims), len(dims)))
        info_other = jnp.zeros((len(dims), 1))
        # idxs_other = jnp.array([dims.index(d) for d in other.dims]) # here, need to fix this
        _, idxs_other = jax.lax.scan(get_indices, dims, unique_other_dim_values)
        idxs_other = idxs_other.flatten().astype(int)
        jax.debug.print("idxs_other: {}", idxs_other)
        prec_other = prec_other.at[jnp.ix_(idxs_other, idxs_other)].set(other.precision)
        info_other = info_other.at[jnp.ix_(idxs_other, jnp.array([0]))].set(other.info.reshape(-1,1))

        # Add
        prec = prec_other + prec_self
        info = (info_other + info_self).squeeze(-1)
        return Gaussian(info, prec, dims)

    def __imul__(self, other: 'Gaussian') -> 'Gaussian':
        return self.__mul__(other)
    
    def marginalize(self, dims: jnp.ndarray) -> "Gaussian":
        info, prec = self.info, self.precision
        info = info.reshape(-1,1)
        unique_dim_values = jnp.unique(self.dims, size=self.dims.shape[0]//4)
        # axis_a = jnp.array([idx for idx, d in enumerate(self.dims) if d not in dims])
        # axis_b = jnp.array([idx for idx, d in enumerate(self.dims) if d in dims])
        
        def find_axis_a(carry, x):
            mask = jnp.where(carry == x, False, True)
            return carry, mask
        _, axis_a = jax.lax.scan(find_axis_a, self.dims, unique_dim_values)
        axis_a = self.dims[axis_a.flatten()]

        def find_axis_b(carry, x):
            mask = jnp.where(carry == x, True, False)
            return carry, mask
        _, axis_b = jax.lax.scan(find_axis_b, self.dims, unique_dim_values)
        axis_b = self.dims[axis_b.flatten()]

        # def axis_a_fn(kp, v):
        #     if v not in dims:
        #         return kp[0].idx
        #     else:
        #         return -1
            
        # def axis_b_fn(kp, v):
        #     if v in dims:
        #         return kp[0].idx
        #     else:
        #         return -1
            
        # axis_a = jnp.array(jax.tree_util.tree_map_with_path(axis_a_fn, self.dims))
        # axis_b = jnp.array(jax.tree_util.tree_map_with_path(axis_b_fn, self.dims))
        # axis_a = axis_a[jnp.where(axis_a != -1)]
        # axis_b = axis_b[jnp.where(axis_b != -1)]

        info_a = info[jnp.ix_(axis_a, jnp.array([0]))]
        prec_aa = prec[jnp.ix_(axis_a, axis_a)]
        info_b = info[jnp.ix_(axis_b, jnp.array([0]))]
        prec_ab = prec[jnp.ix_(axis_a, axis_b)]
        prec_ba = prec[jnp.ix_(axis_b, axis_a)]
        prec_bb = prec[jnp.ix_(axis_b, axis_b)]

        prec_bb_inv = jnp.linalg.inv(prec_bb)
        info_ = info_a - prec_ab @ prec_bb_inv @ info_b
        prec_ = prec_aa - prec_ab @ prec_bb_inv @ prec_ba

        new_dims = tuple(i for i in self.dims if i not in dims)
        # other_dim_unique_values = jnp.unique(dims, size=dims.shape[0]//4)
        # def find_new_dims(carry, x):
        #     mask = jnp.
        # _, new_dims = jax.lax.scan(find_new_dims, self.dims, other_dim_unique_values)
        return Gaussian(info_.squeeze(-1), prec_, new_dims)

In [68]:
same_g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
same_g1 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
same_g0 * same_g1

dims: [1. 1. 1. 1.]
idxs_self: [0 1 2 3]
idxs_other: [0 1 2 3]


Gaussian(info=Array([2., 2., 2., 2.], dtype=float32), precision=Array([[2., 0., 0., 0.],
       [0., 2., 0., 0.],
       [0., 0., 2., 0.],
       [0., 0., 0., 2.]], dtype=float32), dims=Array([1., 1., 1., 1.], dtype=float32))

In [69]:
# same_g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 0)
# same_g1 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones((4,)) * 1)
g1 = Gaussian(jnp.ones(4,) * 2, jnp.eye(4) * 2, jnp.ones((4,)) * 2)
g0 * g1 

dims: [1. 1. 1. 1. 2. 2. 2. 2.]
idxs_self: [0 1 2 3]
idxs_other: [4 5 6 7]


Gaussian(info=Array([1., 1., 1., 1., 2., 2., 2., 2.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2.]], dtype=float32), dims=Array([1., 1., 1., 1., 2., 2., 2., 2.], dtype=float32))

In [70]:
g0 = Gaussian(jnp.ones(4,) * 1, jnp.eye(4) * 1, jnp.ones(4,) * 1)
g1 = Gaussian(jnp.ones(4,) * 2, jnp.eye(4) * 2, jnp.ones(4,) * 2)
g2 = Gaussian(jnp.ones(4,) * 3, jnp.eye(4) * 3, jnp.ones(4,) * 3)
combined_gaussian = g0 * g1 * g2
combined_gaussian

dims: [1. 1. 1. 1. 2. 2. 2. 2.]
idxs_self: [0 1 2 3]
idxs_other: [4 5 6 7]
dims: [1. 1. 1. 1. 2. 2. 2. 2. 3. 3. 3. 3.]
idxs_self: [0 1 2 3 4 5 6 7]
idxs_other: [ 8  9 10 11]


Gaussian(info=Array([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3.]], dtype=float32), dims=Array([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.], dtype=float32))

In [None]:
combined_gaussian = combined_gaussian * g2
combined_gaussian

Gaussian(info=Array([0., 0., 0., 0., 2., 3.], dtype=float32), precision=Array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 0., 3.]], dtype=float32), dims=[Array(0., dtype=float32), Array(0., dtype=float32), Array(0., dtype=float32), Array(0., dtype=float32), Array(1., dtype=float32), Array(2., dtype=float32)])

In [None]:
combined_gaussian = combined_gaussian.marginalize(jnp.array([2.,2.,2.,2.]))