In [4]:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

In [5]:
@dataclass
class Gaussian:
    info: jnp.array
    precision: jnp.array

    def __add__(self, other_gaussian: "Gaussian"):
        return Gaussian(self.info + other_gaussian.info, self.precision + other_gaussian.precision)
    
    @staticmethod
    def identity():
        return Gaussian(jnp.zeros(4), jnp.eye(4))

def add_one_to_info(g: Gaussian) -> Gaussian:
    other_gaussian = Gaussian(jnp.ones((4,)), jnp.zeros((4,4)))
    return g + other_gaussian

def add_one_to_precision(g: Gaussian) -> Gaussian:
    other_gaussian = Gaussian(jnp.zeros((4,)), jnp.ones((4,4)))
    return g + other_gaussian

batch_add_info = jax.vmap(add_one_to_info)
batch_add_precision = jax.vmap(add_one_to_precision)

In [6]:
g = Gaussian(jnp.ones((2,4)), jnp.stack((jnp.eye(4), jnp.eye(4))))
print(batch_add_info(g))
print(batch_add_precision(g))

Gaussian(info=Array([[2., 2., 2., 2.],
       [2., 2., 2., 2.]], dtype=float32), precision=Array([[[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]],

       [[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]], dtype=float32))
Gaussian(info=Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32), precision=Array([[[2., 1., 1., 1.],
        [1., 2., 1., 1.],
        [1., 1., 2., 1.],
        [1., 1., 1., 2.]],

       [[2., 1., 1., 1.],
        [1., 2., 1., 1.],
        [1., 1., 2., 1.],
        [1., 1., 1., 2.]]], dtype=float32))


In [7]:
def add_info_by_n(n: int) -> Gaussian:
    def iterate_add_info(carry, x):
        # carry is the current state, x is current iteration from arange()
        carry = batch_add_info(carry)
        carry = batch_add_precision(carry)
        return carry, x
    h, stacked = jax.lax.scan(iterate_add_info, g, jnp.arange(1,n))
    return h, stacked

add_info_by_n(10)

(Gaussian(info=Array([[10., 10., 10., 10.],
        [10., 10., 10., 10.]], dtype=float32), precision=Array([[[10.,  9.,  9.,  9.],
         [ 9., 10.,  9.,  9.],
         [ 9.,  9., 10.,  9.],
         [ 9.,  9.,  9., 10.]],
 
        [[10.,  9.,  9.,  9.],
         [ 9., 10.,  9.,  9.],
         [ 9.,  9., 10.,  9.],
         [ 9.,  9.,  9., 10.]]], dtype=float32)),
 Array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [8]:
factor_to_var_msgs = {"x1": Gaussian.identity(), "x2": Gaussian.identity(), "x3": Gaussian(jnp.ones(4) * 4, jnp.eye(4) * 5)}
factors = ["f1", "f2", "f3"]

In [9]:
def update_factor_to_var_msgs(factor_to_var_msgs):
    updated_msg = factor_to_var_msgs["x1"] + factor_to_var_msgs["x2"]
    return updated_msg

jax.tree.map(update_factor_to_var_msgs, factors)

TypeError: string indices must be integers

In [10]:
jnp.zeros((1,3)) - jnp.zeros((3,2,3))

Array([[[0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.]]], dtype=float32)

In [6]:
data = {'1': 1, '2': 2, '3': 3}
state = jnp.zeros((4,))

def regular_func(x):
    return x + data['1']

state = regular_func(state)
data['1'] += 1
state = regular_func(state)
data['1'] += 1
state = regular_func(state)
state

Array([6., 6., 6., 6.], dtype=float32)

In [16]:
from copy import deepcopy

data = {'1': 1, '2': 2, '3': 3}
state = jnp.zeros((4,))

@jax.jit
def regular_func(x, data, y):
    return x + data['1'] + y

state = regular_func(state, data, 6)
state
# data['1'] += 1
# state = regular_func(state, data)
# data['1'] += 1
# state = regular_func(state, data)
# state

Array([7., 7., 7., 7.], dtype=float32)