# Building a large GPM-IMERG Virtual Dataset with Dask

## Define Functions

In [None]:
from datetime import datetime, timedelta
import pandas as pd
import coiled
from dask import compute
import dask.bag as db
import itertools

In [None]:
base_url = "s3://gesdisc-cumulus-prod-protected/GPM_L3/GPM_3IMERGHH.07"

def make_url(date: datetime) -> str:
    """Build the S3 URL of the half-hourly granule that starts at `date`.

    The URL has the form ``base_url/<year>/<day of year>/<file name>``. The
    file name encodes the calendar date, the start time, the end time
    (start + 29 min 59 s) and the number of whole minutes elapsed since
    midnight, zero-padded to four digits.
    """
    midnight = datetime(date.year, date.month, date.day)
    minutes_into_day = (date - midnight) // timedelta(minutes=1)
    granule_end = date + timedelta(minutes=29, seconds=59)

    file_name = (
        f"3B-HHR.MS.MRG.3IMERG.{date.strftime('%Y%m%d')}"
        f"-S{date.strftime('%H%M%S')}"
        f"-E{granule_end.strftime('%H%M%S')}"
        f".{minutes_into_day:04d}"
        ".V07B.HDF5"
    )
    year_dir = f"{date.year:04d}"
    day_of_year_dir = date.strftime("%j")
    return f"{base_url}/{year_dir}/{day_of_year_dir}/{file_name}"


In [None]:
def hours_for_day(day):
    """Return the 48 half-hourly timestamps of `day`, starting at `day` itself.

    `day` must sit exactly on midnight: its hour, minute and second must all
    be zero, otherwise the assertion fails.
    """
    assert (day.hour, day.minute, day.second) == (0, 0, 0)
    return pd.date_range(day, periods=48, freq="30min")

def get_info_for_day(day):
    """Return `get_info(url)` for every half-hourly granule URL of `day`.

    NOTE(review): `get_info` is not defined in the visible part of this
    notebook -- confirm it exists before calling this helper.
    """
    granule_urls = map(make_url, hours_for_day(day))
    return [get_info(url) for url in granule_urls]

In [None]:
def open_virtual(url, keep_coords=True):
    """Open the "Grid" group of one granule as a virtual dataset.

    Parameters
    ----------
    url
        Location of the HDF5 granule, passed straight to
        `open_virtual_dataset`.
    keep_coords
        If true, all six coordinate variables are loaded and set as
        coordinates. If false, only `time` and `time_bnds` are kept and the
        remaining coordinate variables are dropped.

    Returns
    -------
    The dataset returned by `open_virtual_dataset`, with the kept coordinate
    variables promoted via `set_coords`.
    """
    # Imported inside the function so the imports run wherever the function
    # is executed.
    from virtualizarr.readers.hdf import HDFVirtualBackend
    from virtualizarr import open_virtual_dataset

    always_dropped = ["Intermediate", "nv", "lonv", "latv"]
    every_coord = ["time", "lon", "lat", "time_bnds", "lon_bnds", "lat_bnds"]
    essential_coords = ["time", "time_bnds"]

    if not keep_coords:
        # Drop every coordinate that is not essential, on top of the usual list.
        surplus_coords = list(set(every_coord).difference(essential_coords))
        dropped = always_dropped + surplus_coords
        kept_coords = essential_coords
    else:
        dropped = always_dropped
        kept_coords = every_coord

    virtual_ds = open_virtual_dataset(
        url,
        indexes={},
        group="Grid",
        backend=HDFVirtualBackend,
        drop_variables=dropped,
        loadable_variables=kept_coords,
    )
    return virtual_ds.set_coords(kept_coords)

In [None]:
def reduce_via_concat(dsets):
    """Concatenate `dsets` along the "time" dimension into one dataset.

    Uses `coords="minimal"` and `join="override"`. `dset_for_year` passes
    this function to the bag reduction as both the per-partition step and
    the aggregation step.
    """
    import xarray as xr

    concat_options = {"dim": "time", "coords": "minimal", "join": "override"}
    return xr.concat(dsets, **concat_options)

In [None]:
from xarray.backends.zarr import FillValueCoder

def fix_ds(ds):
    """Return a copy of `ds` with `_FillValue` set on every data variable.

    For each data variable, the `CodeMissingValue` attribute is cast to the
    variable's dtype, encoded with `FillValueCoder`, and stored in the
    `_FillValue` attribute (promoting the fill value to an attr for Zarr V3).
    The input dataset object itself is not reassigned; work happens on
    `ds.copy()`.
    """
    fixed = ds.copy()
    fill_coder = FillValueCoder()

    for name in fixed.data_vars:
        variable = fixed[name]
        var_dtype = variable.dtype
        # The fill value is taken from the `CodeMissingValue` attribute rather
        # than from the array's zarray fill_value, which is wrong due to a bug
        # in the reader.
        missing_value = var_dtype.type(variable.attrs['CodeMissingValue'])
        variable.attrs['_FillValue'] = fill_coder.encode(missing_value, var_dtype)

    return fixed

In [None]:
def dset_for_year(year):
    """Build the virtual dataset for every half-hourly granule of `year`.

    Enumerates all half-hour timestamps from January 1 to December 31, maps
    each one to its URL and then to a virtual dataset inside a Dask bag,
    concatenates everything with `reduce_via_concat`, and finally applies
    `fix_ds` to the computed result.
    """
    days = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31", freq="1D")
    timestamps = list(
        itertools.chain.from_iterable(hours_for_day(day) for day in days)
    )

    # 48 timestamps per partition, i.e. one day of granules per partition.
    urls = db.from_sequence(timestamps, partition_size=48).map(make_url)
    virtual_datasets = urls.map(open_virtual)
    combined = virtual_datasets.reduction(
        reduce_via_concat, reduce_via_concat
    ).compute()
    return fix_ds(combined)

## Do Computations

In [None]:
# Start a Coiled cluster of 100 workers in us-west-2, using the
# "icechunk-virtualizarr" software environment.
cluster = coiled.Cluster(
    software="icechunk-virtualizarr",
    region="us-west-2",
    n_workers=100,
)
# Pass the Arraylake token to the cluster as a private environment variable.
cluster.send_private_envs({"ARRAYLAKE_TOKEN": "***"})  # fill in appropriately
# Client connected to this cluster; presumably it becomes the default
# scheduler for the `.compute()` call in `dset_for_year` -- confirm.
client = cluster.get_client()

In [None]:
from arraylake import Client
# Arraylake client created with default arguments; used below to get the repo.
aclient = Client()

### Create the repo and write the first year

In [None]:
# Get the target repo, creating it with kind="icechunk" if it does not exist
# yet (per the method name `get_or_create_repo`).
ic_repo = aclient.get_or_create_repo("nasa-impact/GPM_3IMERGHH.07-virtual-full", kind="icechunk")
ic_repo  # bare expression: displays the repo as the cell output

In [None]:
# Build the virtual dataset for 1998 and write it to the repo. No
# `append_dim` is given here, unlike the later per-year appends.
ds_1998 = dset_for_year(1998)
ds_1998.virtualize.to_icechunk(ic_repo)

In [None]:
# Commit the 1998 write to the repo.
ic_repo.commit("Wrote 1998")

### Compute and Append Subsequent Years

The first appends succeed, but each subsequent append uses more memory than the one before it.
The loop stops working at around the 2009 append.

In [None]:
# Append one year at a time, 1999 through 2023 (the range end is exclusive),
# committing after each year so that every year gets its own commit.
for year in range(1999, 2024):
    print(year)
    ds_year = dset_for_year(year)
    # append_dim="time": add this year after the existing data along "time".
    ds_year.virtualize.to_icechunk(ic_repo, append_dim="time")
    cid = ic_repo.commit(f"Appended {year}")
    print(cid)