In [1]:
import jax
import jax.numpy as jnp

In [2]:
delta_t = 0.2  # time step between consecutive states
# Start from the 4x4 identity and write dt * I into the upper-right 2x2 block,
# so component i (i in {0, 1}) picks up delta_t times component i + 2 each step
# (presumably [x, y, vx, vy] -> position += dt * velocity; confirm with the model).
state_transition = jnp.eye(4).at[:2, 2:].set(delta_t * jnp.eye(2))
state_transition

Array([[1. , 0. , 0.2, 0. ],
       [0. , 1. , 0. , 0.2],
       [0. , 0. , 1. , 0. ],
       [0. , 0. , 0. , 1. ]], dtype=float32)

In [8]:
# Two initial state vectors whose nonzero entries mirror each other in sign.
current_state1, current_state2 = (
    jnp.array(components).astype(float)
    for components in ([5, 0, 0.5, 0], [-5, 0, -0.5, 0])
)
# Stack as columns: `state` has shape (4, 2), one column per initial vector.
state = jnp.stack((current_state1, current_state2), axis=1)

# Propagate both columns one step at once.
state_transition @ state

Array([[ 5.1, -5.1],
       [ 0. ,  0. ],
       [ 0.5, -0.5],
       [ 0. ,  0. ]], dtype=float32)

In [9]:
time_horizon = 6  # number of steps rolled out by the scan below


@jax.jit
def update_init_state(
    carry: jax.Array,
    _: "jax.Array | None" = None,
    *,
    transition: "jax.Array | None" = None,
):
    """Advance a stacked state matrix by one step of the linear transition model.

    Written as a ``jax.lax.scan`` body, so it returns ``(new_carry, step_output)``.

    Args:
        carry: matrix multiplied on the left by the transition matrix. In this
            notebook it is ``state``, with shape (4, n_columns).
        _: per-step scan input. Ignored. It is ``None`` when scan is driven only
            by ``length=``, or one slice of ``xs`` otherwise (not an int).
        transition: optional transition matrix to apply. When omitted, the
            global ``state_transition`` is used. Because this function is
            jit-compiled, that global is read at trace time, and reassigning it
            later may not affect calls that hit an already-compiled trace.
            Pass ``transition`` explicitly (e.g. via ``functools.partial``) to
            avoid that hidden-state dependency.

    Returns:
        ``(next_state, next_state.T)``: the propagated carry and its transpose.
        scan stacks the transposes along a new leading time axis.
    """
    matrix = state_transition if transition is None else transition
    next_state = matrix @ carry
    return next_state, next_state.T

In [10]:
# Roll the model forward `time_horizon` steps; scan stacks each step's output
# along a new leading axis: (time_horizon, n_columns, 4).
_, per_step_states = jax.lax.scan(update_init_state, state, length=time_horizon)
# Bring the column axis to the front: (n_columns, time_horizon, 4).
all_states = jnp.moveaxis(per_step_states, 0, 1)
all_states, all_states.shape

(Array([[[ 5.1      ,  0.       ,  0.5      ,  0.       ],
         [ 5.2      ,  0.       ,  0.5      ,  0.       ],
         [ 5.2999997,  0.       ,  0.5      ,  0.       ],
         [ 5.3999996,  0.       ,  0.5      ,  0.       ],
         [ 5.4999995,  0.       ,  0.5      ,  0.       ],
         [ 5.5999994,  0.       ,  0.5      ,  0.       ]],
 
        [[-5.1      ,  0.       , -0.5      ,  0.       ],
         [-5.2      ,  0.       , -0.5      ,  0.       ],
         [-5.2999997,  0.       , -0.5      ,  0.       ],
         [-5.3999996,  0.       , -0.5      ,  0.       ],
         [-5.4999995,  0.       , -0.5      ,  0.       ],
         [-5.5999994,  0.       , -0.5      ,  0.       ]]], dtype=float32),
 (2, 6, 4))

In [18]:
# Apply one further transition step to every row of every trajectory.
# NOTE(review): the saved output below has shape (2, 2, 4), but vmap over
# `all_states` of shape (2, 6, 4) must return (2, 6, 4). That output is stale
# from an out-of-order run; re-run after the cells above to refresh it.
def propagate_trajectory(trajectory):
    """Apply `state_transition` to each row of one (time, 4) trajectory."""
    return (state_transition @ trajectory.T).T


jax.vmap(propagate_trajectory)(all_states)

Array([[[1.8      , 1.8      , 2.       , 2.       ],
        [2.2      , 2.2      , 2.       , 2.       ]],

       [[2.8000002, 2.8000002, 2.       , 2.       ],
        [3.2000003, 3.2000003, 2.       , 2.       ]]], dtype=float32)

In [10]:
# Longer rollout. The scan body ignores its per-step input, so no dummy `xs`
# array is needed: `length` alone drives the loop (same result, no extra
# (time_horizon, 2, 4) allocation).
# NOTE: this rebinds `time_horizon` and `all_states`; earlier cells that read
# them will see the 1000-step values if re-run after this one.
time_horizon = 1000
_, all_states = jax.lax.scan(update_init_state, state, length=time_horizon)
all_states = jnp.swapaxes(all_states, 0, 1)
# Each step emits carry.T of shape (state.shape[1], state.shape[0]); after the
# swap the time axis sits in the middle.
assert all_states.shape == (state.shape[1], time_horizon, state.shape[0])
all_states.shape

(2, 1000, 4)