In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import xarray as xr
import numpy as np
import os
# import pandas as pd
from glob import glob
import functions as f
#import climpredNEW.climpred 
#from climpredNEW.climpred.options import OPTIONS
from climpred.options import OPTIONS
import climpred
import pickle
from mpl_toolkits.basemap import Basemap
from numpy import meshgrid
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import matplotlib.colors as mcolors
import cartopy.feature as cfeature
import itertools
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter, LatitudeLocator
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TwoSlopeNorm
from scipy.stats import rankdata
import bottleneck as bn
import scipy.stats as stats
import pickle
import masks
import verifications
import preprocessUtils as putils


2024-05-22 06:42:02.482835: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-22 06:42:02.626058: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


# Make ACC values for the later plots for permutation tests. 
## These are a bit further down in the script

In [2]:
global obs_original
obs_original = xr.open_dataset(f'Data/GLEAM/RZSM_anomaly.nc').rename({'SMsurf':'RZSM'}).drop('season').load()

In [3]:
#Set script parameters

region_name = 'CONUS' #or ['australia', 'CONUS', 'china']

mask=masks.load_mask(region_name)

start_obs = '2000-01-01' #Beginning of observation period for analysis. We actually have data starting from 1999 so that we could have a 7-day rolling mean applied to the data and have up to 12 weeks lags for RZSM
end_obs = '2020-12-31' #end of observations for ERA5 and GLEAM. We actually needed data through 2020-02-15 since we have an initialization on 2019-12-25
start_testing = '2018-01-01' #Beginning of testing period
end_testing = '2019-12-31'
train_end_string = '2015-12-30' #last string date for training
train_end = 2015 #last year of training dates


global RZSM_or_Tmax_or_both
RZSM_or_Tmax_or_both = 'RZSM' # for getting the predictor from either RZSM and Tmax ('both') or only RZSM ('RZSM')

lead_select = [6,13,20,27,34]

def select_data_by_lead(obs_anomaly_SubX_format, baseline_anomaly, baseline_ecmwf, template_testing_only, day_num):
    obs_lead = expand_dims_by_lead(obs_anomaly_SubX_format.sel(L=day_num))
    baseline_gefs =  expand_dims_by_lead(baseline_anomaly.sel(L=day_num))
    baseline_ecm =  expand_dims_by_lead(baseline_ecmwf.sel(L=day_num))
    template =  expand_dims_by_lead(template_testing_only.sel(L=day_num))
    return(obs_lead,baseline_gefs,baseline_ecm,template)

In [4]:
def print_min_max(file,name):
    print(name)
    print(f'Maximum value in file is {np.nanmax(file[putils.xarray_varname(file)])}')
    print(f'Minimum value in file is {np.nanmin(file[putils.xarray_varname(file)])}')

In [5]:
global verification_var
verification_var = 'soilw_bgrnd_GLEAM' #this is for what we are verifying with DL outputs

#Gleam observations
gleam_dir = 'Data/GLEAM'

#Forecast predictions
gefsv12_fcst_dir = 'Data/GEFSv12_reforecast'

#ERA5 observations
era5_dir = 'Data/ERA5'

if region_name != 'CONUS':
    #Gleam observations
    gleam_dir = f'Data_{region_name}/GLEAM'
    
    #Forecast predictions
    gefsv12_fcst_dir = f'Data_{region_name}/GEFSv12_reforecast'
    
    #ERA5 observations
    era5_dir = f'Data_{region_name}/ERA5'

In [6]:
#Load observation anomaly
print(f'Loading original observations and reforecasts (no anomalies). Selecting dates for observations between {start_obs} and {end_obs}. Applying a 7-day rolling mean to be consistent.')

gleam_anom = verifications.load_RZSM_anomaly_obs(region_name).load()
ecmwf_anom = verifications.load_ECMWF_baseline_anomaly(region_name).load()
gefs_anom = verifications.load_GEFSv12_baseline_anomaly(region_name).load()

'''Load a base file to serve as the xarray template to add our predictions from UNET into.'''
base_file_testing = gefs_anom.copy(deep=True).sel(S=slice(start_testing, None)).load()

global lat, lon
lat = base_file_testing.Y.values  # for plotting later
lon = base_file_testing.X.values


Loading original observations and reforecasts (no anomalies). Selecting dates for observations between 2000-01-01 and 2020-12-31. Applying a 7-day rolling mean to be consistent.


# Now create baseline anomalies of files to not have to re-compute later

In [7]:
#First
#Create seasonal anomaly from observations
print(f'Creating the seasonal anomalies for observational data and then subsetting for everything after {start_testing} date.')
save_anomaly_dir = f'Data/anomaly/{region_name}'
os.system(f'mkdir -p {save_anomaly_dir}')

obs_RZSM_save = f'{save_anomaly_dir}/obs_RZSM_anomaly_testing.nc'
ref_RZSM_save = f'{save_anomaly_dir}/reforecast_RZSM_anomaly_testing.nc'


Creating the seasonal anomalies for observational data and then subsetting for everything after 2018-01-01 date.


# Bias correction (additive). Climpred won't work despite my data being in the exact same format. Let's just make our own functions

In [8]:
def open_file_create_seasonal_anomaly(path,train_end):
    #Must subset by lead first (because we actually have data previously from past lag weeks)
    return(create_seasonal_anomaly(xr.open_mfdataset(path).sel(L=slice(0,34)).rolling(L=7, min_periods=7,center=False).mean(),train_end=train_end,source='reforecast'))

def check_values_in_file(file,lead):
    '''Just print some values to verify the files aren't identical when comparing with other results'''
    name_file = list(file.keys())[0]
    return(print(file[name_file].isel(L=lead).isel(M=10).isel(S=0).values))
    

In [9]:

def load_experiment_predictions_and_observations(lead,experiment,region_name):
    # #Test
    # experiment='EX0'
    day_num = (lead*7)-1
    min_max_dir = f'Data/min_max_values/{region_name}'
    verification_directory = f'Data/model_npy_inputs/{region_name}/Verification_data' #For observation verification
    bias_correction_dir = f'Data/bias_mean_values/Wk_{lead}'

    ex_name = experiment


    #Load the actual observations (used for the Mean Absolute Error calculation)
    obs_final_train,obs_final_validation,obs_final_testing = f.load_verification_observations_updated(lead,verification_directory)
    obs_RZSM = np.array(obs_final_testing) #anomaly

    #Convert observations 0 values to nan (only for the RZSM observations). These values had a zero where there is no land soil moisture
    obs_RZSM = np.where(obs_RZSM == 0,np.nan,obs_RZSM)
    
    obs_RZSM =verifications.reverse_min_max_scaling(obs_RZSM,region_name,day_num,'GLEAM',2019)

        
    predictions_directory = f'predictions/{region_name}/Wk{lead}_testing'
    try:
        prediction_ = np.load(f'{predictions_directory}/Wk{lead}_testing_{ex_name}_RZSM.npy')
        cont = True
    except FileNotFoundError:
        try:
            prediction_ = np.load(f'{predictions_directory}/Wk{lead}_testing_{ex_name}_regular_RZSM.npy')
            cont = True
        except FileNotFoundError:
            cont = False

    if cont:
        print(f'Test prediction shape: {prediction_.shape}')
        prediction_RZSM = verifications.reverse_min_max_scaling(prediction_[-1,:,:,:],region_name, day_num, 'GEFSv12',2019)
        print(f'Shape of prediction RZSM: {prediction_RZSM.shape}')
    
        #Convert back to np.nan values for the ocean and other water bodies
        prediction_RZSM = np.where(np.isnan(obs_RZSM),np.nan,prediction_RZSM.squeeze())
    
        return(prediction_RZSM, obs_RZSM)
    else:
        return(np.zeros(shape=obs_RZSM.shape), obs_RZSM)


In [10]:
def setup_binary_for_hit_rate_with_ensemble_mean(week_lead, region_name, test_start, test_end, unet_file):

    #Test 
    # week_lead=1
    # percentile_eval = 20
    # unet_file = 'EX14'

    #Save dir
    save_dir = f'Data/correct_anomaly_percentile_statistics/{region_name}'
    os.system(f'mkdir -p {save_dir}')

    save_ecmwf = f'{save_dir}/Wk{week_lead}_ecmwf_stats_TP_FP_ensemble_mean_t2.npy'
    save_gefs = f'{save_dir}/Wk{week_lead}_gefs_stats_TP_FP_ensemble_mean_t2.npy'
    save_xg = f'{save_dir}/Wk{week_lead}_xgboost_stats_TP_FP_ensemble_mean_t2.npy'
    save_obs_binary = f'{save_dir}/Wk{week_lead}_obs_stats_TP_FP_ensemble_mean_t2.npy'

    day_num = (week_lead*7) -1
        
    print('Loading observation and baseline anomaly files')
    obs, gefs, ecmwf, template_testing_only_by_lead= select_data_by_lead(gleam_anom, gefs_anom, ecmwf_anom,  base_file_testing, day_num)
    
    obs_percent = obs_anom_percentile.sel(L=day_num).sel(M=0)
    obs_percent['95th_percentile'].shape #(104, 48, 96)


    file = baseline_anomaly
    file.RZSM.shape
    out_check_gefs_base = np.zeros(shape=(104,11,48,96,8)) #Adding 8 channels for the different anomaly spreads
    
    out_check_gefs_base[:,:,:,:] = np.nan
    out_check_ecmwf_base = out_check_gefs_base.copy()
    out_check_unet = out_check_gefs_base.copy()
    
    
    obs_binary_out =np.zeros(shape=(104,48,96,8)) #Adding 8 channels for the different anomaly spreads

    final_perc_gefs = np.zeros(shape=(104,48,96,8)) #Adding 8 channels for the different anomaly spreads
    final_perc_gefs[:,:,:] = np.nan
    final_perc_ecmwf = final_perc_gefs.copy()
    final_perc_unet = final_perc_gefs.copy()
    
    

    test_name = unet_file.split('testing_')[-1].split('.npy')[0].split('ensemble_')[-1]
    
    save_unet = f'{save_dir}/Wk{week_lead}_unet_stats_{test_name}_TP_FP_ensemble_mean_t2.npy'
    
    test =  verifications.reverse_min_max_scaling(np.load(unet_file), region_name, day_num)[2,:,:,:,0] #We only want the last channel
    test = np.reshape(test,(test.shape[0]//11,11,test.shape[1],test.shape[2]))
    test.shape

    #Now mask the input
    test_unet = np.where(mask_anom == 1,test,np.nan)

    #XGBoost
    #Load the XGBoost data
    if region_name == 'CONUS':
        out_check_xg = out_check_gefs_base.copy()
        final_perc_xg = final_perc_gefs.copy()
        
        xgboost_files = sorted(glob(f'predictions_XGBOOST/{region_name}/Wk{week_lead}_testing/*EX28*'))[0]
    
        # break
        #Still working here
        test_name = xgboost_files.split('testing_')[-1].split('.npy')[0]
        load_ = np.expand_dims(np.load(xgboost_files),-1)
        load_.shape
        load_ = np.where(load_ == 0,np.nan,load_)
        load_ =  verifications.reverse_min_max_scaling(load_, region_name, day_num)#We only want the last channel
        
        xg = np.empty(shape=(load_.shape[0],11,load_.shape[1],load_.shape[2],load_.shape[3])) #This will help with climpred functions
        for j in range(11):
            xg[:,j,:,:,:] = load_
    
        xg = xg.squeeze()
        xg.shape
        x_vals = np.nanmean(xg,axis=1)
    else:
        x_vals = out_check_gefs_base.copy()
    #Check if the predicted anomaly is below each threshold


    
    #Test
    # idx = 0
    # mx = 0
    # ix = 10
    # iy =10 #NEGATIVE ANOMALY VALUE
    # iy =5 #POSITIVE ANOMALY VALUE

    #Use np.where to find the values of the percentile

    #Take ensemble mean
    o_vals = np.nanmean(obs.RZSM[:,:,0,:,:].values,axis=1)
    g_vals =  np.nanmean(gefs.RZSM[:,:,0,:,:].values,axis=1)
    e_vals =  np.nanmean(ecmwf.RZSM[:,:,0,:,:].values,axis=1)
    u_vals = np.nanmean(test_unet,axis=1)
    

    #Keep all models the same
    # o_vals = obs.RZSM[:,:,0,:,:].values
    # g_vals =  gefs.RZSM[:,:,0,:,:].values
    # e_vals =  ecmwf.RZSM[:,:,0,:,:].values
    # u_vals = test_unet
    # x_vals = xg

    u_vals.shape
    x_vals.shape

    def check_if_below_percentile_individual_distribution(obs_percent, percentile_num, o_vals, g_vals, e_vals, u_vals, x_vals):
        perc_= obs_percent[f'{percentile_num}th_percentile'].values
        
        def find_percentage(perc_,o_vals,fcst):
            fcst =  np.where(fcst<perc_,1,0)
            obs_binary = np.where((o_vals<perc_),1,0)
            '''Take the mean to find the probability of correct'''
            fcst.shape #(104, 48, 96)
            #Now mask the input of CONUS/region
            fcst = np.where(mask_anom == 1,fcst,np.nan)
            obs_binary = np.where(mask_anom == 1,obs_binary,np.nan)

            return(fcst,obs_binary)

        g_perc,obs_binary = find_percentage(perc_,o_vals,g_vals)
        g_perc.shape
        e_perc,obs_binary = find_percentage(perc_,o_vals,e_vals)
        u_perc,obs_binary  = find_percentage(perc_,o_vals,u_vals)
        x_perc,obs_binary  = find_percentage(perc_,o_vals,x_vals)
        return(g_perc,e_perc,u_perc,x_perc,obs_binary )

    for idx,percentile_num in enumerate([5,10,20,33]):
        final_perc_gefs[:,:,:,idx], final_perc_ecmwf[:,:,:,idx], final_perc_unet[:,:,:,idx],final_perc_xg[:,:,:,idx],obs_binary_out[:,:,:,idx] = check_if_below_percentile(obs_percent, percentile_num, o_vals, g_vals, e_vals, u_vals, x_vals)



    
    def check_if_above_percentile(obs_percent, percentile_num, o_vals, g_vals, e_vals, u_vals, x_vals):
        perc_= obs_percent[f'{percentile_num}th_percentile'].values
       
        def find_percentage(perc_,o_vals,fcst):
            fcst =  np.where(fcst>perc_,1,0)
            obs_binary = np.where((o_vals>perc_),1,0)
            '''Take the mean to find the probability of correct'''
            fcst.shape #(104, 48, 96)
            #Now mask the input of CONUS/region
            fcst = np.where(mask_anom == 1,fcst,np.nan)
            obs_binary = np.where(mask_anom == 1,obs_binary,np.nan)

            return(fcst,obs_binary)

        g_perc,obs_binary  = find_percentage(perc_,o_vals,g_vals)
        g_perc.shape
        e_perc,obs_binary  = find_percentage(perc_,o_vals,e_vals)
        u_perc,obs_binary  = find_percentage(perc_,o_vals,u_vals)
        x_perc,obs_binary  = find_percentage(perc_,o_vals,x_vals)
        
        return(g_perc,e_perc,u_perc,x_perc,obs_binary )
    
    for idx,percentile_num in enumerate([66,80,90,95]):
        '''We are adding 4 to make sure that we get the indices correct, we already added data from below percentiles'''
        final_perc_gefs[:,:,:,idx+4], final_perc_ecmwf[:,:,:,idx+4], final_perc_unet[:,:,:,idx+4], final_perc_xg[:,:,:,idx+4],obs_binary_out[:,:,:,idx+4] = check_if_above_percentile(obs_percent, percentile_num, o_vals, g_vals, e_vals, u_vals, x_vals)

    
    #Save files for later use
    np.save(save_ecmwf,final_perc_ecmwf)
    np.save(save_gefs, final_perc_gefs)
    np.save(save_unet, final_perc_unet)
    np.save(save_xg, final_perc_xg)
    np.save(save_obs_binary, obs_binary_out)


In [11]:
def return_anomaly_correlation_coefficient_compute(output_dictionary,fcst_RZSM,obs_RZSM,base_RZSM_climpred,
                                                    out_name,bias_correction,experiment,dim_order,many_testing_predictions):
    #Now create the anomaly correlation coefficient
    
    if many_testing_predictions == True:
        out_name = f'{out_name}_many_predictions'
    else:
        out_name = f'{out_name}_single_prediction'
    
    ACC_RZSM_prediction = anomaly_correlation_coefficient_function(var_OUT = np.empty(shape=(fcst_RZSM.RZSM.squeeze().shape)), forecast_converted=fcst_RZSM.RZSM.values.squeeze(), obs_converted = obs_RZSM.squeeze().reshape((104,11,48,96)))
    ACC_RZSM_baseline = anomaly_correlation_coefficient_function(np.empty(shape=(fcst_RZSM.RZSM.squeeze().shape)), base_RZSM_climpred.RZSM.values.squeeze(), obs_RZSM.squeeze().reshape((104,11,48,96)))

    output_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC{out_name}'] =  ACC_RZSM_prediction  
    output_dictionary[f'Wk{lead}_MEM_baseline_RZSM_ACC_no_BC'] =  ACC_RZSM_baseline 
    output_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_MEM_RZSM_ACC{out_name}'] =  ACC_RZSM_prediction - ACC_RZSM_baseline 

    
    #Additive bias correction
#     ACC_RZSM_additive_BC = anomaly_correlation_coefficient_function(var_OUT=np.empty(shape=(corrected_RZSM_additive.sel(L=(lead*7)-1).expand_dims({'L': 1}).transpose(*dim_order).squeeze().RZSM.shape)), forecast_converted=corrected_RZSM_additive.isel(L=lead).RZSM.values, obs_converted=obs_RZSM.squeeze().reshape((104,11,48,96)))
#     ACC_RZSM_additive_BC = np.where(np.isnan(obs_RZSM[0,:,:]),np.nan,ACC_RZSM_additive_BC)
#     ACC_tmax_additive_BC = anomaly_correlation_coefficient_function(np.empty(shape=(corrected_tmax_additive.sel(L=(lead*7)-1).expand_dims({'L': 1}).transpose(*dim_order).squeeze().tmax.shape)), corrected_tmax_additive.isel(L=lead).tmax.values, obs_tmax.squeeze().reshape((104,11,48,96)))
    
#     output_dictionary[f'Wk{lead}_baseline_improvement_MEM_RZSM_ACC_additive_BC'] =  ACC_RZSM_additive_BC - ACC_RZSM_baseline
#     output_dictionary[f'Wk{lead}_MEM_RZSM_ACC_additive_BC'] = ACC_RZSM_additive_BC
    
#     output_dictionary[f'Wk{lead}_baseline_improvement_MEM_Tmax_ACC_additive_BC'] =  ACC_tmax_additive_BC - ACC_Tmax_baseline
#     output_dictionary[f'Wk{lead}_MEM_Tmax_ACC_additive_BC'] = ACC_tmax_additive_BC
    
    return(output_dictionary)



In [12]:
def return_non_post_processed_forecasts(lead,dim_order):
    '''We are selecting a single lead time, so use this code'''
    if lead == 0:
        index_sel = 0
    else:
        index_sel = (lead*7)-1
    
    if lead ==0:
        RZSM_base_reforecast_climpred = f.restrict_to_CONUS_bounding_box(RZSM_baseline_reforecast_climpred_day0,CONUS_mask).isel(L=index_sel).expand_dims({'L': 1}).transpose(*dim_order)
        print_min_max(RZSM_base_reforecast_climpred,'RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)')
    elif lead !=0:
        RZSM_base_reforecast_climpred = f.restrict_to_CONUS_bounding_box(gefs_anom,mask).sel(L=(lead*7)-1).expand_dims({'L': 1}).transpose(*dim_order)
        print_min_max(RZSM_base_reforecast_climpred,'RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)')

    if RZSM_or_Tmax_or_both == 'both':
        return(RZSM_base_reforecast_climpred,tmax_base_reforecast_climpred)
    else:
        return(RZSM_base_reforecast_climpred)

    '''Then load the actual UNET predictions and keep the observations that are already in a good format for comparison
    This data has been loaded, converted back to anomalies, and masked for RZSM for non-land regions'''
    
    

In [13]:


def anomaly_correlation_coefficient(lead,experiment,ACC_dictionary):
    output_dictionary = {}
    out_name = ''
    
    #Must re-add back L as a date
    dim_order = ['S','M','Y','X','L']

    RZSM_base_reforecast_climpred = return_non_post_processed_forecasts(lead,dim_order) #Returns the original reforecasts
    prediction_RZSM,obs_RZSM = load_experiment_predictions_and_observations(lead,experiment,region_name) #Returns the UNET prediction and observations
    
    #Change name for climpred processing
    #Reforecast prediction
    prediction_RZSM_climpred = verifications.rename_subx_for_climpred(convert_prediction_to_SubX_format(file=prediction_RZSM,lead=lead,dim_order = dim_order))
    prediction_RZSM_climpred  = prediction_RZSM_climpred.rename({'file_name':'RZSM'})
    print_min_max(prediction_RZSM_climpred,'RZSM anomaly prediction value from UNET')

    
    unet_acc = verifications.create_climpred_ACC(prediction_RZSM_climpred, verifications.rename_obs_for_climpred(obs_original))
    ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)

    #Base reforecast (before post-processing)
    base_RZSM_climpred  = verifications.rename_subx_for_climpred(RZSM_base_reforecast_climpred)
    base_RZSM_climpred = base_RZSM_climpred.sel(init=slice(start_testing, None))
    print_min_max(base_RZSM_climpred,'RZSM baseline anomaly from reforecast. No post-processing.')

    gefs_acc = verifications.create_climpred_ACC(base_RZSM_climpred, verifications.rename_obs_for_climpred(obs_original))

    ACC_dictionary[f'Wk{lead}_MEM_baseline_RZSM_ACC'] = np.nanmean(gefs_acc.acc.values)

    return(ACC_dictionary)
    


In [None]:
#Testing
# many_testing_predictions = True
# experiment = 'EX0'
# lead=1
# bias_correction = True

In [16]:
def convert_prediction_to_SubX_format(file,lead,dim_order):
    if lead == 0:
        cp_base = f.restrict_to_CONUS_bounding_box(base_file_testing_day0.copy(deep=True).isel(L=0),mask).expand_dims({'L': 1})
    else:                                    
        cp_base = f.restrict_to_CONUS_bounding_box(base_file_testing.copy(deep=True).sel(L=(lead*7)-1),mask).expand_dims({'L': 1})

    #reshape data back to original format (testing shape)
    
    cp_base = cp_base.transpose(*dim_order)
    
    #Reshape prediction file
    file = file.reshape((104,11,48,96,1))
    
    var_OUT = xr.Dataset(
            data_vars = dict(
                file_name = (['S','M','Y','X','L'],  file[:,:,:,:,:]),
            ),
            coords = dict(
                S = cp_base.S.values,
                X = cp_base.X.values,
                Y = cp_base.Y.values,
                L = cp_base.L.values,
                M = cp_base.M.values,

            ),
            attrs = dict(
                Description = 'New data added to file'),
        )  


    
    return(var_OUT)


In [17]:
def select_lead(file,dim_order,lead):
    return(file.sel(L=(lead*7)-1).expand_dims({'L': 1}).transpose(*dim_order))



In [None]:
## Testing
# lead=1
# experiment='EX0'
# train_end = 2015



    
# def run_climpred_for_Experiment(lead,experiment, train_end, 
#                                 obs_RZSM_file_for_climpred, obs_tmax_file_for_climpred,
#                                 RZSM_baseline_reforecast_climpred, tmax_baseline_reforecast_climpred,
#                                 base_file_testing,bias_correction,many_testing_predictions=False):
#     '''Returns the CRPS and MAE and ACC values for CONUS. Saves into a dictionary'''
        
#     out_dictionary = {}
    
#     #Must re-add back L as a date
#     dim_order = ['S','M','Y','X','L']
#     RZSM_base_reforecast_climpred, tmax_base_reforecast_climpred = return_original_forecasts(lead,dim_order)
    
#     if bias_correction  == True:
#         out_name = '_bias_corrected'
#     else:
#         out_name = ''
        
#     print(f'Prediction RZSM shape = {prediction_RZSM.shape}')
#     print(f'Prediction Tmax shape = {prediction_tmax.shape}')
#     print(f'Obs RZSM shape = {obs_RZSM_for_mae.shape}')
#     print(f'Obs Tmax shape = {obs_tmax_for_mae.shape}')
    

#     #Find the mean absolute error over all predictions from each ensemble member
#     # out_dictionary[f'Wk{lead}_{experiment}_RZSM_MAE'] = np.abs(np.nanmean(obs_RZSM_for_mae-prediction_RZSM,axis=0))
#     # out_dictionary[f'Wk{lead}_{experiment}_Tmax_MAE'] = np.abs(np.nanmean(obs_tmax_for_mae-prediction_tmax,axis=0))
    
#     prediction_MAE_RZSM = np.abs(np.nanmean(obs_RZSM_for_mae - prediction_RZSM,axis=0))
#     prediction_MAE_Tmax = np.abs(np.nanmean(obs_tmax_for_mae - prediction_tmax,axis=0))
    
#     convert_obs_to_compare_baseline_rzsm = np.abs(np.nanmean(np.subtract(obs_RZSM_for_mae.reshape((104,11,48,96,1)),RZSM_base_reforecast_climpred.RZSM.values).squeeze(),axis=(0,1)))
#     try:
#         convert_obs_to_compare_baseline_tmax = np.abs(np.nanmean(np.subtract(obs_tmax_for_mae.reshape((104,11,48,96,1)),tmax_base_reforecast_climpred.tasmax.values).squeeze(),axis=(0,1)))
#     except AttributeError:
#         convert_obs_to_compare_baseline_tmax = np.abs(np.nanmean(np.subtract(obs_tmax_for_mae.reshape((104,11,48,96,1)),tmax_base_reforecast_climpred.tmax.values).squeeze(),axis=(0,1)))
    
#     out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_RZSM_MAE{out_name}'] = np.subtract(convert_obs_to_compare_baseline_rzsm,prediction_MAE_RZSM)
#     out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_Tmax_MAE{out_name}'] =  np.subtract(convert_obs_to_compare_baseline_tmax, prediction_MAE_Tmax)
    
#     #Change name for climpred processing
#     #Reforecast prediction
#     prediction_RZSM_climpred = f.rename_subx_for_climpred(convert_prediction_to_SubX_format(prediction_RZSM,lead,dim_order))
#     print_min_max(prediction_RZSM_climpred,'RZSM anomaly prediction value from UNET')
    
#     # out_dictionary[f'Wk{lead}_baseline_RZSM_MAE'] = convert_obs_to_compare_baseline

    
#     prediction_tmax_climpred  = f.rename_subx_for_climpred(convert_prediction_to_SubX_format(prediction_tmax,lead,dim_order))
#     prediction_tmax_climpred  = prediction_tmax_climpred .rename({'RZSM':'tmax'})
#     print_min_max(prediction_tmax_climpred,'Tmax anomaly prediction value from UNET')
    
    
    
#     # out_dictionary[f'Wk{lead}_baseline_Tmax_MAE'] = convert_obs_to_compare_baseline_tmax
        
#     #Base reforecast (before post-processing)
#     base_RZSM_climpred  = f.rename_subx_for_climpred(RZSM_base_reforecast_climpred)
#     base_RZSM_climpred = base_RZSM_climpred.sel(init=slice(start_testing, None))
#     print_min_max(base_RZSM_climpred,'RZSM baseline anomaly from reforecast. No post-processing.')
    
#     base_tmax_climpred  = rename_subx_for_climpred(tmax_base_reforecast_climpred)
#     try:
#         base_tmax_climpred = base_tmax_climpred.rename({'tasmax':'tmax'})
#     except ValueError:
#         pass
#     base_tmax_climpred = base_tmax_climpred.sel(init=slice(start_testing, None))
#     print_min_max(base_tmax_climpred,'Tmax baseline anomaly from reforecast. No post-processing.')
    
#     #Observations
#     obs_RZSM_file_climpred = rename_obs_for_climpred(obs_RZSM_file_for_climpred)
#     obs_RZSM_file_climpred = obs_RZSM_file_climpred.rename({'SMsurf':'RZSM'})
#     print_min_max(obs_RZSM_file_climpred,'RZSM observations anomaly.')

#     obs_tmax_file_climpred = rename_obs_for_climpred(obs_tmax_file_for_climpred)
#     obs_tmax_file_climpred = obs_tmax_file_climpred.rename({'mx2t':'tmax'})
#     print_min_max(obs_tmax_file_climpred,'Tmax observations anomaly.')
    
#     #Need to set chunks for climpred functions
#     fcst_RZSM,fcst_tmax,base_RZSM,base_tmax = prediction_RZSM_climpred.chunk({'init': -1}), prediction_tmax_climpred.chunk({'init': -1}), base_RZSM_climpred.chunk({'init': -1}), base_tmax_climpred.chunk({'init': -1})

#     additive_RZSM_bias = rename_subx_for_climpred(corrected_RZSM_additive.isel(L=lead).expand_dims({'L': 1}).transpose(*dim_order)).chunk({'init': -1})
#     additive_tmax_bias = rename_subx_for_climpred(corrected_tmax_additive.isel(L=lead).expand_dims({'L': 1}).transpose(*dim_order)).chunk({'init': -1})
    
#     verif_RZSM,verif_tmax = obs_RZSM_file_climpred.chunk({'time': -1}), obs_tmax_file_climpred.chunk({'time': -1})
#     ######################################### ACC #########################################
#     output_dictionary,ACC_RZSM_prediction,ACC_RZSM_baseline,ACC_Tmax_prediction,ACC_Tmax_baseline,ACC_RZSM_additive_BC,ACC_tmax_additive_BC =  return_anomaly_correlation_coefficient_compute(output_dictionary,fcst_RZSM,obs_RZSM_for_mae,base_RZSM,fcst_tmax,obs_tmax_for_mae,base_tmax,corrected_RZSM_additive,corrected_tmax_additive,out_name)
    
#     #Now create climpred classes for CRPS
#     hindcast_prediction_RZSM = climpred.HindcastEnsemble(fcst_RZSM).add_observations(verif_RZSM)
#     hindcast_prediction_tmax = climpred.HindcastEnsemble(fcst_tmax).add_observations(verif_tmax)

#     hindcast_base_RZSM = climpred.HindcastEnsemble(base_RZSM).add_observations(verif_RZSM)
#     hindcast_base_tmax = climpred.HindcastEnsemble(base_tmax).add_observations(verif_tmax)
    
#     hindcast_base_RZSM_additive_BC = climpred.HindcastEnsemble(additive_RZSM_bias).add_observations(verif_RZSM)
#     hindcast_base_tmax_additive_BC = climpred.HindcastEnsemble(additive_tmax_bias).add_observations(verif_tmax)
    
#     ######################################### CRPS #########################################
#     crps_baseline_RZSM = hindcast_base_RZSM.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(RZSM='crps').load()
#     prediction_RZSM = hindcast_prediction_RZSM.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(RZSM='crps').load() 
    
#     addtive_bias_RZSM = hindcast_base_RZSM_additive_BC.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(RZSM='crps').load() 
    
#     out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_RZSM_CRPS{out_name}'] = crps_baseline_RZSM -   prediction_RZSM
#     out_dictionary[f'Wk{lead}_baseline_improvement_RZSM_CRPS{out_name}_additive_BC'] = crps_baseline_RZSM -   addtive_bias_RZSM
#     # out_dictionary[f'Wk{lead}_{experiment}_Tmax_Baseline_CRPS']= hindcast_base_tmax.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(tmax='crps').load()

    
#     crps_baseline_tmax = hindcast_base_tmax.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(tmax='crps').load()
#     prediction_tmax =  hindcast_prediction_tmax.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(tmax='crps').load()
#     addtive_bias_RZSM = hindcast_base_tmax_additive_BC.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(tmax='crps').load() 
    
#     out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_Tmax_CRPS{out_name}'] = crps_baseline_tmax - prediction_tmax
#     out_dictionary[f'Wk{lead}_baseline_improvement_Tmax_CRPS{out_name}_additive_BC'] = crps_baseline_tmax - addtive_bias_RZSM
    
#     ######################################### RANK HISTOGRAM #########################################
#     #Now plot the rank histogram
#     rank_histogram_RZSM_baseline = hindcast_base_RZSM.verify(metric="rank_histogram", comparison="m2o", dim=["member", "init", "lat", "lon"], alignment="maximize").load().rename(RZSM='rank_histogram')
#     rank_histogram_RZSM_prediction = hindcast_prediction_RZSM.verify(metric="rank_histogram", comparison="m2o", dim=["member", "init", "lat", "lon"], alignment="maximize").load().rename(RZSM='rank_histogram')
    
#     rank_histogram_Tmax_baseline = hindcast_base_tmax.verify(metric="rank_histogram", comparison="m2o", dim=["member", "init", "lat", "lon"], alignment="maximize").load().rename(tmax='rank_histogram')
#     rank_histogram_Tmax_prediction = hindcast_prediction_tmax.verify(metric="rank_histogram", comparison="m2o", dim=["member", "init", "lat", "lon"], alignment="maximize").load().rename(tmax='rank_histogram')
    
#     out_dictionary[f'Wk{lead}_{experiment}_RZSM_rank_histogram{out_name}'] = rank_histogram_RZSM_prediction
#     out_dictionary[f'Wk{lead}_baseline_RZSM_rank_histogram'] = rank_histogram_RZSM_baseline
    
#     out_dictionary[f'Wk{lead}_{experiment}_Tmax_rank_histogram{out_name}'] = rank_histogram_Tmax_prediction
#     out_dictionary[f'Wk{lead}_baseline_Tmax_rank_histogram'] = rank_histogram_Tmax_baseline
    
#     return(out_dictionary)


# Plot data

In [None]:
def get_min_max_of_files(file):
    min_ =  []
    max_ = []
    try:
        for f in list(file.keys()):
            min_.append(np.nanmin(file[f]))
            max_.append(np.nanmax(file[f]))
    except TypeError:
         for f in list(file.keys()):
            min_.append(np.nanmin(file[f].crps))
            max_.append(np.nanmax(file[f].crps))
            
    return(min(min_),max(max_))


In [None]:


   
# cmap = 'coolwarm'

def plot_files_ACC(test_file, var, name_of_test):
    cmap = plt.get_cmap('bwr')    
    
    save_dir = f'Outputs/crps_mae/Wk_{lead}/{var}_ACC'
    os.system(f'mkdir -p {save_dir}')
    
    if lead == 0:
        row=3
        column=5
        width = 20
        height=10
        ex_size = 12
    else:
        row=6
        column=5
        width = 30
        height=25
        ex_size = 16
        
    fig, axs = plt.subplots(
        row, column, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(width, height))
    axs = axs.flatten()
    fig.tight_layout()
    
    min_,max_ = get_min_max_of_files(test_file)
    # test_file = mae_rzsm_keys
    # for Subx original data
    
    #For the ACC values
    if lead == 0:
        text_x = -84  # You may need to adjust this value based on your data
        text_y = 27  # You may need to adjust this value based on your data
        font_size = 12
    else:
        text_x = -84  # You may need to adjust this value based on your data
        text_y = 27  # You may need to adjust this value based on your data
        font_size = 16
    
    for idx,experiment in enumerate(experiment_list):
        data = {key: value for key, value in test_file.items() if experiment in key}
        data = list(data.values())[0]
        
        
        mean_vals = round(np.nanmean(data),4)
        print(f'Mean value: {mean_vals}')
        
        v = np.linspace(min_, max_, 30, endpoint=True)

        map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
        x, y = map(*np.meshgrid(lon, lat))
        # Adjust the text coordinates based on the actual data coordinates


        
        # ax.drawmeridians()
        try:
            norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)
        except ValueError:
            norm = None
        
        try:
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                                  transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
        except TypeError:
            data = np.nanmean(data.crps.values,axis=0)
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                      transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
            
    
        # axs[idx].title.set_text(f'SubX Lead {lead*7}')
        gl = axs[idx].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                   linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        gl.xlabels_top = False
        gl.ylabels_right = False
        if lead != 1:
            gl.ylabels_left = False
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        axs[idx].coastlines()
        # plt.colorbar(im)
        # axs[idx].set_aspect('auto', adjustable=None)
        axs[idx].set_aspect('equal')  # this makes the plots better
        axs[idx].set_title(experiment,fontsize=ex_size)
        axs[idx].text(text_x, text_y, mean_vals, ha='right', va='bottom', fontsize=font_size, color='blue', weight = 'bold')

        # Add a colorbar axis at the bottom of the graph
        # left, bottom, width, height
    
    data_non_EX = {key: value for key, value in test_file.items() if 'EX' not in key}
    #Don't worry about additive bias right now. I can't figure out why it doesn't work
    data_non_EX = {key: value for key, value in data_non_EX.items() if 'no_BC' in key}
    
    for non_used,data_key in enumerate(data_non_EX):
        idx+=1
        # break
        data = {key: value for key, value in test_file.items() if data_key in key}
        data = list(data.values())[0]
        mean_vals = round(np.nanmean(data),4)
        
        v = np.linspace(min_, max_, 30, endpoint=True)

        map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
        x, y = map(*np.meshgrid(lon, lat))


        # ax.drawmeridians()
        try:
            norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)
        except ValueError:
            norm = None
        
        try:
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                                  transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
        except TypeError:
            data = np.nanmean(data.crps.values,axis=0)
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                      transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
        # axs[idx].title.set_text(f'SubX Lead {lead*7}')
        gl = axs[idx].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                   linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        gl.xlabels_top = False
        gl.ylabels_right = False
        if lead != 1:
            gl.ylabels_left = False
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        axs[idx].coastlines()
        # plt.colorbar(im)
        # axs[idx].set_aspect('auto', adjustable=None)
        axs[idx].set_aspect('equal')  # this makes the plots better
        axs[idx].set_title(data_key)
        axs[idx].text(text_x, text_y, mean_vals, ha='right', va='bottom', fontsize=font_size, color='blue', weight = 'bold')
        
        # Add a colorbar axis at the bottom of the graph
        # left, bottom, width, height
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])


    
    # Draw the colorbar
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    plt.tight_layout()
    fig.suptitle(name_of_test, fontsize=30)
    plt.savefig(f'{save_dir}/{name_of_test}.png',bbox_inches='tight')
    plt.show()
    

In [None]:

def plot_rank_histogram(baseline_file,prediction_file, name_of_file):
    # prediction_file = rzsm_Rank_histogram_predictions
    # baseline_file = rzsm_Rank_histogram_baseline
    
    # baseline_file = baseline_file.values()
    baseline_file = list(baseline_file.values())[0]
    
    fig, axs = plt.subplots(3,5, figsize=(20, 7))

#     fig, axs = plt.subplots(
#         3, 5, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(20, 7), gridspec_kw={'height_ratios': [2,2,2]})
    axs = axs.flatten()
    

    # rank_file=rank_histogram_unet
    to_df = baseline_file.rank_histogram[:].to_dataframe()
    to_df['rank_histogram'] = to_df['rank_histogram'] / \
        to_df['rank_histogram'].sum()
    to_df['rank'] = to_df.index
    to_df['rank'] = to_df['rank'].astype(int)
    to_df.index = to_df['rank']
    del to_df['lead']
    del to_df['skill']
    del to_df['rank']
    
    print(f'Shape of to_df : {to_df.rank().shape[0]}')
    # axs[ax].plot(to_df)
    axs[0].bar(np.arange(1,to_df.rank().shape[0]+1),to_df.rank_histogram)
    axs[0].set_xlim(1, 12)

    # Optionally, adjust tick marks
    axs[0].set_xticks(np.arange(1, 13))
    # to_df.bar(ax=axs[0], kind='bar', width=1.4)
    axs[0].set_title(f'Baseline Rank Histogram')
    axs[0].set_xticklabels(axs[0].get_xticklabels(), rotation=0)
    axs[0].set_ylabel('Relative Frequency', rotation=90)

    
    for ax,experiment in enumerate(experiment_list):
        ax+=1
        data = {key: value for key, value in prediction_file.items() if experiment in key}
        data = list(data.values())[0]

        # rank_file=rank_histogram_unet
        to_df = data.rank_histogram[:].to_dataframe()
        to_df['rank_histogram'] = to_df['rank_histogram'] / \
            to_df['rank_histogram'].sum()
        to_df['rank'] = to_df.index
        to_df['rank'] = to_df['rank'].astype(int)
        to_df.index = to_df['rank']
        del to_df['lead']
        del to_df['skill']
        del to_df['rank']
        
        print(f'Shape of to_df : {to_df.rank().shape[0]}')
        # axs[ax].plot(to_df)
        axs[ax].bar(np.arange(1,to_df.rank().shape[0]+1),to_df.rank_histogram)
        axs[ax].set_xlim(1, 12)

        # Optionally, adjust tick marks
        axs[ax].set_xticks(np.arange(1, 13))
        axs[ax].set_title(experiment)
        axs[ax].set_xticklabels(axs[ax].get_xticklabels(), rotation=0)
        axs[ax].set_ylabel('Relative Frequency', rotation=90)
    plt.suptitle(f'{name_of_file} Rank Histogram', fontsize=30)
    plt.tight_layout()
    out_dir_save = f'{save_dir}/{name_of_file}.png'
    plt.savefig(out_dir_save, dpi=300)

# Plot ACC

In [None]:
#Testing
# experiment='EX0'
# bias_correction=True
# many_testing_predictions=True


In [20]:
def grab_ACC_from_dict(dict_,var):
    acc = {key: value for key, value in dict_.items() if f'{var}_ACC' in key}
    return(acc)

def subset_delete(dict_,subset):
    keys_to_delete = [key for key in dict_.keys() if subset in key]
    for key in keys_to_delete:
        del dict_[key]
    return(dict_)

def subset_keep(dict_,subset):
    keys_to_keep = [key for key in dict_.keys() if subset in key]
    new_dict = {key: dict_[key] for key in keys_to_keep}
    return(new_dict)

def subset_update(dict_in, background_dict, subset):
    keys_to_keep = [key for key in background_dict.keys() if subset in key]
    new_dict = {key: background_dict[key] for key in keys_to_keep}
    dict_in.update(new_dict)
    return(dict_in)

In [None]:
#This is for all different tests
# experiment_list = [f'EX{i}' for i in range(0,13)]

# ACC_dictionary = {}
# for experiment in experiment_list:
#     for many_testing_predictions in [True,False]:
#         for bias_correction in [True,False]:
#             ACC_dictionary.update(anomaly_correlation_coefficient(lead,corrected_RZSM_additive,corrected_tmax_additive,bias_correction,many_testing_predictions,experiment))



In [21]:
def run_ACC(lead):
    print(f'Working on lead {lead}')
    # lead=1

    save_dict_dir = f'Outputs/crps_mae/Wk_{lead}'
    os.system(f'mkdir -p {save_dict_dir}')
    
    if lead ==0:
        experiment_list = [f'EX{i}' for i in range(0,13)]
    elif lead <=4:
        experiment_list = [f'EX{i}' for i in range(0,12)]
        experiment_list2 = [f'EX{i}' for i in range(13,30)] #Couldn't run EX12 due to memory issues
        experiment_list = experiment_list + experiment_list2
    elif lead ==5:
        experiment_list = ['EX24','EX27','EX29']

    ACC_dictionary = {}
    for experiment in experiment_list:
        ACC_dictionary.update(anomaly_correlation_coefficient(lead=lead,experiment=experiment,ACC_dictionary=ACC_dictionary ))
        
    return(ACC_dictionary)


In [22]:
# print(ACC_dictionary.keys())

def save_acc_tests(var, ACC_dictionary):

    acc = grab_ACC_from_dict(dict_ = ACC_dictionary, var = var)
    # print(acc.keys())

    ############################### SINGLE PREDICTION, NO BIAS CORRECTION ##################################################################
    acc = grab_ACC_from_dict(dict_ = ACC_dictionary, var = var)
    t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
    print(t1.keys())

    #Save the average ACC values to a dictionary for later plotting
    # t_base = subset_keep(dict_ = t1, subset = 'baseline')
    # t_unet= subset_delete(dict_ = t1, subset = 'baseline')
    # print(t_base.keys())
    # print(t_unet.keys())
    
    file_path = f'Outputs/permutation_tests/Wk_{lead}'
    file_save = f'{file_path}/ACC_vals.pkl'
    
    os.system(f'mkdir -p {file_path}')
    
    with open(file_save, 'wb') as file:
        pickle.dump(t1, file)

    # plot_files_ACC(test_file = t1, var = var, name_of_test = f'{var} Single prediction ACC - No bias correction')

 
    return(0)

for lead in [1,2,3,4]:
    save_acc_tests(var = 'RZSM', ACC_dictionary = run_ACC(lead))


stop

Working on lead 1
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.1511690616607666
Minimum value in file is -0.20214346051216125
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.19910082

  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

RZSM anomaly prediction value from UNET
Maximum value in file is 0.0
Minimum value in file is 0.0


  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

RZSM anomaly prediction value from UNET
Maximum value in file is 0.0
Minimum value in file is 0.0


  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

RZSM anomaly prediction value from UNET
Maximum value in file is 0.0
Minimum value in file is 0.0


  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.24399486184120178
Minimum value in file is -0.20214346051216125
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144

  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.19206494092941284
Minimum value in file is -0.20214346051216125
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.18918319046497345
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2466011941432953
Minimum value in file is -0.20214346051216125

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144

  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2385759800672531
Minimum value in file is -0.1945461928844452
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.24183440208435059
Minimum value in file is -0.2101595252752304

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.22241319715976715
Minimum value in file is -0.2046499401330948
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2385759800672531
Minimum value in file is -0.1945461928844452
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.24183440208435059
Minimum value in file is -0.2101595252752304

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 4

  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2493564486503601
Minimum value in file is -0.1958891600370407
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2493564486503601
Minimum value in file is -0.21714745461940765

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.20019745826721191
Minimum value in file is -0.2090436816215515
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2493564486503601
Minimum value in file is -0.1958891600370407
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.2493564486503601
Minimum value in file is -0.21714745461940765

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 4

  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2505298852920532
Minimum value in file is -0.19464196264743805
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.26432961225509644
Minimum value in file is -0.22833965718746185

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.3771704435348511
Minimum value in file is -0.22833965718746185
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2505298852920532
Minimum value in file is -0.19464196264743805
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.26432961225509644
Minimum value in file is -0.22833965718746185

Now loading Verification Data from Observations.

RZSM anomaly prediction value 

  r = r_num / r_den
  r = r_num / r_den
  ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_ACC'] = np.nanmean(unet_acc.acc.values)


RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2505298852920532
Minimum value in file is -0.19464196264743805
RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)
Maximum value in file is 0.26432961225509644
Minimum value in file is -0.22833965718746185

Now loading Verification Data from Observations.

Test prediction shape: (3, 1144, 48, 96, 1)
Shape of prediction RZSM: (1144, 48, 96, 1)
RZSM anomaly prediction value from UNET
Maximum value in file is 0.17095153033733368
Minimum value in file is -0.22833965718746185
RZSM baseline anomaly from reforecast. No post-processing.
Maximum value in file is 0.2505298852920532
Minimum value in file is -0.19464196264743805
dict_keys(['Wk4_EX0_MEM_RZSM_ACC', 'Wk4_MEM_baseline_RZSM_ACC', 'Wk4_EX1_MEM_RZSM_ACC', 'Wk4_EX2_MEM_RZSM_ACC', 'Wk4_EX3_MEM_RZSM_ACC', 'Wk4_EX4_MEM_RZSM_ACC', 'Wk4_EX5_MEM_RZSM_ACC', 'Wk4_EX6_MEM_RZSM_ACC', 'Wk4_EX7_MEM_RZSM_ACC', 'Wk4_EX8_MEM_RZSM_ACC', 

NameError: name 'stop' is not defined

# Plot CRPS

In [26]:
def crps_analysis(lead,bias_correction,many_testing_predictions,experiment):
    print(f'Working on Experiment: {experiment}. For bias_correction: {bias_correction}. For many_testing_predictions: {many_testing_predictions}')
    
    def create_climpred_CRPS(fcst,obs):
        fcst_name = list(fcst.keys())[0]
        object_ =  climpred.HindcastEnsemble(fcst).add_observations(obs)
        return(object_.verify(metric="crps", comparison="m2o", dim="member", alignment="same_inits").rename({fcst_name:'crps'}).load())

        
    dim_order = ['S','M','Y','X','L']
    
    out_dictionary = {}
    
    if bias_correction  == True:
        if many_testing_predictions == True:
            out_name = '_bias_corrected_many_predictions'
        else:
            out_name = '_bias_corrected_single_prediction'
    else:
        if many_testing_predictions == True:
            out_name = '_many_predictions'
        else:
            out_name = '_single_prediction'
    
    #Must re-add back L as a date
    dim_order = ['S','M','Y','X','L']
    
    if RZSM_or_Tmax_or_both == 'both':
        RZSM_base_reforecast_climpred, tmax_base_reforecast_climpred = return_non_post_processed_forecasts(lead,dim_order) #Returns the original reforecasts
        prediction_RZSM,prediction_tmax,obs_RZSM,obs_tmax = load_experiment_predictions_and_observations(lead,experiment,many_testing_predictions,bias_correction) #Returns the UNET prediction and observations
    else:
        RZSM_base_reforecast_climpred = return_non_post_processed_forecasts(lead,dim_order) #Returns the original reforecasts
        prediction_RZSM,obs_RZSM = load_experiment_predictions_and_observations(lead,experiment,many_testing_predictions,bias_correction) #Returns the UNET prediction and observations
    ######################################### OBSERVATIONS #########################################
    obs_RZSM_climpred = f.rename_obs_for_climpred(obs_RZSM_file_for_climpred).chunk({'time': -1}).rename({'SMsurf':'RZSM'})
    print_min_max(obs_RZSM_climpred,'RZSM observations anomaly.')

        
    ######################################### Reforecast (No post-processing) raw forecasts BASELINE #########################################
    #Base reforecast (before post-processing)
    if lead ==0:
        RZSM_base_reforecast_climpred = f.restrict_to_CONUS_bounding_box(RZSM_baseline_reforecast_climpred_day0,CONUS_mask).isel(L=0).expand_dims({'L': 1}).transpose(*dim_order)
        print_min_max(RZSM_base_reforecast_climpred,'RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)')
    elif lead !=0:
        RZSM_base_reforecast_climpred = f.restrict_to_CONUS_bounding_box(RZSM_baseline_reforecast_climpred,CONUS_mask).sel(L=(lead*7)-1).expand_dims({'L': 1}).transpose(*dim_order)
        print_min_max(RZSM_base_reforecast_climpred,'RZSM baseline value from reforecast (no pre-processing other than anomaly computed.)')
    


    #Reforecast prediction
    prediction_RZSM_climpred = f.rename_subx_for_climpred(convert_prediction_to_SubX_format(file=prediction_RZSM,lead=lead,dim_order = dim_order))
    prediction_RZSM_climpred  = prediction_RZSM_climpred.rename({'file_name':'RZSM'})
    print_min_max(prediction_RZSM_climpred,'RZSM anomaly prediction value from UNET')
    
    if RZSM_or_Tmax_or_both == 'both':
        prediction_tmax_climpred  = f.rename_subx_for_climpred(convert_prediction_to_SubX_format(prediction_tmax,lead,dim_order))
        prediction_tmax_climpred  = prediction_tmax_climpred.rename({'file_name':'tmax'})
        print_min_max(prediction_tmax_climpred,'Tmax anomaly prediction value from UNET')

    #Base reforecast (before post-processing)
    RZSM_base_reforecast_climpred = f.rename_subx_for_climpred(RZSM_base_reforecast_climpred)
    crps_baseline_RZSM = create_climpred_CRPS(RZSM_base_reforecast_climpred,obs_RZSM_climpred)
    print_min_max(crps_baseline_RZSM,'RZSM baseline anomaly from reforecast. No post-processing.')

    ######################################### Reforecast prediction #########################################
    fcst_RZSM = f.rename_subx_for_climpred(convert_prediction_to_SubX_format(prediction_RZSM,lead,dim_order)).chunk({'init': -1}).rename({'file_name':'RZSM'})
    print_min_max(fcst_RZSM,'RZSM anomaly prediction value from UNET')
    prediction_RZSM = create_climpred_CRPS(fcst_RZSM,obs_RZSM_climpred)
    out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_RZSM_CRPS{out_name}'] = crps_baseline_RZSM -   prediction_RZSM
    out_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_CRPS{out_name}'] =  prediction_RZSM  
    out_dictionary[f'Wk{lead}_MEM_baseline_RZSM_CRPS_no_BC'] =  crps_baseline_RZSM 
    
    
    ######################################### Reforecast Additive Bias #########################################
    # additive_BC_fcst_RZSM = f.rename_subx_for_climpred(select_lead(corrected_RZSM_additive,dim_order,lead)).chunk({'init': -1})
    # print_min_max(additive_BC_fcst_RZSM,'RZSM additive bias from raw forecasts (no post-processing with UNET).')
    # addtive_bias_RZSM = create_climpred_CRPS(additive_BC_fcst_RZSM,obs_RZSM_climpred)
    # out_dictionary[f'Wk{lead}_baseline_improvement_RZSM_CRPS{out_name}_additive_BC'] = crps_baseline_RZSM -   addtive_bias_RZSM

    
    
    ######################################### CRPS #########################################


#     out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_RZSM_CRPS{out_name}'] = crps_baseline_RZSM -   prediction_RZSM
#     out_dictionary[f'Wk{lead}_baseline_improvement_RZSM_CRPS{out_name}_additive_BC'] = crps_baseline_RZSM -   addtive_bias_RZSM
#     # out_dictionary[f'Wk{lead}_{experiment}_Tmax_Baseline_CRPS']= hindcast_base_tmax.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(tmax='crps').load()


    
#     addtive_bias_RZSM = hindcast_base_tmax_additive_BC.verify(metric="crps", comparison="m2o", dim="member", alignment="maximize").rename(tmax='crps').load() 
    
#     out_dictionary[f'Wk{lead}_{experiment}_baseline_improvement_Tmax_CRPS{out_name}'] = crps_baseline_tmax - prediction_tmax
#     out_dictionary[f'Wk{lead}_baseline_improvement_Tmax_CRPS{out_name}_additive_BC'] = crps_baseline_tmax - addtive_bias_RZSM
    
    
    return(out_dictionary)
    


In [27]:
lead=4

save_dict_dir = f'Outputs/crps_mae/Wk_{lead}'
os.system(f'mkdir -p {save_dict_dir}')

0

In [30]:
if lead ==0:
    experiment_list = [f'EX{i}' for i in range(0,13)]
elif lead ==4:
    experiment_list = [f'EX{i}' for i in range(0,12)]
    experiment_list2 = [f'EX{i}' for i in range(13,30)] #Couldn't run EX12 due to memory issues
    experiment_list = experiment_list + experiment_list2
elif lead ==5:
    experiment_list = experiment_list = ['EX24','EX27','EX29']

crps_dictionary = {}

for experiment in experiment_list:
    for lead in [3,4,]:
        crps_dictionary.update(crps_analysis(lead,experiment,many_testing_predictions,experiment))

TypeError: crps_analysis() missing 1 required positional argument: 'experiment'

In [None]:
load_experiment_predictions_and_observations

In [None]:


def plot_files_CRPS(test_file, var, name_of_test):
    # cmap = plt.get_cmap('bwr')    
    cmap = 'coolwarm'
    save_dir = f'Outputs/crps_mae/Wk_{lead}/{var}_CRPS'
    os.system(f'mkdir -p {save_dir}')
    
    if lead == 0:
        row=3
        column=5
        width = 20
        height=10
        ex_size=12
    else:
        row=6
        column=5
        width = 30
        height=25
        ex_size=16
    fig, axs = plt.subplots(
        row, column, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(width, height))
    axs = axs.flatten()
    fig.tight_layout()
    
    min_,max_ = get_min_max_of_files(test_file)
    # test_file = mae_rzsm_keys
    # for Subx original data
    
    #For the ACC values
    if lead == 0:
        text_x = -84  # You may need to adjust this value based on your data
        text_y = 27  # You may need to adjust this value based on your data
        font_size = 12
    else:
        text_x = -84  # You may need to adjust this value based on your data
        text_y = 27  # You may need to adjust this value based on your data
        font_size = 16
    
    min_,max_ = get_min_max_of_files(test_file)
    # test_file = mae_rzsm_keys
    # for Subx original data
    for idx,experiment in enumerate(experiment_list):
        data = {key: value for key, value in test_file.items() if experiment in key}
        data = list(data.values())[0]
        data = data.mean(dim='init').crps.values
        
        mean_vals = round(np.nanmean(data),4)
        print(f'Mean value: {mean_vals}')
        
        v = np.linspace(min_, max_, 30, endpoint=True)

        map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
        x, y = map(*np.meshgrid(lon, lat))

        
        # ax.drawmeridians()

        norm = None
        
        try:
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                                  transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
        except TypeError:
            data = np.nanmean(data.crps.values,axis=0)
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                      transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
            
    
        # axs[idx].title.set_text(f'SubX Lead {lead*7}')
        gl = axs[idx].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                   linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        gl.xlabels_top = False
        gl.ylabels_right = False
        if lead != 1:
            gl.ylabels_left = False
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        axs[idx].coastlines()
        # plt.colorbar(im)
        # axs[idx].set_aspect('auto', adjustable=None)
        axs[idx].set_aspect('equal')  # this makes the plots better
        axs[idx].set_title(experiment,fontsize=ex_size)
        axs[idx].text(text_x, text_y, mean_vals, ha='right', va='bottom', fontsize=font_size, color='blue', weight = 'bold')

        # Add a colorbar axis at the bottom of the graph
        # left, bottom, width, height
    
    data_non_EX = {key: value for key, value in test_file.items() if 'EX' not in key}
    #Don't worry about additive bias right now. I can't figure out why it doesn't work
    data_non_EX = {key: value for key, value in data_non_EX.items() if 'no_BC' in key}
    
    for non_used,data_key in enumerate(data_non_EX):
        idx+=1
        # break
        data = {key: value for key, value in test_file.items() if data_key in key}
        data = list(data.values())[0]
        data = data.mean(dim='init').crps.values
        mean_vals = round(np.nanmean(data),4)
        
        v = np.linspace(min_, max_, 30, endpoint=True)

        map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
        x, y = map(*np.meshgrid(lon, lat))
        # Adjust the text coordinates based on the actual data coordinates
        text_x = -84  # You may need to adjust this value based on your data
        text_y = 27  # You may need to adjust this value based on your data

        # ax.drawmeridians()
        norm = None
        
        try:
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                                  transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
        except TypeError:
            data = np.nanmean(data.crps.values,axis=0)
            im = axs[idx].contourf(x, y, data, levels=v, extend='both',
                      transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
        # axs[idx].title.set_text(f'SubX Lead {lead*7}')
        gl = axs[idx].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                   linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        gl.xlabels_top = False
        gl.ylabels_right = False
        if lead != 1:
            gl.ylabels_left = False
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        axs[idx].coastlines()
        # plt.colorbar(im)
        # axs[idx].set_aspect('auto', adjustable=None)
        axs[idx].set_aspect('equal')  # this makes the plots better
        axs[idx].set_title(data_key)
        axs[idx].text(text_x, text_y, mean_vals, ha='right', va='bottom', fontsize=font_size, color='blue', weight = 'bold')
        
        # Add a colorbar axis at the bottom of the graph
        # left, bottom, width, height
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])


    
    # Draw the colorbar
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    plt.tight_layout()
    fig.suptitle(name_of_test, fontsize=30)
    plt.savefig(f'{save_dir}/{name_of_test}.png',bbox_inches='tight')
    plt.show()
    

In [None]:
def grab_crps_from_dict(dict_,var):
    acc = {key: value for key, value in dict_.items() if f'{var}_CRPS' in key}
    return(acc)

def subset_delete(dict_,subset):
    keys_to_delete = [key for key in dict_.keys() if subset in key]
    for key in keys_to_delete:
        del dict_[key]
    return(dict_)

def subset_keep(dict_,subset):
    keys_to_keep = [key for key in dict_.keys() if subset in key]
    new_dict = {key: dict_[key] for key in keys_to_keep}
    return(new_dict)

def subset_update(dict_in, background_dict, subset):
    keys_to_keep = [key for key in background_dict.keys() if subset in key]
    new_dict = {key: background_dict[key] for key in keys_to_keep}
    dict_in.update(new_dict)
    return(dict_in)

In [None]:
# print(ACC_dictionary.keys())

def plot_crps_tests(var):
    
    acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
    # print(acc.keys())

    ############################### SINGLE PREDICTION, NO BIAS CORRECTION ##################################################################
    acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
    t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
    # print(t1.keys())

    t1 = subset_delete(subset_update(subset_update(subset_delete(subset_keep(dict_ = t1, subset = 'single_prediction'), subset='baseline_improvement'),acc,'no_BC'),acc,'additive_BC'),'baseline_improvement')
    # print(t1.keys())
    # print(len(list(t1.keys())))

    #Save the average ACC values to a dictionary for later plotting
    tsave2 = subset_keep(dict_ = t1, subset = 'MEM_RZSM_CRPS_single_prediction')
    #save as a numpy array instead of xarray object
    tsave = {}
    
    for k,v in tsave2.items():
        # break
        tsave[k] = v.crps.mean(dim='init').values
    
    
    t_base2 = subset_keep(dict_ = t1, subset = 'MEM_baseline_RZSM_CRPS_no_BC')
    t_base = {}
    
    for k,v in t_base2.items():
        # break
        t_base[k] = v.crps.mean(dim='init').values
    
    tsave[list(t_base.keys())[0]] = t_base[list(t_base.keys())[0]]
    
    #Also include the baseline reforecast ACC value
    
    
    file_path = f'Outputs/permutation_tests/Wk_{lead}'
    file_save = f'{file_path}/CRPS_vals.pkl'
    
    os.system(f'mkdir -p {file_path}')
    
    with open(file_save, 'wb') as file:
        pickle.dump(tsave, file)

    plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Single prediction CRPS - No bias correction')
    
    
    ############################### SINGLE PREDICTION, NO BIAS CORRECTION --  IMPROVEMENT ##################################################################
    acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
    t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
    # print(t1.keys())

    t1 = subset_keep(subset_update(subset_keep(subset_keep(dict_ = t1, subset = 'single_prediction'), subset='baseline_improvement'),acc,'additive_BC'),'baseline_improvement')
    # print(t1.keys())
    # print(len(list(t1.keys())))

    plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Single prediction CRPS Improvement - No bias correction')
    
    if lead < 3:
        
        ############################### MANY PREDICTION, NO BIAS CORRECTION ##################################################################
        acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
        t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
        # print(t1.keys())

        t1 = subset_delete(subset_update(subset_update(subset_delete(subset_keep(dict_ = t1, subset = 'many_predictions'), subset='baseline_improvement'),acc,'no_BC'),acc,'additive_BC'),'baseline_improvement')
        # print(t1.keys())
        # print(len(list(t1.keys())))

        plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Many prediction CRPS - No bias correction')

        ############################### MANY PREDICTION, BIAS CORRECTION ##################################################################
        acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
        t1 = subset_keep(dict_ = acc, subset = 'bias_corrected')
        # print(t1.keys())

        t1 = subset_delete(subset_update(subset_update(subset_delete(subset_keep(dict_ = t1, subset = 'many_predictions'), subset='baseline_improvement'),acc,'no_BC'),acc,'additive_BC'),'baseline_improvement')
        # print(t1.keys())
        # print(len(list(t1.keys())))

        plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Many prediction CRPS - Bias correction')

        ############################### SINGLE PREDICTION, BIAS CORRECTION ##################################################################
        acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
        t1 = subset_keep(dict_ = acc, subset = 'bias_corrected')
        # print(t1.keys())

        t1 = subset_delete(subset_update(subset_update(subset_delete(subset_keep(dict_ = t1, subset = 'single_prediction'), subset='baseline_improvement'),acc,'no_BC'),acc,'additive_BC'),'baseline_improvement')
        # print(t1.keys())
        # print(len(list(t1.keys())))

        plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Single prediction CRPS - Bias correction')



        ############################### MANY PREDICTION, NO BIAS CORRECTION --  IMPROVEMENT ##################################################################
        acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
        t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
        # print(t1.keys())

        t1 = subset_keep(subset_update(subset_keep(subset_keep(dict_ = t1, subset = 'many_predictions'), subset='baseline_improvement'),acc,'additive_BC'),'baseline_improvement')
        # print(t1.keys())
        # print(len(list(t1.keys())))

        plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Many prediction CRPS Improvement - No bias correction')

        ############################### MANY PREDICTION, BIAS CORRECTION -- Improvement ##################################################################
        acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
        t1 = subset_keep(dict_ = acc, subset = 'bias_corrected')
        print(t1.keys())

        t1 = subset_keep(subset_update(subset_keep(subset_keep(dict_ = t1, subset = 'many_predictions'), subset='baseline_improvement'),acc,'additive_BC'),'baseline_improvement')
        # print(t1.keys())
        # print(len(list(t1.keys())))

        plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Many prediction CRPS Improvement - Bias correction')

        ############################### SINGLE PREDICTION, BIAS CORRECTION -- Improvement ##################################################################
        acc = grab_crps_from_dict(dict_ = crps_dictionary, var = var)
        t1 = subset_keep(dict_ = acc, subset = 'bias_corrected')
        # print(t1.keys())
        t1 = subset_keep(subset_update(subset_keep(subset_keep(dict_ = t1, subset = 'single_prediction'), subset='baseline_improvement'),acc,'additive_BC'),'baseline_improvement')
        # print(t1.keys())
        # print(len(list(t1.keys())))
        plot_files_CRPS(test_file = t1, var = var, name_of_test = f'{var} Single prediction CRPS Improvement - Bias correction')

    #     ############################### MANY PREDICTION, No BIAS CORRECTION -- Improvement ##################################################################
    #     acc = grab_ACC_from_dict(dict_ = ACC_dictionary, var = var)
    #     t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
    #     # print(t1.keys())

    #     t1 = subset_keep(subset_update(subset_keep(subset_keep(dict_ = t1, subset = 'many_predictions'), subset='baseline_improvement'),acc,'additive_BC'),'baseline_improvement')
    #     # print(t1.keys())
    #     # print(len(list(t1.keys())))

    #     plot_files_ACC(test_file = t1, var = var, name_of_test = f'{var} Many prediction ACC Improvement - No Bias correction')

    #     ############################### SINGLE PREDICTION, No BIAS CORRECTION -- Improvement ##################################################################
    #     acc = grab_ACC_from_dict(dict_ = ACC_dictionary, var = 'RZSM')
    #     t1 = subset_delete(dict_ = acc, subset = 'bias_corrected')
    #     # print(t1.keys())
    #     t1 = subset_keep(subset_update(subset_keep(subset_keep(dict_ = t1, subset = 'single_prediction'), subset='baseline_improvement'),acc,'additive_BC'),'baseline_improvement')
    #     # print(t1.keys())
    #     # print(len(list(t1.keys())))

    #     plot_files_ACC(test_file = t1, var = var, name_of_test = f'{var} Single prediction ACC Improvement - No Bias correction')

    return(0)


plot_crps_tests(var = 'RZSM')


