In [1]:
import os
import time

# pip/conda installed
import dask.array as da
import fsspec
import pandas as pd
import xarray as xr
from dask.distributed import as_completed
from dask.distributed import Client
from dask_gateway import GatewayCluster

In [2]:
from utils.hls import HLSBand
from utils.hls import HLSCatalog
from utils.hls import scene_to_urls

## Set up the necessary utility functions

In [3]:
def get_mask(qa_band):
    """Return a boolean mask that is True where an HLS QA band reports good quality.

    A pixel counts as bad quality when any of the cirrus (bit 0), cloud (bit 1),
    adjacent-cloud (bit 2) or cloud-shadow (bit 3) flags is set, or when both
    aerosol bits (6 and 7) are set. Every other pixel is good.

    Args:
        qa_band (xarray.DataArray): integer-valued QA band (bit flags).

    Returns:
        xarray.DataArray: False where quality is bad, True elsewhere.

    Mask usage:
        ds.where(mask)

    Example:
        qa_mask = get_mask(dataset[HLSBand.QA.name])
        ds = dataset.drop_vars(HLSBand.QA.name)
        masked = ds.where(qa_mask)
    """
    # Single-bit flags: any one of them being set disqualifies the pixel.
    single_bit_flags = (
        0b1,     # cirrus
        0b10,    # cloud
        0b100,   # adjacent cloud
        0b1000,  # cloud shadow
    )
    # Two-bit aerosol field: only the value with both bits set is disqualifying.
    high_aerosol = 0b11000000

    # Parenthesised for readability; in Python `&` already binds tighter than `>` / `==`.
    flag_hits = [(qa_band & flag) > 0 for flag in single_bit_flags]
    flag_hits.append((qa_band & high_aerosol) == high_aerosol)

    # OR the hits together left to right (same grouping as a chained `a | b | c | ...`).
    bad_quality = flag_hits[0]
    for hit in flag_hits[1:]:
        bad_quality = bad_quality | hit

    # Invert: True where quality is good, False where it is bad.
    return xr.where(bad_quality, False, True)

In [4]:
def fetch_band_url(tpl, band_chunks=None):
    """Fetch a given url with xarray, creating a dataset with a single data variable of the band name for the url.

    Args:
        tpl (Tuple[str, str]): tuple of the form (band, url) - the band name for the data variable and the url to fetch
        band_chunks (dict, optional): dask chunk sizes passed to ``xr.open_rasterio``. When None
            (the default, and what ``client.map`` uses), the notebook-level ``chunks`` setting from
            the configuration cell is used, so that cell must have been run before this is called.

    Returns:
        xarray.Dataset: Dataset for the given HLS scene url with the data variable being named the given band

    """
    band, url = tpl
    # Resolve the notebook-level `chunks` only when no explicit chunking was passed.
    read_chunks = chunks if band_chunks is None else band_chunks
    # NOTE(review): xr.open_rasterio is deprecated and removed in newer xarray releases;
    # rioxarray.open_rasterio is the replacement if the environment is upgraded.
    # Named `band_array` (not `da`) so the module-level `dask.array as da` import is not shadowed.
    band_array = xr.open_rasterio(url, chunks=read_chunks)
    # Assumes each url is a single-band raster: drop the size-1 `band` dimension and its
    # scalar coordinate without touching x/y (a bare squeeze() would also squeeze a size-1 x or y).
    band_array = band_array.squeeze('band', drop=True)
    return band_array.to_dataset(name=band)

def compute_tile_median(job_id, ds, groupby, qa_name, write_store, spatial_chunk=3660):
    """Compute QA-band-masked {groupby} median reflectance for the given dataset and save the result as zarr to `write_store`.

    Args:
        job_id: The id of the job/tile being computed (a group key from the catalog)
        ds (xarray.Dataset): Dataset to compute on
        groupby (str): How to group the dataset (e.g. "time.month", "time.year"). The output
            dimension is named after the part after the last "." ("month", "year", ...).
        qa_name (str): Name of the QA band to use for masking; masking is skipped if `ds` has no such variable
        write_store (fsspec.FSMap): The location to write the zarr (written with mode='w', i.e. overwritten)
        spatial_chunk (int): Chunk size along `x` and `y` for the written zarr (default 3660)

    Returns:
        The job_id that was computed and written

    """
    # apply QA mask; the QA band itself is not part of the output
    if qa_name in ds.data_vars:
        qa_mask = get_mask(ds[qa_name])
        ds = ds.drop_vars(qa_name).where(qa_mask)
    # groupby("time.month").median() yields a "month" dimension, groupby("time.year") a
    # "year" dimension, etc. -- derive it instead of hard-coding "month", which made
    # any other grouping fail at the .chunk() call.
    group_dim = groupby.split('.')[-1]
    (ds
        .where(ds != -1000)  # -1000 means no data - set those entries to nan
        .groupby(groupby)
        .median()
        # groupby + median changes chunk size...lets change it back
        .chunk({group_dim: 1, 'y': spatial_chunk, 'x': spatial_chunk})
        .to_zarr(write_store, mode='w')
    )
    return job_id

def calculate_job_median(job_id, job_df, job_groupby, bands, band_names, qa_band_name, write_store, client):
    """A job compatible with `process_catalog` which computes per-band median reflectance for the input job_df.

    The task graph submitted per job is: one `fetch_band_url` task per (scene, band),
    one `xr.merge` per scene, one `xr.concat` along a new `time` dimension, and finally
    `compute_tile_median`, which writes the result to `write_store`.

    Args:
        job_id: Id of the job, used for tracking purposes
        job_df (pandas.Dataframe): Dataframe of scenes to include in the computation; uses the
            'scene', 'sensor' and 'dt' columns
        job_groupby (str): How to group the dataset produced from the dataframe (e.g. "time.month")
        bands (List[HLSBand]): List of HLSBand objects to compute median reflectance on
        band_names (List[str]): List of band name strings, in the same order as `bands`
        qa_band_name (str): Name of the QA band to use for masking
        write_store (fsspec.FSMap): The location to write any results
        client (dask.distributed.Client): Dask cluster client to submit tasks to

    Returns:
        dask.distributed.Future: Future for the computation that is being done, can be waited on.

    """
    scene_ds_futures = []
    for _, row in job_df.iterrows():
        band_urls = scene_to_urls(row['scene'], row['sensor'], bands)
        # one future per band, each resolving to a single-variable dataset
        band_ds_futures = client.map(fetch_band_url, list(zip(band_names, band_urls)), priority=0)
        # one future per scene, resolving to a dataset with every band
        scene_ds_futures.append(client.submit(xr.merge, band_ds_futures, priority=1))
    # Build the time coordinate locally (same row order as the loop above) and pass it as a
    # plain argument, instead of closing over the whole `job_df` in a lambda that has to be
    # pickled and shipped to the cluster with the task.
    time_index = pd.DatetimeIndex(job_df['dt'].tolist(), name='time')
    # dataset of a single index/tile with a data var for every band and dimensions: x, y, time
    job_ds_future = client.submit(xr.concat, scene_ds_futures, dim=time_index, priority=2)
    # compute masked, grouped median per band per pixel and write it to `write_store`
    return client.submit(compute_tile_median, job_id, job_ds_future, job_groupby, qa_band_name, write_store, priority=3)

def process_catalog(
    catalog,
    catalog_groupby,
    job_fn,
    job_groupby,
    account_name,
    storage_container,
    account_key,
    client,
):
    """Process a catalog: submit one job per catalog group, then wait for all of them.

    Each group of the catalog's dataframe (grouped by `catalog_groupby`) is passed to
    `job_fn` together with an fsspec mapper pointing at
    ``az://{storage_container}/{job_id}.zarr``. Job results are printed as they finish,
    followed by the number of jobs and the total wall-clock time.

    Args:
        catalog (HLSCatalog): catalog to process; its `xr_ds.attrs['bands']` selects the bands
        catalog_groupby (str): column to group the catalog in to jobs by (e.g. 'INDEX', 'tile')
        job_fn: a function to apply to each job from the grouped catalog (e.g. `calculate_job_median`);
            it must return a dask Future
        job_groupby (str): how to group data built within each job (e.g. 'time.month', 'time.year')
        account_name (str): Azure storage account to write results to
        storage_container (str): Azure storage container within the `account_name` to write results to
        account_key (str): Azure account key for the `account_name` which results are written to
        client (dask.distributed.Client): Dask cluster client to submit tasks to

    """
    # Read the band list from the catalog being processed -- previously this read the
    # notebook global `point_catalog`, silently ignoring the `catalog` argument.
    bands = catalog.xr_ds.attrs['bands']
    band_names = [band.name for band in bands]
    qa_band_name = HLSBand.QA.name

    df = catalog.xr_ds.to_dataframe()
    job_futures = []
    start_time = time.perf_counter()

    for job_id, job_df in df.reset_index().groupby(catalog_groupby):
        write_store = fsspec.get_mapper(
            f"az://{storage_container}/{job_id}.zarr",
            account_name=account_name,
            account_key=account_key
        )
        job_futures.append(
            job_fn(job_id, job_df, job_groupby, bands, band_names, qa_band_name, write_store, client)
        )
    # Report each job as soon as it finishes, in completion order.
    for future in as_completed(job_futures):
        print(future.result())
    print(f"{len(job_futures)} completed in {time.perf_counter()-start_time} seconds")

In [5]:
# HLS data on Azure isn't tiled, so read each whole 3660x3660 tile in a single dask chunk
# (one band per chunk). `fetch_band_url` uses `chunks` when opening each band url.
x_chunk = y_chunk = 3660
chunks = dict(band=1, x=x_chunk, y=y_chunk)

In [6]:
# Provide the Azure account key through the AZURE_ACCOUNT_KEY environment variable (export it
# before starting the kernel) instead of pasting it into the notebook, where it would be saved.
# setdefault keeps a key that is already set; the empty-string fallback only keeps later
# os.environ["AZURE_ACCOUNT_KEY"] lookups from raising KeyError (an empty key will presumably
# fail to authenticate against Azure).
os.environ.setdefault('AZURE_ACCOUNT_KEY', "")

In [7]:
# Open the HLS catalog stored as zarr in the "usfs" storage account.
catalog_url = fsspec.get_mapper(
    "az://fia/catalogs/fia10.zarr",  # plain string: nothing to interpolate
    account_key=os.environ['AZURE_ACCOUNT_KEY'],
    account_name="usfs",
)
point_catalog = HLSCatalog.from_zarr(catalog_url)

In [8]:
account_name = "usfs"
storage_container = "fia/hls-testing"
account_key = os.environ["AZURE_ACCOUNT_KEY"]
catalog_groupby = "INDEX"  # catalog column that splits the catalog into jobs
job_groupby = "time.month"  # temporal grouping used for the median within each job

with GatewayCluster(worker_cores=2, worker_memory=8) as cluster:
    print(f"Cluster dashboard visible at: {cluster.dashboard_link}")
    cluster.scale(16)
    # Use the client as a context manager so it is closed before the outer `with` shuts the
    # cluster down, instead of being left connected to a scheduler that no longer exists.
    with cluster.get_client() as client:
        process_catalog(
            point_catalog,
            catalog_groupby,
            calculate_job_median,
            job_groupby,
            account_name,
            storage_container,
            account_key,
            client,
        )

Cluster dashboard visible at: /services/dask-gateway/clusters/default.8f095b12e15e4aeb969e05609e1506db/status
8
5
2
3 completed in 349.21651702599775 seconds


## TODO

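1. Quality-check (validate) the median results written to zarr
1. Check whether the HLS COG data on Azure is now internally tiled (if so, read chunks smaller than the full 3660x3660 tile may be possible)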