### General

In [4]:
from scipy import integrate

### Geo

In [1]:
def haversine_array(latG, lonG, lat_ref,lon_ref):
    """
    Haversine distance from each point (latG[k], lonG[k]) to one reference point.

    The loop runs over ``range(len(lonG))``; ``latG`` must have at least as
    many entries. Each distance comes from ``hs.haversine`` (presumably the
    ``haversine`` package, default unit km -- confirm against the import).

    Returns
    -------
    numpy.ndarray
        One distance per point, in input order.
    """
    reference = (lat_ref, lon_ref)
    distances = [
        hs.haversine((latG[k], lonG[k]), reference)
        for k in range(len(lonG))
    ]
    return np.array(distances)

In [2]:
def km_to_lat_lon(delta_km, ref_lat):
    """
    Convert a distance in kilometres into latitude and longitude increments.

    delta_km : distance in kilometres
    ref_lat : reference latitude in degrees (sets the length of a degree of
        longitude through cos(ref_lat))

    Returns (delta_lat, delta_lon) in degrees.
    """
    km_per_degree = 111.32  # approximate length of one degree of latitude
    # A degree of longitude shrinks with cos(latitude).
    lon_km_per_degree = km_per_degree * np.cos(np.radians(ref_lat))
    return delta_km / km_per_degree, delta_km / lon_km_per_degree

### Cones

In [None]:
import math

def iw_cones(N2,D,lat,omega):
    """
    Internal-wave ray geometry with one frequency per profile.

    For every profile (column ``i``) and level (row ``j``) the ray angle is
    ``alpha = arctan(sqrt((|N2| - omega**2) / (omega**2 - f**2)))`` with
    ``f = gsw.geostrophy.f(lat[i])``. Per-level offsets are then accumulated
    along each column and interpolated onto the finite depths of ``D``.

    Parameters
    ----------
    N2 : 2-D array (rows = levels, columns = profiles)
        Buoyancy frequency squared; its absolute value is used.
    D : 2-D array, indexed like N2
        Depth coordinate used for the ``interpholes`` interpolation.
    lat : sequence
        Latitude of each profile, passed to ``gsw.geostrophy.f``.
    omega : sequence
        Wave frequency of each profile (``omega[i]``); presumably in the same
        units as f and sqrt(N2) -- confirm at the call site.

    Returns
    -------
    A : angle alpha in degrees.
    B : 90 - A (degrees).
    C : identical to A (kept for interface compatibility).
    DH : ``1 / tan(radians(B))`` per level.
    CH_up : reverse cumulative sum of finite DH (last finite row back to the
        first), then interpolated onto the rows where D is finite.
    CH_dw : forward cumulative sum of finite DH, interpolated likewise.
    All outputs are arrays shaped like N2, NaN where undefined.
    """
    n_levels, n_profiles = N2.shape[0], N2.shape[1]

    A = N2 * np.nan
    B = N2 * np.nan
    C = N2 * np.nan
    DHB = N2 * np.nan
    for i in range(n_profiles):
        w2 = omega[i] ** 2
        # The Coriolis parameter only depends on the profile latitude, so it
        # is evaluated once per column rather than once per level.
        f2 = gsw.geostrophy.f(lat[i]) ** 2
        for j in range(n_levels):
            n2 = np.abs(N2[j, i])
            # NaN (with a RuntimeWarning) when the ratio is negative, e.g.
            # when omega is not between f and N.
            alph = np.arctan(np.sqrt((n2 - w2) / (w2 - f2)))
            alph_deg = np.degrees(alph)
            C[j, i] = alph_deg
            A[j, i] = alph_deg
            B[j, i] = 90 - A[j, i]
            beta = np.radians(B[j, i])
            # tan(alpha) = y / x with a unit vertical step y = 1.
            DHB[j, i] = 1 / np.tan(beta)

    DH = DHB
    CH_up = N2 * np.nan
    CH_dw = N2 * np.nan
    for i in range(n_profiles):
        good = np.where(np.isfinite(DH[:, i]))[0]
        if good.size == 0:
            # Nothing to accumulate or interpolate from: leave the column NaN
            # instead of handing empty arrays to interpholes.
            continue
        dh = DH[good, i]
        CH_up[good, i] = np.cumsum(dh[::-1])[::-1]
        CH_dw[good, i] = np.cumsum(dh)

        target = np.where(np.isfinite(D[:, i]))[0]
        CH_up[target, i] = interpholes(D[good, i], CH_up[good, i], D[target, i])
        CH_dw[target, i] = interpholes(D[good, i], CH_dw[good, i], D[target, i])
    return A, B, C, DH, CH_up, CH_dw

In [None]:
def iw_cones_fz(N2,D,lat,omega):
    """
    Internal-wave ray geometry with a frequency that varies per sample.

    For every profile (column ``i``) and level (row ``j``) the ray angle is
    ``alpha = arctan(sqrt((|N2| - omega[j, i]**2) / (omega[j, i]**2 - f**2)))``
    with ``f = gsw.geostrophy.f(lat[i])``. Per-level offsets are then
    accumulated along each column. No interpolation onto ``D`` is done.

    Parameters
    ----------
    N2 : 2-D array (rows = levels, columns = profiles)
        Buoyancy frequency squared; its absolute value is used.
    D : 2-D array
        Unused; accepted for signature compatibility with ``iw_cones``.
    lat : sequence
        Latitude of each profile, passed to ``gsw.geostrophy.f``.
    omega : 2-D array indexed like N2
        Wave frequency of each sample (``omega[j, i]``).

    Returns
    -------
    A : angle alpha in degrees.
    B : 90 - A (degrees).
    C : identical to A (kept for interface compatibility).
    DH : ``1 / tan(radians(B))`` per level.
    CH_up : reverse cumulative sum of finite DH (last finite row back to the
        first).
    CH_dw : forward cumulative sum of finite DH.
    All outputs are arrays shaped like N2, NaN where undefined.
    """
    n_levels, n_profiles = N2.shape[0], N2.shape[1]

    A = N2 * np.nan
    B = N2 * np.nan
    C = N2 * np.nan
    DHB = N2 * np.nan
    for i in range(n_profiles):
        # The Coriolis parameter only depends on the profile latitude.
        f2 = gsw.geostrophy.f(lat[i]) ** 2
        for j in range(n_levels):
            w2 = omega[j, i] ** 2
            n2 = np.abs(N2[j, i])
            # NaN (with a RuntimeWarning) when the ratio is negative, e.g.
            # when omega is not between f and N.
            alph = np.arctan(np.sqrt((n2 - w2) / (w2 - f2)))
            alph_deg = np.degrees(alph)
            C[j, i] = alph_deg
            A[j, i] = alph_deg
            B[j, i] = 90 - A[j, i]
            beta = np.radians(B[j, i])
            # tan(alpha) = y / x with a unit vertical step y = 1.
            DHB[j, i] = 1 / np.tan(beta)

    DH = DHB
    CH_up = N2 * np.nan
    CH_dw = N2 * np.nan
    for i in range(n_profiles):
        good = np.where(np.isfinite(DH[:, i]))[0]
        dh = DH[good, i]
        CH_up[good, i] = np.cumsum(dh[::-1])[::-1]
        CH_dw[good, i] = np.cumsum(dh)
        # NOTE: interpolation of CH_up / CH_dw onto the finite depths of D
        # (as done in iw_cones) is intentionally disabled in this variant.
    return A, B, C, DH, CH_up, CH_dw

In [None]:
from scipy.interpolate import griddata

def get_cones_lonlat(r,z,Ntheta,lonG,latG, lat_ref=None, lon_ref=None):
    """
    Map an axisymmetric surface z(r) around a reference point onto a lon/lat grid.

    The radial profile (r[k], z[k]) is swept over ``Ntheta`` azimuths. The
    resulting polar samples are converted to longitude/latitude around
    (lon_ref, lat_ref) with a local flat-Earth approximation, then linearly
    interpolated (``scipy.interpolate.griddata``) onto the grid
    ``meshgrid(lonG, latG)``.

    Parameters
    ----------
    r : 1-D array
        Radial distances in km (they are divided by the Earth radius in km).
    z : 1-D array, same length as r
        Value at each radius (e.g. depth of the cone surface).
    Ntheta : int
        Number of azimuths in [0, 2*pi] (both end points included).
    lonG, latG : 1-D arrays
        Target grid longitudes / latitudes in degrees.
    lat_ref, lon_ref : float, optional
        Reference (centre) position in degrees. When omitted, the legacy
        behaviour is used: ``lat_ctd[i]`` / ``lon_ctd[i]`` are read from
        module globals, where ``i`` is whatever global loop index is
        currently defined. Passing the values explicitly is strongly
        preferred.

    Returns
    -------
    ZZC : 2-D array of shape (len(latG), len(lonG))
        Interpolated values, NaN outside the convex hull of the samples.
    """
    if lat_ref is None:
        # Legacy fallback: previously these were always read from globals.
        lat_ref = lat_ctd[i]
    if lon_ref is None:
        lon_ref = lon_ctd[i]

    earth_radius_km = 6371  # Earth radius in kilometres
    theta = np.linspace(0, 2 * np.pi, Ntheta)

    # Polar grid: rows follow r, columns follow theta.
    Theta, Radius = np.meshgrid(theta, r)
    ZZR = np.tile(z[:, np.newaxis], (1, len(theta)))

    # Cartesian offsets in km.
    X = Radius * np.cos(Theta)
    Y = Radius * np.sin(Theta)

    # Local flat-Earth conversion to latitude / longitude (degrees).
    lat_r = lat_ref + (Y / earth_radius_km) * (180 / np.pi)
    lon_r = lon_ref + (X / (earth_radius_km * np.cos(np.radians(lat_ref)))) * (180 / np.pi)

    lonG_grid, latG_grid = np.meshgrid(lonG, latG)

    # Linear interpolation of the scattered polar samples onto the regular grid.
    ZZC = griddata((lon_r.ravel(), lat_r.ravel()), ZZR.ravel(), (lonG_grid, latG_grid), method='linear')

    return ZZC

In [3]:
def get_r_z_from_iref(iprofile,iref,DH,D,debug):
    """
    Cumulative horizontal offsets of one profile, away from a reference row.

    Starting at row ``iref`` of column ``iprofile``, the DH values are scaled
    by 1e-3 (presumably m -> km, matching the plot label). Non-finite entries
    are dropped, and the rest are summed cumulatively:

    * downward branch: rows ``iref, iref+1, ...`` -> (c12, y12)
    * upward branch:   rows ``iref-1, iref-2, ..., 0`` -> (c10, y10)

    y12 / y10 are the matching entries of D. When ``debug == True`` the two
    branches are plotted mirrored about x = 0 with matplotlib.

    Returns
    -------
    c10, y10, c12, y12
    """
    def _accumulate(dh_scaled, depth):
        # Keep rows with a finite offset and sum them cumulatively.
        keep = np.isfinite(dh_scaled)
        return np.cumsum(dh_scaled[keep]), depth[keep]

    c12, y12 = _accumulate(1e-3 * DH[iref:, iprofile], D[iref:, iprofile])
    c10, y10 = _accumulate(
        1e-3 * DH[:iref, iprofile][::-1], D[:iref, iprofile][::-1]
    )

    if debug == True:
        fig, ax = plt.subplots(1, 1, figsize=(2, 3))
        ax.plot(-c12, y12, lw=1)
        ax.plot(+c12, y12, lw=4)
        ax.plot(0, D[iref, iprofile], 'ko')
        ax.plot(-c10, y10, lw=1)
        ax.plot(+c10, y10, lw=4)
        ax.invert_yaxis()
        ax.set_xlabel('Distance (km)')
        ax.set_ylabel('Depth (m)')

    return c10, y10, c12, y12

### PSD general

In [5]:
def psd(
    g,
    dx,
    axis=1,
    ffttype="p",
    detrend=True,
    window="hamming",
    tser_window=None,
    tser_overlap=None,
):
    """
    Compute power spectral density.

    Adapted from Jen MacKinnon.

    Parameters
    ----------
    g : array-like
        Real or complex input data of size [M * 1], [1 * M] or [M * N]
    dx : float
        Distance or time between entries in g
    axis : int, optional
        Axis along which to fft: 0 = columns, 1 = rows (default). Only
        relevant for 2-D input.
    ffttype : str, optional
        Flag for fft type

        - 'p' for periodic boundary conditions = pure fft (default)

        - 'c' or 's' if the input is a sum of cosines or sines only, in which
          case the input is extended evenly or oddly before the fft.

        - 't' for time series, or any other non-exactly periodic series in
          which the data should be windowed and filtered before computing the
          periodogram. In this case you may also specify:

          * tser_window: an integer that gives the length of each window
            the series should be broken up into for filtering and
            computing the periodogram.  Default is length(g).

          * tser_overlap: an integer equal to the length of the point
            overlap between successive windows. Default = tser_window/2,
            which means 50% overlap.
    detrend : bool, optional
        Remove a linear fit along the transform axis using
        scipy.signal.detrend() before computing the spectrum.
        Defaults to True.
    window : str, optional
        Window type used with ffttype='t'. Default 'hamming'. See
        scipy.signal.get_window() for window options.

    Returns
    -------
    Pcw : array-like
        Clockwise power spectral density
    Pccw : array-like
        Counter-clockwise power spectral density
    Ptot : array-like
        Total = ccw + cw psd one-sided spectra
    omega: array-like
        Vector of wavenumbers. For plotting use frequency f = omega/(2*pi).

    Raises
    ------
    ValueError
        If ffttype is not one of 'p', 'c', 's', 't'.

    Notes
    -----
    Spectra are normalised per radian wavenumber, so for ffttype='p' without
    detrending ``sum(Ptot) * domega`` equals ``mean(|g|**2)`` (Parseval).
    For even lengths the zero and Nyquist bins are split equally between
    the cw and ccw halves; for odd lengths only the zero bin is split.
    """
    # Local imports keep the function usable regardless of which notebook
    # cells have already been executed.
    from scipy import fftpack, signal
    from scipy.interpolate import interp1d

    if ffttype not in ("p", "c", "s", "t"):
        raise ValueError(
            f"ffttype must be one of 'p', 'c', 's' or 't', got {ffttype!r}"
        )

    M0 = np.shape(g)
    g = np.atleast_2d(g)

    # FFT and welch act on the last axis. For column-wise spectra transpose
    # the input first. (The previous test `axis == 0 & len(M0) > 1` parsed
    # as the chained comparison `axis == (0 & len(M0)) > 1`, which is
    # always False, so axis=0 was silently ignored.)
    transpose = axis == 0 and len(M0) > 1
    if transpose:
        g = np.transpose(g)

    # Detrend along the same axis the transform is taken on.
    if detrend:
        g = signal.detrend(g, axis=-1)

    # If sin or cos transform needed, appropriately extend the series.
    n = g.shape[1]
    if ffttype == "s":
        g = np.concatenate((g, -1 * np.flip(g[:, 1 : n - 1], axis=1)), axis=1)
    elif ffttype == "c":
        g = np.concatenate((g, np.flip(g[:, 1 : n - 1], axis=1)), axis=1)
    n = g.shape[1]

    # Setup frequency vectors
    df = 1 / (n * dx)
    domega = 2 * np.pi * df
    # full frequency output length n
    f_full = np.linspace(0, (n - 1) * df, num=n, endpoint=True)
    if n % 2 == 0:  # even -> length n/2+1
        f, step = np.linspace(
            0, n * df / 2, num=n // 2 + 1, endpoint=True, retstep=True
        )
    else:  # odd -> length (n+1)/2
        f, step = np.linspace(
            0, (n - 1) * df / 2, num=(n + 1) // 2, endpoint=True, retstep=True
        )
    assert step - df < df / 1000
    omega = 2 * np.pi * f

    if ffttype == "t":
        if tser_window is None:
            # default one segment
            tser_window = np.array(np.fix(n))
        if tser_overlap is None:
            # default 50% overlap
            tser_overlap = np.fix(tser_window / 2)

        f0, Pxx0 = signal.welch(
            g,
            fs=1 / dx,
            axis=-1,
            nperseg=tser_window,
            noverlap=tser_overlap,
            window=window,
            return_onesided=False,
        )

        # Shift the two-sided frequency vector so f0[k] = k * df_segment,
        # which matches the fft ordering of Pxx0.
        f0 = np.fft.fftshift(f0) + np.absolute(f0).max()

        # Interpolate to f_full frequency vector
        Pxx0 = interp1d(f0, Pxx0, bounds_error=False, axis=-1)(f_full)
        Pxx = Pxx0 / 2 / np.pi  # normalize by radial wavenumber/frequency
    else:
        P0 = fftpack.fft(g, axis=-1)
        # |FFT|^2 is real; drop the (zero) imaginary part. Normalize by
        # wavenumber in RADIANS.
        Pxx = (P0 * np.conj(P0)).real / n / (n * domega)

    # Separate into cw and ccw spectra
    zero = np.reshape(0.5 * Pxx[:, 0], (-1, 1))
    if n % 2 == 0:
        half = n // 2
        # Even length: bin 0 (mean) and bin n/2 (Nyquist) are split equally.
        nyquist = np.reshape(0.5 * Pxx[:, half], (-1, 1))
        Pccw = np.concatenate((zero, Pxx[:, 1:half], nyquist), axis=1)
        Pcw = np.concatenate(
            (zero, np.flip(Pxx[:, half + 1 : n], axis=1), nyquist), axis=1
        )
    else:
        half = (n + 1) // 2
        # Odd length: only the zero frequency is split.
        Pccw = np.concatenate((zero, Pxx[:, 1:half]), axis=1)
        Pcw = np.concatenate((zero, np.flip(Pxx[:, half:n], axis=1)), axis=1)
    Ptot = Pccw + Pcw

    Ptot = np.squeeze(Ptot)
    Pcw = np.squeeze(Pcw)
    Pccw = np.squeeze(Pccw)

    # transpose back if axis=0
    if transpose:
        Ptot = Ptot.transpose()
        Pcw = Pcw.transpose()
        Pccw = Pccw.transpose()

    return Pcw, Pccw, Ptot, omega

### Psd shear strain

In [None]:
from scipy import signal
from scipy.signal import welch
# SciPy >= 1.13 no longer exports window functions from scipy.signal; import
# tukey from scipy.signal.windows so the name stays available for other cells.
from scipy.signal.windows import tukey


def PSDst(NFFT,dz,xi,xiz, percent_windowing,varloss, KZ=None):
    """
    Strain spectra from Welch periodograms of ``xi`` and ``xiz``.

    Both series use a Hamming window of length NFFT, 50 % overlap and
    linear detrending, with scaling='spectrum' (power per bin). The first
    (zero) bin is dropped and the result is multiplied by ``varloss``.

    Parameters
    ----------
    NFFT : int
        Segment length in samples.
    dz : float
        Sample spacing; ``dk = 2*pi / (NFFT*dz)``.
    xi : array-like
        Series whose spectrum is converted to strain by multiplying by KZ**2.
    xiz : array-like
        Series whose spectrum is used directly.
    percent_windowing : float
        Unused (only referenced by a disabled Tukey window option).
    varloss : float or array-like
        Multiplicative correction applied to both spectra.
    KZ : array-like, optional
        Wavenumbers matching the retained bins (length of spectrum minus
        one). If omitted, a module-level ``KZ`` is used, which was the only
        behaviour before this parameter existed.

    Returns
    -------
    psd_strain_xi : varloss * S_xi * KZ**2 / dk
    psd_strain_xiz : varloss * S_xiz / dk
    m : welch frequencies with the zero bin removed.

    Notes
    -----
    NOTE(review): welch is called with fs=dk, so ``m`` equals k*dk/NFFT
    rather than k*dk. Callers appear to rely on KZ instead -- confirm before
    using ``m`` as a wavenumber.
    """
    if KZ is None:
        try:
            KZ = globals()["KZ"]
        except KeyError:
            raise NameError(
                "PSDst needs the wavenumber vector: pass KZ=... explicitly"
            ) from None

    nperseg = NFFT
    noverlap = nperseg/2
    LFFT = NFFT * dz
    dk = 2*np.pi / LFFT

    def _spectrum(series):
        # One-sided Welch 'spectrum' with the zero bin dropped.
        freqs, power = welch(series, dk,
            window=signal.windows.hamming(NFFT),
            nperseg=nperseg,
            noverlap=noverlap,
            detrend='linear',
            return_onesided=True, scaling='spectrum')
        return freqs[1:], power[1:]

    m, p_xi = _spectrum(xi)
    psd_strain_xi = (varloss*p_xi)*(KZ**2) /dk

    m, p_xiz = _spectrum(xiz)
    psd_strain_xiz = varloss*p_xiz/dk

    return psd_strain_xi, psd_strain_xiz, m


In [None]:
from scipy import signal
from scipy.signal import welch
# SciPy >= 1.13 no longer exports window functions from scipy.signal; import
# tukey from scipy.signal.windows so the name stays available for other cells.
from scipy.signal.windows import tukey


def PSDsh(NFFT,dz,u,v,uz,vz,KZ, N2seg,percent_windowing,varloss):
    """
    Velocity and shear spectra from Welch periodograms.

    Each series uses a Hamming window of length NFFT, 50 % overlap and
    linear detrending, with scaling='spectrum' (power per bin). The first
    (zero) bin is dropped and the result is multiplied by ``varloss``.

    Parameters
    ----------
    NFFT : int
        Segment length in samples.
    dz : float
        Sample spacing; ``dk = 2*pi / (NFFT*dz)``.
    u, v : array-like
        Velocity components.
    uz, vz : array-like
        Vertical derivatives of the velocity components.
    KZ : array-like
        Wavenumbers matching the retained bins (spectrum length minus one).
    N2seg : float
        Buoyancy frequency squared used to normalise the shear spectra.
    percent_windowing : float
        Unused (only referenced by a disabled Tukey window option).
    varloss : float or array-like
        Multiplicative correction applied to every spectrum.

    Returns
    -------
    psd_uv : (S_u + S_v) / dk
    psd_shear_uv : (S_u + S_v) * KZ**2 / N2seg / dk
    psd_shear_uz : (S_uz + S_vz) / N2seg / dk
    m : welch frequencies of the last call, zero bin INCLUDED, so ``m`` is
        one element longer than the spectra. NOTE(review): PSDst drops
        that bin; kept as-is here because callers may depend on it.
    """
    noverlap = NFFT/2
    LFFT = NFFT * dz
    dk = 2*np.pi / LFFT

    def _spectrum(series):
        # One-sided Welch 'spectrum'; returns all frequencies and the
        # varloss-corrected power without the zero bin.
        freqs, power = welch(series, dk,
            window=signal.windows.hamming(NFFT),
            nperseg=NFFT,
            noverlap=noverlap,
            detrend='linear',
            nfft=None,
            return_onesided=True, scaling='spectrum')
        return freqs, varloss*power[1:]

    _, psd_u = _spectrum(u)
    _, psd_v = _spectrum(v)
    _, psd_uz = _spectrum(uz)
    m, psd_vz = _spectrum(vz)

    psd_u_psd_v = psd_u + psd_v
    psd_uz_psd_vz = psd_uz + psd_vz

    psd_uv        =  psd_u_psd_v / dk
    psd_shear_uv  = ( (psd_u_psd_v* KZ**2)   / N2seg ) / dk
    psd_shear_uz  = (  psd_uz_psd_vz / N2seg ) / dk
    return psd_uv, psd_shear_uv, psd_shear_uz, m

### Psd integration

In [7]:
def integrate_spe(im1,im2,dk,sat,spe,PRINT, KZ=None):
    """
    Integrate a spectrum over [im1, im2] and over a saturation-limited sub-band.

    The running sum ``cumsum(spe[im1..im2]) * dk`` is compared with ``sat``.
    The sub-band [im1, ic] extends to the last index whose running sum is
    still <= sat. It collapses to im1 when no index, or only the first one,
    qualifies, and ic never exceeds im2.

    Parameters
    ----------
    im1, im2 : int
        Inclusive index range into ``spe`` / ``KZ``.
    dk : float
        Wavenumber step used for the running sum and single-bin integrals.
    sat : float
        Threshold on the running integral.
    spe : array-like
        Spectrum.
    PRINT : bool
        Print a short summary when ``PRINT == True``.
    KZ : array-like, optional
        Wavenumber vector used as the trapezoid abscissa. If omitted, a
        module-level ``KZ`` is used (the previous, implicit behaviour).

    Returns
    -------
    va_imm, imm, va_imc, imc, im1, im2, ic, sat
        Integrals and index vectors over the full band (imm) and the
        saturation-limited band (imc). A single-bin band integrates to
        ``spe * dk`` (a trapezoid over one point would give 0).
    """
    if KZ is None:
        try:
            KZ = globals()["KZ"]
        except KeyError:
            raise NameError(
                "integrate_spe needs the wavenumber vector: pass KZ=... explicitly"
            ) from None

    imm = np.arange(im1,im2+1,1)
    cumspe = np.cumsum(spe[imm])*dk

    # Offsets (relative to im1) where the running integral is still <= sat.
    below = np.where(cumspe <= sat)[0]
    # Either none, or only offset 0: the band collapses to im1.
    ic = im1 if len(below) <= 1 else im1 + below[-1]
    ic = np.nanmin([ic,im2])
    imc = np.arange(im1,ic +1,1)

    # Use the same single-bin rule for both bands (consistent with
    # integrate_spe_simple).
    if len(imm)==1:
        va_imm = np.sum(spe[imm])*dk
    else:
        va_imm = np.trapz(spe[imm],KZ[imm])

    if len(imc)==1:
        va_imc = np.sum(spe[imc])*dk
    else:
        va_imc = np.trapz(spe[imc],KZ[imc])

    if PRINT == True:
        print('im1,im2, va[imm]:',im1,im2,va_imm,'\nic:   ',ic,'  ','\nim1,imc, va[imc]:',im1,ic,va_imc)

    return va_imm, imm,   va_imc, imc,   im1,im2,ic,    sat


In [4]:
def integrate_spe_simple(im1,im2,dk,spe,PRINT, KZ=None):
    """
    Integrate a spectrum over the inclusive index range [im1, im2].

    Parameters
    ----------
    im1, im2 : int
        Inclusive index range into ``spe`` / ``KZ``.
    dk : float
        Wavenumber step used for the running sum and the single-bin case.
    spe : array-like
        Spectrum.
    PRINT : bool
        Print a short summary when ``PRINT == True``.
    KZ : array-like, optional
        Wavenumber vector used as the trapezoid abscissa. If omitted, a
        module-level ``KZ`` is used (the previous, implicit behaviour). It
        is only needed when the band has more than one bin.

    Returns
    -------
    va_imm : integral (``spe * dk`` for a single bin, trapezoid otherwise)
    imm : index vector im1..im2
    cumspe : running sum ``cumsum(spe[imm]) * dk``
    """
    imm = np.arange(im1,im2+1,1)
    cumspe = np.cumsum(spe[imm])*dk

    if len(imm)==1:
        va_imm = np.sum(spe[imm])*dk
    else:
        if KZ is None:
            try:
                KZ = globals()["KZ"]
            except KeyError:
                raise NameError(
                    "integrate_spe_simple needs the wavenumber vector: pass KZ=... explicitly"
                ) from None
        va_imm = np.trapz(spe[imm],KZ[imm])

    if PRINT == True:
        print('im1,im2, va[imm]:',im1,im2,va_imm)

    return va_imm, imm, cumspe


### Strain

In [8]:
def calculate_strain(N2_hi,N2_lo,N2M,i1,i2,ip,RHO_hi, g=9.81, rho0=1e3):
    """
    Strain estimates for rows i1:i2 of profile ``ip``.

    Parameters
    ----------
    N2_hi, N2_lo : 2-D arrays (rows = levels, columns = profiles)
        High- and low-resolution buoyancy frequency squared.
    N2M : 1-D array
        Reference N2 profile (optional reference, sliced with i1:i2).
    i1, i2 : int
        Row slice ``i1:i2``.
    ip : int
        Profile (column) index.
    RHO_hi : 2-D array
        Density quantity used for the isopycnal-displacement estimate xi.
    g, rho0 : float, optional
        Gravity and reference density in ``xi = (g/rho0) * RHO / N2seg``
        (defaults reproduce the previous hard-coded 9.81 / 1e3).

    Returns
    -------
    xi, xiz_mea, xiz_fit, xiz_tim, N2seg, N2mea, N2hi, N2lo, N2tim
        xiz_* are (N2hi - reference) / N2seg with the reference being the
        segment mean (mea), a quadratic fit (fit) or N2M (tim).
    """
    N2hi = N2_hi[i1:i2,ip]
    N2lo = N2_lo[i1:i2,ip]

    N2mea = np.nanmean(N2hi)
    N2seg = np.nanmean(N2lo)
    N2tim = N2M[i1:i2]  # optional reference profile
    xiz_mea = (N2hi - N2mea)/N2seg
    xiz_tim = (N2hi - N2tim)/N2seg

    # Quadratic background fit. np.polyfit cannot handle NaNs, so fit only
    # the finite samples; with fewer than 3 of them (and some missing) no
    # fit is attempted and the result is NaN.
    idx = np.arange(i1,i2,1)
    finite = np.isfinite(N2hi)
    if finite.all() or np.count_nonzero(finite) > 2:
        polyf = np.polyfit(idx[finite], N2hi[finite], 2)
        N2fit = np.poly1d(polyf)(idx)
    else:
        N2fit = np.full(N2hi.shape, np.nan)
    xiz_fit = (N2hi - N2fit)/N2seg

    RHOhi = RHO_hi[i1:i2,ip]
    xi = (g/rho0)*(RHOhi/N2seg)

    return xi, xiz_mea, xiz_fit, xiz_tim, N2seg, N2mea, N2hi, N2lo, N2tim



### Utility

In [None]:
def wavelengthes(LZ):
    """Print the distinct wavelengths in LZ, rounded to integers, largest first."""
    rounded = np.round(LZ).astype(int)
    print('Wavelengthes available')
    distinct = np.unique(rounded)
    # Negate, sort ascending, negate back -> descending order.
    print(-np.sort(-distinct))

### GM

In [6]:
def gm_shear_variance(m, iim, N):
    r"""
    Garrett-Munk shear spectrum and its band integral.

    Parameters
    ----------
    m : array-like
        Vertical wavenumber vector [rad/m].
    iim : array-like
        Indexer into `m` selecting the integration band.
    N : float
        Local buoyancy frequency [rad/s].

    Returns
    -------
    Sgm : float
        Trapezoidal integral of `Pgm` over ``m[iim]``: the GM shear variance
        normalized by N^2 (dimensionless, since `Pgm` carries units of m).
    Pgm : array-like
        GM buoyancy-normalized shear spectrum evaluated at every entry of `m`.

    Notes
    -----
    Kunze et al. (2006) :cite:`Kunze2006` eq. 6:

    .. math::

        \frac{_{GM}\left< V_z^2\right>}{\overline{N}^2} = \frac{3 \pi E_0 b j_\ast}{2} \int_{m_\mathrm{min}}^{m_\mathrm{max}} \frac{m^2 dm}{(m + m_\ast)^2}

    with :math:`j_\ast=3`, :math:`E_0=6.3\times10^{-5}`,
    :math:`N_0=5.24\times10^{-3}` rad/s, :math:`b=1300` m, and

    .. math::

        m_\ast = \frac{\overline{N}}{N_0}\frac{\pi j_\ast}{b}

    See also
    --------
    gm_strain_variance : GM strain variance
    """
    E0 = 6.3e-5  # GM76 energy level
    b = 1300  # thermocline scale depth [m]
    jstar = 3  # reference mode number
    N0 = 5.24e-3  # reference buoyancy frequency = 3 cph

    mstar = jstar * np.pi / b * N / N0
    amplitude = 3 * np.pi * E0 * b * jstar / 2
    Pgm = amplitude * m**2 / (m + mstar) ** 2

    band_m = m[iim]
    band_P = Pgm[iim]
    Sgm = np.trapz(band_P, band_m)
    return Sgm, Pgm

In [None]:
def gm_strain_variance(m, iim, N):
    r"""
    Garrett-Munk strain spectrum and its band integral.

    Parameters
    ----------
    m : array-like
        Vertical wavenumber vector [rad/m].
    iim : array-like
        Indexer into `m` selecting the integration band.
    N : float
        Local buoyancy frequency [rad/s].

    Returns
    -------
    Sgm : float
        Trapezoidal integral of `Pgm` over ``m[iim]``: the GM strain variance
        (dimensionless, since `Pgm` carries units of m).
    Pgm : array-like
        GM strain spectrum evaluated at every entry of `m`.

    Notes
    -----
    Kunze et al. (2006) :cite:`Kunze2006` eq. 10:

    .. math::

        _{GM}\left< \xi_z^2\right> = \frac{\pi E_0 b j_\ast}{2} \int_{m_\mathrm{min}}^{m_\mathrm{max}} \frac{m^2 dm}{(m + m_\ast)^2}

    with :math:`j_\ast=3`, :math:`E_0=6.3\times10^{-5}`,
    :math:`N_0=5.24\times10^{-3}` rad/s, :math:`b=1300` m, and

    .. math::

        m_\ast = \frac{\overline{N}}{N_0}\frac{\pi j_\ast}{b}

    This is the buoyancy-normalized GM shear spectrum divided by 3.

    See also
    --------
    gm_shear_variance : GM shear variance
    """
    E0 = 6.3e-5  # GM energy level
    b = 1300  # thermocline scale depth [m]
    jstar = 3  # reference mode number
    N0 = 5.24e-3  # reference buoyancy frequency = 3 cph

    amplitude = np.pi * E0 * b * jstar / 2
    mstar = jstar * np.pi / b * N / N0
    Pgm = amplitude * m**2 / (m + mstar) ** 2

    Sgm = np.trapz(Pgm[iim], m[iim])
    return Sgm, Pgm

In [None]:
# PARAM GM
def GM_model(KZ,dk,N2m):
    """
    Garrett-Munk (GM76) reference strain and shear spectra.

    ``strain(KZ) = (pi * E0 * j* * b / 2) * KZ**2 / (KZ + kz*)**2`` with
    ``kz* = pi * j* * (sqrt(N2m) / N0) / b``; the shear spectrum is three
    times the strain spectrum.

    Parameters
    ----------
    KZ : array-like
        Vertical wavenumbers (units implied by b in m: presumably rad/m).
    dk : float
        Unused; kept so existing calls keep working.
    N2m : float or array-like
        Mean buoyancy frequency squared; sqrt(N2m) scales kz*.

    Returns
    -------
    KZ_GM : the input ``KZ`` (same object)
    spectrum_shear_GM : 3 * spectrum_strain_GM
    spectrum_strain_GM : GM strain spectrum at ``KZ``
    """
    Elevel = 6.3e-5  # GM energy level E0
    bGM = 1300  # thermocline scale depth b [m]
    Jstar = 3  # reference mode number j*
    N0 = 5.24e-3  # reference buoyancy frequency [rad/s]

    kzstar = np.pi*Jstar* (np.sqrt(N2m)  / N0) * (1/bGM)  # m-1
    cstGM = (np.pi*Elevel*Jstar*bGM)/2  # m
    spectrum_strain_GM = cstGM*( KZ**2 )/ (KZ+kzstar)**2
    spectrum_shear_GM = 3.*spectrum_strain_GM
    return KZ, spectrum_shear_GM, spectrum_strain_GM

### Rw

In [None]:
def omega_to_rw(omega,fcor):
    """Shear/strain ratio R_w = (omega**2 + f**2) / (omega**2 - f**2) for frequency omega."""
    w2 = omega**2
    f2 = fcor**2
    return (w2 + f2) / (w2 - f2)

def rw_to_omega(FCOR,R):
    """Invert R = (omega**2 + f**2)/(omega**2 - f**2): omega = sqrt(f**2 (1 + R) / (R - 1))."""
    f2 = FCOR**2
    numerator = f2 + R * f2
    return np.sqrt(numerator / (R - 1))

In [11]:
def beam_iw_angle(omega,n1m,fcor):
    """
    Internal-wave ray angles for frequency ``omega``.

    ``alph = arctan(sqrt((n1m**2 - omega**2) / (omega**2 - fcor**2)))``.
    Note the mixed units of the result: ``alph`` is in RADIANS, while
    ``beta = 90 - degrees(alph)`` is in DEGREES.

    Returns
    -------
    alph, beta
    """
    ratio = (n1m**2 - omega**2) / (omega**2 - fcor**2)
    alph = np.arctan(np.sqrt(ratio))
    return alph, 90 - np.degrees(alph)

### f(Rw)

In [9]:
def f_rw_shst(rw):
    """Shear/strain-ratio factor 3 (rw + 1) / (2 sqrt(2) rw sqrt(rw - 1))."""
    numerator = 3*(rw+1)
    denominator = 2*np.sqrt(2)*rw*np.sqrt(rw-1)
    return numerator / denominator

def f_rw_st(rw):
    """Strain-only ratio factor rw (rw + 1) / (6 sqrt(2) sqrt(rw - 1))."""
    rw_term = rw*(rw+1)
    inv_const = 1/(6*np.sqrt(2))
    inv_root = 1/(np.sqrt(rw-1))
    return rw_term * inv_const * inv_root

def h_rw_GHP(rw,lN,l0):
    """Ratio factor ((1 + 1/rw) / (4/3)) * (lN / l0) * sqrt(2 / (rw - 1))."""
    prefactor = (1+1/rw)/(4/3)
    return prefactor * (lN/l0) * np.sqrt(2 / (rw -1))

def h_rw_HI(rw,l0,l1,l2):
    """
    Piecewise ratio factor.

    rw < 9  : ((1 + 1/rw) / (4/3)) * (l1 / l0) * rw**(-l2)
    rw >= 9 : ((1 + 1/rw) / (4/3)) * (1 / l0) * sqrt(2 / (rw - 1))
    otherwise (e.g. rw is NaN): NaN
    """
    prefactor = (1+1/rw)/(4/3)
    if rw < 9:
        return prefactor * (l1 / l0) * (rw**(-l2))
    if rw >= 9:
        return prefactor * (1 /l0) * np.sqrt(2 / (rw -1))
    return np.nan


def hprime_rw_HI(rw,l0,l1,l2):
    """
    Piecewise ratio factor scaled by (rw/3)**2.

    rw < 9  : (rw/3)**2 * ((1 + 1/rw) / (4/3)) * (l1 / l0) * rw**(-l2)
    rw >= 9 : (rw/3)**2 * ((1 + 1/rw) / (4/3)) * (1 / l0) * sqrt(2 / (rw - 1))
    otherwise (e.g. rw is NaN): NaN
    """
    base = ((rw/3)**2) * ((1+1/rw)/(4/3))
    if rw < 9:
        return base * (l1 / l0) * (rw**(-l2))
    if rw >= 9:
        return base * (1 /l0) * np.sqrt(2 / (rw -1))
    return np.nan


### Parameterizations

In [None]:
def eps_shearstrain(eps0, Nm, N0, var2, rw, lfN):
    """
    Parameterized turbulent kinetic energy dissipation rate (shear/strain form).

    Returns ``eps0 * (Nm**2 / N0**2) * var2 * (f_rw_shst(rw) * lfN)``.

    Parameters
    ----------
    eps0 : float
        GM reference dissipation rate
    Nm : float
        Background stratification in the window
    N0 : float
        GM reference stratification
    var2 : float
        Variance term; presumably the squared ratio of observed to GM
        (shear) variance -- confirm at the call site
    rw : float
        Shear/strain variance ratio, passed to ``f_rw_shst``
    lfN : float
        Extra factor multiplying ``f_rw_shst(rw)``; presumably the
        latitude/stratification correction -- confirm at the call site
    """
    stratification_scaling = Nm**2 / N0**2
    ratio_factor = f_rw_shst(rw) * lfN
    return eps0 * stratification_scaling * var2 * ratio_factor



def eps_strain(eps0, Nm, N0, var2, rw, lfN):
    """
    Parameterized turbulent kinetic energy dissipation rate (strain-only form).

    Returns ``eps0 * (Nm**2 / N0**2) * var2 * (f_rw_st(rw) * lfN)``.

    Parameters
    ----------
    eps0 : float
        GM reference dissipation rate
    Nm : float
        Background stratification in the window
    N0 : float
        GM reference stratification
    var2 : float
        Variance term; presumably the squared ratio of observed to GM
        strain variance -- confirm at the call site
    rw : float
        Shear/strain variance ratio, passed to ``f_rw_st``
    lfN : float
        Extra factor multiplying ``f_rw_st(rw)``; presumably the
        latitude/stratification correction -- confirm at the call site
    """
    stratification_scaling = Nm**2 / N0**2
    ratio_factor = f_rw_st(rw) * lfN
    return eps0 * stratification_scaling * var2 * ratio_factor


def diffusivity(eps, N, Gam=0.2):
    r"""
    Vertical diffusivity from turbulent dissipation.

    Parameters
    ----------
    eps : array-like
        Turbulent dissipation rate.
    N : array-like
        Buoyancy frequency.
    Gam : float, optional
        Mixing efficiency :math:`\Gamma` (default 0.2).

    Returns
    -------
    kappa : array-like
        Vertical diffusivity, assuming a constant mixing efficiency:

        .. math::

            \kappa = \Gamma \frac{\epsilon}{N^2}
    """
    scaled_dissipation = Gam * eps
    return scaled_dissipation / N**2

### Spectral corrections (Thurnherr, Polzin)

In [10]:
# Finite-resolution spectral transfer functions; see the JTECH reference:
# https://journals.ametsoc.org/view/journals/atot/29/4/jtech-d-11-00158_1.xml
#
# Every correction is a power of sinc(KZ * dz / (2*pi)), where np.sinc is the
# normalised sinc sin(pi x) / (pi x). Assuming KZ is in rad/m, this equals
# sin(KZ*dz/2) / (KZ*dz/2), the response of a boxcar of length dz. The
# original expression `KZ * dz / 2*np.pi` parses as (KZ*dz/2)*pi, which put
# the first zero at KZ*dz = 2/pi instead of 2*pi.


def _sinc_transfer(KZ, dz, power=2):
    """Return np.sinc(KZ * dz / (2*pi)) ** power (boxcar response of length dz)."""
    return np.sinc(KZ * dz / (2 * np.pi)) ** power


# dz_cel, dz_tra: presumably ADCP cell and transmit-pulse lengths.
def Traa(KZ,dz_cel,dz_tra):
    """Product of the sinc**2 responses for lengths dz_cel and dz_tra."""
    return _sinc_transfer(KZ, dz_cel) * _sinc_transfer(KZ, dz_tra)


# dz_fid
def Tfid(KZ,dz_fid):
    """sinc**2 response for length dz_fid."""
    return _sinc_transfer(KZ, dz_fid)


# dz_cel
def Tint(KZ,dz_cel):
    """sinc**4 response for length dz_cel."""
    return _sinc_transfer(KZ, dz_cel, power=4)


# dz_bin
def Tbin(KZ,dz_bin):
    """sinc**2 response for bin length dz_bin."""
    return _sinc_transfer(KZ, dz_bin)


def Ttil(KZ,dz_til):
    """sinc**2 response for length dz_til."""
    return _sinc_transfer(KZ, dz_til)


def Tsup(KZ,dz_sup):
    """sinc**2 response for length dz_sup."""
    return _sinc_transfer(KZ, dz_sup)
