In [None]:
import datetime
import copy
import xarray as xr
import numpy as np
import pandas as pd
import fsspec
import kerchunk
from kerchunk.grib2 import scan_grib, grib_tree
import gcsfs
import datatree

# This could be generalized to any gridded FMRC dataset but right now it works with NOAA's Grib2 files
import dynamic_zarr_store


## Extract the zarr store metadata


Pick a file, any file... It must live on the same object store as the data (S3 here, matching the `s3://` URLs below) so that the coords use the same file store as the data vars.

In [None]:
%%time
# Two files (the 0h and 3h steps of the same run) are enough to build a grib_tree
# with the correct dimensions
gfs_files = [
    "s3://ecmwf-forecasts/20240229/00z/ifs/0p25/enfo/20240229000000-0h-enfo-ef.grib2",
    "s3://ecmwf-forecasts/20240229/00z/ifs/0p25/enfo/20240229000000-3h-enfo-ef.grib2",
]

# This operation reads both of the large Grib2 files from S3:
# scan_grib extracts the zarr kerchunk metadata for each individual grib message
scanned_groups = []
for grib_url in gfs_files:
    scanned_groups.extend(scan_grib(grib_url))

# grib_tree builds a zarr/xarray compatible hierarchical view of the dataset
gfs_grib_tree_store = grib_tree(scanned_groups)
# it is slow even in parallel because it requires a huge amount of IO
#CPU times: user 3min 37s, sys: 17.1 s, total: 3min 55s
#Wall time: 50min 58s

In [None]:
%%time
# The grib_tree store can be opened directly using either zarr or xarray datatree.
# Opening is quick (see the timings below); it is scanning every grib file to
# produce the store that is too slow for building big aggregations.
grib_tree_fs = fsspec.filesystem("reference", fo=gfs_grib_tree_store)
gfs_dt = datatree.open_datatree(
    grib_tree_fs.get_mapper(""),
    engine="zarr",
    consolidated=False,
)
gfs_dt
#CPU times: user 130 ms, sys: 28.2 ms, total: 158 ms
#Wall time: 154 ms

## Separating static metadata from the chunk indexes

In [None]:
%%time
# Extract the key metadata associated with each grib message into a table
# (the chunk index), using both the opened datatree and the reference store.
gfs_kind = dynamic_zarr_store.extract_datatree_chunk_index(
    gfs_dt,
    gfs_grib_tree_store,
    grib=True,
)
gfs_kind

In [None]:
%%time
# The static zarr metadata associated with the dataset can be separated from the
# chunk references - it only needs to be created once.
# Work on a deep copy so the original store keeps all of its references;
# the stripped result is read back from the copy that was passed in.
deflated_gfs_grib_tree_store = copy.deepcopy(gfs_grib_tree_store)
dynamic_zarr_store.strip_datavar_chunks(deflated_gfs_grib_tree_store)

# Compare how many reference entries each store holds
for label, ref_store in (
    ("Original references: ", gfs_grib_tree_store),
    ("Stripped references: ", deflated_gfs_grib_tree_store),
):
    print(label, len(ref_store["refs"]))

#Original references:  1006
#Stripped references:  453
#CPU times: user 1.83 ms, sys: 0 ns, total: 1.83 ms
#Wall time: 1.83 ms

In [None]:
%%time
# The index file can be pulled out into a dataframe that starts to look a bit like
# what we extracted above from the actual grib files - but this method runs in
# under a second, reading a file that is less than 100k
idx_fs = fsspec.filesystem("s3")
idx_basename = "s3://ecmwf-forecasts/20240229/00z/ifs/0p25/enfo/20240229000000-3h-enfo-ef"
idxdf = dynamic_zarr_store.parse_grib_idx(fs=idx_fs, basename=idx_basename, suffix="index")
idxdf

In [None]:
%%time
# Unfortunately, some accumulation variables have duplicate attributes, making them
# indistinguishable from one another using the IDX file alone.
# keep=False flags every member of a duplicated group, not just the repeats.
duplicated_attrs = idxdf["attrs"].duplicated(keep=False)
idxdf.loc[duplicated_attrs, :]


In [None]:
%%time
# What we need is a mapping from our grib/zarr metadata to the attributes in the idx files.
# The mapping is unique for each time horizon, e.g. you need to build a separate mapping
# for the 1 hour forecast, the 2 hour forecast... the 48 hour forecast.

# Build one for the 3 hour horizon (the "-3h-" file). This requires reading both the grib
# and the idx file, mapping the data for each grib message in order.
# took 2 minutes for one
# NOTE(review): this basename carries the ".grib2" extension, while the parse_grib_idx
# call earlier passes the bare basename plus suffix="index" - confirm that
# build_idx_grib_mapping locates the index file for this path.
grib_basename = "s3://ecmwf-forecasts/20240229/00z/ifs/0p25/enfo/20240229000000-3h-enfo-ef.grib2"
mapping = dynamic_zarr_store.build_idx_grib_mapping(
    fs=fsspec.filesystem("s3"),
    basename=grib_basename,
)
mapping

In [None]:
%%time
# The run time can be parsed from the idx file path: `20240229/00z/` -> 2024-02-29T00.
# (The previous value, 2023-09-01T00, was left over from the GFS example and did not
# match the run that `mapping` and `idxdf` were read from.)
# With it we can build a fully compatible k_index.
# Both frames are de-duplicated on "attrs", keeping the first occurrence, before mapping.
mapped_index = dynamic_zarr_store.map_from_index(
    pd.Timestamp("2024-02-29T00"),
    mapping.loc[~mapping["attrs"].duplicated(keep="first"), :],
    idxdf.loc[~idxdf["attrs"].duplicated(keep="first"), :],
)
mapped_index

In [None]:
%%time
# Build the chunk index for every 6-hourly GFS run in September 2023
# (30 days x 4 runs = 120 files) by reading only the idx files.
# NOTE(review): `mapping` above was built from an ECMWF file while the idx files read
# here are GFS - confirm the mapping matches the files being indexed.
horizon = 6  # forecast hour of the files to index (the `f006` files)
gcs_fs = fsspec.filesystem("gcs")  # created once rather than on every iteration
deduped_mapping = mapping.loc[~mapping["attrs"].duplicated(keep="first"), :]

mapped_index_list = []
# One entry per model run: 00, 06, 12 and 18 UTC for each day of the month
for run_time in pd.date_range("2023-09-01T00:00", "2023-09-30T18:00", freq="6h"):
    fname = (
        f"gs://global-forecast-system/gfs.{run_time.strftime('%Y%m%d')}"
        f"/{run_time.hour:02}/atmos/gfs.t{run_time.hour:02}z.pgrb2.0p25.f{horizon:03}"
    )

    # Loop-local names, so the `idxdf` / `mapped_index` displayed by earlier cells
    # are not silently overwritten by this cell
    run_idxdf = dynamic_zarr_store.parse_grib_idx(fs=gcs_fs, basename=fname)
    mapped_index_list.append(
        dynamic_zarr_store.map_from_index(
            run_time,
            deduped_mapping,
            run_idxdf.loc[~run_idxdf["attrs"].duplicated(keep="first"), :],
        )
    )

# Concatenate once at the end; note this rebinds `gfs_kind` to the aggregated index
gfs_kind = pd.concat(mapped_index_list)
gfs_kind


## We just aggregated 120 GFS grib files in 18 seconds!

Let's build it back into a data tree!

The `reinflate_grib_store` interface is pretty opaque, but it allows building any slice of an FMRC. It is a good area for future improvement, but for now, since
we have just a single 6 hour horizon slice, let's build that...

In [None]:
# Axes describing the slice to rebuild: a single 6 hour step, and the valid times
# covered by that step.
axes = [
    # closed="right" drops the 0 hour start, leaving only the 6 hour step
    pd.Index(
        [
            pd.timedelta_range(start="0 hours", end="6 hours", freq="6h", closed="right", name="6 hour"),
        ],
        name="step",
    ),
    # Every 6 hours from 2023-09-01T06 through 2023-10-01T00 inclusive.
    # The end date is written out in full instead of the partial "2023-10",
    # so it does not depend on how a day-less date string gets parsed.
    pd.date_range("2023-09-01T06:00", "2023-10-01T00:00", freq="360min", name="valid_time"),
]
axes

In [None]:
%%time
# It is fast to rebuild the datatree - but let's pull out just two variables to look at...
selected_chunks = gfs_kind.loc[gfs_kind.varname.isin(["dswrf", "t2m"])]

# Combine the selected chunk index rows with the stripped (static) metadata store
gfs_store = dynamic_zarr_store.reinflate_grib_store(
    axes=axes,
    aggregation_type=dynamic_zarr_store.AggregationType.HORIZON,
    chunk_index=selected_chunks,
    zarr_ref_store=deflated_gfs_grib_tree_store,
)

In [None]:
# Open the re-inflated store as a datatree. This rebinds `gfs_dt`, replacing the
# tree that was opened from the scanned grib_tree store.
reinflated_fs = fsspec.filesystem("reference", fo=gfs_store)
gfs_dt = datatree.open_datatree(
    reinflated_fs.get_mapper(""),
    engine="zarr",
    consolidated=False,
)
gfs_dt

In [None]:
%%time
# Read a ten-element slice at a single grid point: index 0 on the first axis,
# 0:10 on the second, and positions 300 / 400 on the last two.
# presumably the axes are (step, valid_time, latitude, longitude) - TODO confirm
gfs_dt.dswrf.avg.surface.dswrf[0,0:10,300,400].compute()

In [None]:
%%time
# Reading the data - especially extracting point time series - isn't any faster once
# you have the xarray datatree. This is just a much faster way of building the
# aggregations than directly running scan_grib over all the data first.
# (Same read as the previous cell, timed again.)
gfs_dt.dswrf.avg.surface.dswrf[0,0:10,300,400].compute()

In [None]:
# Plot the full 2-D dswrf field at index 1 of the second axis (first-axis index 0)
gfs_dt.dswrf.avg.surface.dswrf[0,1,:,:].plot(figsize=(12,8))

In [None]:
# Plot the full 2-D dswrf field at index 2 of the second axis (first-axis index 0)
gfs_dt.dswrf.avg.surface.dswrf[0,2,:,:].plot(figsize=(12,8))

In [None]:
# Plot the full 2-D dswrf field at index 3 of the second axis (first-axis index 0)
gfs_dt.dswrf.avg.surface.dswrf[0,3,:,:].plot(figsize=(12,8))

# Timeseries

In [None]:
from joblib import parallel_config

In [None]:
%%time
# Linearly interpolate dswrf to a single (longitude, latitude) location.
# NOTE(review): parallel_config sets defaults for joblib.Parallel - confirm that
# something in this call path actually uses joblib, otherwise it has no effect.
with parallel_config(n_jobs=8):
    dswrf_surface = gfs_dt.dswrf.avg.surface.dswrf
    res = dswrf_surface.interp(
        longitude=[320.5],
        latitude=[20.6],
        method="linear",
    )


In [None]:
# Plot whatever `res` currently holds (the preceding cell assigns the single-point interpolation)
res.plot()


In [None]:
%%time
# Same linear interpolation, now requesting two longitudes and two latitudes
with parallel_config(n_jobs=8):
    dswrf_surface = gfs_dt.dswrf.avg.surface.dswrf
    res = dswrf_surface.interp(
        longitude=[320.5, 300.2],
        latitude=[20.6, 45.7],
        method="linear",
    )

In [None]:
# Display `res` (the preceding cell assigns the two-longitude / two-latitude interpolation)
res

In [None]:
%%time
# Load the whole `surface` node instead of interpolating a single variable.
# Note that `res` is rebound here: it now holds the computed node, not the
# interpolation result from the cells above.
with parallel_config(n_jobs=8):
    surface_node = gfs_dt.dswrf.avg.surface
    res = surface_node.compute()

In [None]:
# Display `res` (the preceding cell assigns the computed `surface` node)
res