# Redo/refactor Project.py

### Step 0

Load packages

In [None]:
#load all packages
import datetime
import pickle
import copy
import os

from sys import argv
from pathlib import Path

import numpy as np
import pandas as pd
import pyvista as pv
import matplotlib.pyplot as plt 
from matplotlib.colors import Normalize


from pyaspect.project import *
from pyaspect.model.gridmod3d import gridmod3d as gm
from pyaspect.model.bbox import bbox as bb
from pyaspect.model.gm3d_utils import *
from pyaspect.moment_tensor import MomentTensor
from pyaspect.specfemio.headers import *
#from pyaspect.specfemio.write import *
from pyaspect.specfemio.read import *
from pyaspect.specfemio.utils import *


import pyaspect.events.gevents as gevents
import pyaspect.events.gstations as gstations
from pyaspect.events.munge.knmi import correct_station_depths as csd_f
import pyaspect.events.mtensors as mtensors
from obspy.imaging.beachball import beach
from obspy import UTCDateTime
import shapefile as sf

### Step 1 

Extract the ndarray of the subsampled, smoothed NAM model and instantiate a new GriddedModel3D object for QC'ing

In [None]:
# Input/output data directories (relative to the notebook's working dir).
# Outputs are written to the same directory the inputs are read from.
data_in_dir  = 'data/output/'
data_out_dir = data_in_dir
# IPython shell escapes: list the available data files
!ls {data_in_dir}
!ls data/groningen

### Step 6 

Decompress the ndarray of the sliced, subsampled, smoothed NAM model and instantiate a new GriddedModel3D object for QC'ing

In [None]:
# set filename then use it to decompress the model
# (os.path.join avoids the double '/' that f'{data_out_dir}/...' produced,
#  since data_out_dir already ends with '/')
ifqn = os.path.join(data_out_dir, 'vsliced_subsmp_smth_nam_2017_vp_vs_rho_Q_model_dx100_dy100_dz100_maxdepth5850_sig250.npz')
vslice_gm3d, other_pars = decompress_gm3d_from_file(ifqn)

print()
print('decompressed gridded model:\n',vslice_gm3d)
print()
print('other parameters:\n',other_pars)
print()

# WARNING: the (disabled) snippet below would unpack all other_pars into the
#          notebook namespace; a key with the same name as an existing
#          variable would silently overwrite it. Kept for reference only.
'''
for key in other_pars:
    locals()[key] = other_pars[key]  #this is more advanced python than I think is reasonable for most 
sig_meters = sig
''';

# Explicitly pull out the grid parameters needed later (the gridmod3d
# accessor functions are an alternative way to get these).
xmin = other_pars['xmin']
dx   = other_pars['dx']
nx   = other_pars['nx']
ymin = other_pars['ymin']
dy   = other_pars['dy']
ny   = other_pars['ny']
zmin = other_pars['zmin']
dz   = other_pars['dz']
nz   = other_pars['nz']
sig_meters = other_pars['sig']  # smoothing sigma; used later
print('sig_meters:',sig_meters)

In [None]:
# Build a pyvista grid of the sliced model for 3-D QC plotting.
# NOTE(review): UniformGrid / cell_arrays are deprecated in newer pyvista
# releases (ImageData / cell_data) -- confirm against the installed version.
# Create the spatial reference
grid = pv.UniformGrid()

# Set the grid dimensions: shape + 1 because we want to inject our values on
#   the CELL data
nam_dims = list(vslice_gm3d.get_npoints())
# local x/y origin at 0; z origin negated so z is negative-down in pyvista's z-up frame
nam_origin = [0,0,-vslice_gm3d.get_gorigin()[2]]
#nam_origin = list(vslice_gm3d.get_gorigin())
#nam_origin[2] *= -1
nam_origin = tuple(nam_origin)
nam_spacing = list(vslice_gm3d.get_deltas())
# negative z spacing: cells extend downward (more negative z) from the origin
nam_spacing[2] *=-1
nam_spacing = tuple(nam_spacing)
print('nam_dims:',nam_dims)
print('nam_origin:',nam_origin)
print('nam_spacing:',nam_spacing)

# Edit the spatial reference
grid.dimensions = np.array(nam_dims) + 1
grid.origin = nam_origin  # grid corner; with dz < 0 the grid extends downward from here
grid.spacing = nam_spacing  # These are the cell sizes along each axis
# first model property only (presumably vp, given the vp_vs_rho_Q file name -- confirm)
nam_pvalues = vslice_gm3d.getNPArray()[0]
print('pvalues.shape:',nam_pvalues.shape)

# Add the data values to the cell data
grid.cell_arrays["values"] = nam_pvalues.flatten(order="F")  # Fortran order: x varies fastest, as pyvista expects

# Now plot the grid!
cmap = plt.cm.jet
#grid.plot(show_edges=True,cmap=cmap)
grid.plot(cmap=cmap,opacity=1.0)


In [None]:
# Three orthogonal slices through the model grid (pyvista default: through the centre)
slices = grid.slice_orthogonal()

#slices.plot(show_edges=True,cmap=cmap)
slices.plot(cmap=cmap)

## Create random virtual sources (SPECFEM stations that, by reciprocity, act as sources)

In [None]:
# Local (model-relative) xyz of every grid point; flip z so it is negative-down,
# matching the pyvista grid built above.
#coords = vslice_gm3d.getGlobalCoordsPointsXYZ()
coords = vslice_gm3d.getLocalCoordsPointsXYZ()
coords[:,2] = -coords[:,2]

# unique grid-line coordinates along each axis
xc = np.unique(coords.T[0,:])
yc = np.unique(coords.T[1,:])
zc = np.unique(coords.T[2,:])


#n_rand_p = 1000

# number of random source points to draw
n_rand_p = 3
np.random.seed(n_rand_p) #nothing special about using n_rand_p just want reproducible random

#stay away from the edges of the model for derivatives 
# and to avoid boundary effects
xy_pad = 500 

# Sampling box: low (lr*) and high (hr*) bounds per axis.
# x/y are padded in from the model edges; z is fixed to -3400..-2600
# (z negative-down, so presumably 2600-3400 m depth).
lrx = np.min(xc) + xy_pad
lry = np.min(yc) + xy_pad
lrz = -3400.0

hrx = np.max(xc) - xy_pad
hry = np.max(yc) - xy_pad
hrz = -2600.0

# box extents
srx = hrx - lrx
sry = hry - lry
srz = hrz - lrz

# one uniformly distributed point per source inside the box
# (draw order x, y, z per point matters for reproducibility)
r_xyz_list = []
for i in range(n_rand_p):
    rx = lrx + srx*np.random.rand()
    ry = lry + sry*np.random.rand()
    rz = lrz + srz*np.random.rand()
    r_xyz_list.append([rx,ry,rz])
    
r_xyz = np.array(r_xyz_list)
    

#r_xyz = np.vstack(np.meshgrid(rx,ry,rz)).reshape(3,-1).T
print('r_xyz:\n',r_xyz)


In [None]:
# QC plot: random source points over the model grid and its orthogonal slices
pv_rpoints = pv.wrap(r_xyz)
p = pv.Plotter()
slices = grid.slice_orthogonal()
#p.add_mesh(slices,cmap=cmap,opacity=0.50)
p.add_mesh(slices,cmap=cmap,opacity=1)
p.add_mesh(grid,cmap=cmap,opacity=0.50)
p.add_mesh(pv_rpoints, render_points_as_spheres=True, point_size=5,opacity=1.0)

p.show()

## Make Moment Tensors and CMTSolutionHeaders for each tensor

In [None]:
# this is the path to the project dir on the cluster
# NOTE(review): my_proj_dir is not used in the visible cells below.
my_proj_dir = '/scratch/seismology/tcullison/test_mesh/FWD_Batch_Src_Test'

# arbitrary test values: one (strike, dip, rake) per random source point
magnitude = np.pi
strike = [30,45,90] # just making three to test
dip = [30,30,60]
rake = [330,190,20]

l_mt = []
for i in range(len(strike)):
    l_mt.append(MomentTensor(mw=magnitude,strike=strike[i],dip=dip[i],rake=rake[i]))

# exactly one moment tensor per source location
assert len(l_mt) == len(r_xyz)

for mt in l_mt:
    print(mt)
    
# One CMT header per point: local x -> lon_xc, y -> lat_yc, depth = -z
# (positive down); each point is its own event (eid=i) with a single source (sid=0).
l_cmt_src = []
for i in range(len(r_xyz)):
    cmt_h = CMTSolutionHeader(date=datetime.datetime.now(),
                              ename=f'Event-{str(i).zfill(4)}',
                              tshift=0.0,
                              hdur=0.0,
                              lat_yc=r_xyz[i,1],
                              lon_xc=r_xyz[i,0],
                              depth=-r_xyz[i,2],
                              mt=l_mt[i],
                              eid=i,
                              sid=0)
    l_cmt_src.append(cmt_h)
    
print()
for cmt in l_cmt_src:
    print(cmt)

## Make Corresponding "Virtual" Receivers (including cross members for derivatives) for the CMTs

In [None]:
m_delta = 50.0 # distance between cross stations for derivatives
# the cross members must stay inside the edge padding used to place the sources
assert m_delta < xy_pad #see cells above this is padding
#l_grp_vrecs = make_grouped_half_cross_reciprocal_station_headers_from_cmt_list(l_cmt_src,m_delta)
l_grp_vrecs = make_grouped_cross_reciprocal_station_headers_from_cmt_list(l_cmt_src,m_delta)

# print every virtual receiver, grouped per CMT source
ig = 0
for grp in l_grp_vrecs:
    print(f'***** Group: {ig} *****\n')
    ir = 0
    for gvrec in grp:
        print(f'*** vrec: {ir} ***\n{gvrec}')
        ir += 1
    ig += 1

# total number of virtual receivers over all groups
print(len(flatten_grouped_headers(l_grp_vrecs)))
    

## Plot Virtual Receiver Groups

In [None]:
# xyz of all virtual receivers (all groups flattened)
all_g_xyz = get_xyz_coords_from_station_list(flatten_grouped_headers(l_grp_vrecs))
all_g_xyz[:,2] *= -1 #pyvista is z-up positive: opposite sign to the depth-positive geophysics convention
pv_all_points = pv.wrap(all_g_xyz)
p = pv.Plotter()
p.add_mesh(grid,cmap=cmap,opacity=0.5)
p.add_mesh(slices,cmap=cmap,opacity=1.0)
p.add_mesh(pv_all_points, render_points_as_spheres=True, point_size=5,opacity=1.0)
p.show()

## Get receiver/station coordinates created from a different notebook

In [None]:
# unpickle the Bounding box (from a different notebook)
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.

#ifqn  = data_out_dir + 'bbox_nvl' + str(int(nvl)) + '_nvb' + str(int(nvb))
#ifqn += '_xsft' + str(xshift) + '_ysft' + str(yshift) + '.pickle'
ifqn = data_out_dir + 'bbox_nvl152_nvb197_xsft4400_ysft19100.pickle'
# `with` closes the file even if unpickling fails
with open(ifqn, 'rb') as f:
    sgf_bbox = pickle.load(f)
print()
print('Unpickled Bounding:\n',sgf_bbox)

In [None]:
# unpickle the events if needed (again from a different notebook)
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
ifqn = data_out_dir + 'bbox_groning_events.pickle'
# `with` closes the file even if unpickling fails
with open(ifqn, 'rb') as f:
    bbox_events = pickle.load(f)
print()
print('Unpickled Events:\n',bbox_events)

In [None]:
#Read moment tensors
mt_in_file  = 'data/groningen/events/event_moments.csv' 
!ls {mt_in_file}
bbox_gf_mts = mtensors(mt_in_file)

# get event catalog of the events (ObsPy catalog)
bbox_event_cat = copy.deepcopy(bbox_events.getIncCatalog())

# This is a bit hokey, but it works. Here we update the
# event time from the moment tensor CSV file with thouse
# from the event catalog
bbox_gf_mts.update_utcdatetime(bbox_event_cat)

'''
#for imt in range(len(bbox_gf_mts)):
#    print("Moment-Tensor %d:/n" %(imt),bbox_gf_mts[imt])
'''

# Create a dictionary that maps moment tensors to events
bbox_emap,bbox_mt_cat,bbox_mts = bbox_gf_mts.get_intersect_map_events_mts(bbox_event_cat)
bbox_e2mt_keys = bbox_emap.keys()

# Print a comparison of events to moment tensors
for key in bbox_e2mt_keys:
    print('UTC: event[%d][Date] = %s' %(key,bbox_mt_cat[key].origins[0].time))
    print('UTC:    MT[%d][Date] = %s' %(key,bbox_emap[key]['Date']))
    print('Mag: event[%d][Date] = %s' %(key,bbox_mt_cat[key].magnitudes[0].mag))
    print('Mag:    MT[%d][Date] = %s' %(key,bbox_emap[key]['ML']))
    print()

#replace moment-tensors with only those that intersect with the events in the BoundingBox
bbox_gf_mts.replace_moment_tensors_from_map(bbox_emap)
    
# add mt_catalog to bbox_events
bbox_events.mergeMomentTensorsCatalog(bbox_mt_cat)
merged_bbox_event_cat = bbox_events.getIncCatalog()
print('bbox_event_cat:\n', bbox_event_cat)
print()
print('merged_bbox_event_cat:\n', merged_bbox_event_cat)
print()
print('bbox_mt_cat:\n', bbox_mt_cat)
print()
print('bbox_mt_df:\n', bbox_gf_mts)

In [None]:
ifqn = data_out_dir + 'bbox_groning_stations.pickle'

print('Unpickling Station Traces')
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
# `with` closes the file even if unpickling fails
with open(ifqn, 'rb') as f:
    bbox_straces = pickle.load(f)

print('Stations:\n',type(bbox_straces))

In [None]:
# read shapefiles (outline of the Groningen field)
shape_in_files  = 'data/groningen/shapefile/Groningen_field' 

gf_shape = sf.Reader(shape_in_files)
print('Groningen Field shape:',gf_shape)

#get coordinates for the Shape-File (first shape only)
s = gf_shape.shape(0)
shape_xy = np.asarray(s.points)

In [None]:
# This is kind of hokey, but it works for now.
# Some of the stations depths do not follow the 
# 50, 100, 150, 200 meter depths -- possibly because
# the boreholes are slanted. To correct for this,
# a hard coded "patch/update" is applied. See the
# code for details and update values.
#from gnam.events.munge.knmi import correct_station_depths as csd_f
bbox_straces.correct_stations(csd_f)

bbox_bb_diam = 1500  #size of the beachball for plotting. I had to play with this parameter
bbox_cmt_bballs = bbox_gf_mts.get_cmt_beachballs(diam=bbox_bb_diam,fc='black')

bbox_mt_coords = bbox_events.getIncCoords()

#get event and borehole keys used for indexing
ekeys = bbox_straces.getEventKeys()
bkeys = bbox_straces.getBoreholeKeys()

#Plot sequence of events with stations (currently only the first event)
#for ie in ekeys:
for i in range(1):
    ie = ekeys[i]
    # coordinates for stations that are in the bounding box
    xy3 = bbox_straces.getIncStationCoords(ie,bkeys[0]) #station code G##3
    xy4 = bbox_straces.getIncStationCoords(ie,bkeys[1]) #station code G##4
    
    # coordinates for stations that are G-stations but outside the bounding box
    ex_xy3 = bbox_straces.getExcStationCoords(ie,bkeys[0]) #station code G##3
    ex_xy4 = bbox_straces.getExcStationCoords(ie,bkeys[1]) #station code G##4
    
    # coordinates for stations that are inside the bounding box but there is no data
    er_xy3 = bbox_straces.getErrStationCoords(ie,bkeys[0]) #station code G##3
    er_xy4 = bbox_straces.getErrStationCoords(ie,bkeys[1]) #station code G##4

    fig, ax = plt.subplots(1,figsize=(8,8))
    fig.gca().set_aspect('equal', adjustable='box')
    
    #Groningen Field Shape-File
    ax.scatter(shape_xy[:,0],shape_xy[:,1],s=1,c='black',zorder=0)
    
    #Bounding Box
    ax.plot(sgf_bbox.getCLoop()[:,0],sgf_bbox.getCLoop()[:,1],c='green',zorder=1)
    
    #Events (reuse event coordinates from cell further above)
    ax.scatter(bbox_mt_coords[ie,0],bbox_mt_coords[ie,1],s=90,c='red',marker='*',zorder=5)
    # beachball collection for this event; named `bball` so it does not
    # shadow obspy's `beach` function imported in the first cell
    bball = bbox_cmt_bballs[ie]
    bball.set_zorder(3)
    ax.add_collection(bball)
    
    #Included stations
    ax.scatter(xy3[:,0],xy3[:,1],s=50,c='blue',marker='v',zorder=3)
    ax.scatter(xy4[:,0],xy4[:,1],s=100,c='gray',marker='o',zorder=2)
    
    #Excluded stations
    ax.scatter(ex_xy3[:,0],ex_xy3[:,1],s=80,c='lightgray',marker='1',zorder=4)
    ax.scatter(ex_xy4[:,0],ex_xy4[:,1],s=100,c='lightgray',marker='2',zorder=3)
    
    #Stations without data
    ax.scatter(er_xy3[:,0],er_xy3[:,1],s=50,c='yellow',marker='v',zorder=4)
    ax.scatter(er_xy4[:,0],er_xy4[:,1],s=100,c='gray',marker='o',zorder=3)
    
    origin_time = bbox_events[ie].origins[0].time
    mag = bbox_events[ie].magnitudes[0].mag
    title_str = 'Event-%d, Origin Time: %s, Magnitude: %1.2f' %(ie,str(origin_time),mag)
    ax.set_title(title_str)
    plt.show()
    


In [None]:
# Map of all G##4 stations (with and without data) for one event.
# NOTE: relies on ie, xy4, er_xy4, origin_time and mag left over from the
# last iteration of the previous cell's loop.
fig1, ax1 = plt.subplots(1,figsize=(8,8))
fig1.gca().set_aspect('equal', adjustable='box')
    
#Groningen Field Shape-File
ax1.scatter(shape_xy[:,0],shape_xy[:,1],s=1,c='black',zorder=0)

#Bounding Box
ax1.plot(sgf_bbox.getCLoop()[:,0],sgf_bbox.getCLoop()[:,1],c='green',zorder=1)

#Events (reuse event coordinates from cell further above)
ax1.scatter(bbox_mt_coords[ie,0],bbox_mt_coords[ie,1],s=90,c='red',marker='*',zorder=5)
    
# included G##4 stations plus those inside the box without data
all_xy4 = np.concatenate((xy4,er_xy4),axis=0)
#all_xy4 = xy4

ax1.scatter(all_xy4[:,0],all_xy4[:,1],s=100,c='gray',marker='o',zorder=2)
# mark the no-data stations with a yellow x on top
ax1.scatter(er_xy4[:,0],er_xy4[:,1],s=100,c='yellow',marker='x',zorder=2)

title_str = 'Event-%d, Origin Time: %s, Magnitude: %1.2f' %(ie,str(origin_time),mag)
ax1.set_title(title_str)
plt.show()

## Make virtual sources (at a subset of the G##4 station locations)

In [None]:
# NOTE(review): `coords` is overwritten here but not used in this cell.
coords = vslice_gm3d.getLocalCoordsPointsXY()

# global model origin, used to convert station coords to model-local coords
x_orig = vslice_gm3d.get_gorigin()[0]
y_orig = vslice_gm3d.get_gorigin()[1]

# hand-picked subset of G##4 stations (rows 9..12) to use as virtual sources
clip_xy = all_xy4[9:13]
print(clip_xy)

# local xyz of the virtual sources, all at z = -200 (z negative-down)
s_xyz = np.zeros((len(clip_xy),3))
s_xyz[:,0] = clip_xy[:,0] - x_orig
s_xyz[:,1] = clip_xy[:,1] - y_orig
s_xyz[:,2] = -200

print(s_xyz)

## Plot virtual sources (red) with virtual receivers (white)

In [None]:
# QC plot: virtual sources (red) with all virtual receivers
pv_spoints = pv.wrap(s_xyz)
p = pv.Plotter()
#p.add_mesh(slices,cmap=cmap,opacity=0.50)
p.add_mesh(grid,cmap=cmap,opacity=0.3)
p.add_mesh(pv_spoints, render_points_as_spheres=True, point_size=8,opacity=1,color='red')
#p.add_mesh(pv_rpoints, render_points_as_spheres=True, point_size=5,opacity=0.5)
# raw ndarray of receiver xyz (already z-flipped in the receiver-group plot cell)
p.add_mesh(all_g_xyz, render_points_as_spheres=True, point_size=5,opacity=0.5)
p.show()

## Make StationHeaders (real receivers)

In [None]:
# One "real" receiver StationHeader at each virtual-source location (trid = index)
l_real_recs = []
for i in range(len(s_xyz)):
    
    tr_bname = 'tr'
    new_r = StationHeader(name=tr_bname,
                          network='NL', #FIXME
                          lon_xc=s_xyz[i,0],
                          lat_yc=s_xyz[i,1],
                          depth=-s_xyz[i,2], #specfem z-down is positive
                          elevation=0.0,
                          trid=i)
    l_real_recs.append(new_r)
    
for rec in l_real_recs:
    print(rec)


## Make ForceSolutionHeaders for the above virtual sources (including force triplets for calculating derivatives)

In [None]:
# grouped force-solution triplets, one group per real receiver
l_grp_vsrcs = make_grouped_reciprocal_force_solution_triplet_headers_from_rec_list(l_real_recs)

## Make replicates of each virtual receiver list: one for each force-triplet

In [None]:
# replicate the virtual-receiver groups for every virtual-source triplet
l_grp_vrecs_by_vsrcs = make_replicated_reciprocal_station_headers_from_src_triplet_list(l_grp_vsrcs,
                                                                                          l_grp_vrecs)

## Plot virtual sources (red) and virtual receivers (white) FROM headers

In [None]:
# unique virtual-source locations (a triplet shares one location)
grp_s_xyz = get_unique_xyz_coords_from_solution_list(flatten_grouped_headers(l_grp_vsrcs))
grp_s_xyz[:,2] *= -1 #pyvista z-up is positive

# doubly-nested groups -> flat list, then unique receiver locations
flat_recs = flatten_grouped_headers(flatten_grouped_headers(l_grp_vrecs_by_vsrcs))
grp_r_xyz = get_unique_xyz_coords_from_station_list(flat_recs)
grp_r_xyz[:,2] *= -1 #pyvista z-up is positive

print(len(grp_s_xyz))
print(len(grp_r_xyz))

pv_spoints = pv.wrap(grp_s_xyz)
pv_rpoints = pv.wrap(grp_r_xyz)

p = pv.Plotter()
p.add_mesh(slices,cmap=cmap,opacity=0.50)
p.add_mesh(grid,cmap=cmap,opacity=0.3)
p.add_mesh(pv_spoints, render_points_as_spheres=True, point_size=8,opacity=1,color='red')
p.add_mesh(pv_rpoints, render_points_as_spheres=True, point_size=5,opacity=0.5)
p.show()

## Make replicates of the "real" receiver list: one for each CMT source

In [None]:
# one copy of the real-receiver list per CMT source: indexed [src][rec]
l_grp_recs_by_srcs = make_replicated_station_headers_from_src_list(l_cmt_src,l_real_recs)


for i in range(len(l_cmt_src)):
    print(f'***** SRC Records for Source: {i} *****\n')
    for j in range(len(l_real_recs)):
        print(f'*** REC Header for Receiver: {j} ***\n{l_grp_recs_by_srcs[i][j]}')
    

## Plot "real" sources (red) and virtual receivers (white) FROM headers

In [None]:
# unique CMT source locations
grp_s_xyz = get_unique_xyz_coords_from_solution_list(l_cmt_src)
grp_s_xyz[:,2] *= -1 #pyvista z-up is positive

# unique "real" receiver locations over all per-source replicas
flat_recs = flatten_grouped_headers(l_grp_recs_by_srcs) #real!
grp_r_xyz = get_unique_xyz_coords_from_station_list(flat_recs)
grp_r_xyz[:,2] *= -1 #pyvista z-up is positive

print(len(grp_s_xyz))
print(len(grp_r_xyz))

pv_spoints = pv.wrap(grp_s_xyz)
pv_rpoints = pv.wrap(grp_r_xyz)

p = pv.Plotter()
p.add_mesh(slices,cmap=cmap,opacity=0.50)
p.add_mesh(grid,cmap=cmap,opacity=0.3)
p.add_mesh(pv_spoints, render_points_as_spheres=True, point_size=12,opacity=1,color='red')
p.add_mesh(pv_rpoints, render_points_as_spheres=True, point_size=8,opacity=0.5)
p.show()

## Make reciprocal RecordHeader

In [None]:
from pyaspect.specfemio.utils import station_auto_data_fname_id
from pyaspect.specfemio.write import _write_header

# sanity counts: number of virtual sources and virtual receivers
print(len(flatten_grouped_headers(l_grp_vsrcs.copy())))
print(len(flatten_grouped_headers(flatten_grouped_headers(l_grp_vrecs_by_vsrcs.copy()))))
# NOTE(review): hard-coded expected receiver total (presumably 21 vrecs per
# vsrc x 12 vsrcs); update it if the header counts above change.
print('nrec_per_src*nsrc:',21*12)

l_flat_vsrcs = flatten_grouped_headers(l_grp_vsrcs)
l_flat_vrecs = flatten_grouped_headers(flatten_grouped_headers(l_grp_vrecs_by_vsrcs))

# one record holding all virtual sources and all virtual receivers
vrecord_h = RecordHeader(name='Reciprocal-Record',solutions_h=l_flat_vsrcs,stations_h=l_flat_vrecs)

vrec_fqp = os.path.join(data_out_dir,'simple_record_h')
_write_header(vrec_fqp,vrecord_h)

!ls -l {vrec_fqp}

print(vrecord_h)
print('\n***************************************************\n')
# disabled scratch code kept for reference (string literal, not executed)
'''
#print(vrecord_h.get_event_nsolutions(1))

idx = pd.IndexSlice
slice_rec_h = vrecord_h.copy()
slice_rec_h.solutions_df.reset_index(inplace=True)
#slice_rec_h.stations_df.reset_index(inplace=True)
slice_rec_h.stations_df
rec_df = slice_rec_h.stations_df
#print(f'rec_df = {rec_df}')
for index, src in slice_rec_h.solutions_df.iterrows():
    print(f'**** src.sid = {src.sid} ****************\n')
    print(f'**** src.eid = {src.eid} ****************\n')
    #print(rec_df[rec_df['proj_id'] == src.proj_id])
    #print(f'index = {index}')
    #new_rec_df = rec_df[(rec_df['proj_id'] == src.proj_id) & (rec_df['eid'] == src.eid) & (rec_df['sid'] == src.sid)]
    for index,rec in rec_df.loc[idx[src.proj_id,src.eid,src.sid],:].reset_index().iterrows():
        print(rec)
#slice_rec_h.stations_df.loc[idx[0,0,:,:],'data_fqdn'] = '/somewhere/over/the/rainbow'
#print(slice_rec_h.stations_df.loc[idx[0,0,:,:],'data_fqdn'])
''';

print()
#svr = vrecord_h[0,0,0,0]
# strided sub-selection over the record's 4 index levels (via RecordHeader.__getitem__)
svr = vrecord_h[::5,::5,::5,::7]
print('svr:',svr)
print('\n***************************************************\n')
# experiment: relabel the project id and set per-station data paths
svr.solutions_df['proj_id'] = 1
svr.stations_df['proj_id'] = 1
old_idx = svr.stations_df.index
svr.stations_df.reset_index(inplace=True)
for index,row in svr.stations_df.iterrows():
    #svr.stations_df.loc[index,'data_fqdn'] = f'../SYN/s{row.sid}.t{str(row.trid).zfill(6)}'
    svr.stations_df.loc[index,'data_fqdn'] = f'../SYN/{station_auto_data_fname_id(row)}'

print()
# restore the original index saved before reset_index
svr.stations_df.set_index(old_idx,inplace=True)
print('proj svr:',svr)

#print('index:',svr.index)
#print('shape:',svr.index.shape)

# disabled scratch code kept for reference (string literal, not executed)
'''
svr_idx_names = svr.index.names
print('names:\n',svr_idx_names)
print()
print('orig:\n',svr)
svr.reset_index(inplace=True)
print()
print('reset:\n',svr)
svr['proj_id'] = 1
print()
print('new proj_id:\n',svr)
svr.set_index(svr_idx_names)
print()
print('new idx:\n',svr)
''';

#assert False

## Redo Make Project Code

In [None]:
from pyaspect.specfemio.write import write_solution
from pyaspect.specfemio.write import write_stations
from pyaspect.specfemio.write import _write_header
from pyaspect.specfemio.utils import station_auto_data_fname_id
from pyaspect.specfemio.headers import SolutionHeader
from pyaspect.specfemio.headers import StationHeader

def df_to_header_list(df,HeaderCls):
    """Build one header object per row of ``df``.

    Each row (a ``pandas.Series``) is passed to ``HeaderCls.from_series``;
    the resulting headers are returned in row order.
    """
    headers = []
    for _, series in df.iterrows():
        headers.append(HeaderCls.from_series(series))
    return headers

def write_record(rdir_fqp,
                 record_h,
                 fname='event_record',
                 write_record_h=True,
                 write_h=False,
                 auto_name=False,
                 auto_network=False):
    """Write SOLUTION/STATIONS files and record headers for one run directory.

    For every source row of ``record_h.solutions_df`` this writes, into
    ``rdir_fqp/DATA``:
      * the source via ``write_solution`` with postfix ``e{eid}s{sid}``;
      * ``STATIONS.e{eid}s{sid}`` containing only the stations that belong to
        that source (same ``eid`` and ``sid``).
    The first source's files are written with ``mk_sym_link=True``, later
    ones with ``mk_sym_link=False``.

    Each station's ``data_fqdn`` is set twice:
      * in ``record_h`` relative to ``SYN`` itself (``./<name>``); this header
        is written to ``rdir_fqp/SYN/<fname>``;
      * in a copy relative to ``DATA`` (``../SYN/<name>``); the copy is written
        to ``rdir_fqp/DATA/<fname>`` when ``write_record_h`` is True.

    Parameters
    ----------
    rdir_fqp : str
        Run directory containing ``DATA`` and ``SYN`` subdirectories.
    record_h : RecordHeader
        Record to write. Modified in place (station ``data_fqdn`` values).
        Its multi-index is restored even if writing fails.
    fname : str
        Base name of the record header files.
    write_record_h : bool
        Also write the DATA-relative record header into ``DATA``.
    write_h, auto_name, auto_network : bool
        Passed through to ``write_solution`` / ``write_stations``.
    """
    data_fqp = os.path.join(rdir_fqp,'DATA')
    syn_fqp  = os.path.join(rdir_fqp,'SYN')
    # station data paths relative to the directory holding each record header
    rel_syn_fqp = os.path.relpath(syn_fqp,syn_fqp)        # '.'
    rel_syn_data_fqp = os.path.relpath(syn_fqp,data_fqp)  # '../SYN'

    # flatten the MultiIndex so eid/sid are plain columns for filtering
    record_h.reset_midx()
    try:
        data_h = record_h.copy()

        src_df = record_h.solutions_df
        SrcHeader = record_h.solution_cls

        rec_df = record_h.stations_df
        data_rec_df = data_h.stations_df
        RecHeader = record_h.station_cls

        mk_sym_link = True # first <CMT|FORCE>SOLUTION and STATIONS files get symlink
        for _, src in src_df.iterrows():

            solution = SrcHeader.from_series(src)
            write_solution(data_fqp,
                           solution,
                           postfix=f'e{src.eid}s{src.sid}',
                           write_h=write_h,
                           mk_sym_link=mk_sym_link)

            # stations belonging to this source: match on event AND source id
            src_mask = (rec_df['eid'] == src.eid) & (rec_df['sid'] == src.sid)
            for ridx, rec in rec_df.loc[src_mask].iterrows():
                data_fname = station_auto_data_fname_id(rec)
                rec_df.loc[ridx,'data_fqdn'] = os.path.join(rel_syn_fqp,data_fname)
                data_rec_df.loc[ridx,'data_fqdn'] = os.path.join(rel_syn_data_fqp,data_fname)

            # only this source's stations go into its STATIONS file
            l_stations = df_to_header_list(data_rec_df.loc[src_mask],RecHeader)
            write_stations(data_fqp,
                           l_stations,
                           fname=f'STATIONS.e{src.eid}s{src.sid}',
                           write_h=write_h,
                           auto_name=auto_name,
                           auto_network=auto_network,
                           mk_sym_link=mk_sym_link)

            mk_sym_link = False #only symlink the first src
    finally:
        # always hand the caller's record back with its original index
        record_h.set_default_midx()

    # write record header in run####/SYN
    syn_record_fqp = _get_header_path(syn_fqp,fname)
    _write_header(syn_record_fqp,record_h)

    #write record header in run####/DATA
    data_h.set_default_midx()
    if write_record_h:
        data_record_fqp = _get_header_path(data_fqp,fname)
        _write_header(data_record_fqp,data_h)
        
            
            


In [None]:
from pyaspect.specfemio.utils import make_record_headers
from pyaspect.specfemio.utils import _mk_symlink
from pyaspect.specfemio.utils import _copy_recursive_dir
from pyaspect.specfemio.utils import _join_path_fname
from pyaspect.specfemio.utils import _get_header_path
from pyaspect.specfemio.utils import station_auto_data_fname_id
from pyaspect.parfile import change_multiple_parameters_in_lines
from pyaspect.parfile import readlines
from pyaspect.parfile import writelines




# Default cap on run directories per (sub)project
# (used as make_fwd_project_dir's max_event_rdirs; see SPECFEM3D_Cartesian manual)
MAX_SPEC_SRC = 9999

# Directory layouts are nested dicts: each key is a directory name and its
# value is a dict of that directory's subdirectories ({} for a leaf).

# directories needed in every run (event) dir
common_dir_struct = {
    'DATA': {},
    'OUTPUT_FILES': {'DATABASES_MPI': {}},
    'SYN': {},
    'FILT_SYN': {},
}

# extra common dirs for fwi
common_fwi_dir_struct = {dname: {} for dname in ('SEM', 'OBS', 'FILT_OBS')}

# directories only needed for the primary run0001 dir
primary_dir_struct = {dname: {} for dname in ('INPUT_GRADIENT',
                                              'INPUT_KERNELS',
                                              'INPUT_MODEL',
                                              'OUTPUT_MODEL',
                                              'OUTPUT_SUM',
                                              'SMOOTH',
                                              'COMBINE',
                                              'topo')}


def _make_dirs(fqdn,access_rights=0o755):
    """Create directory ``fqdn``, including any missing parent directories.

    Parameters
    ----------
    fqdn : str
        Path of the directory to create.
    access_rights : int, optional
        Mode passed to ``os.makedirs`` (still subject to the process umask).

    Raises
    ------
    OSError
        If ``fqdn`` already exists as a directory, or if ``os.makedirs``
        fails for any reason (the failure is also printed).
    """
    if os.path.isdir(fqdn):
        raise OSError(f'The directory {fqdn} has already been created')
    try:
        os.makedirs(fqdn, access_rights)
    except OSError:
        print(f'Creation of the directory {fqdn} failed')
        # re-raise: callers must not carry on writing into a missing directory
        raise
    
    
def _recursive_proj_dirs(dl,pdir,access_rights=0o755):
    """Recursively create the directory tree ``dl`` under ``pdir``.

    Parameters
    ----------
    dl : dict
        Nested layout: each key is a directory name, each value a dict of
        that directory's subdirectories ({} for a leaf).
    pdir : str
        Existing parent directory under which the tree is created.
    access_rights : int, optional
        Mode used for every created directory (subject to the umask).
        Previously this argument was ignored and 0o755 was always used.

    Raises
    ------
    OSError
        Propagated from ``_make_dirs`` (e.g. a directory already exists).
    """
    for dname, sub_dl in dl.items():
        new_dir = os.path.join(pdir, dname)
        _make_dirs(new_dir, access_rights=access_rights)
        _recursive_proj_dirs(sub_dl, new_dir, access_rights=access_rights)


def _make_proj_dir(proj_root_fqp,
                   proj_base_name,
                   pyutils_fqp=None,
                   script_fqp=None):
    """Create the project directory ``proj_root_fqp/proj_base_name``.

    Optionally creates project-level symlinks inside it:
    ``pyutils -> pyutils_fqp`` and ``scriptutils -> script_fqp``.
    A link is skipped when its target is None.

    Returns
    -------
    str
        Path of the created project directory.

    Raises
    ------
    OSError
        Propagated from ``_make_dirs`` if the directory exists or cannot be made.
    """
    projdir_fqp = os.path.join(proj_root_fqp, proj_base_name)
    _make_dirs(projdir_fqp)

    # create project level symlinks: (link name, link target)
    for lname, src in (('pyutils', pyutils_fqp), ('scriptutils', script_fqp)):
        if src is not None:
            _mk_symlink(src, os.path.join(projdir_fqp, lname))

    return projdir_fqp
            
def _make_run_dir(irdir,
                  projdir_fqp,
                  spec_bin_fqp,
                  spec_utils_fqp,
                  par_lines,
                  dir_struct,
                  record_h):
    """Create and populate one run directory ``projdir_fqp/run####``.

    Run dirs are numbered from 1 (``irdir=0`` -> ``run0001``). The run dir
    gets ``bin`` and ``utils`` symlinks, the subdirectories in ``dir_struct``,
    a ``DATA/Par_file`` built from ``par_lines``, and the source/station
    files plus record headers written by ``write_record``.

    Parameters
    ----------
    irdir : int
        Zero-based run-directory index.
    projdir_fqp : str
        (Sub)project directory that will contain the run dir.
    spec_bin_fqp, spec_utils_fqp : str
        Targets of the ``bin`` and ``utils`` symlinks.
    par_lines : list
        Lines of the Par_file to write into ``DATA``.
    dir_struct : dict
        Nested-dict layout of subdirectories to create (must include ``DATA``
        and ``SYN``, which are written to below). Previously ignored in
        favour of the module-level ``common_dir_struct``.
    record_h : RecordHeader
        Record for this run; modified in place by ``write_record``.

    Returns
    -------
    str
        Path of the created run directory.
    """
    rdir_name = 'run' + str(irdir+1).zfill(4)
    rundir_fqp = os.path.join(projdir_fqp, rdir_name)
    _make_dirs(rundir_fqp)

    # symlinks to the SPECFEM binaries and utilities
    _mk_symlink(spec_bin_fqp, os.path.join(rundir_fqp, 'bin'))
    _mk_symlink(spec_utils_fqp, os.path.join(rundir_fqp, 'utils'))

    # make subdirectories for the run, using the caller-supplied layout
    _recursive_proj_dirs(dir_struct,rundir_fqp)

    #write Par_file in the DATA dir
    out_par_fqp = os.path.join(rundir_fqp, 'DATA', 'Par_file')
    writelines(out_par_fqp,par_lines)

    #write Headers and Record
    write_record(rundir_fqp,
                 record_h,
                 fname='record',
                 write_record_h=True,
                 write_h=False,
                 auto_name=True,
                 auto_network=True)

    return rundir_fqp


def setup_mesh_dir(proj_fqp, mesh_fqp, copy_mesh=False):
    """Provide the mesh at ``proj_fqp/MESH-default``.

    Copies ``mesh_fqp`` recursively when ``copy_mesh`` is truthy; otherwise
    creates a symlink ``MESH-default -> mesh_fqp``.
    """
    place_mesh = _copy_recursive_dir if copy_mesh else _mk_symlink
    place_mesh(mesh_fqp, os.path.join(proj_fqp, 'MESH-default'))
    

def _split_rundir_counts(total_rundirs, max_rundirs):
    """Split ``total_rundirs`` run dirs into per-(sub)project counts.

    Returns ``[total_rundirs]`` when everything fits in one project;
    otherwise as many full ``max_rundirs`` chunks as fit, plus one trailing
    chunk for any remainder, e.g. (12, 5) -> [5, 5, 2].
    """
    if total_rundirs <= max_rundirs:
        return [total_rundirs]
    nfull, rem = divmod(total_rundirs, max_rundirs)
    counts = [max_rundirs] * nfull
    if rem != 0:
        counts.append(rem)
    return counts


def make_fwd_project_dir(proj_base_name,
                         proj_root_fqp,
                         parfile_fqp,
                         mesh_fqp,
                         spec_fqp,
                         pyutils_fqp,
                         script_fqp,
                         proj_record_h,
                         sub_proj_name=None,
                         batch_srcs=False,
                         copy_mesh=False,
                         max_event_rdirs=MAX_SPEC_SRC,
                         verbose=False):
    """Create a SPECFEM forward-simulation project directory tree.

    One run dir (``run####``) is made per event when ``batch_srcs`` is True
    (all of an event's sources share the run dir), otherwise one per
    (event, source) pair. If that total exceeds ``max_event_rdirs``, the run
    dirs are split over sub-projects ``<sub_proj_name>_####`` inside a main
    project dir ``proj_root_fqp/proj_base_name``; otherwise they live directly
    in ``proj_root_fqp/proj_base_name``.

    Every station's ``data_fqdn`` in ``proj_record_h`` is updated in place to
    point at its run's ``SYN`` dir, relative to
    ``proj_root_fqp/proj_base_name``, where the record is then written as
    ``project_record``.

    Parameters
    ----------
    proj_base_name, proj_root_fqp : str
        Name and parent directory of the project.
    parfile_fqp : str
        Par_file stub; several parameters are overridden for a forward run.
    mesh_fqp : str
        Mesh directory to copy or symlink as ``MESH-default``.
    spec_fqp : str
        SPECFEM install dir (its ``bin`` and ``utils`` are symlinked per run).
    pyutils_fqp, script_fqp : str
        Targets of the per-(sub)project ``pyutils``/``scriptutils`` symlinks.
    proj_record_h : RecordHeader
        Sources/stations of the whole project (modified in place).
    sub_proj_name : str, optional
        Sub-project base name; defaults to ``'Sub_' + proj_base_name``.
    batch_srcs : bool
        Put all sources of an event into one run dir.
    copy_mesh : bool
        Copy the mesh instead of symlinking it.
    max_event_rdirs : int
        Maximum run dirs per (sub)project; must be >= 1.
    verbose : bool
        Print layout and progress information.

    Returns
    -------
    str
        Path of the main project directory (``proj_root_fqp/proj_base_name``).

    Raises
    ------
    TypeError
        If ``proj_record_h`` is not a RecordHeader or a path/name is not a str.
    ValueError
        If ``max_event_rdirs`` < 1.
    """
    if not isinstance(proj_record_h,RecordHeader):
        raise TypeError('arg: \'proj_record_h\' must be a RecordHeader type')

    # all path/name arguments must be plain strings (checked in signature order)
    for arg_name, arg_val in (('proj_base_name', proj_base_name),
                              ('proj_root_fqp', proj_root_fqp),
                              ('parfile_fqp', parfile_fqp),
                              ('mesh_fqp', mesh_fqp),
                              ('spec_fqp', spec_fqp),
                              ('pyutils_fqp', pyutils_fqp),
                              ('script_fqp', script_fqp)):
        if not isinstance(arg_val, str):
            raise TypeError(f'{arg_name} must be a str type')

    if max_event_rdirs < 1:
        raise ValueError(f'max_event_rdirs must be >= 1, got {max_event_rdirs}')

    ########################################################################
    #
    # setup project structure parameters
    #
    ########################################################################

    nevents = proj_record_h.nevents
    nsrc = proj_record_h.nsrc

    # number of sources written into each run dir
    nbatch = nsrc if batch_srcs else 1

    # run dirs needed in total, then split over (sub)projects
    total_rundirs = nevents if batch_srcs else nevents*nsrc
    l_nrundirs = _split_rundir_counts(total_rundirs, max_event_rdirs)
    nprojdirs = len(l_nrundirs)

    if verbose:
        print(f'nevents: {nevents}')
        print(f'nsrc: {nsrc}')
        print(f'nbatch: {nbatch}')
        print(f'l_nrundirs: {l_nrundirs}')
        print(f'total_rundirs: {sum(l_nrundirs)}')
        print(f'nprojdirs: {nprojdirs}')

    # setup paths to specfem binaries and utils/tools
    spec_bin_fqp   = os.path.join(spec_fqp, 'bin')
    spec_utils_fqp = os.path.join(spec_fqp, 'utils')

    # Read input Par_file stub and override the forward-run parameters
    par_lines = readlines(parfile_fqp)
    use_force_src = proj_record_h.solution_cls is ForceSolutionHeader
    keys_vals_dict = {'SIMULATION_TYPE': 1,
                      'SAVE_FORWARD': False,
                      'USE_FORCE_POINT_SOURCE': use_force_src,
                      'MODEL': 'gll',
                      'SAVE_MESH_FILES': False,
                      'USE_BINARY_FOR_SEISMOGRAMS': True}
    par_lines = change_multiple_parameters_in_lines(par_lines,keys_vals_dict)

    ########################################################################
    #
    # If more than one proj_dir is needed then make a main proj_dir
    # for the sub_proj_dirs.
    #
    ########################################################################

    if sub_proj_name is None:
        sub_proj_name = 'Sub_' + proj_base_name

    # Directory holding the project record; station data paths are made
    # relative to it in BOTH the single- and multi-project layouts.
    proj_main_fqp = os.path.join(proj_root_fqp, proj_base_name)

    # parent directory of the (sub)project dir(s)
    sub_proj_root_fqp = proj_root_fqp
    if 1 < nprojdirs:
        sub_proj_root_fqp = _make_proj_dir(proj_root_fqp, proj_base_name)
        # make or copy MESH-default in the main project dir
        setup_mesh_dir(sub_proj_root_fqp, mesh_fqp, copy_mesh=copy_mesh)
        if verbose: print(f'yes sub projects Main Dir:{sub_proj_root_fqp}')

    if verbose: print(f'sub_proj_root_fqp: {sub_proj_root_fqp}')

    accum_src = 0
    for ipdir, nrundirs in enumerate(l_nrundirs):

        if nprojdirs == 1:
            sub_pdir_name = proj_base_name
        else:
            sub_pdir_name = sub_proj_name + '_' + str(ipdir+1).zfill(4)

        sub_projdir_fqp = _make_proj_dir(sub_proj_root_fqp,
                                         sub_pdir_name,
                                         pyutils_fqp=pyutils_fqp,
                                         script_fqp=script_fqp)

        if nprojdirs == 1:
            #no sub projects so copy or make symlink (user par)
            setup_mesh_dir(sub_projdir_fqp, mesh_fqp, copy_mesh=copy_mesh)
            if verbose: print(f'no sub projects Main Dir:{sub_projdir_fqp}')
        else:
            # sub projects only get a relative symlink to the main project's
            # mesh: <sub_proj>/MESH-default -> ../MESH-default
            rel_mesh_fqp = os.path.join(os.pardir, 'MESH-default')
            if verbose:
                print(f'sub_projdir_fqp:{sub_projdir_fqp}')
                print(f'rel_mesh_fqp:{rel_mesh_fqp}')
            setup_mesh_dir(sub_projdir_fqp, rel_mesh_fqp, copy_mesh=False)

        ####################################################################
        #
        # Make all the run dirs per proj/sub_proj dirs
        #
        ####################################################################
        for irdir in range(nrundirs):

            #FIXME: "ie" assumes event ids are 0..nevents-1 in order;
            #       need list[ipdir][irdir] returning the real event id
            ie   = accum_src//nsrc
            isrc = accum_src%nsrc
            rdir_record_h = proj_record_h[ie,isrc:isrc+nbatch,:,:]

            rdir_record_h.solutions_df['proj_id'] = ipdir
            rdir_record_h.stations_df['proj_id']  = ipdir

            _make_run_dir(irdir,
                          sub_projdir_fqp,
                          spec_bin_fqp,
                          spec_utils_fqp,
                          par_lines,
                          common_dir_struct,
                          rdir_record_h)

            accum_src += nbatch

            ################################################################
            #
            # Point the project record's stations at this run's SYN dir
            #
            ################################################################
            rdir_name = 'run' + str(irdir+1).zfill(4)
            syn_fqp = os.path.join(sub_projdir_fqp, rdir_name, 'SYN')
            rel_syn_fqp = os.path.relpath(syn_fqp,proj_main_fqp)

            # reset the pandas MultiIndex to allow eid and sid filtering;
            # always restore it, even if an update fails
            proj_record_h.reset_midx()
            try:
                rec_df = proj_record_h.stations_df
                ibool = ((rec_df['eid'] == ie) &
                         (rec_df['sid'] >= isrc) &
                         (rec_df['sid'] < isrc+nbatch))
                for ridx, rec in rec_df[ibool].iterrows():
                    new_data_fqdn = os.path.join(rel_syn_fqp,station_auto_data_fname_id(rec))
                    rec_df.loc[ridx,"data_fqdn"] = new_data_fqdn
            finally:
                proj_record_h.set_default_midx()

    ###############################################
    #
    # write record header for the main project
    #
    ###############################################
    proj_record_fqp = _get_header_path(proj_main_fqp,'project_record')
    _write_header(proj_record_fqp,proj_record_h)

    return proj_main_fqp
                

In [None]:
# Driver: build a test forward project from the reciprocal record.
# Paths below are specific to the author's cluster/workstation.
test_proj_name = 'TestNewMKProject'
test_proj_root_fqp =  os.path.join(data_out_dir, 'tmp/TestProjects/NewMKProj')
test_parfile_fqp =  os.path.join(data_out_dir, 'Par_file')
test_mesh_fqp = '/scratch/seismology/tcullison/test_mesh/MESH-default_batch_force_src'
#test_mesh_fqp = '/Users/seismac/Documents/Work/Bench/ForkGnam/pyaspect/notebooks/data/output/tmp/TestProjects/MESH-default_norot_xsft4400_ysft19100_quart_100m_min_vs_sig250m'
test_spec_fqp = '/quanta1/home/tcullison/DevGPU_specfem3d'
test_pyutils_fqp = '/quanta1/home/tcullison/myscripts/python/specfem/pyutils'
test_script_fqp = '/quanta1/home/tcullison/myscripts/specfem'

#print(len(flatten_grouped_headers(l_grp_vsrcs.copy())))
#print(len(flatten_grouped_headers(flatten_grouped_headers(l_grp_vrecs_by_vsrcs.copy()))))
#print('nrec_per_src*nsrc:',21*12)

# rebuild a fresh record (the reciprocal-record cell's copy was modified by later experiments)
l_flat_vsrcs = flatten_grouped_headers(l_grp_vsrcs)
l_flat_vrecs = flatten_grouped_headers(flatten_grouped_headers(l_grp_vrecs_by_vsrcs))

vrecord_h = RecordHeader(name='Reciprocal-Record',solutions_h=l_flat_vsrcs,stations_h=l_flat_vrecs)

#print(vrecord_h)
src_df = vrecord_h.solutions_df
rec_df = vrecord_h.stations_df

# unique source ids present in the solutions vs the stations
s_nsid = list(src_df.index.get_level_values('sid').unique())
r_nsid = list(rec_df.index.get_level_values('sid').unique())

#print(f'Are Same: {s_nsid == r_nsid}')

# NOTE(review): scratch code -- d, k and isin are not used anywhere visible
d = {'one':1, 'two':2, 'three': 3}
k = list(d.keys())

isin = 'three' in k
#print(f'is in: {isin}')

# work on a copy: make_fwd_project_dir modifies the record in place
test_proj_record_h = vrecord_h.copy()

# small max_event_rdirs forces the sub-project code path
make_fwd_project_dir(test_proj_name,
                     test_proj_root_fqp,
                     test_parfile_fqp,
                     test_mesh_fqp,
                     test_spec_fqp,
                     test_pyutils_fqp,
                     test_script_fqp,
                     test_proj_record_h,
                     batch_srcs=False,
                     verbose=True,
                     copy_mesh=False,
                     max_event_rdirs=5)
                     #max_event_rdirs=MAX_SPEC_SRC)
        

print()
print('ls:')
!ls {test_proj_root_fqp}
print('ls:')
!ls {test_proj_root_fqp}/*/*
# disabled cleanup (string literal, not executed); it would rm -rf the test project
'''
print('rm:')
!rm -rf {test_proj_root_fqp}/*
print('ls:')
!ls {test_proj_root_fqp}
''';
# deliberate stop: prevents "run all" from continuing past this cell
assert False

## Make reciprocal project

In [None]:
# Disabled cell: this guard raises immediately, so nothing below runs.
# Delete it to (re)build the reciprocal test project.
assert False
# NOTE(review): _join_path_fname is imported but not used in this cell.
from pyaspect.specfemio.utils import _join_path_fname
test_proj_name = 'ReciprocalTestProject'
test_proj_fqp =  os.path.join(data_out_dir, 'tmp/TestProjects')
test_parfile_fqp =  os.path.join(data_out_dir, 'Par_file')
# NOTE(review): absolute, user-specific cluster paths; adjust per machine.
test_mesh_fqp = '/scratch/seismology/tcullison/test_mesh/MESH-default_batch_force_src'
test_spec_fqp = '/quanta1/home/tcullison/DevGPU_specfem3d'
test_pyutils_fqp = '/quanta1/home/tcullison/myscripts/python/specfem/pyutils'
test_script_fqp = '/quanta1/home/tcullison/myscripts/specfem'

from pyaspect.project import make_project
#make_records(l_src=l_grp_solutions_h,l_rec=src_grouped_stations)
# Build the reciprocal project from the grouped virtual sources and the
# virtual receivers grouped by those sources (both built in earlier cells).
# The listing below expects the result at test_proj_fqp/test_proj_name.
make_project(test_proj_name,
             test_proj_fqp,
             test_parfile_fqp,
             test_mesh_fqp,
             test_spec_fqp,
             test_pyutils_fqp,
             test_script_fqp,
             l_grp_vsrcs,
             l_grp_vrecs_by_vsrcs,
             copy_mesh=False)

# Show what was created, oldest entries first.
!ls -ltrh {os.path.join(test_proj_fqp,test_proj_name)}

## Make forward ("real") project

In [None]:
# Forward counterpart of the reciprocal project: same Par_file/mesh/tool
# paths, but built from l_cmt_src (presumably the CMT moment-tensor sources)
# and the receivers grouped by those sources, both from earlier cells.
# NOTE(review): unlike the reciprocal cell there is no `assert False` guard
# here, so this cell runs whenever it is executed individually.
# NOTE(review): _join_path_fname is imported but not used in this cell.
from pyaspect.specfemio.utils import _join_path_fname
test_proj_name = 'ForwardTestProject'
test_proj_fqp =  os.path.join(data_out_dir, 'tmp/TestProjects')
test_parfile_fqp =  os.path.join(data_out_dir, 'Par_file')
# NOTE(review): absolute, user-specific cluster paths; adjust per machine.
test_mesh_fqp = '/scratch/seismology/tcullison/test_mesh/MESH-default_batch_force_src'
test_spec_fqp = '/quanta1/home/tcullison/DevGPU_specfem3d'
test_pyutils_fqp = '/quanta1/home/tcullison/myscripts/python/specfem/pyutils'
test_script_fqp = '/quanta1/home/tcullison/myscripts/specfem'

from pyaspect.project import make_project
#make_records(l_src=l_grp_solutions_h,l_rec=src_grouped_stations)
# Shares the parent directory test_proj_fqp with the reciprocal project; only
# test_proj_name differs. The listing below expects the result at
# test_proj_fqp/test_proj_name.
make_project(test_proj_name,
             test_proj_fqp,
             test_parfile_fqp,
             test_mesh_fqp,
             test_spec_fqp,
             test_pyutils_fqp,
             test_script_fqp,
             l_cmt_src,
             l_grp_recs_by_srcs,
             copy_mesh=False)

# Show what was created, oldest entries first.
!ls -ltrh {os.path.join(test_proj_fqp,test_proj_name)}