In [24]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

In [25]:
@dataclass
class Gaussian:
    """Multivariate Gaussian in information (canonical) form.

    Fields:
        info: information vector eta = Lambda @ mu, e.g. shape (n,), or (b, n)
            when a batch of Gaussians is stacked along a leading axis.
        precision: precision matrix Lambda, shape (n, n) or (b, n, n).

    Adding two Gaussians sums both parameters, which in information form
    corresponds to multiplying the two densities.
    """

    info: jax.Array
    precision: jax.Array

    def __add__(self, other_gaussian: "Gaussian") -> "Gaussian":
        """Combine two Gaussians by summing info vectors and precision matrices."""
        return Gaussian(self.info + other_gaussian.info, self.precision + other_gaussian.precision)

    @property
    def mean(self) -> jax.Array:
        """Mean mu = Lambda^{-1} eta, computed with a linear solve.

        Solving is numerically preferable to forming the explicit inverse, and
        it also handles Gaussians batched along leading axes.
        """
        # Treat eta as an (n, 1) column so solve is unambiguous for batched inputs.
        return jnp.linalg.solve(self.precision, self.info[..., None])[..., 0]

    @property
    def covariance(self) -> jax.Array:
        """Covariance Sigma = Lambda^{-1}."""
        return jnp.linalg.inv(self.precision)

    @staticmethod
    def identity(dim: int = 4) -> "Gaussian":
        """Gaussian with zero info vector and identity precision of size `dim`.

        NOTE(review): this is not a neutral element for `__add__`, because it adds
        I to the other operand's precision. A true additive identity would have
        zero precision, which has no finite mean/covariance. Confirm the intent.
        """
        return Gaussian(jnp.zeros(dim), jnp.eye(dim))

def add_one_to_info(g: Gaussian) -> Gaussian:
    """Return `g` with 1 added to every entry of its information vector.

    The precision is left unchanged. The increment's shape is taken from `g`
    instead of being hard-coded, so any state dimension works. Under
    `jax.vmap` the function sees one unbatched Gaussian at a time.
    """
    increment = Gaussian(jnp.ones(g.info.shape), jnp.zeros(g.precision.shape))
    return g + increment

def add_one_to_precision(g: Gaussian) -> Gaussian:
    """Return `g` with 1 added to every entry of its precision matrix.

    The info vector is left unchanged. Shapes are taken from `g` instead of
    being hard-coded, so any state dimension works. Under `jax.vmap` the
    function sees one unbatched Gaussian at a time.
    """
    increment = Gaussian(jnp.zeros(g.info.shape), jnp.ones(g.precision.shape))
    return g + increment

# Batched variants: each helper is vmapped over the leading axis of a stacked
# Gaussian (the demo cell below passes a batch of two).
batch_add_info, batch_add_precision = map(jax.vmap, (add_one_to_info, add_one_to_precision))

In [3]:
# Two identical Gaussians (info = 1, precision = I) stacked along a leading batch axis.
g = Gaussian(jnp.ones((2, 4)), jnp.stack((jnp.eye(4), jnp.eye(4))))

info_added = batch_add_info(g)
precision_added = batch_add_precision(g)

# This runs outside jit/vmap tracing, so a plain print suffices;
# jax.debug.print is intended for use inside transformed functions.
print(info_added)
print(precision_added)

# Pin the expected results so a re-run fails loudly if the helpers change.
assert jnp.allclose(info_added.info, 2.0) and jnp.allclose(info_added.precision, g.precision)
assert jnp.allclose(precision_added.info, g.info) and jnp.allclose(precision_added.precision, g.precision + 1.0)

Gaussian(info=Array([[2., 2., 2., 2.],
       [2., 2., 2., 2.]], dtype=float32), precision=Array([[[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]],

       [[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]], dtype=float32))
Gaussian(info=Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32), precision=Array([[[2., 1., 1., 1.],
        [1., 2., 1., 1.],
        [1., 1., 2., 1.],
        [1., 1., 1., 2.]],

       [[2., 1., 1., 1.],
        [1., 2., 1., 1.],
        [1., 1., 2., 1.],
        [1., 1., 1., 2.]]], dtype=float32))


In [6]:
# Length of a single state vector; DynamicsFactor splits it into two halves
# (N_STATES // 2 for its per-half process covariance).
N_STATES: int = 4
# Noise scale for PoseFactor, whose precision is POSE_NOISE**-2 * I.
POSE_NOISE: float = 1e-15
# Scale of the per-half process covariance used by DynamicsFactor.
DYNAMICS_NOISE: float = 0.005
# Not referenced in the cells shown; presumably meant for an obstacle factor (TODO confirm).
OBSTACLE_NOISE: float = 0.005

In [26]:
from abc import abstractmethod

class Factor(Gaussian):
    """Gaussian factor in information form, built from a measurement model.

    Subclasses implement `_calc_measurement(state)`, returning h(state). The
    measurement target z is zero, with the same shape as h(state).

    Args:
        state: point at which the measurement model is evaluated/linearised.
        precision: measurement precision Lambda, square and sized like h(state).
        linear: if True, no Jacobian is applied. `precision` is used as-is and
            info = Lambda @ (z - h(state)). If False, h is linearised with
            `jax.jacfwd`, giving J:
            info = J.T @ Lambda @ (J @ state + z - h(state)) and
            precision = J.T @ Lambda @ J.
    """

    def __init__(self, state: jax.Array, precision: jax.Array, linear: bool = True):
        # Plain attribute (not one of Gaussian's fields). It must be set before
        # the _calc_* helpers below read it.
        self.linear = linear
        super().__init__(self._calc_info(state, precision), self._calc_precision(state, precision))

    @abstractmethod
    def _calc_measurement(self, state):
        """Measurement model h(state); must be overridden by subclasses.

        Factor does not use ABCMeta, so @abstractmethod is not enforced at
        construction time. Raise explicitly instead of silently returning None.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement _calc_measurement")

    def _calc_info(self, state, precision):
        """Information vector of the factor (formulas in the class docstring)."""
        measurement = self._calc_measurement(state)
        # The zero target is sized like the measurement, not the state. The two
        # differ whenever h changes dimension, and a (n, 1) column would
        # broadcast the sum into a matrix.
        target = jnp.zeros(jnp.shape(measurement))
        if self.linear:
            # NOTE(review): for PoseFactor (h = identity) this yields info = -Lambda @ state.
            # The PoseFactor demo cell accordingly shows mean [5, 0, 0.5, 0] for
            # state [-5, 0, -0.5, 0]. Confirm this sign convention is intended.
            return precision @ (target - measurement)
        J = jax.jacfwd(self._calc_measurement)(state)
        return (J.T @ precision) @ (J @ state + target - measurement)

    def _calc_precision(self, state, precision):
        """Precision matrix of the factor (formulas in the class docstring)."""
        if self.linear:
            return precision
        J = jax.jacfwd(self._calc_measurement)(state)
        return J.T @ precision @ J

class PoseFactor(Factor):
    """Factor whose measurement model is the identity, h(x) = x.

    The precision is isotropic: POSE_NOISE**-2 on each of the N_STATES
    dimensions. A tiny POSE_NOISE therefore gives a very large precision.
    """

    def __init__(self, state: jnp.array):
        isotropic_precision = jnp.eye(N_STATES) * jnp.pow(POSE_NOISE, -2)
        super().__init__(state, isotropic_precision)

    def _calc_measurement(self, state):
        # Identity measurement: the full state is observed directly.
        return state

In [27]:
# Build a PoseFactor at this 4-dimensional state and display its mean.
# NOTE(review): the displayed mean is the negation of this state
# ([5, 0, 0.5, 0]); confirm that is the intended sign convention.
initial_pose = jnp.array([-5, 0, -0.5, 0])
pose_factor = PoseFactor(state=initial_pose)
pose_factor.mean

Array([5. , 0. , 0.5, 0. ], dtype=float32)

In [None]:
class DynamicsFactor(Factor):
    """Constant-velocity dynamics factor linking two consecutive states.

    The factor's variable is the concatenation [prev_state, current_state],
    of length 2 * N_STATES. The residual is F @ prev_state - current_state,
    where F = [[I, dt*I], [0, I]] and each block is (N_STATES//2) square. The
    noise covariance is
    DYNAMICS_NOISE * [[dt^3/3 * I, dt^2/2 * I], [dt^2/2 * I, dt * I]],
    and its inverse is the measurement precision.

    Args:
        delta_t: time step dt between the two states.
        state: linearisation point of length 2 * N_STATES; defaults to zeros.
            Because the residual is linear in the state, the resulting
            info/precision do not depend on this value.
    """

    def __init__(self, delta_t: float, state: "jax.Array | None" = None) -> None:
        self.delta_t = delta_t
        half = N_STATES // 2

        process_covariance = DYNAMICS_NOISE * jnp.eye(half)
        top_half = jnp.hstack(
            (delta_t**3 * process_covariance / 3, delta_t**2 * process_covariance / 2)
        )
        bottom_half = jnp.hstack(
            (delta_t**2 * process_covariance / 2, delta_t * process_covariance)
        )
        precision = jnp.linalg.inv(jnp.vstack((top_half, bottom_half)))

        # Must exist before Factor.__init__, which evaluates the measurement model.
        # `.at[...]` returns an update helper, so write the dt*I block with `.set`.
        self.state_transition = jnp.eye(N_STATES).at[:half, half:].set(delta_t * jnp.eye(half))

        if state is None:
            state = jnp.zeros(2 * N_STATES)
        # linear=False: the residual's Jacobian [F, -I] is not the identity.
        super().__init__(state, precision, linear=False)

    def _transition_jacobian(self) -> jax.Array:
        """Constant Jacobian [F, -I] of the residual w.r.t. [prev_state, current_state]."""
        return jnp.hstack((self.state_transition, -jnp.eye(N_STATES)))

    def _calc_measurement(self, state):
        """Residual F @ prev_state - current_state for a stacked 2 * N_STATES state."""
        prev_state = state[:N_STATES]
        current_state = state[N_STATES:]
        return self.state_transition @ prev_state - current_state

    def calc_measurement(self, state):
        """Public alias of `_calc_measurement`, kept for backward compatibility."""
        return self._calc_measurement(state)

    # The residual is linear, so use its closed-form Jacobian rather than
    # differentiating the measurement model with jax.jacfwd.
    def _calc_info(self, state, precision):
        """info = J.T @ Lambda @ (J @ state - h(state)) with a zero measurement target."""
        J = self._transition_jacobian()
        return J.T @ precision @ (J @ state - self._calc_measurement(state))

    def _calc_precision(self, state, precision):
        """Factor precision J.T @ Lambda @ J over the stacked pair of states."""
        J = self._transition_jacobian()
        return J.T @ precision @ J