# Cohort Shapley Explanations

In [None]:
# Ensure the notebook is run from the base repository directory: the output
# and results paths used below (e.g. "subseasonal_toolkit/viz/...",
# "eval/cohort_shapley/...") are relative to the working directory.
# NOTE(review): IPython's %cd may report a missing directory without raising,
# in which case the warning below is never printed — confirm.
try:
    %cd "~/forecast_rodeo_ii"
except Exception as err:
    print(f"Warning: unable to change directory; {repr(err)}")

# Enable IPython's autoreload extension (mode 2) so edits to imported
# project modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import seaborn as sns
from datetime import datetime 

import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import itertools
import importlib
import subprocess
from itertools import product
from functools import partial

from IPython.display import Markdown, display

import copy
import pdb
import calendar 
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.gridspec import GridSpec

from subseasonal_toolkit.utils.eval_util import get_metric_filename, get_target_dates, score_to_mean_rmse, contest_quarter_start_dates, contest_quarter
from subseasonal_toolkit.models.multillr.stepwise_util import default_stepwise_candidate_predictors
from subseasonal_toolkit.utils.models_util import get_selected_submodel_name
from subseasonal_data.utils import get_measurement_variable
from subseasonal_data import data_loaders
from subseasonal_toolkit.utils.experiments_util import clim_merge, pandas2hdf
from subseasonal_toolkit.utils.general_util import printf, tic, toc, make_directories, set_file_permissions

from cohortshapley import cohortshapley as cs
from cohortshapley import similarity
from cohortshapley import figure
from cohortshapley import varianceshapley as vs

# from viz_util_abc import *
from viz_util_optimistic_abc_plots import *
# (get_plot_params_vertical, color_dic, LinearSegmentedColormap, cmap_name, get_feature_name)

import statsmodels.stats.proportion as proportion

from mpl_toolkits.basemap import Basemap

%matplotlib inline
# Inline figure display: SVG for interactive viewing and PDF for
# eventual export to PDF through nbconvert
%config InlineBackend.figure_formats = ['pdf','svg']
#%matplotlib notebook # interactive plotting

## Specify task and models to explain

In [None]:
#
# Read input arguments from environment variables or specify interactively
#
gt_id = os.environ.get("COMPARE_MODELS_gt_id", "us_precip_1.5x1.5")
horizon = os.environ.get("COMPARE_MODELS_horizon", "34w")
target_dates = os.environ.get("COMPARE_MODELS_target_dates", "std_paper_forecast")
metric = os.environ.get("COMPARE_MODELS_metric", "skill")
task = f"{gt_id}_{horizon}"
# Human-readable task description used in figure titles, e.g.
# "precipitation, weeks 3-4" or "temperature, weeks 5-6"; every horizon
# gets the same ", weeks ..." form.
task_long = (task
             .replace('us_precip_1.5x1.5_', 'precipitation')
             .replace('us_tmp2m_1.5x1.5_', 'temperature')
             .replace('12w', ', weeks 1-2')
             .replace('34w', ', weeks 3-4')
             .replace('56w', ', weeks 5-6'))

# This notebook explains the metric of `model` if model2 is None and the
# difference between the `model` and `model2` metrics otherwise.
# Set COMPARE_MODELS_model2="None" to explain a single model.
model = os.environ.get("COMPARE_MODELS_model", "abc_ecmwf")
model2 = os.environ.get("COMPARE_MODELS_model2", "deb_ecmwf")
if model2 == "None":
    model2 = None
model_str = model if model2 is None else f"{model}-vs-{model2}"

# Prepare figure output directories
bin_fig_dir = os.path.join("subseasonal_toolkit", "viz", "bin_figs")
make_directories(bin_fig_dir)
date_fig_dir = os.path.join("subseasonal_toolkit", "viz", "date_figs")
make_directories(date_fig_dir)

# Identify measurement variable name ('tmp2m' or 'precip')
measurement_variable = get_measurement_variable(gt_id)
gt_col = measurement_variable
# Column name for climatology
clim_col = measurement_variable + "_clim"

# Directory for the source data (spreadsheets) accompanying each figure
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)

## Load outcome to be explained: model metric or model metric difference per target date

In [None]:
# Per-date metric of the primary model, indexed by target start date
metrics = pd.read_hdf(
    get_metric_filename(model=model, gt_id=gt_id, horizon=horizon,
                        target_dates=target_dates, metric=metric)
).set_index('start_date')

if model2 is None:
    # Single-model mode: explain the metric itself
    outcome = metrics
else:
    # Comparison mode: explain the per-date metric difference model - model2
    metrics2 = pd.read_hdf(
        get_metric_filename(model=model2, gt_id=gt_id, horizon=horizon,
                            target_dates=target_dates, metric=metric)
    ).set_index('start_date')
    outcome = metrics - metrics2

# Drop target dates with missing values (including dates missing from either model)
outcome = outcome.dropna()
    

## Load explanatory features

In [None]:
# Load continuous features
def continuous_feature_names(gt_id, horizon):
    """Return the list of continuous feature names for a given gt_id and horizon.

    Temperature ("tmp2m") and precipitation ("precip") targets use the same
    features; only the shift suffixes depend on the horizon.

    Args:
        gt_id: ground-truth identifier; must contain "tmp2m" or "precip".
        horizon: forecast horizon, "34w" or "56w".

    Returns:
        A new list: the MEI feature followed by the SST, sea-ice (icec),
        10 hPa and 500 hPa geopotential-height anomaly features.

    Raises:
        ValueError: if no feature set is defined for gt_id/horizon.
    """
    # horizon -> (shift suffix of the MEI feature, shift suffix of all others)
    shifts_by_horizon = {"34w": (45, 30), "56w": (59, 44)}
    supported_target = "tmp2m" in gt_id or "precip" in gt_id
    if not supported_target or horizon not in shifts_by_horizon:
        raise ValueError(
            f"No continuous features defined for gt_id={gt_id!r}, horizon={horizon!r}")
    mei_shift, shift = shifts_by_horizon[horizon]

    feature_names = [f'mei_shift{mei_shift}']
    # (anomaly variable, number of numbered components stored for it)
    for var, num_components in (('sst', 3), ('icec', 3), ('hgt_10', 2), ('hgt_500', 2)):
        feature_names += [f'{var}_anom_2010_{k}_shift{shift}'
                          for k in range(1, num_components + 1)]
    return feature_names

# Continuous anomaly features, indexed by target start date; the loader is
# given the gt_id without its "_1.5x1.5" grid suffix
cols_to_load = ['start_date', *continuous_feature_names(gt_id, horizon)]
file_id = 'date_anom_data'
continuous = data_loaders.load_combined_data(
    file_id,
    gt_id.replace("_1.5x1.5", ""),
    horizon,
    columns=cols_to_load,
).set_index('start_date')

In [None]:
# Load discrete features
def discrete_feature_names(gt_id, horizon):
    """Return the list of discrete feature names for a given gt_id and horizon.

    Temperature ("tmp2m") and precipitation ("precip") targets use the same
    feature; only its shift suffix depends on the horizon.

    Args:
        gt_id: ground-truth identifier; must contain "tmp2m" or "precip".
        horizon: forecast horizon, "34w" or "56w".

    Returns:
        A new single-element list with the shifted 'phase' feature name.

    Raises:
        ValueError: if no feature set is defined for gt_id/horizon.
    """
    shift_by_horizon = {"34w": 17, "56w": 31}
    supported_target = "tmp2m" in gt_id or "precip" in gt_id
    if not supported_target or horizon not in shift_by_horizon:
        raise ValueError(
            f"No discrete features defined for gt_id={gt_id!r}, horizon={horizon!r}")
    return [f'phase_shift{shift_by_horizon[horizon]}']

# Discrete features indexed by target start date, plus the calendar month
# of each start date as an extra categorical feature
cols_to_load = ['start_date', *discrete_feature_names(gt_id, horizon)]
discrete = data_loaders.load_combined_data(
    'date_data',
    gt_id.replace("_1.5x1.5", ""),
    horizon,
    columns=cols_to_load,
).set_index('start_date')
discrete = discrete.assign(month=discrete.index.month)

In [None]:
# Align outcome and features on the shared start-date index; left joins keep
# every target date that has an outcome value
data = (
    outcome
    .merge(continuous, how='left', left_index=True, right_index=True)
    .merge(discrete, how='left', left_index=True, right_index=True)
)

# Outcome vector y and feature matrix X (continuous columns, then discrete ones)
y = data[metric]
X = data[continuous.columns.append(discrete.columns)]

## Assess global variable importance with Shapley effects

- Song et al., "Shapley effects for global sensitivity analysis: Theory and computation" https://epubs.siam.org/doi/abs/10.1137/15M1048070

### Compute Shapley effects (a.k.a. Variance Shapley)

In [None]:
# Replace each continuous feature by its quantile bin; discrete features are kept as-is
quantiles = 10
X_q = X.assign(**{col: pd.qcut(X[col], q=quantiles) for col in continuous.columns})

tic()
vs_values = vs.VarianceShapley(y.values, X_q.values)
# Alternative binning via similarity.binning instead of pd.qcut:
# similarity.bins = 10
# vs_values = vs.VarianceShapley(y.values,
#                                np.concatenate([similarity.binning(X[continuous.columns].values)[0],
#                                                X[discrete.columns].values], axis=1))
toc()

# Feature permutation sorting the Shapley effects from largest to smallest
order = np.argsort(vs_values)[::-1]

## Figure A12. Variable importance
### Visualize Shapley effects


In [None]:
plt.rcParams.update({
    'font.size': 10,
    'font.weight': 'bold',
    'figure.titlesize': 12,
    'figure.titleweight': 'bold',
    'lines.markersize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})

# TODO: improve title
if model2 is None:
    title = f"{gt_id} {horizon}, {all_model_names[model]}"
else:
    title = f"{gt_id} {horizon} ({all_model_names[model]} vs. {all_model_names[model2]})"
ylabel = 'Variable importance'
# Turn identifiers into readable text,
# e.g. "us_precip_1.5x1.5 34w" -> "U.S. Precipitation, weeks 3-4"
title = (title
         .replace('_', '')
         .replace('1.5x1.5', '')
         .replace('us', 'U.S.')
         .replace('precip', ' Precipitation')
         .replace('tmp2m', ' Temperature')
         .replace('56w', ', weeks 5-6')
         .replace('34w', ', weeks 3-4')
         .replace('12w', ', weeks 1-2')
         .replace(' ,', ','))

fig = plt.figure(dpi=300)
ranked_features = X.columns[order]
ax = plt.bar(ranked_features, vs_values[order])
plt.title(title, fontdict={'weight': 'bold'})
plt.ylabel(ylabel, fontdict={'weight': 'bold'})
# Readable tick labels: the final 6 characters are dropped except for month labels
tick_labels = []
for raw_name in ranked_features.values:
    pretty = get_feature_name(raw_name)
    tick_labels.append(pretty if pretty.startswith('month') else pretty[:-6])
plt.xticks(ticks=range(len(tick_labels)), labels=tick_labels)
fig.autofmt_xdate(rotation=45)
plt.tight_layout()
plt.show()

out_file = f"subseasonal_toolkit/viz/shapley_effects_{title.replace(',','').replace(' ','_')}.pdf"
fig.savefig(out_file, bbox_inches='tight')
printf(f'Saving {out_file}')

In [None]:
# Save the figure's source data, replacing any previous export
fig_filename = os.path.join(fig_dir, "fig_a12-variable_importance.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

df_barplot = pd.DataFrame({
    "variable": X.columns[order],
    "importance": vs_values[order],
})
# Sheet name (note: rebinds the notebook-level `task`)
task = f"{gt_id} {horizon}"
with pd.ExcelWriter(fig_filename) as writer:
    df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")


## Compute or load Cohort Shapley impacts for local explanation

- Mase et al., "Explaining black box decisions by Shapley cohort refinement" https://arxiv.org/pdf/1911.00467.pdf

In [None]:
# Choose the similarity type used by Cohort Shapley:
# - If quantiles is not None, replace continuous values with their quantile bins
#   and declare two target dates similar only if their bins match exactly.
# - If quantiles is None, use similarity.similar_in_distance_cutoff with the
#   specified similarity.ratio to determine similarity.
quantiles = 10
similarity.ratio = 0.1
if quantiles is None:
    sim_func = similarity.similar_in_distance_cutoff
    features = X.values
else:
    sim_func = similarity.similar_in_unity
    X_q = X.copy()
    for col in continuous.columns:
        # Bin the columns of X itself (same values as data[col]) so the bins
        # come from the same feature matrix as in the Shapley-effects cell
        X_q[col] = pd.qcut(X[col], q=quantiles)
    features = X_q.values

# Initialize Cohort Shapley object
cs_obj = cs.CohortShapley(
    None, sim_func,
    np.arange(len(y)), features, y=y.values, parallel=16)  # alternatives: mc_num=10000, parallel=4

# Prepare results directory
model_str = model if model2 is None else f"{model}-vs-{model2}"
results_dir = os.path.join("eval", "cohort_shapley", model_str)
make_directories(results_dir)

# Construct results file name; the suffix records the similarity setting
result_file = os.path.join(results_dir, f'{model_str}-{metric}-{gt_id}_{horizon}-{target_dates}')
if quantiles is None:
    result_file += f'-ratio{similarity.ratio}'
else:
    result_file += f'-q{quantiles}'
print(result_file)

try:
    # Load previously saved results if available
    cs_obj.load(result_file)
except Exception as err:
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, and report why loading failed before recomputing from scratch
    printf(f"Unable to load Cohort Shapley results from {result_file}: {err!r}")
    printf("Computing Cohort Shapley results")
    cs_obj.compute_cohort_shapley()
    printf(f"Saving Cohort Shapley results to {result_file}")
    cs_obj.save(result_file)
    # Ensure saved files have full read and write permissions
    set_file_permissions(result_file+".cs.npy", mode=0o666)
    set_file_permissions(result_file+".cs2.npy", mode=0o666)
else:
    printf(f"Loaded Cohort Shapley results from {result_file}")

# Store results as dataframe: one row per target date, one column per feature
df_cs = pd.DataFrame(cs_obj.shapley_values, index=X.index, columns=X.columns)

## Visualize feature impact by quantile or categorical bin

In [None]:
# Larger fonts and markers for the per-bin figures produced below
plt.rcParams.update({
    'font.size': 16,
    'font.weight': 'bold',
    'figure.titlesize': 16,
    'figure.titleweight': 'bold',
    'lines.markersize': 14,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})

# Define helper functions and dictionary for plotting
def get_viz_var(feature):
    """Return the visualization-variable identifier associated with a feature.

    MEI and SST features map to 'wide_us_sst_anom'. Any other feature name is
    truncated at its first '_2010'; icec names are prefixed with 'us_' and hgt
    names with 'north_', and other names pass through unchanged. The result is
    always prefixed with 'wide_'.

    >>> get_viz_var('mei_shift45')
    'wide_us_sst_anom'
    >>> get_viz_var('hgt_500_anom_2010_1_shift30')
    'wide_north_hgt_500_anom'
    """
    # MEI features are visualized with the SST anomaly maps
    if feature.startswith(('mei', 'sst')):
        return 'wide_us_sst_anom'
    viz_var = feature.split('_2010')[0]
    # The truncated name is a prefix of `feature`, so it cannot start with
    # 'sst' here; only icec and hgt need a region prefix
    if viz_var.startswith('icec'):
        viz_var = 'us_' + viz_var
    elif viz_var.startswith('hgt'):
        viz_var = 'north_' + viz_var
    return 'wide_' + viz_var

# Long descriptions of the visualization variables, used as colorbar labels
mean_viz_var_long = dict(
    wide_us_sst_anom='Mean sea surface temperature anomalies',
    wide_us_icec_anom='Mean sea ice concentration anomalies',
    wide_north_hgt_10_anom='Mean 10 hPa geopotential height anomalies',
    wide_north_hgt_500_anom='Mean 500 hPa geopotential height anomalies',
)

def lat_lon_mat(data):
    """Convert a Series of gridded values into a lat x lon matrix.

    Args:
        data: Series whose index labels have the form
            '(gt_var, lat, lon)_shift###'; the second and third comma-separated
            fields are parsed as latitude and longitude.

    Returns:
        DataFrame with one row per distinct latitude and one column per
        distinct longitude, with longitudes wrapped into [-180, 180). Grid
        points absent from `data` are NaN. Missing latitude rows are NOT added
        here; callers needing a contiguous latitude range must reindex.
    """
    fields = [str.split(label, ',') for label in data.index]
    lats = [float(parts[1]) for parts in fields]
    # Wrap longitudes into [-180, 180)
    lons = [(float(parts[2].split(')')[0]) + 180) % 360 - 180 for parts in fields]
    # Select the value column explicitly: squeeze() would turn a single-point
    # input into a scalar and break unstack
    long_format = pd.DataFrame({'var': data.values, 'lat': lats, 'lon': lons})
    return long_format.set_index(['lat', 'lon'])['var'].unstack('lon')

def get_impact_levels_errors(feature, cis, num_bins):
    """Return the centers and half-lengths of the first num_bins confidence intervals.

    Args:
        feature: expected to equal the name of `cis`; it selects the single
            column of cis.to_frame().
        cis: Series holding one (lower, upper) confidence interval per feature bin.
        num_bins: number of leading bins to summarize.

    Returns:
        Two lists (centers, half_lengths), each value rounded to 2 decimals.
    """
    intervals = cis.to_frame()[feature]
    ci_centers, ci_halflens = [], []
    for bin_num in range(num_bins):
        ci = intervals.iloc[bin_num]
        ci_centers.append(round((ci[0] + ci[1]) / 2, 2))
        ci_halflens.append(round((ci[1] - ci[0]) / 2, 2))
    return ci_centers, ci_halflens


def _bins_near_extreme(feature, cis, num_bins, pick):
    """Return the bins whose impact estimate lies inside the confidence interval
    of the bin chosen by `pick` (min or max) over the impact estimates."""
    impact_levels, errors = get_impact_levels_errors(feature, cis, num_bins)
    extreme = pick(impact_levels)
    error = errors[impact_levels.index(extreme)]
    # Compare as an array so the bounds apply elementwise whether the levels
    # are Python floats or NumPy scalars
    levels = np.asarray(impact_levels)
    within = (extreme - error <= levels) & (levels <= extreme + error)
    return cis.index[:num_bins][within]


def get_high_impact_bins(feature, cis, num_bins):
    """Return the bins with impact probability estimates inside the confidence
    interval of the highest impact probability estimate.

    Args:
        feature: name of `cis` (see get_impact_levels_errors).
        cis: Series of (lower, upper) intervals for the probability of positive
            impact, one per feature bin.
        num_bins: number of leading bins to consider.

    Returns:
        The matching labels from cis.index.
    """
    return _bins_near_extreme(feature, cis, num_bins, max)


def get_low_impact_bins(feature, cis, num_bins):
    """Return the bins with impact probability estimates inside the confidence
    interval of the lowest impact probability estimate.

    Arguments and return value are as for get_high_impact_bins.
    """
    return _bins_near_extreme(feature, cis, num_bins, min)

def _colorbar_style(viz_var, viz_df):
    """Return (CB_minmax, cb_skip, CB_colors) for a visualization variable.

    The substrings icec, sst, global and hgt are checked in that order. For a
    match, CB_minmax is a (min, max) tuple and cb_skip the colorbar tick
    spacing; for hgt the symmetric range is derived from viz_df. Otherwise
    CB_minmax is [] and cb_skip is None. CB_colors lists the colormap colors.
    """
    diverging = ['blue', 'dodgerblue', 'lightskyblue', 'white', 'pink', 'red', 'darkred']
    if 'icec' in viz_var:
        max_val = 1/4
        return (-max_val, max_val), max_val, diverging
    if 'sst' in viz_var:
        return (-2, 2), 1, diverging
    if 'global' in viz_var:
        return (-5, 5), 1, ['tan', 'violet', 'yellow', 'green', 'lightskyblue', 'dodgerblue', 'blue']
    if 'hgt' in viz_var:
        max_val = max(np.abs(viz_df.min().min()), viz_df.max().max())/1.25
        return (-max_val, max_val), max_val, diverging
    return [], None, ['purple', 'blue', 'lightblue', 'white', 'pink', 'yellow', 'red']


def plot_lat_lon_mat_all(viz_df, feature, cis, num_bins, viz_var):
    """Draw one map per feature bin and save the panel figure as a PDF.

    Each row of viz_df is converted with lat_lon_mat and drawn in its own
    subplot: a cartopy PlateCarree map for SST variables, otherwise a Basemap
    'npstere' map. Each subplot title shows the rounded center +/- half-length
    of that bin's confidence interval. The title is red when the center lies
    within the interval of the lowest-impact bin and blue when it lies within
    the interval of the highest-impact bin.

    Args:
        viz_df: DataFrame with one row per bin; its column labels must be
            parseable by lat_lon_mat.
        feature: explanatory feature whose bins are plotted.
        cis: Series named `feature` with one (lower, upper) interval per bin.
        num_bins: number of leading entries of cis used to compute impact
            levels; must be at least viz_df.shape[0].
        viz_var: identifier from get_viz_var; selects map type, grid spacing,
            color range and colorbar label.

    Reads the notebook-level gt_id, horizon, task_long, model2, model_str,
    metric, target_dates and bin_fig_dir. The PDF is written to bin_fig_dir;
    its file name uses the notebook-level metric.
    """
    subplots_num = viz_df.shape[0]
    params = get_plot_params_vertical(subplots_num=subplots_num)
    nrows, ncols = params['nrows'], params['ncols']

    # NOTE(review): width scales with nrows and height with ncols, and the grid
    # has only nrows-1 rows; kept as-is — confirm against get_plot_params_vertical.
    fig = plt.figure(figsize=(nrows*params['figsize_x'], ncols*params['figsize_y']))
    gs = GridSpec(nrows=nrows-1, ncols=ncols, width_ratios=params['width_ratios'])

    # Interval centers/half-lengths per bin, and the extreme bins used to color titles
    impact_levels, errors = get_impact_levels_errors(feature, cis, num_bins)
    impact_min, impact_max = min(impact_levels), max(impact_levels)
    errors_min = errors[impact_levels.index(impact_min)]
    errors_max = errors[impact_levels.index(impact_max)]

    # Map styling depends only on viz_var (and on viz_df for hgt), so set it up once
    gt_var = "tmp2m" if "tmp2m" in gt_id else "precip"
    # color_dic settings are looked up for 'skill'; a dedicated name keeps the
    # notebook-level `metric` available for the output file name
    color_metric = 'skill'
    CB_minmax, cb_skip, CB_colors = _colorbar_style(viz_var, viz_df)
    if CB_minmax:
        colorbar_min_value, colorbar_max_value = CB_minmax
    else:
        colorbar_min_value = color_dic[(color_metric, gt_var, horizon)]['colorbar_min_value']
        colorbar_max_value = color_dic[(color_metric, gt_var, horizon)]['colorbar_max_value']
    cmap = LinearSegmentedColormap.from_list(cmap_name, CB_colors, N=100)

    # Subsample lats and lons to reduce figure size
    if 'hgt' in viz_var:
        subsample_factor = 1
    elif 'sst' in viz_var:
        subsample_factor = 4
    else:
        subsample_factor = 2
    # Cell width used to build pcolor cell edges (presumably degrees)
    if 'hgt' in viz_var:
        edge_len = 2.5 * subsample_factor
    elif 'global' in viz_var:
        edge_len = 1.5 * subsample_factor
    else:
        edge_len = 1 * subsample_factor

    for bin_num, (row, col) in enumerate(product(range(nrows), range(ncols))):
        if bin_num >= subplots_num:
            break

        data_matrix = lat_lon_mat(viz_df.iloc[bin_num])
        if feature.startswith('icec'):
            # Add in rows corresponding to any missing lat values with NaN values
            data_matrix = data_matrix.reindex(
                np.arange(data_matrix.index.min(), data_matrix.index.max()+1), fill_value=np.nan)
            # For icec, NaN and 0 values should be treated identically
            data_matrix[data_matrix.isna()] = 0
        data_matrix = data_matrix.iloc[::subsample_factor, ::subsample_factor]

        # Cell edges centered on the (subsampled) grid points
        lats = data_matrix.index.values
        lons = data_matrix.columns.values
        lats_edges = np.arange(lats[0], lats[-1] + edge_len*2, edge_len) - edge_len/2
        lons_edges = np.arange(lons[0], lons[-1] + edge_len*2, edge_len) - edge_len/2
        lat_grid, lon_grid = np.meshgrid(lats_edges, lons_edges)
        values = np.transpose(data_matrix)

        if 'sst' in viz_var:
            ax = fig.add_subplot(gs[row, col], projection=ccrs.PlateCarree(), aspect="auto")
        else:
            ax = fig.add_subplot(gs[row, col], aspect="auto")
        ax.set_facecolor('w')
        ax.axis('off')

        if 'sst' in viz_var:
            ax.coastlines(linewidth=0.9, color='gray')
            land_110m = cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                     edgecolor='face', facecolor='white')
            ax.add_feature(land_110m, edgecolor='gray')
            plot = ax.pcolormesh(lon_grid, lat_grid, values,
                                 vmin=colorbar_min_value, vmax=colorbar_max_value,
                                 cmap=cmap, rasterized=True)
        else:
            # No ax is passed, so Basemap presumably draws on the current
            # (just-added) axes
            boundinglat = 45 if 'icec' in viz_var else 15
            m = Basemap(projection='npstere', boundinglat=boundinglat, lon_0=270,
                        resolution='c', round=True)
            m.drawcoastlines()
            plot = m.pcolor(lon_grid, lat_grid, values,
                            vmin=colorbar_min_value, vmax=colorbar_max_value,
                            cmap=cmap, latlon=True, rasterized=True)

        ax.tick_params(axis='both', labelsize=params['fontsize_ticks'])

        # Subplot title: rounded interval center +/- half-length (leading zeros dropped)
        ci_center, ci_halflen = impact_levels[bin_num], errors[bin_num]
        decile_title = (r"$\bf{Decile\ " + str(bin_num+1) + ":}$"
                        + str(ci_center).replace('0.', '.')
                        + r" ${\pm}$ "
                        + str(ci_halflen).replace('0.', '.'))
        title_kwargs = {'fontsize': params['fontsize_title']}
        if impact_min - errors_min <= ci_center <= impact_min + errors_min:
            title_kwargs['color'] = 'red'    # within the lowest-impact bin's interval
        elif impact_max - errors_max <= ci_center <= impact_max + errors_max:
            title_kwargs['color'] = 'blue'   # within the highest-impact bin's interval
        ax.set_title(decile_title, **title_kwargs)

        if bin_num == 0:
            # NOTE(review): the model label "ABC-ECMWF" is hard-coded regardless of `model`
            fig_title = f"Impact of {get_feature_name(feature)[:-6]} on ABC-ECMWF skill for {task_long}"
            if model2 is not None:
                fig_title = fig_title.replace('skill for', 'skill improvement for')
            fig.suptitle(fig_title, fontsize=params['fontsize_suptitle'], fontweight="bold",
                         y=1.04, x=.55)
            fig.subplots_adjust(wspace=0.025, hspace=0.25)

        # Add a single shared colorbar once subplot number `ncols` has been drawn
        if CB_minmax and bin_num == ncols:
            cb_ax_loc = [0.92, 0.1, 0.01, 0.8] if subplots_num == 10 else [0.2, 0.08, 0.6, 0.02]
            cb_ax = fig.add_axes(cb_ax_loc)
            cb = fig.colorbar(plot, cax=cb_ax, cmap=cmap, orientation='vertical')
            cb.outline.set_edgecolor('black')
            cb.ax.tick_params(labelsize=params['fontsize_ticks'])
            cb.ax.set_ylabel(mean_viz_var_long[viz_var], fontsize=params['fontsize_title'],
                             weight='bold', rotation=270, labelpad=25)
            ticks = np.arange(colorbar_min_value, colorbar_max_value + cb_skip, cb_skip)
            cb.set_ticks(ticks)
            cb.ax.set_yticklabels([f'{tick}' if 'icec' in viz_var else f'{tick:.0f}' for tick in ticks],
                                  fontsize=params['fontsize_title'], weight='bold')

    # Save figure under the notebook-level metric name
    out_file = os.path.join(
        bin_fig_dir,
        f'{model_str}-{metric}-{gt_id}_{horizon}-{target_dates}-{feature}-perdecile.pdf')
    fig.savefig(out_file, orientation='landscape', bbox_inches='tight')
    plt.close(fig)
    # Ensure saved files have full read and write permissions
    set_file_permissions(out_file, mode=0o666)
    print(f"\nFigure saved: {out_file}\n")


## Figure 6. Opportunistic ABC
### Generate forecast of opportunity table

In [None]:
# Opportunistic ABC contrasts `model` with the debiased and raw versions of
# model2, so this cell requires model2 to be set.
if model2 is None:
    raise ValueError("Opportunistic ABC requires COMPARE_MODELS_model2 to name a "
                     "'raw_' or 'deb_' model")

# Load data for the third model: the debiased counterpart of a raw model2,
# or the raw counterpart of a debiased model2.
# NOTE(review): if model2 has neither prefix, model3 equals model2 — confirm intended.
if model2.startswith("raw_"):
    raw_model = model2
    model3 = model2.replace("raw_", "deb_")
    deb_model = model3
else:
    deb_model = model2
    model3 = model2.replace("deb_", "raw_")
    raw_model = model3
metrics3 = pd.read_hdf(get_metric_filename(model=model3, gt_id=gt_id, horizon=horizon,
                                           target_dates=target_dates, metric=metric))
metrics3 = metrics3.set_index('start_date')
task = f"{gt_id}_{horizon}"

# Merge outcome data with individual model performances. The outcome keeps the
# bare metric column name; each model's metric column gets a '_<model>' suffix.
all_metrics = pd.merge(y, metrics, how="left", left_index=True, right_index=True,
                       suffixes=('', '_'+model))
all_metrics = pd.merge(all_metrics, metrics2, how="left", left_index=True, right_index=True,
                       suffixes=('', '_'+model2))
all_metrics = pd.merge(all_metrics, metrics3, how="left", left_index=True, right_index=True,
                       suffixes=('', '_'+model3))

# For each forecast, count how many explanatory features are in high-impact bins
features = X.columns[order]
num_high_impact = pd.Series(0, index=all_metrics.index, dtype=int)

for feature in features:
    # 95% confidence intervals for the probability of positive impact per
    # feature quantile / bin, using statsmodels' default method (normal approximation).
    # NOTE(review): confirm whether Wilson intervals (method='wilson') were intended.
    cis = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
        lambda x: proportion.proportion_confint(x.sum(), x.size))
    # Identify the highest impact bins
    num_bins = len(cis)
    high_impact_bins = get_high_impact_bins(feature, cis, num_bins)
    # Count forecasts whose value of this feature lies in a high-impact bin
    num_high_impact += X_q[feature].isin(high_impact_bins)

# Summarize the forecast-of-opportunity benefit of selectively using ABC
# when num_high_impact >= k (and the debiased model otherwise)
opportunity_high = pd.DataFrame(index=np.sort(num_high_impact.unique()))

num_name = "# High-impact features"
perc_name = "% Forecasts using ABC"
op_name = f"Opportunistic ABC overall {metric}"
abc_high_name = f"ABC high-impact {metric}"
deb_high_name = f"Deb. ECMWF high-impact {metric}"
model_col = metric + '_' + model
deb_col = metric + '_' + deb_model
for k in opportunity_high.index:
    opportunity_high.loc[k, num_name] = k
    # Fraction of dates flagged as high impact at threshold k
    which_rows = num_high_impact >= k
    opportunity_high.loc[k, perc_name] = which_rows.mean()
    # Mean model performances on the flagged dates
    opportunity_high.loc[k, abc_high_name] = all_metrics.loc[which_rows, model_col].mean()
    opportunity_high.loc[k, deb_high_name] = all_metrics.loc[which_rows, deb_col].mean()
    # Overall mean when ABC is used on flagged dates and the debiased model elsewhere
    opportunity_high.loc[k, op_name] = (all_metrics.loc[which_rows, model_col].sum() +
                                        all_metrics.loc[~which_rows, deb_col].sum()) / len(all_metrics)

print('\033[1m' + "\nForecasts of opportunity:" + '\033[0m' +
      f" Mean {metric} of ABC and debiased ECMWF on high-impact target dates")
styler = opportunity_high.drop(columns=op_name).style
# Styler.hide_index() was removed in pandas 2.0 in favor of hide(axis="index")
styler = styler.hide(axis="index") if hasattr(styler, "hide") else styler.hide_index()
opportunity_high_table = styler.set_properties(**{'text-align': 'center'}).format(
    '{:,.2%}'.format, subset=[abc_high_name, deb_high_name]).format(
    '{:,.0f} or more'.format, subset=num_name).format(
    '{:,.0%}'.format, subset=[perc_name])
display(opportunity_high_table)

# Save Figure source data, replacing any previous export
fig_filename = os.path.join(fig_dir, "fig_6-windows_opportunity_abc.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)
with pd.ExcelWriter(fig_filename) as writer:
    opportunity_high_table.to_excel(writer, sheet_name=f"table_{task}", na_rep="NaN")

# Plot high-impact skill of each model and overall skill of opportunistic ABC
# on a fresh figure (so no previously current figure is drawn over)
opp_fig, opp_ax = plt.subplots()
opp_ax.plot(
    opportunity_high[num_name],
    opportunity_high[abc_high_name],
    label="ABC-ECMWF on high-impact dates",
    color='tab:blue', linestyle='dashed')
opp_ax.plot(
    opportunity_high[num_name],
    opportunity_high[deb_high_name],
    label="Deb. ECMWF on high-impact dates",
    color='tab:red', linestyle='dashdot')
opp_ax.plot(
    opportunity_high[num_name],
    opportunity_high[op_name],
    label="Opportunistic ABC on all dates",
    color='tab:green', linewidth=2)
opp_ax.set_ylabel("Skill", fontsize=12, weight='bold')
opp_ax.set_xlabel("Minimum number of high-impact features", fontsize=12, weight='bold')
opp_ax.legend(prop={"size": 12})
opp_fig.tight_layout()

out_file = "subseasonal_toolkit/viz/opportunity.pdf"
opp_fig.savefig(out_file)
printf(f'Saving {out_file}')

# Save figure source data as an additional sheet of the same workbook
with pd.ExcelWriter(fig_filename, mode='a') as writer:
    opportunity_high.to_excel(writer, sheet_name=f"lineplot_{task}", na_rep="NaN")


### Visualize continuous features

In [None]:
# Continuous explanatory features, in the order given by `order`
features = [col for col in X.columns[order[:]] if col in continuous.columns]

# Only the 500 hPa height feature is rendered here; loop over `features`
# instead to visualize every continuous feature.
for feature in ["hgt_500_anom_2010_1_shift30"]:
    printf(f"Visualizing {feature}")
    # 95% confidence interval on the probability of positive impact
    # (df_cs > 0) within each feature quantile / bin.
    # NOTE(review): proportion_confint defaults to the normal approximation,
    # not Wilson; pass method='wilson' if Wilson intervals are intended.
    cis = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
        lambda in_bin: proportion.proportion_confint(in_bin.sum(), in_bin.size))

    # Variable to draw for this feature, and the integer that follows
    # "shift" in the feature name's last "_"-separated token
    viz_var = get_viz_var(feature)
    shift = int(feature.split('_')[-1].split('shift')[1])

    tic()
    viz_df = data_loaders.get_ground_truth(viz_var, shift=shift).set_index('start_date')
    toc()

    # Keep only the explained dates, then average every column within each bin
    viz_df = viz_df.loc[X_q.index].groupby(X_q[feature]).mean()
    plot_lat_lon_mat_all(viz_df, feature, cis, quantiles, viz_var)

    if feature != "hgt_500_anom_2010_1_shift30":
        continue
    # Export the bin-averaged data as figure source data, replacing any old file
    fig_filename = os.path.join(fig_dir, "fig_4-impact_hgt_500_pc1.xlsx")
    if os.path.isfile(fig_filename):
        os.remove(fig_filename)
    with pd.ExcelWriter(fig_filename) as writer:
        viz_df.T.to_excel(writer, sheet_name=f"binfig_{task}", na_rep="NaN")

### Visualize MJO impact by phase

In [None]:
plt.rcParams.update({'font.size': 16,
                     'font.weight': 'bold',
                     'figure.titlesize' : 16,
                     'figure.titleweight': 'bold',
                     'lines.markersize'  : 14,
                     'xtick.labelsize'  : 14,
                     'ytick.labelsize'  : 14})

# Explanatory feature to visualize (MJO phase)
feature = 'phase_shift17'

# 95% confidence interval on the probability of positive impact (df_cs > 0)
# within each phase bin.
# NOTE(review): proportion_confint defaults to method='normal'.
cis_mjo = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
    lambda x: proportion.proportion_confint(x.sum(), x.size))
num_bins = len(cis_mjo)

impact_levels, errors = get_impact_levels_errors(feature, cis_mjo, num_bins)
impact_min, impact_max = min(impact_levels), max(impact_levels)
errors_min, errors_max = errors[impact_levels.index(impact_min)], errors[impact_levels.index(impact_max)]

# Per-phase label of the form "<impact level>±<error>"
probabilities = [str(impact_level) + u"\u00B1" + str(error) for impact_level, error in zip(impact_levels, errors)]

# Sector vertices of the RMM1/RMM2 phase diagram for phases P1..P8, plus the
# positions of each sector's impact text and phase label. Sectors are matched
# to impact_levels by position, so bins are assumed to be ordered P1..P8.
triangles = {
    "P1": {"x": (0, -2, -2, 0), "y": (0, -2, 0, 0), "prob_position": (-1.75, -0.7), "text_position": (-1.95, -1.65)},
    "P2": {"x": (0, 0, -2, 0), "y": (0, -2, -2, 0), "prob_position": (-1.1, -1.5), "text_position": (-1.65, -1.9)},
    "P3": {"x": (0, 0, 2, 0), "y": (0, -2, -2, 0), "prob_position": (0.25, -1.5), "text_position": (1.5, -1.9)},
    "P4": {"x": (0, 2, 2, 0), "y": (0, -2, 0, 0), "prob_position": (1, -0.7), "text_position": (1.75, -1.65)},
    "P5": {"x": (0, 2, 2, 0), "y": (0, 2, 0, 0), "prob_position": (1, 0.5), "text_position": (1.75, 1.5)},
    "P6": {"x": (0, 0, 2, 0), "y": (0, 2, 2, 0), "prob_position": (0.25, 1.35), "text_position": (1.5, 1.75)},
    "P7": {"x": (0, -2, 0, 0), "y": (0, 2, 2, 0), "prob_position": (-1.1, 1.35), "text_position": (-1.6, 1.75)},
    "P8": {"x": (0, -2, -2, 0), "y": (0, 2, 0, 0), "prob_position": (-1.75, 0.5), "text_position": (-1.95, 1.5)},
}

fig = plt.figure('Triangles')
fig.set_size_inches(6, 6)
fig.patch.set_facecolor('white')

ax = fig.add_subplot()
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

# Sector boundaries: both diagonals and both axes
plt.plot([-2, 2], [2, -2], 'black', lw=2, alpha=0.4)
plt.plot([-2, 2], [-2, 2], 'black', lw=2, alpha=0.4)
plt.plot([-2, 2], [0, 0], 'black', lw=2, alpha=0.4)
plt.plot([0, 0], [-2, 2], 'black', lw=2, alpha=0.4)

# Shade each sector with opacity equal to its impact level. The impact text is
# red when it lies within the error of the minimum impact, blue when within the
# error of the maximum impact, and default-colored otherwise. The loop variable
# is `phase` (not `id`) so the builtin `id` is not rebound in the notebook.
for i, phase in enumerate(triangles):
    impact_level = impact_levels[i]
    if impact_min - errors_min <= impact_level <= impact_min + errors_min:
        prob_text_kwargs = {"color": "red"}
    elif impact_max - errors_max <= impact_level <= impact_max + errors_max:
        prob_text_kwargs = {"color": "blue"}
    else:
        prob_text_kwargs = {}
    sector = triangles[phase]
    plt.fill(sector["x"], sector["y"], 'darkorange', alpha=impact_level)
    ax.text(*sector["prob_position"], probabilities[i], fontsize=12, **prob_text_kwargs)
    ax.text(*sector["text_position"], phase, fontsize=12)

ax.text(-0.72, 2.05, "Western Pacific")
ax.text(2.0, -1, " Maritime Continent", rotation=-90)
ax.text(-0.62, -2.16, "Indian Ocean")
ax.text(-2.17, -1.1, " West. Hem. & Africa", rotation=90)

fig_title = f"Impact of {get_feature_name(feature)[:-6]} on ABC-ECMWF skill for {task_long}"
fig_title = fig_title.replace('skill for','skill improvement for') if model2 is not None else fig_title
plt.title(f'{fig_title}\n', weight='bold', fontsize=14)

plt.xticks([-2, -1, 0, 1, 2])
plt.yticks([-2, -1, 0, 1, 2])

ax.set_xlabel('RMM1', weight='bold')
ax.set_ylabel('RMM2', weight='bold')

# Save figure
out_file = os.path.join(bin_fig_dir,
    f'{model_str}-mjo-{gt_id}_{horizon}-{target_dates}-{feature}.pdf')
plt.savefig(out_file, orientation = 'landscape', bbox_inches='tight')
# Ensure saved files have full read and write permissions
set_file_permissions(out_file, mode=0o666)
print(f"\nFigure saved: {out_file}\n")

# Save figure source data (replace any previous export): raw impacts and
# per-bin confidence intervals on separate sheets
fig_filename = os.path.join(fig_dir, "fig_5-impact_mjo_phase.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)
with pd.ExcelWriter(fig_filename) as writer:
    df_cs[feature].to_excel(writer, sheet_name=f"cs_{task}", na_rep="NaN")
with pd.ExcelWriter(fig_filename, mode='a') as writer:
    cis_mjo.to_excel(writer, sheet_name=f"ci_{task}", na_rep="NaN")

## Visualize individual forecasts on which each feature had the greatest impact

### Identify dates with largest impacts

In [None]:
features = X.columns[order][:]
dates_largest_impact, dates_largest_metric = {}, {}

for feature in features:
    display(Markdown(f"#### {feature}:"))

    # Confidence interval for the probability of positive impact (df_cs > 0)
    # within each feature quantile or bin
    cis_feature = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
        lambda x: proportion.proportion_confint(x.sum(), x.size))
    num_bins = len(cis_feature)

    # Highest-impact bins (those within the confidence interval of the bin
    # with the overall highest impact level)
    high_impact_bins = get_high_impact_bins(feature, cis_feature, num_bins)

    # Track the single largest impact across all high-impact bins
    largest_impact = -np.inf
    df = data[feature].to_frame()
    for feature_edges in high_impact_bins:
        if feature in continuous:
            # Continuous features are binned into intervals. Honor the
            # interval's own open/closed ends so observations lying exactly on
            # a bin edge (e.g. the right edge of a right-closed quantile bin)
            # are not silently dropped.
            dmin, dmax = feature_edges.left, feature_edges.right
            above_left = (df[feature] > dmin) if feature_edges.open_left else (df[feature] >= dmin)
            below_right = (df[feature] < dmax) if feature_edges.open_right else (df[feature] <= dmax)
            dsel = df[above_left & below_right]
            label_str = f"between {dmin} and {dmax}"
        else:
            # Discrete features: each bin is a single category value
            dm = feature_edges
            dsel = df[df[feature] == dm]
            label_str = f"{dm}"

        # Impacts of the dates falling in this bin, largest first
        dsel_cs = df_cs[feature][df_cs.index.isin(dsel.index)]
        if dsel_cs.empty:
            # No explained dates in this bin; nothing to compare
            continue
        dsel_cs = dsel_cs.sort_values(ascending=False).iloc[:10]
        top_impact = dsel_cs.iloc[0]
        if top_impact > largest_impact:
            largest_impact = top_impact
            dsel_date = datetime.strftime(dsel_cs.index[0], '%Y%m%d')
            dates_largest_impact[feature] = dsel_date
            printf(f"Date with largest impact ({top_impact:.2}), where {feature} is {label_str}: {dsel_date}")

### Generate metrics for dates with largest impact

In [None]:
# Disabled: regenerate lat/lon anomaly metrics for the dates of interest by
# shelling out to src/eval/bulk_batch_metrics.py (appending to jobs_metrics.out).
if False:
    # Sorted, de-duplicated union of all dates of interest
    target_ranges = sorted(set(dates_largest_impact.values()).union(dates_largest_metric.values()))
    figure_metrics = ['lat_lon_anom']
    # Space-separated model list, always starting with the ground truth
    model_names = 'gt '
    for m in ['ecmwf']:
        if model2 is None:
            model_names += f'abc_{m} '
            continue
        for prefix in ("raw_", "deb_"):
            if model2.startswith(prefix):
                model_names += f'{prefix}{m} abc_{m} '
                break
    cmd = 'rm -f jobs_metrics.out'
    subprocess.call(cmd, shell=True)
    for figure_metric in figure_metrics:
        print(figure_metric)
        for target_range in target_ranges:
            print(target_range)
            cmd = (f"python src/eval/bulk_batch_metrics.py -mn {model_names} "
                   f"-t {target_range} -m {figure_metric} >> jobs_metrics.out")
            print(cmd)
            subprocess.call(cmd, shell=True)

### Visualize anomalies for dates with largest impact

In [None]:
# Helper functions and lookup tables for the date-level map plots

# Human-readable description of each visualization variable, used as a
# panel title in the map plots
viz_var_long = dict(
    wide_us_sst_anom='Lagged SST anomalies',
    wide_us_icec_anom='Lagged sea ice concentration anomalies',
    wide_north_hgt_10_anom='Lagged 10 hPa HGT anomalies',
    wide_north_hgt_500_anom='Lagged 500 hPa HGT anomalies',
)


def plot_metric_maps_trio(metric_dfs, model_names, gt_ids, horizons,
                          data_matrix, viz_df, viz_var, 
                          metric, target_dates, mean_metric_df=None, show=True, 
                          scale_type="linear", CB_colors_customized=None, CB_minmax=[], zoom=False,
                          feature='mei_shift45', bin_str="decile 1"):
    """Plot a visualization-variable map beside per-model forecast maps for one date.

    Panel 1 shows ``data_matrix`` for the visualization variable ``viz_var``:
    on a PlateCarree map for SST variables, otherwise on a north-polar
    stereographic Basemap. Its color range and palette depend only on
    ``viz_var``. The remaining panels show, for every (model, task) pair,
    the lat/lon values of ``metric`` for that model (one row per task, one
    column per model). The figure is saved as a PDF in ``date_fig_dir``; the
    file is then made writable by all users and its group set to
    ``sched_mit_hill``.

    Args:
        metric_dfs: dict mapping task name ``f"{gt_id}_{horizon}"`` to a dict
            mapping metric name to that metric's per-model data (consumed by
            ``format_df_models`` / ``pivot_model_output``).
        model_names: Models to plot, one column each (normalized by
            ``format_df_models``).
        gt_ids, horizons: Every (gt_id, horizon) combination is one task and
            is drawn as one row of model panels.
        data_matrix: Panel-1 grid; index holds latitudes, columns longitudes.
        viz_df: Its overall min/max sets the symmetric panel-1 color range
            for height ('hgt') variables.
        viz_var: Visualization variable name; must be a key of ``viz_var_long``.
        metric: Metric plotted in the model panels.
        target_dates: Passed to ``get_target_dates``; the first date is shown
            in the title.
        mean_metric_df: Optional per-model values whose column means label the
            model panels.
        show: If False, the figure is cleared and closed after saving.
        scale_type: "linear" or "symlognorm" color scaling for the model
            panels when ``CB_colors_customized`` is None.
        CB_colors_customized: None to use the ``color_dic`` colormap name,
            [] to use the ``color_dic`` color list, otherwise a color list.
        CB_minmax: [] to take the model-panel color range from ``color_dic``;
            otherwise (min, max), in which case a horizontal colorbar is drawn.
        zoom: Only used in the output filename.
        feature: Explained feature; used in the title and filename.
        bin_str: Description of the high-impact bins, used in the title.

    Returns:
        The matplotlib Figure.
    """
    # Format target date
    target_dates_objs = get_target_dates(target_dates)
    target_dates_str = datetime.strftime(target_dates_objs[0], '%Y-%m-%d')

    # Figure layout: one panel for the visualization variable plus one per (model, task)
    tasks = [f"{t[0]}_{t[1]}" for t in product(gt_ids, horizons)]
    subplots_num = 1 + (len(model_names) * len(tasks))
    params = get_plot_params_horizontal(subplots_num=subplots_num)
    params['fontsize_title'] += 4
    params['fontsize_ticks'] += 4
    params['y_sup_fontsize'] += 4
    nrows, ncols = params['nrows'], params['ncols']

    # NOTE(review): width uses nrows*figsize_x and height ncols*figsize_y;
    # confirm this pairing is intended by get_plot_params_horizontal.
    fig = plt.figure(figsize=(nrows*params['figsize_x'], ncols*params['figsize_y']))
    gs = GridSpec(nrows=nrows-1, ncols=ncols, width_ratios=params['width_ratios'])
    cb_shift = 0.02 if ncols == 4 else 0

    # ---- Panel 1: visualization variable ---------------------------------
    # Subsample lats and lons to reduce figure size
    if 'hgt' in viz_var:
        subsample_factor = 1
    elif 'sst' in viz_var:
        subsample_factor = 4
    else:
        subsample_factor = 2
    data_matrix = data_matrix.iloc[::subsample_factor, ::subsample_factor]

    # Cell edges centered on each (subsampled) grid point
    lats = data_matrix.index.values
    lons = data_matrix.columns.values
    if 'hgt' in viz_var:
        edge_len = 2.5 * subsample_factor
    elif 'global' in viz_var:
        edge_len = 1.5 * subsample_factor
    else:
        edge_len = 1 * subsample_factor
    lats_edges = np.arange(lats[0], lats[-1]+edge_len*2, edge_len) - edge_len/2
    lons_edges = np.arange(lons[0], lons[-1]+edge_len*2, edge_len) - edge_len/2
    lat_grid, lon_grid = np.meshgrid(lats_edges, lons_edges)
    gt_id, horizon = gt_ids[0], horizons[0]

    if 'sst' in viz_var:
        ax = fig.add_subplot(gs[0, 0], projection=ccrs.PlateCarree(), aspect="auto")
    else:
        ax = fig.add_subplot(gs[0, 0], aspect="auto")
    ax.set_facecolor('w')
    ax.axis('off')

    if 'sst' in viz_var:
        ax.coastlines(linewidth=0.9, color='gray')
        ax.add_feature(cfeature.STATES.with_scale('110m'), edgecolor='gray', linewidth=0.9, linestyle=':')
        land_110m = cfeature.NaturalEarthFeature('physical', 'land', '110m',
                                                 edgecolor='face',
                                                 facecolor='white')
        ax.add_feature(land_110m, edgecolor='gray')

    gt_var = "tmp2m" if "tmp2m" in gt_id else "precip"

    # Panel-1 color settings depend only on the visualization variable. They
    # are kept in local names so the caller's metric / scale_type /
    # CB_colors_customized / CB_minmax still apply to the model panels.
    viz_metric = 'skill'
    viz_cb_colors = ['blue','dodgerblue','lightskyblue','white','pink','red','darkred']
    cb_skip = None
    if 'icec' in viz_var:
        max_val = 1/4
        viz_cb_minmax = (-max_val, max_val)
        cb_skip = max_val
    elif 'sst' in viz_var:
        viz_cb_minmax = (-2, 2)
        cb_skip = 1
    elif 'global' in viz_var:
        viz_cb_minmax = (-5, 5)
        cb_skip = 1
        viz_cb_colors = ['tan','violet','yellow','green','lightskyblue','dodgerblue','blue']
    elif 'hgt' in viz_var:
        # Symmetric range derived from the bin-averaged data in viz_df
        max_val = max(np.abs(viz_df.min().min()), viz_df.max().max())/1.25
        viz_cb_minmax = (-max_val, max_val)
        cb_skip = max_val
    else:
        viz_cb_minmax = []
        viz_cb_colors = ['purple','blue','lightblue','white','pink','yellow','red']

    if not viz_cb_minmax:
        colorbar_min_value = color_dic[(viz_metric, gt_var, horizon)]['colorbar_min_value']
        colorbar_max_value = color_dic[(viz_metric, gt_var, horizon)]['colorbar_max_value']
    else:
        colorbar_min_value, colorbar_max_value = viz_cb_minmax

    cmap = LinearSegmentedColormap.from_list(cmap_name, viz_cb_colors, N=100)
    if 'sst' in viz_var:
        plot = ax.pcolormesh(lon_grid, lat_grid, np.transpose(data_matrix),
                             vmin=colorbar_min_value, vmax=colorbar_max_value,
                             cmap=cmap, rasterized=True)
    else:
        # North-polar stereographic map; sea ice uses a higher bounding latitude
        boundinglat = 45 if 'icec' in viz_var else 15
        m = Basemap(projection='npstere', boundinglat=boundinglat, lon_0=270, resolution='c', round=True)
        m.drawcoastlines()
        plot = m.pcolor(lon_grid, lat_grid, np.transpose(data_matrix),
                        vmin=colorbar_min_value, vmax=colorbar_max_value,
                        cmap=cmap, latlon=True, rasterized=True)

    ax.tick_params(axis='both', labelsize=params['fontsize_ticks'])
    # Title placed below the panel
    ax.set_title(viz_var_long[viz_var], fontsize=params['fontsize_title'], fontweight="bold",
                 y=-0.2,
                 x=.45 if 'icec' in viz_var else .5)

    if viz_cb_minmax:
        # Thin vertical colorbar to the left of panel 1: [left, bottom, width, height]
        cb_ax = fig.add_axes([0.105, 0.16, 0.007, 0.7])
        cb = fig.colorbar(plot, cax=cb_ax, cmap=cmap, orientation='vertical')
        cb.outline.set_edgecolor('black')
        cb.ax.tick_params(labelsize=params['fontsize_ticks'])
        viz_ticks = np.arange(colorbar_min_value, colorbar_max_value+cb_skip, cb_skip)
        cb_ticklabels = [f'{tick}' if 'icec' in viz_var else f'{tick:.0f}' for tick in viz_ticks]
        cb.set_ticks(viz_ticks)
        cb.ax.set_yticklabels(cb_ticklabels, fontsize=params['fontsize_title'], weight='bold')
        cb.ax.yaxis.set_ticks_position('left')

    # ---- Panels 2+: model forecasts for each (model, task) ---------------
    # Build the model lat/lon grid from the first task's data
    df_models = metric_dfs[tasks[0]][metric]
    df_models, model_names = format_df_models(df_models, model_names)
    data_matrix = pivot_model_output(df_models, model_name=model_names[0])

    if '1.5' in tasks[0]:
        lats = np.linspace(25.5, 48, data_matrix.shape[0])
        lons = np.linspace(-123, -67.5, data_matrix.shape[1])
    elif 'us' in tasks[0]:
        lats = np.linspace(27, 49, data_matrix.shape[0])
        lons = np.linspace(-124, -68, data_matrix.shape[1])
    elif 'contest' in tasks[0]:
        lats = np.linspace(27, 49, data_matrix.shape[0])
        lons = np.linspace(-124, -94, data_matrix.shape[1])
    else:
        # Previously this silently reused the panel-1 grid
        raise ValueError(f"Unsupported task region for model maps: {tasks[0]}")

    if '1.5' in tasks[0]:
        lats_edges = np.arange(lats[0], lats[-1]+1.5, 1.5) - 0.75
        lons_edges = np.arange(lons[0], lons[-1]+1.5, 1.5) - 0.75
    else:
        lats_edges = np.arange(int(lats[0]), int(lats[-1])+2) - 0.5
        lons_edges = np.arange(int(lons[0]), int(lons[-1])+2) - 0.5
    lat_grid, lon_grid = np.meshgrid(lats_edges, lons_edges)

    for i, (model_name, task) in enumerate(product(model_names, tasks)):
        if i >= subplots_num:
            break

        # Row = task, column = model (column 0 is the visualization panel)
        x, y = tasks.index(task), model_names.index(model_name)+1

        ax = fig.add_subplot(gs[x, y], projection=ccrs.PlateCarree(), aspect="auto")
        ax.set_facecolor('w')
        ax.axis('off')

        df_models = metric_dfs[task][metric]
        if 'skill' in metric:
            df_models = df_models.apply(lambda col: col*100)
        df_models, model_names = format_df_models(df_models, model_names)

        data_matrix = pivot_model_output(df_models, model_name=model_name)
        ax.coastlines(linewidth=0.9, color='gray')
        ax.add_feature(cfeature.STATES.with_scale('110m'), edgecolor='gray', linewidth=0.9, linestyle=':')

        # Color parameters for this task
        gt_id, horizon = task[:-4], task[-3:]
        gt_var = "tmp2m" if "tmp2m" in gt_id else "precip"
        if CB_minmax == []:
            colorbar_min_value = color_dic[(metric, gt_var, horizon)]['colorbar_min_value']
            colorbar_max_value = color_dic[(metric, gt_var, horizon)]['colorbar_max_value']
        else:
            colorbar_min_value = CB_minmax[0]
            colorbar_max_value = CB_minmax[1]

        color_map_str = color_dic[(metric, gt_var, horizon)]['color_map_str']

        if CB_colors_customized is not None:
            if CB_colors_customized == []:
                cmap = LinearSegmentedColormap.from_list(cmap_name, color_dic[(metric, gt_var, horizon)]['CB_colors'], N=100)
            else:
                cmap = LinearSegmentedColormap.from_list(cmap_name, CB_colors_customized, N=100)
            plot = ax.pcolormesh(lon_grid, lat_grid, np.transpose(data_matrix),
                                 vmin=colorbar_min_value, vmax=colorbar_max_value,
                                 cmap=cmap, rasterized=True)
        else:
            # plt.get_cmap survives matplotlib 3.9, which removed matplotlib.cm.get_cmap
            color_map = plt.get_cmap(color_map_str)
            if "linear" in scale_type:
                plot = ax.pcolormesh(lon_grid, lat_grid, np.transpose(data_matrix),
                                     vmin=colorbar_min_value, vmax=colorbar_max_value,
                                     cmap=color_map, rasterized=True)
            elif "symlognorm" in scale_type:
                # Fully qualified: the notebook-global name `colors` can be rebound
                plot = ax.pcolormesh(lon_grid, lat_grid, np.transpose(data_matrix),
                                     cmap=color_map,
                                     norm=matplotlib.colors.SymLogNorm(vmin=colorbar_min_value,
                                                                       vmax=colorbar_max_value,
                                                                       linthresh=0.03, base=10),
                                     rasterized=True)

        ax.tick_params(axis='both', labelsize=params['fontsize_ticks'])

        # Summary value shown in the panel title
        if mean_metric_df is not None:
            mean_metric = '' if model_name == 'gt' else int(mean_metric_df[model_name].mean())
        elif metric == 'lat_lon_anom' and 'lat_lon_skill' in metric_dfs[task].keys():
            df_mean_metric = metric_dfs[task]['lat_lon_skill'].apply(lambda col: col*100)
            df_mean_metric, model_names = format_df_models(df_mean_metric, model_names)
            mean_metric = int(df_mean_metric[model_name].mean())
        else:
            mean_metric = int(df_models[model_name].mean())

        mean_metric_title = f"{mean_metric}%" if 'skill' in metric else str(mean_metric)
        if x == 0:
            # First row: model name on every panel; columns after the first
            # model also get a "Skill: N%" title (single % sign)
            if y > 1:
                ax.set_title(f"Skill: {mean_metric}%", fontsize=params['fontsize_title'], fontweight="bold")
            ax.text(0.005, 0.55, all_model_names[model_name], va='bottom', ha='center',
                    rotation='vertical', rotation_mode='anchor',
                    transform=ax.transAxes, fontsize=params['fontsize_title'], fontweight="bold")
        else:
            ax.set_title(f"{mean_metric_title}", fontsize=params['fontsize_title'], fontweight="bold")

        # One shared horizontal colorbar, added with the first model panel
        if CB_minmax != [] and i == 0:
            cb_ax = fig.add_axes([0.45-cb_shift, 0.06, 0.4, 0.04])
            if CB_colors_customized is not None:
                cb = fig.colorbar(plot, cax=cb_ax, cmap=cmap, orientation='horizontal')
            else:
                cb = fig.colorbar(plot, cax=cb_ax, orientation='horizontal')
            cb.outline.set_edgecolor('black')
            cb.ax.tick_params(labelsize=params['fontsize_ticks'])
            if metric == 'lat_lon_error':
                cbar_title = 'model bias (mm)' if 'precip' in gt_id else r'model bias ($^\degree$C)'
            elif metric == 'lat_lon_anom':
                cbar_title = f"{gt_var.replace('precip','Precipitation').replace('tmp2m','Temperature')} anomalies"
            elif 'skill' in metric:
                cbar_title = 'Skill (%)'
            else:
                cbar_title = metric
            cb.ax.set_xlabel(cbar_title, fontsize=params['fontsize_title'], weight='bold')
            if "linear" in scale_type:
                cb_skip = color_dic[(metric, gt_var, horizon)]['cb_skip']
                model_ticks = range(colorbar_min_value, colorbar_max_value+cb_skip, cb_skip)
                cb.set_ticks(model_ticks)
                cb.ax.set_xticklabels([f'{tick}' for tick in model_ticks],
                                      fontsize=params['fontsize_title'], weight='bold')

    fig_title = f"Forecast with largest {get_feature_name(feature)[:-6]} impact in {bin_str}: {target_dates_str}"
    fig.suptitle(fig_title, fontsize=params['y_sup_fontsize'], y=params['y_sup_title'])

    # Save figure
    model_names_str = '-'.join(model_names)
    out_file = os.path.join(date_fig_dir, f"{metric}_{target_dates}_{gt_id}_n{subplots_num}_{model_names_str}_zoom{zoom}_{feature}.pdf")
    plt.savefig(out_file, orientation='landscape', bbox_inches='tight')
    # Argument lists (no shell) so paths with spaces/special characters are safe;
    # an unset USER yields ":sched_mit_hill" exactly as the shell's $USER did
    subprocess.call(["chmod", "a+w", out_file])
    subprocess.call(["chown", f"{os.environ.get('USER', '')}:sched_mit_hill", out_file])
    print(f"\nFigure saved: {out_file}\n")
    if not show:
        fig.clear()
        plt.close(fig)

    return fig


# Settings shared by every date-level map figure generated below
figure_gt_ids, figure_horizons = [gt_id], [horizon]
figure_metrics = ['lat_lon_anom']  # optionally also 'lat_lon_skill'
figure_mean_metric_df = pd.DataFrame()
figure_show = figure_zoom = False
# Ground truth first, then model2 (when set), then model
figure_models = ['gt', *([] if model2 is None else [model2]), model]


# Continuous explanatory features, in the order given by `order`
features = [col for col in X.columns[order[:]] if col in continuous.columns]

for feature in ["hgt_500_anom_2010_1_shift30"]:#features[1:6]:
    # Identify the largest impact forecast date
    figure_target_dates = dates_largest_impact[feature]
    # Skip over discrete features
    if feature.startswith('phase_shift') or feature.startswith('month'):
        display(Markdown(f"### {feature}, {figure_target_dates}: SKIPPING."))
        continue
    display(Markdown(f"### {feature}, {figure_target_dates}:"))
    
    # Compute impact level (i.e., the probability of positive impact in 
    # the feature bin associated with this forecast date) and the associated
    # decile
    cis_feature = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
        lambda x: proportion.proportion_confint(x.sum(), x.size))
    num_bins = len(cis_feature)
    
    # Prepare string summary of high impact bins
    high_impact_bins = get_high_impact_bins(feature, cis_feature, num_bins)
    if len(high_impact_bins) == 1:
        bin_str = f"decile {high_impact_bins.categories.get_loc(high_impact_bins[0])+1}"
    else:
        bin_str = f"deciles " + ", ".join(
            [str(high_impact_bins.categories.get_loc(b)+1) for b in high_impact_bins])

    
    # Store metric for each model
    figure_mean_metric_df = pd.DataFrame()
    figure_mean_metric_df[model] = metrics.loc[metrics.index == datetime.strptime(figure_target_dates, '%Y%m%d'), metric].values
    if model2 is not None:
        figure_mean_metric_df[model2] = metrics2.loc[metrics2.index == datetime.strptime(figure_target_dates, '%Y%m%d'), metric].values
    if metric == 'skill':
        # Convert to a percentage
        figure_mean_metric_df = figure_mean_metric_df.apply(lambda x: x*100)
    #RDA: models for which Raw, Debiased and Abc versions are available
    metric_dfs_rda = {}
    for fig_gt_id, fig_horizon in product(figure_gt_ids, figure_horizons):
        fig_task = f"{fig_gt_id}_{fig_horizon}"
        display(Markdown(f"#### Getting metrics for {fig_gt_id} {fig_horizon}"))
        metric_dfs_rda[fig_task] = get_models_metric_lat_lon(gt_id=fig_gt_id, horizon=fig_horizon, 
                                                         target_dates=figure_target_dates, 
                                                         metrics = figure_metrics, model_names=figure_models)
    
    
    figure_target_date = dates_largest_impact[feature]


    # Identify associated visualization variable
    viz_var = get_viz_var(feature)
    shift = int(str.split(str.split(feature,'_')[-1],'shift')[1])

    # Load visualization variable
    tic()
    viz_df = data_loaders.get_ground_truth(
        viz_var, shift=shift).set_index('start_date')
    toc()


    # Restrict to relevant dates
    target_date_obj = get_target_dates(figure_target_date,'%Y%m%d')[0]
    target_date_ind = datetime.strftime(target_date_obj,'%Y-%m-%d')#'2019-04-09'
    data_matrix = viz_df.loc[target_date_ind].to_frame().T
    data_matrix.index.names = [feature]

    #plot single lat lon mat
    bin_num = 0
    data_matrix = lat_lon_mat(data_matrix.iloc[bin_num])
    if feature.startswith('icec'):
        # Add in rows corresponding to any missing lat values with NaN values
        data_matrix = data_matrix.reindex(
            np.arange(data_matrix.index.min(), data_matrix.index.max()+1), fill_value = np.nan)
        # For icec, NaN and 0 values should be treated identically
        data_matrix[data_matrix.isna()] = 0

    # Also provide access to mean anomalies per bin / quantile to set colorbar range
    viz_df = viz_df.loc[X_q.index]
    viz_df = viz_df.groupby(X_q[feature]).mean()    
    
    plot_metric_maps_trio(metric_dfs_rda, 
                          model_names=figure_models,
                          gt_ids=figure_gt_ids,
                          horizons=figure_horizons,
                          data_matrix=data_matrix,
                          viz_df=viz_df,
                          viz_var=viz_var,
                          metric='lat_lon_anom',
                          target_dates=figure_target_dates,
                          mean_metric_df=figure_mean_metric_df,
                          show=figure_show,
                          scale_type="linear", 
                          CB_colors_customized=['orangered','darkorange',"white",'forestgreen','darkgreen'], #['saddlebrown','peru',"white",'yellowgreen','green'], 
                          CB_minmax=(-20, 20), 
                          zoom=figure_zoom,
                          feature=feature, 
                          bin_str=bin_str)
    if feature == "hgt_500_anom_2010_1_shift30":
        # Save figure source data
        fig_filename = os.path.join(fig_dir, "fig_4-impact_hgt_500_pc1.xlsx")    

        if os.path.isfile(fig_filename):
            with pd.ExcelWriter(fig_filename, mode='a') as writer:  
                data_matrix.to_excel(writer, sheet_name=f"high_impact_bin_{task}", na_rep="NaN") 
                metric_dfs_rda[task]['lat_lon_anom'].to_excel(writer, sheet_name=f"anom_{task}", na_rep="NaN") 
        else:
            with pd.ExcelWriter(fig_filename) as writer:  
                data_matrix.to_excel(writer, sheet_name=f"high_impact_bin_{task}", na_rep="NaN") 
                metric_dfs_rda[task]['lat_lon_anom'].to_excel(writer, sheet_name=f"anom_{task}", na_rep="NaN") 


In [None]:
# MJO - Visualize anomalies for dates with largest impact
# Figure parameter values
figure_gt_ids = [gt_id]  # alternatively: us_1_5_gt_ids
figure_horizons = [horizon]  # alternatively: horizons
figure_metrics = ['lat_lon_anom']  # optionally also 'lat_lon_skill'
figure_mean_metric_df = pd.DataFrame()
figure_show = False
figure_models = ["gt", model] if model2 is None else ['gt', model2, model]
figure_zoom = False


# All features, in the order given by `order` over X.columns
features = list(X.columns[order])

# Only the first-ranked MJO phase feature is visualized
for feature in [f for f in features if f.startswith('phase')][:1]:
    # Forecast target date (a '%Y%m%d' string) with the largest impact
    figure_target_dates = dates_largest_impact[feature]
    printf(f'{feature} - {figure_target_dates}:')
    target_dt = datetime.strptime(figure_target_dates, '%Y%m%d')

    # Probability of positive impact per feature value, with binomial CIs
    cis_feature = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
        lambda x: proportion.proportion_confint(x.sum(), x.size))
    num_bins = len(cis_feature)

    # Prepare string summary of high impact bins
    high_impact_bins = get_high_impact_bins(feature, cis_feature, num_bins)
    if len(high_impact_bins) == 1:
        bin_str = f"phase {int(high_impact_bins[0])}"
    else:
        bin_str = "phases " + ", ".join(str(int(b)) for b in high_impact_bins)

    # Skill of each model on the target date, as a percentage
    # NOTE(review): this cell always uses the `skill` column, unlike the
    # continuous-feature cell, which uses `metric` — confirm this is intended.
    figure_mean_metric_df = pd.DataFrame()
    if model2 is not None:
        figure_mean_metric_df[model2] = metrics2[metrics2.index == target_dt].skill.values
        figure_mean_metric_df[model] = metrics[metrics.index == target_dt].skill.values
    else:
        figure_mean_metric_df[model] = outcome[outcome.index == target_dt].skill.values
    figure_mean_metric_df = figure_mean_metric_df * 100

    # RDA: models for which Raw, Debiased and Abc versions are available.
    # fig_-prefixed loop names avoid overwriting the notebook-level
    # gt_id / horizon / task globals that later cells rely on.
    metric_dfs_rda = {}
    for fig_gt_id, fig_horizon in product(figure_gt_ids, figure_horizons):
        fig_task = f"{fig_gt_id}_{fig_horizon}"
        display(Markdown(f"#### Getting metrics for {fig_gt_id} {fig_horizon}"))
        metric_dfs_rda[fig_task] = get_models_metric_lat_lon(
            gt_id=fig_gt_id, horizon=fig_horizon,
            target_dates=figure_target_dates,
            metrics=figure_metrics, model_names=figure_models)

    plot_metric_maps(metric_dfs_rda,
                     model_names=figure_models,
                     gt_ids=figure_gt_ids,
                     horizons=figure_horizons,
                     metric='lat_lon_anom',
                     target_dates=figure_target_dates,
                     mean_metric_df=figure_mean_metric_df,
                     show=figure_show,
                     scale_type='linear',
                     CB_colors_customized=['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen'],  # alt: ['saddlebrown','peru',"white",'yellowgreen','green']
                     CB_minmax=(-20, 20),
                     zoom=figure_zoom,
                     feature=feature,
                     bin_str=bin_str)

    # Save figure source data
    fig_filename = os.path.join(fig_dir, "fig_5-impact_mjo_phase.xlsx")
    # Append to an existing workbook, replacing same-named sheets so that
    # re-running this cell does not fail on duplicate sheet names
    # (if_sheet_exists requires pandas >= 1.3 and is only valid with mode='a')
    excel_kwargs = (dict(mode='a', if_sheet_exists='replace')
                    if os.path.isfile(fig_filename) else {})
    with pd.ExcelWriter(fig_filename, **excel_kwargs) as writer:
        for fig_task, task_metric_dfs in metric_dfs_rda.items():
            task_metric_dfs['lat_lon_anom'].to_excel(writer, sheet_name=f"anom_{fig_task}", na_rep="NaN")

# Old / obsolete code

### Visualize relationship between variables and their impact

In [None]:
if False:
    # Obsolete: per-variable plots of Cohort Shapley impact for the top 5
    # variables. Kept for reference; the `if False` guard means it never runs.

    # Choose title and y-axis label
    title = f"{gt_id} {horizon}"
    if model2 is None:
        title += f", {model}"
        ylabel = f'Variable impact on {metric}'
    else:
        title += f" ({model} vs. {model2})"
        ylabel = f'Variable impact on {metric} difference'

    from scipy import stats

    impact_name = "Positive impact probability"

    for col in X.columns[order][:5]:
        is_continuous = col in continuous

        if False:
            # Scatterplot of (feature value, Cohort Shapley value), shaded by a
            # Gaussian kernel density estimate. Alternatives tried previously:
            #   sns.violinplot(x=pd.qcut(data[col], q=10), y=df_cs[col], scale='width')
            #   sns.scatterplot(x=data[col], y=df_cs[col])
            #   sns.kdeplot(x=data[col], y=df_cs[col], fill=True, alpha=0.6, cut=2, cmap="viridis")
            fig = plt.figure(dpi=300)
            values = np.vstack([data[col], df_cs[col]])
            kernel = stats.gaussian_kde(values)(values)
            sns.scatterplot(x=data[col], y=df_cs[col], c=kernel, cmap="viridis")
            fig.autofmt_xdate(rotation=45)
            plt.title(title)
            plt.ylabel(ylabel)
            plt.tight_layout()
            plt.show()

        if False:
            # Cohort Shapley distribution conditional on quantile bins for
            # continuous features, or on feature value for discrete features
            fig = plt.figure(dpi=300)
            x_violin = pd.qcut(data[col], q=10) if is_continuous else data[col]
            sns.violinplot(x=x_violin, y=df_cs[col], scale='width')
            fig.autofmt_xdate(rotation=45)

        # Estimated positive impact probability for each decile (continuous)
        # or each value (discrete) of the variable, with seaborn's 95%
        # bootstrap confidence intervals. sns.catplot creates its own
        # figure, so no plt.figure() call is made here (one would only leave
        # an extra blank figure for plt.show()).
        x_values = pd.qcut(data[col], q=10) if is_continuous else data[col]
        grid = sns.catplot(kind="bar",
                           data=pd.DataFrame({col: x_values, impact_name: df_cs[col] > 0}),
                           x=col, y=impact_name, estimator=np.mean, alpha=.7,
                           edgecolor=".2",
                           color=sns.color_palette()[1])
        # Number of bars: 10 deciles, or the number of distinct observed
        # values for discrete features (e.g. 8 MJO phases, 12 months)
        n_levels = 10 if is_continuous else data[col].nunique()
        xlabels = [f'{i + 1}' for i in range(n_levels)]
        grid.set(xticks=range(n_levels), xticklabels=xlabels)
        plt.ylim(0, 1)
        plt.xlabel(f"{col} decile" if is_continuous else f"{col}",
                   fontdict={'weight': 'bold'})
        plt.ylabel(impact_name, fontdict={'weight': 'bold'})
        if is_continuous:
            plt.xticks(ticks=range(n_levels), labels=xlabels, weight='bold')
        out_file = f"subseasonal_toolkit/viz/variable_impact_catplot_{title.replace(',','').replace(' ','_')}_{col}.pdf"
        grid.fig.savefig(out_file, dpi=300, bbox_inches="tight")
        printf(f'Saving figure {out_file}')
        plt.tight_layout()
        plt.show()

