In [74]:
#!/usr/bin/env python
# coding: utf-8

# This script is used to compare ensemble outputs with NLDAS data
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import pandas as pd
import xarray as xr
import datetime

def read_ens(out_forc_name_base, metric, start_yr, end_yr):
    """Read yearly ensemble summary files and join them along axis 0.

    One NetCDF file is opened per year, named
    ``<out_forc_name_base>.<year>.<metric>.nc``. The variables ``time``,
    ``pcp``, ``t_mean``, ``t_min``, ``t_max`` and ``t_range`` are read from
    each file and the yearly pieces are concatenated once at the end.

    Parameters
    ----------
    out_forc_name_base : str
        Path prefix of the yearly files (everything before ``.<year>``).
    metric : str
        Metric tag that is part of the file name.
    start_yr, end_yr : int
        First and last year to read, both inclusive.

    Returns
    -------
    tuple
        ``(time, pcp, tmean, tmin, tmax, trange)``. ``time`` is a
        ``pandas.DatetimeIndex``; the other items are numpy arrays
        concatenated along axis 0, regardless of how many years are read.

    Raises
    ------
    ValueError
        If ``start_yr`` is greater than ``end_yr`` (nothing to read).
    """
    if start_yr > end_yr:
        raise ValueError('start_yr (%s) must not be greater than end_yr (%s)'
                         % (start_yr, end_yr))

    var_names = ('time', 'pcp', 't_mean', 't_min', 't_max', 't_range')
    yearly = {name: [] for name in var_names}

    for yr in range(start_yr, end_yr + 1):
        file = out_forc_name_base + '.' + str(yr) + '.' + metric + '.nc'
        # The context manager closes the file; .values pulls the data into
        # memory first so the arrays stay usable afterwards.
        with xr.open_dataset(file) as f:
            for name in var_names:
                yearly[name].append(f[name].values)

    # Concatenate once per variable instead of re-copying the growing array
    # on every year.
    (time_concat, pcp_concat, tmean_concat,
     tmin_concat, tmax_concat, trange_concat) = (
        np.concatenate(yearly[name], axis=0) for name in var_names)

    time_concat = pd.DatetimeIndex(time_concat)

    return time_concat, pcp_concat, tmean_concat, tmin_concat, tmax_concat, trange_concat

#======================================================================================================
# main script
# Input locations; every path below is derived from root_dir.
root_dir = '/glade/u/home/hongli/scratch/2020_04_21nldas_gmet'   
stn_ens_dir = os.path.join(root_dir,'data/stn_ens_summary')
nldas_dir = os.path.join(root_dir,'data/nldas_daily_utc_convert')
start_yr = 2015
end_yr = 2016

gridinfo_file = os.path.join(root_dir,'data/nldas_topo/conus_ens_grid_eighth.nc')

# Scenario folders. Only directories are kept so that stray files in
# result_dir cannot shift the index-based scenario selection.
result_dir = os.path.join(root_dir,'test_uniform_perturb')
test_folders = sorted(d for d in os.listdir(result_dir)
                      if os.path.isdir(os.path.join(result_dir, d)))
scenarios_ids = range(0,9) #[0,1,5,8] 
intervals =  range(10,1,-1) #[10,9,5,2]
scenario_num = len(scenarios_ids)

# Sub-folder and file-name prefix of the ensemble summary files inside each
# scenario folder.
subforlder = 'gmet_ens_summary'
file_basename = 'ens_forc'

ens_num = 100
time_format = '%Y-%m-%d'

# Plot settings: output resolution and the (inclusive) plotting period.
dpi_value = 600
plot_date_start = '2015-01-01'
plot_date_end = '2016-12-31'
plot_date_start_obj = datetime.datetime.strptime(plot_date_start, time_format)
plot_date_end_obj = datetime.datetime.strptime(plot_date_end, time_format)

# Output location; exist_ok avoids the check-then-create race.
output_dir=os.path.join(root_dir, 'scripts/step20_plot_nldas_IQR_DOY')
os.makedirs(output_dir, exist_ok=True)
output_filename = 'step20_plot_nldas_IQR_DOY_temp.png'
   
#======================================================================================================
# Data preparation. The plotting section needs the names created here
# (time_ensub, mask_ens_t, the *_ensiqr arrays and newcmp), so this code has
# to run for the script to work from a fresh kernel.
print('Read gridinfo mask')
# get xy mask from gridinfo.nc
with xr.open_dataset(gridinfo_file) as f_gridinfo:
    mask_xy = f_gridinfo['mask'].values[:] # (y, x). 1 is valid. 0 is invalid.

#======================================================================================================
# read the lower and upper ensemble bounds of one scenario
print('Read nldas ens bounds')
k=6-1  # position in scenarios_ids of the scenario to plot
test_folder = test_folders[scenarios_ids[k]]

print(test_folder)
test_dir = os.path.join(result_dir, test_folder)
fig_title= test_folder

print(' -- read spatial ensemble')
output_namebase = os.path.join(test_dir,subforlder, file_basename)

# lower bound ('enspctl.5' files)
metric = 'enspctl.5'
time_enslb, pcp_enslb, tmean_enslb, tmin_enslb, tmax_enslb, trange_enslb = read_ens(output_namebase, metric, start_yr, end_yr)

# upper bound ('enspctl.95' files)
metric = 'enspctl.95'
time_ensub, pcp_ensub, tmean_ensub, tmin_ensub, tmax_ensub, trange_ensub = read_ens(output_namebase, metric, start_yr, end_yr)

# Both bounds are subtracted element-wise below, so their time axes must agree.
if not time_enslb.equals(time_ensub):
    raise ValueError('lower and upper bound files do not share the same time axis')

# define plot mask for nldas ensemble
mask_ens_t = (time_enslb>=plot_date_start_obj) & (time_enslb<=plot_date_end_obj)

print(' -- calculate IQR')
# "IQR" here = upper limit - lower limit
# NOTE(review): the bounds come from the 'enspctl.5'/'enspctl.95' files, so this
# is presumably a 5-95 percentile range rather than a true interquartile range.
pcp_ensiqr = pcp_ensub[mask_ens_t,:,:]-pcp_enslb[mask_ens_t,:,:]   
tmean_ensiqr = tmean_ensub[mask_ens_t,:,:]-tmean_enslb[mask_ens_t,:,:]
tmin_ensiqr = tmin_ensub[mask_ens_t,:,:]-tmin_enslb[mask_ens_t,:,:]
tmax_ensiqr = tmax_ensub[mask_ens_t,:,:]-tmax_enslb[mask_ens_t,:,:]
trange_ensiqr = trange_ensub[mask_ens_t,:,:]-trange_enslb[mask_ens_t,:,:]
(nt,ny,nx) = np.shape(pcp_ensiqr)

# free the bound arrays, only the differences are needed from here on
del pcp_enslb, tmean_enslb, tmin_enslb, tmax_enslb, trange_enslb  
del pcp_ensub, tmean_ensub, tmin_ensub, tmax_ensub, trange_ensub 

print(' -- extract unmasked values')
# Keep only valid grid cells: (nt,ny,nx) -> (nt,n_valid). Indexing the two
# spatial axes with the 2-D mask gives the same values and order as masking
# with a repeated 3-D mask followed by reshape((nt,-1)), without the 3-D copy.
valid_xy = mask_xy != 0
pcp_ensiqr = pcp_ensiqr[:, valid_xy]
tmean_ensiqr = tmean_ensiqr[:, valid_xy]
tmin_ensiqr = tmin_ensiqr[:, valid_xy]
tmax_ensiqr = tmax_ensiqr[:, valid_xy]
trange_ensiqr = trange_ensiqr[:, valid_xy]

#======================================================================================================    
# create a white-blue-jet colormap
print('create colormap')

# reference: https://stackoverflow.com/questions/25408393/getting-individual-colors-from-a-color-map-in-matplotlib
# plt.get_cmap is used because mpl.cm.get_cmap is no longer available in recent matplotlib
bottom = plt.get_cmap('jet')
c0 = bottom(0.0) # the first (blue) color of jet
top = mpl.colors.LinearSegmentedColormap.from_list("", ["white",c0])

# combine the two linear colormaps: the first 10% fades from white to blue,
# the remaining 90% is jet
# reference: https://matplotlib.org/3.1.0/tutorials/colors/colormap-manipulation.html
newcolors = np.vstack((top(np.linspace(0, 1, int(256*0.1))),bottom(np.linspace(0, 1, int(256*0.9)))))
newcmp = mpl.colors.LinearSegmentedColormap.from_list("WhiteJet", newcolors)

##======================================================================================================    
# plot
# Draws, for each temperature variable, a 2-D histogram image: x = day of year,
# y = IQR bin, color = number of grid cells whose day-of-year mean IQR falls
# in that bin.
# Requires names that are not created in this section: newcmp (colormap),
# time_ensub and mask_ens_t (time axis and its plot-period mask) and the
# *_ensiqr arrays (rows = time steps, columns = grid cells).
print('Plot')
var_list = ['Tmean', 'Tmin', 'Tmax', 'Trange']
# raw strings: '\c' is not a valid escape sequence in a normal string literal
var_units = [r'($^\circ$C)', r'($^\circ$C)', r'($^\circ$C)', r'($^\circ$C)']
iqr_list = [tmean_ensiqr, tmin_ensiqr, tmax_ensiqr, trange_ensiqr]
# TODO: precipitation (pcp_ensiqr) is not plotted; the original note was to
# treat rain / no-rain separately so that zeros do not skew the histogram.

# one subplot row per variable
nrow = len(var_list)
ncol = 1           
# squeeze=False keeps ax indexable by row even when only one variable is plotted
fig, ax = plt.subplots(nrow, ncol, figsize=(5.5,5.5*0.75), squeeze=False)
ax = ax[:, 0]

bins = 100

# Month/day keys for the day-of-year grouping. The data rows were selected
# with mask_ens_t, so the time axis must be masked the same way.
plot_time = time_ensub[mask_ens_t]
time_month = np.asarray(plot_time.month)
time_day = np.asarray(plot_time.day)

for i, data in enumerate(iqr_list):
    
    print(var_list[i])

    if data.shape[0] != len(plot_time):
        raise ValueError('%s has %d time steps but the time axis has %d'
                         % (var_list[i], data.shape[0], len(plot_time)))

    # histogram range; both ends ignore NaN
    vmin = np.nanmin(data)
    vmax = np.nanpercentile(data,95)
    
    # calculate DOY (day of year) mean IQR: one row per (month, day)
    df2 = pd.DataFrame(data).groupby([time_month, time_day]).mean()
    
    # calculate grid count per IQR bin, one column per day of year
    nt = len(df2)
    freq_arr = np.zeros((bins,nt))
    for d in range(nt):
        hist, bin_edges = np.histogram(df2.iloc[d,:], bins=bins, range=(vmin,vmax))
        freq_arr[:,d] = hist    

    # The extent puts column d at x = d+1 (day of year, 1-based) and the bin
    # edges at y = 0..bins, so ticks can be placed exactly on bin edges.
    im = ax[i].imshow(freq_arr, cmap=newcmp, origin='lower',
                      extent=(0.5, nt+0.5, 0, bins))#, aspect = 'auto')

    # colorbar (takes its colormap from the image)
    cbar = fig.colorbar(im,ax=ax[i],orientation='vertical',pad=0.01, shrink=1)
    cbar.ax.tick_params(labelsize='xx-small', length=2, width=1)
    cbar.set_label(label='Number of grids',size='xx-small')    

    # y ticks: five evenly spaced bin edges labelled with their IQR value
    ax[i].set_yticks(np.linspace(0, bins, 5))
    ylabels = [round(float(yy),1) for yy in np.linspace(vmin, vmax, 5)]
    ax[i].set_yticklabels(ylabels)

    # label
    if i == nrow-1:
        xlabel = 'Day of Year (DOY) '
        ax[i].set_xlabel(xlabel, fontsize='xx-small')
    ylabel = 'IQR '+var_units[i]
    ax[i].set_ylabel(ylabel, fontsize='xx-small')
    
    # title, prefixed with (a), (b), ...
    alpha = chr(ord('a') + i)
    ax[i].set_title('('+alpha+') '+var_list[i], fontsize='xx-small', fontweight='semibold')

    # tick
    ax[i].tick_params(axis='both', direction='out',labelsize = 'xx-small', 
                        length=2, width=0.5, pad=1.5)

    # change subplot border width
    for axis in ['top','bottom','left','right']:
        ax[i].spines[axis].set_linewidth(0.5)

# save plot
fig.tight_layout(pad=0.1, h_pad=0.5) 
fig.savefig(os.path.join(output_dir, output_filename), dpi=dpi_value,
            bbox_inches = 'tight', pad_inches = 0.05)
plt.close(fig)
print('Done')


Plot
Tmean
Tmin
Tmax
Trange
Done


In [33]:
# Scratch: inspect the histogram bin edges (bin_edges is assigned by
# np.histogram in the plotting cell).
bin_edges

array([ 0.       ,  0.4132919,  0.8265838,  1.2398757,  1.6531676,
        2.0664594,  2.4797513,  2.8930433,  3.3063352,  3.7196271,
        4.132919 ,  4.546211 ,  4.9595027,  5.3727946,  5.7860866,
        6.1993785,  6.6126704,  7.0259624,  7.4392543,  7.8525457,
        8.265838 ,  8.67913  ,  9.092422 ,  9.505713 ,  9.919005 ,
       10.332297 , 10.745589 , 11.158881 , 11.572173 , 11.985465 ,
       12.398757 , 12.812049 , 13.225341 , 13.638633 , 14.051925 ,
       14.465217 , 14.878509 , 15.2918005, 15.705091 , 16.118383 ,
       16.531675 , 16.944967 , 17.35826  , 17.771551 , 18.184843 ,
       18.598135 , 19.011427 , 19.424719 , 19.83801  , 20.251303 ,
       20.664595 , 21.077887 , 21.491179 , 21.90447  , 22.317762 ,
       22.731054 , 23.144346 , 23.557638 , 23.97093  , 24.384222 ,
       24.797514 , 25.210806 , 25.624098 , 26.03739  , 26.450682 ,
       26.863974 , 27.277266 , 27.690557 , 28.10385  , 28.517141 ,
       28.930433 , 29.343725 , 29.757017 , 30.17031  , 30.5836

In [58]:
# Scratch: look at the automatic x/y tick locations of subplot i.
locsx, locsy = ax[i].get_xticks(), ax[i].get_yticks()
(locsx, locsy)

(array([-100.,    0.,  100.,  200.,  300.,  400.]),
 array([-50.,   0.,  50., 100.]))

In [60]:
# Scratch: last automatic x tick location of subplot i.
ax[i].get_xticks()[-1]

400.0

In [61]:
# Scratch: place five evenly spaced y ticks and label them with the IQR
# values between the first and last bin edge. The labels belong on the y axis
# (the ticks set on the line above), matching the plotting cell.
ax[i].set_yticks(np.linspace(0, ax[i].get_yticks()[-1], 5))
ax[i].set_yticklabels(np.linspace(bin_edges[0],bin_edges[-1],5))

(0.0, 3.601881980895996)

In [62]:
# Scratch: upper histogram bound of the last plotted variable (vmax is the
# 95th percentile computed in the plotting cell).
vmax

3.601881980895996