In [None]:
import os

import fiona.transform
import fsspec
import xarray as xr
import pandas as pd
from affine import Affine

from utils.dask import upload_source
from utils.dask import create_cluster

from utils.hls.compute import process_catalog
from utils import get_logger

from utils.hls import catalog
from utils.hls import compute

In [None]:
# Logger for this run; it is handed to `process_catalog` further down.
logger = get_logger('hls-az')

In [None]:
# Map the catalog zarr store in the `usfs` storage account and open it lazily
# as an xarray Dataset.
az_zr = fsspec.get_mapper(
    "az://fia/catalogs/fia_az_2015-2019.zarr",
    account_name="usfs",
    # The key is read from the environment so it never lands in the notebook.
    account_key=os.environ['AZURE_ACCOUNT_KEY'],
)
ds_az = xr.open_zarr(az_zr)

In [None]:
# Convert the catalog Dataset into a pandas DataFrame for row-wise handling.
df_az = ds_az.to_dataframe()

In [None]:
# Keep only the first row for each INDEX value. The result is reassigned rather
# than applied with `inplace=True`, so the frame built by the previous cell is
# not mutated behind the reader's back.
df_az = df_az.drop_duplicates(subset=['INDEX'])

In [None]:
# Positional batch of the catalog: rows 500..999 (i.e. 500 rows, but NOT the
# first 500). `.iloc` makes the positional intent explicit.
df_az_500 = df_az.iloc[500:1000]

In [None]:
def chip_tiles(job_id, job_df, job_groupby, chunks, write_store, client):
    """A job compatible with `process_catalog` which cuts a square chip around every point of a tile.

    For each year present in `job_df`, the zarr store for that year and tile
    (`job_id`) is opened on the cluster, a `chip_size` x `chip_size` pixel
    window centred on each point's lat/lon is extracted, and every chip is
    saved to its own zarr store.

    Args:
        job_id (str): Id of the job, used for tracking purposes; also used as the tile name in the input zarr path
        job_df (pandas.DataFrame): Dataframe of points to chip; the `year`, `lat`, `lon` and `INDEX` columns are read
        job_groupby (str): Unused here; presumably part of the job signature `process_catalog` expects
        chunks (Dict[str, int]): How to chunk HLS input data
        write_store (fsspec.FSMap): Unused here; each chip goes to its own store built from the year and `INDEX`
        client (dask.distributed.Client): Dask cluster client to submit tasks to
    Returns:
        dask.distributed.Future: Future that takes every chip-save future as input and resolves to `job_id`,
            can be waited on.
    """
    def open_store(path):
        """Build an fsspec mapper for `path` in the `usfs` storage account."""
        return fsspec.get_mapper(
            path,
            account_name="usfs",
            account_key=os.environ['AZURE_ACCOUNT_KEY'],
        )

    def chip(ds, lat, lon, chip_size):
        """Return the `chip_size` x `chip_size` pixel window of `ds` around (lat, lon)."""
        point_crs = "EPSG:4326"
        tfm = Affine(*ds.attrs['transform'])
        # Reproject the point from lon/lat into the CRS recorded on the dataset.
        ([x], [y]) = fiona.transform.transform(
            point_crs, ds.attrs['crs'], [lon], [lat]
        )
        # The inverted affine transform turns projected coordinates into
        # fractional pixel indices, which are rounded to the nearest pixel.
        x_idx, y_idx = [round(coord) for coord in ~tfm * (x, y)]
        half_chip = chip_size // 2
        # NOTE(review): the window is not checked against the tile bounds. A
        # point closer than `half_chip` pixels to the low edge produces
        # negative indices here -- confirm how such points should be handled.
        window = dict(
            x=range(x_idx - half_chip, x_idx + half_chip),
            y=range(y_idx - half_chip, y_idx + half_chip),
        )
        return ds[window].drop_vars(['COASTAL_AEROSOL', 'CIRRUS'])

    chip_size = 32
    futures = []
    for year, points in job_df.groupby("year"):
        # Read the zarr for the job_id (tile) + year. `float(year)` keeps the
        # decimal suffix in the path; presumably this matches how the yearly
        # stores were named -- confirm before changing.
        tile_store = open_store(f"az://fia/hls/{float(year)}/{job_id}.zarr")
        ds = client.submit(xr.open_zarr, tile_store, chunks=chunks)
        # Calculate and save one chip for each point of this year.
        for _, point_row in points.iterrows():
            sample = client.submit(
                chip, ds, point_row['lat'], point_row['lon'], chip_size
            )
            # A dedicated name is used so the `write_store` argument is not
            # silently overwritten.
            chip_store = open_store(
                f"az://fia/chips-test6/hls/az/HLSAZ{int(year)}{int(point_row['INDEX'])}.zarr"
            )
            futures.append(
                client.submit(
                    compute.save_to_zarr, sample, chip_store, 'w', point_row['INDEX']
                )
            )
    # Passing the save futures as an argument makes the returned future depend
    # on all of them; its value is simply the job id.
    return client.submit(lambda x: job_id, futures)

In [None]:
# Number of workers requested; the same value is used later when waiting for
# the workers to come up.
num_workers = 32

cluster = create_cluster(
    # scheduler sizing
    scheduler_threads=1,
    scheduler_memory=8,
    # worker sizing
    workers=num_workers,
    worker_threads=1,
    worker_memory=4,
)
client = cluster.get_client()

# Last expression of the cell, so the notebook displays the cluster.
cluster

In [None]:
# Block until all requested workers have joined the cluster.
logger.info("Waiting for cluster workers to start")
client.wait_for_workers(num_workers)
# Ship the local `utils` package to the workers so that the submitted tasks,
# which use `utils.hls.compute`, can run there.
logger.info("Uploading code to workers")
upload_source('./utils', client)

In [None]:
# Grouping keys handed to `process_catalog`: the catalog is grouped by "tile",
# and "year" is forwarded as the per-job grouping.
catalog_groupby = "tile"
job_groupby = "year"

# Chunking applied when the HLS input stores are opened.
chunks = dict(month=1, x=3660, y=3660)

process_catalog(
    # what to process and how
    catalog=df_az_500,
    catalog_groupby=catalog_groupby,
    job_fn=chip_tiles,
    job_groupby=job_groupby,
    chunks=chunks,
    # storage target and credentials (key comes from the environment)
    account_name="usfs",
    account_key=os.environ['AZURE_ACCOUNT_KEY'],
    storage_container="fia/chips-test6/hls/az",
    # execution settings
    client=client,
    concurrency=4,
    checkpoint_path="./checkpoint_file.txt",
    logger=logger,
)

In [None]:
# Tear the cluster down once processing has finished.
cluster.shutdown()