In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
from glob import glob
import scipy.signal as signal

In [None]:

def filter_timeseries(ts, order, cutoff, btype='lowpass', fs=1, **kwargs):
    """Apply a zero-phase Butterworth filter to a 1D series, ignoring NaNs.

    The filter is designed with ``scipy.signal.butter`` and applied forward and
    backward with ``scipy.signal.filtfilt``.
    For example:
    If 'ts' is a timeseries of daily samples, filter_timeseries(ts, 3, 1/20)
    will return the series without the 20 days or less variability using an
    order 3 butterworth filter.
    In the same way, filter_timeseries(ts, 3, 1/20, btype='highpass') will
    return the series with only the 20 days or less variability.

    Args:
        ts (array_like): timeseries or 1D array to filter (lists are accepted).
            NaN samples are removed before filtering and are NaN in the output.
        order (int): order of the Butterworth filter.
        cutoff (float or array_like): single float for lowpass or highpass
            filters, ``[low, high]`` for bandpass/bandstop filters. Same units
            as ``fs``.
        btype (str, optional): the type of filter passed to
            ``scipy.signal.butter``. Defaults to 'lowpass'.
        fs (float, optional): sampling frequency. Defaults to 1.
        **kwargs: passed to ``scipy.signal.filtfilt``.

    Returns:
        numpy.ndarray: float array shaped like ``ts``. It is NaN wherever
        ``ts`` is NaN, and entirely NaN when ``ts`` has no valid samples.
    """
    # asarray makes boolean-mask indexing below work for plain lists too.
    values = np.asarray(ts, dtype=float)
    output = np.full(values.shape, np.nan)
    valid = ~np.isnan(values)
    if not valid.any():
        return output
    b, a = signal.butter(order, cutoff, btype=btype, fs=fs)
    # The valid samples are filtered as one contiguous series: NaN gaps are
    # closed up rather than interpolated across.
    output[valid] = signal.filtfilt(b, a, values[valid], **kwargs)
    return output
    
def filter_xarray(data, dim, order, cutoff, btype='lowpass', parallel=False, fs=1, **kwargs):
    """Apply ``filter_timeseries`` along the time dimension of a DataArray.

    Every 1D series along ``dim`` (one per combination of the other dimensions)
    is filtered independently with a Butterworth filter.

    Args:
        data (xarray.DataArray): data to filter.
        dim (str): name of the time dimension.
        order (int): butterworth filter order.
        cutoff (float or array_like): if float, the cutoff frequency; for a
            bandpass filter, the ``[min, max]`` frequencies.
        btype (str, optional): {lowpass, highpass, bandpass}. Defaults to 'lowpass'.
        parallel (bool, optional): if True, parallelize with dask
            (``dask='parallelized'``). ``dim`` must then be a single chunk.
            Defaults to False.
        fs (float, optional): sampling frequency. Defaults to 1.
        **kwargs: forwarded to ``scipy.signal.filtfilt``.

    Returns:
        xarray.DataArray: filtered data with the original ``dim`` coordinate.
        As with any ``apply_ufunc`` core dimension, ``dim`` is the last axis.
    """
    # Filter parameters go through apply_ufunc's ``kwargs`` so every call sees them
    # unchanged. Passed positionally they would be broadcast by the vectorized loop,
    # which breaks a two-element bandpass ``cutoff``.
    filter_kwargs = dict(order=order, cutoff=cutoff, btype=btype, fs=fs, **kwargs)
    dask_gufunc_kwargs = None
    if parallel:
        # exclude_dims gives the output core dim its own name, so dask must be told its size.
        dask_gufunc_kwargs = {'output_sizes': {dim: data.sizes[dim]}}
    filt = xr.apply_ufunc(filter_timeseries, data,
                          kwargs=filter_kwargs,
                          input_core_dims=[[dim]],
                          output_core_dims=[[dim]],
                          exclude_dims={dim},
                          keep_attrs=True,
                          vectorize=True,
                          dask='parallelized' if parallel else 'forbidden',
                          output_dtypes=[float],
                          dask_gufunc_kwargs=dask_gufunc_kwargs)
    # exclude_dims drops the coordinate along ``dim``; restore the original one.
    filt[dim] = data[dim]
    return filt



In [None]:
# Climatologies for each Hovmoller section. The anomaly cells further down subtract
# them from day-of-year groups. GLORYS12V1 reanalysis first, then the S2S reforecasts.
tropical_glorys_clim, coastnorth_glorys_clim, coastsouth_glorys_clim = (
    xr.open_dataset(path)
    for path in ('data/GLORYS12V1/HOVMOLLERS/CLIMATOLOGY/glorys_tropical.nc',
                 'data/GLORYS12V1/HOVMOLLERS/CLIMATOLOGY/glorys_coastnorth.nc',
                 'data/GLORYS12V1/HOVMOLLERS/CLIMATOLOGY/glorys_coastsouth.nc')
)

tropical_reforecast_clim, coastnorth_reforecast_clim, coastsouth_reforecast_clim = (
    xr.open_dataset(path)
    for path in ('data/S2S/CLIMATOLOGY/reforecasts_tropical_clim.nc',
                 'data/S2S/CLIMATOLOGY/reforecasts_coastnorth_clim.nc',
                 'data/S2S/CLIMATOLOGY/reforecasts_coastsouth_clim.nc')
)

In [None]:
def _open_noleap(path, dim):
    """Open a .nc file and convert its ``dim`` axis to the 'noleap' calendar."""
    return xr.open_dataset(path).convert_calendar('noleap', dim=dim)


# GLORYS series are indexed by 'time'; reforecasts by initialisation date 'inittime'.
tropical_glorys       = _open_noleap('data/GLORYS12V1/HOVMOLLERS/glorys_tropical.nc', 'time')
tropical_reforecast   = _open_noleap('data/S2S/HOVMOLLERS/reforecasts_tropical.nc', 'inittime')

coastnorth_glorys     = _open_noleap('data/GLORYS12V1/HOVMOLLERS/glorys_coastnorth.nc', 'time')
coastnorth_reforecast = _open_noleap('data/S2S/HOVMOLLERS/reforecasts_coastnorth.nc', 'inittime')

coastsouth_glorys     = _open_noleap('data/GLORYS12V1/HOVMOLLERS/glorys_coastsouth.nc', 'time')
coastsouth_reforecast = _open_noleap('data/S2S/HOVMOLLERS/reforecasts_coastsouth.nc', 'inittime')

---

##### Check the reforecasts against GLORYS for some study cases

In [None]:
# GLORYS anomalies:
# - subtract the day-of-year climatology from each section;
# - stack the tropical section with one coastal section along 'index';
# - convert back to the 'gregorian' calendar.
# The tropical anomaly feeds both stacks, so it is computed only once.
tropical_glorys_anom = (tropical_glorys.zos.groupby('time.dayofyear')
                        - tropical_glorys_clim.zos)

north_glorys = xr.concat([tropical_glorys_anom,
                          coastnorth_glorys.zos.groupby('time.dayofyear') - coastnorth_glorys_clim.zos
                         ], 'index').convert_calendar('gregorian')

south_glorys = xr.concat([tropical_glorys_anom,
                          coastsouth_glorys.zos.groupby('time.dayofyear') - coastsouth_glorys_clim.zos
                         ], 'index').convert_calendar('gregorian')

In [None]:
# Reforecast anomalies:
# - subtract the day-of-year climatology, grouping on the initialisation date;
# - stack the tropical section with one coastal section along 'index';
# - convert 'inittime' back to the 'gregorian' calendar.
# The tropical anomaly feeds both stacks, so it is computed only once.
tropical_reforecast_anom = (tropical_reforecast.zos.groupby('inittime.dayofyear')
                            - tropical_reforecast_clim.zos)

north_reforecast = xr.concat([tropical_reforecast_anom,
                              coastnorth_reforecast.zos.groupby('inittime.dayofyear') - coastnorth_reforecast_clim.zos
                             ], 'index').convert_calendar('gregorian', dim='inittime')

south_reforecast = xr.concat([tropical_reforecast_anom,
                              coastsouth_reforecast.zos.groupby('inittime.dayofyear') - coastsouth_reforecast_clim.zos
                             ], 'index').convert_calendar('gregorian', dim='inittime')

In [None]:
# Study window: a start date and the date 46 days later, both as 'YYYY-MM-DD' strings.
itime = '2008-04-20'
ftime = (pd.Timestamp(itime) + pd.Timedelta(46, unit='D')).strftime('%F')

In [None]:
# Inspect the southern coastal reforecast point whose latitude is closest to -35.
coastsouth_by_lat = coastsouth_reforecast.swap_dims({'index': 'lat'})
coastsouth_by_lat.sel(lat=-35, method='nearest')

In [None]:
# Study case: GLORYS point index=1133 vs. reforecast point index=78 (southern stack).
# For each lead time:
# - shift the reforecast init times to valid times;
# - interpolate onto the GLORYS time axis;
# - correlate both series month by month.
p1 = south_glorys.sel(index=1133)
# NOTE(review): GLORYS stamps are moved back 12 h, presumably to line up with the
# reforecast valid times — confirm.
p1['time'] = p1.time - pd.Timedelta(hours=12)
p2 = south_reforecast.sel(index=78).drop_duplicates('inittime')
leadtimes = p2.leadtime.values

# p3[j] maps month -> reforecast series at the j-th lead time, on the GLORYS time axis.
p3 = []
for lead in leadtimes:
    p = p2.sel(leadtime=lead)
    p['inittime'] = p.inittime + pd.Timedelta(days=lead)  # init time -> valid time
    p = p.rename({'inittime': 'time'}).interp(time=p1.time.values)
    p = dict(p.convert_calendar('noleap').groupby('time.month'))
    p3.append(p)

# p1 becomes month -> GLORYS series, matching the layout of each p3 entry.
p1 = dict(p1.convert_calendar('noleap').groupby('time.month'))

# Rows are calendar months 1..12; columns are lead-time positions. The size comes from
# the data instead of a hard-coded 46, so a different lead count cannot index past p3
# or silently drop lead times.
months = np.arange(1, 13)
corr = np.empty((months.size, len(p3)))
for i, month in enumerate(months):
    for j, by_month in enumerate(p3):
        corr[i, j] = xr.corr(p1[month], by_month[month], 'time')

# Plot r² against the real month numbers and lead-time values.
fig, ax = plt.subplots(figsize=(15, 5))
cf = ax.contourf(months, leadtimes, corr.T**2, vmin=0, vmax=1, cmap='nipy_spectral',
                 levels=np.linspace(0, 1, 100))
fig.colorbar(cf, ax=ax, label='r²')
ax.set(xlabel='month', ylabel='lead time',
       title='Squared correlation: GLORYS index 1133 vs. reforecast index 78');


In [None]:
# Study case: GLORYS point index=893 vs. reforecast point index=66 (southern stack).
# For each lead time:
# - shift the reforecast init times to valid times;
# - interpolate onto the GLORYS time axis;
# - correlate both series month by month.
p1 = south_glorys.sel(index=893)
# NOTE(review): GLORYS stamps are moved back 12 h, presumably to line up with the
# reforecast valid times — confirm.
p1['time'] = p1.time - pd.Timedelta(hours=12)
p2 = south_reforecast.sel(index=66).drop_duplicates('inittime')
leadtimes = p2.leadtime.values

# p3[j] maps month -> reforecast series at the j-th lead time, on the GLORYS time axis.
p3 = []
for lead in leadtimes:
    p = p2.sel(leadtime=lead)
    p['inittime'] = p.inittime + pd.Timedelta(days=lead)  # init time -> valid time
    p = p.rename({'inittime': 'time'}).interp(time=p1.time.values)
    p = dict(p.convert_calendar('noleap').groupby('time.month'))
    p3.append(p)

# p1 becomes month -> GLORYS series, matching the layout of each p3 entry.
p1 = dict(p1.convert_calendar('noleap').groupby('time.month'))

# Rows are calendar months 1..12; columns are lead-time positions. The size comes from
# the data instead of a hard-coded 46, so a different lead count cannot index past p3
# or silently drop lead times.
months = np.arange(1, 13)
corr = np.empty((months.size, len(p3)))
for i, month in enumerate(months):
    for j, by_month in enumerate(p3):
        corr[i, j] = xr.corr(p1[month], by_month[month], 'time')

# Plot r² against the real month numbers and lead-time values.
fig, ax = plt.subplots(figsize=(15, 5))
cf = ax.contourf(months, leadtimes, corr.T**2, vmin=0, vmax=1, cmap='nipy_spectral',
                 levels=np.linspace(0, 1, 100))
fig.colorbar(cf, ax=ax, label='r²')
ax.set(xlabel='month', ylabel='lead time',
       title='Squared correlation: GLORYS index 893 vs. reforecast index 66');


In [None]:
# Study case: GLORYS point index=1073 vs. reforecast point index=77 (southern stack).
# For each lead time:
# - shift the reforecast init times to valid times;
# - interpolate onto the GLORYS time axis;
# - correlate both series month by month.
p1 = south_glorys.sel(index=1073)
# NOTE(review): GLORYS stamps are moved back 12 h, presumably to line up with the
# reforecast valid times — confirm.
p1['time'] = p1.time - pd.Timedelta(hours=12)
p2 = south_reforecast.sel(index=77).drop_duplicates('inittime')
leadtimes = p2.leadtime.values

# p3[j] maps month -> reforecast series at the j-th lead time, on the GLORYS time axis.
p3 = []
for lead in leadtimes:
    p = p2.sel(leadtime=lead)
    p['inittime'] = p.inittime + pd.Timedelta(days=lead)  # init time -> valid time
    p = p.rename({'inittime': 'time'}).interp(time=p1.time.values)
    p = dict(p.convert_calendar('noleap').groupby('time.month'))
    p3.append(p)

# p1 becomes month -> GLORYS series, matching the layout of each p3 entry.
p1 = dict(p1.convert_calendar('noleap').groupby('time.month'))

# Rows are calendar months 1..12; columns are lead-time positions. The size comes from
# the data instead of a hard-coded 46, so a different lead count cannot index past p3
# or silently drop lead times.
months = np.arange(1, 13)
corr = np.empty((months.size, len(p3)))
for i, month in enumerate(months):
    for j, by_month in enumerate(p3):
        corr[i, j] = xr.corr(p1[month], by_month[month], 'time')

# Plot r² against the real month numbers and lead-time values.
fig, ax = plt.subplots(figsize=(15, 5))
cf = ax.contourf(months, leadtimes, corr.T**2, vmin=0, vmax=1, cmap='nipy_spectral',
                 levels=np.linspace(0, 1, 100))
fig.colorbar(cf, ax=ax, label='r²')
ax.set(xlabel='month', ylabel='lead time',
       title='Squared correlation: GLORYS index 1073 vs. reforecast index 77');
