In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import xarray as xr
import numpy as np
import os
from glob import glob
import functions as f
#import climpredNEW.climpred 
#from climpredNEW.climpred.options import OPTIONS
from mpl_toolkits.basemap import Basemap
from numpy import meshgrid
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import matplotlib.colors as mcolors
import cartopy.feature as cfeature
import itertools
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter, LatitudeLocator
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TwoSlopeNorm
import pandas as pd
import math
from scipy.stats import percentileofscore as pos
from datetime import datetime
import datetime as dt
from multiprocessing import Pool
from sklearn.metrics import confusion_matrix as CM
import preprocessUtils as putils
import masks



In [None]:
        
    # NOTE(review): orphaned plotting cell. There is no enclosing `def`, and it reads
    # names defined nowhere in this notebook (init_date, lead, cmap,
    # get_min_max_of_files_anomaly_leadN_2, return_array_anomaly_lead), so it cannot
    # run on a fresh kernel. It looks like an earlier GLEAM-vs-Baseline variant of
    # plot_case_study_anomaly -- TODO restore its `def` line or delete the cell.
    fig, axs = plt.subplots(
        nrows = 5, ncols= 6, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(20, 15))
    axs = axs.flatten()
    
    # Normalise the init date to a 'YYYY-MM-DD' string used for selection below.
    init_date = pd.to_datetime(init_date)
    date = f'{init_date.year}-{init_date.month:02}-{init_date.day:02}'
    
    # Shared colour-scale bounds for both panels.
    min_,max_ = get_min_max_of_files_anomaly_leadN_2(obs, baseline, date)
    # test_file = mae_rzsm_keys
    # for Subx original data
    
    lon = obs.X.values
    lat = obs.Y.values
    
    # NOTE(review): axs_start is never incremented in the loop below, so both datasets
    # are drawn onto axs[0] and the remaining 29 axes stay empty.
    axs_start = 0
    for data_to_plot,name in zip([obs, baseline], ['GLEAM','Baseline']):
        # break
        data = return_array_anomaly_lead(file=data_to_plot, date=date)

        # 20 evenly spaced contour levels between the shared bounds.
        v = np.linspace(min_, max_, 20, endpoint=True)

        #Make sure it diverges at 0: any level exactly equal to 0 is dropped by the strict
        # comparisons, then a single 0 level is inserted between the negative and positive ones.
        # NOTE(review): if this runs at notebook top level, `pos` overwrites the scipy
        # percentileofscore alias imported as `pos` in the first cell.
        neg = [i for i in v if i <0]
        pos = [i for i in v if i >0]
        v = np.array(neg +[0] + pos)
        
    
        # NOTE(review): `map` shadows the builtin; this Basemap is loop-invariant and could be built once.
        map = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                      llcrnrlon=-128, urcrnrlon=-60, resolution='l')
        x, y = map(*np.meshgrid(lon, lat))
        # Adjust the text coordinates based on the actual data coordinates
    
        # Colour normalisation centred on a zero anomaly.
        norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)
    
        im = axs[axs_start].contourf(x, y, data, levels=v, extend='both',
                              transform=ccrs.PlateCarree(), cmap=cmap,norm=norm)


        # axs[idx].title.set_text(f'SubX Lead {lead*7}')
        gl = axs[axs_start].gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                   linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
        # Hide top/right tick labels; left labels are hidden unless lead == 1.
        # NOTE(review): cartopy >= 0.18 names these top_labels/right_labels/left_labels --
        # confirm the legacy names below still take effect with the installed cartopy.
        gl.xlabels_top = False
        gl.ylabels_right = False
        if lead != 1:
            gl.ylabels_left = False
        gl.xformatter = LongitudeFormatter()
        gl.yformatter = LatitudeFormatter()
        axs[axs_start].coastlines()
        # plt.colorbar(im)
        # axs[idx].set_aspect('auto', adjustable=None)
        axs[axs_start].set_aspect('equal')  # this makes the plots better
        axs[axs_start].set_title(f'{name} Lead {lead}',fontsize=15)

In [None]:
#For a single week and experiment.
# Day offset (the `L` coordinate) for each weekly lead, aligned by position with week_lead 0..5.
# day_num is derived from week_lead so the two settings cannot drift apart.
lead_day_by_week = [0, 6, 13, 20, 27, 34]
week_lead = 2  # one of 0..5
day_num = lead_day_by_week[week_lead]

In [None]:
# --- Script parameters ---
region_name = 'CONUS'

# CONUS mask; it serves as the bounding box for every dataset in this notebook.
# Moving to a larger mask would require editing the data produced by the previous scripts.
CONUS_mask = masks.load_CONUS_mask()

# Reforecast RZSM bounds, used later to reverse the min-max scaling of model predictions.
max_RZSM_reforecast, min_RZSM_reforecast = f.load_reforecast_min_max_RZSM()



# Data

In [None]:
# Case-study window of init dates (label slicing with these is inclusive).
start_ = '2019-08-01'
end_ = '2019-10-30'

# Southeast bounding box. Longitude values above 180 imply a 0-360 degrees-east grid.
southeast_lat_bottom = 30
southeast_lat_top = 38
southeast_lon_left = 267
southeast_lon_right = 282

# Land mask from the first slice of the NCA-LDAS mask; cells != 1 are set to np.nan downstream.
mask_anom = CONUS_mask['NCA-LDAS_mask'][0, :, :].values


In [None]:
# GLEAM observed RZSM anomalies in SubX layout, restricted to the selected lead.
obs_anomaly_SubX_format = (
    xr.open_mfdataset('Data/GLEAM/RZSM_anomaly_reformat_SubX_format/RZSM_anomaly*.nc4')
    .sel(L=[day_num])
    .astype(np.float32)
    .load()
)

# Southeast-only, case-study-window, ensemble-mean view of the observations.
# Y is sliced top -> bottom, so latitude is presumably stored in descending order.
obs_anomaly_SubX_format_subset = (
    obs_anomaly_SubX_format
    .sel(
        S=slice(start_, end_),
        X=slice(southeast_lon_left, southeast_lon_right),
        Y=slice(southeast_lat_top, southeast_lat_bottom),
    )
    .mean(dim='M')
)


In [None]:
#######################################   Reforecast baseline files   ###########################################################################
# Root of the reforecast data tree. Set FD_RZSM_DATA_ROOT to override the default scratch location,
# e.g. FD_RZSM_DATA_ROOT=. reads the relative 'Data/...' tree used by the other loads in this notebook.
baseline_data_root = os.environ.get('FD_RZSM_DATA_ROOT', '/glade/scratch/klesinger/FD_RZSM_deep_learning_data')
baseline_anomaly_pattern = os.path.join(
    baseline_data_root, 'Data/GEFSv12_reforecast/soilw_bgrnd/baseline_RZSM_anomaly/RZSM*.nc')
baseline_anomaly_file_list = sorted(glob(baseline_anomaly_pattern))
# Fail with the offending pattern instead of a generic "no files to open" from open_mfdataset.
if not baseline_anomaly_file_list:
    raise FileNotFoundError(f'No baseline RZSM anomaly files match {baseline_anomaly_pattern!r}')

baseline_anomaly = xr.open_mfdataset(baseline_anomaly_file_list).sel(L=[day_num]).astype(np.float32).load()



In [None]:
# Reforecast-shaped GLEAM template for the region, at the selected lead.
# Predictions are later written into a deep copy of it.
template = (
    xr.open_mfdataset(f'Data/GLEAM/reformat_to_reforecast_shape/{region_name}/soilw_bgrnd/*.nc4')
    .sel(L=[day_num])
    .astype(np.float32)
    .load()
)

# Restrict to the 2018-2019 init dates (the testing period).
template_testing_only = template.sel(S=slice('2018-01-01', '2019-12-31'))

In [None]:
def get_min_max_of_files(obs, unet, baseline, date):
    """Return the overall (min, max) of the ``rci`` variable across three datasets at one init date.

    Parameters
    ----------
    obs, unet, baseline :
        Datasets (presumably xarray) with an ``S`` coordinate and an ``rci`` variable.
    date :
        Init-date label passed to ``.sel(S=date)``.

    Returns
    -------
    tuple
        ``(smallest minimum, largest maximum)`` over the three datasets.
    """
    snapshots = [ds.sel(S=date) for ds in (obs, unet, baseline)]
    lowest = min(snap.min().rci.values for snap in snapshots)
    highest = max(snap.max().rci.values for snap in snapshots)
    return lowest, highest

def return_array(file, lead, date):
    """Return the raw ``rci`` values of ``file`` at lead ``lead`` and init date ``date``."""
    selected = file.sel(L=lead, S=date)
    return selected.rci.values


def get_min_max_of_files_anomaly(obs, unet, baseline, init_dates_all):
    """Return the overall (min, max) of ``RZSM`` across three datasets and many init dates.

    Used to build one colour scale shared by every panel of a figure.

    Parameters
    ----------
    obs, unet, baseline :
        Datasets (presumably xarray) with an ``S`` coordinate and an ``RZSM`` variable.
    init_dates_all : iterable
        Init-date labels; iterated exactly once.

    Returns
    -------
    tuple
        ``(smallest minimum, largest maximum)`` over all datasets and dates.

    Raises
    ------
    ValueError
        If ``init_dates_all`` is empty.
    """
    # Date-major, then obs/unet/baseline: same element order as the per-date appends.
    snapshots = [ds.sel(S=date) for date in init_dates_all for ds in (obs, unet, baseline)]
    return (
        min(snap.min().RZSM.values for snap in snapshots),
        max(snap.max().RZSM.values for snap in snapshots),
    )

def return_array_anomaly(file, date):
    """Return the raw ``RZSM`` values of ``file`` at init date ``date``."""
    snapshot = file.sel(S=date)
    return snapshot.RZSM.values

def create_reforecast_with_predictions_single_lead(day_num, week_lead, experiment_name):
    """Load one experiment's saved testing predictions into a reforecast-shaped Dataset.

    Parameters
    ----------
    day_num :
        Lead (``L`` label) selected from the global ``template_testing_only``.
    week_lead :
        Week lead used to locate the predictions file under ``predictions/``.
    experiment_name : str
        Experiment suffix of the ``.npy`` predictions file.

    Returns
    -------
    Dataset
        A deep copy of the template at ``L=day_num`` whose ``RZSM`` holds the
        unscaled predictions (NaN wherever ``mask_anom != 1``), restricted to
        init dates between the globals ``start_`` and ``end_``.
    """
    # Deep copy before writing so the shared global template is never modified.
    reforecast = template_testing_only.copy(deep=True).sel(L=day_num)

    predictions_path = f'predictions/Wk_{week_lead}_testing/Wk{week_lead}_testing_{experiment_name}.npy'
    # NOTE(review): index 2 on the first axis and 0 on the last follow the saved
    # array layout; their meaning is not visible in this notebook.
    scaled = np.load(predictions_path)[2, :, :, :, 0]
    # Undo the min-max scaling using the reforecast RZSM bounds.
    unscaled = f.reverse_min_max_scaling(scaled, max_RZSM_reforecast, min_RZSM_reforecast)

    # Split the flat sample axis into groups of 11 (presumably the M ensemble members -- confirm).
    n_samples = unscaled.shape[0]
    n_rows = unscaled.shape[1]
    n_cols = unscaled.shape[2]
    unscaled = np.reshape(unscaled, (n_samples // 11, 11, n_rows, n_cols))

    # Keep only the land cells of the mask.
    masked = np.where(mask_anom == 1, unscaled, np.nan)

    reforecast.RZSM[:, :, :, :] = masked

    # Only the time window is restricted here; no spatial (Southeast) subset is applied.
    return reforecast.sel(S=slice(start_, end_))

In [None]:
def make_plot_by_experiment_name(exp):
    """Plot GLEAM, UNET and baseline RZSM anomalies side by side for one experiment.

    One figure row is drawn per observed init date in the case-study window
    (globals ``start_``/``end_``), with columns GLEAM | UNET | Baseline at lead
    ``day_num``. UNET and Baseline panels are annotated with their Pearson
    correlation against GLEAM. The figure is saved to
    ``Outputs/Case_studies/test_other_models/{exp}_Wk_lead{week_lead}.png`` and shown.

    Parameters
    ----------
    exp : str
        Experiment name; selects the predictions file and names the output PNG.

    Raises
    ------
    ValueError
        If the observations contain no init dates in the case-study window.
    """
    unet_anomaly = create_reforecast_with_predictions_single_lead(day_num=day_num, week_lead=week_lead, experiment_name=exp)

    # Ensemble-mean each source over the case-study window; NaN out everything except mask_anom == 1.
    obs = xr.where(mask_anom == 1, obs_anomaly_SubX_format.sel(L=day_num).mean(dim='M').sel(S=slice(start_, end_)), np.nan)
    unet = xr.where(mask_anom == 1, unet_anomaly.mean(dim='M'), np.nan)
    baseline = xr.where(mask_anom == 1, baseline_anomaly.mean(dim='M').sel(L=day_num).sel(S=slice(start_, end_)), np.nan)

    # Relabel the model init dates with the observation dates so `.sel(S=...)` matches across sources.
    # NOTE(review): assumes all three sources hold the same init dates in the same order.
    dates_ = obs.S.values
    unet = unet.assign_coords({'S': dates_})
    baseline = baseline.assign_coords({'S': dates_})

    def plot_case_study_anomaly(obs, unet, baseline):
        """Draw, save and show the (init date x source) grid of anomaly maps."""
        text_x = -83.5  # correlation label position (lon, lat) in the axes' PlateCarree coords
        text_y = 27
        font_size_corr = 12

        cmap = plt.get_cmap('bwr')

        save_dir = 'Outputs/Case_studies/test_other_models'
        # Portable replacement for shelling out to `mkdir -p`.
        os.makedirs(save_dir, exist_ok=True)

        init_dates = obs.S.values
        if len(init_dates) == 0:
            raise ValueError(f'No observed init dates between {start_} and {end_} for {exp}')

        fig, axs = plt.subplots(
            nrows=len(init_dates), ncols=3, subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(10, 40))
        axs = axs.flatten()

        # One colour scale shared by every panel, centred on zero.
        # NOTE(review): TwoSlopeNorm raises unless min_ < 0 < max_.
        min_, max_ = get_min_max_of_files_anomaly(obs, unet, baseline, init_dates)
        levels = np.linspace(min_, max_, 20, endpoint=True)
        norm = TwoSlopeNorm(vmin=min_, vcenter=0, vmax=max_)

        # The grid is identical for every panel, so build the Basemap (and its
        # coastline data) and project the mesh once instead of once per panel.
        basemap = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=50,
                          llcrnrlon=-128, urcrnrlon=-60, resolution='l')
        x, y = basemap(*np.meshgrid(obs.X.values, obs.Y.values))

        lead = day_num
        axs_start = 0
        for init_date in init_dates:
            init_date = pd.to_datetime(init_date)
            date = init_date.strftime('%Y-%m-%d')
            # Reference field for the correlation annotations in this row.
            obs_field = return_array_anomaly(file=obs, date=init_date)
            for data_to_plot, name in zip([obs, unet, baseline], ['GLEAM', 'UNET', 'Baseline']):
                ax = axs[axs_start]
                data = return_array_anomaly(file=data_to_plot, date=init_date)

                im = ax.contourf(x, y, data, levels=levels, extend='both',
                                 transform=ccrs.PlateCarree(), cmap=cmap, norm=norm)

                gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                  linewidth=0.7, color='gray', alpha=0.5, linestyle='--')
                # cartopy >= 0.18 attribute names; the legacy xlabels_top/ylabels_* are deprecated.
                gl.top_labels = False
                gl.right_labels = False
                if lead != 1:
                    gl.left_labels = False
                gl.xformatter = LongitudeFormatter()
                gl.yformatter = LatitudeFormatter()
                ax.coastlines()
                ax.set_aspect('equal')  # this makes the plots better
                ax.set_title(f'{name} Lead {lead} Init: {date} ', fontsize=15)

                if name in ('UNET', 'Baseline'):
                    correlation_coefficient = round(_pearson_where_both_valid(obs_field, data), 4)
                    ax.text(text_x, text_y, f'Corr: {correlation_coefficient}', ha='right', va='bottom',
                            fontsize=font_size_corr, color='blue', weight='bold')

                axs_start += 1

        cbar_ax = fig.add_axes([0.05, -0.05, .9, .04])
        fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
        # The figure spans several init dates, so title it with the full span
        # (each panel title already carries its own date).
        first_init = pd.to_datetime(init_dates[0]).strftime('%Y-%m-%d')
        last_init = pd.to_datetime(init_dates[-1]).strftime('%Y-%m-%d')
        fig.suptitle(f'Init dates: {first_init} to {last_init}', fontsize=30)
        fig.tight_layout()

        fig.savefig(f'{save_dir}/{exp}_Wk_lead{week_lead}.png', bbox_inches='tight')
        plt.show()
        # Release the figure; this function is called once per experiment.
        plt.close(fig)

    plot_case_study_anomaly(obs=obs, unet=unet, baseline=baseline)


def _pearson_where_both_valid(reference, candidate):
    """Return the Pearson correlation of two fields over points that are non-NaN in both.

    Parameters
    ----------
    reference, candidate : array-like
        Fields with the same number of elements; both are flattened.

    Returns
    -------
    float
        Pearson r, or NaN when fewer than two points are valid in both fields.
    """
    reference = np.asarray(reference, dtype=float).ravel()
    candidate = np.asarray(candidate, dtype=float).ravel()
    # Mask jointly: a NaN left in either field would make np.corrcoef return NaN.
    both_valid = ~np.isnan(reference) & ~np.isnan(candidate)
    if both_valid.sum() < 2:
        return np.nan
    return np.corrcoef(reference[both_valid], candidate[both_valid])[0, 1]

In [None]:

# Render one comparison figure per experiment, in this order.
for experiment in ['EX19_RZSM', 'EX11_denseLoss_RZSM', 'EX11_RZSM', 'EX18_denseLoss_RZSM', 'EX18_RZSM']:
    make_plot_by_experiment_name(exp=experiment)


# Loop over each experiment to find the best ones which represent only the Southeast