# Processing DAS Data Using Dask: Seven Trees Aftershock
## SEP: June 2023
### Thomas Cullison, 1st Year Geophysics 

<br><br>

In [None]:
import io
import datetime
import h5py
import numba
#import ipycytoscape #work in progress "delayed function graph"

import numpy as np
#import dask.dataframe as dd  #work in progress
#import dask.array as da      #work in progress

from scipy import signal
from google.cloud import storage
from dask.distributed import Client, wait
from time import time
from os import cpu_count

start_time = time()  # wall-clock start; the "Total Processing Time" cell subtracts this

## Get Data from the Cloud: Function Defs

In [None]:
def gcs_download_to_local_disk(pargs,bucket=None,lpath=None):
    """
    Download one per-minute ``.npz`` file from a GCS bucket to local disk.

    Parameters
    ----------
    pargs : tuple (str, int)
        ``(fname, i)``: the file-name prefix and the zero-based file index.
        The blob read is ``fname + '_min' + <i+1 as two digits> + '.npz'``.
    bucket : str
        Name of the Google Cloud Storage bucket to read from.
    lpath : str, optional
        Local directory to write into (trailing separator optional).
        ``None`` means the current working directory. The directory is
        created if it does not exist.

    Returns
    -------
    str
        The local file name the blob was written to.

    Raises
    ------
    FileNotFoundError
        If the blob does not exist in the bucket.
    """
    import os  # local import: this function is shipped to Dask worker processes

    fname, i = pargs
    blobname = fname + '_min' + str(i+1).zfill(2) + '.npz'

    # Write into the caller-supplied directory, not a notebook global.
    local_dir = '' if lpath is None else lpath
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
    dfname = os.path.join(local_dir, blobname)

    # Named `gcs` so it is not confused with the notebook's Dask `client`.
    gcs = storage.Client()
    blob = gcs.get_bucket(bucket).get_blob(blobname)
    if blob is None:  # get_blob returns None for a missing object
        raise FileNotFoundError(f'gs://{bucket}/{blobname} does not exist')
    blob.download_to_filename(dfname)
    return dfname
    
    
def parallel_load_npz(pargs,path=None):
    """
    Load one per-minute DAS ``.npz`` file from local disk.

    Parameters
    ----------
    pargs : tuple (str, int)
        ``(fname, i)``: file-name prefix and zero-based index. The file read is
        ``fname + '_min' + <i+1 as two digits> + '.npz'`` inside ``path``.
    path : str, optional
        Directory containing the files (trailing separator optional).
        ``None`` means the current working directory.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, time)``: the arrays stored under the ``'data'`` and ``'time'``
        keys of the archive.
    """
    import os  # local import: this function is shipped to Dask worker processes

    fname, i = pargs
    dfname = os.path.join('' if path is None else path,
                          fname + '_min' + str(i+1).zfill(2) + '.npz')

    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once both members have been read into memory.
    with np.load(dfname) as dt_data:
        return dt_data['data'], dt_data['time']

## Set Up the List of Files to Read (Requires a Priori Knowledge)

In [None]:
# Name of the GCS bucket (plays the role of the top-level directory) -- leave alone.
buckname = 'sep-allow-others'

# Number of per-minute files (suffixes _min01 ... _min10) to fetch -- leave alone.
nfiles = 10

## Start a Dask Distributed Cluster (At Most 10 Single-Threaded Worker Processes, One per File)

In [None]:
# os.cpu_count() returns None when the core count cannot be determined; fall
# back to one core so min() below cannot raise a TypeError.
ncore = cpu_count() or 1
# One worker per file, but never more workers than cores.
nwork = min(nfiles, ncore)

# processes=True with threads_per_worker=1: every worker is a separate
# single-threaded process (presumably so the CPU-bound per-channel scipy work
# is not serialized by a single process's GIL).
client = Client(n_workers=nwork, processes=True, threads_per_worker=1)

# show dash board link
client

## Read All Files into Arrays -- Mapped to Workers (Data Stays in Cluster Memory)

### Download files to local disk

In [None]:
%%time

path = './data/test/'  # local download directory -- put YOUR local path here
fname = 'oct_7trees_aftershock_das'  # shared blob-name prefix -- leave this ALONE

# One (prefix, index) argument tuple per file; index i selects the "_min{i+1:02d}" file.
lidxs = list(range(nfiles))
lfname = [fname] * nfiles
pargs = list(zip(lfname, lidxs))

# Each worker downloads one file; wait() blocks until every download has finished.
arrs = client.map(gcs_download_to_local_disk, pargs,
                  bucket=buckname, lpath=path, pure=False)
junk_futures = wait(arrs)

### Read compressed numpy arrays into memory

In [None]:
%%time 

# Reuse `pargs` from the download cell: one (prefix, index) tuple per file.
# Each worker loads one .npz and keeps its (data, time) pair in cluster memory
# until it is gathered in the next cell.
arrs = client.map(parallel_load_npz, pargs, path=path, pure=False)
junk_futures = wait(arrs)

## Gather Arrays to Notebook

In [None]:
%%time

# Fetch every (data, time) tuple straight from the workers (direct=True skips
# routing through the scheduler), then display which worker holds which key.
gathered_data = client.gather(arrs, direct=True)
client.who_has()

## Clean-up Cluster Memory

In [None]:
%%time

# Free the loaded (data, time) arrays held on the workers now that the notebook
# has its own copy. Client.cancel accepts a whole list of futures, so a single
# call releases all of them.
client.cancel(arrs)

## Scale Down the Cluster to One Worker (Separate from the Notebook Process)

In [None]:
%%time

client.cluster.scale(1)  # shrink to one worker; the concatenation below runs in the notebook
client.who_has()  # display which keys (if any) are still held on the workers

## Concatenate Arrays Over Time Axis: (Notebook Thread)

In [None]:
%%time

# Transpose the per-file (data, time) tuples into two parallel lists.
tup_list = [list(column) for column in zip(*gathered_data)]
rdlist, rtlist = tup_list[0], tup_list[1]

In [None]:
%%time

# Stitch the per-file blocks together along axis 1 (the time axis) and join
# the per-file time vectors end to end.
rdata = np.concatenate(rdlist, axis=1)
tdata = np.concatenate(rtlist)

# Drop every notebook-side reference to the per-file arrays so their memory can
# actually be reclaimed. rdlist/rtlist are the same list objects as
# tup_list[0]/tup_list[1], and gathered_data still holds the original
# (data, time) tuples, so deleting only rdlist/rtlist would free nothing.
del rdlist, rtlist, tup_list, gathered_data

## Begin Processing: Function Defs

In [None]:
@numba.njit(cache=True, fastmath=True, nogil=True, parallel=True)
def remove_median_xchannel(orig_tr):
    """
    Subtract from each column of a 2-D array the median of that column.

    The array is indexed ``orig_tr[:, j]``. Per the notebook headings, rows
    are channels and columns are time samples, so this removes the
    cross-channel median at every time sample. The input is not modified; a
    corrected copy is returned. Columns are processed in parallel via
    ``numba.prange``.
    """
    ntime = orig_tr.shape[-1]
    demedianed = orig_tr.copy()
    for j in numba.prange(ntime):
        column_median = np.median(orig_tr[:, j])
        demedianed[:, j] -= column_median
    return demedianed


# CANNOT jit ANY of the functions below

def detrend_all_traces(orig_tr):
    """
    Remove the mean and then the best-fit line from every row of a 2-D array.

    Row ``i`` of the result equals
    ``signal.detrend(signal.detrend(orig_tr[i], type='constant'), type='linear')``.
    Rows are processed one at a time, so only row-sized temporaries are
    created.

    Parameters
    ----------
    orig_tr : numpy.ndarray
        2-D array; each row is detrended independently. Not modified.

    Returns
    -------
    numpy.ndarray
        New array of the same shape. float32/float64/complex64/complex128
        inputs keep their dtype. Any other dtype is promoted to float64, as
        ``scipy.signal.detrend`` does.
    """
    orig = np.asarray(orig_tr)
    # scipy.signal.detrend always returns a new array for type='constant', and
    # edits in place for type='linear' only when the data are already floating
    # point, so the returned values must always be used.
    out_dtype = orig.dtype if orig.dtype.char in 'dfDF' else np.float64
    det_traces = np.empty(orig.shape, dtype=out_dtype)
    for i, trace in enumerate(orig):
        demeaned = signal.detrend(trace, type='constant')
        # `demeaned` is a private temporary, so the linear step may reuse it.
        det_traces[i] = signal.detrend(demeaned, type='linear', overwrite_data=True)
    return det_traces


def detrend_single_trace(orig_tr):
    """
    Return a copy of one trace with its mean and then its linear trend removed.

    The constant (mean) detrend is applied first, and the linear detrend is
    applied to its result. ``orig_tr`` itself is left untouched.
    """
    return signal.detrend(
        signal.detrend(orig_tr, type='constant'),
        type='linear',
    )


    
def bandpass_butter_single_trace(trace, fs=None, b0=None, bN=None, order=5):
    """
    Zero-phase Butterworth band-pass of a single trace.

    Designs an ``order``-th order Butterworth band-pass with corners ``b0`` and
    ``bN`` (in the same units as the sampling rate ``fs``) as second-order
    sections. It is then applied forward and backward with
    ``scipy.signal.sosfiltfilt``, so the output has no phase shift.
    """
    corners = (b0, bN)
    sections = signal.butter(order, corners, 'bandpass', fs=fs, output='sos')
    return signal.sosfiltfilt(sections, trace)



def bandpass_butter_all_traces(orig_tr, fs=None, b0=None, bN=None, order=5):
    """
    Zero-phase Butterworth band-pass of every row of an array.

    One ``order``-th order Butterworth band-pass with corners ``b0`` and ``bN``
    (same units as the sampling rate ``fs``) is designed as second-order
    sections. It is applied forward and backward to each row with
    ``scipy.signal.sosfiltfilt``.

    Parameters
    ----------
    orig_tr : numpy.ndarray
        Array whose rows (first axis) are filtered independently along the
        last axis. Not modified.

    Returns
    -------
    numpy.ndarray
        Filtered array of the same shape. Floating/complex inputs keep their
        dtype; any other dtype yields float64 so filtered values are not
        truncated.
    """
    orig = np.asarray(orig_tr)
    # Filter output is fractional; an integer output buffer would silently
    # truncate it, so non-floating input gets a float64 result.
    out_dtype = orig.dtype if np.issubdtype(orig.dtype, np.inexact) else np.float64
    sos = signal.butter(order, (b0, bN), 'bandpass', fs=fs, output='sos')
    bp_traces = np.empty(orig.shape, dtype=out_dtype)
    # Filter row by row so only one row-sized temporary exists at a time.
    for i, trace in enumerate(orig):
        bp_traces[i] = signal.sosfiltfilt(sos, trace)
    return bp_traces


def silly_decimate_single_trace(orig_tr,q=2):
    """
    Keep every ``q``-th sample of a trace (naive decimation).

    No anti-alias filtering is done here; the input must already be
    band-limited below the new Nyquist frequency.

    Returns a compact copy rather than a strided view. A view would keep the
    entire full-rate trace alive for as long as the decimated result exists,
    which defeats the memory saving of decimating.
    """
    return orig_tr[::q].copy()

## Scale-up Cluster for Data Processing: (All Cores)

In [None]:
%%time
client.cluster.scale(ncore)  # one single-threaded worker per CPU core for the per-channel processing

## Scatter Concatenated Data to All Cores (from Notebook to Cluster)

In [None]:
%%time
# Scatter each row of rdata (one channel, per the section headings) as its own
# future so the per-channel steps below can map over them.
future = client.scatter([channel for channel in rdata])
junk = wait(future)

## Detrend Data Per Channel: Multiple Channels Per Thread (Scheduler Decides)

In [None]:
%%time

# Remember the raw dtype for the final gather before the array is dropped.
rdata_dtype = rdata.dtype

det_data = client.map(detrend_single_trace, future, pure=False)
jink = wait(det_data)  # block until every channel is detrended; result unused

del rdata  # free the notebook-side copy; the workers hold the scattered channels

## Bandpass Filter Per Channel: (Same as for Detrend)

In [None]:
%%time

# Band-pass corners and sampling rate, all in the same units (presumably Hz).
bl = 0.025  # low corner
br = 5.0    # high corner
fs = 200    # sampling rate

bp_data = client.map(
    bandpass_butter_single_trace, det_data,
    fs=fs, b0=bl, bN=br, pure=False,
)
junk = wait(bp_data)

## Decimate Per Channel: (Slightly Faster than Serial)

In [None]:
%%time 

ss = 4  # decimation factor: keep every 4th sample
# The band-pass above (high corner br = 5.0 with fs = 200) keeps content well
# below the decimated Nyquist of fs / (2 * ss) = 25.
bp_data = client.map(
    silly_decimate_single_trace, bp_data, q=ss, pure=False,
)
junk = wait(bp_data)

## Gather Processed Data: (From cluster to Notebook)

In [None]:
%%time

# Gather the decimated channels straight from the workers and stack them into
# one 2-D array in a single expression (measured slightly faster than splitting
# it over two statements). The result is cast back to the raw data's dtype.
# NOTE(review): if the raw data are integers, this cast truncates the filtered
# values -- confirm the raw dtype is floating point.
# np.vstack() has ~ the same runtime as np.asarray() here.
dec_bp_data = np.asarray(
    client.gather(bp_data, direct=True),
    dtype=rdata_dtype,
)

## Release Cluster and Scheduler: (and All Related Resources, i.e. Memory, Cores, etc.)

In [None]:
%%time
client.shutdown()  # stop the scheduler and its workers, releasing their cores and memory
client.close()  # also close this notebook's client connection

## X-Channel Median Removal Per Time-Sample: (Numba-Parallel)

In [None]:
%%time

proc_data = remove_median_xchannel(dec_bp_data)  # runs in the notebook process (numba, parallel=True)

del dec_bp_data  # remove_median_xchannel returned a corrected copy; the input is no longer needed

## Total Processing Time

In [None]:
print('Done Processing')
# Elapsed wall-clock seconds since the first cell recorded start_time.
runtime = time() - start_time
print('runtime: ' + str(datetime.timedelta(seconds=runtime)))

## Define function for plotting: (Serial)

In [None]:
def plot_seven_trees_data(data,times,pclip=.95,fig_size=(9,10),
                          start_c=23000,end_c=35000,duration=None):
    """
    Plot a window of channels of DAS data as a gray-scale image.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array indexed ``data[channel, sample]``.
    times : sequence
        Sample timestamps. ``times[0] // 1000000`` is used as POSIX seconds
        (presumably the stamps are in microseconds -- confirm).
    pclip : float
        Color limits are ``+/- (1 - pclip) * max|window|``, so larger values
        clip harder.
    fig_size : tuple
        Figure size passed to ``plt.figure``.
    start_c, end_c : int
        First and last channel (inclusive) to display.
    duration : float, optional
        Extent in seconds of the horizontal axis. ``None`` means
        ``nfiles * 60`` using the notebook-global ``nfiles`` (presumably one
        minute per file).

    Returns
    -------
    module
        ``matplotlib.pyplot``, so the caller can ``.show()`` or ``.savefig()``.
    """
    import matplotlib.pyplot as plt

    if duration is None:
        duration = nfiles*60

    # datetime.utcfromtimestamp is deprecated since Python 3.12; this yields
    # the same naive UTC datetime.
    eqdate = datetime.datetime.fromtimestamp(
        times[0]//1000000, datetime.timezone.utc).replace(tzinfo=None)

    window = data[start_c:end_c+1,:]
    # extent is (left, right, bottom, top); with imshow's default upper origin
    # channel start_c is drawn at the top.
    bounds = (0,duration,end_c,start_c)
    vclip = (1-pclip)*np.abs(window).max()

    plt.figure(figsize=fig_size)
    plt.imshow(window, aspect='auto', interpolation='none', cmap='gray', vmin=-vclip, vmax=vclip, extent=bounds)
    plt.title('DAS for Seven Trees 1st-Aftershock, 3.1 EQ @' + str(eqdate) )
    plt.xlabel('seconds from: ' + str(eqdate.time()))
    plt.ylabel('channel')

    return plt

## 2D Plot of the Processed DAS Data: (Serial)

In [None]:
%%time

pclip = 0.99  # color limits become +/-(1 - pclip) of the window's peak amplitude
print(f'pclip: {pclip}')

plt = plot_seven_trees_data(proc_data, tdata, pclip=pclip)
plt.show()