In [None]:
from __future__ import absolute_import, division, print_function

In [None]:
# License: MIT

In [None]:
import datetime
import itertools
import os
import time

import cartopy.crs as ccrs
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import reanalysis_dbns.indices as rdi
import reanalysis_dbns.utils as rdu

from cartopy.util import add_cyclic_point

In [None]:
%matplotlib inline

# Teleconnection indices calculations

This notebook generates the teleconnection indices that
are used as the data for fitting DBNs.

In [None]:
# Root directory of the project, located under the user's home directory.
# os.path.expanduser is used instead of os.getenv('HOME') because the latter
# returns None when HOME is not set, which makes os.path.join raise a
# TypeError.
PROJECT_DIR = os.path.join(os.path.expanduser('~'), 'projects', 'reanalysis-dbns')

# Top-level data and results directories within the project.
DATA_DIR = os.path.join(PROJECT_DIR, 'data')
RESULTS_DIR = os.path.join(PROJECT_DIR, 'results')

# 'reference-indices' sub-directory of the results directory.
REF_INDICES_RESULTS_DIR = os.path.join(RESULTS_DIR, 'reference-indices')

In [None]:
# Base period as [start, end] dates. It is the default ``base_period`` of the
# helper functions in this notebook and is embedded in output file names.
# NOTE(review): the end date is 30 December, not 31 December - confirm that
# this is intended.
BASE_PERIOD = [np.datetime64(date) for date in ('1979-01-01', '2001-12-30')]
# Earliest year of input data that is kept.
START_YEAR = 1940

In [None]:
# Short identifiers of the reanalysis products; the same strings are used as
# directory names and as keys by the helper functions below.
# NOTE(review): this list is not referenced in the visible part of the
# notebook - confirm whether later cells use it.
REANALYSES = ['jra55', 'nnr1']

In [None]:
def get_reanalysis_full_name(reanalysis):
    """Get full name for reanalysis.

    Parameters
    ----------
    reanalysis : str
        Short identifier; one of 'hadisst', 'nnr1' or 'jra55'.

    Returns
    -------
    full_name : str
        The corresponding full name.

    Raises
    ------
    ValueError
        If the identifier is not recognized.
    """

    if reanalysis == 'hadisst':
        return 'HadISST'

    if reanalysis == 'nnr1':
        return 'NNR1'

    if reanalysis == 'jra55':
        return 'JRA-55'

    # str.format with !r replaces "'%r'" % reanalysis, which wrapped the
    # repr in a second pair of quotes and raised a TypeError rather than the
    # intended ValueError when the argument was a tuple.
    raise ValueError('Unrecognized reanalysis {!r}'.format(reanalysis))

    
def get_reanalysis_results_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the results directory of one reanalysis.

    The path is ``<results_dir>/<reanalysis>``; nothing is created on disk.
    """
    path_parts = (results_dir, reanalysis)
    return os.path.join(*path_parts)


def get_reanalysis_fields_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'fields' directory inside the results of a reanalysis."""
    return os.path.join(
        get_reanalysis_results_dir(reanalysis, results_dir=results_dir),
        'fields')


def get_reanalysis_eofs_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'eofs' directory inside the results of a reanalysis."""
    return os.path.join(
        get_reanalysis_results_dir(reanalysis, results_dir=results_dir),
        'eofs')


def get_reanalysis_eofs_nc_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'nc' (netCDF) directory inside the EOFs directory."""
    return os.path.join(
        get_reanalysis_eofs_dir(reanalysis, results_dir=results_dir),
        'nc')


def get_reanalysis_eofs_plt_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'plt' (plots) directory inside the EOFs directory."""
    return os.path.join(
        get_reanalysis_eofs_dir(reanalysis, results_dir=results_dir),
        'plt')


def get_reanalysis_indices_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'indices' directory inside the results of a reanalysis."""
    return os.path.join(
        get_reanalysis_results_dir(reanalysis, results_dir=results_dir),
        'indices')


def get_reanalysis_indices_csv_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'csv' directory inside the indices directory."""
    return os.path.join(
        get_reanalysis_indices_dir(reanalysis, results_dir=results_dir),
        'csv')


def get_reanalysis_indices_nc_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'nc' (netCDF) directory inside the indices directory."""
    return os.path.join(
        get_reanalysis_indices_dir(reanalysis, results_dir=results_dir),
        'nc')


def get_reanalysis_indices_plt_dir(reanalysis, results_dir=RESULTS_DIR):
    """Return the 'plt' (plots) directory inside the indices directory."""
    return os.path.join(
        get_reanalysis_indices_dir(reanalysis, results_dir=results_dir),
        'plt')

In [None]:
def get_timespan_string(time_span):
    """Format a time span as 'YYYYMMDD_YYYYMMDD' for use in file names.

    Only the first two elements of ``time_span`` (start and end) are used.
    """
    first, last = (pd.to_datetime(time_span[k]).strftime('%Y%m%d')
                   for k in (0, 1))
    return '_'.join([first, last])

In [None]:
def get_index_datafile_name(reanalysis, index, base_period=BASE_PERIOD,
                            frequency='daily', ext='csv'):
    """Get filename for datafile containing index data.

    The name has the form
    ``<reanalysis>.<base period>.<index>.<frequency>.<ext>``, where the base
    period is rendered as ``YYYYMMDD_YYYYMMDD``.

    Parameters
    ----------
    reanalysis : str
        Short identifier of the reanalysis.
    index : str
        Name of the index.
    base_period : sequence, default: BASE_PERIOD
        Start and end of the base period.
    frequency : str, default: 'daily'
        Frequency label placed in the file name.
    ext : str, default: 'csv'
        File extension, without a leading dot.

    Returns
    -------
    filename : str
        The file name (without any directory component).
    """

    # Use the shared formatter rather than repeating the date formatting
    # here, so that all file names render the period identically.
    base_period_str = get_timespan_string(base_period)

    return '.'.join([reanalysis, base_period_str, index, frequency, ext])

In [None]:
def read_reference_index_csv(filename, time_name='time'):
    """Read index CSV file provided in format for downloaded indices.

    The file is parsed as comma-delimited text whose header row names the
    columns; the columns 'year', 'month', 'day' and 'value' are read.

    Parameters
    ----------
    filename : str
        Path of the CSV file.
    time_name : str, default: 'time'
        Name given to the time dimension and coordinate.

    Returns
    -------
    ds : xarray.Dataset
        Dataset with a single variable 'index' along ``time_name`` and the
        attribute 'source_file' set to ``filename``.
    """

    # genfromtxt returns a 0-d structured array when the file holds a single
    # data row; promote it to 1-d so that the columns can be iterated.
    data = np.atleast_1d(
        np.genfromtxt(filename, names=True, delimiter=','))

    # The date parts are parsed as floats, so convert them to integers.
    dates = np.array(
        [datetime.datetime(int(year), int(month), int(day))
         for year, month, day in zip(data['year'], data['month'], data['day'])])

    da = xr.DataArray(data['value'], coords={time_name: dates}, dims=[time_name])

    ds = da.to_dataset(name='index')
    ds.attrs['source_file'] = filename

    return ds

In [None]:
def write_index_to_csv(output_file, index_da, header_attrs=None, time_name=None):
    """Write index data to CSV file.

    The file contains the columns 'date' and 'value'. If ``header_attrs`` is
    given and non-empty, one line of the form ``# <key>: <value>`` per entry
    is written before the column header.

    Parameters
    ----------
    output_file : str
        Path of the file to write; an existing file is overwritten.
    index_da : array-like with a time coordinate
        Index values; ``index_da.data`` provides the values and
        ``index_da[time_name].data`` the dates.
    header_attrs : dict, optional
        Key/value pairs written as comment lines at the top of the file.
    time_name : str, optional
        Name of the time coordinate. If None, it is obtained from
        ``rdu.get_time_name``.
    """

    time_name = time_name if time_name is not None else rdu.get_time_name(index_da)

    df = pd.DataFrame({'value': index_da.data}, index=index_da[time_name].data)

    body = df.to_csv(index_label='date')

    # Only prepend a header when there is something to write: an empty
    # mapping previously produced a spurious blank first line.
    if header_attrs:
        header = '\n'.join('# {}: {}'.format(h, header_attrs[h]) for h in header_attrs)
        body = '\n'.join([header, body])

    with open(output_file, 'w') as ofs:
        ofs.write(body)

In [None]:
def write_index_nc_file(reanalysis, index, index_ds, base_period=BASE_PERIOD,
                        frequency='daily', output_dir=None):
    """Save an index dataset as a netCDF file and return the file path.

    The file name is built by ``get_index_datafile_name`` with extension
    'nc'. When ``output_dir`` is None, the netCDF indices directory of the
    reanalysis is used.
    """
    target_dir = (get_reanalysis_indices_nc_dir(reanalysis)
                  if output_dir is None else output_dir)

    file_name = get_index_datafile_name(
        reanalysis, index, base_period=base_period, frequency=frequency,
        ext='nc')
    nc_path = os.path.join(target_dir, file_name)

    index_ds.to_netcdf(nc_path)

    return nc_path
    
    
def write_index_csv_file(reanalysis, index, index_ds, index_var,
                         base_period=BASE_PERIOD, frequency='daily',
                         output_dir=None, time_name=None):
    """Save one variable of an index dataset as CSV and return the file path.

    The file name is built by ``get_index_datafile_name`` with extension
    'csv'. When ``output_dir`` is None, the CSV indices directory of the
    reanalysis is used. The dataset attributes are passed on as header
    entries and ``index_ds[index_var]`` provides the values.
    """
    target_dir = (get_reanalysis_indices_csv_dir(reanalysis)
                  if output_dir is None else output_dir)

    file_name = get_index_datafile_name(
        reanalysis, index, base_period=base_period, frequency=frequency,
        ext='csv')
    csv_path = os.path.join(target_dir, file_name)

    # Shallow copy of the dataset attributes for the CSV header.
    attrs_copy = dict(index_ds.attrs)

    write_index_to_csv(csv_path, index_ds[index_var], header_attrs=attrs_copy,
                       time_name=time_name)

    return csv_path

In [None]:
def propagate_missing_values_through_time(da, time_name=None):
    """Propagate any missing values at one time to all times.

    A point in the non-time dimensions that is non-finite at any time is
    set to NaN at every time.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with a time dimension.
    time_name : str, optional
        Name of the time dimension. If None, it is obtained from
        ``rdu.get_time_name``.

    Returns
    -------
    filled_da : xarray.DataArray
        If no non-finite values are present, the input (transposed so that
        time is the leading dimension, if it was not already). Otherwise a
        new array with time as the leading dimension.
        NOTE(review): the new array is built from the data, coords and dims
        only, so the name and attrs of the input are not carried over on
        that path - confirm this is acceptable to callers.
    """

    time_name = time_name if time_name is not None else rdu.get_time_name(da)

    feature_dims = [d for d in da.dims if d != time_name]
    original_shape = [da.sizes[d] for d in feature_dims]

    n_samples = da.sizes[time_name]
    # np.prod of an empty list is the float 1.0, which np.reshape rejects as
    # a dimension, so cast to int to support input with only a time axis.
    n_features = int(np.prod(original_shape))

    if da.get_axis_num(time_name) != 0:
        da = da.transpose(*([time_name] + feature_dims))

    # Work on a flattened (time, feature) copy so the input is not modified.
    filled_data = np.reshape(da.data, (n_samples, n_features)).copy()

    points_to_fill = np.any(np.logical_not(np.isfinite(filled_data)), axis=0)

    if not np.any(points_to_fill):
        return da

    # np.nan is used because the np.NaN alias was removed in NumPy 2.0.
    filled_data[:, points_to_fill] = np.nan

    filled_da = xr.DataArray(
        filled_data.reshape([n_samples] + original_shape),
        coords=da.coords, dims=da.dims)

    return filled_da

In [None]:
def plot_indices_timeseries(indices_to_plot, years_per_row=10, points=False, time_name='time'):
    """Plot time-series of indices.

    The full range of years spanned by the indices is split into rows of
    ``years_per_row`` calendar years each, with one subplot per row. All
    rows share the same y-axis limits.

    Parameters
    ----------
    indices_to_plot : dict-like
        Mapping from a label (used in the legend) to the index values.
        Each value is indexed by ``time_name`` and must provide ``min``,
        ``max``, ``where`` and ``data`` (presumably an xarray.DataArray).
    years_per_row : int, default: 10
        Number of calendar years shown in each subplot row.
    points : bool, default: False
        If True, draw unconnected markers; otherwise draw lines.
    time_name : str, default: 'time'
        Name of the time coordinate of each index.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing one subplot per row.
    """

    # Find the overall time range and value range across all indices.
    start_date = None
    end_date = None
    index_min = np.inf
    index_max = -np.inf
    for index in indices_to_plot:
        index_start_date = indices_to_plot[index][time_name].min().values
        index_end_date = indices_to_plot[index][time_name].max().values

        if start_date is None or index_start_date < start_date:
            start_date = index_start_date

        if end_date is None or index_end_date > end_date:
            end_date = index_end_date

        if indices_to_plot[index].min().item() < index_min:
            index_min = indices_to_plot[index].min().item()

        if indices_to_plot[index].max().item() > index_max:
            index_max = indices_to_plot[index].max().item()

    # Number of rows needed to cover every year from first to last.
    start_year = pd.to_datetime(start_date).year
    end_year = pd.to_datetime(end_date).year
    n_years = end_year - start_year + 1
    n_rows = int(np.ceil(n_years / years_per_row))

    fig = plt.figure(figsize=(15, 3 * n_rows))
    gs = gridspec.GridSpec(nrows=n_rows, ncols=1, hspace=0.15)

    for i in range(n_rows):

        ax = fig.add_subplot(gs[i, 0])

        # Each row covers whole calendar years, 1 January to 31 December.
        row_start_year = start_year + i * years_per_row
        row_end_year = start_year + (i + 1) * years_per_row - 1

        row_start_date = np.datetime64('{:d}-01-01'.format(row_start_year))
        row_end_date = np.datetime64('{:d}-12-31'.format(row_end_year))

        # The cycles are re-created for every row so that a given index
        # gets the same marker, color and line style in each row.
        markers = itertools.cycle(('.', 'x', 's', '+', 'd'))
        colors = itertools.cycle(('r', 'b', 'g', 'y', 'k'))
        styles = itertools.cycle(('-', '--', ':', '-.'))

        for index in indices_to_plot:

            # Keep only the samples that fall inside this row's date range.
            index_data = indices_to_plot[index]
            index_data = index_data.where(
                (index_data[time_name] >= row_start_date) &
                (index_data[time_name] <= row_end_date), drop=True)

            marker = next(markers)
            color = next(colors)
            style = next(styles)

            if points:
                ax.plot(index_data[time_name], index_data.data, color=color, marker=marker,
                        ls='none', label=index)
            else:
                ax.plot(index_data[time_name], index_data.data, color=color, ls=style,
                        label=index)

        ax.grid(ls='--', color='gray', alpha=0.5)

        # Use the same axis limits in every row so that rows are comparable.
        ax.set_xlim(row_start_date, row_end_date)
        ax.set_ylim(index_min, index_max)

        ax.tick_params(which='both', labelsize=13)
        ax.set_xlabel('Date', fontsize=14)
        ax.set_ylabel('Index value', fontsize=14)

        # The legend is only drawn on the first row.
        if i == 0:
            ax.legend(fontsize=13)

    return fig

## Index calculations

In [None]:
# Paths of the files written by the index calculations below, keyed first by
# reanalysis and then by index group (e.g. 'AO').
output_files = {}

### 500 hPa geopotential height indices

In [None]:
def get_reanalysis_h500_variable(reanalysis):
    """Get name of reanalysis 500 hPa geopotential height field.

    Parameters
    ----------
    reanalysis : str
        Short identifier; one of 'nnr1' or 'jra55'.

    Returns
    -------
    variable : str
        Name of the height variable for that reanalysis.

    Raises
    ------
    ValueError
        If the identifier is not recognized.
    """

    if reanalysis == 'nnr1':
        return 'hgt'

    if reanalysis == 'jra55':
        return 'HGT_GDS0_ISBL'

    # str.format with !r replaces "'%r'" % reanalysis, which wrapped the
    # repr in a second pair of quotes and raised a TypeError rather than the
    # intended ValueError when the argument was a tuple.
    raise ValueError('Unrecognized reanalysis {!r}'.format(reanalysis))

In [None]:
def calculate_ao_indices(hgt_da, reanalysis, input_file=None,
                         eofs_output_dir=None, indices_csv_output_dir=None,
                         indices_nc_output_dir=None,
                         ao_mode=0, base_period=BASE_PERIOD,
                         lat_name=None, lon_name=None, time_name=None):
    """Compute the daily and monthly AO index and write the results.

    ``rdi.pc_ao`` is called once per frequency. The loadings returned by
    the daily call are written to a netCDF file, and each index is written
    in both netCDF and CSV formats.

    Parameters
    ----------
    hgt_da : array-like
        Height data passed on to ``rdi.pc_ao``.
    reanalysis : str
        Short identifier of the reanalysis, used in paths and file names.
    input_file : str, optional
        If given, stored as the 'input_file' attribute of every output.
    eofs_output_dir : str, optional
        Directory for the loadings file. Defaults to the netCDF EOFs
        directory of the reanalysis.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Directories passed to the index writers.
    ao_mode : int, default: 0
        Passed on to ``rdi.pc_ao``.
    base_period : sequence, default: BASE_PERIOD
        Base period passed to ``rdi.pc_ao`` and used in file names.
    lat_name, lon_name, time_name : str, optional
        Accepted but not used by this function.

    Returns
    -------
    written : dict
        Paths of the written files, keyed by 'eofs_nc', 'daily_index_nc',
        'daily_index_csv', 'monthly_index_nc' and 'monthly_index_csv'.
    """
    pc_ao_kwargs = dict(base_period=base_period, ao_mode=ao_mode)

    loadings_ds, daily_ds = rdi.pc_ao(hgt_da, frequency='daily', **pc_ao_kwargs)
    _, monthly_ds = rdi.pc_ao(hgt_da, frequency='monthly', **pc_ao_kwargs)

    if input_file is not None:
        for ds in (loadings_ds, daily_ds, monthly_ds):
            ds.attrs['input_file'] = input_file

    # Loadings are written first.
    # NOTE(review): the file name is labelled 'monthly' although the
    # loadings come from the daily call - confirm against rdi.pc_ao.
    period_tag = get_timespan_string(base_period)

    if eofs_output_dir is None:
        eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_path = os.path.join(
        eofs_output_dir,
        '.'.join([reanalysis, period_tag, 'ao', 'monthly', 'eofs', 'nc']))

    loadings_ds.to_netcdf(eofs_path)

    written = {'eofs_nc': eofs_path}

    # Daily outputs are written before monthly ones, netCDF before CSV.
    for freq, index_ds in (('daily', daily_ds), ('monthly', monthly_ds)):
        written['{}_index_nc'.format(freq)] = write_index_nc_file(
            reanalysis, 'ao', index_ds, base_period=base_period,
            frequency=freq, output_dir=indices_nc_output_dir)
        written['{}_index_csv'.format(freq)] = write_index_csv_file(
            reanalysis, 'ao', index_ds, 'ao_index',
            base_period=base_period,
            frequency=freq, output_dir=indices_csv_output_dir)

    return written

In [None]:
def calculate_nhtele_indices(hgt_da, reanalysis, input_file=None,
                             composites_output_dir=None, indices_csv_output_dir=None,
                             indices_nc_output_dir=None,
                             base_period=BASE_PERIOD,
                             lat_name=None, lon_name=None, time_name=None):
    """Compute the daily and monthly NHTELE indices and write the results.

    ``rdi.kmeans_pcs`` is called once per frequency with window_length=10,
    n_modes=24, n_clusters=4 and season='DJF'. The 'composites' variable of
    the daily result is written to a netCDF file, and the index of each
    cluster is written as 'nhtele1' ... 'nhtele4' in netCDF and CSV formats.

    Parameters
    ----------
    hgt_da : array-like
        Height data passed on to ``rdi.kmeans_pcs``.
    reanalysis : str
        Short identifier of the reanalysis, used in paths and file names.
    input_file : str, optional
        If given, stored as the 'input_file' attribute of the outputs.
    composites_output_dir : str, optional
        Directory for the composites file. Defaults to the netCDF EOFs
        directory of the reanalysis.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Directories passed to the index writers.
    base_period : sequence, default: BASE_PERIOD
        Base period passed to ``rdi.kmeans_pcs`` and used in file names.
    lat_name, lon_name, time_name : str, optional
        Accepted but not used by this function.

    Returns
    -------
    written : dict
        Paths of the written files, keyed by 'composites_nc' and, for each
        index name, '<frequency>_<index name>_nc' / '..._csv'.
    """
    n_clusters = 4
    kmeans_settings = dict(
        base_period=base_period, window_length=10, n_modes=24,
        n_clusters=n_clusters, season='DJF')

    daily_all_ds = rdi.kmeans_pcs(hgt_da, frequency='daily', **kmeans_settings)
    monthly_all_ds = rdi.kmeans_pcs(hgt_da, frequency='monthly', **kmeans_settings)

    if input_file is not None:
        daily_all_ds.attrs['input_file'] = input_file
        monthly_all_ds.attrs['input_file'] = input_file

    # Composites (taken from the daily result) are written first.
    period_tag = get_timespan_string(base_period)

    if composites_output_dir is None:
        composites_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    composites_path = os.path.join(
        composites_output_dir,
        '.'.join([reanalysis, period_tag, 'nhtele', 'daily', 'composites', 'nc']))

    composites_ds = daily_all_ds['composites'].to_dataset(name='composites')
    composites_ds.attrs.update(daily_all_ds.attrs)
    composites_ds.to_netcdf(composites_path)

    written = {'composites_nc': composites_path}

    for cluster in range(n_clusters):

        # Index names are one-based while cluster labels are zero-based.
        index_name = 'nhtele{:d}'.format(cluster + 1)

        # Extract the single-cluster datasets for both frequencies before
        # anything is written.
        per_frequency = []
        for freq, all_ds in (('daily', daily_all_ds), ('monthly', monthly_all_ds)):
            cluster_ds = all_ds['indices'].sel(cluster=cluster).squeeze().to_dataset(
                name=index_name)
            cluster_ds.attrs.update(all_ds.attrs)
            per_frequency.append((freq, cluster_ds))

        for freq, cluster_ds in per_frequency:
            nc_path = write_index_nc_file(
                reanalysis, index_name, cluster_ds, base_period=base_period,
                frequency=freq, output_dir=indices_nc_output_dir)
            csv_path = write_index_csv_file(
                reanalysis, index_name, cluster_ds, index_name,
                base_period=base_period,
                frequency=freq, output_dir=indices_csv_output_dir)

            written['{}_{}_nc'.format(freq, index_name)] = nc_path
            written['{}_{}_csv'.format(freq, index_name)] = csv_path

    return written

In [None]:
def calculate_pna_indices(hgt_da, reanalysis, input_file=None,
                          eofs_output_dir=None, indices_csv_output_dir=None,
                          indices_nc_output_dir=None, pna_mode=0,
                          base_period=BASE_PERIOD,
                          lat_name=None, lon_name=None, time_name=None):
    """Compute the daily and monthly PNA index and write the results.

    ``rdi.pc_pna`` is called once per frequency with n_modes=10. The
    loadings returned by the daily call are written to a netCDF file, and
    each index is written in both netCDF and CSV formats.

    Parameters
    ----------
    hgt_da : array-like
        Height data passed on to ``rdi.pc_pna``.
    reanalysis : str
        Short identifier of the reanalysis, used in paths and file names.
    input_file : str, optional
        If given, stored as the 'input_file' attribute of every output.
    eofs_output_dir : str, optional
        Directory for the loadings file. Defaults to the netCDF EOFs
        directory of the reanalysis.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Directories passed to the index writers.
    pna_mode : int, default: 0
        Passed on to ``rdi.pc_pna``.
    base_period : sequence, default: BASE_PERIOD
        Base period passed to ``rdi.pc_pna`` and used in file names.
    lat_name, lon_name, time_name : str, optional
        Accepted but not used by this function.

    Returns
    -------
    written : dict
        Paths of the written files, keyed by 'eofs_nc', 'daily_index_nc',
        'daily_index_csv', 'monthly_index_nc' and 'monthly_index_csv'.
    """
    pc_pna_kwargs = dict(base_period=base_period, pna_mode=pna_mode, n_modes=10)

    loadings_ds, daily_ds = rdi.pc_pna(hgt_da, frequency='daily', **pc_pna_kwargs)
    _, monthly_ds = rdi.pc_pna(hgt_da, frequency='monthly', **pc_pna_kwargs)

    if input_file is not None:
        for ds in (loadings_ds, daily_ds, monthly_ds):
            ds.attrs['input_file'] = input_file

    # Loadings are written first.
    # NOTE(review): the file name is labelled 'monthly' although the
    # loadings come from the daily call - confirm against rdi.pc_pna.
    period_tag = get_timespan_string(base_period)

    if eofs_output_dir is None:
        eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_path = os.path.join(
        eofs_output_dir,
        '.'.join([reanalysis, period_tag, 'pna', 'monthly', 'eofs', 'nc']))

    loadings_ds.to_netcdf(eofs_path)

    written = {'eofs_nc': eofs_path}

    # Daily outputs are written before monthly ones, netCDF before CSV.
    for freq, index_ds in (('daily', daily_ds), ('monthly', monthly_ds)):
        written['{}_index_nc'.format(freq)] = write_index_nc_file(
            reanalysis, 'pna', index_ds, base_period=base_period,
            frequency=freq, output_dir=indices_nc_output_dir)
        written['{}_index_csv'.format(freq)] = write_index_csv_file(
            reanalysis, 'pna', index_ds, 'pna_index',
            base_period=base_period,
            frequency=freq, output_dir=indices_csv_output_dir)

    return written

In [None]:
def calculate_psa_indices(hgt_da, reanalysis, input_file=None,
                          eofs_output_dir=None, indices_csv_output_dir=None,
                          indices_nc_output_dir=None, psa1_mode=1, psa2_mode=2,
                          base_period=BASE_PERIOD,
                          lat_name=None, lon_name=None, time_name=None):
    """Compute the daily and monthly PSA indices and write the results.

    ``rdi.real_pc_psa`` is called once per frequency with rotate=False,
    eofs_season='ALL' and eofs_frequency='daily'. The loadings returned by
    the daily call are written to a netCDF file, and the variables
    'psa1_index' and 'psa2_index' are each written in netCDF and CSV
    formats.

    Parameters
    ----------
    hgt_da : array-like
        Height data passed on to ``rdi.real_pc_psa``.
    reanalysis : str
        Short identifier of the reanalysis, used in paths and file names.
    input_file : str, optional
        If given, stored as the 'input_file' attribute of the outputs.
    eofs_output_dir : str, optional
        Directory for the loadings file. Defaults to the netCDF EOFs
        directory of the reanalysis.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Directories passed to the index writers.
    psa1_mode, psa2_mode : int, default: 1 and 2
        Passed on to ``rdi.real_pc_psa``.
    base_period : sequence, default: BASE_PERIOD
        Base period passed to ``rdi.real_pc_psa`` and used in file names.
    lat_name, lon_name, time_name : str, optional
        Accepted but not used by this function.

    Returns
    -------
    written : dict
        Paths of the written files, keyed by 'eofs_nc' and, for each index
        name, '<frequency>_<index name>_nc' / '..._csv'.
    """
    psa_kwargs = dict(
        base_period=base_period, psa1_mode=psa1_mode, psa2_mode=psa2_mode,
        rotate=False, eofs_season='ALL', eofs_frequency='daily')

    loadings_ds, daily_all_ds = rdi.real_pc_psa(
        hgt_da, frequency='daily', **psa_kwargs)
    _, monthly_all_ds = rdi.real_pc_psa(
        hgt_da, frequency='monthly', **psa_kwargs)

    if input_file is not None:
        for ds in (loadings_ds, daily_all_ds, monthly_all_ds):
            ds.attrs['input_file'] = input_file

    # Loadings are written first.
    period_tag = get_timespan_string(base_period)

    if eofs_output_dir is None:
        eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_path = os.path.join(
        eofs_output_dir,
        '.'.join([reanalysis, period_tag, 'psa', 'daily', 'eofs', 'nc']))

    loadings_ds.to_netcdf(eofs_path)

    written = {'eofs_nc': eofs_path}

    for number in (1, 2):

        index_name = 'psa{:d}_index'.format(number)

        # Build the single-variable datasets for both frequencies before
        # anything is written.
        per_frequency = []
        for freq, all_ds in (('daily', daily_all_ds), ('monthly', monthly_all_ds)):
            single_ds = all_ds[index_name].to_dataset(name=index_name)
            single_ds.attrs.update(all_ds.attrs)
            per_frequency.append((freq, single_ds))

        for freq, single_ds in per_frequency:
            nc_path = write_index_nc_file(
                reanalysis, index_name, single_ds, base_period=base_period,
                frequency=freq, output_dir=indices_nc_output_dir)
            csv_path = write_index_csv_file(
                reanalysis, index_name, single_ds, index_name,
                base_period=base_period,
                frequency=freq, output_dir=indices_csv_output_dir)

            written['{}_{}_nc'.format(freq, index_name)] = nc_path
            written['{}_{}_csv'.format(freq, index_name)] = csv_path

    return written

In [None]:
def calculate_sam_indices(hgt_da, reanalysis, input_file=None,
                          eofs_output_dir=None, indices_csv_output_dir=None,
                          indices_nc_output_dir=None,
                          sam_mode=0, base_period=BASE_PERIOD,
                          lat_name=None, lon_name=None, time_name=None):
    """Compute the daily and monthly SAM index and write the results.

    ``rdi.pc_sam`` is called once per frequency. The loadings returned by
    the daily call are written to a netCDF file, and each index is written
    in both netCDF and CSV formats.

    Parameters
    ----------
    hgt_da : array-like
        Height data passed on to ``rdi.pc_sam``.
    reanalysis : str
        Short identifier of the reanalysis, used in paths and file names.
    input_file : str, optional
        If given, stored as the 'input_file' attribute of every output.
    eofs_output_dir : str, optional
        Directory for the loadings file. Defaults to the netCDF EOFs
        directory of the reanalysis.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Directories passed to the index writers.
    sam_mode : int, default: 0
        Passed on to ``rdi.pc_sam``.
    base_period : sequence, default: BASE_PERIOD
        Base period passed to ``rdi.pc_sam`` and used in file names.
    lat_name, lon_name, time_name : str, optional
        Accepted but not used by this function.

    Returns
    -------
    written : dict
        Paths of the written files, keyed by 'eofs_nc', 'daily_index_nc',
        'daily_index_csv', 'monthly_index_nc' and 'monthly_index_csv'.
    """
    pc_sam_kwargs = dict(base_period=base_period, sam_mode=sam_mode)

    loadings_ds, daily_ds = rdi.pc_sam(hgt_da, frequency='daily', **pc_sam_kwargs)
    _, monthly_ds = rdi.pc_sam(hgt_da, frequency='monthly', **pc_sam_kwargs)

    if input_file is not None:
        for ds in (loadings_ds, daily_ds, monthly_ds):
            ds.attrs['input_file'] = input_file

    # Loadings are written first.
    # NOTE(review): the file name is labelled 'monthly' although the
    # loadings come from the daily call - confirm against rdi.pc_sam.
    period_tag = get_timespan_string(base_period)

    if eofs_output_dir is None:
        eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_path = os.path.join(
        eofs_output_dir,
        '.'.join([reanalysis, period_tag, 'sam', 'monthly', 'eofs', 'nc']))

    loadings_ds.to_netcdf(eofs_path)

    written = {'eofs_nc': eofs_path}

    # Daily outputs are written before monthly ones, netCDF before CSV.
    for freq, index_ds in (('daily', daily_ds), ('monthly', monthly_ds)):
        written['{}_index_nc'.format(freq)] = write_index_nc_file(
            reanalysis, 'sam', index_ds, base_period=base_period,
            frequency=freq, output_dir=indices_nc_output_dir)
        written['{}_index_csv'.format(freq)] = write_index_csv_file(
            reanalysis, 'sam', index_ds, 'sam_index',
            base_period=base_period,
            frequency=freq, output_dir=indices_csv_output_dir)

    return written

In [None]:
# Input file containing the 500 hPa geopotential height for each reanalysis.
h500_input_files = {
    'jra55': os.path.join(get_reanalysis_fields_dir('jra55'), 'jra.55.hgt.500.1958010100_2018123118.nc'),
    'nnr1': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.hgt.500.19480101_20200530.nc')
}

# Index groups computed from the height field, in the order they are run.
# Each entry pairs the key used in ``output_files`` (also printed as the
# progress label) with the function that computes and writes the indices.
h500_index_calculators = [
    ('AO', calculate_ao_indices),
    ('NHTELE', calculate_nhtele_indices),
    ('PNA', calculate_pna_indices),
    ('PSA', calculate_psa_indices),
    ('SAM', calculate_sam_indices),
]

for reanalysis in h500_input_files:

    print('* Reanalysis: ', reanalysis)

    eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)
    indices_nc_output_dir = get_reanalysis_indices_nc_dir(reanalysis)
    indices_csv_output_dir = get_reanalysis_indices_csv_dir(reanalysis)

    # exist_ok avoids the race between checking for the directory and
    # creating it, and makes re-running the cell safe.
    for output_dir in (eofs_output_dir, indices_nc_output_dir,
                       indices_csv_output_dir):
        os.makedirs(output_dir, exist_ok=True)

    output_files[reanalysis] = {}

    hgt_var = get_reanalysis_h500_variable(reanalysis)

    input_file = h500_input_files[reanalysis]
    with xr.open_dataset(input_file) as ds:

        print('\t- Preparing data ...', end='')

        start_time = time.perf_counter()

        # Extract height data.
        hgt_da = ds[hgt_var].squeeze()

        # Restrict to common time-period and ensure fixed missing
        # values, if present.
        time_name = rdu.get_time_name(hgt_da)
        hgt_da = hgt_da.where(hgt_da[time_name].dt.year >= START_YEAR,
                              drop=True)

        hgt_da = propagate_missing_values_through_time(
            hgt_da, time_name=time_name)

        # Normalize start times.
        input_frequency = rdu.detect_frequency(hgt_da, time_name=time_name)

        if input_frequency == 'daily':
            hgt_da = hgt_da.resample({time_name: '1D'}).mean(time_name)
        elif input_frequency == 'monthly':
            hgt_da = hgt_da.resample({time_name: '1MS'}).mean(time_name)
        else:
            raise RuntimeError('Could not determine input frequency')

        end_time = time.perf_counter()
        print(' (time: {:.2f}s)'.format(end_time - start_time))

        # Run each index calculation in turn, reporting the elapsed time
        # and recording the files that were written.
        for index_label, calculate_indices in h500_index_calculators:
            print('\t- {} ...'.format(index_label), end='')
            start_time = time.perf_counter()
            output_files[reanalysis][index_label] = calculate_indices(
                hgt_da, reanalysis, input_file=input_file)
            end_time = time.perf_counter()
            print(' (time: {:.2f}s)'.format(end_time - start_time))

### SST indices

In [None]:
def get_reanalysis_sst_variable(reanalysis):
    """Get name of reanalysis SST field.

    Parameters
    ----------
    reanalysis : str
        Short identifier; one of 'hadisst', 'nnr1' or 'jra55'.

    Returns
    -------
    variable : str
        Name of the SST variable for that reanalysis.

    Raises
    ------
    ValueError
        If the identifier is not recognized.
    """

    if reanalysis in ('hadisst', 'nnr1'):
        return 'sst'

    if reanalysis == 'jra55':
        return 'BRTMP_GDS0_SFC'

    # str.format with !r replaces "'%r'" % reanalysis, which wrapped the
    # repr in a second pair of quotes and raised a TypeError rather than the
    # intended ValueError when the argument was a tuple.
    raise ValueError('Unrecognized reanalysis {!r}'.format(reanalysis))

In [None]:
def calculate_indopacific_indices(sst_da, reanalysis, input_file=None,
                                  eofs_output_dir=None, indices_csv_output_dir=None,
                                  indices_nc_output_dir=None,
                                  base_period=BASE_PERIOD,
                                  lat_name=None, lon_name=None, time_name=None):
    """Compute the monthly Indopacific SST indices and write the results.

    ``rdi.dc_sst`` is called with frequency='monthly'. The returned
    loadings are written to a netCDF file, and the variables 'sst1_index'
    and 'sst2_index' are each written in netCDF and CSV formats.

    Parameters
    ----------
    sst_da : array-like
        SST data passed on to ``rdi.dc_sst``.
    reanalysis : str
        Short identifier of the reanalysis, used in paths and file names.
    input_file : str, optional
        If given, stored as the 'input_file' attribute of the outputs.
    eofs_output_dir : str, optional
        Directory for the loadings file. Defaults to the netCDF EOFs
        directory of the reanalysis.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Directories passed to the index writers.
    base_period : sequence, default: BASE_PERIOD
        Base period passed to ``rdi.dc_sst`` and used in file names.
    lat_name, lon_name, time_name : str, optional
        Coordinate names passed on to ``rdi.dc_sst``.

    Returns
    -------
    written : dict
        Paths of the written files, keyed by 'eofs_nc' and, for each index
        name, 'monthly_<index name>_nc' / '..._csv'.
    """
    loadings_ds, monthly_all_ds = rdi.dc_sst(
        sst_da, frequency='monthly', base_period=base_period,
        lat_name=lat_name, lon_name=lon_name, time_name=time_name)

    if input_file is not None:
        for ds in (loadings_ds, monthly_all_ds):
            ds.attrs['input_file'] = input_file

    # Loadings are written first.
    period_tag = get_timespan_string(base_period)

    if eofs_output_dir is None:
        eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_path = os.path.join(
        eofs_output_dir,
        '.'.join([reanalysis, period_tag, 'dc_sst', 'monthly', 'eofs', 'nc']))

    loadings_ds.to_netcdf(eofs_path)

    written = {'eofs_nc': eofs_path}

    for index_name in ('sst1_index', 'sst2_index'):

        # Split out one index and carry over the attributes of the parent.
        single_ds = monthly_all_ds[index_name].to_dataset(name=index_name)
        single_ds.attrs.update(monthly_all_ds.attrs)

        nc_path = write_index_nc_file(
            reanalysis, index_name, single_ds, base_period=base_period,
            frequency='monthly', output_dir=indices_nc_output_dir)
        csv_path = write_index_csv_file(
            reanalysis, index_name, single_ds, index_name,
            base_period=base_period,
            frequency='monthly', output_dir=indices_csv_output_dir)

        written['monthly_{}_nc'.format(index_name)] = nc_path
        written['monthly_{}_csv'.format(index_name)] = csv_path

    return written

In [None]:
def calculate_iod_indices(sst_da, reanalysis, input_file=None,
                          eofs_output_dir=None, indices_csv_output_dir=None,
                          indices_nc_output_dir=None,
                          base_period=BASE_PERIOD,
                          lat_name=None, lon_name=None, time_name=None):
    """Calculate the DMI at daily, monthly and weekly frequency and save it.

    The index is evaluated with ``rdi.dmi`` once per frequency, each result
    is wrapped in a dataset under the variable name ``'dmi'`` and written
    with ``write_index_nc_file`` and ``write_index_csv_file``.

    Parameters
    ----------
    sst_da : object
        SST data handed unchanged to ``rdi.dmi``.
    reanalysis : str
        Dataset identifier forwarded to the file writers.
    input_file : str, optional
        If given, stored in the ``input_file`` attribute of every dataset.
    eofs_output_dir : str, optional
        Accepted for symmetry with the other calculation functions; it is
        not used by this function.
    indices_csv_output_dir, indices_nc_output_dir : str, optional
        Forwarded as ``output_dir`` to the CSV and netCDF writers.
    base_period : list
        Forwarded to ``rdi.dmi`` and to both writers.
    lat_name, lon_name, time_name : str, optional
        Coordinate names forwarded to ``rdi.dmi``.

    Returns
    -------
    dict
        The values returned by the writers, keyed by
        ``'<frequency>_index_nc'`` and ``'<frequency>_index_csv'``.
    """
    frequencies = ('daily', 'monthly', 'weekly')

    # Evaluate the index at every frequency before anything is written.
    index_arrays = [
        rdi.dmi(sst_da, frequency=frequency, base_period=base_period,
                lat_name=lat_name, lon_name=lon_name, time_name=time_name)
        for frequency in frequencies]

    # Wrap each array in a dataset that carries the array's attributes.
    index_datasets = []
    for index_da in index_arrays:
        index_ds = index_da.to_dataset(name='dmi')
        index_ds.attrs.update(index_da.attrs)
        index_datasets.append(index_ds)

    if input_file is not None:
        for index_ds in index_datasets:
            index_ds.attrs['input_file'] = input_file

    # Write the netCDF file and then the CSV file for each frequency in turn.
    nc_files = {}
    csv_files = {}
    for frequency, index_ds in zip(frequencies, index_datasets):
        nc_files[frequency] = write_index_nc_file(
            reanalysis, 'dmi', index_ds, base_period=base_period,
            frequency=frequency, output_dir=indices_nc_output_dir)
        csv_files[frequency] = write_index_csv_file(
            reanalysis, 'dmi', index_ds, 'dmi',
            base_period=base_period,
            frequency=frequency, output_dir=indices_csv_output_dir)

    return {'daily_index_nc': nc_files['daily'],
            'daily_index_csv': csv_files['daily'],
            'monthly_index_nc': nc_files['monthly'],
            'monthly_index_csv': csv_files['monthly'],
            'weekly_index_nc': nc_files['weekly'],
            'weekly_index_csv': csv_files['weekly']}

In [None]:
# Driver cell: compute the Indo-Pacific SST indices and the DMI for every
# SST product listed below and record the written files in `output_files`.
# NOTE: `output_files` is not created in this cell; it must already exist
# when the cell runs.

# SST input file for each dataset key.
sst_input_files = {
    'hadisst': os.path.join(get_reanalysis_fields_dir('hadisst'), 'HadISST_sst.nc'),
    'jra55': os.path.join(get_reanalysis_fields_dir('jra55'), 'fcst_surf125.118_brtmp.1958010100_2018123100.daily.nc')
}

for reanalysis in sst_input_files:

    print('* Reanalysis: ', reanalysis)

    # Create the EOF and index output directories if they are missing.
    eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)
    if not os.path.exists(eofs_output_dir):
        os.makedirs(eofs_output_dir)

    indices_nc_output_dir = get_reanalysis_indices_nc_dir(reanalysis)
    if not os.path.exists(indices_nc_output_dir):
        os.makedirs(indices_nc_output_dir)

    indices_csv_output_dir = get_reanalysis_indices_csv_dir(reanalysis)
    if not os.path.exists(indices_csv_output_dir):
        os.makedirs(indices_csv_output_dir)

    sst_var = get_reanalysis_sst_variable(reanalysis)

    input_file = sst_input_files[reanalysis]
    # The input file stays open for the whole calculation and is closed on
    # leaving the `with` block.
    with xr.open_dataset(input_file) as ds:

        print('\t- Preparing data ...', end='')

        if reanalysis not in output_files:
            output_files[reanalysis] = {}

        start_time = time.perf_counter()

        # Extract SST data.
        sst_da = ds[sst_var].squeeze()

        # Restrict to common time-period and ensure fixed missing
        # values, if present.
        time_name = rdu.get_time_name(sst_da)
        sst_da = sst_da.where(sst_da[time_name].dt.year >= START_YEAR,
                              drop=True)

        sst_da = propagate_missing_values_through_time(
            sst_da, time_name=time_name)

        # Normalize start times.
        input_frequency = rdu.detect_frequency(sst_da, time_name=time_name)

        # Daily input is averaged onto calendar days and monthly input onto
        # month starts; any other detected frequency is rejected.
        if input_frequency == 'daily':
            sst_da = sst_da.resample({time_name: '1D'}).mean(time_name)
        elif input_frequency == 'monthly':
            sst_da = sst_da.resample({time_name: '1MS'}).mean(time_name)
        else:
            raise RuntimeError('Could not determine input frequency')

        end_time = time.perf_counter()
        print(' (time: {:.2f}s)'.format(end_time - start_time))

        # Both calculations below are called without explicit output
        # directories, so they fall back to their own defaults.
        print('\t- Indopacific SST ...', end='')
        start_time = time.perf_counter()
        output_files[reanalysis]['DCSST'] = calculate_indopacific_indices(
            sst_da, reanalysis, input_file=input_file)
        end_time = time.perf_counter()
        print(' (time: {:.2f}s)'.format(end_time - start_time))

        print('\t- DMI ...', end='')
        start_time = time.perf_counter()
        output_files[reanalysis]['DMI'] = calculate_iod_indices(
            sst_da, reanalysis, input_file=input_file)
        end_time = time.perf_counter()
        print(' (time: {:.2f}s)'.format(end_time - start_time))


### MEI

In [None]:
def get_reanalysis_olr_variable(reanalysis):
    """Return the name of the OLR variable used by the given reanalysis.

    Raises ``ValueError`` if the reanalysis is neither ``'jra55'`` nor
    ``'nnr1'``.
    """
    known_variables = (('jra55', 'ULWRF_GDS0_NTAT_ave3h'),
                       ('nnr1', 'ulwrf'))

    for known_reanalysis, variable_name in known_variables:
        if reanalysis == known_reanalysis:
            return variable_name

    raise ValueError("Unrecognized reanalysis '%r'" % reanalysis)
    

def get_reanalysis_slp_variable(reanalysis):
    """Return the name of the SLP variable used by the given reanalysis.

    Raises ``ValueError`` if the reanalysis is neither ``'jra55'`` nor
    ``'nnr1'``.
    """
    known_variables = (('jra55', 'PRMSL_GDS0_MSL'),
                       ('nnr1', 'slp'))

    for known_reanalysis, variable_name in known_variables:
        if reanalysis == known_reanalysis:
            return variable_name

    raise ValueError("Unrecognized reanalysis '%r'" % reanalysis)
    
    
def get_reanalysis_usfc_variable(reanalysis):
    """Return the name of the surface u-wind variable for the reanalysis.

    Raises ``ValueError`` if the reanalysis is neither ``'jra55'`` nor
    ``'nnr1'``.
    """
    known_variables = (('jra55', 'UGRD_GDS0_HTGL'),
                       ('nnr1', 'uwnd'))

    for known_reanalysis, variable_name in known_variables:
        if reanalysis == known_reanalysis:
            return variable_name

    raise ValueError("Unrecognized reanalysis '%r'" % reanalysis)
    
    
def get_reanalysis_vsfc_variable(reanalysis):
    """Return the name of the surface v-wind variable for the reanalysis.

    Raises ``ValueError`` if the reanalysis is neither ``'jra55'`` nor
    ``'nnr1'``.
    """
    known_variables = (('jra55', 'VGRD_GDS0_HTGL'),
                       ('nnr1', 'vwnd'))

    for known_reanalysis, variable_name in known_variables:
        if reanalysis == known_reanalysis:
            return variable_name

    raise ValueError("Unrecognized reanalysis '%r'" % reanalysis)

In [None]:
def normalize_coordinate_names(ds, input_lat_name=None, output_lat_name='lat',
                               input_lon_name=None, output_lon_name='lon',
                               input_time_name=None, output_time_name='time'):
    """Rename the latitude, longitude and time names of ``ds``.

    Parameters
    ----------
    ds : object
        Object providing a ``rename`` method that accepts a mapping from
        current names to new names.
    input_lat_name, input_lon_name, input_time_name : str, optional
        Current names. Any that is ``None`` is looked up with the matching
        ``rdu.get_*_name`` helper.
    output_lat_name, output_lon_name, output_time_name : str, optional
        Desired names. Any that is ``None`` is also looked up with the
        matching ``rdu.get_*_name`` helper.

    Returns
    -------
    object
        The result of ``ds.rename`` when at least one name differs,
        otherwise ``ds`` itself.
    """
    # Fill in the current names that the caller did not supply.
    if input_lat_name is None:
        input_lat_name = rdu.get_lat_name(ds)
    if input_lon_name is None:
        input_lon_name = rdu.get_lon_name(ds)
    if input_time_name is None:
        input_time_name = rdu.get_time_name(ds)

    # Fill in the target names that the caller explicitly set to None.
    if output_lat_name is None:
        output_lat_name = rdu.get_lat_name(ds)
    if output_lon_name is None:
        output_lon_name = rdu.get_lon_name(ds)
    if output_time_name is None:
        output_time_name = rdu.get_time_name(ds)

    name_pairs = ((input_lat_name, output_lat_name),
                  (input_lon_name, output_lon_name),
                  (input_time_name, output_time_name))

    # Only names that actually change are passed to rename.
    renames = {current: target
               for current, target in name_pairs if current != target}

    return ds.rename(renames) if renames else ds

In [None]:
# Driver cell: compute the MEI for each reanalysis, then write the loading
# patterns and the daily and monthly index files.
# NOTE: `output_files` is not created in this cell; it must already exist
# when the cell runs.

# Input files for the MEI calculation, keyed by reanalysis and then by field.
# The 'nnr1' entry takes its SST field from the 'hadisst' fields directory.
mei_input_files = {
    'jra55': {
        'sst': os.path.join(get_reanalysis_fields_dir('jra55'), 'fcst_surf125.118_brtmp.1958010100_2018123100.daily.2.5x2.5.nc'),
        'olr': os.path.join(get_reanalysis_fields_dir('jra55'), 'jra.55.ulwrf.ntat.1958010100_2018123121.daily.2.5x2.5.nc'),
        'uwnd': os.path.join(get_reanalysis_fields_dir('jra55'), 'anl_surf125.033_ugrd.1958010100_2018123100.daily.2.5x2.5.nc'),
        'vwnd': os.path.join(get_reanalysis_fields_dir('jra55'), 'anl_surf125.034_vgrd.1958010100_2018123100.daily.2.5x2.5.nc'),
        'slp': os.path.join(get_reanalysis_fields_dir('jra55'), 'anl_surf125.002_prmsl.1958010100_2018123100.daily.2.5x2.5.nc'),
    },
    'nnr1': {
        'sst': os.path.join(get_reanalysis_fields_dir('hadisst'), 'HadISST_sst.2.5x2.5.nc'),
        'olr': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.ulwrf.ntat.gauss.19480101_20200530.2.5x2.5.nc'),
        'uwnd': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.uwnd.sig995.19480101_20200530.2.5x2.5.nc'),
        'vwnd': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.vwnd.sig995.19480101_20200530.2.5x2.5.nc'),
        'slp': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.slp.19480101_20200530.2.5x2.5.nc'),
    }
}

for reanalysis in mei_input_files:

    print('* Reanalysis: ', reanalysis)

    print('\t- Preparing data ...', end='')
    # Start the timer for the preparation step here, so that the elapsed
    # time printed below is not measured from a stale `start_time` left
    # behind by a previous step or cell.
    start_time = time.perf_counter()

    if reanalysis not in output_files:
        output_files[reanalysis] = {}

    # Create the EOF and index output directories if they are missing.
    eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)
    if not os.path.exists(eofs_output_dir):
        os.makedirs(eofs_output_dir)

    indices_nc_output_dir = get_reanalysis_indices_nc_dir(reanalysis)
    if not os.path.exists(indices_nc_output_dir):
        os.makedirs(indices_nc_output_dir)

    indices_csv_output_dir = get_reanalysis_indices_csv_dir(reanalysis)
    if not os.path.exists(indices_csv_output_dir):
        os.makedirs(indices_csv_output_dir)

    sst_var = get_reanalysis_sst_variable(reanalysis)
    slp_var = get_reanalysis_slp_variable(reanalysis)
    uwnd_var = get_reanalysis_usfc_variable(reanalysis)
    vwnd_var = get_reanalysis_vsfc_variable(reanalysis)
    olr_var = get_reanalysis_olr_variable(reanalysis)

    # Keep the opened datasets themselves. The entries of `input_datasets`
    # are replaced by derived arrays below, so closing those would not
    # release the underlying files.
    mei_open_datasets = {
        field: xr.open_dataset(mei_input_files[reanalysis][field])
        for field in ('sst', 'olr', 'uwnd', 'vwnd', 'slp')}

    input_datasets = {
        'sst': mei_open_datasets['sst'][sst_var],
        'olr': mei_open_datasets['olr'][olr_var],
        'uwnd': mei_open_datasets['uwnd'][uwnd_var],
        'vwnd': mei_open_datasets['vwnd'][vwnd_var],
        'slp': mei_open_datasets['slp'][slp_var]
    }

    # Coordinate names as they appear in each input field.
    input_coords = {
        field: {'input_lat_name': rdu.get_lat_name(field_da),
                'input_lon_name': rdu.get_lon_name(field_da),
                'input_time_name': rdu.get_time_name(field_da)}
        for field, field_da in input_datasets.items()}

    for field in input_datasets:

        # Rename coordinates to the common 'lat', 'lon' and 'time'.
        input_datasets[field] = normalize_coordinate_names(
            input_datasets[field], **input_coords[field])

        # Ensure common time-period.
        input_datasets[field] = input_datasets[field].where(
            input_datasets[field]['time'].dt.year >= START_YEAR,
            drop=True)

        # Ensure monthly input.
        input_datasets[field] = input_datasets[field].resample(time='1MS').mean('time')

        # Where intermittent missing values are present for a given grid point,
        # fill all times at that point with missing values.
        input_datasets[field] = propagate_missing_values_through_time(
            input_datasets[field], time_name='time')

    end_time = time.perf_counter()
    print(' (time: {:.2f}s)'.format(end_time - start_time))

    print('\t- MEI ...', end='')
    start_time = time.perf_counter()

    # NOTE(review): the inputs were resampled to monthly means above, yet a
    # 'daily' index is requested here - confirm how rdi.mei derives it.
    daily_mei_ds = rdi.mei(
        input_datasets['slp'],
        input_datasets['uwnd'],
        input_datasets['vwnd'],
        input_datasets['sst'],
        input_datasets['olr'],
        frequency='daily', base_period=BASE_PERIOD)

    monthly_mei_ds = rdi.mei(
        input_datasets['slp'],
        input_datasets['uwnd'],
        input_datasets['vwnd'],
        input_datasets['sst'],
        input_datasets['olr'],
        frequency='monthly', base_period=BASE_PERIOD)

    # Write EOFs to file.
    base_period_str = get_timespan_string(BASE_PERIOD)

    eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_output_file = '.'.join(
        [reanalysis, base_period_str, 'mei', 'monthly', 'eofs', 'nc'])
    eofs_output_file = os.path.join(eofs_output_dir, eofs_output_file)

    # The loading patterns are taken from the monthly calculation.
    loadings_ds = xr.Dataset(
        {'mslp_pattern': monthly_mei_ds['mslp_pattern'],
         'uwnd_pattern': monthly_mei_ds['uwnd_pattern'],
         'vwnd_pattern': monthly_mei_ds['vwnd_pattern'],
         'sst_pattern': monthly_mei_ds['sst_pattern'],
         'olr_pattern': monthly_mei_ds['olr_pattern']})

    # Carry over the calculation attributes and record each input file.
    for attr in monthly_mei_ds.attrs:
        loadings_ds.attrs[attr] = monthly_mei_ds.attrs[attr]
    for field in input_datasets:
        loadings_ds.attrs['{}_input_file'.format(field)] = mei_input_files[reanalysis][field]

    loadings_ds.to_netcdf(eofs_output_file)

    # Write daily index to file.
    daily_index_ds = daily_mei_ds['index'].to_dataset(name='mei')
    for attr in daily_mei_ds.attrs:
        daily_index_ds.attrs[attr] = daily_mei_ds.attrs[attr]
    for field in input_datasets:
        daily_index_ds.attrs['{}_input_file'.format(field)] = mei_input_files[reanalysis][field]

    daily_index_nc_output_file = write_index_nc_file(
        reanalysis, 'mei', daily_index_ds, base_period=BASE_PERIOD,
        frequency='daily', output_dir=indices_nc_output_dir)
    daily_index_csv_output_file = write_index_csv_file(
        reanalysis, 'mei', daily_index_ds, 'mei',
        base_period=BASE_PERIOD,
        frequency='daily', output_dir=indices_csv_output_dir)

    # Write monthly index to file.
    monthly_index_ds = monthly_mei_ds['index'].to_dataset(name='mei')
    for attr in monthly_mei_ds.attrs:
        monthly_index_ds.attrs[attr] = monthly_mei_ds.attrs[attr]
    for field in input_datasets:
        monthly_index_ds.attrs['{}_input_file'.format(field)] = mei_input_files[reanalysis][field]

    monthly_index_nc_output_file = write_index_nc_file(
        reanalysis, 'mei', monthly_index_ds, base_period=BASE_PERIOD,
        frequency='monthly', output_dir=indices_nc_output_dir)
    monthly_index_csv_output_file = write_index_csv_file(
        reanalysis, 'mei', monthly_index_ds, 'mei',
        base_period=BASE_PERIOD,
        frequency='monthly', output_dir=indices_csv_output_dir)

    # Release the input files now that all outputs have been written.
    for field in mei_open_datasets:
        mei_open_datasets[field].close()

    output_files[reanalysis]['MEI'] = {
        'eofs_nc': eofs_output_file,
        'daily_index_nc': daily_index_nc_output_file,
        'daily_index_csv': daily_index_csv_output_file,
        'monthly_index_nc': monthly_index_nc_output_file,
        'monthly_index_csv': monthly_index_csv_output_file
    }

    end_time = time.perf_counter()
    print(' (time: {:.2f}s)'.format(end_time - start_time))


### MJO

In [None]:
def get_reanalysis_uwnd_variable(reanalysis):
    """Get name of reanalysis u-wind field.

    Parameters
    ----------
    reanalysis : str
        Reanalysis identifier, either ``'jra55'`` or ``'nnr1'``.

    Returns
    -------
    str
        Name of the u-wind variable for that reanalysis.

    Raises
    ------
    ValueError
        If the reanalysis is not recognized.
    """

    if reanalysis == 'jra55':
        return 'UGRD_GDS0_ISBL'

    if reanalysis == 'nnr1':
        return 'uwnd'

    # Interpolate the offending name into the message, as the other
    # variable-name getters do; previously the placeholder was left unfilled.
    raise ValueError("Unrecognized reanalysis '%r'" % reanalysis)

In [None]:
# Driver cell: compute the RMM indices for each reanalysis, then write the
# EOFs and the daily rmm1 / rmm2 index files.
# NOTE: `output_files` is not created in this cell; it must already exist
# when the cell runs.

# Input files for the RMM calculation, keyed by reanalysis and then by field.
# 'enso_index' is the file holding the ENSO index and 'enso_index_name' the
# variable to read from it.
# NOTE(review): the JRA-55 'u200' entry points at a file named '...ugrd.250...'
# - confirm that this level is intended.
mjo_input_files = {
    'jra55': {
        'olr': os.path.join(get_reanalysis_fields_dir('jra55'), 'jra.55.ulwrf.ntat.1958010100_2018123121.daily.2.5x2.5.nc'),
        'u850': os.path.join(get_reanalysis_fields_dir('jra55'), 'jra.55.ugrd.850.1958010100_2016123118.2.5x2.5.nc'),
        'u200': os.path.join(get_reanalysis_fields_dir('jra55'), 'jra.55.ugrd.250.1958010100_2016123118.2.5x2.5.nc'),
        'enso_index': os.path.join(get_reanalysis_indices_nc_dir('jra55'), 'jra55.19790101_20011230.sst1_index.monthly.nc'),
        'enso_index_name': 'sst1_index'
    },
    'nnr1': {
        'olr': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.ulwrf.ntat.gauss.19480101_20200530.2.5x2.5.nc'),
        'u850': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.uwnd.850.19480101_20200530.2.5x2.5.nc'),
        'u200': os.path.join(get_reanalysis_fields_dir('nnr1'), 'nnr1.uwnd.200.19480101_20200530.2.5x2.5.nc'),
        'enso_index': os.path.join(get_reanalysis_indices_nc_dir('hadisst'), 'hadisst.19790101_20011230.sst1_index.monthly.nc'),
        'enso_index_name': 'sst1_index'
    }
}

for reanalysis in mjo_input_files:

    print('* Reanalysis: ', reanalysis)

    print('\t- Preparing data ...', end='')
    start_time = time.perf_counter()

    if reanalysis not in output_files:
        output_files[reanalysis] = {}

    # Create the EOF and index output directories if they are missing.
    eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)
    if not os.path.exists(eofs_output_dir):
        os.makedirs(eofs_output_dir)

    indices_nc_output_dir = get_reanalysis_indices_nc_dir(reanalysis)
    if not os.path.exists(indices_nc_output_dir):
        os.makedirs(indices_nc_output_dir)

    indices_csv_output_dir = get_reanalysis_indices_csv_dir(reanalysis)
    if not os.path.exists(indices_csv_output_dir):
        os.makedirs(indices_csv_output_dir)

    uwnd_var = get_reanalysis_uwnd_variable(reanalysis)
    olr_var = get_reanalysis_olr_variable(reanalysis)

    # Keep the opened datasets themselves. The entries of `input_datasets`
    # are replaced by derived arrays below, so closing those would not
    # release the underlying files.
    mjo_open_datasets = {
        field: xr.open_dataset(mjo_input_files[reanalysis][field])
        for field in ('olr', 'u850', 'u200')}

    input_datasets = {
        'olr': mjo_open_datasets['olr'][olr_var],
        'u850': mjo_open_datasets['u850'][uwnd_var],
        'u200': mjo_open_datasets['u200'][uwnd_var]
    }

    # Coordinate names as they appear in each input field.
    input_coords = {
        field: {'input_lat_name': rdu.get_lat_name(field_da),
                'input_lon_name': rdu.get_lon_name(field_da),
                'input_time_name': rdu.get_time_name(field_da)}
        for field, field_da in input_datasets.items()}

    for field in input_datasets:

        # Rename coordinates to the common 'lat', 'lon' and 'time'.
        input_datasets[field] = normalize_coordinate_names(
            input_datasets[field], **input_coords[field])

        # Ensure common time-period.
        input_datasets[field] = input_datasets[field].where(
            input_datasets[field]['time'].dt.year >= START_YEAR,
            drop=True)

        # Ensure daily input.
        input_datasets[field] = input_datasets[field].resample(time='1D').mean('time')

        # Where intermittent missing values are present for a given grid point,
        # fill all times at that point with missing values.
        input_datasets[field] = propagate_missing_values_through_time(
            input_datasets[field], time_name='time')

    # Keep a handle on the ENSO index dataset as well, so the file can be
    # closed once the calculation is finished.
    enso_ds = xr.open_dataset(mjo_input_files[reanalysis]['enso_index'])
    enso_index = enso_ds[mjo_input_files[reanalysis]['enso_index_name']]

    enso_time_name = rdu.get_time_name(enso_index)
    if enso_time_name != 'time':
        enso_index = enso_index.rename({enso_time_name: 'time'})

    end_time = time.perf_counter()
    print(' (time: {:.2f}s)'.format(end_time - start_time))

    print('\t- RMM ...', end='')
    start_time = time.perf_counter()

    daily_rmm_ds = rdi.wh_rmm(
        input_datasets['olr'],
        input_datasets['u850'],
        input_datasets['u200'],
        enso_index=enso_index,
        base_period=BASE_PERIOD)

    # Write EOFs to file.
    base_period_str = get_timespan_string(BASE_PERIOD)

    eofs_output_dir = get_reanalysis_eofs_nc_dir(reanalysis)

    eofs_output_file = '.'.join(
        [reanalysis, base_period_str, 'rmm', 'daily', 'eofs', 'nc'])
    eofs_output_file = os.path.join(eofs_output_dir, eofs_output_file)

    loadings_ds = xr.Dataset(
        {'olr_eofs': daily_rmm_ds['olr_eofs'],
         'u850_eofs': daily_rmm_ds['u850_eofs'],
         'u200_eofs': daily_rmm_ds['u200_eofs']})

    # Carry over the calculation attributes and record each input file.
    for attr in daily_rmm_ds.attrs:
        loadings_ds.attrs[attr] = daily_rmm_ds.attrs[attr]
    for field in input_datasets:
        loadings_ds.attrs['{}_input_file'.format(field)] = mjo_input_files[reanalysis][field]
    loadings_ds.attrs['enso_input_file'] = mjo_input_files[reanalysis]['enso_index']

    loadings_ds.to_netcdf(eofs_output_file)

    output_files[reanalysis]['MJO'] = {'eofs_nc': eofs_output_file}

    for index in ('rmm1', 'rmm2'):

        # Write daily index to file.
        daily_index_ds = daily_rmm_ds[index].to_dataset(name=index)
        for attr in daily_rmm_ds.attrs:
            daily_index_ds.attrs[attr] = daily_rmm_ds.attrs[attr]
        for field in input_datasets:
            daily_index_ds.attrs['{}_input_file'.format(field)] = mjo_input_files[reanalysis][field]
        daily_index_ds.attrs['enso_input_file'] = mjo_input_files[reanalysis]['enso_index']

        daily_index_nc_output_file = write_index_nc_file(
            reanalysis, index, daily_index_ds, base_period=BASE_PERIOD,
            frequency='daily', output_dir=indices_nc_output_dir)
        daily_index_csv_output_file = write_index_csv_file(
            reanalysis, index, daily_index_ds, index,
            base_period=BASE_PERIOD,
            frequency='daily', output_dir=indices_csv_output_dir)

        output_files[reanalysis]['MJO']['daily_{}_index_nc'.format(index)] = daily_index_nc_output_file
        output_files[reanalysis]['MJO']['daily_{}_index_csv'.format(index)] = daily_index_csv_output_file

    # Release the input files now that all outputs have been written.
    for field in mjo_open_datasets:
        mjo_open_datasets[field].close()

    enso_ds.close()

    end_time = time.perf_counter()
    print(' (time: {:.2f}s)'.format(end_time - start_time))


## Plots

### AO

In [None]:
# Plot the AO EOF pattern of every reanalysis for which AO results were
# recorded in `output_files`.
for reanalysis in output_files:

    # Skip datasets without AO results.
    if 'AO' not in output_files[reanalysis]:
        continue

    ds = xr.open_dataset(output_files[reanalysis]['AO']['eofs_nc'])

    lat_name = rdu.get_lat_name(ds)
    lon_name = rdu.get_lon_name(ds)

    # The mode to plot is read from the file's 'ao_mode' attribute.
    mode = int(ds.attrs['ao_mode'])

    vmin = np.min(ds['EOFs'].sel(mode=mode))
    vmax = np.max(ds['EOFs'].sel(mode=mode))


    # Make the colour limits symmetric about zero, using whichever of the
    # two extremes has the larger magnitude.
    if np.abs(vmax) > np.abs(vmin):
        vmin = -np.abs(vmax)
        vmax = np.abs(vmax)
    else:
        # vmin has just been replaced by -|vmin|, so |vmin| is unchanged
        # when it is used for vmax on the next line.
        vmin = -np.abs(vmin)
        vmax = np.abs(vmin)


    # Two stacked rows: the map on top and a thin row for the colour bar.
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[15, 1], hspace=0.05)

    # Orthographic view centred on the North Pole.
    projection = ccrs.Orthographic(central_latitude=90)
    cmap = plt.cm.RdBu_r

    lat = ds[lat_name]
    lon = ds[lon_name]

    pattern = ds['EOFs'].sel(mode=mode).squeeze().data
    # Converted to a percentage for the title.
    explained_var = 100.0 * ds['explained_var'].sel(mode=mode).squeeze().values

    # Append a cyclic point in longitude before building the plotting grid.
    pattern, lon = add_cyclic_point(pattern, coord=lon)
    lon_grid, lat_grid = np.meshgrid(lon, lat)

    ax = fig.add_subplot(gs[0, 0], projection=projection)

    ax.coastlines()
    ax.set_global()

    cs = ax.pcolor(lon_grid, lat_grid, pattern, shading='auto',
                   vmin=vmin, vmax=vmax, cmap=cmap, transform=ccrs.PlateCarree())

    # The title shows the mode number plus one.
    ax.set_title('{} mode {:d} ({:.2f}%)'.format(
        get_reanalysis_full_name(reanalysis), mode + 1, explained_var), y=1.01, fontsize=14)
    ax.set_aspect('equal')

    # Horizontal colour bar in the bottom row of the grid.
    cb_ax = fig.add_subplot(gs[-1, :])
    cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

    ds.close()

In [None]:
# Plot the daily AO index of every reanalysis that has AO results, together
# with the reference index from 'cpc.ao.daily.csv' when that file exists.
datasets = {}
for reanalysis in output_files:

    # Skip datasets without AO results.
    if 'AO' not in output_files[reanalysis]:
        continue

    datasets[reanalysis] = xr.open_dataset(output_files[reanalysis]['AO']['daily_index_nc'])

    time_name = rdu.get_time_name(datasets[reanalysis])

    # Rename the time coordinate to 'time' when it has a different name.
    if time_name != 'time':
        datasets[reanalysis] = datasets[reanalysis].rename(
            {time_name: 'time'})

# Series to plot, keyed by the display name of each reanalysis.
indices_to_plot = {get_reanalysis_full_name(reanalysis): datasets[reanalysis]['ao_index'] for reanalysis in datasets}

ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'cpc.ao.daily.csv')

# The reference series is optional and only added when its file is present.
if os.path.exists(ref_index_datafile):
    ref_index_ds = read_reference_index_csv(ref_index_datafile)
    indices_to_plot['CPC'] = ref_index_ds['index']

fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

plt.suptitle('Daily AO index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

plt.show()

# Close everything that was opened for the plot.
for reanalysis in datasets:
    datasets[reanalysis].close()

# `ref_index_ds` only exists when the reference file was found above.
if os.path.exists(ref_index_datafile):
    ref_index_ds.close()

In [None]:
# Plot the monthly AO index of every reanalysis that has AO results, together
# with the reference index from 'cpc.ao.monthly.csv' when that file exists.
datasets = {}
for reanalysis in output_files:

    # Skip datasets without AO results.
    if 'AO' not in output_files[reanalysis]:
        continue

    datasets[reanalysis] = xr.open_dataset(output_files[reanalysis]['AO']['monthly_index_nc'])

    time_name = rdu.get_time_name(datasets[reanalysis])

    # Rename the time coordinate to 'time' when it has a different name.
    if time_name != 'time':
        datasets[reanalysis] = datasets[reanalysis].rename(
            {time_name: 'time'})

# Series to plot, keyed by the display name of each reanalysis.
indices_to_plot = {get_reanalysis_full_name(reanalysis): datasets[reanalysis]['ao_index'] for reanalysis in datasets}

ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'cpc.ao.monthly.csv')

# The reference series is optional and only added when its file is present.
if os.path.exists(ref_index_datafile):
    ref_index_ds = read_reference_index_csv(ref_index_datafile)
    indices_to_plot['CPC'] = ref_index_ds['index']

fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

plt.suptitle('Monthly AO index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

plt.show()

# Close everything that was opened for the plot.
for reanalysis in datasets:
    datasets[reanalysis].close()

# `ref_index_ds` only exists when the reference file was found above.
if os.path.exists(ref_index_datafile):
    ref_index_ds.close()

### DMI

In [None]:
# Plot the weekly DMI of every dataset that has DMI results, together with
# the reference index from 'bom.dmi.weekly.csv' when that file exists.
datasets = {}
for reanalysis in output_files:

    # Skip datasets without DMI results.
    if 'DMI' not in output_files[reanalysis]:
        continue

    datasets[reanalysis] = xr.open_dataset(output_files[reanalysis]['DMI']['weekly_index_nc'])

    time_name = rdu.get_time_name(datasets[reanalysis])

    # Rename the time coordinate to 'time' when it has a different name.
    if time_name != 'time':
        datasets[reanalysis] = datasets[reanalysis].rename(
            {time_name: 'time'})

# Series to plot, keyed by the display name of each dataset.
indices_to_plot = {get_reanalysis_full_name(reanalysis): datasets[reanalysis]['dmi'] for reanalysis in datasets}

ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'bom.dmi.weekly.csv')

# The reference series is optional and only added when its file is present.
if os.path.exists(ref_index_datafile):
    ref_index_ds = read_reference_index_csv(ref_index_datafile)
    indices_to_plot['BOM'] = ref_index_ds['index']

fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

plt.suptitle('Weekly DMI (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

plt.show()

# Close everything that was opened for the plot.
for reanalysis in datasets:
    datasets[reanalysis].close()

# `ref_index_ds` only exists when the reference file was found above.
if os.path.exists(ref_index_datafile):
    ref_index_ds.close()

### MEI

In [None]:
# For every reanalysis with MEI results, draw one figure per input field
# showing that field's loading pattern for each season.
for reanalysis in output_files:

    # Skip datasets without MEI results.
    if 'MEI' not in output_files[reanalysis]:
        continue

    ds = xr.open_dataset(output_files[reanalysis]['MEI']['eofs_nc'])

    lat_name = rdu.get_lat_name(ds)
    lon_name = rdu.get_lon_name(ds)

    # Display label -> variable name in the EOF file.
    patterns = {'MSLP': 'mslp_pattern', 'SST': 'sst_pattern', 'OLR': 'olr_pattern',
                'u-wind': 'uwnd_pattern', 'v-wind': 'vwnd_pattern'}

    for pattern in patterns:

        pattern_da = ds[patterns[pattern]]
        seasons = pattern_da['season'].data

        # Colour limits are shared by all panels of this figure.
        # NOTE(review): the limits are the raw minimum and maximum and are
        # not made symmetric about zero - confirm this is intended for the
        # diverging colour map used below.
        vmin = np.min(pattern_da)
        vmax = np.max(pattern_da)

        # Grid of nrows x ncols map panels plus one extra, thinner row that
        # holds the colour bar.
        # NOTE(review): assumes there are at most nrows * ncols seasons.
        nrows = 4
        ncols = 3
        fig = plt.figure(figsize=(ncols * 5, nrows * 3))
        height_ratios = nrows * [5] + [1,]

        gs = gridspec.GridSpec(nrows=(nrows + 1), ncols=ncols, height_ratios=height_ratios,
                               wspace=0.15, hspace=0.25)

        projection = ccrs.PlateCarree(central_longitude=180)
        cmap = plt.cm.RdBu_r

        # Panels are filled row by row.
        row_index = 0
        col_index = 0
        for season in seasons:

            lat = pattern_da[lat_name]
            lon = pattern_da[lon_name]
            pattern_data = pattern_da.sel(season=season)

            lon_grid, lat_grid = np.meshgrid(lon, lat)

            ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

            ax.coastlines()
            # Map extent: longitudes 100 to 290, latitudes -30 to 30.
            ax.set_extent([100, 290, -30, 30], ccrs.PlateCarree())
            ax.set_aspect('auto')

            ax.set_title('{}'.format(season), fontsize=12)

            cs = ax.contourf(lon_grid, lat_grid, pattern_data, vmin=vmin, vmax=vmax,
                             cmap=cmap, transform=ccrs.PlateCarree())

            # Advance to the next panel, wrapping to a new row when needed.
            col_index += 1
            if col_index == ncols:
                col_index = 0
                row_index += 1

        # The colour bar is built from the contour set of the last season
        # drawn.
        cb_ax = fig.add_subplot(gs[-1, :])
        cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

        plt.suptitle('{} {} patterns (base period {} - {})'.format(
                get_reanalysis_full_name(reanalysis), pattern,
                pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
                pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')),
            fontsize=14, y=0.94)

        plt.show()

    ds.close()

In [None]:
# Plot the daily MEI of every reanalysis that has MEI results. No reference
# index is added to this plot.
datasets = {}
for reanalysis in output_files:

    # Skip datasets without MEI results.
    if 'MEI' not in output_files[reanalysis]:
        continue

    datasets[reanalysis] = xr.open_dataset(output_files[reanalysis]['MEI']['daily_index_nc'])

    time_name = rdu.get_time_name(datasets[reanalysis])

    # Rename the time coordinate to 'time' when it has a different name.
    if time_name != 'time':
        datasets[reanalysis] = datasets[reanalysis].rename(
            {time_name: 'time'})

# Series to plot, keyed by the display name of each reanalysis.
indices_to_plot = {get_reanalysis_full_name(reanalysis): datasets[reanalysis]['mei'] for reanalysis in datasets}

fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

plt.suptitle('Daily MEI (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

plt.show()

# Close everything that was opened for the plot.
for reanalysis in datasets:
    datasets[reanalysis].close()

In [None]:
# Plot the monthly MEI of every reanalysis that has MEI results, together
# with the reference index from 'esrl.mei.bimonthly.csv' when that file
# exists.
datasets = {}
for reanalysis in output_files:

    # Skip datasets without MEI results.
    if 'MEI' not in output_files[reanalysis]:
        continue

    datasets[reanalysis] = xr.open_dataset(output_files[reanalysis]['MEI']['monthly_index_nc'])

    time_name = rdu.get_time_name(datasets[reanalysis])

    # Rename the time coordinate to 'time' when it has a different name.
    if time_name != 'time':
        datasets[reanalysis] = datasets[reanalysis].rename(
            {time_name: 'time'})

# Series to plot, keyed by the display name of each reanalysis.
indices_to_plot = {get_reanalysis_full_name(reanalysis): datasets[reanalysis]['mei'] for reanalysis in datasets}

ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'esrl.mei.bimonthly.csv')

# The reference series is optional and only added when its file is present.
if os.path.exists(ref_index_datafile):
    ref_index_ds = read_reference_index_csv(ref_index_datafile)
    indices_to_plot['ESRL'] = ref_index_ds['index']

fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

plt.suptitle('Monthly MEI index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

plt.show()

# Close everything that was opened for the plot.
for reanalysis in datasets:
    datasets[reanalysis].close()

# `ref_index_ds` only exists when the reference file was found above.
if os.path.exists(ref_index_datafile):
    ref_index_ds.close()

### MJO

In [None]:
# For each reanalysis with MJO output, plot the first two RMM EOFs as
# functions of longitude, one panel per mode with one line per field.
for reanalysis in output_files:

    if 'MJO' not in output_files[reanalysis]:
        continue

    # The context manager closes the file even if plotting raises.
    with xr.open_dataset(output_files[reanalysis]['MJO']['eofs_nc']) as ds:

        lon_name = rdu.get_lon_name(ds)
        lons = ds[lon_name]

        fig = plt.figure(figsize=(6, 9))

        gs = gridspec.GridSpec(nrows=2, ncols=1, wspace=0.25, hspace=0.3)

        for mode in range(2):

            ax = fig.add_subplot(gs[mode, 0])

            ax.plot(lons, ds['olr_eofs'].sel(mode=mode), 'b-', label='OLR')
            ax.plot(lons, ds['u850_eofs'].sel(mode=mode), 'r--', label='u850')
            ax.plot(lons, ds['u200_eofs'].sel(mode=mode), 'g:', label='u200')

            ax.grid(ls='--', color='gray', alpha=0.5)
            ax.legend()

            ax.tick_params(which='both', labelsize=12)
            ax.set_xlabel('Longitude', fontsize=13)
            ax.set_ylabel('Normalized magnitude', fontsize=13)
            # Modes are zero-based in the file but numbered from one in titles.
            ax.set_title('EOF {:d}'.format(mode + 1), fontsize=13)

        plt.suptitle('{} RMM EOFs (base period {} - {})'.format(
                get_reanalysis_full_name(reanalysis),
                pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
                pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')),
            fontsize=14, y=0.95)

        plt.show()

In [None]:
# Plot the daily RMM1 index from every reanalysis for which it was computed,
# together with the BOM reference series when its CSV file is available.
datasets = {}
# Sentinel: only a reference dataset opened by this cell is closed below.
ref_index_ds = None
try:
    for reanalysis in output_files:

        if 'MJO' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['MJO']['daily_rmm1_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['rmm1']
        for reanalysis in datasets}

    ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'bom.rmm1.daily.csv')

    # The reference index is optional; it is added only if the file exists.
    if os.path.exists(ref_index_datafile):
        ref_index_ds = read_reference_index_csv(ref_index_datafile)
        indices_to_plot['BOM'] = ref_index_ds['index']

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily RMM1 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

    if ref_index_ds is not None:
        ref_index_ds.close()

In [None]:
# Plot the daily RMM2 index from every reanalysis for which it was computed,
# together with the BOM reference series when its CSV file is available.
datasets = {}
# Sentinel: only a reference dataset opened by this cell is closed below.
ref_index_ds = None
try:
    for reanalysis in output_files:

        if 'MJO' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['MJO']['daily_rmm2_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['rmm2']
        for reanalysis in datasets}

    ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'bom.rmm2.daily.csv')

    # The reference index is optional; it is added only if the file exists.
    if os.path.exists(ref_index_datafile):
        ref_index_ds = read_reference_index_csv(ref_index_datafile)
        indices_to_plot['BOM'] = ref_index_ds['index']

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily RMM2 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

    if ref_index_ds is not None:
        ref_index_ds.close()

### NHTELE

In [None]:
# For each reanalysis with NHTELE output, draw the four cluster composites
# on a 2 x 2 grid of polar maps sharing one horizontal colorbar.
for reanalysis in output_files:

    if 'NHTELE' not in output_files[reanalysis]:
        continue

    # The context manager closes the file even if plotting raises.
    with xr.open_dataset(output_files[reanalysis]['NHTELE']['composites_nc']) as ds:

        lat_name = rdu.get_lat_name(ds)
        lon_name = rdu.get_lon_name(ds)
        lat = ds[lat_name]

        # Colour limits symmetric about zero, taken over all composites so
        # that the four panels share one scale.
        vlim = max(abs(ds['composites'].min().item()),
                   abs(ds['composites'].max().item()))
        vmin, vmax = -vlim, vlim

        fig = plt.figure(figsize=(8, 8))
        # Two rows of maps plus a thin bottom row reserved for the colorbar.
        gs = gridspec.GridSpec(nrows=3, ncols=2, height_ratios=[6, 6, 1], hspace=0.2)

        projection = ccrs.Orthographic(central_latitude=90, central_longitude=0)
        cmap = plt.cm.RdBu_r

        for cluster in range(4):

            # Fill the grid row by row: clusters 0, 1 on top and 2, 3 below.
            row_index, col_index = divmod(cluster, 2)

            pattern = ds['composites'].sel(cluster=cluster).squeeze().data

            # Append a wrap-around longitude column to the data and coordinate.
            pattern, lon = add_cyclic_point(pattern, coord=ds[lon_name])

            lon_grid, lat_grid = np.meshgrid(lon, lat)

            ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

            ax.coastlines()
            ax.set_global()

            cs = ax.pcolor(lon_grid, lat_grid, pattern, shading='auto',
                           vmin=vmin, vmax=vmax,
                           cmap=cmap, transform=ccrs.PlateCarree())

            # Clusters are zero-based in the file but numbered from one in titles.
            ax.set_title('NHTELE{:d}'.format(cluster + 1))

            ax.set_aspect('equal')

        cb_ax = fig.add_subplot(gs[-1, :])
        cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

        plt.suptitle(
            'Daily {} NH teleconnection patterns (base period {} - {})'.format(
                get_reanalysis_full_name(reanalysis),
                pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
                pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), y=0.95, fontsize=14)

        plt.show()

In [None]:
# Plot the daily NHTELE1 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['daily_nhtele1_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele1']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily NHTELE1 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the daily NHTELE2 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['daily_nhtele2_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele2']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily NHTELE2 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the daily NHTELE3 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['daily_nhtele3_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele3']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily NHTELE3 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the daily NHTELE4 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['daily_nhtele4_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele4']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily NHTELE4 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the monthly NHTELE1 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['monthly_nhtele1_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele1']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly NHTELE1 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the monthly NHTELE2 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['monthly_nhtele2_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele2']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly NHTELE2 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the monthly NHTELE3 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['monthly_nhtele3_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele3']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly NHTELE3 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the monthly NHTELE4 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'NHTELE' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['NHTELE']['monthly_nhtele4_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['nhtele4']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly NHTELE4 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

### PNA

In [None]:
# For each reanalysis with PNA output, map the EOF selected as the PNA mode
# on a polar projection with a horizontal colorbar underneath.
for reanalysis in output_files:

    if 'PNA' not in output_files[reanalysis]:
        continue

    # The context manager closes the file even if plotting raises.
    with xr.open_dataset(output_files[reanalysis]['PNA']['eofs_nc']) as ds:

        lat_name = rdu.get_lat_name(ds)
        lon_name = rdu.get_lon_name(ds)

        # Index of the EOF identified as the PNA pattern, read from the
        # file's 'pna_mode' attribute.
        mode = int(ds.attrs['pna_mode'])

        eof = ds['EOFs'].sel(mode=mode)

        # Colour limits symmetric about zero, held as plain floats.
        vlim = max(abs(eof.min().item()), abs(eof.max().item()))
        vmin, vmax = -vlim, vlim

        fig = plt.figure(figsize=(4, 4))
        # One map row plus a thin bottom row reserved for the colorbar.
        gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[15, 1], hspace=0.05)

        projection = ccrs.Orthographic(central_latitude=90, central_longitude=280)
        cmap = plt.cm.RdBu_r

        lat = ds[lat_name]
        lon = ds[lon_name]

        pattern = eof.squeeze().data
        # Scaled by 100 for display as a percentage in the title.
        explained_var = 100.0 * ds['explained_var'].sel(mode=mode).squeeze().values

        # Append a wrap-around longitude column to the data and coordinate.
        pattern, lon = add_cyclic_point(pattern, coord=lon)
        lon_grid, lat_grid = np.meshgrid(lon, lat)

        ax = fig.add_subplot(gs[0, 0], projection=projection)

        ax.coastlines()
        ax.set_global()

        cs = ax.pcolor(lon_grid, lat_grid, pattern, shading='auto',
                       vmin=vmin, vmax=vmax, cmap=cmap, transform=ccrs.PlateCarree())

        # Modes are zero-based in the file but numbered from one in titles.
        ax.set_title('{} mode {:d} ({:.2f}%)'.format(
            get_reanalysis_full_name(reanalysis), mode + 1, explained_var), y=1.01, fontsize=14)
        ax.set_aspect('equal')

        cb_ax = fig.add_subplot(gs[-1, :])
        cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

        # Show each figure explicitly, as the other plotting cells do.
        plt.show()

In [None]:
# Plot the daily PNA index from every reanalysis for which it was computed,
# together with the CPC reference series when its CSV file is available.
datasets = {}
# Sentinel: only a reference dataset opened by this cell is closed below.
ref_index_ds = None
try:
    for reanalysis in output_files:

        if 'PNA' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['PNA']['daily_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['pna_index']
        for reanalysis in datasets}

    ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'cpc.pna.daily.csv')

    # The reference index is optional; it is added only if the file exists.
    if os.path.exists(ref_index_datafile):
        ref_index_ds = read_reference_index_csv(ref_index_datafile)
        indices_to_plot['CPC'] = ref_index_ds['index']

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily PNA index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

    if ref_index_ds is not None:
        ref_index_ds.close()

In [None]:
# Plot the monthly PNA index from every reanalysis for which it was computed,
# together with the CPC reference series when its CSV file is available.
datasets = {}
# Sentinel: only a reference dataset opened by this cell is closed below.
ref_index_ds = None
try:
    for reanalysis in output_files:

        if 'PNA' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['PNA']['monthly_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['pna_index']
        for reanalysis in datasets}

    ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'cpc.pna.monthly.csv')

    # The reference index is optional; it is added only if the file exists.
    if os.path.exists(ref_index_datafile):
        ref_index_ds = read_reference_index_csv(ref_index_datafile)
        indices_to_plot['CPC'] = ref_index_ds['index']

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly PNA index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

    if ref_index_ds is not None:
        ref_index_ds.close()

### PSA

In [None]:
# For each reanalysis with PSA output, map the two EOFs selected as the PSA1
# and PSA2 modes in stacked panels sharing one horizontal colorbar.
for reanalysis in output_files:

    if 'PSA' not in output_files[reanalysis]:
        continue

    # The context manager closes the file even if plotting raises.
    with xr.open_dataset(output_files[reanalysis]['PSA']['eofs_nc']) as ds:

        # Indices of the EOFs identified as PSA1 and PSA2, read from the
        # file's attributes.
        psa1_mode = int(ds.attrs['psa1_mode'])
        psa2_mode = int(ds.attrs['psa2_mode'])

        lat_name = rdu.get_lat_name(ds)
        lon_name = rdu.get_lon_name(ds)

        psa1_eof = ds['EOFs'].sel(mode=psa1_mode)
        psa2_eof = ds['EOFs'].sel(mode=psa2_mode)

        # Colour limits symmetric about zero and common to both panels,
        # held as plain floats.
        vlim = max(abs(extreme) for extreme in (
            psa1_eof.min().item(), psa1_eof.max().item(),
            psa2_eof.min().item(), psa2_eof.max().item()))
        vmin, vmax = -vlim, vlim

        fig = plt.figure(figsize=(6, 8))
        # Two map rows plus a thin bottom row reserved for the colorbar.
        gs = gridspec.GridSpec(nrows=3, ncols=1, height_ratios=[6, 6, 1], hspace=0.2)

        projection = ccrs.PlateCarree(central_longitude=180)
        cmap = plt.cm.RdBu_r

        lat = ds[lat_name]
        lon = ds[lon_name]

        # Scaled by 100 for display as percentages in the panel titles.
        psa1_explained_var = 100.0 * ds['explained_var'].sel(mode=psa1_mode).squeeze().values
        psa2_explained_var = 100.0 * ds['explained_var'].sel(mode=psa2_mode).squeeze().values

        # Append a wrap-around longitude column to each pattern; both calls
        # extend the same longitude coordinate.
        psa1_pattern, cyclic_lon = add_cyclic_point(psa1_eof.squeeze().data, coord=lon)
        psa2_pattern, cyclic_lon = add_cyclic_point(psa2_eof.squeeze().data, coord=lon)

        lon_grid, lat_grid = np.meshgrid(cyclic_lon, lat)

        ax = fig.add_subplot(gs[0, 0], projection=projection)

        ax.coastlines()

        # NOTE(review): the extent is expressed in a PlateCarree CRS centred on
        # 180E rather than the default one - confirm the intended longitudes.
        ax.set_extent([0, 357.5, -90, -20], ccrs.PlateCarree(central_longitude=180))
        ax.set_aspect('auto')

        cs = ax.pcolor(lon_grid, lat_grid, psa1_pattern, shading='auto',
                       vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        # Modes are zero-based in the file but numbered from one in titles.
        ax.set_title(
            'Mode {:d} ({:.2f}%)'.format(
                psa1_mode + 1, psa1_explained_var
                ), y=1.01, fontsize=14)

        ax = fig.add_subplot(gs[1, 0], projection=projection)

        ax.coastlines()

        cs = ax.pcolor(lon_grid, lat_grid, psa2_pattern, shading='auto',
                       vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_extent([0, 357.5, -90, -20], ccrs.PlateCarree(central_longitude=180))
        ax.set_aspect('auto')

        ax.set_title(
            'Mode {:d} ({:.2f}%)'.format(
                psa2_mode + 1, psa2_explained_var
                ), y=1.01, fontsize=14)

        cb_ax = fig.add_subplot(gs[-1, :])
        cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

        plt.suptitle('{} daily PSA patterns (base period {} - {})'.format(
            get_reanalysis_full_name(reanalysis),
            pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
            pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')),
                    fontsize=14, y=0.98)

        plt.show()

In [None]:
# Plot the daily PSA1 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'PSA' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['PSA']['daily_psa1_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['psa1_index']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily PSA1 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the daily PSA2 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'PSA' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['PSA']['daily_psa2_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['psa2_index']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily PSA2 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the monthly PSA1 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'PSA' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['PSA']['monthly_psa1_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['psa1_index']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly PSA1 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

In [None]:
# Plot the monthly PSA2 index from every reanalysis for which it was computed.
datasets = {}
try:
    for reanalysis in output_files:

        if 'PSA' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['PSA']['monthly_psa2_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['psa2_index']
        for reanalysis in datasets}

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly PSA2 index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

### SAM

In [None]:
# For each reanalysis with SAM output, map the EOF selected as the SAM mode
# on a south-polar projection with a horizontal colorbar underneath.
for reanalysis in output_files:

    if 'SAM' not in output_files[reanalysis]:
        continue

    # The context manager closes the file even if plotting raises.
    with xr.open_dataset(output_files[reanalysis]['SAM']['eofs_nc']) as ds:

        lat_name = rdu.get_lat_name(ds)
        lon_name = rdu.get_lon_name(ds)

        # Index of the EOF identified as the SAM pattern, read from the
        # file's 'sam_mode' attribute.
        mode = int(ds.attrs['sam_mode'])

        eof = ds['EOFs'].sel(mode=mode)

        # Colour limits symmetric about zero, held as plain floats.
        vlim = max(abs(eof.min().item()), abs(eof.max().item()))
        vmin, vmax = -vlim, vlim

        fig = plt.figure(figsize=(4, 4))
        # One map row plus a thin bottom row reserved for the colorbar.
        gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[15, 1], hspace=0.05)

        projection = ccrs.Orthographic(central_latitude=-90)
        cmap = plt.cm.RdBu_r

        lat = ds[lat_name]
        lon = ds[lon_name]

        pattern = eof.squeeze().data
        # Scaled by 100 for display as a percentage in the title.
        explained_var = 100.0 * ds['explained_var'].sel(mode=mode).squeeze().values

        # Append a wrap-around longitude column to the data and coordinate.
        pattern, lon = add_cyclic_point(pattern, coord=lon)
        lon_grid, lat_grid = np.meshgrid(lon, lat)

        ax = fig.add_subplot(gs[0, 0], projection=projection)

        ax.coastlines()
        ax.set_global()

        cs = ax.pcolor(lon_grid, lat_grid, pattern, shading='auto',
                       vmin=vmin, vmax=vmax, cmap=cmap, transform=ccrs.PlateCarree())

        # Modes are zero-based in the file but numbered from one in titles.
        ax.set_title('{} mode {:d} ({:.2f}%)'.format(
            get_reanalysis_full_name(reanalysis), mode + 1, explained_var), y=1.01, fontsize=14)
        ax.set_aspect('equal')

        cb_ax = fig.add_subplot(gs[-1, :])
        cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

        # Show each figure explicitly, as the other plotting cells do.
        plt.show()

In [None]:
# Plot the daily SAM index from every reanalysis for which it was computed,
# together with the CPC reference series when its CSV file is available.
datasets = {}
# Sentinel: only a reference dataset opened by this cell is closed below.
ref_index_ds = None
try:
    for reanalysis in output_files:

        if 'SAM' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['SAM']['daily_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['sam_index']
        for reanalysis in datasets}

    ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'cpc.sam.daily.csv')

    # The reference index is optional; it is added only if the file exists.
    if os.path.exists(ref_index_datafile):
        ref_index_ds = read_reference_index_csv(ref_index_datafile)
        indices_to_plot['CPC'] = ref_index_ds['index']

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Daily SAM index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

    if ref_index_ds is not None:
        ref_index_ds.close()

In [None]:
# Plot the monthly SAM index from every reanalysis for which it was computed,
# together with the CPC reference series when its CSV file is available.
datasets = {}
# Sentinel: only a reference dataset opened by this cell is closed below.
ref_index_ds = None
try:
    for reanalysis in output_files:

        if 'SAM' not in output_files[reanalysis]:
            continue

        datasets[reanalysis] = xr.open_dataset(
            output_files[reanalysis]['SAM']['monthly_index_nc'])

        # Rename the time coordinate so every dataset uses the name 'time'
        # before the series are plotted together.
        time_name = rdu.get_time_name(datasets[reanalysis])
        if time_name != 'time':
            datasets[reanalysis] = datasets[reanalysis].rename(
                {time_name: 'time'})

    # Key the series by the display name of the reanalysis.
    indices_to_plot = {
        get_reanalysis_full_name(reanalysis): datasets[reanalysis]['sam_index']
        for reanalysis in datasets}

    ref_index_datafile = os.path.join(REF_INDICES_RESULTS_DIR, 'cpc.sam.monthly.csv')

    # The reference index is optional; it is added only if the file exists.
    if os.path.exists(ref_index_datafile):
        ref_index_ds = read_reference_index_csv(ref_index_datafile)
        indices_to_plot['CPC'] = ref_index_ds['index']

    fig = plot_indices_timeseries(indices_to_plot, years_per_row=10)

    plt.suptitle('Monthly SAM index (base period {} - {})'.format(
        pd.to_datetime(BASE_PERIOD[0]).strftime('%Y%m%d'),
        pd.to_datetime(BASE_PERIOD[1]).strftime('%Y%m%d')), fontsize=14, y=0.9)

    plt.show()
finally:
    # Release the file handles even if loading or plotting failed part-way.
    for reanalysis in datasets:
        datasets[reanalysis].close()

    if ref_index_ds is not None:
        ref_index_ds.close()