In [1]:
import sys
import cartopy
import numpy as np
import netCDF4 as nc
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.patches import Polygon
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mticker
from scipy.interpolate import griddata
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.feature import NaturalEarthFeature
from datetime import datetime, timedelta
from wrf import (getvar, interplevel, get_cartopy, cartopy_xlim,
                 cartopy_ylim, to_np, latlon_coords, ALL_TIMES)

In [2]:
def time_mask(time, time_s, time_e, seconds=None):

    '''
    Boolean mask of the entries of `time` that fall in [time_s, time_e).

    Parameters
    ----------
    time : array of datetime.timedelta
        Offsets from 2000-01-01 00:00 (the convention used by read_var).
    time_s, time_e : datetime
        Start (inclusive) and end (exclusive) of the selection window.
    seconds : pair of numbers, optional
        Extra time-of-day filter on `timedelta.seconds` (seconds within the
        day), kept when in [seconds[0], seconds[1]). If seconds[0] >= seconds[1]
        the window wraps past midnight and the test becomes
        ">= seconds[0] OR < seconds[1]". May be a list, tuple or numpy array.

    Returns
    -------
    Boolean array when `seconds` is None, otherwise a list of booleans
    (one per entry of `time`); both can index the time axis of an array.

    Checked on 14 Dec 2021, no problem was identified.
    '''

    ref    = datetime(2000,1,1,0,0,0)
    Time_s = time_s - ref
    Time_e = time_e - ref

    # `is None`: `seconds == None` is ambiguous when seconds is a numpy array
    if seconds is None:
        return (time >= Time_s) & (time < Time_e)

    sec_s, sec_e = seconds[0], seconds[1]
    # loop-invariant: does the time-of-day window cross midnight?
    wraps = sec_s >= sec_e

    time_cood = []
    for t in time:
        in_window = (t >= Time_s) & (t < Time_e)
        if wraps:
            if_seconds = (t.seconds >= sec_s) | (t.seconds < sec_e)
        else:
            if_seconds = (t.seconds >= sec_s) & (t.seconds < sec_e)
        time_cood.append(in_window & if_seconds)

    return time_cood

In [3]:
def UTC_to_AEST(time):
    '''
    Shift UTC time stamp(s) to Australian Eastern Standard Time.

    `time` may be a single datetime or an array of datetimes; a fixed
    +10 h offset is added (no daylight-saving adjustment).
    '''
    aest_offset = timedelta(hours=10)
    return time + aest_offset

In [4]:
def mask_by_lat_lon(file_path, loc_lat, loc_lon, lat_name, lon_name):

    '''
    Make a boolean mask selecting the grid cells inside a lat/lon box.

    Parameters
    ----------
    file_path : str
        netCDF file holding the coordinate variables.
    loc_lat, loc_lon : sequence of two numbers
        [min, max] latitude / longitude of the box.
    lat_name, lon_name : str
        Coordinate variable names. 3-D coordinates are reduced to their
        first slice along axis 0.

    Returns
    -------
    mask : 2-D boolean array (nlat, nlon)
        True inside the box, padded by about half a grid spacing so cells
        whose centre lies just outside the edge are still kept.

    Raises
    ------
    ValueError
        If the (reduced) latitude variable is neither 1-D nor 2-D.
    '''

    # read the coordinates, closing the file as soon as they are in memory
    with nc.Dataset(file_path, mode='r') as file:
        if file.variables[lat_name].ndim == 3:
            lat = file.variables[lat_name][0,:,:]
            lon = file.variables[lon_name][0,:,:]
        else:
            lat = file.variables[lat_name][:]
            lon = file.variables[lon_name][:]

    if len(np.shape(lat)) == 1:
        # abs(): a descending axis (e.g. north-to-south) would otherwise give a
        # negative spacing and shrink the box instead of padding it
        lat_spc = abs(lat[1] - lat[0])
        lon_spc = abs(lon[1] - lon[0])
        lons, lats = np.meshgrid(lon, lat)
        mask = ((lats > (loc_lat[0] - lat_spc/2)) & (lats < (loc_lat[1] + lat_spc/2))
                & (lons > (loc_lon[0] - lon_spc/2)) & (lons < (loc_lon[1] + lon_spc/2)))
    elif len(np.shape(lat)) == 2:
        # Spacing is sampled at pixel (150,150); the index is clamped so a LIS
        # run over a small domain (< 151 pixels per side) does not IndexError.
        i = min(150, np.shape(lat)[0] - 1)
        j = min(150, np.shape(lat)[1] - 1)
        lat_spc = abs(lat[i,j] - lat[i-1,j])
        lon_spc = abs(lon[i,j] - lon[i,j-1])
        print("print lat_spc  & lon_spc in def mask_by_lat_lon ", lat_spc, lon_spc)
        # caution: LIS spacing is irregular, so +/- spc/2 may include more than
        # one pixel; the spacing is therefore divided by 2.1 rather than 2
        mask = ((lat > (loc_lat[0] - lat_spc/2.1)) & (lat < (loc_lat[1] + lat_spc/2.1))
                & (lon > (loc_lon[0] - lon_spc/2.1)) & (lon < (loc_lon[1] + lon_spc/2.1)))
    else:
        raise ValueError("mask_by_lat_lon: latitude has %d dimensions; expected 1 or 2"
                         % len(np.shape(lat)))
    return mask

In [5]:
def spital_var(time, Var, time_s, time_e, seconds=None):
    '''
    Time-average `Var` over the steps selected by time_mask.

    `time` should be AEST offsets from 2000-01-01 (as returned by read_var);
    `Var` is indexed along its first (time) axis. NaNs are ignored in the
    mean (np.nanmean over axis 0).
    '''
    selected = Var[time_mask(time, time_s, time_e, seconds)]
    return np.nanmean(selected, axis=0)

In [6]:
def _fill_to_nan(nc_var, data):
    '''
    Return `data` with the netCDF variable's fill value replaced by NaN.

    The `_FillValue` attribute is checked first, then `_fillvalue`; when the
    variable has neither, `data` is returned unchanged.
    '''
    for attr in ('_FillValue', '_fillvalue'):
        if hasattr(nc_var, attr):
            return np.where(data == getattr(nc_var, attr), np.nan, data)
    return data


def read_var(file_path, var_name, loc_lat=None, loc_lon=None, lat_name=None, lon_name=None):

    '''
    Read one variable from a netCDF file; return its time axis and values.

    Parameters
    ----------
    file_path : str
        netCDF file with a 'time' variable carrying a `units` attribute.
    var_name : str
        Variable to read. If it equals `lat_name` or `lon_name`, that
        coordinate is returned on a 2-D grid, masked to the region.
    loc_lat, loc_lon : [min, max], optional
        Region to keep. When `loc_lat` is None the whole variable is read;
        otherwise cells outside the region (see mask_by_lat_lon) become NaN.
    lat_name, lon_name : str, optional
        Coordinate variable names; needed when a region is given.

    Returns
    -------
    time : array of datetime.timedelta
        Offsets from 2000-01-01 00:00. Times are shifted UTC -> AEST except
        for paths containing 'AWAP' or 'cable_out' (presumably already in
        local time -- TODO confirm).
    Var : array
        Variable values with its fill value replaced by NaN (coordinate
        requests skip the fill-value step). GLEAM variables 'E','Ei','Es','Et'
        are reordered from (time, lon, lat) to (time, lat, lon); soil
        variables keep their first 6 layers.

    Raises
    ------
    ValueError
        If a coordinate is requested and the latitude variable is not 1-, 2-
        or 3-D.
    '''

    print(var_name)

    with Dataset(file_path, mode='r') as obs_file:
        time_tmp = nc.num2date(obs_file.variables['time'][:], obs_file.variables['time'].units,
                               only_use_cftime_datetimes=False, only_use_python_datetimes=True)
        if 'AWAP' in file_path or 'cable_out' in file_path:
            time = time_tmp - datetime(2000,1,1,0,0,0)
        else:
            time = UTC_to_AEST(time_tmp) - datetime(2000,1,1,0,0,0)
        ntime = len(time)

        # `is None`: `== None` is ambiguous if a numpy array is passed
        if loc_lat is None:
            nc_var = obs_file.variables[var_name]
            Var    = _fill_to_nan(nc_var, nc_var[:])

        elif var_name == lat_name or var_name == lon_name:
            # coordinate request: return lat or lon itself, masked to the region
            mask = mask_by_lat_lon(file_path, loc_lat, loc_lon, lat_name, lon_name)
            lat  = obs_file.variables[lat_name]
            lon  = obs_file.variables[lon_name]
            ndim = len(np.shape(lat))
            if ndim == 1:
                lons, lats = np.meshgrid(lon, lat)
                Var = np.where(mask, lons if var_name == lon_name else lats, np.nan)
            elif ndim == 2:
                Var = np.where(mask, obs_file.variables[var_name][:], np.nan)
            elif ndim == 3:
                Var = np.where(mask, obs_file.variables[var_name][0,:,:], np.nan)
            else:
                raise ValueError("read_var: %s has %d dimensions; expected 1, 2 or 3"
                                 % (lat_name, ndim))

        else:
            # any other variable, masked to the region
            mask   = mask_by_lat_lon(file_path, loc_lat, loc_lon, lat_name, lon_name)
            nc_var = obs_file.variables[var_name]

            if var_name in ['E','Ei','Es','Et']:
                # change GLEAM's coordinates from (time, lon, lat) to (time, lat, lon)
                tmp = np.moveaxis(nc_var, -1, 1)
            else:
                tmp = nc_var[:]

            # the 2-D (lat, lon) mask broadcasts over the leading time axis
            if var_name in ["SoilMoist_inst","SoilTemp_inst", "SoilMoist", "SoilTemp"]:
                nlat, nlon = np.shape(mask)
                Var_tmp = np.zeros((ntime,6,nlat,nlon))
                for j in range(6):
                    Var_tmp[:,j,:,:] = np.where(mask, tmp[:,j,:,:], np.nan)
            else:
                Var_tmp = np.where(mask, tmp, np.nan)

            Var = _fill_to_nan(nc_var, Var_tmp)

    return time,Var


In [7]:
def regrid_data(lat_in, lon_in, lat_out, lon_out, input_data):
    '''
    Resample spatial data onto a new lat/lon grid by linear interpolation.

    Parameters
    ----------
    lat_in, lon_in : 1-D or 2-D arrays
        Source coordinates. 1-D axes are expanded with np.meshgrid(lon, lat),
        so `input_data` must then be laid out as (lat, lon).
    lat_out, lon_out : 1-D or 2-D arrays
        Target coordinates, expanded the same way when 1-D.
    input_data : array
        Source values, one per (expanded) source coordinate. NaN values are
        dropped together with their coordinates before interpolating.

    Returns
    -------
    Array on the target grid from scipy griddata(method="linear"); target
    points outside the convex hull of the valid source points are NaN.

    Raises
    ------
    ValueError
        If the source or target latitude is neither 1-D nor 2-D.
    '''
    print("regrid_data")

    if np.ndim(lat_in) == 1:
        lon_in_2D, lat_in_2D = np.meshgrid(lon_in, lat_in)
    elif np.ndim(lat_in) == 2:
        lon_in_2D, lat_in_2D = lon_in, lat_in
    else:
        raise ValueError("regrid_data: lat_in has %d dimensions; expected 1 or 2"
                         % np.ndim(lat_in))
    lon_in_1D = np.reshape(lon_in_2D, -1)
    lat_in_1D = np.reshape(lat_in_2D, -1)

    if np.ndim(lat_out) == 1:
        lon_out_2D, lat_out_2D = np.meshgrid(lon_out, lat_out)
    elif np.ndim(lat_out) == 2:
        lon_out_2D, lat_out_2D = lon_out, lat_out
    else:
        raise ValueError("regrid_data: lat_out has %d dimensions; expected 1 or 2"
                         % np.ndim(lat_out))

    # CAUTION: NaN in the values is the standard for "no data"; the same mask
    # is applied to the coordinates so points and values stay aligned
    value_1D = np.reshape(input_data, -1)
    valid    = ~np.isnan(value_1D)

    return griddata((lon_in_1D[valid], lat_in_1D[valid]), value_1D[valid],
                    (lon_out_2D, lat_out_2D), method="linear")

In [8]:
def read_off_wb(file_path, offline_path,
                time_s=datetime(2016,12,31,0,0,0,0), time_e=datetime(2016,12,31,23,59,0,0),
                loc_lat=(-40,-25), loc_lon=(135,155)):
    '''
    Read offline CABLE soil water and regrid it to the LIS grid.

    Parameters
    ----------
    file_path : str
        netCDF file whose 2-D "lat"/"lon" variables define the target grid.
    offline_path : str
        Offline CABLE output with 'SoilMoist', 'GWMoist', 'latitude' and
        'longitude'.
    time_s, time_e : datetime, optional
        Averaging window [time_s, time_e); defaults to 31 Dec 2016.
    loc_lat, loc_lon : [min, max], optional
        Region read from the offline file; defaults to SE Australia.

    Returns
    -------
    (wb_regrid, GWwb_regrid)
        Time-mean soil moisture of the 6 layers, shape (6, nlat, nlon), and
        time-mean aquifer moisture on the target grid.
    '''
    print("read_off_wb")

    time, wb_tmp   = read_var(offline_path, 'SoilMoist', loc_lat, loc_lon, 'latitude', 'longitude')
    wb             = spital_var(time, wb_tmp, time_s, time_e)
    time, GWwb_tmp = read_var(offline_path, 'GWMoist', loc_lat, loc_lon, 'latitude', 'longitude')
    GWwb           = spital_var(time, GWwb_tmp, time_s, time_e)

    # source coordinates
    with Dataset(offline_path, mode='r') as offline:
        lat_in = offline.variables["latitude"][:]
        lon_in = offline.variables["longitude"][:]

    # target coordinates
    with Dataset(file_path, mode='r') as file:
        lat_out = file.variables["lat"][:,:]
        lon_out = file.variables["lon"][:,:]
    nlat  = np.shape(lat_out)[0]
    nlon  = np.shape(lon_out)[1]
    nsoil = 6

    wb_regrid = np.zeros((nsoil,nlat,nlon))
    for l in range(nsoil):
        wb_regrid[l,:,:] = regrid_data(lat_in, lon_in, lat_out, lon_out, wb[l,:,:])

    GWwb_regrid = regrid_data(lat_in, lon_in, lat_out, lon_out, GWwb)

    return (wb_regrid, GWwb_regrid)

In [9]:
def read_rst_wb(file_path, lis_rst_path):

    '''
    Read soil and aquifer water from a LIS restart file and regrid them.

    Parameters
    ----------
    file_path : str
        netCDF file whose 2-D "lat"/"lon" variables define the target grid.
    lis_rst_path : str
        LIS restart file with "WB", "WBICE", "GWWB" and point "lat"/"lon".

    Returns
    -------
    (wb_regrid, GWwb_regrid)
        Total soil water (liquid + ice*den_rat) of the 6 layers, shape
        (6, nlat, nlon), and aquifer water, both linearly interpolated
        (scipy griddata) onto the target grid.
    '''
    print("read_rst_wb")

    # weight applied to WBICE before adding it to WB
    # (presumably the ice/liquid water density ratio -- TODO confirm)
    den_rat = 0.921

    with Dataset(lis_rst_path, mode='r') as lis_rst:
        wbliq  = lis_rst.variables["WB"][:,:]
        wbice  = lis_rst.variables["WBICE"][:,:]
        GWwb   = lis_rst.variables["GWWB"][:]
        lat_in = lis_rst.variables["lat"][:]
        lon_in = lis_rst.variables["lon"][:]
    # NOTE(review): unlike read_off_wb, no NaN/fill filtering is done before
    # griddata -- confirm the restart file has no fill values
    wb = wbliq + wbice*den_rat

    with Dataset(file_path, mode='r') as file:
        lat_out = file.variables["lat"][:,:]
        lon_out = file.variables["lon"][:,:]
    nlat  = np.shape(lat_out)[0]
    nlon  = np.shape(lon_out)[1]
    nsoil = 6

    wb_regrid = np.zeros((nsoil,nlat,nlon))
    for l in range(nsoil):
        wb_regrid[l,:,:] = griddata((lon_in, lat_in), wb[l,:], (lon_out, lat_out), method="linear")

    GWwb_regrid = griddata((lon_in, lat_in), GWwb, (lon_out, lat_out), method="linear")

    return (wb_regrid, GWwb_regrid)

In [10]:
def _apply_plot_style():
    '''Set the shared matplotlib rcParams used by the difference maps.'''
    almost_black = '#262626'
    plt.rcParams.update({
        'text.usetex'     : False,
        'font.family'     : "sans-serif",
        'font.serif'      : "Helvetica",
        'axes.linewidth'  : 1.5,
        'axes.labelsize'  : 14,
        'font.size'       : 14,
        'legend.fontsize' : 14,
        'xtick.labelsize' : 14,
        'ytick.labelsize' : 14,
        # ticks, text and axis frame in almost-black rather than pure black
        'ytick.color'     : almost_black,
        'xtick.color'     : almost_black,
        'text.color'      : almost_black,
        'axes.edgecolor'  : almost_black,
        'axes.labelcolor' : almost_black,
    })


def _plot_diff_map(lons, lats, data, clevs, title, message, out_dir):
    '''Draw one SE-Australia filled-contour map of `data` and save it as a PNG.'''
    import os

    fig = plt.figure(figsize=(6,5))
    ax  = plt.axes(projection=ccrs.PlateCarree())

    ax.set_extent([135,155,-40,-25])
    ax.coastlines(resolution="50m", linewidth=1)

    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, linewidth=1, color='black', linestyle='--')
    # hide top/right labels: legacy flags for cartopy < 0.18, current ones for >= 0.18
    gl.xlabels_top   = False
    gl.ylabels_right = False
    gl.top_labels    = False
    gl.right_labels  = False
    gl.xlines        = True

    gl.xlocator      = mticker.FixedLocator([135,140,145,150,155])
    gl.ylocator      = mticker.FixedLocator([-40,-35,-30,-25])

    gl.xformatter    = LONGITUDE_FORMATTER
    gl.yformatter    = LATITUDE_FORMATTER
    gl.xlabel_style  = {'size':10, 'color':'black'}
    gl.ylabel_style  = {'size':10, 'color':'black'}

    cf = ax.contourf(lons, lats, data, clevs, transform=ccrs.PlateCarree(),
                     cmap=plt.cm.seismic, extend='both')

    cb = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.8)
    cb.ax.tick_params(labelsize=10)
    ax.set_title(title, size=16)

    fig.savefig(os.path.join(out_dir, 'spatial_map_'+message+'.png'), dpi=300)


def spatial_map_single_plot_diff(file_path, lis_rst_path, offline_path, wrf_path,
                                 out_dir='/g/data/w97/mm3972/scripts/Drought/drght_2017-2019/plots/WTD_sudden_change/'):
    '''
    Map restart-minus-offline differences of soil and aquifer water.

    One PNG per soil layer ("spatial_map_rst-off_wb_lyr=<l>.png") plus one
    for aquifer water ("spatial_map_rst-off_GWwb.png") are written to
    `out_dir`. Plot coordinates are XLONG/XLAT (first time step) from the
    WRF output `wrf_path`.
    '''
    wb_rst, GWwb_rst = read_rst_wb(file_path, lis_rst_path)
    wb_off, GWwb_off = read_off_wb(file_path, offline_path)

    wb_diff   = wb_rst - wb_off
    GWwb_diff = GWwb_rst - GWwb_off

    with Dataset(wrf_path, mode='r') as wrf:
        lons = wrf.variables['XLONG'][0,:,:]
        lats = wrf.variables['XLAT'][0,:,:]

    # style must be set before the first figure is created, otherwise the
    # first map is drawn with the previous rcParams
    _apply_plot_style()

    # =========================== plot wb ===========================
    wb_clevs = [-0.3,-0.25,-0.2,-0.15,-0.1,-0.05,0.05,0.1,0.15,0.2,0.25,0.3]
    for l in range(np.shape(wb_diff)[0]):
        _plot_diff_map(lons, lats, wb_diff[l,:,:], wb_clevs,
                       "wb_diff_lyr="+str(l), "rst-off_wb_lyr="+str(l), out_dir)

    # ========================== plot GWwb ==========================
    GWwb_clevs = [-0.05,-0.04,-0.03,-0.02,-0.01,-0.005,0.005,0.01,0.02,0.03,0.04,0.05]
    _plot_diff_map(lons, lats, GWwb_diff, GWwb_clevs, "GWwb_diff", "rst-off_GWwb", out_dir)

        

In [None]:
if __name__ == "__main__":

    # Inputs for the LIS-restart vs offline-CABLE soil-water comparison.
    # Target grid: its 2-D "lat"/"lon" are the grid both datasets are regridded to
    file_path      = "/g/data/w97/mm3972/model/wrf/NUWRF/LISWRF_configs/drght_2017_2019_bl_pbl2_mp4_sf_sfclay2/LIS_output/LIS_HIST_201701011200_depth_varying.d01.nc"
    # not used by spatial_map_single_plot_diff below
    gridinfo_path  = "/g/data/w97/mm3972/model/cable/src/CABLE-AUX/offline/gridinfo_AWAP_OpenLandMap_ELEV_DLCM_fix_10km.nc"
    # WRF output: XLONG/XLAT (first time step) are used as plot coordinates
    wrf_path       = "/g/data/w97/mm3972/model/wrf/NUWRF/LISWRF_configs/uniform_soil_param/drght_2017_2019/run_Jan2017/WRF_output/wrfout_d01_2017-01-01_11:00:00"
    # not used by spatial_map_single_plot_diff below
    lis_input_path = "/g/data/w97/mm3972/model/wrf/NUWRF/LISWRF_configs/drght1719_bdy_data/bdy_data/lis_input.d01.nc_vec"
    # not used by spatial_map_single_plot_diff below
    lis_hist_path  = "/g/data/w97/mm3972/model/wrf/NUWRF/LISWRF_configs/uniform_soil_param/drght_2017_2019/run_Jan2017/LIS_output/LIS.CABLE.20170101110000.d01.nc"
    # LIS restart file: WB, WBICE and GWWB on point lat/lon
    lis_rst_path   = "/g/data/w97/mm3972/model/wrf/NUWRF/LISWRF_configs/offline_rst_output/output_1719_drght/LIS_RST_CABLE_201701011100.d01.nc"
    # offline CABLE output: SoilMoist and GWMoist
    offline_path   = "/g/data/w97/mm3972/model/cable/runs/runs_4_coupled/gw_after_sp30yrx3/outputs/cable_out_2000-2019.nc"
    
    spatial_map_single_plot_diff(file_path, lis_rst_path, offline_path, wrf_path)

read_rst_wb
read_off_wb
SoilMoist
print lat_spc  & lon_spc in def mask_by_lat_lon  0.10000038 0.099998474
