Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Magnitude #2

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CWT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__all__ = ['cwt']
__version__ = "0.2.0"
135 changes: 71 additions & 64 deletions cwt.py → CWT/cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import numpy as np
import tensorflow as tf


class ContinuousWaveletTransform(object):
"""CWT layer implementation in Tensorflow for GPU acceleration."""
def __init__(self, n_scales, border_crop=0, stride=1, name="cwt"):
def __init__(self, n_scales, border_crop=0, stride=1, name="cwt", output='Complex' ):
"""
Args:
n_scales: (int) Number of scales for the scalogram.
Expand All @@ -26,8 +25,8 @@ def __init__(self, n_scales, border_crop=0, stride=1, name="cwt"):
self.border_crop = border_crop
self.stride = stride
self.name = name
with tf.variable_scope(self.name):
self.real_part, self.imaginary_part = self._build_wavelet_bank()
self.output = output
self.real_part, self.imaginary_part = self._build_wavelet_bank()

def _build_wavelet_bank(self):
"""Needs implementation to compute the real and imaginary parts
Expand All @@ -36,7 +35,8 @@ def _build_wavelet_bank(self):
real_part = None
imaginary_part = None
return real_part, imaginary_part


@tf.function
def __call__(self, inputs):
    """Computes the CWT of a batch of signals with the wavelet bank.

    Args:
        inputs: (tensor) Batch of signals of shape
            [batch_size, time_len, n_channels].

    Returns:
        Scalograms tensor. If self.output == 'Magnitude', the shape is
        [batch, time_len // stride, n_scales, n_channels]; otherwise
        ('Complex') it is [batch, time_len // stride, n_scales,
        2 * n_channels], with real and imaginary parts stacked per channel.
    """
    # border_crop is given in input samples; convert to output samples.
    border_crop = int(self.border_crop / self.stride)
    start = border_crop
    end = (-border_crop) if (border_crop > 0) else None
    # Input has expected shape of [batch_size, time_len, n_channels].
    # We first unstack the input channels and process each independently.
    inputs_unstacked = tf.unstack(inputs, axis=2)
    multi_channel_cwt = []
    for single_channel in inputs_unstacked:
        # Reshape input [batch, time_len] -> [batch, 1, time_len, 1]
        inputs_expand = tf.expand_dims(single_channel, axis=1)
        inputs_expand = tf.expand_dims(inputs_expand, axis=3)
        bank_real = self.real_part
        bank_imag = -self.imaginary_part  # Conjugation
        out_real = tf.nn.conv2d(
            input=inputs_expand, filters=bank_real,
            strides=[1, 1, self.stride, 1], padding="SAME")
        out_imag = tf.nn.conv2d(
            input=inputs_expand, filters=bank_imag,
            strides=[1, 1, self.stride, 1], padding="SAME")
        out_real_crop = out_real[:, :, start:end, :]
        out_imag_crop = out_imag[:, :, start:end, :]
        if self.output == 'Magnitude':
            # Only compute the magnitude when requested; it is wasted
            # work in the 'Complex' case.
            out_concat = tf.sqrt(out_real_crop ** 2 + out_imag_crop ** 2)
        else:
            out_concat = tf.concat([out_real_crop, out_imag_crop], axis=1)
        # [batch, :, time, n_scales] -> [batch, time, n_scales, :]
        single_scalogram = tf.transpose(a=out_concat, perm=[0, 2, 3, 1])
        multi_channel_cwt.append(single_scalogram)
    # Concatenate channels in the last axis:
    # [batch, time_len, n_scales, 2*n_channels] for 'Complex' output,
    # [batch, time_len, n_scales, n_channels] for 'Magnitude' output.
    scalograms = tf.concat(multi_channel_cwt, -1)
    return scalograms

Expand All @@ -101,6 +105,7 @@ def __init__(
trainable=False,
border_crop=0,
stride=1,
output='Complex' ,
name="cwt"):
"""
Computes the complex morlet wavelets
Expand Down Expand Up @@ -159,6 +164,9 @@ def __init__(
raise ValueError("lower_freq should be lower than upper_freq")
if lower_freq < 0:
raise ValueError("Expected positive lower_freq.")
if output not in ['Complex', 'Magnitude']:
raise ValueError("Expected output to be 'Complex' or 'Magnitude'.")


self.initial_wavelet_width = wavelet_width
self.fs = fs
Expand All @@ -180,39 +188,38 @@ def __init__(
trainable=self.trainable,
name='wavelet_width',
dtype=tf.float32)
super().__init__(n_scales, border_crop, stride, name)
super().__init__(n_scales, border_crop, stride, name, output)

def _build_wavelet_bank(self):
    """Builds the complex Morlet wavelet filter bank.

    Returns:
        Tuple (real, imag) of tensors with shape
        [1, kernel_size, 1, n_scales], suitable as tf.nn.conv2d filters.
    """
    # Generate the wavelets.
    # We will make a bigger wavelet in case the width grows during
    # training; for the size of the wavelet we use the initial width.
    # |t| < truncation_size => |k| < truncation_size * fs
    truncation_size = self.scales.max() * np.sqrt(4.5 * self.initial_wavelet_width) * self.fs
    one_side = int(self.size_factor * truncation_size)
    kernel_size = 2 * one_side + 1
    k_array = np.arange(kernel_size, dtype=np.float32) - one_side
    t_array = k_array / self.fs  # Time units
    # One (real, imag) kernel pair per scale.
    wavelet_bank_real = []
    wavelet_bank_imag = []
    for scale in self.scales:
        # Normalization keeps responses comparable across scales.
        norm_constant = tf.sqrt(np.pi * self.wavelet_width) * scale * self.fs / 2.0
        scaled_t = t_array / scale
        exp_term = tf.exp(-(scaled_t ** 2) / self.wavelet_width)
        kernel_base = exp_term / norm_constant
        kernel_real = kernel_base * np.cos(2 * np.pi * scaled_t)
        kernel_imag = kernel_base * np.sin(2 * np.pi * scaled_t)
        wavelet_bank_real.append(kernel_real)
        wavelet_bank_imag.append(kernel_imag)
    # Stack wavelets (shape = kernel_size, n_scales)
    wavelet_bank_real = tf.stack(wavelet_bank_real, axis=-1)
    wavelet_bank_imag = tf.stack(wavelet_bank_imag, axis=-1)
    # Give it proper shape for convolutions:
    # -> shape: 1, kernel_size, n_scales
    wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=0)
    wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=0)
    # -> shape: 1, kernel_size, 1, n_scales
    wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=2)
    wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=2)
    return wavelet_bank_real, wavelet_bank_imag
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

A TensorFlow implementation of the Continuous Wavelet Transform obtained via the complex Morlet wavelet. Please see the demo Jupyter Notebook for usage demonstration and more implementation details.

This implementation is aimed to leverage GPU acceleration for the computation of the CWT in TensorFlow models. The morlet's wavelet width can be set as a trainable parameter if you want to adjust it via backprop. Please note that this implementation was made before TensorFlow 2, so you need TensorFlow 1 (i.e. tf 1.x).
This implementation is aimed at leveraging GPU acceleration for the computation of the CWT in TensorFlow models. The Morlet wavelet's width can be set as a trainable parameter if you want to adjust it via backprop. This implementation now supports TensorFlow 2.

This module was used to obtain the CWT of EEG signals for the RED-CWT model, described in:

Expand Down
611 changes: 1 addition & 610 deletions cwt_demo.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import setuptools

# Read the long description once, with an explicit encoding, and close
# the file deterministically.
with open('README.md', encoding='utf-8') as readme_file:
    _long_description = readme_file.read()

setuptools.setup(
    name="CWT",
    # PEP 440 version string; keep in sync with CWT/__init__.py
    # (__version__ = "0.2.0"). The previous value (the integer 2) was
    # not a valid version specifier and disagreed with the package.
    version="0.2.0",
    description="A tensorflow 2.0 Continuous Wavelet Transform",
    long_description=_long_description,
    # README.md is Markdown; declare it so PyPI renders it correctly.
    long_description_content_type="text/markdown",
    packages=['CWT'],
    install_requires=['numpy', 'tensorflow'],
    python_requires='>=3.6',
)