Skip to content

Commit

Permalink
fixup! Allow memory_d > 1 for LMUFFT
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Jun 15, 2021
1 parent 672109f commit 7aeb03a
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions keras_lmu/tests/test_layers.py
Expand Up @@ -63,14 +63,16 @@ def test_multivariate_lmu(rng):


@pytest.mark.parametrize("has_input_kernel", (True, False))
def test_layer_vs_cell(has_input_kernel, rng):
@pytest.mark.parametrize("fft", (True, False))
def test_layer_vs_cell(has_input_kernel, fft, rng):
n_steps = 10
input_d = 32
kwargs = dict(
memory_d=4 if has_input_kernel else input_d,
order=12,
theta=n_steps,
has_input_kernel=has_input_kernel,
memory_to_memory=not fft,
)
hidden_cell = lambda: tf.keras.layers.SimpleRNNCell(units=64)

Expand All @@ -87,14 +89,17 @@ def test_layer_vs_cell(has_input_kernel, rng):
lmu_layer.layer.set_weights(lmu_cell.get_weights())
layer_out = lmu_layer(inp)

assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)

for w0, w1 in zip(
sorted(lmu_cell.weights, key=lambda w: w.shape.as_list()),
sorted(lmu_layer.weights, key=lambda w: w.shape.as_list()),
):
assert np.allclose(w0.numpy(), w1.numpy())

assert np.allclose(cell_out, lmu_cell(inp))
assert np.allclose(cell_out, layer_out)
atol = 2e-6 if fft else 1e-8
assert np.allclose(cell_out, lmu_cell(inp), atol=atol)
assert np.allclose(cell_out, layer_out, atol=atol)


def test_save_load_weights(rng, tmp_path):
Expand Down Expand Up @@ -423,11 +428,7 @@ def test_fit(fft):

_, acc = model.evaluate(x_test, y_test, verbose=0)

if fft:
assert isinstance(lmu_layer.layer, layers.LMUFFT)
else:
assert isinstance(lmu_layer.layer, tf.keras.layers.RNN)

assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)
assert acc == 1.0


Expand Down

0 comments on commit 7aeb03a

Please sign in to comment.