In [4]:
from typing import Any, Callable, Sequence, Optional

In [116]:
import jax
import jax.numpy as np
import optax
from flax import linen as nn

In [117]:
import jax
import jax.numpy as np
import optax
import haiku as hk

In [208]:
class LSTM(hk.Module):
    def __init__(self, name='lstm'):
        super().__init__(name=name)
        self._w = hk.Linear(4, True, name="w")
        self._u = hk.Linear(4, False, name="u")


    def __call__(self, x):        
        h, c = np.zeros((x.shape[0], 1)), np.zeros((x.shape[0], 1))
        for i in range(x.shape[-1]):
             o, h, c = self._call(x[:, i, None], h, c)
        return o
        
    def _call(self, x_t, h_t, c_t):                        
        iw, gw, fw, ow = np.split(self._w(x_t), indices_or_sections=4, axis=-1)        
        iu, gu, fu, ou = np.split(self._u(h_t), indices_or_sections=4, axis=-1)        
        i = jax.nn.sigmoid(iw + iu)
        f = jax.nn.sigmoid(fw + fu + 1.0)
        g = np.tanh(gw + gu)
        o = jax.nn.sigmoid(ow + ou)
        c = f * c_t + i * g
        h =  o * jnp.tanh(c)
        return o, h, c
    

def _lstm(x):
    module = LSTM()
    return module(x)

In [209]:
model = hk.transform(_lstm)
model = hk.without_apply_rng(model)

x_ = np.array([[1., 2., 3.]])
key = jax.random.PRNGKey(42)

params = model.init(key, x_)

In [210]:
x = jax.random.uniform(key, (10, 3))
model.apply(x=x, params=params)

DeviceArray([[0.51601326],
             [0.6739608 ],
             [0.5244281 ],
             [0.7774125 ],
             [0.75583863],
             [0.51171345],
             [0.5947514 ],
             [0.59904695],
             [0.7415096 ],
             [0.74563086]], dtype=float32)