# Legendre Memory Unit (LMU)

The Legendre Memory Unit ([LMU; Voelker et al.,
2019](http://compneuro.uwaterloo.ca/files/publications/voelker.2019.lmu.pdf)) is a
dynamical system that maintains a rolling window of continuous-time input history. It is
implementable on neuromorphic hardware using Nengo. Here we show how it can be implemented
easily with the `gyrus.lti` operator, which also makes it straightforward to apply the LMU
not only to scalars (as in the original paper) but also to arbitrary input ndarrays (i.e.,
tensors) using spiking neurons.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import nengo
from nengo_gui.ipython import InlineGUI

import numpy as np

import seaborn as sns

import gyrus

In [None]:
def lmu(theta, order):
    """Build the ``(A, B)`` linear system of a Legendre Memory Unit.

    Follows equation 1 of Voelker et al. (2019), with both matrices already
    divided by the window length ``theta``.

    Parameters
    ----------
    theta : float
        Length of the rolling window represented by the memory.
    order : int
        Number of Legendre coefficients, i.e. the state dimensionality.

    Returns
    -------
    A : ndarray of shape (order, order)
    B : ndarray of shape (order, 1)
    """
    k = np.arange(order, dtype=np.float64)
    rows = k[:, None]  # column vector of row indices i
    cols = k[None, :]  # row vector of column indices j
    # Each row i is scaled by (2i + 1) / theta.
    scale = (2 * rows + 1) / theta

    # Entries strictly above the diagonal are -1; on/below it they alternate
    # in sign as (-1)^(i - j + 1).
    sign_pattern = np.where(rows < cols, -1, (-1.0) ** (rows - cols + 1))
    A = sign_pattern * scale
    B = (-1.0) ** rows * scale
    return A, B

In [None]:
# LMU system matrices: window length 1.0, 8 Legendre dimensions.
AB = lmu(theta=1.0, order=8)

# Settings for the spiking ensembles that decode the LMU state.
ens_kwargs = {
    "n_neurons": 250,
    "neuron_type": nengo.SpikingRectifiedLinear(),
    "max_rates": nengo.dists.Uniform(1e2, 1e3),
}

# 20 step inputs; step i switches from 0 to 1 once t exceeds its onset,
# with onsets spaced evenly over [0, 0.1].
step = gyrus.stimulus(
    [lambda t, i=i: 1 if t > i else 0 for i in np.linspace(0, 0.1, 20)]
)


def spiking_state(latent):
    """Pass the LTI state through split -> spiking decode -> join."""
    return latent.split().decode(**ens_kwargs).join()


# Run the same LMU twice: once with a spiking state, once as the ideal system.
x = gyrus.fold(
    [
        step.lti(AB, state=spiking_state),
        step.lti(AB),
    ]
)

out_hat, out_ideal = np.asarray(x.run(1.1))
assert out_hat.shape == out_ideal.shape
print(out_hat.shape)  # (shifted step functions, time steps, lmu dimensions)

In [None]:
# One color per shifted step input; all LMU dimensions of that input share it.
colors = sns.color_palette("mako", n_colors=out_hat.shape[0])

plt.figure(figsize=(16, 4))
plt.title("Legendre Memory Unit (LMU) - Shifted Step Responses")
for color, spiking, ideal in zip(colors, out_hat, out_ideal):
    # Each of spiking/ideal is (time steps, lmu dimensions); plt.plot draws one
    # line per column, i.e. one line per LMU dimension.
    # The spiking estimate was previously drawn with alpha=0.0 (invisible);
    # make it faintly visible so it can be compared with the ideal response.
    plt.plot(spiking, color=color, alpha=0.3)
    plt.plot(ideal, color=color, linestyle="--")  # ideal (non-spiking) LTI
plt.xlabel("Time-step")
plt.show()

In [None]:
# Instantiate `x` inside a fresh Nengo network so it can be shown in the GUI.
model = nengo.Network()
with model:
    x.make()

InlineGUI(model)