In [None]:
import tempfile
import warnings

import dask
import icechunk
import matplotlib.style as mplstyle
import numpy as np
import xarray as xr
from distributed import Client
from icechunk.distributed import merge_sessions

from ocr.chunking_config import ChunkingConfig

# Apply matplotlib's built-in 'fast' style sheet to every plot rendered in this notebook.
mplstyle.use('fast')

In [None]:
# Start a dask.distributed client with 12 workers. `write_regions` (defined below)
# passes this notebook-level `client` to `dask.compute` as its scheduler.
# NOTE(review): 12 is hard-coded; adjust to the machine running the notebook.
client = Client(n_workers=12)
# Bare expression so the client summary is shown as the cell output.
client

In [None]:
# Default chunking configuration. Later cells read the grid from `config.ds`, the
# chunk sizes from `config.chunks`, and use its bbox / chunk / slice helper methods.
config = ChunkingConfig()
# Bare expression so the configuration repr is shown as the cell output.
config

In [None]:
# Open the source icechunk repository on S3 (settings taken from the environment via
# from_env=True) and read its 'main' branch through a read-only session.
# NOTE: `storage`, `repo` and `session` are rebound to a local scratch repository in a
# later cell; only `ds` keeps pointing at this S3 source afterwards.
storage = icechunk.s3_storage(
    bucket='carbonplan-ocr',
    prefix='input/fire-risk/tensor/USFS/RDS-2022-0016-02_all_vars_merge_icechunk',
    from_env=True,
)


repo = icechunk.Repository.open(storage)
session = repo.readonly_session('main')
# Keep only the 'BP' variable; chunks={} requests dask-backed (lazy) arrays.
ds = xr.open_zarr(session.store, consolidated=False, chunks={})[['BP']]
# Cast to float32 so the data matches the dtype of the template created below.
ds['BP'] = ds['BP'].astype('float32')
# Drop the encoding inherited from the source store — presumably so its chunking /
# fill settings do not conflict with the target array on write; TODO confirm.
ds['BP'].encoding = {}

In [None]:
# Build a lazy, all-zeros placeholder for 'BP' that spans the full grid described by
# `config.ds` and is chunked according to `config.chunks`. It is only used to lay out
# the target store; the real values are written region by region later on.
zarr_chunks = config.chunks
n_rows, n_cols = config.ds.sizes['y'], config.ds.sizes['x']
bp_placeholder = dask.array.zeros(
    (n_rows, n_cols),
    dtype='float32',
    chunks=(zarr_chunks['y'], zarr_chunks['x']),
)

# Reuse the grid coordinates, minus the 'spatial_ref' variable.
template = xr.Dataset(config.ds.coords).drop_vars('spatial_ref')
template['BP'] = xr.DataArray(bp_placeholder, dims=('y', 'x'))

In [None]:
# Create a scratch icechunk repository on the local filesystem.
# Keep a reference to the TemporaryDirectory object: calling
# `tempfile.TemporaryDirectory().name` without holding the object lets it be finalized
# immediately, which removes the directory again and leaves nothing to clean up the
# path that icechunk later writes to. Holding `scratch_dir` keeps the directory alive
# for the lifetime of the kernel (or until this cell is re-run).
scratch_dir = tempfile.TemporaryDirectory()
storage = icechunk.local_filesystem_storage(scratch_dir.name)
repo = icechunk.Repository.create(storage)
session = repo.writable_session('main')

# compute=False: only the store layout / metadata is written here, not the placeholder
# values. The actual data is inserted later with region writes.
template.to_zarr(
    session.store,
    compute=False,
    mode='w',
    # IMPORTANT: pin the on-disk chunk shape to the configured chunk sizes and use NaN
    # as fill value — presumably so regions that are never written read back as NaN.
    encoding={'BP': {'chunks': (config.chunks['y'], config.chunks['x']), 'fill_value': np.nan}},
    consolidated=False,
)

# The commit message 'template' marks the initial, data-free layout commit.
session.commit('template')

In [None]:
def get_commit_messages_ancestry(repo: icechunk.Repository) -> list[str]:
    """Return the commit messages along the ancestry of the ``main`` branch.

    Parameters
    ----------
    repo
        Repository whose ``main`` branch history is inspected.

    Returns
    -------
    list
        One message per commit, in the order yielded by ``repo.ancestry``.
    """
    return [commit.message for commit in repo.ancestry(branch='main')]


@dask.delayed
def insert_region(session: icechunk.Session, subset_ds: xr.Dataset):
    """Write one spatial subset into the session's store and hand the session back.

    Wrapped in ``dask.delayed``, so calling it only builds a task; the write happens
    when the task is computed.

    Parameters
    ----------
    session
        Writable session whose store receives the data.
    subset_ds
        Subset to write. ``region='auto'`` lets xarray work out the target region
        from the subset's coordinates.

    Returns
    -------
    icechunk.Session
        The same session object, now carrying the changes made by this write.
    """
    target_store = session.store
    subset_ds.to_zarr(target_store, region='auto', consolidated=False)
    return session


def write_regions(ds: xr.Dataset, session: icechunk.Session, region_dict: dict):
    """Write the not-yet-committed regions of ``ds`` in parallel, then commit once.

    Region ids are recorded in commit messages as a comma-separated list. A region is
    treated as already written when its id appears in any commit message on ``main``,
    and is skipped.

    NOTE: besides its arguments, this function reads two notebook-level globals:
    ``repo`` (to inspect the commit history) and ``client`` (the dask scheduler).

    Parameters
    ----------
    ds
        Source dataset that the regions are sliced from.
    session
        Writable session that receives the writes and the final commit.
    region_dict
        Maps a region id (string) to a pair of slices, unpacked as
        ``(x_slice, y_slice)``.
        NOTE(review): region ids must not contain ',' — it is the separator used in
        commit messages.

    Returns
    -------
    None
        Emits a warning and returns without committing when every requested region
        has already been committed.
    """
    commit_messages = get_commit_messages_ancestry(repo)
    # str.split returns [message] when there is no comma, so no special-casing is
    # needed; a set gives constant-time membership checks below.
    already_commited_messages = {
        msg for message in commit_messages for msg in message.split(',')
    }
    uncommited_dict = {
        key: subset for key, subset in region_dict.items() if key not in already_commited_messages
    }
    if not uncommited_dict:
        # maybe add logging
        # Report the requested ids (the pending dict is always empty at this point).
        warnings.warn(
            f'No new chunks to commit!: all requested regions already committed: {list(region_dict)}',
            stacklevel=2,
        )
        return

    # IMPORTANT! we need to pass in subsets, not the entire dataset to get pickled.
    ds_subsets_uncommited = [
        ds.isel(x=x_slice, y=y_slice) for x_slice, y_slice in uncommited_dict.values()
    ]

    # The session has to be pickled to travel with each task; icechunk only permits
    # that inside this context manager.
    with session.allow_pickling():
        tasks = [
            insert_region(session=session, subset_ds=subset_ds)
            for subset_ds in ds_subsets_uncommited
        ]
        # we could persist or w/e here
        sessions = dask.compute(*tasks, scheduler=client)

    # One commit for the whole batch, labelled with the region ids it contains.
    commit_region_ids = ','.join(list(uncommited_dict))

    # Fold the sessions returned by the tasks back into one session before committing.
    session = merge_sessions(session, *sessions)
    session.commit(f'{commit_region_ids}')

## Write CA BBOX chunks

In [None]:
# California bounding box in WGS84 degrees.
# NOTE(review): presumably (min lon, min lat, max lon, max lat) — confirm against
# ChunkingConfig.bbox_from_wgs84.
ca_bounds_wgs84 = (-125.277100, 32.374502, -113.961182, 41.951126)
ca_bbox = config.bbox_from_wgs84(*ca_bounds_wgs84)

# Chunks touched by the bbox, then the slices used to cut those regions out of `ds`.
ca_chunks = config.get_chunks_for_bbox(ca_bbox)
chunk_slices_ca = config.chunks_to_slices(ca_chunks)

# Re-open the repository and take a fresh writable session on 'main' for this batch.
repo = icechunk.Repository.open(storage)
session = repo.writable_session('main')
write_regions(ds=ds, session=session, region_dict=chunk_slices_ca)

In [None]:
# Read the written data back through a read-only session to check the result.
repo = icechunk.Repository.open(storage)
session = repo.readonly_session('main')
rtds = xr.open_zarr(session.store, consolidated=False, chunks={})[['BP']]

# Preview: take a fixed window, average over 10x10 blocks (trimming any remainder)
# and plot the coarsened 'BP' field.
preview_window = rtds.isel(y=slice(0, 90000), x=slice(0, 120000))
bp_coarse = preview_window.coarsen(x=10, y=10, boundary='trim').mean()['BP']
bp_coarse.plot()

## Write OR BBOX chunks


In [None]:
# Oregon bounding box in WGS84 degrees.
# NOTE(review): presumably (min lon, min lat, max lon, max lat) — confirm against
# ChunkingConfig.bbox_from_wgs84.
or_bounds_wgs84 = (-124.958496, 41.963324, -116.477051, 46.208322)
or_bbox = config.bbox_from_wgs84(*or_bounds_wgs84)

# Chunks touched by the bbox, then the slices used to cut those regions out of `ds`.
or_chunks = config.get_chunks_for_bbox(or_bbox)
chunk_slices_or = config.chunks_to_slices(or_chunks)

# Re-open the repository and take a fresh writable session on 'main' for this batch.
repo = icechunk.Repository.open(storage)
session = repo.writable_session('main')
write_regions(ds=ds, session=session, region_dict=chunk_slices_or)

In [None]:
# Read the store back again after the second batch of region writes.
repo = icechunk.Repository.open(storage)
session = repo.readonly_session('main')
rtds = xr.open_zarr(session.store, consolidated=False, chunks={})[['BP']]

# Same preview as before: fixed window, 10x10 block mean (remainder trimmed), plot.
preview_window = rtds.isel(y=slice(0, 90000), x=slice(0, 120000))
bp_coarse = preview_window.coarsen(x=10, y=10, boundary='trim').mean()['BP']
bp_coarse.plot()

### RT (round trip)

In [None]:
# TODO: can we get config to return all slices / return all chunks?