In [51]:
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import sys
import datetime as dt
import ujson
import kerchunk 
import fsspec

import dask
from dask.distributed import Client

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from settings import DATA_DIR, JSON_DIR

import importlib

# https://github.com/lsterzinger/kerchunk-medium-tutorial/blob/main/tutorial.ipynb


In [54]:
def parse_date_arg(date, default=None): 
    if date is None and default is None: 
        return pd.to_datetime(pd.Timestamp.today().date())
    elif date is None and default is not None: 
        return pd.to_datetime(default)
    else: 
        return pd.to_datetime(date) 


class GEFSKerchunk: 

    def __init__(self, run_date=None, cycle: str = '00', var_class: str = 'pgrb2ap5', 
                 type_of_level: str = 'heightAboveGround', level: str = 2, local_json: bool = True): 
        """Class for loading collections of GEFS GRIB files as a single zarr-like dataset. 

        Common combintions of type_of_level and level are: 
            * 'heightAboveGround' (levels 2, 10, 80, 100). Temperature, wind, moisture. 
            * 'isobaricInhPa' (levels 1000, 925, ... 10). Pressure levels. Geopotential height,
                temperature, wind, moisture. 
            * 'nominalTop' (level 0). Top of atmosphere. OLR. 
            * 'surface' (level 0). Surface variables. 
            * 'atmosphere' (level 0). Total atmosphere, eg. cloud cover. 
            * 'depthBelowLandLayer' (level 0). Soil properties, temperature/moisture. 

        Args:
            run_date (datetime-like, optional): GEFS run date. Defaults to None.
            cycle (str, optional): GEFS run cycle. Defaults to '00'.
            var_class (str, optional): Which group of variables to draw from. Typical values are 
                'pgrb2ap5' (primary parameters) and 'pgrb2bp5' (secondary parameters). Will probably use
                'pgrb2ap5' for everything except 100 m winds. Defaults to 'pgrb2ap5'.
            type_of_level (str, optional): Type of level at which to pull data. 
                Defaults to heightAboveGround. 
            level (int, optional): Atmospheric data level, operates in concert with type_of_level. 
                Defaults to 2.
            local_json (bool, optional): Write json locally if True, on AWS if False. Defaults to True.
        """
        self.run_date = parse_date_arg(run_date)
        self.cycle = '%02d' % int(cycle)
        self.var_class = var_class
        self.local_json = local_json
        self.type_of_level = type_of_level
        self.level = level

        # File system information. 
        self.fs_remote = None
        self.fs_local = None
        self.json_dir = None
        self.json_combined = None
        self._set_file_system_info()

    def _set_file_system_info(self):
        """Set parameters needed for interacting with remote (NOAA s3) 
        and local file systems. 

        Raises:
            NotImplementedError: Local system as s3 not implemented yet. 
        """

        # S3 remote filesystem: where gefs data is coming from
        self.fs_remote = fsspec.filesystem('s3', anon=True, skip_instance_cache=True)

        # File system for managing json files. Can be local disk or s3. 
        if self.local_json: 
            self.fs_local = fsspec.filesystem('file')

            # Set addresses for json files. 
            type_of_level, level = self.type_of_level, str(self.level)
            self.json_dir = os.path.join(
                JSON_DIR, 'gefs', 
                self.run_date.strftime("%Y%m%d"), self.cycle, 
                self.var_class, type_of_level, level
            )
            self.json_combined = 'combined.json'

            self.fs_local.makedirs(self.json_dir, exist_ok=True)

        else: 
            raise NotImplementedError

    def _get_dask_client(self, n_workers: int = 5) -> dask.distributed.Client: 
        """Initialize dask client for use in writing JSONs. 

        Args:
            n_workers (int, optional): Number of workers in dask client. Defaults to 5.

        Raises:
            NotImplementedError: Not implemented for AWS client. 

        Returns:
            dask.distributed.Client: Dask client with n_workers workers. 
        """
        if self.local_json: 
            cluster = dask.distributed.LocalCluster(n_workers=n_workers)
        else: 
            raise NotImplementedError
        return dask.distributed.Client(cluster)

    def _get_remote_files(self, test: bool = True) -> list: 
        """Get list of s3 grib files in model run. 

        Args:
            test (bool, optional): Only download first 6 hours if True. Defaults to True. 

        Returns:
            list: list of s3 addresses. 
        """
        
        # Get list of all possible files. 
        datestr = self.run_date.strftime("%Y%m%d")
        files = [
            f"s3://{f}" 
            for f in self.fs_remote.glob(f's3://noaa-gefs-pds/gefs.{datestr}/{self.cycle}/atmos/{self.var_class}/*') 
            if f[-4:] != '.idx'
            and '/geavg.' not in f
            and '/gespr.' not in f
        ]

        if test: 
            # Filter on hours
            fhr_filter = ['000', '003']
            files = [f for f in files if f[-3:] in fhr_filter]
        
        return files

    def _generate_json(self, remote_files: list, n_workers: int = 10): 
        """Generate JSON files that kerchunk will use to access the data. 

        Args:
            remote_files (list): list of grib file s3 addresses 
            n_workers (int, optional): number of workers to use for writing files. Defaults to 10.
        """

        def write_json(grib_file: str) -> None: 
            """Write a json file for a single grib file.

            Args:
                grib_file (str): address of grib file on s3.  
            """
        
            # Set options
            storage_options = dict(
                anon=True, 
                default_cache_type="readahead"
            )

            # Filter for grib parameters. 
            afilter = {
                'typeOfLevel': self.type_of_level, 
                'level': self.level
            }

            common_coords = ['time', 'step', 'latitude', 'longitude', 'valid_time', 'number']

            # Scan GRIB file
            out = kerchunk.grib2.scan_grib(grib_file, common_coords, storage_options, inline_threashold=100, filter=afilter)
            output_file = os.path.join(self.json_dir, f"{grib_file.split('/')[-1]}.json")

            # Save to self.json_dir. 
            try: 
                self.fs_local.rm_file(output_file)    
            except OSError: 
                pass

            with open(output_file, 'w') as outf:
                outf.write(ujson.dumps(out))

        # Write grib files in parallel. 
        client = self._get_dask_client(n_workers=n_workers)
        print('JSON generation running on dask server:', client.dashboard_link)
        _ = dask.compute(*[dask.delayed(write_json)(u) for u in remote_files], retries=10)
        client.close()

    def _combine_json(self):
        """Combine kerchunk json files for each grib into a single, combined JSON.
        """

        try: 
            self.fs_local.rm_file(os.path.join(self.json_dir, self.json_combined))    
        except OSError: 
            pass

        json_files = [f for f in self.fs_local.ls(self.json_dir) if f != self.json_combined]
        mzz = kerchunk.combine.MultiZarrToZarr(
            json_files,
            remote_protocol="s3",
            remote_options=dict(anon=True, skip_instance_cache=True),
            concat_dims=['valid_time', 'number'],
            identical_dims=['latitude', 'longitude', self.type_of_level, 'time'],
            inline_threshold=500, 
            # preprocess=kerchunk.combine.drop("step"),
        )
        mzz.translate(os.path.join(self.json_dir, self.json_combined))

    def build_json_catalog(self, n_workers: int = 10):
        """Function that creates the combined datset with kerchunk. 

        Args:
            n_workers (int, optional): Number of dask workers to use for 
                creating kerchunk json files. Defaults to 10.
        """
        files = self._get_remote_files()
        self._generate_json(remote_files=files, n_workers=n_workers) 
        self._combine_json()

    def open_dataset(self) -> xr.Dataset: 
        """Open combined dataset in xarray. 

        Returns:
            xr.Dataset: Dataset with the data in question. 
        """

        fs = fsspec.filesystem(
            "reference", 
            fo=os.path.join(self.json_dir, self.json_combined), 
            remote_protocol="s3", 
            remote_options={"anon": True}, 
            skip_instance_cache=True
        )
        m = fs.get_mapper("")
        ds = xr.open_dataset(m, engine='zarr', consolidated=False)

        # Zero's get mapped to NaN in dimensions for some reason. 
        ds['number'] = ds['number'].fillna(0.)
        ds['longitude'] = ds['longitude'].fillna(0.)
        ds['latitude'] = ds['latitude'].fillna(0.)

        # Don't know how to deal with this properly. Not important. 
        ds = ds.drop('step')

        return ds

In [65]:
loader = GEFSKerchunk(
    run_date=pd.Timestamp.today(tz='UTC'), 
    cycle='00', 
    var_class='pgrb2bp5', 
    type_of_level='heightAboveGround', 
    level=100, 
    local_json=True
)
loader._combine_json()
ds = loader.open_dataset().chunk(chunks={'number': 1})

ds

Unnamed: 0,Array,Chunk
Bytes,744 B,24 B
Shape,"(3, 31)","(3, 1)"
Count,32 Tasks,31 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 744 B 24 B Shape (3, 31) (3, 1) Count 32 Tasks 31 Chunks Type timedelta64[ns] numpy.ndarray",31  3,

Unnamed: 0,Array,Chunk
Bytes,744 B,24 B
Shape,"(3, 31)","(3, 1)"
Count,32 Tasks,31 Chunks
Type,timedelta64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,92.21 MiB,2.97 MiB
Shape,"(3, 31, 361, 720)","(3, 1, 361, 720)"
Count,32 Tasks,31 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 92.21 MiB 2.97 MiB Shape (3, 31, 361, 720) (3, 1, 361, 720) Count 32 Tasks 31 Chunks Type float32 numpy.ndarray",3  1  720  361  31,

Unnamed: 0,Array,Chunk
Bytes,92.21 MiB,2.97 MiB
Shape,"(3, 31, 361, 720)","(3, 1, 361, 720)"
Count,32 Tasks,31 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,92.21 MiB,2.97 MiB
Shape,"(3, 31, 361, 720)","(3, 1, 361, 720)"
Count,32 Tasks,31 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 92.21 MiB 2.97 MiB Shape (3, 31, 361, 720) (3, 1, 361, 720) Count 32 Tasks 31 Chunks Type float32 numpy.ndarray",3  1  720  361  31,

Unnamed: 0,Array,Chunk
Bytes,92.21 MiB,2.97 MiB
Shape,"(3, 31, 361, 720)","(3, 1, 361, 720)"
Count,32 Tasks,31 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,92.21 MiB,2.97 MiB
Shape,"(3, 31, 361, 720)","(3, 1, 361, 720)"
Count,32 Tasks,31 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 92.21 MiB 2.97 MiB Shape (3, 31, 361, 720) (3, 1, 361, 720) Count 32 Tasks 31 Chunks Type float32 numpy.ndarray",3  1  720  361  31,

Unnamed: 0,Array,Chunk
Bytes,92.21 MiB,2.97 MiB
Shape,"(3, 31, 361, 720)","(3, 1, 361, 720)"
Count,32 Tasks,31 Chunks
Type,float32,numpy.ndarray


In [48]:
ds = xr.open_dataset(
    os.path.join(DATA_DIR, 'gec00.t00z.pgrb2a.0p50.f003'),
    engine='cfgrib', 
    filter_by_keys={'typeOfLevel': 'isobaricInhPa', 'level': 500}, 
    indexpath='',
)

In [50]:
ds['step'].value

In [134]:
import cfgrib
cds = cfgrib.open_datasets(os.path.join(DATA_DIR, 'gec00.t00z.pgrb2a.0p50.f003'))
cds

[<xarray.Dataset>
 Dimensions:     (latitude: 361, longitude: 720)
 Coordinates:
     number      int64 0
     time        datetime64[ns] 2022-07-15
     step        timedelta64[ns] 03:00:00
     atmosphere  float64 0.0
   * latitude    (latitude) float64 90.0 89.5 89.0 88.5 ... -89.0 -89.5 -90.0
   * longitude   (longitude) float64 0.0 0.5 1.0 1.5 ... 358.0 358.5 359.0 359.5
     valid_time  datetime64[ns] ...
 Data variables:
     tcc         (latitude, longitude) float32 ...
 Attributes:
     GRIB_edition:            2
     GRIB_centre:             kwbc
     GRIB_centreDescription:  US National Weather Service - NCEP
     GRIB_subCentre:          2
     Conventions:             CF-1.7
     institution:             US National Weather Service - NCEP,
 <xarray.Dataset>
 Dimensions:                (latitude: 361, longitude: 720)
 Coordinates:
     number                 int64 0
     time                   datetime64[ns] 2022-07-15
     step                   timedelta64[ns] 03:00:00
  

In [135]:
cds

[<xarray.Dataset>
 Dimensions:     (latitude: 361, longitude: 720)
 Coordinates:
     number      int64 0
     time        datetime64[ns] 2022-07-15
     step        timedelta64[ns] 03:00:00
     atmosphere  float64 0.0
   * latitude    (latitude) float64 90.0 89.5 89.0 88.5 ... -89.0 -89.5 -90.0
   * longitude   (longitude) float64 0.0 0.5 1.0 1.5 ... 358.0 358.5 359.0 359.5
     valid_time  datetime64[ns] ...
 Data variables:
     tcc         (latitude, longitude) float32 ...
 Attributes:
     GRIB_edition:            2
     GRIB_centre:             kwbc
     GRIB_centreDescription:  US National Weather Service - NCEP
     GRIB_subCentre:          2
     Conventions:             CF-1.7
     institution:             US National Weather Service - NCEP,
 <xarray.Dataset>
 Dimensions:                (latitude: 361, longitude: 720)
 Coordinates:
     number                 int64 0
     time                   datetime64[ns] 2022-07-15
     step                   timedelta64[ns] 03:00:00
  

In [125]:
ds = xr.open_dataset(
    os.path.join(DATA_DIR, 'gec00.t00z.pgrb2b.0p50.f003'), 
    engine='cfgrib', 
    # filter_by_keys={'typeOfLevel':'atmosphereSingleLayer'}
    filter_by_keys={'paramId': 'UGRD'}
)

In [126]:
ds

In [111]:
ds2 = xr.open_dataset(os.path.join(DATA_DIR, 'gec00.t00z.pgrb2b.0p50.f003'), engine='cfgrib')

DatasetBuildError: multiple values for unique key, try re-open the file with one of:
    filter_by_keys={'typeOfLevel': 'meanSea'}
    filter_by_keys={'typeOfLevel': 'hybrid'}
    filter_by_keys={'typeOfLevel': 'surface'}
    filter_by_keys={'typeOfLevel': 'planetaryBoundaryLayer'}
    filter_by_keys={'typeOfLevel': 'isobaricInhPa'}
    filter_by_keys={'typeOfLevel': 'depthBelowLandLayer'}
    filter_by_keys={'typeOfLevel': 'heightAboveGround'}
    filter_by_keys={'typeOfLevel': 'atmosphereSingleLayer'}
    filter_by_keys={'typeOfLevel': 'lowCloudLayer'}
    filter_by_keys={'typeOfLevel': 'middleCloudLayer'}
    filter_by_keys={'typeOfLevel': 'highCloudLayer'}
    filter_by_keys={'typeOfLevel': 'cloudCeiling'}
    filter_by_keys={'typeOfLevel': 'convectiveCloudBottom'}
    filter_by_keys={'typeOfLevel': 'lowCloudBottom'}
    filter_by_keys={'typeOfLevel': 'middleCloudBottom'}
    filter_by_keys={'typeOfLevel': 'highCloudBottom'}
    filter_by_keys={'typeOfLevel': 'nominalTop'}
    filter_by_keys={'typeOfLevel': 'convectiveCloudTop'}
    filter_by_keys={'typeOfLevel': 'lowCloudTop'}
    filter_by_keys={'typeOfLevel': 'middleCloudTop'}
    filter_by_keys={'typeOfLevel': 'highCloudTop'}
    filter_by_keys={'typeOfLevel': 'convectiveCloudLayer'}
    filter_by_keys={'typeOfLevel': 'boundaryLayerCloudLayer'}
    filter_by_keys={'typeOfLevel': 'heightAboveGroundLayer'}
    filter_by_keys={'typeOfLevel': 'tropopause'}
    filter_by_keys={'typeOfLevel': 'maxWind'}
    filter_by_keys={'typeOfLevel': 'heightAboveSea'}
    filter_by_keys={'typeOfLevel': 'isothermZero'}
    filter_by_keys={'typeOfLevel': 'highestTroposphericFreezing'}
    filter_by_keys={'typeOfLevel': 'pressureFromGroundLayer'}
    filter_by_keys={'typeOfLevel': 'sigmaLayer'}
    filter_by_keys={'typeOfLevel': 'sigma'}
    filter_by_keys={'typeOfLevel': 'theta'}
    filter_by_keys={'typeOfLevel': 'potentialVorticity'}

In [103]:
ds2['pressureFromGroundLayer'].values