Fix bug with hidden_to_memory and no hidden_cell #26

Merged: 3 commits, Nov 16, 2020
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -5,3 +5,9 @@ repos:
    rev: 20.8b0
    hooks:
      - id: black
        files: \.py$
  - repo: https://github.com/pycqa/isort
    rev: 5.6.4
    hooks:
      - id: isort
        files: \.py$
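
With the isort hook added alongside black, both checks run in one pass. A minimal sketch (assuming pre-commit is installed) of driving the configured hooks from Python rather than a shell:

```python
import subprocess

# Run every configured hook (black plus the new isort hook) across the repo,
# equivalent to running "pre-commit run --all-files" in a shell.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```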
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -22,6 +22,13 @@ Release history
0.3.1 (unreleased)
==================

**Changed**

- Raise a validation error if ``hidden_to_memory`` or ``input_to_hidden`` is True
  when ``hidden_cell=None``. (`#26`_)

.. _#26: https://github.com/nengo/keras-lmu/pull/26


0.3.0 (November 6, 2020)
========================
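
A minimal sketch of the validation behavior described in this changelog entry, mirroring the tests added later in this PR (assumes this branch of keras-lmu plus pytest are installed):

```python
import pytest

from keras_lmu import layers

# Positional arguments are (memory_d, order, theta, hidden_cell). With no
# hidden cell there is no hidden state to feed back into the memory or to
# receive a skip connection, so these flags are now rejected up front.
with pytest.raises(ValueError, match="hidden_to_memory must be False"):
    layers.LMUCell(1, 2, 3, None, hidden_to_memory=True)

with pytest.raises(ValueError, match="input_to_hidden must be False"):
    layers.LMUFFT(1, 2, 3, None, input_to_hidden=True)
```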
7 changes: 1 addition & 6 deletions keras_lmu/__init__.py
@@ -1,11 +1,6 @@
"""KerasLMU provides a package for deep learning with Legendre Memory Units."""

from .layers import (
    LMUCell,
    LMUFFT,
    LMU,
)

from .layers import LMU, LMUFFT, LMUCell
from .version import version as __version__

__copyright__ = "2019-2020, Applied Brain Research"
23 changes: 12 additions & 11 deletions keras_lmu/layers.py
@@ -3,8 +3,8 @@
"""

import numpy as np
from scipy.signal import cont2discrete
import tensorflow as tf
from scipy.signal import cont2discrete
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin


@@ -101,11 +101,11 @@ def __init__(
        self.B = None

        if self.hidden_cell is None:
            # if input_to_hidden=True then we can't determine the output size
            # until build time
            self.hidden_output_size = (
                None if input_to_hidden else self.memory_d * self.order
            )
            for conn in ("hidden_to_memory", "input_to_hidden"):
                if getattr(self, conn):
                    raise ValueError(f"{conn} must be False if hidden_cell is None")

            self.hidden_output_size = self.memory_d * self.order
            self.hidden_state_size = []
        elif hasattr(self.hidden_cell, "state_size"):
            self.hidden_output_size = self.hidden_cell.output_size
@@ -142,10 +142,6 @@ def build(self, input_shape):

        super().build(input_shape)

        if self.input_to_hidden and self.hidden_cell is None:
            self.hidden_output_size = self.memory_d * self.order + input_shape[-1]
            self.output_size = self.hidden_output_size

        enc_d = input_shape[-1]
        if self.hidden_to_memory:
            enc_d += self.hidden_output_size
@@ -526,7 +522,12 @@ def __init__(
        if memory_d != 1:
            # TODO: we can support this by reusing the same impulse response
            # for each dimension
            raise NotImplementedError("Multi-dimensional memory not supported")
            raise NotImplementedError(
                "Multi-dimensional memory not supported in LMUFFT"
            )

        if input_to_hidden and hidden_cell is None:
            raise ValueError("input_to_hidden must be False if hidden_cell is None")

        self.memory_d = memory_d
        self.order = order
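
One consequence of the change above: with `hidden_cell=None`, `hidden_output_size` is now fixed at construction time rather than potentially deferred to `build` (the removed block in `build` handled the old `input_to_hidden` case). A small sketch of the resulting behavior, assuming this branch is installed:

```python
from keras_lmu import layers

cell = layers.LMUCell(memory_d=2, order=8, theta=10, hidden_cell=None)

# With no hidden cell, the layer exposes the raw memory directly, so the
# output size is always memory_d * order, independent of the input shape.
assert cell.hidden_output_size == 2 * 8
assert cell.hidden_state_size == []
```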
88 changes: 40 additions & 48 deletions keras_lmu/tests/test_layers.py
@@ -202,12 +202,20 @@ def test_fft(return_sequences, hidden_cell, rng):
    assert np.allclose(rnn_out, fft_out, atol=2e-6)


def test_fft_errors():
def test_validation_errors():
    fft_layer = layers.LMUFFT(1, 2, 3, None)

    with pytest.raises(ValueError, match="temporal axis be fully specified"):
        fft_layer(tf.keras.Input((None, 32)))

    with pytest.raises(ValueError, match="hidden_to_memory must be False"):
        layers.LMUCell(1, 2, 3, None, hidden_to_memory=True)

    with pytest.raises(ValueError, match="input_to_hidden must be False"):
        layers.LMUCell(1, 2, 3, None, input_to_hidden=True)

    with pytest.raises(ValueError, match="input_to_hidden must be False"):
        layers.LMUFFT(1, 2, 3, None, input_to_hidden=True)


@pytest.mark.parametrize(
"hidden_to_memory, memory_to_memory, memory_d",
@@ -218,7 +226,7 @@ def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d):
        memory_d,
        2,
        3,
        None,
        tf.keras.layers.Dense(5),
        hidden_to_memory=hidden_to_memory,
        memory_to_memory=memory_to_memory,
    )
@@ -269,49 +277,64 @@ def test_hidden_types(hidden_cell, fft, rng, seed):


@pytest.mark.parametrize("fft", (True, False))
def test_connection_params(fft):
@pytest.mark.parametrize("hidden_cell", (None, tf.keras.layers.Dense))
def test_connection_params(fft, hidden_cell):
    input_shape = (32, 7 if fft else None, 6)

    x = tf.keras.Input(batch_shape=input_shape)

    lmu_args = dict(
        memory_d=1,
        order=3,
        theta=4,
        hidden_cell=tf.keras.layers.Dense(units=5),
        hidden_cell=hidden_cell if hidden_cell is None else hidden_cell(units=5),
        input_to_hidden=False,
    )
    if not fft:
        lmu_args["hidden_to_memory"] = False
        lmu_args["memory_to_memory"] = False

    lmu = layers.LMUCell(**lmu_args) if not fft else layers.LMUFFT(**lmu_args)
    lmu.build(input_shape)
    y = lmu(x) if fft else tf.keras.layers.RNN(lmu)(x)
    assert lmu.kernel.shape == (input_shape[-1], lmu.memory_d)
    if not fft:
        assert lmu.recurrent_kernel is None
    assert lmu.hidden_cell.kernel.shape == (
        lmu.memory_d * lmu.order,
        lmu.hidden_cell.units,
    if hidden_cell is not None:
        assert lmu.hidden_cell.kernel.shape == (
            lmu.memory_d * lmu.order,
            lmu.hidden_cell.units,
        )
    assert y.shape == (
        input_shape[0],
        lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
    )

    lmu_args["input_to_hidden"] = True
    lmu_args["input_to_hidden"] = hidden_cell is not None
    if not fft:
        lmu_args["hidden_to_memory"] = True
        lmu_args["hidden_to_memory"] = hidden_cell is not None
        lmu_args["memory_to_memory"] = True

    lmu = layers.LMUCell(**lmu_args) if not fft else layers.LMUFFT(**lmu_args)
    lmu.hidden_cell.built = False  # so that the kernel will be rebuilt
    lmu.build(input_shape)
    if hidden_cell is not None:
        lmu.hidden_cell.built = False  # so that the kernel will be rebuilt
    y = lmu(x) if fft else tf.keras.layers.RNN(lmu)(x)
    assert lmu.kernel.shape == (
        input_shape[-1] + (lmu.hidden_cell.units if not fft else 0),
        input_shape[-1] + (0 if fft or hidden_cell is None else lmu.hidden_cell.units),
        lmu.memory_d,
    )
    if not fft:
        assert lmu.recurrent_kernel.shape == (
            lmu.order * lmu.memory_d,
            lmu.memory_d,
        )
    assert lmu.hidden_cell.kernel.shape == (
        lmu.memory_d * lmu.order + input_shape[-1],
        lmu.hidden_cell.units,
    if hidden_cell is not None:
        assert lmu.hidden_cell.kernel.shape == (
            lmu.memory_d * lmu.order + input_shape[-1],
            lmu.hidden_cell.units,
        )
    assert y.shape == (
        input_shape[0],
        lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
    )


@@ -341,34 +364,3 @@ def test_dropout(dropout, recurrent_dropout, fft):
    y0 = lmu(np.ones((32, 10, 64)), training=False).numpy()
    y1 = lmu(np.ones((32, 10, 64)), training=False).numpy()
    assert np.allclose(y0, y1)


@pytest.mark.parametrize(
    "hidden_cell",
    (tf.keras.layers.SimpleRNNCell(units=10), tf.keras.layers.Dense(units=10), None),
)
def test_skip_connection(rng, hidden_cell):
    memory_d = 4
    order = 16
    n_steps = 10
    input_d = 32

    inp = tf.keras.Input(shape=(n_steps, input_d))

    lmu = layers.LMUCell(
        memory_d=memory_d,
        order=order,
        theta=n_steps,
        hidden_cell=hidden_cell,
        input_to_hidden=True,
    )
    assert lmu.output_size == (None if hidden_cell is None else 10)

    out = tf.keras.layers.RNN(lmu)(inp)

    output_size = (
        (memory_d * order + input_d) if hidden_cell is None else hidden_cell.units
    )
    assert out.shape[-1] == output_size
    assert lmu.hidden_output_size == output_size
    assert lmu.output_size == output_size
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -5,3 +5,7 @@ requires = ["setuptools", "wheel"]

[tool.black]
target-version = ['py35', 'py36', 'py37', 'py38']

[tool.isort]
profile = "black"
src_paths = ["keras_lmu"]
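
For reference, this isort configuration, together with the hook added in `.pre-commit-config.yaml`, is what produced the import reordering in `keras_lmu/__init__.py` above. A quick illustration using isort's Python API (available in isort 5.x, the pinned major version; the output shown is an expectation based on the diff, since isort's default `order_by_type` sorting places all-caps names like `LMU` and `LMUFFT` before CamelCase names like `LMUCell`):

```python
import isort

before = """\
from .layers import (
    LMUCell,
    LMUFFT,
    LMU,
)
"""

# Sort with the same profile configured in pyproject.toml above.
print(isort.code(before, profile="black"))
# Expected: from .layers import LMU, LMUFFT, LMUCell
```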