# Tutorial demonstrating verification of CAFE v1 1000 hPa temperature forecasts against JRA-55

#### Import pyLatte package

In [1]:
from pylatte import utils
from pylatte import skill

#### Currently, the following packages are required to load the data - this process will be replaced by the CAFE cookbook

In [2]:
import calendar
import os

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
# import warnings    
# warnings.filterwarnings("ignore")
from ipywidgets import FloatProgress

# Jupyter specific -----
%matplotlib inline

# A note about the methodology of pyLatte
The pyLatte package is constructed around the xarray Python package. This is particularly useful for verification computations which require large numbers of samples (different model runs) to converge. 

The approach here is to generate very large xarray objects that reference all data required for the verification, but do not store the data in memory. Operations are performed on these xarray objects out-of-memory. When it is necessary to perform a compute (e.g. to produce a plot), this is distributed over multiple processors using the dask Python package.

### Important:
To simplify comparison between different datasets, the functions in pyLatte assume that the datetime value given for a particular interval corresponds to the START of that interval. E.g., for monthly-frequency data, `time = ['2002-01-01', '2002-02-01', '2002-03-01', ...]`. The function `utils.trunc_time()` can be used to truncate datetimes to the start of the period corresponding to a provided frequency.

# Initialise dask (currently not working on vm31)

In [3]:
# import dask
# import distributed
# client = distributed.Client(local_dir='/tmp/squ027-dask-worker-space', n_workers=4)
# client

# Construct xarray objects for forecasts and observations

(The CAFE cookbook should replace many of these code blocks)

At the moment, the user must specify the frequency of the data they are choosing to load. This is necessary for keeping track of the correct units of lead time when time information is formatted as lead time/initial date rather than datetime. `pandas` provides a function `pandas.infer_freq()` for determining the most likely frequency from a datetime array. However, this function is fairly limited because of the difficulties in defining '1 month'. For example, `pandas.infer_freq()` is unable to determine the frequency of the ocean_month data, which has time values, e.g., `time = ['2003-01-16T12:00:00', '2003-02-15T00:00:00', '2003-03-16T12:00:00', ...]`. Thus, for now, we leave it up to the user to correctly input the data frequency.

In [95]:
# Location of forecast data -----
# The loading loop below builds paths of the form
# <fcst_folder>/yr<year>/mn<month>/OUTPUT.<ensemble>/<fcst_filename>.nc
fcst_folder = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/'
fcst_filename = 'atmos_daily*' # glob pattern; '.nc' is appended when the path is built
fcst_variable = 'temp' # variable kept from each forecast file
fcst_freq = 'D' # e.g. 'A', '3M', 'M', '7D', '6H'...

# Location of observation data -----
obsv_folder = '/OSM/CBR/OA_DCFP/data/observations/jra55/isobaric/011_tmp/cat/'
obsv_filename = 'jra.55.tmp.1000.1958010100_2016123118.nc'
obsv_variable = 'TMP_GDS0_ISBL' # variable name inside the JRA-55 file
obsv_freq = 'D' # NOTE(review): not referenced by any other cell in this notebook

In [88]:
# Initial dates to include (takes approximately 1 min 30 sec per date) -----
# Month-start ('MS') dates from 2003-01 to 2003-12 inclusive
init_dates = pd.date_range('2003-1','2003-12' , freq='1MS')

# Ensembles to include (members 1..11; range excludes its stop value) -----
ensembles = range(1,12)

# Forecast length -----
FCST_LENGTH = 2 # years; used to truncate lead times after loading
# lead_times = utils.get_lead_times(FCST_LENGTH, resample_freq)
# (kept for reference only: `lead_times` is not defined while the line above is commented out)

In [89]:
# Resampling details -----
# NOTE(review): `resample_freq` is reassigned ('5D') in the downsampling prototype cell
# further down, so its value depends on which cells have been run.
resample_freq = 'M'
resample_method = 'mean' # NOTE(review): not referenced by any other cell in this notebook

### Construct forecasts xarray object
Note, dask has a known bug that manifests when trying to concatenate data containing timedelta64 arrays (see https://github.com/pydata/xarray/issues/1952 for further details). For example, try to concatenate the following two Datasets:

`In : path = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/yr2002/mn7/'`

`In : ens5 = xr.open_mfdataset(path + 'OUTPUT.5/atmos_daily*.nc', autoclose=True)`

`In : ens6 = xr.open_mfdataset(path + 'OUTPUT.6/atmos_daily*.nc', autoclose=True)`

`In : xr.concat([ens5, ens6],'ensemble')`

`Out : TypeError: invalid type promotion`

The error here is actually caused by the variables `average_DT` and `time_bounds`, which are timedelta64 arrays. However, I still do not fully understand the bug: concatenation of `ens4` and `ens5`, for example, works fine, even though `ens4` also contains the timedelta64 variables `average_DT` and `time_bounds`. Regardless, because of this bug, it is currently not possible to create an xarray Dataset object containing all model variables. Instead, only the variables of interest (i.e. `fcst_variable` and `obsv_variable`) are retained in the concatenated xarray objects.

In [90]:
# Instantiate progress bar -----
f = FloatProgress(min=0, max=len(init_dates)*len(ensembles), description='Loading...') 
display(f)

# Loop over initial dates -----
fcst_list = []
for init_date in init_dates:
    year = init_date.year
    month = init_date.month

    # Loop over ensembles -----
    ens_list = []
    for ensemble in ensembles:
        # Build the path with os.path.join so that the trailing '/' on fcst_folder does not
        # produce a doubled separator -----
        path = os.path.join(fcst_folder, 'yr' + str(year), 'mn' + str(month),
                            'OUTPUT.' + str(ensemble), fcst_filename + '.nc')

        # Keep only the variable of interest: concatenating whole Datasets trips the
        # timedelta64 concatenation bug described above -----
        ens_list.append(xr.open_mfdataset(path, autoclose=True)[fcst_variable])

        # Signal to increment the progress bar once this ensemble has been opened -----
        f.value += 1

    # Concatenate ensembles and label them by ensemble number -----
    ens_object = xr.concat(ens_list, dim='ensemble')
    ens_object['ensemble'] = list(ensembles)

    # Stack concatenated ensembles into a list for each initial date -----
    fcst_list.append(ens_object)

# Keep track of the number of lead times (time steps) of each initialization -----
n_lead_time = [len(x.time) for x in fcst_list]

# Keep track of initial dates truncated to frequency of runs -----
start_dates = np.array([utils.trunc_time(x.time.values[0], fcst_freq) for x in fcst_list])

# Concatenate initial dates; runs shorter than the longest one are NaN-filled by xr.concat -----
da_fcst = xr.concat(fcst_list, dim='init_date').rename({'time' : 'lead_time'})
da_fcst['init_date'] = start_dates
# Lead times are integer step counts (in units of fcst_freq) from the first time step -----
da_fcst['lead_time'] = range(len(da_fcst.lead_time))
da_fcst['lead_time'].attrs['units'] = fcst_freq

# Rechunk so each chunk spans the whole ensemble and lead_time dimensions -----
da_fcst = utils.prune(da_fcst.chunk(chunks={'ensemble' : len(da_fcst.ensemble), 
                                            'lead_time' : len(da_fcst.lead_time)}).squeeze())

#### Truncate the forecast lead times at 2 years
The January and July forecasts are run for 5 years rather than 2 years. The xarray concatenation above can deal with this, but fills the shorter forecasts with nans for lead times longer than 2 years. Let's get rid of some of these nans by truncating the forecasts at the lead time corresponding to the longest 2 year forecast.

In [91]:
# Lead times are counted in fcst_freq steps; 366 steps per year assumes daily data (fcst_freq = 'D') -----
max_increments = FCST_LENGTH * 366

# Longest run that is no longer than FCST_LENGTH years (the 2-year runs) -----
short_runs = [i for i in n_lead_time if i <= max_increments]
if not short_runs:
    raise ValueError(f'No forecast has at most {max_increments} lead times; '
                     'check FCST_LENGTH and fcst_freq.')
n_trunc = max(short_runs)

# Drop the trailing lead times that only the longer (5-year) runs populate -----
da_fcst = da_fcst.isel(lead_time=slice(0, n_trunc))

### Construct observations xarray object

In [96]:
# Instantiate progress bar -----
f = FloatProgress(min=0, max=1, description='Loading...') 
display(f)

# JRA temperature fields are only saved in a time-concatenated form. The DataArray is renamed
# to fcst_variable and its JRA dimension names are mapped to the forecast names -----
path = os.path.join(obsv_folder, obsv_filename)
da_obsv = (xr.open_mfdataset(path, autoclose=True)[obsv_variable]
             .rename(fcst_variable)
             .rename({'initial_time0_hours' : 'time',
                      'g0_lon_3' : 'lon',
                      'g0_lat_2' : 'lat'}))

# Truncate jra frequency to forecast frequency -----
da_obsv['time'] = utils.trunc_time(da_obsv['time'].values, da_fcst.lead_time.attrs['units'])

# Stack by initial date to match forecast structure -----
da_obsv = utils.stack_by_init_date(da_obsv, da_fcst.init_date.values, n_trunc)
f.value += 1

# Average over forecast dimension if it exists. da_obsv is a DataArray (presumably still one
# after stack_by_init_date), so reduce it directly: indexing it by obsv_variable would look up
# a non-existent coordinate -----
if 'forecast_time1' in da_obsv.dims:
    da_obsv = da_obsv.mean(dim='forecast_time1', keep_attrs=True)

# Rechunk so each chunk spans the whole init_date dimension -----
da_obsv = utils.prune(da_obsv.chunk(chunks={'init_date' : len(da_obsv.init_date)}).squeeze())

### Resample forecast and observations to desired frequency
The data are currently stored in lead time/initial date format. The easiest and fastest way to perform the resampling is to use `xr.resample()`, which requires that time information be in datetime format (note, this process is wrapped in the pyLatte `utils.resample()` function). Thus it is necessary to first convert from the lead time/initial date format to a datetime format, then resample the data, then convert back to the lead time/initial date format. The `utils.leadtime_to_datetime()` and `utils.datetime_to_leadtime()` functions enable these types of operation.

In [97]:
def month_delta(date_in, delta, trunc_to_start=False):
    """ Increments a single datetime value by delta months.

        Parameters
        ----------
        date_in : scalar datetime-like accepted by pd.Timestamp (e.g. np.datetime64)
        delta : int, number of months to add (may be negative)
        trunc_to_start : bool, if True the result is truncated to the start of its month
            with utils.trunc_time(..., 'M')

        Returns
        -------
        np.datetime64 shifted by delta months. The day of month is clipped to the length of
        the target month, e.g. 31 Jan + 1 month -> 28 Feb (29 Feb in leap years).
    """
    date_mod = pd.Timestamp(date_in)

    # Use a zero-based month index so divmod yields the year carry directly,
    # including for negative deltas -----
    year_carry, month_index = divmod(date_mod.month - 1 + delta, 12)
    y = date_mod.year + year_carry
    m = month_index + 1

    # Clip the day to the number of days in the target month; calendar.monthrange applies the
    # full Gregorian leap-year rule (1900 is not a leap year, 2000 is) -----
    d = min(date_mod.day, calendar.monthrange(y, m)[1])

    date_out = np.datetime64(date_mod.replace(year=y, month=m, day=d))
    if trunc_to_start:
        date_out = utils.trunc_time(date_out, 'M')
    return date_out

In [98]:
def year_delta(date_in, delta, trunc_to_start=False):
    """ Shifts a single datetime value by delta whole years.

        The shift is delegated to month_delta as delta * 12 months, so day-of-month clipping
        (e.g. 29 Feb -> 28 Feb) follows month_delta. If trunc_to_start is True, the shifted
        value is truncated to the start of its year with utils.trunc_time(..., 'Y').
    """
    shifted = month_delta(date_in, delta * 12)
    return utils.trunc_time(shifted, 'Y') if trunc_to_start else shifted

In [103]:
def leadtime_to_datetime(data, lead_time_name='lead_time', init_date_name='init_date'):
    """ Converts time information from lead time/initial date dimension pair to single datetime dimension

        Parameters
        ----------
        data : xarray object with `lead_time_name` and `init_date_name` coordinates. The lead
            time coordinate must carry a 'units' attribute holding a frequency string such as
            'D', '2D', 'M', '3M', 'MS', 'A' or 'Y'.
        lead_time_name : str, name of the lead time dimension
        init_date_name : str, name of the initial date coordinate

        Returns
        -------
        `data` with `init_date_name` dropped and `lead_time_name` renamed to 'time' (holding
        datetime64 values), passed through utils.prune.

        Notes
        -----
        Only the FIRST initial date in `data` is used to build the datetimes.
        For month ('M', 'MS') and year ('A', 'AS', 'Y', 'YS') frequencies, lead time value ix
        maps to init_date shifted by incr * ix months/years (via month_delta/year_delta). All
        other frequencies use pd.date_range, which only uses the NUMBER of lead times, i.e. the
        lead time values are assumed to be 0, 1, 2, ... steps of the frequency.
    """
    init_date = data[init_date_name].values[0]
    lead_times = [int(ix) for ix in data[lead_time_name].values]
    freq = data[lead_time_name].attrs['units']

    # Split the frequency into its integer multiple and its alias, e.g. '3M' -> (3, 'M'),
    # 'MS' -> (1, 'MS'). This matches how resample frequencies are parsed elsewhere in this
    # notebook and, unlike freq.replace('M', ''), does not choke on aliases such as 'MS' -----
    incr_string = ''.join([c for c in freq if c.isdigit()])
    incr = int(incr_string) if incr_string else 1
    freq_type = ''.join([c for c in freq if not c.isdigit()])

    # Deal with special cases of monthly and yearly frequencies -----
    if freq_type in ('M', 'MS'):
        datetimes = np.array([month_delta(init_date, incr * ix) for ix in lead_times])
    elif freq_type in ('A', 'AS', 'Y', 'YS'):
        datetimes = np.array([year_delta(init_date, incr * ix) for ix in lead_times])
    else:
        datetimes = (pd.date_range(init_date, periods=len(lead_times), freq=freq)).values

    data = data.drop(init_date_name)
    data = data.rename({lead_time_name : 'time'})
    data['time'] = datetimes

    return utils.prune(data)

In [335]:
# Prototype of a `downsample_complete` function (see the commented-out def below): resample
# `data` to `resample_freq` with `how`, keeping only output bins that contain a complete set of
# input time steps.
# NOTE(review): as saved, this cell raises IndexError at the `keep` selection (see output);
# the cause is explained at the `dummy_incr` / `keep` lines below.

# Test input: first initial date only, every second daily lead time, relabelled as 2-day steps.
# leadtime_to_datetime's date_range branch uses only the number of lead times, so the resulting
# datetimes are spaced 2 days apart.
temp = da_fcst.isel(init_date=[0],lead_time=range(0,731,2))
temp.lead_time.attrs['units'] = '2D'

data = leadtime_to_datetime(temp)
# NOTE(review): overwrites the `resample_freq` set in the "Resampling details" cell
resample_freq='5D'
how='sum'
input_freq = None

#def downsample_complete(data, resample_freq, how, input_freq=None):

dates = data.time.values

# Try to infer input frequency -----
# NOTE(review): `is None` is the idiomatic comparison here
if input_freq == None:
    input_freq = utils.infer_freq(dates)
if input_freq == None:
    raise ValueError('Unable to infer input frequency. Please specify this explicity.')

# Split frequencies into numbers and strings -----
# e.g. '5D' -> (5, 'D'); a missing number means 1. `[x][0]` is just x.
incr_string = ''.join([i for i in resample_freq if i.isdigit()])
resample_incr = [int(incr_string) if incr_string else 1][0]
resample_type = ''.join([i for i in resample_freq if not i.isdigit()])

incr_string = ''.join([i for i in input_freq if  i.isdigit()])
input_incr = [int(incr_string) if incr_string else 1][0]
input_type = ''.join([i for i in input_freq if not i.isdigit()])

# Construct dummy date array to determine complete number of increments in each resample bin -----
# The dummy extends the data's time grid by one resample period on each side, so every bin it
# shares with the data is fully populated; its per-bin counts are the "expected" counts.
# The three branches below differ only in how `start`/`end` are computed.
if 'M' in resample_type: # Deal with special case of months
    start = month_delta(dates[0],-resample_incr)
    end = month_delta(dates[-1],resample_incr)
    
    # Ensure dummy_dates align with dates in overlap region -----
    # Shift the left padding so its last element equals dates[0] (keeps it on the data grid);
    # the right padding starts at dates[-1], hence dates[1:-1] in the middle.
    left_chunk = (pd.date_range(start, dates[0], freq = input_freq)).values
    left_chunk_aligned = left_chunk + (dates[0] - left_chunk[-1])
    right_chunk_aligned = (pd.date_range(dates[-1], end, freq = input_freq)).values
    dummy_dates = np.concatenate([left_chunk_aligned, dates[1:-1], right_chunk_aligned])
    
elif ('A' in resample_type) | ('Y' in resample_type): # Deal with special case of years
    start = year_delta(dates[0],-resample_incr)
    end = year_delta(dates[-1],resample_incr)
    
    # Ensure dummy_dates align with dates in overlap region -----
    left_chunk = (pd.date_range(start, dates[0], freq = input_freq)).values
    left_chunk_aligned = left_chunk + (dates[0] - left_chunk[-1])
    right_chunk_aligned = (pd.date_range(dates[-1], end, freq = input_freq)).values
    dummy_dates = np.concatenate([left_chunk_aligned, dates[1:-1], right_chunk_aligned])
    
else:
    start = dates[0] - pd.Timedelta(resample_incr, unit = resample_type)
    end = dates[-1] + pd.Timedelta(resample_incr, unit = resample_type)
    
    # Ensure dummy_dates align with dates in overlap region -----
    left_chunk = (pd.date_range(start, dates[0], freq = input_freq)).values
    left_chunk_aligned = left_chunk + (dates[0] - left_chunk[-1])
    right_chunk_aligned = (pd.date_range(dates[-1], end, freq = input_freq)).values
    dummy_dates = np.concatenate([left_chunk_aligned, dates[1:-1], right_chunk_aligned])
    # Debug output (see saved output below)
    print(dates)
    print(dummy_dates)
    
# Package dummy date array as xarray object and resample -----
dummy = xr.DataArray(np.zeros(dummy_dates.shape), coords=[dummy_dates], dims='time')
dummy_sampled = dummy.resample(time=resample_freq)
data_sampled = data.resample(time=resample_freq)

# Find and compare number of increments in each dummy bin and data bin -----
# NOTE(review): `[1:-1]` assumes the dummy has exactly one extra bin on each side AND that the
# dummy and data bins share the same edges. The printed counts show this fails here: the dummy
# counts are out of phase with the data counts ([2, 3, ...] vs [3, 2, ...]) and one element
# shorter (146 vs 147) -- presumably because the dummy's bins are anchored on its own first
# timestamp rather than on dates[0]. Matching bins by label, or counting dummy_dates within each
# data bin's edges, would be more robust.
dummy_incr = [len(dummy_bin.time) for name, dummy_bin in dummy_sampled][1:-1]
data_incr = [len(data_bin.time) for name, data_bin in data_sampled]
data_bins = [name for name, data_bin in data_sampled]
# NOTE(review): zip silently truncates to the shorter list, so `keep` can be shorter than the
# number of data bins -- this is what triggers the IndexError recorded below.
keep = [dum == dat for (dum, dat) in zip(dummy_incr, data_incr)]
print(dummy_incr)
print(data_incr)

# Perform resampling according to specified method -----
if how == 'mean':
    data_resampled = data_sampled.mean(dim='time',keep_attrs=True)
elif how == 'sum':
    data_resampled = data_sampled.sum(dim='time',keep_attrs=True)
else:
    raise ValueError(f'Unrecognised "how" method: {how}. Please feel free to add methods.')

# Strangely, xarray.resample().how() adds an additional time step to the beginning or end of 
# the data (depending on whether the resample frequency is a start or end frequency) when the 
# time interval of the data being resampled is wholly divisible by the resampling frequency.
# Data at this time step are all nans. Let's make sure we only keep output time steps that 
# exist in the xarray.core.resample.DataArrayResample object -----
data_resampled = data_resampled.sel(time=slice(str(data_bins[0]), str(data_bins[-1])))

# Only keep resample bins that are complete -----
data_resampled = data_resampled.sel(time=data_resampled.time[keep])

print(data_resampled)
print(keep)



['2003-01-01T00:00:00.000000000' '2003-01-03T00:00:00.000000000'
 '2003-01-05T00:00:00.000000000' '2003-01-07T00:00:00.000000000'
 '2003-01-09T00:00:00.000000000' '2003-01-11T00:00:00.000000000'
 '2003-01-13T00:00:00.000000000' '2003-01-15T00:00:00.000000000'
 '2003-01-17T00:00:00.000000000' '2003-01-19T00:00:00.000000000'
 '2003-01-21T00:00:00.000000000' '2003-01-23T00:00:00.000000000'
 '2003-01-25T00:00:00.000000000' '2003-01-27T00:00:00.000000000'
 '2003-01-29T00:00:00.000000000' '2003-01-31T00:00:00.000000000'
 '2003-02-02T00:00:00.000000000' '2003-02-04T00:00:00.000000000'
 '2003-02-06T00:00:00.000000000' '2003-02-08T00:00:00.000000000'
 '2003-02-10T00:00:00.000000000' '2003-02-12T00:00:00.000000000'
 '2003-02-14T00:00:00.000000000' '2003-02-16T00:00:00.000000000'
 '2003-02-18T00:00:00.000000000' '2003-02-20T00:00:00.000000000'
 '2003-02-22T00:00:00.000000000' '2003-02-24T00:00:00.000000000'
 '2003-02-26T00:00:00.000000000' '2003-02-28T00:00:00.000000000'
 '2003-03-02T00:00:00.000

  label=label, base=base)


[2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3]
[3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 1]


IndexError: Boolean array size 146 is used to index array with shape (147,).

In [247]:
# Pull out two single-time fields (last pfull level, first ensemble member) for inspection.
# NOTE(review): relies on `data_resampled` left over from the prototype cell above; that cell
# raises at the `keep` selection, so this is the resample before incomplete bins are removed.
first = data_resampled.isel(time=[0], pfull = [-1], ensemble = [0]).squeeze().compute()
last = data_resampled.isel(time=[12], pfull = [-1], ensemble = [0]).squeeze().compute()


In [250]:
# Display `last` (in the saved output: all-NaN at time 2005-01-01)
last

<xarray.DataArray 'temp' (lat: 90, lon: 144)>
array([[ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       ..., 
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan]])
Coordinates:
    time      datetime64[ns] 2005-01-01
  * lon       (lon) float64 1.25 3.75 6.25 8.75 11.25 13.75 16.25 18.75 ...
  * lat       (lat) float64 -89.49 -87.98 -85.96 -83.93 -81.91 -79.89 -77.87 ...
    pfull     float64 996.1
    ensemble  int64 1
Attributes:
    long_name:      temperature
    units:          deg_K
    valid_range:    [ 100.  350.]
    cell_methods:   time: mean
    time_avg_info:  average_T1,average_T2,average_DT

In [321]:
# Scratch check of the values produced by pd.date_range(..., closed='left') on a 2-day grid
pd.date_range(start='1/1/2000', end='2/1/2000', freq = '2D', closed = 'left').values

array(['2000-01-01T00:00:00.000000000', '2000-01-03T00:00:00.000000000',
       '2000-01-05T00:00:00.000000000', '2000-01-07T00:00:00.000000000',
       '2000-01-09T00:00:00.000000000', '2000-01-11T00:00:00.000000000',
       '2000-01-13T00:00:00.000000000', '2000-01-15T00:00:00.000000000',
       '2000-01-17T00:00:00.000000000', '2000-01-19T00:00:00.000000000',
       '2000-01-21T00:00:00.000000000', '2000-01-23T00:00:00.000000000',
       '2000-01-25T00:00:00.000000000', '2000-01-27T00:00:00.000000000',
       '2000-01-29T00:00:00.000000000', '2000-01-31T00:00:00.000000000'], dtype='datetime64[ns]')

In [None]:
def resamp(data, freq, how):
    """ Resample data held in lead time/initial date format to frequency `freq` using `how`.

        Converts to a datetime axis with utils.leadtime_to_datetime, resamples with
        utils.resample (the bare name `resample` is not defined in this notebook), then
        converts back with utils.datetime_to_leadtime.
    """
    return utils.datetime_to_leadtime(
               utils.resample(
                   utils.leadtime_to_datetime(data), freq, how))

# Skill metrics for probabilistic forecasts

#### E.g. for temperature at 1000hPa averaged over Australia

In [None]:
with utils.timer():
    # Region of interest -----
    region = (-38.0, -11.0, 113.0 , 153.0) # (lat_min,lat_max,lon_min,lon_max)

    # Select 1000hPa (nearest model level), average over the region and convert K -> degC.
    # The notebook builds the DataArrays da_fcst / da_obsv (there is no ds_fcst / ds_obsv).
    # NOTE(review): this rebinds da_fcst / da_obsv, so re-running the cell requires re-running
    # the loading cells first.
    da_fcst = utils.calc_boxavg_latlon(da_fcst.sel(pfull=1000, method='nearest', drop=True),
                                       region).compute() - 273.15
    da_obsv = utils.calc_boxavg_latlon(da_obsv.squeeze().drop('lv_ISBL1'),
                                       region).compute() - 273.15

### Plot one initialization date

In [None]:
# Display the forecast DataArray (expected to be the box-averaged series from the cell above)
da_fcst

In [None]:
# Display the observation DataArray (expected to be the box-averaged series from the cell above)
da_obsv

In [None]:
fig1 = plt.figure(figsize=(10,5))

# All ensemble members (coloured) and the observations (black) for the first initial date -----
ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
ax.grid()
ax.plot(da_fcst['lead_time'],da_fcst.isel(init_date=[0]).squeeze())
ax.plot(da_obsv['lead_time'],da_obsv.isel(init_date=[0]).squeeze(),'k-',linewidth=2)
ax.set_xlabel('lead time')
# The box-average cell subtracts 273.15, so the plotted values are in degC, not K -----
ax.set_ylabel('average temp [degC]');

## Rank histogram

#### Rank the data and compute histograms as a function of lead time

In [None]:
with utils.timer():
    # Only init_date is passed as an independent dimension here; the "all lead times"
    # section below also passes lead_time -----
    rank_kwargs = {'indep_dims': 'init_date'}
    rank_histogram = skill.compute_rank_histogram(da_fcst, da_obsv, **rank_kwargs)

In [None]:
with utils.timer():
    # One panel per lead time on a grid with ncol columns. The number of panels comes from the
    # histogram itself (`lead_times` is not defined in this notebook) -----
    n_lead = len(rank_histogram.lead_time)
    ncol = 4
    nrow = int(np.ceil(n_lead / ncol))
    fig, axs = plt.subplots(figsize=(15,15), nrows=nrow, ncols=ncol, squeeze=False);

    # Common y limit so panels are directly comparable (computed once, not per panel) -----
    hist_max = rank_histogram.max()

    for idx, ax in enumerate(axs.reshape(-1)):
        # The grid can have more panels than lead times: hide the spare ones instead of
        # indexing past the end of lead_time -----
        if idx >= n_lead:
            ax.set_visible(False)
            continue

        ax.grid()
        ax.bar(rank_histogram.bins, rank_histogram.isel(lead_time=idx, drop=True))
        ax.set_ylim(0, hist_max)
        # NOTE(review): the 'mn' label and x position 10.3 assume monthly lead times and
        # roughly 11 rank bins -- confirm for other configurations.
        ax.text(10.3, 0.85*hist_max, 'mn '+str(idx+1))

        if idx % ncol == 0:
            ax.set_ylabel('count')

        # Label the x axis on the lowest visible panel of each column -----
        if idx + ncol >= n_lead:
            ax.set_xlabel('bins')

#### Rank the data and compute histograms for all lead times

In [None]:
with utils.timer():
    # Both init_date and lead_time are passed as independent dimensions -----
    pooled_dims = ('init_date', 'lead_time')
    rank_histogram = skill.compute_rank_histogram(da_fcst, da_obsv, indep_dims=pooled_dims)

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8, 3))
    ax1 = fig1.add_axes([0.1, 0.1, 0.8, 0.8])

    # Single rank histogram computed over all initial dates and lead times -----
    ax1.bar(rank_histogram.bins, rank_histogram)
    ax1.grid()
    ax1.set(xlabel='bins', ylabel='count');

## (Continuous) ranked probability score

In [None]:
with utils.timer():
    # 100 evenly spaced bin edges from 0 to 40 used to build the cdfs
    # (presumably degC, following the box-average conversion) -----
    bins = np.linspace(start=0, stop=40, num=100)

    # Ranked probability score per lead time -----
    rps_kwargs = {'bins': bins, 'indep_dims': 'init_date', 'ensemble_dim': 'ensemble'}
    rps = skill.compute_rps(da_fcst, da_obsv, **rps_kwargs)

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8, 4))
    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])

    # Ranked probability score as a function of lead time -----
    # NOTE(review): the x label assumes monthly lead times
    ax.plot(rps['lead_time'], rps, linewidth=2)
    ax.grid()
    ax.set(xlabel='Lead time [months]', ylabel='Ranked probability score');

# Skill metrics for continuous variables

## Additive bias error

In [None]:
with utils.timer():
    # Same sampling/ensemble dimensions as the other continuous-variable metrics -----
    metric_kwargs = {'indep_dims': 'init_date', 'ensemble_dim': 'ensemble'}
    mean_additive_bias = skill.compute_mean_additive_bias(da_fcst, da_obsv, **metric_kwargs)

## Multiplicative bias error

In [None]:
with utils.timer():
    # Same sampling/ensemble dimensions as the other continuous-variable metrics -----
    metric_kwargs = {'indep_dims': 'init_date', 'ensemble_dim': 'ensemble'}
    mean_multiplicative_bias = skill.compute_mean_multiplicative_bias(da_fcst, da_obsv, **metric_kwargs)

## Mean absolute error

In [None]:
with utils.timer():
    # Same sampling/ensemble dimensions as the other continuous-variable metrics -----
    metric_kwargs = {'indep_dims': 'init_date', 'ensemble_dim': 'ensemble'}
    mean_absolute_error = skill.compute_mean_absolute_error(da_fcst, da_obsv, **metric_kwargs)

## Mean squared error

In [None]:
with utils.timer():
    # Same sampling/ensemble dimensions as the other continuous-variable metrics -----
    metric_kwargs = {'indep_dims': 'init_date', 'ensemble_dim': 'ensemble'}
    mean_squared_error = skill.compute_mean_squared_error(da_fcst, da_obsv, **metric_kwargs)

## Root mean squared error

In [None]:
with utils.timer():
    # Same sampling/ensemble dimensions as the other continuous-variable metrics -----
    metric_kwargs = {'indep_dims': 'init_date', 'ensemble_dim': 'ensemble'}
    rms_error = skill.compute_rms_error(da_fcst, da_obsv, **metric_kwargs)

#### Plot as a function of lead time

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,4))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    # Every curve needs a label, otherwise ax.legend() below has no handles to show -----
    ax.plot(mean_additive_bias['lead_time'], mean_additive_bias, linewidth=2,
            label='mean additive bias')
    ax.plot(mean_multiplicative_bias['lead_time'], mean_multiplicative_bias, linewidth=2,
            label='mean multiplicative bias')
    ax.plot(mean_absolute_error['lead_time'], mean_absolute_error, linewidth=2,
            label='mean absolute error')
    # ax.plot(mean_squared_error['lead_time'], mean_squared_error, linewidth=2,
    #         label='mean squared error')
    ax.plot(rms_error['lead_time'], rms_error, linewidth=2,
            label='RMS error')
    ax.set_xlabel('Lead time [months]')
    ax.set_ylabel('Error');
    ax.legend();

# Close dask client

In [None]:
# with utils.timer():
#     client.close()