In [9]:
#!/usr/bin/env python
# coding: utf-8

# This script is used to compare ensemble outputs with NLDAS data

from mpl_toolkits.basemap import Basemap
from pyproj import Proj
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import pandas as pd
import xarray as xr
import datetime

startTime = datetime.datetime.now()

def read_ens(out_forc_name_base, metric, start_yr, end_yr):
    """Read the yearly ensemble summary files and stack them along time.

    One NetCDF file is read per year; its name is built as
    ``<out_forc_name_base>.<year>.<metric>.nc``.

    Parameters
    ----------
    out_forc_name_base : str
        Path prefix of the yearly files (directory plus file base name).
    metric : str
        Summary tag embedded in the file name (this script passes
        'ensmean' and 'ensstd').
    start_yr, end_yr : int
        First and last year to read, both inclusive.

    Returns
    -------
    tuple
        ``(time, pcp, t_mean, t_min, t_max, t_range)`` where ``time`` is a
        ``pandas.DatetimeIndex`` and the other five items are numpy arrays
        concatenated along axis 0 in year order.

    Raises
    ------
    ValueError
        If ``start_yr`` is greater than ``end_yr`` (nothing to read).
    """
    if start_yr > end_yr:
        raise ValueError('start_yr (%s) must not exceed end_yr (%s)' % (start_yr, end_yr))

    var_names = ('pcp', 't_mean', 't_min', 't_max', 't_range')
    time_chunks = []
    var_chunks = {name: [] for name in var_names}

    for yr in range(start_yr, end_yr+1):
        file = out_forc_name_base + '.' + str(yr) + '.' + metric + '.nc'
        # The context manager closes the file handle once the year is read.
        # `.values` pulls the data into memory first, so the arrays remain
        # valid after the dataset is closed.
        with xr.open_dataset(file) as f:
            time_chunks.append(f['time'].values)
            for name in var_names:
                var_chunks[name].append(f[name].values)

    # Concatenate once per variable instead of re-copying the growing array
    # on every year.
    time_concat = pd.DatetimeIndex(np.concatenate(time_chunks, axis=0)) # (time)
    pcp_concat, tmean_concat, tmin_concat, tmax_concat, trange_concat = (
        np.concatenate(var_chunks[name], axis=0) for name in var_names)

    return time_concat, pcp_concat, tmean_concat, tmin_concat, tmax_concat, trange_concat

def plot_basemap(llcrnrlon,llcrnrlat,urcrnrlon,urcrnrlat,ax,lat_0,lon_0,ny,nx):
    """Create a cylindrical-projection Basemap on `ax` and draw boundaries.

    The map extent is given by the lower-left (`llcrnrlon`, `llcrnrlat`) and
    upper-right (`urcrnrlon`, `urcrnrlat`) corners. State, country and
    coastline boundaries are drawn, in that order.

    `lat_0`, `lon_0`, `ny` and `nx` are accepted for call compatibility but
    are not used by this function.

    Returns the Basemap instance.
    """
    basemap = Basemap(llcrnrlon=llcrnrlon,
                      llcrnrlat=llcrnrlat,
                      urcrnrlon=urcrnrlon,
                      urcrnrlat=urcrnrlat,
                      resolution='l',
                      projection='cyl',
                      ax=ax)

    # (drawing method, line width, line colour) for each boundary layer
    boundary_layers = (
        (basemap.drawstates, 0.5, 'grey'),
        (basemap.drawcountries, 0.5, 'k'),
        (basemap.drawcoastlines, .25, 'k'),
    )
    for draw_layer, width, colour in boundary_layers:
        draw_layer(linewidth=width, linestyle='solid', color=colour)

    return basemap

# set the colormap and centre the colorbar
class MidpointNormalize(mpl.colors.Normalize):
    """Normalisation that pins a chosen data value to the centre of the colormap.

    Data are mapped piecewise-linearly: vmin -> 0, midpoint -> 0.5, vmax -> 1.
    NaN inputs are masked in the result.

    source: http://chris35wills.github.io/matplotlib_diverging_colorbar/
    e.g. im=ax1.imshow(array, norm=MidpointNormalize(midpoint=0.,vmin=-300, vmax=1000))
    """
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # `clip` is accepted for interface compatibility but is not used here.
        data_anchors = [self.vmin, self.midpoint, self.vmax]
        unit_anchors = [0, 0.5, 1]
        scaled = np.interp(value, data_anchors, unit_anchors)
        return np.ma.masked_array(scaled, np.isnan(value))

#======================================================================================================
# main script
# ---- paths -------------------------------------------------------------------------------------------
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
root_dir = '/glade/u/home/hongli/scratch/2020_04_21nldas_gmet'   
nldas_dir = os.path.join(root_dir,'data/nldas_daily_utc_convert')
gridinfo_file = os.path.join(root_dir,'data/nldas_topo/conus_ens_grid_eighth.nc')

# ---- years of yearly files to read (inclusive) ------------------------------------------------------
start_yr = 2015
end_yr = 2016

# ---- ensemble test folders --------------------------------------------------------------------------
result_dir = os.path.join(root_dir,'test_uniform_perturb')
# Keep only directories: a stray regular file in result_dir would otherwise be
# treated as a test folder and fail when its ensemble files are read.
test_folders = sorted(d for d in os.listdir(result_dir)
                      if os.path.isdir(os.path.join(result_dir, d)))
subforlder = 'gmet_ens_summary'  # (sic) spelling kept: the plotting loop refers to this name
file_basename = 'ens_forc'

ens_num = 100
time_format = '%Y-%m-%d'

# ---- plot settings ----------------------------------------------------------------------------------
dpi_value = 150
plot_date_start = '2015-01-01'
plot_date_end = '2016-12-31'
plot_date_start_obj = datetime.datetime.strptime(plot_date_start, time_format)
plot_date_end_obj = datetime.datetime.strptime(plot_date_end, time_format)

# ---- output -----------------------------------------------------------------------------------------
output_dir=os.path.join(root_dir, 'scripts/step10_plot_spatial_NLDAS_ens')
# exist_ok avoids the check-then-create race of "if not exists: makedirs"
os.makedirs(output_dir, exist_ok=True)
    
#======================================================================================================
# Load the grid mask/coordinates and the NLDAS time-mean fields.
# The plotting loop further down uses mask_xy, latitude, longitude and the five
# NLDAS *_mean maps defined here, so this section has to run on a fresh kernel
# (with it commented out the notebook only works with leftover kernel state).
print('Read gridinfo mask')
# get xy mask from gridinfo.nc
with xr.open_dataset(gridinfo_file) as f_gridinfo:
    mask_xy = f_gridinfo['mask'].values[:] # (y, x). 1 is valid. 0 is invalid.
    latitude = f_gridinfo['latitude'].values[:]
    longitude = f_gridinfo['longitude'].values[:]

#======================================================================================================
# read historical nldas data
print('Read nldas data')
nldas_var_names = ('pcp', 't_mean', 't_min', 't_max', 't_range') # pcp: mm/day, temperatures: degC
nldas_yearly = {name: [] for name in nldas_var_names}
nldas_time_yearly = []
for yr in range(start_yr, end_yr+1):

    nldas_file = 'NLDAS_'+str(yr)+'.nc'
    nldas_path = os.path.join(nldas_dir, nldas_file)

    # `.values` loads the data into memory, so the arrays stay usable after
    # the dataset is closed at the end of the with-block.
    with xr.open_dataset(nldas_path) as f_nldas:
        for name in nldas_var_names:
            nldas_yearly[name].append(f_nldas[name].values) # (time, y, x)
        nldas_time_yearly.append(f_nldas['time'].values)

# get time mask from nldas data
time_obj = pd.to_datetime(np.concatenate(nldas_time_yearly, axis=0))
mask_t  = (time_obj >= plot_date_start_obj) & (time_obj <= plot_date_end_obj) 
time = time_obj[mask_t]

# time series mean over the plot period, with masked grid cells set to nan
nldas_mean = {}
for name in nldas_var_names:
    # pop() releases the yearly pieces of each variable as soon as its mean is computed
    nldas_field = np.concatenate(nldas_yearly.pop(name), axis=0) # (time, y, x)
    nldas_mean[name] = np.where(mask_xy==0, np.nan, np.nanmean(nldas_field[mask_t,:,:], axis=0)) # (y, x)
    del nldas_field

prcp_mean = nldas_mean['pcp']
tmean_mean = nldas_mean['t_mean']
tmin_mean = nldas_mean['t_min']
tmax_mean = nldas_mean['t_max']
trange_mean = nldas_mean['t_range']

time_nldas = datetime.datetime.now()
print('read NLDAS time:',time_nldas - startTime)

#======================================================================================================
def _period_mean_map(field, mask_t, mask_xy):
    """Average `field` over the time steps selected by `mask_t`, NaN-aware.

    `field` is indexed as field[mask_t,:,:] and averaged along axis 0; cells
    where `mask_xy` equals 0 are set to NaN in the returned 2-D map.
    """
    period_mean = np.nanmean(field[mask_t,:,:], axis=0)
    return np.where(mask_xy==0, np.nan, period_mean)

print('Plot')

# Fixed colorbar ranges, identical for every test folder (presumably so that the
# figures of different tests can be compared side by side).
vmin_prcp_mean=0
vmax_prcp_mean=20
vmin_prcp_std=0
vmax_prcp_std=9

vmin_t_mean_mean=-12
vmax_t_mean_mean=35
vmin_t_mean_std=0
vmax_t_mean_std=6

vmin_t_range_mean=0
vmax_t_range_mean=35
vmin_t_range_std=0
vmax_t_range_std=3.5

# One entry per figure row:
# (variable label used in titles, colormap, (vmin, vmax) of the NLDAS and ens-mean panels,
#  (vmin, vmax) of the ens-std panel, colorbar unit label).
# The t_min and t_max rows reuse the t_mean ranges.
# The degree label is a raw string: '\c' is not a valid escape sequence in a normal string.
panel_rows = [
    ('Daily Precip', plt.cm.Blues,  (vmin_prcp_mean, vmax_prcp_mean),       (vmin_prcp_std, vmax_prcp_std),       '(mm/day)'),
    ('Mean Temp',    plt.cm.Reds,   (vmin_t_mean_mean, vmax_t_mean_mean),   (vmin_t_mean_std, vmax_t_mean_std),   r'($^\circ$C)'),
    ('Min Temp',     plt.cm.Reds,   (vmin_t_mean_mean, vmax_t_mean_mean),   (vmin_t_mean_std, vmax_t_mean_std),   r'($^\circ$C)'),
    ('Max Temp',     plt.cm.Reds,   (vmin_t_mean_mean, vmax_t_mean_mean),   (vmin_t_mean_std, vmax_t_mean_std),   r'($^\circ$C)'),
    ('Temp Range',   plt.cm.Greens, (vmin_t_range_mean, vmax_t_range_mean), (vmin_t_range_std, vmax_t_range_std), r'($^\circ$C)'),
]
# Title pattern of each figure column: NLDAS, ens mean, ens std
column_titles = ('NLDAS %s', 'Ens Mean of %s', 'Ens Std of %s')
nrow = len(panel_rows) # prcp, tmean, tmin, tmax, trange
ncol = len(column_titles) # NLDAS, ens mean, ens std

# NLDAS time-mean maps in the same order as panel_rows
nldas_maps = (prcp_mean, tmean_mean, tmin_mean, tmax_mean, trange_mean)

# Map extent and grid shape do not depend on the test folder
llcrnrlon = longitude[0,0]
urcrnrlon = longitude[-1,-1]
llcrnrlat = latitude[0,0]
urcrnrlat = latitude[-1,-1]
lat_0=0.5*(llcrnrlat+urcrnrlat)
lon_0=0.5*(llcrnrlon+urcrnrlon)
(ny,nx)=np.shape(longitude)

# loop through all uniform tests
for test_folder in test_folders:

    print(test_folder)
    test_dir = os.path.join(result_dir, test_folder)
    fig_title= test_folder

    # read ensemble mean and ensemble std; the fields come back in the order
    # pcp, t_mean, t_min, t_max, t_range, matching panel_rows
    output_namebase = os.path.join(test_dir,subforlder, file_basename)
    time_ensmean, *ensmean_fields = read_ens(output_namebase, 'ensmean', start_yr, end_yr)
    time_ensstd, *ensstd_fields = read_ens(output_namebase, 'ensstd', start_yr, end_yr)

    # define plot mask for nldas ensemble (the ensmean time axis is used for both metrics)
    mask_ens_t = (time_ensmean>=plot_date_start_obj) & (time_ensmean<=plot_date_end_obj)

    # calculate time series mean (ny,nx) and convert masked values to nan
    ensmean_maps = [_period_mean_map(field, mask_ens_t, mask_xy) for field in ensmean_fields]
    ensstd_maps = [_period_mean_map(field, mask_ens_t, mask_xy) for field in ensstd_fields]

    # plot
    fig, ax = plt.subplots(nrow, ncol, figsize=(8,8*0.85), constrained_layout=True)
    fig.suptitle(fig_title, fontsize='x-small', fontweight='semibold', color='g')

    for i, (var_label, cmap, mean_range, std_range, unit_label) in enumerate(panel_rows):
        # data and color range of the three panels in this row
        row_panels = (
            (nldas_maps[i], mean_range),
            (ensmean_maps[i], mean_range),
            (ensstd_maps[i], std_range),
        )
        for j, (data, (vmin, vmax)) in enumerate(row_panels):

            # plot Basemap
            m = plot_basemap(llcrnrlon,llcrnrlat,urcrnrlon,urcrnrlat,ax[i,j],lat_0,lon_0,ny,nx)

            # plot data
            im1 = m.pcolormesh(longitude,latitude,data,shading='flat',latlon=True,cmap=cmap,vmin=vmin,vmax=vmax)

            # set title: panels are lettered (a), (b), ... row by row
            panel_letter = chr(ord('a') + i*ncol + j)
            title_str = '(%s) %s' % (panel_letter, column_titles[j] % var_label)
            ax[i,j].set_title(title_str, fontsize='xx-small', fontweight='semibold')

            # set colorbar
            cbar = m.colorbar(im1, location='right')
            cbar.set_label(label=unit_label, size='xx-small', rotation='horizontal', labelpad=-20, y=1.1)
            cbar.ax.tick_params(labelsize='xx-small') 

    # save plot
    output_filename = test_folder+'.png'
    fig.savefig(os.path.join(output_dir, output_filename), dpi=dpi_value)
    plt.close(fig) # release the figure before the next test folder

    # free the per-test arrays before reading the next test folder
    del time_ensmean, ensmean_fields, ensmean_maps
    del time_ensstd, ensstd_fields, ensstd_maps

print('Done')
print('Total time:', datetime.datetime.now() - startTime)


Plot
00810grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


00974grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


01225grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


01610grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


02251grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


03186grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


04951grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


08884grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


18074grids


The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.


Done
Total time: 0:06:03.658332


In [11]:
# NOTE(review): vmax_t_min_std and vmax_t_max_std are not assigned anywhere in the
# visible notebook code, so this cell presumably relies on leftover kernel state and
# would raise NameError after a kernel restart -- confirm or remove.
vmax_t_mean_std,vmax_t_min_std,vmax_t_max_std,vmax_t_range_std

(4.8909454, 4.8909454, 4.951986169708093, 2.0169)

In [13]:
# NOTE(review): the recorded output of this cell, (0, 30), differs from the values
# assigned in the plotting cell (0 and 35); the output looks stale -- re-run to confirm.
vmin_t_range_mean,vmax_t_range_mean

(0, 30)

In [14]:
# NOTE(review): the recorded output of this cell, (0, 3.2), differs from the values
# assigned in the plotting cell (0 and 3.5); the output looks stale -- re-run to confirm.
vmin_t_range_std,vmax_t_range_std

(0, 3.2)