jax-lmu

An unofficial implementation of the Legendre Memory Unit (LMU) in JAX (a personal-interest project, and a work in progress! 😅)

Original paper: Legendre Memory Units: Continuous-Time Representation in Recurrent Neural Networks; the original paper's repo uses Keras.

A follow-up paper parallelizes LMU training by replacing the sequential recurrence with an FFT-based convolution; its PyTorch implementation is what this adaptation is mostly based on.
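
As a rough sketch of that idea (my own variable names here, not the repo's API): the discretized LMU memory update m_t = Ā m_{t-1} + B̄ u_t unrolls into a causal convolution of the input with the kernel [B̄, Ā B̄, Ā² B̄, ...], which an FFT evaluates in O(n log n) instead of a step-by-step scan:

import jax
import jax.numpy as jnp

def impulse_response(A_bar, B_bar, seq_len):
    # Stack [B, AB, A^2 B, ...]: the memory's response to a unit impulse.
    def step(A_pow, _):
        return A_bar @ A_pow, A_pow @ B_bar
    _, H = jax.lax.scan(step, jnp.eye(A_bar.shape[0]), None, length=seq_len)
    return H  # (seq_len, memory_dim)

def lmu_memory_fft(u, A_bar, B_bar):
    # m_t = sum_k A^(t-k) B u_k, a causal convolution of the scalar input
    # sequence u with the impulse response; zero-padding to 2n makes the
    # circular FFT convolution agree with the linear one.
    n = u.shape[0]
    H = impulse_response(A_bar, B_bar, n)
    Hf = jnp.fft.rfft(H, n=2 * n, axis=0)
    Uf = jnp.fft.rfft(u, n=2 * n)
    m = jnp.fft.irfft(Hf * Uf[:, None], n=2 * n, axis=0)
    return m[:n]  # (seq_len, memory_dim)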

Models

The original (recurrent) LMU and its training/evaluation scripts are jax_lmu.py and jax_lmu_run_exp.py (also runnable from the jax_lmu_mnist.ipynb notebook);

The parallelized version of the LMU and its training/evaluation scripts are jax_lmufft.py and jax_lmufft_run_exp.py (also runnable from the jax_lmu_parallel_mnist.ipynb notebook).
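
A hypothetical invocation (I have not listed any flags; check each script for the arguments it actually accepts):

python jax_lmu_run_exp.py       # recurrent LMU on MNIST
python jax_lmufft_run_exp.py    # parallelized (FFT) LMU on MNIST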

Environment Setup

Minimum Requirements

Python version is 3.9.15. Required packages are installed with a mixture of conda and pip:

conda install cudatoolkit=11.3
conda install cudnn=8.2.1
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax==0.6.2
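
After installing, a quick sanity check that JAX actually sees the GPU (the exact device repr varies across JAX versions):

python -c "import jax; print(jax.devices())"   # expect a GPU device, not just CpuDevice(id=0)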

Experiments on MNIST

The MNIST dataset is loaded through TensorFlow Datasets:

pip install tqdm
pip install tensorflow tensorflow_datasets
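
As a minimal sketch of how that loading can look from the JAX side (the batch size and the flattening to a pixel sequence are my assumptions, not necessarily what the scripts do):

import tensorflow as tf
import tensorflow_datasets as tfds

train_ds = tfds.load("mnist", split="train", as_supervised=True)
train_ds = train_ds.batch(128).prefetch(tf.data.AUTOTUNE)

for images, labels in tfds.as_numpy(train_ds):
    # images: (128, 28, 28, 1) uint8 NumPy arrays, ready for jnp.asarray;
    # for a sequential-MNIST task each image is typically flattened into
    # a length-784 pixel sequence.
    seq = images.reshape(images.shape[0], -1, 1).astype("float32") / 255.0
    break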
