In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import xarray as xr
import numpy as np
import os
from glob import glob
import functions as f
from mpl_toolkits.basemap import Basemap
from numpy import meshgrid
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import matplotlib.colors as mcolors
import cartopy.feature as cfeature
import itertools
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter, LatitudeLocator
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TwoSlopeNorm
import pandas as pd
import math
from datetime import datetime
import datetime as dt
from multiprocessing import Pool
from sklearn.metrics import confusion_matrix as CM
import masks


2024-02-11 05:52:00.893226: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-11 05:52:01.022084: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# GLEAM observation anomalies: open every matching file as one dataset, keep
# only the five L values used throughout this notebook, and pull it into memory.
obs_anomaly_mf = xr.open_mfdataset(
    'Data/GLEAM/RZSM_anomaly_reformat_SubX_format/RZSM_anomaly*.nc4'
)
obs_anomaly_mf = obs_anomaly_mf.sel(L=[6, 13, 20, 27, 34]).load()

# GEFSv12 reforecast baseline anomalies: same treatment, with the file list
# sorted so the files are always opened in a deterministic order.
baseline_anomaly_file_list = sorted(
    glob('Data/GEFSv12_reforecast/soilw_bgrnd/baseline_RZSM_anomaly/RZSM*.nc')
)
baseline_anomaly = xr.open_mfdataset(baseline_anomaly_file_list)
baseline_anomaly = baseline_anomaly.sel(L=[6, 13, 20, 27, 34]).load()

In [2]:
# Set script parameters.
# CONUS mask, which serves as the bounding box for every map in this notebook.
# It could later be swapped for a larger file, but then the data produced by
# the previous scripts would have to be edited as well.
CONUS_mask = f.load_CONUS_mask()

# Full 'NCA-LDAS_mask' array with length-1 axes squeezed out; kept for later
# masking with np.nan.
CONUS_array = CONUS_mask['NCA-LDAS_mask'].values.squeeze()
# NOTE(review): the next line is a bare expression in the middle of the cell, so
# its value is never displayed — it looks like a leftover interactive shape check.
CONUS_array.shape

# 2-D mask taken at index 0 of the variable's leading axis. Cells are compared
# against 1 later on (`mask_anom == 1`) and everything else is set to np.nan.
mask_anom = CONUS_mask['NCA-LDAS_mask'][0,:,:].values



# Data (observation and baseline reforecast)

In [6]:
# We have run several different experiments
def load_unet_prediction(experiment_name, bias_correction_season_or_init_or_None):
    """Load the UNET anomaly predictions of one experiment into memory.

    Parameters
    ----------
    experiment_name : str
        Experiment prefix of the prediction files (e.g. ``'EX26_RZSM'``).
    bias_correction_season_or_init_or_None : str
        Second component of the file-name prefix; the callers in this notebook
        pass ``'No'``, ``'season'`` or ``'init'``.

    Returns
    -------
    The combined dataset of every file matching
    ``predictions/anomaly_no_julian_dates/{experiment}_{bias}*.nc``, restricted
    to ``L = [6, 13, 20, 27, 34]`` and fully loaded.
    """
    pattern = f'predictions/anomaly_no_julian_dates/{experiment_name}_{bias_correction_season_or_init_or_None}*.nc'
    matching_files = glob(pattern)
    matching_files.sort()  # deterministic open order

    combined = xr.open_mfdataset(matching_files)
    return combined.sel(L=[6, 13, 20, 27, 34]).load()



# Plot anomaly for 2019

In [4]:
def get_min_max_of_files_anomaly(obs, unet, baseline, date):
    """Return the ``(min, max)`` RZSM range spanned by three datasets on one date.

    Each dataset is reduced with ``.sel(S=date).min()`` / ``.max()`` and the
    ``RZSM`` value is read from the result. Both bounds are seeded with 0, so
    the returned minimum is never above 0 and the maximum is never below 0.

    Parameters
    ----------
    obs, unet, baseline :
        Objects supporting ``.sel(S=...)``, ``.min()``, ``.max()`` and exposing
        an ``RZSM`` variable.
    date :
        Value passed to ``.sel(S=date)``.

    Returns
    -------
    tuple
        ``(min_out, max_out)``.
    """
    per_date = [dataset.sel(S=date) for dataset in (obs, unet, baseline)]
    lows = [subset.min().RZSM.values for subset in per_date]
    highs = [subset.max().RZSM.values for subset in per_date]

    # The leading 0 is the seed: a candidate only wins when it is strictly
    # below (for the minimum) or strictly above (for the maximum) zero.
    min_out = min([0] + lows)
    max_out = max([0] + highs)
    return (min_out, max_out)

def return_array_anomaly(file,lead,date):
    """Return the raw ``RZSM`` values of ``file`` at lead ``L=lead`` and init date ``S=date``."""
    subset = file.sel(L=lead, S=date)
    return subset.RZSM.values

In [81]:


   
# cmap = 'coolwarm'
def _pearson_on_valid_cells(reference, candidate):
    """Return Pearson's r, rounded to 4 decimals, over cells that are non-NaN in both arrays."""
    reference = np.asarray(reference, dtype=float).ravel()
    candidate = np.asarray(candidate, dtype=float).ravel()
    # Mask on both inputs: a single NaN in the forecast that is not also NaN in
    # the observations would otherwise turn the whole coefficient into NaN.
    valid = ~(np.isnan(reference) | np.isnan(candidate))
    if valid.sum() < 2:
        return float('nan')
    return round(float(np.corrcoef(reference[valid], candidate[valid])[0, 1]), 4)


def plot_case_study_anomaly(obs, unet, baseline, init_date, year, file_name_out='anomaly'):
    """Plot GLEAM, UNET and baseline RZSM anomalies for one initialization date.

    The figure has one row per lead found in ``obs.L`` and three columns
    (GLEAM, UNET, Baseline). All panels share one diverging colour scale
    centred on 0, and the UNET/Baseline panels are annotated with their Pearson
    correlation against the observations. The figure is written to
    ``{save_dir}/init_{date}_{file_name_out}.png`` and then shown.

    Parameters
    ----------
    obs, unet, baseline :
        Datasets with an ``RZSM`` variable and ``S`` / ``L`` coordinates; ``obs``
        additionally supplies the ``X`` / ``Y`` coordinates and the list of leads.
    init_date :
        Initialization date; anything ``pd.to_datetime`` accepts.
    year : int
        Selects the output directory: 2019 (Southeast_US), 2017 (High_Plains)
        or 2012 (Central_US).
    file_name_out : str, optional
        Suffix of the output file name, also printed in the figure title.
        Defaults to ``'anomaly'`` so callers that do not care about the label
        can omit it.

    Raises
    ------
    ValueError
        If ``year`` is not one of the supported case-study years.
    """
    save_dirs = {
        2019: 'Outputs/Case_studies/Southeast_US/anomaly',
        2017: 'Outputs/Case_studies/High_Plains/anomaly',
        2012: 'Outputs/Case_studies/Central_US/anomaly',
    }
    if year not in save_dirs:
        raise ValueError(f'Unsupported case-study year {year!r}; expected one of {sorted(save_dirs)}')
    save_dir = save_dirs[year]
    os.makedirs(save_dir, exist_ok=True)

    text_x = -83.5
    text_y = 27
    font_size_corr = 12

    # Diverging colour scale (RdBu), centred on 0 through the norm below
    cmap = plt.get_cmap('RdBu')

    init_date = pd.to_datetime(init_date)
    date = f'{init_date.year}-{init_date.month:02}-{init_date.day:02}'

    # Plot whatever leads the observation dataset carries, so a caller that
    # pre-selected a subset (e.g. L=[20, 27, 34]) does not hit a missing lead.
    leads = obs.L.values
    panels = [(obs, 'GLEAM'), (unet, 'UNET'), (baseline, 'Baseline')]

    fig, axs = plt.subplots(
        nrows=len(leads), ncols=len(panels), subplot_kw={'projection': ccrs.PlateCarree()},
        figsize=(20, 3 * len(leads)), squeeze=False)
    axs = axs.flatten()

    min_, max_ = get_min_max_of_files_anomaly(obs, unet, baseline, date)

    # Contour levels shared by every panel; 0 is forced in so the scale diverges there
    v = np.linspace(min_, max_, 20, endpoint=True)
    neg = [i for i in v if i < 0]
    pos = [i for i in v if i > 0]
    v = np.array(neg + [0] + pos)

    # The map projection and coordinate grid do not depend on the panel, so
    # build them once instead of once per panel.
    lon = obs.X.values
    lat = obs.Y.values
    basemap = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
    x, y = basemap(*np.meshgrid(lon, lat))

    panel_idx = 0
    for lead in leads:
        obs_values = return_array_anomaly(file=obs, lead=lead, date=date)
        for data_to_plot, name in panels:
            ax = axs[panel_idx]
            data = return_array_anomaly(file=data_to_plot, lead=lead, date=date)

            norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)
            im = ax.contourf(x, y, data, levels=v, extend='both',
                             transform=ccrs.PlateCarree(), cmap=cmap, norm=norm)

            gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                              linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
            gl.xlabels_top = False
            gl.ylabels_right = False
            # NOTE(review): no lead used in this notebook equals 1, so the left
            # labels are hidden on every panel — confirm whether the first
            # column was meant to keep them.
            if lead != 1:
                gl.ylabels_left = False
            gl.xformatter = LongitudeFormatter()
            gl.yformatter = LatitudeFormatter()
            ax.coastlines()
            ax.set_aspect('equal')  # this makes the plots better
            ax.set_title(f'{name} Lead {lead}', fontsize=15)

            if name in ['UNET', 'Baseline']:
                # Pearson correlation of this panel against the observations
                correlation_coefficient = _pearson_on_valid_cells(obs_values, data)
                ax.text(text_x, text_y, f'Corr: {correlation_coefficient}', ha='right', va='bottom',
                        fontsize=font_size_corr, color='blue', weight='bold')

            panel_idx += 1

    # One horizontal colour bar below the grid; every panel uses the same scale
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])
    fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    fig.suptitle(f'Init date: {date}\n{file_name_out}', fontsize=30)
    fig.tight_layout()

    plt.savefig(f'{save_dir}/init_{date}_{file_name_out}.png', bbox_inches='tight')

    plt.show()
    # This function is called once per init date; release the figure so they
    # do not pile up in memory.
    plt.close(fig)


In [None]:
#dates
start_ = '2019-07-25'
end_ =  '2019-09-14'

# Ensemble-mean (over M) anomalies for the case-study window
obs_anom = obs_anomaly_mf.sel(S=slice(start_,end_)).mean(dim='M').astype(np.float32)
baseline_anom = baseline_anomaly.sel(S=slice(start_,end_)).mean(dim='M').astype(np.float32)

# Observation and baseline masking does not depend on the UNET experiment, so it
# is done once here rather than being repeated for every experiment below.
# First keep only cells where the CONUS mask equals 1 ...
obs_anom = xr.where(mask_anom ==1, obs_anom,np.nan).sel(L=[6,13,20,27,34])
baseline_anom = xr.where(mask_anom ==1, baseline_anom,np.nan).sel(L=[6,13,20,27,34])
# ... then drop cells that are missing in the observations.
baseline_anom = xr.where(~np.isnan(obs_anom), baseline_anom,np.nan)

for experiment_name in ['EX26_RZSM', 'EX10_grid_cell_standardization_RZSM']:
    for bias_correction_season_or_init_or_None in ['No','season','init']:
        # `unet_anomaly` stays a notebook-level name: the 2017 and 2012 cells
        # further down read it.
        unet_anomaly = load_unet_prediction(experiment_name, bias_correction_season_or_init_or_None)

        # Same ensemble mean and masking as for the observations/baseline
        unet_anom = unet_anomaly.sel(S=slice(start_,end_)).mean(dim='M')
        unet_anom = xr.where(mask_anom ==1, unet_anom,np.nan).sel(L=[6,13,20,27,34])
        unet_anom = xr.where(~np.isnan(obs_anom), unet_anom,np.nan)

        #Now plot the data, one figure per initialization date
        for init_date in obs_anom.S.values:
            plot_case_study_anomaly(obs=obs_anom, unet=unet_anom, baseline=baseline_anom, init_date=init_date, 
                                    year=2019, file_name_out = f'{experiment_name}_{bias_correction_season_or_init_or_None}_bias_correction_anomaly')

In [20]:
def return_array_anomaly_lead(file,date):
    """Return the raw ``RZSM`` values of ``file`` at init date ``S=date``."""
    at_date = file.sel(S=date)
    return at_date.RZSM.values
    
def get_min_max_of_files_anomaly_lead(obs,  unet_select):
    """Return the ``(min, max)`` RZSM range spanned by ``obs`` and ``unet_select``.

    Unlike the date-based variants, no ``S`` selection happens here: each input
    is reduced as given with ``.min()`` / ``.max()``. Both bounds start at 0,
    so the minimum is never above 0 and the maximum never below 0.
    """
    min_out = 0
    max_out = 0
    for dataset in (obs, unet_select):
        low = dataset.min().RZSM.values
        high = dataset.max().RZSM.values
        # Strict comparisons: a candidate only replaces the running bound when
        # it actually extends the range.
        if low < min_out:
            min_out = low
        if high > max_out:
            max_out = high

    return (min_out, max_out)

def get_min_max_of_files_anomaly_leadN_2(obs, baseline, date):
    """Return the ``(min, max)`` RZSM range of ``obs`` and ``baseline`` on one init date.

    Each dataset is reduced with ``.sel(S=date).min()`` / ``.max()``. Both
    bounds are seeded with 0, so the range always contains 0.
    """
    obs_at_date = obs.sel(S=date)
    baseline_at_date = baseline.sel(S=date)

    # 0 goes first so that it is kept unless a candidate is strictly beyond it
    min_out = min(
        0,
        obs_at_date.min().RZSM.values,
        baseline_at_date.min().RZSM.values,
    )
    max_out = max(
        0,
        obs_at_date.max().RZSM.values,
        baseline_at_date.max().RZSM.values,
    )
    return (min_out, max_out)

# cmap = 'coolwarm'

def plot_unet_experiments(obs, baseline, init_date, 
                            year, file_name_out , experiment_list, lead):
    """Plot GLEAM, the baseline and every UNET experiment for one init date and lead.

    The first two panels show the observations and the baseline on a shared
    colour scale. Each following panel shows one experiment from
    ``experiment_list`` (ensemble mean of ``load_unet_prediction_by_lead``,
    masked to the cells where ``obs`` is not NaN) on a scale computed from the
    observations and that experiment. Every panel except the GLEAM one is
    annotated with its Pearson correlation against the observations. The
    figure is saved as ``{save_dir}/init_{date}_{file_name_out}.png`` with
    ``date`` formatted as ``YYYY-MM-DD``, then shown.

    Parameters
    ----------
    obs, baseline :
        Datasets with an ``RZSM`` variable already reduced to a single lead.
        ``obs`` must have the same shape as the experiment arrays, because each
        experiment is written into a copy of it.
    init_date :
        Initialization date; anything ``pd.to_datetime`` accepts.
    year : int
        Selects the output directory: 2019, 2017 or 2012.
    file_name_out : str
        Suffix of the output file name, also printed in the figure title.
    experiment_list : list of str
        Experiment names passed to ``load_unet_prediction_by_lead``.
    lead :
        Weekly lead, used for loading the experiments and in the panel titles.

    Raises
    ------
    ValueError
        If ``year`` is not one of the supported case-study years.
    """
    save_dirs = {
        2019: 'Outputs/Case_studies/Southeast_US/anomaly/specific_UNETs',
        2017: 'Outputs/Case_studies/High_Plains/anomaly/specific_UNETs',
        2012: 'Outputs/Case_studies/Central_US/anomaly/specific_UNETs',
    }
    if year not in save_dirs:
        raise ValueError(f'Unsupported case-study year {year!r}; expected one of {sorted(save_dirs)}')
    save_dir = save_dirs[year]
    os.makedirs(save_dir, exist_ok=True)

    text_x = -83.5
    text_y = 27
    font_size_corr = 12

    # Diverging colour scale (RdBu), centred on 0 through the per-panel norm
    cmap = plt.get_cmap('RdBu')

    init_date = pd.to_datetime(init_date)
    date = f'{init_date.year}-{init_date.month:02}-{init_date.day:02}'

    # Two reference panels plus one per experiment. Keep the original 5 x 6
    # layout and only add rows when the experiment list would not fit.
    n_cols = 6
    n_panels = 2 + len(experiment_list)
    n_rows = max(5, -(-n_panels // n_cols))

    fig, axs = plt.subplots(
        nrows=n_rows, ncols=n_cols, subplot_kw={'projection': ccrs.PlateCarree()},
        figsize=(20, 3 * n_rows))
    axs = axs.flatten()

    # The map projection and coordinate grid are identical for every panel
    lon = obs.X.values
    lat = obs.Y.values
    basemap = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
    x, y = basemap(*np.meshgrid(lon, lat))

    obs_values = return_array_anomaly_lead(file=obs, date=date)

    def _draw_panel(ax, values, title, vmin, vmax, show_corr):
        """Draw one filled-contour map on ``ax`` and return the contour set."""
        v = np.linspace(vmin, vmax, 20, endpoint=True)
        #Make sure it diverges at 0
        neg = [i for i in v if i < 0]
        pos = [i for i in v if i > 0]
        v = np.array(neg + [0] + pos)

        norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
        contours = ax.contourf(x, y, values, levels=v, extend='both',
                               transform=ccrs.PlateCarree(), cmap=cmap, norm=norm)

        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        gl.xlabels_top = False
        gl.ylabels_right = False
        # NOTE(review): left labels are only kept when lead == 1 — confirm
        # whether the first column was meant to keep them instead.
        if lead != 1:
            gl.ylabels_left = False
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        ax.coastlines()
        ax.set_aspect('equal')  # this makes the plots better
        ax.set_title(title, fontsize=15)

        if show_corr:
            # Pearson correlation against the observations over cells that are
            # non-NaN in both fields.
            obs_corr = np.asarray(obs_values, dtype=float).ravel()
            data_corr = np.asarray(values, dtype=float).ravel()
            valid = ~(np.isnan(obs_corr) | np.isnan(data_corr))
            if valid.sum() < 2:
                correlation_coefficient = float('nan')
            else:
                correlation_coefficient = round(
                    float(np.corrcoef(obs_corr[valid], data_corr[valid])[0, 1]), 4)
            ax.text(text_x, text_y, f'Corr: {correlation_coefficient}', ha='right', va='bottom',
                    fontsize=font_size_corr, color='blue', weight='bold')
        return contours

    panel_idx = 0

    # Reference panels: observations and baseline on a shared scale
    min_, max_ = get_min_max_of_files_anomaly_leadN_2(obs, baseline, date)
    for data_to_plot, name in zip([obs, baseline], ['GLEAM', 'Baseline']):
        data = return_array_anomaly_lead(file=data_to_plot, date=date)
        im = _draw_panel(axs[panel_idx], data, f'{name} Lead {lead}', min_, max_,
                         show_corr=(name == 'Baseline'))
        panel_idx += 1

    # One panel per UNET experiment
    obs_at_init = obs.sel(S=init_date)
    for experiment_name in experiment_list:
        # Mean over axis 1 of the loaded array, ignoring NaNs
        data = np.nanmean(load_unet_prediction_by_lead(experiment_name, lead), axis=1)

        # Borrow the coordinates of `obs` by writing the predictions into a deep
        # copy, then blank out every cell that is missing in the observations.
        unet_select = obs.copy(deep=True)
        unet_select.RZSM[:, :, :] = data
        unet_select = xr.where(~np.isnan(obs), unet_select, np.nan)
        unet_select = unet_select.sel(S=init_date)

        min_, max_ = get_min_max_of_files_anomaly_lead(obs_at_init, unet_select)
        # The correlation is drawn and the panel index advanced unconditionally;
        # neither may depend on a variable left over from the loop above.
        im = _draw_panel(axs[panel_idx], unet_select.RZSM.values,
                         f'{experiment_name} Lead {lead}', min_, max_, show_corr=True)
        panel_idx += 1

    # Hide grid cells that received no panel
    for unused_ax in axs[panel_idx:]:
        unused_ax.set_visible(False)

    # NOTE(review): experiment panels each use their own min/max, so this single
    # colour bar only describes the last panel drawn.
    cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])
    fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    fig.suptitle(f'Init date: {date}\n{file_name_out}', fontsize=30)
    fig.tight_layout()

    # Use the formatted date, not the Timestamp, so the file name carries no
    # " 00:00:00" suffix (spaces and colons) and matches the figure title.
    plt.savefig(f'{save_dir}/init_{date}_{file_name_out}.png', bbox_inches='tight')

    plt.show()
    # Called once per init date; release the figure so they do not accumulate.
    plt.close(fig)


# Plot UNET experiments for weekly leads 3–5

In [10]:
# We have run several different experiments
def load_unet_prediction_by_lead(experiment_name, lead):
    """Load one experiment's saved test-set array for a weekly lead and undo its scaling.

    Reads ``predictions/Wk_{lead}_testing/Wk{lead}_testing_{experiment_name}.npy``,
    takes ``[2, :, :, :, 0]`` of it, passes that slice through
    ``f.reverse_min_max_scaling`` with the reforecast max/min, and regroups the
    leading axis into blocks of 11.

    Returns
    -------
    Array of shape ``(n // 11, 11, height, width)``, where ``n`` is the length
    of the leading axis after slicing.
    """
    # NOTE(review): index 2 on the first axis and index 0 on the last are taken
    # as given — presumably they pick the prediction out of a stacked array;
    # confirm against the script that wrote the .npy files.
    max_RZSM_reforecast, min_RZSM_reforecast = f.load_reforecast_min_max_RZSM()
    saved_array = np.load(f'predictions/Wk_{lead}_testing/Wk{lead}_testing_{experiment_name}.npy')

    #Convert back to anomalies
    rescaled = f.reverse_min_max_scaling(
        saved_array[2, :, :, :, 0], max_RZSM_reforecast, min_RZSM_reforecast)

    # 11 is presumably the number of ensemble members per init date (callers
    # average over this axis) — TODO confirm.
    n_samples = rescaled.shape[0]
    height = rescaled.shape[1]
    width = rescaled.shape[2]
    return np.reshape(rescaled, (n_samples // 11, 11, height, width))



In [11]:
# Experiment names plotted for week 3: EX0 ... EX25
Wk3 = ['EX{}_RZSM'.format(idx) for idx in range(0, 26)]

# Experiment names plotted for week 4: EX0 ... EX11 and EX13 ... EX26 (EX12 is skipped)
Wk4a = ['EX{}_RZSM'.format(idx) for idx in range(0, 12)]
Wk4b = ['EX{}_RZSM'.format(idx) for idx in range(13, 27)]
Wk4 = [*Wk4a, *Wk4b]

In [None]:
# Case-study window: one figure is produced per init date inside it
start_ = '2019-07-25'
end_ =  '2019-09-14'

# Testing period. plot_unet_experiments writes each experiment's array into a
# copy of the observations, so the observations passed to it must span this
# whole period, not just the case-study window.
start_test = '2018-01-01'
end_test = '2019-12-31'

# Ensemble-mean (over M) anomalies for the testing period, all leads
obs_anom_test = obs_anomaly_mf.sel(S=slice(start_test,end_test)).mean(dim='M').astype(np.float32)
baseline_anom_test = baseline_anomaly.sel(S=slice(start_test,end_test)).mean(dim='M').astype(np.float32)

for (experiment_list,lead) in zip([Wk3,Wk4], [3,4]):
    # The lead-specific selection has to happen inside the loop: `lead` does not
    # exist before it, and each weekly lead needs its own L slice.
    obs_select = obs_anom_test.sel(L=(lead*7)-1)
    baseline_select = baseline_anom_test.sel(L=(lead*7)-1)

    # Keep cells where the CONUS mask equals 1; the baseline is additionally
    # limited to cells present in the masked observations.
    obs_anom = xr.where(mask_anom ==1, obs_select,np.nan)
    baseline_anom = xr.where(~np.isnan(obs_anom), baseline_select,np.nan)

    obs_anom_dates = obs_anom.sel(S=slice(start_,end_))

    for init_date in obs_anom_dates.S.values:
        plot_unet_experiments(obs=obs_anom, baseline=baseline_anom, init_date=init_date, 
                                year=2019, file_name_out = f'ALL_unets_Lead_{lead}', experiment_list = experiment_list, lead=lead)

# 2017 Flash Drought

In [30]:
# Dates of the 2017 case-study window
start_ = '2017-04-01'
end_ = '2017-08-30'


# Ensemble-mean (over M) anomalies inside the window.
# NOTE(review): `unet_anomaly` is not defined in this cell. In the visible
# notebook its only assignment is inside the 2019 experiment loop, so here it
# holds whichever experiment / bias-correction combination that loop loaded
# last — confirm that this is the intended model, or load it explicitly.
obs_anom = obs_anomaly_mf.sel(S=slice(start_,end_)).mean(dim='M')
unet_anom = unet_anomaly.sel(S=slice(start_,end_)).mean(dim='M')
baseline_anom = baseline_anomaly.sel(S=slice(start_,end_)).mean(dim='M')

# Keep only cells where the CONUS mask equals 1 (everything else becomes NaN)
# and restrict to L = 20, 27 and 34.
obs_anom = xr.where(mask_anom ==1, obs_anom,np.nan).sel(L=[20,27,34])
unet_anom = xr.where(mask_anom ==1, unet_anom,np.nan).sel(L=[20,27,34])
baseline_anom = xr.where(mask_anom ==1, baseline_anom,np.nan).sel(L=[20,27,34])

# Further masking: also blank out cells that are NaN in the (masked) observations
unet_anom = xr.where(~np.isnan(obs_anom), unet_anom,np.nan)
baseline_anom = xr.where(~np.isnan(obs_anom), baseline_anom,np.nan)

In [None]:
# One figure per initialization date in the 2017 window. `file_name_out` is
# passed explicitly: it ends up in both the figure title and the saved file name.
for init_date in obs_anom.S.values:
    plot_case_study_anomaly(obs=obs_anom, unet=unet_anom, baseline=baseline_anom, init_date=init_date,
                            year = 2017, file_name_out = 'UNET_vs_baseline_anomaly')

# 2012 Flash Drought

In [32]:
# Dates of the 2012 case-study window
start_ = '2012-04-01'
end_ = '2012-06-30'


# Ensemble-mean (over M) anomalies inside the window.
# NOTE(review): `unet_anomaly` is not defined in this cell. In the visible
# notebook its only assignment is inside the 2019 experiment loop, so here it
# holds whichever experiment / bias-correction combination that loop loaded
# last — confirm that this is the intended model, or load it explicitly.
obs_anom = obs_anomaly_mf.sel(S=slice(start_,end_)).mean(dim='M')
unet_anom = unet_anomaly.sel(S=slice(start_,end_)).mean(dim='M')
baseline_anom = baseline_anomaly.sel(S=slice(start_,end_)).mean(dim='M')

# Keep only cells where the CONUS mask equals 1 (everything else becomes NaN)
# and restrict to L = 20, 27 and 34.
obs_anom = xr.where(mask_anom ==1, obs_anom,np.nan).sel(L=[20,27,34])
unet_anom = xr.where(mask_anom ==1, unet_anom,np.nan).sel(L=[20,27,34])
baseline_anom = xr.where(mask_anom ==1, baseline_anom,np.nan).sel(L=[20,27,34])

# Further masking: also blank out cells that are NaN in the (masked) observations
unet_anom = xr.where(~np.isnan(obs_anom), unet_anom,np.nan)
baseline_anom = xr.where(~np.isnan(obs_anom), baseline_anom,np.nan)

In [None]:
# One figure per initialization date in the 2012 window. `file_name_out` is
# passed explicitly: it ends up in both the figure title and the saved file name.
for init_date in obs_anom.S.values:
    plot_case_study_anomaly(obs=obs_anom, unet=unet_anom, baseline=baseline_anom, init_date=init_date,
                            year = 2012, file_name_out = 'UNET_vs_baseline_anomaly')