In [24]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

In [25]:
@dataclass
class Gaussian:
    info: jnp.array
    precision: jnp.array

    def __add__(self, other_gaussian: "Gaussian"):
        return Gaussian(self.info + other_gaussian.info, self.precision + other_gaussian.precision)
    
    @property
    def mean(self):
        return jnp.linalg.inv(self.precision) @ self.info
    
    @property
    def covariance(self):
        return jnp.linalg.inv(self.precision)
    
    @staticmethod
    def identity():
        return Gaussian(jnp.zeros(4), jnp.eye(4))

def add_one_to_info(g: Gaussian) -> Gaussian:
    other_gaussian = Gaussian(jnp.ones((4,)), jnp.zeros((4,4)))
    return g + other_gaussian

def add_one_to_precision(g: Gaussian) -> Gaussian:
    other_gaussian = Gaussian(jnp.zeros((4,)), jnp.ones((4,4)))
    return g + other_gaussian

batch_add_info = jax.vmap(add_one_to_info)
batch_add_precision = jax.vmap(add_one_to_precision)

In [3]:
g = Gaussian(jnp.ones((2,4)), jnp.stack((jnp.eye(4), jnp.eye(4))))
jax.debug.print("{}", batch_add_info(g))
jax.debug.print("{}", batch_add_precision(g))

Gaussian(info=Array([[2., 2., 2., 2.],
       [2., 2., 2., 2.]], dtype=float32), precision=Array([[[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]],

       [[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]], dtype=float32))
Gaussian(info=Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32), precision=Array([[[2., 1., 1., 1.],
        [1., 2., 1., 1.],
        [1., 1., 2., 1.],
        [1., 1., 1., 2.]],

       [[2., 1., 1., 1.],
        [1., 2., 1., 1.],
        [1., 1., 2., 1.],
        [1., 1., 1., 2.]]], dtype=float32))


In [64]:
outer_idx = jnp.array([0, -1])
dummy_state = jnp.zeros((2, 10, 4))
pose_msgs = jax.vmap(jax.vmap(lambda _: Gaussian.identity()))(dummy_state[:,outer_idx])
jax.vmap(lambda x: jax.debug.print("{}, {}", x.info.shape, x.precision.shape))(pose_msgs)

(Array(2, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), (Array(2, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True))


In [6]:
N_STATES = 4
POSE_NOISE = 1e-15
DYNAMICS_NOISE = 0.005
OBSTACLE_NOISE = 0.005

In [138]:
from abc import abstractmethod

class Factor:
    def __init__(
        self, state: jnp.array, state_precision: jnp.array, linear: bool = True
    ) -> None:
        self._state = state
        self._state_precision = state_precision
        self._linear = linear

    def calculate_likelihood(self) -> Gaussian:
        return Gaussian(
            self._calc_info(self._state, self._state_precision),
            self._calc_precision(self._state, self._state_precision),
        )

    @abstractmethod
    def _calc_measurement(self, state: jnp.array) -> jnp.array:
        pass

    def _calc_info(self, state: jnp.array, precision: jnp.array) -> jnp.array:
        X = state
        if self._linear:
            eta = precision @ (jnp.zeros(N_STATES) - self._calc_measurement(state))
        else:
            J = jax.jacfwd(self._calc_measurement)(state)
            eta = (J.T @ precision) @ (
                J @ X + jnp.zeros((X.shape[0], 1)) - self.calc_measurement(state)
            )
        return eta

    def _calc_precision(self, state: jnp.array, precision: jnp.array) -> jnp.array:
        if self._linear:
            return precision
        else:
            J = jax.jacfwd(self._calc_measurement)(state)
            return J.T @ precision @ J


class PoseFactor(Factor):
    def __init__(self, state: jnp.array) -> None:
        precision = jnp.pow(POSE_NOISE, -2) * jnp.eye(N_STATES)
        super(PoseFactor, self).__init__(state, precision)

    def _calc_measurement(self, state: jnp.array) -> jnp.array:
        return state


class DynamicsFactor(Factor):
    def __init__(self, state: jnp.array, delta_t: float) -> None:
        self.delta_t = delta_t
        process_covariance = DYNAMICS_NOISE * jnp.eye(N_STATES // 2)
        top_half = jnp.hstack(
            (
                self.delta_t**3 * process_covariance / 3,
                self.delta_t**2 * process_covariance / 2,
            )
        )
        bottom_half = jnp.hstack(
            (
                self.delta_t**2 * process_covariance / 2,
                self.delta_t * process_covariance,
            )
        )
        precision = jnp.vstack((top_half, bottom_half))
        precision = jnp.linalg.inv(precision)

        self.state_transition = jnp.eye(4)
        self.state_transition = self.state_transition.at[0:2, 2:].set(
            jnp.eye(2) * self.delta_t
        )

        super(DynamicsFactor, self).__init__(state, precision)

    def _calc_measurement(self, state: jnp.array) -> jnp.array:
        prev_state = state[0:4]
        current_state = state[4:]
        return self.state_transition @ prev_state - current_state

In [139]:
# I think the sign got switched for pose b/c \eta = \lambda * (0 - h(x))
# see the - before h(x)
pose_factor = PoseFactor(jnp.array([-5, 0, -0.5, 0])).calculate_likelihood()
pose_factor.mean

Array([5. , 0. , 0.5, 0. ], dtype=float32)

In [141]:
jnp.linalg.inv(pose_factor.covariance) @ pose_factor.mean, pose_factor.info

(Array([5.e+30, 0.e+00, 5.e+29, 0.e+00], dtype=float32),
 Array([5.e+30, 0.e+00, 5.e+29, 0.e+00], dtype=float32))

In [142]:
jax.vmap(lambda x: PoseFactor(x).calculate_likelihood())(jnp.array([[-5, 0, -0.5, 0], [-8, 0, 0, 0]]))

Gaussian(info=Array([[5.e+30, 0.e+00, 5.e+29, 0.e+00],
       [8.e+30, 0.e+00, 0.e+00, 0.e+00]], dtype=float32), precision=Array([[[1.e+30, 0.e+00, 0.e+00, 0.e+00],
        [0.e+00, 1.e+30, 0.e+00, 0.e+00],
        [0.e+00, 0.e+00, 1.e+30, 0.e+00],
        [0.e+00, 0.e+00, 0.e+00, 1.e+30]],

       [[1.e+30, 0.e+00, 0.e+00, 0.e+00],
        [0.e+00, 1.e+30, 0.e+00, 0.e+00],
        [0.e+00, 0.e+00, 1.e+30, 0.e+00],
        [0.e+00, 0.e+00, 0.e+00, 1.e+30]]], dtype=float32))

In [143]:
dyn_factor = DynamicsFactor(jnp.array([5.1, 0, 0.5, 0, 5.2, 0, 0.5, 0]), 0.2).calculate_likelihood()
dyn_factor.mean

Array([0., 0., 0., 0.], dtype=float32)

In [144]:
dyn_factor

Gaussian(info=Array([0., 0., 0., 0.], dtype=float32), precision=Array([[299999.84  ,      0.    , -29999.98  ,      0.    ],
       [     0.    , 299999.84  ,      0.    , -29999.98  ],
       [-29999.982 ,     -0.    ,   3999.9978,     -0.    ],
       [     0.    , -29999.982 ,      0.    ,   3999.9978]],      dtype=float32))

In [44]:
dyn_factor = DynamicsFactor(jnp.array([5.1, 0, 0.5, 0, 5.5, 0, 0.5, 0]), 0.2)
dyn_factor.mean, dyn_factor

(Array([3.0000007e-01, 0.0000000e+00, 5.2619725e-07, 0.0000000e+00],      dtype=float32),
 DynamicsFactor(info=Array([90000.01,     0.  , -9000.  ,     0.  ], dtype=float32), precision=Array([[299999.84  ,      0.    , -29999.98  ,      0.    ],
        [     0.    , 299999.84  ,      0.    , -29999.98  ],
        [-29999.982 ,     -0.    ,   3999.9978,     -0.    ],
        [     0.    , -29999.982 ,      0.    ,   3999.9978]],      dtype=float32)))

# Message Passing

In [96]:
@dataclass
class Var2FacMessages:
    poses: jnp.array
    dynamics: jnp.array


@dataclass
class Fac2VarMessages:
    poses: jnp.array
    dynamics: jnp.array

@dataclass
class Factors:
    poses: PoseFactor
    dynamics: DynamicsFactor

In [97]:
end_pos = jnp.array([[8.,0.,0.,0.], [-8.,0.,0.,0.]])

In [90]:
state_transition = jnp.eye(4)
delta_t = 0.2
state_transition = state_transition.at[:2,2:].set(jnp.eye(2) * delta_t)

current_state1 = jnp.array([5, 0, 0.5, 0]).astype(float)
current_state2 = jnp.array([-5, 0, -0.5, 0]).astype(float)
state = jnp.stack((current_state1, current_state2))

time_horizon = 4
@jax.jit
def update_init_state(carry: jnp.array, _: int=None):
    carry = state_transition @ carry
    return carry, carry.T

_, states = jax.lax.scan(update_init_state, state.T, length=time_horizon)
states = jnp.swapaxes(states, 0, 1)

In [124]:
def init_var2fac_msgs():
    n_agents = states.shape[0]
    time_horizon = states.shape[1]
    pose_msgs = jax.vmap(jax.vmap(lambda _: Gaussian.identity()))(
        jnp.zeros((n_agents, 2))
    )
    dynamics_msgs = jax.vmap(jax.vmap(lambda _: Gaussian.identity()))(
        jnp.zeros((n_agents, (time_horizon - 1) * 2))
    )
    return Var2FacMessages(poses=pose_msgs, dynamics=dynamics_msgs)

var2fac_msgs = init_var2fac_msgs()

In [145]:
delta_t = 0.2
def update_factor_likelihoods(states):
    def batch_update_factor_likelihoods(agent_states, end_pos):
        poses = jax.vmap(lambda x: PoseFactor(x).calculate_likelihood())(jnp.stack((agent_states[0], end_pos)))
        agent_state_combos = jnp.hstack((agent_states[0:-1], agent_states[1:]))
        dynamics = jax.vmap(lambda x: DynamicsFactor(x, delta_t).calculate_likelihood())(agent_state_combos)
        return Factors(poses, dynamics)
    return jax.vmap(batch_update_factor_likelihoods)(states, end_pos)

factors = update_factor_likelihoods(states)

In [137]:
factors.dynamics.info.shape

(2, 3, 4)

In [92]:
def update_marginals(fac2var_msgs):
    print("{}, {}", fac2var_msgs.pose.info.shape, fac2var_msgs.pose.precision.shape)
    print("{}, {}", fac2var_msgs.dynamics.info.shape, fac2var_msgs.dynamics.precision.shape)
    pose = fac2var_msgs.pose
    dynamics = fac2var_msgs.dynamics

    outer_info = pose.info[outer_idx] + dynamics.info[outer_idx]
    outer_precision = pose.precision[outer_idx] + dynamics.precision[outer_idx]

    inner_info = dynamics.info[1:-1:2] + dynamics.info[2:-1:2]
    inner_precision = dynamics.precision[1:-1:2] + dynamics.precision[2:-1:2]

    return Gaussian(
        info=jnp.concat((outer_info[0:1], inner_info, outer_info[-1:])),
        precision=jnp.concat((outer_precision[0:1], inner_precision, outer_precision[-1:])),
    )

marginals = jax.vmap(update_marginals)(var2fac_msgs)
marginals.info.shape, marginals.precision.shape

{}, {} (2, 4) (2, 4, 4)
{}, {} (6, 4) (6, 4, 4)


((2, 4, 4), (2, 4, 4, 4))