<a href="https://colab.research.google.com/github/m-wessler/nbm-verification/blob/main/NBM_Reliability_New.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **1. Import Packages & Verify Environment**

Use pip to install packages that the runtime does not already provide.

In [None]:
# Install the packages this notebook needs that the runtime may not provide
# (boto3: S3 access to NBM grib2, pygrib: grib2 decoding, swifter: faster apply)
!pip install boto3
!pip install pygrib
!pip install swifter

In [None]:
import gc
import os
import time
import json
import boto3
import pygrib
import swifter
import requests
import itertools

import numpy as np
import xarray as xr
import pandas as pd

from glob import glob
from functools import partial
from datetime import datetime, timedelta
from multiprocessing import Pool, cpu_count
from multiprocessing import set_start_method, get_context

## **2. Define functions and methods**


Global Variables

In [None]:
# Multiprocess settings
# Oversubscribe the CPUs: the pool (section 5) runs S3 downloads, which
# presumably spend most of their time waiting on the network.
process_pool_size = cpu_count()*8
print(f'Process Pool Size: {process_pool_size}')

# Synoptic API token.
# Prefer supplying it through the SYNOPTIC_TOKEN environment variable.
# NOTE(review): a live credential is committed in this notebook as the
# fallback below; consider rotating it and removing the literal.
user_token = (os.environ.get('SYNOPTIC_TOKEN')
              or 'a2386b75ecbc4c2784db1270695dde73')

# Backend APIs
metadata_api = "https://api.synopticdata.com/v2/stations/metadata?"
qc_api = "https://api.synopticdata.com/v2/stations/qcsegments?"

# Data Query APIs
timeseries_api = "https://api.synopticdata.com/v2/stations/timeseries?"
# NOTE(review): this one uses the synopticlabs.org host while the others use
# synopticdata.com; confirm which host is intended.
statistics_api = "https://api.synopticlabs.org/v2/stations/statistics?"
precipitation_api = "https://api.synopticdata.com/v2/stations/precipitation?"

# Assign API to element name
synoptic_apis = {
    'qpf': precipitation_api,
    'maxt': statistics_api,
    'mint': statistics_api}

# Synoptic network id lists, sent as the `network=` query parameter
synoptic_networks = {"NWS+RAWS+HADS": "1,2,106",
                     "NWS+RAWS": "1,2",
                     "NWS": "1",
                     "RAWS": "2",
                     "ALL": ""}
                    #  "CUSTOM": "&network="+network_input,
                    #  "LIST": "&stid="+network_input}

# Assign synoptic variable to element name (QPF uses pmode=totals instead)
synoptic_vars = {
    'qpf': None,
    'maxt': 'air_temp',
    'mint': 'air_temp'}

# Column holding each element's value after pd.json_normalize of the response
synoptic_vars_out = {
    'qpf': 'OBSERVATIONS.precipitation',
    'maxt': 'STATISTICS.air_temp_set_1.maximum',
    'mint': 'STATISTICS.air_temp_set_1.minimum'}

# Assign stat type to element name
stat_type = {
    'qpf': 'total',
    'maxt': 'maximum',
    'mint': 'minimum'}

# Observation window [start HHMM, end HHMM] in UTC per element; the end falls
# on the following day when the query's days_offset is 1.
ob_hours = {
    'qpf': ['1200', '1200'],
    'maxt': ['1200', '0600'],
    'mint': ['0000', '1800']}

# NBM Globals
aws_bucket = 'noaa-nbm-grib2-pds'

# Where to place the grib file (subdirs can be added in local) (not used)
# output_dir = './'

# Which grib variables do each element correlate with
nbm_vars = {'qpf': 'APCP',
            'maxt': 'TMP',
            'mint': 'TMP'}

# Which grib levels do each element correlate with
nbm_levs = {'qpf': 'surface',
            'maxt': '2 m above ground',
            'mint': '2 m above ground'}

# If a grib message contains any of these, exclude
excludes = ['ens std dev', '% lev']

# Fix MDL's kelvin thresholds: map the rounded Kelvin value stored in the grib
# to the exact Kelvin equivalent of the whole-degree-F threshold it stands for
# (-40, -20, -10, 0, 10, 28, 32, 80, 90, 100, 110, 120 F).
tk_fix = {233.0: 233.15, 244.0: 244.261, 249.0: 249.817, 255.0: 255.372,
          260.0: 260.928, 270.0: 270.928, 273.0: 273.15, 299.0: 299.817,
          305.0: 305.372, 310.0: 310.928, 316.0: 316.483, 322.0: 322.039}

General Methods

In [None]:
def mkdir_p(path):
    """Ensure directory *path* exists, creating missing parents (``mkdir -p``).

    An already-existing directory is not an error. The argument is handed
    back unchanged so the call can be used inline when building file paths.
    """
    from pathlib import Path

    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
    return path

def cwa_list(input_region):
    """Return the three-letter CWA (forecast office) IDs for an NWS region.

    Parameters
    ----------
    input_region : str
        Region code, case-insensitive: ``"WR"``, ``"CR"``, ``"ER"``, ``"SR"``,
        or ``"CONUS"`` for all four combined.

    Returns
    -------
    list of str or numpy.ndarray
        The region's IDs as a list. For ``"CONUS"``, a NumPy string array
        concatenating the WR, CR, ER and SR lists in that order.

    Raises
    ------
    KeyError
        If *input_region* is not one of the codes above.
    """
    input_region = input_region.upper()

    # Each office appears exactly once (MRX was previously listed twice in SR,
    # which duplicated it in SR/CONUS results and in the cwa= API query).
    region_dict = {
        "WR": ["BYZ", "BOI", "LKN", "EKA", "FGZ", "GGW", "TFX", "VEF", "LOX", "MFR",
               "MSO", "PDT", "PSR", "PIH", "PQR", "REV", "STO", "SLC", "SGX", "MTR",
               "HNX", "SEW", "OTX", "TWC"],

        "CR": ["ABR", "BIS", "CYS", "LOT", "DVN", "BOU", "DMX", "DTX", "DDC", "DLH",
               "FGF", "GLD", "GJT", "GRR", "GRB", "GID", "IND", "JKL", "EAX", "ARX",
               "ILX", "LMK", "MQT", "MKX", "MPX", "LBF", "APX", "IWX", "OAX", "PAH",
               "PUB", "UNR", "RIW", "FSD", "SGF", "LSX", "TOP", "ICT"],

        "ER": ["ALY", "LWX", "BGM", "BOX", "BUF", "BTV", "CAR", "CTP", "RLX", "CHS",
               "ILN", "CLE", "CAE", "GSP", "MHX", "OKX", "PHI", "PBZ", "GYX", "RAH",
               "RNK", "AKQ", "ILM"],

        "SR": ["ABQ", "AMA", "FFC", "EWX", "BMX", "BRO", "CRP", "EPZ", "FWD", "HGX",
               "HUN", "JAN", "JAX", "KEY", "MRX", "LCH", "LZK", "LUB", "MLB", "MEG",
               "MAF", "MFL", "MOB", "OHX", "LIX", "OUN", "SJT", "SHV", "TAE",
               "TBW", "TSA"]}

    if input_region == "CONUS":
        return np.hstack([region_dict[region] for region in region_dict])

    try:
        return region_dict[input_region]
    except KeyError:
        raise KeyError(f'Unknown region {input_region!r}; expected one of '
                       f'{sorted(region_dict)} or CONUS') from None

Synoptic API Query Methods

In [None]:
def fetch_obs_from_API(date, cwa='', output_type='csv', use_saved=True, **req):
    """Query the Synoptic API for one observation window and save the result.

    The raw JSON response is cached under ``./obs_json/``. The station table,
    filtered to ACTIVE stations, is written under ``./obs_<output_type>/``.

    Parameters
    ----------
    date : str
        Window start date as ``YYYYMMDD``. The window runs from
        ``date`` + ``req['obs_start_hour']`` to ``date`` +
        ``req['days_offset']`` days + ``req['obs_end_hour']``.
    cwa : str
        Comma-separated CWA list sent as the ``cwa`` query parameter. It is
        also used in the file names when ``req['region']`` is falsy.
    output_type : str
        ``'csv'`` or ``'pickle'``; any other value skips writing the table.
    use_saved : bool
        Reuse an existing output file or cached JSON instead of querying.
    **req
        Query settings: ``region``, ``element``, ``ob_stat``, ``api``,
        ``network_query``, ``obs_start_hour``, ``obs_end_hour``,
        ``vars_query`` and ``days_offset``.

    Returns
    -------
    str or None
        - The existing output path when a saved file is reused.
        - ``None`` after a successful fetch and save.
        - *date* when the response could not be parsed, so a caller can
          collect dates to retry.
    """
    valid = True
    cwa_filename = req['region'] if req['region'] else cwa

    output_dir = mkdir_p(f'./obs_{output_type}/')

    output_file = output_dir + f'obs.{req["element"]}.{req["ob_stat"]}' +\
                    f'.{date}.{cwa_filename}.{output_type}'

    if os.path.isfile(output_file) and use_saved:
        return output_file

    json_dir = mkdir_p('./obs_json/')

    json_file = json_dir + f'obs.{req["element"]}.{req["ob_stat"]}' +\
                    f'.{date}.{cwa_filename}.json'

    adjusted_end_date = (datetime.strptime(date, '%Y%m%d') +
                         timedelta(days=req['days_offset'])
                         ).strftime('%Y%m%d')

    if os.path.isfile(json_file) and use_saved:
        # Reuse the cached raw response (read-only access is enough)
        with open(json_file, 'rb') as rfp:
            response_dataframe = pd.json_normalize(json.load(rfp)['STATION'])

    else:
        api_query_args = {
            'api_token': f'&token={user_token}',
            'station_query': f'&cwa={cwa}',
            'network_query': f'&network={req["network_query"]}',
            'start_date_query': f'&start={date}{req["obs_start_hour"]}',
            'end_date_query': f'&end={adjusted_end_date}{req["obs_end_hour"]}',
            'vars_query': (f'&pmode=totals' if req["element"] == 'qpf'
                else f'&vars={req["vars_query"]}'),
            'stats_query': f'&type={req["ob_stat"]}',
            'timezone_query': '&obtimezone=utc',
            'api_extras': '&fields=name,status,latitude,longitude,elevation'}

        api_query = req['api'] + ''.join(api_query_args.values())

        print(f'Polling API for: {date}\n{api_query}')

        response, status_code, response_count = None, None, 0
        while (status_code != 200) and (response_count <= 10):
            print(f'{date}, HTTP:{status_code}, #:{response_count}')

            # Don't sleep first try, sleep increasing amount for each retry
            time.sleep(2*response_count)
            response_count += 1

            try:
                # A timeout keeps a stalled connection from hanging the run
                response = requests.get(api_query, timeout=120)
            except requests.RequestException as err:
                # Network failure: report it and let the loop retry
                print(f'{date}, request failed: {err}')
                continue

            status_code = response.status_code

        valid = response is not None
        if valid:
            try:
                response_dataframe = pd.json_normalize(
                    response.json()['STATION'])
            except (ValueError, KeyError, TypeError):
                # Not JSON, or no STATION payload (e.g. an API error body)
                valid = False
            else:
                with open(json_file, 'wb') as wfp:
                    wfp.write(response.content)

    if not valid:
        return date

    # Check ACTIVE flag (Can disable in config above if desired).
    # .copy() so adding columns below modifies a real frame, not a slice.
    response_dataframe = response_dataframe[
        response_dataframe['STATUS'] == "ACTIVE"].copy()

    # Un-nest the QPF totals; stations without a precipitation record get NaN
    if req['element'] == 'qpf':
        response_dataframe['TOTAL'] = [
            obs[0]['total'] if isinstance(obs, list) and obs else np.nan
            for obs in response_dataframe['OBSERVATIONS.precipitation']]

    if output_type == 'pickle':
        # Save out df as pickle
        response_dataframe.to_pickle(output_file)

    elif output_type == 'csv':
        # Save out df as csv
        response_dataframe.to_csv(output_file)

    return None

NBM Query Methods

In [None]:
def ll_to_index(loclat, loclon, datalats, datalons):
    """Find the grid cell nearest to the point (*loclat*, *loclon*).

    Distance is measured as the larger of the absolute latitude and longitude
    differences, in the units of the grid arrays. On ties, the first cell in
    flattened (C) order wins.

    Returns
    -------
    tuple
        ``(row, col)`` index into arrays shaped like *datalons*.
    """
    # index, loclat, loclon = loclatlon
    lat_gap = np.abs(datalats - loclat)
    lon_gap = np.abs(datalons - loclon)
    nearest_flat = np.argmin(np.maximum(lon_gap, lat_gap))
    return np.unravel_index(nearest_flat, datalons.shape)

def fetch_grib_from_AWS(iter_item, save_dir='./grib2/', **req):
    """Download the selected NBM GRIB2 messages for one init date from S3.

    The ``.idx`` file drives byte-range requests so that only messages whose
    ``var``/``level`` match ``req['var']``/``req['level']`` (and that contain
    none of the ``excludes`` strings) are fetched. They are concatenated into
    a single local file.

    Parameters
    ----------
    iter_item : str
        Init date as ``YYYYMMDD``.
    save_dir : str
        Directory for the output file (created if missing).
    **req
        Must provide ``hh`` (init hour), ``lead_time_days``, ``element``,
        ``nbm_area``, ``var`` and ``level``.

    Returns
    -------
    str
        Path of the output file. If no message matched, no file is written
        but the path is still returned.
    """
    from botocore import UNSIGNED
    from botocore.client import Config

    yyyymmdd = iter_item

    nbm_sets = ['qmd']  # , 'core']

    mkdir_p(save_dir)

    output_file = (save_dir +
        f'{yyyymmdd}.t{req["hh"]:02d}z.fhr{req["lead_time_days"]*24:03d}.{req["element"]}.grib2')

    if os.path.isfile(output_file):
        return output_file

    # Download into a temporary file and move it into place only once every
    # message is written, so an interrupted run never leaves a partial file
    # that the isfile() check above would mistake for a complete one.
    partial_file = output_file + '.part'
    if os.path.isfile(partial_file):
        os.remove(partial_file)

    wrote_any = False
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    try:
        for nbm_set in nbm_sets:

            bucket_dir = f'blend.{yyyymmdd}/{req["hh"]:02d}/{nbm_set}/'

            grib_file = f'{bucket_dir}blend.t{req["hh"]:02d}z.' +\
                        f'{nbm_set}.f{req["lead_time_days"]*24:03d}.{req["nbm_area"]}.grib2'

            index_file = f'{grib_file}.idx'

            print(index_file)

            index_data = _read_grib_index(client, index_file)
            byte_ranges = _select_byte_ranges(index_data, req['var'], req['level'])

            if not byte_ranges:
                continue

            # Fetch the data by byte range, write from stream
            with open(partial_file, 'ab') as wfp:
                for byte_range in byte_ranges:
                    output_bytes = client.get_object(
                        Bucket=aws_bucket, Key=grib_file, Range=byte_range)

                    for chunk in output_bytes['Body'].iter_chunks(chunk_size=4096):
                        wfp.write(chunk)
                    wrote_any = True
    finally:
        client.close()

    if wrote_any:
        os.replace(partial_file, output_file)

    return output_file


def _read_grib_index(client, index_file):
    """Download a GRIB2 ``.idx`` file; return it as a DataFrame indexed by message number."""
    index_data_raw = client.get_object(
        Bucket=aws_bucket, Key=index_file)['Body'].read().decode().split('\n')

    cols = ['num', 'byte', 'date', 'var', 'level',
            'forecast', 'fthresh', 'ftype', '']

    # Keep only as many column names as the index lines actually have
    cols = cols[:len(index_data_raw[0].split(':'))]

    index_data = pd.DataFrame(
        [item.split(':') for item in index_data_raw], columns=cols)

    # Clean up any ghost rows (e.g. from the trailing newline), set the index
    index_data = index_data[index_data['num'] != '']
    return index_data.astype({'num': int}).set_index('num').sort_index()


def _select_byte_ranges(index_data, var, level):
    """Build HTTP ``Range`` header values for the wanted messages.

    A message spans from its own start offset to just before the next
    message's start. The last message in the file gets an open-ended range
    (read to EOF). Messages containing any ``excludes`` string are skipped.
    """
    starts = index_data['byte'].astype(int)
    # Next message (by message number) starts where this one ends
    next_starts = starts.shift(-1)

    index_subset = index_data[
        (index_data['var'] == var) & (index_data['level'] == level)]

    # Filter out excluded vars
    for ex in excludes:
        mask = np.column_stack([index_subset[col].str.contains(ex, na=False)
                                for col in index_subset])
        index_subset = index_subset.loc[~mask.any(axis=1)]

    byte_ranges = []
    for num in index_subset.index:
        end = next_starts.loc[num]
        if pd.isna(end):
            byte_ranges.append(f'bytes={starts.loc[num]}-')
        else:
            # HTTP byte ranges are inclusive: stop one byte before the next message
            byte_ranges.append(f'bytes={starts.loc[num]}-{int(end) - 1}')
    return byte_ranges

def extract_nbm_value(grib_index, nbm_data):
    """Return the element of *nbm_data* at position *grib_index*.

    The data comes first in the signature's usage via ``functools.partial``
    (``nbm_data=...``), leaving the per-station ``(row, col)`` index from
    ``ll_to_index`` as the single mapped argument.
    """
    point_value = nbm_data[grib_index]
    return point_value

# **3. Set Global Variables & User Configuration**

In [None]:
# Collect user inputs
element = 'qpf'
element = element.lower()  # Failsafe

region_selection = 'WR'
cwa_selection = 'SLC'

start_date = '20231101'
end_date = '20231115'

interval_selection = 24  # 6/12/24/48/72, if element==temp then False
init_hour_selection = 12  # int(input('Desired init hour int(HH)? '))
lead_days_selection = 1  # int(input('Desired forecast hour/lead time int(HHH)?'))

# Fail fast on a typo instead of with a KeyError several cells later
if element not in nbm_vars:
    raise ValueError(f'Unsupported element {element!r}; '
                     f'choose one of {sorted(nbm_vars)}')

# Convert user input to datetime objects (midnight of each date)
start_date, end_date = [datetime.strptime(date+'0000', '%Y%m%d%H%M')
    for date in [start_date, end_date]]

# A reversed range would silently produce no dates and fail much later
if start_date > end_date:
    raise ValueError(f'start_date {start_date:%Y%m%d} is after '
                     f'end_date {end_date:%Y%m%d}')

### **3a. Build arg dicts and clean up configs**


In [None]:
# Keyword arguments shared by every Synoptic observation query
synoptic_api_args = dict(
    obs_start_hour=ob_hours[element][0],
    obs_end_hour=ob_hours[element][1],
    ob_stat=stat_type[element],
    api=synoptic_apis[element],
    element=element,
    region=region_selection,
    network_query=synoptic_networks['NWS+RAWS'],  # add config feature later
    vars_query=(f'{synoptic_vars[element]}' if element != 'qpf' else None),
    days_offset=(0 if element == 'mint' else 1),
)

# One YYYYMMDD string per day from start_date through end_date, inclusive
n_days = (end_date - start_date).days + 1
date_selection_iterable = [
    (start_date + timedelta(days=offset)).strftime('%Y%m%d')
    for offset in range(n_days)]

# A region selection takes precedence over the single-CWA selection
if region_selection is not None:
    cwa_query = ','.join(cwa_list(region_selection))
else:
    cwa_query = cwa_selection

# Bind the fixed query arguments so only the date varies per call
multiprocess_function = partial(fetch_obs_from_API, cwa=cwa_query,
                                **synoptic_api_args)

In [None]:
# QPF 0/6/12/18, valid 0/6/12/18
# MaxT 6/18 valid 6
# MinT 6/18 valid 18

is_temp_element = element in ('maxt', 'mint')

# Keyword arguments for fetch_grib_from_AWS. Temperature elements have no
# accumulation interval and use a fixed init hour (06Z maxt, 18Z mint).
nbm_request_args = {
    #'yyyymmdd':yyyymmdd, #input('Desired init date (YYYYMMDD)? '),
    'interval': False if is_temp_element else interval_selection,
    'hh': ((6 if element == 'maxt' else 18) if is_temp_element
           else init_hour_selection),
    'lead_time_days': lead_days_selection,
    'nbm_area': 'co',
    'element': element,
    'var': nbm_vars[element],
    'level': nbm_levs[element],
}

# Shift the date range back by the whole days between init and valid time so
# the init dates fetched verify between the chosen dates
valid_hours_advance = (
    nbm_request_args['hh'] + nbm_request_args['lead_time_days'] * 24)

if valid_hours_advance >= 24:
    days_back = timedelta(days=int(valid_hours_advance / 24))
    start_date, end_date = start_date - days_back, end_date - days_back

print(start_date, end_date)

# **4. Acquire Observations**

In [None]:
import ast

# Multithreaded requests currently not supported by the Synoptic API
for iter_item in date_selection_iterable:
    multiprocess_function(iter_item)

# with Pool(process_pool_size) as pool:
#     print(f'Spooling up process pool for {len(multiprocess_iterable)} tasks '
#           f'across {process_pool_size} workers')

#     retry = pool.map(multiprocess_function, multiprocess_iterable)
#     pool.terminate()

#     print('Multiprocessing Complete')

# Glob together csv files
# Need to filter by variable/region in case of region change or re-run!
synoptic_varname = synoptic_vars_out[element]

searchstring = (f'*{element}*{region_selection}*.csv'
    if region_selection is not None else f'*{element}*{cwa_selection}*.csv')

obs_files = glob(os.path.join('./obs_csv/', searchstring))
if not obs_files:
    # pd.concat([]) would only say "No objects to concatenate"
    raise FileNotFoundError(
        f'No observation files match ./obs_csv/{searchstring}')

df = pd.concat(map(pd.read_csv, obs_files), ignore_index=True)

if element == 'qpf':
    def _first_precip_record(cell):
        """Return the first precipitation-summary dict stored in a CSV cell.

        The CSV holds the API's list of dicts as a Python repr, which
        ast.literal_eval parses safely (swapping quotes for JSON breaks on
        values such as None). Missing or empty cells yield {} so every
        observation row keeps exactly one matching output row.
        """
        if not isinstance(cell, str):
            return {}
        records = ast.literal_eval(cell)
        return records[0] if records else {}

    # Un-nest precipitation observations, aligned row-for-row with df
    df_qpf = pd.DataFrame(
        [_first_precip_record(cell) for cell in df[synoptic_varname]],
        index=df.index)

    df = df.drop(columns=synoptic_varname).join(df_qpf)

    # Rename the variable since we've changed the column name
    synoptic_varname = 'total'

# Identify the timestamp column (changes with variable); last match wins
time_cols = [k for k in df.keys() if ('date_time' in k) or ('last_report' in k)]
if not time_cols:
    raise KeyError("No observation time column ('date_time' or "
                   "'last_report') found in the observation files")
time_col = time_cols[-1]

df.rename(columns={time_col: 'timestamp'}, inplace=True)
time_col = 'timestamp'

# Convert read strings to datetime object
df[time_col] = pd.to_datetime(df['timestamp']).round('60min')

if element == 'maxt':
    # The 12Z-06Z max-temp window spans two UTC dates. Attribute a max stamped
    # at or before 06Z to the day prior; otherwise keep the stamped date, so
    # both halves of one window land on the same day.
    # NOTE(review): confirm this date convention matches the NBM valid date
    # used when matching obs to grib files in section 6.
    before_cutoff = df['timestamp'].dt.hour <= 6
    df['timestamp'] = df['timestamp'].where(
        ~before_cutoff, df['timestamp'] - pd.Timedelta(1, unit='D')).dt.date

elif element == 'mint':
    df['timestamp'] = df['timestamp'].dt.date

elif element == 'qpf':
    # Might need to do something different here so breaking into own elif...
    df['timestamp'] = df['timestamp'].dt.date

# Drop any NaNs and sort by date with station as secondary index
df.set_index(['timestamp'], inplace=True)
df = df[df.index.notnull()].reset_index().set_index(['timestamp', 'STID'])
df.sort_index(inplace=True)

df = df[['LATITUDE', 'LONGITUDE', 'ELEVATION', synoptic_varname]]
df = df.rename(columns={synoptic_varname: element.upper()})

df

# **5. Acquire NBM Data**

In [None]:
# Build the list of NBM init dates (YYYYMMDD) from the shifted date range
iter_date = start_date
date_selection_iterable = []
while iter_date <= end_date:
    date_selection_iterable.append(iter_date.strftime('%Y%m%d'))
    iter_date += timedelta(days=1)

# Assign the fixed kwargs to the function
multiprocess_function = partial(fetch_grib_from_AWS, **nbm_request_args)

# Set up this way for later additions (e.g. a 2D iterable)
# multiprocess_iterable = [item for item in itertools.product(
#     other_iterable, date_selection_iterable)]
multiprocess_iterable = date_selection_iterable

# Don't start more workers than there are dates to fetch; Pool needs >= 1
n_workers = max(1, min(process_pool_size, len(multiprocess_iterable)))

# 'fork' lets workers inherit this notebook's definitions and globals
# (fork is POSIX-only, so this cell does not run on Windows).
with get_context('fork').Pool(n_workers) as pool:
    print(f'Spooling up process pool for {len(multiprocess_iterable)} tasks '
          f'across {n_workers} workers')
    grib_output_files = pool.map(multiprocess_function, multiprocess_iterable)
# Leaving the with-block terminates the pool; map() has already returned
print('Multiprocessing Complete')

# **6. Calculate Statistics**

**Start by matching obs with NBM values in order to streamline the bulk stats**


In [None]:
# Match each observation with the NBM probabilities for the same valid date.
# Loop over dates in the DataFrame, open one NBM file at a time.
for valid_date in df.index.get_level_values(0).unique():

    # We are looping over the observation dates... the filenames are stamped
    # with the INIT DATE. We need to offset the observation dates to work!
    init_date = valid_date - pd.Timedelta(
        nbm_request_args['lead_time_days'], 'day')

    print(f'i:{init_date}, v:{valid_date}')

    datestr = init_date.strftime('%Y%m%d')
    nbm_file = f'./grib2/{datestr}.t{nbm_request_args["hh"]:02d}z' +\
            f'.fhr{nbm_request_args["lead_time_days"]*24:03d}.{element}.grib2'

    if not os.path.isfile(nbm_file):
        print(f'{nbm_file} not found, skipping')
        continue

    nbm = pygrib.open(nbm_file)
    try:
        # If not yet indexed, go ahead and build the indexer (first file only)
        if 'grib_index' not in df.columns:

            nbmlats, nbmlons = nbm.message(1).latlons()

            # One row per station (first reported coordinates) so the join on
            # STID below cannot duplicate observation rows when a station's
            # coordinates vary between files.
            df_indexed = df.reset_index()[
                ['STID', 'LATITUDE', 'LONGITUDE', 'ELEVATION']
            ].drop_duplicates(subset='STID').copy()

            ll_to_index_mapped = partial(ll_to_index,
                                         datalats=nbmlats, datalons=nbmlons)

            print('\nFirst pass: creating y/x grib indicies from lat/lon\n')

            df_indexed['grib_index'] = df_indexed.swifter.apply(
                lambda x: ll_to_index_mapped(x.LATITUDE, x.LONGITUDE), axis=1)

            # Extract the grid latlon of each station's matched grid point
            df_indexed['grib_lat'] = df_indexed['grib_index'].apply(
                partial(extract_nbm_value, nbm_data=nbmlats))

            df_indexed['grib_lon'] = df_indexed['grib_index'].apply(
                partial(extract_nbm_value, nbm_data=nbmlons))

            df_indexed.set_index('STID', inplace=True)

            df = df.join(
                df_indexed[['grib_index', 'grib_lat', 'grib_lon']]).sort_index()

        # Store each probability message as a column for this valid date
        for msg in nbm:
            msg_desc = str(msg)

            if 'Probability' not in msg_desc:
                continue

            # Deal with column names
            if 'Precipitation' in msg_desc:
                # Threshold mm -> inches for the column name
                threshold_in = round(msg['upperLimit']*0.0393701, 2)
                name = f"tp_ge_{str(threshold_in).replace('.','p')}"

            elif 'temperature' in msg_desc:
                below = 'below' in msg_desc
                gtlt = 'le' if below else 'ge'
                tk = msg['lowerLimit'] if below else msg['upperLimit']
                # Map rounded grib thresholds to exact Kelvin when known
                tk = tk_fix.get(tk, tk)
                tc = tk - 273.15
                tf = tc * (9/5) + 32
                name = f"temp_{gtlt}_{tf:.0f}".replace('-', 'm')

            else:
                # Unrecognized probability field: skip it rather than
                # overwrite the column named by a previous message
                continue

            if name not in df.columns:
                df[name] = np.nan

            extract_nbm_value_mapped = partial(extract_nbm_value,
                                               nbm_data=msg.values)

            df.loc[valid_date, name] = df.loc[valid_date]['grib_index'].apply(
                extract_nbm_value_mapped).values
    finally:
        nbm.close()

# Remove rows with missing data
df = df.dropna(how='any')
df

**Proceed to calculate the statistics from the pandas DataFrame**

In [None]:
# Sample query... Build from here
# df.query('QPF >= 1.0')['tp_ge_1p0'].hist()

# **7. Display Statistics**

# **8. Plot Visualizations**
