##### ---
# Unit15: Ambient Noise Tomography

This notebook contains the activities for the course **ProSeisSN**. It covers time-series processing of a passive seismic dataset with [ObsPy](https://docs.obspy.org/).

#### Dependencies: ObsPy, NumPy, Matplotlib
#### To reset the Jupyter notebook and run it again, choose:
***Kernel*** -> ***Restart & Clear Output***

In [None]:
"""
====================== Leads to Colab ======================
1)
!git clone https://github.com/jandyr/ProSeisSN
%cd ProSeisSN
2)
import subprocess

def install_obspy():
    # Try conda first, then fall back to pip
    try:
        subprocess.check_call(['conda', 'install', '-y', '-c', 'conda-forge', 'obspy'])
        print("obspy installed successfully using conda.")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Failed to install obspy using conda. Trying pip.")
        try:
            subprocess.check_call(['pip', 'install', 'obspy'])
            print("obspy installed successfully using pip.")
            return True
        except subprocess.CalledProcessError:
            print("Failed to install obspy using both conda and pip.")
            return False
"""
#!pip install pycwt
#import pycwt
!pip install disba mpi_master_slave
!pip install pylops
#!pip install obspy
!pip install import-ipynb
import import_ipynb

In [None]:
#------ Import OS Libraries
import sys
import os

#------ Work with the directory structure to include auxiliary codes
print('\n Local directory ==> ', os.getcwd())
print('  - Contents: ', os.listdir(), '\n')

#--- Add ../CodePy to the module search path (test the same path that is appended)
path = os.path.abspath(os.path.join('..', 'CodePy'))
if path not in sys.path:
    sys.path.append(path)

#%run ../CodePy/ImpMod.ipynb
%run ../CodePy/ImpMod
#------ Alter default matplotlib rcParams
from matplotlib import rcParams
import matplotlib.dates as dates
# Change the defaults of the runtime configuration settings in the global variable matplotlib.rcParams
plt.rcParams['figure.figsize'] = 9, 5
#plt.rcParams['lines.linewidth'] = 0.5
plt.rcParams["figure.subplot.hspace"] = (.9)
plt.rcParams['figure.dpi'] = 100
#------ Magic commands
%matplotlib inline
%matplotlib widget
#%pylab notebook
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

---
## The 84S Dataset
- **The context**
  1) The seismic source is ambient noise generated by the energy transfer from the winds on features on the snow surface and the Cryosphere-I.
  2) Two 66m gathers with 12 geophones each. Each trace has a time window 𝛿𝑡 = 8s and a sampling rate 𝛿𝑓 = 2000Hz,
resulting in 16000 data points/trace. We focus here on the first gather.
  3) Geophone location along profile due S from Crios-I: D1 = 200m, D2 = 40m, G1 = G2 = 66m
  
 │Crios│ D1 ├─────── G1 ──────────| D2 ├─────── G2 ──────────| 
 
            o o o o o o o o o o o o<-->o o o o o o o o o o o o 

  4) Files 4600 and 4608 were collected in a day with a strong breeze, wind force 6 on the Beaufort Scale, with gusts reaching 12𝑚/𝑠. The data is notched at 60.62Hz as well as its even harmonics up to 606.2Hz.

<div style="text-align: center;">
<img src="./84S.png" width="800">
</div>
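
The figures quoted in the context list can be cross-checked with a few lines of arithmetic. This is only a sketch: the variable names are illustrative, and the phone spacing assumes the 12 geophones are evenly spaced with one at each end of the 66 m gather.

```python
# Acquisition figures quoted in the context list above
window_s = 8.0          # record length per trace [s]
fs_hz = 2000.0          # sampling rate [Hz]
n_phones = 12           # geophones per gather
gather_len_m = 66.0     # gather length [m]

npts = int(window_s * fs_hz)        # samples per trace
nyquist_hz = fs_hz / 2.0            # highest resolvable frequency
# ASSUMPTION: phones are evenly spaced, with one at each end of the gather
spacing_m = gather_len_m / (n_phones - 1)

# Notched frequencies: the fundamental plus its even harmonics up to 606.2 Hz
f0_hz = 60.62
notches_hz = [f0_hz] + [round(n * f0_hz, 2) for n in range(2, 11, 2)]

print(npts, nyquist_hz, spacing_m)
print(notches_hz)
```

All notches lie well below the 1000 Hz Nyquist frequency, and most of them sit above the 10-40 Hz band used by default later in this notebook.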

In [None]:
"""
====================== READ PHONES LOCATIONS ======================
"""
#------ Read the phones cartesian locations
#--- Reads the CSV file with (x, y)m locations
loc84S = u.RGloc('../Data/'+'84S_loc.dat')
#------ Read the phones geographic locations
#--- Reads the CSV file with (lat,lon). All (0., 0.)!
gloc84S = u.RGloc('../Data/'+'84S_gloc.dat')
#
#------ Plot gather in cartesian
p.pgather(loc84S[:,1], loc84S[:,2], loc84S[:,0], coord='cartesian')

In [None]:
"""
====================== READ THE SEISMIC DATA LOCALLY ======================
File hints: 4600, 4608

"""
#------ Read the seismic data
ent = input('   Enter a file number [4600 or 4608], rtn=4600:\n') or '4600'
ent = ent.rstrip().split(' ')
print(f">> Read with data file {ent}")
ent = '../Data/84S/'+ent[0]+'.dat'
#------- Read the data file as a SEG2 object.
st     = read(ent)
#
#------- Print stream information
dummy = float(st[-1].stats.seg2.RECEIVER_LOCATION)
print(f">> Gather acquired on {st[0].stats.starttime}, has {int(st[0].stats.npts)} data points.")
"""
================= Create a new stream from the SEG2 stream ======================
                         Retain a gather copy
"""
#------ Create a new stream from the SEG2 stream.
#       1) Adds coordinates to gather. Stores a copy in gather0
#       2) Gather barycenter = bcenter.
gather, bcenter = u.creastrm(st, gloc84S, hght = 2000., surv = '84S')
gather0 = gather.copy()
#
#--- Phone choice
phone = None

---
## Data processing
- Filter data
- Display the seismogram
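
As a rough illustration of what the preprocessing step does, the sketch below removes the mean and a linear trend and then applies a brick-wall band-pass in the frequency domain. It uses NumPy only and a synthetic trace. It is a simplified stand-in, not the `u.otrstr` routine used in this notebook, and the function names are hypothetical.

```python
import numpy as np

def demean_detrend(x):
    """Remove the mean and a least-squares linear trend from a 1-D trace."""
    n = np.arange(x.size)
    slope, intercept = np.polyfit(n, x, 1)
    return x - (slope * n + intercept)

def fft_bandpass(x, dt, fmin, fmax):
    """Zero every Fourier coefficient outside [fmin, fmax] (brick-wall filter)."""
    spec = np.fft.rfft(x)
    freq = np.fft.rfftfreq(x.size, dt)
    spec[(freq < fmin) | (freq > fmax)] = 0.0
    return np.fft.irfft(spec, n=x.size)

# Synthetic trace: offset + drift + a 25 Hz signal + 200 Hz noise
fs = 2000.0
t = np.arange(16000) / fs
raw = 0.5 + 0.1 * t + np.sin(2 * np.pi * 25.0 * t) + np.sin(2 * np.pi * 200.0 * t)

filtered = fft_bandpass(demean_detrend(raw), 1.0 / fs, 10.0, 40.0)
print(np.max(np.abs(filtered - np.sin(2 * np.pi * 25.0 * t))))
```

A brick-wall filter is chosen here only because it is short to write. Practical band-pass filters taper their edges to limit ringing.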

In [None]:
#
"""
================= Filter data and look at the frequency contents ======================
                    Create a new stream from the SEG2 stream
"""
#
#------- Remove mean and trend + filter the stream
#--- Filter parameters: change them as you wish.
MTparam = [ 1,   1,    'bp',  10.,   40.,   0,    0]
# └─────> [dtr, line, ftype, Fmin, Fmax, taper, gain]
#                                          └─> data will be windowed at trace normalization and spectral whitening
ent = str(MTparam[3]) + ' ' + str(MTparam[4])
ent = input(f'\n Enter filter min and max frequencies (dflt = {MTparam[3]}, {MTparam[4]})') or ent
ent = ent.rstrip().split(' ')
MTparam[3], MTparam[4] = [float(dummy) for dummy in ent]
#
gather = u.otrstr(gather, MTparam)
#
#------- Check frequency contents to accept preprocessing
#--- Pick up a random phone/trace
phone = phone if phone is not None else np.random.randint(1, len(gather)+1)
print(f' Random phone {phone} ')
#--- Go to trace instead of phone: trace = phone -1
phone = phone - 1
#--- Relative time: numpy array
time = gather[phone].times(type="relative")
#--- Plot Trace+Spectrogram
p.Pspect(time, gather[phone])
#
#------- Once filtering is accepted create a new backup for gather
ent = input(f' Run this cell again (rtn= No, else plot Spectrogram)?: ') or False
if not ent:
    gather0 = gather.copy()
    print(f' A new stream backup was created.')
else:
    gather = gather0.copy()

### Down-sample the data
- Down-sample the data to a number of points compatible with the upper limit of the band-pass filter
- Reduce computational costs
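
The admissible decimation factors can be listed with a short helper. The sketch below keeps the integer factors that divide the sampling rate exactly and leave the new Nyquist frequency above the band-pass upper limit. The helper name and the cap of 15 are assumptions made for this illustration; it is not the `u.divisors` routine used in the next cell.

```python
def decimation_options(fs, fmax, max_factor=15):
    """Return (factor, new_rate) pairs such that factor divides fs exactly
    and the new Nyquist frequency fs / (2 * factor) stays above fmax."""
    fs = int(fs)
    return [(q, fs // q) for q in range(2, max_factor + 1)
            if fs % q == 0 and fs / (2 * q) > fmax]

# 2000 Hz data band-passed up to 40 Hz
options = decimation_options(2000, 40.0)
for q, rate in options:
    print(f"factor {q:2d} -> {rate:4d} Hz (Nyquist {rate / 2:.1f} Hz)")
```

With these inputs the largest listed factor is 10, giving a 200 Hz sampling rate and a Nyquist frequency of 100 Hz, comfortably above the 40 Hz filter limit.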

In [None]:
import math
#
gather = gather0.copy()
"""
================= Downsample stream by an integer factor ======================
"""
print(f'\n>> Phone {phone+1} has {gather[phone].stats.npts} data points with a sampling rate of {gather[phone].stats.sampling_rate}Hz,')
#--- Divisors of sampling rate > upper frequency bound of the band-pass filter above
dummy =  u.divisors(int(gather[phone].stats.sampling_rate), MTparam[4])
#--- Find the decimation/resampling factors
factor = [math.trunc(gather[phone].stats.sampling_rate / num) for num in dummy]
factor = [factor for factor in factor if factor < 16]
#
dummy =  [int(gather[phone].stats.sampling_rate/factor) for factor in factor]
print(f'    this sampling rate can be lowered to the following integer values {dummy}Hz')
print(f'    representing decimation factors of............................... {factor},')
dummy = [math.trunc(gather[phone].stats.npts / factor) for factor in factor]
print(f'    with a total of points of ....................................... {dummy},')
#
print(f'\n>> If you need a higher decimation factor than {factor[0]}:')
print(f'    1) Set a value from the above, and accept the running;')
print(f'    2) Run the cell again with a new factor to your goal.')
#
print(f'\n>> Note that data is already band pass filtered in the range [{MTparam[3]}, {MTparam[4]}]Hz.')
print(f'    Be sure decimation is above the upper limit')
#
ent = input(f'\n<< Enter a new sampling rate from the above list:')
ent = float( ent.rstrip().split(' ')[0] )
#--- Check on Fmax
#MTparam[4] = MTparam[4] if MTparam[4] <= ent else ent
#
#------- Check on decimation factor: the new rate must divide the current one exactly.
rate = gather[phone].stats.sampling_rate
if ent <= 0 or rate % ent != 0: raise ValueError("Decimation factor is not an integer.")
factor = int(rate // ent)           # // is a floor division
"""
Decimate or resample
1) Use resample instead of decimate if factor >16, as automatic filter design becomes unstable.
2) After decimation every n-th sample remains in the trace.
3) Prior to decimation a lowpass filter is applied to prevent aliasing artifacts.
4) If the decimation factor is too large the FFT will fold itself in two, with a middle spike within its frequency range.
"""
print(factor)
if factor <= 15:
    gather.decimate(factor=factor)                        #Uses: no_filter=False, strict_length=False
    print(f'\n>> Decimation factor is {factor} -> use decimate.')
else:
    gather.resample(ent, window='hann', no_filter=False)  #Uses: strict_length=False
    print(f'\n>> Decimation factor is {factor} -> use resample.')
#
print(f'>> Phone {phone+1} has now {gather[phone].stats.npts} data points with a new sampling rate of {gather[phone].stats.sampling_rate}Hz.')
print(f'    Resampled data is low-pass filtered to prevent aliasing, with a decimation factor of  {factor}.')
#
#------- Check frequency contents of a trace to accept downsampling
#--- Relative time: numpy array
time = gather[phone].times(type="relative")
#--- Plot Trace+Spectrogram
p.Pspect(time, gather[phone])
#
#------- Once downsampling is accepted create a new backup for gather
ent = input(f' Run this cell again (rtn= No)?: ') or False
if not ent:
    gather0 = gather.copy()
    print(f' A new stream backup was created.')
else:
    gather = gather0.copy()

In [None]:
"""
====================== Plot Seismogram ======================
"""
gather.plot(type='section',
            scale=1.3, alpha=.7,
            orientation='vertical')

---
### Normalization and Spectral whitening
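
Before the processing cell, here is a minimal sketch of the two operations named in the heading: one-bit time-domain normalization (keep only the sign of each sample) and spectral whitening (divide the spectrum by a smoothed version of its own amplitude). It is a simplified, NumPy-only illustration with hypothetical function names, and it is not the `u.noise_processing` routine called in the next cell.

```python
import numpy as np

def one_bit(x):
    """Time-domain normalization: keep only the sign of each sample."""
    return np.sign(x)

def whiten(x, dt, fmin, fmax, smooth_n=20):
    """Divide the spectrum by its running-mean amplitude inside [fmin, fmax];
    coefficients outside the band are set to zero."""
    spec = np.fft.rfft(x)
    freq = np.fft.rfftfreq(x.size, dt)
    kernel = np.ones(smooth_n) / smooth_n
    smooth_amp = np.convolve(np.abs(spec), kernel, mode="same")
    band = (freq >= fmin) & (freq <= fmax)
    white = np.zeros_like(spec)
    white[band] = spec[band] / np.maximum(smooth_amp[band], 1e-12)
    return white

# Synthetic noise whose amplitude grows with time (e.g. a rising wind)
rng = np.random.default_rng(0)
dt = 0.005
x = rng.standard_normal(1600) * np.linspace(1.0, 5.0, 1600)

xb = one_bit(x)                    # amplitudes are now -1, 0 or +1
W = whiten(xb, dt, 10.0, 40.0)     # complex one-sided spectrum, flat in band
print(W.shape, np.abs(W).max())
```

One-bit normalization suppresses strong transients such as gusts, while whitening balances the contribution of each frequency in the band before cross-correlation.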

In [None]:
"""
====================== Trace normalization and spectral whitening ======================
1) Cross-correlations are estimated in receiver pairs, where one is considered as a source:
-> BEFORE CROSS-CORRELATION IT IS NECESSARY TO TAKE THE CONJUGATE OF THE SOURCE TRACE
                            trS = np.conjugate(trS)
    SO THE SOURCE SPECTRUM IS CONSISTENT WITH THE RECEIVER.
2) The O/P of 'noise_processing', tr0W, will be next_fast_len(len(FtrTplt)) greater than len(FtrTplt)
3) We are not discarding the first element of the FFT output, or its DC energy as we have already detrended and demeaned the data.
"""
#
#------ Initialization
st = Stream()
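#--- Complex dtype: the array stores spectra, whose imaginary part would otherwise be discarded
nwGather = np.zeros((len(gather[0]), len(gather)), dtype=complex)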
#
#------ Loop through traces:
#       1) Append each processed spectrum as a new column into a new 2-D array
#       2) Create a new stream with the inverse FFT.
#
for i, t in enumerate(gather):
#--- Build the processed freq. domain data: nwGather[data, trace]
    tr0W = u.noise_processing(t.data, t.stats.delta, MTparam[3], MTparam[4], 
                              smooth_N = 100, time_norm = 'one_bit', freq_norm = 'rma')
    nwGather[:, i] = tr0W
#--- Inverse fourier transform of the whitened spectra (not tapered)
    trw = np.fft.ifft(tr0W)
#--- Discard small complex values left-overs after the inverse fourier transform
    trw = np.real(trw)
#--- Build the trace. Distance is along cable, NOT from Crios-I
    trw = Trace(data=trw)
    trw.stats.sampling_rate = t.stats.sampling_rate
    trw.stats.npts          = len(trw)
    trw.stats.station       = t.stats.station
    trw.stats.starttime     = t.stats.starttime
    trw.stats.network       = t.stats.network   
    trw.stats.channel       = "HHZ"                     # Assign channel code
    trw.stats.location      = "0"                       # Assign location code
    trw.stats.distance      = float(t.stats.distance) - float(gather[0].stats.distance)
#
    st += trw
#
#------ Control
print(f"\n>> Created an array nwGather with processed traces, with shape {np.shape(nwGather)}. Each column is a processed trace.")
#
print(f"\n>> Created a new processed gather with {len(st)} traces, with {int(st[0].stats.npts)} data points.")
print(f"    From {st[0].stats.distance} to {st[-1].stats.distance}m. The sampling rate is {st[0].stats.sampling_rate}Hz.")
print(f"    The new time window is {st[0].stats.endtime - st[-1].stats.starttime}s.")
print(f"\n>> Plot of phone {phone+1}:")
#--- Whitened trace has more points than the original one! Adjust time here.
dummy = st[phone].times(type="relative")
dummy = np.linspace(dummy[0], dummy[-1], num = len(st[phone]))
#--- Freq. axis (HZ). f=0 is not discarded here
fNy = st[phone].stats.sampling_rate / 2.         # Nyquist
freq = np.linspace(0, fNy, np.shape(nwGather)[0])
#--- Plot
p.pltTrSp( st[phone].times(type="relative"), st[phone], freq, abs(nwGather[:,phone]),
              x1label='s', y1label='Ampl.', y1log=False, clr1 = 'k', 
              x2label='Hz', y2label='Spec.', y2log=True, clr2 = 'r' )

## Ambient Noise Cross-correlation
- Two seismometers on the surface record the ground motions $u_1(t)$ and $u_2(t)$ as functions of time. Over long periods of time, the cross-correlation of these ground motions is
$$C_{1,2}\left(\tau\right)=\int u_{1}\left(t\right)\,u_{2}\left(t+\tau\right)dt$$

- Phone 1 will be considered as a source for all other phones as Crios-I is due N from it.
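
The integral above can be evaluated in the frequency domain as the product of the conjugate spectrum of $u_1$ with the spectrum of $u_2$. The sketch below does this with NumPy for two synthetic traces. The function name and the test signal are assumptions for illustration only, and this is not the `u.correlate` routine used in the next cell.

```python
import numpy as np

def xcorr_fft(u1, u2, dt):
    """C12(tau) = sum_t u1(t) * u2(t + tau), computed through the FFT.
    Returns the lag axis in seconds and the correlation, lag 0 at the centre."""
    n = u1.size
    nfft = 2 * n - 1                      # zero-pad to avoid circular wrap-around
    U1 = np.fft.rfft(u1, nfft)
    U2 = np.fft.rfft(u2, nfft)
    c = np.fft.irfft(np.conj(U1) * U2, nfft)
    c = np.roll(c, n - 1)                 # move lag 0 to index n - 1
    lags = np.arange(-(n - 1), n) * dt
    return lags, c

# u2 is u1 delayed by 25 samples, as if the wave travelled from phone 1 to phone 2
rng = np.random.default_rng(1)
dt = 0.005
u1 = rng.standard_normal(500)
u2 = np.concatenate([np.zeros(25), u1[:-25]])

lags, c = xcorr_fft(u1, u2, dt)
print(lags[np.argmax(c)])                 # travel time from phone 1 to phone 2
```

With this sign convention a wave travelling from the first to the second receiver produces a peak at positive lag, which is the causal part of the cross-correlation.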

In [None]:
"""
====================== Cross correlation on phone pairs ======================
"""
print(f'\n>> Phone spatial sampling is {loc84S[1, 1] - loc84S[0, 1]}m,')
print(f'    while the gather is {loc84S[-1, 1] - loc84S[0, 1]}m long.')
#
#------ Change distances making phone 1 to be at 0m, instead of 200m; the source for all other phones.
#       Remove the useless 3rd column
loc84S[:,1] = loc84S[:,1] - loc84S[0,1]
loc84S = loc84S[:, :2]
#
#------ Phone 1 is the source for all other phones in a shot gather. Therefore it is necessary to
#        take the conjugate, to multiply it by the corresponding of the receivers'.
nwGather[:, 0] = np.conjugate(nwGather[:, 0])
#
#------ Datetime object list to be used in correlation
dataS_t = [st[0].stats.starttime]
#
#------ Enter max lag
ent = st[0].stats.endtime - st[0].stats.starttime
ent = input(f'\n<< Enter max lag time <={round(ent,1)} (rtn = {int(ent)}s):') or str(int(ent))
ent = ent.rstrip().split(' ')
maxlag = float(ent[0])
#
#------ Loop through the phone pairs from phone 1 to 12. Keep phone 1 as the source -> 11 xcorr.
#--- Initialize. The O/P from correlate has 1 fewer point than the I/P traces.
xcorr  = np.zeros((nwGather.shape[0] - 1, nwGather.shape[1] - 1))
tstamp = xcorr.copy(); nxcorr = xcorr.copy()
#--- The source
tr0W = nwGather[:,0]
#--- Iterate from 0 to the number of columns-1. NB: t_corr is a timestamp object
for i in range(1, nwGather.shape[1]):                         #nwGather.shape[1] - 1
    tr1W = nwGather[:, i]
#    corr| Tstamp| #segments          source|receiver|time sampling     |Datetime list
    tdata, t_corr, n_corr = u.correlate(tr0W, tr1W,   st[0].stats.delta, dataS_t,
                                        MTparam[3], MTparam[4], maxlag, method = 'xcorr')
    xcorr[:,  i - 1] = tdata
    tstamp[:, i - 1] = t_corr
    nxcorr[:, i - 1] = n_corr
#
#------- A sequential time for the x-correlation
tvec = np.linspace(-maxlag, maxlag, xcorr.shape[0])    #maxlag+dt
#
#------- Plot correlation between phone 1 and another
print(f'\n>> Plot correlation between phones 1 and {phone+1}, with a distance of {loc84S[phone, 1]}m')
#
fig, ax = plt.subplots(figsize=(12,5))
ax.plot(tvec,xcorr[:, phone-1])
ax.set_xlabel("Time [s]")
ax.set_ylabel("Amplitude")
#ax.set_title('Cross-Correlation Function Between %s and %s'%(ssta,rsta))
plt.show()

In [None]:
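# Helper: restore the saved cross-correlation (y) or save the current one (any other answer).
# NB: on the first run answer 'n', since nothing has been saved yet.
ent = input('\n<< Restore saved x-corr (rtn=y)?') or 'y'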
if ent == 'y': 
    xcorr  = Sxcorr.copy()
    tvec   = Stvec.copy()
    nxcorr = Snxcorr.copy()
else:
    Sxcorr  = xcorr.copy()
    Stvec   = tvec.copy()
    Snxcorr = nxcorr.copy()


In [None]:
"""
====================== Construct a cross correlation signal ======================
"""
#
#------- Choose cross correlation part
ent = input(f'\n<< Acausal (a) or causal (c) of {maxlag}s x-corr(rtn=c))?') or 'c'
ent = ent.rstrip().split(' ')[0]
if ent not in ['a', 'c']: raise ValueError(f"'{ent}' not allowed")
if ent == 'a': raise ValueError(f"'{ent}' not implemented yet")
#
#------- Go through all traces, slicing appropriately.
i = np.where(tvec >= 0)[0] if ent == 'c' else np.where(tvec < 0)[0]
xcorr  = xcorr[i,:]
tvec   = tvec[i]
nxcorr = nxcorr[i,:]
#
print(f'>> Work with the {ent[0]}-part with len={len(xcorr)}')
#
#------ Produce a new gather. Initialization
xst = Stream()
#--- Loop through traces
for i in range(0, xcorr.shape[1]):
    t                     = Trace(data=xcorr[:, i])
    t.stats.delta         = np.absolute(tvec[1] - tvec[0])    #sample spacing (s); caters for both a and c parts
    t.stats.npts          = len(tvec)
    t.stats.station       = str(i)
    t.stats.starttime     = st[0].stats.starttime
    t.stats.network       = st[0].stats.network 
    t.stats.channel       = "HHZ"                     # Assign channel code
    t.stats.location      = "0"                       # Assign location code
    t.stats.distance      = st[i+1].stats.distance
#--- Taper 2.5% of the trace at each end (5% in total)
    t.taper(type = 'hann', max_percentage = 0.025)
#--- Add trace to stream
    xst += t
#
#------ Plot Seismogram
xst.plot(type='section',
        scale=1.3, alpha=.7,
        orientation='vertical')

In [None]:
"""
====================== Work with cross correlation signal ======================
"""
#
#------ Dispersion for frequency axis: 401 values
#                   |   fmin  |    fmax   |nfreq|
fdisp = np.linspace(MTparam[3], MTparam[4], 401)
#
#------ Augment traces
# Nfft = int(next_fast_len(int(xcorr.shape[0])))
Nfft = int(xcorr.shape[0])
#
#------ 
#                      |---------- Nyquist ----------|
freq = np.linspace(0., xst[0].stats.sampling_rate / 2.)
Fxcoor = np.zeros((Nfft//2+1, xcorr.shape[1]),dtype=xcorr.dtype)
vf     = np.zeros((Nfft//2+1, xcorr.shape[1]),dtype=xcorr.dtype)
#
#------ Velocity dispersion -> lvdispsp  100.0 1000.0 401
#                    |vmin| vmax| nvel|
vdispsp = np.linspace(100., 1000., 401)

for i in range(0, xcorr.shape[1]):
    Fxcoor[:, i] = np.fft.rfft(xcorr[:, i])
#    Fxcoor[:, i] = scipy.fftpack.fft(xcorr[:, i], Nfft)
#--- Interpolate velocities
    vf[:, i]     = np.interp(Fxcoor[:, i], fdisp, vdispsp)
#
#------ nt, nr = 11 200
#      xcorr.T size `nx x nt`
nt, nr = xcorr.shape
dt = tvec[1] - tvec[0]
#--- len(t) = 200 -> tvec[;,i]; f= 0.0 14.072142857142751 100
f = sp.fftpack.fftfreq(nt, dt)[:nt//2+1]

#df = f[1] - f[0]
#fmax_idx = int(MTparam[4]//df)      #fmax_idx 281
fmax_idx = len(f)


#
#------ Phase velocity range -> C  100.0 997.75 400

c = np.arange(vdispsp[0], vdispsp[-1], vdispsp[1]-vdispsp[0])

#
#------ x 0.0 60.0 11
dx = xst[0].stats.distance
x = np.linspace(0.0, (nr-1)*dx, nr)
#
#------ Fx spectrum -> U (11, 100)
U = sp.fftpack.fft(xcorr.T)[:, :nt//2]

#
#------ Dispersion panel
disp = np.zeros((len(c), fmax_idx))

for fi in range(fmax_idx-1):    #range(fmax_idx)
    for ci in range(len(c)):
        k = 2.0*np.pi*f[fi]/(c[ci])
        disp[ci, fi] = np.abs(
                np.dot(dx * np.exp(1.0j*k*x), U[:, fi]/np.abs(U[:, fi])))


fig, axs = plt.subplots(1, 2, figsize=(10, 4), gridspec_kw={'width_ratios': [1, 2]})
axs[0].imshow(xcorr, cmap='gray', vmin=-.1, vmax=.1, extent=[x[0], x[-1], tvec[-1], tvec[0]])
axs[0].axis('tight')
axs[0].set_xlabel('Offset [m]')
axs[0].set_ylabel('Time [s]')
axs[0].set_title('Shot gather')
axs[1].imshow(disp, origin='lower', extent=(f[0], f[-1], vdispsp[0], vdispsp[-1]))
axs[1].plot(f, vf, 'w', lw=4)  #vf*1e3
axs[1].axis('tight')
axs[1].set_xlim(0, 45)
axs[1].set_ylabel('Velocity [m/s]')
axs[1].set_xlabel('Frequency [Hz]')
axs[1].set_title('Dispersion panel');
    


---


In [None]:
import pylops
import warnings
from pylops.signalprocessing import FFT2D
from functools import partial
from scipy.optimize import minimize, Bounds
from disba import PhaseDispersion
import scipy as sp

"""
inversion.py
"""
def fun(x, nlayers, t, vdispobs, dc=0.005):
    r"""Surface wave inversion misfit function

    Create a dispersion curve from the input parameters 
    (thicknesses and shear wave velocities) and compute
    misfit

    Parameters
    ----------
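    x : :obj:`numpy.ndarray`
        Set of thicknesses and shear wave velocities
    nlayers : :obj:`int`
        Number of layers: ``x[:nlayers]`` are thicknesses, ``x[nlayers:]`` are velocities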
    t : :obj:`numpy.ndarray`
        Period
    vdispobs : :obj:`numpy.ndarray`
        Observed dispersion curve
    dc : :obj:`float`
        Phase velocity increment for root finding.
 
    Returns
    -------
    loss : :obj:`float`
        Loss function

    """
    # Create model
    thick = x[:nlayers]
    vs = x[nlayers:]
    vp = vs * 4
    rho = 1. * np.ones(nlayers)
    model = np.vstack([thick, vp, vs, rho]).T
    
    # Compute the Rayleigh-wave modal dispersion curves
    pd = PhaseDispersion(*model.T, dc=dc)
    cpr = pd(t, mode=0, wave="rayleigh") 
    vdisp = cpr[1]

    return np.linalg.norm(vdisp-vdispobs)

"""
dispersionspectra.py
"""
def parkdispersion(data, dx, dt, cmin, cmax, dc, fmax):
    """Dispersion panel
    
    Calculate dispersion curves using the method of
    Park et al. 1998
    
    Parameters
    ----------
    data : :obj:`numpy.ndarray`
        Data of size `nx x nt`
    dx : :obj:`float`
        Spatial sampling
    dt : :obj:`float`
        Time sampling
    cmin : :obj:`float`
        Minimum velocity
    cmax : :obj:`float`
        Maximum velocity
    dc : :obj:`float`
        Velocity sampling
    fmax : :obj:`float`
        Maximum frequency
        
    Returns
    -------
    f : :obj:`numpy.ndarray`
        Frequency axis
    c : :obj:`numpy.ndarray`
        Velocity axis
    disp : :obj:`numpy.ndarray`
        Dispersion panel of size `nc x nf`
    """
    nr, nt = data.shape
    
    # Axes
    t = np.linspace(0.0, nt*dt, nt)

    f = sp.fftpack.fftfreq(nt, dt)[:nt//2]
    df = f[1] - f[0]
    fmax_idx = int(fmax//df)

    c = np.arange(cmin, cmax, dc)  # set phase velocity range
    x = np.linspace(0.0, (nr-1)*dx, nr)

    # Fx spectrum
    U = sp.fftpack.fft(data)[:, :nt//2]
    
    # Dispersion panel
    disp = np.zeros((len(c), fmax_idx))
    for fi in range(fmax_idx):
        for ci in range(len(c)):
            k = 2.0*np.pi*f[fi]/(c[ci])
            disp[ci, fi] = np.abs(
                np.dot(dx * np.exp(1.0j*k*x), U[:, fi]/np.abs(U[:, fi])))

    return f[:fmax_idx], c, disp
"""
surfacewaves.py
"""
def _tcrop(t):
    """Crop time axis with even number of samples"""
    if len(t) % 2 == 0:
        t = t[:-1]
        warnings.warn("one sample removed from time axis...")
    return t


def ormsby(t, f=(5.0, 10.0, 45.0, 50.0), taper=None):
    r"""Ormsby wavelet

    Create a Ormsby wavelet given time axis ``t`` and frequency range
    defined by four frequencies which parametrize a trapezoidal shape in
    the frequency spectrum.

    Parameters
    ----------
    t : :obj:`numpy.ndarray`
        Time axis (positive part including zero sample)
    f : :obj:`tuple`, optional
        Frequency range
    taper : :obj:`func`, optional
        Taper to apply to wavelet (must be a function that
        takes the size of the window as input)

    Returns
    -------
    w : :obj:`numpy.ndarray`
        Wavelet
    t : :obj:`numpy.ndarray`
        Symmetric time axis
    wcenter : :obj:`int`
        Index of center of wavelet

    """
    def numerator(f, t):
        """The numerator of the Ormsby wavelet"""
        return (np.sinc(f * t) ** 2) * ((np.pi * f) ** 2)

    t = _tcrop(t)
    t = np.concatenate((np.flipud(-t[1:]), t), axis=0)
    f1, f2, f3, f4 = f

    pf43 = (np.pi * f4) - (np.pi * f3)
    pf21 = (np.pi * f2) - (np.pi * f1)
    w = (
        (numerator(f4, t) / pf43)
        - (numerator(f3, t) / pf43)
        - (numerator(f2, t) / pf21)
        + (numerator(f1, t) / pf21)
    )
    w = w / np.amax(w)
    wcenter = np.argmax(np.abs(w))

    # apply taper
    if taper is not None:
        w *= taper(len(t))

    return w, t, wcenter

def surfacewavedata(nt, dt, nx, dx, nfft, fdisp, vdisp, wav):
    r"""Surface wave data
    Synthesise surface-wave-only seismic data from a dispersion relation
    Parameters
    ----------
    nt : :obj:`int`         number of time samples
    dt : :obj:`float`         time sampling 
    nx : :obj:`int`         number of spatial samples
    dx : :obj:`float`         spatial sampling
    nfft : :obj:`int`        number of fft samples
    fdisp : :obj:`numpy.ndarray`         frequency axis of dispersion relation
    vdisp : :obj:`numpy.ndarray`         velocity axis of dispersion relation in km/s
    wav : :obj:`numpy.ndarray`         source wavelet
    Returns
    -------
    dshift : :obj:`numpy.ndarray`         synthetic data of size `nx x nt`
    f : :obj:`numpy.ndarray`         frequency axis of the FFT
    vf : :obj:`numpy.ndarray`         phase velocity interpolated onto `f`, in km/s
    """
    # Axes and gridded phase velocity
    t, x = np.arange(nt)*dt, np.arange(nx)*dx
    f = np.fft.rfftfreq(nfft, dt)
    vf = np.interp(f, fdisp, vdisp)

    # Create data
    data = np.outer(wav, np.ones(nx)).T
    D = np.fft.rfft(data, n=nfft, axis=1)

    # Define and apply freq-dependant shifts
    shifts = np.outer(x, 1e-3/vf)
    shifts = np.exp(-1j * 2 * np.pi * f[np.newaxis, :] * shifts)

    Dshift = D * shifts
    dshift = np.fft.irfft(Dshift, n=nfft, axis=1)[:, :nt]

    return dshift, f, vf


## Dispersion of Surface waves
Velocity dispersion panels can be seen as an FK transform stretched over the K axis. There are different approaches to computing dispersion panels; here we use the Park method.
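
The core of the phase-shift idea can be sketched for a single frequency: each receiver spectrum is normalized to unit amplitude, rotated by the phase a wave of trial velocity $c$ would accumulate over the offset, and stacked. The stack is largest when the trial velocity matches the true phase velocity. The example below is a NumPy-only sketch on a synthetic, non-dispersive plane wave. The geometry (12 receivers, 6 m apart) and all names are assumptions for illustration, and the constant `dx` scaling used in `parkdispersion` is omitted.

```python
import numpy as np

def phase_shift_stack(U, x, f, c):
    """Stacked amplitude |sum_j exp(i*2*pi*f*x_j/c) * U_j/|U_j|| for one
    frequency f and one trial phase velocity c."""
    k = 2.0 * np.pi * f / c
    return np.abs(np.sum(np.exp(1j * k * x) * U / np.abs(U)))

# Receiver spectra at f0 for a plane wave travelling at c0 along the line
x = np.arange(12) * 6.0                     # offsets [m]
f0, c0 = 20.0, 500.0                        # frequency [Hz], true velocity [m/s]
U = np.exp(-1j * 2.0 * np.pi * f0 * x / c0)

# Scan trial velocities and keep the one with the largest stack
c_trial = np.arange(200.0, 1000.0, 5.0)
stack = np.array([phase_shift_stack(U, x, f0, c) for c in c_trial])
c_best = c_trial[np.argmax(stack)]
print(c_best, stack.max())
```

Repeating the scan for every frequency fills one column of the dispersion panel at a time, which is what the double loop in `parkdispersion` does.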

1) Consider a parametrized layered medium, where each layer is defined by the following 4 parameters: $\left(\Delta z,v_{p},v_{s},\rho\right)$.
2) In a layered medium, the elastic wave equation can be turned into an eigenvalue problem whose solution is the surface wave dispersion curve. This is referred to in the literature as Dunkin's matrix algorithm.
3) Parameters

| Material | Thickness (km) | $V_P$ (km/s) | $V_S$ (km/s) | Density (g/cm3) |
| :-: | -: | -: | -: | -: |
| Firn | 0.1 | 2.5-3 | 1.2-2 | 0.7 |
| Ice | 2 | 3.8 | 2 | 0.9 |
| Granite | $\infty$ | 4.5 | 3.5 | 2.8 |


In [None]:
"""
====================== A dispersion curve for the above model ======================
"""
# 
#------ Define thickness(km), Vp(km/s), Vs(km/s), density(g/cm3)
thick = np.array([0.1, 2, 10])
vs = np.array([1.5, 2, 3.5])
true_model = np.vstack([thick, vs*4, vs, np.ones(3)]).T

# Frequency axis
fdisp = np.linspace(2, 50, 101)

# Periods (must be sorted starting with low periods)
period = np.flipud(1/fdisp)

# Rayleigh-wave fundamental mode dispersion curve 
pd = PhaseDispersion(*true_model.T)
cpr = pd(period, mode=0, wave="rayleigh") 
vdisp = np.flipud(cpr[1]) # flip it back to show it as function of f instead of period

plt.figure(figsize=(8, 4))
plt.plot(fdisp, vdisp, 'k', lw=4)
plt.xlabel('f [Hz]')
plt.ylabel('v [km/s]')
plt.title('Rayleigh-wave dispersion curve');


### Synthetic seismogram of surface-wave only shot gather
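
The synthetic gather is built by delaying a source wavelet at each offset through a phase ramp in the frequency domain, as `surfacewavedata` does with a frequency-dependent velocity. The sketch below shows the mechanism for the simplest case of a constant velocity, where the delay is just offset divided by velocity. The function name, offset and velocity are hypothetical values chosen so that the delay is a whole number of samples.

```python
import numpy as np

def delay_trace(w, dt, tau):
    """Delay trace w by tau seconds with a linear phase ramp exp(-i*2*pi*f*tau)."""
    f = np.fft.rfftfreq(w.size, dt)
    W = np.fft.rfft(w)
    return np.fft.irfft(W * np.exp(-2j * np.pi * f * tau), n=w.size)

dt = 0.008                          # time sampling [s]
w = np.zeros(256)
w[10] = 1.0                         # unit spike at sample 10

offset_m, v_ms = 80.0, 500.0        # hypothetical offset [m] and velocity [m/s]
shifted = delay_trace(w, dt, offset_m / v_ms)   # delay of 0.16 s = 20 samples
print(np.argmax(shifted))
```

When the velocity depends on frequency, each Fourier component receives its own delay, so the wavelet changes shape with offset. That spreading is the dispersion measured in the panels of this notebook.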


In [None]:
#------  Axes
nt = 600 # number of time samples
dt = 0.008 # time sampling in s
nx = 201 # number of spatial samples
dx = 2 # spatial sampling in m
nfft = 2**10
t, x = np.arange(nt)*dt, np.arange(nx)*dx
# Wavelet
wav = ormsby(t[:nt//2+1], f=[2, 4, 38, 40], taper=np.hanning)[0][:-1]
wav = np.roll(np.fft.ifftshift(wav), 20) # apply small shift to make it causal

dshift, f, vf = surfacewavedata(nt, dt, nx, dx, nfft, fdisp, vdisp, wav)

In [None]:
# Convert from f-kx to f-velocity
nvel = 401
vlims = (1000, 2000)
vdispsp = np.linspace(vlims[0], vlims[1], nvel)
Ddispsp = parkdispersion(dshift, dx, dt, vlims[0], vlims[1], vdispsp[1]-vdispsp[0], f[-1])[2]
#
fig, axs = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 2]})
axs[0].imshow(dshift.T, cmap='gray', vmin=-.1, vmax=1, extent=[x[0], x[-1], t[-1], t[0]])
axs[0].axis('tight')
axs[0].set_xlabel('Offset [m]')
axs[0].set_ylabel('Time [s]')
axs[0].set_title('Shot gather')
axs[1].imshow(Ddispsp, origin='lower', extent=(f[0], f[-1], vdispsp[0], vdispsp[-1]))
axs[1].plot(f, vf*1e3, 'w', lw=4)
axs[1].axis('tight')
axs[1].set_xlim(0, 45)
axs[1].set_ylabel('Velocity [m/s]')
axs[1].set_xlabel('Frequency [Hz]')
axs[1].set_title('Dispersion panel');