In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import xarray as xr
import numpy as np
import os
from glob import glob
import functions as f
#import climpredNEW.climpred 
#from climpredNEW.climpred.options import OPTIONS
from mpl_toolkits.basemap import Basemap
from numpy import meshgrid
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import matplotlib.colors as mcolors
import cartopy.feature as cfeature
import itertools
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter, LatitudeLocator
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TwoSlopeNorm
import pandas as pd
import math
from scipy.stats import percentileofscore as pos
from datetime import datetime
import datetime as dt
from multiprocessing import Pool
from sklearn.metrics import confusion_matrix as CM
import masks
import preprocessUtils as putils
import verifications
import warnings



# Ignore specific RuntimeWarning
warnings.filterwarnings("ignore", message="invalid value encountered in sqrt")

2024-03-02 10:17:20.329615: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-02 10:17:22.911102: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Set script parameters
global region_name
region_name = 'CONUS' #['CONUS','australia', 'china']

if region_name == 'CONUS':
    source = 'Data'
else:
    source = f'{Data}_{region_name}'

mask = masks.load_mask(region_name)


#Mask with np.nan for non-CONUS land values
mask_anom = mask[putils.xarray_varname(mask)][0,:,:].values


#leads to select
leads_ = [6,13,20,27,34]

In [3]:

#For RCI week differencing
global week_differencing
week_differencing = 4

# Data

In [4]:

obs_anomaly = xr.open_dataset(f'{source}/GLEAM/RZSM_anomaly.nc')

obs_anomaly = obs_anomaly.rename({'longitude':'X','latitude':'Y'})

print(f'Loading GLEAM RZSM anomalies')
obs_anomaly_SubX_format = xr.open_mfdataset(f'Data/GLEAM/RZSM_anomaly_reformat_SubX_format/{region_name}/RZSM_anomaly*.n*').sel(L=leads_).load()

init_date_list = [pd.to_datetime(i) for i in obs_anomaly_SubX_format.S.values] #Use for later reformatting the files



Loading GLEAM RZSM anomalies


In [5]:
#Reforecast baseline files
#######################################   Reforecast baseline files   ###########################################################################
baseline_anomaly_file_list = sorted(glob(f'{source}/GEFSv12_reforecast/soilw_bgrnd/baseline_RZSM_anomaly/soil*.n*'))
gefs_anom = xr.open_mfdataset(baseline_anomaly_file_list).sel(L=leads_).load()

ecmwf_anom = verifications.load_ECMWF_baseline_anomaly(region_name).sel(L=leads_).load()

In [6]:


# baseline_percentile_file_list = sorted(glob(f'{source}/GEFSv12_reforecast/soilw_bgrnd/percentiles_baseline/RZSM_percentiles_2*.nc'))
# baseline_percentile_MEM_file_list = sorted(glob(f'{source}/GEFSv12_reforecast/soilw_bgrnd/percentiles_baseline/RZSM_percentiles_MEM_2*.nc'))

# baseline_percentile = xr.open_mfdataset(baseline_percentile_file_list,combine='nested',concat_dim=['S']).sel(L=[0,6,13,20,27,34]).astype(np.float32).load()
# baseline_percentile_MEM = xr.open_mfdataset(baseline_percentile_MEM_file_list,combine='nested',concat_dim=['S']).sel(L=[0,6,13,20,27,34]).astype(np.float32).load()

# #########################################   Prediction (UNET) files   ######################################################################################
# unet_anomaly_file_list = sorted(glob(f'predictions/no_julian_dates/{experiment_name}_*.nc'))
# unet_percentile_file_list = sorted(glob(f'predictions/UNET/percentiles/{experiment_name}/RZSM_percentiles_2*.nc'))
# unet_percentile_MEM_file_list = sorted(glob(f'predictions/UNET/percentiles/{experiment_name}/RZSM_percentiles_MEM_2*.nc'))

# unet_percentile = xr.open_mfdataset(unet_percentile_file_list,combine='nested',concat_dim=['S']).sel(L=[0,6,13,20,27,34]).astype(np.float32).load()
# unet_percentile_MEM = xr.open_mfdataset(unet_percentile_MEM_file_list,combine='nested',concat_dim=['S']).sel(L=[0,6,13,20,27,34]).astype(np.float32).load()

# #Test
# # anomaly_file_list=baseline_anomaly_file_list
# # percentile_file_list = baseline_percentile_file_list
# # percentile_file_list_MEM=baseline_percentile_MEM_file_list
# # obs_anomaly=obs_anomaly
# # save_dir='Data/GEFSv12_reforecast/soilw_bgrnd/SMVI'
# # MEM_or_by_model='MEM'



# Rapid change index

In [7]:
# Create a time series to later iterate through for dates
time_index_short = pd.to_datetime(obs_anomaly.sel(time=slice('2000-01-01','2000-12-31')).time.values)
time_index_full = pd.to_datetime(obs_anomaly.time.values)

# Step 1, compute the mean weekly difference across all years for the same dates (OBSERVATIONS ONLY). 

## We are first applying a rolling mean equal to $Week_differencing (3 weeks right now)

## Also compute the standard deviation of the weekly differences

In [8]:
#Can change the difference weeks based on if we want to look at longer (or shorter) differenceing intervals


def return_std_and_mean_diff_across_years():

    std_daily_dict = {}
    mean_diff_dict = {}
    
    #First compute the rolling average based on the week differencing
    rolling_average = obs_anomaly.rolling(time=7*week_differencing, min_periods=7*week_differencing,center=False).mean()
    
    for idx,date in enumerate(time_index_short):
        # break
        #Grab all the same month and day values across all years
        
        #Need to add this because the leap year dates don't have enough values
        if date == pd.to_datetime('2000-02-29'):
            new_date = pd.to_datetime('2000-02-28')
        else:
            new_date = date
        
        #Select all the days across all years with the same month and day
        mask_current_week = (time_index_full.month == new_date.month) & (time_index_full.day == new_date.day)
        selected_data = obs_anomaly.isel(time=mask_current_week)

        #Now find the data from the previous week for each of those days
        if new_date == pd.to_datetime('2000-03-07'):
            leap_date = pd.to_datetime('2000-02-28')
        else:
            leap_date = new_date

        previous_week = leap_date - np.timedelta64(week_differencing,'W')
        mask_previous_week = (time_index_full.month == previous_week.month) & (time_index_full.day == previous_week.day)
        
        selected_data_previous = obs_anomaly.isel(time=mask_previous_week)

        #Sometimes we have a mis-match between years (specifically the number of data points, they must be equal!), so this fixes it
        if len(selected_data_previous.time.values) > len(selected_data.time.values):
            selected_data_previous = selected_data_previous.isel(time = slice(0,len(selected_data.time.values)))
        elif len(selected_data_previous.time.values) < len(selected_data.time.values):
            selected_data = selected_data.isel(time = slice(0,len(selected_data_previous.time.values)))


        #Now find the mean difference across all years and average
        mean_diff_across_years = np.nanmean(selected_data[putils.xarray_varname(selected_data)][:,:,:].values - selected_data_previous[putils.xarray_varname(selected_data)][:,:,:].values,axis=0)
        mean_diff_dict[f'{pd.to_datetime(date).year}-{pd.to_datetime(date).month:02}-{pd.to_datetime(date).day:02}'] = mean_diff_across_years
        
        # rv1_data = selected_data.RZSM[:,:,:].values
        # rv2_data = selected_data_previous.RZSM[:,:,:].values
        # std_combined = xr.concat([selected_data, selected_data_previous], dim='RZSM_2').std(dim='time')
        
        # Calculate the covariance between rv1 and rv2 along the third dimension
        # covariance_rv1_rv2 = np.nancov(rv1_data.reshape(-1, rv1_data.shape[-1]), rv2_data.reshape(-1, rv2_data.shape[-1]))[0, 1]
        # Calculate the standard deviation of the difference between rv1 and rv2
        # std_dev_diff = np.sqrt(std_dev_rv1**2 + std_dev_rv2**2 - 2 * covariance_rv1_rv2)
        
        diff_across_years = selected_data[putils.xarray_varname(selected_data)][:,:,:].values - selected_data_previous[putils.xarray_varname(selected_data)][:,:,:].values
        
        std_ = np.nanstd(diff_across_years,axis=0)
                
        std_daily_dict[f'{pd.to_datetime(date).year}-{pd.to_datetime(date).month:02}-{pd.to_datetime(date).day:02}'] = std_
        
    return(std_daily_dict,mean_diff_dict)


std_daily_dict,mean_diff_dict =return_std_and_mean_diff_across_years()

  mean_diff_across_years = np.nanmean(selected_data[putils.xarray_varname(selected_data)][:,:,:].values - selected_data_previous[putils.xarray_varname(selected_data)][:,:,:].values,axis=0)
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


In [9]:


def get_current_and_previous_values_difference(date,file,week_differencing):
    #test 
    # file = obs_anomaly

    selected_data = obs_anomaly.sel(time=date)
    
    #Now find the data from the previous week for each of those days
    if (pd.to_datetime(date).month == 3) and (pd.to_datetime(date).day == 7):
        leap_date = pd.to_datetime(f'{pd.to_datetime(date).year}-02-28')
    else:
        leap_date = date

    previous_week = leap_date - np.timedelta64(week_differencing,'W')
    selected_data_previous = obs_anomaly.sel(time=previous_week)
    
    return(selected_data - selected_data_previous)

def rci_function_OBS_only(obs_anomaly,week_differencing):
    #Now we need to calculate the weekly difference betweeen weeks, then substract the mean, and divide by standard deviation
    rci = xr.zeros_like(obs_anomaly)

    for idx,date in enumerate(obs_anomaly.time.values):
        # break
        
        #We must begin only at MARCH
        if (pd.to_datetime(date).month == 3) and (pd.to_datetime(date).day <=7):
            # print(f'Working on date {date}')
            start_of_year = True
            #We don't have any weeks to work with before the 7th of MARCH
            rci[putils.xarray_varname(rci)][idx,:,:] = 0
        elif (pd.to_datetime(date).month in [3, 4, 5, 6, 7, 8, 9, 10, 11]) and (pd.to_datetime(date).year >=2000):
            # print(f'Working on date {date}')
            start_of_year = False
            # break

            diff_weeks = get_current_and_previous_values_difference(date,obs_anomaly,week_differencing)
            #Now standardize
            rci_standardized = (diff_weeks - mean_diff_dict[f'2000-{pd.to_datetime(date).month:02}-{pd.to_datetime(date).day:02}']) / std_daily_dict[f'2000-{pd.to_datetime(date).month:02}-{pd.to_datetime(date).day:02}']
            # plt.hist(rci_standardized.RZSM.values.flatten(),bins=30)
            # plt.show()
            
            #Now update RCI value
            subtract_ = xr.where(rci_standardized < -0.75,1,0)
            add_ = xr.where(rci_standardized > 0.75,1,0)
            
            #If the signs switch between postive and negative with RCI, then we reset rci to 0
            switch1 = np.where((rci_standardized[putils.xarray_varname(rci_standardized)] > 0) & (rci[putils.xarray_varname(rci)][idx - 7,:,:] < 0),2,0)
            switch2 = np.where((rci_standardized[putils.xarray_varname(rci_standardized)] < 0) & (rci[putils.xarray_varname(rci)][idx - 7,:,:] > 0),2,0)
            
            sub = np.where(subtract_[putils.xarray_varname(subtract_)] == 1, rci[putils.xarray_varname(rci)][idx-7,:,:] - np.sqrt(np.abs(rci_standardized[putils.xarray_varname(rci_standardized)])-0.75),0)
            add = np.where(add_[putils.xarray_varname(add_)] == 1, rci[putils.xarray_varname(rci)][idx-7,:,:] + np.sqrt(rci_standardized[putils.xarray_varname(rci_standardized)]-0.75),0)
            
            final = sub + add
            
            #Now switch back the data if the signs are oppositve
            final = np.where(switch1 != 2, final,0)
            final = np.where(switch2 != 2, final,0)
            
            rci[putils.xarray_varname(rci)][idx,:,:] = final
            
            # plt.hist(rci.RZSM[idx,:,:].values.flatten(),bins=30)
            # plt.show()
            
    return(rci)




In [10]:
#RCI save 
save_rci = f'{source}/GLEAM/RCI_index_{week_differencing}_week.nc'

if os.path.exists(save_rci):
    rci_output = xr.open_dataset(save_rci).load()
else:
    rci_output = rci_function_OBS_only(obs_anomaly,week_differencing)
    rci_output.to_netcdf(save_rci)


#Plot the distribution for the RCI file
rci_output

In [None]:


def convert_OBS_to_SubX_format(_date):  
# for _date in init_date_list:
    # var='RZSM_weighted'
    # _date=init_date_list[0]
    
    '''We are going to create new leads that are different than reforecast. The reasoning for this is that we want the actual weekly lags (and 1 day lag) and this will
    assist with future predictions within the deep learning model'''
    
    save_dir = f'{source}/GLEAM/RCI_reformat_SubX_format'
    os.system(f'mkdir -p {save_dir}')
    
    # for var in ['geopotential']:
    ref_dir = f'{source}/GEFSv12_reforecast/soilw_bgrnd' #Just use a single reference directory to serve as the template for file creation
  
    #Grab a single SubX to use as the template. Doesn't matter if it is the same variable or not or the same date
    fcst_file = glob(f'{ref_dir}/*soil*2000-01-05*')[0]
    op = xr.open_dataset(fcst_file)
    
    if region_name == 'CONUS':
        new_X_coords = [i+360 if i < 0 else i for i in op.X.values]
        op = op.assign_coords({'X':new_X_coords})
    
    open_date_SubX = putils.restrict_to_bounding_box(op,mask)
    out_file = xr.zeros_like(open_date_SubX)
    
    '''We are going to create a new lead day that represents the previous day before the forecast was initialized
    #New shape will be (1x11x48x48x96)
    This will include the day lag 1, and weekly lags 1-12'''
    
    file_shape = out_file[list(out_file.keys())[0]].shape

    save_date = f'{_date.year}-{_date.month:02}-{_date.day:02}'
    
    obs_file_name = f'RCI_{week_differencing}week_reformat_{save_date}.nc4'
    save_file = f'{save_dir}/{obs_file_name}'
    
    # if os.path.exists(save_file):
    if os.path.exists('this.out'):
        pass
    else:
        # os.system(f'rm {save_file}') #Just to avoid getting random duplicates
        print(f'Working on initialized day {_date} to find values integrating with SubX models, leads, & coordinates and saving data into {save_dir}.')
        
        for idx,i_lead in enumerate(out_file.L.values):
            # break

            date_val = pd.to_datetime(pd.to_datetime(_date) + dt.timedelta(days=int(i_lead)+0)) #Adding +1 may be suitable for other forecasts which predict the next day. But GEFSv12 predicts lead 0 as 12 UTC on the same date it is initialized
            #But be careful if you adapt this code to a new script. We are looking backwards in time from the first date.
                
            date_val = f'{date_val.year}-{date_val.month:02}-{date_val.day:02}'

            out_file[putils.xarray_varname(out_file)][0,:, idx, :, :] = \
                rci_output[putils.xarray_varname(rci_output)].sel(time = date_val).values

        var_OUT = xr.Dataset(
            data_vars = dict(
                rci = (['S','M','L','Y','X'],    out_file[list(out_file.keys())[0]].values),
            ),
            coords = dict(
                S = np.atleast_1d(_date),
                X = open_date_SubX.X.values,
                Y = open_date_SubX.Y.values,
                L = out_file.L.values,
                M = open_date_SubX.M.values,

            ),
            attrs = dict(
                Description = f'RCI values on the exact same date and grid \
                cell as EMC reforecast data. '),
        )                    

        var_OUT = var_OUT.astype(np.float32)
        
        var_OUT.to_netcdf(save_file)

    return(0)





####### RUN FUNCTION #######
for date in init_date_list:
    convert_OBS_to_SubX_format(date)


# Now re-open the RCI files to analyze by lead

In [12]:

#open the RCI file in subx format
print('Loading the Observation RCI index file already computed.')
obs_rci = xr.open_mfdataset(f'{source}/GLEAM/RCI_reformat_SubX_format/RCI*{week_differencing}week*').sel(L=leads_).load()


Loading the Observation RCI index file already computed.


# Now construct the RCI values for each of the models

In [13]:
# Data

#Unet final experiment name (week 5)
experiment_name='EX27_regular_RZSM'


#######################################   UNET prediction   ###########################################################################
unet_anomaly_file_list = sorted(glob(f'predictions/{region_name}/anomaly_no_julian_dates/{experiment_name}_*.nc'))
unet_anomaly = xr.open_mfdataset(unet_anomaly_file_list).sel(L=leads_).load()


# Now calculate the RCI value based on previous observation differencing of standard deviation and day of year


In [14]:

def rci_reforecast_MEM(std_daily_dict, mean_diff_dict,week_differencing, reforecast_anomaly_MEM, obs_file, save_dir):

    os.system(f'mkdir -p {save_dir}')
    
    #testing
    # reforecast_anomaly_MEM = unet_anomaly
                        # Lead   0   ,    1,    ,   2      ,      3    ,     4     ,      5
    # output shape array([       nan, 0.0111007 , 0.01243984, 0.00984243, 0.01684886, 0.0191455 ])
    
    # obs_file = obs_rci.mean(dim='M')
    # save_dir = f'predictions/UNET/RCI/{experiment_name}'

    #Now go through each init day 

    #Take the mean
    
    
    # reforecast_anomaly_MEM = reforecast_anomaly_MEM.mean(dim='M').rolling(L=2,min_periods=2,center=False).mean() #this is for the 1-week RCI calculation
    # reforecast_anomaly_MEM.RZSM[0,:,10,10].values
    
    for idx,date in enumerate(reforecast_anomaly_MEM.S.values):
        # break
                
        datetime_dt = pd.to_datetime(date)
        save_date = f'{datetime_dt.year}-{datetime_dt.month:02}-{datetime_dt.day:02}'

        date_within_mean_std_dicts = f'2000-{datetime_dt.month:02}-{datetime_dt.day:02}'

        save_name_out = f'{save_dir}/RCI_MEM_{week_differencing}week_{save_date}.nc'
        #We need the very first lead as our beginning point with the RCI values from the observations
        rci_obs_to_save_over = obs_file.sel(S=date).copy(deep=True)
        rci_obs_to_save_over.rci[1:,:,:] = 0
        
        #Index only starts in March of every year and we must have it after the 7th bevcause that's when accumulation begins
        if (pd.to_datetime(date).month == 3) and (pd.to_datetime(date).day <=7):
            rci_obs_to_save_over.rci[0,:,:] = 0
            rci_obs_to_save_over = rci_obs_to_save_over.expand_dims({'S':1})
            rci_obs_to_save_over.to_netcdf(f'{save_name_out}')
            
        elif (datetime_dt.month in [3, 4, 5, 6, 7, 8, 9, 10, 11]):
            # break
            
            for idx2,lead in enumerate([6,13,20,27,34]):
                # break
                len_week_diff = np.arange(week_differencing)
                
                if idx2 in len_week_diff:
                    #We want to grab N weeks of the observations before the init date
                    obs_date_select = date - np.timedelta64(lead+1,'D')
                    week_diff = (reforecast_anomaly_MEM.sel(S=date, L=lead) - obs_anomaly.rename({'SMsurf':'RZSM'}).drop('season').sel(time=obs_date_select))
                    
                else:
                    #Now find the difference between the 2 weeks. The difference between the current week's forecast and the observation
                                   #Current day                                #Previous week
                    week_diff = reforecast_anomaly_MEM.sel(S=date, L=lead) - rci_obs_to_save_over.isel(L=idx2-week_differencing).rename({'rci':'RZSM'})
                    
                    #Now compute RCI
                    rci_standardized = (week_diff - mean_diff_dict[date_within_mean_std_dicts])/ std_daily_dict[date_within_mean_std_dicts]
                    rci_standardized.min()
                    rci_standardized.max()
                    rci_standardized.mean()
                    
                    #Now update RCI value
                    subtract_ = xr.where(rci_standardized < -0.75,1,0)
                    add_ = xr.where(rci_standardized > 0.75,1,0)
                    
                    #If the signs switch between postive and negative with RCI, then we reset rci to 0
                    switch1 = np.where((rci_standardized.RZSM > 0) & (rci_obs_to_save_over.isel(L=idx2-1).rci.values < 0),2,0)
                    switch2 = np.where((rci_standardized.RZSM < 0) & (rci_obs_to_save_over.isel(L=idx2-1).rci.values  > 0),2,0)
                    
                    sub = np.where(subtract_.RZSM.values == 1, rci_obs_to_save_over.isel(L=idx2-1).rci.values - (np.sqrt(np.abs(rci_standardized.RZSM.values) - 0.75)),0)
                    add = np.where(add_.RZSM.values == 1,  rci_obs_to_save_over.isel(L=idx2-1).rci.values + (np.sqrt(rci_standardized.RZSM.values-0.75)),0)

                    final = sub + add
            
                    #Now switch back the data if the signs are oppositve
                    final = np.where(switch1 != 2, final,0)
                    final = np.where(switch2 != 2, final,0)
                    
                    rci_obs_to_save_over.rci[idx2,:,:] = final
                    
            rci_obs_to_save_over = rci_obs_to_save_over.expand_dims({'S':1})
            rci_obs_to_save_over.to_netcdf(f'{save_name_out}')
                
        else:
            #don't need to update any dates that aren't between March through November
            rci_obs_to_save_over.rci[0,:,:] = 0
            rci_obs_to_save_over = rci_obs_to_save_over.expand_dims({'S':1})
            rci_obs_to_save_over.to_netcdf(f'{save_name_out}')

    return('Completed')

In [15]:
# UNET
rci_reforecast_MEM(std_daily_dict=std_daily_dict, mean_diff_dict=mean_diff_dict, 
                   week_differencing = week_differencing, reforecast_anomaly_MEM=unet_anomaly.mean(dim='M'), obs_file=obs_rci.mean(dim='M'), save_dir= f'predictions/{region_name}/UNET/RCI/{experiment_name}')

# Baseline reforecast GEFSv12
rci_reforecast_MEM(std_daily_dict=std_daily_dict, mean_diff_dict=mean_diff_dict, 
                   week_differencing = week_differencing, reforecast_anomaly_MEM=gefs_anom.mean(dim='M'), obs_file=obs_rci.mean(dim='M'), save_dir= f'{source}/GEFSv12_reforecast/soilw_bgrnd/RCI')

# Baseline reforecast ECMWF
rci_reforecast_MEM(std_daily_dict=std_daily_dict, mean_diff_dict=mean_diff_dict, 
                   week_differencing = week_differencing, reforecast_anomaly_MEM=ecmwf_anom.mean(dim='M'), obs_file=obs_rci.mean(dim='M'), save_dir= f'Data/ECMWF/soilw_bgrnd_processed/{region_name}/RCI')


'Completed'

# Now to it for each model realization

In [16]:
# Now calculate the RCI value based on previous observation differencing of standard deviation and day of year

def rci_reforecast(std_daily_dict, mean_diff_dict,week_differencing, reforecast_anomaly, obs_file, save_dir):

    os.system(f'mkdir -p {save_dir}')
    #testing
    # reforecast_anomaly = unet_anomaly
                        # Lead   0   ,    1,    ,   2      ,      3    ,     4     ,      5
    # output shape array([       nan, 0.0111007 , 0.01243984, 0.00984243, 0.01684886, 0.0191455 ])
    
    # obs_file = obs_rci
    # save_dir = f'predictions/UNET/RCI/{experiment_name}'

    #Now go through each init day 

    # reforecast_anomaly = reforecast_anomaly.rolling(L=2,min_periods=2,center=False).mean() #this is for the 1-week RCI calculation
    # reforecast_anomaly_MEM.RZSM[0,:,10,10].values
    
    for idx,date in enumerate(reforecast_anomaly.S.values):
        # break
                
        datetime_dt = pd.to_datetime(date)
        save_date = f'{datetime_dt.year}-{datetime_dt.month:02}-{datetime_dt.day:02}'

        date_within_mean_std_dicts = f'2000-{datetime_dt.month:02}-{datetime_dt.day:02}'

        save_name_out = f'{save_dir}/RCI_{week_differencing}week_{save_date}.nc'
        #We need the very first lead as our beginning point with the RCI values from the observations
        rci_obs_to_save_over = obs_file.sel(S=date).copy(deep=True)
        
        rci_obs_to_save_over.rci[:,:,:,:] = 0
        
        #Index only starts in March of every year and we must have it after the 7th bevcause that's when accumulation begins
        if (pd.to_datetime(date).month == 3) and (pd.to_datetime(date).day <=7):
            rci_obs_to_save_over.rci[:,0,:,:] = 0
            rci_obs_to_save_over = rci_obs_to_save_over.expand_dims({'S':1})
            rci_obs_to_save_over.to_netcdf(f'{save_name_out}')
            
        elif (datetime_dt.month in [3, 4, 5, 6, 7, 8, 9, 10, 11]):
            # break
            
            for idx2,lead in enumerate([6,13,20,27,34]):
                # break
                len_week_diff = np.arange(week_differencing)
                
                if idx2 in len_week_diff:
                    #We want to grab N weeks of the observations before the init date
                    obs_date_select = date - np.timedelta64(lead+1,'D')
                    week_diff = (reforecast_anomaly.sel(S=date, L=lead) - obs_anomaly.rename({'SMsurf':'RZSM'}).drop('season').sel(time=obs_date_select))
                    
                else:
                    #Now find the difference between the 2 weeks. The difference between the current week's forecast and the observation
                                   #Current day                                #Previous week
                    week_diff = reforecast_anomaly.sel(S=date, L=lead) - rci_obs_to_save_over.rci.isel(L=idx2-week_differencing)
                    
                #Now compute RCI
                rci_standardized = (week_diff - mean_diff_dict[date_within_mean_std_dicts])/ std_daily_dict[date_within_mean_std_dicts]
                rci_standardized.min()
                rci_standardized.max()
                rci_standardized.mean()
                
                try:
                    rci_standardized = rci_standardized.drop(['L','S'])
                except ValueError:
                    pass
                    
                #Now update RCI value
                subtract_ = xr.where(rci_standardized < -0.75,1,0)
                add_ = xr.where(rci_standardized > 0.75,1,0)
                
                #If the signs switch between postive and negative with RCI, then we reset rci to 0
                switch1 = np.where((rci_standardized[putils.xarray_varname(rci_standardized)] > 0) & (rci_obs_to_save_over.isel(L=idx2-1).rci.values < 0),2,0)
                switch2 = np.where((rci_standardized[putils.xarray_varname(rci_standardized)]< 0) & (rci_obs_to_save_over.isel(L=idx2-1).rci.values  > 0),2,0)
                
                sub = np.where(subtract_.RZSM.values == 1, rci_obs_to_save_over.isel(L=idx2-1).rci.values - (np.sqrt(np.abs(rci_standardized.RZSM.values) - 0.75)),0)
                add = np.where(add_.RZSM.values == 1,  rci_obs_to_save_over.isel(L=idx2-1).rci.values + (np.sqrt(rci_standardized.RZSM.values-0.75)),0)

                final = sub + add
        
                #Now switch back the data if the signs are oppositve
                final = np.where(switch1 != 2, final,0)
                final = np.where(switch2 != 2, final,0)
                
                rci_obs_to_save_over.rci[:,idx2,:,:] = final
                
            rci_obs_to_save_over = rci_obs_to_save_over.expand_dims({'S':1})
            rci_obs_to_save_over.to_netcdf(f'{save_name_out}')
                    
        else:
            #don't need to update any dates that aren't between March through November
            rci_obs_to_save_over.rci[:,0,:,:] = 0
            rci_obs_to_save_over = rci_obs_to_save_over.expand_dims({'S':1})
            rci_obs_to_save_over.to_netcdf(f'{save_name_out}')

    return('Completed')

In [17]:
# UNET - Good for CONUS
rci_reforecast(std_daily_dict=std_daily_dict, mean_diff_dict=mean_diff_dict, 
                   week_differencing = week_differencing, reforecast_anomaly=unet_anomaly, obs_file=obs_rci, save_dir= f'predictions/{region_name}/UNET/RCI/{experiment_name}')


'Completed'

In [18]:
# Baseline reforecast GEFS - Good for CONUS
rci_reforecast(std_daily_dict=std_daily_dict, mean_diff_dict=mean_diff_dict, 
                   week_differencing = week_differencing, reforecast_anomaly=gefs_anom, obs_file=obs_rci, save_dir= f'{source}/GEFSv12_reforecast/soilw_bgrnd/RCI')


'Completed'

In [19]:

# Baseline reforecast ECMWF
rci_reforecast(std_daily_dict=std_daily_dict, mean_diff_dict=mean_diff_dict, 
                   week_differencing = week_differencing, reforecast_anomaly=ecmwf_anom, obs_file=obs_rci, save_dir= f'Data/ECMWF/soilw_bgrnd_processed/{region_name}/RCI')


'Completed'

# Now do a case study of 2019 Southeast Flash Drought (ensemble mean only)

In [20]:
obs_rci

unet_rci = xr.open_mfdataset(f'predictions/{region_name}/UNET/RCI/{experiment_name}/RCI_MEM_{week_differencing}*').sel(L=[20,27,34]).load()
baseline_rci = xr.open_mfdataset(f'{source}/GEFSv12_reforecast/soilw_bgrnd/RCI/RCI_MEM_{week_differencing}*').sel(L=[20,27,34]).load()
ecmwf_rci = xr.open_mfdataset(f'Data/ECMWF/soilw_bgrnd_processed/{region_name}/RCI/RCI_MEM_{week_differencing}*').sel(L=[20,27,34]).load()

In [21]:
#dates
start_ = '2019-08-01'
end_ = '2019-10-30'

obs = obs_rci.sel(S=slice(start_,end_)).mean(dim='M')
unet = unet_rci.sel(S=slice(start_,end_))
baseline = baseline_rci.sel(S=slice(start_,end_))
ecmwf = ecmwf_rci.sel(S=slice(start_,end_))

In [22]:
#Mask with np.nan for non-CONUS land values
mask_anom = mask['NCA-LDAS_mask'][0,:,:].values


In [23]:
if region_name == 'CONUS':
    obs = xr.where(mask_anom ==1, obs,np.nan).sel(L=[20,27,34])
    unet = xr.where(mask_anom ==1, unet,np.nan)
    baseline = xr.where(mask_anom ==1, baseline,np.nan)

In [24]:
def get_min_max_of_files(obs, unet, baseline, ecmwf, date):
    #test 
    # date = '2019-08-07'
    
    min_ = []
    max_ = []

    min_.append(np.nanmin(obs.sel(S=date).rci.values))
    min_.append(np.nanmin(unet.sel(S=date).rci.values))
    min_.append(np.nanmin(baseline.sel(S=date).rci.values))
    min_.append(np.nanmin(ecmwf.sel(S=date).rci.values))
    
    max_.append(np.nanmax(obs.sel(S=date).rci.values))
    max_.append(np.nanmax(unet.sel(S=date).rci.values))
    max_.append(np.nanmax(baseline.sel(S=date).rci.values))
    max_.append(np.nanmax(ecmwf.sel(S=date).rci.values))
    
    return(min(min_),max(max_))

In [25]:
def return_array(file,lead,date):
    return(file.isel(L=lead).sel(S=date).rci.values)

In [26]:


   
# cmap = 'coolwarm'
def plot_case_study_rci(obs, unet, baseline, ecmwf, init_date):
    cmap = plt.get_cmap('bwr')    
    
    save_dir = f'Outputs/Case_studies/Southeast_US/RCI'
    os.system(f'mkdir -p {save_dir}')
        
    fig, axs = plt.subplots(
        nrows = 3, ncols= 4, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(15, 10))
    axs = axs.flatten()
    
    init_date = pd.to_datetime(init_date)
    date = f'{init_date.year}-{init_date.month:02}-{init_date.day:02}'
    
    min_,max_ = get_min_max_of_files(obs, unet, baseline, ecmwf,date)
    # test_file = mae_rzsm_keys
    # for Subx original data
    v = np.linspace(-3, 3, 20, endpoint=True)

    pos = [i for i in v if i > 0]
    neg = [i for i in v if i < 0]

    neg.append(0)
    v = neg + pos
    
    lon = obs.X.values
    lat = obs.Y.values
    
    axs_start = 0
    for index_, lead in enumerate([20,27,34]):
        for data_to_plot,name in zip([obs, unet, baseline,ecmwf], ['GLEAM','UNET','GEFSv12','ECMWF']):
            # break
            data = return_array(file=data_to_plot,lead=index_, date=date)
        
            map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                          llcrnrlon=-128, urcrnrlon=-60, resolution='l')
            x, y = map(*np.meshgrid(lon, lat))
            # Adjust the text coordinates based on the actual data coordinates
        
            norm = TwoSlopeNorm(vmin=v[0], vcenter=0, vmax=v[-1])
        
            im = axs[axs_start].contourf(x, y, data, levels=v, extend='both',
                                  transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
    
    
            # axs[idx].title.set_text(f'SubX Lead {lead*7}')
            gl = axs[axs_start].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                       linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
            gl.xlabels_top = False
            gl.ylabels_right = False
            if lead != 1:
                gl.ylabels_left = False
            gl.xformatter = LongitudeFormatter()
            gl.yformatter = LatitudeFormatter()
            axs[axs_start].coastlines()
            # plt.colorbar(im)
            # axs[idx].set_aspect('auto', adjustable=None)
            axs[axs_start].set_aspect('equal')  # this makes the plots better
            axs[axs_start].set_title(f'{name} Lead {lead}',fontsize=15)
            axs_start+=1
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])
    
    # Draw the colorbar
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    fig.suptitle(f'Init date: {date}', fontsize=30)
    fig.tight_layout()
    
    plt.savefig(f'{save_dir}/Southeast_{week_differencing}week_RCI_init{date}.png',bbox_inches='tight')
    plt.show()


In [None]:
for init_date in obs.S.values:
    plot_case_study_rci(obs, unet, baseline, ecmwf, init_date)

In [28]:
End script here

SyntaxError: invalid syntax (1610507816.py, line 1)

# Plot anomaly for 2019

In [None]:
obs_anomaly_mf = xr.open_mfdataset('Data/GLEAM/RZSM_anomaly_reformat_SubX_format/RZSM_anomaly*.nc4').sel(L=[20,27,34]).load()

In [None]:
def get_min_max_of_files_anomaly(obs, unet, baseline, date):
    #test 
    # date = '2019-08-07'
    
    min_ = []
    max_ = []

    min_.append(obs.sel(S=date).min().RZSM.values)
    min_.append(unet.sel(S=date).min().RZSM.values)
    min_.append(baseline.sel(S=date).min().RZSM.values)

    max_.append(obs.sel(S=date).max().RZSM.values)
    max_.append(unet.sel(S=date).max().RZSM.values)
    max_.append(baseline.sel(S=date).max().RZSM.values)

    return(min(min_),max(max_))

def return_array_anomaly(file,lead,date):
    return(file.sel(L=lead,S=date).RZSM.values)

In [None]:


   
# cmap = 'coolwarm'
def plot_case_study_anomaly(obs, unet, baseline, init_date, year):

    text_x = -83.5
    text_y = 27
    font_size_corr = 12
    
    cmap = plt.get_cmap('bwr')    
    # Create a diverging color scale using RdBu colormap
    cmap = plt.get_cmap('RdBu')
    if year == 2019:
        save_dir = f'Outputs/Case_studies/Southeast_US/anomaly'
    elif year == 2017:
        save_dir = f'Outputs/Case_studies/High_Plains/anomaly'
    elif year == 2012:
        save_dir = f'Outputs/Case_studies/Central_US/anomaly'
        
    os.system(f'mkdir -p {save_dir}')
        
    fig, axs = plt.subplots(
        nrows = 3, ncols= 3, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(15, 10))
    axs = axs.flatten()
    
    init_date = pd.to_datetime(init_date)
    date = f'{init_date.year}-{init_date.month:02}-{init_date.day:02}'
    
    min_,max_ = get_min_max_of_files_anomaly(obs, unet, baseline, date)
    # test_file = mae_rzsm_keys
    # for Subx original data
    
    lon = obs.X.values
    lat = obs.Y.values
    
    axs_start = 0
    for lead in [20,27,34]:
        for data_to_plot,name in zip([obs, unet, baseline], ['GLEAM','UNET','Baseline']):
            # break
            data = return_array_anomaly(file=data_to_plot,lead=lead, date=date)
    
            v = np.linspace(min_, max_, 20, endpoint=True)
        
            map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                          llcrnrlon=-128, urcrnrlon=-60, resolution='l')
            x, y = map(*np.meshgrid(lon, lat))
            # Adjust the text coordinates based on the actual data coordinates
        
            norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)
        
            im = axs[axs_start].contourf(x, y, data, levels=v, extend='both',
                                  transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)
    
    
            # axs[idx].title.set_text(f'SubX Lead {lead*7}')
            gl = axs[axs_start].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                       linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
            gl.xlabels_top = False
            gl.ylabels_right = False
            if lead != 1:
                gl.ylabels_left = False
            gl.xformatter = LongitudeFormatter()
            gl.yformatter = LatitudeFormatter()
            axs[axs_start].coastlines()
            # plt.colorbar(im)
            # axs[idx].set_aspect('auto', adjustable=None)
            axs[axs_start].set_aspect('equal')  # this makes the plots better
            axs[axs_start].set_title(f'{name} Lead {lead}',fontsize=15)

            if name in ['UNET','Baseline']:
                # Calculate the Pearson correlation coefficient
                obs_corr = return_array_anomaly(file=obs,lead=lead, date=date).flatten()
                data_corr = data.flatten()

                data_corr = data_corr[~np.isnan(obs_corr)]
                obs_corr = obs_corr[~np.isnan(obs_corr)]
                
                correlation_matrix = np.corrcoef(obs_corr, data_corr)
                # The correlation coefficient is in the top right corner of the correlation matrix
                correlation_coefficient = correlation_matrix[0, 1]
                correlation_coefficient = round(correlation_coefficient,4)
                #find the correlation coefficient across the dataset
                axs[axs_start].text(text_x, text_y, f'Corr: {correlation_coefficient}', ha='right', va='bottom', fontsize=font_size_corr, color='blue', weight = 'bold')
            
            
            axs_start+=1
            
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])
    
    # Draw the colorbar
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    fig.suptitle(f'Init date: {date}', fontsize=30)
    fig.tight_layout()

    plt.savefig(f'{save_dir}/init_{date}_anomaly.png',bbox_inches='tight')
    
    plt.show()


In [None]:
#dates
start_ = '2019-08-01'
end_ = '2019-10-30'


obs_anom = obs_anomaly_mf.sel(S=slice(start_,end_)).mean(dim='M')
unet_anom = unet_anomaly.sel(S=slice(start_,end_)).mean(dim='M')
baseline_anom = baseline_anomaly.sel(S=slice(start_,end_)).mean(dim='M')

obs_anom = xr.where(mask_anom ==1, obs_anom,np.nan).sel(L=[20,27,34])
unet_anom = xr.where(mask_anom ==1, unet_anom,np.nan).sel(L=[20,27,34])
baseline_anom = xr.where(mask_anom ==1, baseline_anom,np.nan).sel(L=[20,27,34])

unet_anom = xr.where(~np.isnan(obs_anom), unet_anom,np.nan)
baseline_anom = xr.where(~np.isnan(obs_anom), baseline_anom,np.nan)

In [None]:
for init_date in obs_anom.S.values:
    plot_case_study_anomaly(obs=obs_anom, unet=unet_anom, baseline=baseline_anom, init_date=init_date, year=2019)

# 2017 Flash Drought

In [None]:
#dates
start_ = '2017-04-01'
end_ = '2017-08-30'


obs_anom = obs_anomaly_mf.sel(S=slice(start_,end_)).mean(dim='M')
unet_anom = unet_anomaly.sel(S=slice(start_,end_)).mean(dim='M')
baseline_anom = baseline_anomaly.sel(S=slice(start_,end_)).mean(dim='M')

#Masking on ocean areas
obs_anom = xr.where(mask_anom ==1, obs_anom,np.nan).sel(L=[20,27,34])
unet_anom = xr.where(mask_anom ==1, unet_anom,np.nan).sel(L=[20,27,34])
baseline_anom = xr.where(mask_anom ==1, baseline_anom,np.nan).sel(L=[20,27,34])

#Further masking of grid cells within CONUS that aren't in GlEAM
unet_anom = xr.where(~np.isnan(obs_anom), unet_anom,np.nan)
baseline_anom = xr.where(~np.isnan(obs_anom), baseline_anom,np.nan)

In [None]:
for init_date in obs_anom.S.values:
    plot_case_study_anomaly(obs=obs_anom, unet=unet_anom, baseline=baseline_anom, init_date=init_date, year = 2017)

# 2012 Flash Drought

In [None]:
#dates
start_ = '2012-04-01'
end_ = '2012-06-30'


obs_anom = obs_anomaly_mf.sel(S=slice(start_,end_)).mean(dim='M')
unet_anom = unet_anomaly.sel(S=slice(start_,end_)).mean(dim='M')
baseline_anom = baseline_anomaly.sel(S=slice(start_,end_)).mean(dim='M')

#Masking on ocean areas
obs_anom = xr.where(mask_anom ==1, obs_anom,np.nan).sel(L=[20,27,34])
unet_anom = xr.where(mask_anom ==1, unet_anom,np.nan).sel(L=[20,27,34])
baseline_anom = xr.where(mask_anom ==1, baseline_anom,np.nan).sel(L=[20,27,34])

#Further masking of grid cells within CONUS that aren't in GlEAM
unet_anom = xr.where(~np.isnan(obs_anom), unet_anom,np.nan)
baseline_anom = xr.where(~np.isnan(obs_anom), baseline_anom,np.nan)

In [None]:
for init_date in obs_anom.S.values:
    plot_case_study_anomaly(obs=obs_anom, unet=unet_anom, baseline=baseline_anom, init_date=init_date, year = 2012)