## Build the High-Resolution (1-km grid spacing) Reflectivity Dataset

This notebooks generates the dataset for the super resolution work. The code opens the 1-km WoFS summary files and extracts a random patch of composite reflectivity from a random ensemble member. The final dataset is saved as float16 to enable additional samples. 

In [1]:
import numpy as np
import xarray as xr 

from glob import glob
import itertools

# Adding wofs_super_res to the system path. 
from os.path import dirname, basename
import sys, os ; sys.path.insert(0, dirname(dirname(os.getcwd())))
path = os.path.dirname(os.getcwd())

from wofs_super_res.util.resample import resample
from wofs_super_res.util.filtering import SpatialFilter

In [2]:
BASE_PATH_1KM = '/work/brian.matilla/WOFS_2021/summary_files/WOFS_JET/WOFS_1KM/'
BASE_PATH_3KM = '/work/mflora/SummaryFiles/'
OUT_PATH = '/work/mflora/SUPER_RES/data'

In [3]:
# Loads the DBZ. 
def is_same_file(path_1km, path_3km):
    return basename(path_1km) == basename(path_3km)

def zip_files(paths_1km, paths_3km): 
    paths = []
    for path_1km, path_3km in zip(paths_1km, paths_3km):
        if is_same_file(path_1km, path_3km):
            paths.append((path_1km, path_3km))
    return paths 

def get_random_centroid(nx, delta):
    """Get random centroid"""
    rng = np.arange(delta, nx-delta)
    i,j = np.random.choice(rng, size=2, replace=False)
    return i,j 

def save_dataset(fname, dataset):
    """ saves xarray dataset to netcdf """
    comp = dict(zlib=True, complevel=5)
    encoding = {var: comp for var in dataset.data_vars}
    #os.makedirs(os.path.dirname(fname), exist_ok=True)
    dataset.to_netcdf( path = fname, encoding=encoding )
    dataset.close( )
    del dataset

def generate_data(paths, save_fname):
    
    # patch size radius 
    patch_size = 60 # 120 x 120 patches 
    NX = 402
    
    dbz_1km_flt_set = []
    dbz_1km_set = []
    dbz_3km_set = [] 
    for path_1km, path_3km in paths:
        
        # Load the 1-km and 3-km reflectivity. 
        ds_1km = xr.load_dataset(path_1km, decode_times=False)
        ds_3km = xr.load_dataset(path_3km, decode_times=False)
        #i = np.random.choice(range(18), replace=False)
        
        # Get the lat,lon grids from the 1-km and 3-km 
        xlat_1km, xlon_1km = ds_1km['xlat'].values, ds_1km['xlon'].values
        xlat_3km, xlon_3km = ds_3km['xlat'].values, ds_3km['xlon'].values
        
        # TODO: Grab a random ensemble member
        ns = np.random.choice(range(18), size=3, replace=False)
        for n in ns:
            i,j = get_random_centroid(NX, patch_size)
            dbz_1km = ds_1km['comp_dz'].values[n, i-patch_size:i+patch_size, j-patch_size:j+patch_size]
            dbz_3km = ds_3km['comp_dz'].values[n,:,:]
        
            # Resample 3-km to 1-km 
            dbz_3km_res = resample(target_grid=(xlat_1km, xlon_1km), 
                    original_grid=(xlat_3km, xlon_3km), 
                    variable=dbz_3km)
        
            # Get the same patch from the 3-km data.
            dbz_3km_res = dbz_3km_res[i-patch_size:i+patch_size, j-patch_size:j+patch_size]
        
            # Get the coordinates of the 1-km patch. 
            xlat_1km_res = xlat_1km[i-patch_size:i+patch_size, j-patch_size:j+patch_size]
            xlon_1km_res = xlon_1km[i-patch_size:i+patch_size, j-patch_size:j+patch_size]

            # Setting the minimal resolution as 6 km as that is the minimum resolution on
            # a 3-km grid, but the effective resolution is likely larger. 
            flt = SpatialFilter(grid_spacing=1000, min_resolution=6000, filter_order=4)
            dbz_1km_flt = flt.filter(dbz_1km)
        
            dbz_1km_flt_set.append(dbz_1km_flt[:-1, :-1])
            dbz_1km_set.append(dbz_1km)
            dbz_3km_set.append(dbz_3km_res)
        
    dbz_1km_flt_set = np.array(dbz_1km_flt_set, dtype=np.float32)
    dbz_1km_set = np.array(dbz_1km_set, dtype=np.float32)
    dbz_3km_set = np.array(dbz_3km_set, dtype=np.float32)
    
    # Convert to xarray dataset.
    data = {}
    
    data['REFL_1KM'] = (['n_samples', 'ny', 'nx'], dbz_1km_set)
    data['REFL_3KM'] = (['n_samples', 'ny', 'nx'], dbz_3km_set)
    data['REFL_1KM_FILT'] = (['n_samples', 'ny', 'nx'], dbz_1km_flt_set)
    data['xlat'] = (['ny', 'nx'], xlat_1km_res)
    data['xlon'] = (['ny', 'nx'], xlon_1km_res)
    
    ds = xr.Dataset(data)
    
    save_dataset(save_fname, ds)
    
    return save_fname 

In [4]:
dates = ['20210427',
 '20210518',
 '20210527',
 '20210523',
 '20210524',
 '20210514',
 '20210526',
 '20210519',
 '20210517',
 '20210504',
 '20210503'
]

init_times = ['0000',
 '0100',
 '0200',
 '0300',
 '1900',
 '2000',
 '2100',
 '2200',
 '2300',
]
# TODO: Add parallelization! 
# NOTE: If I use the 3-km as input, I'll have to manually resample it to a 3-km grid spacing
# as it currently at a 1-km grid spacing. 

for date, init_time in itertools.product(dates, init_times):
    paths_1km = glob(os.path.join(BASE_PATH_1KM, date, init_time, f'wofs_ENS_*'))
    paths_3km = glob(os.path.join(BASE_PATH_3KM, date, init_time, f'wofs_ENS_*'))

    paths_1km.sort()
    paths_3km.sort()

    # Only keep the first 15 min; forecasts errors are likely too large after that
    # for a fair comparison of 3km -> 1km. 
    paths = zip_files(paths_1km[:3], paths_3km[:3])

    save_fname = os.path.join(OUT_PATH, f'super_res_patches_{date}{init_time}.nc')
    
    ds_1km, ds_3km = generate_data(paths, save_fname)

NameError: name 'fname' is not defined