added memory_d>1 functionality #40

Merged: 3 commits, Jun 16, 2021
9 changes: 9 additions & 0 deletions CHANGES.rst
@@ -22,6 +22,15 @@ Release history
0.3.2 (unreleased)
==================

**Added**

- Setting ``kernel_initializer=None`` now removes the dense input kernel. (`#40`_)
- The ``keras_lmu.LMUFFT`` layer now supports ``memory_d > 1``. ``keras_lmu.LMU`` now
uses this implementation for all values of ``memory_d`` when feedforward conditions
are satisfied (no hidden-to-memory or memory-to-memory connections,
and the sequence length is not ``None``). (`#40`_)

.. _#40: https://github.com/nengo/keras-lmu/pull/40

0.3.1 (November 16, 2020)
=========================
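Taken together, the changelog entries above change how the layers are constructed. A minimal usage sketch of both features (the dimensions and the SimpleRNNCell hidden cell are illustrative, not part of the PR):

import tensorflow as tf
import keras_lmu

inputs = tf.keras.Input(shape=(100, 8))  # (timesteps, input_d), fixed sequence length

# memory_d > 1 now runs on the feedforward (FFT) implementation, since there are
# no hidden-to-memory or memory-to-memory connections and the length is known.
lmu = keras_lmu.LMU(
    memory_d=4,
    order=16,
    theta=100,
    hidden_cell=tf.keras.layers.SimpleRNNCell(32),
)
hidden = lmu(inputs)

# kernel_initializer=None removes the dense input kernel entirely; the input
# dimension must then equal memory_d (both are 8 here), or a ValueError is raised.
lmu_no_kernel = keras_lmu.LMU(
    memory_d=8,
    order=16,
    theta=100,
    hidden_cell=tf.keras.layers.SimpleRNNCell(32),
    kernel_initializer=None,
)
hidden_no_kernel = lmu_no_kernel(inputs)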
86 changes: 53 additions & 33 deletions keras_lmu/layers.py
@@ -48,7 +48,8 @@ class to create a recurrent Keras layer to process the whole sequence. Calling
If True, connect the input directly to the hidden component (in addition to
the connection from the memory component) (default False).
kernel_initializer : ``tf.initializers.Initializer``
Initializer for weights from input to memory/hidden component.
Initializer for weights from input to memory/hidden component. If ``None``,
no weights will be used, and the input size must match the memory/hidden size.
recurrent_initializer : ``tf.initializers.Initializer``
Initializer for ``memory_to_memory`` weights (if that connection is enabled).
dropout : float
@@ -146,11 +147,19 @@ def build(self, input_shape):
if self.hidden_to_memory:
enc_d += self.hidden_output_size

self.kernel = self.add_weight(
name="kernel",
shape=(enc_d, self.memory_d),
initializer=self.kernel_initializer,
)
if self.kernel_initializer is not None:
self.kernel = self.add_weight(
name="kernel",
shape=(enc_d, self.memory_d),
initializer=self.kernel_initializer,
)
else:
self.kernel = None
if enc_d != self.memory_d:
raise ValueError(
f"For LMUCells with no input kernel, the input dimension ({enc_d})"
f" must equal `memory_d` ({self.memory_d})."
)

if self.memory_to_memory:
self.recurrent_kernel = self.add_weight(
@@ -207,7 +216,7 @@ def call(self, inputs, states, training=None):
u_in = tf.concat((inputs, h[0]), axis=1) if self.hidden_to_memory else inputs
if self.dropout > 0:
u_in *= self.get_dropout_mask_for_cell(u_in, training)
u = tf.matmul(u_in, self.kernel)
u = u_in if self.kernel is None else tf.matmul(u_in, self.kernel)

if self.memory_to_memory:
if self.recurrent_dropout > 0:
@@ -328,7 +337,8 @@ class LMU(tf.keras.layers.Layer):
If True, connect the input directly to the hidden component (in addition to
the connection from the memory component) (default False).
kernel_initializer : ``tf.initializers.Initializer``
Initializer for weights from input to memory/hidden component.
Initializer for weights from input to memory/hidden component. If ``None``,
no weights will be used, and the input size must match the memory/hidden size.
recurrent_initializer : ``tf.initializers.Initializer``
Initializer for ``memory_to_memory`` weights (if that connection is enabled).
dropout : float
@@ -398,7 +408,6 @@ def build(self, input_shapes):
if (
not self.hidden_to_memory
and not self.memory_to_memory
and self.memory_d == 1
and input_shapes[1] is not None
):
self.layer = LMUFFT(
@@ -507,7 +516,8 @@ class LMUFFT(tf.keras.layers.Layer):
If True, connect the input directly to the hidden component (in addition to
the connection from the memory component) (default False).
kernel_initializer : ``tf.initializers.Initializer``
Initializer for weights from input to memory/hidden component.
Initializer for weights from input to memory/hidden component. If ``None``,
no weights will be used, and the input size must match the memory/hidden size.
dropout : float
Dropout rate on input connections.
return_sequences : bool, optional
@@ -529,13 +539,6 @@ def __init__(
):
super().__init__(**kwargs)

if memory_d != 1:
# TODO: we can support this by reusing the same impulse response
# for each dimension
raise NotImplementedError(
"Multi-dimensional memory not supported in LMUFFT"
)

if input_to_hidden and hidden_cell is None:
raise ValueError("input_to_hidden must be False if hidden_cell is None")

@@ -548,17 +551,17 @@ def __init__(
self.dropout = dropout
self.return_sequences = return_sequences

# create a standard LMUCell to generate the impulse response during `build`
self.delay_layer = tf.keras.layers.RNN(
LMUCell(
memory_d=memory_d,
memory_d=1,
order=order,
theta=theta,
hidden_cell=None,
input_to_hidden=False,
hidden_to_memory=False,
memory_to_memory=False,
kernel_initializer="ones",
dropout=0,
kernel_initializer=None,
trainable=False,
),
return_sequences=True,
@@ -577,26 +580,37 @@ def build(self, input_shape):

super().build(input_shape)

if input_shape[1] is None:
seq_len = input_shape[1]
enc_d = input_shape[-1]

if seq_len is None:
# TODO: we could dynamically run the impulse response for longer if
# needed using stateful=True
raise ValueError(
f"LMUFFT requires that the input shape's temporal axis be fully "
f"specified (got {input_shape[1]})"
f"specified (got {seq_len})"
)

impulse = tf.reshape(tf.eye(input_shape[1], 1), (1, -1, 1))
impulse = tf.reshape(tf.eye(seq_len, 1), (1, -1, 1))

self.impulse_response = tf.signal.rfft(
tf.squeeze(tf.transpose(self.delay_layer(impulse)), axis=-1),
fft_length=[2 * input_shape[1]],
fft_length=[2 * seq_len],
)

self.kernel = self.add_weight(
name="kernel",
shape=(input_shape[-1], self.memory_d),
initializer=self.kernel_initializer,
)
if self.kernel_initializer is not None:
self.kernel = self.add_weight(
name="kernel",
shape=(input_shape[-1], self.memory_d),
initializer=self.kernel_initializer,
)
else:
self.kernel = None
if enc_d != self.memory_d:
raise ValueError(
f"For LMUCells with no input kernel, the input dimension ({enc_d})"
f" must equal `memory_d` ({self.memory_d})."
)

if self.hidden_cell is not None and not self.hidden_cell.built:
hidden_input_d = self.memory_d * self.order
@@ -627,20 +641,26 @@ def call(self, inputs, training=None):
)(inputs)

# Apply input encoders
u = tf.matmul(inputs, self.kernel, name="input_encoder_mult")
# FFT requires shape (batch, 1, timesteps)
u = (
inputs
if self.kernel is None
else tf.matmul(inputs, self.kernel, name="input_encoder_mult")
)

# FFT requires shape (batch, memory_d, timesteps)
u = tf.transpose(u, perm=[0, 2, 1])

# Pad sequences to avoid circular convolution
# Perform the FFT
fft_input = tf.signal.rfft(u, fft_length=[2 * seq_len], name="input_pad")

# Elementwise product of FFT (broadcasting done automatically)
result = fft_input * self.impulse_response
# Elementwise product of FFT (with broadcasting)
result = tf.expand_dims(fft_input, axis=-2) * self.impulse_response

# Inverse FFT
m = tf.signal.irfft(result, fft_length=[2 * seq_len])[..., :seq_len]

m = tf.reshape(m, (-1, self.order * self.memory_d, seq_len))
m = tf.transpose(m, perm=[0, 2, 1])

# apply hidden cell
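The call changes above rest on the fact that, with no hidden-to-memory or memory-to-memory connections, the memory update is a linear time-invariant system: each of the memory_d encoded input signals is convolved with the same impulse response (computed once in build from the frozen single-dimension LMUCell), and the added expand_dims broadcasts that shared response across all memory dimensions. Zero-padding the FFTs to twice the sequence length keeps the circular wrap-around out of the samples that are kept. A standalone sanity check of that padding argument (toy shapes, not library code):

import numpy as np
import tensorflow as tf

rng = np.random.RandomState(0)
seq_len = 10
u = rng.uniform(-1, 1, size=seq_len).astype(np.float32)  # one encoded input signal
h = rng.uniform(-1, 1, size=seq_len).astype(np.float32)  # stand-in impulse response

# Direct causal convolution: m[t] = sum over k <= t of h[k] * u[t - k]
direct = np.array([sum(h[k] * u[t - k] for k in range(t + 1)) for t in range(seq_len)])

# FFT-based convolution, padded to 2 * seq_len so the circular part falls
# outside the first seq_len samples that are sliced off afterwards.
fft_u = tf.signal.rfft(u, fft_length=[2 * seq_len])
fft_h = tf.signal.rfft(h, fft_length=[2 * seq_len])
fft_based = tf.signal.irfft(fft_u * fft_h, fft_length=[2 * seq_len])[:seq_len]

assert np.allclose(direct, fft_based.numpy(), atol=1e-5)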
76 changes: 50 additions & 26 deletions keras_lmu/tests/test_layers.py
@@ -58,45 +58,48 @@ def test_multivariate_lmu(rng):

for i in range(memory_d):
assert np.allclose(
results[0][..., i * order : (i + 1) * order], results[i + 1], atol=1e-6
results[0][..., i * order : (i + 1) * order], results[i + 1], atol=2e-6
)


def test_layer_vs_cell(rng):
memory_d = 4
order = 12
@pytest.mark.parametrize("has_input_kernel", (True, False))
@pytest.mark.parametrize("fft", (True, False))
def test_layer_vs_cell(has_input_kernel, fft, rng):
n_steps = 10
input_d = 32
kwargs = dict(
memory_d=4 if has_input_kernel else input_d,
order=12,
theta=n_steps,
kernel_initializer="glorot_uniform" if has_input_kernel else None,
memory_to_memory=not fft,
)
hidden_cell = lambda: tf.keras.layers.SimpleRNNCell(units=64)

inp = rng.uniform(-1, 1, size=(2, n_steps, input_d))

lmu_cell = tf.keras.layers.RNN(
layers.LMUCell(
memory_d, order, n_steps, tf.keras.layers.SimpleRNNCell(units=64)
),
layers.LMUCell(hidden_cell=hidden_cell(), **kwargs),
return_sequences=True,
)
cell_out = lmu_cell(inp)

lmu_layer = layers.LMU(
memory_d,
order,
n_steps,
tf.keras.layers.SimpleRNNCell(units=64),
return_sequences=True,
)
lmu_layer = layers.LMU(return_sequences=True, hidden_cell=hidden_cell(), **kwargs)
lmu_layer.build(inp.shape)
lmu_layer.layer.set_weights(lmu_cell.get_weights())
layer_out = lmu_layer(inp)

assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)

for w0, w1 in zip(
sorted(lmu_cell.weights, key=lambda w: w.shape.as_list()),
sorted(lmu_layer.weights, key=lambda w: w.shape.as_list()),
):
assert np.allclose(w0.numpy(), w1.numpy())

assert np.allclose(cell_out, lmu_cell(inp))
assert np.allclose(cell_out, layer_out)
atol = 2e-6 if fft else 1e-8
assert np.allclose(cell_out, lmu_cell(inp), atol=atol)
assert np.allclose(cell_out, layer_out, atol=atol)


def test_save_load_weights(rng, tmp_path):
@@ -183,18 +186,26 @@ def test_save_load_serialization(mode, tmp_path):

@pytest.mark.parametrize("return_sequences", (True, False))
@pytest.mark.parametrize(
"hidden_cell", (None, tf.keras.layers.Dense(4), tf.keras.layers.SimpleRNNCell(4))
"hidden_cell",
(
lambda: None,
lambda: tf.keras.layers.Dense(4),
lambda: tf.keras.layers.SimpleRNNCell(4),
),
)
def test_fft(return_sequences, hidden_cell, rng):
@pytest.mark.parametrize("memory_d", [1, 4])
def test_fft(return_sequences, hidden_cell, memory_d, rng):
kwargs = dict(memory_d=memory_d, order=2, theta=3, hidden_cell=hidden_cell())

x = rng.uniform(-1, 1, size=(2, 10, 32))

rnn_layer = tf.keras.layers.RNN(
layers.LMUCell(1, 2, 3, hidden_cell),
layers.LMUCell(**kwargs),
return_sequences=return_sequences,
)
rnn_out = rnn_layer(x)

fft_layer = layers.LMUFFT(1, 2, 3, hidden_cell, return_sequences=return_sequences)
fft_layer = layers.LMUFFT(return_sequences=return_sequences, **kwargs)
fft_layer.build(x.shape)
fft_layer.kernel.assign(rnn_layer.cell.kernel)
fft_out = fft_layer(x)
Expand Down Expand Up @@ -239,7 +250,7 @@ def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d, steps):
lmu.build((32, steps, 8))

assert isinstance(lmu.layer, tf.keras.layers.RNN) == (
hidden_to_memory or memory_to_memory or memory_d != 1 or steps is None
hidden_to_memory or memory_to_memory or steps is None
)


@@ -417,9 +428,22 @@ def test_fit(fft):

_, acc = model.evaluate(x_test, y_test, verbose=0)

if fft:
assert isinstance(lmu_layer.layer, layers.LMUFFT)
else:
assert isinstance(lmu_layer.layer, tf.keras.layers.RNN)

assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)
assert acc == 1.0


@pytest.mark.parametrize("fft", (True, False))
def test_no_input_kernel_dimension_mismatch(fft):
lmu_layer = layers.LMU(
memory_d=1,
order=4,
theta=4,
hidden_cell=tf.keras.layers.SimpleRNNCell(units=10),
hidden_to_memory=False,
memory_to_memory=not fft,
input_to_hidden=not fft,
kernel_initializer=None,
)

with pytest.raises(ValueError, match="no input kernel"):
lmu_layer(tf.ones((4, 10, 2)))
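
For contrast with the failure case exercised here, the same configuration passes when the last input dimension matches memory_d; a sketch (not part of the PR's test suite):

lmu_layer = layers.LMU(
    memory_d=2,
    order=4,
    theta=4,
    hidden_cell=tf.keras.layers.SimpleRNNCell(units=10),
    kernel_initializer=None,
)
lmu_layer(tf.ones((4, 10, 2)))  # input dimension 2 == memory_d, so this builds and runs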