<a href="https://colab.research.google.com/github/m-wessler/nbm-verification/blob/main/get_nbm_aws_streamline_multi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
!pip install boto3
!pip install pygrib

import os, gc
import boto3
import pygrib

import numpy as np
import pandas as pd
import xarray as xr

from functools import partial
from botocore import UNSIGNED
from botocore.client import Config
from datetime import datetime, timedelta
from multiprocessing import cpu_count, Pool

# Globals

In [None]:
# Multiprocess settings
# The worker pool is sized at eight processes per reported CPU
process_pool_size = 8 * cpu_count()
print(f'Process Pool Size: {process_pool_size}')

# Define Globals
# Name of the S3 bucket that every download request is issued against
aws_bucket = 'noaa-nbm-grib2-pds'

# Where to place the grib file (subdirs can be added in local) (not used)
# output_dir = './'

# GRIB variable short name requested for each supported element
element_var = dict(qpf='APCP', maxt='TMP', mint='TMP')

# GRIB level string requested for each supported element
element_lev = dict(
    qpf='surface',
    maxt='2 m above ground',
    mint='2 m above ground')

# Any grib index record containing one of these strings is skipped
excludes = ['ens std dev', '% lev']

# Methods

In [None]:
def _select_byte_ranges(index_lines, var, level, exclude_terms):
    """Turn the lines of a GRIB ``.idx`` file into HTTP ``Range`` header values.

    Each idx record is a ':'-separated line whose leading fields are
    ``num``, ``byte``, ``date``, ``var``, ``level`` (then forecast/threshold
    descriptors). ``byte`` is the offset at which that message starts.

    Parameters
    ----------
    index_lines : list of str
        The idx file split into lines.
    var, level : str
        Only records whose ``var`` and ``level`` fields equal these are kept.
    exclude_terms : iterable of str
        A record is dropped when any of its fields contains one of these
        substrings.

    Returns
    -------
    list of str
        One ``'bytes=<first>-<last>'`` value per selected message, in file
        order. The final message in the file is open ended (``'bytes=<first>-'``).
    """
    # Blank/short lines (e.g. the empty record after the trailing newline)
    # carry no message and are dropped before neighbours are looked up.
    records = [fields for fields in
               (line.split(':') for line in index_lines if line.strip())
               if len(fields) >= 5]

    byte_ranges = []
    for position, fields in enumerate(records):
        if fields[3] != var or fields[4] != level:
            continue

        if any(term in field for field in fields for term in exclude_terms):
            continue

        first_byte = int(fields[1])

        if position + 1 < len(records):
            # HTTP byte ranges are inclusive at both ends, so the message ends
            # one byte before the next message starts.
            last_byte = int(records[position + 1][1]) - 1
            byte_ranges.append(f'bytes={first_byte}-{last_byte}')
        else:
            # Last message in the file: read through to EOF
            byte_ranges.append(f'bytes={first_byte}-')

    return byte_ranges


def fetch_grib_from_AWS(iter_item, **req):
    """Download the requested GRIB2 messages for one init date to a local file.

    The idx file that sits next to each remote GRIB2 file is read first so that
    only the messages matching ``req['var']``/``req['level']`` (and not matching
    any entry of the global ``excludes``) are fetched, by byte range.

    Parameters
    ----------
    iter_item : str
        Init date as ``YYYYMMDD``; this is the value mapped over by the pool.
    **req
        Fixed request settings: ``hh`` (int init hour), ``fhr`` (int forecast
        hour), ``nbm_area`` (str), ``var`` (str) and ``level`` (str).

    Returns
    -------
    str
        Path of the local output file. If no message matched, nothing is
        written and the returned path does not exist.

    Notes
    -----
    Reads the notebook globals ``element``, ``aws_bucket`` and ``excludes``.
    Errors raised by the S3 client (e.g. a missing key) propagate to the caller.
    """
    yyyymmdd = iter_item
    init_hour, lead_time = req['hh'], req['fhr']

    # The 'qmd' set is requested in addition to 'core' only for qpf at init
    # hours divisible by 6.
    wants_qmd = (element == 'qpf') and (init_hour % 6 == 0)
    nbm_sets = ['core', 'qmd'] if wants_qmd else ['core']

    output_file = f'{yyyymmdd}.t{init_hour:02d}z.fhr{lead_time:03d}.{req["var"]}.grib2'

    # An existing file is treated as an already-completed download
    if os.path.isfile(output_file):
        return output_file

    # Stream into a scratch file and rename only on success, so an interrupted
    # download can never be mistaken for a finished one by the check above.
    partial_file = f'{output_file}.part'
    bytes_written = 0

    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))

    try:
        with open(partial_file, 'wb') as wfp:
            for nbm_set in nbm_sets:

                bucket_dir = f'blend.{yyyymmdd}/{init_hour:02d}/{nbm_set}/'

                grib_file = f'{bucket_dir}blend.t{init_hour:02d}z.'+\
                            f'{nbm_set}.f{lead_time:03d}.{req["nbm_area"]}.grib2'

                index_file = f'{grib_file}.idx'

                index_lines = client.get_object(
                    Bucket=aws_bucket, Key=index_file)['Body'].read().decode().split('\n')

                byte_ranges = _select_byte_ranges(
                    index_lines, req['var'], req['level'], excludes)

                # Fetch the data by byte range, write from stream
                for byte_range in byte_ranges:
                    response = client.get_object(
                        Bucket=aws_bucket, Key=grib_file, Range=byte_range)

                    for chunk in response['Body'].iter_chunks(chunk_size=4096):
                        bytes_written += wfp.write(chunk)

        if bytes_written:
            os.replace(partial_file, output_file)

    finally:
        # Clean up the scratch file on failure or when nothing matched
        if os.path.exists(partial_file):
            os.remove(partial_file)
        client.close()

    return output_file

In [None]:
def grib2xarray(grib_file_path, precip_interval=None):
    """Read the probability messages of a GRIB2 file into an xarray Dataset.

    A message is kept when its string representation contains 'Probability',
    its ``lengthOfTimeRange`` equals ``precip_interval`` and its threshold
    (``upperLimit`` converted with the mm-to-inch factor) is at most 4.0.

    Parameters
    ----------
    grib_file_path : str
        Path of the GRIB2 file to open with pygrib.
    precip_interval : int, optional
        Required value of the message's ``lengthOfTimeRange`` key.

    Returns
    -------
    xarray.Dataset
        One variable per kept threshold, named ``tp_gt_<threshold>`` with the
        decimal point written as 'p', each carrying a scalar ``valid_time``
        coordinate. Variables are merged with ``compat='override'``.
    """
    # NOTE(review): presumably `upperLimit` is in mm, making the thresholds
    # inches -- confirm against the NBM product definition.
    mm_to_inches = 0.0393701
    max_threshold_in = 4.0

    # Create a list to store the data arrays for each variable
    data_arrays = []

    # Open the GRIB2 file using pygrib
    with pygrib.open(grib_file_path) as grib_file:
        print('\nreading: ', grib_file_path)

        # Iterate over each message in the GRIB2 file
        for msg in grib_file:

            # `and` short-circuits, so the time-range key is only read for
            # probability messages.
            if not ('Probability' in str(msg)
                    and msg['lengthOfTimeRange'] == precip_interval):
                continue

            threshold_in = round(msg['upperLimit']*mm_to_inches, 2)

            if not (threshold_in <= max_threshold_in):
                continue

            # validityTime is an HHMM integer (0, 600, 1200, ...); pad it to
            # four digits so every valid hour parses, not only 00Z.
            valid_time = datetime.strptime(
                f"{msg['validityDate']}{msg['validityTime']:04d}",
                '%Y%m%d%H%M')

            # Extract data and coordinates from the GRIB2 message
            data = msg.values
            lats, lons = msg.latlons()

            # NOTE(review): only the first column of lats and first row of
            # lons are used as 1-D coordinates -- confirm this is acceptable
            # for the grid being read.
            da = xr.DataArray(data,
                              coords={'lat': lats[:, 0],
                                      'lon': lons[0, :]},
                              dims=['lat', 'lon'])

            da.name = f"tp_gt_{str(threshold_in).replace('.','p')}"
            da['valid_time'] = valid_time

            # Add the DataArray to the list
            data_arrays.append(da)

    # Combine the list of DataArrays into a single xarray dataset
    ds = xr.merge(data_arrays, compat='override')
    gc.collect()

    return ds

# User Input/Multiprocessing Inputs

In [None]:
# Element to process; keys of element_var/element_lev are the valid choices
element = 'qpf' #input('Desired element? (QPF/MaxT/MinT)').lower()

# Inclusive date range, entered as YYYYMMDD strings
start_date = '20231001'
end_date = '20231031'

# Immediately convert user input to datetime objects (midnight of each day)
start_date = datetime.strptime(f'{start_date}0000', '%Y%m%d%H%M')
end_date = datetime.strptime(f'{end_date}0000', '%Y%m%d%H%M')

# Main/Multiprocessing Call

In [None]:
# Fixed keyword arguments shared by every download task; the init date is
# supplied separately, one per task
nbm_request_args = dict(
    # yyyymmdd=yyyymmdd,  #input('Desired init date (YYYYMMDD)? ')
    hh=0,  #int(input('Desired init hour int(HH)? '))
    fhr=24,  #int(input('Desired forecast hour/lead time int(HHH)?'))
    nbm_area='co',
    var=element_var[element],
    level=element_lev[element],
)

In [None]:
# Enumerate every init date (YYYYMMDD) from start_date through end_date
day_count = (end_date - start_date).days
date_selection_iterable = [
    (start_date + timedelta(days=offset)).strftime('%Y%m%d')
    for offset in range(day_count + 1)]

# Bind the fixed request settings so the pool only has to supply the date
multiprocess_function = partial(fetch_grib_from_AWS, **nbm_request_args)

# Set up this way for later additions (e.g. a 2D iterable)
# multiprocess_iterable = [item for item in itertools.product(
#     other_iterable, date_selection_iterable)]
multiprocess_iterable = date_selection_iterable

with Pool(process_pool_size) as pool:
    task_count = len(multiprocess_iterable)
    print(f'Spooling up process pool for {task_count} tasks across {process_pool_size} workers')

    # One download per init date; results come back in input order
    output_files = pool.map(multiprocess_function, multiprocess_iterable)

    pool.terminate()
    print('Multiprocessing Complete')

In [None]:
# Reader with the accumulation interval fixed at 24
mp_grib2xarray = partial(grib2xarray, precip_interval=24)

# NOTE(review): only the first four downloaded files are read, on a fixed
# pool of 8 workers -- confirm whether this subset is intentional.
files_to_read = output_files[:4]

with Pool(8) as pool:
    ds_list = pool.map(mp_grib2xarray, files_to_read)
    pool.terminate()

# Compile along time axis
ds = xr.concat(ds_list, dim='valid_time')