In [37]:
from typing import Tuple

import jax
import jax.numpy as jnp
from flax.struct import dataclass

In [2]:
@dataclass
class Var2FacMessages:
    """Variable-to-factor messages, grouped by the kind of factor that receives them."""
    # Messages addressed to the pose factors (init_var2fac_msgs stores a batched Gaussian here).
    poses: jnp.array
    # Messages addressed to the dynamics factors (likewise a batched Gaussian in init_var2fac_msgs).
    dynamics: jnp.array


@dataclass
class Fac2VarMessages:
    """Factor-to-variable messages, grouped by the kind of factor that sent them."""
    # Messages produced by the pose factors.
    poses: jnp.array
    # Messages produced by the dynamics factors.
    dynamics: jnp.array

# One 4-component target state per agent (two agents).
target_states = jnp.array([[8.,0.,0.,0.], [-8.,0.,0.,0.]])

# Linear transition: identity, plus delta_t times the last two components added
# to the first two (presumably a [x, y, vx, vy] layout - confirm).
state_transition = jnp.eye(4)
delta_t = 0.2
state_transition = state_transition.at[:2,2:].set(jnp.eye(2) * delta_t)

# Initial state of each agent, stacked into a (2, 4) array.
current_state1 = jnp.array([5, 0, 0.5, 0]).astype(float)
current_state2 = jnp.array([-5, 0, -0.5, 0]).astype(float)
state = jnp.stack((current_state1, current_state2))

# Number of steps produced by the rollout below.
time_horizon = 4
@jax.jit
def update_init_state(carry: jnp.array, _: int=None):
    """Advance the carried states by one step of ``state_transition``.

    Written as a ``jax.lax.scan`` body: the second argument is the (unused)
    per-step input.

    Returns:
        A pair ``(next_carry, output)`` where ``next_carry`` is
        ``state_transition @ carry`` and ``output`` is its transpose.
    """
    advanced = state_transition @ carry
    emitted = advanced.T
    return advanced, emitted

# Roll the initial states forward for time_horizon steps. The first stacked
# output is already one transition past `state`; scan stacks outputs on axis 0.
_, states = jax.lax.scan(update_init_state, state.T, length=time_horizon)
# Swap the first two axes so that states is indexed [agent, step, component].
states = jnp.swapaxes(states, 0, 1)

In [3]:
@dataclass
class Gaussian:
    """Gaussian stored in information (canonical) form.

    ``mean`` and ``covariance`` are derived from the stored fields by
    inverting ``precision``. ``dims`` carries one label per row of
    ``info`` / ``precision``; ``__mul__`` and ``marginalize`` use those labels
    to align and select rows.

    NOTE(review): ``__mul__`` and ``marginalize`` iterate over ``dims`` in
    Python and therefore need concrete (non-traced) labels. They also assume
    every label identifies exactly one row; repeated labels (as produced by
    ``identity``) collapse onto a single row in ``__mul__``. That open issue
    was already flagged in the original code and is not resolved here.
    """

    info: jnp.ndarray  # information vector (precision @ mean)
    precision: jnp.ndarray  # precision matrix (inverse of the covariance)
    dims: jnp.ndarray  # one label per row of info / precision

    @property
    def shape(self):
        """Shapes of the three fields, keyed by field name."""
        return {
            "info": self.info.shape,
            "precision": self.precision.shape,
            "dims": self.dims.shape
        }

    @property
    def mean(self) -> jnp.ndarray:
        """Mean vector, ``inv(precision) @ info``."""
        return jnp.linalg.inv(self.precision) @ self.info

    @property
    def covariance(self) -> jnp.ndarray:
        """Covariance matrix, ``inv(precision)``."""
        return jnp.linalg.inv(self.precision)

    @staticmethod
    def identity(variable: int) -> "Gaussian":
        """Return a 4-dimensional Gaussian with zero info and identity precision.

        Args:
            variable: label written into all four entries of ``dims``.
        """
        dims = jnp.array([variable, variable, variable, variable])
        return Gaussian(jnp.zeros(4), jnp.eye(4), dims)

    def concatenate(self, other_gaussian: "Gaussian") -> "Gaussian":
        """Concatenate every field with the matching field of ``other_gaussian``.

        The arrays are joined along their leading axis. ``jnp.concatenate``
        takes a *sequence* of arrays as its first argument, so each pair is
        wrapped in a tuple (passing the second array positionally would be
        interpreted as ``axis``).
        """
        return Gaussian(
            jnp.concatenate((self.info, other_gaussian.info)),
            jnp.concatenate((self.precision, other_gaussian.precision)),
            jnp.concatenate((self.dims, other_gaussian.dims))
        )

    def __getitem__(self, index) -> "Gaussian":
        """Index all three fields with the same ``index``."""
        return Gaussian(self.info[index], self.precision[index], self.dims[index])

    def _expand_to(self, dims: list) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Zero-pad ``info`` / ``precision`` so that their rows follow ``dims``.

        Args:
            dims: merged list of labels; must contain every label of ``self.dims``.

        Returns:
            ``(info, precision)`` of sizes ``len(dims)`` and
            ``(len(dims), len(dims))`` with this Gaussian's entries written at
            the positions of its own labels and zeros elsewhere.
        """
        size = len(dims)
        rows = jnp.array([dims.index(d) for d in self.dims], dtype=int)
        info = jnp.zeros(size).at[rows].set(self.info)
        precision = jnp.zeros((size, size)).at[jnp.ix_(rows, rows)].set(self.precision)
        return info, precision

    def __mul__(self, other: 'Gaussian') -> 'Gaussian':
        """Multiply two Gaussians by adding their information-form parameters.

        The result is defined over the union of both label sets (labels of
        ``self`` first, then unseen labels of ``other`` in order).

        Args:
            other: the second factor, or ``None`` to leave ``self`` unchanged.

        Returns:
            The product Gaussian; its ``dims`` is an array of the merged labels.
        """
        if other is None:
            # flax struct dataclasses are frozen, so handing back the same
            # instance is safe and no copy is needed.
            return self

        # Merge the labels while preserving first-seen order.
        merged = list(self.dims)
        for d in other.dims:
            if d not in merged:
                merged.append(d)

        info_self, prec_self = self._expand_to(merged)
        info_other, prec_other = other._expand_to(merged)

        # In information form a product is a plain sum of the parameters.
        return Gaussian(info_other + info_self, prec_other + prec_self, jnp.asarray(merged))

    def __imul__(self, other: 'Gaussian') -> 'Gaussian':
        """In-place product; delegates to ``__mul__`` and returns a new instance."""
        return self.__mul__(other)

    def marginalize(self, dims: Tuple) -> "Gaussian":
        """Marginalize out every row whose label appears in ``dims``.

        Uses the Schur complement of the removed block:
        ``info_a - P_ab P_bb^-1 info_b`` and ``P_aa - P_ab P_bb^-1 P_ba``.

        Args:
            dims: labels to remove.

        Returns:
            A Gaussian over the remaining labels (``self`` when no row matches).
        """
        keep_rows = [idx for idx, d in enumerate(self.dims) if d not in dims]
        drop_rows = [idx for idx, d in enumerate(self.dims) if d in dims]
        if not drop_rows:
            # Nothing to integrate out.
            return self

        axis_a = jnp.array(keep_rows, dtype=int)
        axis_b = jnp.array(drop_rows, dtype=int)
        info, prec = self.info, self.precision

        info_a = info[axis_a]
        info_b = info[axis_b]
        prec_aa = prec[jnp.ix_(axis_a, axis_a)]
        prec_ab = prec[jnp.ix_(axis_a, axis_b)]
        prec_ba = prec[jnp.ix_(axis_b, axis_a)]
        prec_bb = prec[jnp.ix_(axis_b, axis_b)]

        prec_bb_inv = jnp.linalg.inv(prec_bb)
        info_ = info_a - prec_ab @ prec_bb_inv @ info_b
        prec_ = prec_aa - prec_ab @ prec_bb_inv @ prec_ba

        kept_dims = jnp.asarray([self.dims[idx] for idx in keep_rows])
        return Gaussian(info_, prec_, kept_dims)

In [4]:
def init_var2fac_msgs():
    """Build the initial variable-to-factor messages from the global ``states``.

    Every message is ``Gaussian.identity`` labelled with the index of the
    variable it comes from. Per agent there are two pose slots (first and last
    variable) and ``2 * (time_horizon - 1)`` dynamics slots ordered
    ``0, 1, 1, 2, 2, ..., time_horizon - 1``.
    """
    n_agents, n_steps = states.shape[0], states.shape[1]
    last_var = n_steps - 1
    # Maps an (agent, slot) grid of variable indices to a batched Gaussian.
    identity_grid = jax.vmap(jax.vmap(Gaussian.identity))

    pose_vars = jnp.repeat(jnp.array([[0, last_var]]), n_agents, axis=0)

    # Inner variables appear twice each, the two end variables once.
    inner_vars = jnp.repeat(jnp.arange(1, last_var), 2).reshape((1, -1))
    dynamics_row = jnp.concat(
        (jnp.array([[0]]), inner_vars, jnp.array([[last_var]])), axis=1
    )
    dynamics_vars = jnp.repeat(dynamics_row, n_agents, axis=0)

    return Var2FacMessages(
        poses=identity_grid(pose_vars), dynamics=identity_grid(dynamics_vars)
    )

var2fac_msgs = init_var2fac_msgs()
# Layout is (agent, message slot, label entry): 2 pose slots and
# 2 * (time_horizon - 1) = 6 dynamics slots per agent for 2 agents and 4 steps.
assert var2fac_msgs.poses.dims.shape == (2,2,4)
assert var2fac_msgs.dynamics.dims.shape == (2,6,4)

In [5]:
var2fac_msgs.dynamics.dims[0]

Array([[0, 0, 0, 0],
       [1, 1, 1, 1],
       [1, 1, 1, 1],
       [2, 2, 2, 2],
       [2, 2, 2, 2],
       [3, 3, 3, 3]], dtype=int32)

In [6]:
# Number of components in one state vector; sizes the factor precision matrices.
N_STATES = 4
# PoseFactor uses POSE_NOISE ** -2 on the diagonal of its precision (1e30 here).
POSE_NOISE = 1e-15
# Scale of the process covariance built in DynamicsFactor.
DYNAMICS_NOISE = 0.005
# Not referenced by the cells shown in this notebook.
OBSTACLE_NOISE = 0.005

In [7]:
from abc import abstractmethod

class Factor:
    """Base class turning a measurement function into a Gaussian likelihood.

    Subclasses implement ``_calc_measurement``; ``calculate_likelihood`` then
    builds the information-form Gaussian, either directly (``linear=True``) or
    through a first-order expansion around ``state`` (``linear=False``).
    The observed value is taken to be zero in both cases.

    NOTE: ``@abstractmethod`` is not enforced because the class does not use
    ``ABCMeta``; the hook raises ``NotImplementedError`` instead.
    """

    def __init__(
        self, state: jnp.array, state_precision: jnp.array, dims: jnp.array, linear: bool = True
    ) -> None:
        """Store the linearisation point, measurement precision and labels.

        Args:
            state: state vector handed to ``_calc_measurement``.
            state_precision: precision matrix of the measurement.
            dims: labels forwarded unchanged to the resulting ``Gaussian``.
            linear: whether the measurement function is treated as linear.
        """
        self._state = state
        self._state_precision = state_precision
        self._linear = linear
        self._dims = dims

    def calculate_likelihood(self) -> Gaussian:
        """Return the factor likelihood as a ``Gaussian`` over ``dims``."""
        return Gaussian(
            self._calc_info(self._state, self._state_precision),
            self._calc_precision(self._state, self._state_precision),
            self._dims
        )

    @abstractmethod
    def _calc_measurement(self, state: jnp.array) -> jnp.array:
        """Measurement function h(state); must be overridden by subclasses."""
        raise NotImplementedError

    def _calc_info(self, state: jnp.array, precision: jnp.array) -> jnp.array:
        """Information vector of the likelihood.

        Linear case: ``precision @ (0 - h(state))``.
        Non-linear case: ``J.T @ precision @ (J @ state + 0 - h(state))`` with
        ``J`` the Jacobian of ``h`` at ``state``.
        """
        measurement = self._calc_measurement(state)
        # The observed value is zero; size it from the measurement itself so
        # the residual always has the measurement's shape.
        observed = jnp.zeros(measurement.shape)
        if self._linear:
            return precision @ (observed - measurement)
        J = jax.jacfwd(self._calc_measurement)(state)
        return (J.T @ precision) @ (J @ state + observed - measurement)

    def _calc_precision(self, state: jnp.array, precision: jnp.array) -> jnp.array:
        """Precision of the likelihood: ``precision`` itself, or ``J.T @ precision @ J``."""
        if self._linear:
            return precision
        J = jax.jacfwd(self._calc_measurement)(state)
        return J.T @ precision @ J


class PoseFactor(Factor):
    """Factor whose measurement is the state itself, with a very tight precision."""

    def __init__(self, state: jnp.ndarray, dims: jnp.ndarray) -> None:
        """Create the factor with precision ``POSE_NOISE ** -2`` on the diagonal."""
        scale = jnp.pow(POSE_NOISE, -2)
        pose_precision = scale * jnp.eye(N_STATES)
        super().__init__(state, pose_precision, dims)

    def _calc_measurement(self, state: jnp.ndarray) -> jnp.ndarray:
        """Identity measurement: h(state) = state."""
        return state


class DynamicsFactor(Factor):
    """Factor tying two consecutive states together through a linear transition.

    ``state`` is the concatenation ``[previous_state, current_state]`` (each of
    length ``N_STATES``). The measurement is
    ``state_transition @ previous_state - current_state``.
    """

    def __init__(self, state: jnp.array, delta_t: float, dims: jnp.array) -> None:
        """Build the process precision and the transition matrix.

        Args:
            state: concatenated previous and current state.
            delta_t: time step between the two states.
            dims: labels forwarded to the resulting ``Gaussian``.
        """
        self.delta_t = delta_t
        # All block sizes derive from N_STATES so the covariance, the
        # transition matrix and the state slicing stay consistent.
        half = N_STATES // 2
        noise = DYNAMICS_NOISE * jnp.eye(half)
        top_half = jnp.hstack(
            (
                self.delta_t**3 * noise / 3,
                self.delta_t**2 * noise / 2,
            )
        )
        bottom_half = jnp.hstack(
            (
                self.delta_t**2 * noise / 2,
                self.delta_t * noise,
            )
        )
        process_covariance = jnp.vstack((top_half, bottom_half))
        precision = jnp.linalg.inv(process_covariance)

        # Identity, plus delta_t coupling from the second half of the state
        # into the first half.
        self.state_transition = jnp.eye(N_STATES)
        self.state_transition = self.state_transition.at[:half, half:].set(
            jnp.eye(half) * self.delta_t
        )

        super().__init__(state, precision, dims)

    def _calc_measurement(self, state: jnp.array) -> jnp.array:
        """Residual between the predicted and the actual current state."""
        prev_state = state[:N_STATES]
        current_state = state[N_STATES:]
        return self.state_transition @ prev_state - current_state

In [8]:
@dataclass
class Factors:
    """Per-factor-type likelihoods of the factor graph."""
    # NOTE(review): update_factor_likelihoods stores the result of
    # calculate_likelihood() (a Gaussian) in both fields, not factor objects.
    poses: PoseFactor
    dynamics: DynamicsFactor

In [9]:
# NOTE: the information vector comes out with the opposite sign to the input
# state, because Factor._calc_info computes eta = Lambda @ (0 - h(x));
# note the minus sign in front of h(x).
dims = jnp.array([0.0, 0.0, 0.0, 0.0])
pose_factor = PoseFactor(jnp.array([-5, 0, -0.5, 0]), dims).calculate_likelihood()
pose_factor.dims

Array([0., 0., 0., 0.], dtype=float32)

In [10]:
# Consistency check: inv(covariance) @ mean (i.e. precision @ mean) should equal info.
jnp.linalg.inv(pose_factor.covariance) @ pose_factor.mean, pose_factor.info

(Array([5.e+30, 0.e+00, 5.e+29, 0.e+00], dtype=float32),
 Array([5.e+30, 0.e+00, 5.e+29, 0.e+00], dtype=float32))

In [11]:
# Consistent pair: 5.1 + 0.2 * 0.5 == 5.2, so the dynamics residual (and info) is zero.
dyn_factor = DynamicsFactor(jnp.array([5.1, 0, 0.5, 0, 5.2, 0, 0.5, 0]), 0.2, jnp.array([0.0, 0.0, 0.0, 0.0])).calculate_likelihood()
dyn_factor.mean

Array([0., 0., 0., 0.], dtype=float32)

In [12]:
dyn_factor

Gaussian(info=Array([0., 0., 0., 0.], dtype=float32), precision=Array([[299999.84  ,      0.    , -29999.98  ,      0.    ],
       [     0.    , 299999.84  ,      0.    , -29999.98  ],
       [-29999.982 ,     -0.    ,   3999.9978,     -0.    ],
       [     0.    , -29999.982 ,      0.    ,   3999.9978]],      dtype=float32), dims=Array([0., 0., 0., 0.], dtype=float32))

In [13]:
# Inconsistent pair: the second state is at 5.5 instead of the predicted 5.2,
# so the residual is -0.3 in the first component and info is non-zero.
dyn_factor = DynamicsFactor(jnp.array([5.1, 0, 0.5, 0, 5.5, 0, 0.5, 0]), 0.2, jnp.zeros(4,)).calculate_likelihood()
dyn_factor.mean, dyn_factor

(Array([3.0000007e-01, 0.0000000e+00, 5.2619725e-07, 0.0000000e+00],      dtype=float32),
 Gaussian(info=Array([90000.01,     0.  , -9000.  ,     0.  ], dtype=float32), precision=Array([[299999.84  ,      0.    , -29999.98  ,      0.    ],
        [     0.    , 299999.84  ,      0.    , -29999.98  ],
        [-29999.982 ,     -0.    ,   3999.9978,     -0.    ],
        [     0.    , -29999.982 ,      0.    ,   3999.9978]],      dtype=float32), dims=Array([0., 0., 0., 0.], dtype=float32)))

# Message Passing

In [14]:
var2fac_msgs.poses.dims

Array([[[0, 0, 0, 0],
        [3, 3, 3, 3]],

       [[0, 0, 0, 0],
        [3, 3, 3, 3]]], dtype=int32)

In [15]:
# Re-declares the step size defined alongside state_transition; keep both values in sync.
delta_t = 0.2
def update_factor_likelihoods(states: jnp.array) -> Factors:
    """Evaluate every pose and dynamics factor likelihood for all agents.

    Args:
        states: array indexed ``[agent, step, component]``.

    Returns:
        ``Factors`` whose fields hold batched likelihoods: two pose factors per
        agent (first state and the agent's entry of the global
        ``target_states``) and ``time_horizon - 1`` dynamics factors, one per
        pair of consecutive steps.
    """
    n_steps = states.shape[1]

    def pose_likelihood(pose_state, pose_dim):
        return PoseFactor(pose_state, pose_dim).calculate_likelihood()

    def dynamics_likelihood(state_pair, pair_dim):
        return DynamicsFactor(state_pair, delta_t, pair_dim).calculate_likelihood()

    def batch_update_factor_likelihoods(agent_states, end_pos):
        # Pose factors anchor the first state and the target; shape [2, 4].
        anchor_states = jnp.stack((agent_states[0], end_pos))
        anchor_dims = jnp.stack([jnp.zeros(4,), jnp.ones(4,) * (n_steps - 1)])
        pose_gaussians = jax.vmap(pose_likelihood)(anchor_states, anchor_dims)

        # Dynamics factors link step t with step t + 1; shape [time_horizon - 1, 8].
        step_labels = jnp.arange(0, n_steps - 1).reshape(-1,1) * jnp.ones((n_steps - 1, 4))
        pair_dims = jnp.hstack((step_labels, step_labels + 1))
        state_pairs = jnp.hstack((agent_states[0:-1], agent_states[1:]))
        dynamics_gaussians = jax.vmap(dynamics_likelihood)(state_pairs, pair_dims)

        return Factors(pose_gaussians, dynamics_gaussians)

    return jax.vmap(batch_update_factor_likelihoods)(states, target_states)

factors = update_factor_likelihoods(states)

In [16]:
factors.dynamics.info.shape, factors.dynamics.dims

((2, 3, 4),
 Array([[[0., 0., 0., 0., 1., 1., 1., 1.],
         [1., 1., 1., 1., 2., 2., 2., 2.],
         [2., 2., 2., 2., 3., 3., 3., 3.]],
 
        [[0., 0., 0., 0., 1., 1., 1., 1.],
         [1., 1., 1., 1., 2., 2., 2., 2.],
         [2., 2., 2., 2., 3., 3., 3., 3.]]], dtype=float32))

In [17]:
# Positions of the first and last message slot.
outer_idx = jnp.array([0, -1])
def update_marginals(fac2var_msgs):
    """Sum the incoming messages of every variable into its marginal.

    The first and last variable each combine one pose message with one
    dynamics message; every inner variable combines the two neighbouring
    dynamics messages (slots 1+2, 3+4, ...). The returned ``Gaussian`` has
    ``dims=None``.
    """
    pose_msgs = fac2var_msgs.poses
    dynamics_msgs = fac2var_msgs.dynamics

    def combine(pose_field, dynamics_field):
        # End variables: pose message plus the outermost dynamics message.
        ends = pose_field[outer_idx] + dynamics_field[outer_idx]
        # Inner variables: the two dynamics messages that touch them.
        middle = dynamics_field[1:-1:2] + dynamics_field[2:-1:2]
        return jnp.concat((ends[0:1], middle, ends[-1:]))

    return Gaussian(
        info=combine(pose_msgs.info, dynamics_msgs.info),
        precision=combine(pose_msgs.precision, dynamics_msgs.precision),
        dims=None
    )

# Smoke test, vmapped over agents. NOTE(review): update_marginals names its
# argument fac2var_msgs but is fed the initial var2fac_msgs here.
marginals = jax.vmap(update_marginals)(var2fac_msgs)
marginals.info.shape, marginals.precision.shape, marginals

((2, 4, 4),
 (2, 4, 4, 4),
 Gaussian(info=Array([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],
 
        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], dtype=float32), precision=Array([[[[2., 0., 0., 0.],
          [0., 2., 0., 0.],
          [0., 0., 2., 0.],
          [0., 0., 0., 2.]],
 
         [[2., 0., 0., 0.],
          [0., 2., 0., 0.],
          [0., 0., 2., 0.],
          [0., 0., 0., 2.]],
 
         [[2., 0., 0., 0.],
          [0., 2., 0., 0.],
          [0., 0., 2., 0.],
          [0., 0., 0., 2.]],
 
         [[2., 0., 0., 0.],
          [0., 2., 0., 0.],
          [0., 0., 2., 0.],
          [0., 0., 0., 2.]]],
 
 
        [[[2., 0., 0., 0.],
          [0., 2., 0., 0.],
          [0., 0., 2., 0.],
          [0., 0., 0., 2.]],
 
         [[2., 0., 0., 0.],
          [0., 2., 0., 0.],
          [0., 0., 2., 0.],
          [0., 0., 0., 2.]],
 
         [[2., 0., 0

In [18]:
# n_agents = 2
# time_horizon = 4
# info is a 4-vector, precision is 4x4 matrix

# there are only 2 pose factors per agent
# and time_horizon - 1 dynamics factors per agent
factors.dynamics.precision.shape

(2, 3, 4, 4)

In [35]:
var2fac_msgs.dynamics.info.shape

(2, 6, 4)

In [20]:
def update_factor_to_var_messages(
    var2fac_msgs: Var2FacMessages,
    factors: Factors
) -> Fac2VarMessages:
    """Compute factor-to-variable messages for every agent (work in progress).

    The per-agent helper is vmapped over the leading axis of both arguments.

    NOTE(review): the incoming variable-to-factor messages are read but not
    used yet, and the ``dynamics`` slot of the result is filled with the pose
    likelihoods as a placeholder.
    """
    def batched_update_factor_to_var_messages(agent_var2fac_msgs: Var2FacMessages, factors: Factors) -> Fac2VarMessages:
        # Both locals below are currently unused.
        poses = agent_var2fac_msgs.poses
        dynamics = agent_var2fac_msgs.dynamics

        # Pose messages are the pose likelihoods passed through unchanged.
        updated_poses = factors.poses 
        # Placeholder: the dynamics messages are not computed yet.
        updated_dynamics = None # poses 
        # inner_dynamics = dynamics[jnp.array([2,1,4,3])]
        # updated_dynamics = jax.tree_util.tree_map(lambda x, y, z: jnp.concatenate((x, y, z)), outer_dynamics[0:1], inner_dynamics, outer_dynamics[1:])
        # NOTE(review): updated_poses is returned in BOTH slots; updated_dynamics is ignored.
        return Fac2VarMessages(updated_poses,updated_poses)
    return jax.vmap(batched_update_factor_to_var_messages)(var2fac_msgs, factors)
    # The outer vmap runs over the agents, so the batched function handles
    # one agent's factor graph at a time.

updated_fac2_var_msgs = update_factor_to_var_messages(var2fac_msgs, factors)

In [None]:
var2fac_msgs.dynamics.shape

In [None]:
factors.dynamics

In [None]:
# NOTE(review): unexecuted scratch cell. DynamicsFactor.__init__ requires
# (state, delta_t, dims), so this one-argument call would raise a TypeError;
# factors.dynamics already holds the result of calculate_likelihood().
jax.vmap(lambda f: DynamicsFactor(f).calculate_likelihood())(factors.dynamics)

In [23]:
updated_fac2_var_msgs.poses[0][0].shape

{'info': (4,), 'precision': (4, 4), 'dims': (4,)}

In [77]:
def update_var_to_factor_messages(
    fac2var_msgs: Fac2VarMessages
) -> Var2FacMessages:
    """Route factor-to-variable messages back out as variable-to-factor messages.

    For each agent (the helper is vmapped over the leading axis):

    * the pose slots receive the first and last dynamics message;
    * the first and last dynamics slot receive the two pose messages;
    * every adjacent pair of inner dynamics slots is swapped, so slot order
      ``1, 2, 3, 4, ...`` becomes ``2, 1, 4, 3, ...``.

    The swap permutation is derived from the number of dynamics messages
    instead of being hard-coded for ``time_horizon == 4``.
    """
    def batched_update_var_to_factor_messages(agent_fac2var_msgs: Fac2VarMessages):
        poses = agent_fac2var_msgs.poses
        dynamics = agent_fac2var_msgs.dynamics

        updated_poses = dynamics[outer_idx] # jnp.array([0, -1])

        # Inner slots come in pairs (one pair per inner variable); reverse
        # each pair. Shapes are static, so this also works under vmap.
        n_dynamics = dynamics.info.shape[0]
        n_inner_pairs = (n_dynamics - 2) // 2
        swapped_inner = (
            jnp.arange(1, n_dynamics - 1).reshape(n_inner_pairs, 2)[:, ::-1].reshape(-1)
        )
        inner_dynamics = dynamics[swapped_inner]

        updated_dynamics = jax.tree_util.tree_map(
            lambda first, inner, last: jnp.concatenate((first, inner, last)),
            poses[0:1], inner_dynamics, poses[1:],
        )
        return Var2FacMessages(updated_poses, updated_dynamics)

    return jax.vmap(batched_update_var_to_factor_messages)(fac2var_msgs)

# TODO: pass real factor-to-variable messages here; var2fac_msgs is only a
# stand-in with the same structure.
updated_var2fac_msgs = update_var_to_factor_messages(var2fac_msgs) # TODO: change to fac2var messages for args
updated_var2fac_msgs.poses.shape, updated_var2fac_msgs.dynamics.shape

{'dims': (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), 'info': (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), 'precision': (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True))}
{'dims': (Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), 'info': (Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), 'precision': (Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True))}
{'dims': (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), 'info': (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True)), 'precision': (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True))}


({'info': (2, 2, 4), 'precision': (2, 2, 4, 4), 'dims': (2, 2, 4)},
 {'info': (2, 6, 4), 'precision': (2, 6, 4, 4), 'dims': (2, 6, 4)})

In [36]:
var2fac_msgs.dynamics.shape

{'info': (2, 6, 4), 'precision': (2, 6, 4, 4), 'dims': (2, 6, 4)}

In [50]:
var2fac_msgs.dynamics[0].shape

{'info': (6, 4), 'precision': (6, 4, 4), 'dims': (6, 4)}

In [66]:
@dataclass
class Metric:
    """Toy three-field dataclass used below to try out ``jax.tree_util.tree_map``."""
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

# Two identical instances to combine field by field.
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

# tree_map applies the function to matching fields of both dataclasses.
jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y]), a, b)

Metric(score1=Array([10, 10, 10, 10, 10, 10], dtype=int32), score2=Array([20, 20, 20, 20, 20, 20], dtype=int32), score3=Array([30, 30, 30, 30, 30, 30], dtype=int32))

In [78]:
# NOTE: a and b are rebound here from Metric instances to plain lists.
a = [1,2,3,4,5]
b = [4,5,6,7,8]

# On plain lists tree_map pairs the elements position by position.
jax.tree_util.tree_map(lambda x, y: x * y, a, b)

[4, 10, 18, 28, 40]