In [2]:
#!/usr/bin/env python
# coding: utf-8

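# This script compares daily NLDAS fields with an ensemble output (e.g., NLDAS-based GMET):
# for precipitation, mean temperature and temperature range it maps the NLDAS time mean,
# the ensemble mean and the ensemble standard deviation side by side.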

from mpl_toolkits.basemap import Basemap
from pyproj import Proj
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import pandas as pd
import xarray as xr
import datetime

startTime = datetime.datetime.now()

def read_ens(out_forc_name_base, start_yr, end_yr, ens_num):
    """Read yearly ensemble forcing files and stack them into 4-D arrays.

    One NetCDF file is opened per (year, member). Its path is
    ``<out_forc_name_base>.<year>.<member>.nc`` where ``member`` is a
    zero-padded three-digit number starting at 001.

    Parameters
    ----------
    out_forc_name_base : str
        Path prefix shared by all ensemble files.
    start_yr, end_yr : int
        First and last year to read (both inclusive).
    ens_num : int
        Number of ensemble members per year.

    Returns
    -------
    lats_ens, lons_ens : numpy.ndarray
        The 'latitude' and 'longitude' variables, shape (y, x), taken from
        the first member of the last year read.
    time_ens : pandas.DatetimeIndex
        Time stamps of all years, floored to the day.
    pcp_ens, tmean_ens, trange_ens : numpy.ndarray
        The 'pcp', 't_mean' and 't_range' variables stacked as
        (time, y, x, member), with the years concatenated along time.

    Raises
    ------
    ValueError
        If ``ens_num`` is smaller than 1 or ``end_yr`` precedes ``start_yr``.
    """
    if ens_num < 1:
        raise ValueError('ens_num must be at least 1, got %r' % (ens_num,))
    if end_yr < start_yr:
        raise ValueError('end_yr (%r) must not precede start_yr (%r)' % (end_yr, start_yr))

    lats_ens = lons_ens = None
    time_parts = []
    pcp_parts, tmean_parts, trange_parts = [], [], []

    for yr in range(start_yr, end_yr + 1):
        for i in range(ens_num):
            ens_file = '{}.{}.{:03d}.nc'.format(out_forc_name_base, yr, i + 1)

            # The context manager closes the file as soon as the arrays are in memory.
            with xr.open_dataset(ens_file) as f:
                pcp = f['pcp'].values
                tmean = f['t_mean'].values
                trange = f['t_range'].values

                if i == 0:
                    lats_ens = f['latitude'].values  # shape (y,x)
                    lons_ens = f['longitude'].values
                    # assumes every member of a year shares this time axis -- TODO confirm
                    yr_time = pd.DatetimeIndex(f['time'].values).floor('D')

            if i == 0:
                # per-year containers, one slot per member along the last axis
                mb_shape = (np.shape(pcp)[0], np.shape(pcp)[1], np.shape(pcp)[2], ens_num)
                pcp_ens_mb = np.zeros(mb_shape)
                tmean_ens_mb = np.zeros(mb_shape)
                trange_ens_mb = np.zeros(mb_shape)

            pcp_ens_mb[:, :, :, i] = pcp
            tmean_ens_mb[:, :, :, i] = tmean
            trange_ens_mb[:, :, :, i] = trange

        time_parts.append(yr_time)
        pcp_parts.append(pcp_ens_mb)
        tmean_parts.append(tmean_ens_mb)
        trange_parts.append(trange_ens_mb)

    def _join(parts):
        # Skip the copy when there is only one year; the arrays can be very large.
        return parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)

    # Time is converted once, after the loop, so every year is appended to the
    # same kind of object (the per-iteration conversion broke the second year).
    time_ens = pd.DatetimeIndex(np.concatenate([t.values for t in time_parts]))
    pcp_ens = _join(pcp_parts)        # (time,y,x,ens_num)
    tmean_ens = _join(tmean_parts)
    trange_ens = _join(trange_parts)

    return lats_ens, lons_ens, time_ens, pcp_ens, tmean_ens, trange_ens

def plot_basemap(llcrnrlon,llcrnrlat,urcrnrlon,urcrnrlat,ax,lat_0,lon_0,ny,nx):
    """Build a cylindrical-projection Basemap on ``ax`` and draw its boundaries.

    ``llcrnrlon``/``llcrnrlat`` and ``urcrnrlon``/``urcrnrlat`` are the
    lower-left and upper-right corners of the map. ``lat_0``, ``lon_0``,
    ``ny`` and ``nx`` are accepted but not used by the 'cyl' projection.

    Returns the Basemap instance with state, country and coast lines drawn.
    """
    corners = dict(
        llcrnrlon=llcrnrlon,
        llcrnrlat=llcrnrlat,
        urcrnrlon=urcrnrlon,
        urcrnrlat=urcrnrlat,
    )
    basemap = Basemap(resolution='l', projection='cyl', ax=ax, **corners)

    # boundaries, from the finest political level out to the coastline
    basemap.drawstates(color='grey', linestyle='solid', linewidth=1)
    basemap.drawcountries(color='k', linestyle='solid', linewidth=1)
    basemap.drawcoastlines(color='k', linestyle='solid', linewidth=.75)
    return basemap

# set the colormap and centre the colorbar
class MidpointNormalize(mpl.colors.Normalize):
    """Normalisation that pins a chosen midpoint to the centre of the colormap.

    Values are mapped piecewise-linearly so that vmin -> 0, midpoint -> 0.5
    and vmax -> 1.
    source: http://chris35wills.github.io/matplotlib_diverging_colorbar/
    e.g. im=ax1.imshow(array, norm=MidpointNormalize(midpoint=0.,vmin=-300, vmax=1000))
    """
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        super().__init__(vmin, vmax, clip)
        self.midpoint = midpoint

    def __call__(self, value, clip=None):
        # ``clip`` is accepted for interface compatibility but is not used here.
        anchors = [self.vmin, self.midpoint, self.vmax]
        targets = [0, 0.5, 1]
        scaled = np.interp(value, anchors, targets)
        # NaN inputs are masked in the result
        return np.ma.masked_array(scaled, np.isnan(value))

#======================================================================================================
# main script
# --- input locations ---------------------------------------------------
root_dir = '/glade/u/home/hongli/scratch/2020_04_21nldas_gmet'
nldas_dir = os.path.join(root_dir,'data/nldas_daily_utc')
gridinfo_file = os.path.join(root_dir,'scripts/conus_ens_grid_eighth_deg_v1p1.nc')

# --- period and ensemble size ------------------------------------------
start_yr = 2015
end_yr = 2016
ens_num = 100

# --- test cases: one sub-directory of result_dir per test --------------
result_dir = os.path.join(root_dir,'test_uniform')
# Keep directories only: a stray file in result_dir is not a test case and
# would shift the positions of the real folders in the sorted list.
test_folders = sorted(
    d for d in os.listdir(result_dir)
    if os.path.isdir(os.path.join(result_dir, d))
)

# --- plot settings -----------------------------------------------------
time_format = '%Y-%m-%d'
dpi_value = 150
plot_date_start = '2015-01-01'
plot_date_end = '2016-12-31'
plot_date_start_obj = datetime.datetime.strptime(plot_date_start, time_format)
plot_date_end_obj = datetime.datetime.strptime(plot_date_end, time_format)

# --- output ------------------------------------------------------------
output_dir=os.path.join(root_dir, 'scripts/step7_plot_spatial_NLDAS_ens')
# exist_ok avoids the check-then-create race of os.path.exists + makedirs
os.makedirs(output_dir, exist_ok=True)
of_namebase = 'spatial_'
    
#======================================================================================================
print('Read gridinfo mask')
# get xy mask from gridinfo.nc
# The with-block closes the NetCDF file once both masks are loaded into memory.
with xr.open_dataset(gridinfo_file) as f_gridinfo:
    mask_xy = f_gridinfo['mask'].values # (y, x). 1 is valid. 0 is invalid.
    data_mask = f_gridinfo['data_mask'].values # (y, x). 1 is valid. 0 is invalid.

#======================================================================================================
# read historical nldas data
print('Read nldas data')
prcp_parts, tmin_parts, tmax_parts, day_parts = [], [], [], []
for yr in range(start_yr, end_yr+1):

    nldas_file = 'NLDAS_'+str(yr)+'.nc'
    nldas_path = os.path.join(nldas_dir, nldas_file)

    # Each yearly file is closed as soon as its arrays are in memory.
    with xr.open_dataset(nldas_path) as f_nldas:
        prcp_parts.append(f_nldas['prcp_avg'].values) # (time, y, x). unit: kg/m^2 = mm
        tmin_parts.append(f_nldas['tair_min'].values) # (time, y, x). unit: K
        tmax_parts.append(f_nldas['tair_max'].values)
        # normalize() drops the time of day, like formatting with time_format would
        day_parts.append(pd.to_datetime(f_nldas['time'].values).normalize())

# join the years once instead of re-copying the growing arrays every iteration
prcp_avg = np.concatenate(prcp_parts, axis = 0)
tair_min = np.concatenate(tmin_parts, axis = 0)
tair_max = np.concatenate(tmax_parts, axis = 0)
del prcp_parts, tmin_parts, tmax_parts

# get time mask from nldas data
time_obj = np.concatenate([days.to_pydatetime() for days in day_parts]) # array of datetime.datetime
mask_t  = (time_obj >= plot_date_start_obj) & (time_obj <= plot_date_end_obj)
if not mask_t.any():
    # Without this check the means below would silently be all-NaN maps.
    raise ValueError('No NLDAS time step between %s and %s' % (plot_date_start, plot_date_end))
time = time_obj[mask_t]

# convert unit and calculate mean values
# assumes prcp_avg is an hourly-mean rate, hence the factor 24 -- TODO confirm against the file metadata
prcp_sum = np.multiply(prcp_avg[mask_t,:,:], 24.0) #mm/hr to mm/day
tair_min = np.subtract(tair_min[mask_t,:,:], 273.15) # K to degC
tair_max = np.subtract(tair_max[mask_t,:,:], 273.15)

prcp_mean = np.nanmean(prcp_sum, axis=0) #(y, x)
tmean_mean = np.nanmean(0.5*(tair_min+tair_max), axis=0)
trange_mean = np.nanmean((tair_max-tair_min), axis=0)
del prcp_avg,tair_min,tair_max

time_nldas = datetime.datetime.now()
print('read NLDAS time:',time_nldas - startTime)

#======================================================================================================
print('Plot')


def _ens_time_mean_and_spread(ens, t_mask, xy_mask):
    """Collapse a (time, y, x, member) array into two (y, x) maps.

    Returns the time average of the ensemble mean and the time average of the
    ensemble standard deviation over the time steps selected by ``t_mask``.
    Cells where ``xy_mask`` equals 0 are set to NaN in both maps.
    """
    selected = ens[t_mask, :, :, :]
    mean_map = np.nanmean(np.nanmean(selected, axis=3), axis=0)
    # nanstd rather than std, so a NaN member is ignored the same way nanmean ignores it
    spread_map = np.nanmean(np.nanstd(selected, axis=3), axis=0)
    outside = (xy_mask == 0)
    return np.where(outside, np.nan, mean_map), np.where(outside, np.nan, spread_map)


def _nan_range(*fields):
    """Return (min, max) over all given arrays, ignoring NaN."""
    lows = [np.nanmin(field) for field in fields]
    highs = [np.nanmax(field) for field in fields]
    return np.nanmin(lows), np.nanmax(highs)


# loop through all uniform tests
# Only the folder at position 6 of the sorted list is processed; widen the slice to plot more tests.
selected_folders = test_folders[6:7]
if not selected_folders:
    # An out-of-range slice is just empty, so without this message the script
    # prints 'Done' without having plotted anything.
    print('Warning: test_folders[6:7] selects nothing (%d entries in %s); no figure will be written'
          % (len(test_folders), result_dir))

for test_folder in selected_folders:

    print(test_folder)
    test_dir = os.path.join(result_dir, test_folder)
    fig_title= test_folder

    # read output ensemble
    time_ens0 = datetime.datetime.now()

    output_namebase = os.path.join(test_dir,'outputs', 'ens_forc')
    lats_ens, lons_ens, time_ens, pcp_ens, tmean_ens, trange_ens = read_ens(output_namebase, start_yr, end_yr, ens_num)

    time_ens1 = datetime.datetime.now()
    print('read ensemble time:',time_ens1 - time_ens0)

    # define plot mask for nldas ensemble
    mask_ens_t = (time_ens>=plot_date_start_obj) & (time_ens<=plot_date_end_obj)

    # time-series mean of the ensemble mean and std, masked outside the domain. (y,x)
    pcp_ens_mean, pcp_ens_std = _ens_time_mean_and_spread(pcp_ens, mask_ens_t, mask_xy)
    tmean_ens_mean, tmean_ens_std = _ens_time_mean_and_spread(tmean_ens, mask_ens_t, mask_xy)
    trange_ens_mean, trange_ens_std = _ens_time_mean_and_spread(trange_ens, mask_ens_t, mask_xy)

    # One entry per figure row: NLDAS field, ens mean, ens std, colormap,
    # colorbar unit label and the three panel titles (NLDAS, ens mean, ens std).
    deg_c_label = r'($^\circ$C)'  # raw string: '\c' is not a valid escape sequence
    panel_rows = [
        (prcp_mean, pcp_ens_mean, pcp_ens_std, plt.cm.Blues, '(mm/day)',
         ('(a) NLDAS Daily Precip', '(b) Ens Mean of Daily Precip', '(c) Ens Std of Daily Precip')),
        (tmean_mean, tmean_ens_mean, tmean_ens_std, plt.cm.Reds, deg_c_label,
         ('(d) NLDAS Mean Temp', '(e) Ens Mean of Mean Temp', '(f) Ens Std of Mean Temp')),
        (trange_mean, trange_ens_mean, trange_ens_std, plt.cm.Greens, deg_c_label,
         ('(g) NLDAS Temp Range', '(h) Ens Mean of Temp Range', '(i) Ens Std of Temp Range')),
    ]

    # plot
    nrow = len(panel_rows) # prcp, tmean, trange
    ncol = 3 # NLDAS, ens mean, ens std
    fig, ax = plt.subplots(nrow, ncol)#, constrained_layout=True)
    fig.set_figwidth(5.5*ncol)
    fig.set_figheight(5.5*0.75*nrow)
    fig.suptitle(fig_title, fontsize='medium', fontweight='semibold', color='g')

    # map extent from the corner cells of the ensemble grid
    llcrnrlon = lons_ens[0,0]
    urcrnrlon = lons_ens[-1,-1]
    llcrnrlat = lats_ens[0,0]
    urcrnrlat = lats_ens[-1,-1]
    lat_0=0.5*(llcrnrlat+urcrnrlat)
    lon_0=0.5*(llcrnrlon+urcrnrlon)
    (ny,nx)=np.shape(lats_ens)

    for i, (nldas_field, ens_mean, ens_std, cmap, unit_label, titles) in enumerate(panel_rows):
        # NLDAS and ensemble-mean panels share one colour range so they can be
        # compared directly; the std panel gets its own range.
        mean_limits = _nan_range(nldas_field, ens_mean)
        std_limits = _nan_range(ens_std)
        columns = ((nldas_field, mean_limits), (ens_mean, mean_limits), (ens_std, std_limits))

        for j, (data, (vmin, vmax)) in enumerate(columns):

            # plot Basemap
            m = plot_basemap(llcrnrlon,llcrnrlat,urcrnrlon,urcrnrlat,ax[i,j],lat_0,lon_0,ny,nx)

            # plot data
            # NOTE(review): recent matplotlib releases reject shading='flat' when the
            # coordinate arrays have the same shape as the data -- confirm against the
            # installed version before upgrading.
            im1 = m.pcolormesh(lons_ens, lats_ens,data,shading='flat',latlon=True,cmap=cmap,vmin=vmin,vmax=vmax)

            # set title
            ax[i,j].set_title(titles[j], fontsize='small', fontweight='semibold')

            # set colorbar
            cbar = m.colorbar(im1, location='right')
            cbar.set_label(label=unit_label, size='small', rotation='horizontal', labelpad=-20, y=1.1) #y=1.04
            cbar.ax.tick_params(labelsize='small')

    # save plot
    fig.tight_layout(pad=0.1, w_pad=0.05, h_pad=0.00)
    output_filename = of_namebase+test_folder+'.png'
    fig.savefig(os.path.join(output_dir, output_filename), dpi=dpi_value)
    plt.close(fig)
    #     plt.show()
    # release the large ensemble arrays before the next test is read
    del pcp_ens, tmean_ens, trange_ens

print('Done')
print('Total time:', datetime.datetime.now() - startTime)


Read gridinfo mask
Read nldas data




read NLDAS time: 0:00:05.383655
Plot
Done
Total time: 0:00:05.385083




In [18]:
np.shape(prcp_mean)

(array([], dtype=int64),)