# ABC Plots

In [None]:
""" 
Generates ABC result figures for 

Adaptive Bias Correction for Improved Subseasonal Forecasting

Soukayna Mouatadid, Paulo Orenstein, Genevieve Flaspohler, 
Judah Cohen, Miruna Oprescu, Ernest Fraenkel, and Lester Mackey. 
"""
# Run the notebook from the base repository directory so that relative paths
# such as "eval/metrics/..." resolve. Failure is reported, not fatal.
import os
import sys

try:
    os.chdir(f"/home/{os.environ['USER']}/forecast_rodeo_ii/")
except Exception as chdir_error:
    print(f"Warning: unable to change directory; {chdir_error!r}")
    
%load_ext autoreload
%autoreload 2
%matplotlib inline    
    
import itertools
import importlib
import subprocess
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import product
from functools import partial
from datetime import datetime
from IPython.display import Markdown, display

import copy
import pdb
import calendar 
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib   
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap

from subseasonal_toolkit.utils.experiments_util import pandas2hdf
from subseasonal_toolkit.utils.general_util import printf
from subseasonal_toolkit.utils.eval_util import get_target_dates, score_to_mean_rmse, contest_quarter_start_dates, contest_quarter
from subseasonal_toolkit.utils.models_util import get_selected_submodel_name

from viz_util_abc import get_metrics_df, plot_metric_maps, get_models_metric_lat_lon, plot_metric_maps, add_groupby_cols, all_model_names


# Seaborn defaults: 8x6 inch figures at the base font scale
sns.set(font_scale=1, rc={'figure.figsize': (8, 6)})

# Directory for saving output
out_dir = f"/home/{os.environ['USER']}/forecast_rodeo_ii/subseasonal_toolkit/viz"


In [None]:
#
# Full set of regions, times, and tasks to evaluate
#
metrics = ["rmse", "skill", "score"]

# Every region comes as a (temperature, precipitation) pair of ground-truth ids
_gt_variables = ("tmp2m", "precip")
contest_gt_ids = [f"contest_{variable}" for variable in _gt_variables]
us_gt_ids = [f"us_{variable}" for variable in _gt_variables]
east_gt_ids = [f"east_{variable}" for variable in _gt_variables]
us_1_5_gt_ids = [f"us_{variable}_1.5x1.5" for variable in _gt_variables]

# All ground truth ids
gt_ids = contest_gt_ids + us_gt_ids

horizons = ["12w", "34w", "56w"]
target_eval_dates = ["std_paper", "std_contest"]

# The full set of models we evaluate in at least one experiment.
# Every entry needs a trailing comma: adjacent string literals are silently
# concatenated by Python (a missing comma once turned 'ecmwf' and
# 'tuned_climpp' into the bogus entry 'ecmwftuned_climpp').
all_models = [
    # Raw Baselines
    'raw_cfsv2',
    # Baselines
    "climatology",
    'deb_cfsv2',
    'persistence',
    # ECMWF
    'ecmwf',
    # Toolkit
    'tuned_climpp',
    'tuned_cfsv2pp',
    'perpp',
    # Learning
    'autoknn',
    'informer',
    'tuned_localboosting',
    'multillr',
    'nbeats',
    'prophet',
    'salient',
    'tuned_salient2',
    # Ensembles
    'linear_ensemble',
    'online_learning',
]

# Model groups shared by several experiments (fresh lists are built below by
# concatenation, so no experiment list aliases another)
_shared_baselines = ["climatology", "deb_cfsv2", "persistence"]
_shared_toolkit = ["tuned_climpp", "tuned_cfsv2pp", "perpp"]

# Main experiment model names
main_experiment_models = _shared_baselines + _shared_toolkit + [
    # Learning
    "autoknn", "tuned_localboosting", "multillr", "nbeats", "informer",
    "prophet", "tuned_salient2",
    # Ensembles
    "linear_ensemble", "online_learning",
]

# Rodeo experiment model names
rodeo_experiment_models = _shared_baselines + _shared_toolkit + [
    # Learning
    "autoknn", "tuned_localboosting", "multillr", "prophet", "tuned_salient2",
    # Ensembles
    "linear_ensemble_localFalse_dynamicFalse_stepFalse_LtCtD",
    "linear_ensemble_localFalse_dynamicFalse_stepFalse_AMLPtCtDtKtS",
    "online_learning-ah_rpNone_R1_recent_g_SC_LtCtD",
    "online_learning-ah_rpNone_R1_recent_g_SC_AMLPtCtDtKtS",
]

# Salient experiment model names: baseline, toolkit and learning model
salient_experiment_models = ["deb_cfsv2", "tuned_cfsv2pp", "tuned_salient2"]

# ECMWF experiment model names
ecmwf_experiment_models = _shared_baselines + _shared_toolkit + [
    # ECMWF
    "ecmwf-years20_leads15-15_lossmse_forecastc_debiasp+c",
    "ecmwf-years20_leads15-15_lossmse_forecastp_debiasp+c",
    # Ensembles
    "online_learning",
    "linear_ensemble",
]

# De-biasing experiment model names: every raw dynamical model, followed by
# the ABC-corrected version of each, in the same order
_dynamical_models = ["ccsm4", "cfsv2", "geos_v2p1", "nesm", "fimr1p1", "gefs", "gem", "ecmwf"]
debias_experiment_models = (
    [f"raw_{name}" for name in _dynamical_models]
    + [f"abc_{name}" for name in _dynamical_models]
)

In [None]:
#
# Dictionaries mapping task names to their display names
#

# Display label for each variable/horizon suffix; identical for every region
_task_suffix_labels = {
    "tmp2m_34w": "Temp. weeks 3-4",
    "tmp2m_56w": "Temp. weeks 5-6",
    "precip_34w": "Precip. weeks 3-4",
    "precip_56w": "Precip. weeks 5-6",
}

east_tasks = {f"east_{suffix}": label for suffix, label in _task_suffix_labels.items()}
contest_tasks = {f"contest_{suffix}": label for suffix, label in _task_suffix_labels.items()}
us_tasks = {f"us_{suffix}": label for suffix, label in _task_suffix_labels.items()}


## Read in all metrics for all tasks and all models
Reads metrics, generates a summary of missing data, and produces the `all_metrics` dictionary to be used in further analysis. 

In [None]:
"""
Generate a dictionary with metric values for all models and every combination of gt_id, 
horizon, and target dates
"""

# all_metrics maps (metric, task, target_dates) -> DataFrame returned by
# get_metrics_df, with the yearly/quarterly groupby columns added.
all_metrics = {}

# (metric, gt_id, horizon, target_dates) combinations to load. Other
# experiments used, for example:
#   product(['rmse', 'skill'], us_gt_ids, horizons, ['std_paper'])
#   product(['rmse'], contest_gt_ids, horizons, ['std_contest'])
#   product(['rmse'], contest_gt_ids, horizons, ['std_paper'])
#   product(['rmse', 'skill'], us_1_5_gt_ids, horizons, ['std_ecmwf'])
metric_task_combos = product(['rmse', 'skill'], us_1_5_gt_ids, horizons, ['std_paper_forecast'])

# The ABC figures are built from the raw_*/abc_* models, so the
# per-experiment model selection below is overridden with
# debias_experiment_models. Set to False to use the per-experiment lists.
use_debias_models = True

for metric, gt_id, horizon, target_dates in metric_task_combos:
    # Set model names for this experiment
    if 'us' in gt_id:
        model_names = ecmwf_experiment_models if '1.5x1.5' in gt_id else debias_experiment_models
        model_names_str = 'ecmwf_experiment_models' if '1.5x1.5' in gt_id else 'debias_experiment_models'
    elif 'contest' in gt_id:
        model_names = rodeo_experiment_models if 'contest' in target_dates else salient_experiment_models
        model_names_str = 'rodeo_experiment_models' if 'contest' in target_dates else 'salient_experiment_models'
    else:
        model_names = all_models
        model_names_str = 'all_models'

    if use_debias_models:
        # Keep the displayed label in sync with the list actually loaded
        model_names = debias_experiment_models
        model_names_str = 'debias_experiment_models'

    task = f"{gt_id}_{horizon}"
    display(Markdown(f"### {model_names_str}: {metric}, {task}, {target_dates}"))

    # Get all metrics; None means no models exist for this task
    df = get_metrics_df(gt_id, horizon, metric, target_dates, model_names=model_names)
    if df is None:
        continue

    # Add yearly and quarterly columns to the dataframe
    df = add_groupby_cols(df, horizon=horizon)

    key = (metric, task, target_dates)
    all_metrics[key] = copy.copy(df)

    if metric in ['rmse', 'skill']:
        try:
            missing_df = df.loc[df[model_names].isnull().any(axis=1), :]
        except KeyError:
            # Some requested models have no column at all; fall back to
            # checking every column for missing values.
            missing_df = df.loc[df.isnull().any(axis=1), :]
        if len(missing_df) > 0:
            display(Markdown("#### Missing metrics"))
            display(missing_df)
        else:
            printf("All metrics present.")

## Figure 1: Raw vs. ABC models
### Bar plots 
This code produces bar plots comparing the skill of raw and ABC-corrected models over the evaluation period 2018-2021.

In [None]:
# Global seaborn theme for the figures below. A single set_theme call
# replaces the previous chain set_context/set_theme/set_palette/set/set_style.
# sns.set() resets context, style and palette, so the earlier font_scale=2.5,
# "Paired" palette and linewidth 0.5 never took effect. Only these settings
# (plus the final whitegrid style) did.
sns.set_theme(
    context="notebook",
    style="whitegrid",
    palette="deep",
    font_scale=1.5,
    rc={'font.weight': 'bold', 'figure.facecolor': 'white', "lines.linewidth": 0.75},
)

def _rawabc_task_title(task):
    """Return the display title for *task*, e.g. 'U.S. Precipitation, weeks 5-6'.

    Drops underscores, expands the variable codes and the 'us' region, turns
    horizon codes into week ranges and removes the '1.5x1.5' resolution tag.
    """
    title = (task.replace('_', '')
             .replace('precip', ' Precipitation')
             .replace('tmp2m', ' Temperature')
             .replace('us', 'U.S.'))
    return (title.replace('56w', ', weeks 5-6')
            .replace('34w', ', weeks 3-4')
            .replace('12w', ', weeks 1-2')
            .replace('1.5x1.5', ''))


def _load_rawabc_metrics(model_names, gt_id, horizon, metric, target_dates):
    """Read each model's metric file for one task and tag rows for plotting.

    Adds a 'debias_method' column ('Dynamical' for raw_* models, 'ABC'
    otherwise) and a 'model' column holding the display name of the
    corresponding raw_* model. Missing files are reported and skipped.

    Returns:
        The concatenated DataFrame, or None if no metrics file was found.
    """
    task = f'{gt_id}_{horizon}'
    frames = []
    for m in model_names:
        sn = get_selected_submodel_name(m, gt_id, horizon)
        f = os.path.join('eval', 'metrics', m, 'submodel_forecasts', sn, task,
                         f'{metric}-{task}-{target_dates}.h5')
        if not os.path.isfile(f):
            printf(f"Metrics file missing for {metric} {m} {task}")
            continue
        df = pd.read_hdf(f)
        df['debias_method'] = 'Dynamical' if m.split('_')[0] == 'raw' else 'ABC'
        # raw_X and abc_X both get raw_X's display name so their bars share an x position
        df['model'] = all_model_names["raw_" + m.partition('_')[2]]
        frames.append(df)
    # DataFrame.append was removed in pandas 2.0; concatenate once instead
    return pd.concat(frames) if frames else None


def barplot_rawabc(model_names, gt_id, horizon, metric, target_dates, show=True):
    """Bar plot of *metric* for raw (dynamical) vs. ABC-corrected models on one task.

    Bars are grouped by raw-model display name, with hue separating
    'Dynamical' (raw_*) from 'ABC' models; error bars are 95% confidence
    intervals. The figure is saved as a PDF under subseasonal_toolkit/viz/,
    relative to the current working directory.

    Args:
        model_names: model identifiers such as 'raw_cfsv2' / 'abc_cfsv2'.
        gt_id: ground-truth id, e.g. 'us_precip_1.5x1.5'.
        horizon: forecast horizon, e.g. '12w', '34w' or '56w'.
        metric: metric name used in the metrics file name, e.g. 'skill'.
        target_dates: target-date set name, e.g. 'std_paper_forecast'.
        show: if False, the figure is cleared and closed after saving.

    Returns:
        None. Nothing is plotted or saved when no metrics file is found.
    """
    # Models excluded at the 56w horizon (listed as missing weeks 5-6 results)
    models_missing_56w = {
        "raw_fimr1p1", "raw_gefs", "raw_gem",
        "abc_fimr1p1", "abc_gefs", "abc_gem",
    }
    if horizon == '56w':
        model_names = [m for m in model_names if m not in models_missing_56w]

    task = f'{gt_id}_{horizon}'
    df_barplot = _load_rawabc_metrics(model_names, gt_id, horizon, metric, target_dates)
    if df_barplot is None:
        printf(f"No metrics found for {metric} {task} {target_dates}; skipping bar plot")
        return

    # Draw on a fresh figure so repeated calls in one cell do not overlay
    fig, ax = plt.subplots()
    sns.barplot(x="model", y=metric, hue="debias_method", data=df_barplot, ci=95,
                capsize=0.1, palette={'Dynamical': 'red', 'ABC': 'skyblue'}, ax=ax)
    ax.set_title(f"{_rawabc_task_title(task)}\n", weight='bold')

    # Fewer models are plotted at 56w, presumably leaving room for larger labels
    tick_size = 16 if '56w' in horizon else 11
    ax.set_xticklabels(ax.get_xticklabels(), fontdict={'size': tick_size})
    ax.set(xlabel=None)
    ax.set_ylabel('Skill' if metric == 'skill' else metric.upper(), fontdict={'weight': 'bold'})
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels, frameon=True, edgecolor='white', framealpha=1)

    # The fixed y-limits are skill ranges, so only apply them when plotting skill
    if metric == 'skill' and target_dates.startswith('std_'):
        short_lead = '12' in horizon
        if 'precip' in gt_id:
            ax.set(ylim=(-0.025, 0.65) if short_lead else (-0.025, 0.3))
        elif 'tmp2m' in gt_id:
            ax.set(ylim=(-0.03, 0.9) if short_lead else (-0.03, 0.5))

    out_file = f"subseasonal_toolkit/viz/barplot_{metric}_{task}_{target_dates}.pdf"
    fig.savefig(out_file)
    # Pass argument lists (no shell) so the path is never reinterpreted by a shell
    subprocess.call(["chmod", "a+w", out_file])
    subprocess.call(["chown", f"{os.environ.get('USER', '')}:sched_mit_hill", out_file])
    print(f"\nFigure saved: {out_file}\n")
    if not show:
        fig.clear()
        plt.close(fig)

# Figure 1 models: all raw dynamical models, then their ABC-corrected versions
_figure1_base_models = ["ccsm4", "cfsv2", "geos_v2p1", "nesm", "fimr1p1", "gefs", "gem", "ecmwf"]
figure_models = [f"{prefix}_{base}" for prefix in ("raw", "abc") for base in _figure1_base_models]

# Task and evaluation settings for the bar plot
gt_id, horizon = 'us_precip_1.5x1.5', '56w'
metric, target_dates = 'skill', 'std_paper_forecast'

barplot_rawabc(model_names=figure_models, gt_id=gt_id, horizon=horizon,
               metric=metric, target_dates=target_dates)

## Figure 2: Raw vs. Debiased and ABC models
### lat_lon_skill map
This code produces maps of spatial skill (`lat_lon_skill`) for the raw, debiased, and ABC-corrected versions of CFSv2 and ECMWF.

In [None]:
figure_target_dates = 'std_paper_forecast'
# Raw, standard-debiased ("deb") and ABC versions of CFSv2 and ECMWF
figure_models = [f"{version}_{base}" for version in ("raw", "deb", "abc") for base in ("cfsv2", "ecmwf")]

# RDA: models for which Raw, Debiased and ABC versions are available.
# Keyed by task; values are whatever get_models_metric_lat_lon returns.
metric_dfs_rda = {}
for gt_id in us_1_5_gt_ids:
    for horizon in horizons:
        task = f"{gt_id}_{horizon}"
        display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
        metric_dfs_rda[task] = get_models_metric_lat_lon(
            gt_id=gt_id,
            horizon=horizon,
            target_dates=figure_target_dates,
            metrics=['lat_lon_skill', 'skill'],
            model_names=figure_models,
        )

In [None]:
# Spatial skill maps for raw, debiased and ABC-corrected CFSv2
figure_model_names = ["raw_cfsv2", "deb_cfsv2", "abc_cfsv2"]

# Figure parameter values
figure_gt_ids, figure_horizons = us_1_5_gt_ids, horizons
figure_metric = 'lat_lon_skill'
figure_mean_metric_df = metric_dfs_rda
figure_show = True

# Colorbar palette for the skill maps; a fresh list is passed to each call
skill_map_colors = ("white", "#dede00", "#ff7f00", "blueviolet", "indigo",
                    "yellowgreen", "lightgreen", "darkgreen")

# One plot_metric_maps call per ground-truth variable
for gt_id in figure_gt_ids:
    plot_metric_maps(metric_dfs_rda,
                     model_names=figure_model_names,
                     gt_ids=[gt_id],
                     horizons=figure_horizons,
                     metric=figure_metric,
                     target_dates=figure_target_dates,
                     mean_metric_df=figure_mean_metric_df,
                     show=figure_show,
                     scale_type='linear',
                     CB_colors_customized=list(skill_map_colors),
                     CB_minmax=(0, 85))

In [None]:
# Spatial skill maps for raw, debiased and ABC-corrected ECMWF
figure_model_names = ["raw_ecmwf", "deb_ecmwf", "abc_ecmwf"]

# Figure parameter values
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_skill'
figure_mean_metric_df = metric_dfs_rda
figure_show = True

for gt_id in figure_gt_ids:
    # Colorbar palette for the skill maps, rebuilt for every call
    skill_colors = ["white", "#dede00", "#ff7f00", "blueviolet", "indigo",
                    "yellowgreen", "lightgreen", "darkgreen"]
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        CB_colors_customized=skill_colors,
        CB_minmax=(0, 85),
    )

## Figure 3: Raw vs. ABC models
### Bias map (lat_lon_error)
This code produces maps of the bias (`lat_lon_error`) of the raw and ABC-corrected versions of CFSv2 and ECMWF.

In [None]:
figure_target_dates = 'std_paper_forecast'
# Raw CFSv2 baseline and its ABC-corrected version
figure_models = ["raw_cfsv2", "abc_cfsv2"]

# Bias (lat_lon_error) metrics for every U.S. 1.5x1.5 task, keyed by task
metric_dfs_rda = {}
for gt_id in us_1_5_gt_ids:
    for horizon in horizons:
        task = f"{gt_id}_{horizon}"
        display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
        metric_dfs_rda[task] = get_models_metric_lat_lon(
            gt_id=gt_id,
            horizon=horizon,
            target_dates=figure_target_dates,
            metrics=['lat_lon_error'],
            model_names=figure_models,
        )

In [None]:
# Bias maps for raw vs. ABC-corrected CFSv2
figure_model_names = ["raw_cfsv2", "abc_cfsv2"]

# Figure parameter values
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_show = True

for gt_id in figure_gt_ids:
    # Temperature tasks get a blue/red diverging scale; all others a brown/green one
    is_tmp2m = 'tmp2m' in gt_id
    CB_colors_customized = (
        ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
        if is_tmp2m
        else ['saddlebrown', 'peru', "white", 'yellowgreen', 'green']
    )
    CB_minmax = (-4, 4) if is_tmp2m else (-15, 15)
    plot_metric_maps(metric_dfs_rda,
                     model_names=figure_model_names,
                     gt_ids=[gt_id],
                     horizons=figure_horizons,
                     metric=figure_metric,
                     target_dates=figure_target_dates,
                     mean_metric_df=figure_mean_metric_df,
                     show=figure_show,
                     scale_type='linear',
                     CB_colors_customized=CB_colors_customized,
                     CB_minmax=CB_minmax)

In [None]:
# Load bias (lat_lon_error) metrics for raw and ABC-corrected ECMWF on every
# U.S. 1.5x1.5 task. The former `if True:` wrapper was a no-op and is removed.
figure_target_dates = 'std_paper_forecast'  # e.g. 'cold_gl' for the Great Lakes cold wave
figure_models = [
    # Baseline
    "raw_ecmwf",
    # ABC-corrected
    "abc_ecmwf",
]

# Keyed by task; values are whatever get_models_metric_lat_lon returns
metric_dfs_rda = {}
for gt_id, horizon in product(us_1_5_gt_ids, horizons):
    task = f"{gt_id}_{horizon}"
    display(Markdown(f"#### Getting metrics for {gt_id} {horizon}"))
    metric_dfs_rda[task] = get_models_metric_lat_lon(gt_id=gt_id, horizon=horizon, target_dates=figure_target_dates, metrics=['lat_lon_error'], model_names=figure_models)
        
        
# Bias maps for raw vs. ABC-corrected ECMWF
figure_model_names = ["raw_ecmwf", "abc_ecmwf"]

# Figure parameter values
figure_gt_ids = us_1_5_gt_ids
figure_horizons = horizons
figure_metric = 'lat_lon_error'
figure_mean_metric_df = None
figure_show = True

for gt_id in figure_gt_ids:
    # Non-temperature tasks use a brown/green scale; temperature a blue/red one
    if 'tmp2m' not in gt_id:
        CB_colors_customized = ['saddlebrown', 'peru', "white", 'yellowgreen', 'green']
        CB_minmax = (-15, 15)
    else:
        CB_colors_customized = ['blue', 'dodgerblue', 'lightskyblue', "white", 'red', 'firebrick', 'darkred']
        CB_minmax = (-4, 4)
    plot_metric_maps(
        metric_dfs_rda,
        model_names=figure_model_names,
        gt_ids=[gt_id],
        horizons=figure_horizons,
        metric=figure_metric,
        target_dates=figure_target_dates,
        mean_metric_df=figure_mean_metric_df,
        show=figure_show,
        scale_type='linear',
        CB_colors_customized=CB_colors_customized,
        CB_minmax=CB_minmax,
    )
