# Parallelized Legendre Memory Unit (LMU) on the permuted MNIST dataset
Implementation based on 
* the ICML 2021 paper "Parallelizing Legendre Memory Unit Training" (https://proceedings.mlr.press/v139/chilkuri21a.html)
* and the GitHub repository of the LMU PyTorch implementation (https://github.com/hrshtv/pytorch-lmu)

## Imports

In [1]:
from functools import partial
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import jax
from jax import lax, random, numpy as jnp
from jax import grad, jit, vmap, value_and_grad

from flax import linen as nn

from scipy.signal import cont2discrete

import tensorflow_datasets as tfds 

import optax
from flax.training import train_state

2023-01-05 14:59:13.918236: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-05 14:59:13.918278: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
  from .autonotebook import tqdm as notebook_tqdm


## Modules

In [2]:
#---------------------- Parallelized LMU ----------------------#
class LMUFFT(nn.Module):
    """ Parallelized LMU Layer

        The memory states of all timesteps are computed at once: the encoded
        input is convolved with the impulse response of the discretized LTI
        system, and the convolution is evaluated in the frequency domain.
        
        Parameters:
            input_size (int) : 
                Size of the input vector (x_t)
            hidden_size (int) : 
                Size of the hidden vector (h_t)
            memory_size (int) :
                Size of the memory vector (m_t)
            seq_len (int) :
                Size of the sequence length (n)
            theta (int) :
                The number of timesteps in the sliding window that is represented using the LTI system
    
    """

    input_size: int
    hidden_size: int
    memory_size: int
    seq_len: int
    theta: int

    def setup(self):
        """
        Precomputes the constant (non-trainable) quantities of the LTI system.

        A: [memory_size, memory_size]
        B: [memory_size, 1]
        H: [memory_size, seq_len]
        H_fft: [memory_size, seq_len + 1]
        """
        self.A, self.B = self.stateSpaceMatrices() # numpy
        self.H, self.H_fft = self.impulse() # numpy

    @nn.compact
    def __call__(self, x):
        """
        Parameters:
            x: [batch_size, seq_len, input_size];

        Returns:
            h: [batch_size, seq_len, hidden_size]; The parallelized/flattened hidden states of every timestep;
            h_n: [batch_size, hidden_size]; The hidden state of the last timestep;

        Raises:
            ValueError: If the sequence length of x differs from self.seq_len,
                for which the impulse response was precomputed.
        """
        _, seq_len, _ = x.shape
        if seq_len != self.seq_len:
            # H_fft was built for exactly self.seq_len timesteps; fail early
            # instead of with an opaque broadcasting error further down.
            raise ValueError(
                f"LMUFFT was configured with seq_len={self.seq_len}, "
                f"but received an input with seq_len={seq_len}."
            )

        # Equation 18 of the paper
        u_pre_act = nn.Dense(features=1,
                            use_bias=True)
        u = nn.relu(u_pre_act(x)) # [batch_size, seq_len, 1]

        # Equation 26 of the paper
        # Transforming with n = 2*seq_len zero-pads the signal, so the first
        # seq_len samples of the inverse transform are a linear (not circular)
        # convolution.
        fft_input = u.transpose(0, 2, 1) # [batch_size, 1, seq_len]
        fft_u = jnp.fft.rfft(fft_input, n = 2*seq_len, axis = -1) # [batch_size, 1, seq_len+1]

        # Element-wise multiplication (uses broadcasting)
        # fft_u: [batch_size, 1, seq_len+1]
        # self.H_fft: [memory_size, seq_len+1] -> expanded in dimension 0
        H_fft = self.H_fft.reshape(1, self.memory_size, seq_len + 1) # [1, memory_size, seq_len+1]
        temp = jnp.multiply(fft_u, H_fft) # [batch_size, memory_size, seq_len+1]

        m = jnp.fft.irfft(temp, n = 2*seq_len, axis = -1) # [batch_size, memory_size, 2*seq_len]
        m = m[:, :, :seq_len] # [batch_size, memory_size, seq_len]
        m = m.transpose(0, 2, 1) # [batch_size, seq_len, memory_size]

        # Equation 20 of the paper
        input_h = jnp.concatenate((x, m), axis=-1) # [batch_size, seq_len, input_size + memory_size]
        h_pre_act = nn.Dense(features=self.hidden_size,
                            use_bias=True)

        # ReLU non-linearity (a tanh variant was previously left commented out here)
        h = nn.relu(h_pre_act(input_h)) # [batch_size, seq_len, hidden_size]

        h_n = h[:, -1, :] # [batch_size, hidden_size]

        return h, h_n

    def impulse(self):
        """ Returns the matrices H and the 1D Fourier transform of H (Equations 23, 26 of the paper) """
        H = np.empty((self.memory_size, self.seq_len), dtype = np.float32) # [memory_size, seq_len]
        A_i = np.eye(self.memory_size, dtype = np.float32) # [memory_size, memory_size]
        # Column n of H is A^n B; A_i carries the running power of A.
        for n in range(self.seq_len):
            H_n = np.matmul(A_i, self.B) # [memory_size, 1]
            H[:, n] = H_n.reshape(-1)
            A_i = np.matmul(self.A, A_i)
        H_fft = np.fft.rfft(H, n = 2*self.seq_len, axis = -1) # [memory_size, seq_len + 1]
        return H, H_fft

    def stateSpaceMatrices(self):
        """ Returns the discretized state space matrices A and B

        Returns:
            A: [memory_size, memory_size]
            B: [memory_size, 1]
        """
        Q = np.arange(self.memory_size, dtype = np.float32)
        R = (2*Q + 1) / self.theta
        i, j = np.meshgrid(Q, Q, indexing = "ij")

        # Continuous
        # a_ij = (2i + 1)/theta * (-1 if i < j else (-1)^(i - j + 1)):
        # the factor depends on the ROW index i, so R must be broadcast as a
        # column vector (a plain 1D R would scale the columns instead).
        A = R[:, None] * np.where(i < j, -1, (-1.0)**(i - j + 1))
        # b_i = (2i + 1)/theta * (-1)^i
        B = R * ((-1.0)**Q)
        B = B.reshape(-1, 1)
        C = np.ones((1, self.memory_size))
        D = np.zeros((1,))

        # Convert to discrete (zero-order hold, unit time step)
        A, B, C, D, dt = cont2discrete(
                                system = (A, B, C, D), 
                                dt = 1.0, 
                                method = "zoh"
                            )
            
        return A, B

## The Classification Model

In [3]:
#---------------------- Parallelized LMU for pMNIST Classification ----------------------#
class Model(nn.Module):
    """Sequence classifier: an LMUFFT layer followed by a dense read-out.

    Only the hidden state of the last timestep is passed to the output layer.

    Parameters:
        input_size (int) : size of the input vector at each timestep
        output_size (int) : number of output features (logits)
        hidden_size (int) : size of the LMUFFT hidden state
        memory_size (int) : size of the LMUFFT memory
        theta (int) : sliding-window length forwarded to LMUFFT
        seq_len (int) : sequence length forwarded to LMUFFT
    """
    input_size: int
    output_size: int
    hidden_size: int
    memory_size: int
    theta: int
    seq_len:int

    @nn.compact
    def __call__(self, x):
        """Maps x: [batch_size, seq_len, input_size] to logits: [batch_size, output_size]."""
        lmu_layer = LMUFFT(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            memory_size=self.memory_size,
            seq_len=self.seq_len,
            theta=self.theta,
        )
        # The per-timestep hidden states are not needed for classification.
        _, last_hidden = lmu_layer(x)
        activated = nn.relu(last_hidden)
        readout = nn.Dense(features=self.output_size)
        return readout(activated)

## Hyperparameters

In [4]:
# Model / data dimensions
N_x = 1 # dimension of the input at each time step: a single pixel
N_t = 784 # number of time steps (sequence length): 28 * 28, since each MNIST image is flattened to 1D
N_h = 128 # dimension of the hidden state
N_m = 64 # dimension of the memory
N_c = 10 # number of classes
# Sliding-window length passed to the LMU as `theta`; set to the full sequence length
THETA = N_t
# Training schedule
N_b = 200 # batch size
N_epochs = 20 # number of epochs


lr=1e-4 # learning rate for the Adam optimizer

## Load Data

In [5]:
def get_datasets():
    """Load the MNIST train and test splits fully into memory.

    Returns:
        (train_ds, val_ds): two dicts as produced by tfds; the 'image' entry
        of each is cast to float32. `val_ds` is the tfds 'test' split.

    NOTE(review): pixel values are only cast, not rescaled, and no pixel
    permutation is applied here -- confirm this is intended for a
    "permuted MNIST" experiment.
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    def _load_split(split_name):
        # batch_size=-1 returns the whole split as a single batch.
        split = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        split['image'] = jnp.float32(split['image'])
        return split

    return _load_split('train'), _load_split('test')


# Load both splits once at module level; `val_ds` is the MNIST 'test' split.
train_ds, val_ds = get_datasets()

2023-01-05 14:59:14.715023: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2023-01-05 14:59:14.715047: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


## Loss & Metrics

In [6]:
"""Following the Flax example: https://flax.readthedocs.io/en/latest/getting_started.html"""

def cross_entropy_loss(*, logits, labels):
    """Mean softmax cross-entropy between logits and integer class labels.

    Parameters:
        logits: [..., num_classes]; unnormalized class scores
        labels: [...]; integer class indices

    Returns:
        Scalar: the cross-entropy averaged over all leading dimensions.
    """
    # Take the number of classes from the logits instead of hard-coding 10,
    # so the one-hot targets always match the model's output width.
    num_classes = logits.shape[-1]
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_metrics(*, logits, labels):
    """Compute loss and accuracy for a batch of predictions.

    Parameters:
        logits: [batch_size, num_classes]; unnormalized class scores
        labels: [batch_size]; integer class indices

    Returns:
        dict with the keys 'loss' (mean cross-entropy) and 'accuracy'
        (fraction of samples whose arg-max logit equals the label).
    """
    predicted_classes = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted_classes == labels),
    }

## Training

### Utility Functions

In [7]:
def create_train_state(rng, learning_rate=lr):
    """Build the model, initialize its parameters and wrap them in a TrainState.

    Parameters:
        rng: PRNG key used to initialize the model parameters
        learning_rate: learning rate of the Adam optimizer

    Returns:
        A flax TrainState holding the apply function, parameters and optimizer.
    """
    hyperparams = dict(
        input_size=N_x,
        output_size=N_c,
        hidden_size=N_h,
        memory_size=N_m,
        seq_len=N_t,
        theta=THETA,
    )
    model = Model(**hyperparams)

    # A dummy batch of one sequence is enough to trace the parameter shapes.
    dummy_batch = jnp.ones((1, N_t, N_x))
    params = model.init(rng, dummy_batch)['params']
    print("Model initialized.")

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optax.adam(learning_rate),
    )

In [8]:
@jit
def train_step(state, batch):
    """Run one optimization step on a single batch.

    The step is jit-compiled so that the forward/backward passes are traced
    once and reused, instead of being dispatched op by op (and re-running the
    model's setup) on every call.

    Parameters:
        state: the current TrainState
        batch: dict with 'image' [batch_size, seq_len, input_size] and
            'label' [batch_size]

    Returns:
        (new_state, metrics): the updated TrainState and a dict with 'loss'
        and 'accuracy', evaluated on the same batch with the UPDATED parameters.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        return cross_entropy_loss(logits=logits, labels=batch['label'])

    # Only the gradients are needed; the loss value is recomputed below
    # from the updated parameters.
    _, grads = value_and_grad(loss_fn, has_aux=False)(state.params)
    new_state = state.apply_gradients(grads=grads)

    logits = new_state.apply_fn({'params': new_state.params}, batch['image'])
    metrics = compute_metrics(logits=logits, labels=batch['label'])

    return new_state, metrics

In [9]:
def eval_step(state, batch, seq_len=N_t, input_size=N_x):
    """Evaluate the model on a batch without updating any parameters.

    Parameters:
        state: TrainState providing `apply_fn` and `params`
        batch: dict with 'image' (reshapeable to [-1, seq_len, input_size])
            and 'label'
        seq_len: number of timesteps per sequence
        input_size: number of features per timestep

    Returns:
        dict with 'loss' and 'accuracy' for the batch.
    """
    # Reshape into a local variable: writing back into `batch` would silently
    # modify the caller's dataset dict.
    images = batch['image'].reshape((-1, seq_len, input_size))
    logits = state.apply_fn({'params': state.params}, images)
    return compute_metrics(logits=logits, labels=batch['label'])


In [10]:
def train_epoch(state, train_ds, epoch, rng, batch_size=N_b, seq_len=N_t, input_size=N_x):
    """Train for one epoch over a freshly shuffled copy of the training set.

    Parameters:
        state: the current TrainState
        train_ds: dict with 'image' and 'label' arrays of equal length
        epoch: epoch index, used only for logging
        rng: PRNG key used to shuffle the sample order
        batch_size: number of samples per training step
        seq_len: number of timesteps per sequence
        input_size: number of features per timestep

    Returns:
        The TrainState after the last step of the epoch.

    Raises:
        ValueError: if the dataset holds fewer than `batch_size` samples.
    """
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            f"Dataset of size {train_ds_size} is smaller than one batch "
            f"(batch_size={batch_size})."
        )

    perms = jax.random.permutation(rng, train_ds_size) # get a randomized index array
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size)) # index array, where each row is a batch

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()} # dict{'image': array, 'label': array}
        # reshape to the input dimensions required by the model
        batch['image'] = batch['image'].reshape((batch_size, seq_len, input_size))
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]} # jnp.mean does not work on lists

    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

    return state

In [11]:
def eval_model(state, test_ds):
    """Evaluate the model on a whole dataset in one step.

    Parameters:
        state: TrainState providing the parameters to evaluate
        test_ds: dict with 'image' and 'label' arrays

    Returns:
        (loss, accuracy) as Python scalars.
    """
    host_metrics = jax.device_get(eval_step(state, test_ds))
    # Convert every leaf of the metrics dict to a Python scalar.
    scalars = jax.tree_util.tree_map(lambda leaf: leaf.item(), host_metrics)
    return scalars['loss'], scalars['accuracy']

### Training Loop

In [None]:
# Random seed
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Initialize model with the dedicated init key; `rng` itself is split again
# below, so it must not also be consumed for parameter initialization.
state = create_train_state(init_rng, learning_rate=lr)
del init_rng

for epoch in range(N_epochs):
    # Use a separate PRNG key to shuffle the training samples in each epoch
    rng, input_rng = jax.random.split(rng)
    state = train_epoch(state, train_ds, epoch, input_rng)
    test_loss, test_accuracy = eval_model(state, val_ds)
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))