#### Import pyLatte package

In [1]:
from pylatte import utils
from pylatte import skill
from pylatte import indices

#### Currently, the following packages are required to load the data - this process will be replaced by the CAFE cookbook

In [2]:
import numpy as np
import pandas as pd
import xarray as xr
import glob

#### Import some plotting packages and widgets

In [3]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import warnings    
warnings.filterwarnings("ignore")

# Jupyter specific -----
from ipywidgets import FloatProgress
%matplotlib inline

# A note about the methodology of pyLatte
The pyLatte package is constructed around the xarray Python package. This is particularly useful for verification computations which require large numbers of samples (different model runs) to converge. 

The approach here is to generate very large xarray objects that reference all data required for the verification, but do not store the data in memory. Operations are performed on these xarray objects out-of-memory. When it is necessary to perform a compute (e.g. to produce a plot), this is distributed over multiple processors using the dask Python package.

# Initialise dask (currently not working on vm31)

In [4]:
# import dask
# import distributed
# client = distributed.Client(local_dir='/tmp/squ027-dask-worker-space', n_workers=4, ip='*')
# client

# Construct xarray objects for forecasts and observations
(The CAFE cookbook will replace these code blocks)

In [5]:
# Location of forecast data -----
# Root folder of the forecast output, glob pattern of the files to read (a
# '.nc' suffix is appended when the search path is built in the loading cell)
# and the name of the variable extracted from each file.
# NOTE(review): 'sst' is selected here, whereas the SOI section further down
# loads 'slp' climatologies - confirm these are the intended inputs.
fcst_folder = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/'
fcst_filename = 'ocean_daily*'
fcst_variable = 'sst'

# Location of observation data -----
# Folder, file name and variable name of the observational dataset.
obsv_folder = '/OSM/CBR/OA_DCFP/data/observations/sst/'
obsv_filename = 'HadISST_sst.nc'
obsv_variable = 'sst'

In [6]:
# Initial dates to include (takes approximately 1 min 30 sec per date) -----
# Month-start ('1MS') initialisation dates from 2013-04 to 2016-04 inclusive.
init_dates = pd.date_range('2013-4','2016-4' , freq='1MS')

# Ensembles to include -----
# Ensemble member numbers 1 to 11 (the upper bound of range() is exclusive).
ensembles = range(1,12)

# Forecast length -----
# Used further down to decide where the forecast lead times are truncated.
FCST_LENGTH = 2 # years

In [7]:
# Resampling details -----
# Frequency string passed to .resample(time=...) for both the forecasts and
# the observations ('MS' is the pandas alias for month start).
resample_freq = 'MS'

### Construct forecasts xarray object
Note, dask has a known bug that manifests when trying to concatenate data containing timedelta64 arrays (see https://github.com/pydata/xarray/issues/1952 for further details). For example, try to concatenate the following two Datasets:

`In : path = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/yr2002/mn7/'`

`In : ens5 = xr.open_mfdataset(path + 'OUTPUT.5/atmos_daily*.nc', autoclose=True)`

`In : ens6 = xr.open_mfdataset(path + 'OUTPUT.6/atmos_daily*.nc', autoclose=True)`

`In : xr.concat([ens5, ens6],'ensemble')`

`Out : TypeError: invalid type promotion`

The error here is actually caused by the variables `average_DT` and `time_bounds`, which are timedelta64 arrays. However, I still do not fully understand the bug: concatenation of `ens4` and `ens5`, for example, works fine, even though `ens4` also contains the timedelta64 variables `average_DT` and `time_bounds`. Regardless, because of this bug, it is not currently possible to create an xarray Dataset object containing all model variables. Instead, only the variables of interest (i.e. `fcst_variable` and `obsv_variable`) are retained in the concatenated xarray object.

In [None]:
# Instantiate progress bar (one tick per initial date / ensemble member pair) -----
f = FloatProgress(min=0, max=len(init_dates)*len(ensembles), description='Loading...')
display(f)

# Loop over initial dates -----
fcst_list = []
for init_date in init_dates:
    year = init_date.year
    month = init_date.month

    # Loop over ensembles -----
    ens_list = []
    for ensemble in ensembles:
        # Signal to increment the progress bar -----
        f.value += 1

        # Glob pattern matching this member's files; rstrip() avoids a doubled
        # '/' when fcst_folder already ends with a separator -----
        path = '{}/yr{}/mn{}/OUTPUT.{}/{}.nc'.format(
            fcst_folder.rstrip('/'), year, month, ensemble, fcst_filename)

        # Fail early, naming the pattern, rather than letting xr.concat() choke
        # on an empty list further down -----
        files = sorted(glob.glob(path))
        if not files:
            raise IOError('No forecast files found matching ' + path)

        # xr.open_mfdataset() is slow - manually concatenate in time -----
        datasets = [xr.open_dataset(fname, autoclose=True)[fcst_variable]
                    for fname in files]
        dataset = xr.concat(datasets, dim='time', coords='all').sortby('time')

        # Stack the resampled ensemble member into a list -----
        ens_list.append(dataset.resample(time=resample_freq).mean(dim='time'))

    # Concatenate ensembles and label them with their member numbers -----
    ens_object = xr.concat(ens_list, dim='ensemble')
    ens_object['ensemble'] = ensembles

    # Stack concatenated ensembles into a list for each initial date -----
    fcst_list.append(utils.datetime_to_leadtime(ens_object))

# Keep track of the lead time for each initialization -----
n_lead_time = [len(x.lead_time) for x in fcst_list]

# Concatenate initial dates -----
da_fcst = xr.concat(fcst_list, dim='init_date')

# Rechunk for chunksizes of at least 1,000,000 elements -----
da_fcst = utils.prune(da_fcst.chunk(chunks={'ensemble' : len(da_fcst.ensemble),
                                            'lead_time' : len(da_fcst.lead_time)}).squeeze())

#### Truncate the forecast lead times at 2 years
The January and July forecasts are run for 5 years rather than 2 years. The xarray concatenation above can deal with this, but fills the shorter forecasts with nans for lead times longer than 2 years. Let's get rid of some of these nans by truncating the forecasts at the lead time corresponding to the longest 2 year forecast.

In [None]:
# Number of lead-time increments in FCST_LENGTH years; the factor 12 assumes
# monthly resampling (resample_freq = 'MS') -----
max_increments = FCST_LENGTH * 12

# Truncate at the longest forecast that does not exceed max_increments. If every
# forecast is longer than that, fall back to max_increments itself rather than
# failing with "max() arg is an empty sequence" -----
n_within_length = [n for n in n_lead_time if n <= max_increments]
n_trunc = max(n_within_length) if n_within_length else max_increments
da_fcst = da_fcst.isel(lead_time=range(n_trunc))

### Construct observations xarray object

In [None]:
# Instantiate progress bar (a single step: the observations are one file) -----
f = FloatProgress(min=0, max=1, description='Loading...')
display(f)

# Dates are referenced to 0000-01-01 - xr.open_mfdataset cannot deal with this -----
year_shift = 1970
obsv_path = obsv_folder + obsv_filename
dataset = xr.open_dataset(obsv_path, autoclose=True)[obsv_variable]
dataset = dataset.rename({'latitude':'lat','longitude':'lon'})

# Disabled time-axis workaround, retained for reference -----
# time_units = 'days since ' + str(year_shift) + '-01-01'

# decoded_time = xr.coding.times.decode_cf_datetime(dataset.time,time_units)
# shifted_time = np.array([np.datetime64(time - relativedelta(years=year_shift)).astype('datetime64[ns]') 
#                          for time in decoded_time])
# dataset.coords['time'] = ('time', shifted_time, {'long_name' : 'time', 'decoded_using' : time_units })

# Resample to monthly frequency -----
da_obsv_raw = dataset.resample(time=resample_freq).mean(dim='time')

# Stack by initial date to match forecast structure -----
fcst_init_dates = da_fcst.init_date.values
da_obsv = utils.stack_by_init_date(da_obsv_raw, fcst_init_dates, n_trunc)
f.value += 1

# Rechunk for chunksizes of at least 1,000,000 elements -----
obsv_chunks = {'init_date' : len(da_obsv.init_date)}
da_obsv = utils.prune(da_obsv.chunk(chunks=obsv_chunks).squeeze())

# SOI

In [None]:
# Development aid: re-execute the already-imported `skill` module so that edits
# to its source are picked up without restarting the kernel. reload() returns
# the module object, which is re-bound to the same name here.
import importlib
skill = importlib.reload(skill)

### Observations
#### Get appropriate years

In [None]:
# Keep only the monthly observations at positions 663 to 707 of the time axis.
# NOTE: da_obsv_raw is overwritten in place, so this cell is not idempotent -
# re-create da_obsv_raw from the loading cell before running it a second time.
obsv_first_index, obsv_last_index = 663, 707
da_obsv_raw = da_obsv_raw.isel(time=range(obsv_first_index, obsv_last_index + 1))

#### Load climatology

In [None]:
# Climatology used to anomalize the observations, requested from the
# 'jra_1958-2016' source for variable 'slp' at month-start ('MS') frequency.
# NOTE(review): the observations loaded above use obsv_variable = 'sst' while
# this climatology is for 'slp' - confirm the two are meant to be combined.
da_obsv_clim = utils.load_mean_climatology('jra_1958-2016', 'slp', freq='MS')

#### Anomalize

In [None]:
# Anomalize the (time-indexed) raw observations against the observed climatology.
da_obsv_raw_anom = utils.anomalize(da_obsv_raw, da_obsv_clim)

In [None]:
def anomalize(data, clim):
    """Anomalize lead-time-indexed data for a single initial date.

    The data are converted from lead time to datetime, passed through
    utils.anomalize() together with `clim`, and converted back to lead time.

    Parameters
    ----------
    data : one 'init_date' group of a lead-time-indexed object
    clim : climatology handed to utils.anomalize()

    Returns
    -------
    The anomalized data, indexed by lead time again.
    """
    data_by_time = utils.leadtime_to_datetime(data)
    return utils.datetime_to_leadtime(utils.anomalize(data_by_time, clim))

# Apply separately to each initial date of the stacked observations -----
da_obsv_anom = da_obsv.groupby('init_date').apply(anomalize, clim=da_obsv_clim)

#### Compute and smooth SOI

In [None]:
# SOI of the raw observations, smoothed with a 3-step rolling mean along time
# and evaluated eagerly with .compute() -----
soi_obsv_raw_unsmoothed = indices.compute_soi(da_obsv_raw_anom, std_dim='time')
soi_obsv_raw = (
    soi_obsv_raw_unsmoothed
    .rolling(time=3, min_periods=3)
    .mean(dim='time')
    .compute()
)

# Same for the observations stacked by initial date, smoothing along lead time -----
soi_obsv_unsmoothed = indices.compute_soi(da_obsv_anom, std_dim='lead_time')
soi_obsv = (
    soi_obsv_unsmoothed
    .rolling(lead_time=3, min_periods=3)
    .mean(dim='lead_time')
    .compute()
)

### Forecasts
#### Load climatology

In [None]:
# Forecast climatology for variable 'slp' at month-start ('MS') frequency.
# NOTE(review): the factor of 100 is presumably a unit conversion matching the
# `100 * da_fcst` scaling applied when the forecasts are anomalized - confirm.
da_fcst_clim = utils.load_mean_climatology('cafe_fcst_v1_atmos_2003-2021', 'slp', freq='MS') * 100

#### Anomalize

In [None]:
def anomalize(data, clim):
    """Anomalize lead-time-indexed data for a single initial date.

    Same computation as the `anomalize` helper defined in the observations
    section: convert from lead time to datetime, pass the result through
    utils.anomalize() together with `clim`, and convert back to lead time.

    Parameters
    ----------
    data : one 'init_date' group of a lead-time-indexed object
    clim : climatology handed to utils.anomalize()

    Returns
    -------
    The anomalized data, indexed by lead time again.
    """
    data_by_time = utils.leadtime_to_datetime(data)
    return utils.datetime_to_leadtime(utils.anomalize(data_by_time, clim))

In [None]:
# Scale the forecasts by 100, then anomalize each initial date separately
# against the forecast climatology -----
da_fcst_scaled = 100 * da_fcst
da_fcst_by_init_date = da_fcst_scaled.groupby('init_date')
da_fcst_anom = da_fcst_by_init_date.apply(anomalize, clim=da_fcst_clim)

#### Compute and smooth SOI

In [None]:
# SOI of the forecast anomalies, smoothed with a 3-step rolling mean along
# lead time and evaluated eagerly with .compute() -----
soi_fcst_unsmoothed = indices.compute_soi(da_fcst_anom, std_dim='lead_time')
soi_fcst = (
    soi_fcst_unsmoothed
    .rolling(lead_time=3, min_periods=3)
    .mean(dim='lead_time')
    .compute()
)

### Forecasts at different lead times

In [None]:
def plot_each(data,ax,col,lw,fade=False):
    """Draw one lead-time-indexed series on `ax` as a colour-graded line.

    The series is converted to a datetime axis with
    utils.leadtime_to_datetime(). If it has an 'ensemble' dimension, all
    members are first drawn in light grey and the ensemble mean is used for
    the graded line. The line is drawn segment by segment, with the colour
    moving towards white from the first segment to the last.

    Parameters
    ----------
    data : lead-time-indexed series, with or without an 'ensemble' dimension
    ax : matplotlib axes to draw on
    col : base colour, one of 'b', 'k', 'r' or 'g'; any other value draws no
        graded line (ensemble members are still drawn)
    lw : line width of the graded line
    fade : if True, the grading starts half-way towards white instead of at
        the full base colour

    Returns
    -------
    The datetime-indexed series that was drawn (the ensemble mean when `data`
    has an 'ensemble' dimension).
    """
    # Convert once and reuse for both the members and their mean
    series = utils.leadtime_to_datetime(data)
    if 'ensemble' in data.dims:
        # Plot ensembles
        ax.plot(series['time'],series.transpose(),color=(0.9,0.9,0.9))
        ts = series.mean(dim='ensemble')
    else:
        ts = series
    x = ts['time'].values
    y = ts.values

    # Grading weight per point: 0 is the full base colour, 1 is white
    T = np.linspace(0.5 if fade is True else 0,1,np.size(x))

    # RGB colour of a segment as a function of its grading weight
    rgb_for_weight = {
        'b': lambda t: (t,t,1),
        'k': lambda t: (t,t,t),
        'r': lambda t: (1,t,t),
        'g': lambda t: (t,1,t),
    }.get(col)

    # Segment plot: one two-point line per consecutive pair, coloured by T
    if rgb_for_weight is not None:
        for i in range(len(x)-1):
            ax.plot(x[i:i+2], y[i:i+2], color=rgb_for_weight(T[i]), linewidth=lw)
    return ts

fig1 = plt.figure(figsize=(10,5))
ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
ax.grid()

# Observations stacked by initial date: thin, faded red lines -----
soi_obsv_selected = soi_obsv.sel(init_date=slice('2013-04', '2016-04'))
lines = soi_obsv_selected.groupby('init_date').apply(
    plot_each, ax=ax, col='r', lw=1, fade=True)

# Full observed SOI time series: thick red line -----
ax.plot(soi_obsv_raw['time'], soi_obsv_raw, 'r-', linewidth=3)

# Forecasts initialised in 2015-11: thick, colour-graded blue line -----
soi_fcst_selected = soi_fcst.sel(init_date='2015-11')
lines = soi_fcst_selected.groupby('init_date').apply(
    plot_each, ax=ax, col='b', lw=3)

ax.set_xlabel('date')
ax.set_ylabel('SOI');

### Skill at different lead times

In [None]:
# Pearson correlation between forecast and observed SOI, computed over lead
# time and then averaged over the ensemble -----
corrcoef_by_member = skill.compute_Pearson_corrcoef(
    soi_fcst, soi_obsv, over_dims=['lead_time'], subtract_local_mean=False)
corrcoef_fcst = corrcoef_by_member.mean('ensemble')

# RMS error of the forecast SOI, computed over ensemble and lead time -----
rms_error = skill.compute_rms_error(
    soi_fcst, soi_obsv, over_dims=['ensemble','lead_time'])

In [None]:
fig1 = plt.figure(figsize=(10,5))

# Upper panel: observed SOI -----
# NOTE(review): bottom 0.9 + height 0.4 places this axes partly above the
# figure canvas; presumably this relies on the inline backend cropping to the
# tight bounding box - confirm before saving the figure by other means.
ax1 = fig1.add_axes([0.1, 0.9, 0.8, 0.4])
ax1.grid()
ax1.plot(soi_obsv_raw['time'],soi_obsv_raw,'r-',linewidth=3)
ax1.set_ylabel('SOI')
ax1.set_xlim(left='2013-06',right='2016-04')

# Lower panel: skill metrics as a function of initial date -----
# Both lines are labelled so that the legend has entries to show.
ax2 = fig1.add_axes([0.1, 0.1, 0.8, 0.7])
ax2.grid()
ax2.plot(corrcoef_fcst['init_date'],corrcoef_fcst,'k-',linewidth=3,
         label='Pearson correlation (mean over ensemble)')
ax2.plot(rms_error['init_date'],rms_error,'b-',linewidth=3,
         label='RMS error')
ax2.set_xlabel('date')
ax2.set_xlim(left='2013-06',right='2016-04')

ax2.legend();

# Close dask client

In [None]:
# with utils.timer():
#     client.close()