In [None]:
import numpy as np

from scipy.signal import welch, get_window
from scipy.signal import iirfilter, sosfilt, zpk2sos
from obspy.signal.filter import bandpass

from obspy.clients.filesystem.sds import Client
from obspy.clients.fdsn import RoutingClient
from obspy.core import UTCDateTime as UTC
from obspy.signal import util

import h5py

In [None]:
import logging

# Notebook-wide logger: DEBUG level, timestamped records on a stream handler.
logger = logging.getLogger('notebook')
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)  # set level
cformatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                            datefmt='%y-%m-%d %H:%M:%S')
ch.setFormatter(cformatter)
# `getLogger` returns the same logger object every time, so re-running this
# cell would otherwise stack up handlers and print each record repeatedly.
if not logger.handlers:
    logger.addHandler(ch)

In [None]:
import matplotlib.pyplot as plt
plt.style.use('tableau-colorblind10')
#%matplotlib widget

### Processing steps

- **remove_sensitivity**: convert from counts to m/s. Also converts the data type from `np.int32` to `np.float64`. With the latter you can use `np.nan`, which facilitates dealing with data gaps.
- **merge**: if data contains gaps, the stream consists of multiple traces. We want 1 single trace with gaps represented as np.nan
- **trim** with padding: if the data do not start/end at the requested times, the trace is shortened by default. However, we want traces to always have the exact length of the requested interval, so use `pad=True` and `fill_value=np.nan` to fill up the trace.
-

In [None]:

def process(tr, winlen_in_s, nperseg, fmin, fmax):
    """
    Obsolete: per-window spectra and band-passed amplitude of one trace.

    The trace data are cut into adjacent windows of `winlen_in_s` seconds;
    samples behind the last complete window are dropped. Welch spectra are
    computed from the unfiltered windows, the amplitude measure from the
    band-pass filtered windows.

    Parameters
    ----------
    tr : trace-like
        Object providing `data` (1-D array) and `stats.sampling_rate`.
        It is not modified.
    winlen_in_s : float
        Window length in seconds.
    nperseg : int
        Segment length passed to `scipy.signal.welch`.
    fmin, fmax : float
        Corner frequencies passed to `bandpass`.

    Returns
    -------
    prctl : ndarray
        75th percentile (ignoring NaNs) of each filtered window.
    freq : ndarray
        Frequencies returned by `welch`.
    P : ndarray
        Spectral estimate returned by `welch`, one row per window.
    """
    sr = tr.stats.sampling_rate
    nwin = int(winlen_in_s * sr)

    # Number of complete windows that fit into the trace
    nf = tr.data.size // nwin
    matrix = np.reshape(tr.data[:nf * nwin], (nf, nwin))

    freq, P = welch(matrix, fs=sr, nperseg=nperseg, axis=1)

    # `np.reshape` returns a view whenever it can, so writing the filtered
    # rows back into `matrix` would overwrite `tr.data` (and truncate the
    # result if the data are integers). Collect them in a fresh float array.
    filtered = np.empty(matrix.shape, dtype=np.float64)
    for i, row in enumerate(matrix):
        filtered[i, :] = bandpass(row, fmin, fmax, sr)

    prctl = np.nanpercentile(filtered, 75, axis=1)
    return prctl, freq, P



def get_data(dataclient, starttime, overlap, proclen,
            network, station, location, channel, inv):
    """
    Fetch one padded, gap-filled trace from `dataclient`.

    The requested interval is
    ``[starttime - overlap, starttime + proclen + overlap]``. The stream is
    passed through `remove_sensitivity(inv)`, merged with gaps filled by
    NaN and trimmed (with NaN padding) to the requested interval.

    Parameters
    ----------
    dataclient : object
        Provides `get_waveforms(network, station, location, channel,
        starttime, endtime)`.
    starttime : time-like
        Start of the interval of interest (without overlap).
    overlap : float
        Extra time requested before and after the interval.
    proclen : float
        Length of the interval of interest.
    network, station, location, channel : str
        Stream identifiers handed to `get_waveforms`.
    inv : object
        Handed to `remove_sensitivity`.

    Returns
    -------
    The first (and only) trace of the processed stream.

    Raises
    ------
    RuntimeWarning
        If the stream still holds more than one trace after merging.
    """
    t_first = starttime - overlap
    t_last = t_first + proclen + 2*overlap

    stream = dataclient.get_waveforms(network, station, location, channel,
                                      t_first, t_last)
    stream.remove_sensitivity(inv)
    stream.merge(fill_value=np.nan)
    stream.trim(t_first, t_last, pad=True, fill_value=np.nan,
                nearest_sample=False)

    if len(stream) <= 1:
        return stream[0]
    raise RuntimeWarning("More than 1 trace in stream!")
    


def iterate_database(dataclient, invclient, network, station, location, 
                     channel, startdate, enddate, 
                    overlap, winlen_in_s, nperseg, fmin, fmax,
                    proclen=24*3600,):
    """
    Iterate chunk-wise (by default day-wise) over data in sds data base.

    For every chunk of `proclen` seconds between `startdate` and `enddate`
    the trace is loaded with `overlap` seconds of padding on both sides,
    demeaned (ignoring NaNs) and cut into windows of `winlen_in_s` seconds.
    Welch spectra and the band-passed amplitude measure are computed per
    window.

    Parameters
    ----------
    dataclient : object
        Waveform source handed to `get_data`.
    invclient : object
        Provides `get_stations(...)`; queried once with `level='response'`.
    network, station, location, channel : str
        Stream identifiers.
    startdate, enddate : time-like
        Processing starts at `startdate` and continues while the chunk
        start lies before `enddate`.
    overlap : float
        Padding in seconds loaded before and after each chunk.
    winlen_in_s : float
        Window length in seconds.
    nperseg : int
        Segment length passed to `scipy.signal.welch`.
    fmin, fmax : float
        Corner frequencies of the band-pass used for the amplitude.
    proclen : float, optional
        Chunk length in seconds.

    Returns
    -------
    AMP : ndarray
        One entry per chunk holding the per-window amplitude values.
    PXX : ndarray
        One entry per chunk holding the per-window spectra.
    freq : ndarray or None
        Frequencies of the last computed spectra; None if the date range
        contained no chunk.
    """
    AMP = []
    PXX = []
    # Without this the final `return` fails when the loop body never runs
    freq = None

    inv = invclient.get_stations(network=network, station=station, 
                                 location=location, starttime=startdate, 
                                 endtime=enddate, channel=channel, 
                                 level='response')

    # Number of windows per chunk does not change between iterations
    nf = int(proclen/winlen_in_s)

    # `starttime` is the start of the padded trace, i.e. `overlap` seconds
    # before the first window of the chunk
    starttime = startdate-overlap
    while starttime < enddate - overlap:
        logger.debug("%s", starttime)
        chunk_start = starttime + overlap
        tr = get_data(dataclient, chunk_start, overlap, proclen,
                      network, station, location, channel, inv)

        # Demean ignoring gaps
        tr.data = tr.data - np.nanmean(tr.data)

        sr = tr.stats.sampling_rate
        winlen_samples = int(winlen_in_s * sr)

        # Spectra
        data = get_adjacent_frames(tr, chunk_start, nf, winlen_samples)
        freq, P = welch(data, fs=sr, nperseg=nperseg, axis=1)

        # Amplitude; uses the trace-based signature
        # get_amplitude(tr, starttime, fmin, fmax, overlap, winlen_samples, nf)
        prctl = get_amplitude(tr, chunk_start, fmin, fmax,
                              overlap, winlen_samples, nf)

        AMP.append(prctl)
        PXX.append(P)

        starttime = starttime + proclen

    return np.array(AMP), np.array(PXX), freq


def get_adjacent_frames(tr, starttime, nf, winlen_samples):
    """
    Cut trace data into `nf` adjacent frames of `winlen_samples` samples.

    The data are taken from `starttime` onwards; samples beyond
    ``nf * winlen_samples`` are ignored.

    Returns
    -------
    ndarray, shape (nf, winlen_samples)
    """
    n_needed = int(nf*winlen_samples)
    samples = tr.slice(starttime, endtime=None).data
    return np.reshape(samples[:n_needed], (nf, winlen_samples))


def get_overlapping_tapered_frames(tr, starttime, nf, winlen_samples,
                           taper_samples):
    """
    Splits the vector up into (overlapping) Tukey windows.

    Each frame consists of `winlen_samples` samples plus `taper_samples`
    samples on either side; consecutive frames are shifted by
    `winlen_samples`. The first frame starts `taper_samples` samples
    before `starttime`. Every frame is multiplied by a Tukey window whose
    tapered fraction is ``2*taper_samples / frame length``.

    Frames containing any Nans are set entirely to Nan.
    Loosely based on `obspy.signal.util.enframe`

    Parameters
    ----------
    tr : trace-like
        Provides `stats.sampling_rate` and `slice(starttime)` returning an
        object with a `data` array.
    starttime : time-like
        Start of the untapered part of the first frame.
    nf : int
        Number of frames.
    winlen_samples : int
        Frame length without tapers, in samples.
    taper_samples : int
        Length of one taper, in samples.

    Returns
    -------
    f : ndarray, shape (nf, winlen_samples + 2*taper_samples)
        Tapered frames, one per row.
    taper_samples : int
        The taper length, passed through unchanged.
    """
    sr = tr.stats.sampling_rate

    # Samples in window including tapers
    nwin = int(winlen_samples + 2*taper_samples)

    # Total number of samples of trace to process
    proclen_samples = int(nf * winlen_samples + 2*taper_samples)

    # Cut out the needed data
    x = tr.slice(starttime-taper_samples/sr).data[:proclen_samples]

    # Ratio of tapers to total window size
    a = 2*taper_samples / nwin
    win = get_window(('tukey', a), nwin, fftbins=False)

    # The frame length is always `nwin`. The special case inherited from
    # `enframe` (window of length 1) assigned the window array itself as the
    # length, which broke the index construction for one-sample frames.
    frame_starts = winlen_samples * np.arange(nf)
    idx = frame_starts[:, np.newaxis] + np.arange(nwin)[np.newaxis, :]

    f = x[idx] * win
    f[np.any(np.isnan(f), axis=1), :] = np.nan
    return f, taper_samples



def get_amplitude(tr, starttime, fmin, fmax, overlap, winlen_samples, nf):
    """
    75th percentile of the band-pass filtered data for each of `nf` windows.

    If trace is free of nans, we can simply filter the whole trace at 
    once and use reshape to get the windows.
    If there are nans, the obspy filter function only filters
    the data up to the first nan. Thus you can lose almost the entire
    trace because of a single nan in the early part. To avoid this, we
    check for nans and if there are some, we filter the data per
    window. This means however, we have to ensure again some overlap
    and the windowing will be slower.

    Parameters
    ----------
    tr : trace-like
        Trace covering the windows plus `overlap` seconds before the first
        window. It is not modified.
    starttime : time-like
        Start of the first window (without overlap).
    fmin, fmax : float
        Corner frequencies of the band-pass.
    overlap : float
        Taper length in seconds used for the per-window filtering.
    winlen_samples : int
        Window length in samples.
    nf : int
        Number of windows.

    Returns
    -------
    prctl : ndarray, shape (nf,)
        75th percentile (ignoring NaNs) of each filtered window.
    """
    sr = tr.stats.sampling_rate

    if np.any(np.isnan(tr.data)):
        logger.info('Found nans in %s', tr)
        taper_samples = int(overlap*sr)

        data, taper_samples = get_overlapping_tapered_frames(
            tr, starttime, nf, winlen_samples, taper_samples)
        data = bandpass(data, fmin, fmax, sr)
        # Explicit stop index: `-taper_samples` as stop would select nothing
        # when the taper length is 0.
        data = data[:, taper_samples:data.shape[1] - taper_samples]

    else:
        # Filter a copy so the caller's trace keeps its unfiltered data.
        filtered = tr.copy().filter('bandpass', freqmin=fmin, freqmax=fmax)
        # `starttime` already is the start of the first window (the NaN
        # branch treats it the same way), so no further offset is added.
        data = get_adjacent_frames(filtered, starttime, nf, winlen_samples)

    prctl = np.nanpercentile(data, 75, axis=1)
    return prctl

In [None]:
network = 'GR'
station = 'BFO'
location = ''
channel = 'HHZ'
startdate = UTC("2020-01-01")
enddate = UTC("2020-02-01")
overlap = 300 #3600

fmin, fmax = (4, 14)
nperseg = 2048
winlen_in_s = 3600
proclen = 24*3600

sdsclient = Client('/home/lehr/sds/data')
invclient = RoutingClient('eida-routing')

In [None]:
%%time
AMP, PXX, freq = iterate_database(sdsclient, invclient,
                    network, station, location, channel, startdate, enddate,
                    overlap, winlen_in_s, nperseg, fmin, fmax, proclen=proclen)

In [None]:
inv = invclient.get_stations(network=network, station=station, 
                                 location=location, starttime=startdate, 
                                 endtime=enddate, channel=channel, 
                                 level='response')

In [None]:
starttime = UTC("2020-336")
overlap = 600
tr = get_data(sdsclient, starttime, overlap, proclen,
            network, station, location, channel, inv)

In [None]:
tr

In [None]:
np.where(np.isnan(tr.data))

In [None]:
tr.plot(show=False)

In [None]:
tr.data = tr.data - np.nanmean(tr.data)
# Get some numbers
sr = tr.stats.sampling_rate
nf = int(proclen/winlen_in_s)
#proclen_samples = proclen * sr
winlen_samples = int(winlen_in_s * sr)

In [None]:
# Spectra
data = get_windows(tr, starttime+overlap, nf, winlen_samples)

In [None]:
plt.imshow(data, aspect='auto')

In [None]:
scale = 1e6
for i, row in enumerate(data):
    plt.plot(row*scale + i)

In [None]:
x = data[:,250000]
plt.plot(x, 'o-')
print(x)

In [None]:
np.any(np.isnan(data), axis=1)

In [None]:
freq, P = welch(data, fs=sr, nperseg=nperseg, axis=1)

In [None]:
plt.imshow(np.log(P), aspect='auto')

In [None]:
i, j = np.where(np.isnan(P))
print(np.unique(i))

In [None]:
prctl = get_amplitude(tr, starttime+overlap, fmin, fmax,
                              overlap, winlen_samples, nf)

In [None]:
plt.plot(prctl)

In [None]:
np.where(np.isnan(prctl))

In [None]:
x = np.random.rand(1000)
x[10] = np.nan
f, p = welch(x)

In [None]:
p

In [None]:
plt.plot(f, p)

In [None]:
for a in [PXX, AMP]:
    print(a.__sizeof__()/1e6)

In [None]:
for i, row in enumerate(AMP):
    #print(row.shape)z
    if row.shape[0] != 24:
        print(UTC("2020-{:03d}".format(i+1)), row.shape)
        #AMP[i] = np.append(row, np.nan)

In [None]:
AMP = np.asarray(AMP)
AMP.shape

In [None]:
plt.matshow(AMP.T)

In [None]:
np.any(np.isnan(AMP))

In [None]:
nanpos_amp = np.where(np.isnan(AMP))

In [None]:
nanhours = nanpos_amp[0]*24 + nanpos_amp[1]

### Check Spectra

In [None]:
PXX = np.array(PXX)

In [None]:
PXX.shape

In [None]:
plt.matshow(np.log(PXX.reshape((366*24, 1025))).T)

In [None]:
plt.matshow(np.log(PXX.reshape((366*24, 1025)))[nanhours,:])

In [None]:
nanpos_amp

In [None]:
np.where(~np.isnan(PXX.reshape((366*24, 1025))[nanhours,:]))

In [None]:
PXX.reshape((366*24, 1025))[nanhours,:][28,:]

In [None]:
nanhours

In [None]:
nanhours[28] // 24

In [None]:
network = 'GR'
station = 'BFO'
location = ''
channel = 'HHZ'
startdate = UTC("2020-108")
enddate = UTC("2020-109")
overlap = 600 #3600

fmin, fmax = (4, 14)
nperseg = 2048
winlen_in_s = 3600
proclen = 24*3600

dataclient = Client('/home/lehr/sds/data')
invclient = RoutingClient('eida-routing')

In [None]:
starttime = startdate-overlap
inv = invclient.get_stations(network=network, station=station, 
                             location=location, starttime=startdate, 
                             endtime=enddate, channel=channel, 
                             level='response')
while starttime < enddate - overlap:
    print(starttime)
    endtime = starttime + proclen + 2*overlap
    st = dataclient.get_waveforms(network, station, 
                             location, channel,starttime, endtime)
    st.remove_sensitivity(inv)
    st.merge(fill_value=np.nan)
    if len(st) > 1:
        raise RuntimeWarning("More than 1 trace in stream!")
    tr = st[0]    

    # Demean ignoring gaps
    tr.data = tr.data - np.nanmean(tr.data)
    
    break
    # Get some numbers
    sr = tr.stats.sampling_rate
    nf = int(proclen/winlen_in_s)
    #proclen_samples = proclen * sr
    winlen_samples = int(winlen_in_s * sr)

    # Spectra
    data = get_windows(tr, starttime+overlap, nf, winlen_samples)
    freq, P = welch(data, fs=sr, nperseg=nperseg, axis=1)

    # Amplitude
    tr.filter('bandpass', freqmin=fmin, freqmax=fmax)
    data = get_windows(tr, starttime+overlap, nf, winlen_samples)
    prctl = np.nanpercentile(data, 75, axis=1)

    #AMP.append(prctl) #amp[1:-1])
    #PXX.append(P) # pxx[1:-1,:])

    starttime = starttime + proclen

In [None]:
tr.plot(show=False, size=(1500, 200), linewidths=0.5)

In [None]:
scale=1e6
plt.figure(figsize=(10, 10))
for i, row in enumerate(get_windows(tr, starttime+overlap, nf, winlen_samples)):
    plt.plot(row*scale + i, lw=0.5)

In [None]:
np.any(np.isnan(tr.data))

In [None]:
np.where(np.isnan(tr.data))

In [None]:
trfilt = tr.copy().filter('bandpass', freqmin=fmin, freqmax=fmax)#.plot(show=False)

In [None]:
trfilt

In [None]:
np.where(np.isnan(trfilt.data))

In [None]:
idx = np.where(np.isnan(PXX))

In [None]:
idx[0].size

In [None]:
plt.plot(freq, PXX[0,:,:].T);
plt.xlim(0, 1)

In [None]:
rPXX = PXX.reshape((366*24, 1025))

In [None]:
plt.plot(freq, rPXX[:24,:].T);
plt.xlim(0, 1)

In [None]:
rPXX[:24,:].shape, PXX[0,:,:].shape

In [None]:
np.all(np.isclose(rPXX[:24,:], PXX[0,:,:], equal_nan=True))
                      

In [None]:
rPXX[:24,:] - PXX[0,:,:]

### Create HDF5

In [None]:
fname = "data/{}.{}.{}.{}_{}-{}.hdf5".format(network, station, location, channel, startdate.date, enddate.date)
print(fname)

In [None]:
with h5py.File(fname, "w") as fout:
    for k, v in zip(['PXX', 'AMP', 'freq'],
                    [PXX, AMP, freq]):
        fout.create_dataset(k, data=v)
    

In [None]:
%ls -lh data/*.hdf5

# Check missing data

In [None]:
np.where(np.isnan(AMP))

In [None]:
startdate = UTC("2020-009")
enddate = UTC("2020-010")
st = sdsclient.get_waveforms(network, station, 
                                 location, channel,startdate, enddate)

In [None]:
inv = invclient.get_stations(network=network, station=station, 
                                 location=location, starttime=startdate, 
                                 endtime=enddate, channel=channel, 
                                 level='response')

In [None]:
st.remove_sensitivity(inv)

In [None]:
st.plot(show=False)

In [None]:
st.merge(fill_value=np.nan)

In [None]:
st.trim(startdate, enddate, pad=True, fill_value=np.nan)

In [None]:
st.plot(show=False)

In [None]:
tr = st[0].copy()
amp, freq, pxx = process(tr, winlen_in_s, nperseg, fmin, fmax)

In [None]:
amp

In [None]:
amp

# Filter without taper

In [None]:
starttime = UTC("2020-06-01")
endtime = UTC("2020-06-02")
st = sdsclient.get_waveforms(network, station, 
                                 location, channel,starttime, endtime)

In [None]:
inv = invclient.get_stations(network=network, station=station, 
                                 location=location, starttime=starttime, 
                                 endtime=endtime, channel=channel, 
                                 level='response')

In [None]:
st.remove_sensitivity(inv)
st.merge(fill_value=np.nan)
st.trim(starttime, endtime, pad=True, fill_value=np.nan)

In [None]:
tr = st[0].copy()

In [None]:
sr = tr.stats.sampling_rate
nwin = int(winlen_in_s * sr)
nf = tr.data.size // nwin # Get number of frames
nsize = nf*nwin

In [None]:
matrix = np.reshape(tr.data[:nsize], (nf, nwin))

In [None]:
plt.figure(figsize=(10, 12))
scale = 5e6
for i, row in enumerate(matrix):
    plt.plot(scale*row + i)
    
    frow = bandpass(row.copy(), fmin, fmax, sr)
    plt.plot(frow*scale +i, 'k', lw=0.5)
plt.xlim(-100, 1000)

In [None]:
amp, freq, pxx = process(tr, winlen_in_s, nperseg, fmin, fmax)

# Benchmark windowing
`enframe` of obspy.signal is significantly slower than just reshaping. So unless you want overlaps between windows, there is no reason to use enframe.

In [None]:
# %load -n util.enframe
#from scipy import fix
def enframe(x, win, inc, use_obspy=False):
    """
    Splits the vector up into (overlapping) frames beginning at increments
    of inc. Each frame is multiplied by the window win().
    The length of the frames is given by the length of the window win().
    The centre of frame I is x((I-1)*inc+(length(win)+1)/2) for I=1,2,...

    Benchmark variant: the two branches selected by `use_obspy` build the
    same frame index matrix in two different ways so they can be timed
    against each other.

    :param x: signal to split in frames
    :param win: window multiplied to each frame, length determines frame length
    :param inc: increment to shift frames, in samples
    :param use_obspy: if True, build the index matrix by stacking index
        vectors with `np.vstack` (presumably the original obspy code);
        if False, build it by broadcasting two expanded index vectors
    :return f: output matrix, each frame occupies one row
    :return length, no_win: length of each frame in samples, number of frames
    """
    nx = len(x)
    nwin = len(win)
    if (nwin == 1):
        length = win
    else:
        # length = next_pow_2(nwin)
        length = nwin
    # Number of complete frames that fit into x
    nf = int(np.fix((nx - length + inc) // inc))
    # f = np.zeros((nf, length))
    # Start index of each frame
    indf = inc * np.arange(nf)
    if use_obspy:
        # 1-based in-frame offsets, shifted back by the trailing `- 1`
        inds = np.arange(length) + 1
        f = x[(np.transpose(np.vstack([indf] * length)) +
           np.vstack([inds] * nf)) - 1]
    else:
        # Same index matrix via broadcasting: (nf, 1) + (1, length)
        f = x[np.expand_dims(indf, 1) + 
              np.expand_dims(np.arange(length), 0)]
    if (nwin > 1):
        # Multiply every frame by the window (window repeated for each row)
        w = np.transpose(win)
        f = f * np.vstack([w] * nf)
    #f = signal.detrend(f, type='constant')
    no_win, _ = f.shape
    return f, length, no_win

In [None]:
network = 'GR'
station = 'BFO'
location = ''
channel = 'HHZ'
startdate = UTC("2020-001")
enddate = UTC("2020-002")
overlap = 0 #3600

fmin, fmax = (4, 14)
nperseg = 2048
winlen_in_s = 3600

dataclient = Client('/home/lehr/sds/data')
invclient = RoutingClient('eida-routing')

In [None]:
inv = invclient.get_stations(network=network, station=station, 
                                 location=location, starttime=startdate, 
                                 endtime=enddate, channel=channel, 
                                 level='response')

In [None]:
starttime = startdate
endtime = starttime + 24*3600+2*overlap
st = dataclient.get_waveforms(network, station, 
                         location, channel,starttime, endtime)
st.remove_sensitivity(inv)
st.merge(fill_value=np.nan)
st.trim(starttime, endtime, pad=True, fill_value=np.nan)

In [None]:
tr = st[0]

In [None]:
%%timeit 
sr = tr.stats.sampling_rate
nwin = int(winlen_in_s * sr)
win = get_window('boxcar', nwin, fftbins=False)

matrix, nx, ny  = enframe(tr.data, win, win.size)

In [None]:
prctl = np.percentile(matrix, 75, axis=1)

plt.plot(prctl)

In [None]:
%%timeit
sr = tr.stats.sampling_rate
nwin = int(winlen_in_s * sr)
nf = tr.data.size // nwin
nsize = nf*nwin
matrix = np.reshape(tr.data[:nsize], (nf,nwin))

In [None]:
prctl = np.percentile(matrix, 75, axis=1)

plt.plot(prctl)

# Using tapered windows

In [None]:
network = 'GR'
station = 'BFO'
location = ''
channel = 'HHZ'
startdate = UTC("2020-002")
enddate = UTC("2020-003")
overlap = 5*60 #3600

fmin, fmax = (4, 14)
nperseg = 2048
winlen_in_s = 3600 + 2*overlap

dataclient = Client('/home/lehr/sds/data')
invclient = RoutingClient('eida-routing')

In [None]:
starttime = startdate - overlap
endtime = starttime + 24*3600+2*overlap
st = dataclient.get_waveforms(network, station, 
                         location, channel,starttime, endtime)
inv = invclient.get_stations(network=network, station=station, 
                                 location=location, starttime=startdate, 
                                 endtime=enddate, channel=channel, 
                                 level='response')

st.remove_sensitivity(inv)
st.merge(fill_value=np.nan)
st.trim(starttime, endtime, pad=True, fill_value=np.nan)

In [None]:
tr = st[0].copy()

In [None]:
#%%timeit 
sr = tr.stats.sampling_rate
winlen_in_s = 3600+600
nwin = int(winlen_in_s * sr)
a = 600 / winlen_in_s
win = get_window(('tukey', a), nwin, fftbins=False)

In [None]:
tax = np.linspace(-5, 65, nwin) 
plt.plot(tax, win)
#plt.xlim(-1, 61)

In [None]:
%%timeit
matrix, nx, ny  = enframe(tr.data, win, int(3600*sr), use_obspy=False)

In [None]:
%%timeit
matrix, nx, ny  = enframe(tr.data, win, int(3600*sr), use_obspy=True)

In [None]:
matrix, nx, ny  = enframe(tr.data, win, int(3600*sr), use_obspy=False)

In [None]:
plt.figure(figsize=(10, 12))
scale = 1e6
for i, row in enumerate(matrix):
    plt.plot(scale*row + i)
    
    #frow = bandpass(row.copy(), fmin, fmax, sr)
    #plt.plot(frow*scale +i, 'k', lw=0.5)
#plt.xlim(-100, 100000)
#plt.xlim(matrix.shape[-1]-100000, None)

In [None]:
plt.figure(figsize=(10, 12))
scale = 1e6
noverlap = int(overlap*sr)
for i, row in enumerate(matrix[:,noverlap:-noverlap]):
    plt.plot(scale*row + i)
    
    #frow = bandpass(row.copy(), fmin, fmax, sr)
    #plt.plot(frow*scale +i, 'k', lw=0.5)
#plt.xlim(-100, 100000)
#plt.xlim(matrix.shape[-1]-100000, None)

In [None]:
x = tr.data.copy()
inc = int(3600*sr)

In [None]:
nx = len(x)
nwin = len(win)
if (nwin == 1):
    length = win
else:
    # length = next_pow_2(nwin)
    length = nwin

In [None]:
length

In [None]:
nf = int(np.fix((nx - length + inc) // inc))

In [None]:
nf

In [None]:
nx - length + inc

In [None]:
# f = np.zeros((nf, length))
indf = inc * np.arange(nf)
print(indf)

In [None]:
inds = np.arange(length) + 1
inds.size, win.size

In [None]:
print(len([indf] * length))
print(([indf] * length)[1])

In [None]:
idx = (np.transpose(np.vstack([indf] * length)) +
       np.vstack([inds] * nf)) - 1

In [None]:
idx.shape

In [None]:
idx

In [None]:
f = x[(np.transpose(np.vstack([indf] * length)) +
       np.vstack([inds] * nf)) - 1]
if (nwin > 1):
    w = np.transpose(win)
    f = f * np.vstack([w] * nf)
#f = signal.detrend(f, type='constant')
no_win, _ = f.shape

In [None]:
length

In [None]:
idx = np.expand_dims(indf, 1) + np.expand_dims(np.arange(length), 0)

In [None]:
x[idx].shape

In [None]:
%%timeit -n100
view = np.lib.stride_tricks.sliding_window_view(x, length)
m = view[::inc,:]*win

In [None]:
x.strides

In [None]:
# Row i has to start `inc` samples after row i-1, so the strides must be given
# explicitly. Without them the result is laid out like a contiguous
# (24, 420000) array: the rows do not start at multiples of `inc` and the view
# can extend beyond the memory owned by `x`.
# NOTE(review): assumes len(x) >= 23*inc + 420000 -- confirm for other inputs.
view = np.lib.stride_tricks.as_strided(
    x, (24, 420000), strides=(inc * x.strides[0], x.strides[0]),
    writeable=False)

In [None]:
view

In [None]:
m = view*win

In [None]:
m

In [None]:
plt.figure(figsize=(10, 12))
scale = 1e6
for i, row in enumerate(m):
    plt.plot(scale*row + i)
    
    #frow = bandpass(row.copy(), fmin, fmax, sr)
    #plt.plot(frow*scale +i, 'k', lw=0.5)
plt.xlim(-100, 100000)

In [None]:
(fview * win)

In [None]:
inc

In [None]:
view[::inc,:].shape

In [None]:

def bandpass_matrix(data, freqmin, freqmax, df, corners=4, 
                    zerophase=False, axis=-1):

    """
    Butterworth-Bandpass Filter for arrays of any dimension.

    Filter data from ``freqmin`` to ``freqmax`` using ``corners``
    corners.
    The filter uses :func:`scipy.signal.iirfilter` (for design)
    and :func:`scipy.signal.sosfilt` (for applying the filter).

    :type data: numpy.ndarray
    :param data: Data to filter.
    :param freqmin: Pass band low corner frequency.
    :param freqmax: Pass band high corner frequency.
    :param df: Sampling rate in Hz.
    :param corners: Filter corners / order.
    :param zerophase: If True, apply filter once forwards and once backwards.
        This results in twice the filter order but zero phase shift in
        the resulting filtered trace.
    :param axis: Axis of ``data`` along which the filter is applied.
    :return: Filtered data.
    :raises ValueError: If ``freqmax`` is at or above the Nyquist frequency
        or ``freqmin`` is above it.
    """
    fe = 0.5 * df
    low = freqmin / fe
    high = freqmax / fe
    # raise for some bad scenarios
    if high - 1.0 > -1e-6:
        msg = ("Selected high corner frequency ({}) of bandpass is at or "
               "above Nyquist ({}).").format(
            freqmax, fe)
        raise ValueError(msg)
    if low > 1:
        msg = "Selected low corner frequency is above Nyquist."
        raise ValueError(msg)

    z, p, k = iirfilter(corners, [low, high], btype='band',
                        ftype='butter', output='zpk')
    sos = zpk2sos(z, p, k)

    firstpass = sosfilt(sos, data, axis=axis)
    if not zerophase:
        return firstpass

    # The backward pass has to reverse the samples along the filter axis.
    # `[::-1]` always reverses axis 0, which for a matrix filtered along
    # its last axis only swaps the rows instead of reversing time.
    backward = sosfilt(sos, np.flip(firstpass, axis=axis), axis=axis)
    return np.flip(backward, axis=axis)

In [None]:


def get_windows(tr, starttime, nf, winlen_samples,
                           taper_samples=0):
    """
    Split trace data into adjacent or overlapping frames
    depending whether Nans are present.
    
    If Nans are present, overlapping frames are created.
    The overlap is determined by `taper_samples`. Thus if
    you expect Nans in your data and want overlapping frames
    give `taper_samples` > 0

    Parameters
    ----------
    tr : trace-like
        Trace providing `data`; handed on to the framing helpers.
    starttime : time-like
        Start of the first (untapered) frame.
    nf : int
        Number of frames.
    winlen_samples : int
        Frame length without tapers, in samples.
    taper_samples : int, optional
        Length of one taper in samples; only used if Nans are present.
    
    Returns
    ----------
    data : ndarray
        framed data, 
        shape=(nf, winlen_samples+2*taper_samples) if Nans are
        present, else shape=(nf, winlen_samples)
    nontapered : slice
        slice object, that gives the slice of the
        untapered region in `data`.
    """
    
    if np.any(np.isnan(tr.data)):
        logger.info('Found nans in %s', tr)
        taper_samples = int(taper_samples)
        data, taper_samples = get_overlapping_windows(tr, 
                            starttime, nf, winlen_samples,
                           taper_samples)
        # A stop of `-0` would produce an empty slice, so use an open end
        # when no taper is applied (the default).
        nontapered = slice(taper_samples, -taper_samples or None, None)
    else:
        data = get_adjacent_windows(tr, starttime, nf, 
                                    winlen_samples)
        nontapered = slice(None,None,None)
    return data, nontapered
    


def get_adjacent_windows(tr, starttime, nf, winlen_samples):
    """
    Arrange trace data in non-overlapping frames, one frame per row.

    The data from `starttime` onwards are truncated to
    ``nf * winlen_samples`` samples and reshaped with `reshape`.

    Returns
    -------
    ndarray, shape (nf, winlen_samples)
    """
    frame_shape = (nf, winlen_samples)
    ntot = int(nf*winlen_samples)
    trimmed = tr.slice(starttime, endtime=None).data[:ntot]
    return trimmed.reshape(frame_shape)


def get_overlapping_windows(tr, starttime, nf, winlen_samples,
                           taper_samples):
    """
    Splits the vector up into (overlapping) frames 

    Each frame holds `winlen_samples` samples plus `taper_samples` samples
    on either side and is multiplied by a Tukey window; consecutive frames
    are shifted by `winlen_samples`. The first frame starts
    `taper_samples` samples before `starttime`. Frames containing any Nan
    are set entirely to Nan.

    The number of returned frames is derived from the amount of data that
    is actually available, so it can be smaller than the requested `nf`.

    Parameters
    ----------
    tr : trace-like
        Provides `stats.sampling_rate` and `slice(starttime)` returning an
        object with a `data` array.
    starttime : time-like
        Start of the untapered part of the first frame.
    nf : int
        Requested number of frames.
    winlen_samples : int
        Frame length without tapers, in samples.
    taper_samples : int
        Length of one taper, in samples.

    Returns
    -------
    f : ndarray, shape (n_frames, winlen_samples + 2*taper_samples)
        Tapered frames, one per row.
    taper_samples : int
        The taper length, passed through unchanged.
    """
    sr = tr.stats.sampling_rate

    # Samples in window including tapers
    nwin = int(winlen_samples + 2*taper_samples)

    # Total number of samples of trace to process
    proclen_samples = int(nf * winlen_samples + 2*taper_samples)

    # Ratio of tapers to total window size
    a = 2*taper_samples / nwin
    win = get_window(('tukey', a), nwin, fftbins=False)

    # Cut out the needed data
    x = tr.slice(starttime-taper_samples/sr).data[:proclen_samples]

    # Number of complete frames in the available data. The frame length is
    # always `nwin`; the special case inherited from `enframe` (window of
    # length 1) used the window array itself as the length and broke the
    # indexing for one-sample frames.
    n_frames = int((len(x) - nwin + winlen_samples) // winlen_samples)

    frame_starts = winlen_samples * np.arange(n_frames)
    idx = frame_starts[:, np.newaxis] + np.arange(nwin)[np.newaxis, :]

    f = x[idx] * win
    f[np.any(np.isnan(f), axis=1), :] = np.nan
    return f, taper_samples



def get_amplitude(data, nontapered, fmin, fmax, sr):
    """
    75th percentile of the band-pass filtered data for each frame.

    The frames are filtered as a whole (including their tapered margins,
    if any); afterwards only the columns selected by `nontapered` enter
    the percentile.

    Parameters
    ----------
    data : ndarray
        Framed data, one frame per row, e.g. the first return value of
        `get_windows`.
    nontapered : slice
        Columns of `data` belonging to the untapered region, e.g. the
        second return value of `get_windows`.
    fmin, fmax : float
        Corner frequencies passed to `bandpass`.
    sr : float
        Sampling rate passed to `bandpass`.

    Returns
    -------
    prctl : ndarray
        75th percentile (ignoring Nans) of each filtered frame.
    """
    # The orphaned `else:` branch of the earlier trace-based version was
    # removed: it was a syntax error and used names that are not parameters.
    data = bandpass(data, fmin, fmax, sr)
    data = data[:, nontapered]

    prctl = np.nanpercentile(data, 75, axis=1)
    return prctl