In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import xarray as xr
import numpy as np
import os
# import pandas as pd
from glob import glob
#import climpredNEW.climpred 
#from climpredNEW.climpred.options import OPTIONS
from climpred.options import OPTIONS
import climpred
import pickle
from mpl_toolkits.basemap import Basemap
from numpy import meshgrid
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import matplotlib.colors as mcolors
import cartopy.feature as cfeature
import itertools
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter, LatitudeLocator
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TwoSlopeNorm
from scipy.stats import rankdata
import bottleneck as bn
import scipy.stats as stats

from function import preprocessUtils as putils
from function import masks
from function import verifications
from function import funs as f
from function import conf
from function import loadbias
from function import dataLoad
from function import quikplot as qp


NameError: name 't' is not defined

# Compute and save CRPSS values for the later plots


In [None]:
# Run configuration: the region to verify and the observation product used as truth.
region_name = 'china'  # one of 'australia', 'CONUS', 'china'
obs_source = 'GLEAM'   # one of 'ERA5', 'GLEAM'

# Root directory of the selected observation product.
if obs_source == 'ERA5':
    soil_dir = conf.era_data
elif obs_source == 'GLEAM':
    soil_dir = conf.gleam_data
else:
    # Fail here instead of with a NameError on `soil_dir` in a later cell.
    raise ValueError(f"obs_source must be 'ERA5' or 'GLEAM', got {obs_source!r}")



In [None]:
# Load the root-zone soil moisture (RZSM) observations for the selected product and region.
# NOTE(review): the two returned objects look like a processed set and a raw set — confirm in dataLoad.
obs_original,obs_raw = dataLoad.load_rzsm_observations(soil_dir, region_name)

# Floor every timestamp to the day so the time labels carry no sub-daily component.
obs_original["time"] = obs_original["time"].dt.floor("D")
obs_raw["time"] = obs_raw["time"].dt.floor("D")

# Observations renamed for climpred; the skill functions below verify against this object.
obs_anom_climp = verifications.rename_obs_for_climpred(obs_original)

# Region masks; later cells keep the grid cells where mask_anom == 1.
mask, mask_anom = masks.load_mask_vals(region_name)

In [None]:


# Analysis periods (ISO date strings).
# Beginning of the observation period for analysis. Data actually start in 1999 so that a
# 7-day rolling mean can be applied and up to 12 weeks of RZSM lags are available.
start_obs = '2000-01-01'
# End of observations for ERA5 and GLEAM. Data are needed through 2020-02-15 because there
# is an initialization on 2019-12-25.
end_obs = '2020-12-31'
start_testing = '2018-01-01'  # beginning of the testing period
end_testing = '2019-12-31'
train_end_string = '2015-12-31'  # last training date, as a string
train_end = 2015  # last year of training dates

# Predictor source: 'both' (RZSM and Tmax) or 'RZSM' (RZSM only).
RZSM_or_Tmax_or_both = 'RZSM'

lead_select = [6,13,20,27]

In [None]:
def print_min_max(file,name):
    '''Print a label, then the NaN-ignoring maximum and minimum of one dataset variable.

    The variable inspected is the one named by ``putils.xarray_varname(file)``.

    Args:
        file: dataset-like object indexable by that variable name.
        name: label printed on the first line.
    '''
    print(name)
    # Resolve the variable once and reuse it for both statistics.
    values = file[putils.xarray_varname(file)]
    print(f'Maximum value in file is {np.nanmax(values)}')
    print(f'Minimum value in file is {np.nanmin(values)}')

In [None]:
# Name of the variable that the DL outputs are verified against.
verification_var = f'soilw_bgrnd_{obs_source}'

# Observation directory for the selected product and region.
# NOTE(review): despite its name this follows soil_dir, so it points at the ERA5 data when
# obs_source == 'ERA5'.
gleam_dir = f'{soil_dir}/{region_name}'

# GEFSv12 forecast predictions
gefsv12_fcst_dir = f'{conf.gefsv12_data}/{region_name}'

# ERA5 observations
era5_dir = f'{conf.era_data}/{region_name}'


In [None]:
# Load the observation anomaly and the two baseline reforecast anomalies into memory.
# NOTE(review): `gleam_anom` is read from soil_dir, so it holds ERA5 data when obs_source == 'ERA5'.
gleam_anom = verifications.load_RZSM_anomaly_obs(region_name, soil_dir).load()
ecmwf_anom = verifications.load_ECMWF_baseline_anomaly(region_name).load()
gefs_anom = verifications.load_GEFSv12_baseline_anomaly(region_name).load()

# Template dataset (testing initializations only) that the UNET predictions are placed into.
# A deep copy is taken so the template is independent of gefs_anom.
base_file_testing = gefs_anom.copy(deep=True).sel(S=slice(start_testing, None)).load()

# Grid coordinates, kept at module level for plotting later.
lat = base_file_testing.Y.values
lon = base_file_testing.X.values


In [None]:
# #Load the percentile files from 
# ecmwf_perc = verifications.load_ECMWF_percentile_anomaly(region_name).load()
# gefs_perc = verifications.load_GEFSv12_percentile_anomaly(region_name).load()

In [None]:
# Bias-corrected baseline data for GEFS and ECMWF, loaded once here. run_CRPSS reads these
# two globals and passes them on as gef_bc_acc / ecm_bc_acc.
# NOTE(review): presumably precomputed CRPSS fields for the leads in lead_select — confirm in loadbias.
gef_bc_crpss, ecm_bc_crpss = loadbias.load_additive_bias_corrected_data_CRPSS(lead_select,region_name,obs_source)


# Create the baseline anomaly files now so they do not have to be re-computed later

In [None]:
# First: set up where the seasonal anomalies (observations and reforecast) are stored.
print(f'Creating the seasonal anomalies for observational data and then subsetting for everything after {start_testing} date.')
save_anomaly_dir = f'Data/anomaly/{region_name}'
# Pure-Python equivalent of `mkdir -p`: no shell involved, no error if the directory exists.
os.makedirs(save_anomaly_dir, exist_ok=True)

# Output files for the testing-period anomalies.
obs_RZSM_save = f'{save_anomaly_dir}/obs_RZSM_anomaly_testing_{obs_source}.nc'
ref_RZSM_save = f'{save_anomaly_dir}/reforecast_RZSM_anomaly_testing_{obs_source}.nc'


In [None]:
def open_file_create_seasonal_anomaly(path,train_end):
    '''Open reforecast files, average over 7 leads, and return their seasonal anomaly.

    Args:
        path: path or glob pattern handed to ``xr.open_mfdataset``.
        train_end: forwarded to ``create_seasonal_anomaly`` as ``train_end``.

    Returns:
        Whatever ``create_seasonal_anomaly(..., source='reforecast')`` returns.
        (``create_seasonal_anomaly`` is not defined in the visible part of this notebook.)
    '''
    reforecast = xr.open_mfdataset(path)
    # Subset by lead first: the files also hold data from earlier lag weeks.
    reforecast = reforecast.sel(L=slice(0,34))
    # Trailing 7-step mean along L; the first six positions stay NaN (min_periods=7).
    weekly_mean = reforecast.rolling(L=7, min_periods=7,center=False).mean()
    return create_seasonal_anomaly(weekly_mean,train_end=train_end,source='reforecast')

def check_values_in_file(file,lead):
    '''Print one spot value so files can be told apart when comparing results.

    Looks at the first variable of ``file`` at lead index ``lead``, ensemble-member
    index 10 and initialization index 0. Returns None.
    '''
    first_var = list(file.keys())[0]
    spot = file[first_var].isel(L=lead, M=10, S=0)
    print(spot.values)
    

In [None]:

def load_experiment_predictions_and_observations(lead,experiment,region_name,obs_source):
    '''Load the UNET testing predictions (GEFS and, if present, ECMWF) and the observations.

    Observations are the testing split from ``f.load_verification_observations_updated``;
    zeros are replaced by NaN and the result is passed through
    ``verifications.reverse_min_max_scaling``. Predictions are read from
    ``predictions/{region_name}/Wk{lead}_testing``; only the last entry along their first
    axis is kept, un-scaled, squeezed, and set to NaN wherever the observations are NaN.

    Args:
        lead: lead week; the lead day used for un-scaling is ``lead*7 - 1``.
        experiment: experiment tag used in the prediction file names (e.g. ``'EX29'``).
        region_name: region sub-directory.
        obs_source: ``'GLEAM'`` or ``'ERA5'``.

    Returns:
        ``(prediction_RZSM_GEFS, prediction_RZSM_ECMWF, obs_RZSM)``. If the ECMWF file does
        not exist, the ECMWF entry is an all-NaN array shaped like the un-squeezed GEFS
        prediction. For any other ``obs_source`` both predictions are zero arrays shaped
        like the observations.

    Raises:
        FileNotFoundError: if the GEFS prediction file is missing.
        Any other error raised while reading an existing ECMWF file is propagated; only a
        missing ECMWF file is tolerated.
    '''
    day_num = (lead*7)-1
    verification_directory = f'Data/model_npy_input_data/{region_name}/Verification_data' #For observation verification

    # Only the testing split is used here.
    _, _, obs_final_testing = f.load_verification_observations_updated(lead,verification_directory,obs_source)
    obs_RZSM = np.array(obs_final_testing)

    # Zeros mark cells without land soil moisture; turn them into NaN before un-scaling.
    obs_RZSM = np.where(obs_RZSM == 0,np.nan,obs_RZSM)
    obs_RZSM = verifications.reverse_min_max_scaling(obs_RZSM,region_name,day_num,obs_source,2019)

    # The observation product only changes an infix in the prediction file names.
    file_infix = {'GLEAM': '', 'ERA5': 'ERA5_'}.get(obs_source)
    if file_infix is None:
        return(np.zeros(shape=obs_RZSM.shape),np.zeros(shape=obs_RZSM.shape), obs_RZSM)

    file_stem = f'predictions/{region_name}/Wk{lead}_testing/Wk{lead}_testing_{experiment}'
    prediction_GEFS = np.load(f'{file_stem}_regular_{file_infix}RZSM.npy')
    try:
        prediction_ECMWF = np.load(f'{file_stem}_ECMWF_regular_{file_infix}RZSM.npy')
    except FileNotFoundError:
        # ECMWF runs are optional; a NaN placeholder is returned instead.
        prediction_ECMWF = None

    print(f'Test prediction shape: {prediction_GEFS.shape}')
    prediction_RZSM_GEFS = verifications.reverse_min_max_scaling(prediction_GEFS[-1,:,:,:],region_name, day_num, 'GEFSv12',2019)
    if prediction_ECMWF is not None:
        prediction_RZSM_ECMWF = verifications.reverse_min_max_scaling(prediction_ECMWF[-1,:,:,:],region_name, day_num, 'ECMWF',2019)
    else:
        prediction_RZSM_ECMWF = np.full_like(prediction_RZSM_GEFS, np.nan, dtype=float)
    print(f'Shape of prediction RZSM: {prediction_RZSM_GEFS.shape}')

    # Put NaN back over the ocean and other water bodies.
    no_observation = np.isnan(obs_RZSM)
    prediction_RZSM_GEFS = np.where(no_observation,np.nan,prediction_RZSM_GEFS.squeeze())
    if prediction_ECMWF is not None:
        prediction_RZSM_ECMWF = np.where(no_observation,np.nan,prediction_RZSM_ECMWF.squeeze())

    return(prediction_RZSM_GEFS, prediction_RZSM_ECMWF, obs_RZSM)


In [None]:
def setup_binary_for_hit_rate_with_ensemble_mean(week_lead, region_name, test_start, test_end, unet_file):
    '''Save binary percentile-event fields computed from ensemble means.

    For observations, GEFS, ECMWF, the UNET prediction in ``unet_file`` and (CONUS only)
    XGBoost, the ensemble-mean anomaly is compared with the observed percentile
    thresholds at lead day ``week_lead*7 - 1``. Channels 0-3 flag values below the
    5/10/20/33th percentiles, channels 4-7 flag values above the 66/80/90/95th
    percentiles (1 = event, 0 = no event, NaN outside ``mask_anom == 1``).
    Five ``.npy`` files are written under
    ``Data/correct_anomaly_percentile_statistics/{region_name}``.

    Outside CONUS no XGBoost predictions are read, so the XGBoost file holds only NaN.

    Args:
        week_lead: lead week.
        region_name: region sub-directory; ``'CONUS'`` additionally loads XGBoost output.
        test_start, test_end: accepted for interface compatibility; not used.
        unet_file: path of the UNET prediction ``.npy`` file.

    NOTE(review): this function reads ``select_data_by_lead``, ``obs_anomaly_SubX_format``,
    ``baseline_anomaly``, ``baseline_ecmwf``, ``var_OUT``, ``template_testing_only`` and
    ``obs_anom_percentile``, none of which is defined in the visible part of the notebook.
    NOTE(review): ``reverse_min_max_scaling`` is called with three arguments here but with
    five elsewhere in this notebook — confirm the trailing parameters have defaults.
    '''
    n_members = 11                      # ensemble size used by the reshape/tiling below
    low_percentiles = [5,10,20,33]      # event: value below the percentile
    high_percentiles = [66,80,90,95]    # event: value above the percentile

    save_dir = f'Data/correct_anomaly_percentile_statistics/{region_name}'
    os.makedirs(save_dir, exist_ok=True)

    save_ecmwf = f'{save_dir}/Wk{week_lead}_ecmwf_stats_TP_FP_ensemble_mean.npy'
    save_gefs = f'{save_dir}/Wk{week_lead}_gefs_stats_TP_FP_ensemble_mean.npy'
    save_xg = f'{save_dir}/Wk{week_lead}_xgboost_stats_TP_FP_ensemble_mean.npy'
    save_obs_binary = f'{save_dir}/Wk{week_lead}_obs_stats_TP_FP_ensemble_mean.npy'
    test_name = unet_file.split('testing_')[-1].split('.npy')[0].split('ensemble_')[-1]
    save_unet = f'{save_dir}/Wk{week_lead}_unet_stats_{test_name}_TP_FP_ensemble_mean.npy'

    day_num = (week_lead*7) -1

    print('Loading observation and baseline anomaly files')
    obs, gefs, ecmwf, var_OUT_overwrite, template_testing_only_by_lead= select_data_by_lead(obs_anomaly_SubX_format, baseline_anomaly, baseline_ecmwf, var_OUT, template_testing_only, day_num)

    obs_percent = obs_anom_percentile.sel(L=day_num).sel(M=0)

    # UNET prediction: keep index 2 of the first axis and channel 0 of the last axis,
    # then split the leading axis into (initialization, member) and mask to the region.
    test = verifications.reverse_min_max_scaling(np.load(unet_file), region_name, day_num)[2,:,:,:,0]
    test = np.reshape(test,(test.shape[0]//n_members,n_members,test.shape[1],test.shape[2]))
    test_unet = np.where(mask_anom == 1,test,np.nan)

    # XGBoost predictions are only read for CONUS.
    x_vals = None
    if region_name == 'CONUS':
        xgboost_file = sorted(glob(f'predictions_XGBOOST/{region_name}/Wk{week_lead}_testing/*EX28*'))[0]
        load_ = np.expand_dims(np.load(xgboost_file),-1)
        load_ = np.where(load_ == 0,np.nan,load_)
        load_ = verifications.reverse_min_max_scaling(load_, region_name, day_num)

        # Tile the single prediction across a member axis so it is averaged like the ensembles.
        xg = np.empty(shape=(load_.shape[0],n_members,load_.shape[1],load_.shape[2],load_.shape[3]))
        for j in range(n_members):
            xg[:,j,:,:,:] = load_
        x_vals = np.nanmean(xg.squeeze(),axis=1)

    # Ensemble means over the member axis.
    o_vals = np.nanmean(obs.RZSM[:,:,0,:,:].values,axis=1)
    g_vals = np.nanmean(gefs.RZSM[:,:,0,:,:].values,axis=1)
    e_vals = np.nanmean(ecmwf.RZSM[:,:,0,:,:].values,axis=1)
    u_vals = np.nanmean(test_unet,axis=1)

    # One channel per percentile threshold.
    stats_shape = o_vals.shape + (len(low_percentiles) + len(high_percentiles),)
    obs_binary_out = np.zeros(shape=stats_shape)
    final_perc_gefs = np.full(stats_shape, np.nan)
    final_perc_ecmwf = final_perc_gefs.copy()
    final_perc_unet = final_perc_gefs.copy()
    final_perc_xg = final_perc_gefs.copy()

    def _event_flags(values, perc_, below):
        '''Return 1/0 where ``values`` is beyond ``perc_``; NaN outside the region mask.'''
        if below:
            flags = np.where(values<perc_,1,0)
        else:
            flags = np.where(values>perc_,1,0)
        return np.where(mask_anom == 1,flags,np.nan)

    thresholds = [(p, True) for p in low_percentiles] + [(p, False) for p in high_percentiles]
    for idx,(percentile_num, below) in enumerate(thresholds):
        perc_ = obs_percent[f'{percentile_num}th_percentile'].values
        obs_binary_out[:,:,:,idx] = _event_flags(o_vals, perc_, below)
        final_perc_gefs[:,:,:,idx] = _event_flags(g_vals, perc_, below)
        final_perc_ecmwf[:,:,:,idx] = _event_flags(e_vals, perc_, below)
        final_perc_unet[:,:,:,idx] = _event_flags(u_vals, perc_, below)
        if x_vals is not None:
            final_perc_xg[:,:,:,idx] = _event_flags(x_vals, perc_, below)

    #Save files for later use
    np.save(save_ecmwf,final_perc_ecmwf)
    np.save(save_gefs, final_perc_gefs)
    np.save(save_unet, final_perc_unet)
    np.save(save_xg, final_perc_xg)
    np.save(save_obs_binary, obs_binary_out)


In [None]:
def return_non_post_processed_forecasts(lead,dim_order):
    '''Return the GEFS and ECMWF baseline anomalies for a single lead week.

    The lead index selected along ``L`` is 0 for ``lead == 0`` and ``lead*7 - 1`` otherwise.
    For region ``'CONUS'`` each dataset is first passed through
    ``f.restrict_to_CONUS_bounding_box``. ``L`` is then re-added as a length-1 dimension
    and the result transposed to ``dim_order``. The min/max of both results are printed.

    Args:
        lead: lead week (0 selects the first lead index).
        dim_order: dimension names passed to ``transpose``.

    Returns:
        ``(GEFS, ECMWF)``; when the global ``RZSM_or_Tmax_or_both`` is ``'both'`` a third
        element ``tmax_base_reforecast_climpred`` is appended.
    '''
    index_sel = 0 if lead == 0 else (lead*7)-1

    def _single_lead(anomaly):
        # Same selection chain for both models.
        if region_name == 'CONUS':
            anomaly = f.restrict_to_CONUS_bounding_box(anomaly,mask)
        return anomaly.sel(L=index_sel).expand_dims({'L': 1}).transpose(*dim_order)

    RZSM_base_reforecast_climpred_GEF = _single_lead(gefs_anom)
    RZSM_base_reforecast_climpred_ECM = _single_lead(ecmwf_anom)

    print_min_max(RZSM_base_reforecast_climpred_GEF,'GEFS RZSM baseline value from reforecast (training, validation, testing) (no pre-processing other than anomaly computed.)')
    print_min_max(RZSM_base_reforecast_climpred_ECM,'ECMWF RZSM baseline value from reforecast (training, validation, testing) (no pre-processing other than anomaly computed.)')

    if RZSM_or_Tmax_or_both == 'both':
        # NOTE(review): tmax_base_reforecast_climpred is not defined in the visible part of
        # the notebook, and the callers shown unpack exactly two values.
        return(RZSM_base_reforecast_climpred_GEF, RZSM_base_reforecast_climpred_ECM,tmax_base_reforecast_climpred)
    return(RZSM_base_reforecast_climpred_GEF, RZSM_base_reforecast_climpred_ECM)
    
    

In [None]:
def convert_prediction_to_SubX_format(file,lead,dim_order):
    '''Wrap a prediction array in an xarray Dataset that uses the testing template's coordinates.

    The template is ``base_file_testing`` at lead index ``lead*7 - 1`` (passed through
    ``f.restrict_to_CONUS_bounding_box`` for region ``'CONUS'``), with ``L`` re-added as a
    length-1 dimension. ``file`` is reshaped to the template's ``(S, M, L, Y, X)`` sizes,
    so it must contain exactly that many values.

    Args:
        file: numpy array of predictions.
        lead: lead week.
        dim_order: dimension names used to transpose the template.

    Returns:
        Dataset with one variable ``file_name`` on dimensions ``(S, M, L, Y, X)``.
    '''
    cp_base = base_file_testing.copy(deep=True).sel(L=(lead*7)-1)
    if region_name == 'CONUS':
        cp_base = f.restrict_to_CONUS_bounding_box(cp_base,mask)
    cp_base = cp_base.expand_dims({'L': 1}).transpose(*dim_order)

    # Reshape to the template's own sizes instead of a fixed grid, so any region works.
    data_dims = ['S','M','L','Y','X']
    file = file.reshape(tuple(cp_base.sizes[dim] for dim in data_dims))

    var_OUT = xr.Dataset(
            data_vars = dict(
                file_name = (data_dims,  file),
            ),
            coords = dict(
                S = cp_base.S.values,
                X = cp_base.X.values,
                Y = cp_base.Y.values,
                L = cp_base.L.values,
                M = cp_base.M.values,

            ),
            attrs = dict(
                Description = 'New data added to file'),
        )

    return(var_OUT)


In [None]:
def crpss_(lead,experiment,ACC_dictionary,obs_anom_climp,gef_bc_acc,ecm_bc_acc,obs_source):
    '''Add the median skill of one experiment and its baselines to ``ACC_dictionary``.

    Skill fields come from ``verifications.create_climpred_CRPSS_no_chunk`` for the UNET
    predictions (GEFS- and ECMWF-based) and for the un-processed reforecasts restricted to
    initializations from ``start_testing`` on. The bias-corrected entries are the median
    of row ``lead-1`` of ``gef_bc_acc`` / ``ecm_bc_acc``. All medians ignore NaN.

    Args:
        lead: lead week.
        experiment: experiment tag (e.g. ``'EX29'``).
        ACC_dictionary: dict updated in place and returned.
        obs_anom_climp: observations in climpred naming.
        gef_bc_acc, ecm_bc_acc: bias-corrected skill datasets.
        obs_source: observation product name.

    Returns:
        The same ``ACC_dictionary`` with six ``Wk{lead}_..._RZSM_CRPSS`` keys added.
    '''
    # L has to be re-added as a dimension after selecting a single lead.
    dim_order = ['S','M','L','Y','X']

    def _median_skill(skill, row=None):
        values = skill[putils.xarray_varname(skill)].values
        if row is not None:
            values = values[row,:,:]
        return np.nanmedian(values)

    def _to_climpred(prediction):
        subx = convert_prediction_to_SubX_format(file=prediction,lead=lead,dim_order = dim_order)
        return verifications.rename_subx_for_climpred(subx).rename({'file_name':'RZSM'})

    # Original reforecasts, then the UNET predictions (the observations returned here are not needed).
    raw_gefs, raw_ecmwf = return_non_post_processed_forecasts(lead,dim_order)
    unet_gefs, unet_ecmwf, _ = load_experiment_predictions_and_observations(lead,experiment,region_name,obs_source)

    unet_gefs = _to_climpred(unet_gefs)
    unet_ecmwf = _to_climpred(unet_ecmwf)

    print_min_max(unet_ecmwf,'ECMWF RZSM anomaly prediction value from UNET')
    print_min_max(unet_gefs,'GEFSv12 RZSM anomaly prediction value from UNET')

    unet_skill_gefs = verifications.create_climpred_CRPSS_no_chunk(unet_gefs, obs_anom_climp)
    unet_skill_ecmwf = verifications.create_climpred_CRPSS_no_chunk(unet_ecmwf, obs_anom_climp)
    ACC_dictionary[f'Wk{lead}_{experiment}_MEM_RZSM_CRPSS'] = _median_skill(unet_skill_gefs)
    ACC_dictionary[f'Wk{lead}_{experiment}_ECMWF_MEM_RZSM_CRPSS'] = _median_skill(unet_skill_ecmwf)

    # Baseline reforecasts (before post-processing), testing initializations only.
    base_gefs = verifications.rename_subx_for_climpred(raw_gefs).sel(init=slice(start_testing, None))
    print_min_max(base_gefs,'GEFS RZSM baseline anomaly from reforecast. No post-processing. \n')

    base_ecmwf = verifications.rename_subx_for_climpred(raw_ecmwf).sel(init=slice(start_testing, None))
    print_min_max(base_ecmwf,'ECMWF RZSM baseline anomaly from reforecast. No post-processing. \n')

    base_skill_gefs = verifications.create_climpred_CRPSS_no_chunk(base_gefs, obs_anom_climp)
    base_skill_ecmwf = verifications.create_climpred_CRPSS_no_chunk(base_ecmwf, obs_anom_climp)

    ACC_dictionary[f'Wk{lead}_GEFS_MEM_baseline_RZSM_CRPSS'] = _median_skill(base_skill_gefs)
    ACC_dictionary[f'Wk{lead}_ECMWF_MEM_baseline_RZSM_CRPSS'] = _median_skill(base_skill_ecmwf)
    ACC_dictionary[f'Wk{lead}_GEFS_MEM_BC_baseline_RZSM_CRPSS'] = _median_skill(gef_bc_acc, row=lead-1)
    ACC_dictionary[f'Wk{lead}_ECMWF_MEM_BC_baseline_RZSM_CRPSS'] = _median_skill(ecm_bc_acc, row=lead-1)

    return(ACC_dictionary)
    


In [None]:
def crpss_by_season(lead, experiment, ACC_dictionary, obs_anom_climp, gef_bc_acc, ecm_bc_acc, obs_source):
    '''Add per-season median skill of one experiment and its baselines to ``ACC_dictionary``.

    Forecasts are split by the month of ``init`` and observations by the month of ``time``
    into DJF, MAM, JJA and SON. For each season the skill from
    ``verifications.create_climpred_CRPSS_no_chunk`` is reduced with a NaN-ignoring median.

    Args:
        lead: lead week.
        experiment: experiment tag (e.g. ``'EX29'``).
        ACC_dictionary: dict updated in place and returned.
        obs_anom_climp: observations in climpred naming.
        gef_bc_acc, ecm_bc_acc: bias-corrected skill datasets.
        obs_source: observation product name.

    Returns:
        The same ``ACC_dictionary`` with six ``{season}_Wk{lead}_..._RZSM_CRPSS`` keys per season.

    NOTE(review): the bias-corrected arrays are indexed only by ``lead-1``, so the same
    value is stored under every season key.
    NOTE(review): observations are limited to the season's months, while forecasts
    initialized late in a season verify in the following month — confirm how the skill
    helper aligns those dates.
    '''
    # L has to be re-added as a dimension after selecting a single lead.
    dim_order = ['S','M','L','Y','X']
    season_months = {
        'DJF': [12, 1, 2],
        'MAM': [3, 4, 5],
        'JJA': [6, 7, 8],
        'SON': [9, 10, 11]
    }

    def _median_skill(skill):
        return np.nanmedian(skill[putils.xarray_varname(skill)].values)

    def _in_season(data, season, time_dim):
        # Keep the entries whose `time_dim` month belongs to the season.
        return data.sel({time_dim: data[time_dim].dt.month.isin(season_months[season])})

    def _to_climpred(prediction):
        subx = convert_prediction_to_SubX_format(file=prediction, lead=lead, dim_order=dim_order)
        return verifications.rename_subx_for_climpred(subx).rename({'file_name':'RZSM'})

    # Original reforecasts, then the UNET predictions (the observations returned here are not needed).
    raw_gefs, raw_ecmwf = return_non_post_processed_forecasts(lead, dim_order)
    unet_gefs, unet_ecmwf, _ = load_experiment_predictions_and_observations(lead, experiment, region_name, obs_source)

    unet_gefs = _to_climpred(unet_gefs)
    unet_ecmwf = _to_climpred(unet_ecmwf)

    print_min_max(unet_ecmwf, 'ECMWF RZSM anomaly prediction value from UNET')
    print_min_max(unet_gefs, 'GEFSv12 RZSM anomaly prediction value from UNET')

    # Season-independent inputs are prepared once, outside the loop.
    base_gefs = verifications.rename_subx_for_climpred(raw_gefs).sel(init=slice(start_testing, None))
    base_ecmwf = verifications.rename_subx_for_climpred(raw_ecmwf).sel(init=slice(start_testing, None))
    gefs_bc_median = np.nanmedian(gef_bc_acc[putils.xarray_varname(gef_bc_acc)].values[lead-1,:,:])
    ecmwf_bc_median = np.nanmedian(ecm_bc_acc[putils.xarray_varname(ecm_bc_acc)].values[lead-1,:,:])

    for season in season_months:
        obs_season = _in_season(obs_anom_climp, season, 'time')

        # UNET predictions
        unet_skill_gefs = verifications.create_climpred_CRPSS_no_chunk(_in_season(unet_gefs, season, 'init'), obs_season)
        unet_skill_ecmwf = verifications.create_climpred_CRPSS_no_chunk(_in_season(unet_ecmwf, season, 'init'), obs_season)
        ACC_dictionary[f'{season}_Wk{lead}_{experiment}_MEM_RZSM_CRPSS'] = _median_skill(unet_skill_gefs)
        ACC_dictionary[f'{season}_Wk{lead}_{experiment}_ECMWF_MEM_RZSM_CRPSS'] = _median_skill(unet_skill_ecmwf)

        # Baseline reforecasts
        base_skill_gefs = verifications.create_climpred_CRPSS_no_chunk(_in_season(base_gefs, season, 'init'), obs_season)
        base_skill_ecmwf = verifications.create_climpred_CRPSS_no_chunk(_in_season(base_ecmwf, season, 'init'), obs_season)
        ACC_dictionary[f'{season}_Wk{lead}_GEFS_MEM_baseline_RZSM_CRPSS'] = _median_skill(base_skill_gefs)
        ACC_dictionary[f'{season}_Wk{lead}_ECMWF_MEM_baseline_RZSM_CRPSS'] = _median_skill(base_skill_ecmwf)

        # Bias-corrected baselines (not seasonal, see the note in the docstring)
        ACC_dictionary[f'{season}_Wk{lead}_GEFS_MEM_BC_baseline_RZSM_CRPSS'] = gefs_bc_median
        ACC_dictionary[f'{season}_Wk{lead}_ECMWF_MEM_BC_baseline_RZSM_CRPSS'] = ecmwf_bc_median

    return ACC_dictionary


In [None]:
def _select_experiments(lead, region_name, obs_source):
    '''Return the experiment tags to evaluate for this lead, region and observation product.'''
    if lead == 5:
        return ['EX26']
    if not lead <= 4:
        raise ValueError(f'No experiments are configured for lead {lead!r}')
    if region_name != 'CONUS':
        return ['EX29']
    if obs_source == 'ERA5':
        return ['EX29']
    if obs_source != 'GLEAM':
        raise ValueError(f"obs_source must be 'ERA5' or 'GLEAM' for CONUS, got {obs_source!r}")

    # CONUS + GLEAM: EX0..EX29 without EX26; leads 1-2 additionally skip EX18-EX21.
    skipped = {'EX26'}
    if lead <= 2:
        skipped.update(['EX18', 'EX19', 'EX20', 'EX21'])
    return [f'EX{i}' for i in range(0,30) if f'EX{i}' not in skipped]


def run_CRPSS(lead,region_name,obs_source):
    '''Compute overall and per-season median skill for every experiment of one lead week.

    Reads the notebook globals ``obs_anom_climp``, ``gef_bc_crpss`` and ``ecm_bc_crpss``.

    Args:
        lead: lead week (leads up to 4 and lead 5 are configured).
        region_name: region name; ``'CONUS'`` has its own experiment list.
        obs_source: ``'GLEAM'`` or ``'ERA5'``.

    Returns:
        ``(ACC_dictionary, ACC_season)``: results of ``crpss_`` and ``crpss_by_season``.

    Raises:
        ValueError: if no experiment list is configured for the given arguments.
    '''
    print(f'Working on lead {lead}')
    experiment_list = _select_experiments(lead, region_name, obs_source)

    ACC_dictionary = {}
    ACC_season = {}
    for experiment in experiment_list:
        ACC_dictionary.update(crpss_(lead=lead,experiment=experiment,
                                     ACC_dictionary=ACC_dictionary,
                                     obs_anom_climp=obs_anom_climp,
                                     gef_bc_acc=gef_bc_crpss, ecm_bc_acc=ecm_bc_crpss,
                                     obs_source=obs_source))
        ACC_season.update(crpss_by_season(lead=lead,experiment=experiment,
                                          ACC_dictionary=ACC_season,
                                          obs_anom_climp=obs_anom_climp,
                                          gef_bc_acc=gef_bc_crpss, ecm_bc_acc=ecm_bc_crpss,
                                          obs_source=obs_source))

    return(ACC_dictionary,ACC_season)


In [None]:
# print(ACC_dictionary.keys())

def save_CRPSS_tests(var, ACC_dictionary, region_name,obs_source,season, lead=None):
    '''Pickle the ``{var}_CRPSS`` entries of ``ACC_dictionary`` for later plotting.

    The file is written to ``Outputs/permutation_tests/{region_name}/Wk_{lead}`` as
    ``CRPSS_vals_{obs_source}_season.pkl`` when ``season`` is true and as
    ``CRPSS_vals_{obs_source}.pkl`` otherwise.

    Args:
        var: variable tag; only keys containing ``f'{var}_CRPSS'`` are saved.
        ACC_dictionary: mapping of result names to values.
        region_name: region sub-directory.
        obs_source: observation product name used in the file name.
        season: True when ``ACC_dictionary`` holds per-season values.
        lead: lead week for the output directory. Defaults to the notebook-level ``lead``
            (the loop variable of the driver cell).

    Returns:
        0

    Raises:
        ValueError: if ``lead`` is not given and no notebook-level ``lead`` exists.
    '''
    if lead is None:
        lead = globals().get('lead')
        if lead is None:
            raise ValueError('lead was not passed and no notebook-level `lead` is defined')

    t1 = grab_ACC_from_dict(dict_ = ACC_dictionary, var = var)

    file_path = f'Outputs/permutation_tests/{region_name}/Wk_{lead}'
    # Seasonal results carry the `_season` suffix; overall results do not.
    if season:
        file_save = f'{file_path}/CRPSS_vals_{obs_source}_season.pkl'
    else:
        file_save = f'{file_path}/CRPSS_vals_{obs_source}.pkl'

    os.makedirs(file_path, exist_ok=True)

    with open(file_save, 'wb') as file:
        pickle.dump(t1, file)

    return(0)



In [None]:
def grab_ACC_from_dict(dict_,var):
    '''Return the entries of ``dict_`` whose key contains ``f'{var}_CRPSS'``.'''
    selected = {}
    for key, value in dict_.items():
        if f'{var}_CRPSS' in key:
            selected[key] = value
    return(selected)

In [None]:
# Compute and save the overall and per-season skill for lead weeks 1-4.
# `lead` stays a notebook-level name on purpose: save_CRPSS_tests reads it for its output directory.
for lead in [1,2,3,4]:
    CRPSS_dictionary,CRPSS_season = run_CRPSS(lead,region_name,obs_source)
    print(CRPSS_dictionary)
    # As a note, week 1 doesn't ever have any experimental runs for EX18-EX21, also EX26 is not a model that I ran
    save_CRPSS_tests(var = 'RZSM', ACC_dictionary=CRPSS_dictionary, region_name = region_name,obs_source=obs_source, season=False)
    save_CRPSS_tests(var = 'RZSM', ACC_dictionary=CRPSS_season, region_name = region_name,obs_source=obs_source, season=True)

# Deliberate halt so that "Run All" does not continue into the unfinished plotting cells below.
raise RuntimeError('Intentional stop: the cells below this point are not finished.')

In [None]:
# NOTE(review): no notebook-level ACC_dictionary is assigned in the visible cells (the name
# only exists as a local/parameter inside the functions above); CRPSS_dictionary may be the
# intended name — confirm before relying on this cell.
ACC_dictionary

In [None]:
#Testing
# many_testing_predictions = True
# experiment = 'EX0'
# lead=1
# bias_correction = True

In [None]:
def select_lead(file,dim_order,lead):
    '''Select lead index ``lead*7 - 1`` along ``L``, re-add ``L`` as a length-1 dimension,
    and transpose to ``dim_order``.'''
    lead_index = (lead*7)-1
    single_lead = file.sel(L=lead_index)
    with_lead_dim = single_lead.expand_dims({'L': 1})
    return with_lead_dim.transpose(*dim_order)



# Plot data
## Not yet completed for the new ERA5 data

In [None]:
def get_min_max_of_files(file):
    '''Return ``(overall_min, overall_max)`` across every entry of a mapping, ignoring NaN.

    Each entry is first reduced directly; if that raises ``TypeError`` the pass is
    repeated over every entry using its ``.crps`` attribute instead.
    '''
    lows = []
    highs = []

    def _collect(pick):
        # Append the NaN-ignoring min and max of each entry, in key order.
        for key in list(file.keys()):
            lows.append(np.nanmin(pick(file[key])))
            highs.append(np.nanmax(pick(file[key])))

    try:
        _collect(lambda entry: entry)
    except TypeError:
        _collect(lambda entry: entry.crps)

    return(min(lows),max(highs))


In [None]:


   
# cmap = 'coolwarm'

def _first_entry_matching(results, token):
    """Return the first value of ``results`` whose key contains ``token``."""
    for key, value in results.items():
        if token in key:
            return value
    raise KeyError(f'No entry in the results dictionary has a key containing {token!r}')


def _panel_field_and_mean(entry):
    """Return ``(field, mean)``: the object to contour and its NaN-mean rounded to 4 places."""
    try:
        overall_mean = np.nanmean(entry)
        field = entry
    except TypeError:
        # numpy cannot reduce this entry directly; use its `crps` variable and
        # average it over axis 0 to obtain the field that is drawn.
        crps_values = entry.crps.values
        field = np.nanmean(crps_values, axis=0)
        overall_mean = np.nanmean(crps_values)
    return field, round(overall_mean, 4)


def _hide_gridline_labels(gridliner, hide_left):
    """Turn off the top and right (and optionally left) gridline labels."""
    label_attrs = [('top_labels', 'xlabels_top'), ('right_labels', 'ylabels_right')]
    if hide_left:
        label_attrs.append(('left_labels', 'ylabels_left'))
    for current_name, legacy_name in label_attrs:
        # Use the `*_labels` attributes when the installed cartopy provides them
        # and fall back to the older `[xy]labels_*` spelling otherwise.
        attr_name = current_name if hasattr(gridliner, current_name) else legacy_name
        setattr(gridliner, attr_name, False)


def plot_files_ACC(test_file, var, name_of_test):
    """Draw one filled-contour map per result entry on a shared colour scale and save the figure.

    Panels are drawn first for every name in ``experiment_list`` (using the
    first entry of ``test_file`` whose key contains that name), then for every
    entry whose key contains ``'no_BC'`` but not ``'EX'``. Each panel shows its
    rounded NaN-mean as blue text. The figure is written to
    ``Outputs/crps_mae/Wk_<lead>/<var>_ACC/<name_of_test>.png`` and shown.

    Notebook-level names read by this function (not parameters): ``lead``,
    ``experiment_list``, ``lon``, ``lat`` and ``get_min_max_of_files``.

    NOTE(review): the Basemap box (25-50 lat, -128 to -60 lon) and the text
    position (-84, 27) are hard-coded -- confirm they suit the selected region.

    Parameters
    ----------
    test_file : dict
        Mapping of result name to the field to plot.
    var : str
        Variable name, used only to build the output directory.
    name_of_test : str
        Figure title and output file name (without extension).

    Raises
    ------
    KeyError
        If an experiment in ``experiment_list`` matches no key of ``test_file``.
    ValueError
        If there is nothing to plot or there are more panels than grid cells.
    """
    cmap = plt.get_cmap('bwr')

    save_dir = f'Outputs/crps_mae/Wk_{lead}/{var}_ACC'
    os.makedirs(save_dir, exist_ok=True)

    # Grid layout and label size depend on the lead being plotted.
    if lead == 0:
        n_rows, n_cols, fig_width, fig_height, label_size = 3, 5, 20, 10, 12
    else:
        n_rows, n_cols, fig_width, fig_height, label_size = 6, 5, 30, 25, 16
    text_x, text_y = -84, 27

    # Each panel is (title, entry, title keyword arguments, print the mean?).
    panels = [
        (experiment, _first_entry_matching(test_file, experiment), {'fontsize': label_size}, True)
        for experiment in experiment_list
    ]
    # Additive bias correction is not handled yet, so only the 'no_BC'
    # non-experiment entries are shown. They are taken by their own key so a
    # title can never be paired with another entry that merely contains it.
    panels += [
        (key, value, {}, False)
        for key, value in test_file.items()
        if 'EX' not in key and 'no_BC' in key
    ]
    if not panels:
        raise ValueError('plot_files_ACC() found nothing to plot in test_file')
    if len(panels) > n_rows * n_cols:
        raise ValueError(
            f'{len(panels)} panels do not fit in a {n_rows}x{n_cols} grid')

    # Shared colour scale, computed once for all panels.
    min_, max_ = get_min_max_of_files(test_file)
    levels = np.linspace(min_, max_, 30, endpoint=True)
    try:
        norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)
    except ValueError:
        # TwoSlopeNorm needs vmin < 0 < vmax; otherwise use default scaling.
        norm = None

    # The coordinate grid does not change between panels either.
    basemap = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
    x, y = basemap(*np.meshgrid(lon, lat))

    fig, axs = plt.subplots(
        n_rows, n_cols, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(fig_width, fig_height))
    axs = axs.flatten()
    fig.tight_layout()

    for ax, (title, entry, title_kwargs, announce_mean) in zip(axs, panels):
        field, mean_vals = _panel_field_and_mean(entry)
        if announce_mean:
            print(f'Mean value: {mean_vals}')

        im = ax.contourf(x, y, field, levels=levels, extend='both',
                         transform=ccrs.PlateCarree(), cmap=cmap, norm=norm)

        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        _hide_gridline_labels(gl, hide_left=(lead != 1))
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        ax.coastlines()
        ax.set_aspect('equal')  # this makes the plots better
        ax.set_title(title, **title_kwargs)
        ax.text(text_x, text_y, mean_vals, ha='right', va='bottom',
                fontsize=label_size, color='blue', weight='bold')

    # Colorbar axis at the bottom of the figure: left, bottom, width, height.
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])
    fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    plt.tight_layout()
    fig.suptitle(name_of_test, fontsize=30)
    plt.savefig(f'{save_dir}/{name_of_test}.png',bbox_inches='tight')
    plt.show()
    

In [None]:

def plot_rank_histogram(baseline_file,prediction_file, name_of_file):
    """Plot relative-frequency rank histograms for the baseline and each experiment.

    The first panel of a 3x5 grid shows the baseline; the following panels show
    one experiment each, in the order of ``experiment_list``. The figure is
    saved as ``<save_dir>/<name_of_file>.png`` at 300 dpi.

    ``experiment_list`` and ``save_dir`` are not defined inside this function;
    they are read from the surrounding notebook namespace.
    NOTE(review): ``save_dir`` is only assigned as a local of plot_files_ACC in
    the visible code -- confirm a notebook-level ``save_dir`` exists.

    Parameters
    ----------
    baseline_file : dict
        Only the first value is used; it must expose ``rank_histogram``.
    prediction_file : dict
        For each experiment, the first value whose key contains the experiment
        name is used; it must expose ``rank_histogram``.
    name_of_file : str
        Used in the figure title and as the output file name.
    """
    # prediction_file = rzsm_Rank_histogram_predictions
    # baseline_file = rzsm_Rank_histogram_baseline
    
    # baseline_file = baseline_file.values()
    # Only the first entry of the baseline dictionary is plotted.
    baseline_file = list(baseline_file.values())[0]
    
    fig, axs = plt.subplots(3,5, figsize=(20, 7))

#     fig, axs = plt.subplots(
#         3, 5, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(20, 7), gridspec_kw={'height_ratios': [2,2,2]})
    axs = axs.flatten()
    

    # rank_file=rank_histogram_unet
    to_df = baseline_file.rank_histogram[:].to_dataframe()
    # Convert counts to relative frequencies (the column then sums to 1).
    to_df['rank_histogram'] = to_df['rank_histogram'] / \
        to_df['rank_histogram'].sum()
    # Re-index by the integer rank, then drop every column except the histogram.
    to_df['rank'] = to_df.index
    to_df['rank'] = to_df['rank'].astype(int)
    to_df.index = to_df['rank']
    del to_df['lead']
    del to_df['skill']
    del to_df['rank']
    
    # NOTE(review): the 'rank' column was deleted above, so `to_df.rank()` is
    # presumably pandas' DataFrame.rank method and `.shape[0]` is simply the
    # number of rows (bins) -- confirm; `len(to_df)` would say the same.
    print(f'Shape of to_df : {to_df.rank().shape[0]}')
    # axs[ax].plot(to_df)
    axs[0].bar(np.arange(1,to_df.rank().shape[0]+1),to_df.rank_histogram)
    # The x-range is hard-coded to ranks 1..12 -- TODO confirm the bin count.
    axs[0].set_xlim(1, 12)

    # Optionally, adjust tick marks
    axs[0].set_xticks(np.arange(1, 13))
    # to_df.bar(ax=axs[0], kind='bar', width=1.4)
    axs[0].set_title(f'Baseline Rank Histogram')
    axs[0].set_xticklabels(axs[0].get_xticklabels(), rotation=0)
    axs[0].set_ylabel('Relative Frequency', rotation=90)

    
    # Same processing as the baseline, repeated once per experiment.
    for ax,experiment in enumerate(experiment_list):
        # Panel 0 holds the baseline, so experiments start at panel 1.
        ax+=1
        # First entry whose key contains the experiment name (substring match).
        data = {key: value for key, value in prediction_file.items() if experiment in key}
        data = list(data.values())[0]

        # rank_file=rank_histogram_unet
        to_df = data.rank_histogram[:].to_dataframe()
        to_df['rank_histogram'] = to_df['rank_histogram'] / \
            to_df['rank_histogram'].sum()
        to_df['rank'] = to_df.index
        to_df['rank'] = to_df['rank'].astype(int)
        to_df.index = to_df['rank']
        del to_df['lead']
        del to_df['skill']
        del to_df['rank']
        
        print(f'Shape of to_df : {to_df.rank().shape[0]}')
        # axs[ax].plot(to_df)
        axs[ax].bar(np.arange(1,to_df.rank().shape[0]+1),to_df.rank_histogram)
        axs[ax].set_xlim(1, 12)

        # Optionally, adjust tick marks
        axs[ax].set_xticks(np.arange(1, 13))
        axs[ax].set_title(experiment)
        axs[ax].set_xticklabels(axs[ax].get_xticklabels(), rotation=0)
        axs[ax].set_ylabel('Relative Frequency', rotation=90)
    plt.suptitle(f'{name_of_file} Rank Histogram', fontsize=30)
    plt.tight_layout()
    # `save_dir` comes from outside this function (see the docstring).
    out_dir_save = f'{save_dir}/{name_of_file}.png'
    plt.savefig(out_dir_save, dpi=300)