In [1]:
import os
import time

# pip/conda installed
import dask.array as da
import fsspec
import pandas as pd
import xarray as xr
from dask.distributed import Client

In [2]:
from utils.hls import HLSBand
from utils.hls import HLSCatalog
from utils.hls import scene_to_urls

## Set up the necessary utility functions

In [3]:
def create_multiband_dataset(row, bands, chunks):
    '''Load several bands of one HLS scene into a single xarray Dataset.

    Adapted from
    https://github.com/scottyhq/cog-benchmarking/blob/master/notebooks/landsat8-cog-ndvi-mod.ipynb

    Args:
        row: a catalog row providing ``row['scene']`` and ``row['sensor']``.
        bands: the bands to load; each becomes a data variable named after the band.
        chunks: dask chunk sizes passed to ``xr.open_rasterio``.

    Returns:
        xr.Dataset with one data variable per band (merged with ``xr.merge``).
    '''
    band_datasets = []
    # NOTE(review): assumes scene_to_urls returns one URL per band, in the same
    # order as `bands` -- zip() silently truncates if the lengths differ.
    for band, url in zip(bands, scene_to_urls(row['scene'], row['sensor'], bands)):
        # Named `band_array` (not `da`) so the module-level `dask.array as da`
        # import is not shadowed.
        band_array = xr.open_rasterio(url, chunks=chunks)
        # Squeeze away the length-1 `band` dimension (assumes each URL holds a
        # single band), then drop the leftover scalar `band` coordinate.
        # `drop_vars` replaces the deprecated `drop(labels=...)` form.
        band_array = band_array.squeeze().drop_vars('band')
        band_datasets.append(band_array.to_dataset(name=band))
    return xr.merge(band_datasets)

def create_timeseries_multiband_dataset(df, bands, chunks):
    '''For a single HLS tile create a multi-date, multi-band xarray dataset.

    Acquisitions that fail to load are reported and skipped. The ``time``
    coordinate is built only from the acquisitions that actually loaded, so it
    always has the same length as the concatenated data.

    Args:
        df: one row per acquisition, with ``scene``, ``sensor`` and ``dt`` columns.
        bands: bands to load for every acquisition.
        chunks: dask chunk sizes used when opening each band.

    Returns:
        xr.Dataset concatenated along a new ``time`` dimension.

    Raises:
        ValueError: if no acquisition could be loaded.
    '''
    datasets = []
    times = []  # kept in lock-step with `datasets`
    for _, row in df.iterrows():
        try:
            ds = create_multiband_dataset(row, bands, chunks)
        except Exception as e:
            # Best effort: one unreadable acquisition should not abort the tile.
            print('ERROR loading, skipping acquistion!')
            print(e)
            continue
        datasets.append(ds)
        times.append(row['dt'])
    if not datasets:
        # xr.concat([]) would also raise ValueError, but with a less helpful message.
        raise ValueError('No acquisitions could be loaded for this tile')
    DS = xr.concat(datasets, dim=pd.DatetimeIndex(times, name='time'))
    print('Dataset size (Gb): ', DS.nbytes/1e9)
    return DS

In [4]:
def get_mask(qa_band):
    """Return a boolean mask that is True for good-quality pixels of an HLS QA band.

    A pixel counts as bad when any of the cirrus, cloud, adjacent-cloud or
    cloud-shadow bits is set, or when both high-aerosol bits (0b11000000) are set.

    Mask usage:
        ds.where(mask)

    Example:
        qa_mask = get_mask(dataset[HLSBand.QA])
        ds = dataset.drop_vars(HLSBand.QA)
        masked = ds.where(qa_mask)
    """
    # Flags where a single set bit disqualifies the pixel.
    single_bit_flags = (
        0b1,     # cirrus
        0b10,    # cloud
        0b100,   # adjacent cloud
        0b1000,  # cloud shadow
    )
    # Two-bit field: only the value with both bits set counts as bad.
    high_aerosol = 0b11000000

    def is_bad_quality(qa):
        bad = None
        for flag in single_bit_flags:
            hit = (qa & flag) > 0
            bad = hit if bad is None else bad | hit
        return bad | ((qa & high_aerosol) == high_aerosol)

    # Invert the "bad" condition: True where quality is good, False where it is bad.
    return xr.where(is_bad_quality(qa_band), False, True)

In [5]:
def process_catalog(catalog, catalog_groupby, job_fn, job_groupby, account_name, storage_container, account_key):
    """Split a catalog into jobs, run each one, and write each result to Azure as zarr.

    Args:
        catalog (HLSCatalog): catalog to process; its ``xr_ds`` is grouped into
            jobs and ``xr_ds.attrs['bands']`` is handed to every job.
        catalog_groupby (str): column to group the catalog in to jobs by (e.g. 'INDEX', 'year')
        job_fn: called as ``job_fn(job_id, dataframe, bands, job_groupby)`` and must
            return an object with a ``to_zarr`` method (e.g. `calculate_tile_median`).
        job_groupby (str): how to group data built within each job (e.g. 'time.month', 'time.year')
        account_name (str): Azure storage account name used for the output store.
        storage_container (str): container (optionally with a path prefix) under
            which ``<job_id>.zarr`` is written.
        account_key (str): Azure storage account key used for the output store.
    """
    jobs = list(catalog.xr_ds.groupby(catalog_groupby))

    for job_id, job_ds in jobs:
        job_df = job_ds.to_dataframe()
        zarr_path = f"az://{storage_container}/{job_id}.zarr"
        store = fsspec.get_mapper(zarr_path, account_name=account_name, account_key=account_key)

        print(f"Starting {job_id}")
        t0 = time.perf_counter()
        # Run the job, then persist its result to blob storage.
        result = job_fn(job_id, job_df, catalog.xr_ds.attrs['bands'], job_groupby)
        result.to_zarr(store)
        elapsed = time.perf_counter() - t0
        print(f"{job_id} finished in {elapsed}")

In [6]:
def calculate_tile_median(job_id, dataframe, bands, groupby, read_chunks=None):
    """Build one tile's multi-date dataset and reduce it to per-group medians.

    Args:
        job_id: job identifier; not used in the computation (kept so the
            function matches the ``job_fn`` call made by ``process_catalog``).
        dataframe: one row per acquisition (see ``create_timeseries_multiband_dataset``).
        bands: bands to load.
        groupby: xarray groupby key such as 'time.month' or 'time.year'. The
            grouped output dimension is named after the part after the last '.'.
        read_chunks: dask chunk sizes used to open the rasters; its 'y'/'x'
            entries also set the output spatial chunking. Defaults to the
            notebook-level ``chunks`` config.

    Returns:
        xr.Dataset of medians, one chunk per group along the grouped dimension,
        with string variable names suitable for zarr.
    """
    if read_chunks is None:
        read_chunks = chunks  # notebook-level config defined in the chunking cell
    tile_ds = create_timeseries_multiband_dataset(dataframe, bands, read_chunks)
    # apply QA mask
    if HLSBand.QA in tile_ds.data_vars:
        qa_mask = get_mask(tile_ds[HLSBand.QA])
        tile_ds = (tile_ds
            .drop_vars(HLSBand.QA)  # drop QA band
            .where(qa_mask)  # Apply mask
        )
    # groupby('time.month') yields a 'month' dim, groupby('time.year') a 'year'
    # dim, etc. -- derive it instead of hard-coding 'month'.
    group_dim = groupby.split('.')[-1]
    # Variable names may be HLSBand enum members; zarr needs string names.
    new_names = {var: getattr(var, 'name', var) for var in tile_ds.data_vars}
    return (tile_ds
        .where(tile_ds != -1000)  # -1000 means no data - set those entries to nan
        .groupby(groupby)
        .median()
        # groupby + median changes chunk size... change it back
        .chunk({group_dim: 1, 'y': read_chunks.get('y', 3660), 'x': read_chunks.get('x', 3660)})
        .rename(new_names)
    )

In [7]:
# Each HLS tile on Azure is one untiled 3660x3660 raster, so read every band
# in a single chunk that covers the whole tile.
x_chunk = y_chunk = 3660
chunks = dict(band=1, x=x_chunk, y=y_chunk)

In [8]:
from dask_gateway import GatewayCluster

# Each worker: 2 cores and 8 units of memory (presumably GiB per the
# dask-gateway docs -- confirm for this deployment).
cluster = GatewayCluster(worker_memory=8, worker_cores=2)
# Let the gateway scale the pool between 1 and 50 workers on demand.
cluster.adapt(maximum=50, minimum=1)

cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [10]:
# Connect a dask client to the gateway cluster created above; the bare
# `client` expression displays its scheduler/dashboard summary.
client = cluster.get_client()
client

0,1
Client  Scheduler: gateway://traefik-dhub-dask-gateway.default:80/default.06546f4ba2084c50871c09b1606c0038  Dashboard: /services/dask-gateway/clusters/default.06546f4ba2084c50871c09b1606c0038/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [11]:
import getpass

# Use the storage key from the environment when it is already set; otherwise
# prompt for it without echoing. This avoids clobbering an existing key with a
# placeholder and keeps the secret out of the saved notebook.
if not os.environ.get('AZURE_ACCOUNT_KEY'):
    os.environ['AZURE_ACCOUNT_KEY'] = getpass.getpass('Azure storage account key: ')

In [12]:
# Open the FIA point catalog (a zarr store in the "usfs" Azure account).
catalog_url = fsspec.get_mapper(
    "az://fia/catalogs/fia10.zarr",
    account_key=os.environ['AZURE_ACCOUNT_KEY'],
    account_name="usfs",
)
point_catalog = HLSCatalog.from_zarr(catalog_url)

In [None]:
# One job per catalog 'INDEX' group: compute monthly medians with
# calculate_tile_median and write each result to az://fia/hls/<INDEX>.zarr
# in the "usfs" storage account.
process_catalog(
    catalog=point_catalog,
    catalog_groupby='INDEX',
    job_fn=calculate_tile_median,
    job_groupby='time.month',
    storage_container="fia/hls",
    account_name="usfs",
    account_key=os.environ["AZURE_ACCOUNT_KEY"],
)