In [1]:
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import sys
import datetime as dt
import ujson
import kerchunk 
import fsspec

import dask
from dask.distributed import Client

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from settings import DATA_DIR, JSON_DIR

import importlib

# https://github.com/lsterzinger/kerchunk-medium-tutorial/blob/main/tutorial.ipynb


In [8]:
def parse_date_arg(date, default=None):
    """Normalise a user-supplied date argument to a pandas Timestamp.

    Resolution order: ``date`` if given, otherwise ``default`` if given,
    otherwise today's date at midnight (local time).

    Args:
        date: Anything ``pd.to_datetime`` accepts, or None.
        default: Fallback value used when ``date`` is None. Defaults to None.

    Returns:
        pd.Timestamp: The parsed date.
    """
    if date is not None:
        return pd.to_datetime(date)
    if default is not None:
        return pd.to_datetime(default)
    # Neither argument supplied: fall back to today, truncated to the day.
    return pd.to_datetime(pd.Timestamp.today().date())


class HRRRInitKerchunk:

    def __init__(self, var_class: str = 'sfc', type_of_level: str = 'heightAboveGround',
                 level: int = 2, local_json: bool = True):
        """Load collections of HRRR analysis (f00) GRIB files as a single zarr-like dataset.

        Kerchunk reference JSONs are generated for each GRIB file on NOAA's public s3
        bucket and then combined into a single JSON catalog that xarray opens lazily.

        Common combinations of type_of_level and level are:
            * 'heightAboveGround' (levels 2, 10, 80). Temperature, wind, moisture.
            * 'isobaricInhPa' (levels 1000, 925, ... 10). Pressure levels. Geopotential height,
                temperature, wind, moisture.
            * 'surface' (level 0). Surface variables.
            * 'atmosphere' (level 0). Total atmosphere, eg. cloud cover.

        Args:
            var_class (str, optional): Which group of variables to draw from (the ``wrf<var_class>``
                part of the file name). Typical values are 'sfc' (surface + key pressure levels)
                and 'prs' (all pressure levels). Will probably use 'sfc' for everything.
                Defaults to 'sfc'.
            type_of_level (str, optional): GRIB ``typeOfLevel`` at which to pull data.
                Defaults to 'heightAboveGround'.
            level (int, optional): Atmospheric data level, operates in concert with type_of_level.
                Defaults to 2.
            local_json (bool, optional): Write json locally if True, on AWS if False.
                Only True is currently supported. Defaults to True.

        Raises:
            NotImplementedError: If local_json is False.
        """
        self.var_class = var_class
        self.local_json = local_json
        self.type_of_level = type_of_level
        self.level = level

        # File system information, populated by _set_file_system_info().
        self.fs_remote = None
        self.fs_local = None
        self.json_dir = None
        self.json_combined = None
        self._set_file_system_info()

    def _set_file_system_info(self):
        """Set parameters needed for interacting with remote (NOAA s3) and local file systems.

        Creates ``<JSON_DIR>/hrrr_init/<var_class>/<type_of_level>/<level>`` if it does not exist.

        Raises:
            NotImplementedError: Storing the JSON files on s3 is not implemented yet.
        """
        # S3 remote filesystem: where the HRRR data is coming from (public bucket, anonymous access).
        self.fs_remote = fsspec.filesystem('s3', anon=True, skip_instance_cache=True)

        if not self.local_json:
            raise NotImplementedError('Storing kerchunk JSON files on s3 is not implemented yet.')

        # File system for managing json files: local disk.
        self.fs_local = fsspec.filesystem('file')
        self.json_dir = os.path.join(
            JSON_DIR, 'hrrr_init',
            self.var_class, self.type_of_level, str(self.level)
        )
        self.json_combined = 'combined.json'
        self.fs_local.makedirs(self.json_dir, exist_ok=True)

    def _get_dask_client(self, n_workers: int = 5) -> dask.distributed.Client:
        """Initialize a dask client (backed by a new LocalCluster) for writing JSONs.

        The caller is responsible for closing both the client and ``client.cluster``.

        Args:
            n_workers (int, optional): Number of workers in dask client. Defaults to 5.

        Raises:
            NotImplementedError: Not implemented for AWS client.

        Returns:
            dask.distributed.Client: Dask client with n_workers workers.
        """
        if not self.local_json:
            raise NotImplementedError('A dask client for s3-hosted JSON is not implemented yet.')

        cluster = dask.distributed.LocalCluster(n_workers=n_workers)
        return dask.distributed.Client(cluster)

    def _get_remote_files(self, start_date: pd.Timestamp, end_date: pd.Timestamp) -> list:
        """Get the list of s3 analysis (f00) GRIB files for every run in a date range.

        Only the ``hrrr.YYYYMMDD`` day prefixes inside the range are listed. Globbing
        ``hrrr.*`` and filtering afterwards would enumerate every day in the archive.

        Args:
            start_date (pd.Timestamp): Start of date range for update (inclusive).
            end_date (pd.Timestamp): End of date range for update (inclusive).

        Returns:
            list: list of ``s3://`` addresses; empty if no day falls in the range.
        """
        # A day directory is included when its midnight timestamp lies in [start_date, end_date].
        # ceil/floor reproduce that inclusive comparison for inputs carrying a time of day.
        days = pd.date_range(
            pd.Timestamp(start_date).ceil('D'),
            pd.Timestamp(end_date).floor('D'),
            freq='D',
        )

        files = []
        for day in days:
            # The pattern ends in '.grib2', so the '.idx' sidecar files can never match.
            pattern = (
                f"s3://noaa-hrrr-bdp-pds/hrrr.{day.strftime('%Y%m%d')}/conus/"
                f"hrrr.t*z.wrf{self.var_class}f00.grib2"
            )
            files.extend(f"s3://{f}" for f in sorted(self.fs_remote.glob(pattern)))

        return files

    def _generate_json(self, remote_files: list, n_workers: int = 10):
        """Generate one kerchunk JSON file per GRIB file, in parallel on a local dask cluster.

        Existing JSON files with the same name are overwritten.

        Args:
            remote_files (list): list of grib file s3 addresses. Nothing is done if empty.
            n_workers (int, optional): number of workers to use for writing files. Defaults to 10.
        """

        def write_json(grib_file: str) -> None:
            """Write a json file for a single grib file into self.json_dir.

            Args:
                grib_file (str): address of grib file on s3.
            """
            # Imported here so it is loaded in each worker process: `import kerchunk`
            # alone is not guaranteed to load the grib2 submodule.
            from kerchunk.grib2 import scan_grib

            storage_options = dict(
                anon=True,
                default_cache_type="readahead"
            )

            # Only keep GRIB messages at the requested level.
            afilter = {
                'typeOfLevel': self.type_of_level,
                'level': self.level,
            }

            common_coords = ['time', 'step', 'latitude', 'longitude', 'valid_time']

            # NOTE(review): `inline_threashold` (sic) and the positional `common` argument follow
            # the older kerchunk signature used in the linked tutorial; newer kerchunk releases
            # spell it `inline_threshold` — confirm against the installed version.
            out = scan_grib(grib_file, common_coords, storage_options,
                            inline_threashold=100, filter=afilter)
            output_file = os.path.join(self.json_dir, f"{grib_file.split('/')[-1]}.json")

            # Mode 'w' truncates any existing file, so no separate delete is needed.
            with open(output_file, 'w') as outf:
                outf.write(ujson.dumps(out))

        if not remote_files:
            print('No GRIB files found for the requested dates; skipping JSON generation.')
            return

        client = self._get_dask_client(n_workers=n_workers)
        try:
            print('JSON generation running on dask server:', client.dashboard_link)
            dask.compute(*[dask.delayed(write_json)(u) for u in remote_files], retries=10)
        finally:
            # Shut down workers even if compute fails. A Client does not close a cluster
            # object it was handed, so the cluster is closed explicitly as well.
            cluster = client.cluster
            client.close()
            if cluster is not None:
                cluster.close()

    def _combine_json(self):
        """Combine the per-GRIB kerchunk JSON files into a single combined JSON.

        Every ``*.json`` in self.json_dir other than the combined file itself is used,
        including files written by earlier updates.

        Raises:
            FileNotFoundError: If there are no per-GRIB JSON files to combine.
        """
        from kerchunk.combine import MultiZarrToZarr

        combined_path = os.path.join(self.json_dir, self.json_combined)
        if self.fs_local.exists(combined_path):
            self.fs_local.rm_file(combined_path)

        # fs.ls returns full paths, so compare on the basename to exclude the combined file.
        json_files = sorted(
            f for f in self.fs_local.ls(self.json_dir)
            if f.endswith('.json') and os.path.basename(f) != self.json_combined
        )
        if not json_files:
            raise FileNotFoundError(f'No kerchunk JSON files to combine in {self.json_dir}')

        mzz = MultiZarrToZarr(
            json_files,
            remote_protocol="s3",
            remote_options=dict(anon=True, skip_instance_cache=True),
            concat_dims=['valid_time', 'number'],
            identical_dims=['latitude', 'longitude', self.type_of_level, 'time'],
            inline_threshold=500,
            # 'step' is not handled here; it is dropped after opening in open_dataset().
        )
        mzz.translate(combined_path)

    def update_json_catalog(self, start_date, end_date, n_workers: int = 10):
        """Create or refresh the combined dataset catalog with kerchunk.

        Args:
            start_date (pd.Timestamp): Start of date range for update (inclusive). None means today.
            end_date (pd.Timestamp): End of date range for update (inclusive). None means today.
            n_workers (int, optional): Number of dask workers to use for
                creating kerchunk json files. Defaults to 10.
        """
        start_date = parse_date_arg(start_date)
        end_date = parse_date_arg(end_date)
        files = self._get_remote_files(start_date=start_date, end_date=end_date)
        self._generate_json(remote_files=files, n_workers=n_workers)
        self._combine_json()

    def open_dataset(self) -> xr.Dataset:
        """Open the combined kerchunk catalog lazily in xarray.

        update_json_catalog must have been run first so that the combined JSON exists.

        Returns:
            xr.Dataset: Dataset with the data in question, without the 'step' coordinate.
        """
        fs = fsspec.filesystem(
            "reference",
            fo=os.path.join(self.json_dir, self.json_combined),
            remote_protocol="s3",
            remote_options={"anon": True},
            skip_instance_cache=True
        )
        m = fs.get_mapper("")
        ds = xr.open_dataset(m, engine='zarr', consolidated=False)

        # Zeros get mapped to NaN in these coordinates for some reason.
        # NOTE(review): possibly a zarr fill_value of 0 being masked on read — unconfirmed.
        for coord in ('number', 'longitude', 'latitude'):
            ds[coord] = ds[coord].fillna(0.)

        # Don't know how to deal with 'step' properly. Not important.
        # drop_vars replaces the deprecated Dataset.drop.
        ds = ds.drop_vars('step')

        return ds

In [7]:
# Build (or refresh) the 2 m above-ground catalog for a single day and open it.
run_date = pd.to_datetime('2022-07-20')

loader = HRRRInitKerchunk(
    var_class='sfc',
    type_of_level='heightAboveGround',
    level=2,
    local_json=True,
)
loader.update_json_catalog(start_date=run_date, end_date=run_date, n_workers=10)

ds = loader.open_dataset()
ds

KeyboardInterrupt: 

In [2]:
import cfgrib

# Open a locally downloaded HRRR surface file directly with cfgrib, for comparison
# with the kerchunk view; open_datasets returns one Dataset per compatible message group.
grib_path = os.path.join(DATA_DIR, 'hrrr.t00z.wrfsfcf00.grib2')
cds = cfgrib.open_datasets(grib_path)
cds

[<xarray.Dataset>
 Dimensions:                (y: 1059, x: 1799)
 Coordinates:
     time                   datetime64[ns] 2022-07-12
     step                   timedelta64[ns] 00:00:00
     adiabaticCondensation  float64 0.0
     latitude               (y, x) float64 ...
     longitude              (y, x) float64 ...
     valid_time             datetime64[ns] ...
 Dimensions without coordinates: y, x
 Data variables:
     gh                     (y, x) float32 ...
 Attributes:
     GRIB_edition:            2
     GRIB_centre:             kwbc
     GRIB_centreDescription:  US National Weather Service - NCEP
     GRIB_subCentre:          0
     Conventions:             CF-1.7
     institution:             US National Weather Service - NCEP,
 <xarray.Dataset>
 Dimensions:     (y: 1059, x: 1799)
 Coordinates:
     time        datetime64[ns] 2022-07-12
     step        timedelta64[ns] 00:00:00
     atmosphere  float64 0.0
     latitude    (y, x) float64 21.14 21.15 21.15 21.16 ... 47.86 47.

In [None]:
# Display the list of datasets returned by cfgrib.open_datasets in the cell above.
cds