# Processing DAS data using mpi4py: Seven Trees Aftershock
## SEP: June  2023
### Thomas Cullison, 1st Year Geophysics 

<br><br>

In [None]:
# Host-side (serial thread)
import datetime
import numba
import ipyparallel

import numpy as np

from time import time
from os import cpu_count

start_time = time() #start timing of processing; the "Total Processing Time" cell subtracts this from time()

## Start MPI Local Cluster: (All 30 Threads)

**This is a bit different than for the Serial and DASK notebooks. We need to start the cluster before we use the "parallel" magic.**

In [None]:
# Number of CPUs reported by the OS on the notebook host; forwarded to the client as `n`.
ncore = cpu_count()

# attach to a running cluster
# (the 'mpi' profile must already exist -- this cell connects, it does not launch engines)
cluster = ipyparallel.Client(profile='mpi',n=ncore)
print('profile:', cluster.profile)
print("IDs:", list(cluster.ids)) # Print the ids of the engines this client can see

**We also need to import packages for the parallel environment/processes.**

In [None]:
%%px 
# ^-- Look!  (ipyparallel magic)


# MPI process-side imports
import io
import datetime
import h5py
import numba

import numpy as np

from mpi4py import MPI     # <-- Look!
from scipy import signal
from google.cloud import storage

In [None]:
%%px
# ^-- Look again!


# init
# MPI.Init(): this is done by the ipyparallel.Client() part. DO NOT call MPI.Init with parallel magic.

# get WORLD_COMM info
comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # this process's index within COMM_WORLD
size = comm.Get_size()  # number of processes in COMM_WORLD

# only the root reports the process count, so it is printed once
if rank == 0:
    print(f'root: num-process: {comm.size}')
    
# every process reports its own rank
print(f'my rank: {rank}')

## Get Data from the Cloud: Function Defs

In [None]:
%%px 

def gcs_download_to_local_disk(pargs,bucket=None,lpath=None):
    """
       Download one '<fname>_minNN.npz' object from a GCS bucket to local disk.

       pargs  : (fname, i) tuple. `fname` is the object-name prefix and `i` is a
                zero-based file index; the suffix uses i+1 zero-padded to 2 digits.
       bucket : name of the GCS bucket that holds the object.
       lpath  : local directory prefix. It is concatenated as-is, so it must end
                with a path separator. None means the current working directory.

       Raises FileNotFoundError if the object does not exist in the bucket.
    """
    fname, i = pargs
    blobname = fname + '_min' + str(i+1).zfill(2) + '.npz'

    # Use the caller-supplied directory rather than a module-level global.
    prefix = '' if lpath is None else lpath

    client = storage.Client()
    gcs_bucket = client.get_bucket(bucket)  # separate name: do not clobber the `bucket` argument
    blob = gcs_bucket.get_blob(blobname)
    if blob is None:
        # get_blob() returns None for a missing object; fail with a clear message
        raise FileNotFoundError(f'{blobname} not found in bucket {bucket}')

    blob.download_to_filename(prefix + blobname)
    
    
def parallel_load_npz(pargs,path=None):
    """
       Load one DAS '<fname>_minNN.npz' file from local disk.

       pargs : (fname, i) tuple. `fname` is the file-name prefix and `i` is a
               zero-based file index; the suffix uses i+1 zero-padded to 2 digits.
       path  : directory prefix. It is concatenated as-is, so it must end with a
               path separator. None means the current working directory.

       Returns the arrays stored under the 'data' and 'time' keys, in that order.
    """
    fname, i = pargs
    prefix = '' if path is None else path
    dfname = prefix + fname + '_min' + str(i+1).zfill(2) + '.npz'

    # np.load() on an .npz keeps the archive open until closed; the context
    # manager releases the file handle once both arrays have been read.
    with np.load(dfname) as dt_data:
        return dt_data['data'], dt_data['time']

## Setup List of Files to Read: (a priori Knowledge Req.)

In [None]:
%%px

# Name of the GCS bucket the download step reads from.
buckname = 'sep-allow-others' #kind of like the head/main directory -- Leave ALONE

# Number of '_minNN' files in the data set; ranks 0..nfiles-1 each handle one file.
nfiles = 10 # Leave this ALONE

## Read All Files to Array -- Map to Threads : (Memory in Cluster)

#### First we need to setup MPI Comm's

In [None]:
%%px

# Assign a "Color": those who read files (ranks 0..nfiles-1), and those who don't.
# Non-readers pass MPI.UNDEFINED, so Split() hands them MPI.COMM_NULL directly.
# No second communicator is created only to be discarded without being freed.
key = rank  # keep world-rank ordering inside file_comm, so world root is also file_comm root
color = 0 if rank < nfiles else MPI.UNDEFINED
file_comm = comm.Split(color,key)
    
# init comm ranks and size (stay at -1 on ranks outside file_comm)
fc_rank = -1
fc_size = -1

#get comm info if reading file
if file_comm != MPI.COMM_NULL:
    fc_rank = file_comm.Get_rank()
    fc_size = file_comm.Get_size()
    print('FILE_COMM')
    print(f'\nw-comm:  rank,   size    = {rank},{size}\nfc-comm: fc_rank,fc_size = {fc_rank},{fc_size}')
    
comm.Barrier() #pedantic WORLD_COMM
    
if file_comm == MPI.COMM_NULL:
    print('NULL_COMM')
    print(f'\nw-comm:  rank,   size    = {rank},{size}\nfc-comm: fc_rank,fc_size = {fc_rank},{fc_size}')
    
comm.Barrier() #pedantic WORLD_COMM

### Now, participating MPI-processes get the data from the cloud and save to disk

In [None]:
%%px


path = './data/test/' # put YOUR local path here (must end with '/': it is string-concatenated with the file name)
fname = 'oct_7trees_aftershock_das' # leading prefix. leave this ALONE 

# (file prefix, zero-based file index): the world rank doubles as the file index
pargs = (fname,rank)

# only ranks that belong to file_comm download a file
if file_comm != MPI.COMM_NULL: 
    gcs_download_to_local_disk(pargs,bucket=buckname,lpath=path)
    
# all ranks (readers or not) wait here until the downloads have returned
comm.Barrier()


# Watch Below!

### Read compressed numpy arrays into memory

In [None]:
%%px


rdata = None  #non-participating ranks stay as "None"
rtime = None

# pargs: same as above

# each file_comm rank loads its own file's 'data' and 'time' arrays into memory
if file_comm != MPI.COMM_NULL: 
    rdata, rtime = parallel_load_npz(pargs,path=path)
    print(f'rank-{rank}, rdata.shape: {rdata.shape}')
    print(f'rank-{rank}, rdata.dtype: {rdata.dtype}')
    
# all ranks (readers or not) synchronize before the gather step
comm.Barrier()


# Watch Below!

## Gather Arrays to ROOT Process (scatterV might be more prudent) 

In [None]:
%%px


# Receive buffers: only the root allocates them, every other rank passes None.
gath_rdata = None
gath_rtime = None

if rank == 0:
    # one slot per file, each shaped like root's own arrays
    # NOTE(review): assumes every file has the same shape/dtype as root's, and that
    # file_comm holds exactly nfiles ranks (world size >= nfiles) -- confirm
    gath_rdata = np.empty((nfiles, rdata.shape[0], rdata.shape[1]), dtype=rdata.dtype)
    gath_rtime = np.empty((nfiles, rtime.shape[0]), dtype=rtime.dtype)

# Gather over file_comm only; root=0 is file_comm rank 0
if file_comm != MPI.COMM_NULL: 
    file_comm.Gather(rdata,gath_rdata,root=0)
    file_comm.Gather(rtime,gath_rtime,root=0)
    
comm.Barrier()

## Clean-up Cluster Memory (data on 10 participating ranks)

In [None]:
%%px


# Ranks that loaded a file drop their per-file arrays here; rebinding both
# names to None releases the references and leaves the names defined.
if file_comm != MPI.COMM_NULL:
    rdata = rtime = None

## Scale-down Cluster to One Thread: (Not the Same as Notebook Thread)

## Concatenate Arrays Over Time Axis: (Root MPI Process, Not the Notebook Thread)

In [None]:
%%px


# Defined on every rank; only the root ends up holding arrays.
cc_rdata = None
cc_rtime = None
if rank == 0:
    # hstack iterates over the leading (per-file) axis and joins the pieces along
    # axis 1 for the 3-D data and end-to-end for the 2-D times
    cc_rdata = np.hstack(gath_rdata)
    cc_rtime = np.hstack(gath_rtime) #UTC times used for plotting, etc.

## Begin Processing: Function Defs

In [None]:
# LOOK!  (no magic, will use numba from Host/Notebook process)

@numba.njit(cache=True, fastmath=True, nogil=True, parallel=True)
def remove_median_xchannel(orig_tr):
    """
       Subtract from every column of `orig_tr` the median of that column.

       The input is only read; a modified copy is returned.
       NOTE(review): the notebook passes a 2-D (channel, time-sample) array, so each
       column is one time sample across all channels -- confirm before reusing.
    """
    rmed_traces = orig_tr.copy()
    # prange with parallel=True: numba may spread the column iterations over threads
    for it in numba.prange(orig_tr.shape[-1]):
        rmed_traces[:, it] -= np.median(orig_tr[:, it])
    return rmed_traces

In [None]:
%%px
# ^-- LOOK! (these tasks done by MPI processes)

# CANNOT jit ANY of the functions below

def detrend_all_traces(orig_tr):
    """
       Remove the mean and then the linear trend from every row of `orig_tr`.

       orig_tr : 2-D array-like, one trace per row. It is not modified.

       Returns a new array of the same shape. Floating/complex input keeps its
       dtype; any other dtype is promoted to float64 so the result can be stored.
    """
    orig_tr = np.asarray(orig_tr)
    if np.issubdtype(orig_tr.dtype, np.inexact):
        work_dtype = orig_tr.dtype
    else:
        # an integer buffer cannot hold the detrended values
        work_dtype = np.float64
    det_traces = orig_tr.astype(work_dtype)  # astype() copies, so the input stays untouched

    # One trace at a time keeps the temporaries small (scipy would otherwise
    # allocate whole-array intermediates).
    for i in range(det_traces.shape[0]):
        # Store the returned arrays explicitly: detrend(type='constant') never
        # writes in place, and overwrite_data is not guaranteed to reach the input.
        demeaned = signal.detrend(det_traces[i], type='constant')
        det_traces[i] = signal.detrend(demeaned, type='linear', overwrite_data=True)
    return det_traces


def detrend_single_trace(orig_tr):
    """
       Detrend one trace: subtract its mean, then remove the least-squares line.

       Returns a new array; `orig_tr` is left unchanged.
    """
    return signal.detrend(
        signal.detrend(orig_tr, type='constant'),
        type='linear',
    )


    
def bandpass_butter_single_trace(trace, fs=None, b0=None, bN=None, order=5):
    """
       Zero-phase Butterworth band-pass of a single trace.

       trace  : 1-D samples to filter
       fs     : sampling frequency handed to scipy.signal.butter
       b0, bN : lower and upper band edges (same units as fs)
       order  : Butterworth order handed to scipy.signal.butter

       Returns the forward-backward filtered trace.
    """
    band = (b0, bN)
    sections = signal.butter(order, band, btype='bandpass', output='sos', fs=fs)
    return signal.sosfiltfilt(sections, trace)



def bandpass_butter_all_traces(orig_tr, fs=None, b0=None, bN=None, order=5):
    """
       Zero-phase Butterworth band-pass applied to each row of `orig_tr`.

       orig_tr : 2-D array, one trace per row (not modified)
       fs      : sampling frequency handed to scipy.signal.butter
       b0, bN  : lower and upper band edges (same units as fs)
       order   : Butterworth order handed to scipy.signal.butter

       Returns an array with the shape and dtype of `orig_tr`.
    """
    # The filter is designed once and reused for every row.
    sections = signal.butter(order, (b0, bN), btype='bandpass', output='sos', fs=fs)
    filtered = np.zeros_like(orig_tr)
    for row, one_trace in enumerate(orig_tr):
        filtered[row, :] = signal.sosfiltfilt(sections, one_trace)
    return filtered



def silly_decimate_single_trace(orig_tr,q=2):
    """
       Keep every q-th sample of `orig_tr`, starting with the first one.

       This is plain slicing with no filtering. The result is whatever slicing
       returns for the input type; append .copy() at the call site if an
       independent array is needed.
    """
    every_qth = slice(None, None, q)
    return orig_tr[every_qth]

## Scale-up Cluster for Data Processing: Not Used with MPI

## Scatter Concatenated Data to All Cores (from Notebook to Cluster)

In [None]:
%%px


# Setup metadata that ONLY root has. Needed for scattering data
d_shape = None
buf_dtype = None
if rank == 0:
    # per-rank chunk: an equal share of axis 0, the full extent of axis 1
    # NOTE(review): integer division -- the reshape further down presumably fails unless
    # cc_rdata.shape[0] is an exact multiple of `size`; confirm for other cluster sizes
    d_shape = (cc_rdata.shape[0]//size,cc_rdata.shape[1]) #ndarray.shape
    buf_dtype = cc_rdata.dtype  #float type
    
    
# Broadcast Meta data
# (lower-case bcast sends pickled Python objects; every rank gets root's values)
d_shape = comm.bcast(d_shape,root=0)
buf_dtype = comm.bcast(buf_dtype,root=0)


# Setup receiving buffers data
sendbuf = None  
sc_rdata = np.zeros(d_shape,dtype=buf_dtype)  # each rank's chunk lands here


# Only root sends data
if rank == 0:
    # leading axis of length `size`: slab k goes to rank k
    sendbuf = cc_rdata.reshape((size,d_shape[0],d_shape[1])) #IMPORTANT: reshape NO splitting

comm.Barrier()
    
# Scatter the data --> all ranks get their respective chunks as ndarrays
comm.Scatter(sendbuf,sc_rdata,root=0)

# root drops its references to the full array once the chunks are out
if rank == 0:
    del cc_rdata
    cc_rdata = None
    sendbuf = None

comm.Barrier()

## Detrend Data Per Channel: CHUNKS of Channels Per MPI Process

In [None]:
%%px

#FIXME: "_traces"
# every rank detrends the rows of its own scattered chunk
det_data = detrend_all_traces(sc_rdata)

del sc_rdata  #each rank cleans up their data


# Watch Below!

## Bandpass Filter Per Channel: (Same as for Detrend)

In [None]:
%%px


bl = 0.025  # lower band edge, passed to the filter as b0
br = 5.0    # upper band edge, passed to the filter as bN
fs = 200    # sampling frequency passed to the filter -- presumably Hz; confirm against the acquisition

# every rank band-passes the rows of its own detrended chunk
bp_data = bandpass_butter_all_traces(det_data,fs=fs,b0=bl,bN=br)

del det_data  # release the detrended chunk on this rank

## Decimate Per Channel: (Same as Bandpass)

In [None]:
%%px


#FIXME: global before MPI Init?
ss = 4  # decimation stride: keep every ss-th sample along axis 1
# .copy() gives the decimated chunk its own memory, so deleting bp_data below really releases it
dec_bp_data = bp_data[:,::ss].copy()

del bp_data

## Gather Processed Data: From cluster to Notebook (work-around)

#### First, gather data to root, then we will copy to disk so that the "Host/Notebook" thread can read (this is a "poor-man's" work-around).

In [None]:
%%px


# All processes need the variable
g_dec_data = None


# But, ONLY root gets this data (initialize ndarray): one slab per rank
if rank == 0:
    g_dec_data = np.empty((size, dec_bp_data.shape[0], dec_bp_data.shape[1]), dtype=dec_bp_data.dtype)

    
# Gather data to root (slab k of the receive buffer is rank k's chunk)
comm.Gather(dec_bp_data,g_dec_data,root=0)


# Wait for gather to finish
comm.Barrier()


# Reshape (new VIEW): merge the (rank, row) axes into one row axis.
# The receive buffer is C-contiguous, so reshape() returns a view over the
# same memory; np.vstack() gave the same layout but copied the whole array.
if rank == 0:
    g_dec_data = g_dec_data.reshape((size*dec_bp_data.shape[0], dec_bp_data.shape[1]))
    

comm.Barrier()

#### Next, have root write the gathered data to disk.

In [None]:
%%px

# Set file mode
amode = MPI.MODE_WRONLY|MPI.MODE_CREATE


# Create file descriptors (all ranks): File.Open is collective over `comm`
f_d = MPI.File.Open(comm, './data/processed_traces.np', amode)
f_t = MPI.File.Open(comm, './data/times_4_traces.np', amode)


# MODE_CREATE does not truncate an existing file. Reset both files to zero
# length so a rerun that writes fewer bytes cannot leave stale data at the end.
# Set_size is collective, so every rank calls it.
f_d.Set_size(0)
f_t.Set_size(0)


# Only root writes file
if rank == 0:
    # ravel() hands over the existing contiguous memory where it can;
    # flatten() would always make another full copy of the gathered array.
    f_d.Write(g_dec_data.ravel())
    f_t.Write(cc_rtime.ravel()) ## Remember this from above? It has our UTC times
    

# Wait for write to finish
comm.Barrier()

# close file descriptors (all ranks)
f_d.Close()
f_t.Close()


# clean up data
del g_dec_data

## Release Cluster: (and All Related Resources, i.e. Memory, Cores, etc.)

In [None]:
%%px

# Shut down MPI on every engine; no MPI calls may be made in these processes afterwards.
MPI.Finalize()

In [None]:
# Host/Notebook

cluster.shutdown()  # ask the engines to stop
cluster.close()     # then release the notebook-side client connection

## X-Channel Median Removal Per Time-Sample: Numba Parallel (shared memory)

#### We need to first read the MPI processed data from disk (Serial). 
**We might get better performance by using `MPI_Alltoall()` to transpose the data within the MPI ranks, which would avoid this disk-based workaround. Another alternative is to communicate through an ipyparallel client-view (still trying to figure this out).**

In [None]:
%%time

# The files are raw bytes with no header, so dtype and shape must be supplied here.
# NOTE(review): float32 / int64 are assumed to match what the MPI root wrote -- confirm
traces_from_disk = np.fromfile('./data/processed_traces.np', dtype=np.float32)
tdata = np.fromfile('./data/times_4_traces.np', dtype=np.int64)
# NOTE(review): the shape is hard-coded; it must be updated by hand if `ss` or the input data change
dec_bp_data = traces_from_disk.reshape((48000,30000)) # 30,000 = 120,000/ss above (sorta cheating)

#### Now, we remove the xchannel median (Parallel). Note, some of the runtime is due to the JIT compile.

In [None]:
%%time

# runs in the notebook process; the first call also pays numba's JIT-compile cost
proc_data = remove_median_xchannel(dec_bp_data) #func() defined above in a "Host" cell not "ipyparallel" cell

del dec_bp_data  # the function returned a separate array, so the input can go

## Total Processing Time

In [None]:
print('Done Processing')
# elapsed wall-clock seconds since the first cell ran
runtime = time() - start_time #start is in first cell
print(f'runtime: {datetime.timedelta(seconds=runtime)}')  # timedelta renders as H:MM:SS

## Define function for plotting: (Serial)

In [None]:
def plot_seven_trees_data(data,times,pclip=.95,fig_size=(9,10)):
    """
       Draw a gray-scale image of rows 23000..35000 of `data`.

       data     : 2-D array; rows are labelled 'channel' on the y axis
       times    : sequence whose first entry, integer-divided by 1,000,000, is read
                  as a UTC epoch timestamp for the title and x label
                  (presumably microseconds since the epoch -- confirm)
       pclip    : the color range is +/- (1 - pclip) * max|window|
       fig_size : figure size handed to matplotlib

       Reads the module-level `nfiles` for the x extent (nfiles * 60).
       Returns the matplotlib.pyplot module so the caller can call .show().
    """
    import matplotlib.pyplot as plt

    first_ch, last_ch = 23000, 35000
    eqdate = datetime.datetime.utcfromtimestamp(times[0]//1000000)

    # the same window feeds both the clip level and the image
    window = data[first_ch:last_ch+1,:]
    vclip = (1-pclip)*np.abs(window).max()

    # (left, right, bottom, top): the first row of the window is drawn at the top
    extent = (0, nfiles*60, last_ch, first_ch)

    plt.figure(figsize=fig_size)
    plt.imshow(
        window,
        aspect='auto',
        interpolation='none',
        cmap='gray',
        vmin=-vclip,
        vmax=vclip,
        extent=extent,
    )
    plt.title('DAS for Seven Trees 1st-Aftershock, 3.1 EQ @' + str(eqdate) )
    plt.xlabel('seconds from: ' + str(eqdate.time()))
    plt.ylabel('channel')

    return plt

## 2D Plot of the Processed DAS Data: (Serial)

In [None]:
# Small Hack for Plotting (var is created inside MPI cluster)
# plot_seven_trees_data() reads `nfiles` as a global to size its x axis (nfiles*60),
# so the notebook process needs its own copy of the value.
nfiles = 10

In [None]:
%%time

pclip = .99  # color range becomes +/- (1 - pclip) * max|amplitude| of the plotted window
print(f'pclip: {pclip}')

# the function returns the pyplot module, bound here to the name `plt`
plt = plot_seven_trees_data(proc_data,tdata,pclip=pclip)
plt.show()