In [None]:
#!/usr/bin/env python
# coding: utf-8

import os, sys
import logging

import numpy as np
import h5netcdf

import xarray as xr
import xarray.ufuncs as xu
import xrft
import pandas as pd
from scipy.signal import convolve2d, detrend

from matplotlib import pyplot as plt

import cartopy.crs as ccrs
import cartopy

plt.rc("figure", figsize=(12, 10))
plt.rc("font", size=14)

from dask.distributed import Client, LocalCluster

# Start an in-process dask cluster (processes=False): a single worker running
# THREADS_PER_WORKER threads -- i.e. 4 threads in total, not 32 cores.
# silence_logs='error' keeps worker log output down to error-level messages.
N_WORKERS = 1
THREADS_PER_WORKER = 4
cluster = LocalCluster(processes=False, n_workers=N_WORKERS,
                       threads_per_worker=THREADS_PER_WORKER, silence_logs='error')
client = Client(cluster)
client

In [None]:
# Data locations. Two setups exist: a local copy on the author's machine and
# the shared CNRM tropics store. Exactly one is active, chosen by this switch.
USE_LOCAL_DATA = False

path = "/home/durand/Documents/OLR/"
# path = ""

if USE_LOCAL_DATA:
    indir_data = path + 'Anomaly/'
    outdir_TF = path + 'TF2D/'
else:
    indir_data = '/cnrm/tropics/commun/DATACOMMUN/WAVE/NO_SAVE/DATA/RAW_ANOMALY/OLR/'
    outdir_TF = '/cnrm/tropics/commun/DATACOMMUN/WAVE/NO_SAVE/DATA/TF2D/OLR/'

var = 'OLR'
prefix = 'TF2D'

addDay = 180   # NOTE(review): not used in the visible cells
spd = 1        # samples per day: scales window length and converts frequency to cycles/day

In [None]:
def split_hann_taper(series_length, fraction):
    """Build a `split cosine bell` (split Hann) taper of length series_length.

    Only ``round(fraction * series_length)`` points are tapered in total (both
    ends combined). The first ``npts//2 + 1`` points of a Hann window of that
    size are placed at the start of the series and the remaining window points
    at the end. Every other weight is 1.

    The weights taper to zero on the ends. To taper to the mean of a series X:
        XTAPER = (X - X.mean())*series_taper + X.mean()

    Parameters
    ----------
    series_length : int
        Length of the series to taper.
    fraction : float
        Fraction in [0, 1] of the points that are tapered.

    Returns
    -------
    series_taper : numpy.ndarray, shape (series_length,)
        Multiplicative taper weights.
    taper : numpy.ndarray
        The full Hann window ``np.hanning(npts)`` the weights were cut from.

    Raises
    ------
    ValueError
        If `fraction` is not within [0, 1].
    """
    if not 0 <= fraction <= 1:
        raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")
    npts = int(np.rint(fraction * series_length))  # total size of taper
    taper = np.hanning(npts)
    series_taper = np.ones(series_length)
    if npts == 0:
        return series_taper, taper
    n_head = npts // 2 + 1
    head, tail = taper[:n_head], taper[n_head:]
    series_taper[:head.size] = head
    # Index the tail from an explicit start position: a negative start such as
    # `-npts//2 + 1` evaluates to 0 when npts <= 2 and would span the whole series.
    series_taper[series_length - tail.size:] = tail
    return series_taper, taper

def smooth_wavefreq(data, kern=None, nsmooth=None, freq_ax=None, freq_name=None):
    """Apply a convolution of (data, kern) nsmooth times.

    The convolution is applied separately to the positive and negative
    frequencies. The zero-frequency slice is left unchanged.
    Either the name (freq_name: str) or the axis index (freq_ax: int) of the
    frequency dimension is required; the name is preferred when both are given.

    Parameters
    ----------
    data : xarray.DataArray
        2-D array (the underlying ``convolve2d`` only accepts 2-D input).
    kern : numpy.ndarray, optional
        Convolution kernel; defaults to ``simple_smooth_kernel()``.
    nsmooth : int, optional
        Number of smoothing passes (default 20). 0 returns an unsmoothed copy.
    freq_ax : int, optional
        Axis index of the frequency dimension (used only if freq_name is None).
    freq_name : str, optional
        Name of the frequency dimension.

    Returns
    -------
    xarray.DataArray
        Smoothed values with the dims and coords of `data` (name/attrs are not copied).

    Raises
    ------
    ValueError
        If neither freq_name nor freq_ax is given.
    """
    assert isinstance(data, xr.DataArray)
    if kern is None:
        kern = simple_smooth_kernel()
    if nsmooth is None:
        nsmooth = 20
    if freq_name is not None:
        axnum = list(data.dims).index(freq_name)
        if freq_name in data.coords:
            # Locate freq == 0 from the coordinate itself instead of assuming it
            # sits in the middle of the axis.
            nzero = int(np.argmin(np.abs(np.asarray(data[freq_name].values, dtype=float))))
        else:
            nzero = data.sizes[freq_name] // 2  # assumes a centred (fft-shifted) axis
    elif freq_ax is not None:
        axnum = freq_ax
        nzero = data.shape[freq_ax] // 2  # assumes a centred (fft-shifted) axis
    else:
        raise ValueError("smooth_wavefreq needs to know how to find frequency dimension.")
    # convolvePosNeg smooths each side of index nzero independently and leaves
    # the zero-frequency (mean) slice untouched. Exactly nsmooth passes are applied.
    smoothed = data.values.copy()
    for _ in range(nsmooth):
        smoothed = convolvePosNeg(smoothed, kern, axnum, nzero)
    return xr.DataArray(smoothed, dims=data.dims, coords=data.coords)

def simple_smooth_kernel():
    """Return a normalized 3x3 five-point smoothing kernel.

    The centre has weight 2 and its four edge neighbours weight 1; corners are
    0. The weights are divided by their total (6), so they sum to 1.
    """
    weights = np.zeros((3, 3))
    weights[1, :] = [1, 2, 1]      # middle row
    weights[:, 1] += [1, 0, 1]     # add the vertical neighbours
    return weights / weights.sum()

def convolvePosNeg(arr, k, dim, boundary_index):
    """Convolve `arr` with kernel `k` on either side of `boundary_index` along `dim`.

    arr: 2-D array-like (numpy ndarray or DataArray) supporting .copy(),
         .ndim and tuple-of-slices indexing.
    k: 2-D kernel passed to ``scipy.signal.convolve2d``.
    dim: integer axis of `arr` along which the array is split.
    boundary_index: position on `dim` that is excluded from the convolution.

    The parts [0, boundary_index) and (boundary_index, end] are each convolved
    independently (boundary='symm', mode='same'). The slice at boundary_index
    keeps its input values. A fresh copy of `arr` is returned; the input is
    not modified.
    """
    result = arr.copy()
    below = [slice(None)] * arr.ndim
    above = [slice(None)] * arr.ndim
    below[dim] = slice(None, boundary_index)
    above[dim] = slice(boundary_index + 1, None)
    for region in (tuple(below), tuple(above)):
        # read from the untouched input, write into the copy
        result[region] = convolve2d(arr[region], k, boundary='symm', mode='same')
    return result

In [None]:
def decompose2SymAsym(arr):
    """Mimic NCL function to decompose into symmetric and asymmetric parts.

    arr: xarray DataArray with a 'lat' coordinate. The latitude grid must be
         symmetric about the equator (lat[i] == -lat[-1-i]), because the
         mirrored field is built by reversing the latitude axis.
    return: copy of `arr` with the symmetric component at lat < 0 (SH) and the
            antisymmetric component at lat > 0 (NH). A row exactly at the
            equator, if present, keeps its original values.
    Raises:
        ValueError if the latitude grid is not symmetric about the equator.
    Note:
        This function produces indistinguishable results from NCL version.
    """
    lat = arr['lat']
    if not np.allclose(lat.values, -lat.values[::-1]):
        raise ValueError("decompose2SymAsym requires a latitude grid symmetric about the equator")
    # Mirror the field about the equator: reverse the values along lat, then
    # relabel with the original latitudes so the value at -phi lines up with +phi.
    mirrored = arr.reindex(lat=lat[::-1])
    mirrored['lat'] = lat
    data_sym = 0.5 * (arr + mirrored)
    data_asy = 0.5 * (arr - mirrored)
    out = arr.copy()  # might not be best to copy, but is safe
    # NCL convention: symmetric component in SH, asymmetric component in NH.
    out.loc[{'lat': lat[lat < 0]}] = data_sym.isel(lat=data_sym.lat < 0)
    out.loc[{'lat': lat[lat > 0]}] = data_asy.isel(lat=data_asy.lat > 0)
    return out

def _ig_frequency(n, k, c, he, Beta, g):
    """Angular frequency [rad/s] of the inertio-gravity wave with meridional mode n.

    Starts from the estimate without the Beta*k term and refines it with 5
    correction iterations that include g*he*Beta*k/omega.
    """
    dell = (Beta*c)
    deif = np.sqrt((2.*n+1.)*dell + (g*he) * k**2)
    for _ in range(5):
        deif = np.sqrt((2.*n+1.)*dell + (g*he) * k**2 + g*he*Beta*k/deif)
    return deif


def _sw_wave_frequency(ww, k, c, he, Beta, g, fillval):
    """Angular frequency [rad/s] of wave type ww (1..6) at zonal wavenumber k [1/m].

    Returns `fillval` where the branch is not defined for the sign of k.
    """
    # Anti-symmetric curves
    if ww == 1:        # MRG wave
        if k < 0:
            dell = np.sqrt(1.0 + (4.0 * Beta)/(k**2 * c))
            return k * c * (0.5 - 0.5 * dell)
        if k == 0:
            return np.sqrt(c * Beta)
        return fillval
    if ww == 2:        # n=0 IG wave
        if k < 0:
            return fillval
        if k == 0:
            return np.sqrt(c * Beta)
        dell = np.sqrt(1.+(4.0*Beta)/(k**2 * c))
        return k * c * (0.5 + 0.5 * dell)
    if ww == 3:        # n=2 IG wave
        return _ig_frequency(2., k, c, he, Beta, g)
    # Symmetric curves
    if ww == 4:        # n=1 ER wave
        if k < 0.0:
            n = 1.
            dell = (Beta/c)*(2.*n+1.)
            return -Beta*k/(k**2 + dell)
        return fillval
    if ww == 5:        # Kelvin wave
        return k*c
    if ww == 6:        # n=1 IG wave
        return _ig_frequency(1., k, c, he, Beta, g)
    raise ValueError(f"unknown wave type {ww}")


def genDispersionCurves(nWaveType=6, nPlanetaryWave=50, rlat=0, Ahe=(50, 25, 12)):
    """
    Function to derive the shallow water dispersion curves. Closely follows NCL version.
    input:
        nWaveType : integer, number of wave types to do (0..6)
        nPlanetaryWave: integer (>= 2), number of zonal wavenumbers, evenly spaced from +20 to -20
        rlat: latitude in radians (just one latitude, usually 0.0)
        Ahe: (50, 25, 12) equivalent depths [m]
              ==> defines parameter: nEquivDepth ; integer, number of equivalent depths to do == len(Ahe)
    returns: tuple of size 2
        Afreq: Frequency [cycles/day], shape is (nWaveType, nEquivDepth, nPlanetaryWave);
               1e20 where a branch is undefined
        Apzwn: Zonal wavenumber, shape is (nWaveType, nEquivDepth, nPlanetaryWave)
        Index j on the depth axis corresponds to Ahe[j].
    raises:
        ValueError if nWaveType is outside 0..6.

    notes:
        The outputs contain both symmetric and antisymmetric waves. In the case of
        nWaveType == 6:
        0,1,2 are (ASYMMETRIC) MRG, n=0 IG, n=2 IG
        3,4,5 are (SYMMETRIC) n=1 ER, Kelvin, n=1 IG
    """
    if not 0 <= nWaveType <= 6:
        raise ValueError(f"nWaveType must be between 0 and 6, got {nWaveType}")
    nEquivDepth = len(Ahe)
    pi     = np.pi
    radius = 6.37122e06    # [m]     average radius of earth
    g      = 9.80665       # [m/s^2] gravity at 45 deg lat used by the WMO
    omega  = 7.292e-05     # [1/s]   earth's angular vel
    ll     = 2.*pi*radius*np.cos(np.abs(rlat))   # circumference of the latitude circle [m]
    Beta   = 2.*omega*np.cos(np.abs(rlat))/radius
    fillval = 1e20

    Afreq = np.empty((nWaveType, nEquivDepth, nPlanetaryWave))
    Apzwn = np.empty((nWaveType, nEquivDepth, nPlanetaryWave))

    for ww in range(1, nWaveType+1):
        for ed, he in enumerate(Ahe):
            c = np.sqrt(g * he)  # phase speed [m/s]
            for wn in range(1, nPlanetaryWave+1):
                s = -20.*(wn-1)*2./(nPlanetaryWave-1) + 20.   # planetary wavenumber, +20 .. -20
                k = 2.0 * pi * s / ll
                deif = _sw_wave_frequency(ww, k, c, he, Beta, g, fillval)
                # `ed` comes from enumerate (0-based), so it indexes the depth axis directly.
                Apzwn[ww-1, ed, wn-1] = s
                if deif != fillval:
                    P = 2.*pi/(deif*24.*60.*60.)   # period [days]
                    Afreq[ww-1, ed, wn-1] = 1./P
                else:
                    Afreq[ww-1, ed, wn-1] = fillval
    return Afreq, Apzwn

In [None]:
year = np.arange(2004, 2007)   # NOTE(review): not used in the visible cells

filenames = np.arange(2001, 2020)
# One anomaly file per year; keep only latitudes within +/-15.1 of each before
# concatenating along time.
datasets = [
    xr.open_mfdataset(indir_data + 'anom_OLR_daily_brut_' + str(yr) + '.nc',
                      chunks={'lat': 1}).sel(lat=slice(-15.1, 15.1))
    for yr in filenames
]

ds = xr.concat(datasets, dim='time', coords='minimal', compat='override')
# Hemispheric copies: values outside the hemisphere (including lat == 0) are set to 0.
ds_NH = ds.where(ds.lat > 0, 0)
ds_SH = ds.where(ds.lat < 0, 0)


In [None]:
segsize = 96*spd   # segment length in samples (WK99 use 96-day window)
noverlap = 65      # overlap between consecutive segments, in samples (as in NCL's wkSpaceTime)
_ds = ds_NH['OLR_ano']

if not 0 <= noverlap < segsize:
    raise ValueError("noverlap must be in [0, segsize)")

# `construct(..., stride=...)` takes the step between consecutive window starts,
# not the overlap: a 65-sample overlap means a step of segsize - noverlap.
# NOTE(review): this yields roughly twice as many segments as stride=noverlap did.
x_roll = _ds.rolling(time=segsize, min_periods=segsize)
x_win = x_roll.construct("segments", stride=segsize - noverlap).dropna("time")

In [None]:
# Split-cosine-bell taper over 10% of each segment (try to replicate NCL's).
taper, tt = split_hann_taper(segsize, 0.1)

# NOTE(review): the taper is computed but NOT applied; this switch makes that
# explicit. xrft.fft below linearly detrends its input, i.e. after any taper
# applied here -- confirm the intended order before enabling.
APPLY_TAPER = False
if APPLY_TAPER:
    # broadcast by dimension name rather than relying on axis position
    x_wintap = x_win * xr.DataArray(taper, dims="segments")
else:
    x_wintap = x_win

In [None]:
# One chunk along the FFT dimension "segments", one latitude per chunk.
x_wintap = x_wintap.chunk({"segments": -1, "lat": 1})

# 2-D FFT over (segments, lon) with linear detrending. The name `tcwvhat` is
# historical: the field transformed here is OLR.
tcwvhat = xrft.fft(
    x_wintap,
    dim=['segments', 'lon'],
    detrend='linear',
    true_phase=False,
    true_amplitude=True,
)
tcwvhat

In [None]:
# Power |X|^2. `xarray.ufuncs` is deprecated and removed in recent xarray
# releases; numpy ufuncs dispatch on DataArrays directly (and stay lazy under dask).
tcwvhat_s = (tcwvhat * np.conj(tcwvhat)).real
tcwvhat_s

In [None]:
# whatbhat = whatbhat.set_index(wavenumber= ['freq_lon'])
# Convert the longitude frequency into zonal wavenumber (cycles per 360 degrees).
# Assumes freq_lon is in cycles per degree, i.e. lon is in degrees.
# NOTE(review): the "+ 1" shifts every wavenumber by one -- confirm this offset is intended.
wavenumber = tcwvhat.freq_lon.values * 360 + 1

# Build relabelled coordinates on a new object (tcwvhat_s itself is left unchanged).
# Multiplying by spd converts the segment frequency (presumably cycles per sample,
# since "segments" has no coordinate) into cycles per day.
z_final = tcwvhat_s.assign_coords(
    freq_lon=wavenumber,
    freq_segments=spd * tcwvhat_s['freq_segments'].values,
)
# Average over segments (time) and sum over latitude.
z_final = z_final.mean(dim='time').sum(dim='lat').load()
z_final

In [None]:
# Reverse the values along the frequency axis while keeping the coordinate labels
# in place (presumably to match the sign convention of the dispersion curves).
# Transpose by name first so axis=1 is guaranteed to be freq_segments.
# NOTE(review): if the frequency coordinate is not symmetric about 0 (e.g. an
# even-length FFT axis), the flipped values no longer sit exactly at -frequency.
z_final = z_final.transpose("freq_lon", "freq_segments")
z_final = xr.DataArray(np.flip(z_final.values, axis=1),
                       dims=("freq_lon", "freq_segments"),
                       coords={
                           "freq_lon": z_final.freq_lon,
                           "freq_segments": z_final.freq_segments},
                       name='power')

### Spectre pour l'hémisphère Nord (Northern Hemisphere spectrum)

In [None]:
# From here on z_final holds log10(power). This rebinds z_final, so re-running
# this cell alone would take the log twice. np.log10 replaces the removed xarray.ufuncs.
z_final = np.log10(z_final)
z_final.loc[{'freq_segments': 0}] = np.nan   # mask the zero-frequency (time-mean) row
z_final.plot.contourf(x='freq_lon', y='freq_segments', ylim=[-0, 0.8], xlim=[-15, 15],
                      vmin=7, vmax=11, levels=21, cmap='Spectral_r')
plt.xlabel("Zonal wavenumber")
plt.ylabel("Frequency (cycles/day)")
plt.title("log10 OLR power, Northern Hemisphere")
plt.grid()

In [None]:
#Cell to create dispersion curve
# Theoretical shallow-water dispersion curves, shape (6 wave types, 3 depths, 50 wavenumbers).
# The 1e20 fill value marks undefined branches; turn it into NaN so it is not plotted.
swfreq, swwn = genDispersionCurves()
swf = swfreq.copy()
swf[swf == 1e20] = np.nan
swk = swwn.copy()
swk[swk == 1e20] = np.nan

In [None]:
# Smoothed background spectrum: repeated 2-D smoothing, applied separately to
# positive and negative frequencies.
background = smooth_wavefreq(z_final, kern=simple_smooth_kernel(), nsmooth=50, freq_name='freq_segments')

# Mask the zero-frequency row. This modifies `background` itself (used for the
# normalisation below); `test` is the same object, not a copy.
background.loc[{'freq_segments': 0}] = np.nan
test = background

# NOTE(review): in cell order z_final already holds log10(power), so this plots
# log10 of a smoothed log10 spectrum -- confirm that is intended.
_tcwvhat = np.log10(test)

_tcwvhat.plot.contourf(x='freq_lon', y='freq_segments', ylim=[-0, 0.5], xlim=[-15, 15], levels=20,
                       cmap='Spectral_r')
plt.grid()

In [None]:
# separate components
# Normalise the spectrum by its smoothed background (in cell order, both are
# log10 quantities at this point) and relabel the wavenumber axis.
z2 = z_final
nspec = (z2 / background).assign_coords(freq_lon=wavenumber)


In [None]:
plt.figure()
nspec.plot.contourf(x='freq_lon', y='freq_segments', ylim=[-0, 0.8], xlim=[-15, 15],
                    levels=21, vmin=0.95, vmax=1.04, cmap='Spectral_r')
# Overlay the dispersion curves, one per equivalent depth: antisymmetric wave
# types (0-2) with the default line style, symmetric types (3-5) dashed.
for wave in range(6):
    extra = {} if wave < 3 else {'linestyle': 'dashed'}
    for depth in range(3):
        plt.plot(swk[wave, depth, :], swf[wave, depth, :], color='grey', **extra)
plt.grid()
plt.show()

In [None]:
plt.figure()
# FIXME(review): `nspec_asy` is not assigned anywhere in this notebook
# (decompose2SymAsym is defined but never called), so this raises NameError on a fresh kernel.
nspec_asy.plot.contourf(x='freq_lon', y='freq_segments', levels=20, xlim=[-15, 15],
                        ylim=[-0, 0.8], vmax=1.5, cmap='Spectral_r')
# Antisymmetric dispersion curves (wave types 0-2), one per equivalent depth.
for wave in range(3):
    for depth in range(3):
        plt.plot(swk[wave, depth, :], swf[wave, depth, :], color='grey')
plt.grid()
plt.show()

In [None]:
# FIXME(review): nspec_sym and nspec_asy are not assigned anywhere in this notebook
# (decompose2SymAsym is defined but never called), so this raises NameError on a fresh kernel.
nspec_tot = nspec_sym + nspec_asy

In [None]:
plt.figure()
# Depends on nspec_tot from the previous cell (see the note there about undefined inputs).
nspec_tot.plot.contourf(x='freq_lon', y='freq_segments', levels=20, xlim=[-15, 15],
                        ylim=[-0, 0.8], vmax=3, cmap='Spectral_r')
# Dispersion curves per equivalent depth: antisymmetric types (0-2) with the
# default line style, symmetric types (3-5) dashed.
for wave in range(6):
    extra = {} if wave < 3 else {'linestyle': 'dashed'}
    for depth in range(3):
        plt.plot(swk[wave, depth, :], swf[wave, depth, :], color='grey', **extra)
plt.grid()
plt.show()