In [1]:
import jax
import jax.numpy as jnp

In [2]:
# Discrete-time linear transition matrix for a 4-component state.
# The upper-right 2x2 block adds delta_t * (components 2, 3) onto
# components 0, 1 and leaves components 2, 3 unchanged — presumably
# position += velocity * dt for an (x, y, vx, vy) state (TODO confirm).
delta_t = 0.2
state_transition = jnp.eye(4).at[:2, 2:].set(delta_t * jnp.eye(2))
state_transition

Array([[1. , 0. , 0.2, 0. ],
       [0. , 1. , 0. , 0.2],
       [0. , 0. , 1. , 0. ],
       [0. , 0. , 0. , 1. ]], dtype=float32)

In [3]:
# Two example states, placed side by side as the columns of a (4, 2) matrix
# so a single matmul with state_transition propagates both at once.
current_state1 = jnp.array([1, 1, 2, 2], dtype=float)
current_state2 = jnp.array([2, 2, 2, 2], dtype=float)
state = jnp.column_stack((current_state1, current_state2))

state_transition @ state

Array([[1.4, 2.4],
       [1.4, 2.4],
       [2. , 2. ],
       [2. , 2. ]], dtype=float32)

In [4]:
time_horizon = 10


@jax.jit
def update_state(carry: jax.Array, x: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Advance every state by one step of the linear transition model.

    Intended as the step function for ``jax.lax.scan``: it returns the new
    carry plus a per-step output that scan stacks along a leading time axis.

    Args:
        carry: Current states, one per column. The leading dimension must
            match ``state_transition`` (4x4), so this is expected to be
            (4, n_states) as built in the ``state`` cell — TODO confirm for
            other callers.
        x: The per-step slice of scan's ``xs``. Its values are ignored; it
            only serves as an output template — every element is overwritten
            with the transposed new states (cast to ``x``'s dtype), so it is
            expected to have shape (n_states, 4).

    Returns:
        ``(new_carry, output)``: the propagated states in the same layout as
        ``carry``, and those same states transposed and written into ``x``.

    Note:
        ``state_transition`` is read from the notebook's global scope. Under
        ``jax.jit`` it is captured as a constant when the function is first
        traced, so reassigning it in a later cell does not change an
        already-compiled ``update_state``; re-run this cell to pick it up.
    """
    new_carry = state_transition @ carry
    # Full-slice .set() fills x with new_carry.T, keeping x's shape and dtype.
    output = x.at[:, :].set(new_carry.T)
    return new_carry, output

In [5]:
# Show the matrix that update_state closes over. A plain display is enough at
# top level; jax.debug.print is only needed inside traced (jit/scan) code.
state_transition

[[1.  0.  0.2 0. ]
 [0.  1.  0.  0.2]
 [0.  0.  1.  0. ]
 [0.  0.  0.  1. ]]


In [6]:
# Roll the initial states forward `time_horizon` steps. The zeros array is
# only the per-step output template that update_state overwrites; its
# per-step shape is derived from `state` (transposed) rather than hard-coded,
# so it stays correct if the number or size of states changes.
_, all_states = jax.lax.scan(
    update_state,
    state,
    jnp.zeros((time_horizon, *state.T.shape)),
    length=time_horizon,
)
# scan stacks per-step outputs along a new leading axis:
# (time, n_states, 4) -> (n_states, time, 4), one trajectory per state.
all_states = jnp.swapaxes(all_states, 0, 1)
all_states

Array([[[1.4      , 1.4      , 2.       , 2.       ],
        [1.8      , 1.8      , 2.       , 2.       ],
        [2.2      , 2.2      , 2.       , 2.       ],
        [2.6000001, 2.6000001, 2.       , 2.       ],
        [3.0000002, 3.0000002, 2.       , 2.       ],
        [3.4000003, 3.4000003, 2.       , 2.       ],
        [3.8000004, 3.8000004, 2.       , 2.       ],
        [4.2000003, 4.2000003, 2.       , 2.       ],
        [4.6000004, 4.6000004, 2.       , 2.       ],
        [5.0000005, 5.0000005, 2.       , 2.       ]],

       [[2.4      , 2.4      , 2.       , 2.       ],
        [2.8000002, 2.8000002, 2.       , 2.       ],
        [3.2000003, 3.2000003, 2.       , 2.       ],
        [3.6000004, 3.6000004, 2.       , 2.       ],
        [4.0000005, 4.0000005, 2.       , 2.       ],
        [4.4000006, 4.4000006, 2.       , 2.       ],
        [4.8000007, 4.8000007, 2.       , 2.       ],
        [5.200001 , 5.200001 , 2.       , 2.       ],
        [5.600001 , 5.6000

In [7]:
# Longer rollout. NOTE: this rebinds the global `time_horizon`, so re-running
# the earlier rollout cell after this one will also use 1000 steps.
time_horizon = 1000
_, all_states = jax.lax.scan(
    update_state,
    state,
    # Output template shape derived from `state`, matching update_state's output.
    jnp.zeros((time_horizon, *state.T.shape)),
    length=time_horizon,
)
# (time, n_states, 4) -> (n_states, time, 4)
all_states = jnp.swapaxes(all_states, 0, 1)
all_states.shape