In [1]:
from sys import path
path.append("../")

In [2]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass 

from abc import abstractmethod

In [3]:
from fg import Gaussian

# Inter Robot

In [4]:
x = jnp.array([
    [1,1],
    [2,2],
    [3,3],
    [4,4]
])
states = jnp.stack((x, x + 0.1, x + 0.3))
states, states.shape

(Array([[[1. , 1. ],
         [2. , 2. ],
         [3. , 3. ],
         [4. , 4. ]],
 
        [[1.1, 1.1],
         [2.1, 2.1],
         [3.1, 3.1],
         [4.1, 4.1]],
 
        [[1.3, 1.3],
         [2.3, 2.3],
         [3.3, 3.3],
         [4.3, 4.3]]], dtype=float32, weak_type=True),
 (3, 4, 2))

In [5]:
@dataclass
class InterRobotMapping:
    points: jnp.ndarray
    other_robot: int

In [6]:
def find_closest_robot(pts: jnp.ndarray):
    def find_closest_pt_across_horizon(pt, other_pts):
        jax.debug.print("other pts: {}", other_pts)
        closest_index = jnp.argmin(jnp.linalg.norm(pt - other_pts, axis=1))
        return InterRobotMapping(other_pts[closest_index], closest_index)

    def find_batched_closest_pt(pts, i):
        other_pts = jnp.delete(pts, jnp.array([i]), assume_unique_indices=True, axis=0) # (N - 1, 4, 2)
        return jax.vmap(find_closest_pt_across_horizon, in_axes=(0, 1))(pts[i], other_pts)
    return jax.vmap(find_batched_closest_pt, in_axes=(None, 0))(pts, jnp.arange(pts.shape[0]))
closest_robots = find_closest_robot(states[:,:,0:2]) # ideally, should be # (N - 1, 4, 2)
closest_robots

other pts: [[1.1 1.1]
 [1.3 1.3]]
other pts: [[1.  1. ]
 [1.3 1.3]]
other pts: [[1.  1. ]
 [1.1 1.1]]
other pts: [[2.1 2.1]
 [2.3 2.3]]
other pts: [[2.  2. ]
 [2.3 2.3]]
other pts: [[2.  2. ]
 [2.1 2.1]]
other pts: [[3.1 3.1]
 [3.3 3.3]]
other pts: [[3.  3. ]
 [3.3 3.3]]
other pts: [[3.  3. ]
 [3.1 3.1]]
other pts: [[4.1 4.1]
 [4.3 4.3]]
other pts: [[4.  4. ]
 [4.3 4.3]]
other pts: [[4.  4. ]
 [4.1 4.1]]


InterRobotMapping(points=Array([[[1.1, 1.1],
        [2.1, 2.1],
        [3.1, 3.1],
        [4.1, 4.1]],

       [[1. , 1. ],
        [2. , 2. ],
        [3. , 3. ],
        [4. , 4. ]],

       [[1.1, 1.1],
        [2.1, 2.1],
        [3.1, 3.1],
        [4.1, 4.1]]], dtype=float32, weak_type=True), other_robot=Array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [1, 1, 1, 1]], dtype=int32))

In [24]:
N_STATES = 4
POSE_NOISE = 1e-15
DYNAMICS_NOISE = 0.005
OBSTACLE_NOISE = 0.005
INTER_ROBOT_NOISE = 0.005

class Factor:
    def __init__(
        self, state: jnp.array, state_precision: jnp.ndarray, dims: jnp.array, linear: bool = True
    ) -> None:
        self._state = state
        self._state_precision = state_precision
        self._linear = linear
        self._dims = dims

    def calculate_likelihood(self) -> Gaussian:
        return Gaussian(
            self._calc_info(self._state, self._state_precision),
            self._calc_precision(self._state, self._state_precision),
            self._dims
        )

    @abstractmethod
    def _calc_measurement(self, state: jnp.ndarray) -> jnp.ndarray:
        pass

    def _calc_info(self, state: jnp.ndarray, precision: jnp.ndarray) -> jnp.ndarray:
        X = state
        if self._linear:
            eta = precision @ (jnp.zeros(N_STATES) - self._calc_measurement(state))
        else:
            J = jax.jacfwd(self._calc_measurement)(state)
            eta = (J.T @ precision) @ (
                (J @ X.reshape((-1,1))) + 0 - self._calc_measurement(state).reshape((-1,1))
            )
        return eta.squeeze()

    def _calc_precision(self, state: jnp.ndarray, precision: jnp.ndarray) -> jnp.ndarray:
        if self._linear:
            return precision
        else:
            J = jax.jacfwd(self._calc_measurement)(state)
            return J.T @ precision @ J

class InterRobotFactor(Factor):
    def __init__(
        self,
        state: jnp.ndarray,
        agent_radius: float,
        critical_distance: float,
        t: jnp.ndarray, #ndarray just to hold time,
        dims: jnp.ndarray,
    ) -> None:
        self._critical_distance = critical_distance
        self._agent_radius = agent_radius
        precision = jnp.pow(t * INTER_ROBOT_NOISE, -2) * jnp.eye(N_STATES)
        super(InterRobotFactor, self).__init__(state, precision, dims, False)

    def _calc_measurement(self, state: jnp.ndarray):
        current_state = state[0:4]
        other_state = state[4:]
        dist = self._calc_dist(current_state, other_state)
        measurement = jax.lax.select(
            dist < self._critical_distance, jnp.full((4,), 1.0 - dist / self._critical_distance), jnp.zeros((4,)) 
        )
        return measurement
    
    def _calc_dist(self, state: jnp.array, other_state: jnp.array):
        return jnp.linalg.norm(state[0:2] - other_state[0:2])

ir_factor =  InterRobotFactor(jnp.array([1.0, 2.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0]), 0.2, 1.5, jnp.array([1.]), jnp.array([1., 1., 1., 1., 100., 100., 100., 100.]))
ir_factor.calculate_likelihood()

Gaussian(info=Array([      0.   , -106666.664,       0.   ,       0.   ,       0.   ,
        106666.664,       0.   ,       0.   ], dtype=float32), precision=Array([[     0.  ,      0.  ,      0.  ,      0.  ,      0.  ,      0.  ,
             0.  ,      0.  ],
       [     0.  ,  71111.12,      0.  ,      0.  ,      0.  , -71111.12,
             0.  ,      0.  ],
       [     0.  ,      0.  ,      0.  ,      0.  ,      0.  ,      0.  ,
             0.  ,      0.  ],
       [     0.  ,      0.  ,      0.  ,      0.  ,      0.  ,      0.  ,
             0.  ,      0.  ],
       [     0.  ,      0.  ,      0.  ,      0.  ,      0.  ,      0.  ,
             0.  ,      0.  ],
       [     0.  , -71111.12,      0.  ,      0.  ,      0.  ,  71111.12,
             0.  ,      0.  ],
       [     0.  ,      0.  ,      0.  ,      0.  ,      0.  ,      0.  ,
             0.  ,      0.  ],
       [     0.  ,      0.  ,      0.  ,      0.  ,      0.  ,      0.  ,
             0.  ,      0.  ]], 

# Obstacle

In [5]:
obstacle = jnp.array([
    [4.2, 4.2],
    [1.0, 1.0]
])

def find_closest_obstacle(states, obstacles):
    def find_closest_obstacle_for_pt(state_t, obstacles):
        closest_obstacle_idx = jnp.argmin(jnp.linalg.norm(state_t - obstacles, axis=1))
        return obstacles[closest_obstacle_idx]
    def batch_find_closest_obstacle(agent_states, obstacles):
        return jax.vmap(find_closest_obstacle_for_pt, in_axes=(0, None))(agent_states, obstacles)
    return jax.vmap(batch_find_closest_obstacle, in_axes=(0, None))(states, obstacles)
closest_obstacle = find_closest_obstacle(states, obstacle)
closest_obstacle, closest_obstacle.shape

(Array([[[1. , 1. ],
         [1. , 1. ],
         [4.2, 4.2],
         [4.2, 4.2]],
 
        [[1. , 1. ],
         [1. , 1. ],
         [4.2, 4.2],
         [4.2, 4.2]],
 
        [[1. , 1. ],
         [1. , 1. ],
         [4.2, 4.2],
         [4.2, 4.2]]], dtype=float32),
 (3, 4, 2))