Improve speed of LMUFFT on CPU
drasmuss committed Jun 29, 2021
1 parent 0202eaf commit 1257a40
Showing 3 changed files with 88 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
@@ -41,6 +41,7 @@ Release history
- The ``A`` and ``B`` matrices are now stored as constants instead of non-trainable
variables. This can improve the training/inference speed, but it means that saved
weights from previous versions will be incompatible. (`#41`_)
- Improved speed of ``keras_lmu.LMUFFT`` when running on CPU. (`#40`_)

.. _#40: https://github.com/nengo/keras-lmu/pull/40
.. _#41: https://github.com/nengo/keras-lmu/pull/41
62 changes: 58 additions & 4 deletions keras_lmu/layers.py
@@ -2,9 +2,13 @@
Core classes for the KerasLMU package.
"""

import os
from functools import partial

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
from tensorflow.python.ops.signal.fft_ops import irfft, rfft


class LMUCell(DropoutRNNCellMixin, tf.keras.layers.Layer):
@@ -786,7 +790,7 @@ def call(self, inputs, training=None):
if self.dropout:
inputs = tf.keras.layers.Dropout(
self.dropout, noise_shape=(inputs.shape[0], 1) + inputs.shape[2:]
)(inputs)
)(inputs, training=training)

# Apply input encoders
u = (
@@ -798,15 +802,17 @@ def call(self, inputs, training=None):
# FFT requires shape (batch, memory_d, timesteps)
u = tf.transpose(u, perm=[0, 2, 1])

# Pad sequences to avoid circular convolution
# Perform the FFT
fft_input = tf.signal.rfft(u, fft_length=[2 * seq_len], name="input_pad")
# Pad sequences to avoid circular convolution
fft_input = LMUFFT._parallel_rfft(u, fft_length=[2 * seq_len])

# Elementwise product of FFT (with broadcasting)
result = tf.expand_dims(fft_input, axis=-2) * self.impulse_response

# Inverse FFT
m = tf.signal.irfft(result, fft_length=[2 * seq_len])[..., :seq_len]
m = LMUFFT._parallel_rfft(result, inverse=True, fft_length=[2 * seq_len])[
..., :seq_len
]

m = tf.reshape(m, (-1, self.order * self.memory_d, seq_len))
m = tf.transpose(m, perm=[0, 2, 1])
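
The padding to ``2 * seq_len`` in the FFT above, together with the ``[..., :seq_len]`` slice after the inverse FFT, is what turns the DFT's circular convolution into an ordinary (linear) convolution with the impulse response. A minimal standalone sketch of that identity, for illustration only (the sequence length and tolerance are arbitrary, and this code is not part of the commit):

    import numpy as np
    import tensorflow as tf

    seq_len = 8
    x = tf.random.normal((1, seq_len))  # input sequence
    h = tf.random.normal((1, seq_len))  # impulse response

    # Zero-pad both signals to 2 * seq_len so the DFT's circular convolution
    # cannot wrap around, then keep only the first seq_len outputs.
    X = tf.signal.rfft(x, fft_length=[2 * seq_len])
    H = tf.signal.rfft(h, fft_length=[2 * seq_len])
    y_fft = tf.signal.irfft(X * H, fft_length=[2 * seq_len])[..., :seq_len]

    # Reference: direct causal convolution
    y_ref = np.convolve(x.numpy()[0], h.numpy()[0])[:seq_len]
    assert np.allclose(y_fft.numpy()[0], y_ref, atol=1e-4)
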
@@ -831,6 +837,54 @@ def call(self, inputs, training=None):

return h

    @staticmethod
    def _parallel_rfft(x, inverse=False, fft_length=None):
        """
        Computes FFT in parallel, which results in faster performance on CPU.

        See https://github.com/tensorflow/tensorflow/issues/6541#issuecomment-713578892.

        Parameters
        ----------
        x : ``tf.Tensor``
            The input to the FFT.
        inverse : bool
            If True, compute the inverse FFT.
        fft_length : list of int
            Pad/crop the fft axis (last dimension of ``x``) to this length (specified
            as a length-1 list for some reason).

        Returns
        -------
        y : ``tf.Tensor``
            The (inverse) real-valued FFT of ``x``.
        """

        if len(tf.config.get_visible_devices("GPU")) > 0:
            # tensorflow's regular fft implementation already runs fast on GPU,
            # don't need to do anything special
            return (
                tf.signal.irfft(x, fft_length=fft_length)
                if inverse
                else tf.signal.rfft(x, fft_length=fft_length)
            )

        op = partial(irfft if inverse else rfft, fft_length=fft_length)

        x_shape = tf.shape(x)

        # use map to parallelize fft calls
        y = tf.map_fn(
            op,
            tf.reshape(x, (-1, x_shape[-1])),
            parallel_iterations=os.cpu_count(),
            dtype=x.dtype.real_dtype
            if inverse
            else (tf.complex64 if x.dtype == tf.float32 else tf.complex128),
        )

        return tf.reshape(y, tf.concat((x_shape[:-1], [-1]), axis=0))

    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""
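
For reference, the CPU trick used by ``_parallel_rfft`` above can be distilled into a short standalone helper: merge all leading dimensions, run one FFT per row through ``tf.map_fn`` with ``parallel_iterations`` set to the CPU count, and restore the original shape afterwards (on GPU the method simply falls back to ``tf.signal.rfft``/``irfft``). The sketch below is illustrative only, with a made-up helper name and arbitrary shapes, and checks its output against ``tf.signal.rfft``:

    import os
    from functools import partial

    import numpy as np
    import tensorflow as tf

    def mapped_rfft(x, fft_length):
        # Illustrative helper: one rfft per row, run through tf.map_fn so the
        # per-row FFTs can execute in parallel on CPU.
        x_shape = tf.shape(x)
        flat = tf.reshape(x, (-1, x_shape[-1]))  # merge leading dimensions
        y = tf.map_fn(
            partial(tf.signal.rfft, fft_length=fft_length),
            flat,
            parallel_iterations=os.cpu_count(),
            dtype=tf.complex64 if x.dtype == tf.float32 else tf.complex128,
        )
        return tf.reshape(y, tf.concat((x_shape[:-1], [-1]), axis=0))

    x = tf.random.uniform((8, 4, 20))
    assert np.allclose(
        mapped_rfft(x, [40]), tf.signal.rfft(x, fft_length=[40]), atol=1e-4
    )
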
35 changes: 29 additions & 6 deletions keras_lmu/tests/test_layers.py
@@ -358,9 +358,11 @@ def test_connection_params(fft, hidden_cell):
lmu.memory_d * lmu.order,
lmu.hidden_cell.units,
)
assert y.shape == (
input_shape[0],
lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
assert y.shape.is_compatible_with(
(
None if fft else input_shape[0], # fft loses track of static batch shape
lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
)
)

lmu_args["input_to_hidden"] = hidden_cell is not None
@@ -386,9 +388,11 @@ def test_connection_params(fft, hidden_cell):
lmu.memory_d * lmu.order + input_shape[-1],
lmu.hidden_cell.units,
)
assert y.shape == (
input_shape[0],
lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
assert y.shape.is_compatible_with(
(
None if fft else input_shape[0], # fft loses track of static batch shape
lmu.memory_d * lmu.order if hidden_cell is None else lmu.hidden_cell.units,
)
)
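
``TensorShape.is_compatible_with`` treats ``None`` as a wildcard dimension, which is why the assertions above tolerate the FFT path losing the static batch size. A tiny illustration (values chosen arbitrarily, not part of the commit):

    import tensorflow as tf

    s = tf.TensorShape([None, 10])
    assert s.is_compatible_with((32, 10))      # unknown dim matches any size
    assert not s.is_compatible_with((32, 12))  # known dims must still agree
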


@@ -617,3 +621,22 @@ def test_theta_attribute(mode):
cell = layer if mode == "cell" else layer.layer.cell
cell.theta_inv.assign(10)
assert np.allclose(layer.theta, 0.1)


@pytest.mark.parametrize("gpu", (True, False))
def test_parallel_fft(gpu, rng, monkeypatch):
    monkeypatch.setattr(
        tf.config, "get_visible_devices", lambda *_: ["a_gpu"] if gpu else []
    )

    x = tf.constant(rng.uniform(-1, 1, size=(32, 20, 5)), dtype=tf.float32)

    y0 = tf.signal.rfft(x, fft_length=[10])
    y1 = layers.LMUFFT._parallel_rfft(x, fft_length=[10])

    assert np.allclose(y0, y1)

    x0 = tf.signal.irfft(y0, fft_length=[10])
    x1 = layers.LMUFFT._parallel_rfft(y1, fft_length=[10], inverse=True)

    assert np.allclose(x0, x1)
