# Tutorial demonstrating verification of v1 1000hPa temp against jra55

#### Import pyLatte package

In [1]:
from pylatte import utils
from pylatte import skill

#### Currently, the following packages are required to load the data - this process will be replaced by the CAFE cookbook

In [34]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import warnings    
warnings.filterwarnings("ignore")
from ipywidgets import FloatProgress

# Jupyter specific -----
%matplotlib inline

#### The pyLatte package is constructed around the xarray Python package. This is particularly useful for verification computations which require large numbers of samples (different model runs) to converge. 

#### The approach here is to generate very large xarray objects that reference all data required for the verification, but do not store the data in memory. Operations are performed on these xarray objects out-of-memory. When it is necessary to perform a compute (e.g. to produce a plot), this is distributed over multiple processors using the dask Python package.

# Initialise dask (currently not working on vm31)

In [3]:
# import dask
# import distributed
# client = distributed.Client(local_dir='/tmp/squ027-dask-worker-space', n_workers=4)
# client

# Construct xarray objects for forecasts and observations

#### (The CAFE cookbook will replace these code blocks)

In [19]:
# Resampling details -----
resample_freq = '1MS' # Must be '1MS' for monthly
resample_method = 'mean'

# Location of forecast data -----
fcst_folder = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/'
fcst_filename = 'atmos_daily*'
fcst_variable = 'temp'

# Location of observation data -----
obsv_folder = '/OSM/CBR/OA_DCFP/data/observations/jra55/isobaric/011_tmp/cat/'
obsv_filename = 'jra.55.tmp.1000.1958010100_2016123118.nc'
obsv_variable = 'TMP_GDS0_ISBL'

In [72]:
# Initialization dates (takes approximately 1 min 30 sec per date) -----
init_dates = pd.date_range('2/2002','7/2002' , freq='1MS')  # Must be '1MS' for monthly

# Ensembles to include -----
ensembles = range(1,12)

# Forecast length -----
FCST_LENGTH = 2 # years
no_leap = 2001
n_incr = len(pd.date_range('1/1/' + str(no_leap),
                           '12/1/' + str(no_leap+FCST_LENGTH-1),
                           freq=resample_freq)) # number of lead_time increments
lead_times = range(1,n_incr+1)

### Construct forecasts xarray object
Note, dask has a known bug that manifests when trying to concatentate data containing timedelta64 arrays (see https://github.com/pydata/xarray/issues/1952 for further details). For example, try to concatenate the following two Datasets:

`In : path = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/yr2002/mn7/'`

`In : ens5 = xr.open_mfdataset(path + 'OUTPUT.5/atmos_daily*.nc', autoclose=True)`

`In : ens6 = xr.open_mfdataset(path + 'OUTPUT.6/atmos_daily*.nc', autoclose=True)`

`In : xr.concat([ens5, ens6],'ensemble')`

`Out : TypeError: invalid type promotion`

The error here is actually caused by the variables `average_DT` and `time_bounds`, which are timedelta64 arrays. However, I still do not fully unstand the bug: concatenation of `ens4` and `ens5`, for example, works fine, even though `ens4` also contains the timedelta64 variables `average_DT` and `time_bounds`. Regardless, because of this bug, it is not possible currently to create an xarray Dataset object containing all model variables. Instead, only the variable of interest (i.e. `fcst_variable` and `obsv_variable`) are retained in the concatenated xarray object.

In [79]:
# Instantiate progress bar -----
f = FloatProgress(min=0, max=len(init_dates)*len(ensembles), description='Loading...') 
display(f)

# Loop over initial dates -----
fcst_list = []
for init_date in init_dates:
    year = init_date.year
    month = init_date.month
    
    # Loop over ensembles -----
    ens_list = []
    for ensemble in ensembles:
        # Signal to increment the progress bar -----
        f.value += 1 
        
        # Stack ensembles into a list -----
        path = fcst_folder + '/yr' + str(year) + '/mn' + str(month) + \
               '/OUTPUT.' + str(ensemble) + '/' + fcst_filename + '.nc'
        ens_list.append(xr.open_mfdataset(path, autoclose=True)[fcst_variable])
        
    # Concatenate ensembles -----
    ens_object = xr.concat(ens_list, dim='ensemble')
    ens_object['ensemble'] = ensembles
    
    # Stack concatenated ensembles into a list for each initial date -----
    fcst_list.append(ens_object)

# Keep track of the lead time for each initialization -----
n_lead_time = [len(x.time) for x in fcst_list]

# Concatenate initial dates -----
da_fcst = xr.concat(fcst_list, dim='init_date').rename({'time' : 'lead_time'})
da_fcst['init_date'] = init_dates
freq = pd.infer_freq(da_fcst.lead_time.values)
da_fcst['lead_time'] = range(len(da_fcst.lead_time))
da_fcst['lead_time'].attrs['units'] = freq

# Rechunk for chunksizes of at least 1,000,000 elements -----
da_fcst = da_fcst.chunk(chunks={'ensemble' : len(da_fcst.ensemble), 'lead_time' : len(da_fcst.lead_time)})

#### Truncate the forecast lead times at 2 years
The January and July forecasts are run for 5 years rather than 2 years. The xarray concatenation above can deal with this, but fills the shorter forecasts with nans for lead times longer than 2 years. Let's get rid of some of these nans.

In [87]:
n_trunc = max([i for i in n_lead_time if i < 2000])
da_fcst = da_fcst.isel(lead_time=range(n_trunc))

### Construct observations xarray object

In [90]:
def stack_by_init_date(data_in, init_dates, N_lead_steps, init_date_name='init_date', lead_time_name='lead_time'):
    """ 
    Splits provided data array into n chunks beginning at time=init_date[n] and spanning 
    N_lead_steps time increments
    Input Dataset/DataArray must span full range of times required for this operation
    """

    init_list = []
    for init_date in init_dates:
        print(data_in.time)
        print(np.datetime64(init_date))
        start_index = np.where(data_in.time == np.datetime64(init_date))[0].item()
        init_list.append(
                      datetime_to_leadtime(
                          data_in.isel(time=range(start_index, start_index + N_lead_steps))))
    
    data_out = xr.concat(init_list, dim='ensemble')
    
    # # Initialize xarray object for first initialization date -----
    # start_index = np.where(data_in.time == np.datetime64(init_dates[0]))[0].item()
    # data_out = data_in.isel(time=range(start_index, start_index + N_lead_steps))
    # data_out = datetime_to_leadtime(data_out).expand_dims(init_date_name)
    
    # # Loop over remaining initialization dates -----
    # for init_date in init_dates[1:]:
    #     start_index = np.where(data_in.time == np.datetime64(init_date))[0].item()
    #     data_temp = data_in.isel(time=range(start_index, start_index + N_lead_steps))

        # # Concatenate along initialization date dimension/coordinate -----
        # data_temp = datetime_to_leadtime(data_temp)
        # data_out = xr.concat([data_out, data_temp],init_date_name) 
    
    return data_out

In [91]:
# Instantiate progress bar -----
f = FloatProgress(min=0, max=1, description='Loading...') 
display(f)

# JRA temperature fields are only save in a time-concatenated form -----
path = obsv_folder + obsv_filename
da_obsv = (xr.open_mfdataset(path, autoclose=True)[obsv_variable]).rename(fcst_variable) \
                                                               .rename({'initial_time0_hours' : 'time',
                                                                        'g0_lon_3' : 'lon',
                                                                        'g0_lat_2' : 'lat'})
f.value += 1 

# 

da_obsv = stack_by_init_date(da_obsv,init_dates,n_trunc)

<xarray.DataArray 'time' (time: 21550)>
array(['1958-01-01T09:00:00.000000000', '1958-01-02T09:00:00.000000000',
       '1958-01-03T09:00:00.000000000', ..., '2016-12-29T09:00:00.000000000',
       '2016-12-30T09:00:00.000000000', '2016-12-31T09:00:00.000000000'], dtype='datetime64[ns]')
Coordinates:
  * time     (time) datetime64[ns] 1958-01-01T09:00:00 1958-01-02T09:00:00 ...
Attributes:
    standard_name:  time
    long_name:      initial time
    bounds:         initial_time0_hours_bnds
    axis:           T
2002-02-01T00:00:00.000000


ValueError: can only convert an array of size 1 to a Python scalar

In [16]:
# Compute forecast years -----
fcst_years = [pd.to_datetime(init_dates).year]
fcst_year_min = min(fcst_years)[0]
fcst_year_max = max(fcst_years)[-1] + FCST_LENGTH

# Loop over all years in forecasts -----
year_list = []
for year_to_load in range(fcst_year_min+1,fcst_year_max+1):
    path = obsv_folder + obsv_filename + str(year_to_load) + '*.nc'
    print(path)
    year_list.append(xr.open_mfdataset(path, autoclose=True)[obsv_variable])

# Concatenate years -----
ens_object = xr.concat(year_list)

/OSM/CBR/OA_DCFP/data/observations/jra55/isobaric/011_tmp/cat/jra.55.tmp.1000.1958010100_2016123118.nc2003*.nc


OSError: no files to open

In [12]:

    
ds_obsv_all = xr.open_mfdataset(obsv_folder + obsv_filename, autoclose=True)[obsv_variable]

ds_obsv_all = xr.open_mfdataset(obsv_folder + filename + str(fcst_year_min) + '*', 
                                 autoclose=True)
    for year_to_load in range(fcst_year_min+1,fcst_year_max+1):
        ds_temp2 = xr.open_mfdataset(obsv_folder + filename + str(year_to_load) + '*', 
                                     autoclose=True)
        ds_jra = xr.concat([ds_jra, ds_temp2],'initial_time0_hours')

    # Standardize naming -----
    ds_jra = ds_jra.rename({'initial_time0_hours':'time',
                                      'g0_lon_3':'lon',
                                      'g0_lat_2':'lat',
                                      'TPRAT_GDS0_SFC_ave3h':'precip'})

KeyError: 'temp'

In [None]:
ds_obsv_all = xr.open_mfdataset(obsv_folder + obsv_filename, autoclose=True)[obsv_variable]

# Standardize naming -----
ds_temp1 = ds_temp1.rename({'initial_time0_hours':'time',
                                  'g0_lon_3':'lon',
                                  'g0_lat_2':'lat',
                                  obsv_variable : fcst_variable})

# Resample to desired frequency -----
ds_temp1 = ds_temp1.resample(freq=resample_freq, dim='time', how=resample_method)

# ===============================================
# Stack to resemble ds_forecast coordinates -----
# ===============================================
# Initialize xarray object for first lead_time -----
start_index = np.where(ds_temp1.time == np.datetime64(init_dates[0]))[0].item()
ds_obsv = ds_temp1.isel(time=range(start_index, start_index+len(lead_times)))
ds_obsv.coords['init_date'] = init_dates[0]
ds_obsv = ds_obsv.expand_dims('init_date')
ds_obsv = ds_obsv.rename({'time' : 'lead_time'})
ds_obsv['lead_time'] = lead_times

# Loop over remaining lead_time -----
for init_date in init_dates[1:]:
    start_index = np.where(ds_temp1.time == np.datetime64(init_date))[0].item()
    ds_temp3 = ds_temp1.isel(time=range(start_index, start_index+len(lead_times)))

    # Concatenate along 'lead_time' dimension/coordinate -----
    ds_temp3 = ds_temp3.rename({'time' : 'lead_time'})
    ds_temp3['lead_time'] = lead_times
    ds_temp3.coords['init_date'] = init_date
    ds_obsv = xr.concat([ds_obsv, ds_temp3],'init_date') 

#### Rechunk

In [None]:
with utils.timer():
    # Rechunk for chunksizes of at least 1,000,000 elements -----
    ds_obsv = ds_obsv.chunk(chunks={'init_date' : len(init_dates)})

# Skill metrics for probabilistic forecasts

#### E.g. for temperature at 1000hPa averaged over Australia

In [None]:
with utils.timer():
    # Region of interest -----
    region = (-38.0, -11.0, 113.0 , 153.0) # (lat_min,lat_max,lon_min,lon_max)

    da_fcst = utils.calc_boxavg_latlon(ds_fcst['temp']
                                       .sel(pfull=1000,method='nearest',drop=True), region).compute()-273.15
    da_obsv = utils.calc_boxavg_latlon(ds_obsv['temp']
                                       .squeeze().drop('lv_ISBL1'), region).compute()-273.15

### Plot one initialization date

In [None]:
da_fcst

In [None]:
da_obsv

In [None]:
fig1 = plt.figure(figsize=(10,5))

ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
ax.grid()
ax.plot(da_fcst['lead_time'],da_fcst.isel(init_date=[0]).squeeze())
ax.plot(da_obsv['lead_time'],da_obsv.isel(init_date=[0]).squeeze(),'k-',linewidth=2)
ax.set_xlabel('lead time')
ax.set_ylabel('average temp [K]');

## Rank histogram

#### Rank the data and compute histograms as a function of lead time

In [None]:
with utils.timer():
    rank_histogram = skill.compute_rank_histogram(da_fcst, da_obsv, indep_dims='init_date')

In [None]:
with utils.timer():
    ncol = 4; nrow = int(np.ceil(len(lead_times)/ncol));
    fig, axs = plt.subplots(figsize=(15,15), nrows=nrow, ncols=ncol);

    for idx,ax in enumerate(axs.reshape(-1)): 
        ax.grid()
        ax.bar(rank_histogram.bins,rank_histogram.isel(lead_time=idx, drop=True))
        ax.set_ylim(0,rank_histogram.max())
        ax.text(10.3,0.85*rank_histogram.max(),'mn '+str(idx+1))

        if idx % ncol == 0:
            ax.set_ylabel('count')

        if idx / ncol >= nrow - 1:
            ax.set_xlabel('bins')

#### Rank the data and compute histograms for all lead times

In [None]:
with utils.timer():
    rank_histogram = skill.compute_rank_histogram(da_fcst, da_obsv, indep_dims=('init_date','lead_time'))

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,3))

    ax1 = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax1.grid()
    ax1.bar(rank_histogram.bins,rank_histogram)
    ax1.set_xlabel('bins')
    ax1.set_ylabel('count');

## (Continuous) ranked probability score

In [None]:
with utils.timer():
    # Specify bins for computation of cdf -----
    bins = np.linspace(0,40,100)

    # Compute ranked probability score -----
    rps = skill.compute_rps(da_fcst, da_obsv, bins=bins, indep_dims='init_date', ensemble_dim='ensemble')

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,4))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    ax.plot(rps['lead_time'],rps,linewidth=2)
    ax.set_xlabel('Lead time [months]')
    ax.set_ylabel('Ranked probability score');

# Skill metrics for continuous variables

## Additive bias error

In [None]:
with utils.timer():
    mean_additive_bias = skill.compute_mean_additive_bias(da_fcst, da_obsv, 
                                                          indep_dims='init_date', ensemble_dim='ensemble')

## Multiplicative bias error

In [None]:
with utils.timer():
    mean_multiplicative_bias = skill.compute_mean_multiplicative_bias(da_fcst, da_obsv, 
                                                                      indep_dims='init_date', ensemble_dim='ensemble')

## Mean absolute error

In [None]:
with utils.timer():
    mean_absolute_error = skill.compute_mean_absolute_error(da_fcst, da_obsv, 
                                                            indep_dims='init_date', ensemble_dim='ensemble')

## Mean squared error

In [None]:
with utils.timer():
    mean_squared_error = skill.compute_mean_squared_error(da_fcst, da_obsv, 
                                                          indep_dims='init_date', ensemble_dim='ensemble')

## Root mean squared error

In [None]:
with utils.timer():
    rms_error = skill.compute_rms_error(da_fcst, da_obsv, 
                                        indep_dims='init_date', ensemble_dim='ensemble')

#### Plot as a function of lead time

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,4))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    ax.plot(mean_additive_bias['lead_time'],mean_additive_bias,linewidth=2)
    ax.plot(mean_multiplicative_bias['lead_time'],mean_multiplicative_bias,linewidth=2)
    ax.plot(mean_absolute_error['lead_time'],mean_absolute_error,linewidth=2)
    # ax.plot(mean_squared_error['lead_time'],mean_squared_error,linewidth=2)
    ax.plot(rms_error['lead_time'],rms_error,linewidth=2)
    ax.set_xlabel('Lead time [months]')
    ax.set_ylabel('Error');
    ax.legend();

# Close dask client

In [None]:
# with utils.timer():
#     client.close()