Skip to content

Commit

Permalink
Merge 804f208 into a41612d
Browse files Browse the repository at this point in the history
  • Loading branch information
Duncan Macleod committed Feb 21, 2018
2 parents a41612d + 804f208 commit f3cec73
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 121 deletions.
42 changes: 29 additions & 13 deletions gwpy/signal/fft/lal.py
Expand Up @@ -31,6 +31,7 @@
import numpy

from ...frequencyseries import FrequencySeries
from ..window import canonical_name
from .utils import scale_timeseries_unit
from . import registry as fft_registry

Expand Down Expand Up @@ -86,7 +87,7 @@ def generate_fft_plan(length, level=None, dtype='float64', forward=True):
return LAL_FFTPLANS[key]


def generate_window(length, window=('kaiser', 24), dtype='float64'):
def generate_window(length, window=None, dtype='float64'):
"""Generate a time-domain window for use in a LAL FFT
Parameters
Expand All @@ -110,6 +111,9 @@ def generate_window(length, window=('kaiser', 24), dtype='float64'):
import lal
from ...utils.lal import LAL_TYPE_STR_FROM_NUMPY

if window is None:
window = ('kaiser', 24)

# generate key for caching window
laltype = LAL_TYPE_STR_FROM_NUMPY[numpy.dtype(dtype).type]
key = (length, str(window), laltype)
Expand All @@ -121,17 +125,31 @@ def generate_window(length, window=('kaiser', 24), dtype='float64'):
except KeyError:
# parse window as name and arguments, e.g. ('kaiser', 24)
if isinstance(window, (list, tuple)):
args = window[1:]
window = str(window[0])
window, beta = window
else:
args = []
window = window.title() if window.islower() else window
beta = 0
window = canonical_name(window)
# create window
create = getattr(lal, 'Create%s%sWindow' % (window, laltype))
LAL_WINDOWS[key] = create(length, *args)
create = getattr(lal, 'CreateNamed{}Window'.format(laltype))
LAL_WINDOWS[key] = create(window, beta, length)
return LAL_WINDOWS[key]


def window_from_array(array):
"""Convert a `numpy.ndarray` into a LAL `Window` object
"""
import lal
from ...utils.lal import LAL_TYPE_STR_FROM_NUMPY
laltype = LAL_TYPE_STR_FROM_NUMPY[array.dtype.type]

# create sequence
seq = getattr(lal, 'Create{}Sequence'.format(laltype))(array.size)
seq.data = array

# create window from sequence
return getattr(lal, 'Create{}WindowFromSequence'.format(laltype))(seq)


# -- spectrumm methods ------------------------------------------------------

def _lal_spectrum(timeseries, segmentlength, noverlap=None, method='welch',
Expand All @@ -152,10 +170,10 @@ def _lal_spectrum(timeseries, segmentlength, noverlap=None, method='welch',
noverlap : `int`
number of samples to overlap between segments, defaults to 50%.
window : `tuple`, `str`, optional
window parameters to apply to timeseries prior to FFT
window : `lal.REAL8Window`, optional
window to apply to timeseries prior to FFT
plan : `REAL8FFTPlan`, optional
plan : `lal.REAL8FFTPlan`, optional
LAL FFT plan to use when generating average spectrum
Returns
Expand All @@ -174,9 +192,7 @@ def _lal_spectrum(timeseries, segmentlength, noverlap=None, method='welch',
# get window
if window is None:
window = generate_window(segmentlength, dtype=timeseries.dtype)
elif isinstance(window, (tuple, str)):
window = generate_window(segmentlength, window=window,
dtype=timeseries.dtype)

# get FFT plan
if plan is None:
plan = generate_fft_plan(segmentlength, dtype=timeseries.dtype)
Expand Down
66 changes: 35 additions & 31 deletions gwpy/signal/fft/ui.py
Expand Up @@ -70,7 +70,7 @@ def seconds_to_samples(x, rate):
return int((Quantity(x, 's') * rate).decompose().value)


def normalize_fft_params(series, kwargs=None):
def normalize_fft_params(series, kwargs=None, library=None):
"""Normalize a set of FFT parameters for processing
This method reads the ``fftlength`` and ``overlap`` keyword arguments
Expand All @@ -89,6 +89,10 @@ def normalize_fft_params(series, kwargs=None):
kwargs : `dict`
the dict of keyword arguments passed by the user
library: `str`, optional
the name of the library that provides the FFT methods, e.g.
'scipy'
Examples
--------
>>> from numpy.random import normal
Expand All @@ -104,24 +108,44 @@ def normalize_fft_params(series, kwargs=None):
samp = series.sample_rate
fftlength = kwargs.pop('fftlength', None)
overlap = kwargs.pop('overlap', None)
window = kwargs.pop('window', None)

# get canonical window name
if isinstance(kwargs.get('window', None), str):
kwargs['window'] = canonical_name(kwargs['window'])
if isinstance(window, str):
window = canonical_name(window)

# fftlength -> nfft
if fftlength is None:
fftlength = series.duration
nfft = seconds_to_samples(fftlength, samp)

# overlap -> noverlap
if overlap is None and isinstance(kwargs.get('window', None), str):
noverlap = recommended_overlap(kwargs['window'], nfft)
if overlap is None and isinstance(window, str):
noverlap = recommended_overlap(window, nfft)
elif overlap is None:
noverlap = 0
else:
noverlap = seconds_to_samples(overlap, samp)

# create window
if library == 'lal' and isinstance(window, numpy.ndarray):
from .lal import window_from_array
window = window_from_array(window)
elif library == 'lal':
from .lal import generate_window
window = generate_window(nfft, window=window, dtype=series.dtype)
elif isinstance(window, (str, tuple)):
window = get_window(window, nfft)

# allow FFT methods to use their own defaults
if window is not None:
kwargs['window'] = window

# create FFT plan for LAL
if library == 'lal' and kwargs.get('plan', None):
from .lal import generate_fft_plan
kwargs['plan'] = generate_fft_plan(nfft, dtype=series.dtype)

kwargs.update({
'nfft': nfft,
'noverlap': noverlap,
Expand All @@ -141,8 +165,9 @@ def wrapped_func(series, method_func, *args, **kwargs):
else:
data = series

# extract parameters in seconds, setting recommended default overlap
normalize_fft_params(data, kwargs)
# normalise FFT parmeters for all libraries
library = method_func.__module__.rsplit('.', 1)[-1]
normalize_fft_params(data, kwargs=kwargs, library=library)

return func(series, method_func, *args, **kwargs)

Expand Down Expand Up @@ -217,10 +242,10 @@ def average_spectrogram(timeseries, method_func, stride, *args, **kwargs):
epoch = timeseries.t0.value
nstride = seconds_to_samples(stride, timeseries.sample_rate)
kwargs['fftlength'] = kwargs.pop('fftlength', stride) or stride
normalize_fft_params(timeseries, kwargs)
normalize_fft_params(timeseries, kwargs=kwargs,
library=method_func.__module__.rsplit('.', 1)[-1])
nfft = kwargs['nfft']
noverlap = kwargs['noverlap']
window = kwargs.pop('window', None)

# sanity check parameters
if nstride > timeseries.size:
Expand All @@ -231,21 +256,6 @@ def average_spectrogram(timeseries, method_func, stride, *args, **kwargs):
if noverlap >= nfft:
raise ValueError("overlap must be less than fftlength")

# generate windows and FFT plans up-front
if method_func.__module__.endswith('.lal'):
from .lal import (generate_fft_plan, generate_window)
if isinstance(window, (str, tuple)):
window = generate_window(nfft, window=window,
dtype=timeseries.dtype)
if kwargs.get('plan', None) is None:
kwargs['plan'] = generate_fft_plan(nfft, dtype=timeseries.dtype)
else:
if isinstance(window, (str, tuple)):
window = get_window(window, nfft)
# don't operate on None, let the method_func work out its own defaults
if window is not None:
kwargs['window'] = window

# set up single process Spectrogram method
def _psd(series):
"""Calculate a single PSD for a spectrogram
Expand Down Expand Up @@ -293,18 +303,12 @@ def spectrogram(timeseries, method_func, **kwargs):
if noverlap >= nfft:
raise ValueError("overlap must be less than fftlength")

# get window once (if given)
window = kwargs.pop('window', None) or 'hann'
if isinstance(window, (str, tuple)):
window = get_window(window, nfft)

# set up single process Spectrogram method
def _psd(series):
"""Calculate a single PSD for a spectrogram
"""
try:
return method_func(series, nfft=nfft, window=window,
**kwargs)[1]
return method_func(series, nfft=nfft, **kwargs)[1]
except Exception as exc: # pylint: disable=broad-except
if nproc == 1:
raise
Expand Down

0 comments on commit f3cec73

Please sign in to comment.