In [1]:
from sys import path
path.append("../")

In [2]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass 

from abc import abstractmethod

In [3]:
from fg import Gaussian

# Inter Robot

In [4]:
x = jnp.array([
    [1,1],
    [2,2],
    [3,3],
    [4,4]
])
states = jnp.stack((x, x + 0.1, x + 0.3))
vels = jnp.zeros_like(states)
states = jnp.concat((states, vels), axis=-1)
states, states.shape

(Array([[[1. , 1. , 0. , 0. ],
         [2. , 2. , 0. , 0. ],
         [3. , 3. , 0. , 0. ],
         [4. , 4. , 0. , 0. ]],
 
        [[1.1, 1.1, 0. , 0. ],
         [2.1, 2.1, 0. , 0. ],
         [3.1, 3.1, 0. , 0. ],
         [4.1, 4.1, 0. , 0. ]],
 
        [[1.3, 1.3, 0. , 0. ],
         [2.3, 2.3, 0. , 0. ],
         [3.3, 3.3, 0. , 0. ],
         [4.3, 4.3, 0. , 0. ]]], dtype=float32, weak_type=True),
 (3, 4, 4))

In [5]:
def find_closest_robot(states: jnp.ndarray):
    def find_closest_robot_across_horizon(robot, other_robots):
        closest_index = jnp.argmin(jnp.linalg.norm(robot[0:2] - other_robots[:,0:2], axis=1))
        return closest_index

    def find_batched_closest_robot(batch_states, i):
        modified_states = batch_states.at[i].set(jnp.inf)
        return jax.vmap(find_closest_robot_across_horizon, in_axes=(0, 1))(batch_states[i], modified_states)
    return jax.vmap(find_batched_closest_robot, in_axes=(None, 0))(states, jnp.arange(states.shape[0]))
closest_robots = find_closest_robot(states) # ideally, should be # (N - 1, 4, 2)
closest_robots

Array([[1, 1, 1, 1],
       [0, 0, 0, 0],
       [1, 1, 1, 1]], dtype=int32)

In [6]:
N_STATES = 4
POSE_NOISE = 1e-15
DYNAMICS_NOISE = 0.005
OBSTACLE_NOISE = 0.005
INTER_ROBOT_NOISE = 0.005

class Factor:
    def __init__(
        self, state: jnp.array, state_precision: jnp.ndarray, dims: jnp.array, linear: bool = True
    ) -> None:
        self._state = state
        self._state_precision = state_precision
        self._linear = linear
        self._dims = dims

    def calculate_likelihood(self) -> Gaussian:
        return Gaussian(
            self._calc_info(self._state, self._state_precision),
            self._calc_precision(self._state, self._state_precision),
            self._dims
        )

    @abstractmethod
    def _calc_measurement(self, state: jnp.ndarray) -> jnp.ndarray:
        pass

    def _calc_info(self, state: jnp.ndarray, precision: jnp.ndarray) -> jnp.ndarray:
        X = state
        if self._linear:
            eta = precision @ (jnp.zeros(N_STATES) - self._calc_measurement(state))
        else:
            J = jax.jacfwd(self._calc_measurement)(state)
            eta = (J.T @ precision) @ (
                (J @ X.reshape((-1,1))) + 0 - self._calc_measurement(state).reshape((-1,1))
            )
        return eta.squeeze()

    def _calc_precision(self, state: jnp.ndarray, precision: jnp.ndarray) -> jnp.ndarray:
        if self._linear:
            return precision
        else:
            J = jax.jacfwd(self._calc_measurement)(state)
            return J.T @ precision @ J

class InterRobotFactor:
    def __init__(
        self,
        state: jnp.ndarray,
        critical_distance: float,
        t: jnp.ndarray, #ndarray just to hold time,
        dims: jnp.ndarray,
    ) -> None:
        self._crit_distance = critical_distance
        self._z_precision = 100

        self._dist = self._calc_dist(state[0:4], state[4:])
        dx, dy = (state[0] - state[4])/self._dist, (state[1] - state[5])/self._dist
        self._J = jnp.array([[-dx/self._crit_distance, -dy/self._crit_distance, 0, 0,
                              dx/self._crit_distance, dy/self._crit_distance, 0, 0]])

        self._state = state
        # self._precision = jnp.pow(t * INTER_ROBOT_NOISE, -2) * jnp.eye(N_STATES)
        self._state_precision = self._z_precision * jnp.eye(1)
        self._dims = dims
    
    def calculate_likelihood(self) -> Gaussian:
        return Gaussian(
            self._calc_info(self._state, self._state_precision),
            self._calc_precision(self._state, self._state_precision),
            self._dims
        )
 
    def _calc_dist(self, state: jnp.array, other_state: jnp.array):
        return jnp.linalg.norm(state[0:2] - other_state[0:2]) 
    
    def _calc_info(self, state: jnp.ndarray, state_precision: jnp.ndarray) -> jnp.ndarray:
        def safe_fn():
            return (state_precision @ state[jnp.newaxis,:]).squeeze()
        def unsafe_fn():
            return (self._J.T @ state_precision @ (self._J @ state[:,jnp.newaxis] - self._calc_measurement(state))).squeeze()
        info = jax.lax.select(self._dist >= self._crit_distance, safe_fn(), unsafe_fn())
        return info
        
    def _calc_precision(self, state: jnp.ndarray, state_precision: jnp.ndarray) -> jnp.ndarray:
        unsafe_precision = self._J.T @ state_precision @ self._J 
        unsafe_precision = unsafe_precision.at[2:4,2:4].set(jnp.eye(2)).at[6:,6:].set(jnp.eye(2))

        # Update A
        unsafe_precision = unsafe_precision.at[:2,:2].set(unsafe_precision[0, 0])

        # Update B
        unsafe_precision = unsafe_precision.at[4:6,4:6].set(unsafe_precision[4,4])

        # Update C
        unsafe_precision = unsafe_precision.at[0:2, 4:6].set(unsafe_precision[0, 4])

        # Update D
        unsafe_precision = unsafe_precision.at[4:6, 0:2].set(unsafe_precision[4, 0])

        precision = jax.lax.select(self._dist >= self._crit_distance, jnp.eye(8), unsafe_precision)
        return precision @ precision.T

    def _calc_measurement(self, state: jnp.ndarray):
        current_state = state[0:4]
        other_state = state[4:]
        dist = self._calc_dist(current_state, other_state)
        measurement = jax.lax.select(
            dist < self._crit_distance, 1.0 - dist / self._crit_distance, 0.
        )
        return measurement

In [7]:
ir_factor =  InterRobotFactor(jnp.array([-5.0, 0., 0.5, 0.0, 5.0, 0.0, -0.5, 0.0]), 1.5, jnp.array([1.]), jnp.array([1., 1., 1., 1., 100., 100., 100., 100.]))
g = ir_factor.calculate_likelihood()
g

Gaussian(info=Array([-500.,    0.,   50.,    0.,  500.,    0.,  -50.,    0.], dtype=float32), precision=Array([[1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32), dims=Array([  1.,   1.,   1.,   1., 100., 100., 100., 100.], dtype=float32))

In [8]:
g = g.marginalize(jnp.array([1., 1., 1., 1.]))
g

Gaussian(info=Array([500.,   0., -50.,   0.], dtype=float32), precision=Array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32), dims=Array([100., 100., 100., 100.], dtype=float32))

In [9]:
jnp.linalg.inv(g.precision)

Array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)

In [10]:
ir_factor =  InterRobotFactor(jnp.array([-0.1, 0., 0.1, 0.0, 0.1, 0.0, -0.1, 0.0]), 1.5, jnp.array([1.]), jnp.array([1., 1., 1., 1., 100., 100., 100., 100.]))
g = ir_factor.calculate_likelihood()
g

Gaussian(info=Array([-66.66667,   0.     ,  -0.     ,  -0.     ,  66.66667,  -0.     ,
        -0.     ,  -0.     ], dtype=float32), precision=Array([[ 7.901237e+03,  7.901237e+03,  0.000000e+00,  0.000000e+00,
        -7.901237e+03, -7.901237e+03,  0.000000e+00,  0.000000e+00],
       [ 7.901237e+03,  7.901237e+03,  0.000000e+00,  0.000000e+00,
        -7.901237e+03, -7.901237e+03,  0.000000e+00,  0.000000e+00],
       [ 0.000000e+00,  0.000000e+00,  1.000000e+00,  0.000000e+00,
         0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
       [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  1.000000e+00,
         0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
       [-7.901237e+03, -7.901237e+03,  0.000000e+00,  0.000000e+00,
         7.901237e+03,  7.901237e+03,  0.000000e+00,  0.000000e+00],
       [-7.901237e+03, -7.901237e+03,  0.000000e+00,  0.000000e+00,
         7.901237e+03,  7.901237e+03,  0.000000e+00,  0.000000e+00],
       [ 0.000000e+00,  0.000000e+0

In [11]:
g = g.marginalize(jnp.array([1., 1., 1., 1.]))
g

Gaussian(info=Array([66.66667, -0.     , -0.     , -0.     ], dtype=float32), precision=Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32), dims=Array([100., 100., 100., 100.], dtype=float32))

In [12]:
jnp.linalg.inv(g.precision)

Array([[nan, nan, nan, nan],
       [nan, nan, nan, nan],
       [ 0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  1.]], dtype=float32)

# Obstacle

In [13]:
obstacle = jnp.array([
    [4.2, 4.2],
    [3.0, 3.0],
    [1.0, 1.0]
])

def find_closest_obstacle(states, obstacles):
    def find_closest_obstacle_for_pt(state_t, obstacles):
        closest_obstacle_idx = jnp.argmin(jnp.linalg.norm(state_t[0:2] - obstacles, axis=1))
        return obstacles[closest_obstacle_idx]
    def batch_find_closest_obstacle(agent_states, obstacles):
        return jax.vmap(find_closest_obstacle_for_pt, in_axes=(0, None))(agent_states, obstacles)
    return jax.vmap(batch_find_closest_obstacle, in_axes=(0, None))(states, obstacles)
closest_obstacle = find_closest_obstacle(states, obstacle)
closest_obstacle, closest_obstacle.shape

(Array([[[1. , 1. ],
         [3. , 3. ],
         [3. , 3. ],
         [4.2, 4.2]],
 
        [[1. , 1. ],
         [3. , 3. ],
         [3. , 3. ],
         [4.2, 4.2]],
 
        [[1. , 1. ],
         [3. , 3. ],
         [3. , 3. ],
         [4.2, 4.2]]], dtype=float32),
 (3, 4, 2))

In [14]:
states

Array([[[1. , 1. , 0. , 0. ],
        [2. , 2. , 0. , 0. ],
        [3. , 3. , 0. , 0. ],
        [4. , 4. , 0. , 0. ]],

       [[1.1, 1.1, 0. , 0. ],
        [2.1, 2.1, 0. , 0. ],
        [3.1, 3.1, 0. , 0. ],
        [4.1, 4.1, 0. , 0. ]],

       [[1.3, 1.3, 0. , 0. ],
        [2.3, 2.3, 0. , 0. ],
        [3.3, 3.3, 0. , 0. ],
        [4.3, 4.3, 0. , 0. ]]], dtype=float32, weak_type=True)

In [15]:
@dataclass
class Gaussian:
    info: jnp.ndarray
    precision: jnp.ndarray
    dims: jnp.ndarray 

    @property
    def shape(self):
        return {
            "info": self.info.shape,
            "precision": self.precision.shape,
            "dims": self.dims.shape
        }
 
    @property
    def mean(self) -> jnp.ndarray:
        return jnp.linalg.inv(self.precision) @ self.info
    
    @property
    def covariance(self) -> jnp.ndarray:
        return jnp.linalg.inv(self.precision)
    
    @staticmethod
    def identity(variable: int) -> jnp.ndarray:
        dims = jnp.array([variable, variable, variable, variable])
        return Gaussian(jnp.zeros(4), jnp.eye(4), dims)
    
    def __call__(self, x: jnp.ndarray) -> float:
        return 1/2 * x.T @ self.precision @ x - self.info.T @ x

    def __getitem__(self, index) -> "Gaussian":
        return Gaussian(self.info[index], self.precision[index], self.dims[index])

    def __mul__(self, other: 'Gaussian') -> 'Gaussian':
        if other is None:
            return self.copy()
        
        dims = self.dims.copy()
        other_dims = other.dims.copy()
        other_unique_val = jnp.unique(other_dims, size=other_dims.shape[0]//4)[0]
        
        if dims.shape[0] == other_dims.shape[0] and dims[0] == other_dims[0]:
            idxs_self = idxs_other = jnp.arange(len(dims))
        elif dims.shape[0] == 8:
            idxs_self = jnp.arange(len(dims), dtype=int)
            idxs_other = jnp.where(dims == other_unique_val, size=4)[0]
        else:
            idxs_self = jnp.arange(len(dims), dtype=int)
            idxs_other = jnp.arange(len(dims), len(dims) + len(other_dims)) # jax.lax.select((dims == other_dims).sum() == dims.shape[0], jnp.arange(len(dims)), jnp.arange(len(dims), len(dims) + len(other_dims)))
            dims = jnp.concatenate((dims, other_dims))

            # if idxs_other[-1] != idxs_self[-1]:
            #     dims = jnp.concatenate((dims, other_dims))
            
        
        # Extend self matrix
        prec_self = jnp.zeros((len(dims), len(dims)))
        info_self = jnp.zeros((len(dims), 1))
        
        prec_self = prec_self.at[jnp.ix_(idxs_self, idxs_self)].set(self.precision)
        info_self = info_self.at[jnp.ix_(idxs_self, jnp.zeros(1, dtype=int))].set(self.info.reshape(len(self.dims), 1))
        # Extend other matrix
        prec_other = jnp.zeros((len(dims), len(dims)))
        info_other = jnp.zeros((len(dims), 1))
        
        prec_other = prec_other.at[jnp.ix_(idxs_other, idxs_other)].set(other.precision)
        info_other = info_other.at[jnp.ix_(idxs_other, jnp.zeros((1), dtype=int))].set(other.info.reshape(len(other.dims), 1))
        # Add
        prec = prec_other + prec_self
        info = info_other + info_self
        return Gaussian(info, prec, dims.astype(float))
    
    def mul(self, other: 'Gaussian', combine_dims: bool) -> 'Gaussian':
        if other is None:
            return self.copy()
        
        dims = self.dims.copy()
        other_dims = other.dims.copy()
        other_unique_val = jnp.unique(other_dims, size=other_dims.shape[0]//4)[0]
        
        if dims.shape[0] == 8:
            idxs_self = jnp.arange(len(dims), dtype=int)
            idxs_other = jnp.where(dims == other_unique_val, size=4)[0]
        else:
            idxs_self = jnp.arange(len(dims), dtype=int)
            idxs_other = jax.lax.select((dims == other_dims).sum() == dims.shape[0], jnp.arange(len(dims)), jnp.arange(len(dims), len(dims) + len(other_dims)))
            
            if combine_dims:
                dims = jnp.concatenate((dims, other_dims))

        # Extend self matrix
        prec_self = jnp.zeros((len(dims), len(dims)))
        info_self = jnp.zeros((len(dims), 1))
        
        prec_self = prec_self.at[jnp.ix_(idxs_self, idxs_self)].set(self.precision)
        info_self = info_self.at[jnp.ix_(idxs_self, jnp.zeros(1, dtype=int))].set(self.info.reshape(len(self.dims), 1))
        # Extend other matrix
        prec_other = jnp.zeros((len(dims), len(dims)))
        info_other = jnp.zeros((len(dims), 1))
        
        prec_other = prec_other.at[jnp.ix_(idxs_other, idxs_other)].set(other.precision)
        info_other = info_other.at[jnp.ix_(idxs_other, jnp.zeros((1), dtype=int))].set(other.info.reshape(len(other.dims), 1))
        # Add
        prec = prec_other + prec_self
        info = info_other + info_self
        return Gaussian(info.squeeze(), prec, dims.astype(float))
    
    def marginalize(self, dims_to_remove) -> "Gaussian":
        info, prec = self.info.reshape(-1, 1), self.precision
        
        axis_a = jnp.where(self.dims != dims_to_remove[0], size=4)[0]
        axis_b = jax.lax.select(axis_a[0] == 0, jnp.arange(4, 8), jnp.arange(4))

        info_a = info[jnp.ix_(axis_a, jnp.zeros((1,), dtype=int))]
        prec_aa = prec[jnp.ix_(axis_a, axis_a)]
        info_b = info[jnp.ix_(axis_b, jnp.zeros((1, ), dtype=int))]
        prec_ab = prec[jnp.ix_(axis_a, axis_b)]
        prec_ba = prec[jnp.ix_(axis_b, axis_a)]
        prec_bb = prec[jnp.ix_(axis_b, axis_b)]

        # diag_values = jnp.diag(prec_bb)
        # prec_bb = jnp.fill_diagonal(prec_bb, jnp.where(diag_values == 0, diag_values[0], diag_values), inplace=False)

        prec_bb_inv = jnp.linalg.inv(prec_bb)
        info_ = (info_a - prec_ab @ prec_bb_inv @ info_b).squeeze()
        prec_ = (prec_aa - prec_ab @ prec_bb_inv @ prec_ba).squeeze()

        new_dims = jnp.full(4, self.dims[axis_a[0]])

        info = jax.lax.select(jnp.logical_not(jnp.logical_and(jnp.all(jnp.isfinite(info_)), jnp.all(jnp.isfinite(prec_)))), jnp.zeros_like(info_), info_)
        prec = jax.lax.select(jnp.logical_not(jnp.logical_and(jnp.all(jnp.isfinite(info_)), jnp.all(jnp.isfinite(prec_)))), jnp.zeros_like(prec_), prec_)

        return Gaussian(info, prec, new_dims)

In [33]:
class ObstacleFactor:
    def __init__(
        self, state: jnp.ndarray, closest_obstacle: jnp.ndarray, crit_distance: float, agent_radius: float, dims: jnp.ndarray
    ) -> None:
        self._state = state
        self._closest_obstacle = closest_obstacle
        self._crit_distance = crit_distance
        self._agent_radius = agent_radius
        self._state_precision = OBSTACLE_NOISE ** (-2) * jnp.eye(1)
        self._dims = dims
        self._gap_multiplier = 1

        dist, dx, dy = self._calc_dist(state, closest_obstacle) 
        def safe_fn():
            return jnp.zeros((1,4))
        def unsafe_fn():
            return jnp.array([[-dx/crit_distance, -dy/crit_distance, 0, 0]])
        self._J = jax.lax.select(dist >= self._crit_distance, safe_fn(), unsafe_fn())
    
    def calculate_likelihood(self) -> Gaussian:
        return Gaussian(
            self._calc_info(self._state, self._state_precision),
            self._calc_precision(self._state_precision),
            self._dims
        )

    def _calc_info(self, state: jnp.ndarray, state_precision: jnp.ndarray):
        info = self._gap_multiplier * (self._J.T @ state_precision @ (self._J @ state[:,jnp.newaxis] - self._calc_measurement(state))).squeeze()
        return info

    def _calc_precision(self, state_precision: jnp.ndarray):
        precision = self._J.T @ state_precision @ self._J
        return precision

    def _calc_measurement(self, state):
        dist = self._calc_dist(state, self._closest_obstacle)[0]
        return jax.lax.select(dist < self._agent_radius, 1 - dist / self._agent_radius, 0.)

    def _calc_dist(self, state, other_state):
        dist = jnp.linalg.norm(state[0:2] - other_state[0:2]) - self._agent_radius
        return dist, state[0] - other_state[0], state[1] - other_state[1]

In [34]:
obs_factor = ObstacleFactor(jnp.array([1.0, 2.0, 0.0, 0.0]), jnp.array([1.0, 3.0, 0.0, 0.0]), 1.5, 0.2, jnp.array([1.0, 1.0, 1.0, 1.0]))
obs_factor.calculate_likelihood()

0.8
0.0 -1.0
[1. 2.] [1. 3.]
0.8


Gaussian(info=Array([   -0.  , 35555.56,     0.  ,     0.  ], dtype=float32), precision=Array([[    0.  ,    -0.  ,    -0.  ,    -0.  ],
       [   -0.  , 17777.78,     0.  ,     0.  ],
       [   -0.  ,     0.  ,     0.  ,     0.  ],
       [   -0.  ,     0.  ,     0.  ,     0.  ]], dtype=float32), dims=Array([1., 1., 1., 1.], dtype=float32))

# Random

In [None]:
x = jnp.array([[1,2,3,4,5,6], [7,8,9,10,11,12]])
y = x * 2
x + y

Array([[ 3,  6,  9, 12, 15, 18],
       [21, 24, 27, 30, 33, 36]], dtype=int32)

In [None]:
x = jnp.arange(32).reshape((8,4))
x[1:-1:2] + x[2:-1:2]

Array([[12, 14, 16, 18],
       [28, 30, 32, 34],
       [44, 46, 48, 50]], dtype=int32)

In [None]:
x[1:-1:2], x[2:-1:2]

(Array([[ 4,  5,  6,  7],
        [12, 13, 14, 15],
        [20, 21, 22, 23]], dtype=int32),
 Array([[ 8,  9, 10, 11],
        [16, 17, 18, 19],
        [24, 25, 26, 27]], dtype=int32))

In [None]:
x = jnp.eye(8)
A = x[0:2,0:2]
B = x[0:2,4:6]
C = x[4:6,0:2]
D = x[4:6,4:6]
x = x.at[0:2,0:2].set(jnp.where(A == 0, 1.0, A))
x = x.at[0:2,4:6].set(jnp.where(B == 0, 1.0, B))
x = x.at[4:6,0:2].set(jnp.where(C == 0, 1.0, C))
x = x.at[4:6,4:6].set(jnp.where(D == 0, 1.0, D))
x, jnp.linalg.inv(x)

(Array([[1., 1., 0., 0., 1., 1., 0., 0.],
        [1., 1., 0., 0., 1., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 1., 1., 0., 0.],
        [1., 1., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32),
 Array([[nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]], dtype=float32))