# ABC Plots

In [None]:
""" 
Generates ABC result figures for 

Adaptive Bias Correction for Improved Subseasonal Forecasting

Soukayna Mouatadid, Paulo Orenstein, Genevieve Flaspohler, 
Judah Cohen, Miruna Oprescu, Ernest Fraenkel, and Lester Mackey. 
"""
# Ensure notebook is being run from base repository directory
import os, sys
try:
    os.chdir(os.path.join("..",".."))
except Exception as err:
    print(f"Warning: unable to change directory; {repr(err)}")
    
%load_ext autoreload
%autoreload 2
%matplotlib inline    
    
import itertools
import importlib
import subprocess
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import product
from functools import partial
from datetime import datetime
from IPython.display import Markdown, display

import os
import copy
import pdb
import calendar 
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib   
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap

from subseasonal_toolkit.utils.experiments_util import pandas2hdf, get_measurement_variable, clim_merge, pandas2hdf
from subseasonal_toolkit.utils.general_util import printf, make_directories, tic, toc, set_file_permissions
from subseasonal_toolkit.utils.eval_util import (get_target_dates, score_to_mean_rmse, contest_quarter_start_dates, contest_quarter,
                                                 year_quarter, get_task_metrics_dir, save_metric, get_metric_filename)
from subseasonal_toolkit.utils.models_util import get_selected_submodel_name
from subseasonal_data.data_loaders import get_ground_truth, get_climatology
from subseasonal_data.utils import get_measurement_variable
from subseasonal_data import data_loaders
from subseasonal_toolkit.models.multillr.stepwise_util import default_stepwise_candidate_predictors

from cohortshapley import cohortshapley as cs
from cohortshapley import similarity
from cohortshapley import figure
from cohortshapley import varianceshapley as vs

from viz_util_abc import *

import statsmodels.stats.proportion as proportion

from mpl_toolkits.basemap import Basemap

# NOTE: repeats the `%matplotlib inline` magic issued at the top of this cell; harmless.
%matplotlib inline
# Inline figure display: SVG for interactive and PDF for 
# eventual outputting as PDF through nbconvert
%config InlineBackend.figure_formats = ['pdf','svg']
#%matplotlib notebook # interactive plotting

# set figure and font sizes for seaborn plots
sns.set(rc={'figure.figsize':(8,6)}, font_scale=1)

#
# Directory for saving output
# (relative path, resolved against the current working directory)
#
out_dir = os.path.join("viz", "abc", "figures")


In [None]:
#
# Full set of metrics, regions, horizons, and target date sets to evaluate
#
metrics = ["rmse", "skill"]
us_1_5_gt_ids = ["us_tmp2m_1.5x1.5", "us_precip_1.5x1.5"]

# All ground truth ids
gt_ids = us_1_5_gt_ids

horizons = ["12w", "34w", "56w"]
target_eval_dates = ["std_paper_forecast"]

# The full set of models we evaluate in the de-biasing experiment:
# every raw baseline first, then the ABC ensemble for each, in the same
# base-model order.
debias_experiment_models = [
    f"{variant}_{base_model}"
    for variant in ("raw", "abc")
    for base_model in ("ccsm4", "cfsv2", "geos_v2p1", "nesm",
                       "fimr1p1", "gefs", "gem", "ecmwf")
]

#
# Dictionary mapping task names (ground truth id + horizon) to display names
#
us_tasks = {
    f"{task_gt_id}_{task_horizon}": f"{variable_label} weeks {week_range}"
    for task_gt_id, variable_label in (("us_tmp2m_1.5x1.5", "Temp."),
                                       ("us_precip_1.5x1.5", "Precip."))
    for task_horizon, week_range in (("34w", "3-4"), ("56w", "5-6"))
}

## Read in all metrics for all tasks and all models
Reads metrics, generates a summary of missing data, and produces the `all_metrics` dictionary to be used in further analysis. 

In [None]:
"""
Generate a dictionary with metric values for all models and every combination of gt_id, 
horizon, and target dates
"""

all_metrics = {}

# Every experiment in this notebook evaluates the same set of models
model_names = debias_experiment_models
model_names_str = 'all_models'

# Get metrics for all experiments
for metric, gt_id, horizon, target_dates in product(metrics, gt_ids, horizons, target_eval_dates):
    # Get task
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"### {model_names_str}: {metric}, {task}, {target_dates}"))

    # Get all metrics; None means no models exist for this task
    df = get_metrics_df(gt_id, horizon, metric, target_dates, model_names=model_names)
    if df is None:
        continue

    # Add yearly and quarterly columns to the dataframe
    df = add_groupby_cols(df, horizon=horizon)

    key = (metric, task, target_dates)
    all_metrics[key] = copy.copy(df)

    # Report the rows that are missing a value for any model. When some model
    # columns are absent from df altogether, indexing by model_names raises
    # KeyError; in that case fall back to checking every column.
    try:
        missing_rows = df[model_names].isnull().any(axis=1)
    except KeyError:
        missing_rows = df.isnull().any(axis=1)
    missing_df = df.loc[missing_rows, :]
    if missing_df.shape[0] != 0:
        display(Markdown("#### Missing metrics"))
        display(missing_df)
    else:
        printf("All metrics present.")

# Main manuscript figures

## Figure 1:  Raw vs. ABC models (skill barplots)

In [None]:
# Figure 1 parameters: every raw baseline (including the SubX mean) followed
# by its ABC counterpart, in the same base-model order
figure_model_names = [
    f"{variant}_{base_model}"
    for variant in ("raw", "abc")
    for base_model in ("ecmwf", "ccsm4", "cfsv2", "geos_v2p1", "nesm",
                       "fimr1p1", "gefs", "gem", "subx_mean")
]

figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'skill'
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_1-average_forecast_skill.xlsx"
figure_show = False

# One barplot per (ground truth, horizon) task
for gt_id, horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### {gt_id}_{horizon}"))
    df_barplot = barplot_rawabc(
        model_names=figure_model_names,
        gt_id=gt_id,
        horizon=horizon,
        metric=figure_metric,
        target_dates=figure_target_dates,
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
        show=figure_show,
    )

## Figure 2:  Raw vs. Deb vs. ABC models (lat_lon_skill maps)

In [None]:
# Figure 2 inputs: spatial and overall skill of CFSv2 and ECMWF in their raw,
# standard de-biased, and ABC versions
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_skill', 'skill']
figure_models = [
    f"{variant}_{base_model}"
    for variant in ("raw", "deb", "abc")
    for base_model in ("cfsv2", "ecmwf")
]

# Task name -> metrics returned by get_models_metric_lat_lon
metric_dfs_rda = {}
for gt_id, horizon in product(figure_gt_ids, figure_horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=gt_id,
        horizon=horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_models,
    )

In [None]:
# Figure 2 parameters
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_skill'
# The metric dictionary itself is passed as mean_metric_df (alternative: None)
figure_mean_metric_df = metric_dfs_rda
figure_source_data = True
figure_source_data_filename = "fig_2-spatial_skill_distribution.xlsx"
figure_show = False

# One group of maps per model family: raw, standard de-biased, and ABC versions
for figure_model_names in (["raw_cfsv2", "deb_cfsv2", "abc_cfsv2"],
                           ["raw_ecmwf", "deb_ecmwf", "abc_ecmwf"]):
    display(Markdown(f'#### Models: {", ".join(figure_model_names)}'))
    for gt_id in figure_gt_ids:
        display(Markdown(f"#### {gt_id}"))
        plot_metric_maps(
            metric_dfs_rda,
            model_names=figure_model_names,
            gt_ids=[gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=["white", "#dede00", "#ff7f00", "blueviolet", "indigo",
                                  "yellowgreen", "lightgreen", "darkgreen"],
            CB_minmax=(0, 90),
            source_data=figure_source_data,
            source_data_filename=figure_source_data_filename,
        )

## Figure 3:  Fraction of grid points above skill threshold (lat_lon_skill barplots)

In [None]:
# Figure 3 parameters: fraction of grid points above a skill threshold,
# for the last horizon only
figure_root_model_names = ["cfsv2", "ecmwf"]
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons[-1:]
figure_metric = 'lat_lon_skill'
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_3-fraction_above_skill_threshold.xlsx"
figure_show = False

for model_name_root, gt_id, horizon in product(figure_root_model_names, figure_gt_ids, figure_horizons):
    # Compare the raw, standard de-biased, and ABC versions of one base model
    figure_model_names = [f"{variant}_{model_name_root}" for variant in ("raw", "deb", "abc")]
    display(Markdown(f"#### {gt_id}_{horizon}"))
    df_barplot = barplot_skillthreshold(
        model_names=figure_model_names,
        gt_id=gt_id,
        horizon=horizon,
        metric=figure_metric,
        target_dates=figure_target_dates,
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
        show=figure_show,
    )

## Figure 4:  Raw vs. ABC models (lat_lon_error maps)

In [None]:
# Figure 4 inputs: spatial error of CFSv2 and ECMWF in their raw and ABC versions
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_error']
figure_model_names = [
    f"{variant}_{base_model}"
    for variant in ("raw", "abc")
    for base_model in ("cfsv2", "ecmwf")
]

# Task name -> metrics returned by get_models_metric_lat_lon
metric_dfs_rda = {}
for gt_id, horizon in product(figure_gt_ids, figure_horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=gt_id,
        horizon=horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )


In [None]:
# Figure 4 parameters
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_source_data = True
figure_source_data_filename = "fig_4-spatial_bias_distribution.xlsx"
figure_show = False

# One group of maps per model family: raw vs. ABC
for figure_model_names in (["raw_cfsv2", "abc_cfsv2"], ["raw_ecmwf", "abc_ecmwf"]):
    display(Markdown(f'#### Models: {", ".join(figure_model_names)}'))
    for gt_id in figure_gt_ids:
        display(Markdown(f"#### {gt_id}"))
        # Diverging palette centred on white, with a per-variable color range
        if 'tmp2m' in gt_id:
            CB_colors_customized, CB_minmax = (
                ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred'],
                (-4, 4),
            )
        else:
            CB_colors_customized, CB_minmax = (
                ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen'],
                (-15, 15),
            )
        plot_metric_maps(
            metric_dfs_rda,
            model_names=figure_model_names,
            gt_ids=[gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=CB_colors_customized,
            CB_minmax=CB_minmax,
            source_data=figure_source_data,
            source_data_filename=figure_source_data_filename,
        )

# Cohort Shapley Figures: Figures 5-7
## Run Cohort Shapley analysis

In [None]:
# Step 1: Specify cs_task and models to explain
#
# Read input arguments from COMPARE_MODELS_* environment variables when set,
# otherwise use the defaults below
#
cs_gt_id = os.environ.get("COMPARE_MODELS_gt_id", "us_precip_1.5x1.5")
cs_horizon = os.environ.get("COMPARE_MODELS_horizon", "34w")
cs_target_dates = os.environ.get("COMPARE_MODELS_target_dates", "std_paper_forecast")
cs_metric = os.environ.get("COMPARE_MODELS_metric", "skill")
cs_task = f"{cs_gt_id}_{cs_horizon}"
# Human-readable task name, e.g. "precipitation, weeks 3-4". Every horizon
# suffix maps to the same ", weeks a-b" form.
cs_task_long = (cs_task
                .replace('us_precip_1.5x1.5_', 'precipitation')
                .replace('us_tmp2m_1.5x1.5_', 'temperature')
                .replace('12w', ', weeks 1-2')
                .replace('34w', ', weeks 3-4')
                .replace('56w', ', weeks 5-6'))


# This code will explain the metrics of cs_model if cs_model2 is None
# and will explain the difference between cs_model and cs_model2 metrics
# if cs_model2 is not None (set COMPARE_MODELS_model2="None" for the former)
cs_model = os.environ.get("COMPARE_MODELS_model", "abc_ecmwf")
cs_model2 = os.environ.get("COMPARE_MODELS_model2", "deb_ecmwf")
if cs_model2 == "None":
    cs_model2 = None
# Model identifier used in result directory and file names
cs_model_str = cs_model if cs_model2 is None else f"{cs_model}-vs-{cs_model2}"
# Prepare figure output directories for binned-impact figures and date maps
bin_fig_dir = os.path.join(out_dir, "bin_figs")
date_fig_dir = os.path.join(out_dir, "date_maps")
make_directories(bin_fig_dir)
make_directories(date_fig_dir)
# Measurement variable name ('tmp2m' or 'precip') for this ground truth,
# plus the ground-truth and climatology column names derived from it
measurement_variable = get_measurement_variable(cs_gt_id)
gt_col, clim_col = measurement_variable, f"{measurement_variable}_clim"
# Save figure source data here
fig_dir = os.path.join(out_dir, "source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
    
    
# Step 2: Load the outcome to be explained: the per-target-date cs_metric of
# cs_model, or cs_model minus cs_model2 when a comparison model is given.
# NOTE: this rebinds `metrics`, which earlier in the notebook held the list of
# metric names; later cells use it as this DataFrame.
metrics = pd.read_hdf(get_metric_filename(
    model=cs_model, gt_id=cs_gt_id, horizon=cs_horizon,
    target_dates=cs_target_dates, metric=cs_metric)).set_index('start_date')
if cs_model2 is None:
    outcome = metrics
else:
    metrics2 = pd.read_hdf(get_metric_filename(
        model=cs_model2, gt_id=cs_gt_id, horizon=cs_horizon,
        target_dates=cs_target_dates, metric=cs_metric)).set_index('start_date')
    outcome = metrics - metrics2
# Discard target dates with missing values
outcome = outcome.dropna()


# Step 3: Load explanatory features, each frame indexed by start_date
# Continuous (anomaly) features
file_id = 'date_anom_data'
cols_to_load = ['start_date', *continuous_feature_names(cs_gt_id, cs_horizon)]
continuous = (data_loaders
              .load_combined_data(file_id, cs_gt_id.replace("_1.5x1.5", ""), cs_horizon,
                                  columns=cols_to_load)
              .set_index('start_date'))
# Discrete features
cols_to_load = ['start_date', *discrete_feature_names(cs_gt_id, cs_horizon)]
discrete = (data_loaders
            .load_combined_data('date_data', cs_gt_id.replace("_1.5x1.5", ""), cs_horizon,
                                columns=cols_to_load)
            .set_index('start_date'))
# Calendar month of each start date as an extra discrete feature
discrete = discrete.assign(month=discrete.index.month)

# Step 4: Left-join features onto the outcome by start_date so outcome,
# continuous, and discrete data share a common index
data = (outcome
        .merge(continuous, how='left', left_index=True, right_index=True)
        .merge(discrete, how='left', left_index=True, right_index=True))
# Response y (the explained metric) and explanatory features X
y = data[cs_metric]
X = data.loc[:, continuous.columns.append(discrete.columns)]

# Step 5: Assess global variable importance with Shapley effects (a.k.a. Variance Shapley)
# - Song et al., "Shapley effects for global sensitivity analysis: Theory and computation" https://epubs.siam.org/doi/abs/10.1137/15M1048070

# Continuous features are first replaced by their quantile bins
quantiles = 10
X_q = X.copy()
for col in continuous.columns:
    X_q[col] = pd.qcut(X[col], q=quantiles)

tic()
vs_values = vs.VarianceShapley(y.values, X_q.values)
toc()

# Feature indices ordered from largest to smallest Shapley effect
order = np.flip(np.argsort(vs_values))


# Step 6: Compute or load Cohort Shapley impacts for local explanation
# - Mase et al., "Explaining black box decisions by Shapley cohort refinement" https://arxiv.org/pdf/1911.00467.pdf

# Choose similarity type
# If quantiles is not None, replace continuous values with their quantile bins
# and declare points similar if they match exactly
# If quantiles is None, use similarity.similar_in_distance_cutoff with
# specified similarity.ratio to determine similarity
quantiles = 10
similarity.ratio = 0.1
if quantiles is None:
    sim_func = similarity.similar_in_distance_cutoff
    features = X.values
else:
    sim_func = similarity.similar_in_unity
    X_q = X.copy()
    for col in continuous.columns:
        X_q[col] = pd.qcut(data[col], q=quantiles)
    features = X_q.values

# Initialize Cohort Shapley object
cs_obj = cs.CohortShapley(
    None, sim_func,
    np.arange(len(y)), features, y=y.values, parallel=16)

# Prepare results directory (cs_model_str is set in Step 1 of this cell)
results_dir = os.path.join("eval", "cohort_shapley", cs_model_str)
make_directories(results_dir)

# Construct results file name; the suffix records the similarity setting
result_file = os.path.join(results_dir, f'{cs_model_str}-{cs_metric}-{cs_gt_id}_{cs_horizon}-{cs_target_dates}')
result_file += f'-ratio{similarity.ratio}' if quantiles is None else f'-q{quantiles}'
print(result_file)

try:
    # Load previously saved results if available
    cs_obj.load(result_file)
except Exception as load_err:
    # Otherwise, compute Cohort Shapley from scratch and save to disk.
    # `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
    # stop the notebook instead of silently triggering a long recomputation.
    printf(f"Could not load Cohort Shapley results from {result_file}: {load_err!r}")
    printf(f"Computing Cohort Shapley results")
    cs_obj.compute_cohort_shapley()
    printf(f"Saving Cohort Shapley results to {result_file}")
    cs_obj.save(result_file)
    # Ensure saved files have full read and write permissions
    set_file_permissions(result_file+".cs.npy", mode=0o666)
    set_file_permissions(result_file+".cs2.npy", mode=0o666)
else:
    printf(f"Loaded Cohort Shapley results from {result_file}")

# Store results as dataframe: one row per target date, one column per feature
df_cs = pd.DataFrame(cs_obj.shapley_values, index=X.index, columns=X.columns)


# Step 7: Visualize individual forecasts on which each feature had the greatest impact
# Identify dates with largest impacts, visiting features in decreasing order
# of their Shapley effect
features = X.columns[order]
dates_largest_impact, dates_largest_metric = {}, {}

for feature in features:
    display(Markdown(f"#### {feature}:"))

    # Produce confidence interval for the probability of positive impact
    # within each feature quantile or bin
    cis_feature = (df_cs[feature] > 0).groupby(X_q[feature]).apply(
        lambda x: proportion.proportion_confint(x.sum(), x.size))
    num_bins = len(cis_feature)

    # Identify the highest impact bins (those within confidence interval of bin
    # with overall highest impact_level)
    high_impact_bins = get_high_impact_bins(feature, cis_feature, num_bins)

    # Identify the largest impact amongst all high impact bin forecasts
    largest_impact = -np.inf
    df = data[feature].to_frame()
    for feature_edges in high_impact_bins:
        # Identify dates whose feature value falls in this bin
        if feature in continuous.columns:
            dmin, dmax = feature_edges.left, feature_edges.right
            # pd.qcut bins are right-closed intervals (dmin, dmax], so values
            # equal to the upper edge belong to this bin
            dsel = df[(df[feature] > dmin) & (df[feature] <= dmax)]
            label_str = f"between {dmin} and {dmax}"
        else:
            dm = feature_edges
            dsel = df[df[feature] == dm]
            label_str = f"{dm}"

        # Impacts of the selected dates, largest first
        dsel_cs = df_cs[feature][df_cs.index.isin(dsel.index)]
        if dsel_cs.empty:
            # No forecasts fall in this bin; nothing to compare
            continue
        dsel_cs = dsel_cs.sort_values(ascending=False)
        # Positional access via .iloc: the index holds dates, so integer
        # Series[...] lookups are deprecated (and label-based in newer pandas)
        top_impact = dsel_cs.iloc[0]
        if top_impact > largest_impact:
            largest_impact = top_impact
            dsel_date = datetime.strftime(dsel_cs.index[0], '%Y%m%d')
            dates_largest_impact[feature] = dsel_date
            printf(f"Date with largest impact ({top_impact:.2}), where {feature} is {label_str}: {dsel_date}")
            
# Step 8 (disabled): regenerate lat/lon anomaly metrics for the dates found in
# Step 7 by shelling out to bulk_batch_metrics.py. Change the guard to True to run.
if False:
    target_ranges = sorted(set(dates_largest_impact.values()) | set(dates_largest_metric.values()))
    figure_metrics = ['lat_lon_anom']  # optionally also 'lat_lon_skill'
    cs_model_names = 'gt '
    for m in ['ecmwf']:  # optionally also 'cfsv2'
        if cs_model2 is None:
            cs_model_names += f'abc_{m} '
        elif cs_model2.startswith(("raw_", "deb_")):
            # cs_model2[:4] is the "raw_" or "deb_" prefix
            cs_model_names += f'{cs_model2[:4]}{m} abc_{m} '
    subprocess.call('rm -f jobs_metrics.out', shell=True)
    batch_script = os.path.join('src', 'eval', 'bulk_batch_metrics.py')
    for figure_metric in figure_metrics:
        print(figure_metric)
        for target_range in target_ranges:
            print(target_range)
            cmd = f"python {batch_script} -mn {cs_model_names} -t {target_range} -m {figure_metric} >> jobs_metrics.out"
            print(cmd)
            subprocess.call(cmd, shell=True)
            
# Explanatory features to visualize, ranked by decreasing Shapley effect
features = list(X.columns[order])

# The continuous subset of those features, in the same ranking
features_continuous = [name for name in features if name in continuous.columns]

## Figure 5: HGT_500 impact

#### Figure 5 a. Visualize the deciles of hgt_500_pc1

In [None]:
# Figure 5a parameters: explain cs_model (or cs_model vs. cs_model2) on the
# Cohort Shapley task, for the hgt_500 PC1 feature
figure_gt_id, figure_horizon = cs_gt_id, cs_horizon
figure_model, figure_model2 = cs_model, cs_model2
figure_target_dates = cs_target_dates
figure_feature = "hgt_500_anom_2010_1_shift30"
figure_source_data = True
figure_source_data_filename = "fig_5-impact_hgt_500_pc1.xlsx"
figure_show = True

# NOTE(review): num_bins is the value Step 7 computed for the *last* feature in
# its loop, not necessarily figure_feature; confirm plot_lat_lon_mat_all expects that.
plot_lat_lon_mat_all(
    df_cs, X_q, quantiles, num_bins,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    feature=figure_feature,
    model=figure_model,
    model2=figure_model2,
    target_dates=figure_target_dates,
    source_data=figure_source_data,
    source_data_filename=figure_source_data_filename,
    show=figure_show,
)

#### Figure 5 b. Visualize anomalies for date with largest impact

In [None]:
# Figure 5b parameters: anomaly maps on the date where the hgt_500 PC1 feature
# had the largest Cohort Shapley impact
figure_feature = "hgt_500_anom_2010_1_shift30"
figure_gt_ids = [cs_gt_id]
figure_horizons = [cs_horizon]
figure_metrics = ['lat_lon_anom']
figure_target_dates = dates_largest_impact[figure_feature]
figure_model = cs_model
figure_model2 = cs_model2
figure_models = ["gt", figure_model] if figure_model2 is None else ['gt', figure_model2, figure_model]
figure_source_data = True
figure_source_data_filename = "fig_5-impact_hgt_500_pc1.xlsx"
figure_show = False

# mean_metric_df: the cs_metric value of each model on the selected date.
# Index with cs_metric (the column loaded into `metrics`/`metrics2` in Step 2),
# not the `metric` loop variable left over from the all-metrics cell, which
# is undefined if that cell was skipped and wrong whenever cs_metric differs.
target_datetime = datetime.strptime(figure_target_dates, '%Y%m%d')
mean_metric_df = pd.DataFrame()
mean_metric_df[figure_model] = metrics.loc[metrics.index == target_datetime, cs_metric].values
if figure_model2 is not None:
    mean_metric_df[figure_model2] = metrics2.loc[metrics2.index == target_datetime, cs_metric].values
if cs_metric == 'skill':
    # Convert to a percentage
    mean_metric_df = mean_metric_df.apply(lambda x: x*100)
figure_mean_metric_df = mean_metric_df

# Per-task lat/lon metrics for ground truth and the explained model(s)
metric_dfs_rda = {}
for gt_id, horizon in product(figure_gt_ids, figure_horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))

    metric_dfs_rda[task] = get_models_metric_lat_lon(gt_id=gt_id,
                                                     horizon=horizon,
                                                     target_dates=figure_target_dates,
                                                     metrics=figure_metrics,
                                                     model_names=figure_models)

plot_metric_maps_date(metric_dfs_rda,
                      dates_largest_impact,
                      df_cs,
                      X_q,
                      model_names=figure_models,
                      gt_ids=figure_gt_ids,
                      horizons=figure_horizons,
                      metric='lat_lon_anom',
                      target_dates=figure_target_dates,
                      mean_metric_df=figure_mean_metric_df,
                      scale_type="linear",
                      CB_colors_customized=['orangered','darkorange',"white",'forestgreen','darkgreen'],
                      CB_minmax=(-20, 20),
                      feature=figure_feature,
                      source_data=figure_source_data,
                      source_data_filename=figure_source_data_filename,
                      show=figure_show)


## Figure 6: MJO_phase impact

In [None]:
# Figure 6 parameters: impact of the MJO phase feature on cs_model
# (or on cs_model vs. cs_model2)
figure_gt_id, figure_horizon = cs_gt_id, cs_horizon
figure_target_dates = cs_target_dates
figure_model, figure_model2 = cs_model, cs_model2
figure_show = False
figure_feature = "phase_shift17"
figure_source_data = True
figure_source_data_filename = "fig_6-impact_mjo_phase.xlsx"

plot_mjo_impact(
    df_cs, X_q,
    model=figure_model,
    model2=figure_model2,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    target_dates=figure_target_dates,
    feature=figure_feature,
    source_data=figure_source_data,
    source_data_filename=figure_source_data_filename,
    show=figure_show,
)


#### Figure 6 b. Visualize anomalies for date with largest impact

In [None]:
# Figure 6b parameters: anomaly maps on the date where the MJO phase feature
# had the largest Cohort Shapley impact
figure_gt_ids = [cs_gt_id]
figure_horizons = [cs_horizon]
figure_metrics = ['lat_lon_anom']
figure_feature = "phase_shift17"
figure_target_dates = dates_largest_impact[figure_feature]
figure_model = cs_model
figure_model2 = cs_model2
figure_show = False
figure_models = ["gt", cs_model] if cs_model2 is None else ['gt', cs_model2, cs_model]
figure_source_data = True
figure_source_data_filename = "fig_6-impact_mjo_phase.xlsx"


# mean_metric_df: the cs_metric value of each model on the selected date.
# Index with cs_metric (the column loaded into `metrics`/`metrics2` in Step 2),
# not the `metric` loop variable left over from the all-metrics cell, which
# is undefined if that cell was skipped and wrong whenever cs_metric differs.
target_datetime = datetime.strptime(figure_target_dates, '%Y%m%d')
mean_metric_df = pd.DataFrame()
mean_metric_df[figure_model] = metrics.loc[metrics.index == target_datetime, cs_metric].values
if figure_model2 is not None:
    mean_metric_df[figure_model2] = metrics2.loc[metrics2.index == target_datetime, cs_metric].values
if cs_metric == 'skill':
    # Convert to a percentage
    mean_metric_df = mean_metric_df.apply(lambda x: x*100)
figure_mean_metric_df = mean_metric_df

# Per-task lat/lon metrics for ground truth and the explained model(s)
metric_dfs_rda = {}
for gt_id, horizon in product(figure_gt_ids, figure_horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(gt_id=gt_id,
                                                     horizon=horizon,
                                                     target_dates=figure_target_dates,
                                                     metrics=figure_metrics,
                                                     model_names=figure_models)
plot_metric_maps_date(metric_dfs_rda,
                      dates_largest_impact,
                      df_cs,
                      X_q,
                      model_names=figure_models,
                      gt_ids=figure_gt_ids,
                      horizons=figure_horizons,
                      metric='lat_lon_anom',
                      target_dates=figure_target_dates,
                      mean_metric_df=figure_mean_metric_df,
                      scale_type="linear",
                      CB_colors_customized=['orangered','darkorange',"white",'forestgreen','darkgreen'],
                      CB_minmax=(-20, 20),
                      feature=figure_feature,
                      source_data=figure_source_data,
                      source_data_filename=figure_source_data_filename,
                      show=figure_show)



## Figure 7: Opportunistic ABC

In [None]:
# Figure 7 parameter values: opportunistic ABC over the Cohort Shapley target dates
figure_feature = "phase_shift17"
figure_target_dates = cs_target_dates  # alternative: dates_largest_impact[figure_feature]
figure_gt_id = cs_gt_id
figure_horizons = [cs_horizon]
figure_metric = cs_metric
figure_model, figure_model2 = cs_model, cs_model2
figure_models = ["gt", cs_model] if cs_model2 is None else ['gt', cs_model2, cs_model]
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_7-windows_opportunity_abc.xlsx"

plot_opportunistic_abc(
    X=X,
    X_q=X_q,
    df_cs=df_cs,
    y=y,
    metrics=metrics,
    metrics2=metrics2,
    order=order,
    model=figure_model,
    model2=figure_model2,
    gt_id=figure_gt_id,
    horizon=figure_horizons[0],
    target_dates=figure_target_dates,
    metric=figure_metric,
    show=figure_show,
    source_data=figure_source_data,
    source_data_filename=figure_source_data_filename,
)


## Figure 8: ABC flowchart

In [None]:
# Figure 8 is a pre-made flowchart PDF; report where it is, or warn that it is
# missing instead of silently producing no output.
figure_filename = os.path.join(out_dir, "abco", "abc_flowchart.pdf")
if os.path.isfile(figure_filename):
    printf(f"Figure is saved in:\n {figure_filename}")
else:
    printf(f"Warning: figure not found at:\n {figure_filename}")

# Supplementary Figures
## Figure S1: Average forecast skill bar plots - Raw vs. ABC models 
### Skill barplots per season 

In [None]:
# Set figure parameters: average skill of raw vs. ABC models, per season (quarter)
figure_models = [
    # Raw models
    "raw_ecmwf",
    "raw_ccsm4",
    "raw_cfsv2",
    "raw_geos_v2p1",
    "raw_nesm",
    "raw_fimr1p1",
    "raw_gefs",
    "raw_gem",
    "raw_subx_mean",

    # ABC models
    "abc_ecmwf",
    "abc_ccsm4",
    "abc_cfsv2",
    "abc_geos_v2p1",
    "abc_nesm",
    "abc_fimr1p1",
    "abc_gefs",
    "abc_gem",
    "abc_subx_mean"]

figure_gt_ids = us_1_5_gt_ids
figure_horizons = ["34w", "56w"]
figure_metric = 'skill'
figure_target_dates = 'std_paper_forecast'
figure_quarters = ["DJF", "MAM", "JJA", "SON"]
figure_show = False
# Same naming as the other figure cells; the flag is now actually passed on
# (previously `figure_source` was defined but never used).
figure_source_data = True
figure_source_data_filename = "fig_s1-average_skill_season.xlsx"


for figure_gt_id, figure_horizon, figure_quarter in product(figure_gt_ids,
                                                            figure_horizons,
                                                            figure_quarters):
    display(Markdown(f"#### {figure_gt_id} {figure_horizon} {figure_quarter}:"))
    df_barplot = barplot_rawabc_quarterly(model_names=figure_models,
                                          gt_id=figure_gt_id,
                                          horizon=figure_horizon,
                                          metric=figure_metric,
                                          target_dates=figure_target_dates,
                                          quarter=figure_quarter,
                                          show=figure_show,
                                          source_data=figure_source_data,
                                          source_data_filename=figure_source_data_filename)

## Figure S2: Baselines vs. ABC models
#### Skill barplots 

In [None]:
# Set figure parameters: ABC models against the debiasing baselines
figure_models = [
    "abc_ecmwf",            # Ensembles
    "abc_cfsv2",
    "nn-a",                 # Baselines
    "deb_loess_ecmwf",
    "deb_loess_cfsv2",
    "deb_quantile_ecmwf",
    "deb_quantile_cfsv2",
]
figure_metric = 'skill'
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = ["34w", "56w"]
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s2-average_skill_baselines.xlsx"

for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### {figure_gt_id} {figure_horizon}:"))
    df_barplot = barplot_baselineabc(
        model_names=figure_models,
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        metric=figure_metric,
        target_dates=figure_target_dates,
        show=figure_show,
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
    )


## Figure S3: Spatial skill distribution plots - Raw vs. Deb vs. ABC models 
#### (SubX mean lat_lon_skill maps)

In [None]:
# Load lat/lon skill and overall skill for the SubX-mean raw, debiased and ABC models
figure_model_names = [
    "raw_subx_mean",   # Baselines
    "deb_subx_mean",   # Standard de-biasing
    "abc_subx_mean",   # Ensembles
]
figure_metrics = ['lat_lon_skill', 'skill']
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons

metric_dfs_rda = {}
for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### Getting metrics for {figure_gt_id} {figure_horizon}"))
    task = f"{figure_gt_id}_{figure_horizon}"
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )
    


In [None]:
# Set figure parameters: SubX-mean lat/lon skill maps (raw vs. deb vs. ABC)
figure_model_names = [
    "raw_subx_mean",   # Baselines
    "deb_subx_mean",   # Standard de-biasing
    "abc_subx_mean",   # Ensembles
]
figure_metric = 'lat_lon_skill'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
# Per-task metric tables from the previous cell double as the mean-metric source
figure_mean_metric_df = metric_dfs_rda  # or None
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s3-spatial_skill_subx_mean.xlsx"


for figure_gt_id in figure_gt_ids:
    display(Markdown(f"#### Processing {figure_gt_id}:"))
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[figure_gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        CB_colors_customized=["white", "#dede00", "#ff7f00", "blueviolet", "indigo",
                              "yellowgreen", "lightgreen", "darkgreen"],
        CB_minmax=(0, 90),
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
    )

## Figure S4: Baselines - Raw vs. Deb vs. ABC models 
#### (Baseline lat_lon_skill maps)

In [None]:
# Load lat/lon skill and overall skill for the baselines and the ABC models
figure_model_names = [
    "nn-a",                 # Baselines
    "deb_loess_ecmwf",
    "deb_loess_cfsv2",
    "deb_quantile_ecmwf",
    "deb_quantile_cfsv2",
    "abc_ecmwf",            # Ensembles
    "abc_cfsv2",
]
figure_metrics = ['lat_lon_skill', 'skill']
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons

metric_dfs_rda = {}
for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### Getting metrics for {figure_gt_id} {figure_horizon}"))
    task = f"{figure_gt_id}_{figure_horizon}"
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )
    

In [None]:
# Set figure parameters: baseline vs. ABC lat/lon skill maps
figure_metric = 'lat_lon_skill'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_mean_metric_df = metric_dfs_rda  # or None
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s4-spatial_skill_baselines.xlsx"
figure_cb_colors = ["white", "#dede00", "#ff7f00", "blueviolet", "indigo",
                    "yellowgreen", "lightgreen", "darkgreen"]

# Each model group gets its own set of maps; the second element is the suffix
# appended to the per-gt_id heading for that group.
figure_model_groups = [
    (["nn-a", "deb_loess_cfsv2", "deb_quantile_cfsv2", "abc_cfsv2"], ":"),
    (["deb_loess_ecmwf", "deb_quantile_ecmwf", "abc_ecmwf"], ""),
]
for figure_model_names, heading_suffix in figure_model_groups:
    display(Markdown(f'#### Models: {", ".join(figure_model_names)}'))
    for figure_gt_id in figure_gt_ids:
        display(Markdown(f"#### {figure_gt_id}{heading_suffix}"))
        plot_metric_maps(
            metric_dfs_rda,
            model_names=figure_model_names,
            gt_ids=[figure_gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=figure_cb_colors,
            CB_minmax=(0, 90),
            source_data=figure_source_data,
            source_data_filename=figure_source_data_filename,
        )

## Figure S5: Fraction of grid points above skill threshold
#### (CFSv2 and ECMWF lat_lon_skill barplots)

In [None]:
# Set figure parameters: fraction of grid points above a skill threshold,
# raw vs. deb vs. ABC, for each root model
figure_root_model_names = ["cfsv2", "ecmwf"]
figure_metric = 'lat_lon_skill'
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = ['34w']
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s5-fraction_above_skill_threshold.xlsx"


for model_name_root, gt_id, horizon in product(figure_root_model_names, figure_gt_ids, figure_horizons):
    figure_model_names = [f"{prefix}_{model_name_root}" for prefix in ("raw", "deb", "abc")]
    display(Markdown(f"#### {gt_id}_{horizon}"))
    df_barplot = barplot_skillthreshold(
        model_names=figure_model_names,
        gt_id=gt_id,
        horizon=horizon,
        metric=figure_metric,
        target_dates=figure_target_dates,
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
        show=figure_show,
    )

## Figure S6: Fraction of grid points above skill threshold
#### (SubX mean lat_lon_skill barplots)

In [None]:
# Set figure parameters: fraction of grid points above a skill threshold for
# the SubX mean, over every horizon after the first
figure_root_model_names = ["subx_mean"]
figure_metric = 'lat_lon_skill'
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons[1:]
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s6-fraction_above_skill_threshold.xlsx"


for model_name_root, gt_id, horizon in product(figure_root_model_names, figure_gt_ids, figure_horizons):
    figure_model_names = [f"{prefix}_{model_name_root}" for prefix in ("raw", "deb", "abc")]
    display(Markdown(f"#### {gt_id}_{horizon}"))
    df_barplot = barplot_skillthreshold(
        model_names=figure_model_names,
        gt_id=gt_id,
        horizon=horizon,
        metric=figure_metric,
        target_dates=figure_target_dates,
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
        show=figure_show,
    )

# Probabilistic Evaluation Plots: Figures S7-S10
#### Save bss and lat_lon_bss metric dataframes

In [None]:
def _load_p1_p3_mse(model, gt_id, horizon, target_dates, metric, keys):
    """Load the p1 and p3 versions of an MSE-type metric for ``model`` and combine them.

    The p1 and p3 metric files are joined on ``keys`` (not by row position),
    so a target date or grid point that is missing from one file cannot shift
    the values of the other.

    Args:
        model: model name whose metrics directory is looked up.
        gt_id: ground-truth id; only used to choose "tmp2m" vs "precip".
        horizon: forecast horizon string, e.g. "56w".
        target_dates: target-date set name used in the metric file names.
        metric: metric file prefix / column name ("mse" or "lat_lon_mse").
        keys: columns identifying a row in both files
            (e.g. ["start_date"] or ["lat", "lon"]).

    Returns:
        The p1 table with ``metric`` renamed to ``mse_p1``, plus ``mse_p3``
        (NaN where the p3 file has no matching row) and
        ``rps = mse_p1 + mse_p3``.
    """
    measurement_var = "tmp2m" if "tmp2m" in gt_id else "precip"
    gt_id_p1 = f"us_{measurement_var}_p1_1.5x1.5"
    gt_id_p3 = f"us_{measurement_var}_p3_1.5x1.5"
    task_p1 = f"{gt_id_p1}_{horizon}"
    task_p3 = f"{gt_id_p3}_{horizon}"
    metrics_dir = get_task_metrics_dir(model=model, submodel=None, gt_id=gt_id_p1,
                                       horizon=horizon, target_dates=target_dates)
    # The p3 metrics live in the directory obtained by swapping the p1 tag for p3
    metrics_dir_p3 = metrics_dir.replace("_p1", "_p3")

    model_metrics = pd.read_hdf(os.path.join(metrics_dir, f"{metric}-{task_p1}-{target_dates}.h5"))
    model_metrics = model_metrics.rename(columns={metric: "mse_p1"})
    p3_metrics = pd.read_hdf(os.path.join(metrics_dir_p3, f"{metric}-{task_p3}-{target_dates}.h5"))
    p3_metrics = p3_metrics[list(keys) + [metric]].rename(columns={metric: "mse_p3"})
    # Left join keeps every p1 row in its original order; unmatched rows get NaN,
    # which callers already discard with .dropna()
    model_metrics = model_metrics.merge(p3_metrics, on=list(keys), how="left")
    # Compute model_rps = mse_p1 + mse_p3
    model_metrics["rps"] = model_metrics["mse_p1"] + model_metrics["mse_p3"]
    return model_metrics


def get_model_task_rps(model = "d2p_raw_ecmwf", gt_id = "us_tmp2m_1.5x1.5", horizon = "56w",  target_dates = 'std_paper_forecast'):
    """Return per-start_date p1/p3 MSE and their sum (rps) for ``model``.

    Columns include ``start_date``, ``mse_p1``, ``mse_p3`` and ``rps``.
    """
    return _load_p1_p3_mse(model, gt_id, horizon, target_dates,
                           metric="mse", keys=["start_date"])

def get_model_task_lat_lon_rps(model = "d2p_raw_ecmwf", gt_id = "us_tmp2m_1.5x1.5", horizon = "56w",  target_dates = 'std_paper_forecast'):
    """Return per-grid-point p1/p3 MSE and their sum (rps) for ``model``.

    Columns include ``lat``, ``lon``, ``mse_p1``, ``mse_p3`` and ``rps``.
    """
    return _load_p1_p3_mse(model, gt_id, horizon, target_dates,
                           metric="lat_lon_mse", keys=["lat", "lon"])

# One-off generation of the "bss" (per start_date) and "lat_lon_bss" (per grid
# point) metric files, saved with save_metric. It is disabled with `if False:`;
# flip the guard to regenerate them.
# BSS here is 1 - (model mse_p3) / (reference mse_p3), where the reference is a
# constant 1/3 probability forecast (presumably climatological tercile
# probabilities -- confirm).
# NOTE(review): when enabled, this block rebinds notebook globals such as
# `horizons`, `target_dates`, `gt_id`, `horizon`, `model` and `metric`, which
# later cells may read.
if False:
    # Ground-truth probability targets: one DataFrame per prob_gt_id with
    # start_date rows and (lat, lon) columns.
    gt = {}
    prob_gt_ids = ["us_tmp2m_p1_1.5x1.5", "us_tmp2m_p3_1.5x1.5", "us_precip_p1_1.5x1.5", "us_precip_p3_1.5x1.5"]
    for prob_gt_id in prob_gt_ids:
        gt[prob_gt_id] = get_ground_truth(prob_gt_id).set_index(["lat","lon","start_date"]).squeeze().unstack(["lat","lon"])

    # Save BSS metric dataframes
    target_dates = "std_paper_forecast"
    target_date_objs = get_target_dates(target_dates)
    gt_ids = ["us_tmp2m_1.5x1.5","us_precip_1.5x1.5"]
    horizons = ["12w","34w","56w"]

    # BSS by date
    for gt_id in gt_ids:
        measurement_var = get_measurement_variable(gt_id)
        p1_var = measurement_var+"_p1"
        p3_var = measurement_var+"_p3"

        # Reference-forecast MSE per start_date: squared error of a constant 1/3
        # forecast, averaged over grid points (axis=1 = the (lat, lon) columns).
        clim_rps = pd.DataFrame(data = {
            'mse_p1': np.square(gt[gt_id.replace(measurement_var, p1_var)].loc[target_date_objs] - 1./3).mean(axis=1),
            'mse_p3': np.square(gt[gt_id.replace(measurement_var, p3_var)].loc[target_date_objs] - 1./3).mean(axis=1)})
        clim_rps['rps'] = clim_rps['mse_p1']+clim_rps['mse_p3']
        for horizon in horizons:
            mets = {}
            for model in ["raw_ecmwf", "deb_ecmwf", "abcds_ecmwf", "shift_deb_loess_ecmwf", "shift_deb_quantile_ecmwf"]: 
                # Metrics are read from the "d2p_"-prefixed model but saved under the bare model name
                mets[model] = get_model_task_rps(model = f"d2p_{model}", gt_id = gt_id, horizon = horizon,  target_dates = 'std_paper_forecast').dropna().set_index("start_date")
                # Only the p3 component enters the BSS; aligned with clim_rps on start_date
                mets[model]['bss'] = 1-mets[model]['mse_p3']/clim_rps['mse_p3']
                sn = get_selected_submodel_name(model=model, gt_id=gt_id, horizon=horizon)
                metric = "bss"
                df_metric = mets[model].reset_index()[["start_date", metric]]#.rename(columns={"rpss":metric})
                save_metric(df_metric, model=model, submodel=sn,
                            gt_id=gt_id, horizon=horizon, 
                            target_dates=target_dates, metric=metric)
    # BSS by grid point
    for gt_id in gt_ids:
        measurement_var = get_measurement_variable(gt_id)
        p1_var = measurement_var+"_p1"
        p3_var = measurement_var+"_p3"

        # Reference-forecast MSE per grid point: averaged over start dates
        # (axis=0), giving a Series indexed by (lat, lon).
        clim_rps = pd.DataFrame(data = {
            'mse_p1': np.square(gt[gt_id.replace(measurement_var, p1_var)].loc[target_date_objs] - 1./3).mean(axis=0),
            'mse_p3': np.square(gt[gt_id.replace(measurement_var, p3_var)].loc[target_date_objs] - 1./3).mean(axis=0)})
        clim_rps['rps'] = clim_rps['mse_p1']+clim_rps['mse_p3']
        for horizon in horizons:
            mets = {}
            for model in ["raw_ecmwf", "deb_ecmwf", "abcds_ecmwf"]: 
                mets[model] = get_model_task_lat_lon_rps(model = f"d2p_{model}", gt_id = gt_id, horizon = horizon,  target_dates = 'std_paper_forecast').dropna().set_index(['lat','lon'])
                # Aligned with clim_rps on the (lat, lon) index
                mets[model]['bss'] = 1-mets[model]['mse_p3']/clim_rps['mse_p3']
                sn = get_selected_submodel_name(model=model, gt_id=gt_id, horizon=horizon)
                metric = "lat_lon_bss"
                df_metric = mets[model].reset_index()[["lat", "lon", "bss"]].rename(columns={"bss":metric})
                save_metric(df_metric, model=model, submodel=sn,
                            gt_id=gt_id, horizon=horizon, 
                            target_dates=target_dates, metric=metric)
    

## Figure S7: Raw vs. ABC models (BSS barplots by season)

In [None]:
# Set figure parameters: raw vs. ABC ECMWF, BSS barplots
figure_model_names = [
    # Baselines
    "raw_ecmwf",
    # Ensembles
    "abcds_ecmwf",
]
# The loop below iterates over these plural names. Binding them here keeps the
# cell independent of values left behind by earlier cells (they were previously
# bound to figure_gt_id / figure_horizon and never used).
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'bss'
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_s7-average_bss.xlsx"
figure_show = False


for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### {figure_gt_id} {figure_horizon}:"))
    df_barplot = barplot_rawabc_bss(model_names=figure_model_names,
                                    gt_id=figure_gt_id,
                                    horizon=figure_horizon,
                                    metric=figure_metric,
                                    target_dates=figure_target_dates,
                                    source_data=figure_source_data,
                                    source_data_filename=figure_source_data_filename,
                                    show=figure_show)

## Figure S8: Raw vs. ABC models (CRPS barplots)

In [None]:
# Set figure parameters: raw vs. ABC ECMWF, CRPS barplots
figure_model_names = [
    # Baselines
    "raw_ecmwf",
    # Ensembles
    "abcds_ecmwf",
]
# The loop below iterates over these plural names. Binding them here keeps the
# cell independent of values left behind by earlier cells (they were previously
# bound to figure_gt_id / figure_horizon and never used).
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'crps'
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_s8-average_crps.xlsx"
figure_show = False


for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### {figure_gt_id} {figure_horizon}:"))
    df_barplot = barplot_rawabc_crps(model_names=figure_model_names,
                                     gt_id=figure_gt_id,
                                     horizon=figure_horizon,
                                     metric=figure_metric,
                                     target_dates=figure_target_dates,
                                     source_data=figure_source_data,
                                     source_data_filename=figure_source_data_filename,
                                     show=figure_show)


## Figure S9: Baselines vs. ABC models (BSS barplots by season)

In [None]:
# Set figure parameters: shifted debiasing baselines vs. ABC ECMWF, BSS barplots
figure_model_names = [
    # Baselines
    "shift_deb_quantile_ecmwf",
    "shift_deb_loess_ecmwf",
    # Ensembles
    "abcds_ecmwf",
]
# The loop below iterates over these plural names. Binding them here keeps the
# cell independent of values left behind by earlier cells (they were previously
# bound to figure_gt_id / figure_horizon and never used).
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'bss'
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_s9-average_bss_baselines.xlsx"
figure_show = False


# Only the first gt_id and the first horizon are plotted
for figure_gt_id, figure_horizon in product(figure_gt_ids[:1], figure_horizons[:1]):
    display(Markdown(f"#### {figure_gt_id} {figure_horizon}:"))
    df_barplot = barplot_baselinesabc_bss(model_names=figure_model_names,
                                          gt_id=figure_gt_id,
                                          horizon=figure_horizon,
                                          metric=figure_metric,
                                          target_dates=figure_target_dates,
                                          source_data=figure_source_data,
                                          source_data_filename=figure_source_data_filename,
                                          show=figure_show)

## Figure S10: Baselines vs. ABC models (CRPS barplots by season)

In [None]:
# Set figure parameters: shifted debiasing baselines vs. ABC ECMWF, CRPS barplots
figure_model_names = [
    # Baselines
    "shift_deb_quantile_ecmwf",
    "shift_deb_loess_ecmwf",
    # Ensembles
    "abcds_ecmwf",
]
# The loop below iterates over these plural names. Binding them here keeps the
# cell independent of values left behind by earlier cells (they were previously
# bound to figure_gt_id / figure_horizon and never used).
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'crps'
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_s10-average_crps_baselines.xlsx"
figure_show = False


for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### {figure_gt_id} {figure_horizon}:"))
    df_barplot = barplot_baselinesabc_crps(model_names=figure_model_names,
                                           gt_id=figure_gt_id,
                                           horizon=figure_horizon,
                                           metric=figure_metric,
                                           target_dates=figure_target_dates,
                                           source_data=figure_source_data,
                                           source_data_filename=figure_source_data_filename,
                                           show=figure_show)

# Spatial bias plots: Figures S11-S13
## Figure S11: Raw vs. ABC models (SubX mean lat_lon_error maps)

In [None]:
# Load lat/lon error maps for the raw and ABC SubX-mean models
figure_model_names = ["raw_subx_mean", "abc_subx_mean"]
figure_metrics = ['lat_lon_error']
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons

metric_dfs_rda = {}
for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### Getting metrics for {figure_gt_id} {figure_horizon}"))
    task = f"{figure_gt_id}_{figure_horizon}"
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )

In [None]:
# Figure parameter values: raw vs. ABC SubX-mean spatial bias maps
figure_model_names = ["raw_subx_mean", "abc_subx_mean"]
figure_metric = 'lat_lon_error'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_mean_metric_df = None
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s11-spatial_bias_subx_mean.xlsx"

for figure_gt_id in figure_gt_ids:
    display(Markdown(f"#### Processing {figure_gt_id}:"))
    # Diverging colour bar; temperature and precipitation use different ranges
    if 'tmp2m' not in figure_gt_id:
        CB_colors_customized = ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
        CB_minmax = (-20, 20)
    else:
        CB_colors_customized = ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
        CB_minmax = (-4, 4)
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[figure_gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        CB_colors_customized=CB_colors_customized,
        CB_minmax=CB_minmax,
        source_data=figure_source_data,
        source_data_filename=figure_source_data_filename,
    )

## Figure S12: Baselines vs. ABC models (lat_lon_error maps)

In [None]:
# Load lat/lon error maps for the baselines and the ABC models
figure_model_names = [
    "nn-a",
    "deb_loess_cfsv2",
    "deb_quantile_cfsv2",
    "abc_cfsv2",
    "deb_loess_ecmwf",
    "deb_quantile_ecmwf",
    "abc_ecmwf",
]
figure_metrics = ['lat_lon_error']
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons

metric_dfs_rda = {}
for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### Getting metrics for {figure_gt_id} {figure_horizon}"))
    task = f"{figure_gt_id}_{figure_horizon}"
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )

In [None]:
# Set figure parameters: baseline vs. ABC spatial bias maps
figure_metric = 'lat_lon_error'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_mean_metric_df = None
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s12-spatial_bias_baselines.xlsx"

# One set of maps per model group (CFSv2-based, then ECMWF-based)
figure_model_groups = [
    ["nn-a", "deb_loess_cfsv2", "deb_quantile_cfsv2", "abc_cfsv2"],
    ["deb_loess_ecmwf", "deb_quantile_ecmwf", "abc_ecmwf"],
]
for figure_model_names in figure_model_groups:
    display(Markdown(f'#### Models: {", ".join(figure_model_names)}'))
    for figure_gt_id in figure_gt_ids:
        display(Markdown(f"#### {figure_gt_id}:"))
        # Diverging colour bar; temperature and precipitation use different ranges
        if 'tmp2m' not in figure_gt_id:
            CB_colors_customized = ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
            CB_minmax = (-20, 20)
        else:
            CB_colors_customized = ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
            CB_minmax = (-4, 4)
        plot_metric_maps(
            metric_dfs_rda,
            model_names=figure_model_names,
            gt_ids=[figure_gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=CB_colors_customized,
            CB_minmax=CB_minmax,
            source_data=figure_source_data,
            source_data_filename=figure_source_data_filename,
        )

## Figure S13: Deb vs. ABC models (lat_lon_error maps)

In [None]:
# Load lat/lon error maps for the standard-debiased and ABC models
figure_model_names = ["deb_cfsv2", "abc_cfsv2", "deb_ecmwf", "abc_ecmwf"]
figure_metrics = ['lat_lon_error']
figure_target_dates = 'std_paper_forecast'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons

metric_dfs_rda = {}
for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    display(Markdown(f"#### Getting metrics for {figure_gt_id} {figure_horizon}"))
    task = f"{figure_gt_id}_{figure_horizon}"
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )

In [None]:
# Set figure parameters: deb vs. ABC spatial bias maps (narrower colour ranges)
figure_metric = 'lat_lon_error'
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_mean_metric_df = None
figure_show = False
figure_source_data = True
figure_source_data_filename = "fig_s13-spatial_bias_deb.xlsx"

# One set of maps per model pair (CFSv2, then ECMWF)
figure_model_groups = [
    ["deb_cfsv2", "abc_cfsv2"],
    ["deb_ecmwf", "abc_ecmwf"],
]
for figure_model_names in figure_model_groups:
    display(Markdown(f'#### Models: {", ".join(figure_model_names)}'))
    for figure_gt_id in figure_gt_ids:
        display(Markdown(f"#### {figure_gt_id}:"))
        # Diverging colour bar; temperature and precipitation use different ranges
        if 'tmp2m' not in figure_gt_id:
            CB_colors_customized = ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
            CB_minmax = (-5, 5)
        else:
            CB_colors_customized = ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
            CB_minmax = (-1, 1)
        plot_metric_maps(
            metric_dfs_rda,
            model_names=figure_model_names,
            gt_ids=[figure_gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=CB_colors_customized,
            CB_minmax=CB_minmax,
            source_data=figure_source_data,
            source_data_filename=figure_source_data_filename,
        )

# Overall variable importance

## Figure S14: Variable importance

In [None]:
# Variable importance plot. Run the Cohort Shapley analysis cells first:
# X, vs_values, order, cs_model and cs_model2 are not defined in this cell
# and presumably come from those cells.
figure_gt_id = "us_precip_1.5x1.5"
figure_horizon = "34w"
figure_model = cs_model
figure_model2 = cs_model2
# NOTE(review): figure_target_dates is reset here but is not passed to
# plot_variable_importance below; it may only matter to later cells.
figure_target_dates = 'std_paper_forecast'
figure_source_data = True
figure_source_data_filename = "fig_s14-variable_importance.xlsx"
figure_show = False

plot_variable_importance(
    X=X,
    vs_values=vs_values,
    order=order,
    model=figure_model,
    model2=figure_model2,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    source_data=figure_source_data,
    source_data_filename=figure_source_data_filename,
    show=figure_show,
)

# ABC schematic 

## Figure S15: Schematic of ABC input-output (lat_lon_anom maps)

In [None]:
# Gather per-grid-point anomalies ('lat_lon_anom') for the schematic maps.
# Here figure_target_dates is the single date string '20201218' rather than
# a named date set. The list includes raw, debiased, ABC and post-processed
# models plus "gt". Results go into metric_dfs_rda keyed "<gt_id>_<horizon>",
# replacing whatever an earlier cell stored there.
figure_target_dates = '20201218'
figure_metrics = ['lat_lon_anom']
figure_gt_ids = ["us_precip_1.5x1.5"]
figure_horizons = horizons
figure_model_names = [
    "raw_cfsv2",
    "deb_cfsv2",
    "abc_cfsv2",
    "tuned_cfsv2pp",
    "perpp_cfsv2",
    "tuned_climpp",
    "gt",
]

metric_dfs_rda = {}
for figure_gt_id, figure_horizon in product(figure_gt_ids, figure_horizons):
    task = f"{figure_gt_id}_{figure_horizon}"
    display(Markdown(f"#### Getting metrics for {figure_gt_id} {figure_horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=figure_gt_id,
        horizon=figure_horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )


In [None]:
# Set figure parameters for the schematic anomaly maps (lat_lon_anom),
# using the metrics collected in metric_dfs_rda by the previous cell.
figure_gt_ids = ["us_precip_1.5x1.5"]
figure_horizons = horizons
figure_metric = 'lat_lon_anom'
figure_mean_metric_df = None
figure_source_data = True
figure_source_data_filename = "fig_s15-schematic_abc_anoms.xlsx"
figure_show = False
figure_model_names = ["raw_cfsv2",
                      "deb_cfsv2",
                      "abc_cfsv2",
                      "tuned_cfsv2pp",
                      "perpp_cfsv2",
                      "tuned_climpp",
                      "gt",
                      ]

display(Markdown(f'#### Models: {", ".join(figure_model_names)}'))
for figure_gt_id in figure_gt_ids:
    display(Markdown(f"#### {figure_gt_id}:"))
    # Symmetric colour range with "white" at the centre of the palette.
    # The range is wider than the bias maps because these are anomalies.
    # Only the precip branch is reached with the gt_ids set above.
    if 'tmp2m' in figure_gt_id:
        CB_colors_customized = ['blue', 'dodgerblue', 'lightskyblue', "white",
                                'red', 'firebrick', 'darkred']
        CB_minmax = (-4, 4)
    else:
        CB_colors_customized = ['orangered', 'darkorange', "white",
                                'forestgreen', 'darkgreen']
        CB_minmax = (-10, 10)
    plot_metric_maps(metric_dfs_rda,
                     model_names=figure_model_names,
                     gt_ids=[figure_gt_id],
                     horizons=figure_horizons,
                     metric=figure_metric,
                     target_dates=figure_target_dates,
                     mean_metric_df=figure_mean_metric_df,
                     show=figure_show,
                     scale_type='linear',
                     CB_colors_customized=CB_colors_customized,
                     CB_minmax=CB_minmax,
                     source_data=figure_source_data,
                     source_data_filename=figure_source_data_filename)

# Report where the final schematic PDF lives, or say explicitly when it is
# missing instead of silently printing nothing.
# NOTE(review): this checks existence only, so a PDF left over from an
# earlier run would also satisfy it.
figure_filename = os.path.join(out_dir, "abco", "abc_schematic.pdf")
if os.path.isfile(figure_filename):
    printf(f"\nFinal figure saved in:\n {figure_filename}")
else:
    printf(f"\nWarning: final figure not found at:\n {figure_filename}")