In [1]:
import tensorflow_probability.substrates.jax as tfp  # jax TFP
import jax.random as jr
tfd = tfp.distributions

In [3]:
# A Deterministic distribution is a point mass: each draw below returns `loc`,
# even though a different PRNG key is used every time.
delta = tfd.Deterministic(loc=5)

rng = jr.PRNGKey(0)
for i in range(3):
    # Thread the key: keep one half for the next iteration, spend the other.
    rng, seed = jr.split(rng)
    draw = delta.sample(seed=seed)
    print(repr(draw))

DeviceArray(5., dtype=float32)
DeviceArray(5., dtype=float32)
DeviceArray(5., dtype=float32)


In [4]:
from ssm.base import SSM

In [7]:
class DeterministicRNN(SSM):
    """Work-in-progress SSM whose dynamics are a deterministic RNN step.

    The RNN transition itself is not wired up here. Supply ``transition_fn``
    (a callable mapping the current state to the next state) to make
    :meth:`dynamics_distribution` usable.

    Args:
        rnn_type: label for the kind of recurrent cell, e.g. ``"gru"``. It is
            only stored for now.
        transition_fn: optional callable ``state -> next_state``.
    """

    def __init__(self, rnn_type="gru", transition_fn=None):
        # Keep the configuration around instead of silently discarding it.
        self.rnn_type = rnn_type
        self.transition_fn = transition_fn

    def dynamics_distribution(self, state):
        """Return a point mass (``tfd.Deterministic``) at the next state.

        Args:
            state: the current latent state, passed to ``transition_fn``.

        Returns:
            ``tfd.Deterministic`` centred on ``transition_fn(state)``.

        Raises:
            NotImplementedError: if no ``transition_fn`` was configured.
        """
        if self.transition_fn is None:
            raise NotImplementedError(
                f"DeterministicRNN(rnn_type={self.rnn_type!r}) has no transition_fn; "
                "pass one to __init__ to define the dynamics."
            )
        # Deterministic dynamics: all probability mass on the RNN's next state.
        return tfd.Deterministic(self.transition_fn(state))

In [8]:
import flax.linen as nn
import jax.numpy as np

In [45]:
# NOTE: no batching yet

num_timesteps = 100
obs_dim = 10
hidden_dim = 3

# One key each for: data generation, carry initialization, parameter init.
seed = jr.PRNGKey(0)
rngs = jr.split(seed, 3)

# Synthetic observation sequence, one row per time step.
data = jr.normal(rngs[0], (num_timesteps, obs_dim))

# Unbatched LSTM carry (empty batch_dims), sized for `hidden_dim` units.
carry = nn.LSTMCell().initialize_carry(rngs[1], batch_dims=(), size=hidden_dim)

In [46]:
# initialize our LSTM params using the carry and data
# NOTE(review): `data` here is the whole (num_timesteps, obs_dim) sequence, while
# these params are later applied to single steps such as data[0]. That appears
# to work because the parameter shapes follow the trailing feature dimension —
# confirm against the Flax LSTMCell docs before relying on it.
initial_params = nn.LSTMCell().init(rngs[2], carry=carry, inputs=data)

In [47]:
# In functional form, pass the LSTM params into this "apply" function.
# Because `data` holds every time step at once, this takes one LSTM step for
# each time step in parallel, all from the same initial `carry`. It is NOT a
# recurrence over time (the `jax.lax.scan` cell further down is). That is why
# both `last_carry` components and `hs` have a leading time axis here.
last_carry, hs = nn.LSTMCell().apply(initial_params, carry, data)

In [48]:
data.shape  # expected (num_timesteps, obs_dim)

(100, 10)

In [56]:
# Take a single LSTM step on the first observation.
print(data[0].shape)
# NOTE(review): this rebinds the global `carry`. Re-running the cell keeps
# advancing the state, and later cells that reuse `carry` (e.g. the scan cell)
# start from the advanced state rather than the initial one.
carry, h = nn.LSTMCell().apply(initial_params, carry, data[0])
print(h.shape)

(10,)
(3,)


In [17]:
# Shapes of both carry components (space-separated), then of the outputs.
print(*(part.shape for part in last_carry))
print(hs.shape)

(100, 3) (100, 3)
(100, 3)


In [18]:
# The step output `h` should equal the second component of the new carry.
bool(np.array_equal(carry[1], h))

True

In [35]:
carry[1].shape  # expected (hidden_dim,), same as the per-step output `h`

(3,)

In [36]:
# Take one more single step from the current `carry`. Fresh names keep this
# cell idempotent (re-running it no longer advances the global `carry`), and
# `hs` keeps meaning "outputs over time" instead of holding one (hidden_dim,)
# step output.
step_carry, step_h = nn.LSTMCell().apply(initial_params, carry, data[0])

In [41]:
import jax

# Start the recurrence from a freshly initialized carry. The global `carry` may
# already have been advanced by the single-step cells above, which would
# silently shift the whole trajectory.
initial_state = nn.LSTMCell().initialize_carry(rngs[1], batch_dims=(), size=hidden_dim)


def scan_f(carry, xs):
    """One scan step: apply the LSTM cell with the fixed `initial_params`."""
    return nn.LSTMCell().apply(initial_params, carry, xs)


# lax.scan threads the carry through time (axis 0 of `data`) and stacks the
# per-step outputs into `hs`.
last_carry, hs = jax.lax.scan(scan_f, initial_state, data)

# NOTE: no batching yet

#### Reminder: weight matrices may be transposed relative to what you would expect (e.g., compared with TensorFlow conventions)

In [61]:
def run_lstm(params, initial_state, data_over_time):
    """Unroll an LSTM cell over the leading (time) axis of `data_over_time`.

    Args:
        params: variables produced by ``nn.LSTMCell().init``.
        initial_state: carry the recurrence starts from.
        data_over_time: inputs stacked along axis 0, one slice per time step.

    Returns:
        The per-step outputs stacked along axis 0. The final carry is discarded.
    """
    def step(state, x_t):
        return nn.LSTMCell().apply(params, state, x_t)

    _, outputs = jax.lax.scan(step, initial_state, data_over_time)
    return outputs


num_timesteps = 100
obs_dim = 10
hidden_dim = 3

# Separate keys for data generation, carry init and parameter init.
seed = jr.PRNGKey(0)
rngs = jr.split(seed, 3)

# Synthetic, unbatched observation sequence: one row per time step.
data = jr.normal(rngs[0], (num_timesteps, obs_dim))

# Carry without a batch dimension.
initial_state = nn.LSTMCell().initialize_carry(rngs[1], batch_dims=(), size=hidden_dim)

# Parameters are created from the carry and the data.
initial_params = nn.LSTMCell().init(rngs[2], carry=initial_state, inputs=data)

hs = run_lstm(initial_params, initial_state, data)
print(hs.shape)  # expect (num_timesteps, hidden_dim)

(100, 3)


# Now we batch: `vmap` `run_lstm` over a leading batch axis of sequences

In [67]:
from jax import vmap

batch_dim = 5
num_timesteps = 100
obs_dim = 10
hidden_dim = 3

seed = jr.PRNGKey(0)
rngs = jr.split(seed, 3)

# A batch of independent sequences laid out as (batch, time, features).
batched_data = jr.normal(rngs[0], (batch_dim, num_timesteps, obs_dim))

# The carry stays unbatched; vmap shares it across the batch (in_axes=None).
initial_state = nn.LSTMCell().initialize_carry(rngs[1], batch_dims=(), size=hidden_dim)

# Parameters are created from the carry and the batched data.
initial_params = nn.LSTMCell().init(rngs[2], carry=initial_state, inputs=batched_data)

# Map run_lstm over axis 0 of the data; params and initial state are broadcast.
hs = vmap(run_lstm, in_axes=(None, None, 0))(initial_params, initial_state, batched_data)
print(hs.shape)

(5, 100, 3)
