# Legendre Memory Units (LMUs) for MNIST classification
Implementation based on:
* the original NIPS 2019 paper on LMU (https://proceedings.neurips.cc/paper/2019/hash/952285b9b7e7a1be5aa7849f32ffff05-Abstract.html);
* and the GitHub repository of the LMU PyTorch implementation (https://github.com/hrshtv/pytorch-lmu)

## Imports

In [1]:
from functools import partial
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import jax
from jax import lax, random, numpy as jnp
from jax import grad, jit, vmap, value_and_grad

from flax import linen as nn

from scipy.signal import cont2discrete

import tensorflow_datasets as tfds 

import optax
from flax.training import train_state

2022-12-26 18:14:56.906933: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-26 18:14:56.907014: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
  from .autonotebook import tqdm as notebook_tqdm


## Modules

In [2]:
#---------------------- Cell Base ----------------------#
# The other classes inherit from this class

class RecurrentCellBase(nn.Module):
  """Recurrent cell base class."""

  @staticmethod
  def initialize_carry(batch_dims, size, init_fn=nn.initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.
    Returns:
      An initialized carry for the given cell.
    """
    raise NotImplementedError

In [3]:
#---------------------- One LMU Cell ----------------------#
# References: 
# https://flax.readthedocs.io/en/latest/_modules/flax/linen/recurrent.html
# https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html

class LMUCell(RecurrentCellBase):
    """LMU Cell
    
    Parameters:
        input_size (int) : 
            Size/dimensions of the input vector (x_t)
        hidden_size (int) : 
            Size of the hidden vector (h_t)
        memory_size (int) :
            Size of the memory vector (m_t)
        theta (int) :
            The number of timesteps in the sliding window that is represented using the LTI system  
    
    Other Attributes:
        State Space Matrices:
            A: [memory_size, memory_size]
            B: [memory_size, 1]
        Initializers:
            init_lecun_uni: lecun_normal, for two encoding vectors, e_x and e_h
            init_zero: constant(0), for the encoding vector e_m
            init_xav_norm: xavier_normal, for three kernel matrices, W_x, W_h, W_m
    """

    input_size: int
    hidden_size: int
    memory_size: int
    theta: int

    init_lecun_uni: Callable = nn.initializers.lecun_normal()
    init_zero: Callable = nn.initializers.constant(0.)
    init_xav_norm: Callable = nn.initializers.xavier_normal()

    activation_fn: Callable = nn.tanh
    # activation_fn: Callable = nn.relu

    def setup(self):
        self.A, self.B = self.stateSpaceMatrices()

    @nn.compact
    def __call__(self, carry, x):
        """
        Args:
            carry (tuple): the previous state:
                h: the hidden vector, [batch_size, hidden_size]
                m: the memory vector, [batch_size, memory_size]
                
            x (array): the input vector,[batch_size, input_size]
                
        """

        # Unpack the hidden state and memory
        h, m = carry

        # Eq (7) of the paper, compute the input signal u
        # u: [batch_size, 1]
        u_x = nn.Dense(features=1,
                        use_bias=False,
                        kernel_init=self.init_lecun_uni,
                        bias_init=self.init_zero,
                        param_dtype=jnp.float32)
        u_h = nn.Dense(features=1,
                        use_bias=False,
                        kernel_init=self.init_lecun_uni,
                        bias_init=self.init_zero,
                        param_dtype=jnp.float32)
        u_m = nn.Dense(features=1,
                        use_bias=False,
                        kernel_init=self.init_zero,
                        bias_init=self.init_zero,
                        param_dtype=jnp.float32)
        u = u_x(x) + u_h(h) + u_m(m)


        # Eq (4) of the paper, compute the memory
        # m: [batch_size, memory_size]
        m_m = jnp.matmul(m, self.A) 
        m_u = jnp.matmul(u, self.B.T)
        new_m= m_m + m_u

        # Eq (6) of the paper, compute the hidden state
        # h: [batch_size, hidden_size]
        h_x = nn.Dense(features=self.hidden_size,
                        use_bias=False,
                        kernel_init=self.init_xav_norm,
                        bias_init=self.init_zero,
                        param_dtype=jnp.float32)
        h_h = nn.Dense(features=self.hidden_size,
                        use_bias=False,
                        kernel_init=self.init_xav_norm,
                        bias_init=self.init_zero,
                        param_dtype=jnp.float32)
        h_m = nn.Dense(features=self.hidden_size,
                        use_bias=False,
                        kernel_init=self.init_xav_norm,
                        bias_init=self.init_zero,
                        param_dtype=jnp.float32)
        new_h = self.activation_fn(h_x(x) + h_h(h) + h_m(m))
        return (new_h, new_m), new_h

    @staticmethod
    def initialized_carry(batch_size, hidden_size, memory_size):
        """Initialize the LMU cell carry.
        Args:
            batch_size (int) : 
                Size of the batch
            hidden_size (int) : 
                Size of the hidden vector (h_t)
            memory_size (int) :
                Size of the memory vector (m_t)
        Returns:
            An initialized carry for the given cell.
        """
        h = jnp.zeros((batch_size, hidden_size))
        m = jnp.zeros((batch_size, memory_size))
        return h, m

    

    def stateSpaceMatrices(self):
        """ Returns the discretized state space matrices A and B """
        Q = np.arange(self.memory_size, dtype = np.float32)
        R = (2*Q + 1) / self.theta
        i, j = np.meshgrid(Q, Q, indexing = "ij")

        # Continuous
        A = R * np.where(i < j, -1, (-1.0)**(i - j + 1))
        B = R * ((-1.0)**Q)
        B = B.reshape(-1, 1)
        C = np.ones((1, self.memory_size))
        D = np.zeros((1,))

        # Convert to discrete
        A, B, C, D, dt = cont2discrete(
                                system = (A, B, C, D), 
                                dt = 1.0, 
                                method = "zoh"
                            )
            
        return A, B

In [4]:
#---------------------- LMU Layer ----------------------#
# Reference: https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html

class LMU(nn.Module):
    """
    The LMU Layer

    Parameters:
        input_size (int) : 
            Size of the input vector (x_t)
        hidden_size (int) : 
            Size of the hidden vector (h_t)
        memory_size (int) :
            Size of the memory vector (m_t)
        theta (int) :
            The number of timesteps in the sliding window that is represented using the LTI system
    
    """

    input_size: int
    hidden_size: int
    memory_size: int
    theta: int

    @nn.compact
    def __call__(self, carry, x):
        """
        Parameters:
            carry (tuple): the previous state, set to None initially
                h: the hidden vector, [batch_size, hidden_size]
                m: the memory vector, [batch_size, memory_size]
            x (array): the input vector, [batch_size, seq_len, input_size]
        """

        LMULayer = nn.scan(LMUCell,
                            variable_broadcast="params",
                            split_rngs={"params": False},
                            in_axes=1,
                            out_axes=1)

        return LMULayer(self.input_size, self.hidden_size, self.memory_size, self.theta)(carry, x)
    
    @staticmethod
    def initialize_carry(batch_size, hidden_size, memory_size):
        return LMUCell.initialized_carry(batch_size, hidden_size, memory_size)

## The Classification Model

In [5]:
#----------------- LMU Model for pMNIST Classification ----------------#
class Model(nn.Module):
    input_size: int
    output_size: int
    hidden_size: int
    memory_size: int
    theta: int

    @nn.compact
    def __call__(self, x):

        # Initialize the carry at the first time step
        # Get initial carry/state (h_0, m_0)
        init_carry = LMU.initialize_carry(x.shape[0], self.hidden_size, self.memory_size)

        # Run the LMU layer
        # h_n: [batch_size, hidden_size]
        (h_n , _),_= LMU(self.input_size, self.hidden_size, self.memory_size, self.theta)(init_carry, x)

        # Run the output layer
        output=nn.Dense(features=self.output_size)(h_n)

        return output

## Hyperparameters

In [6]:
N_x = 1 # dimension of the input, a single pixel
N_t = 784 # number of time steps (sequence length) - here it's 28 * 28 since we are using MNIST and making it 1D
N_h = 256 # dimension of the hidden state
N_m = 256 # dimension of the memory
N_c = 10 # number of classes 
THETA = N_t
N_b = 512 # batch size
N_epochs = 10 # number of epochs


lr=1e-4 # learning rate for adam optimizer

## Load Data

In [7]:
def get_datasets():
    """Load MNIST train and test datasets into memory."""

    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    val_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))


    train_ds['image'] = jnp.float32(train_ds['image'])
    val_ds['image'] = jnp.float32(val_ds['image'])

    return train_ds, val_ds



train_ds, val_ds = get_datasets()

2022-12-26 18:14:57.968049: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2022-12-26 18:14:57.968070: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


## Loss & Metrics

In [8]:
"""Following the Flax example: https://flax.readthedocs.io/en/latest/getting_started.html"""

def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean() 

def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics

## Training

### Utility Functions

In [9]:
def create_train_state(rng, learning_rate=lr):
    model = Model(
            input_size  = N_x,
            output_size = N_c,
            hidden_size = N_h, 
            memory_size = N_m, 
            theta = THETA
        )

    
    params = model.init(rng, jnp.ones((1, N_t, N_x)))['params']
    print("Model initialized.")

    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

In [10]:
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss
    grad_fn = value_and_grad(loss_fn, has_aux=False)
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)

    logits = new_state.apply_fn({'params': new_state.params}, batch['image'])
    metrics = compute_metrics(logits=logits, labels=batch['label'])

    return new_state, metrics 

In [11]:
def eval_step(state, batch, seq_len=N_t, input_size=N_x):
    batch['image']=batch['image'].reshape((-1, seq_len, input_size))
    logits = state.apply_fn({'params': state.params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])

In [12]:
def train_epoch(state, train_ds, epoch, rng, batch_size=N_b, seq_len=N_x, input_size=N_x):

    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size) # get a randomized index array
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size)) # index array, where each row is a batch

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()} # dict{'image': array, 'label': array}
        batch['image']=batch['image'].reshape((batch_size, -1, input_size)) # reshape to the required input dimensions
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)
    
    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]} # jnp.mean does not work on lists

    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

    return state

In [13]:
def eval_model(state, test_ds):
    metrics = eval_step(state, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics) # map the function over all leaves in metrics
    return summary['loss'], summary['accuracy']

### Training Loop

In [14]:
# Random seed
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Initialize model
state=create_train_state(rng, learning_rate=lr)
del init_rng

for epoch in range(N_epochs):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    state = train_epoch(state, train_ds, epoch, input_rng)
    test_loss, test_accuracy = eval_model(state, val_ds)
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

Model initialized.
train epoch: 0, loss: 1.1877, accuracy: 68.73
 test epoch: 0, loss: 0.62, accuracy: 86.53
train epoch: 1, loss: 0.4881, accuracy: 88.41
 test epoch: 1, loss: 0.38, accuracy: 90.42
train epoch: 2, loss: 0.3562, accuracy: 90.78
 test epoch: 2, loss: 0.30, accuracy: 91.88
train epoch: 3, loss: 0.2892, accuracy: 92.30
 test epoch: 3, loss: 0.26, accuracy: 92.98
