Skip to content

Commit

Permalink
API cwt_morlet() defaults: zero_mean=True, output='complex'
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Sep 28, 2021
1 parent 7d4abcf commit f97eb23
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions eelbrain/_ndvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,36 +410,46 @@ def cross_correlation(in1, in2, name=None):
return NDVar(x_corr, (time,), *op_name(in1, '*', in2, merge_info((in1, in2)), name))


def cwt_morlet(y, freqs, use_fft=True, n_cycles=3.0, zero_mean=False,
out='magnitude', decim=1):
def cwt_morlet(
y: NDVar,
frequencies: Sequence[float],
use_fft: bool = True,
n_cycles: Union[float, Sequence[float]] = 3.0,
zero_mean: bool = True,
output: Literal['complex', 'power', 'phase', 'magnitude'] = 'complex',
decim: int = 1,
) -> NDVar:
"""Time frequency decomposition with Morlet wavelets (mne-python)
Parameters
----------
y : NDVar with time dimension
Signal.
freqs : scalar | array
Frequency/ies of interest. For a scalar, the output will not contain a
y
Input signal.
frequencies
Frequencies of interest. For a scalar, the output will not contain a
frequency dimension.
use_fft : bool
use_fft
Compute convolution with FFT or temporal convolution.
n_cycles: float | array of float
n_cycles
Number of cycles. Fixed number or one per frequency.
zero_mean : bool
zero_mean
Make sure the wavelets are zero mean.
out : 'complex' | 'magnitude'
Format of the data in the returned NDVar.
output
Format of the data in the returned NDVar. Default is the complex eavelet
transform.
decim
Decimate the time axis by this factor.
Returns
-------
tfr : NDVar
tfr
Time frequency decompositions.
"""
if out == 'magnitude':
if output == 'magnitude':
magnitude_out = True
out = 'power'
elif out not in ('complex', 'phase', 'power'):
raise ValueError("out=%r" % (out,))
output = 'power'
elif output not in ('complex', 'phase', 'power'):
raise ValueError(f"{output=}")
else:
magnitude_out = False
dimnames = y.get_dimnames(last='time')
Expand All @@ -449,15 +459,14 @@ def cwt_morlet(y, freqs, use_fft=True, n_cycles=3.0, zero_mean=False,
data_flat = data.reshape((1, shape_outer, data.shape[-1]))
time_dim = dims[-1]
sfreq = 1. / time_dim.tstep
if np.isscalar(freqs):
freqs = [freqs]
if np.isscalar(frequencies):
frequencies = [frequencies]
fdim = None
else:
fdim = Scalar("frequency", freqs, 'Hz')
freqs = fdim.values
fdim = Scalar("frequency", frequencies, 'Hz', '%.0f')
frequencies = fdim.values

x_flat = mne.time_frequency.tfr_array_morlet(
data_flat, sfreq, freqs, n_cycles, zero_mean, use_fft, decim, out)
x_flat = mne.time_frequency.tfr_array_morlet(data_flat, sfreq, frequencies, n_cycles, zero_mean, use_fft, decim, output)

out_shape = list(data.shape)
out_dims = list(dims)
Expand Down

0 comments on commit f97eb23

Please sign in to comment.