# Imports

In [None]:
proj_dir='/path/to/main_project_folder/' # edit this line

import matplotlib.pyplot as plt
import matplotlib.font_manager
import seaborn
import importlib
import xarray as xr
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.cm as cm
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import seaborn as sns
from os.path import exists
import importlib
from matplotlib import colors
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator, MaxNLocator
from PIL import Image
from shapely.geometry.polygon import Polygon
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.neighbors import KernelDensity
import sys
sys.path.append(proj_dir)
from project_utils import parameters as param
from project_utils import load_region
from project_utils import prepare_inputs
importlib.reload(param)
importlib.reload(prepare_inputs)
importlib.reload(load_region)

# Set figure constants

In [None]:
# Base font size; other font sizes in this notebook are offsets from it.
fs = 15
w = 200
ninputs = 2
# When True, the soil-moisture samples assembled below are restricted to
# non-winter months (see the SNOWFREE filtering loop).
SNOWFREE = True
# Lag whose saved model predictions define each region's train/unseen years.
nlag = 1

# Canonical spelling of the bundled font. Matplotlib compares family names
# case-insensitively, so the previous 'DeJavu Serif' spelling resolved to the
# same font; this just states the intent correctly.
plt.rcParams['font.family'] = 'DejaVu Serif'
# NOTE(review): 'font.serif' is only consulted when font.family is the generic
# 'serif' alias; with a concrete family set above it has no effect.
plt.rcParams['font.serif'] = ['Helvetica']

skill_cmap = "cividis"
pdp_cmap = "bone"
pnorm = 0.45
# Box style for the bold panel letters (a, b, c, ...).
sublabel_bbox = dict(boxstyle="round", pad=0.2, fc="white", ec="k", lw=2)
tick_fs = fs - 1
win = 400
min_per = 1
clip = 15

# Plot colour for each region, keyed by the region identifiers used throughout
# the notebook (the same keys as the entries of region_str_tpl).
c_vec = {
    'southcentral_north_america': cm.cool(0.9),
    'southeastern_north_america': cm.cividis(0.8),
    'northcentral_north_america': 'darkslategray',
    'southwestern_europe': cm.tab20(0),
    'western_europe': 'maroon',
    'central_europe': 'limegreen',
    'eastern_europe': cm.tab20(12),
    'northeastern_europe': 'k',
    'northeastern_asia': 'tab:cyan',
    'southeastern_asia': 'darkgreen',
    'northsouthern_south_america': 'tan',
    'southsouthern_south_america': 'darkviolet',
    'southeastern_africa': 'orange',
    'southwestern_africa': 'teal',
    'southeastern_australia': 'tab:gray',
    'southwestern_australia': 'firebrick',
}

# 0-based day-of-year on which each calendar month starts (non-leap year),
# with one-letter and three-letter month labels.
month_doy = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334]
month_str = ['J', 'F', 'M', 'A', 'M', 'J', 'J', 'A', 'S', 'O', 'N', 'D']
month_str_long = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

# Month-start offsets for an axis that begins in July (used for the
# southern-hemisphere climatology panels).
# NOTE(review): counting Jul..May gives 335 for the start of June, not 334;
# the one-day difference only shifts the last shading band/tick slightly.
shifted_month_doy = [0, 31, 62, 92, 123, 153, 184, 215, 243, 274, 304, 334]
shifted_month_str = ['J', 'A', 'S', 'O', 'N', 'D', 'J', 'F', 'M', 'A', 'M', 'J']
shifted_month_str_long = ['Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']

# Integer midpoint of each month (where its tick label is centred); the final
# month is closed off at day 366.
centered_month_doy = [
    (start + nxt) // 2
    for start, nxt in zip(month_doy, month_doy[1:] + [366])
]
centered_shifted_month_doy = [
    (start + nxt) // 2
    for start, nxt in zip(shifted_month_doy, shifted_month_doy[1:] + [366])
]

hemispheres = ['north', 'south']

# Continents drawn for each hemisphere, in figure-row order.
continents = {
    'north': ['north_america', 'europe', 'asia'],
    'south': ['south_america', 'africa', 'australia'],
}

# Map extent per continent as [lon_min, lon_max, lat_min, lat_max], the order
# passed to cartopy's set_extent.
cont_bounds = {
    'north_america': [-134., -65., 15., 55.],
    'south_america': [-99., -30., -60., -20.],
    'europe': [-14., 55., 31., 71.],
    'africa': [-9., 60., -43., -3.],
    'asia': [71., 140., 14., 54.],
    'australia': [101., 170., -47., -7.],
}
ext_adjust = [6, -6, 1.5, -6]

# Region identifiers belonging to each continent, in drawing order.
region_str_tpl = {
    'north_america': ['southcentral_north_america', 'southeastern_north_america', 'northcentral_north_america'],
    'europe': ['southwestern_europe', 'western_europe', 'central_europe', 'eastern_europe', 'northeastern_europe'],
    'asia': ['northeastern_asia', 'southeastern_asia'],
    'south_america': ['northsouthern_south_america', 'southsouthern_south_america'],
    'africa': ['southwestern_africa', 'southeastern_africa'],
    'australia': ['southwestern_australia', 'southeastern_australia'],
}

# Short labels parallel to region_str_tpl.
# NOTE(review): several entries are still raw identifiers (e.g.
# 'southwestern_europe', 'W. northeastern_europe') — confirm intended text.
reg_abbrev_labels = {
    'north_america': ['S. Plains', 'southeastern_north_america', 'northcentral_north_america'],
    'europe': ['southwestern_europe', 'western_europe', 'central_europe', 'eastern_europe', 'W. northeastern_europe'],
    'asia': ['N. China', 'S. China'],
    'south_america': ['C. Argentina', 'S. Argentina'],
    'africa': ['SW. Africa', 'SE. Africa'],
    'australia': ['W. Australia', 'E. Australia'],
}

# Two-line display name of every region (descriptor on the first line,
# continent on the second).
reg_labels_split = {
    'southcentral_north_america': 'southcentral\nNorth America',
    'southeastern_north_america': 'southeastern\nNorth America',
    'northcentral_north_america': 'northcentral\nNorth America',
    'southwestern_europe': 'southwestern\nEurope',
    'western_europe': 'western\nEurope',
    'central_europe': 'central\nEurope',
    'eastern_europe': 'eastern\nEurope',
    'northeastern_europe': 'northeastern\nEurope',
    'northeastern_asia': 'northeastern\nAsia',
    'southeastern_asia': 'southeastern\nAsia',
    'northsouthern_south_america': 'north-southern\nSouth America',
    'southsouthern_south_america': 'south-southern\nSouth America',
    'southwestern_africa': 'southwestern\nAfrica',
    'southeastern_africa': 'southeastern\nAfrica',
    'southwestern_australia': 'southwestern\nAustralia',
    'southeastern_australia': 'southeastern\nAustralia',
}

# Same two-line names with one space of padding on each side of every line.
reg_labels_split_padded = {
    key: '\n'.join(f' {line} ' for line in text.split('\n'))
    for key, text in reg_labels_split.items()
}

# Single-line versions of the same names.
reg_labels = {key: text.replace('\n', ' ') for key, text in reg_labels_split.items()}

# Single-line names with shortened continent names (and reworded South
# America descriptors).
reg_labels_abbrev = {
    'southcentral_north_america': 'southcentral N. America',
    'southeastern_north_america': 'southeastern N. America',
    'northcentral_north_america': 'northcentral N. America',
    'southwestern_europe': 'southwestern Europe',
    'western_europe': 'western Europe',
    'central_europe': 'central Europe',
    'eastern_europe': 'eastern Europe',
    'northeastern_europe': 'northeastern Europe',
    'northeastern_asia': 'northeastern Asia',
    'southeastern_asia': 'southeastern Asia',
    'northsouthern_south_america': 'central southern S. America',
    'southsouthern_south_america': 'southern southern S. America',
    'southwestern_africa': 'southwestern Africa',
    'southeastern_africa': 'southeastern Africa',
    'southwestern_australia': 'southwestern Australia',
    'southeastern_australia': 'southeastern Australia',
}

# Per-region nudges (degrees, in map data coordinates) added to the centre of
# each region's box when placing its map label. Every region starts at zero.
lon_adjust = {
    region_str: 0
    for hem in hemispheres
    for cont in continents[hem]
    for region_str in region_str_tpl[cont]
}
lat_adjust = dict.fromkeys(lon_adjust, 0)

# Hand-tuned offsets for labels that would otherwise overlap or spill out.
lon_adjust.update({
    'northsouthern_south_america': -1,
    'southwestern_europe': 4,
    'northeastern_europe': -3.4,
    'southeastern_africa': 3.75,
    'southwestern_africa': -2.85,
    'southcentral_north_america': -3.5,
    'southeastern_asia': 0.5,
    'northeastern_asia': 0,
    'southeastern_north_america': 2,
    'southwestern_australia': -3.3,
    'southeastern_australia': 3.75,
})
lat_adjust.update({
    'southcentral_north_america': 2.5,
    'northeastern_asia': 1.25,
    'southeastern_north_america': 1.8,
    'southwestern_australia': 1,
    'southeastern_australia': 1.2,
})

# Lags for which soil-moisture predictor vectors are built in the next loop
# (presumably days, since each soil-moisture row is one daily sample).
nlag_list = [0, 1, 2, 3, 7, 14, 30]

# Per-region containers; the x_vec* dicts are filled per lag by the next loop.
ice_y = {}
x_vec = {}
x_vec_unseen_bool = {}
x_vec_train_bool = {}
years = {}
train_years = {}
for hem in hemispheres:
    for cont in continents[hem]:
        for region_str in region_str_tpl[cont]:
            # The saved predictions at lag `nlag` record which calendar years
            # were used for training and which were held out ("unseen").
            pred_temps = pd.read_csv(f"../processed_data_NCEP/{region_str}/model_predictions_lag{nlag}.csv")
            pred_temps['date'] = pd.to_datetime(pred_temps['date'])
            pred_years = pred_temps['date'].dt.year
            test_years = pred_years[pred_temps['set'] == "unseen"].unique()
            train_yrs = pred_years[pred_temps['set'] == "train"].unique()
            if hem == 'south':
                # 1979 is excluded from the southern unseen years.
                # NOTE(review): reason not visible here — confirm.
                test_years = test_years[test_years != 1979]
            years[region_str] = test_years
            train_years[region_str] = train_yrs
            x_vec[region_str] = {}
            x_vec_unseen_bool[region_str] = {}
            x_vec_train_bool[region_str] = {}
    
# For every region and every lag in nlag_list, store the region's daily
# soil-moisture anomalies sorted ascending (x_vec), plus boolean masks marking
# which sorted samples fall in the region's unseen / train years.
# jan_1st_idx records, for southern regions at lag 0, where the first retained
# row of the file lands in that sorted order.
jan_1st_idx = {}
for hem in hemispheres:
    for cont in continents[hem]:
        for region_str in region_str_tpl[cont]:
            # The anomaly file does not depend on the lag, so read it once per
            # region instead of once per (lag, region) pair.
            base_soilw_df = pd.read_csv("../processed_data_NCEP/"+region_str+"/region_avg_soilw_cday_anomaly.csv")
            base_soilw_df['time'] = pd.to_datetime(base_soilw_df['time'])
            month = base_soilw_df['time'].dt.month
            for curr_lag in nlag_list:
                soilw_df = base_soilw_df
                if SNOWFREE:
                    # Drop the winter months: Jun-Aug in the south, Dec-Feb in
                    # the north (presumably the snow season).
                    if hem == 'south':
                        soilw_df = base_soilw_df[(month < 6) | (month > 8)].reset_index(drop=True)
                        if curr_lag != 0:
                            # NOTE(review): only the southern branch trims the
                            # last curr_lag samples; the northern branch keeps
                            # every row for all lags — confirm this is intended.
                            soilw_df = soilw_df.iloc[:-curr_lag]
                    else:
                        soilw_df = base_soilw_df[(month >= 3) & (month <= 11)].reset_index(drop=True)
                sorted_soilw_df = soilw_df.sort_values(by='soilw_daily_anom')
                if SNOWFREE and hem == 'south' and curr_lag == 0:
                    # Position of label 0 (the first retained row after
                    # reset_index) within this region's sorted samples. This must
                    # run after sorting the *current* region's frame; looking it
                    # up earlier picked up the previous region's sorted frame.
                    jan_1st_idx[region_str] = sorted_soilw_df.index.get_loc(0)
                lag_key = str(curr_lag)
                x_vec[region_str][lag_key] = sorted_soilw_df.soilw_daily_anom.values
                x_vec_unseen_bool[region_str][lag_key] = sorted_soilw_df.time.dt.year.isin(years[region_str]).values
                x_vec_train_bool[region_str][lag_key] = sorted_soilw_df.time.dt.year.isin(train_years[region_str]).values

# Region shown in each of the 16 subplot panels, in panel order, and the
# hemisphere of the region at the same position.
subplot_region_order = [
    'northcentral_north_america', 'central_europe', 'northeastern_europe', 'northeastern_asia',
    'southcentral_north_america', 'western_europe', 'eastern_europe', 'southeastern_asia',
    'southeastern_north_america', 'southwestern_europe', 'southwestern_africa', 'southeastern_africa',
    'northsouthern_south_america', 'southsouthern_south_america', 'southwestern_australia', 'southeastern_australia',
]

# First ten panels are northern regions, the remaining six southern.
subplot_hem_order = ['north'] * 10 + ['south'] * 6

# Box edges and plot colour for every region. Geometry comes from
# project_utils.load_region; all of its return values are unpacked into
# notebook-scope names (kept in case later cells read them).
region_bboxes = {}
for hem in hemispheres:
    for cont in continents[hem]:
        for region_str in region_str_tpl[cont]:
            (region_abbrev, region_input_lat_bbox, region_input_lon_bbox,
             region_box_x, region_box_y, region_lat, region_lon, region_lon_EW,
             region_t62_lats, region_t62_lons, read) = load_region.load_region_constants_modules(region_str)
            # region_lat / region_lon_EW expose .start/.stop (presumably
            # slices); the edges are stored in [stop, start] order.
            region_bboxes[region_str] = dict(
                lon=[region_lon_EW.stop, region_lon_EW.start],
                lat=[region_lat.stop, region_lat.start],
                color=c_vec[region_str],
            )
            
def show_fig(fname, grayscale):
    """Display an image file on a bare 20x20-inch matplotlib axis.

    Parameters
    ----------
    fname : str or path-like
        Image file to open with PIL.
    grayscale : bool
        If True, convert to 8-bit luminance ("L") and render with the 'gray'
        colormap; otherwise render the pixels as stored. Both use a fixed
        0-255 value range.
    """
    # Read the pixels inside a context manager so the file handle is released
    # as soon as they are copied into the array.
    with Image.open(fname) as image:
        arr = np.asarray(image.convert("L") if grayscale else image)

    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    # cmap=None falls back to the default colormap, matching the colour path.
    ax.imshow(arr, cmap='gray' if grayscale else None, vmin=0, vmax=255)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()
    plt.show()


def calc_full_rolling_mean(vec, max_win=400):
    """Centred rolling mean whose window shrinks symmetrically near the edges.

    Element ``jj`` is the mean of ``vec[jj - k : jj + k + 1]`` where
    ``k = min(jj, len(vec) - 1 - jj, max_win)``. The window is therefore always
    centred on ``jj``, never runs past either end, and is at most
    ``2 * max_win + 1`` samples wide. The first and last elements equal the
    input.

    Parameters
    ----------
    vec : array-like, 1-D
        Values to smooth. Indexed by position, so a pandas Series with a
        non-default index is handled correctly.
    max_win : int, optional
        Maximum half-width of the window (default 400, the previously
        hard-coded value).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``vec`` (empty for empty input).
    """
    # Positional access; label-based Series indexing would break vec[jj].
    values = np.asarray(vec)
    last_idx = len(values) - 1
    rolling_mean = np.zeros(np.shape(values))
    for jj in range(len(values)):
        # Distance to the nearer end, capped at max_win (0 at both endpoints).
        half = min(jj, last_idx - jj, max_win)
        rolling_mean[jj] = np.mean(values[jj - half:jj + half + 1])
    return rolling_mean
    
    
# Signal that the setup cell finished executing.
print('done')

# Supplemental Figure 5: NCEP regions and climatologies

In [None]:
# Sup. Fig. 5 options: draw +/- 1 std envelopes, and label regions at their
# box centres rather than with top-left arrows.
show_std, top_left_labels, center_labels = True, False, True
plt.rcParams['axes.linewidth'] = 4

fig = plt.figure(figsize=(25*0.9*1.5, 16*0.95*1.5))
fig.patch.set_facecolor('white')

# Panel geometry in figure-fraction units: left edge and bottom of the top
# row, width of the map axes (ax1_width) and of each climatology axes
# (ax2_width), and the height of one continent row.
col1_left, row1_bot = 0.00, 0.83
ax1_width, ax2_width = 0.5, 0.18
ax_hgt = 0.150
# Base rotation for the month tick labels (applied as lab_rot + 30).
lab_rot = 270

# Bound at notebook scope (unused in this cell) in case later cells rely on it.
from scipy.interpolate import make_interp_spline  # noqa: F401

# Styling shared by every climatology panel (constant, so defined once).
lw = 6            # climatology line width
lw_std = 1.8      # +/- 1 std envelope line width
alph = 0.9
alph_std = 0.20
al = 0.15         # alpha of the seasonal shading bands
ax1_color = 'k'
ax2_color = 'k'
# NOTE(review): fill_alph, cp and ebar_shift are not used in this cell; they
# are kept at notebook scope in case later cells read them.
fill_alph = 0.0
cp = 3
ebar_shift = {
    'north_america': [-4, 0, 4],
    'europe': [-8, -4, 0, 4, 8],
    'asia': [-2, 2],
    'south_america': [-2, 2],
    'africa': [-2, 2],
    'australia': [-2, 2],
}

# Regions whose map label uses the two-line reg_labels_split text; all others
# use the single-line reg_labels text.
split_label_regions = {'southwestern_africa', 'southeastern_africa', 'southcentral_north_america',
                       'northcentral_north_america', 'southeastern_north_america'}

# One row per continent: a map of its regions (left) plus the daily TMAX and
# soil-moisture climatologies of those regions (two panels, right).
cont_num = 0
for hem in hemispheres:
    # Southern rows are pushed down by vpad (presumably to leave room for the
    # month labels under the last northern row and the S. Hemisphere title).
    vpad = 0.045 if hem == 'south' else 0.0
    # Southern climatologies are rotated to start at day-of-year index
    # month_doy[6] (1 July in non-leap years), so their x-axis runs Jul..Jun
    # and uses the shifted month offsets and labels.
    shift = month_doy[6] if hem == 'south' else 0
    season_doy = shifted_month_doy if hem == 'south' else month_doy
    tick_centers = centered_shifted_month_doy if hem == 'south' else centered_month_doy
    tick_names = shifted_month_str_long if hem == 'south' else month_str_long

    for cont in continents[hem]:
        row_bottom = row1_bot-(ax_hgt)*cont_num-vpad
        ax = fig.add_axes([col1_left, row_bottom, ax1_width, ax_hgt], projection=ccrs.PlateCarree())
        ax.set_extent(cont_bounds[cont], crs=ccrs.PlateCarree())
        if cont == 'north_america':
            ax.add_feature(cfeature.STATES, linewidth=0.5)
        ax.add_feature(cfeature.NaturalEarthFeature('physical', 'ocean', '50m', edgecolor='face', facecolor='azure'), zorder=-30)
        ax.add_feature(cfeature.BORDERS.with_scale('110m'), linewidth=1, zorder=-20)
        ax.add_feature(cfeature.COASTLINE.with_scale('110m'), linewidth=1, zorder=-10)
        ax1 = fig.add_axes([col1_left+0.75*ax1_width, row_bottom, ax2_width, ax_hgt])
        ax2 = fig.add_axes([col1_left+0.75*ax1_width+1.19*ax2_width, row_bottom, ax2_width, ax_hgt])

        for region_str in region_str_tpl[cont]:
            bbox = region_bboxes[region_str]
            lon0, lon1 = bbox["lon"]
            lat0, lat1 = bbox["lat"]
            reg_color = bbox["color"]

            # Region box on the map, plus its label anchor points.
            pgon = Polygon(((lon0, lat0), (lon0, lat1), (lon1, lat1), (lon1, lat0), (lon0, lat0)))
            reg_center = ((lon0+lon1)/2 + lon_adjust[region_str],
                          (lat0+lat1)/2 + lat_adjust[region_str])
            top_left = (lon1, lat1)
            top_left_txt = (lon1-4, lat1+4)

            ax.add_geometries([pgon], crs=ccrs.PlateCarree(), facecolor=reg_color, edgecolor='k', alpha=0.7, linewidth=1.75, zorder=-40)
            if center_labels:
                label_text = reg_labels_split[region_str] if region_str in split_label_regions else reg_labels[region_str]
                ax.annotate(label_text,
                            xy=reg_center, xycoords='data',
                            bbox=dict(edgecolor='k', facecolor='white', boxstyle='round'),
                            horizontalalignment='center', verticalalignment='center', zorder=100, fontsize=fs-0.5)
            elif top_left_labels:
                ax.annotate(reg_labels[region_str],
                            xy=top_left, xycoords='data',
                            xytext=top_left_txt, textcoords='data',
                            arrowprops=dict(facecolor=reg_color, shrink=0.05),
                            bbox=dict(edgecolor='k', facecolor='white', boxstyle='round'),
                            horizontalalignment='right', verticalalignment='top', zorder=100, fontsize=fs)

            tmax_df = pd.read_csv("../processed_data_NCEP/"+region_str+"/region_avg_tmax.csv")
            tmax_df['time'] = pd.to_datetime(tmax_df['time'])
            soilw_df = pd.read_csv("../processed_data_NCEP/"+region_str+"/region_avg_soilw.csv")
            soilw_df['time'] = pd.to_datetime(soilw_df['time'])

            # Day-of-year climatologies, computed once and rotated by -shift
            # (np.roll with shift 0 leaves northern rows unchanged).
            tmax_groups = tmax_df.groupby([tmax_df['time'].dt.dayofyear])
            soilw_groups = soilw_df.groupby([soilw_df['time'].dt.dayofyear])
            tmax_clim = np.roll(tmax_groups.mean().values, -shift, axis=0)
            soilw_clim = np.roll(soilw_groups.mean().values, -shift, axis=0)

            # TMAX is converted from kelvin to degrees C for plotting.
            ax1.plot(range(0,366), tmax_clim-273.15, color=reg_color, label=region_str, linewidth=lw, alpha=alph)
            ax2.plot(range(0,366), soilw_clim, color=reg_color, label=region_str, linewidth=lw, alpha=alph)
            if show_std:
                tmax_mean = tmax_clim[:,0]-273.15
                soilw_mean = soilw_clim[:,0]
                tmax_std = np.roll(tmax_groups.std().values, -shift, axis=0)[:,0]
                soilw_std = np.roll(soilw_groups.std().values, -shift, axis=0)[:,0]
                # The std curves are drawn directly: an interpolating spline
                # evaluated at its own sample points just reproduces them, so
                # fitting one first added no smoothing.
                X_ = range(0, 366)
                for sign in (1, -1):
                    ax1.plot(X_, tmax_mean+sign*tmax_std, color=reg_color, label='_nolegend_', linewidth=lw_std, alpha=alph_std)
                    ax2.plot(X_, soilw_mean+sign*soilw_std, color=reg_color, label='_nolegend_', linewidth=lw_std, alpha=alph_std)

        ax1.set_xlim([0,365])
        ax1.set_ylim([260-273.15, 310-273.15])
        ax1.tick_params(axis='y', labelcolor=ax1_color, labelsize=fs+2, length=10, width=1.5, direction='inout')
        ax1.set_yticks(ticks=[-10, 0, 10, 20, 30])
        ax2.set_xlim([0,365])
        ax2.set_ylim([0.05, 0.47])
        ax2.tick_params(axis='y', labelcolor=ax2_color, labelsize=fs+2, length=10, width=1.5, direction='inout')
        ax2.set_yticks(ticks=[0.10, 0.20, 0.30, 0.40])

        # Shade the local summer (red: JJA north, DJF south) and local winter
        # (gray) months, and mark month boundaries with dotted lines.
        for clim_ax in (ax1, ax2):
            clim_ax.axvspan(season_doy[5], season_doy[8], alpha=al, color='tab:red', zorder=-10)
            clim_ax.axvspan(season_doy[0], season_doy[2], alpha=al, color='gray', zorder=-10)
            clim_ax.axvspan(season_doy[11], season_doy[11]+31, alpha=al, color='gray', zorder=-10)
            clim_ax.vlines(season_doy, clim_ax.get_ylim()[0], clim_ax.get_ylim()[1], colors='k', linestyles='dotted', linewidths=0.5)

        if cont_num == 0:
            ax.set_title('N. Hemisphere Regions', fontsize = fs+12, fontweight='bold')
            ax1.set_title('TMAX climatology', fontsize = fs+12, fontweight='bold')
            ax2.set_title('SM climatology', fontsize = fs+12, fontweight='bold')
            ax1.set_ylabel('degrees (\u00B0C)', color=ax1_color, fontsize = fs+8)
            ax2.set_ylabel('fraction (vol)', color=ax2_color, fontsize = fs+8)
        elif cont_num == 3:
            ax.set_title('S. Hemisphere Regions', fontsize = fs+12, fontweight='bold')
            ax1.set_ylabel('degrees (\u00B0C)', color=ax1_color, fontsize = fs+8)
            ax2.set_ylabel('fraction (vol)', color=ax2_color, fontsize = fs+8)
        elif cont_num in (2, 5):
            # Month names only under the bottom row of each hemisphere block.
            for clim_ax in (ax1, ax2):
                clim_ax.set_xticks(ticks=tick_centers, labels=tick_names, fontsize = fs+8)
                clim_ax.tick_params(axis='x', length=8, width=1.5, direction='inout', labelrotation=lab_rot+30)
        cont_num = cont_num+1
        
label_fs = fs+10

# Text properties for the bold panel letters (passed as fig.text's fontdict).
txt_kwargs = {
    'ha': 'left', 'va': 'center', 'fontsize': label_fs, 'fontweight': 'bold',
    'color': 'k', 'zorder': 10000, 'bbox': sublabel_bbox,
}

# Letters sit over the map / TMAX / SM columns: a-c above the northern block,
# d-f above the southern block.
panel_x = (0.1685, 0.3810, 0.5960)
upper_y = 0.9650
lower_y = 0.4709-0.02+0.017
for row_y, letters in ((upper_y, 'abc'), (lower_y, 'def')):
    for x_pos, letter in zip(panel_x, letters):
        fig.text(x_pos, row_y, letter, txt_kwargs)

fig.canvas.draw()
plt.savefig("../figures_NCEP/SupFig_05_NCEP_regions.png", transparent=False, dpi=800)

# Supplemental Figures 6-8: NCEP model skill

In [None]:
def _skill_annotation(y_true, y_pred):
    """Return the three-line "R2 / MAE / RMSE" text drawn in a skill panel."""
    return ("R2 = {:.2f}".format(r2_score(y_true, y_pred))+"\n"+
            "MAE = {:.2f}".format(mean_absolute_error(y_true, y_pred))+"\n"+
            "RMSE = {:.2f}".format(np.sqrt(mean_squared_error(y_true, y_pred))))


def _scatter_with_density(ax, y_true, y_pred, color, label=None):
    """Scatter the raw points under a 2-D density histogram; return hist2d's result.

    Histogram bins whose value is below ``cmin = 10/len(y_true)`` are left blank,
    so those regions show only the gray scatter. Reads the notebook globals
    ``nbins`` and ``skill_cmap``.
    """
    scatter_kwargs = {} if label is None else {'label': label}
    ax.scatter(y_true, y_pred, s=0.5, alpha=0.5, color=color, **scatter_kwargs)
    return ax.hist2d(y_true, y_pred, bins=nbins, cmin=10/len(y_true),
                     zorder=5, cmap=skill_cmap, density=True)


def _add_reference_lines(ax, x_l, x_r, lw=2):
    """Draw the 1:1 line (red) and +/-3 degC lines (dotted gray), then set both
    the x and y limits of ``ax`` to [x_l, x_r]."""
    xs = np.arange(x_l, x_r, 0.5)
    ax.plot(xs, xs, 'r', zorder=10, linewidth=lw)
    ax.plot(xs, xs+3, 'gray', zorder=10, linestyle='dotted', linewidth=lw)
    ax.plot(xs, xs-3, 'gray', zorder=10, linestyle='dotted', linewidth=lw)
    ax.set_xlim([x_l, x_r])
    ax.set_ylim([x_l, x_r])


def _format_skill_ticks(ax, major, minor, tick_fs, y_fmt=''):
    """Apply the shared major/minor tick layout; ``y_fmt=''`` hides y tick labels."""
    ax.yaxis.set_major_locator(MultipleLocator(major))
    ax.xaxis.set_major_locator(MultipleLocator(major))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    ax.yaxis.set_major_formatter(FormatStrFormatter(y_fmt))
    ax.yaxis.set_minor_locator(MultipleLocator(minor))
    ax.xaxis.set_minor_locator(MultipleLocator(minor))
    ax.tick_params(axis='both', which='major', labelsize=tick_fs, length=6, width=1.5, direction='inout')
    ax.tick_params(axis='both', which='minor', labelsize=tick_fs, length=5, width=1, direction='inout')


def plot_model_skill_compare(outer_fig, axes, regnum, cbar_ax):
    """Draw one region's row of model-skill panels (Supplemental Figs. 6-8).

    Panels, by index into ``axes``:
      0-2  CNN with all inputs on the 'train' / 'test' / 'unseen' rows
           (titled "train" / "validate" / "test"),
      3    CNN without SM ("no_SM_01" predictions) on the 'unseen' rows,
      4    calendar-day mean of NCEP TMAX vs. daily NCEP TMAX.

    Each panel shows a gray scatter under a 2-D density histogram, a red 1:1
    line, dotted +/-3 degC lines and an R2/MAE/RMSE annotation. All density
    histograms use the colour limits shared by panels 0-2. Temperatures are
    converted from the CSV values by subtracting 273.15.

    Parameters
    ----------
    outer_fig : matplotlib Figure
        Not used here; kept for call-site compatibility.
    axes : sequence of 5 matplotlib Axes
        The panels for this region's row.
    regnum : int
        Row index; the density colourbar is drawn only when ``regnum == 0``.
    cbar_ax : matplotlib Axes
        Axes that receives the colourbar.

    Notes
    -----
    Reads the notebook globals ``region_str``, ``nbins``, ``nlag``, ``fs`` and
    ``skill_cmap`` and prints the three models' RMSE plus relative RMSE
    reductions.
    """
    skill_RMSE = {}
    train_col = 'gray'
    test_col = 'gray'
    unseen_col = 'gray'
    tick_fs = fs-1

    ax_title_loc = (0.10, 0.90)
    metric_loc = (0.99, 0.01)
    title_kwargs = dict(xycoords='axes fraction', fontsize=fs+4, ha='left', va='center',
                        bbox=dict(boxstyle="round", pad=0.2, fc="white", ec="k", lw=2))
    metric_kwargs = dict(xycoords='axes fraction', fontsize=fs, ha='right', va='bottom')

    # Tick spacing: coarser for the regions listed here; every other region
    # (including any not explicitly anticipated) uses 10/5 instead of
    # failing with an unbound variable.
    if region_str in ['northcentral_north_america', 'central_europe', 'northeastern_europe',
                      'northeastern_asia', 'western_europe', 'eastern_europe']:
        major, minor = 15, 7.5
    else:
        major, minor = 10, 5

    ##########################################
    ## full CNN: train / validate / test    ##
    ##########################################
    pred_path = "../processed_data_NCEP/"+region_str+"/model_predictions"+"_lag"+str(nlag)+".csv"
    full_df = pd.read_csv(pred_path)
    splits = [('train', train_col, '_train'),
              ('test', test_col, '_test'),
              ('unseen', unseen_col, '_unseen')]
    split_data = []
    hists = []
    for k, (set_name, col, lab) in enumerate(splits):
        rows = full_df[full_df.set == set_name]
        y_true = rows.true_y-273.15
        y_pred = rows.predicted_tmax-273.15
        split_data.append((y_true, y_pred))
        hists.append(_scatter_with_density(axes[k], y_true, y_pred, col, label=lab))

    # Common range from the autoscaled train and validate limits. Both ends
    # take the union (the right end previously used min(), which clipped the
    # wider panel on one side only).
    x_l0, x_r0 = axes[0].get_xlim()
    x_l1, x_r1 = axes[1].get_xlim()
    x_l = min(x_l0, x_l1)
    x_r = max(x_r0, x_r1)

    for k, (y_true, y_pred) in enumerate(split_data):
        _add_reference_lines(axes[k], x_l, x_r)
        axes[k].annotate(_skill_annotation(y_true, y_pred), metric_loc, **metric_kwargs)

    y_unseen, pred_unseen = split_data[2]
    skill_RMSE['SM+GPH+calday'] = np.sqrt(mean_squared_error(y_unseen, pred_unseen))

    # One colour scale spanning the three split panels; reused for panels 3-4.
    clims = [h[3].get_clim() for h in hists]
    vmin0, vmax0 = clims[0]
    vmin = min(lo for lo, _ in clims)
    vmax = max(hi for _, hi in clims)
    for h in hists:
        h[3].set_clim(vmin, vmax)

    if regnum == 0:
        cbar = plt.colorbar(hists[0][3], cax=cbar_ax)
        cbar.set_ticks([])
        cbar.set_label('density', fontsize=fs+4)
        cbar.ax.text(1.35, vmax, 'high', ha='left', va='center', fontsize=fs+4)
        cbar.ax.text(1.35, vmin, 'low', ha='left', va='center', fontsize=fs+4)
        cbar.set_ticks([vmin0, vmax0])
        cbar.ax.set_yticklabels([])

    for k, title in enumerate(["train", "validate", "test"]):
        _format_skill_ticks(axes[k], major, minor, tick_fs)
        axes[k].annotate(title, ax_title_loc, **title_kwargs)

    #################
    ## shuffled SM ##
    #################
    ax_b = axes[3]
    ax_b.annotate("test", ax_title_loc, **title_kwargs)
    no_sm_df = pd.read_csv("../processed_data_NCEP/"+region_str+"/no_SM_01_model_predictions"+"_lag"+str(nlag)+".csv")
    rows = no_sm_df[no_sm_df.set == 'unseen']
    y_unseen = rows.true_y-273.15
    pred_unseen = rows.predicted_tmax-273.15
    hist_b = _scatter_with_density(ax_b, y_unseen, pred_unseen, unseen_col, label='_unseen')

    # This panel keeps its own autoscaled range.
    x_l, x_r = ax_b.get_xlim()
    _add_reference_lines(ax_b, x_l, x_r)
    ax_b.annotate(_skill_annotation(y_unseen, pred_unseen), metric_loc, **metric_kwargs)
    skill_RMSE['GPH+calday'] = np.sqrt(mean_squared_error(y_unseen, pred_unseen))
    hist_b[3].set_clim(vmin, vmax)
    _format_skill_ticks(ax_b, major, minor, tick_fs)

    ##########################
    ## baseline model skill ##
    ##########################
    ax_c = axes[4]
    full_df['date'] = pd.to_datetime(full_df['date'])
    doy = full_df['date'].dt.dayofyear
    # Prediction for each row = mean NCEP TMAX of that row's day of year,
    # computed over every row of the file (all splits). The panel is likewise
    # evaluated on every row, not just 'unseen'.
    ncep_seasonal_cycle = full_df.true_y.groupby(doy).mean()
    y_true = full_df.true_y-273.15
    y_clim = doy.map(ncep_seasonal_cycle)-273.15
    hist_c = _scatter_with_density(ax_c, y_true, y_clim, 'gray')

    # Range combines this panel's limits with the validate panel's autoscaled
    # limits captured above; the union is used on both ends.
    x_l0, x_r0 = ax_c.get_xlim()
    x_l = min(x_l0, x_l1)
    x_r = max(x_r0, x_r1)
    _add_reference_lines(ax_c, x_l, x_r)
    ax_c.annotate(_skill_annotation(y_true, y_clim), metric_loc, **metric_kwargs)
    skill_RMSE['calday'] = np.sqrt(mean_squared_error(y_true, y_clim))
    hist_c[3].set_clim(vmin, vmax)
    # Panel 4 has its own y label, so its y tick labels are shown.
    _format_skill_ticks(ax_c, major, minor, tick_fs, y_fmt='%d')

    print('RMSE')
    print(region_str, skill_RMSE['calday'], skill_RMSE['GPH+calday'], skill_RMSE['SM+GPH+calday'])
    print('calday-->shuffled', np.round(100*((skill_RMSE['calday']-skill_RMSE['GPH+calday'])/skill_RMSE['calday']), 4), '%')
    print('shuffled-->full', np.round(100*((skill_RMSE['GPH+calday']-skill_RMSE['SM+GPH+calday'])/skill_RMSE['GPH+calday']), 4), '%')

In [None]:
# Supplemental Figs. 6-8: one figure per region group. Each row of a figure is
# one region and holds five skill panels drawn by plot_model_skill_compare.
panel_1_regions = ['northcentral_north_america', 'southcentral_north_america', 'southeastern_north_america', 'northsouthern_south_america', 'southsouthern_south_america']
panel_2_regions = ['southwestern_europe', 'western_europe', 'central_europe', 'eastern_europe', 'northeastern_europe', 'southwestern_africa', 'southeastern_africa']
panel_3_regions = ['northeastern_asia', 'southeastern_asia', 'southwestern_australia', 'southeastern_australia']

for p_num, panel_regs in enumerate([panel_1_regions, panel_2_regions, panel_3_regions]):
    # Per-figure nudges: reglab_adj shifts the rotated region labels in x,
    # xlab_yadj shifts the bottom x-axis labels in y.
    if p_num == 0:
        reglab_adj = 0.01
        xlab_yadj = 0
    elif p_num == 1:
        reglab_adj = -0.005
        xlab_yadj = 0.01
    elif p_num == 2:
        reglab_adj = -0.005
        xlab_yadj = -0.01
    
    plt.rcParams['axes.linewidth'] = 3
    # Module-level global read by plot_model_skill_compare as the hist2d bin count.
    nbins = 50
    # Host figure; height scales with the number of regions (15 in per 5 rows).
    if p_num == 2:
        outer_fig, [outer_axes, cbar_ax] = plt.subplots(1,2,
                                             gridspec_kw={'left':0.2, 'bottom':0.1, 'width_ratios': [0.96, 0.04]}, 
                                             subplot_kw={'frame_on':False, 'xticks': [], 'yticks':[]},
                                             figsize=(16, 15*(len(panel_regs)/5)), frameon=True)
    else:
        outer_fig, [outer_axes, cbar_ax] = plt.subplots(1,2,
                                             gridspec_kw={'left':0.2, 'bottom':0.1, 'width_ratios': [0.97, 0.03]}, 
                                             subplot_kw={'frame_on':False, 'xticks': [], 'yticks':[]},
                                             figsize=(16, 15*(len(panel_regs)/5)), frameon=True)
        
    # Replace the gridspec colourbar axes with one at a fixed figure position.
    cbar_ax.remove()
    cbar_ax = outer_fig.add_axes([0.93, 0.12, 0.02, 0.80], frameon=True, xticks=[], yticks=[])

    # Manual panel geometry in figure fractions; row height and vertical
    # padding shrink as the number of region rows grows.
    ax_l = 0.09
    ax_t = 0.99
    ax_h = 0.155*(5/len(panel_regs))
    ax_w = 0.27
    ax_wpad = 0.01
    ax_hpad = 0.03*(5/len(panel_regs))

    outer_fig.patch.set_facecolor('white')
    title_fs = fs+2.5
    
    # Column titles for the three panel groups.
    outer_fig.text(x=ax_l+ax_w/2+0.05, y=0.973, s="(a) CNN with all input variables", ha='center', fontsize=title_fs, fontweight='bold')
    outer_fig.text(x=ax_l+1*(ax_w+ax_wpad)+ax_w/2+0.11, y=0.973, s="(b) CNN without SM input", ha='center', fontsize=title_fs, fontweight='bold')
    outer_fig.text(x=ax_l+2*(ax_w+ax_wpad)+ax_w/2+0.07, y=0.973, s="(c) seasonal climatology", ha='center', fontsize=title_fs, fontweight='bold')

    # Shared x-axis labels under each panel group.
    outer_fig.text(x=0.29, y=0.033+xlab_yadj, s="daily NCEP TMAX (\u00B0C)", ha='center', fontsize=fs-0.5)
    outer_fig.text(x=0.62, y=0.033+xlab_yadj, s="daily NCEP TMAX (\u00B0C)", ha='center', fontsize=fs-0.5)
    outer_fig.text(x=0.852, y=0.033+xlab_yadj, s="daily NCEP TMAX (\u00B0C)", ha='center', fontsize=fs-0.5)

    # p_ax is a placeholder grid: it only supplies an object array of the right
    # shape for `axes`; every placeholder is removed after the real panels are
    # placed with add_axes.
    p_ax = outer_fig.subplots(len(panel_regs),5)
    axes = np.empty_like(p_ax)
    for ij in range(len(panel_regs)):
        # Columns 0-2: three adjacent half-width panels (group a).
        axes[ij,0] = outer_fig.add_axes([ax_l, ax_t-(ij+1)*(ax_h+ax_hpad), ax_w/2, ax_h], frame_on=True)
        axes[ij,1] = outer_fig.add_axes([ax_l+ax_w/2, ax_t-(ij+1)*(ax_h+ax_hpad), ax_w/2, ax_h], frame_on=True)
        axes[ij,2] = outer_fig.add_axes([ax_l+ax_w/2+ax_w/2, ax_t-(ij+1)*(ax_h+ax_hpad), ax_w/2, ax_h], frame_on=True)
        axes[ij,0].set_ylabel('predicted TMAX (\u00B0C)', fontsize=fs)
        
        # Column 3: group b.
        axes[ij,3] = outer_fig.add_axes([ax_l+1.5*(ax_w+ax_wpad)+ax_w/4-0.02, ax_t-(ij+1)*(ax_h+ax_hpad), ax_w/2, ax_h], frame_on=True)
        axes[ij,3].set_ylabel('predicted TMAX (\u00B0C)', fontsize=fs)
        
        # Column 4: group c.
        axes[ij,4] = outer_fig.add_axes([ax_l+2*(ax_w+ax_wpad)+ax_w/4+ax_w/4, ax_t-(ij+1)*(ax_h+ax_hpad), ax_w/2, ax_h], frame_on=True)
        axes[ij,4].set_ylabel('calendar-day mean\nTMAX (\u00B0C)', fontsize=fs)
        
        p_ax[ij,0].remove()
        p_ax[ij,1].remove()
        p_ax[ij,2].remove()
        p_ax[ij,3].remove()
        p_ax[ij,4].remove()
        
    # region_str is a module-level global here and is read inside
    # plot_model_skill_compare, so the loop variable name must stay as is.
    for regnum, region_str in enumerate(panel_regs):
        # Rotated, boxed region label to the left of each row.
        if p_num == 0:
            outer_fig.text(x=0.048-0.03+reglab_adj, y=ax_t-(regnum+1)*(ax_h+ax_hpad)+ax_h/2, s=reg_labels_split[region_str], rotation=90, ha='center', va='center', fontsize=fs+2, bbox={'fc':'w', 'boxstyle':'round,pad=0.15'}, zorder=1000, fontweight='bold')
        elif p_num == 1:
            outer_fig.text(x=0.0473+reglab_adj, y=ax_t-(regnum+1)*(ax_h+ax_hpad)+ax_h/2, s=reg_labels_split_padded[region_str], rotation=90, ha='center', va='center', fontsize=fs+4, bbox={'fc':'w', 'boxstyle':'round,pad=0.09'}, zorder=1000, fontweight='bold')
        elif p_num == 2:
            outer_fig.text(x=0.0485+reglab_adj, y=ax_t-(regnum+1)*(ax_h+ax_hpad)+ax_h/2, s=reg_labels_split[region_str], rotation=90, ha='center', va='center', fontsize=fs+2, bbox={'fc':'w', 'boxstyle':'round,pad=0.15'}, zorder=1000, fontweight='bold')
        plot_model_skill_compare(outer_fig, axes[regnum,:], regnum, cbar_ax)
    
    fignum = ["06", "07", "08"]
    
    # Saves pyplot's current figure, i.e. outer_fig created by plt.subplots above.
    plt.savefig("../figures_NCEP/SupFig_"+fignum[p_num]+"_NCEP_model_skill.png", transparent=False, dpi=800)

# Supplemental Figure 9: NCEP SM-T PDPs

In [None]:
# Supplemental Fig. 9: 4x4 grid of summertime SM-T partial dependence plots,
# one panel per region in subplot_region_order, for the NCEP-trained models.
# Host layout: a 1x2 subplot grid whose thin right-hand axes is removed; the
# left axes' subplotspec then hosts a subfigure containing the 4x4 panels.
outer_fig, [outer_axes, cbar_ax] = plt.subplots(1,2,
                                     gridspec_kw={'left':0, 'bottom':0, 'width_ratios': [0.99, 0.01]}, 
                                     subplot_kw={'frame_on':False, 'xticks': [], 'yticks':[]},
                                     figsize=(20, 15), frameon=True)
cbar_ax.remove()

outer_fig.patch.set_facecolor('white')
# Figure-level title and shared x / y axis labels.
outer_fig.text(x=0.5, y=0.996, s="Summertime SM-T Partial Dependence Plots (NCEP)",
               fontsize=fs+8, fontweight='bold', ha='center', va='top')
outer_fig.text(x=0.485, y=0.004, s="local SM anomaly (S.D.)",
               fontsize=fs+8, fontweight='bold', ha='center', va='bottom')
outer_fig.text(x=0.004, y=0.5, s="TMAX - TMAX(SM=0) (\u00B0C)", rotation=90, ha='left', va='center',
               fontsize=fs+8, fontweight='bold')
outer_gridspec = outer_axes.get_subplotspec()
outer_subfig = outer_fig.add_subfigure(outer_gridspec, frameon=True)
axes = outer_subfig.subplots(4,4, gridspec_kw={'left':0.050, 'bottom':0.046, 'right':1.0, 'top':0.954, 'hspace':0.22, 'wspace':0.04}, 
                             subplot_kw={'frame_on':True, 'xticks': [], 'yticks':[]})
# Row-major flattening: ax[0..3] is the top row, ax[0], ax[4], ax[8], ax[12] the left column.
ax = axes.flatten()

# Overall SM range across regions at lag nlag. NOTE(review): not used later in this cell.
x_min = min([min(x_vec[region_str][str(nlag)]) for region_str in subplot_region_order])
x_max = max([max(x_vec[region_str][str(nlag)]) for region_str in subplot_region_order])

# Smoothed PDP summaries per region, keyed by region name.
pdp_5th = {}
pdp_95th = {}
pdp_mean = {}

# NOTE(review): cmap_list is not used in this cell.
cmap_list = ['bone', 'Greys', 'copper', 'pink',
             'twilight_shifted', 'twilight', 'summer', 'Greys_r',
             'viridis', 'plasma', 'inferno', 'cividis', 
             'viridis_r', 'plasma_r', 'inferno_r', 'cividis_r']

for regnum, region_str in enumerate(subplot_region_order):
    pdp_cmap = 'bone'  # reassigns the global; not otherwise used in this cell
    hem = subplot_hem_order[regnum]
    # Days of PDP data kept per year: 92 for northern, 90 for southern regions
    # (southern files are truncated to their first 90 rows below).
    # NOTE(review): ndays stays unbound for any other `hem` value.
    if hem == 'north':
        ndays = 92
    elif hem == 'south':
        ndays = 90
        
    # Boolean masks over the SM grid x_vec[region_str][str(nlag)]: test
    # ("unseen") points for the PDP, training points for the rug plot.
    unseen_bool = x_vec_unseen_bool[region_str][str(nlag)]
    train_bool = x_vec_train_bool[region_str][str(nlag)]
        
    # (n_years, ndays, n_SM_grid) stack of the per-year PDP prediction files.
    predictions = np.ones((len(years[region_str]), ndays, len(x_vec[region_str][str(nlag)])))
    for kk_yr, yr in enumerate(years[region_str]):
            if hem == 'north':
                predictions[kk_yr,:,:] = np.load("../processed_data_NCEP/"+region_str+"/JJA_pdp/"+region_str+"_centered_predictions_pdp_data_JJA_"+str(yr)+"_w"+str(w)+"_nlag"+str(nlag)+".npy")
            elif hem == 'south':
                predictions[kk_yr,:,:] = np.load("../processed_data_NCEP/"+region_str+"/JJA_pdp/"+region_str+"_centered_predictions_pdp_data_JJA_"+str(yr)+"_w"+str(w)+"_nlag"+str(nlag)+".npy")[:90,:]

    # One row per (year, day) sample; the SM grid is repeated to match.
    predictions_grid = predictions.reshape(len(years[region_str])*ndays, len(x_vec[region_str][str(nlag)]))
    x_vec_grid = np.tile(x_vec[region_str][str(nlag)], [len(years[region_str])*ndays, 1])
    
    # Keep only the unseen SM grid columns.
    x_vec_grid = x_vec_grid[:,unseen_bool]
    predictions_grid = predictions_grid[:,unseen_bool]
    
    ## center TMAX predictions at TMAX(SM=0) ##
    # Subtract from each row its mean over +/-win_size grid points around the
    # grid point whose SM value is closest to 0.
    # NOTE(review): if origin_idx < win_size the slice start is negative and
    # wraps from the end, selecting the wrong (possibly empty -> NaN) window.
    # Confirm origin_idx >= 100 for every region.
    origin_idx = np.abs(x_vec[region_str][str(nlag)][unseen_bool]).argmin()
    win_size = 100
    for ll in range(len(predictions_grid[:,0])):
        tmax_shift = np.mean(predictions_grid[ll,origin_idx-win_size:origin_idx+win_size])
        predictions_grid[ll,:] = predictions_grid[ll,:] - tmax_shift
        
    long_predictions = predictions_grid.flatten()
    long_x_vec = x_vec_grid.flatten()
    

    # Every centered sample as a faint black dot.
    ax[regnum].scatter(long_x_vec, long_predictions,
                       marker='.', color = 'k', s = 3, linewidth=0.25, zorder = 0, alpha=0.01)

    # Mean and 5th/95th percentiles across samples at each SM grid point,
    # smoothed by calc_full_rolling_mean (defined elsewhere in the notebook).
    pdp_mean[region_str] = calc_full_rolling_mean(np.mean(predictions_grid, axis=0))
    pdp_5th[region_str] = calc_full_rolling_mean(np.percentile(predictions_grid, q=5, axis=0))
    pdp_95th[region_str] = calc_full_rolling_mean(np.percentile(predictions_grid, q=95, axis=0))
    
    # NOTE(review): extreme_idx is not used in this cell.
    extreme_idx = np.floor(0.01*len(pdp_mean[region_str])).astype(int)
    
    # Red PDP curves, dropping `clip` grid points from each end.
    ax[regnum].scatter(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip], pdp_mean[region_str][clip:-clip], color='r', zorder=500, s=10, linewidth=0, label='_mean')
    ax[regnum].scatter(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip], pdp_5th[region_str][clip:-clip], color='r', zorder=500, s=1, linewidth=0, label='_5th / 95th')
    ax[regnum].scatter(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip], pdp_95th[region_str][clip:-clip], color='r', zorder=500, s=1, linewidth=0, label='_95th percentile')
    # Legend proxies: artists at (100, 100) fall outside the axis limits set
    # below, so they only appear in the legend. The gray (0, 0) point lies
    # inside the limits and marks TMAX(SM=0).
    ax[regnum].scatter(100,100, color='k', label='test samples', zorder=2000, s=50)
    ax[regnum].scatter(0,0, color='gray', label='TMAX(SM=0)', zorder=2000, s=100)
    ax[regnum].plot([100, 101], [100, 101], color='r', zorder=500, linewidth=4, label='mean')
    ax[regnum].plot([100, 101], [100, 101], color='r', zorder=500, linewidth=0.5, label='5th / 95th')    
    
    
    # Thin black axes through the origin.
    ax[regnum].hlines(y=0, xmin=-5, xmax=5, color='k', linewidth=0.5, zorder=490, label='_y=0')
    ax[regnum].vlines(x=0, ymin=-10, ymax=10, color='k', linewidth=0.5, zorder=490, label='_x=0')
    
    # Legend only on the top-right panel (index 3).
    if regnum == 3:
        leg = ax[regnum].legend(fontsize=fs+2, loc='upper right', 
                          markerscale=0.9, borderpad=0.2, labelspacing=0.25, 
                          handlelength=0.75, handleheight=0.9, handletextpad=0.30, 
                          borderaxespad=0.20, columnspacing=0.50, ncol=2, framealpha=1.0)
        leg.get_frame().set_facecolor('w')
        leg.get_frame().set_fill(True)
        leg.set_zorder(10000)
        # Text label for the rug of training SM values drawn below; the colour
        # presumably matches seaborn's default rugplot colour.
        ax[regnum].annotate('training samples', (0.51, 0.120), 
                        xycoords='axes fraction', fontsize=fs+2.5, ha='left', va='center', color=sns.color_palette()[0])
        

    ax[regnum].set_title(reg_labels[region_str], fontsize=fs+5)
    ax[regnum].set_xticks([-3, -2, -1, 0, 1, 2, 3])
    
    # x range = clipped unseen SM grid; fixed y range shared by all panels.
    ax[regnum].set_xlim(min(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip]), max(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip]))
    ax[regnum].set_ylim(-6.3,4.1)
    
    # Rug of training-set SM grid values along the bottom of the panel.
    sns.rugplot(data=x_vec[region_str][str(nlag)][train_bool], ax=ax[regnum], lw=1, alpha=.05, height=0.075, zorder=-1000)
        
    major = 2
    minor = 1
    xmajor=1
    xminor=0.5
    tick_fs = fs-1
    # y tick labels only on the left column (panels 0, 4, 8, 12).
    if regnum in [0, 4, 8, 12]:    
        ax[regnum].yaxis.set_major_locator(MultipleLocator(major)) 
        ax[regnum].yaxis.set_major_formatter(FormatStrFormatter('%d'))
        ax[regnum].yaxis.set_minor_locator(MultipleLocator(minor))
    else:
        ax[regnum].yaxis.set_major_locator(MultipleLocator(major)) 
        ax[regnum].yaxis.set_major_formatter(FormatStrFormatter(''))
        ax[regnum].yaxis.set_minor_locator(MultipleLocator(minor))
    
    ax[regnum].xaxis.set_major_locator(MultipleLocator(xmajor))
    ax[regnum].xaxis.set_major_formatter(FormatStrFormatter('%d'))
    ax[regnum].xaxis.set_minor_locator(MultipleLocator(xminor))
    ax[regnum].xaxis.set_minor_formatter(FormatStrFormatter(''))
    ax[regnum].tick_params(axis='both', which='major', labelsize=tick_fs, length=10, width=2, direction='inout')
    ax[regnum].tick_params(axis='both', which='minor', labelsize=tick_fs, length=8, width=1.5, direction='inout')
    
    # Peak-to-peak range of the clipped mean PDP, in degC.
    ax[regnum].annotate(u'RANGE = {:.2f}\u00B0C'.format(pdp_mean[region_str][clip:-clip].max()-pdp_mean[region_str][clip:-clip].min()), (0.03, 0.15), 
                        xycoords='axes fraction', fontsize=fs+1, ha='left', va='center', bbox=dict(boxstyle="round", pad=0.2, fc="white", ec="k", lw=0.75, alpha=0.9), zorder=5000)
    
plt.savefig("../figures_NCEP/SupFig_09_NCEP_PDPs.png", transparent=False, dpi=800)

# Supplemental Figure 10: 100 baseline (SM-shuffled) PDPs vs. 1 true PDP

In [None]:
# Supplemental Fig. 10 layout: one frameless host axes whose subplotspec
# carries a subfigure holding the 4x4 grid of regional panels.
outer_fig, outer_axes = plt.subplots(1,1,
                                     gridspec_kw={'left':0, 'bottom':0, 'right':0.99},
                                     subplot_kw={'frame_on':False, 'xticks': [], 'yticks':[]},
                                     figsize=(20, 15), frameon=True)
outer_fig.patch.set_facecolor('white')

# Figure-level title plus shared x / y axis labels, all bold at fs+8.
for _label_kwargs in (
    dict(x=0.5, y=0.996, s="Summertime SM-T Partial Dependence Plots (NCEP)",
         fontsize=fs+8, fontweight='bold', ha='center', va='top'),
    dict(x=0.5, y=0.004, s="local SM anomaly (S.D.)",
         fontsize=fs+8, fontweight='bold', ha='center', va='bottom'),
    dict(x=0.006, y=0.5, s="TMAX - TMAX(SM=0) (\u00B0C)", rotation=90, ha='left', va='center',
         fontsize=fs+8, fontweight='bold'),
):
    outer_fig.text(**_label_kwargs)

outer_gridspec = outer_axes.get_subplotspec()
outer_subfig = outer_fig.add_subfigure(outer_gridspec, frameon=True)
axes = outer_subfig.subplots(
    4, 4,
    gridspec_kw=dict(left=0.050, bottom=0.046, right=0.99, top=0.954, hspace=0.22, wspace=0.04),
    subplot_kw=dict(frame_on=True, xticks=[], yticks=[]),
)
ax = axes.flatten()

# Lag of the SM-shuffled baseline models (the true PDP uses nlag).
shuff_lag = 0

# Per-region results, keyed by region name (each name bound to its own dict).
shuff_pdp_mean, pdp_mean, pdp_95th, pdp_5th = {}, {}, {}, {}
null_99th, null_1st, max_null, min_null = {}, {}, {}, {}

# Styling of the true PDP versus the null (shuffled-SM) ensemble.
true_c, null_c = 'r', 'k'
true_lw, null_lw = 0, 1.5
true_alph, null_alph = 1.0, 0.1

for regnum, region_str in enumerate(subplot_region_order):
    unseen_bool_shuff = x_vec_unseen_bool[region_str][str(shuff_lag)]
    unseen_bool = x_vec_unseen_bool[region_str][str(nlag)]
    max_cpl =  0
    print(region_str)
    hem = subplot_hem_order[regnum]

    if hem == 'north':
        ndays = 92
    elif hem == 'south':
        ndays = 90
      
    predictions = np.ones((len(years[region_str]), ndays, len(x_vec[region_str][str(nlag)])))
    for kk_yr, yr in enumerate(years[region_str]):
            if hem == 'north':
                predictions[kk_yr,:,:] = np.load("../processed_data_NCEP/"+region_str+"/JJA_pdp/"+region_str+"_centered_predictions_pdp_data_JJA_"+str(yr)+"_w"+str(w)+"_nlag"+str(nlag)+".npy")
            elif hem == 'south':
                predictions[kk_yr,:,:] = np.load("../processed_data_NCEP/"+region_str+"/JJA_pdp/"+region_str+"_centered_predictions_pdp_data_JJA_"+str(yr)+"_w"+str(w)+"_nlag"+str(nlag)+".npy")[:90,:]

    predictions_grid = predictions.reshape(len(years[region_str])*ndays, len(x_vec[region_str][str(nlag)]))
    x_vec_grid = np.tile(x_vec[region_str][str(nlag)], [len(years[region_str])*ndays, 1])
    
    predictions_grid = predictions_grid[:,unseen_bool]
    x_vec_grid = x_vec_grid[:,unseen_bool]
    
    ## center TMAX predictions at TMAX(SM=0) ##
    origin_idx = np.abs(x_vec[region_str][str(nlag)][unseen_bool]).argmin()
    win_size = 100
    for ll in range(len(predictions_grid[:,0])):
        tmax_shift = np.mean(predictions_grid[ll,origin_idx-win_size:origin_idx+win_size])
        predictions_grid[ll,:] = predictions_grid[ll,:] - tmax_shift
            
    long_predictions = predictions_grid.flatten()
    long_x_vec = x_vec_grid.flatten()

    pdp_mean[region_str] = calc_full_rolling_mean(np.mean(predictions_grid, axis=0)) 
    pdp_5th[region_str] = calc_full_rolling_mean(np.percentile(predictions_grid, q=5, axis=0))
    pdp_95th[region_str] = calc_full_rolling_mean(np.percentile(predictions_grid, q=95, axis=0))
                
    ax[regnum].set_ylim(-6,6)
    
    for rep in range(1, 101):
        pred_temps = pd.read_csv("../processed_data_NCEP/"+region_str+"/SM_shuff_"+str(rep).zfill(2)+"_model_predictions"+"_lag0.csv")
        pred_temps['date'] = pd.to_datetime(pred_temps['date'])
        shuff_years = pred_temps[pred_temps.set == "unseen"].date.dt.year.unique()
        if hem == 'south':
            shuff_years = np.delete(shuff_years, np.argwhere(shuff_years==1979))
        if exists("../processed_data_NCEP/"+region_str+"/SM_shuff_JJA_pdp/"+region_str+"_"+str(rep).zfill(2)+"_centered_predictions_pdp_data_JJA_"+str(shuff_years[0])+"_w"+str(w)+"_nlag"+str(shuff_lag)+".npy"):
            shuff_predictions = np.ones((len(shuff_years), ndays, len(x_vec[region_str][str(shuff_lag)])))   
            for kk_yr, yr in enumerate(shuff_years):
                    if hem == 'north':
                        shuff_predictions[kk_yr,:,:] = np.load("../processed_data_NCEP/"+region_str+"/SM_shuff_JJA_pdp/"+region_str+"_"+str(rep).zfill(2)+"_centered_predictions_pdp_data_JJA_"+str(yr)+"_w"+str(w)+"_nlag"+str(shuff_lag)+".npy")
                    elif hem == 'south':
                        shuff_predictions[kk_yr,:,:] = np.load("../processed_data_NCEP/"+region_str+"/SM_shuff_JJA_pdp/"+region_str+"_"+str(rep).zfill(2)+"_centered_predictions_pdp_data_JJA_"+str(yr)+"_w"+str(w)+"_nlag"+str(shuff_lag)+".npy")[:90,:]

            shuff_predictions_grid = shuff_predictions.reshape(len(shuff_years)*ndays, len(x_vec[region_str][str(shuff_lag)]))
            shuff_x_vec_grid = np.tile(x_vec[region_str][str(shuff_lag)], [len(shuff_years)*ndays, 1])

            shuff_predictions_grid = shuff_predictions_grid[:,unseen_bool_shuff]
            shuff_x_vec_grid = shuff_x_vec_grid[:,unseen_bool_shuff]
            
            ## center TMAX predictions at TMAX(SM=0) ##
            origin_idx = np.abs(x_vec[region_str][str(shuff_lag)][unseen_bool_shuff]).argmin()
            win_size = 100
            for ll in range(len(shuff_predictions_grid[:,0])):
                tmax_shift = np.mean(shuff_predictions_grid[ll,origin_idx-win_size:origin_idx+win_size])
                shuff_predictions_grid[ll,:] = shuff_predictions_grid[ll,:] - tmax_shift
            
            shuff_long_predictions = shuff_predictions_grid.flatten()
            shuff_long_x_vec = shuff_x_vec_grid.flatten()

            shuff_pdp_mean[region_str] = calc_full_rolling_mean(np.mean(shuff_predictions_grid, axis=0))
            
            curr_cpl = np.max(shuff_pdp_mean[region_str][clip:-clip])-np.min(shuff_pdp_mean[region_str][clip:-clip])
            if curr_cpl > max_cpl:
                max_rep = rep
                max_cpl = curr_cpl
                print(max_rep, max_cpl)
            
            if rep == 1:
                max_null[region_str] = 0*shuff_pdp_mean[region_str][clip:-clip]
                min_null[region_str] = 0*shuff_pdp_mean[region_str][clip:-clip]
            max_null[region_str] = np.maximum(max_null[region_str], shuff_pdp_mean[region_str][clip:-clip])
            min_null[region_str] = np.minimum(min_null[region_str], shuff_pdp_mean[region_str][clip:-clip])

            if hem == 'south':
                x_vec1 = x_vec[region_str][str(shuff_lag)]
            elif hem == 'north':
                x_vec1 = x_vec[region_str][str(shuff_lag)]
                
            ax[regnum].plot(x_vec1[unseen_bool_shuff][clip:-clip], shuff_pdp_mean[region_str][clip:-clip], color=null_c, alpha=null_alph, zorder=400, linewidth=null_lw, label='_null')
    
    ax[regnum].scatter(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip], pdp_mean[region_str][clip:-clip], color='r', zorder=500, s=18, linewidth=true_lw, label='_true')
    ax[regnum].scatter(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip], pdp_5th[region_str][clip:-clip], color='r', zorder=500, s=1, linewidth=0, label='_5th / 95th')
    ax[regnum].scatter(x_vec[region_str][str(nlag)][unseen_bool][clip:-clip], pdp_95th[region_str][clip:-clip], color='r', zorder=500, s=1, linewidth=0, label='_95th percentile')
    
    ax[regnum].hlines(y=0, xmin=-5, xmax=5, color='k', linewidth=0.5, zorder=450, label='_y=0')
    ax[regnum].vlines(x=0, ymin=-10, ymax=10, color='k', linewidth=0.5, zorder=450, label='_x=0')
    
    ax[regnum].set_title(reg_labels[region_str], fontsize=fs+5)
    ax[regnum].set_xticks([-3, -2, -1, 0, 1, 2, 3])
    
    ax[regnum].set_xlim(min(x_vec[region_str][str(shuff_lag)][unseen_bool_shuff][clip:-clip]), max(x_vec[region_str][str(shuff_lag)][unseen_bool_shuff][clip:-clip]))
    ax[regnum].set_ylim(-5.9,2.9)
        
    major = 2
    minor = 1
    xmajor=1
    xminor=0.5
    if regnum in [0, 4, 8, 12]:    
        ax[regnum].yaxis.set_major_locator(MultipleLocator(major)) 
        ax[regnum].yaxis.set_major_formatter(FormatStrFormatter('%d'))
        ax[regnum].yaxis.set_minor_locator(MultipleLocator(minor))
        if regnum == 0:
            ax[regnum].plot([100, 101], [100, 101], color='r', zorder=500, linewidth=4, label='true')
            ax[regnum].plot([100, 101], [100, 101], color='r', zorder=500, linewidth=0.75, label='5th / 95th')
            ax[regnum].plot([0,0.01], [0,0], color='gray', alpha = 0.8, zorder=-100, linewidth=3, label='baseline') 
            leg = ax[regnum].legend(fontsize=fs+5, loc='lower left', 
                              markerscale=0.9, 
                                borderpad=0.4, labelspacing=0.25, 
                                    handlelength=0.75, handleheight=0.9,
                                    handletextpad=0.45, 
                              borderaxespad=0.50, columnspacing=0.75, facecolor='gainsboro')
            leg.get_frame().set_facecolor('w')
            leg.get_frame().set_fill(True)
            leg.set_zorder(10000)
    else:
        ax[regnum].yaxis.set_major_locator(MultipleLocator(major)) 
        ax[regnum].yaxis.set_major_formatter(FormatStrFormatter(''))
        ax[regnum].yaxis.set_minor_locator(MultipleLocator(minor))
    
    ax[regnum].xaxis.set_major_locator(MultipleLocator(xmajor))
    ax[regnum].xaxis.set_major_formatter(FormatStrFormatter('%d'))
    ax[regnum].xaxis.set_minor_locator(MultipleLocator(xminor))
    ax[regnum].xaxis.set_minor_formatter(FormatStrFormatter(''))
    ax[regnum].tick_params(axis='both', which='major', labelsize=tick_fs+3, length=10, width=2, direction='inout')
    ax[regnum].tick_params(axis='both', which='minor', labelsize=tick_fs+3, length=10, width=1, direction='inout')
    
plt.savefig("../figures_NCEP/SupFig_10_NCEP_PDPs_baseline.png", transparent=False, dpi=800)

# Supplemental Figure 11: NCEP lag sensitivity

In [None]:
# Supplemental Figure 11: sensitivity of the SM-T partial dependence plots (PDPs)
# to the soil-moisture input lag (NCEP). For every region one PDP curve is drawn
# per lag in `nlag_list`, over the shuffled-input baseline envelope
# (`min_null` / `max_null`) computed in the previous cell.
from matplotlib.ticker import NullFormatter

shuff_lag = 0
NDAYS_BY_HEM = {'north': 92, 'south': 90}  # days of predictions kept per year
N_ROWS, N_COLS = 4, 4
xmajor, xminor = 1, 0.5
ymajor, yminor = 2, 1


def _center_on_sm_zero(pred_grid, x_masked, win_size=100):
    """Shift every row of `pred_grid` so that its mean around SM=0 is zero.

    Parameters
    ----------
    pred_grid : 2-D array with one row per (year, day) sample and one column
        per value in `x_masked`.
    x_masked : 1-D array of SM anomaly values matching the columns of `pred_grid`.
    win_size : number of columns taken on each side of the column whose SM value
        is closest to 0.

    Returns
    -------
    A new array equal to `pred_grid` minus, for each row, the mean of that row
    over columns ``[origin_idx - win_size, origin_idx + win_size)``.
    """
    origin_idx = np.abs(x_masked).argmin()
    # Clamp the window start: a negative start index would wrap around to the end
    # of the row, produce an empty slice and turn the whole row into NaN.
    lo = max(origin_idx - win_size, 0)
    hi = origin_idx + win_size
    return pred_grid - pred_grid[:, lo:hi].mean(axis=1, keepdims=True)


plt.rcParams['axes.linewidth'] = 4
outer_fig, outer_axes = plt.subplots(1,1,constrained_layout=True,
                                     gridspec_kw={'left':0.1, 'bottom':0.1}, 
                                     subplot_kw={'frame_on':False, 'xticks': [], 'yticks':[]},
                                     figsize=(20, 15), frameon=True)
outer_fig.patch.set_facecolor('white')
outer_axes.set_title("Sensitivity of SM-T Partial Dependence Plots to SM input lag (NCEP)", fontsize=fs+8, fontweight='bold')
outer_axes.set_xlabel("local SM anomaly (S.D.)", fontsize=fs+8, fontweight='bold')
outer_axes.set_ylabel("TMAX - TMAX(SM=0) (\u00B0C)", fontsize=fs+8, fontweight='bold')
outer_gridspec = outer_axes.get_subplotspec()
outer_subfig = outer_fig.add_subfigure(outer_gridspec, frameon=True)
axes = outer_subfig.subplots(N_ROWS, N_COLS, gridspec_kw={'hspace':0, 'wspace':0}, 
                             subplot_kw={'frame_on':True, 'xticks': [], 'yticks':[]})
ax = axes.flatten()

for regnum, region_str in enumerate(subplot_region_order):
    hem = subplot_hem_order[regnum]
    if hem not in NDAYS_BY_HEM:
        raise ValueError(f"unknown hemisphere {hem!r} for region {region_str!r}")
    ndays = NDAYS_BY_HEM[hem]
    curr_ax = ax[regnum]
    # Column mask of the lag-0 SM grid; used for the baseline x-values and x-limits.
    unseen_bool_lag0 = x_vec_unseen_bool[region_str][str(0)]

    for nn, curr_lag in enumerate(nlag_list):
        lag_key = str(curr_lag)
        unseen_bool = x_vec_unseen_bool[region_str][lag_key]
        x_full = x_vec[region_str][lag_key]

        pred_temps = pd.read_csv(f"../processed_data_NCEP/{region_str}/model_predictions_lag{curr_lag}.csv")
        pred_temps['date'] = pd.to_datetime(pred_temps['date'])
        test_years = pred_temps[pred_temps.set == "unseen"].date.dt.year.unique()
        if hem == 'south':
            # 1979 is excluded for southern-hemisphere regions.
            test_years = test_years[test_years != 1979]

        # Stack the per-year PDP prediction arrays: (year, day, SM grid point).
        lag_predictions = np.ones((len(test_years), ndays, len(x_full)))
        for kk_yr, yr in enumerate(test_years):
            pdp_path = (f"../processed_data_NCEP/{region_str}/JJA_pdp/{region_str}"
                        f"_centered_predictions_pdp_data_JJA_{yr}_w{w}_nlag{curr_lag}.npy")
            year_pdp = np.load(pdp_path)
            if hem == 'south':
                # Southern-hemisphere files are truncated to the first `ndays` rows.
                year_pdp = year_pdp[:ndays, :]
            lag_predictions[kk_yr, :, :] = year_pdp

        # Flatten (year, day) into samples, keep only the masked SM grid points,
        # then center each sample's curve at TMAX(SM=0).
        lag_predictions_grid = lag_predictions.reshape(len(test_years) * ndays, len(x_full))[:, unseen_bool]
        lag_predictions_grid = _center_on_sm_zero(lag_predictions_grid, x_full[unseen_bool])

        lag_pdp_mean = calc_full_rolling_mean(np.mean(lag_predictions_grid, axis=0))

        lag_color = cm.cool(1 - nn / len(nlag_list))
        curr_ax.plot(x_full[unseen_bool][clip:-clip], lag_pdp_mean[clip:-clip], color=lag_color, zorder=50 - nn, alpha=0.75, linewidth=3, label=f'_lag={curr_lag}')
        # Off-axis proxy point so the legend shows a round marker for this lag.
        curr_ax.scatter(-100, 0, color=lag_color, zorder=-100, s=60, label=f'lag={curr_lag}')

    # Shuffled-input baseline envelope from the previous cell.
    # NOTE(review): assumes max_null/min_null were computed on the lag-`shuff_lag`
    # grid with the same clip, so their length matches these x-values — confirm.
    x_vec1 = x_vec[region_str][str(shuff_lag)]
    curr_ax.fill_between(x=x_vec1[unseen_bool_lag0][clip:-clip], y1=max_null[region_str], y2=min_null[region_str], color='k', hatch='///', alpha=0.25, zorder=1500, linewidth=0, label='_baseline')
    # Off-axis proxy marker for the baseline legend entry.
    curr_ax.scatter(100, 0, s=250, marker='s', lw=0, color='k', alpha=0.25, hatch='///', label='baseline')

    curr_ax.hlines(y=0, xmin=-5, xmax=5, color='k', linewidth=0.5, zorder=2000, label='_y=0')
    curr_ax.vlines(x=0, ymin=-10, ymax=10, color='k', linewidth=0.5, zorder=2000, label='_x=0')

    curr_ax.set_title(reg_labels[region_str], fontsize=fs+5)

    x_lag0 = x_vec[region_str][str(0)][unseen_bool_lag0][clip:-clip]
    curr_ax.set_xlim(min(x_lag0), max(x_lag0))
    curr_ax.set_ylim(-6.3,3.4)

    # Tick layout: y tick labels only on the left-most column of the grid.
    is_left_col = regnum % N_COLS == 0
    curr_ax.xaxis.set_major_locator(MultipleLocator(xmajor))
    curr_ax.xaxis.set_minor_locator(MultipleLocator(xminor))
    curr_ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    curr_ax.yaxis.set_major_locator(MultipleLocator(ymajor))
    curr_ax.yaxis.set_minor_locator(MultipleLocator(yminor))
    curr_ax.yaxis.set_major_formatter(FormatStrFormatter('%d') if is_left_col else NullFormatter())
    curr_ax.tick_params(axis='both', which='major', labelsize=tick_fs+3, length=8, width=2, direction='inout')
    curr_ax.tick_params(axis='both', which='minor', labelsize=tick_fs+3, length=6, width=1, direction='inout')

    if regnum == 0:
        leg = curr_ax.legend(fontsize=fs+3, loc='lower center', ncol=2, bbox_to_anchor=(0.36, 0.00), 
                             markerscale=1.0, borderpad=0.3, labelspacing=0.20, 
                             handletextpad=-0.05, borderaxespad=0.4, columnspacing=0.70, framealpha=1.0)
        leg.get_frame().set_facecolor('w')
        leg.get_frame().set_fill(True)
        leg.set_zorder(10000)
        for line in leg.get_lines():
            line.set_linewidth(3.0)

outer_fig.savefig("../figures_NCEP/SupFig_11_NCEP_PDPs_lag.png", transparent=False, dpi=800)