In [2]:
#!/usr/bin/env python
# coding: utf-8

import os
# import conda
# conda_file_dir = conda.__file__
# conda_dir = conda_file_dir.split('lib')[0]
# proj_lib = os.path.join(os.path.join(conda_dir, 'share'), 'proj')
# os.environ["PROJ_LIB"] = proj_lib
# reference :http://jtdz-solenoids.com/stackoverflow_/questions/54201946/how-can-i-avoid-proj-lib-error-in-importing-basemap
os.environ["PROJ_LIB"] = '/glade/u/home/hongli/tools/miniconda3/envs/conda_hongli/share/proj'

from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import os, shutil
import numpy as np
import xarray as xr
from itertools import chain
from osgeo import gdal
from pyproj import Proj

def plot_basemap(llcrnrlon,llcrnrlat,urcrnrlon,urcrnrlat,ax,nx,ny,lat_0,lon_0):
    """Draw a 'tmerc' Basemap with a labelled graticule, shaded relief and a cell grid.

    Parameters
    ----------
    llcrnrlon, llcrnrlat : longitude / latitude of the lower-left map corner.
    urcrnrlon, urcrnrlat : longitude / latitude of the upper-right map corner.
    ax : matplotlib axes the map is drawn on.
    nx, ny : number of grid columns / rows; nx+1 meridians and ny+1 parallels are
        drawn evenly spaced between the corners.
    lat_0, lon_0 : central latitude / longitude handed to the projection.

    Returns
    -------
    The Basemap instance, so the caller can project and plot further data on it.
    """
    basemap = Basemap(llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat,
                      resolution='l', projection='tmerc', ax=ax,
                      lat_0=lat_0, lon_0=lon_0)
    basemap.drawmapboundary(color='k', linewidth=1.)

    # Labelled graticule every 0.2 degrees. Label flags are [left, right, top, bottom].
    labelled_lats = np.arange(np.floor(llcrnrlat), np.ceil(urcrnrlat), 0.2)
    labelled_lons = np.arange(np.floor(llcrnrlon), np.ceil(urcrnrlon), 0.2)
    basemap.drawparallels(labelled_lats, labels=[True,False,False,False],
                          dashes=[1,1], fontsize='small')
    basemap.drawmeridians(labelled_lons, labels=[False,False,False,True],
                          dashes=[1,1], fontsize='small')

    # Shaded-relief background image.
    basemap.shadedrelief(scale=0.5)

    # Unlabelled lines outlining the grid cells. Both calls return a dictionary
    # whose values hold the drawn line artists in their first element.
    cell_lats = np.linspace(llcrnrlat, urcrnrlat, ny+1)
    cell_lons = np.linspace(llcrnrlon, urcrnrlon, nx+1)
    parallels = basemap.drawparallels(cell_lats, labels=[False,False,False,False],
                                      dashes=[0.5,0.5])
    meridians = basemap.drawmeridians(cell_lons, labels=[False,False,False,False],
                                      dashes=[0.5,0.5])

    # Restyle every cell-grid line (parallels first, then meridians).
    for drawn in (parallels, meridians):
        for artists in drawn.values():
            for grid_line in artists[0]:
                grid_line.set(linestyle='-', alpha=0.3, color='grey')

    basemap.drawstates(linewidth=0.5, linestyle='solid', color='k')
    return basemap

# ===============================================================================
# Configuration: input / output locations and plotting options.
root_dir='/glade/u/home/hongli/scratch/2019_10_01gssha/ens_forc_wrf2'
grid_info_file = os.path.join(root_dir, 'dem/step4_create_gridinfo/gridinfo.nc')

asc_dir = '/glade/u/home/hongli/scratch/2019_10_01gssha/model/270m_Forward_Jan01_2018-Apr15_2018_WestWRF/input/hmet_ascii_data_1day_lead'
asc_files = sorted(name for name in os.listdir(asc_dir) if '.asc' in name)

# The output folder is wiped and recreated on every run so that station lists
# from a previous run cannot leak into the plot made at the end.
outfolder = 'scripts/step4_sample_stnlist_perturb'
out_dir = os.path.join(root_dir, outfolder)
if os.path.exists(out_dir):
    shutil.rmtree(out_dir)
os.makedirs(out_dir)
ofile_name_base = 'stnlist'
dpi_value = 100

# Fixed seed so the random perturbation of the sampled cells is reproducible.
np.random.seed(seed=123455)

# ==========================================================================================
# read grid info
# grid_info_file already contains root_dir, so it is opened directly instead of
# being joined with root_dir a second time. The dataset is closed as soon as the
# arrays have been pulled into memory.
with xr.open_dataset(grid_info_file) as grid_ds:
    mask = grid_ds['mask'].values[:] # 1 is valid. 0 is invalid.
    latitude = grid_ds['latitude'].values[:]
    longitude = grid_ds['longitude'].values[:]
    elev = grid_ds['elev'].values[:]
    gradient_n_s = grid_ds['gradient_n_s'].values[:]
    gradient_w_e = grid_ds['gradient_w_e'].values[:]

# Grid dimensions and the (row, column) indexes of every valid cell.
(ny,nx)=np.shape(mask)
(y_ids,x_ids)=np.where(mask==1)
total_stn_num = len(y_ids)

# ==========================================================================================
# sampled grid interval
index_intervals=[1,2,3] #1, 1/4, 1/9.


def _perturb_cell(y_origin, x_origin, rnd, valid_mask):
    """Shift one cell by at most one row and one column according to `rnd`.

    rnd in {1,2,8} moves one row up in index, {4,5,6} one row down in index;
    rnd in {2,3,4} moves one column up in index, {6,7,8} one column down in index;
    rnd == 0 leaves the cell in place. A move along one axis is undone when it
    leaves the grid or lands on a cell whose mask value is not 1. The row move is
    settled first and the column move is then checked against the settled row.

    Returns the (row, column) pair of the perturbed cell.
    """
    n_rows, n_cols = np.shape(valid_mask)

    if rnd in (1, 2, 8):
        y_new = y_origin + 1
    elif rnd in (4, 5, 6):
        y_new = y_origin - 1
    else:
        y_new = y_origin
    if y_new < 0 or y_new >= n_rows or valid_mask[y_new, x_origin] != 1:
        y_new = y_origin

    if rnd in (2, 3, 4):
        x_new = x_origin + 1
    elif rnd in (6, 7, 8):
        x_new = x_origin - 1
    else:
        x_new = x_origin
    if x_new < 0 or x_new >= n_cols or valid_mask[y_new, x_new] != 1:
        x_new = x_origin

    return y_new, x_new


# Number of uniformly sampled cells of the previous interval (before perturbation).
sample_num_previous = 0
for index_interval in index_intervals:

    # uniform sample: valid cells whose row and column index are multiples of the interval
    sample_indexes = np.where((y_ids%index_interval==0) & (x_ids%index_interval==0))[0]
    uniform_num = len(sample_indexes)
    # Drawn before the skip test so the random stream consumed by later intervals
    # does not depend on whether this interval is skipped. high is exclusive: 0..8.
    rnds=np.random.randint(low=0, high=8+1, size=np.shape(sample_indexes))

    # An interval that selects the same number of cells as the previous one adds
    # nothing new, so skip it instead of writing an empty station list.
    if uniform_num == sample_num_previous:
        print('index interval = '+str(index_interval)+' skipped: same uniform sample size as previous interval')
        continue
    sample_num_previous = uniform_num

    # perturb in eight directions, keeping each resulting cell only once
    record = []      # perturbed (row, column) pairs in first-seen order
    seen = set()     # same pairs, for constant-time duplicate checks
    for choice_index, rnd in zip(sample_indexes, rnds):
        cell = _perturb_cell(int(y_ids[choice_index]), int(x_ids[choice_index]), rnd, mask)
        if cell not in seen:
            seen.add(cell)
            record.append(cell)

    # record the perturbed samples
    sample_num = len(record)
    # NOTE(review): the file name depends only on the number of cells, so two
    # intervals that end up with the same count would overwrite each other.
    ofile = ofile_name_base +'_'+str('%03d' %(sample_num))+'grids'+'.txt'

    print('index interval = '+str(index_interval)+', choice num = '+str(sample_num))

    with open(os.path.join(root_dir, outfolder, ofile), 'w') as f_out:
        f_out.write('NSITES\t'+str(sample_num)+'\n') # total number line
        f_out.write('STA_ID LAT LON ELEV SLP_N SLP_E STA_NAME\n') # title line

        for y_id, x_id in record:
            sta_id = 'Row%03dCol%03d' % (y_id, x_id)
            stn_name = '"'+sta_id+'"'
            f_out.write('%s, %f, %f, %f, %f, %f, %s\n' \
                        % (sta_id, latitude[y_id,x_id], longitude[y_id,x_id], elev[y_id,x_id],
                           gradient_n_s[y_id,x_id], gradient_w_e[y_id,x_id], stn_name))

# ==========================================================================================
print('plot distribution')
# lat/lon bounds and central lat/lon
# Only the header of the first ASCII grid is needed, so reading stops as soon as
# every required key has been found instead of loading the whole grid.
if not asc_files:
    raise FileNotFoundError('no .asc file found in ' + asc_dir)
header_keys = ('ncols', 'nrows', 'xllcorner', 'yllcorner', 'cellsize')
header = {}
with open(os.path.join(asc_dir, asc_files[0]), 'r') as f:
    for line in f:
        parts = line.split()
        if len(parts) >= 2 and parts[0].lower() in header_keys:
            header[parts[0].lower()] = parts[1]
            if len(header) == len(header_keys):
                break
missing_keys = [key for key in header_keys if key not in header]
if missing_keys:
    raise ValueError('missing header entries in ' + asc_files[0] + ': ' + ', '.join(missing_keys))

ncols = int(header['ncols'])
nx = ncols
nrows = int(header['nrows'])
ny = nrows
xllcorner = float(header['xllcorner'])
yllcorner = float(header['yllcorner'])
cellsize = float(header['cellsize'])

# Convert the grid corners from projected coordinates to lon/lat.
# NOTE(review): assumes the ASCII grids are in UTM zone 10 / WGS84 -- confirm.
p = Proj(proj='utm',zone=10,ellps='WGS84', preserve_units=False)
start_lon, start_lat = p(xllcorner, yllcorner, inverse=True)
end_lon, end_lat = p(xllcorner+cellsize*nx, yllcorner+cellsize*ny, inverse=True)

lat_0=0.5*(start_lat+end_lat)
lon_0=0.5*(start_lon+end_lon)

# plot
stnlist_files = [f for f in os.listdir(os.path.join(root_dir, outfolder)) if ofile_name_base in f]
stnlist_files = sorted(stnlist_files)
# Denominator of the percentage shown in each title: the number of valid mask
# cells, taken from the mask itself rather than from a hard-coded count.
total_stn_num = len(y_ids)

nrow = 1 #1
ncol = 3 #len(stnlist_files)
# squeeze=False + ravel keeps `ax` indexable even when nrow*ncol == 1
fig, ax = plt.subplots(nrow, ncol, squeeze=False)
ax = ax.ravel()
fig.set_figwidth(3.5*ncol)
fig.set_figheight(4*nrow)

for j in range(ncol):

    if j<len(stnlist_files):

        # read sampled stnlist.txt
        stnlist_file = os.path.join(root_dir, outfolder, stnlist_files[j])
        # ndmin=2 keeps one row per station even when the file holds a single station
        data = np.loadtxt(stnlist_file, skiprows=2, delimiter=',', dtype='str', ndmin=2) #STA_ID[0], LAT[1], LON[2], ELEV[3], SLP_N[4], SLP_E[5], STA_NAME[6]
        stn_num = len(data)
        stn_lons = [float(row[2]) for row in data]
        stn_lats = [float(row[1]) for row in data]
        print(str(stn_num) +' Grids')

        m = plot_basemap(llcrnrlon=start_lon,llcrnrlat=start_lat,
                         urcrnrlon=end_lon,urcrnrlat=end_lat, ax=ax[j],
                         nx=nx,ny=ny,lat_0=lat_0,lon_0=lon_0) # plot Basemap

        x, y = m(stn_lons,stn_lats) # convert the lat/lon values to x/y projections.
        m.plot(x, y, 'bs', markersize=2) # plot sampled grid points

        # set title: panel letter, number of sampled cells and their share of all valid cells
        perctl=round(stn_num/total_stn_num*100,0)
        title_str = '('+chr(ord('a') + j) +') ' + str(stn_num)  +' Sampled Grids ('+str('%d' %(perctl))+'%)'
        ax[j].set_title(title_str, fontsize='small', fontweight='semibold')

    else: # blank axis
        ax[j].axis('off')

# save plot
fig.tight_layout()
ofile = 'sample_grids_dist.png'
fig.savefig(os.path.join(root_dir, outfolder, ofile), dpi=dpi_value)
plt.close(fig)

print('Done')


index interval = 1, choice num = 252
index interval = 2, choice num = 94
index interval = 3, choice num = 46
plot distribution
46 Grids
94 Grids
252 Grids
Done


In [30]:
# Scratch cell: re-read the station list that `stnlist_file` currently points to,
# skipping the first two lines and keeping every comma-separated field as a string.
data = np.loadtxt(stnlist_file, skiprows=2, delimiter=',', dtype='str')

In [29]:
# Scratch cell: display the path currently held in `stnlist_file`.
stnlist_file

'/glade/u/home/hongli/scratch/2019_10_01gssha/ens_forc_wrf2/scripts/step4_sample_stnlist_perturb/stnlist_00011grids_interval6.txt'

In [6]:
# Scratch cell: peek at the first ten random codes in `rnds`.
rnds[0:10]

array([1, 2, 1, 8, 0, 7, 4, 8, 4, 2])

In [8]:
# Scratch check of randint's upper bound: `high` is exclusive, so high=8 can only
# return values 0..7; a call must pass high=9 to be able to draw 8.
rnd=np.random.randint(low=0, high=8, size=np.shape(choice_index))
rnd.max()

7