Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keras Layer compatibility #3

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CWT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__all__ = ['cwt']
__version__ = "0.2.0"
145 changes: 75 additions & 70 deletions cwt.py → CWT/cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

import numpy as np
import tensorflow as tf
from keras.layers import Layer


class ContinuousWaveletTransform(object):
class ContinuousWaveletTransform(Layer):
"""CWT layer implementation in Tensorflow for GPU acceleration."""
def __init__(self, n_scales, border_crop=0, stride=1, name="cwt"):
def __init__(self, n_scales, border_crop=0, stride=1, outputformat='Complex' ):
"""
Args:
n_scales: (int) Number of scales for the scalogram.
Expand All @@ -20,14 +20,13 @@ def __init__(self, n_scales, border_crop=0, stride=1, name="cwt"):
desired size to remove border effects of the CWT. Default 0.
stride: (int) The stride of the sliding window across the input.
Default is 1.
name: (string) A name for the op. Default "cwt".
"""
super(ContinuousWaveletTransform, self).__init__()
self.n_scales = n_scales
self.border_crop = border_crop
self.stride = stride
self.name = name
with tf.variable_scope(self.name):
self.real_part, self.imaginary_part = self._build_wavelet_bank()
self.outputformat = outputformat
self.real_part, self.imaginary_part = self._build_wavelet_bank()

def _build_wavelet_bank(self):
"""Needs implementation to compute the real and imaginary parts
Expand All @@ -36,8 +35,9 @@ def _build_wavelet_bank(self):
real_part = None
imaginary_part = None
return real_part, imaginary_part

def __call__(self, inputs):

@tf.function
def call(self, inputs):
"""
Computes the CWT with the specified wavelet bank.
If the signal has more than one channel, the CWT is computed for
Expand All @@ -57,33 +57,37 @@ def __call__(self, inputs):
border_crop = int(self.border_crop / self.stride)
start = border_crop
end = (-border_crop) if (border_crop > 0) else None
with tf.variable_scope(self.name):
# Input has expected shape of [batch_size, time_len, n_channels]
# We first unstack the input channels
inputs_unstacked = tf.unstack(inputs, axis=2)
multi_channel_cwt = []
for j, single_channel in enumerate(inputs_unstacked):
# Reshape input [batch, time_len] -> [batch, 1, time_len, 1]
inputs_expand = tf.expand_dims(single_channel, axis=1)
inputs_expand = tf.expand_dims(inputs_expand, axis=3)
with tf.name_scope('%s_%d' % (self.name, j)):
bank_real = self.real_part
bank_imag = -self.imaginary_part # Conjugation
out_real = tf.nn.conv2d(
input=inputs_expand, filter=bank_real,
strides=[1, 1, self.stride, 1], padding="SAME")
out_imag = tf.nn.conv2d(
input=inputs_expand, filter=bank_imag,
strides=[1, 1, self.stride, 1], padding="SAME")
out_real_crop = out_real[:, :, start:end, :]
out_imag_crop = out_imag[:, :, start:end, :]
out_concat = tf.concat(
[out_real_crop, out_imag_crop], axis=1)
# [batch, 2, time, n_scales]->[batch, time, n_scales, 2]
single_scalogram = tf.transpose(
out_concat, perm=[0, 2, 3, 1])
multi_channel_cwt.append(single_scalogram)
# Input has expected shape of [batch_size, time_len, n_channels]
# We first unstack the input channels
inputs_unstacked = tf.unstack(inputs, axis=2)
multi_channel_cwt = []
for j, single_channel in enumerate(inputs_unstacked):
# Reshape input [batch, time_len] -> [batch, 1, time_len, 1]
inputs_expand = tf.expand_dims(single_channel, axis=1)
inputs_expand = tf.expand_dims(inputs_expand, axis=3)
bank_real = self.real_part
bank_imag = -self.imaginary_part # Conjugation
out_real = tf.nn.conv2d(
input=inputs_expand, filters=bank_real,
strides=[1, 1, self.stride, 1], padding="SAME")
out_imag = tf.nn.conv2d(
input=inputs_expand, filters=bank_imag,
strides=[1, 1, self.stride, 1], padding="SAME")
out_real_crop = out_real[:, :, start:end, :]
out_imag_crop = out_imag[:, :, start:end, :]
out_mag_crop = tf.sqrt(out_real_crop**2 + out_imag_crop**2)

if self.outputformat == 'Magnitude':
out_concat = out_mag_crop
else:
out_concat = tf.concat([out_real_crop, out_imag_crop], axis=1)

# [batch, :, time, n_scales]->[batch, time, n_scales, :]
single_scalogram = tf.transpose(
a=out_concat, perm=[0, 2, 3, 1])
multi_channel_cwt.append(single_scalogram)
# Get all in shape [batch, time_len, n_scales, 2*n_channels]
# or if output='Magnitude [batch, time_len, n_scales, 2*n_channels]
scalograms = tf.concat(multi_channel_cwt, -1)
return scalograms

Expand All @@ -101,7 +105,7 @@ def __init__(
trainable=False,
border_crop=0,
stride=1,
name="cwt"):
output='Complex'):
"""
Computes the complex morlet wavelets

Expand Down Expand Up @@ -151,14 +155,16 @@ def __init__(
desired size to remove border effects of the CWT. Default 0.
stride: (int) The stride of the sliding window across the input.
Default is 1.
name: (string) A name for the op. Default "cwt".
"""

# Checking
if lower_freq > upper_freq:
raise ValueError("lower_freq should be lower than upper_freq")
if lower_freq < 0:
raise ValueError("Expected positive lower_freq.")
if output not in ['Complex', 'Magnitude']:
raise ValueError("Expected output to be 'Complex' or 'Magnitude'.")


self.initial_wavelet_width = wavelet_width
self.fs = fs
Expand All @@ -180,39 +186,38 @@ def __init__(
trainable=self.trainable,
name='wavelet_width',
dtype=tf.float32)
super().__init__(n_scales, border_crop, stride, name)
super().__init__(n_scales, border_crop, stride, output)

def _build_wavelet_bank(self):
with tf.variable_scope("cmorlet_bank"):
# Generate the wavelets
# We will make a bigger wavelet in case the width grows
# For the size of the wavelet we use the initial width value.
# |t| < truncation_size => |k| < truncation_size * fs
truncation_size = self.scales.max() * np.sqrt(4.5 * self.initial_wavelet_width) * self.fs
one_side = int(self.size_factor * truncation_size)
kernel_size = 2 * one_side + 1
k_array = np.arange(kernel_size, dtype=np.float32) - one_side
t_array = k_array / self.fs # Time units
# Wavelet bank shape: 1, kernel_size, 1, n_scales
wavelet_bank_real = []
wavelet_bank_imag = []
for scale in self.scales:
norm_constant = tf.sqrt(np.pi * self.wavelet_width) * scale * self.fs / 2.0
scaled_t = t_array / scale
exp_term = tf.exp(-(scaled_t ** 2) / self.wavelet_width)
kernel_base = exp_term / norm_constant
kernel_real = kernel_base * np.cos(2 * np.pi * scaled_t)
kernel_imag = kernel_base * np.sin(2 * np.pi * scaled_t)
wavelet_bank_real.append(kernel_real)
wavelet_bank_imag.append(kernel_imag)
# Stack wavelets (shape = kernel_size, n_scales)
wavelet_bank_real = tf.stack(wavelet_bank_real, axis=-1)
wavelet_bank_imag = tf.stack(wavelet_bank_imag, axis=-1)
# Give it proper shape for convolutions
# -> shape: 1, kernel_size, n_scales
wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=0)
wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=0)
# -> shape: 1, kernel_size, 1, n_scales
wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=2)
wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=2)
# Generate the wavelets
# We will make a bigger wavelet in case the width grows
# For the size of the wavelet we use the initial width value.
# |t| < truncation_size => |k| < truncation_size * fs
truncation_size = self.scales.max() * np.sqrt(4.5 * self.initial_wavelet_width) * self.fs
one_side = int(self.size_factor * truncation_size)
kernel_size = 2 * one_side + 1
k_array = np.arange(kernel_size, dtype=np.float32) - one_side
t_array = k_array / self.fs # Time units
# Wavelet bank shape: 1, kernel_size, 1, n_scales
wavelet_bank_real = []
wavelet_bank_imag = []
for scale in self.scales:
norm_constant = tf.sqrt(np.pi * self.wavelet_width) * scale * self.fs / 2.0
scaled_t = t_array / scale
exp_term = tf.exp(-(scaled_t ** 2) / self.wavelet_width)
kernel_base = exp_term / norm_constant
kernel_real = kernel_base * np.cos(2 * np.pi * scaled_t)
kernel_imag = kernel_base * np.sin(2 * np.pi * scaled_t)
wavelet_bank_real.append(kernel_real)
wavelet_bank_imag.append(kernel_imag)
# Stack wavelets (shape = kernel_size, n_scales)
wavelet_bank_real = tf.stack(wavelet_bank_real, axis=-1)
wavelet_bank_imag = tf.stack(wavelet_bank_imag, axis=-1)
# Give it proper shape for convolutions
# -> shape: 1, kernel_size, n_scales
wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=0)
wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=0)
# -> shape: 1, kernel_size, 1, n_scales
wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=2)
wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=2)
return wavelet_bank_real, wavelet_bank_imag
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

A TensorFlow implementation of the Continuous Wavelet Transform obtained via the complex Morlet wavelet. Please see the demo Jupyter Notebook for usage demonstration and more implementation details.

This implementation is aimed to leverage GPU acceleration for the computation of the CWT in TensorFlow models. The morlet's wavelet width can be set as a trainable parameter if you want to adjust it via backprop. Please note that this implementation was made before TensorFlow 2, so you need TensorFlow 1 (i.e. tf 1.x).
This implementation is aimed to leverage GPU acceleration for the computation of the CWT in TensorFlow models. The morlet's wavelet width can be set as a trainable parameter if you want to adjust it via backprop. This implementation now supports TensorFlow 2.

This module was used to obtain the CWT of EEG signals for the RED-CWT model, described in:

Expand Down
611 changes: 1 addition & 610 deletions cwt_demo.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import setuptools

setuptools.setup(
name="CWT",
version=2,
description="A tensorflow 2.0 Continuous Wavelet Transform",
long_description=open('README.md').read(),
packages=['CWT'],
install_requires=['numpy', 'tensorflow'],
python_requires='>=3.6',
)