In [7]:
import os

# Cap OpenMP/BLAS threads BEFORE anything imports numpy (cfgrib, pygrib, pandas,
# xarray and sklearn all do). Threaded BLAS/OpenMP runtimes typically read
# OMP_NUM_THREADS once when they are loaded, so setting it after those imports
# does not limit the thread pool inherited by the forked Pool workers.
os.environ['OMP_NUM_THREADS'] = '1'

import gc
import pickle
import cfgrib
import pygrib

import pandas as pd
import numpy as np
import xarray as xr

import matplotlib.pyplot as plt
import multiprocessing as mp

from glob import glob
from functools import reduce
from datetime import datetime
from sklearn.preprocessing import RobustScaler

mp_use_cores = 32       # worker processes for the parallel GRIB ingest Pool
use_era_scaler = False  # True: load the pickled scaler from mlmodel_dir; False: fit a new RobustScaler on the GFS inputs

In [8]:
# Run configuration. The paths default to the group storage used for the original
# analysis; set SLR_ARCHIVE / SLR_MODEL_DIR in the environment to run elsewhere.
# Both must end with '/' because later cells build paths by string concatenation.
model = 'gfs0p25'
archive = os.environ.get(
    'SLR_ARCHIVE',
    '/uufs/chpc.utah.edu/common/home/steenburgh-group10/mewessler/archive/')
mlmodel_dir = os.environ.get(
    'SLR_MODEL_DIR',
    '/uufs/chpc.utah.edu/common/home/steenburgh-group10/mewessler/output/slr_models/all_dev/')

In [9]:
# Model initialization time, plus the strftime patterns used to build archive paths.
date_fmt, datetime_fmt = '%Y%m%d', '%Y%m%d%H'
init = datetime(year=2020, month=1, day=1, hour=0, minute=0)

## Data Ingest

In [15]:
def _to_signed_longitude(ds):
    """Return ``ds`` with longitudes mapped from 0..360 to [-180, 180), sorted ascending."""
    return ds.assign_coords(longitude=((ds['longitude'] + 180) % 360) - 180).sortby('longitude')


def ingest_gfs(f):
    """Read one GFS GRIB2 file and return its surface and isobaric fields.

    cfgrib splits a GRIB2 file into several Datasets; only the variables named in
    ``keep_keys`` are retained.

    Parameters
    ----------
    f : str
        Path of a GRIB2 file readable by ``cfgrib.open_datasets``.

    Returns
    -------
    list
        ``[sfc, iso]``. ``sfc`` merges the single-level fields, with the surface
        temperature ``t`` dropped. ``iso`` merges the pressure-level fields, with
        ``isobaricInhPa`` renamed to ``level`` and the level order reversed.
        Both have longitudes wrapped to [-180, 180) and their ``step``
        coordinate dropped so files can be concatenated along ``valid_time``.
    """
    datasets = cfgrib.open_datasets(f)

    keep_keys = ['tp', 'q', 't', 'u', 'v', 'absv', 'w', 'gh', 'r', 'd',
                 'u10', 'v10', 'u100', 'v100', 't2m', 'd2m',
                 'cape', 'prmsl', 'sp', 'orog', 'hpbl']

    sfc, iso = [], []

    for ds in datasets:

        key_match = [name for name in ds.data_vars if name in keep_keys]
        if not key_match:
            continue

        coords = ds[key_match].coords

        if ('heightAboveGround' in coords) and ('heightAboveGround' not in ds.dims):
            # Single fixed height stored as a scalar coordinate; dropped, presumably so
            # the 2 m / 10 m / 100 m groups merge without a conflicting coordinate.
            sfc.append(ds[key_match].drop_vars('heightAboveGround'))

        elif 'isobaricInhPa' in coords:
            iso.append(ds[key_match])

        elif ('surface' in coords) or ('meanSea' in coords):
            sfc.append(ds[key_match])

        elif 'prmsl' in ds.data_vars:
            sfc.append(ds['prmsl'])

    sfc = _to_signed_longitude(xr.merge(sfc).drop_vars('t'))

    iso = xr.merge(iso).rename({'isobaricInhPa': 'level'})
    iso = _to_signed_longitude(iso.sel(level=iso.level[::-1]))

    return [sfc.drop_vars(['surface', 'meanSea', 'step']),
            iso.drop_vars('step')]

In [16]:
# Forecast GRIB2 files for this initialization. glob() returns paths in arbitrary
# filesystem order, so sort before slicing. [1:] is meant to skip the first file
# (presumably the f000 analysis -- TODO confirm), which only holds if the names
# sort by lead time.
flist = sorted(glob(archive + init.strftime(date_fmt)
                    + '/models/%s/*%s*.grib2' % (model, init.strftime(datetime_fmt))))[1:]

if not flist:
    raise FileNotFoundError('No %s GRIB2 files for init %s under %s'
                            % (model, init.strftime(datetime_fmt), archive))

# Read files in parallel; ingest_gfs returns [sfc, iso] for each file.
with mp.get_context('fork').Pool(mp_use_cores) as p:
    returns = p.map(ingest_gfs, flist, chunksize=1)
    p.close()
    p.join()

# Unzip the per-file [sfc, iso] pairs directly. Wrapping a list of Datasets in
# np.array(..., dtype=object) lets numpy try to iterate into the Datasets, which
# is fragile.
sfc_parts, iso_parts = zip(*returns)

# Stack the files along valid_time, which becomes the 'time' dimension.
iso = xr.concat(iso_parts, dim='valid_time').drop_vars('time').rename({'valid_time': 'time'}).sortby('time')
sfc = xr.concat(sfc_parts, dim='valid_time').drop_vars('time').rename({'valid_time': 'time'}).sortby('time')

In [17]:
def wind_dir_spd(u, v):
    """Return ``(direction, speed)`` for wind components ``u`` and ``v``.

    Direction is the bearing the wind blows from, in degrees clockwise from north,
    computed as ``90 - degrees(arctan2(-v, -u))`` and wrapped into (0, 360].
    Calm points (u == v == 0) are given direction 0. Speed is ``sqrt(u**2 + v**2)``.
    """
    wdir = 90 - np.degrees(np.arctan2(-v, -u))
    wdir = xr.where(wdir <= 0, wdir + 360, wdir)
    wdir = xr.where(((u == 0) & (v == 0)), 0, wdir)
    return wdir, np.sqrt(u**2 + v**2)


# Pressure-level winds.
iso['dir'], iso['spd'] = wind_dir_spd(iso['u'], iso['v'])

# 10 m and 100 m winds.
for hgt in [10, 100]:
    sfc['dir%dm' % hgt], sfc['spd%dm' % hgt] = wind_dir_spd(sfc['u%d' % hgt], sfc['v%d' % hgt])

In [18]:
# sfc['day_of_year'] = (('latitude', 'longitude'), 
#                       np.full(sfc.orog.shape, 
#                               fill_value=pd.to_datetime(
#                                   sfc.time.values).strftime('%j')).astype(int))
# sfc

## Transform to levels above ground

In [19]:
orog = sfc.orog
gh = iso.gh

# For each grid point, record the last level (in iso's level order) whose
# geopotential height is at or below the orography. If levels run bottom
# (highest pressure) upward, this is the highest underground level -- TODO confirm
# the level order produced by ingest_gfs.
lowest_level = np.full(orog.shape, fill_value=np.nan)
lowest_level_index = np.full(orog.shape, fill_value=np.nan)

for i, level in enumerate(iso['level']):

    below_ground = orog >= gh.sel(level=level)
    lowest_level = xr.where(below_ground, level.values, lowest_level)
    lowest_level_index = xr.where(below_ground, i, lowest_level_index)

# Points with no underground level start at index 0. Label them with the actual
# first level instead of a hard-coded 1000, so lowest_level and lowest_level_index
# always describe the same level (identical result when level[0] is 1000).
# NOTE(review): where a level IS underground, index + 0 is that underground level,
# while elsewhere index 0 is above ground. The "_01agl" features built from this
# index are therefore not consistently the first above-ground level. Left as-is
# because the trained model presumably used the same convention -- confirm against
# the training pipeline.
no_level_underground = np.isnan(lowest_level)
lowest_level_index = xr.where(no_level_underground, 0, lowest_level_index)
lowest_level = xr.where(no_level_underground, iso['level'].values[0], lowest_level)

In [20]:
N_LEVELS_AGL = 10  # number of levels, counted up from lowest_level_index, exported per pressure-level field

match_rename = {'absv':'vo', 'gh':'z', 'hpbl':'blh', 'prmsl':'msl', 'tp':'swe_mm',
               'u10':'u10m', 'v10':'v10m', 'u100':'u100m', 'v100':'v100m'}


def _take_levels_agl(values, base_index, offset):
    """Pick one level per grid point: ``values[t, base_index[t, y, x] + offset, y, x]``.

    Parameters
    ----------
    values : ndarray, shape (time, level, lat, lon)
    base_index : ndarray, shape (time, lat, lon)
        Level indices as floats; NaN is allowed.
    offset : int
        Added to ``base_index`` before selecting.

    Returns
    -------
    ndarray, shape (time, lat, lon), float64
        NaN wherever ``base_index + offset`` is NaN, non-integer, or outside
        ``[0, n_levels)``.
    """
    level_index = base_index + offset
    n_levels = values.shape[1]
    in_range = (np.isfinite(level_index)
                & (level_index == np.round(level_index))
                & (level_index >= 0)
                & (level_index < n_levels))
    safe_index = np.where(in_range, level_index, 0).astype(int)
    picked = np.take_along_axis(values, safe_index[:, np.newaxis], axis=1)[:, 0]
    return np.where(in_range, picked.astype(np.float64), np.nan)


# Level indices do not change per variable: order them (time, lat, lon) once.
base_index = np.asarray(lowest_level_index.transpose('time', 'latitude', 'longitude'))

# One single-column DataFrame per output feature, all indexed by (time, latitude, longitude).
frames = []

for ds in [iso, sfc.drop_vars('orog')]:

    for var_name in ds.data_vars:

        new_var_name = match_rename.get(var_name, var_name)
        print('Reducing (%s) to %s index level AGL'%(var_name, new_var_name))

        var = ds[var_name]
        coords = {'time': ds['time'],
                  'latitude': ds['latitude'],
                  'longitude': ds['longitude']}

        if 'level' in var.coords:

            values = var.transpose('time', 'level', 'latitude', 'longitude').values

            for i in range(N_LEVELS_AGL):
                var_agl = xr.DataArray(_take_levels_agl(values, base_index, i),
                                       dims=['time', 'latitude', 'longitude'],
                                       coords=coords)
                frames.append(var_agl.to_dataframe(name='%s_%02dagl'%(new_var_name.upper(), i+1)))

        else:

            var_agl = xr.DataArray(var.values,
                                   dims=['time', 'latitude', 'longitude'],
                                   coords=coords)
            frames.append(var_agl.to_dataframe(name=new_var_name.upper()))

# Every frame shares the same (time, latitude, longitude) index, so one
# index-aligned column concat replaces the chain of pairwise merges.
df = pd.concat(frames, axis=1, join='inner')
df = df.rename(columns={'SWE_MM':'swe_mm'})

Reducing (t) to t index level AGL
Reducing (gh) to z index level AGL
Reducing (u) to u index level AGL
Reducing (v) to v index level AGL
Reducing (r) to r index level AGL
Reducing (w) to w index level AGL
Reducing (absv) to vo index level AGL
Reducing (dir) to dir index level AGL
Reducing (spd) to spd index level AGL
Reducing (u10) to u10m index level AGL
Reducing (v10) to v10m index level AGL
Reducing (t2m) to t2m index level AGL
Reducing (d2m) to d2m index level AGL
Reducing (u100) to u100m index level AGL
Reducing (v100) to v100m index level AGL
Reducing (prmsl) to msl index level AGL
Reducing (cape) to cape index level AGL
Reducing (sp) to sp index level AGL
Reducing (tp) to swe_mm index level AGL
Reducing (hpbl) to blh index level AGL
Reducing (dir10m) to dir10m index level AGL
Reducing (spd10m) to spd10m index level AGL
Reducing (dir100m) to dir100m index level AGL
Reducing (spd100m) to spd100m index level AGL


In [21]:
def _latest_file(pattern):
    """Return the lexically last path matching ``pattern``.

    glob() order is filesystem-dependent, so the matches are sorted to make the
    choice deterministic. The last name is the newest only if names embed a
    sortable timestamp -- TODO confirm.

    Raises
    ------
    FileNotFoundError
        If nothing matches ``pattern``.
    """
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError('No file matches %s' % pattern)
    return matches[-1]


scaler_file = _latest_file(mlmodel_dir + '*scaler*')
stats_file = _latest_file(mlmodel_dir + '*train_stats*')
model_file = _latest_file(mlmodel_dir + '*SLRmodel*')

# NOTE: pickle.load can execute arbitrary code; only load files from this trusted model directory.
if use_era_scaler:
    with open(scaler_file, 'rb') as rfp:
        scaler = pickle.load(rfp)
else:
    # A fresh scaler, to be fitted on the GFS inputs themselves.
    scaler = RobustScaler(quantile_range=(25, 75))

with open(stats_file, 'rb') as rfp:
    train_stats, train_stats_norm = pickle.load(rfp)
    model_keys = train_stats.keys()

with open(model_file, 'rb') as rfp:
    SLRmodel = pickle.load(rfp)

In [22]:
# Sanity check: the GFS feature columns should match the keys of train_stats.
print('\ncheck: missing from model', [k for k in df.columns if k not in model_keys])
print('\ncheck: missing from input', [k for k in model_keys if k not in df.columns])
print()

# Keep only the model's features, in train_stats key order.
df = df.loc[:, list(model_keys)]

# Fit only a freshly created scaler. Refitting the pickled training scaler on
# GFS data would discard the statistics it was loaded for (use_era_scaler=True).
if not use_era_scaler:
    scaler = scaler.fit(df)


check: missing from model []

check: missing from input []



In [23]:
# Scale every feature column, keeping the (time, latitude, longitude) index.
df_norm = pd.DataFrame(
    data=scaler.transform(df),
    columns=df.columns,
    index=df.index,
)
df_norm

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,swe_mm,T_01agl,U_01agl,V_01agl,VO_01agl,W_01agl,Z_01agl,R_01agl,SPD_01agl,DIR_01agl,...,MSL,SP,U10M,V10M,U100M,V100M,SPD10M,DIR10M,SPD100M,DIR100M
time,latitude,longitude,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
2020-01-01 03:00:00,30.0,-130.00,0.0000,0.810550,-1.578271,-1.883023,-0.156490,-0.161989,-0.521525,0.088398,1.180398,-1.512315,...,0.286250,0.704357,-1.533032,-1.848136,-1.233992,-1.380398,1.159144,-1.514494,0.718654,-1.638382
2020-01-01 03:00:00,30.0,-129.75,0.0000,0.810550,-1.522281,-1.904771,-0.116318,-0.138513,-0.523108,0.118785,1.172742,-1.522722,...,0.274667,0.703124,-1.475959,-1.871487,-1.191504,-1.397411,1.152551,-1.525535,0.712262,-1.650020
2020-01-01 03:00:00,30.0,-129.50,0.0000,0.801828,-1.470022,-1.921082,-0.102928,-0.190159,-0.524818,0.146409,1.163797,-1.532157,...,0.263224,0.701701,-1.427667,-1.888470,-1.157514,-1.408753,1.145939,-1.534686,0.706124,-1.659113
2020-01-01 03:00:00,30.0,-129.25,0.0000,0.793110,-1.428962,-1.946455,-0.129709,-0.063393,-0.526444,0.162983,1.166655,-1.540690,...,0.251914,0.700467,-1.388156,-1.907575,-1.126356,-1.427183,1.145190,-1.542663,0.707577,-1.668470
2020-01-01 03:00:00,30.0,-129.00,0.0000,0.793110,-1.376704,-2.004450,-0.116318,-0.039917,-0.527673,0.168508,1.190551,-1.553535,...,0.240332,0.699424,-1.335473,-1.962768,-1.088117,-1.468298,1.168259,-1.555683,0.725493,-1.681990
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-01-08 12:00:00,50.0,-101.00,2.8750,-0.959301,-1.284737,0.696174,1.205198,-0.706215,-0.360181,0.798343,0.191396,-0.846830,...,-1.937280,0.211109,-1.377313,0.667643,-1.237626,0.815126,0.240745,-0.864722,0.208051,-0.886929
2020-01-08 12:00:00,50.0,-100.75,3.0000,-1.072674,-1.490037,0.551186,1.017729,-0.443291,-0.356054,0.784530,0.280408,-0.925948,...,-1.888062,0.218320,-1.596822,0.540274,-1.464228,0.659174,0.357281,-0.934388,0.304211,-0.979794
2020-01-08 12:00:00,50.0,-100.50,3.0625,-1.168606,-1.602018,0.442445,1.031120,-0.602923,-0.351983,0.773481,0.335135,-0.971952,...,-1.838671,0.221736,-1.684625,0.412904,-1.591691,0.545754,0.393120,-0.980561,0.369459,-1.032442
2020-01-08 12:00:00,50.0,-100.25,3.0000,-1.220932,-1.639346,0.388074,0.857041,-0.391646,-0.348605,0.767956,0.352453,-0.991760,...,-1.788941,0.222210,-1.728527,0.349220,-1.634179,0.489044,0.416121,-1.002460,0.389548,-1.054514


In [24]:
# TODO: split the prediction into bins (by time is likely simplest) and run them in parallel.
slr = (
    pd.DataFrame(SLRmodel.predict(df_norm), columns=['slr'], index=df_norm.index)
    .to_xarray()
    .slr
)

# Clamp negative predictions to zero; NaN values pass through unchanged.
slr = xr.where(slr < 0, 0, slr)

slr

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# One contour panel per forecast time on a grid sized to the data. A fixed
# 10 x 6 grid fails for more than 60 times and leaves blank panels for fewer.
ncols = 6
n_times = slr.sizes['time']
nrows = max(1, -(-n_times // ncols))  # ceiling division

fig, axs = plt.subplots(nrows, ncols, figsize=(40, 5 * nrows), squeeze=False)
axs = axs.flatten()

for ax, t in zip(axs, slr.time):
    slr.sel(time=t).plot.contourf(ax=ax, vmin=0, vmax=30, cmap='viridis_r')

# Hide the panels left over when n_times is not a multiple of ncols.
for ax in axs[n_times:]:
    ax.set_visible(False)

# TODO: add coastlines and state/province borders (ccrs.PlateCarree axes with
# cfeature.NaturalEarthFeature 'admin_1_states_provinces_lines').

plt.show()