# Processing DAS data using Multiprocessing: Seven Trees Aftershock
## SEP: June  2023
### Thomas Cullison, 1st Year Geophysics 

<br><br>

In [None]:
import io
import datetime
import h5py
import multiprocess # <-- Look! (extension of multiprocessing package; confusing? indeed.)

import numpy as np

from scipy import signal
from google.cloud import storage
from time import time
from os import cpu_count
from ctypes import c_float, c_int64  #used for delcaring shared memory

# Wall-clock start of the whole run; subtracted from time() again in the
# "Total Processing Time" cell to report the overall runtime.
start_time = time() #start timing of processing

## Get Data from the Cloud: Function Defs

In [None]:
def gcs_download_to_local_disk(bucket,lpath,fname,i):
    """
       Download one per-minute .npz file from a Google Cloud Storage bucket
       to local disk.

       bucket: name of the GCS bucket
       lpath:  local destination directory prefix (concatenated as-is, so it
               must end with a path separator)
       fname:  leading prefix shared by the blob name and the local file name
       i:      zero-based file index; the file suffix is '_minNN.npz' with
               NN = i+1, zero-padded to two digits

       Raises FileNotFoundError if the blob does not exist in the bucket.
    """
    blobname = fname + '_min' + str(i+1).zfill(2) + '.npz'
    # Use the lpath argument (not a notebook-level global) so the function
    # works no matter which variables happen to exist in the caller.
    dfname = lpath + blobname

    client = storage.Client()
    blob = client.get_bucket(bucket).get_blob(blobname)
    if blob is None:
        # get_blob() returns None for a missing object; fail with a clear message
        raise FileNotFoundError(f'blob not found in bucket {bucket!r}: {blobname}')

    print(f'downloading to: {dfname}\n')
    blob.download_to_filename(dfname)
    
    
def parallel_load_npz(path,fname,i):
    """
       Load one per-minute DAS .npz file from local disk into the shared arrays.

       path:  local directory prefix (concatenated as-is, so it must end with
              a path separator)
       fname: leading prefix of the file name
       i:     zero-based file index; the file read is
              path + fname + '_minNN.npz' with NN = i+1, zero-padded

       Nothing is returned: the 'data' array is written to s_rdata[:,i,:] and
       the 'time' array to s_tdata[i,:].
    """
    global s_rdata
    global s_tdata

    dfname = path+fname + '_min' + str(i+1).zfill(2) + '.npz'
    print(f'reading: {dfname}\n')
    # The context manager closes the underlying archive once both arrays
    # have been copied out.
    with np.load(dfname) as dt_data:
        s_rdata[:,i,:] = dt_data['data'][:,:]
        s_tdata[i,:] = dt_data['time'][:]

## Setup List of Files to Read: (a priori Knowledge Req.)

In [None]:
# Name of the cloud-storage bucket the files are fetched from
buckname = 'sep-allow-others' #kind of like the head/main directory -- Leave ALONE

# Number of files ('_min01' ... '_minNN') to download and read; it also sizes
# the shared buffers and the download/read process pools.
nfiles = 10 # Leave this ALONE

## Start Dask Distributed Cluster: (10 Threads at Most, One-per-file)

In [None]:
# os.cpu_count() returns None when the count cannot be determined; fall back
# to one core so the later chunk-size arithmetic (nchan//ncore etc.) and the
# process pools still work.
ncore = cpu_count() or 1 #Only this part is needed for the process pools (later)
print(f'ncore: {ncore}')

## Read All Files to Array -- Map to Threads : (Memory in Processes)

In [None]:
%%time

nchan = 48_000 #fancy commas --> underscore for python
ntsmp = 12_000

# Declare buffers for shared memory between processes
buf_rdata = multiprocess.RawArray(c_float,nchan*ntsmp*nfiles)
buf_tdata = multiprocess.RawArray(c_int64,ntsmp*nfiles)

# "Wrap" above into ndarrays
s_rdata = np.frombuffer(buf_rdata,dtype=np.float32).reshape((nchan, nfiles, ntsmp))
s_tdata = np.frombuffer(buf_tdata,dtype=np.int64).reshape((nfiles,ntsmp))

### Download files to local disk

In [None]:
%%time

path = './data/test/' # put YOUR local path here
fname = 'oct_7trees_aftershock_das' # leading prefix. leave this ALONE 

# setup a list of tuples as args to mapped function
blist = [buckname for i in range(nfiles)] # <-- Note "nfiles" not "ncore"
lpath = [path for i in range(nfiles)] 
lfname = [fname for i in range(nfiles)] 
lidxs = [i for i in range(nfiles)] 
pargs = list(zip(blist,lpath,lfname,lidxs))

# Parallel part
with multiprocess.Pool(processes=nfiles) as pool:
    pool.starmap(gcs_download_to_local_disk, pargs) 
    

### Read compressed numpy arrays into memory

In [None]:
%%time 

# lpath:  Same as above
# lfname: Same as above
# lidxs:  Same as above
# One argument tuple per file: (local path, file prefix, file index)
pargs = list(zip(lpath,lfname,lidxs))

# Parallel part: one worker process per file
with multiprocess.Pool(processes=nfiles) as pool:
    pool.starmap(parallel_load_npz, pargs) # Nothing is returned; each worker writes into s_rdata/s_tdata

## Gather Arrays to Notebook

## Clean-up Cluster Memory

## Scale-down Cluster to One Thread: (Not the Same as Notebook Thread)

## Concatenate Arrays Over Time Axis: (Notebook Thread)

**We only need to "reshape" the arrays here. Note that this works only because we declared the correct shape and used the correct indices above, when we loaded the data from each file into the shared arrays.**

In [None]:
%%time

# Declare a second set of shared-memory buffers, the same size as the input
# buffers; cc_rdata (aliased as out_tr below) receives the bandpass output.
tbuf_rdata = multiprocess.RawArray(c_float,nchan*ntsmp*nfiles)
tbuf_tdata = multiprocess.RawArray(c_int64,ntsmp*nfiles)
# NOTE(review): test_buf is as large as the full data set but is not referenced
# anywhere else in the visible code -- confirm it is still needed.
test_buf = multiprocess.RawArray(c_float,nchan*ntsmp*nfiles)

# "Wrap" above into ndarrays, already in the concatenated (channel, time) layout
cc_rdata = np.frombuffer(tbuf_rdata,dtype=np.float32).reshape((nchan, nfiles*ntsmp))
cc_tdata = np.frombuffer(tbuf_tdata,dtype=np.int64).reshape((nfiles*ntsmp))


# Reshape: concatenate the per-minute files into one continuous record.
# Because the file axis sits directly before the time axis, re-wrapping the
# same buffer with the two axes merged is all that is required (no copy).
s_rdata = None # drop the old view; the buffer itself is still held by buf_rdata
s_tdata = None # drop the old view; the buffer itself is still held by buf_tdata
s_rdata = np.frombuffer(buf_rdata,dtype=np.float32).reshape((nchan, nfiles*ntsmp))
s_tdata = np.frombuffer(buf_tdata,dtype=np.int64).reshape((nfiles*ntsmp))


# To make it easy to follow, we switch to these variable names (out_tr is only needed once)
# NOTE(review): raw_data is a full copy of the data and is not referenced
# anywhere else in the visible code -- confirm it is still needed.
raw_data = s_rdata.copy()
in_tr = s_rdata   # input of each processing step (view of buf_rdata)
out_tr = cc_rdata # output of the bandpass step (view of tbuf_rdata)

## Begin Processing: Function Defs

**These functions need significant changes for multiprocessing: they return nothing, they take start and end indices, and they operate on global/shared-memory arrays.**

In [None]:
# No JIT'ing!
#@numba.njit(cache=True, fastmath=True, nogil=True, parallel=True)
def remove_median_xchannel(start,end):
    """
       Subtract, in place, the median over axis 0 from each column of the
       global array in_tr, for column indices start (inclusive) to end
       (exclusive). Nothing is returned.
    """
    global in_tr #shared memory
    for col in range(start, end):
        column = in_tr[:, col]       # view into in_tr, not a copy
        column -= np.median(column)  # in-place update of that view


# CANNOT jit ANY of the functions below

def detrend_all_traces(start,end):
    """
       Remove the mean and then the linear trend (along the last axis) from
       rows start (inclusive) to end (exclusive) of the global array in_tr.

       The results are written back into in_tr; nothing is returned.
    """
    global in_tr #shared memory
    rows = in_tr[start:end] # view into the shared array
    # signal.detrend returns its result; assign it back explicitly rather than
    # relying on overwrite_data, which is not guaranteed to modify the input
    # in place (and does not for type='constant').
    rows[...] = signal.detrend(rows, type='constant')
    rows[...] = signal.detrend(rows, type='linear')
        

def bandpass_butter_all_traces(start, end, fs=None, b0=None, bN=None, order=5):
    """
       Bandpass-filter rows start (inclusive) to end (exclusive) of the global
       array in_tr and store the result in the same rows of the global array
       out_tr. Nothing is returned.

       fs:     sampling frequency handed to signal.butter
       b0, bN: lower and upper band edges
       order:  Butterworth filter order

       The filter is applied with signal.sosfiltfilt.
    """
    global in_tr #shared memory
    global out_tr #shared memory
    band_edges = (b0, bN)
    sos = signal.butter(order, band_edges, btype='bandpass', output='sos', fs=fs)
    filtered = signal.sosfiltfilt(sos, in_tr[start:end])
    out_tr[start:end,:] = filtered


#FIXME: maybe this is why there is a difference between all?
def silly_zero_data(d2z,i):
    """
       Set entry i of d2z to zero, in place. Nothing is returned.
    """
    d2z[i] = 0
    

# Swap the two global names that refer to the shared-memory arrays
def swap_views():
    """
       Exchange the global names in_tr and out_tr. Only the names are
       rebound; no array data is copied. Nothing is returned.
    """
    global in_tr
    global out_tr
    in_tr, out_tr = out_tr, in_tr

## Scale-up Cluster for Data Processing: (All Cores)

## Scatter Concatenated Data to All Cores (from Notebook to Cluster)

## Detrend Data Per Channel: Multiple Channels Per Thread (Scheduler Decides)

In [None]:
%%time


# We want a total of "ncore" chuncks (range of channels)
chan_chnk = nchan//ncore

# start and end indices into shared memory
slst = [i for i in range(0,nchan,chan_chnk)]
elst = [i for i in range(chan_chnk,nchan+1,chan_chnk)]
pargs = pargs = list(zip(slst,elst))


# Parallel part
with multiprocess.Pool(processes=ncore) as pool:
    pool.starmap(detrend_all_traces, pargs)
    # output is stored in in_tr shared-array
    

## Bandpass Filter Per Channel: (Same as for Detrend)

In [None]:
%%time

# Bandpass parameters
bl = 0.025
#br = 5.0
#bl = 0.075
br = 1.5
fs = 200

# start and end indices into shared memory
bl_lst = [bl for i in range(ncore)]
br_lst = [br for i in range(ncore)]
fs_lst = [fs for i in range(ncore)]
pargs = pargs = list(zip(slst,elst,fs_lst,bl_lst,br_lst))


# Parallel part
with multiprocess.Pool(processes=ncore) as pool:
    pool.starmap(bandpass_butter_all_traces, pargs)
    # output is stored in out_tr shared-array
    
    
swap_views() # like pointer swapping (for shared arrays)

## Decimate Per Channel: (Slightly Faster than Serial)

In [None]:
%%time 

ss = 30 # subsampling step: keep every ss-th time sample
ohz = 200 # NOTE(review): not used in the visible code; presumably the original sampling rate -- confirm
# Basic slicing gives a strided view of the same array (no copy).
# Re-running this cell would subsample the already-subsampled view again.
in_tr = in_tr[:,::ss] # easy peasy

## Gather Processed Data: (From cluster to Notebook)

## Release Cluster and Scheduler: (and All Related Resources, i.e. Memory, Cores, etc.)

## X-Channel Median Removal Per Time-Sample: (Parallel Over Time Chunks)

In [None]:
%%time


# concatenated and decimated num time samps
# again, find chunck size for "ncore" chunks (but in time)
dntsmp = (ntsmp*nfiles)//ss 
time_chnk = dntsmp//ncore

# start and end indices into shared memory
slst = [i for i in range(0,dntsmp,time_chnk)]
elst = [i for i in range(time_chnk,dntsmp+1,time_chnk)]
pargs = pargs = list(zip(slst,elst))


# Parallel part
with multiprocess.Pool(processes=ncore) as pool:
    pool.starmap(remove_median_xchannel, pargs)
    # output is stored in in_tr shared-array
    

## Total Processing Time

In [None]:
print('Done Processing')
# Elapsed wall-clock time since the first cell, i.e. including the download
# and file-reading steps as well as the processing.
runtime = time() - start_time #start is in first cell
print(f'runtime: {datetime.timedelta(seconds=runtime)}') # timedelta prints as H:MM:SS

## Define function for plotting: (Serial)

In [None]:
def plot_seven_trees_data(data,times,pclip=.95,fig_size=(9,10)):
    """
       Show channels 22,000 through 28,000 (inclusive) of data as a gray-scale image.

       data:     2D array; the first axis is sliced by channel number
       times:    sequence of time stamps; only times[0] is used, integer-divided
                 by 1,000,000 and interpreted as UTC seconds since the epoch
                 (presumably the stamps are microseconds -- TODO confirm)
       pclip:    the color scale is clipped at +/- (1-pclip) times the largest
                 absolute amplitude in the plotted window
       fig_size: figure size handed to matplotlib

       Returns the tuple (pyplot module, current figure, start datetime).
       The x-axis extent is nfiles*60, with nfiles read from the notebook globals.
    """

    import matplotlib.pyplot as plt

    # Start time of the record, used for the title and the x-axis label
    eqdate = datetime.datetime.utcfromtimestamp(times[0]//1000000)

    # Channel window to display (a narrower window, 23_225 to 23_300, was
    # tried earlier -- "spool?")
    first_chan = 22_000
    last_chan = 28_000

    # Image extent as (left, right, bottom, top); bottom/top are flipped so
    # that channel numbers increase downwards
    extent = (0,nfiles*60,last_chan,first_chan)

    window = data[first_chan:last_chan+1,:]
    vclip = (1-pclip)*np.abs(window).max()

    plt.figure(figsize=fig_size)
    plt.imshow(
        window,
        aspect='auto',
        interpolation='none',
        cmap='gray',
        vmin=-vclip,
        vmax=vclip,
        extent=extent,
    )
    plt.title('Recorded: ' + str(eqdate))
    plt.xlabel('seconds from: ' + str(eqdate.time()))
    plt.ylabel('channel')

    return plt,plt.gcf(),eqdate

## 2D Plot of the Processed DAS Data: (Serial)

In [None]:
%matplotlib inline
#%%time

# Clip fraction handed to the plotting function: the color scale is limited to
# +/- (1-pclip) times the largest absolute amplitude in the plotted window.
pclip = .99
#pclip = .0
print(f'pclip: {pclip}')

# in_tr holds the fully processed data; only the first entry of s_tdata is
# used by the plotting function (for the start time shown in the labels).
plt, fig, eqdate = plot_seven_trees_data(in_tr,s_tdata,pclip=pclip)
plt.show()