In [None]:
from importlib import reload
from pathlib import Path
import os
import numpy as np
from scipy.signal import medfilt, medfilt2d, get_window
from scipy.ndimage import median_filter

from obspy.clients.filesystem.sds import Client as SDSClient
from obspy.clients.fdsn import RoutingClient
from obspy.core import UTCDateTime as UTC, read

import matplotlib.pyplot as plt
plt.style.use('tableau-colorblind10')

from data_quality_control import sds_db, base, util, analysis
from data_quality_control.analysis import Analyzer

In [None]:
# Only for display in documentation!
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

## Define parameters

In [None]:
# Station identifier in NETWORK.STATION.LOCATION.CHANNEL form
# (the location field is empty in this code).
nslc_code = "GR.BFO..BHZ"

# Processing parameters.
# NOTE(review): none of the values in this group is referenced again in the
# visible cells of this notebook; units are presumably seconds (overlap,
# winlen_in_s, proclen) and Hz (fmin, fmax) -- confirm against the
# processing code.
overlap = 60 #3600
fmin, fmax = (4, 14)
nperseg = 2048
winlen_in_s = 3600
proclen = 24*3600

# Directory passed to Analyzer / Interpolator as their data directory.
outdir = 'output'

# Root directory handed to SDSClient further down.
sds_root = os.path.abspath('../../sample_sds/')
# NOTE(review): not used in the visible cells; presumably selects the routing
# service for station inventories -- confirm.
inventory_routing_type = "eida-routing"

# Test trimming

In [None]:
startdate = UTC("2020-12-24")
enddate = UTC("2021-01-15")

In [None]:
reload(analysis)
reload(base)
#reload(util)
lyza = analysis.Analyzer(outdir, nslc_code,
                            fileunit="year")

In [None]:
print(lyza)

We can query which files and time ranges are available for
the given station code, data directory and file unit.

In [None]:
files = lyza.get_available_datafiles()
print(files)

In [None]:
lyza.get_available_timerange()

In [None]:
reload(analysis)
reload(base)
#reload(util)
lyza = analysis.Analyzer(outdir, nslc_code,
                            fileunit="year")

lyza.get_data(startdate, enddate)

In [None]:
lyza

In [None]:
lyza.trim(UTC("2020-12-30"), UTC("2021-01-06"))

In [None]:
lyza

In [None]:
lyza.trim(startdate, enddate)

In [None]:
lyza.plot_spectrogram();

In [None]:
lyza.trim(startdate, enddate, fill_value=np.nan)

lyza.plot_spectrogram();

# Interpolator

### Iterator over PSD-array

In [None]:
psds = lyza.psds.copy()
psds.shape

In [None]:
x = psds[:26,0]
x.shape

In [None]:
f = util.get_overlapping_frames(x, 2, 6)
f.shape

In [None]:
# Running median over overlapping frames, applied to every column of `psds`
# (presumably one column per frequency bin -- confirm).
# `get_overlapping_frames` lives in `util`; the bare name is never imported
# in this notebook and raised NameError on a fresh kernel.
smoothed = []
for psd in psds.T:
    frames = util.get_overlapping_frames(psd, 3, 6)
    # NaN-ignoring median so frames with gaps still give a value.
    smoothed.append(np.nanmedian(frames, axis=1))
smoothed = np.array(smoothed).T

In [None]:
smoothed.shape

In [None]:
plt.imshow(np.log10(smoothed), aspect="auto")

In [None]:
# Same smoothing as the explicit loop, written as a comprehension.
# `get_overlapping_frames` must be accessed through `util`; the bare name is
# not defined in this notebook.
# NOTE(review): other calls in this notebook pass (x, kernel_size, kernel_shift);
# here the names suggest (increment, window length) -- confirm the order.
inc = 3
winlen = 6
smoothed = np.array([
    np.nanmedian(util.get_overlapping_frames(column, inc, winlen), axis=1)
    for column in psds.T
]).T

smoothed.shape

In [None]:
plt.imshow(np.log10(smoothed.T), aspect="auto")

In [None]:
plt.imshow(np.log10(psds.T), aspect="auto")

class Interpolator():
    """
    Draft helper that finds the processed-data files of one station and
    prints their content.

    Parameters
    ----------
    datadir : str
        Directory that is searched for the HDF5 files.
    nslc_code : str
        Dot-separated code; split into network, station, location, channel.
    fileunit : str
        Key into `util.TIME_ITERATORS`, `util.FNAME_FMTS` and
        `util.FNAME_WILDCARD`.
    """
    def __init__(self, datadir, nslc_code, fileunit="year"):
        self.stationcode = nslc_code
        self.datadir = datadir
        self.fileunit = fileunit
        self.iter_time = util.TIME_ITERATORS[fileunit]

        # Split the file-name template at its last underscore: the leading
        # part is filled in now, the trailing part is kept as a template.
        head, underscore, tail = util.FNAME_FMTS[fileunit].rpartition("_")
        filled_head = head.format(outdir=datadir, **self.nslc_as_dict())
        self.fmtstr = filled_head + underscore + tail

    def nslc_as_dict(self):
        """Return the station code as a dict with network/station/location/channel keys."""
        keys = ("network", "station", "location", "channel")
        return dict(zip(keys, self.stationcode.split(".")))

    def interpolate(self):
        """Read every matching HDF5 file in name order and print it (no interpolation yet)."""
        pattern = (self.stationcode + "_" +
                   util.FNAME_WILDCARD[self.fileunit] + ".hdf5")
        for fname in sorted(str(p) for p in Path(self.datadir).glob(pattern)):
            data = base.BaseProcessedData().from_file(fname)
            print(fname)
            print(data)
            print()

In [None]:
lyza.psds.shape

In [None]:
reload(analysis)
reload(base)
reload(util)
class Interpolator(analysis.Analyzer):
    """
    Analyzer that condenses stored amplitudes and PSDs with a running median.

    The stored windows are grouped into overlapping frames of `kernel_size`
    windows whose starts are `kernel_shift` windows apart. Each frame is
    replaced by its NaN-ignoring median and the result is written with
    `to_file`.
    """
    def __init__(self, datadir, nslc_code, fileunit="year"):
        super().__init__(datadir, nslc_code, fileunit)

    def _get_SECONDS_PER_WINDOW(self, TSTA, TEND):
        """
        Load the first time range yielded by `self.iter_time(TSTA, TEND)` and
        store its window length in `self.SECONDS_PER_WINDOW`.

        Nothing is stored if the iterator yields no time range.
        """
        self.logger.debug("\n\nLooking for window size")
        for tsta, tend in self.iter_time(TSTA, TEND):
            self.get_data(tsta, tend)
            self.logger.debug("Time range to get window size: {} - {}".format(tsta, tend))
            self.logger.info("Expecting window size is {:g}s".format(self.seconds_per_window))
            self.SECONDS_PER_WINDOW = self.seconds_per_window
            # Only the first time range is needed.
            break

    def _set_check_SECONDS_PER_WINDOW(self):
        """
        Remember the window length on first use and verify it afterwards.

        Raises
        ------
        RuntimeError
            If `self.seconds_per_window` differs from the stored
            `self.SECONDS_PER_WINDOW`.
        """
        if not hasattr(self, "SECONDS_PER_WINDOW"):
            self.logger.info("Expecting window size = {:g}s".format(
                self.seconds_per_window))
            self.SECONDS_PER_WINDOW = self.seconds_per_window
        elif self.SECONDS_PER_WINDOW != self.seconds_per_window:
            msg = "Window size changed"
            self.logger.error(msg)
            raise RuntimeError(msg)

    def _check_framed_shape(self, x, X, kernel_shift, label=""):
        """
        Assert that the frames in `X` cover the series `x` exactly.

        `X` has one frame per row; with frames starting `kernel_shift`
        samples apart, the covered length is (rows-1)*kernel_shift + columns.

        Raises
        ------
        AssertionError
            If the covered length differs from `x.size`.
        """
        nk, ks = X.shape
        ns = (nk-1)*kernel_shift+ks
        assert ns == x.size, \
            "{:d} of {} timeseries remain".format(x.size-ns, label)

    def _interpolate(self, kernel_size, kernel_shift):
        """
        Return the frame-wise NaN-median of the amplitudes and of every
        column of the PSD array.

        Returns
        -------
        amplitudes_, PSD_ : ndarray
            One value (amplitudes) or one row (PSDs) per frame.
        """
        x = self.amplitudes
        X = util.get_overlapping_frames(x,
                                       kernel_size, kernel_shift)
        self.logger.debug("amplitude: {} samples -> frames {}".format(
            x.size, X.shape))
        self._check_framed_shape(x, X, kernel_shift, "amplitude")
        amplitudes_ = np.nanmedian(X, axis=1)

        # The first PSD column is framed once only to verify the shape.
        x = self.psds[:,0]
        X = util.get_overlapping_frames(x, kernel_size, kernel_shift)
        self.logger.debug("psd: {} samples -> frames {}".format(
            x.size, X.shape))
        self._check_framed_shape(x, X, kernel_shift, "psd")
        PSD_ = np.array([np.nanmedian(
                util.get_overlapping_frames(x, kernel_size, kernel_shift),axis=1)
                         for x in self.psds.T]).T

        return amplitudes_, PSD_

    def iter_times_kernel(self, tsta, tend, kernel_size, kernel_shift):
        """
        Yield (start, end) times per file unit, stretched so that the number
        of windows in between fits the kernel layout.

        Requires `self.SECONDS_PER_WINDOW` to be set.

        Note
        -------
        yielded endtime is starttime of last sample plus window size.
        """
        seconds_per_day = 24*3600
        new_tsta = None
        for _tsta, _tend in self.iter_time(tsta, tend):
            # NOTE(review): one day is added to the yielded end; presumably
            # `iter_time` returns the start of the last day -- confirm.
            _tend = _tend + seconds_per_day
            if new_tsta is None:
                new_tsta = _tsta

            # Number of windows between the segment start and the unit's end,
            # split into whole shifts and a remainder.
            N = int((_tend-new_tsta) / self.SECONDS_PER_WINDOW)
            n_kernels, n_left = np.divmod(N, kernel_shift)
            self.logger.debug(
                "{} windows = {} shifts + {} left over".format(
                    N, n_kernels, n_left))

            # Extend the end so that n_kernels*kernel_shift + kernel_size
            # windows are covered.
            new_tend = _tend + (kernel_size-n_left)*self.SECONDS_PER_WINDOW
            self.logger.debug("Times adjusted to kernel: {} - {}".format(
                new_tsta, new_tend))

            yield new_tsta, new_tend

            # NOTE(review): the next segment starts one full kernel before
            # the previous end, which looks like the start of the previous
            # segment's last frame -- confirm that this overlap is intended.
            new_tsta = new_tend - kernel_size*self.SECONDS_PER_WINDOW

    def interpolate(self, kernel_size, kernel_shift=1, outdir="."):
        """
        Median-filter all available data and write the result to `outdir`.

        Parameters
        ----------
        kernel_size : int
            Number of windows per frame.
        kernel_shift : int
            Number of windows between the starts of consecutive frames.
        outdir : str
            Directory passed to `to_file`; created if it does not exist.

        Raises
        ------
        RuntimeError
            If the window length of loaded data differs from the one found
            in the first time range.
        """
        TSTA, TEND = self.get_available_timerange()
        TSTA = UTC(TSTA.date)
        TEND = UTC(TEND.date)
        self._get_SECONDS_PER_WINDOW(TSTA, TEND)

        # Make sure the target directory exists before the first write.
        Path(outdir).mkdir(parents=True, exist_ok=True)

        self.logger.debug("\n\nStarting interpolation\n")
        for tsta, tend in self.iter_times_kernel(TSTA, TEND, kernel_size, kernel_shift):

            self.logger.debug("Yielded {} - {}".format(tsta, tend))
            self.logger.info("Interpolating {:} - {}".format(tsta, tend))
            self.logger.debug("Getting data...")

            self.get_data(tsta, tend)
            # Pad missing data with NaN so that the NaN-median ignores gaps.
            self.trim(tsta, tend, fill_value=np.nan)
            self._set_check_SECONDS_PER_WINDOW()
            self.logger.debug("Interpolating")
            amplitudes_, psds_ = self._interpolate(kernel_size, kernel_shift)
            self.set_data(amplitudes_, psds_, self.frequency_axis)
            # Each output sample now spans kernel_shift original windows.
            self.set_time(tsta, tend+(kernel_shift-kernel_size)*self.SECONDS_PER_WINDOW)
            self.seconds_per_window = kernel_shift*self.SECONDS_PER_WINDOW
            self._check_shape_vs_time()

            self.to_file(outdir)
            self.logger.debug("\n")
            
            
            

In [None]:
reload(analysis)
#reload(Analyzer)
reload(base)
reload(util)
polly = Interpolator(outdir, nslc_code )

In [None]:
polly.interpolate(6, 5, "output/interpolated/")

In [None]:
# Removed `ax = axs[0]`: `axs` is only created by the plotting cell below, so
# this line raised NameError on a fresh kernel. The plotting loop assigns
# `ax` itself.

In [None]:
reload(base)
# Compare the PSDs stored before (top) and after (bottom) interpolation.
fig, axs = plt.subplots(2, 1, figsize=(8, 10))

titles = ["raw", "interpolated"]
fpatterns = ["*202*hdf5", "interpolated/*hdf5"]

for ax, title, fpattern in zip(axs, titles, fpatterns):
    res = base.BaseProcessedData()
    # Path.glob returns files in arbitrary order; sort so they are appended
    # in a reproducible (file-name) order.
    for fname in sorted(Path("output/").glob(fpattern)):
        print(fname)
        try:
            res.extend_from_file(fname)
        except RuntimeWarning as err:
            # Still best effort, but report the problem and continue with
            # the remaining files instead of silently dropping them.
            print("Skipped {}: {}".format(fname, err))

    res.plot_psds(np.log10, ax=ax)
    ax.set_title(title)

In [None]:
lyza = analysis.Analyzer("output/interpolated/", nslc_code)
lyza.get_data(startdate, enddate)
lyza.plot_spectrogram()

# Figuring out how my get_overlapping_frames() works

In [None]:
sdsclient = SDSClient(sds_root)

In [None]:
startdate = UTC("2020-12-25")

In [None]:
# Read one day of waveforms for the station and extend the stream by 60 s on
# both sides, filling the padded samples with 0.
endtime = startdate+24*3600
st = sdsclient.get_waveforms(*nslc_code.split("."), startdate, endtime)
st = st.trim(startdate-60, endtime+60, pad=True, fill_value=0)
tr = st[0]
# Staircase test signal: the values 1..24, each repeated 72000 times
# (presumably 3600 s at 20 Hz, i.e. one step per hour -- confirm).
x = np.arange(1,24+1).repeat(72000)
print(x.shape)
print(tr.stats.npts)
# Overwrite the trace between the first 60*20+1 and the last 60*20 samples
# (presumably the padded part at 20 Hz -- confirm). The assignment only works
# if that slice holds exactly x.size samples.
tr.data[60*20+1:-60*20] = x

In [None]:
tr.data.size % 24 # 74400

In [None]:
tr.plot(endtime=startdate+600);

In [None]:
procparams = base.ProcessingParameters()

In [None]:
procparams

In [None]:
nf = int(procparams.proclen_seconds/
        procparams.winlen_seconds)

In [None]:
nf

In [None]:
f, taps = util.get_overlapping_tapered_frames(tr, startdate, 24, int(3600*20), 60*20)

In [None]:
f.shape

In [None]:
plt.imshow(f, aspect="auto")

In [None]:
#plt.plot(f[10,60*20+1:-60*20])
plt.plot(f[3,:])
#plt.xlim(-1, 10)

In [None]:
x = np.arange(24).repeat(74400)

In [None]:
x

In [None]:
def get_overlapping_tapered_frames(tr, starttime, nf, winlen_samples,
                           taper_samples):
    """
    Cut `nf` overlapping frames out of a trace.

    Every frame holds `winlen_samples` samples plus `taper_samples` samples
    on either side. Consecutive frames start `winlen_samples` apart, so
    neighbouring frames share 2*taper_samples samples. The frames are
    returned as cut; no taper window is applied in this experimental version.

    Parameters
    ----------
    tr : trace-like
        Needs `stats.sampling_rate` and `slice(starttime)` returning an
        object with a `data` array.
    starttime : time-like
        Start of the first window, excluding its leading taper.
    nf : int
        Number of frames.
    winlen_samples : int
        Window length in samples without tapers (integral floats accepted).
    taper_samples : int
        Number of extra samples on each side of a window.

    Returns
    -------
    f : ndarray, shape (nf, winlen_samples + 2*taper_samples)
        One frame per row.
    taper_samples
        The input value, passed through unchanged.

    Raises
    ------
    IndexError
        If the trace holds fewer samples after the slice than the frames need.
    """
    sr = tr.stats.sampling_rate

    # Samples in window including tapers
    nwin = int(winlen_samples + 2*taper_samples)

    # Total number of samples of trace to process
    proclen_samples = int(nf * winlen_samples + 2*taper_samples)

    # Cut out the needed data, starting one taper length before starttime
    x = tr.slice(starttime-taper_samples/sr).data[:proclen_samples]

    # First sample index of every frame; cast to int so that a float
    # `winlen_samples` does not produce a float (invalid) index array.
    frame_starts = (winlen_samples * np.arange(nf)).astype(int)

    # Fail with a clear message instead of an out-of-bounds fancy index.
    needed = int(frame_starts[-1]) + nwin if frame_starts.size else 0
    if len(x) < needed:
        raise IndexError(
            "need {:d} samples for {:d} frames, got {:d}".format(
                needed, int(frame_starts.size), len(x)))

    # Row i holds samples frame_starts[i] ... frame_starts[i] + nwin - 1
    f = x[frame_starts[:, np.newaxis] + np.arange(nwin)[np.newaxis, :]]
    return f, taper_samples

In [None]:
st

In [None]:
endtime = startdate+24*3600
st = sdsclient.get_waveforms(*nslc_code.split("."), startdate, endtime)
#st = st.trim(startdate-60, endtime+60, pad=True, fill_value=0)
tr = st[0]

In [None]:
# Synthetic staircase signal: the values 1..24, each repeated int(3600*sr)
# times.
sr = 0.01
# `np.float_` was removed in NumPy 2.0; `astype(float)` yields the same
# float64 array on every NumPy version.
x = np.arange(1, 24 + 1).repeat(int(3600 * sr)).astype(float)
print(x.shape)
print(x.size/24)
plt.plot(x)

In [None]:
tr.data = x
tr.stats.sampling_rate = sr
tr.stats.starttime = startdate
tr = tr.trim(startdate-600, endtime+600, pad=True, fill_value=0.1)
print(tr.stats.npts)

In [None]:
tr.stats.starttime, tr.stats.endtime

In [None]:
tr.plot();

In [None]:
reload(util)
f, taps = util.get_overlapping_tapered_frames(tr, startdate, 24, int(3600*sr), 600*sr)

print(f.shape)

plt.imshow(f, aspect="auto")

In [None]:
# Repeat the framing by hand with the generic `util.get_overlapping_frames`.
reload(util)
nf = 24
winlen_samples = int(3600*sr)
taper_samples = int(600*sr)
# Samples needed for nf windows plus one taper at each end.
proclen_samples = int(nf * winlen_samples + 2*taper_samples)

# Cut out the needed data
x = tr.slice(startdate-taper_samples/sr).data[:proclen_samples]

# NOTE(review): the third argument here is `taper_samples`. If it is the shift
# between frames (as in the other calls of this notebook), frames advance by
# one taper length rather than by `winlen_samples` -- confirm this is intended.
f = util.get_overlapping_frames(x, int(winlen_samples + 2*taper_samples), 
                                     taper_samples)

print(f.shape)

plt.imshow(f, aspect="auto")

In [None]:
f[-1,:]

In [None]:
f[0, 6:6+36]

In [None]:
x2 = np.append(x, np.zeros(10))

In [None]:
reload(util)

In [None]:
f1 = util.get_overlapping_frames(x2, 18, 36 )

In [None]:
f1.shape

In [None]:
plt.imshow(f1, aspect="auto")

In [None]:
plt.plot(f1[2,:])

In [None]:
x

### Simple test of get_overlapping_frames()

In [None]:
x = np.arange(1,5+1).repeat(3)
print(x.shape)
print(x)
plt.plot(x)

In [None]:
reload(util)

f = util.get_overlapping_frames(x, 3, 3)
print(f.shape)
print(f)
plt.imshow(f, aspect="auto")

In [None]:
nf, fs = f.shape
print((nf-1)*3+fs, x.size)

In [None]:
reload(util)

f = util.get_overlapping_frames(x, 5, 1)
print(f.shape)
print(f)
plt.imshow(f, aspect="auto")

In [None]:
nf, fs = f.shape
print((nf-1)*1+fs, x.size)

In [None]:
reload(util)

f = util.get_overlapping_frames(x, 4, 3)
print(f.shape)
print(f)
plt.imshow(f, aspect="auto")

In [None]:
nf, fs = f.shape
print((nf-1)*3+fs, x.size)