<a href="https://colab.research.google.com/github/m-wessler/nbm-verification/blob/main/get_nbm_aws_streamline_multi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
!pip install boto3
!pip install pygrib

import os, gc
import boto3
import pygrib

import numpy as np
import pandas as pd
import xarray as xr

from functools import partial
from datetime import datetime, timedelta
from multiprocessing import cpu_count, get_context

from multiprocessing import set_start_method

# Globals

In [None]:
# Multiprocess settings: eight workers per reported CPU
process_pool_size = 8 * cpu_count()
print(f'Process Pool Size: {process_pool_size}')

# Define Globals
# Bucket name passed to every S3 request
aws_bucket = 'noaa-nbm-grib2-pds'

# Where to place the grib file (subdirs can be added in local) (not used)
# output_dir = './'

# Grib variable name that each element corresponds to
element_var = dict(qpf='APCP', maxt='TMP', mint='TMP')

# Grib level that each element corresponds to
element_lev = dict(
    qpf='surface',
    maxt='2 m above ground',
    mint='2 m above ground',
)

# Any index row containing one of these substrings is excluded
excludes = ['ens std dev', '% lev']

# Fix MDL's bad kelvin thresholds: whole-kelvin value found in the grib
# message -> corrected kelvin value
tk_fix = {
    233.0: 233.15,
    244.0: 244.261,
    249.0: 249.817,
    255.0: 255.372,
    260: 260.928,
    270.0: 270.928,
    273.0: 273.15,
    299.0: 299.817,
    305.0: 305.372,
    310.0: 310.928,
    316.0: 316.483,
    322.0: 322.039,
}

# Methods

In [None]:
def mkdir_p(path):
    """Create directory `path`, including missing parents, and return `path`.

    An already existing directory is not an error, so the call can be
    repeated safely. The argument is returned unchanged so the call can be
    used inline when building file paths.
    """
    import pathlib

    directory = pathlib.Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    return path

In [None]:
def download_unzip(url, save_dir='./shapefiles/', chunk_size=128):
    """Download a zip archive once and extract it once.

    Parameters
    ----------
    url : str
        Address of the archive; the text after the last '/' is used as the
        local file name.
    save_dir : str
        Directory that receives both the archive and the extracted folder.
        Created if missing.
    chunk_size : int
        Number of bytes requested per streamed chunk.

    Returns
    -------
    str
        Path of the directory holding the extracted contents (the archive
        name with '.zip' removed, inside `save_dir`).

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status. Nothing is cached in
        that case, so the next call retries the download.
    """
    import requests
    import zipfile

    save_file = url.split('/')[-1]
    save_path = os.path.join(mkdir_p(save_dir), save_file)

    if not os.path.isfile(save_path):
        # Stream into a scratch file and rename only on success, so an
        # interrupted or failed download is never mistaken for the archive.
        partial_path = save_path + '.part'
        try:
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(partial_path, 'wb') as fd:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        fd.write(chunk)
            os.replace(partial_path, save_path)
        finally:
            if os.path.isfile(partial_path):
                os.remove(partial_path)

    unzipped_dir = mkdir_p(
        os.path.join(save_dir, save_file.replace('.zip', '')))

    # Only extract into an empty folder; a populated one is taken as done
    if not os.listdir(unzipped_dir):
        with zipfile.ZipFile(save_path, 'r') as zip_ref:
            zip_ref.extractall(unzipped_dir)

    return unzipped_dir

In [None]:
def fetch_grib_from_AWS(iter_item, save_dir='./grib2/', **req):
    """Download the selected grib2 messages for one init date from S3.

    The '.idx' companion of the remote grib2 file is read first. Rows whose
    'var' and 'level' equal the requested values are kept, rows containing
    any string from the module-level `excludes` are dropped, and the
    remaining messages are fetched by byte range and written back to back
    into a single local file.

    Parameters
    ----------
    iter_item : str
        Init date as YYYYMMDD (the value iterated over by the process pool).
    save_dir : str
        Output directory, expected to end with a path separator because the
        file name is appended directly. Created if missing.
    **req
        Request settings. The keys read here are 'hh' (int init hour),
        'lead_time_days' (int), 'var', 'level' and 'nbm_area'.

    Returns
    -------
    str
        Path of the local grib2 file. If no index row matched, nothing is
        written and the returned path does not exist.

    Notes
    -----
    Errors raised by the S3 client (for example for a missing key) are not
    caught and propagate to the caller.
    """
    from botocore import UNSIGNED
    from botocore.client import Config

    yyyymmdd = iter_item
    hh = req['hh']
    fhr = req['lead_time_days'] * 24

    nbm_sets = ['qmd'] #, 'core']

    mkdir_p(save_dir)

    output_file = (save_dir +
        f'{yyyymmdd}.t{hh:02d}z.fhr{fhr:03d}.{req["var"]}.grib2')

    # A finished download is never repeated
    if os.path.isfile(output_file):
        return output_file

    # Write to a per-process scratch file and rename at the very end, so a
    # crash mid-download cannot leave a truncated file that later runs
    # would accept as complete.
    partial_file = f'{output_file}.{os.getpid()}.part'
    bytes_written = 0

    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))

    try:
        with open(partial_file, 'wb') as wfp:
            for nbm_set in nbm_sets:

                bucket_dir = f'blend.{yyyymmdd}/{hh:02d}/{nbm_set}/'

                grib_file = (f'{bucket_dir}blend.t{hh:02d}z.'
                             f'{nbm_set}.f{fhr:03d}.{req["nbm_area"]}.grib2')

                index_file = f'{grib_file}.idx'

                print(index_file)

                index_data_raw = client.get_object(
                    Bucket=aws_bucket, Key=index_file
                    )['Body'].read().decode().split('\n')

                # Use only as many column names as the index provides
                cols = ['num', 'byte', 'date', 'var', 'level',
                    'forecast', 'fthresh', 'ftype', '']
                cols = cols[:len(index_data_raw[0].split(':'))]

                index_data = pd.DataFrame(
                    [item.split(':') for item in index_data_raw],
                    columns=cols)

                # Clean up any ghost indices (e.g. the empty row left by
                # the trailing newline), then index by message number
                index_data = index_data[index_data['num'] != ''].copy()
                index_data['num'] = index_data['num'].astype(int)
                index_data = index_data.set_index('num')

                # byte start >> byte range. The end of an HTTP range is
                # inclusive, so a message ends one byte before the next
                # one starts; the last message is open-ended (to EOF).
                starts = index_data['byte'].astype(int).tolist()
                ends = [str(nxt - 1) for nxt in starts[1:]] + ['']
                index_data['byte_range'] = [
                    f'bytes={start}-{end}'
                    for start, end in zip(starts, ends)]

                index_subset = index_data[
                    (index_data['var'] == req['var']) &
                    (index_data['level'] == req['level'])]

                # Filter out excluded vars: drop a row if any of its
                # columns contains the excluded text
                for ex in excludes:
                    mask = np.column_stack(
                        [index_subset[col].str.contains(ex, na=False)
                         for col in index_subset])

                    index_subset = index_subset.loc[~mask.any(axis=1)]

                # Fetch the data by byte range, write from stream
                for byte_range in index_subset['byte_range']:
                    output_bytes = client.get_object(
                        Bucket=aws_bucket, Key=grib_file, Range=byte_range)

                    for chunk in output_bytes['Body'].iter_chunks(
                            chunk_size=4096):
                        wfp.write(chunk)
                        bytes_written += len(chunk)

        if bytes_written:
            os.replace(partial_file, output_file)
        else:
            print(f'No matching grib messages found for {output_file}')

    finally:
        client.close()
        if os.path.isfile(partial_file):
            os.remove(partial_file)

    return output_file

In [None]:
def get_region_bounds(nws_region):
    """Return the total bounds of an NWS region or of a single CWA.

    The CWA shapefile archive is fetched through `download_unzip` and read
    with geopandas. 'WR', 'CR', 'ER' and 'SR' are matched against the
    REGION column; any other value is matched against the CWA column.

    Parameters
    ----------
    nws_region : str
        Region code or CWA identifier.

    Returns
    -------
    The `total_bounds` of the matching rows.
    """
    import geopandas as gpd

    shapefile_dir = download_unzip(
        'https://www.weather.gov/source/gis/Shapefiles/WSOM/w_08mr23.zip')
    cwas = gpd.read_file(shapefile_dir)

    if nws_region in ('WR', 'CR', 'ER', 'SR'):
        selector = f"REGION == '{nws_region}'"
    else:
        selector = f"CWA == '{nws_region}'"

    return cwas.query(selector).total_bounds

In [None]:
def grib2nc(grib_file_path, subset_bounds=None,
                interval=False, save_dir='./netcdf/'):
    """Convert the probability messages of a grib2 file into a netCDF file.

    Each retained message becomes one named variable:
    'tp_ge_<inches>' for precipitation and 'temp_<le|ge>_<degF>' for
    temperature ('.' is written as 'p' and '-' as 'm').

    Parameters
    ----------
    grib_file_path : str
        Local grib2 file; its base name (with '.grib2' -> '.nc') names the
        output.
    subset_bounds : sequence of 4 floats, optional
        (min lon, min lat, max lon, max lat) used to crop every field.
        With None the full grid is kept.
    interval : int or False
        When truthy, only messages whose endStep - startStep equals this
        value are kept and precipitation thresholds above 4.0 (after
        conversion) are skipped. When falsy no interval filter is applied.
    save_dir : str
        Output directory, expected to end with a path separator. Created
        if missing.

    Returns
    -------
    str
        Path of the netCDF file (an existing file is reused, not rebuilt).
    """
    mkdir_p(save_dir)

    netcdf_file_path = (save_dir +
        grib_file_path.split('/')[-1].replace('.grib2', '.nc'))

    # A finished conversion is never repeated
    if os.path.isfile(netcdf_file_path):
        return netcdf_file_path

    # One DataArray per retained grib message
    data_arrays = []

    with pygrib.open(grib_file_path) as grib_file:
        print('\nreading: ', grib_file_path)

        for msg in grib_file:

            smsg = str(msg).lower()

            if 'probability' not in smsg:
                continue

            if interval and (msg['endStep'] - msg['startStep']) != interval:
                continue

            # 0.0393701 ~= 1/25.4 -- presumably upperLimit is in mm and
            # this converts to inches; verify against the grib metadata
            threshold_in = (round(msg['upperLimit']*0.0393701, 2)
                            if interval else 0)

            if not (('temperature' in smsg) or (threshold_in <= 4.0)):
                continue

            if 'precipitation' in smsg:
                name = f"tp_ge_{str(threshold_in).replace('.','p')}"

            elif 'temperature' in smsg:
                below = 'below' in smsg
                gtlt = 'le' if below else 'ge'
                tk = msg['lowerLimit'] if below else msg['upperLimit']
                # Thresholds missing from the table are used as they are
                # instead of raising KeyError
                tk = tk_fix.get(tk, tk)
                tf = (tk - 273.15) * (9/5) + 32
                # Round to an int before formatting so a value a hair
                # below zero prints as '0' rather than '-0'
                name = f"temp_{gtlt}_{int(round(float(tf))):d}".replace(
                    '-', 'm')

            else:
                # No naming rule for this message; an unnamed array
                # cannot be merged into the dataset
                continue

            # validityTime is an HHMM integer; without zero padding a 00Z
            # valid time would read as '0' and fail to parse
            valid_time = datetime.strptime(
                f"{msg['validityDate']}{int(msg['validityTime']):04d}",
                '%Y%m%d%H%M')

            # Extract data and coordinates from the GRIB2 message
            data = msg.values
            lats, lons = msg.latlons()

            # Less memory intensive method to subset on read but
            # returns 2 1D arrays (LCC projection??) and need 2D
            # data, lats, lons = msg.data(lat1=nlat, lat2=xlat,
            #                             lon1=nlon, lon2=xlon)

            # 1D coordinates are taken from the first column of lats and
            # the first row of lons
            da = xr.DataArray(data,
                              coords={'lat': lats[:, 0],
                                      'lon': lons[0, :]},
                              dims=['lat', 'lon'],
                              name=name)

            if subset_bounds is not None:
                nlon, nlat, xlon, xlat = subset_bounds
                da = da.sel(lat=slice(nlat, xlat),
                            lon=slice(nlon, xlon))
            gc.collect()

            da['valid_time'] = valid_time

            data_arrays.append(da)

    # Combine the list of DataArrays into a single xarray dataset
    ds = xr.merge(data_arrays, compat='override')
    gc.collect()

    # Write to a per-process scratch file and rename on success so an
    # interrupted write is never picked up as a finished conversion
    partial_path = f'{netcdf_file_path}.{os.getpid()}.part'
    try:
        ds.to_netcdf(partial_path)
        os.replace(partial_path, netcdf_file_path)
    finally:
        if os.path.isfile(partial_path):
            os.remove(partial_path)

    return netcdf_file_path

# User Input/Multiprocessing Inputs

In [None]:
# Region (or CWA) to subset to, and the element to retrieve
region_selection = 'WR'
element = 'qpf'  # 'qpf', 'maxt' or 'mint'

# Date window, YYYYMMDD
start_date = '20230901'
end_date = '20231130'

# QPF 0/6/12/18, valid 0/6/12/18
# MaxT 6/18 valid 6
# MinT 6/18 valid 18

# Fixed keyword arguments handed to every download task
nbm_request_args = dict(
    interval=24,        # 6/12/24/48/72; replaced by False for temperature
    hh=12,              # init hour
    lead_time_days=1,   # multiplied by 24 to get the forecast hour
    nbm_area='co',
    var=element_var[element],
    level=element_lev[element],
)

# Temperature elements use no interval and an element-specific init hour
if element in ('maxt', 'mint'):
    nbm_request_args.update(
        interval=False,
        hh=6 if element == 'maxt' else 18,
    )

# Main/Multiprocessing Call

In [None]:
# Parse the YYYYMMDD strings into datetimes at 00:00
start_date, end_date = (
    datetime.strptime(day + '0000', '%Y%m%d%H%M')
    for day in (start_date, end_date))

# Hours between the init date's 00:00 and the valid time; used to move the
# window back so the chosen dates refer to valid time rather than init time
valid_hours_advance = (
    nbm_request_args['lead_time_days']*24 + nbm_request_args['hh'])

if valid_hours_advance >= 24:
    # Shift both ends back by the whole days contained in the advance
    start_date, end_date = (
        day - timedelta(days=valid_hours_advance // 24)
        for day in (start_date, end_date))

print(start_date, end_date)

In [None]:
# One YYYYMMDD string per day from start_date through end_date inclusive
date_selection_iterable = [
    (start_date + timedelta(days=offset)).strftime('%Y%m%d')
    for offset in range((end_date - start_date).days + 1)]

# Bind the request settings so each task only needs its date
multiprocess_function = partial(fetch_grib_from_AWS, **nbm_request_args)

# Set up this way for later additions (e.g. a 2D iterable)
# multiprocess_iterable = [item for item in itertools.product(
#     other_iterable, date_selection_iterable)]
multiprocess_iterable = date_selection_iterable

with get_context('fork').Pool(process_pool_size) as pool:
    n_tasks = len(multiprocess_iterable)
    print(f'Spooling up process pool for {n_tasks} tasks '
          f'across {process_pool_size} workers')

    # map blocks until every date has been processed, results in date order
    grib_output_files = pool.map(multiprocess_function, multiprocess_iterable)
    pool.terminate()
    print('Multiprocessing Complete')

In [None]:
# Look the bounds up once and bind them, together with the interval, so each
# task only needs the path of its grib file
region_bounds = get_region_bounds(region_selection)
mp_grib2nc = partial(
    grib2nc,
    subset_bounds=region_bounds,
    interval=nbm_request_args['interval'])

# Serial alternative:
# netcdf_output_files = [mp_grib2nc(f) for f in grib_output_files]

# Runs with 8 workers. Earlier note: seems to behave OK with 4 procs,
# unstable higher than ?
with get_context('fork').Pool(8) as pool:
    netcdf_output_files = pool.map(
        mp_grib2nc, grib_output_files, chunksize=1)
    pool.terminate()

# Compile along time axis
nbm = xr.open_mfdataset(
    netcdf_output_files, combine='nested', concat_dim='time')
nbm