# Tutorial demonstrating verification of v1 precip against jra55

#### Import pyLatte package

In [1]:
from pylatte import utils
from pylatte import skill

#### Currently, the following packages are required to load the data - this process will be replaced by the CAFE cookbook

In [2]:
import numpy as np
import pandas as pd
import xarray as xr

#### Import some plotting packages

In [3]:
import matplotlib.pyplot as plt
import warnings    
warnings.filterwarnings("ignore")

# Jupyter specific -----
%matplotlib inline

# A note about the methodology of pyLatte
The pyLatte package is constructed around the xarray Python package. This is particularly useful for verification computations, which require large numbers of samples (different model runs) to converge. 

The approach here is to generate very large xarray objects that reference all data required for the verification, but do not store the data in memory. Operations are performed on these xarray objects out-of-memory. When it is necessary to perform a compute (e.g. to produce a plot), this is distributed over multiple processors using the dask Python package.

# Initialise dask (currently not working on vm31)

In [4]:
# import dask
# import distributed
# client = distributed.Client(local_dir='/tmp/squ027-dask-worker-space', n_workers=4)
# client

# Construct xarray objects for forecasts and observations
(The CAFE cookbook will replace these code blocks)

In [10]:
# Resampling details -----
# resample_freq is passed to the xarray resample calls below and is also spliced into the
# numpy datetime64 dtype string '<M8[...]' when the time axis is truncated
resample_freq = 'M'
resample_method = 'sum'

# Location of forecast data -----
# Member files are looked up as
#   <fcst_folder>/yr<year>/mn<month>/OUTPUT.<ensemble>/<fcst_filename>
# by the forecast-loading cell
fcst_folder = '/OSM/CBR/OA_DCFP/data/model_output/CAFE/forecasts/v1/'
fcst_filename = 'atmos_daily*'

# Location of observation data -----
# Observation files are globbed as <obsv_folder><obsv_filename><year>*
# NOTE(review): the folder is named for 'tprat' and the variable later renamed to 'precip' is
# 'TPRAT_GDS0_SFC_ave3h', yet this prefix contains 'prmsl' - confirm it matches the files
# that actually live in obsv_folder
obsv_folder = '/OSM/CBR/OA_DCFP/data/observations/jra55/isobaric/061_tprat/'
obsv_filename = 'anl_surf125.002_prmsl.'

In [13]:
# Initial dates (takes approximately 1 min 30 sec per date) -----
# Month-start ('1MS') forecast start dates from 2002-06 to 2005-06
init_dates = pd.date_range('2002-06','2005-06' , freq='1MS')

# Ensembles to include -----
# range(1,12) -> members 1 to 11; each value is used in the 'OUTPUT.<n>' directory name
ensembles = range(1,12)

# Forecast length -----
FCST_LENGTH = 2 # years
# Lead times for a forecast of FCST_LENGTH years at the resampling frequency; the number of
# entries sets how many resampled time steps are kept per forecast
lead_times = utils.get_lead_times(FCST_LENGTH, resample_freq)

### Construct forecasts xarray object

In [14]:
# import pdb, traceback, sys

# ==================================================
# Initialize xarray object for first init_date -----
# ==================================================
with utils.timer():
    first_init = init_dates[0]
    print(f'Gathering data for forecast started on {first_init.month}-{first_init.year}...')

    # All ensemble members of one initial date share the same year/month directory -----
    member_root = fcst_folder + '/yr' + str(first_init.year) + '/mn' + str(first_init.month)

    # Seed the forecast object with the first ensemble member -----
    ds_fcst = xr.open_mfdataset(member_root + '/OUTPUT.' + str(ensembles[0]) + '/' + fcst_filename,
                                autoclose=True)
    ds_fcst.coords['ensemble'] = ensembles[0]

    # Append the remaining members one at a time along 'ensemble' -----
    for ensemble in ensembles[1:]:
        ds_temp = xr.open_mfdataset(member_root + '/OUTPUT.' + str(ensemble) + '/' + fcst_filename,
                                    autoclose=True)
        ds_temp.coords['ensemble'] = ensemble
        ds_fcst = xr.concat([ds_fcst, ds_temp], 'ensemble')

    # Resample, keep one time step per lead time, then swap the time axis for lead time -----
    ds_fcst = ds_fcst.resample(freq=resample_freq, dim='time', how=resample_method)
    ds_fcst = ds_fcst.isel(time=range(len(lead_times)))
    ds_fcst['time'] = ds_fcst['time'].values.astype('<M8[' + resample_freq + ']')
    ds_fcst = utils.datetime_to_leadtime(ds_fcst)
    ds_fcst = ds_fcst.expand_dims('init_date')
    
# ==============================================
# Loop over remaining initialization dates -----
# ==============================================
# Build one ensemble-stacked dataset per remaining initial date and append it to ds_fcst
# along 'init_date'.
# NOTE(review): the recorded output of this cell ends with
# "ValueError: too many different dimensions to concatenate: {'lat', 'latb'}", which the
# TypeError handler below does not catch.
for init_date in init_dates[1:]:
    with utils.timer():
        year = init_date.year
        month = init_date.month
        print(f'Gathering data for forecast started on {month}-{year}...')

        # There is a bug in xarray that causes an 'invalid type promotion' sometimes when concatenating 
        # The following while loop provides a work-around: members are concatenated until a
        # TypeError is raised, the members gathered so far are stored in ds_chunk, and
        # concatenation restarts from the member that failed
        more_ensembles = True
        first_chunk = True
        # NOTE(review): current_ensemble is used as a 1-based position into `ensembles` here but
        # is later assigned a member *value*; the two agree only while ensembles == range(1, N)
        current_ensemble = 1

        while more_ensembles:
            try:
                # Initialize xarray object for first ensemble -----
                ds_temp1 = xr.open_mfdataset(fcst_folder + 
                                             '/yr' + str(year) + 
                                             '/mn' + str(month) + 
                                             '/OUTPUT.' + str(ensembles[current_ensemble-1]) + 
                                             '/' + fcst_filename, autoclose=True)
                ds_temp1.coords['ensemble'] = ensembles[current_ensemble-1]

                for ensemble in ensembles[current_ensemble:]:
                    ds_temp2 = xr.open_mfdataset(fcst_folder + 
                                                '/yr' + str(year) + 
                                                '/mn' + str(month) + 
                                                '/OUTPUT.' + str(ensemble) + 
                                                '/' + fcst_filename, autoclose=True)
                    # Concatenate along 'ensemble' dimension/coordinate -----
                    ds_temp2.coords['ensemble'] = ensemble
                    ds_temp1 = xr.concat([ds_temp1, ds_temp2],'ensemble')

                # All remaining members concatenated cleanly: store them, or append them to
                # the members gathered before an earlier failure
                # try:
                if first_chunk:
                    ds_chunk = ds_temp1
                else:
                    ds_chunk = xr.concat([ds_chunk, ds_temp1],'ensemble')
                # except:
                #     type, value, tb = sys.exc_info()
                #     traceback.print_exc()
                #     pdb.post_mortem(tb)

                more_ensembles = False
            except TypeError:
                # Keep whatever was concatenated before the failure
                if first_chunk:
                    ds_chunk = ds_temp1
                    first_chunk = False
                else:
                    ds_chunk = xr.concat([ds_chunk, ds_temp1],'ensemble')
                # Restart from the member whose concatenation failed.
                # NOTE(review): `ensemble` is only assigned by the for loops; if the TypeError is
                # raised before the inner loop starts, this reads a stale value
                current_ensemble = ensemble

        # Resample to desired frequency and resave time as lead time -----
        ds_chunk = ds_chunk.resample(freq=resample_freq, dim='time', how=resample_method) \
                           .isel(time = range(len(lead_times)))
        ds_chunk['time'] = ds_chunk['time'].values.astype('<M8[' + resample_freq + ']')
        ds_chunk = utils.datetime_to_leadtime(ds_chunk).expand_dims('init_date')
        
        # Concatenate along 'init_date' dimension/coordinate -----
        ds_fcst = xr.concat([ds_fcst, ds_chunk],'init_date')

# There seems to be a bug that re-adds the 'time' dimension after renaming - drop this -----
ds_fcst = utils.prune(ds_fcst) 

Gathering data for forecast started on 6-2002...
   Elapsed: 5.680202484130859 sec
Gathering data for forecast started on 7-2002...
   Elapsed: 12.016013145446777 sec


ValueError: too many different dimensions to concatenate: {'lat', 'latb'}

#### Rechunk

In [None]:
with utils.timer():
    # Rechunk for chunksizes of at least 1,000,000 elements -----
    # Each chunk is sized to hold every ensemble member and every lead time
    ds_fcst = ds_fcst.chunk(chunks={'ensemble' : len(ensembles), 'lead_time' : len(lead_times)})

### Construct observations xarray object

In [None]:
# Calendar years touched by the forecasts: the earliest start year through the latest start
# year plus the forecast length. min()/max() are taken over the years themselves so the
# result does not depend on init_dates being sorted -----
fcst_years = pd.to_datetime(init_dates).year
fcst_year_min = fcst_years.min()
fcst_year_max = fcst_years.max() + FCST_LENGTH

# ===================================================
# Only load years for which forecast data exist -----
# ===================================================
with utils.timer():
    print('Gathering data for observations...')

    # Open the first year, then append each following year along the native time dimension -----
    ds_jra = xr.open_mfdataset(obsv_folder + obsv_filename + str(fcst_year_min) + '*',
                               autoclose=True)
    for year_to_load in range(fcst_year_min + 1, fcst_year_max + 1):
        ds_year = xr.open_mfdataset(obsv_folder + obsv_filename + str(year_to_load) + '*',
                                    autoclose=True)
        ds_jra = xr.concat([ds_jra, ds_year], 'initial_time0_hours')

    # Standardize naming -----
    ds_jra = ds_jra.rename({'initial_time0_hours': 'time',
                            'g0_lon_3': 'lon',
                            'g0_lat_2': 'lat',
                            'TPRAT_GDS0_SFC_ave3h': 'precip'})

    # Resample to desired frequency -----
    ds_jra = ds_jra.resample(freq=resample_freq, dim='time', how=resample_method)
    ds_jra['time'] = ds_jra['time'].values.astype('<M8[' + resample_freq + ']')

    # ===============================================
    # Stack to resemble ds_fcst coordinates -----
    # ===============================================
    # Use the same number of lead times as the forecasts (which are truncated to
    # len(lead_times) steps) instead of a hard-coded 24, so that changing FCST_LENGTH or
    # resample_freq keeps forecasts and observations aligned
    ds_obsv = utils.stack_by_init_date(ds_jra, init_dates, len(lead_times))

ds_obsv = utils.prune(ds_obsv)

#### Rechunk

In [None]:
with utils.timer():
    # Rechunk for chunksizes of at least 1,000,000 elements -----
    # Each chunk is sized to hold every initial date
    ds_obsv = ds_obsv.chunk(chunks={'init_date' : len(init_dates)})

# Let's look at average monthly rainfall over Tasmania

##### Extract forecast and observation over region
Note we `compute()` the xarray objects here to save time later on. Once dask is working, it will probably be most sensible to leave the objects uncomputed

In [None]:
# Region of interest -----
region = (-44.0, -40.0, 144.0 , 148.0) # (lat_min,lat_max,lon_min,lon_max)

# Box-average the scaled forecast precipitation over the region and load the result into memory
# NOTE(review): the factor 60*60*24/998.2*1000 presumably converts a kg/m^2/s flux to mm/day
# (998.2 looks like the density of water in kg/m^3) - confirm against the model output units
da_fcst = utils.calc_boxavg_latlon(ds_fcst['precip'] * 60 * 60 * 24 / 998.2 * 1000, region).compute()

# The jra55 precip data is saved with 3hr and 6hr forecasts as an additional dimension - deal with these -----
# The two forecast windows are summed over 'forecast_time1' and the total is scaled by 1/8
# NOTE(review): 1/8 presumably reflects eight 3-hourly values per day - confirm
da_obsv = utils.calc_boxavg_latlon(1 / 8 * ds_obsv['precip'].sum(dim='forecast_time1'), region).compute()

##### Load climatology data
Various climatologies are/will be accessible using utils.load_climatology(). Here we use a climatology computed over the full 55 year jra reanalysis

In [None]:
# Monthly ('M') precipitation climatology from the 'jra_1958-2016' set
jra_clim = utils.load_climatology('jra_1958-2016', 'precip', freq='M')

# Box-average over the same region as the forecasts/observations and load into memory
da_clim = utils.calc_boxavg_latlon(jra_clim, region).compute()

##### Compute anomaly data
Recall that the forecast and observation data are saved as functions of lead time and initial date. The function `utils.anomalize()` computes anomalies given data and a climatology which each have a datetime dimension `time`. Thus it is necessary to first convert from the lead time/initial date format to a datetime format, then compute the anomaly, then convert back to the lead time/initial date format.  

In [None]:
def anomalize(data, clim):
    """Return anomalies of lead-time-indexed `data` relative to `clim`.

    `data` is converted from lead time to a datetime axis, anomalies are
    computed against `clim` with `utils.anomalize`, and the result is
    converted back to lead time.
    """
    data_on_datetime = utils.leadtime_to_datetime(data)
    anomalies = utils.anomalize(data_on_datetime, clim)
    return utils.datetime_to_leadtime(anomalies)

In [None]:
# Apply anomalize() separately to each initial date, since the lead time -> datetime
# conversion it performs is done one initial date at a time
da_fcst_anom = da_fcst.groupby('init_date').apply(anomalize, clim=da_clim)

da_obsv_anom = da_obsv.groupby('init_date').apply(anomalize, clim=da_clim)

##### Compute persistence data
This requires repeating the data at the first lead time over all lead times. `utils.repeat_data()` allows us to do this

In [None]:
# Persistence reference built from the observations, repeated along 'lead_time'
da_pers = utils.repeat_data(da_obsv,'lead_time')

## Before computing any metrics, let's make some example plots

##### Plot the forecast ensembles and observations for the first initial date

In [None]:
fig1 = plt.figure(figsize=(10, 5))
ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
ax.grid()

# Forecast ensemble for the first initial date, with the observations overlaid in black
first_fcst = da_fcst.isel(init_date=[0]).squeeze()
first_obsv = da_obsv.isel(init_date=[0]).squeeze()
ax.plot(da_fcst['lead_time'], first_fcst)
ax.plot(da_obsv['lead_time'], first_obsv, 'k-', linewidth=2)

ax.set_xlabel('lead time')
ax.set_ylabel('monthly rainfall [mm]');

##### Plot the forecast and observation anomalies for the first initial date

In [None]:
fig1 = plt.figure(figsize=(10,5))

ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
ax.grid()
# Forecast anomalies for the first initial date, with the observed anomalies overlaid in black
ax.plot(da_fcst['lead_time'],da_fcst_anom.isel(init_date=[0]).squeeze())
ax.plot(da_obsv['lead_time'],da_obsv_anom.isel(init_date=[0]).squeeze(),'k-',linewidth=2)
ax.set_xlabel('lead time')
# The plotted quantity is the anomaly, not the raw monthly total
ax.set_ylabel('monthly rainfall anomaly [mm]');

# Note:
The climatology and persistence data are not actually used here, since this notebook demonstrates skill metrics for probabilistic and event-based forecasts. Currently we only have access to the mean climatology - i.e. we cannot determine climatological probabilities of events occurring. I plan to instead load saved fields of the climatological PDFs which will enable climatological probabilities to be computed for any user-specified event

# Skill metrics for probabilistic forecasts

#### E.g. for the event of monthly rainfall over Tasmania being greater than 100 mm/month but less than 600 mm/month

In [None]:
# Event definition string handed to skill.did_event(): greater than 100 and less than 600
event = '(> 100) and (< 600)'

## Reliability diagram

#### Compute reliability as a function of lead time 

In [None]:
with utils.timer():
    # Compute the event data for forecast likelihood and observations -----
    fcst_likelihood = skill.compute_likelihood(skill.did_event(da_fcst, event))
    obsv_logical = skill.did_event(da_obsv, event)

    # Compute the reliability -----
    # N+1 evenly spaced probabilities between 0 and 1, where N is the ensemble size
    fcst_probabilities = np.linspace(0,1,len(da_fcst['ensemble'])+1)
    # Only 'init_date' is listed as an independent dimension, so the result keeps 'lead_time'
    # (it is indexed by lead_time when plotted)
    reliability = skill.compute_reliability(fcst_likelihood,obsv_logical,
                                            fcst_probabilities,indep_dims='init_date')

In [None]:
with utils.timer():
    ncol = 4
    nrow = int(np.ceil(len(lead_times) / ncol))
    fig, axs = plt.subplots(figsize=(15,15), nrows=nrow, ncols=ncol)

    # Inset placement as fractions of the parent axes: [x0, y0, width, height]
    subpos = [0.05,0.65,0.3,0.3]

    for idx,ax in enumerate(axs.reshape(-1)):
        # The nrow x ncol grid can hold more panels than there are lead times;
        # blank the spares instead of indexing past the end of 'lead_time'
        if idx >= len(lead_times):
            ax.axis('off')
            continue

        rel_freq = reliability['relative_freq'].isel(lead_time=idx, drop=True)
        fcst_number = reliability['fcst_number'].isel(lead_time=idx, drop=True)

        ax.grid()
        # 1:1 reference line and a horizontal line at the mean relative frequency
        ax.plot([0, 1],[0, 1],'k--')
        sample_clim = rel_freq.mean()
        ax.plot([-1, 2],[sample_clim, sample_clim],'k--')
        ax.plot(reliability['forecast_probability'],rel_freq,'r',linewidth=2)
        ax.set_xlim(0,1)
        ax.set_ylim(0,1)
        ax.text(0.82,0.7,'mn '+str(idx+1))

        # Label only the left-most column and the bottom row
        if idx % ncol == 0:
            ax.set_ylabel('Relative frequency')

        if idx // ncol == nrow - 1:
            ax.set_xlabel('Forecast probability')

        # Inset bar chart of 'fcst_number': convert the inset's axes-relative
        # anchor into figure coordinates and scale its size by the parent axes
        box = ax.get_position()
        inax_position = ax.transAxes.transform(subpos[0:2])
        infig_position = fig.transFigure.inverted().transform(inax_position)
        subax = fig.add_axes([infig_position[0], infig_position[1],
                              box.width * subpos[2], box.height * subpos[3]])
        subax.yaxis.tick_right()
        subax.bar(reliability['forecast_probability'],fcst_number,
                  width=reliability['forecast_probability'][1])

#### Compute reliability across all lead times

In [None]:
with utils.timer():
    # Compute the event data for forecast likelihood and observations -----
    fcst_likelihood = skill.compute_likelihood(skill.did_event(da_fcst, event))
    obsv_logical = skill.did_event(da_obsv, event)

    # Compute the reliability -----
    # N+1 evenly spaced probabilities between 0 and 1, where N is the ensemble size
    fcst_probabilities = np.linspace(0,1,len(da_fcst['ensemble'])+1)
    # Both 'init_date' and 'lead_time' are listed as independent dimensions here, unlike the
    # per-lead-time computation which lists only 'init_date'
    reliability = skill.compute_reliability(fcst_likelihood,obsv_logical,fcst_probabilities,
                                            indep_dims=['init_date','lead_time'])

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8, 6))
    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()

    fcst_prob = reliability['forecast_probability']
    rel_freq = reliability['relative_freq']

    # 1:1 reference line and a horizontal line at the mean relative frequency
    ax.plot([0, 1], [0, 1], 'k--')
    sample_clim = rel_freq.mean()
    ax.plot([-1, 2], [sample_clim, sample_clim], 'k--')
    ax.plot(fcst_prob, rel_freq, 'r', linewidth=2)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel('Forecast probability')
    ax.set_ylabel('Relative frequency');

    # Inset bar chart of 'fcst_number', anchored at axes-relative subpos[0:2]
    # and sized as a fraction of the parent axes
    fig = plt.gcf()
    subpos = [0.05, 0.65, 0.3, 0.3]
    box = ax.get_position()
    inax_position = ax.transAxes.transform(subpos[0:2])
    transFigure = fig.transFigure.inverted()
    infig_position = transFigure.transform(inax_position)
    x, y = infig_position[0], infig_position[1]
    width = box.width * subpos[2]
    height = box.height * subpos[3]

    subax = fig.add_axes([x, y, width, height])
    subax.yaxis.tick_right()
    subax.bar(fcst_prob, reliability['fcst_number'], width=fcst_prob[1]);

## Brier score

#### Compute Brier scores as a function of lead time

In [None]:
with utils.timer():
    # Compute the event data for forecast likelihood and observations -----
    fcst_likelihood = skill.compute_likelihood(skill.did_event(da_fcst, event))
    obsv_logical = skill.did_event(da_obsv, event)

    # Compute the Brier score -----
    # NOTE(review): every other metric in this notebook uses len(ensemble)+1 probability values;
    # here it is len(ensemble)-6 (5 values for 11 members) - confirm the coarser spacing is
    # intentional
    fcst_probabilities = np.linspace(0,1,len(da_fcst['ensemble'])-6)
    # Only 'init_date' is listed as an independent dimension, so the result keeps 'lead_time'
    Brier = skill.compute_Brier_score(fcst_likelihood,obsv_logical,fcst_prob=fcst_probabilities,
                                      indep_dims='init_date')

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,4))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    # Each curve carries a label so that ax.legend() has entries to show;
    # without labels the legend is empty
    ax.plot(Brier['lead_time'],Brier['Brier_reliability'],linewidth=1,label='Reliability')
    ax.plot(Brier['lead_time'],Brier['Brier_resolution'],linewidth=1,label='Resolution')
    ax.plot(Brier['lead_time'],Brier['Brier_uncertainty'],linewidth=1,label='Uncertainty')
    ax.plot(Brier['lead_time'],Brier['Brier_total'],linewidth=2,label='Total')
    ax.set_xlabel('Lead time [months]')
    ax.set_ylabel('Brier score')
    ax.legend();

## Relative operating characteristic

#### Compute ROC diagrams as a function of lead time

In [None]:
with utils.timer():
    # Compute the event data for forecast likelihood and observations -----
    fcst_likelihood = skill.compute_likelihood(skill.did_event(da_fcst, event))
    obsv_logical = skill.did_event(da_obsv, event)

    # Compute the roc -----
    # N+1 evenly spaced probabilities between 0 and 1, where N is the ensemble size
    fcst_probabilities = np.linspace(0,1,len(da_fcst['ensemble'])+1)
    # Only 'init_date' is listed as an independent dimension, so the result keeps 'lead_time'
    roc = skill.compute_roc(fcst_likelihood, obsv_logical, fcst_probabilities, 
                            indep_dims='init_date')

In [None]:
with utils.timer():
    ncol = 4
    nrow = int(np.ceil(len(lead_times) / ncol))
    fig, axs = plt.subplots(figsize=(15,15), nrows=nrow, ncols=ncol)

    for idx,ax in enumerate(axs.reshape(-1)):
        # The nrow x ncol grid can hold more panels than there are lead times;
        # blank the spares instead of indexing past the end of 'lead_time'
        if idx >= len(lead_times):
            ax.axis('off')
            continue

        ax.grid()
        # 1:1 reference diagonal
        ax.plot([-1, 2],[-1, 2],'k--')
        ax.plot(roc['false_alarm_rate'].isel(lead_time=idx, drop=True),
                roc['hit_rate'].isel(lead_time=idx, drop=True),'ro-',linewidth=2)
        ax.set_xlim(-0.02,1.02)
        ax.set_ylim(-0.02,1.02)
        ax.text(0.82,0.7,'mn '+str(idx+1))

        # Label only the left-most column and the bottom row
        if idx % ncol == 0:
            ax.set_ylabel('Hit rate')

        if idx // ncol == nrow - 1:
            ax.set_xlabel('False alarm rate')

#### Compute ROC diagram for all lead times

In [None]:
with utils.timer():
    # Compute the event data for forecast likelihood and observations -----
    fcst_likelihood = skill.compute_likelihood(skill.did_event(da_fcst, event))
    obsv_logical = skill.did_event(da_obsv, event)

    # Compute the roc -----
    # N+1 evenly spaced probabilities between 0 and 1, where N is the ensemble size
    fcst_probabilities = np.linspace(0,1,len(da_fcst['ensemble'])+1)
    # Both 'init_date' and 'lead_time' are listed as independent dimensions here
    roc = skill.compute_roc(fcst_likelihood, obsv_logical, fcst_probabilities, 
                            indep_dims=('init_date','lead_time'))

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,6))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    # 1:1 reference diagonal
    ax.plot([-1, 2],[-1, 2],'k--')
    ax.plot(roc['false_alarm_rate'],roc['hit_rate'],'ro-',linewidth=2)
    ax.set_xlim(-0.02,1.02)
    ax.set_ylim(-0.02,1.02)
    # False alarm rate is plotted on x and hit rate on y, matching the per-lead-time ROC panels
    ax.set_xlabel('False alarm rate')
    ax.set_ylabel('Hit rate');

## Discrimination diagram

#### Compute discrimination diagrams as a function of lead time

In [None]:
with utils.timer():
    # Compute the event data for forecast likelihood and observations -----
    fcst_likelihood = skill.compute_likelihood(skill.did_event(da_fcst, event))
    obsv_logical = skill.did_event(da_obsv, event)

    # Compute the discrimination -----
    # N+1 evenly spaced probabilities between 0 and 1, where N is the ensemble size
    fcst_probabilities = np.linspace(0,1,len(da_fcst['ensemble'])+1)
    # Only 'init_date' is listed as an independent dimension, so the result keeps 'lead_time'
    discrimination = skill.compute_discrimination(fcst_likelihood, obsv_logical, 
                                                   fcst_probabilities, indep_dims='init_date')

In [None]:
with utils.timer():
    ncol = 4
    nrow = int(np.ceil(len(lead_times) / ncol))
    fig, axs = plt.subplots(figsize=(15,15), nrows=nrow, ncols=ncol)

    # Width of each bar in a side-by-side pair; the same for every panel
    scale_width = 2.5
    bar_width = discrimination.bins[1]/scale_width

    for idx,ax in enumerate(axs.reshape(-1)):
        # The nrow x ncol grid can hold more panels than there are lead times;
        # blank the spares instead of indexing past the end of 'lead_time'
        if idx >= len(lead_times):
            ax.axis('off')
            continue

        hist_obsved = discrimination['hist_obsved'].isel(lead_time=idx, drop=True)
        hist_not_obsved = discrimination['hist_not_obsved'].isel(lead_time=idx, drop=True)

        ax.grid()
        # 'hist_obsved' in blue just left of each bin, 'hist_not_obsved' in red just right of it
        ax.bar(discrimination.bins-bar_width/2,
               hist_obsved,
               width=bar_width,
               color='b')
        ax.bar(discrimination.bins+bar_width/2,
               hist_not_obsved,
               width=bar_width,
               color='r')
        # Place the month label relative to the tallest bar in this panel
        max_count = max([hist_obsved.max(), hist_not_obsved.max()])
        ax.text(0.9,0.85*max_count,'mn '+str(idx+1))

        # Label only the left-most column and the bottom row
        if idx % ncol == 0:
            ax.set_ylabel('Likelihood')

        if idx // ncol == nrow - 1:
            ax.set_xlabel('Forecast probability')

#### Compute discrimination diagram for all lead times

In [None]:
with utils.timer():
    # Unlike the other "all lead times" cells, this one does not recompute fcst_likelihood,
    # obsv_logical or fcst_probabilities: it relies on the values left behind by the
    # per-lead-time discrimination cell, so that cell must have been run first
    discrimination = skill.compute_discrimination(fcst_likelihood, obsv_logical, 
                                                  fcst_probabilities, indep_dims=('init_date','lead_time'))

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8, 3))
    ax1 = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax1.grid()

    # Bars are drawn as side-by-side pairs, each bins[1]/scale_width wide
    scale_width = 2.5
    bins = discrimination.bins
    bar_width = bins[1] / scale_width

    # 'hist_obsved' in blue just left of each bin, 'hist_not_obsved' in red just right of it
    ax1.bar(bins - bar_width / 2, discrimination['hist_obsved'],
            width=bar_width, color='b')
    ax1.bar(bins + bar_width / 2, discrimination['hist_not_obsved'],
            width=bar_width, color='r')

    ax1.set_xlabel('Forecast probability')
    ax1.set_ylabel('Likelihood');

# Skill metrics for categorized forecasts

## Contingency table

#### E.g. for 4 categories between 25 and 150 mm of monthly rainfall

In [None]:
# Define category edges -----
# 5 evenly spaced edges from 25 to 150 (25, 56.25, 87.5, 118.75, 150) bound 4 categories
category_edges = np.linspace(25,150,5)

#### Compute contingency as a function of lead time

In [None]:
with utils.timer():
    # Compute contingency table -----
    # Only 'init_date' is listed as an independent dimension, so the table keeps 'lead_time'
    # (it is indexed by lead_time when plotted)
    contingency = skill.compute_contingency_table(da_fcst,da_obsv, category_edges,
                                                  ensemble_dim='ensemble', indep_dims='init_date')

In [None]:
with utils.timer():
    ncol = 4
    nrow = int(np.ceil(len(lead_times) / ncol))
    fig, axs = plt.subplots(figsize=(10,15), nrows=nrow, ncols=ncol)

    for idx,ax in enumerate(axs.reshape(-1)):
        # The nrow x ncol grid can hold more panels than there are lead times;
        # blank the spares instead of indexing past the end of 'lead_time'
        if idx >= len(lead_times):
            ax.axis('off')
            continue

        ax.grid()
        im = ax.imshow(contingency.isel(lead_time=idx, drop=True))
        # ax.text(0.82,0.7,'mn '+str(idx+1))

        # Label only the left-most column and the bottom row
        if idx % ncol == 0:
            ax.set_ylabel('Forecast category')

        if idx // ncol == nrow - 1:
            ax.set_xlabel('Observed category')

    # Reserve the right-hand strip of the figure for a single colourbar
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7])
    # NOTE(review): `im` is the last panel drawn and each imshow() call scales its colours
    # independently, so this colourbar is only exact for that last panel
    fig.colorbar(im, cax=cbar_ax)
    cbar_ax.set_ylabel('counts', rotation=270, labelpad=15);

## Accuracy score

In [None]:
with utils.timer():
    # Accuracy score derived from the multi-category contingency table
    accuracy_score = skill.compute_accuracy_score(contingency)

## Heidke skill score

In [None]:
with utils.timer():
    # Heidke skill score derived from the multi-category contingency table
    Heidke_score = skill.compute_Heidke_score(contingency)

## Peirce skill score / Hanssen and Kuipers discriminant

In [None]:
with utils.timer():
    # Peirce skill score (Hanssen and Kuipers discriminant) derived from the
    # multi-category contingency table
    Peirce_score = skill.compute_Peirce_score(contingency)

## Gerrity score

In [None]:
with utils.timer():
    # Gerrity score derived from the multi-category contingency table
    Gerrity_score = skill.compute_Gerrity_score(contingency)

#### Plot as a function of lead_time

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,4))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    # Each curve carries a label so that ax.legend() has entries to show;
    # without labels the legend is empty
    ax.plot(accuracy_score['lead_time'],accuracy_score,linewidth=2,label='Accuracy score')
    ax.plot(Heidke_score['lead_time'],Heidke_score,linewidth=2,label='Heidke skill score')
    ax.plot(Peirce_score['lead_time'],Peirce_score,linewidth=2,label='Peirce skill score')
    ax.plot(Gerrity_score['lead_time'],Gerrity_score,linewidth=2,label='Gerrity score')
    ax.set_xlabel('Lead time [months]')
    ax.set_ylabel('Score');
    ax.legend();

# Skill metrics for dichotomously categorized forecasts

## Contingency table

#### E.g. for monthly rainfall being > or < 100 mm 

In [None]:
with utils.timer():
    # Define category edges -----
    # A single interior edge at 100 with open-ended outer edges gives two categories
    category_edges = [-np.inf, 100, np.inf]

    # Compute contingency table -----
    # This overwrites the multi-category `contingency` computed earlier in the notebook
    contingency = skill.compute_contingency_table(da_fcst,da_obsv,category_edges,
                                                  ensemble_dim='ensemble',indep_dims='init_date')

## Bias score

In [None]:
with utils.timer():
    # Bias score derived from the two-category contingency table
    bias_score = skill.compute_bias_score(contingency)

## Probability of detection

In [None]:
with utils.timer():
    # Hit rate (headed "Probability of detection" in this notebook) derived from the
    # two-category contingency table
    hit_rate = skill.compute_hit_rate(contingency)

## False alarm ratio

In [None]:
with utils.timer():
    # False alarm ratio derived from the two-category contingency table
    # (a separate quantity from the false alarm *rate* computed next)
    false_alarm_ratio = skill.compute_false_alarm_ratio(contingency)

## False alarm rate

In [None]:
with utils.timer():
    # False alarm rate derived from the two-category contingency table
    false_alarm_rate = skill.compute_false_alarm_rate(contingency)

## Success ratio

In [None]:
with utils.timer():
    # Success ratio derived from the two-category contingency table
    success_ratio = skill.compute_success_ratio(contingency)

## Threat score

In [None]:
with utils.timer():
    # Threat score derived from the two-category contingency table
    threat_score = skill.compute_threat_score(contingency)

## Equitable threat score

In [None]:
with utils.timer():
    # Equitable threat score derived from the two-category contingency table
    equit_threat_score = skill.compute_equit_threat_score(contingency)

## Odds ratio

In [None]:
with utils.timer():
    # Odds ratio derived from the two-category contingency table
    # (computed here, but its line in the summary plot is commented out)
    odds_ratio = skill.compute_odds_ratio(contingency)

## Odds ratio skill score

In [None]:
with utils.timer():
    # Odds ratio skill score derived from the two-category contingency table
    odds_ratio_skill = skill.compute_odds_ratio_skill(contingency)

#### Plot as a function of lead time

In [None]:
with utils.timer():
    fig1 = plt.figure(figsize=(8,4))

    ax = fig1.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.grid()
    # Each curve carries a label so that ax.legend() has entries to show;
    # without labels the legend is empty
    ax.plot(bias_score['lead_time'],bias_score,linewidth=2,label='Bias score')
    ax.plot(hit_rate['lead_time'],hit_rate,linewidth=2,label='Hit rate')
    ax.plot(false_alarm_ratio['lead_time'],false_alarm_ratio,linewidth=2,label='False alarm ratio')
    ax.plot(false_alarm_rate['lead_time'],false_alarm_rate,linewidth=2,label='False alarm rate')
    ax.plot(success_ratio['lead_time'],success_ratio,linewidth=2,label='Success ratio')
    ax.plot(threat_score['lead_time'],threat_score,linewidth=2,label='Threat score')
    ax.plot(equit_threat_score['lead_time'],equit_threat_score,linewidth=2,
            label='Equitable threat score')
    # ax.plot(odds_ratio['lead_time'],odds_ratio,linewidth=2,label='Odds ratio')
    ax.plot(odds_ratio_skill['lead_time'],odds_ratio_skill,linewidth=2,
            label='Odds ratio skill score')
    ax.set_xlabel('Lead time [months]')
    ax.set_ylabel('Score');
    ax.legend();

# Close dask client

In [None]:
# with utils.timer():
#     client.close()