In [None]:
import os, gc, sys
import pygrib
import regionmask
import cartopy
import cartopy.crs as ccrs
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import multiprocessing as mp
import matplotlib.pyplot as plt 

from glob import glob
from functools import partial
from matplotlib import gridspec
from datetime import datetime, timedelta
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import colors

# Request a single OpenMP thread per process.
# NOTE(review): this is assigned after numpy/xarray/etc. have already been
# imported above; libraries that read OMP_NUM_THREADS when they are loaded may
# not pick it up. Consider moving this ahead of those imports — TODO confirm.
os.environ['OMP_NUM_THREADS'] = '1'

In [None]:
# CONFIG # # CONFIG # # CONFIG # # CONFIG # # CONFIG # 
# Region to process. It is used to select a region from the zone mask and as a
# key into cwa_bounds; 'WESTUS' is special-cased when the mask is built.
# (Previously read from the command line via sys.argv[1].)
cwa = 'SEW'# 'SEW'#sys.argv[1]
# Forecast hours processed are np.arange(fhr_start, fhr_end+1, fhr_step),
# i.e. fhr_end is inclusive.
fhr_start, fhr_end, fhr_step = 24, 108, 6

# start_date = datetime(2020, 5, 18, 0)
# end_date = datetime(2020, 10, 1, 0)

# Valid times are kept when start_date <= valid <= end_date (both inclusive).
start_date = datetime(2020, 10, 1, 0)
end_date = datetime(2020, 12, 1, 0)

# Exceedance thresholds: each is used to select along the NBM 'threshold'
# coordinate and is compared against URMA after its mm -> inch conversion.
produce_thresholds = [0.01, 0.1, 0.25, 0.50, 1.0]
# NOTE(review): bint / bins_custom are not referenced by any visible cell;
# presumably the bin width (percent) and an optional custom bin list for
# calc_pbin / calc_pbin_fixed — confirm before relying on them.
bint, bins_custom = 5, None

# Plot extents per region, ordered [lat_min, lat_max, lon_min, lon_max]
# (that is how the commented-out map code indexes them for set_ylim/set_xlim).
cwa_bounds = {
    'WESTUS':[30, 50, -130, -100],
    'SEW':[46.0, 49.0, -125.0, -120.5],
    'SLC':[37.0, 42.0, -114.0, -110],
    'MSO':[44.25, 49.0, -116.75, -112.25],
    'MTR':[35.75, 38.75, -123.5, -120.25],}
# CONFIG # # CONFIG # # CONFIG # # CONFIG # # CONFIG # 

In [None]:
# Root of the NBM data; the 'extract/' subdirectory is globbed for input files.
nbm_dir = '/scratch/general/lustre/u1070830/nbm/'
# Root of the URMA data; 'agg/urma_agg.nc' is read from here.
urma_dir = '/scratch/general/lustre/u1070830/urma/'
# Scratch directory; created below if it does not already exist.
tmp_dir = '/scratch/general/lustre/u1070830/tmp/'
# NOTE(review): fig_dir is not used by any visible cell — presumably the
# output location for saved figures; confirm.
fig_dir = '/uufs/chpc.utah.edu/common/home/steenburgh-group10/mewessler/nbm/'
os.makedirs(tmp_dir, exist_ok=True)

In [None]:
def cmap_colors(_bins):
    """Return the colour sequence associated with a forecast-probability bin.

    Parameters
    ----------
    _bins : tuple
        ``(lower, upper)`` bin edges in percent, e.g. ``(40, 60)``.

    Returns
    -------
    tuple of str or None
        Hex colour strings for the bin, or ``None`` when ``_bins`` is not one
        of the known 20%-wide bins.

    Notes
    -----
    Each group of bins is written as a tuple *of tuples* (note the trailing
    comma on single-bin groups). Without it, ``((0, 20))`` is simply
    ``(0, 20)`` and the membership test compares the bin against the integers
    0 and 20, which never matches a ``(lower, upper)`` tuple.
    """
    palettes = (
        (((0, 20),),
         ('#f5f5f5','#d8b365')),
        (((20, 40),),
         ('#5ab4ac','#f5f5f5','#f5f5f5','#d8b365', '#d8b365')),
        (((40, 60), (60, 80)),
         ('#5ab4ac','#5ab4ac','#f5f5f5','#f5f5f5','#d8b365')),
        (((80, 100),),
         ('#5ab4ac','#f5f5f5')),
    )

    for bin_group, palette in palettes:
        if _bins in bin_group:
            return palette

    # Unknown bin: keep the historical "no palette" result.
    return None

def resize_colobar(event):
    """Resize-event callback that keeps the colorbar axes beside the map axes.

    Relies on two module-level names being defined when the event fires:
    ``ax`` (the main axes) and ``colorbar_ax`` (the colorbar axes).
    NOTE(review): the commented-out plotting code in this notebook names its
    colorbar axes ``cbar_ax``; make sure ``colorbar_ax`` is bound before
    connecting this callback.

    Parameters
    ----------
    event
        The event object passed by the canvas; it is not used.
    """
    # Tell matplotlib to re-draw everything, so that we can get
    # the correct location from get_position.
    plt.draw()

    posn = ax.get_position()
    # Place the colorbar 0.01 to the right of the axes, 0.04 wide, and as tall
    # as the axes (previously read from an undefined name, `axpos`).
    colorbar_ax.set_position([posn.x0 + posn.width + 0.01, posn.y0,
                              0.04, posn.height])
    
def calc_pbin(pbin, _bint, _thresh, _data, _urma):
    """Count forecasts and exceedances for a bin centred on ``pbin``.

    The bin is the half-open interval ``[pbin - _bint/2, pbin + _bint/2)``.

    Parameters
    ----------
    pbin : number
        Bin centre.
    _bint : number
        Full bin width.
    _thresh : number
        Value that ``_urma`` must exceed to count as an event.
    _data, _urma
        Arrays with a ``'valid'`` dimension; comparisons are element-wise.

    Returns
    -------
    tuple
        ``(pbin, n, N)`` where ``N`` is the count along ``'valid'`` of
        ``_data`` values inside the bin and ``n`` is the count of those for
        which ``_urma > _thresh`` as well.
    """
    half_width = _bint/2
    p0 = pbin - half_width
    p1 = pbin + half_width

    # Forecasts falling inside the bin, and the subset that verified.
    in_bin = (_data >= p0) & (_data < p1)
    verified = in_bin & (_urma > _thresh)

    N = xr.where(in_bin, 1, 0).sum(dim=['valid'])
    n = xr.where(verified, 1, 0).sum(dim='valid')

    return pbin, n, N

def calc_pbin_fixed(pbin, _thresh, _data, _urma):
    """Count forecasts and exceedances for an explicit ``(low, high)`` bin.

    Unlike ``calc_pbin``, both edges are inclusive: ``low <= _data <= high``.

    Parameters
    ----------
    pbin : sequence of two numbers
        ``(low, high)`` bin edges.
    _thresh : number
        Value that ``_urma`` must exceed to count as an event.
    _data, _urma
        Arrays with a ``'valid'`` dimension; comparisons are element-wise.

    Returns
    -------
    tuple
        ``(pbin, n, N)`` where ``N`` is the count along ``'valid'`` of
        ``_data`` values inside the bin and ``n`` is the count of those for
        which ``_urma > _thresh`` as well.
    """
    low, high = pbin

    # Forecasts falling inside the bin, and the subset that verified.
    in_bin = (_data >= low) & (_data <= high)
    verified = in_bin & (_urma > _thresh)

    N = xr.where(in_bin, 1, 0).sum(dim=['valid'])
    n = xr.where(verified, 1, 0).sum(dim='valid')

    return pbin, n, N

In [None]:
# NBM extract files; the per-forecast-hour file is picked from this list later.
extract_dir = nbm_dir + 'extract/'
extract_flist = sorted(glob(extract_dir + '*'))

# The URMA aggregate is required by every later cell, so fail immediately with
# a clear message instead of continuing and hitting a NameError on `urma`.
urma_agg_file = urma_dir + 'agg/urma_agg.nc'
if not os.path.isfile(urma_agg_file):
    raise FileNotFoundError('URMA aggregate not found: %s'%urma_agg_file)

urma = xr.open_dataset(urma_agg_file)['apcp24h_mm']

# Convert millimetres to inches (25.4 mm per inch) and rename to match.
urma = urma/25.4
urma = urma.rename('apcp24h_in')
lons, lats = urma.lon, urma.lat

In [None]:
# Directory holding the forecast-zone shapefile; the first *.shp found is used
# (raises IndexError if the directory contains none).
geodir = '../forecast-zones/'
zones_shapefile = glob(geodir + '*.shp')[0]

# Read the shapefile
zones = gpd.read_file(zones_shapefile)

# Prune to Western Region using TZ
zones = zones.set_index('TIME_ZONE').loc[['M', 'Mm', 'm', 'MP', 'P']].reset_index()
# Merge the individual zones into one geometry per CWA.
cwas = zones.dissolve(by='CWA').reset_index()[['CWA', 'geometry']]
# Working copy, so `cwas` keeps the per-CWA geometries.
_cwas = cwas.copy()

if cwa == 'WESTUS':
    # Relabel every CWA as 'WESTUS' and dissolve again so a single region
    # remains; its bounds are the bounds of the whole set.
    _cwas['CWA'] = 'WESTUS'
    _cwas = _cwas.dissolve(by='CWA').reset_index()
    bounds = _cwas.total_bounds
else:
    # Bounds of the one requested CWA (IndexError if `cwa` is not present).
    bounds = _cwas[_cwas['CWA'] == cwa].bounds.values[0]

# bounds is indexed later as [0]/[2] for longitude and [1]/[3] for latitude.
print(bounds)

lons, lats = urma.lon, urma.lat
# Boolean mask per region on the URMA grid, with the region dimension renamed
# to 'cwa' and relabelled with the CWA names so it can be selected by name.
mask = regionmask.mask_3D_geopandas(_cwas, lons, lats).rename({'region':'cwa'})
mask['cwa'] = _cwas.iloc[mask.cwa]['CWA'].values.astype(str)
mask = mask.sel(cwa=cwa)
# Bare expression: displays the mask as the cell output.
mask

In [None]:
# Index arrays of the grid points whose lat/lon fall inside the bounding box.
# idx[0] indexes the first dimension and idx[1] the second of urma.lat/lon
# (assumed to be y and x respectively, as they are used below — TODO confirm).
idx = np.where(
    (urma.lat >= bounds[1]) & (urma.lat <= bounds[3]) &
    (urma.lon >= bounds[0]) & (urma.lon <= bounds[2]))

# Crop the mask and URMA to the bounding box.
# NOTE(review): the slice stop is exclusive for isel, so the row/column at
# idx[...].max() is dropped. The NBM data is later cut with .sel() using the
# same slice values; verify that both selections give matching shapes.
# NOTE(review): this cell overwrites `mask` and `urma`, so it is not
# idempotent — re-running it without re-running the cells above crops again.
mask = mask.isel(y=slice(idx[0].min(), idx[0].max()), x=slice(idx[1].min(), idx[1].max()))
urma = urma.isel(y=slice(idx[0].min(), idx[0].max()), x=slice(idx[1].min(), idx[1].max()))
# Put the time dimension first.
urma = urma.transpose('valid', 'y', 'x')

In [None]:
# TODO: multithread this! Each forecast hour only writes its own
# data[thresh][fhr] entries.

# data[threshold][fhr] -> list of [bins, n, N], one entry per probability bin,
# where N counts forecasts in the bin and n counts those that verified.
data = {k:{} for k in produce_thresholds}

for fhr in np.arange(fhr_start, fhr_end+1, fhr_step):

    # Locate the extract file for this forecast hour; fail with a clear
    # message rather than an IndexError when it is missing.
    fhr_files = [f for f in extract_flist if 'fhr%03d'%fhr in f]
    if not fhr_files:
        raise FileNotFoundError(
            'No NBM extract file matching fhr%03d in %s'%(fhr, extract_dir))
    open_file = fhr_files[0]

    # Subset the NBM probabilities to the same index window as URMA
    nbm = xr.open_dataset(open_file)['probx'].sel(
        y=slice(idx[0].min(), idx[0].max()),
        x=slice(idx[1].min(), idx[1].max()))

    # Subset the times: keep valid times present in both datasets and inside
    # [start_date, end_date] (both ends inclusive).
    nbm_time = nbm.valid
    urma_time = urma.valid
    # np.isin replaces the deprecated np.in1d.
    time_match = nbm_time[np.isin(nbm_time.values, urma_time.values)].values
    match_stamps = pd.to_datetime(time_match)
    time_match = time_match[(match_stamps >= start_date) & (match_stamps <= end_date)]

    if time_match.size == 0:
        nbm.close()
        raise ValueError(
            'No NBM/URMA valid times between %s and %s for fhr%03d'%(
                start_date, end_date, fhr))

    # First/last matched time, kept for plot titles.
    date0 = pd.to_datetime(time_match[0]).strftime('%Y/%m/%d %H UTC')
    date1 = pd.to_datetime(time_match[-1]).strftime('%Y/%m/%d %H UTC')

    _nbm = nbm.sel(valid=time_match)
    _urma = urma.sel(valid=time_match)
    _mask, _nbm, _urma = xr.broadcast(mask, _nbm, _urma)

    # Points outside the region become NaN, so every comparison below is
    # False there and they contribute nothing to the counts.
    _nbm_masked = xr.where(_mask, _nbm, np.nan)
    _urma_masked = xr.where(_mask, _urma, np.nan)

    for thresh in produce_thresholds:

        print('Processing: %s f%03d %.2f"'%(cwa, fhr, thresh))

        _nbm_masked_select = _nbm_masked.sel(threshold=thresh)

        _data = []
        for bins in zip(np.arange(0, 81, 20), np.arange(20, 101, 20)):
        #for bins in zip(np.arange(0, 91, 10), np.arange(10, 101, 10)):

            b0, b1 = bins
            # Contour levels for the (currently disabled) map plot.
            levels = np.unique([0, b0, b1, 100])

            # The meat and potatoes of the thing:
            # forecasts inside the bin (b0, b1], computed once and reused.
            in_bin = (_nbm_masked_select > b0) & (_nbm_masked_select <= b1)

            N = xr.where(in_bin, 1, 0).sum(dim='valid')

            n = xr.where(
                in_bin & (_urma_masked > thresh),
                1, 0).sum(dim='valid')

            # Observed relative frequency in percent, blanked where there are
            # 5 or fewer verifying events; not stored in `data`.
            obs_rel_freq = xr.where(n > 5, n/N, np.nan)*100
            _data.append([bins, n, N])

        data[thresh][fhr] = _data

    nbm.close()

In [None]:
            # Make the plot
#             fig = plt.figure(figsize=(12, 12), facecolor='w')
#             ax = fig.add_axes([0, 0, 1, 1], projection=ccrs.PlateCarree())
#             cmap = colors.ListedColormap(cmap_colors(bins), name='orf_cmap')

#             if cwa != 'WESTUS':
#                 zones.geometry.boundary.plot(color=None, linestyle='--', edgecolor='black', linewidth=0.75, ax=ax)

#             cwas.geometry.boundary.plot(color=None, edgecolor='black', linewidth=1.5, ax=ax)
#             ax.add_feature(cartopy.feature.OCEAN, zorder=100, color='w', edgecolor=None)
#             ax.coastlines(linewidth=2, zorder=101)

#             cbd = ax.contourf(obs_rel_freq.lon, obs_rel_freq.lat, obs_rel_freq,
#                              levels=levels, cmap=cmap)

#             nan_shade = xr.where(np.isnan(obs_rel_freq) & _mask.isel(valid=0), -1, np.nan)
#             ax.contourf(obs_rel_freq.lon, obs_rel_freq.lat, nan_shade, cmap='gray', alpha=0.5)

#             cbar_ax = fig.add_axes([1.01, .075, .05, .85])
#             cbar = plt.colorbar(cbd, cax=cbar_ax)
#             cbar.ax.tick_params(labelsize=16)
#             fig.canvas.mpl_connect('resize_event', resize_colobar)

#             ax.set_title('CWA: %s\nFHR: %03d\nThreshold: %.02f"\nBin: %d%% - %d%%'%(cwa, fhr, thresh, bins[0], bins[1]), fontsize=16)
#             cbar.set_label(label='\n[< Too Wet]        Observed Relative Frequency        [Too Dry >]', fontsize=16)

#             ax.set_ylim(bottom=cwa_bounds[cwa][0]+0.05, top=cwa_bounds[cwa][1]+0.085)
#             ax.set_xlim(left=cwa_bounds[cwa][2]-0.15, right=cwa_bounds[cwa][3]+0.15)

#             plt.show()
                
#         plotdata = []
#         for d in data:
#             bins, n, N = d
#             center = np.mean(bins)/100
#             obs_rel_freq = n.sum()/N.sum()
#             plotdata.append([center, obs_rel_freq, n.sum(), N.sum()])

#         plotdata = np.array(plotdata)

#         # Make the figure
#         fig = plt.figure(figsize=(9, 11), facecolor='w') 
#         axs = gridspec.GridSpec(2, 1, height_ratios=[4, 1]) 
#         ax = plt.subplot(axs[0])
#         ax1 = plt.subplot(axs[1])

#         ax.plot(plotdata[:, 0], plotdata[:, 1], linewidth=3, color='r',
#                 marker='+', markersize=15, label='ALL')

#         perfect = np.arange(0, 1.1, .1)
#         climo = xr.where(urma > thresh, 1, 0).sum().values/urma.size
#         skill = perfect - ((perfect - climo)/2)

#         ax.plot(perfect, perfect, 
#                 color='k')

#         ax.axhline(climo, 
#                 color='k', linestyle='--')

#         ax.plot(perfect, skill, 
#                 color='k', linestyle='--')

#         fillperf = np.arange(climo, 1, .001)
#         ax.fill_between(fillperf, fillperf - (fillperf - climo)/2, 1,
#                 color='gray', alpha=0.35)

#         fillperf = np.arange(0, climo, .001)
#         ax.fill_between(fillperf, 0, fillperf - (fillperf - climo)/2,
#                 color='gray', alpha=0.35)

#         ax.set_xlim([0, 1])
#         ax.set_ylim([0, 1])

#         ax.set_xticks(perfect)
#         ax.set_yticks(perfect)

#         ax.set_xlabel('Forecast Probability')
#         ax.set_ylabel('Observed Relative Frequency')
#         ax.grid(zorder=1)

#         ax.set_title((
#             'NBM Reliability | CWA: %s\n'%cwa +
#             '%s - %s\n'%(date0, date1) + 
#             '%02dh Acc QPF | %3dh Lead Time\n\n'%(nbm.interval, nbm.fhr) +
#             'Probability of Exceeding %.2f"\n\n'%thresh + 
#             'n forecast prob > 0: %2.1e | n observed > %.2f: %2.1e'%(
#                 plotdata[:, 2].sum(), thresh, plotdata[:, 3].sum())))

#         ax.legend(loc='upper left')

#         # # # # # # # # # # # # # # # # # # # # # # # #

#         ax1.bar(plotdata[:, 0], plotdata[:, 3], color='k', width=0.095, zorder=10)
#         ax1.bar(plotdata[:, 0], plotdata[:, 2], color='r', alpha=0.25, width=0.095, zorder=11)

#         ax1.set_xticks(np.arange(0, 1.1, .1))
#         # ax1.set_xticklabels(plotdata[:, 0][::2])
#         ax1.set_xlim([0, 1])

#         ax1.set_yscale('log')
#         ax1.set_yticks([1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8])

#         ax1.set_xlabel('Forecast Probability')
#         ax1.set_ylabel('# Forecasts')
#         ax1.grid(zorder=-1)
        
#         plt.show()