In [7]:
import os, gc
import pygrib
import numpy as np
import pandas as pd
import xarray as xr
import multiprocessing as mp
import matplotlib.pyplot as plt 

from glob import glob
from functools import partial
from matplotlib import gridspec
from datetime import datetime, timedelta

import regionmask
import cartopy
import cartopy.crs as ccrs
import geopandas as gpd

import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from numpy/pandas/xarray;
# consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

# Request a single OpenMP thread per process.
# NOTE(review): numpy is already imported above when this line runs -- confirm
# the setting is still honoured, or move it ahead of the imports.
os.environ['OMP_NUM_THREADS'] = '1'
# Worker-process count; not referenced by any cell visible in this section.
n_cores = 64

In [8]:
# Root of the NBM archive; files matching 'agg/*fNNN.WR.nc' are globbed beneath it.
nbm_dir = '/scratch/general/lustre/u1070830/nbm/'

# Root of the URMA archive; 'agg/urma_agg.nc' is opened from here.
urma_dir = '/scratch/general/lustre/u1070830/urma/'
# Scratch directory, created below if it does not exist yet.
# NOTE(review): these are user-specific absolute paths -- adjust before running elsewhere.
tmp_dir = '/scratch/general/lustre/u1070830/tmp/'
os.makedirs(tmp_dir, exist_ok=True)

In [9]:
# Verification window. Both ends are inclusive: matched valid times are kept
# when start_date <= t <= end_date.
start_date = datetime(2020, 10, 1, 0)
end_date = datetime(2021, 5, 15, 23, 59)

In [10]:
# Open the aggregated URMA file and keep only the 24-h precipitation variable.
urma = xr.open_dataset(urma_dir + 'agg/urma_agg.nc')
# NOTE(review): this only renames the variable ('apcp24h_mm' -> 'apcp24h_in');
# no unit conversion is applied, so the values presumably remain in mm --
# confirm the intended units before comparing against thresholds.
urma = urma['apcp24h_mm'].rename('apcp24h_in')
# Shift longitudes by -360 (presumably 0..360 -> negative-west, to line up
# with the shapefile bounds used in the masking cell -- TODO confirm).
urma['lon'] = urma['lon'] - 360

# Subset for only 0/12Z valid times.
# The hour test is vectorized and selection is positional, so an empty match
# yields an empty selection instead of an empty float64 label array.
valid_hours = pd.to_datetime(urma.valid.values).hour
urma = urma.isel(valid=np.flatnonzero(np.isin(valid_hours, [0, 12])))

In [11]:
# Target region: a single CWA identifier, or 'WESTUS' for the union of all
# CWAs retained by the time-zone filter below.
cwa = 'WESTUS'

geodir = '../forecast-zones/'
# Fail with a clear message (rather than a bare IndexError) when no shapefile
# is present in geodir.
shapefiles = glob(geodir + '*.shp')
if not shapefiles:
    raise FileNotFoundError('No *.shp file found in ' + geodir)
zones_shapefile = shapefiles[0]

# Read the shapefile
zones = gpd.read_file(zones_shapefile)

# Prune to Western Region using TZ
zones = zones.set_index('TIME_ZONE').loc[['M', 'Mm', 'm', 'MP', 'P']].reset_index()
# One polygon per CWA
cwas = zones.dissolve(by='CWA').reset_index()[['CWA', 'geometry']]
_cwas = cwas.copy()

if cwa == 'WESTUS':
    # Relabel every CWA as 'WESTUS' so the dissolve merges them into one region
    _cwas['CWA'] = 'WESTUS'
    _cwas = _cwas.dissolve(by='CWA').reset_index()
    bounds = _cwas.total_bounds
else:
    bounds = _cwas[_cwas['CWA'] == cwa].bounds.values[0]

# bounds is indexed below as [min lon, min lat, max lon, max lat]
print(bounds)

# Rasterize the region polygon(s) onto the URMA grid and keep the requested region
lons, lats = urma.lon, urma.lat
mask = regionmask.mask_3D_geopandas(_cwas, lons, lats).rename({'region':'cwa'})
mask['cwa'] = _cwas.iloc[mask.cwa]['CWA'].values.astype(str)
mask = mask.sel(cwa=cwa)

# Grid indices of every point that falls inside the bounding box
idx = np.where(
    (urma.lat >= bounds[1]) & (urma.lat <= bounds[3]) &
    (urma.lon >= bounds[0]) & (urma.lon <= bounds[2]))

# NOTE(review): slices exclude their stop index, so the last row/column inside
# the bounding box is dropped. extract_pqpf_verif_stats rebuilds the same
# slices from `idx`; change both together if the last index should be kept.
y_slice = slice(idx[0].min(), idx[0].max())
x_slice = slice(idx[1].min(), idx[1].max())

# NOTE: `urma` is overwritten with its own subset here, so re-running this
# cell without first re-running the URMA load cell recomputes `idx` on the
# already-cropped grid.
mask = mask.isel(y=y_slice, x=x_slice)
urma = urma.isel(y=y_slice, x=x_slice)

[-124.75392086   31.33426476 -104.04257965   49.00038889]


In [12]:
def extract_pqpf_verif_stats(_fhr, _urma, _mask):
    """Return region-masked NBM 'probx' and URMA arrays on their common valid times.

    Despite the name, no statistics are computed here; the function only
    aligns and masks the two datasets.

    Relies on module-level names defined in earlier cells: ``nbm_dir`` (NBM
    root directory), ``idx`` (grid indices used to crop the NBM grid) and
    ``start_date`` / ``end_date`` (inclusive verification window).

    Parameters
    ----------
    _fhr : int
        Forecast hour; selects files matching ``agg/*f{_fhr:03d}.WR.nc``.
    _urma : xarray.DataArray
        URMA analysis with a ``valid`` dimension; expected to be cropped to
        the same y/x window as the NBM slice taken below.
    _mask : xarray.DataArray
        Boolean region mask, broadcast against both datasets.

    Returns
    -------
    tuple
        ``(_nbm_masked, _urma_masked)`` with NaN wherever the mask is False.

    Raises
    ------
    FileNotFoundError
        If no NBM file matches the requested forecast hour.
    """
    # Sorted so the concatenation order is reproducible (glob order is arbitrary)
    nbm_files = sorted(glob(nbm_dir + 'agg/*f%03d.WR.nc'%_fhr))
    if not nbm_files:
        raise FileNotFoundError(
            'No NBM files match ' + nbm_dir + 'agg/*f%03d.WR.nc'%_fhr)

    # Subset the threshold value
    # combine='nested' is needed for concat_dim to be accepted by current xarray.
    nbm = xr.open_mfdataset(
        nbm_files, combine='nested', concat_dim='valid')['probx'].isel(
        y=slice(idx[0].min(), idx[0].max()),
        x=slice(idx[1].min(), idx[1].max()))

    # Subset the times: valid times present in both datasets ...
    nbm_valid = nbm.valid.values
    time_match = nbm_valid[np.isin(nbm_valid, _urma.valid.values)]
    # ... and inside the inclusive [start_date, end_date] window. The boolean
    # mask keeps the datetime dtype even when nothing matches.
    matched = pd.to_datetime(time_match)
    time_match = time_match[(matched >= start_date) & (matched <= end_date)]

    _nbm = nbm.sel(valid=time_match)
    _urma = _urma.sel(valid=time_match)
    # Broadcasting with the mask first puts the mask dimensions first in the result
    nbm_mask, _nbm = xr.broadcast(_mask, _nbm)
    urma_mask, _urma = xr.broadcast(_mask, _urma)

    # NaN-out everything outside the region
    _nbm_masked = xr.where(nbm_mask, _nbm, np.nan)
    _urma_masked = xr.where(urma_mask, _urma, np.nan)

    return _nbm_masked, _urma_masked

In [13]:
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
from sklearn.calibration import calibration_curve

# Precipitation thresholds; the values equal 0.01, 0.10, 0.25, 0.50, 1, 2, 3
# and 4 inches multiplied by 25.4 (i.e. expressed in mm).
wpc_thresholds = [0.254, 2.54, 6.35, 12.7, 25.4, 50.8, 76.2, 101.6]
# The first five thresholds above, expressed in inches.
wpc_thresholds_in = [0.01, 0.10, 0.25, 0.50, 1.0]
# One colour per entry of wpc_thresholds_in (same length; presumably paired for plotting).
wpc_colors = ['red', 'lime', 'blue', 'cyan', 'magenta']
# Number of bins; not referenced by any active cell visible in this section.
nbins = 10

In [14]:
# Forecast hour 24 (selects NBM files named '*f024.WR.nc'), using the cropped
# URMA analysis and region mask built in the earlier cells.
_nbm_masked, _urma_masked = extract_pqpf_verif_stats(24, urma, mask)

In [15]:
# threshold = 0.01*25.4

# n_bins= 10
# bins = np.linspace(0., 1. + 1e-8, n_bins + 1)

# fraction_of_positives, mean_predicted_value = [], []

# _urma_masked_binary = xr.where(_urma_masked > threshold, True, False)

# for _urma_masked_flat, _nbm_masked_flat in zip(
#     _urma_masked_binary.values.reshape(-1, _urma_masked.valid.size), 
#     _nbm_masked.sel(threshold=threshold).values.reshape(-1, _urma_masked.valid.size)):
    
#     # fp - fraction positive (ORF); mp - mean predicted (ThreshXprob)
#     fp, mp = calibration_curve(_urma_masked_flat, _nbm_masked_flat/100, n_bins=n_bins)
    
#     if len(fp) != n_bins:
#         fp = mp = np.full(n_bins, fill_value=np.nan, dtype=float)
        
#     fraction_of_positives.append(fp)
#     mean_predicted_value.append(mp)
    
# fraction_of_positives = np.array(fraction_of_positives) 
# mean_predicted_value = np.array(mean_predicted_value)

# spatial_frac_pos = xr.DataArray(
#     fraction_of_positives.reshape(np.append(_urma_masked_binary.shape[:-1], n_bins))
# ).rename({'dim_0':'y', 'dim_1':'x', 'dim_2':'bin'})

# spatial_frac_pos['lat'] = _urma_masked_binary.lat
# spatial_frac_pos['lon'] = _urma_masked_binary.lon
# spatial_frac_pos['bin'] = (bins[:-1] + bins[1:]) / 2

# spatial_frac_pos.to_netcdf('spatial_frac_pos.baseline.nc')

In [19]:
# Show the array summary (shape, chunking, dtype) of the masked NBM array.
_nbm_masked

Unnamed: 0,Array,Chunk
Bytes,9.07 GB,1.46 GB
Shape,"(897, 823, 8, 384)","(897, 823, 8, 62)"
Count,72 Tasks,7 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 9.07 GB 1.46 GB Shape (897, 823, 8, 384) (897, 823, 8, 62) Count 72 Tasks 7 Chunks Type float32 numpy.ndarray",897  1  384  8  823,

Unnamed: 0,Array,Chunk
Bytes,9.07 GB,1.46 GB
Shape,"(897, 823, 8, 384)","(897, 823, 8, 62)"
Count,72 Tasks,7 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.07 kB,496 B
Shape,"(384,)","(62,)"
Count,34 Tasks,7 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 3.07 kB 496 B Shape (384,) (62,) Count 34 Tasks 7 Chunks Type datetime64[ns] numpy.ndarray",384  1,

Unnamed: 0,Array,Chunk
Bytes,3.07 kB,496 B
Shape,"(384,)","(62,)"
Count,34 Tasks,7 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,64 B,64 B
Shape,"(8,)","(8,)"
Count,40 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 64 B 64 B Shape (8,) (8,) Count 40 Tasks 1 Chunks Type float64 numpy.ndarray",8  1,

Unnamed: 0,Array,Chunk
Bytes,64 B,64 B
Shape,"(8,)","(8,)"
Count,40 Tasks,1 Chunks
Type,float64,numpy.ndarray
