In [1]:
# Picker
# Pick seismic phase arrivals (Pn, Pg, Sn, Lg) in waveform windows and pass the arrivals to the assoc_feed app

In [2]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import scipy.stats as stats
import tensorflow as tf
from flask import Flask, request, jsonify
from obspy import UTCDateTime
from scipy.signal import lfilter, butter, decimate, find_peaks
# Window functions live in scipy.signal.windows; recent SciPy releases no
# longer expose them from the top-level scipy.signal namespace.
from scipy.signal.windows import hann

In [3]:
# Import configuration parameters for the pipeline
import pconf

In [4]:
%matplotlib inline
#%matplotlib widget

In [5]:
# Application-level state shared by the cells below
app = Flask(__name__)
app_name = 'picker'    # this service: used for its route and as its key in pconf.apps
napp = 'assoc_feed'    # downstream service addressed by forward()

# The model file name carries 'key:value' pairs joined by '|'. The last three
# characters of the name (the '.h5' suffix in the output shown further down) are
# dropped before parsing. Every value is kept as a string.
model_par = dict(
    (fields[0], fields[1])
    for fields in (pair.split(':') for pair in pconf.picker_model[:-3].split('|'))
)

In [6]:
# Enable on-demand memory growth for every GPU TensorFlow can see.
# With no GPU present the loop body simply never runs.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Report the problem and keep going rather than aborting the notebook.
    print(err)

In [7]:
def get_custom_loss(trim):
    """Build a mean-squared-error loss that ignores the edges of each sequence.

    Parameters
    ----------
    trim : int or str
        Number of samples dropped from both the start and the end of axis 1
        before the error is computed. A numeric string is accepted because
        the value may come from the model file name, where it is text.

    Returns
    -------
    callable
        ``inner_mse(y_true, y_pred)`` returning the mean of the squared error
        over the untrimmed part of the inputs.

    Raises
    ------
    ValueError
        If ``trim`` cannot be converted to an integer.
    """
    trim = int(trim)
    # A stop index of -0 selects nothing, so keep the whole axis when trim is 0.
    kept = slice(trim, -trim) if trim else slice(None)

    def inner_mse(y_true, y_pred):
        error = y_true[:, kept] - y_pred[:, kept]
        return tf.reduce_mean(tf.square(error))

    return inner_mse

In [8]:
def load_my_model(name, location, trim):
    """Load a saved Keras model together with its custom trimmed-MSE loss.

    Parameters
    ----------
    name : str
        File name of the saved model.
    location : str
        Directory that contains the model file.
    trim : int
        Forwarded to ``get_custom_loss``.

    Returns
    -------
    The object returned by ``tf.keras.models.load_model``.
    """
    # The loss is registered under the name 'inner_mse' so that load_model
    # can resolve it (presumably the name stored in the saved model).
    custom_objects = {'inner_mse': get_custom_loss(trim)}
    model_path = os.path.join(location, name)
    return tf.keras.models.load_model(model_path, custom_objects=custom_objects)

In [10]:
# Load the trained picker. Values parsed from the model file name are strings,
# so the trim length is converted to an integer before it reaches the loss,
# where it is used as a slice bound.
picker = load_my_model(pconf.picker_model, pconf.model_folder, int(model_par['left_trim']))
print(pconf.picker_model)
picker.summary()

f:64|k:20|d:2x4x16x256|s:2|bs:100|o_len:1800|w_len:1800|n_phs:4|shift:0|left_trim:500|right_trim:1000|r_smp:20|f_low:0.5|f_hig:8|c_len:60|c_buf:0|c_shp:gauss|c_amp:1|noise:0.1|mixed:0.1|lr:0.0005|pat:10|time:1627374037.h5
Model: "encoder_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, None, 3)]    0                                            
__________________________________________________________________________________________________
conv1d_72 (Conv1D)              (None, None, 64)     3904        input_7[0][0]                    
__________________________________________________________________________________________________
activation_36 (Activation)      (None, None, 64)     0           conv1d_72[0][0]                  
______________________________________________________________

In [11]:
# Preprocessing helpers: normalization, band-pass filtering and tapering
def DAT_normalize(X):
    """Demean along axis 1 and scale every window to a peak amplitude of 1.

    Parameters
    ----------
    X : array-like with at least three dimensions
        The mean is taken along axis 1; the peak absolute value is taken over
        axes 1 and 2 for each entry of axis 0.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``X``. A window that is constant along
        axis 1 comes back as all zeros instead of NaN.
    """
    X = np.asarray(X)
    X = X - X.mean(axis=1, keepdims=True)
    peak = np.abs(X).max(axis=(1, 2), keepdims=True)
    # After demeaning, a flat window has a peak of 0; dividing by 1 instead
    # avoids 0/0 and leaves the window as zeros.
    peak[peak == 0] = 1
    return X / peak

def butter_bandpass(lowcut, highcut, fs, order=8):
    """Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Band edges, in the same unit as ``fs``.
    fs : float
        Sampling rate.
    order : int, optional
        Order passed to ``scipy.signal.butter``.

    Returns
    -------
    tuple
        The ``(b, a)`` coefficient arrays returned by ``butter``.
    """
    nyquist = 0.5 * fs
    # butter() is given the band edges as fractions of the Nyquist frequency.
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band_edges, btype='band')
    return numerator, denominator

def DAT_filter(X, pdict, order=3):
    """Band-pass filter ``X`` along axis 1.

    Parameters
    ----------
    X : array-like
        Data to filter; filtering runs along axis 1.
    pdict : dict
        Supplies the band edges ``'f_low'`` / ``'f_hig'`` and the sampling
        rate ``'r_smp'``.
    order : int, optional
        Filter order handed to ``butter_bandpass``.

    Returns
    -------
    numpy.ndarray
        Output of ``scipy.signal.lfilter``.
    """
    numerator, denominator = butter_bandpass(
        pdict['f_low'], pdict['f_hig'], pdict['r_smp'], order=order
    )
    return lfilter(numerator, denominator, X, axis=1)

def DAT_taper(X, taper_percentage=.05):
    """Fade in the start of every trace with the rising half of a Hann window.

    Only the leading ``taper_percentage`` of axis 1 is tapered; the rest of
    the trace, including its end, is multiplied by 1.

    Parameters
    ----------
    X : numpy.ndarray, 3-D
        The taper is applied along axis 1 and broadcast over axes 0 and 2.
    taper_percentage : float, optional
        Fraction of axis 1 covered by the fade-in.

    Returns
    -------
    numpy.ndarray
        Tapered copy of ``X``.
    """
    npts = X.shape[1]
    taper_len = int(npts * taper_percentage)
    # np.hanning yields the same symmetric Hann window as scipy.signal.hann,
    # which recent SciPy releases no longer provide under that name.
    rising_edge = np.hanning(2 * taper_len + 1)[:taper_len]
    taper = np.concatenate((rising_edge, np.ones(npts - taper_len)))
    return X * taper.reshape(1, -1, 1)

In [12]:
def showme(wave, picks, pdict, sp=0):
    """Plot one window: waveforms, model output, smoothed output and arrivals.

    Five stacked panels are drawn: the input waveform, a 2-8 Hz and a
    0.5-2 Hz band-passed copy, the raw model output and the output of
    ``clean_y``. The margins outside the trusted range are shaded and every
    arrival returned by ``get_arrivals`` is marked with a vertical line.

    Parameters
    ----------
    wave : numpy.ndarray
        Batch of windows; ``wave[sp]`` is plotted as (samples, channels).
    picks : numpy.ndarray
        Model output for the same batch; ``picks[sp]`` is plotted.
    pdict : dict
        Reads ``'metadata'`` (per-window ``'sta'`` and ``'start'``),
        ``'window_size'``, ``'trim'`` and ``'r_smp'``, and is passed on to
        ``clean_y`` and ``get_arrivals``.
    sp : int, optional
        Index of the window to plot.
    """
    fs = pdict['r_smp']
    trace = np.asarray(wave[sp])
    # Keep a leading batch axis so that only this window is smoothed and
    # picked, instead of the whole batch on every call.
    picks_sp = np.asarray(picks[sp])[np.newaxis]

    # wave[sp] is (samples, channels), so time runs along axis 0.
    lp_b, lp_a = butter_bandpass(.5, 2, fs, order=8)
    lp_wave = lfilter(lp_b, lp_a, trace, axis=0)

    hp_b, hp_a = butter_bandpass(2, 8, fs, order=8)
    hp_wave = lfilter(hp_b, hp_a, trace, axis=0)

    sta = pdict['metadata'][sp]['sta']
    start = UTCDateTime(pdict['metadata'][sp]['start'])
    end = start + pdict['window_size']

    fig, axs = plt.subplots(5, figsize=(15,12), sharex=True)
    fig.suptitle(f'{sta}: {start} - {end}')
    panels = [
        (trace, 'Y - bandpass filtered 0.5-8 Hz'),
        (hp_wave, 'hp_Y - bandpass filtered 2-8 Hz'),
        (lp_wave, 'lp_Y - bandpass filtered 0.5-2 Hz'),
        (picks[sp], 'Y_prime - raw model output'),
        (clean_y(picks_sp, pdict)[0], '~Y_prime - correlated output'),
    ]
    for ax, (data, title) in zip(axs, panels):
        ax.plot(data)
        ax.title.set_text(title)

    # Shade the margins at both ends that lie outside the trusted range.
    margin = pdict['trim'] * fs
    n_samples = len(trace)
    for ax in axs:
        ax.axvspan(0, margin, facecolor='k', alpha = 0.2)
        ax.axvspan(n_samples - margin, n_samples, facecolor='k', alpha = 0.2)

    arrivals = get_arrivals(picks_sp, pdict)[0]
    for ph_name, key in zip(['Pn', 'Pg', 'Sn', 'Lg'], ['pn', 'pg', 'sn', 'lg']):
        for arrival in arrivals[key]:
            for i, ax in enumerate(axs):
                ax.axvline(x=arrival, c='k')
                # Phase labels go on the three waveform panels only.
                if i <= 2:
                    ax.text(arrival, 0, ph_name)

    axs[3].set_ylim((0,1))
    axs[4].set_ylim((0,1))

    fig.tight_layout()
    plt.show()

In [13]:
def clean_y(characteristic_function, pdict):
    """Smooth the model output by correlating it with a Gaussian template.

    Parameters
    ----------
    characteristic_function : numpy.ndarray, shape (m, n, >=4)
        Model output; the first four entries of the last axis are processed.
    pdict : dict
        ``'c_len'`` and ``'r_smp'`` set the template length
        (``c_len * int(r_smp)`` samples).

    Returns
    -------
    numpy.ndarray
        Correlated output with one (samples, 4) block per window.
    """
    # Rebuild the template: a normal pdf sampled on [-3, 3], rescaled to [0, 1].
    n_template = pdict['c_len'] * int(pdict['r_smp'])
    template = stats.norm.pdf(np.linspace(-3, 3, n_template))
    shifted = template - template.min()
    template = shifted / shifted.max()

    # Correlating the template with itself peaks at the sum of its squares;
    # dividing by that value normalizes the correlated output.
    scale = np.sum(template ** 2)

    def smooth(trace):
        return np.correlate(trace, template, mode='same') / scale

    smoothed = []
    for p in range(characteristic_function.shape[0]):
        channels = [smooth(characteristic_function[p, :, ch]) for ch in range(4)]
        smoothed.append(np.column_stack(channels))
    return np.array(smoothed)

In [14]:
def get_arrivals(characteristic_function, pdict):
    """Pick arrivals as peaks of the smoothed model output.

    Parameters
    ----------
    characteristic_function : numpy.ndarray, shape (m, n, 4)
        Model output for ``m`` windows.
    pdict : dict
        ``'trim'`` times ``'r_smp'`` gives the number of samples ignored at
        each end; ``'dt'`` is the threshold below which values are zeroed.
        Also passed on to ``clean_y``.

    Returns
    -------
    numpy.ndarray
        One dict per window with keys ``'pn'``, ``'pg'``, ``'sn'`` and
        ``'lg'``, each holding the peak sample indices for that phase.
    """
    # Only samples between the two margins are searched for peaks.
    low_trusted_limit = pdict['trim'] * pdict['r_smp']
    high_trusted_limit = characteristic_function.shape[1] - low_trusted_limit

    # Smooth, then zero out everything below the threshold.
    smoothed = clean_y(characteristic_function, pdict)
    smoothed[smoothed < pdict['dt']] = 0

    trusted = smoothed[:, low_trusted_limit:high_trusted_limit, :]
    phase_names = ('pn', 'pg', 'sn', 'lg')

    # find_peaks indexes relative to the trusted slice, so the lower limit is
    # added back to express the picks in samples from the window start.
    arrivals = [
        {
            name: find_peaks(window[:, ch])[0] + low_trusted_limit
            for ch, name in enumerate(phase_names)
        }
        for window in trusted
    ]
    return np.array(arrivals)

In [15]:
def _next_arid(catalog_path):
    """Return the id for the next catalog row, or 0 when none can be derived."""
    try:
        last_row = None
        with open(catalog_path, "r") as cat:
            for line in cat:
                # Skip blank lines so a trailing newline does not reset the ids.
                if line.strip():
                    last_row = line
        if last_row is None:
            return 0
        return int(last_row.split(',')[0]) + 1
    except (OSError, ValueError):
        # Missing or unreadable catalog, or a last row whose first field is
        # not an integer: start numbering from 0.
        return 0


def build_catalog(characteristic_function, pdict):
    """Append the picked arrivals to the catalog file named by ``pdict['cat']``.

    One comma-separated row is written per arrival: arrival id, phase key,
    arrival time, station, station latitude, station longitude. Ids continue
    from the last row already in the file.

    Parameters
    ----------
    characteristic_function : numpy.ndarray
        Model output, passed to ``get_arrivals``.
    pdict : dict
        Reads ``'cat'``, ``'r_smp'`` and ``'metadata'`` (per-window
        ``'start'``, ``'sta'``, ``'st_lat'``, ``'st_lon'``); also passed to
        ``get_arrivals``.

    Returns
    -------
    bool
        Always ``True``.
    """
    arrivals = get_arrivals(characteristic_function, pdict)
    arid = _next_arid(pdict['cat'])

    with open(pdict['cat'], "a") as cat:
        for i, window_arrivals in enumerate(arrivals):
            for phase, indices in window_arrivals.items():
                for idx in indices:
                    meta = pdict['metadata'][i]
                    # NOTE(review): the sample index is shifted by one before
                    # conversion to seconds - confirm this offset is intended.
                    arrival_time = UTCDateTime(meta['start']) + ((idx+1)/pdict['r_smp'])
                    cat.write(f"{arid},{phase},{str(arrival_time)},{meta['sta']},{meta['st_lat']},{meta['st_lon']}\n")
                    arid += 1
    return True

In [16]:
# Health check: lets a caller confirm that the service is up
@app.route('/')
def apitest():
    """Return a short text message naming this app."""
    message = f'{app_name} is working'
    return message

In [17]:
# Endpoint that processes the windows posted to this app
@app.route(f'/{app_name}', methods=['POST'])
def process_request():
    """Run the picker on the posted windows, plot them and update the catalog.

    The JSON body must be an object with an ``'X'`` key holding the input
    windows. ``'X'`` is removed from the payload so it is not sent back, and
    the model output is attached under ``'picks'``.

    Returns
    -------
    flask.Response
        The payload as JSON, or a 400 response with an ``'error'`` message
        when the body is not a JSON object or has no ``'X'`` key.
    """
    pdict = request.get_json(silent=True)
    if not isinstance(pdict, dict) or 'X' not in pdict:
        return jsonify({'error': "request body must be a JSON object with an 'X' key"}), 400

    # Pop rather than read, so the input data does not travel back in the response.
    X = np.array(pdict.pop('X'))
    # Evaluate the model on the posted windows
    characteristic_function = picker.predict(X)
    # Attach the model output to the payload that is sent back
    pdict['picks'] = characteristic_function.tolist()
    for w in range(X.shape[0]):
        showme(X, characteristic_function, pdict, w)
    build_catalog(characteristic_function, pdict)
    # jsonify builds the JSON response explicitly instead of relying on
    # Flask converting a returned dict.
    return jsonify(pdict)

In [18]:
def forward(json_data, timeout=None):
    """POST ``json_data`` to the downstream app named by ``napp``.

    Parameters
    ----------
    json_data : JSON-serializable object
        Sent as the request body.
    timeout : float or tuple, optional
        Passed to ``requests.post``. The default ``None`` waits indefinitely,
        as before; pass a number of seconds to fail fast when the downstream
        app does not answer.

    Returns
    -------
    requests.Response
        The response of the downstream app.
    """
    url = f'http://{pconf.host}:{pconf.apps[napp]}/{napp}'
    return requests.post(url, json=json_data, headers=pconf.head, timeout=timeout)

In [None]:
# Start the Flask server on the host and port configured for this app
if __name__ == '__main__':
    app.run(host=pconf.host, port=pconf.apps[app_name], debug=False)

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:6003/ (Press CTRL+C to quit)
