In [1]:
%load_ext jupyter_black

In [2]:
import dataclasses
from abc import abstractmethod
from datetime import datetime, timedelta
from typing import Literal, Callable, TypeVar, ParamSpec, Iterator

import s3fs
import pandas as pd
import xarray as xr
from IPython.display import HTML

from griblib.hrrr._zarr import ZArrTable

In [3]:
from datetime import datetime
from griblib.hrrr import load_hrrr

In [4]:
hrrr = load_hrrr(datetime(2022, 6, 16, 6), 5)

In [5]:
ZArrTable

Unnamed: 0_level_0,Unnamed: 1_level_0,vertical_level,parameter_short_name,units,1st_version_available
analysis_or_forecast,parameter_long_name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
both,u-component_vertical_wind_shear,0_1000m_above_ground,VUCSH,1/s,V3
both,v-component_vertical_wind_shear,0_1000m_above_ground,VVCSH,1/s,V3
both,convective_available_potential_energy,0_3000m_above_ground,CAPE,J/kg,V4
both,u-component_storm_motion,0_6000m_above_ground,USTM,m/s,V3
both,v-component_storm_motion,0_6000m_above_ground,VSTM,m/s,V3
...,...,...,...,...,...
anl,hail,entire_atmosphere,HAIL_max_fcst,m,V4
anl,total_column_integrated_graupel,entire_atmosphere_single_layer,TCOLG_max_fcst,kg/m2,V3
anl,hail,surface,HAIL_max_fcst,m,V4
anl,storm_surface_runoff,surface,SSRUN_acc_fcst,kg/m2,V3


In [32]:
T = TypeVar("T")
P = ParamSpec("P")
idx: slice = pd.IndexSlice


class SharedTable:
    __table = ZArrTable[["vertical_level", "parameter_short_name"]]

    @property
    def __model_type(self):
        if isinstance(self, Forecast):
            return "fcst", "both"
        elif isinstance(self, Analysis):
            return "anl", "both"
        else:
            return "fcst", "anl", "both"

    def __repr__(self) -> str:
        return self.table.__repr__()

    def _repr_html_(self) -> HTML:
        return self.table._repr_html_()

    @property
    def table(self) -> pd.DataFrame:
        return self.__table.loc[idx[self.__model_type, :], :].droplevel(0).copy()


class Base(SharedTable):
    def __init__(self, hrrr: "HRRR2", level_type: Literal["sfc", "prs"]):
        model_type = "fcst" if isinstance(self, Forecast) else "anl"
        base_url = f"s3://hrrrzarr/{level_type}/" + hrrr.date_range.strftime(f"%Y%m%d/%Y%m%d_%Hz_{model_type}.zarr")

        def generate_urls():
            for long_name, (vlevel, short_name) in self.table.iterrows():
                base = base_url + f"/{vlevel}/{short_name}"
                yield (long_name, vlevel), tuple(zip(base, base + f"/{vlevel}"))

        self.hrrr = hrrr
        self._urldf = pd.DataFrame(dict(generate_urls()), index=hrrr.date_range)

    def iterload(self, long_name: str, vertical_level: str) -> Iterator[xr.Dataset]:
        for urls in self._urldf[long_name, vertical_level]:
            yield xr.open_mfdataset(
                (s3fs.S3Map(url, s3=self.hrrr.fs) for url in urls),
                engine="zarr",
            )


def loadermethod(func: Callable[P, T]) -> Callable[P, T]:
    # possible_levels = get_args(func.__annotations__["vertical_level"])
    anno = func.__annotations__.copy()
    anno.pop("return", None)
    # ds_callback = xr.Dataset in func.__annotations__.values()
    default: tuple[str, ...] = func.__defaults__
    if default:
        (default_value,) = default

    long_name = func.__name__

    def inner(self: "Base", vertical_level: str = default_value):

        return xr.concat(
            self.iterload(long_name, vertical_level),
            dim="valid_time",
            combine_attrs="override",
        )

    return inner


class Both(Base):
    @loadermethod
    @abstractmethod
    def temperature(
        self,
        vertical_level: Literal[
            "1000mb", "2m_above_ground", "500mb", "700mb", "850mb", "925mb", "surface"
        ] = "surface",
    ) -> xr.Dataset:
        ...

    @loadermethod
    @abstractmethod
    def hail(
        self, vertical_level: Literal["0.1_sigma_layer", "entire_atmosphere", "surface"] = "surface"
    ) -> xr.Dataset:
        ...

    # @withrepr(lambda x: "<Func: %s>" % x.__name__)
    @loadermethod
    @abstractmethod
    def vertical_velocity(self, vertical_level: Literal["0.5_0.8_sigma_layer", "700mb"] = "700mb") -> xr.Dataset:
        ...


class Forecast(Both):
    ...


class Analysis(Both):
    ...


# LevelType enherits the SharedTable the so the entire table can be viewed
class LevelType(SharedTable):
    def __init__(self, hrrr: "HRRR2", level_type: Literal["sfc", "prs"]):
        self.__forecast = Forecast(hrrr, level_type)
        self.__analysis = Analysis(hrrr, level_type)

    @property
    def forecast(self):
        return self.__forecast

    @property
    def analysis(self):
        return self.__analysis


@dataclasses.dataclass
class HRRR2:
    start_date: datetime
    hours: int
    date_range: pd.DatetimeIndex
    fs: s3fs.S3FileSystem

    @property
    def surface(self):
        return LevelType(self, level_type="sfc")

    @property
    def pressure(self):
        return LevelType(self, level_type="prs")


def load_hrrr(start_date: datetime, hour_delta: int) -> HRRR2:
    date_range = pd.date_range(start_date, start_date + timedelta(hours=hour_delta))
    return HRRR2(start_date, hour_delta, date_range, s3fs.S3FileSystem(anon=True))


hrrr = load_hrrr(datetime(2022, 6, 16, 6), 5)
hrrr.surface.forecast.vertical_velocity()

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray
"Array Chunk Bytes 174.42 MiB 2.06 MiB Shape (1, 48, 1059, 1799) (1, 48, 150, 150) Count 193 Tasks 96 Chunks Type float16 numpy.ndarray",1  1  1799  1059  48,

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 384 B 384 B Shape (1, 48) (1, 48) Count 3 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",48  1,

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray


In [16]:
hrrr.surface.forecast.vertical_velocity()

<bound method reprwrapper._loadermethod of self._repr(self._func)>

In [8]:
# uses surface by default
hrrr.surface.forecast.temperature()

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray
"Array Chunk Bytes 174.42 MiB 2.06 MiB Shape (1, 48, 1059, 1799) (1, 48, 150, 150) Count 193 Tasks 96 Chunks Type float16 numpy.ndarray",1  1  1799  1059  48,

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 384 B 384 B Shape (1, 48) (1, 48) Count 3 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",48  1,

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray


In [9]:
hrrr.surface.forecast.temperature(vertical_level="1000mb")

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray
"Array Chunk Bytes 174.42 MiB 2.06 MiB Shape (1, 48, 1059, 1799) (1, 48, 150, 150) Count 193 Tasks 96 Chunks Type float16 numpy.ndarray",1  1  1799  1059  48,

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 384 B 384 B Shape (1, 48) (1, 48) Count 3 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",48  1,

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray


In [10]:
hrrr.surface.forecast.hail(vertical_level="entire_atmosphere")

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray
"Array Chunk Bytes 174.42 MiB 2.06 MiB Shape (1, 48, 1059, 1799) (1, 48, 150, 150) Count 193 Tasks 96 Chunks Type float16 numpy.ndarray",1  1  1799  1059  48,

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 384 B 384 B Shape (1, 48) (1, 48) Count 3 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",48  1,

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray


In [11]:
hrrr.surface.forecast.hail(vertical_level="surface")

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray
"Array Chunk Bytes 174.42 MiB 2.06 MiB Shape (1, 48, 1059, 1799) (1, 48, 150, 150) Count 193 Tasks 96 Chunks Type float16 numpy.ndarray",1  1  1799  1059  48,

Unnamed: 0,Array,Chunk
Bytes,174.42 MiB,2.06 MiB
Shape,"(1, 48, 1059, 1799)","(1, 48, 150, 150)"
Count,193 Tasks,96 Chunks
Type,float16,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 384 B 384 B Shape (1, 48) (1, 48) Count 3 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",48  1,

Unnamed: 0,Array,Chunk
Bytes,384 B,384 B
Shape,"(1, 48)","(1, 48)"
Count,3 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
