# ABC Plots

In [None]:
""" 
Generates ABC result figures for 

Adaptive Bias Correction for Improved Subseasonal Forecasting

Soukayna Mouatadid, Paulo Orenstein, Genevieve Flaspohler, 
Judah Cohen, Miruna Oprescu, Ernest Fraenkel, and Lester Mackey. 
"""
# Move into the base repository checkout so the relative paths used throughout
# this notebook (e.g. "eval/metrics/...", "subseasonal_toolkit/viz/...") resolve.
# Failure is non-fatal: a warning is printed and the current directory is kept.
import os, sys
try:
    os.chdir(f"/home/{os.environ['USER']}/forecast_rodeo_ii/")
except Exception as err:
    print("Warning: unable to change directory; " + repr(err))
    
%load_ext autoreload
%autoreload 2
%matplotlib inline    
    
import itertools
import importlib
import subprocess
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import product
from functools import partial
from datetime import datetime
from IPython.display import Markdown, display

import os
import copy
import pdb
import calendar 
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib   
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap

from subseasonal_toolkit.utils.experiments_util import pandas2hdf, get_measurement_variable
from subseasonal_toolkit.utils.general_util import printf, make_directories
from subseasonal_toolkit.utils.eval_util import (get_target_dates, score_to_mean_rmse, 
                                                contest_quarter_start_dates, contest_quarter, year_quarter,
                                                get_task_metrics_dir, save_metric)


from subseasonal_toolkit.utils.models_util import get_selected_submodel_name
from subseasonal_data.data_loaders import get_ground_truth, get_climatology

from viz_util_abc import get_metrics_df, plot_metric_maps, get_models_metric_lat_lon, plot_metric_maps, add_groupby_cols, all_model_names



# Notebook-wide seaborn defaults: 8x6-inch figures at the base font scale.
sns.set(font_scale=1, rc={'figure.figsize': (8, 6)})

# Output directory for saved figures, inside the user's forecast_rodeo_ii checkout.
out_dir = f"/home/{os.environ['USER']}/forecast_rodeo_ii/subseasonal_toolkit/viz"


In [None]:
#
# Full set of regions, times, and tasks to evaluate
#
metrics = ["rmse", "skill", "score"]

contest_gt_ids = ["contest_tmp2m", "contest_precip"]
us_gt_ids = ["us_tmp2m", "us_precip"]
east_gt_ids = ["east_tmp2m", "east_precip"]
us_1_5_gt_ids = ["us_tmp2m_1.5x1.5", "us_precip_1.5x1.5"]

# All ground truth ids (contest and U.S.; east and 1.5x1.5 ids are listed separately)
gt_ids = contest_gt_ids + us_gt_ids

horizons = ["12w", "34w", "56w"]
target_eval_dates = ["std_paper", "std_contest"]

# The full set of models we evaluate in at least one experiment.
# NOTE: every entry needs its trailing comma -- without it Python silently
# concatenates adjacent string literals (e.g. 'ecmwf' 'tuned_climpp' would
# become the single bogus name 'ecmwftuned_climpp').
all_models = [
    # Raw Baselines
    'raw_cfsv2',
    # Baselines
    "climatology",
    'deb_cfsv2',
    'persistence',
    # ECMWF
    'ecmwf',
    # Toolkit
    'tuned_climpp',
    'tuned_cfsv2pp',
    'perpp',
    # Learning
    'autoknn',
    'informer',
    'tuned_localboosting',
    'multillr',
    'nbeats',
    'prophet',
    'salient',
    'tuned_salient2',
    # Ensembles
    'linear_ensemble',
    'online_learning',
]

# Main experiment model names
main_experiment_models = [
    # Baselines
    "climatology", "deb_cfsv2", "persistence",
    # Toolkit
    "tuned_climpp", "tuned_cfsv2pp", "perpp",
    # Learning
    "autoknn", "tuned_localboosting", "multillr", "nbeats", "informer",
    "prophet", "tuned_salient2",
    # Ensembles
    "linear_ensemble", "online_learning",
]

# Rodeo experiment model names
rodeo_experiment_models = [
    # Baselines
    "climatology", "deb_cfsv2", "persistence",
    # Toolkit
    "tuned_climpp", "tuned_cfsv2pp", "perpp",
    # Learning
    "autoknn", "tuned_localboosting", "multillr", "prophet", "tuned_salient2",
    # Ensembles (fully specified submodel names)
    "linear_ensemble_localFalse_dynamicFalse_stepFalse_LtCtD",
    "linear_ensemble_localFalse_dynamicFalse_stepFalse_AMLPtCtDtKtS",
    "online_learning-ah_rpNone_R1_recent_g_SC_LtCtD",
    "online_learning-ah_rpNone_R1_recent_g_SC_AMLPtCtDtKtS",
]

# Salient experiment model names: baseline, toolkit and learning model respectively
salient_experiment_models = ["deb_cfsv2", "tuned_cfsv2pp", "tuned_salient2"]

# ECMWF experiment model names
ecmwf_experiment_models = [
    # Baselines
    "climatology", "deb_cfsv2", "persistence",
    # Toolkit
    "tuned_climpp", "tuned_cfsv2pp", "perpp",
    # ECMWF
    "ecmwf-years20_leads15-15_lossmse_forecastc_debiasp+c",
    "ecmwf-years20_leads15-15_lossmse_forecastp_debiasp+c",
    # Ensembles
    "online_learning", "linear_ensemble",
]

# De-biasing experiment model names: every dynamical model in its raw form,
# followed by the same models in their ABC form (same model order in both halves).
_debias_dynamical_models = ("ccsm4", "cfsv2", "geos_v2p1", "nesm",
                            "fimr1p1", "gefs", "gem", "ecmwf")
debias_experiment_models = (
    [f"raw_{name}" for name in _debias_dynamical_models]
    + [f"abc_{name}" for name in _debias_dynamical_models]
)

In [None]:
#
# Dictionaries mapping task names ("<region>_<variable>_<horizon>") to their display names
#
_task_variables = (("tmp2m", "Temp."), ("precip", "Precip."))
_task_horizons = (("34w", "3-4"), ("56w", "5-6"))

# Keys are ordered tmp2m_34w, tmp2m_56w, precip_34w, precip_56w for every region.
east_tasks, contest_tasks, us_tasks = (
    {
        f"{region}_{var}_{hz}": f"{var_label} weeks {hz_label}"
        for var, var_label in _task_variables
        for hz, hz_label in _task_horizons
    }
    for region in ("east", "contest", "us")
)


## Read in all metrics for all tasks and all models
Reads metrics, generates a summary of missing data, and produces the `all_metrics` dictionary to be used in further analysis. 

In [None]:
"""
Generate a dictionary with metric values for all models and every combination of gt_id, 
horizon, and target dates
"""

all_metrics = {}

# Get metrics for main experiment, rodeo experiment, salient experiment and ecmwf experiment
for metric, gt_id, horizon, target_dates in \
        [x for x in product(['rmse', 'skill'], us_1_5_gt_ids, horizons, ['std_paper_forecast'])]:
        #[x for x in product(['rmse', 'skill'], us_gt_ids, horizons, ['std_paper'])] \
        #+[x for x in product(['rmse'], contest_gt_ids, horizons, ['std_contest'])] \
        #+[x for x in product(['rmse'], contest_gt_ids, horizons, ['std_paper'])] \
        #+[x for x in product(['rmse', 'skill'], us_1_5_gt_ids, horizons, ['std_ecmwf'])]: 
   
    
    #Set model names   
    if 'us' in gt_id:
        model_names = ecmwf_experiment_models if '1.5x1.5' in gt_id else debias_experiment_models
        model_names_str = 'ecmwf_experiment_models' if '1.5x1.5' in gt_id else 'debias_experiment_models'
    elif 'contest' in gt_id:
        model_names = rodeo_experiment_models if 'contest' in target_dates else salient_experiment_models
        model_names_str = 'rodeo_experiment_models' if 'contest' in target_dates else 'salient_experiment_models'
    else:
        model_names = all_models
        model_names_str = 'all_models'
        

    
    model_names = debias_experiment_models
    # Get task
    task = f"{gt_id}_{horizon}"

    #display(Markdown(f"### Loading metric {metric} for task {task} and dates {target_dates}"))
    display(Markdown(f"### {model_names_str}: {metric}, {task}, {target_dates}"))

    # Get all metrics
    df = get_metrics_df(gt_id, horizon, metric, target_dates, model_names=model_names)
    # No models exist for this task    
    if df is None: 
        continue

    # Add yearly and quarterly columns to the dataframe
    df = add_groupby_cols(df, horizon=horizon)

    all_metrics[(metric, task, target_dates)] = copy.copy(df)
    #print(all_metrics)

    if metric in ['rmse', 'skill']:
        key = (metric, task, target_dates)
        try:        
            missing_df = all_metrics[key].loc[(all_metrics[key][model_names].isnull().any(axis=1)),:]
        except:        
            missing_df = all_metrics[key].loc[(all_metrics[key].isnull().any(axis=1)),:]            
        if missing_df.shape[0] != 0:
            True
            display(Markdown(f"#### Missing metrics"))
            display(missing_df)    
        else:
            printf("All metrics present.")

# Main manuscript figures

## Figure 1: Raw vs. ABC models
### Skill barplots 
This code produces bar plots comparing models' average skill over the evaluation period 2018-2021.

In [None]:
# Seaborn theme for the bar plots below.
# sns.set() is an alias of sns.set_theme(), which resets context, style and
# palette to their defaults ("notebook", "darkgrid", "deep"). The earlier
# set_context(font_scale=2.5) / set_theme / set_palette("Paired") calls were
# therefore always overridden, and a trailing set_style("whitegrid") replaced
# "darkgrid". This single call reproduces the resulting state explicitly.
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=1.5,
              rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})

def barplot_rawabc(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Bar plot of mean `metric` per model for raw ("Dynamical") vs. ABC forecasts.

    For each model in `model_names`, reads its per-date metric file
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    (missing files are reported and skipped). It then writes the combined data
    to a sheet named ``<task>`` of the Figure 1 source-data workbook. It draws a
    seaborn bar plot with 95% confidence intervals and saves it as
    ``subseasonal_toolkit/viz/barplot_<metric>_<task>_<target_dates>.pdf``.

    Args:
        model_names: model names. A ``raw_`` prefix is labelled "Dynamical" and
            any other prefix "ABC". Both are labelled with the display name of
            the corresponding ``raw_`` model in ``all_model_names``.
        gt_id: ground-truth id, e.g. ``"us_tmp2m_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"34w"``. For ``"56w"``, the models in
            ``figure_models_missing56w`` are skipped.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: name of the target-date set, e.g. ``"std_paper_forecast"``.
        show: if False, clear and close the figure after saving it.

    Returns:
        The concatenated DataFrame that was plotted and saved.
    """
    # Excluded at weeks 5-6 (per the list name, these models lack 56w metrics).
    figure_models_missing56w = {
        "raw_fimr1p1",
        "raw_gefs",
        "raw_gem",
        "abc_fimr1p1",
        "abc_gefs",
        "abc_gem",
    }
    task = f'{gt_id}_{horizon}'
    if horizon == '56w':
        model_names = [m for m in model_names if m not in figure_models_missing56w]

    # Start from an empty frame so the column order (and an all-missing result)
    # matches the previous DataFrame.append-based implementation.
    # DataFrame.append was removed in pandas 2.0, hence pd.concat.
    frames = [pd.DataFrame(columns=['start_date', metric, 'debias_method', 'model'])]
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['debias_method'] = 'Dynamical' if m.split('_')[0] == 'raw' else 'ABC'
            df['model'] = all_model_names[f"raw_{'_'.join(m.split('_')[1:])}"]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}")
    df_barplot = pd.concat(frames)

    # Save figure source data: append a sheet if the workbook exists, else create it
    fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
    if not os.path.isdir(fig_dir):
        make_directories(fig_dir)
    fig_filename = os.path.join(fig_dir, "fig_1-average_forecast_skill.xlsx")
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

    ax = sns.barplot(x="model", y=metric, hue="debias_method", data=df_barplot, ci=95, capsize=0.1, palette={
        'Dynamical': 'red',
        'ABC': 'skyblue'
    })
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    fig_title = f"{fig_title}\n"
    ax.set_title(fig_title, weight='bold')
    # Fewer models are shown at 56w, so their tick labels can be larger
    if '56w' in horizon:
        ax.set_xticklabels(ax.get_xticklabels(), fontdict={'size': 16}, rotation=90)
    else:
        ax.set_xticklabels(ax.get_xticklabels(), fontdict={'size': 11}, rotation=90)
    ax.set(xlabel=None)
    ax.set_ylabel('Skill', fontdict={'weight': 'bold'})
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[0:], labels=labels[0:], frameon=True, edgecolor='white', framealpha=1)
    # Fixed y-limits per variable / horizon so the standard-date figures share a scale
    if target_dates.startswith('std_'):
        if 'precip' in gt_id and '12' not in horizon:
            ax.set(ylim=(-0.025, 0.3))
        elif 'precip' in gt_id and '12' in horizon:
            ax.set(ylim=(-0.025, 0.65))
        elif 'tmp2m' in gt_id and '12' not in horizon:
            ax.set(ylim=(-0.03, 0.5))
        elif 'tmp2m' in gt_id and '12' in horizon:
            ax.set(ylim=(-0.03, 0.9))
    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_{task}_{target_dates}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot

figure_models = [
    # Baselines
    "raw_ecmwf",
    "raw_ccsm4",
    "raw_cfsv2",
    "raw_geos_v2p1",
    "raw_nesm",
    "raw_fimr1p1",
    "raw_gefs",
    "raw_gem",
    "raw_subx_mean",

    # Ensembles
    "abc_ecmwf",
    "abc_ccsm4",
    "abc_cfsv2",
    "abc_geos_v2p1",
    "abc_nesm",
    "abc_fimr1p1",
    "abc_gefs",
    "abc_gem",
    "abc_subx_mean",

]

# Example: plot a single task interactively
# gt_id = 'us_tmp2m_1.5x1.5'
# horizon = '12w'
# metric = 'skill'
# target_dates = 'std_paper_forecast'
# barplot_rawabc(model_names=figure_models, gt_id=gt_id, horizon=horizon, metric=metric, target_dates=target_dates)

# Run this to generate figure source data. The workbook is recreated from scratch.
# metric and target_dates are passed explicitly, not read from globals left over
# by earlier loops, so the output no longer depends on which cells ran before.
fig_filename = os.path.join("subseasonal_toolkit","viz","abc_figures_source_data", "fig_1-average_forecast_skill.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    barplot_rawabc(model_names=figure_models, gt_id=gt_id, horizon=horizon,
                   metric='skill', target_dates='std_paper_forecast', show=False)

## Figure 2: Raw vs. Debiased and ABC models
### lat_lon_skill map
This code produces maps to analyze skill for different models. 

In [None]:
# Figure 2 inputs: lat/lon skill and overall skill for the raw, standard-debiased
# ("deb_") and ABC versions of CFSv2 and ECMWF over the std_paper_forecast dates.
# Note: figure_models here replaces the list defined for Figure 1.
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_skill','skill']
figure_models = [
    # Baselines
    "raw_cfsv2",
    "raw_ecmwf",
#     "raw_subx_mean",
    # Standard de-biasing
    "deb_cfsv2",
    "deb_ecmwf",
#     "deb_subx_mean",
    # Ensembles
    "abc_cfsv2",
    "abc_ecmwf",
#     "abc_subx_mean",
]

#RDA: models for which Raw, Debiased and Abc versions are available
# Maps task "<gt_id>_<horizon>" to the result of get_models_metric_lat_lon; later
# cells index it by metric name (presumably a dict of DataFrames -- confirm in viz_util_abc).
metric_dfs_rda = {}
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    task = f"{gt_id}_{horizon}"
#     if 'us_tmp2m_1.5x1.5_34w' not in task:
#         continue
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(gt_id=gt_id, horizon=horizon, target_dates=figure_target_dates, metrics = figure_metrics, model_names=figure_models)

In [None]:
# Figure parameter values
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_skill'
figure_mean_metric_df = metric_dfs_rda  # per-task metrics loaded by the previous cell
figure_show = False

# Colour scale shared by every skill map in this figure.
skill_map_colors = ["white", "#dede00", "#ff7f00", "blueviolet", "indigo", "yellowgreen", "lightgreen", "darkgreen"]

# Raw vs. debiased vs. ABC: first the CFSv2 family, then ECMWF, with one
# plot_metric_maps call per ground-truth id.
for figure_model_names in (["raw_cfsv2", "deb_cfsv2", "abc_cfsv2"],
                           ["raw_ecmwf", "deb_ecmwf", "abc_ecmwf"]):
    for gt_id in figure_gt_ids:
        plot_metric_maps(metric_dfs_rda,
                         model_names=figure_model_names,
                         gt_ids=[gt_id],
                         horizons=figure_horizons,
                         metric=figure_metric,
                         target_dates=figure_target_dates,
                         mean_metric_df=figure_mean_metric_df,
                         show=figure_show,
                         scale_type='linear',
                         # fresh list per call, just like the original inline literal
                         CB_colors_customized=list(skill_map_colors),
                         CB_minmax=(0, 90))

In [None]:
# Save Figure 2 source data: one sheet per task in a freshly created workbook.
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_2-spatial_skill_distribution.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # The first sheet creates the workbook ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        metric_dfs_rda[task][metric].to_excel(writer, sheet_name=task, na_rep="NaN")

## Figure 3: Raw vs. ABC models
### lat_lon_error map
This code produces maps to analyze bias for different models. 

In [None]:
# Figure 3 inputs: lat/lon error maps for the raw and ABC versions of CFSv2 and
# ECMWF over the std_paper_forecast dates (debiased variants left commented out).
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_error']
figure_model_names = [
    # Baselines
    "raw_cfsv2",
    "raw_ecmwf",
#     "raw_subx_mean",
#     # Standard de-biasing
#     "deb_cfsv2",
#     "deb_ecmwf",
#     "deb_subx_mean",
    # Ensembles
    "abc_cfsv2",
    "abc_ecmwf",
#     "abc_subx_mean",
]

#RDA: models for which Raw, Debiased and Abc versions are available
# Rebuilt from scratch here, replacing the Figure 2 metrics; keyed by task "<gt_id>_<horizon>".
metric_dfs_rda = {}
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(gt_id=gt_id, horizon=horizon, target_dates=figure_target_dates, metrics = figure_metrics, model_names=figure_model_names)


In [None]:
# Figure parameter values
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_show = True

# Raw vs. ABC error maps: first CFSv2, then ECMWF, with one plot_metric_maps
# call per ground-truth id.
for figure_model_names in (["raw_cfsv2", "abc_cfsv2"], ["raw_ecmwf", "abc_ecmwf"]):
    for gt_id in figure_gt_ids:
        # Diverging colours with white in the middle of a symmetric range;
        # temperature and precipitation use different ranges.
        if 'tmp2m' not in gt_id:
            CB_colors_customized = ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
            CB_minmax = (-15, 15)
        else:
            CB_colors_customized = ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
            CB_minmax = (-4, 4)
        plot_metric_maps(metric_dfs_rda,
                         model_names=figure_model_names,
                         gt_ids=[gt_id],
                         horizons=figure_horizons,
                         metric=figure_metric,
                         target_dates=figure_target_dates,
                         mean_metric_df=figure_mean_metric_df,
                         show=figure_show,
                         scale_type='linear',
                         CB_colors_customized=CB_colors_customized,
                         CB_minmax=CB_minmax)

In [None]:
# Save Figure 3 source data: one sheet per task in a freshly created workbook.
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_3-spatial_bias_distribution.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # The first sheet creates the workbook ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        metric_dfs_rda[task][metric].to_excel(writer, sheet_name=task, na_rep="NaN")


## Figure 4: HGT_500 impact
See subseasonal_toolkit/viz/cohort_shapley_all.ipynb

## Figure 5: MJO_phase impact
See subseasonal_toolkit/viz/cohort_shapley_all.ipynb

## Figure 6: Opportunistic ABC
See subseasonal_toolkit/viz/cohort_shapley_all.ipynb

## Figure 7: ABC flowchart
See subseasonal_toolkit/viz/cohort_shapley_all.ipynb

# A1. Average forecast skill bar plots

## Figure A1: Raw vs. ABC models
### Skill barplots per season 

In [None]:
# from viz_util import (highlight_min, highlight_max, 
#                       bold_min, bold_max, 
#                       get_per_period_metrics_df, all_model_types,
#                      styles)
from subseasonal_toolkit.utils.eval_util import get_target_dates, score_to_mean_rmse, contest_quarter_start_dates, contest_quarter, year_quarter, mean_rmse_to_score

# Seaborn theme for the bar plots below.
# sns.set() is an alias of sns.set_theme(), which resets context, style and
# palette to their defaults ("notebook", "darkgrid", "deep"). The earlier
# set_context(font_scale=2.5) / set_theme / set_palette("Paired") calls were
# therefore always overridden, and a trailing set_style("whitegrid") replaced
# "darkgrid". This single call reproduces the resulting state explicitly.
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=1.5,
              rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})

def barplot_rawabc_quarterly(model_names, gt_id, horizon, metric, target_dates, quarter, show=True):
    """Bar plot of mean `metric` per model, raw vs. ABC, restricted to one season.

    Reads each model's per-date metric file
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    (missing files are reported and skipped). It tags every row with its season
    (DJF/MAM/JJA/SON via ``year_quarter``), keeps only rows in `quarter`, and
    draws a seaborn bar plot with 95% confidence intervals. The plot is saved as
    ``subseasonal_toolkit/viz/barplot_<metric>_<task>_<target_dates>_<quarter>.pdf``.

    Args:
        model_names: model names. A ``raw_`` prefix is labelled "Dynamical" and
            any other prefix "ABC". Both are labelled with the display name of
            the corresponding ``raw_`` model in ``all_model_names``.
        gt_id: ground-truth id, e.g. ``"us_precip_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"34w"``. For ``"56w"``, the models in
            ``figure_models_missing56w`` are skipped.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: name of the target-date set, e.g. ``"std_paper_forecast"``.
        quarter: one of "DJF", "MAM", "JJA", "SON". Only "DJF" keeps the legend
            and y-axis label.
        show: if False, clear and close the figure after saving it.

    Returns:
        The season-filtered DataFrame that was plotted.
    """
    # Excluded at weeks 5-6 (per the list name, these models lack 56w metrics).
    figure_models_missing56w = {
        "raw_fimr1p1",
        "raw_gefs",
        "raw_gem",
        "abc_fimr1p1",
        "abc_gefs",
        "abc_gem",
    }
    task = f'{gt_id}_{horizon}'
    if horizon == '56w':
        model_names = [m for m in model_names if m not in figure_models_missing56w]

    # DataFrame.append was removed in pandas 2.0; collect frames and concatenate once.
    # The empty leading frame keeps the previous column order.
    frames = [pd.DataFrame(columns=['start_date', metric, 'debias_method', 'model'])]
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['debias_method'] = 'Dynamical' if m.split('_')[0] == 'raw' else 'ABC'
            df['model'] = all_model_names[f"raw_{'_'.join(m.split('_')[1:])}"]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}")
    df_barplot = pd.concat(frames)

    # Tag each row with its season name. Assign the result rather than calling
    # replace(inplace=True) on df["quarter"]: that chained in-place call is a
    # silent no-op under pandas Copy-on-Write.
    quarter_names = {"Q0": "DJF", "Q1": "MAM", "Q2": "JJA", "Q3": "SON"}
    df_barplot['quarter'] = pd.Series(
        [f"Q{year_quarter(date)}" for date in df_barplot.start_date],
        index=df_barplot.index,
    ).replace(quarter_names)
    df_barplot = df_barplot[df_barplot.quarter == quarter]

    ax = sns.barplot(x="model", y=metric, hue="debias_method", data=df_barplot, ci=95, capsize=0.1, palette={
        'Dynamical': 'red',
        'ABC': 'skyblue'
    })
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    fig_title = f"{fig_title}\n{quarter}"
    ax.set_title(fig_title, weight='bold', fontdict={'size': 25})
    ax.set_xticklabels(ax.get_xticklabels(), fontdict={'size': 25}, rotation=90)
    ax.set(xlabel=None)
    ax.set_ylabel('Skill', fontdict={'weight': 'bold', 'size': 25})
    # Only the first season's panel (DJF) keeps the legend and y label
    if quarter == "DJF":
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles[0:], labels=labels[0:], frameon=True, edgecolor='white', framealpha=1)
    else:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set(ylabel=None)
    dic_ylim = {"us_tmp2m_1.5x1.5_12w": (-0.15, 0.7),
                "us_tmp2m_1.5x1.5_34w": (-0.1, 0.6),
                "us_tmp2m_1.5x1.5_56w": (-0.2, 0.6),
                # NOTE(review): upper limit 14 looks like a typo for a skill
                # axis (others are <= 0.7) -- confirm the intended range.
                "us_precip_1.5x1.5_12w": (-0.5, 14),
                "us_precip_1.5x1.5_34w": (-0.1, 0.4),
                "us_precip_1.5x1.5_56w": (-0.1, 0.4),
               }
    # Fixed y-limits only for the tasks listed; others keep seaborn's autoscaling
    if task in dic_ylim:
        ax.set(ylim=dic_ylim[task])
    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_{task}_{target_dates}_{quarter.lower()}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot

# Models for the seasonal bar plots: raw and ABC versions of each dynamical model
# (including the SubX mean). barplot_rawabc_quarterly drops the fimr1p1/gefs/gem
# models at the 56w horizon.
figure_models = [
    # Baselines
    "raw_ecmwf",
    "raw_ccsm4",
    "raw_cfsv2",
    "raw_geos_v2p1",
    "raw_nesm",
    "raw_fimr1p1",
    "raw_gefs",
    "raw_gem",
    "raw_subx_mean",

    # Ensembles
    "abc_ecmwf",
    "abc_ccsm4",
    "abc_cfsv2",
    "abc_geos_v2p1",
    "abc_nesm",
    "abc_fimr1p1",
    "abc_gefs",
    "abc_gem",
    "abc_subx_mean",

]

# Single interactive example: one task and one season, displayed inline.
# These assignments also overwrite the module-level gt_id/horizon/metric/target_dates.
gt_id = 'us_precip_1.5x1.5'
horizon = '34w'
metric = 'skill'
target_dates = 'std_paper_forecast'
figure_quarters = ["DJF", "MAM", "JJA", "SON"]
quarter = "DJF"
show = True

df_barplot = barplot_rawabc_quarterly(model_names=figure_models, gt_id=gt_id, horizon=horizon,  metric=metric, target_dates=target_dates, quarter=quarter, show=show)
# display(df_barplot)
# To save every season without displaying:
# for quarter in figure_quarters:
#     df_barplot = barplot_rawabc_quarterly(model_names=figure_models,
#                                               gt_id=gt_id,
#                                               horizon=horizon,
#                                               metric=metric,
#                                               target_dates=target_dates,
#                                               quarter=quarter,
#                                               show=False)
                                          


In [None]:
# Save Figure A1 source data: one sheet per (task, season) in a fresh workbook.
# Figure parameter values
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metrics = ['skill']
figure_target_dates = 'std_paper_forecast'
figure_quarters = ["DJF", "MAM", "JJA", "SON"]

fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a1-average_skill_season.xlsx")
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric, quarter in product(figure_gt_ids, figure_horizons, figure_metrics, figure_quarters):
    task = f"{gt_id}_{horizon}_{quarter.lower()}"
    printf(f"Processing {task} {metric}")
    df_barplot = barplot_rawabc_quarterly(model_names=figure_models, gt_id=gt_id, horizon=horizon,
                                          metric=metric, target_dates=figure_target_dates,
                                          quarter=quarter, show=False)
    # The first sheet creates the workbook ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

## Figure A2: Baselines vs. ABC models
### Skill barplots 
This code produces bar plots comparing models' average skill over the evaluation period 2018-2021.

In [None]:
# Seaborn theme for the bar plots below.
# sns.set() is an alias of sns.set_theme(), which resets context, style and
# palette to their defaults ("notebook", "darkgrid", "deep"). The earlier
# set_context(font_scale=2.5) / set_theme / set_palette("Paired") calls were
# therefore always overridden, and a trailing set_style("whitegrid") replaced
# "darkgrid". This single call reproduces the resulting state explicitly.
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font_scale=1.5,
              rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})

def barplot_baselineabc(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Bar plot of mean `metric` per model (one bar per model, no hue).

    Reads each model's per-date metric file
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    (missing files are reported with their path and skipped). It draws a seaborn
    bar plot with 95% confidence intervals and saves it as
    ``subseasonal_toolkit/viz/barplot_baselines_<metric>_<task>_<target_dates>.pdf``.

    Args:
        model_names: model names; each is labelled with ``all_model_names[model]``.
        gt_id: ground-truth id, e.g. ``"us_precip_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"34w"``.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: name of the target-date set, e.g. ``"std_paper_forecast"``.
        show: if False, clear and close the figure after saving it.

    Returns:
        The concatenated DataFrame that was plotted.
    """
    task = f"{gt_id}_{horizon}"
    # DataFrame.append was removed in pandas 2.0; collect frames and concatenate once.
    # The empty leading frame keeps the previous column order.
    frames = [pd.DataFrame(columns=['start_date', metric, 'model'])]
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['model'] = all_model_names[m]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}\n{f}")
    df_barplot = pd.concat(frames)

    ax = sns.barplot(x="model", y=metric, data=df_barplot, ci=95, capsize=0.1, color="skyblue")
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    fig_title = f"{fig_title}\n"
    ax.set_title(fig_title, weight='bold')
    ax.set_xticklabels(ax.get_xticklabels(), fontdict={'size': 16}, rotation=90)
    ax.set(xlabel=None)
    ax.set_ylabel('Skill', fontdict={'weight': 'bold'})
    dic_ylim = {"us_tmp2m_1.5x1.5_12w": (-0.2, 0.9),
                "us_tmp2m_1.5x1.5_34w": (-0.2, 0.5),
                "us_tmp2m_1.5x1.5_56w": (-0.2, 0.5),
                "us_precip_1.5x1.5_12w": (-0.025, 0.65),
                "us_precip_1.5x1.5_34w": (-0.025, 0.3),
                "us_precip_1.5x1.5_56w": (-0.025, 0.3),
               }
    # Fixed y-limits only for the tasks listed; others keep seaborn's autoscaling
    if task in dic_ylim:
        ax.set(ylim=dic_ylim[task])
    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_baselines_{metric}_{task}_{target_dates}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot

# Figure A2 models: ABC versions of ECMWF and CFSv2 shown alongside the
# baseline models below (bar labels come from all_model_names).
figure_models = [
    # Ensembles
    "abc_ecmwf",
    "abc_cfsv2",
    # Baselines
    "nn-a",
    "deb_loess_ecmwf",
    "deb_loess_cfsv2",
    "deb_quantile_ecmwf",
    "deb_quantile_cfsv2",

]

# Single interactive example: one task, displayed inline (show defaults to True).
# These assignments also overwrite the module-level gt_id/horizon/metric/target_dates.
gt_id = 'us_precip_1.5x1.5'
horizon = '34w'
metric = 'skill'
target_dates = 'std_paper_forecast'

df_barplot = barplot_baselineabc(model_names=figure_models, gt_id=gt_id, horizon=horizon, metric=metric, target_dates=target_dates)
# display(df_barplot)
# display(df_barplot)

In [None]:
# Save Figure A2 source data: one Excel sheet per task holding the skill rows
# drawn by barplot_baselineabc.
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metrics = ['skill']
figure_target_dates = 'std_paper_forecast'
figure_models = [
    "abc_ecmwf", "abc_cfsv2",                      # Ensembles
    "nn-a",                                        # Baselines
    "deb_loess_ecmwf", "deb_loess_cfsv2",
    "deb_quantile_ecmwf", "deb_quantile_cfsv2",
]

fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a2-average_skill_baselines.xlsx")
# Rebuild the workbook from scratch on every run.
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, figure_metrics):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    df_barplot = barplot_baselineabc(model_names=figure_models, gt_id=gt_id, horizon=horizon,
                                     metric=metric, target_dates=figure_target_dates, show=False)
    # The first sheet creates the file ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

# A2. Spatial skill distribution plots

## Figure A3: SubX mean - Raw vs. Debiased and ABC models
### lat_lon_skill map
This code produces `lat_lon_skill` maps comparing the spatial skill of the raw, debiased, and ABC SubX mean models.

In [None]:
# Load lat_lon_skill and skill metrics for the raw, debiased and ABC SubX mean models.
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_skill', 'skill']
figure_model_names = ["raw_subx_mean",   # Baselines
                      "deb_subx_mean",   # Standard de-biasing
                      "abc_subx_mean"]   # Ensembles

# RDA: models for which Raw, Debiased and ABC versions are available; keyed by task.
metric_dfs_rda = {}
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=gt_id,
        horizon=horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )


In [None]:
# Figure A3 panels: one plot_metric_maps call per ground-truth id.
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_skill'
figure_mean_metric_df = metric_dfs_rda
figure_show = False
figure_model_names = ["raw_subx_mean", "deb_subx_mean", "abc_subx_mean"]

_cb_colors = ["white", "#dede00", "#ff7f00", "blueviolet", "indigo",
              "yellowgreen", "lightgreen", "darkgreen"]
for gt_id in figure_gt_ids:
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        # A fresh list per call, exactly as the literal in the loop used to give.
        CB_colors_customized=list(_cb_colors),
        CB_minmax=(0, 90),
    )


In [None]:
# Save Figure A3 source data: one Excel sheet per task with the lat_lon_skill table.
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a3-spatial_skill_subx_mean.xlsx")
# Rebuild the workbook from scratch on every run.
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # The first sheet creates the file ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        metric_dfs_rda[task][metric].to_excel(writer, sheet_name=task, na_rep="NaN")

## Figure A4: Baselines - Raw vs. Debiased and ABC models
### lat_lon_skill map
This code produces `lat_lon_skill` maps comparing the spatial skill of the baseline debiasing models and the ABC models.

In [None]:
# Load lat_lon_skill and skill metrics for the baseline and ABC models.
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_skill', 'skill']
figure_model_names = [
    "nn-a",                                        # Baselines
    "deb_loess_ecmwf", "deb_loess_cfsv2",
    "deb_quantile_ecmwf", "deb_quantile_cfsv2",
    "abc_ecmwf", "abc_cfsv2",                      # Ensembles
]

# RDA: models for which Raw, Debiased and ABC versions are available; keyed by task.
metric_dfs_rda = {}
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=gt_id,
        horizon=horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )


In [None]:
# Figure A4 panels: CFSv2-based models first, then ECMWF-based models,
# one plot_metric_maps call per ground-truth id for each group.
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_skill'
figure_mean_metric_df = metric_dfs_rda
figure_show = False

_cb_colors = ["white", "#dede00", "#ff7f00", "blueviolet", "indigo",
              "yellowgreen", "lightgreen", "darkgreen"]
_model_groups = (
    ["nn-a", "deb_loess_cfsv2", "deb_quantile_cfsv2", "abc_cfsv2"],
    ["deb_loess_ecmwf", "deb_quantile_ecmwf", "abc_ecmwf"],
)
# figure_model_names is left bound to the last (ECMWF) group afterwards.
for figure_model_names in _model_groups:
    for gt_id in figure_gt_ids:
        plot_metric_maps(
            metric_dfs_rda,
            model_names=figure_model_names,
            gt_ids=[gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=list(_cb_colors),
            CB_minmax=(0, 90),
        )

In [None]:
# Save Figure A4 source data: one Excel sheet per task with the lat_lon_skill table.
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a4-spatial_skill_baselines.xlsx")
# Rebuild the workbook from scratch on every run.
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # The first sheet creates the file ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        metric_dfs_rda[task][metric].to_excel(writer, sheet_name=task, na_rep="NaN")

# A3. Probabilistic evaluation plots
#### Save bss and lat_lon_bss metric dataframes

In [None]:
def _load_p1_p3_mse(model, gt_id, horizon, target_dates, mse_metric):
    """Load the p1 and p3 versions of an MSE metric for ``model`` and add their sum as ``rps``.

    Args:
        model: model directory name, e.g. ``"d2p_raw_ecmwf"``.
        gt_id: base ground-truth id. Only whether it contains ``"tmp2m"`` matters;
            anything else is treated as ``"precip"``.
        horizon: forecast horizon, e.g. ``"56w"``.
        target_dates: target-date set name.
        mse_metric: metric file prefix and column name (``"mse"`` or ``"lat_lon_mse"``).

    Returns:
        The p1 metric DataFrame with its ``mse_metric`` column renamed to ``mse_p1``,
        plus ``mse_p3`` and ``rps = mse_p1 + mse_p3`` columns.
    """
    var = "tmp2m" if "tmp2m" in gt_id else "precip"
    p1_gt_id, p3_gt_id = (f"us_{var}_{p}_1.5x1.5" for p in ("p1", "p3"))
    p1_dir = get_task_metrics_dir(model=model, submodel=None, gt_id=p1_gt_id,
                                  horizon=horizon, target_dates=target_dates)
    # The p3 directory is derived from the p1 one by swapping "_p1" for "_p3".
    p3_dir = p1_dir.replace("_p1", "_p3")
    p1_path = os.path.join(p1_dir, f"{mse_metric}-{p1_gt_id}_{horizon}-{target_dates}.h5")
    p3_path = os.path.join(p3_dir, f"{mse_metric}-{p3_gt_id}_{horizon}-{target_dates}.h5")

    metrics = pd.read_hdf(p1_path).rename(columns={mse_metric: "mse_p1"})
    metrics["mse_p3"] = pd.read_hdf(p3_path)[mse_metric]
    metrics["rps"] = metrics["mse_p1"] + metrics["mse_p3"]
    return metrics


def get_model_task_rps(model="d2p_raw_ecmwf", gt_id="us_tmp2m_1.5x1.5", horizon="56w", target_dates='std_paper_forecast'):
    """Return the p1/p3 ``mse`` metrics of ``model`` and their sum ``rps``."""
    return _load_p1_p3_mse(model, gt_id, horizon, target_dates, "mse")


def get_model_task_lat_lon_rps(model="d2p_raw_ecmwf", gt_id="us_tmp2m_1.5x1.5", horizon="56w", target_dates='std_paper_forecast'):
    """Return the p1/p3 ``lat_lon_mse`` metrics of ``model`` and their sum ``rps``."""
    return _load_p1_p3_mse(model, gt_id, horizon, target_dates, "lat_lon_mse")

# One-off precomputation, disabled by default (flip the guard to True to rerun):
# derive Brier skill scores from the saved p1/p3 MSE metrics and write them with
# save_metric, both per start date ("bss") and per grid point ("lat_lon_bss").
# BSS = 1 - mse_p3(model) / mse_p3(reference), where the reference forecast is a
# constant 1/3 probability.
# NOTE(review): when enabled this rebinds the notebook globals target_dates,
# gt_ids, horizons, gt_id, horizon, model and metric.
if False:
    # Ground-truth p1/p3 fields pivoted to rows = start_date, columns = (lat, lon).
    # Presumably tercile indicators/probabilities — TODO confirm.
    gt = {}
    prob_gt_ids = ["us_tmp2m_p1_1.5x1.5", "us_tmp2m_p3_1.5x1.5", "us_precip_p1_1.5x1.5", "us_precip_p3_1.5x1.5"]
    for prob_gt_id in prob_gt_ids:
        gt[prob_gt_id] = get_ground_truth(prob_gt_id).set_index(["lat","lon","start_date"]).squeeze().unstack(["lat","lon"])

    # Save BSS metric dataframes
    target_dates = "std_paper_forecast"
    target_date_objs = get_target_dates(target_dates)
    gt_ids = ["us_tmp2m_1.5x1.5","us_precip_1.5x1.5"]
    horizons = ["12w","34w","56w"]

    # BSS by date: the reference MSE is averaged over grid points (axis=1), giving
    # one value per start_date that aligns with the model metrics' start_date index.
    for gt_id in gt_ids:
        measurement_var = get_measurement_variable(gt_id)
        p1_var = measurement_var+"_p1"
        p3_var = measurement_var+"_p3"

        clim_rps = pd.DataFrame(data = {
            'mse_p1': np.square(gt[gt_id.replace(measurement_var, p1_var)].loc[target_date_objs] - 1./3).mean(axis=1),
            'mse_p3': np.square(gt[gt_id.replace(measurement_var, p3_var)].loc[target_date_objs] - 1./3).mean(axis=1)})
        clim_rps['rps'] = clim_rps['mse_p1']+clim_rps['mse_p3']
        for horizon in horizons:
            mets = {}
            for model in ["raw_ecmwf", "deb_ecmwf", "abcds_ecmwf", "shift_deb_loess_ecmwf", "shift_deb_quantile_ecmwf"]: 
                # NOTE(review): hard-codes 'std_paper_forecast' instead of using target_dates.
                mets[model] = get_model_task_rps(model = f"d2p_{model}", gt_id = gt_id, horizon = horizon,  target_dates = 'std_paper_forecast').dropna().set_index("start_date")
                # RPSS variant (disabled):
    #             mets[model]['rpss'] = 1-mets[model]['rps']/clim_rps['rps']
                # Division aligns on start_date.
                mets[model]['bss'] = 1-mets[model]['mse_p3']/clim_rps['mse_p3']
                sn = get_selected_submodel_name(model=model, gt_id=gt_id, horizon=horizon)
                metric = "bss"
                df_metric = mets[model].reset_index()[["start_date", metric]]#.rename(columns={"rpss":metric})
                # Saved under the non-"d2p_" model name and the base (non-p1/p3) gt_id.
                save_metric(df_metric, model=model, submodel=sn,
                            gt_id=gt_id, horizon=horizon, 
                            target_dates=target_dates, metric=metric)
    # BSS by grid point: the reference MSE is averaged over target dates (axis=0),
    # giving one value per (lat, lon) that aligns with the model metrics' index.
    for gt_id in gt_ids:
        measurement_var = get_measurement_variable(gt_id)
        p1_var = measurement_var+"_p1"
        p3_var = measurement_var+"_p3"

        clim_rps = pd.DataFrame(data = {
            'mse_p1': np.square(gt[gt_id.replace(measurement_var, p1_var)].loc[target_date_objs] - 1./3).mean(axis=0),
            'mse_p3': np.square(gt[gt_id.replace(measurement_var, p3_var)].loc[target_date_objs] - 1./3).mean(axis=0)})
        clim_rps['rps'] = clim_rps['mse_p1']+clim_rps['mse_p3']
        for horizon in horizons:
            mets = {}
            # Only the raw / deb / ABC ECMWF variants are processed spatially.
            for model in ["raw_ecmwf", "deb_ecmwf", "abcds_ecmwf"]: 
                mets[model] = get_model_task_lat_lon_rps(model = f"d2p_{model}", gt_id = gt_id, horizon = horizon,  target_dates = 'std_paper_forecast').dropna().set_index(['lat','lon'])
                # RPSS variant (disabled):
    #             mets[model]['rpss'] = 1-mets[model]['rps']/clim_rps['rps']
                # Division aligns on (lat, lon).
                mets[model]['bss'] = 1-mets[model]['mse_p3']/clim_rps['mse_p3']
                sn = get_selected_submodel_name(model=model, gt_id=gt_id, horizon=horizon)
                metric = "lat_lon_bss"
                df_metric = mets[model].reset_index()[["lat", "lon", "bss"]].rename(columns={"bss":metric})
                save_metric(df_metric, model=model, submodel=sn,
                            gt_id=gt_id, horizon=horizon, 
                            target_dates=target_dates, metric=metric)
    

## Figure A5: Raw vs. ABC models
### BSS barplots by season 

In [None]:
# Seaborn theme for the seasonal barplots below.
# sns.set (an alias of sns.set_theme) re-applies context, style and palette from
# scratch, so the set_context / set_theme / set_palette calls that used to precede
# it were overwritten without effect and have been removed. The final rcParams
# are unchanged.
sns.set(font_scale=1.5, rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})
sns.set_style("whitegrid")

def barplot_rawabc_bss(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Draw, save and return a seasonal barplot of a per-date metric for raw vs. ABC models.

    For each model, the metric file of its selected submodel
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    is loaded; missing files are reported and skipped. Rows are grouped by season
    (``year_quarter`` 0-3 mapped to DJF/MAM/JJA/SON) and plotted with 95% CI error
    bars, with hue set to the display name from ``all_model_names``. The figure is
    saved as a PDF under ``subseasonal_toolkit/viz``, made writable by all, and
    chown'ed to group ``sched_mit_hill``.

    Args:
        model_names: model identifiers; each must be a key of ``all_model_names``.
        gt_id: ground-truth id, e.g. ``"us_tmp2m_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"12w"``. Only the ``"12w"`` panel keeps
            the legend and the y-axis label.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: target-date set name, e.g. ``"std_paper_forecast"``.
        show: if False, the figure is cleared and closed after saving.

    Returns:
        DataFrame of the loaded metric rows plus ``model`` and ``quarter`` columns.
    """
    task = f'{gt_id}_{horizon}'
    base_cols = ['start_date', metric, 'model']
    frames = []
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['model'] = all_model_names[m]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}")
    # DataFrame.append was removed in pandas 2.0. Collect the frames and concatenate
    # once, keeping the original leading column order.
    if frames:
        df_barplot = pd.concat(frames)
        extra_cols = [c for c in df_barplot.columns if c not in base_cols]
        df_barplot = df_barplot.reindex(columns=base_cols + extra_cols)
    else:
        df_barplot = pd.DataFrame(columns=base_cols)
    df_barplot["quarter"] = [f"Q{year_quarter(date)}" for date in df_barplot.start_date]
    df_barplot = df_barplot.replace({"quarter": {"Q0": "DJF", "Q1": "MAM", "Q2": "JJA", "Q3": "SON"}})

    ax = sns.barplot(x="quarter", y=metric, hue="model", data=df_barplot, ci=95, capsize=0.1, palette={
        'ECMWF': 'red',
        'ABC-ECMWF': 'skyblue'
    })

    # Readable title, e.g. "us_tmp2m_1.5x1.5_12w" -> "U.S. Temperature, weeks 1-2".
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    ax.set_title(f"{fig_title}\n", weight='bold')
    ax.set(xlabel=None)
    ax.set_ylabel(metric.upper(), fontdict={'weight': 'bold'})
    # Only the weeks 1-2 panel keeps the legend and the y-axis label.
    if horizon == "12w":
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels, frameon=True, edgecolor='white', framealpha=1)
    else:
        legend = ax.get_legend()
        if legend is not None:  # no legend is drawn when there is no data
            legend.remove()
        ax.set(ylabel=None)
    dic_ylim = {"us_tmp2m_1.5x1.5_12w": (-0.1, 0.8),
                "us_tmp2m_1.5x1.5_34w": (-0.25, 0.4),
                "us_tmp2m_1.5x1.5_56w": (-0.25, 0.4),
                "us_precip_1.5x1.5_12w": (-0.1, 0.6),
                "us_precip_1.5x1.5_34w": (-0.15, 0.15),
                "us_precip_1.5x1.5_56w": (-0.15, 0.15),
               }
    ax.set(ylim=dic_ylim[task])
    # Distance (in points) between each bar and its value label.
    dic_labelpad = {"us_tmp2m_1.5x1.5_12w": 30,
                    "us_tmp2m_1.5x1.5_34w": 30,
                    "us_tmp2m_1.5x1.5_56w": 30,
                    "us_precip_1.5x1.5_12w": 30,
                    "us_precip_1.5x1.5_34w": 30,
                    "us_precip_1.5x1.5_56w": 30,
                   }
    for container in ax.containers:
        ax.bar_label(container, fmt="%.2f", fontsize=15, fontweight='bold', padding=dic_labelpad[task])

    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_quarterly_{'_'.join(model_names)}_{task}_{target_dates}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot

# Preview one panel: weeks 1-2 temperature BSS for raw_ecmwf vs. abcds_ecmwf.
figure_model_names = ["raw_ecmwf",     # Baselines
                      "abcds_ecmwf"]   # Ensembles
figure_gt_id, figure_horizon = 'us_tmp2m_1.5x1.5', '12w'
figure_metric, figure_target_dates = 'bss', 'std_paper_forecast'

df_barplot = barplot_rawabc_bss(
    model_names=figure_model_names,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    metric=figure_metric,
    target_dates=figure_target_dates,
)

In [None]:
# Save Figure A5 source data: one Excel sheet per task with the BSS rows behind
# each barplot. Reuses figure_model_names from the preview cell above.
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metrics = ['bss']
figure_target_dates = 'std_paper_forecast'

fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a5-average_bss.xlsx")
# Rebuild the workbook from scratch on every run.
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, figure_metrics):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Use this cell's figure_target_dates rather than the unrelated notebook-global
    # target_dates, so the saved data always matches the declared parameters.
    df_barplot = barplot_rawabc_bss(model_names=figure_model_names,
                                    gt_id=gt_id,
                                    horizon=horizon,
                                    metric=metric,
                                    target_dates=figure_target_dates,
                                    show=False)
    # The first sheet creates the file ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

## Figure A6: Raw vs. ABC models
### CRPS barplots by season 

In [None]:
# Seaborn theme for the seasonal barplots below.
# sns.set (an alias of sns.set_theme) re-applies context, style and palette from
# scratch, so the set_context / set_theme / set_palette calls that used to precede
# it were overwritten without effect and have been removed. The final rcParams
# are unchanged.
sns.set(font_scale=1.5, rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})
sns.set_style("whitegrid")

def barplot_rawabc_crps(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Draw, save and return a seasonal CRPS-style barplot for raw vs. ABC models.

    For each model, the metric file of its selected submodel
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    is loaded; missing files are reported and skipped. Rows are grouped by season
    (``year_quarter`` 0-3 mapped to DJF/MAM/JJA/SON) and plotted with 95% CI error
    bars and vertical value labels. The figure is saved as a PDF under
    ``subseasonal_toolkit/viz``, made writable by all, and chown'ed to group
    ``sched_mit_hill``.

    Args:
        model_names: model identifiers; each must be a key of ``all_model_names``.
        gt_id: ground-truth id, e.g. ``"us_tmp2m_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"12w"``. Only the ``"12w"`` panel keeps
            the legend and the y-axis label.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: target-date set name, e.g. ``"std_paper_forecast"``.
        show: if False, the figure is cleared and closed after saving.

    Returns:
        DataFrame of the loaded metric rows plus ``model`` and ``quarter`` columns.
    """
    task = f'{gt_id}_{horizon}'
    base_cols = ['start_date', metric, 'model']
    frames = []
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['model'] = all_model_names[m]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}")
    # DataFrame.append was removed in pandas 2.0. Collect the frames and concatenate
    # once, keeping the original leading column order.
    if frames:
        df_barplot = pd.concat(frames)
        extra_cols = [c for c in df_barplot.columns if c not in base_cols]
        df_barplot = df_barplot.reindex(columns=base_cols + extra_cols)
    else:
        df_barplot = pd.DataFrame(columns=base_cols)
    df_barplot["quarter"] = [f"Q{year_quarter(date)}" for date in df_barplot.start_date]
    df_barplot = df_barplot.replace({"quarter": {"Q0": "DJF", "Q1": "MAM", "Q2": "JJA", "Q3": "SON"}})

    ax = sns.barplot(x="quarter", y=metric, hue="model", data=df_barplot, ci=95, capsize=0.1, palette={
        'ECMWF': 'red',
        'ABC-ECMWF': 'skyblue'
    })

    # Readable title, e.g. "us_tmp2m_1.5x1.5_12w" -> "U.S. Temperature, weeks 1-2".
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    ax.set_title(f"{fig_title}\n", weight='bold')
    ax.set(xlabel=None)
    ax.set_ylabel(metric.upper(), fontdict={'weight': 'bold'})
    # Only the weeks 1-2 panel keeps the legend and the y-axis label.
    if horizon == "12w":
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels, frameon=True, edgecolor='white', framealpha=1)
    else:
        legend = ax.get_legend()
        if legend is not None:  # no legend is drawn when there is no data
            legend.remove()
        ax.set(ylabel=None)
    dic_ylim = {"us_tmp2m_1.5x1.5_12w": (-0.25, 3),
                "us_tmp2m_1.5x1.5_34w": (-0.25, 3),
                "us_tmp2m_1.5x1.5_56w": (-0.25, 3),
                "us_precip_1.5x1.5_12w": (-0.5, 20),
                "us_precip_1.5x1.5_34w": (-0.5, 20),
                "us_precip_1.5x1.5_56w": (-0.5, 20),
               }
    ax.set(ylim=dic_ylim[task])
    # Distance (in points) between each bar and its value label.
    dic_labelpad = {"us_tmp2m_1.5x1.5_12w": 18,
                    "us_tmp2m_1.5x1.5_34w": 25,
                    "us_tmp2m_1.5x1.5_56w": 27,
                    "us_precip_1.5x1.5_12w": 15,
                    "us_precip_1.5x1.5_34w": 20,
                    "us_precip_1.5x1.5_56w": 20,
                   }
    for container in ax.containers:
        ax.bar_label(container, fmt="%.2f", fontsize=18, fontweight='bold', padding=dic_labelpad[task], rotation=90)

    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_quarterly_{'_'.join(model_names)}_{task}_{target_dates}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot


# Preview one panel: weeks 1-2 temperature CRPS for raw_ecmwf vs. abcds_ecmwf.
figure_model_names = ["raw_ecmwf",     # Baselines
                      "abcds_ecmwf"]   # Ensembles
figure_gt_id, figure_horizon = 'us_tmp2m_1.5x1.5', '12w'
figure_metric, figure_target_dates = 'crps', 'std_paper_forecast'

df_barplot = barplot_rawabc_crps(
    model_names=figure_model_names,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    metric=figure_metric,
    target_dates=figure_target_dates,
)


In [None]:
# Save Figure A6 source data: one Excel sheet per task with the CRPS rows behind
# each barplot. Reuses figure_model_names from the preview cell above.
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metrics = ['crps']
figure_target_dates = 'std_paper_forecast'

fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a6-average_crps.xlsx")
# Rebuild the workbook from scratch on every run.
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, figure_metrics):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Use this cell's figure_target_dates rather than the unrelated notebook-global
    # target_dates, so the saved data always matches the declared parameters.
    df_barplot = barplot_rawabc_crps(model_names=figure_model_names,
                                     gt_id=gt_id,
                                     horizon=horizon,
                                     metric=metric,
                                     target_dates=figure_target_dates,
                                     show=False)
    # The first sheet creates the file ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

## Figure A7: Baselines vs. ABC models
### BSS barplots by season 

In [None]:
# Seaborn theme for the seasonal barplots below.
# sns.set (an alias of sns.set_theme) re-applies context, style and palette from
# scratch, so the set_context / set_theme / set_palette calls that used to precede
# it were overwritten without effect and have been removed. The final rcParams
# are unchanged.
sns.set(font_scale=1.5, rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})
sns.set_style("whitegrid")

def barplot_baselinesabc_bss(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Draw, save and return a seasonal BSS-style barplot for baseline vs. ABC models.

    For each model, the metric file of its selected submodel
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    is loaded; missing files are reported and skipped. Rows are grouped by season
    (``year_quarter`` 0-3 mapped to DJF/MAM/JJA/SON) and plotted with 95% CI error
    bars and vertical value labels (QM, LOESS and ABC hues). The figure is saved as
    a PDF under ``subseasonal_toolkit/viz``, made writable by all, and chown'ed to
    group ``sched_mit_hill``.

    Args:
        model_names: model identifiers; each must be a key of ``all_model_names``.
        gt_id: ground-truth id, e.g. ``"us_tmp2m_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"12w"``. Only the ``"12w"`` panel keeps
            the legend and the y-axis label.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: target-date set name, e.g. ``"std_paper_forecast"``.
        show: if False, the figure is cleared and closed after saving.

    Returns:
        DataFrame of the loaded metric rows plus ``model`` and ``quarter`` columns.
    """
    task = f'{gt_id}_{horizon}'
    base_cols = ['start_date', metric, 'model']
    frames = []
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['model'] = all_model_names[m]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}")
    # DataFrame.append was removed in pandas 2.0. Collect the frames and concatenate
    # once, keeping the original leading column order.
    if frames:
        df_barplot = pd.concat(frames)
        extra_cols = [c for c in df_barplot.columns if c not in base_cols]
        df_barplot = df_barplot.reindex(columns=base_cols + extra_cols)
    else:
        df_barplot = pd.DataFrame(columns=base_cols)
    df_barplot["quarter"] = [f"Q{year_quarter(date)}" for date in df_barplot.start_date]
    df_barplot = df_barplot.replace({"quarter": {"Q0": "DJF", "Q1": "MAM", "Q2": "JJA", "Q3": "SON"}})

    ax = sns.barplot(x="quarter", y=metric, hue="model", data=df_barplot, ci=95, capsize=0.1, palette={
        'QM-ECMWF': 'red',
        'LOESS-ECMWF': 'gold',
        'ABC-ECMWF': 'skyblue'
    })

    # Readable title, e.g. "us_tmp2m_1.5x1.5_12w" -> "U.S. Temperature, weeks 1-2".
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    ax.set_title(f"{fig_title}\n", weight='bold')
    ax.set(xlabel=None)
    ax.set_ylabel(metric.upper(), fontdict={'weight': 'bold'})
    # Only the weeks 1-2 panel keeps the legend and the y-axis label.
    if horizon == "12w":
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels, frameon=True, edgecolor='white', framealpha=1)
    else:
        legend = ax.get_legend()
        if legend is not None:  # no legend is drawn when there is no data
            legend.remove()
        ax.set(ylabel=None)
    dic_ylim = {"us_tmp2m_1.5x1.5_12w": (-0.1, 1),
                "us_tmp2m_1.5x1.5_34w": (-0.6, 0.5),
                "us_tmp2m_1.5x1.5_56w": (-0.6, 0.5),
                "us_precip_1.5x1.5_12w": (-0.1, 0.8),
                "us_precip_1.5x1.5_34w": (-0.45, 0.2),
                "us_precip_1.5x1.5_56w": (-0.45, 0.2),
               }
    ax.set(ylim=dic_ylim[task])
    # Distance (in points) between each bar and its value label.
    dic_labelpad = {"us_tmp2m_1.5x1.5_12w": 30,
                    "us_tmp2m_1.5x1.5_34w": 40,
                    "us_tmp2m_1.5x1.5_56w": 45,
                    "us_precip_1.5x1.5_12w": 30,
                    "us_precip_1.5x1.5_34w": 30,
                    "us_precip_1.5x1.5_56w": 30,
                   }
    for container in ax.containers:
        ax.bar_label(container, fmt="%.2f", fontsize=15, fontweight='bold', padding=dic_labelpad[task], rotation=90)

    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_quarterly_{'_'.join(model_names)}_{task}_{target_dates}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot

# Preview one panel: weeks 1-2 temperature BSS for the QM / LOESS baselines vs. ABC.
figure_model_names = ["shift_deb_quantile_ecmwf",   # Baselines
                      "shift_deb_loess_ecmwf",
                      "abcds_ecmwf"]                # Ensembles
figure_gt_id, figure_horizon = 'us_tmp2m_1.5x1.5', '12w'
figure_metric, figure_target_dates = 'bss', 'std_paper_forecast'

df_barplot = barplot_baselinesabc_bss(
    model_names=figure_model_names,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    metric=figure_metric,
    target_dates=figure_target_dates,
)

In [None]:
# Save Figure A7 source data: one Excel sheet per task with the BSS rows behind
# each baseline-vs-ABC barplot. Reuses figure_model_names from the preview cell above.
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metrics = ['bss']
figure_target_dates = 'std_paper_forecast'

fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a7-average_bss_baselines.xlsx")
# Rebuild the workbook from scratch on every run.
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, figure_metrics):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Use this cell's figure_target_dates rather than the unrelated notebook-global
    # target_dates, so the saved data always matches the declared parameters.
    df_barplot = barplot_baselinesabc_bss(model_names=figure_model_names,
                                          gt_id=gt_id,
                                          horizon=horizon,
                                          metric=metric,
                                          target_dates=figure_target_dates,
                                          show=False)
    # The first sheet creates the file ('w'); later sheets are appended ('a').
    writer_mode = 'a' if os.path.isfile(fig_filename) else 'w'
    with pd.ExcelWriter(fig_filename, mode=writer_mode) as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

## Figure A8: Baselines vs. ABC models
### CRPS barplots by season 

In [None]:
# Seaborn theme for the seasonal barplots below.
# sns.set (an alias of sns.set_theme) re-applies context, style and palette from
# scratch, so the set_context / set_theme / set_palette calls that used to precede
# it were overwritten without effect and have been removed. The final rcParams
# are unchanged.
sns.set(font_scale=1.5, rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75})
sns.set_style("whitegrid")

def barplot_baselinesabc_crps(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Draw, save and return a seasonal CRPS-style barplot for baseline vs. ABC models.

    For each model, the metric file of its selected submodel
    ``eval/metrics/<model>/submodel_forecasts/<submodel>/<task>/<metric>-<task>-<target_dates>.h5``
    is loaded; missing files are reported and skipped. Rows are grouped by season
    (``year_quarter`` 0-3 mapped to DJF/MAM/JJA/SON) and plotted with 95% CI error
    bars and vertical value labels (QM, LOESS and ABC hues). The figure is saved as
    a PDF under ``subseasonal_toolkit/viz``, made writable by all, and chown'ed to
    group ``sched_mit_hill``.

    Args:
        model_names: model identifiers; each must be a key of ``all_model_names``.
        gt_id: ground-truth id, e.g. ``"us_tmp2m_1.5x1.5"``.
        horizon: forecast horizon, e.g. ``"12w"``. Only the ``"12w"`` panel keeps
            the legend and the y-axis label.
        metric: metric name, used in the file name and as the plotted column.
        target_dates: target-date set name, e.g. ``"std_paper_forecast"``.
        show: if False, the figure is cleared and closed after saving.

    Returns:
        DataFrame of the loaded metric rows plus ``model`` and ``quarter`` columns.
    """
    task = f'{gt_id}_{horizon}'
    base_cols = ['start_date', metric, 'model']
    frames = []
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task, f'{metric}-{task}-{target_dates}.h5')
        if os.path.isfile(f):
            df = pd.read_hdf(f)
            df['model'] = all_model_names[m]
            frames.append(df)
        else:
            printf(f"Metrics file missing for {metric} {m} {task}")
    # DataFrame.append was removed in pandas 2.0. Collect the frames and concatenate
    # once, keeping the original leading column order.
    if frames:
        df_barplot = pd.concat(frames)
        extra_cols = [c for c in df_barplot.columns if c not in base_cols]
        df_barplot = df_barplot.reindex(columns=base_cols + extra_cols)
    else:
        df_barplot = pd.DataFrame(columns=base_cols)
    df_barplot["quarter"] = [f"Q{year_quarter(date)}" for date in df_barplot.start_date]
    df_barplot = df_barplot.replace({"quarter": {"Q0": "DJF", "Q1": "MAM", "Q2": "JJA", "Q3": "SON"}})

    ax = sns.barplot(x="quarter", y=metric, hue="model", data=df_barplot, ci=95, capsize=0.1, palette={
        'QM-ECMWF': 'red',
        'LOESS-ECMWF': 'gold',
        'ABC-ECMWF': 'skyblue'
    })

    # Readable title, e.g. "us_tmp2m_1.5x1.5_12w" -> "U.S. Temperature, weeks 1-2".
    fig_title = f"{task.replace('_','').replace('precip',' Precipitation').replace('tmp2m',' Temperature').replace('us','U.S.')}"
    fig_title = fig_title.replace('56w', ', weeks 5-6').replace('34w', ', weeks 3-4').replace('12w', ', weeks 1-2').replace('1.5x1.5', '')
    ax.set_title(f"{fig_title}\n", weight='bold')
    ax.set(xlabel=None)
    ax.set_ylabel(metric.upper(), fontdict={'weight': 'bold'})
    # Only the weeks 1-2 panel keeps the legend and the y-axis label.
    if horizon == "12w":
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels, frameon=True, edgecolor='white', framealpha=1)
    else:
        legend = ax.get_legend()
        if legend is not None:  # no legend is drawn when there is no data
            legend.remove()
        ax.set(ylabel=None)
    dic_ylim = {"us_tmp2m_1.5x1.5_12w": (-0.25, 3),
                "us_tmp2m_1.5x1.5_34w": (-0.25, 3),
                "us_tmp2m_1.5x1.5_56w": (-0.25, 3),
                "us_precip_1.5x1.5_12w": (-0.5, 25),
                "us_precip_1.5x1.5_34w": (-0.5, 25),
                "us_precip_1.5x1.5_56w": (-0.5, 25),
               }
    ax.set(ylim=dic_ylim[task])
    # Distance (in points) between each bar and its value label.
    dic_labelpad = {"us_tmp2m_1.5x1.5_12w": 18,
                    "us_tmp2m_1.5x1.5_34w": 25,
                    "us_tmp2m_1.5x1.5_56w": 27,
                    "us_precip_1.5x1.5_12w": 15,
                    "us_precip_1.5x1.5_34w": 20,
                    "us_precip_1.5x1.5_56w": 20,
                   }
    for container in ax.containers:
        ax.bar_label(container, fmt="%.2f", fontsize=18, fontweight='bold', padding=dic_labelpad[task], rotation=90)

    fig = ax.get_figure()
    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_quarterly_{'_'.join(model_names)}_{task}_{target_dates}.pdf"
    fig.savefig(out_file, bbox_inches='tight')
    subprocess.call("chmod a+w "+out_file, shell=True)
    subprocess.call("chown $USER:sched_mit_hill "+out_file, shell=True)
    print(f"\nFigure saved: {out_file}\n")
    if show is False:
        fig.clear()
        plt.close(fig)
    return df_barplot

# Preview one panel: weeks 1-2 temperature CRPS for the QM / LOESS baselines vs. ABC.
figure_model_names = ["shift_deb_quantile_ecmwf",   # Baselines
                      "shift_deb_loess_ecmwf",
                      "abcds_ecmwf"]                # Ensembles
figure_gt_id, figure_horizon = 'us_tmp2m_1.5x1.5', '12w'
figure_metric, figure_target_dates = 'crps', 'std_paper_forecast'

df_barplot = barplot_baselinesabc_crps(
    model_names=figure_model_names,
    gt_id=figure_gt_id,
    horizon=figure_horizon,
    metric=figure_metric,
    target_dates=figure_target_dates,
)


In [None]:
# Save Figure A8 source data: for every (gt_id, horizon) task, the barplot
# table of the models in `figure_model_names` (set in the previous cell) is
# written to its own sheet of one Excel workbook.
# Figure parameters
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metrics = ['crps']
figure_target_dates = 'std_paper_forecast'

fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a8-average_crps_baselines.xlsx")
# Remove any workbook left by a previous run so stale sheets are not kept
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, figure_metrics):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Use the target-date set chosen for this figure (figure_target_dates)
    df_barplot = barplot_baselinesabc_crps(model_names=figure_model_names,
                                           gt_id=gt_id,
                                           horizon=horizon,
                                           metric=metric,
                                           target_dates=figure_target_dates,
                                           show=False)
    # Create the workbook for the first task, then append one sheet per task
    with pd.ExcelWriter(fig_filename, mode='a' if os.path.isfile(fig_filename) else 'w') as writer:
        df_barplot.to_excel(writer, sheet_name=task, na_rep="NaN")

# A4. Spatial bias plots
### Bias map (lat_lon_error)
The cell below loads the per-grid-point bias metric (`lat_lon_error`) for every model; the bias maps themselves are drawn in the Figure A9–A11 cells.

In [None]:
# Load the per-grid-point bias (lat_lon_error) of every model used by the
# spatial bias maps below.
figure_target_dates = 'std_paper_forecast'
figure_metrics = ['lat_lon_error']
figure_models = [
    # Baselines
    "raw_cfsv2", "raw_ecmwf", "raw_subx_mean",
    # Standard de-biasing
    "deb_cfsv2", "deb_ecmwf", "deb_subx_mean",
    "deb_loess_ecmwf", "deb_loess_cfsv2",
    "deb_quantile_ecmwf", "deb_quantile_cfsv2",
    "nn-a",
    # Ensembles
    "abc_cfsv2", "abc_ecmwf", "abc_subx_mean",
]

# metric_dfs_rda: task name "<gt_id>_<horizon>" -> result of get_models_metric_lat_lon
metric_dfs_rda = {}
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=gt_id,
        horizon=horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_models,
    )


## Figure A9: SubX mean - Raw vs. ABC models
### Bias map (lat_lon_error)
This code produces maps to analyze bias for different models. 

In [None]:
# Figure A9: bias maps of raw vs. ABC SubX-mean forecasts
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_show = True

figure_model_names = ["raw_subx_mean", "abc_subx_mean"]
for gt_id in figure_gt_ids:
    # tmp2m tasks get a blue/red palette; all other (precipitation) tasks a
    # brown/green one, each with its own symmetric color range
    CB_colors_customized = (
        ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
        if 'tmp2m' in gt_id
        else ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
    )
    CB_minmax = (-4, 4) if 'tmp2m' in gt_id else (-20, 20)
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        CB_colors_customized=CB_colors_customized,
        CB_minmax=CB_minmax,
    )

In [None]:
# Save Figure A9 source data: one Excel sheet per (gt_id, horizon) task
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a9-spatial_bias_subx_mean.xlsx")
# Start from a fresh workbook so sheets from an earlier run are not kept
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Create the workbook on the first task, append a sheet afterwards
    with pd.ExcelWriter(fig_filename, mode='a' if os.path.isfile(fig_filename) else 'w') as writer:
        metric_dfs_rda[task][metric][figure_model_names].to_excel(writer, sheet_name=task, na_rep="NaN")


## Figure A10: Baselines vs. ABC models
### Bias map (lat_lon_error)
This code produces maps to analyze bias for different models. 

In [None]:
# Figure A10: bias maps of the baselines vs. ABC, drawn as two panels
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_show = True

figure_model_names_p1 = ["nn-a", "deb_loess_cfsv2", "deb_quantile_cfsv2", "abc_cfsv2"]
figure_model_names_p2 = ["deb_loess_ecmwf", "deb_quantile_ecmwf", "abc_ecmwf"]

# All gt_ids for panel 1 are plotted first, then all gt_ids for panel 2
for panel_model_names in (figure_model_names_p1, figure_model_names_p2):
    for gt_id in figure_gt_ids:
        # tmp2m tasks get a blue/red palette; other (precipitation) tasks brown/green
        CB_colors_customized = (
            ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
            if 'tmp2m' in gt_id
            else ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
        )
        CB_minmax = (-4, 4) if 'tmp2m' in gt_id else (-20, 20)
        plot_metric_maps(
            metric_dfs_rda,
            model_names=panel_model_names,
            gt_ids=[gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=CB_colors_customized,
            CB_minmax=CB_minmax,
        )

In [None]:
# Save Figure A10 source data: both panels' models, one sheet per task
figure_model_names = figure_model_names_p1 + figure_model_names_p2
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a10-spatial_bias_baselines.xlsx")
# Start from a fresh workbook so sheets from an earlier run are not kept
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Create the workbook on the first task, append a sheet afterwards
    with pd.ExcelWriter(fig_filename, mode='a' if os.path.isfile(fig_filename) else 'w') as writer:
        metric_dfs_rda[task][metric][figure_model_names].to_excel(writer, sheet_name=task, na_rep="NaN")


## Figure A11: Debiased vs. ABC models
### Bias map (lat_lon_error)
This code produces maps to analyze bias for different models. 

In [None]:
# Figure A11: bias maps of standard-debiased vs. ABC forecasts, two panels
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_show = True

figure_model_names_p1 = ["deb_cfsv2", "abc_cfsv2"]
figure_model_names_p2 = ["deb_ecmwf", "abc_ecmwf"]

# All gt_ids for panel 1 are plotted first, then all gt_ids for panel 2.
# The color ranges are narrower than in Figures A9/A10.
for panel_model_names in (figure_model_names_p1, figure_model_names_p2):
    for gt_id in figure_gt_ids:
        # tmp2m tasks get a blue/red palette; other (precipitation) tasks brown/green
        CB_colors_customized = (
            ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
            if 'tmp2m' in gt_id
            else ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
        )
        CB_minmax = (-1, 1) if 'tmp2m' in gt_id else (-5, 5)
        plot_metric_maps(
            metric_dfs_rda,
            model_names=panel_model_names,
            gt_ids=[gt_id],
            horizons=figure_horizons,
            metric=figure_metric,
            target_dates=figure_target_dates,
            mean_metric_df=figure_mean_metric_df,
            show=figure_show,
            scale_type='linear',
            CB_colors_customized=CB_colors_customized,
            CB_minmax=CB_minmax,
        )

In [None]:
# Save Figure A11 source data: both panels' models, one sheet per task
figure_model_names = figure_model_names_p1 + figure_model_names_p2
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a11-spatial_bias_deb.xlsx")
# Start from a fresh workbook so sheets from an earlier run are not kept
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, figure_horizons, [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Create the workbook on the first task, append a sheet afterwards
    with pd.ExcelWriter(fig_filename, mode='a' if os.path.isfile(fig_filename) else 'w') as writer:
        metric_dfs_rda[task][metric][figure_model_names].to_excel(writer, sheet_name=task, na_rep="NaN")

# A5. Overall variable importance

## Figure A12: Variable importance
See subseasonal_toolkit/viz/cohort_shapley_all.ipynb

# A6. ABC schematic 

## Figure A13: Schematic of ABC input/output
### Anomalies maps (lat_lon_anom)
This code produces maps to analyze anomalies for different models. 

In [None]:
# Load per-grid-point anomalies (lat_lon_anom) for the single target date
# shown in the ABC schematic (Figure A13), including ground truth ("gt").
figure_target_dates = '20201218'
figure_metrics = ['lat_lon_anom']
figure_gt_ids = ["us_precip_1.5x1.5"]
figure_horizons = horizons
figure_model_names = [
    "raw_cfsv2",
    "deb_cfsv2",
    "abc_cfsv2",
    "tuned_cfsv2pp",
    "perpp_cfsv2",
    "tuned_climpp",
    "gt",
]

# metric_dfs_rda is rebuilt here: task name "<gt_id>_<horizon>" -> anomaly metrics
metric_dfs_rda = {}
for gt_id, horizon in product(figure_gt_ids, figure_horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(
        gt_id=gt_id,
        horizon=horizon,
        target_dates=figure_target_dates,
        metrics=figure_metrics,
        model_names=figure_model_names,
    )


In [None]:
# Figure A13: anomaly maps feeding into / produced by ABC for one target date
figure_gt_ids = ["us_precip_1.5x1.5"]
figure_horizons = horizons
figure_metric = 'lat_lon_anom'
figure_mean_metric_df = None
figure_show = True

figure_model_names = [
    "raw_cfsv2",
    "deb_cfsv2",
    "abc_cfsv2",
    "tuned_cfsv2pp",
    "perpp_cfsv2",
    "tuned_climpp",
    "gt",
]
for gt_id in figure_gt_ids:
    # tmp2m tasks get a blue/red palette; other (precipitation) tasks brown/green
    CB_colors_customized = (
        ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
        if 'tmp2m' in gt_id
        else ['orangered', 'darkorange', "white", 'forestgreen', 'darkgreen']
    )
    CB_minmax = (-4, 4) if 'tmp2m' in gt_id else (-10, 10)
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        CB_colors_customized=CB_colors_customized,
        CB_minmax=CB_minmax,
    )

In [None]:
# Save Figure A13 source data: only the 34w horizon is exported
fig_dir = os.path.join("subseasonal_toolkit", "viz", "abc_figures_source_data")
if not os.path.isdir(fig_dir):
    make_directories(fig_dir)
fig_filename = os.path.join(fig_dir, "fig_a13-schematic_abc_anoms.xlsx")
# Start from a fresh workbook so sheets from an earlier run are not kept
if os.path.isfile(fig_filename):
    os.remove(fig_filename)

for gt_id, horizon, metric in product(figure_gt_ids, ["34w"], [figure_metric]):
    task = f"{gt_id}_{horizon}"
    printf(f"Processing {task} {metric}")
    # Create the workbook on the first task, append a sheet afterwards
    with pd.ExcelWriter(fig_filename, mode='a' if os.path.isfile(fig_filename) else 'w') as writer:
        metric_dfs_rda[task][metric][figure_model_names].to_excel(writer, sheet_name=task, na_rep="NaN")
