In [None]:
import wrf
import pandas as pd
import numpy as np
import xarray as xr
import metpy.calc as mpc
import multiprocessing as mp

from functools import partial
from metpy.units import units as mpu
from datetime import datetime

In [None]:
def calc_wbzh(tw, gh, orog, wbzparam=0.5):
    """Height of the wet-bulb zero (WBZ): where Tw falls through ``wbzparam``.

    The profile is scanned from the top level downward. In each layer where
    Tw is >= ``wbzparam`` at the lower level and <= ``wbzparam`` at the upper
    level, the WBZ height is linearly interpolated in Tw. The first (highest)
    such crossing found is kept.

    Parameters
    ----------
    tw : xarray.DataArray
        Wet-bulb temperature (degC) with a pressure ``level`` dimension (hPa).
    gh : xarray.DataArray
        Geopotential height on the same levels as ``tw``.
    orog : xarray.DataArray
        Surface orography. Currently unused; kept for interface compatibility.
    wbzparam : float, optional
        Wet-bulb threshold (degC). 0.5 as in the Western Region technical
        attachment.

    Returns
    -------
    xarray.DataArray
        WBZ height with ``level`` reduced out. 0 where the entire column has
        Tw < ``wbzparam``; NaN where no crossing was found otherwise.

    Raises
    ------
    ValueError
        If fewer than two levels are supplied.
    """
    # Order by increasing pressure so index 0 is the top of the profile and
    # index i lies below index i-1. The bot/top naming and the crossing test
    # below depend on this, so don't rely on the caller's level order.
    tw = tw.sortby('level')
    gh = gh.sortby('level')

    if tw.level.size < 2:
        raise ValueError('calc_wbzh needs at least two pressure levels')

    wbzh = None
    for i in range(1, tw.level.size):
        print('Searching for WBZ between %d and %d hPa'
              % (tw.level.values[i - 1], tw.level.values[i]))

        # Level i is the lower (higher-pressure) one, level i-1 the upper one
        gh_bot, gh_top = gh.isel(level=i), gh.isel(level=i - 1)
        tw_bot, tw_top = tw.isel(level=i), tw.isel(level=i - 1)

        # Linear interpolation of the WBZ height, used only where it is in this layer
        interp_wbzh = gh_bot + ((wbzparam - tw_bot) * ((gh_top - gh_bot) / (tw_top - tw_bot)))
        in_layer = (tw_bot >= wbzparam) & (tw_top <= wbzparam)

        if wbzh is None:
            # First layer: establish the WBZ height array (NaN = not found yet)
            wbzh = xr.where(in_layer, interp_wbzh, np.nan)
        else:
            # Fill only where no (higher) crossing has been found yet
            wbzh = xr.where(in_layer & np.isnan(wbzh), interp_wbzh, wbzh)

    # Where NaNs remain because the entire column has Tw < wbzparam, fill with 0 m AMSL
    wbzh = xr.where(np.isnan(wbzh) & (tw.max(dim='level') < wbzparam), 0, wbzh)

    return wbzh

In [None]:
def calc_tlayer(t, gh, orog):
    """Layer temperature used by the SLR curve fit.

    Heights are taken relative to ``orog + 500`` m (the 500 m buffer from the
    original method; possibly a bias correction — see Alcott). The result is
    the temperature at that height, linearly extrapolated from the lowest level
    above it and the level immediately above that one.

    Fallbacks:
    * if 1000 hPa is already above ``orog + 500``, the 1000 hPa temperature;
    * if 500 hPa is below ``orog + 500``, the 500 hPa temperature;
    * 0 if no level satisfies the search (should not normally happen).

    Parameters
    ----------
    t : xarray.DataArray
        Temperature (degC) with a pressure ``level`` dimension (hPa) that must
        include 1000 and 500.
    gh : xarray.DataArray
        Geopotential height on the same levels as ``t``.
    orog : xarray.DataArray
        Surface elevation, broadcastable against ``gh``.

    Returns
    -------
    xarray.DataArray
        Layer temperature, in the same units as ``t``, with ``level`` reduced out.
    """
    # The neighbour logic below walks the profile upward from 1000 hPa:
    # index 0 must be the highest pressure so that i+1 is the level above and
    # i-1 the level below. Enforce that order regardless of the input order.
    t = t.sortby('level', ascending=False)
    gh = gh.sortby('level', ascending=False)

    # Geopotential height relative to ground level + 500 m buffer
    gh_agl = gh - (orog + 500.0)
    # Avoid an exact 0.0 m height (it would satisfy neither the > 0 nor the < 0 test)
    gh_agl = xr.where(gh_agl == 0.0, 1.0, gh_agl)

    # Start from T(1000 hPa) where 1000 hPa is above ground + buffer, else 0
    tvals = xr.where(gh_agl.sel(level=1000) > 0, t.sel(level=1000), 0)

    nlev = t.level.size
    for i in range(nlev):
        # Neighbouring levels; both wrap at the ends of the profile
        # (i+1 past the top returns to index 0; i-1 at index 0 is the last level).
        up = i + 1 if i + 1 < nlev else 0
        dn = i - 1

        zc, zup, zdn = gh_agl.isel(level=i), gh_agl.isel(level=up), gh_agl.isel(level=dn)
        tc, tup = t.isel(level=i), t.isel(level=up)

        # Lowest level above ground + buffer (the level below it is underground):
        # extrapolate T linearly from this level and the one above to gh_agl == 0.
        first_above = (zc > 0.0) & (zdn < 0.0)
        extrapolated = (zc / (zc - zup)) * (tup - tc) + tc
        tvals = xr.where(first_above, extrapolated, tvals)

    # Failsafe (probably redundant): if 500 hPa is below ground + buffer, use T500
    tlayer = xr.where(gh_agl.sel(level=500) < 0, t.sel(level=500), tvals)

    return tlayer

In [None]:
def calc_slr(tlayer, wbzh, orog, all_snow_buffer=0, transition_layer=200):
    """Snow-to-liquid ratio (SLR) from layer temperature and snow level.

    The SLR curve is a piecewise fit to the Alcott and Steenburgh (2010) SMLR
    results:

    * tlayer >= 0             -> 5
    * -15 <= tlayer < 0       -> 5 - tlayer        (rises to 20 at -15)
    * -20 <= tlayer < -15     -> 20 + (tlayer + 15) (falls to 15 at -20)
    * tlayer < -20            -> 15

    That value applies at or above the snow level. Within ``transition_layer``
    below the snow level it is scaled linearly toward 0, and below that it is 0.

    Parameters
    ----------
    tlayer : xarray.DataArray
        Layer temperature (degC).
    wbzh : xarray.DataArray
        Wet-bulb zero height (m AMSL), used as the base snow level.
    orog : xarray.DataArray
        Surface elevation (m AMSL), broadcastable against ``wbzh``.
    all_snow_buffer : float, optional
        Metres by which the snow level is lowered below ``wbzh``. Default 0.
    transition_layer : float, optional
        Depth (m) below the snow level over which SLR is attenuated to 0.
        Default 200.

    Returns
    -------
    xarray.DataArray
        SLR. Where ``wbzh`` is NaN, both height comparisons are False and the
        result is 0.
    """
    # Extend the snow level below the wet-bulb-zero height if a buffer is set,
    # never below sea level
    snow_level = wbzh - all_snow_buffer
    snow_level = xr.where(snow_level < 0., 0., snow_level)

    # Curve fit to Alcott and Steenburgh (2010) SMLR results
    init_slr = xr.where(tlayer < 0., 5. - tlayer, 5.)
    init_slr = xr.where(tlayer < -15., 20. + (tlayer + 15.), init_slr)
    init_slr = xr.where(tlayer < -20., 15., init_slr)

    # Keep the initial SLR above the snow level; 0 elsewhere
    slr = xr.where(orog >= snow_level, init_slr, 0.)

    # Linear attenuation of the SLR within the transition layer
    transition_base = snow_level - transition_layer
    in_transition = (orog < snow_level) & (orog > transition_base)
    slr = xr.where(in_transition,
                   (init_slr * (orog - transition_base) / transition_layer), slr)

    return slr

In [None]:
def wetbulb(ti, pres, tkel, qv, units):
    """Wet-bulb temperature for one time step via ``wrf.wetbulb``.

    Parameters
    ----------
    ti : time coordinate value to select from each input.
    pres, tkel, qv : xarray.DataArray
        Pressure, temperature and mixing ratio, each with a ``time`` coordinate.
    units : forwarded to ``wrf.wetbulb`` as its 4th *positional* argument.

    NOTE(review): wrf-python's signature is believed to be
    ``wetbulb(pres, tkel, qv, meta=True, units='K')``. If so, this value lands
    in ``meta`` rather than ``units``, and the result stays in Kelvin.
    Downstream code subtracts 273.15, so confirm before switching to
    ``units=units``.
    """
    at_time = {'time': ti}
    pres_t, tkel_t, qv_t = (field.sel(at_time) for field in (pres, tkel, qv))
    return wrf.wetbulb(pres_t, tkel_t, qv_t, units)

In [None]:
# Load temperature (t), geopotential height (gh) and relative humidity (r)
# on level indices 20-32. The context manager closes the file once the
# subset is loaded into memory.
# NOTE(review): this index slice is expected to cover 500-1000 hPa, since
# calc_tlayer selects both levels; confirm against the file's level order.
with xr.open_dataset('./CLNX_12h_delay12_extract_gfs_ISO.nc') as raw:
    data = raw[['t', 'gh', 'r']].isel(level=slice(20, 33)).load()

# Optional: restrict to a time window for faster iteration
# data = data.sel(time=slice(datetime(2020, 1, 1, 0), datetime(2020, 1, 15, 12)))
data

In [None]:
# Point surface elevation from the GFS FV3 orography file at the data's
# latitude/longitude. The file is indexed with 0-360 degree longitudes;
# `(lon + 360) % 360` maps negative longitudes into that range and leaves
# longitudes already in 0-360 unchanged.
with xr.open_dataset('./gfs_fv3_orog.nc') as orog_ds:
    # Pull the value out inside the `with` so the file can be closed afterwards
    orog_value = orog_ds['orog'].sel(
        latitude=data.latitude, longitude=(data.longitude + 360) % 360).values

# Repeat the time-invariant elevation along the time axis so it aligns with `data`
orog = xr.DataArray(np.full(data.time.shape, fill_value=orog_value),
                    dims='time', coords={'time': data.time})
data['orog'] = orog

In [None]:
# A high-resolution wet-bulb temperature is not needed: the difference between
# computing Tw before or after downscaling is negligible.
print('Calculating Tw...')

# Pressure (hPa) at every point, broadcast from the level coordinate and laid
# out in the same dimension order as data.t
p = (xr.ones_like(data.t) * data.level).transpose(*data.t.dims).rename('p')

# Water-vapour mixing ratio (kg/kg). Keyword arguments keep this call correct
# across MetPy versions: MetPy 1.0 changed the positional order from
# (relative_humidity, temperature, pressure) to (pressure, temperature, relative_humidity).
mixing_ratio = mpc.mixing_ratio_from_relative_humidity(
    pressure=p.values * mpu.millibar,
    temperature=data.t.values * mpu.kelvin,
    relative_humidity=(data.r.values / 100) * mpu.dimensionless)

# qv inherits its coordinates from data.t; only the values are replaced.
# m_as strips the units explicitly instead of via a warning-raising np.array()
qv = data.t.copy().rename('qv')
qv.values = np.asarray(mixing_ratio.m_as('kg/kg'))

# Tw comes back in Kelvin. The `wetbulb` helper forwards its `units` argument
# positionally to wrf.wetbulb (believed to be the `meta` slot), so a 'degC'
# request was never honoured. Request 'K' so the label matches the Kelvin
# result that the WBZ step's `- 273.15` expects. Pressure is passed in Pa.
wetbulb_mp = partial(wetbulb, pres=p*100, tkel=data.t, qv=qv, units='K')
tw = xr.concat([wetbulb_mp(ti) for ti in data.time.values], dim='time')

# Repair the dimensions after wrf messes with them
tw['time'] = data.time
tw['level'] = data.level
tw['lat'] = data.latitude
tw['lon'] = data.longitude

# Sanity check: physically plausible wet-bulb temperatures in K are well above 100
assert float(tw.max()) > 100.0, 'Tw expected in Kelvin'

data['tw'] = tw
print('Done')

data['tw']

In [None]:
# Wet-bulb zero height. data['tw'] is in Kelvin, calc_wbzh works in degC.
# Levels are passed in their stored order (an unassigned `data.sel(...)`
# expression has no effect in Python, so none is used here).
wbzh = calc_wbzh(data['tw'] - 273.15, data['gh'], data['orog'])

# Times at which a WBZ height was found
nn_times = wbzh[~np.isnan(wbzh)].time

data['wbzh'] = wbzh
data['wbzh']

In [None]:
# Layer temperature from the Kelvin temperature field converted to degC
tlayer = calc_tlayer(t=data['t'] - 273.15, gh=data['gh'], orog=data['orog'])
tlayer

In [None]:
# Snow-to-liquid ratio from layer temperature and wet-bulb zero height
slr = calc_slr(tlayer=tlayer, wbzh=wbzh, orog=data['orog'])
slr = slr.rename('slr')
slr

In [None]:
# Distribution of SLR in unit-width bins from 0 to 30. The trailing ';'
# suppresses the (counts, bins, patches) tuple repr so only the figure shows.
slr.plot.hist(bins=np.arange(0, 31, 1), edgecolor='k');

In [None]:
# Persist the SLR series for downstream use
slr.to_netcdf(path='webslr.nc')