## Data Workflow 

In [None]:
# Conda requirement: python 3.8, xarray, numpy, scipy, numba, monte_python, scikit-explain

# Import required libraries 
from glob import glob 
from os.path import join

# Xarray is a great library for processing n-dimensional arrays! 
import xarray as xr 
import numpy as np 
import numba as nb 
from scipy.ndimage import maximum_filter
from skexplain.common.multiprocessing_utils import run_parallel, to_iterator 

# Download https://github.com/WarnOnForecast/wofs_ml_severe
# change your system path.
import sys
sys.path.append('/home/monte.flora/python_packages/wofs_ml_severe')
from wofs_ml_severe.data_pipeline.storm_report_loader import StormReportLoader

# To use MontePython, clone it to your local directory
# and then run `python setup.py install` 
import monte_python 
from monte_python.object_identification import quantize_probabilities

In [None]:
# Notebook-wide settings (base paths, ensemble configuration, patch geometry).

# Root directory of the WoFS summary files and the directory for output files.
BASE_DATA_PATH = '/work/mflora/SummaryFiles/'
OUTDIR = '/work/mflora/tmp'

# N_JOBS == 1 processes the cases in a plain loop; any other value
# dispatches them through run_parallel.
N_JOBS = 1

# Number of ensemble members looped over when classifying storm modes.
ENSEMBLE_SIZE = 18

# Patches are only extracted when the maximum supercell probability exceeds
# this value (4 of the 18 members).
PROB_THRESH = 4 / ENSEMBLE_SIZE

# Half-width of a patch in grid points. On a 3-km grid, a radius of 20 gives
# a 40 x 40 patch, i.e. a 120 x 120 km area.
PATCH_SIZE_RADIUS = 20

# Variables from the SVR summary file that are ensemble-averaged for the patches.
SVR_VARIABLES = ['srh_0to1', 'cape_ml', 'cin_ml']

### Questions 

1. Do we limit the lead time to at most one hour?

For an instantaneous NMEP valid at time t, we want the data from t-15, t-30, t-45, and t-60 min.



In [None]:
def load_data(date, init_time, time_index):
    """Load the WoFS ENS and SVR summary files for one forecast time step.

    Parameters
    --------------------
    date : str
        Name of the date sub-directory under BASE_DATA_PATH (e.g., '20210524').

    init_time : str
        Name of the initialization-time sub-directory (e.g., '2300').

    time_index : int
        Forecast time index that appears in the summary file names.
        Must be at least 12.

    Returns
    --------------------
    ds_ens, ds_svr : xarray.Dataset
        The ENS and SVR summary files, fully loaded into memory with
        time decoding disabled.

    Raises
    --------------------
    ValueError
        If time_index is smaller than 12.
    FileNotFoundError
        If no ENS or SVR summary file matches the requested case.
    """
    # TODO: ENS has to be at a minimum of an hour.
    # to prevent sampling too close to initialization.
    # An explicit exception is used (rather than `assert`) so the check
    # still runs when Python is started with -O.
    if time_index < 12:
        raise ValueError(f'time_index must be >= 12, got {time_index}')

    case_dir = join(BASE_DATA_PATH, date, init_time)

    paths = {}
    for kind in ('ENS', 'SVR'):
        pattern = join(case_dir, f'wofs_{kind}_{time_index}*')
        # Sort so the chosen file is deterministic if several files match.
        matches = sorted(glob(pattern))
        if not matches:
            raise FileNotFoundError(f'No WoFS {kind} summary file matches {pattern}')
        paths[kind] = matches[0]

    ds_ens = xr.load_dataset(paths['ENS'], decode_times=False)
    ds_svr = xr.load_dataset(paths['SVR'], decode_times=False)

    return ds_ens, ds_svr

In [None]:
def save_dataset(fname, dataset):
    """Save an xarray dataset to a zlib-compressed netCDF file.

    Parameters
    --------------------
    fname : str
        Path of the netCDF file to write. Missing parent directories
        are created.

    dataset : xarray.Dataset
        Dataset to write. It is closed afterwards, whether or not the
        write succeeded.
    """
    import os

    # Create the output directory up front; skip when fname has no directory
    # part, because os.makedirs('') raises.
    out_dir = os.path.dirname(fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Apply the same compression settings to every data variable
    # (a separate dict per variable so none of them share state).
    encoding = {var: dict(zlib=True, complevel=5) for var in dataset.data_vars}

    try:
        dataset.to_netcdf(path=fname, encoding=encoding)
    finally:
        dataset.close()

def compute_ensemble_mean(dataset, data_vars):
    """Average each requested variable over its leading axis.

    Returns a list with one averaged field per entry of `data_vars`,
    in the same order.
    """
    ens_means = []
    for name in data_vars:
        ens_means.append(np.mean(dataset[name], axis=0))
    return ens_means

def extract_patch(data, centers, delta=10):
    """Extract a square patch around each center.

    Parameters
    --------------------
    data : array of shape (v, y, x)
        Stack of 2D fields to cut the patches from.

    centers : iterable of (y, x) pairs
        Patch centers in grid-point coordinates. Non-integer values
        (e.g., object centroids) are rounded to the nearest grid point.

    delta : int (default=10)
        Half-width of a patch; each patch spans 2*delta grid points
        in both y and x.

    Returns
    --------------------
    patches : list of arrays, each of shape (v, 2*delta, 2*delta)
        One patch per center, in the same order. The patches are views
        into `data`. A center closer than `delta` to an edge is shifted
        inward so that the patch stays inside the domain and keeps its
        full size.

    Raises
    --------------------
    ValueError
        If the domain is smaller than one patch.
    """
    # Plain Python/NumPy on purpose: slicing only creates views, so there is
    # nothing for a JIT compiler to speed up here.
    n_y, n_x = data.shape[-2], data.shape[-1]
    if n_y < 2 * delta or n_x < 2 * delta:
        raise ValueError(
            f'Domain of size ({n_y}, {n_x}) is too small for patches of width {2 * delta}'
        )

    patches = []
    for obj_y, obj_x in centers:
        # Ensure the centers do not conflict with the boundaries: without the
        # clamp, a negative start index would wrap around and the patch
        # would come out empty or with the wrong shape.
        row = min(max(int(round(obj_y)), delta), n_y - delta)
        col = min(max(int(round(obj_x)), delta), n_x - delta)
        patches.append(data[:, row - delta:row + delta, col - delta:col + delta])

    return patches

In [None]:
def compute_supercell_probability(dataset, max_size=5, classify_embedded=True):
    """Supercell neighborhood maximum ensemble probability (NMEP).

    Each ensemble member is run through the storm mode classification
    scheme of Potvin et al. (2022); the supercell regions of every member
    are widened with a maximum filter and then averaged across members
    (NMEP; Schwartz and Sobash 2017).

    Parameters
    --------------------
    dataset : xarray.dataset
        A WoFS ENS summary file containing `comp_dz` and `uh_2to5_instant`.

    max_size : int (default=5)
        The maximum filter diameter (in grid points).

    classify_embedded : True/False (default=True)
        With classify_embedded=False the storm mode classification is
        restricted to a 3-mode scheme, which runs much faster than the
        7-mode scheme.

    Returns
    --------------------
    supercell_prob : 2D array
        Fraction of the ENSEMBLE_SIZE members with a (filtered) supercell
        at each grid point.

    References:
        Schwartz, C. S. & Sobash, R. A. (2017). Generating probabilistic forecasts from
        convection-allowing ensembles using neighborhood approaches:
        A review and recommendations. Monthly Weather Review.
        https://doi.org/10.1175/mwr-d-16-0400.1

        Potvin, C. K., and co-authors (2022). An Iterative Storm Segmentation
        and Classification Algorithm for Convection-Allowing Models and Gridded Radar Analyses,
        Journal of Atmospheric and Oceanic Technology, 39(7), 999-1013.
    """
    # Pull the member-stacked fields out of the dataset once.
    all_dbz = dataset['comp_dz'].values
    all_rot = dataset['uh_2to5_instant'].values

    member_masks = []
    for mem in range(ENSEMBLE_SIZE):
        classifier = monte_python.StormModeClassifier()
        storm_modes, _labels, _dbz_props = classifier.classify(
            all_dbz[mem, :, :],
            all_rot[mem, :, :],
            classify_embedded=classify_embedded,
        )

        # Keep only the grid points classified as supercell.
        supercell_id = classifier.MODES.index('SUPERCELL') + 1
        is_supercell = np.where(storm_modes == supercell_id, 1, 0)

        # The maximum filter spreads each region out, which reduces
        # phase errors between members.
        member_masks.append(maximum_filter(is_supercell, size=max_size))

    # Averaging the binary masks gives the ensemble probability.
    return np.mean(member_masks, axis=0)

In [None]:
def get_storm_patches(supercell_prob, ds_svr):
    """Identify supercell objects in the NMEP field and extract patches
    centered on them.

    Parameters
    --------------------
    supercell_prob : 2D array
        Supercell NMEP on the model grid.

    ds_svr : xarray.Dataset
        A WoFS SVR summary file containing the variables in SVR_VARIABLES.

    Returns
    --------------------
    dataset : xarray.Dataset
        One variable per entry of SVR_VARIABLES (named `<var>_ens_mean`)
        plus `supercell probs`, each with dimensions
        (n_samples, ny, nx). n_samples is zero when no object is found.
    """
    # Watershed settings passed to monte_python.label.
    # NOTE(review): the thresholds look like ensemble-member counts
    # (max_thresh equals ENSEMBLE_SIZE) -- confirm against monte_python.
    params = {'min_thresh': 5,
              'max_thresh': 18,
              'data_increment': 1,
              'area_threshold': 500,
              'dist_btw_objects': 25}

    input_data = quantize_probabilities(supercell_prob, ENSEMBLE_SIZE)

    _, sup_props = monte_python.label(input_data=input_data,
                                      method='watershed',
                                      return_object_properties=True,
                                      params=params)

    # Use those supercells to center the patches. Centroids are float-valued,
    # so round them to the nearest grid point before they are used as
    # slice indices.
    centers = [tuple(int(round(coord)) for coord in region.centroid)
               for region in sup_props]

    # Stack the ensemble-mean environment fields and the NMEP into (v, y, x).
    fields = compute_ensemble_mean(ds_svr, SVR_VARIABLES)
    fields.append(supercell_prob)
    fields = np.array(fields)

    variables = SVR_VARIABLES + ['supercell probs']

    if centers:
        patches = np.array(extract_patch(fields, centers, delta=PATCH_SIZE_RADIUS))
    else:
        # No objects were found: return zero samples with the usual patch
        # shape instead of failing on the indexing below.
        width = 2 * PATCH_SIZE_RADIUS
        patches = np.empty((0, len(variables), width, width), dtype=fields.dtype)

    data_vars = {f'{v}_ens_mean' if v in SVR_VARIABLES else v
                 : (['n_samples', 'ny', 'nx'], patches[:, i, :, :])
                 for i, v in enumerate(variables)}

    # Convert data to xarray.Dataset.
    return xr.Dataset(data_vars)

## Example Workflow

In [None]:
# Dates, initialization times, and forecast time indices to process.
# The three lists are combined by `to_iterator` below and each resulting
# (date, init_time, time_index) case is handed to `worker_fn`.
dates = ['20210524', '20210526']
init_times = ['2300']
time_indices = [12]

# TODO: add the storm reports!!

def worker_fn(date, init_time, time_index):
    """Process one case: load the data, compute the supercell probability,
    and save the storm patches. Usable as a multiprocessing worker.

    Parameters
    --------------------
    date, init_time : str
        Date and initialization time of the case.

    time_index : int
        Forecast time index of the case.

    Returns
    --------------------
    out_path : str or None
        Path of the netCDF file that was written, or None when the data
        could not be loaded, the supercell probability never exceeded
        PROB_THRESH, or no patches were found.
    """
    # Load the data. A failed load is reported and the case is skipped;
    # only Exception is caught so that Ctrl-C still stops the run.
    try:
        ds_ens, ds_svr = load_data(date, init_time, time_index)
    except Exception as err:
        print(f'Unable to load data for {date}, {init_time}, {time_index}: {err!r}')
        return None

    # Compute the probability of a supercell.
    supercell_prob = compute_supercell_probability(ds_ens, max_size=5, classify_embedded=True)

    # Check that supercell prob exceeds some threshold!
    if np.max(supercell_prob) > PROB_THRESH:
        # Get data patches.
        dataset = get_storm_patches(supercell_prob, ds_svr)

        # Save the data, unless there is nothing to save.
        if dataset.sizes.get('n_samples', 0) > 0:
            out_path = join(OUTDIR, f'wofs_data_{date}_{init_time}_{time_index}.nc')
            save_dataset(out_path, dataset)
            return out_path

    return None
        
# Build the case iterator once; it is consumed by whichever branch runs.
args_iterator = to_iterator(dates, init_times, time_indices)

if N_JOBS != 1:
    # Hand the cases to the parallel runner.
    run_parallel(worker_fn, args_iterator, n_jobs=N_JOBS)
else:
    # Serial run: announce each case, then process it.
    for date, init_time, time_idx in args_iterator:
        print(date, init_time, time_idx)
        worker_fn(date, init_time, time_idx)

In [None]:
# You can plot data using built-in plot.
# The patches only exist as a file written by `worker_fn`, so reload the file
# of the first case here; this keeps the cell runnable on a fresh kernel.
# The file is only written when the supercell probability exceeded
# PROB_THRESH, so it may not exist for every case.
example_path = join(OUTDIR, f'wofs_data_{dates[0]}_{init_times[0]}_{time_indices[0]}.nc')
dataset = xr.load_dataset(example_path)
dataset['supercell probs'][0,:,:].plot();

### Optional Plotting code

In [None]:
# (Optional) Plot the storm modes and labels to get familiar with the data!
import matplotlib.pyplot as plt

# Everything plotted below is rebuilt in this cell, so it does not depend on
# variables that only exist inside the functions above.
member = 0  # ensemble member shown in the dBZ, storm-label, and storm-mode panels
ds_ens, _ = load_data(dates[0], init_times[0], time_indices[0])
supercell_prob = compute_supercell_probability(ds_ens, max_size=5, classify_embedded=True)

# Supercell objects, using the same watershed settings as get_storm_patches.
sup_labels, sup_props = monte_python.label(
    input_data=quantize_probabilities(supercell_prob, ENSEMBLE_SIZE),
    method='watershed',
    return_object_properties=True,
    params={'min_thresh': 5,
            'max_thresh': 18,
            'data_increment': 1,
            'area_threshold': 500,
            'dist_btw_objects': 25},
)

dbz_vals = ds_ens['comp_dz'].values[member,:,:]
rot_vals = ds_ens['uh_2to5_instant'].values[member,:,:]
clf = monte_python.StormModeClassifier()
storm_modes, labels, dbz_props = clf.classify(dbz_vals, rot_vals,
                                              classify_embedded=False)

# meshgrid takes the x coordinates first, so the returned arrays have the
# same (ny, nx) shape as the plotted fields even on a non-square grid.
n_y, n_x = dbz_vals.shape
x, y = np.meshgrid(range(n_x), range(n_y))
fig, axes = plt.subplots(dpi=300, ncols=2, nrows=2, figsize=(8,8))

axes[0,0].contourf(x,y,dbz_vals, alpha=0.6, levels=np.arange(20,75,5), cmap='jet')
monte_python.plot_storm_labels(x, y, labels, dbz_props, ax=axes[0,1])
monte_python.plot_storm_modes(x, y, storm_modes, dbz_props, clf.converter, ax=axes[1,1])

axes[1,0].contourf(x,y,supercell_prob, alpha=0.6, levels=np.arange(0.1, 1.1, 0.1), cmap='rainbow')
monte_python.plot_storm_labels(x, y, sup_labels, sup_props, ax=axes[1,0], alpha=0.6)

# Titles follow the row-major order of axes.flat.
titles = ['WoFS dBZ', 'Storm Labels', 'Supercell Regions', 'Storm Modes']
for ax, title in zip(axes.flat, titles):
    ax.set_title(title)