An LSTM that returns sequences and is stateful. If you don't need that, just use [nn.scan](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html). Also, see below for learnable initialization.

In [1]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training import train_state
import logging
import numpy as np
import optax
from typing import Any

In [2]:
# Emit INFO-level records so the per-epoch loss logged by the training loop is shown.
logging.basicConfig(level=logging.INFO)

In [3]:
@flax.struct.dataclass
class Carry:
  """Container holding the carries of the two LSTM layers used by ``Model``."""
  # Carry passed to and returned by the first LSTM layer (``Model.lstm1``).
  carry1: Any
  # Carry passed to and returned by the second LSTM layer (``Model.lstm2``).
  carry2: Any

class TrainState(train_state.TrainState):
  """Train state that also holds the LSTM carry, for stateful LSTM training."""
  # Carry fed into the next forward pass; ``train_epoch`` replaces it after
  # every batch with the carry returned by the model.
  carry: Carry

In [4]:
class SequenceLSTMCell(nn.Module):
    """Apply a single ``nn.LSTMCell`` step by step along axis 1 of the input.

    Attributes:
        features: Forwarded to ``nn.LSTMCell`` as its ``features`` argument.
        return_sequences: If true, the outputs of all time steps are returned
            stacked along axis 1; otherwise only the last step's output is
            returned.
    """

    features: int
    return_sequences: bool

    @nn.compact
    def __call__(self, carry, inputs):
        """Run the cell over every time step of ``inputs``.

        Args:
            carry: Initial carry handed to ``nn.LSTMCell`` at the first step.
            inputs: Array that is sliced as ``inputs[:, t, :]``, i.e. axis 1
                is treated as the time axis.

        Returns:
            A ``(carry, outputs)`` tuple: the carry after the last step and
            either the per-step outputs stacked along axis 1
            (``return_sequences=True``) or the last step's output.

        Raises:
            ValueError: If the time axis of ``inputs`` has length zero.
        """
        sequence_length = inputs.shape[1]
        if sequence_length == 0:
            # Without this guard the code below fails with an unrelated-looking
            # error (unbound ``output`` or stacking an empty list).
            raise ValueError(
                "SequenceLSTMCell requires at least one time step, "
                f"got inputs with shape {inputs.shape}."
            )

        # One cell instance, so all time steps share the same parameters.
        lstm_cell = nn.LSTMCell(features=self.features)

        outputs = []
        for t in range(sequence_length):
            carry, output = lstm_cell(carry, inputs[:, t, :])
            if self.return_sequences:
                outputs.append(output)

        if self.return_sequences:
            # Stack along the sequence dimension.
            return carry, jnp.stack(outputs, axis=1)
        return carry, output

In [5]:

class Model(nn.Module):
    """Two stacked sequence LSTMs followed by a single-unit dense layer.

    Attributes:
        features: ``features[0]`` is the size of the first LSTM layer and
            ``features[1]`` the size of the second one.
    """

    features: list[int]

    def setup(self):
        # The first layer emits the full sequence so the second layer can
        # consume it; the second layer only emits its last output.
        self.lstm1 = SequenceLSTMCell(features=self.features[0],
                                      return_sequences=True, name="sequence_lstm")
        self.lstm2 = SequenceLSTMCell(features=self.features[1],
                                      return_sequences=False, name="lstm")
        self.dense = nn.Dense(features=1, name="dense")

    def __call__(self, carry, x):
        """Run both LSTM layers and the dense head.

        Args:
            carry: ``Carry`` holding the carries of both LSTM layers.
            x: Input batch, forwarded to the first LSTM layer.

        Returns:
            ``(new_carry, prediction)`` where ``new_carry`` is a ``Carry`` with
            the carries returned by both layers.
        """
        carry1, carry2 = carry.carry1, carry.carry2
        carry1, x = self.lstm1(carry=carry1, inputs=x)
        carry2, x = self.lstm2(carry=carry2, inputs=x)
        x = self.dense(x)
        return Carry(carry1, carry2), x

    def initialize_carry(self, batch_size=None):
        """Create an all-zero carry for both LSTM layers.

        Args:
            batch_size: Optional leading batch dimension. When omitted (the
                default) each carry array has shape ``(features,)`` as before;
                when given, each has shape ``(batch_size, features)``.

        Returns:
            A ``Carry`` whose ``carry1`` is sized by ``features[0]`` and whose
            ``carry2`` is sized by ``features[1]``.
        """
        leading = () if batch_size is None else (batch_size,)
        shape1 = leading + (self.features[0],)
        # The second layer's carry must match the second layer's width.
        shape2 = leading + (self.features[1],)
        carry1 = (jnp.zeros(shape1), jnp.zeros(shape1))
        carry2 = (jnp.zeros(shape2), jnp.zeros(shape2))
        return Carry(carry1, carry2)
    
# If you don't need stateful LSTM:
#
# class Model(nn.Module):
#     features: list[int]
#   
#     def setup(self):
#         self.lstm1 = SequenceLSTMCell(features=self.features[0], 
#                                       return_sequences=True, name="sequence_lstm")
#         self.lstm2 = SequenceLSTMCell(features=self.features[1], 
#                                       return_sequences=False, name="lstm")
#         self.dense = nn.Dense(features=1, name="dense")
#
#         # Initialize the carry as a learnable parameter.
#         # It becomes part of `params` and is trained just like the W and b
#         # of a Dense layer.
#         self.carry1 = self.param('carry1_lstm1', nn.initializers.zeros, 
#                                  (self.features[0],))
#         self.carry2 = self.param('carry2_lstm2', nn.initializers.zeros, 
#                                  (self.features[1],))
#
#     def __call__(self, x):
#         carry1 = self.carry1
#         carry2 = self.carry2
#         carry1, x = self.lstm1(carry=carry1, inputs=x)
#         carry2, x = self.lstm2(carry=carry2, inputs=x)
#         x = self.dense(x)
#         return x


In [6]:

@jax.jit
def apply_model(state, X, y):
    """Compute the loss, its gradients and the new carry for a single batch.

    Args:
        state: Object exposing ``apply_fn``, ``params`` and ``carry``.
        X: Input batch passed to ``state.apply_fn``.
        y: Targets compared against the model output.

    Returns:
        ``(carry, loss, grads)``: the carry returned by the model, the mean
        squared error between ``y`` and the prediction, and the gradients of
        that loss with respect to ``state.params``.
    """

    def mean_squared_error(y, yhat):
        return jnp.mean((y - yhat) ** 2)

    def compute_loss_fn(params, carry):
        carry, yhat = state.apply_fn({"params": params}, carry, X)
        # The updated carry is returned as auxiliary output so it can be
        # threaded into the next batch.
        return mean_squared_error(y, yhat), carry

    # Differentiates with respect to the first argument (params) only.
    grad_fn = jax.value_and_grad(compute_loss_fn, has_aux=True)
    (loss, carry), grads = grad_fn(state.params, state.carry)
    return carry, loss, grads

In [7]:
@jax.jit
def update_model(state, grads):
    """Return the state obtained by applying ``grads`` to ``state``.

    Args:
        state: Object exposing ``apply_gradients(grads=...)``.
        grads: Gradients forwarded to ``state.apply_gradients``.

    Returns:
        Whatever ``state.apply_gradients`` returns.
    """
    updated_state = state.apply_gradients(grads=grads)
    return updated_state

In [8]:
def train_epoch(state, dataset_fn):
    """Train for one pass over the batches produced by ``dataset_fn``.

    Args:
        state: Training state; must support ``replace(carry=...)`` and be
            accepted by ``apply_model`` / ``update_model``.
        dataset_fn: Zero-argument callable returning an iterable of
            ``(X, y)`` batches.

    Returns:
        ``(state, train_loss)``: the state after the last batch and the mean
        of the per-batch losses.
    """
    batch_losses = []
    for X, y in dataset_fn():
        carry, loss, grads = apply_model(state, X, y)
        # Apply the optimizer step, then store the carry returned by the model
        # so the next batch continues from it (stateful LSTM).
        state = update_model(state, grads).replace(carry=carry)
        batch_losses.append(loss)
    return state, np.mean(batch_losses)

In [9]:
batch_size = 10
sequence_length = 3


def datagen():
    """Yield ``(X, y)`` training batches cut from a noisy synthetic series.

    Each batch holds ``batch_size`` consecutive sliding windows of
    ``sequence_length`` values of ``x``; the target of a window is the value
    of ``y`` at the index right after the window.

    Yields:
        ``X`` of shape ``(batch_size, sequence_length, 1)`` and ``y`` of shape
        ``(batch_size, 1)``.
    """
    x = jnp.arange(300)
    y = jnp.sin(x * 0.01)
    shifted = jnp.roll(y, shift=1)
    # NOTE(review): algebraically this is just ``y - shifted``; the same rolled
    # series is added once and subtracted twice. Confirm this is intended.
    y = y + shifted - 2 * shifted
    # Fixed PRNG key, so the added noise is the same on every call.
    y = y + jax.random.normal(jax.random.PRNGKey(0), y.shape) * 0.1

    # window_offsets[j, k] = j + k: index of element k of window j, relative
    # to the start of the batch.
    window_offsets = (jnp.arange(batch_size)[:, np.newaxis]
                      + jnp.arange(sequence_length)[np.newaxis, :])
    # Index of the target of window j, relative to the start of the batch.
    target_offsets = jnp.arange(batch_size) + sequence_length

    num_batches = len(x) - sequence_length - batch_size + 1
    for start in range(num_batches):
        X_batch = x[start + window_offsets][:, :, np.newaxis]
        y_batch = y[start + target_offsets][:, np.newaxis]
        yield X_batch, y_batch


In [10]:

# Build the model with LSTM layer sizes 5 and 3 and create its initial carry.
model = Model(features=[5, 3])
init_rng = jax.random.PRNGKey(42)
carry = model.initialize_carry()

# Initialize the parameters by running the model on a dummy all-ones batch.
params = model.init(init_rng,
                    carry=carry, 
                    x=jnp.ones((batch_size, sequence_length, 1)))['params']

# Bundle the parameters, the Adam optimizer and the LSTM carry in one state.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(learning_rate=1e-3),
    carry=carry)

for i in range(10):
    state, train_loss = train_epoch(state, datagen)
    # Within an epoch, train_epoch threads the carry from batch to batch.
    # Across epochs the carry is not kept: it is reset here.
    state = state.replace(carry=model.initialize_carry())
    logging.info(f"Epoch: {i:4d}, train_loss: {train_loss:.4f}")
    

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.


Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB



"builtin.module"() ({
  "func.func"() ({
  ^bb0(%arg0: tensor<5x5xf32>):
    %0 = "mhlo.constant"() {value = dense<-1> : tensor<5x5xi32>} : () -> tensor<5x5xsi32>
    %1 = "mhlo.constant"() {value = dense<0.000000e+00> : tensor<5x5xf32>} : () -> tensor<5x5xf32>
    %2:2 = "mhlo.custom_call"(%arg0) {api_version = 1 : i32, backend_config = "", call_target_name = "Qr", called_computations = [], has_side_effect = false} : (tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<5xf32>)
    %3 = "mhlo.custom_call"(%2#0, %2#1) {api_version = 1 : i32, backend_config = "", call_target_name = "ProductOfElementaryHouseholderReflectors", called_computations = [], has_side_effect = false} : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5x5xf32>
    %4 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xsi32>
    %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xsi32>) -> tensor<5x5xsi32>
    %6 = "mhlo.add"(%5, %0) : (tensor<5x5xsi32>, tensor<5x5xsi32>) -> te

XlaRuntimeError: UNKNOWN: /var/folders/09/9zmlsg756kxfcbx6l3tdhf600000gn/T/ipykernel_99180/857170341.py:14:28: error: failed to legalize operation 'mhlo.custom_call'
/var/folders/09/9zmlsg756kxfcbx6l3tdhf600000gn/T/ipykernel_99180/857170341.py:14:28: note: see current operation: %4:2 = "mhlo.custom_call"(%arg0) {api_version = 1 : i32, backend_config = "", call_target_name = "Qr", called_computations = [], has_side_effect = false} : (tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<5xf32>)
