In [11]:
import os
import time

# pip/conda installed
import dask.array as da
import pandas as pd
import xarray as xr
from dask.distributed import Client

In [4]:
from utils.hls import HLSBand
from utils.hls import HLSCatalog
from utils.hls import scene_to_urls

## Set up the necessary utility functions and classes

In [5]:
def create_multiband_dataset(row, bands, chunks):
    '''Load every requested band of one HLS scene into a single xarray Dataset.

    Adapted from https://github.com/scottyhq/cog-benchmarking/blob/master/notebooks/landsat8-cog-ndvi-mod.ipynb

    Parameters
    ----------
    row : mapping
        One catalog record; only its 'scene' and 'sensor' entries are read.
    bands : sequence
        Bands to load. Each becomes a data variable keyed by the band itself.
    chunks : dict
        Dask chunk sizes forwarded to ``xr.open_rasterio``.

    Returns
    -------
    xarray.Dataset
        One data variable per band, combined with ``xr.merge``.
    '''
    urls = scene_to_urls(row['scene'], row['sensor'], bands)
    band_datasets = []
    for band, url in zip(bands, urls):
        # Named ``band_array`` rather than ``da`` so the module-level
        # ``import dask.array as da`` is not shadowed.
        band_array = xr.open_rasterio(url, chunks=chunks)
        # Squeeze out length-1 dims (presumably the single 'band' axis of each
        # file - TODO confirm) and drop the leftover 'band' coordinate.
        # ``drop_vars`` replaces the deprecated ``drop(labels=...)``.
        band_array = band_array.squeeze().drop_vars('band')
        band_datasets.append(band_array.to_dataset(name=band))
    return xr.merge(band_datasets)

def create_timeseries_multiband_dataset(df, bands, chunks):
    '''For a single HLS tile create a multi-date, multi-band xarray dataset.

    Each row of ``df`` is one acquisition. It is loaded with
    ``create_multiband_dataset`` and stacked along a new 'time' dimension
    built from the row's 'dt' value. Acquisitions that fail to load are
    reported and skipped, and their timestamps are left out of the time index.

    Parameters
    ----------
    df : pandas.DataFrame
        Acquisitions of one tile; must provide 'scene', 'sensor' and 'dt'.
    bands : sequence
        Bands to load for every acquisition.
    chunks : dict
        Dask chunk sizes forwarded to ``create_multiband_dataset``.

    Returns
    -------
    xarray.Dataset
        The per-acquisition datasets concatenated along 'time'.

    Raises
    ------
    ValueError
        If no acquisition could be loaded.
    '''
    datasets = []
    times = []
    for _, row in df.iterrows():
        try:
            ds = create_multiband_dataset(row, bands, chunks)
        except Exception as e:
            print('ERROR loading, skipping acquistion!')
            print(e)
            continue
        datasets.append(ds)
        # Record the timestamp only for acquisitions that actually loaded, so
        # the time index stays the same length as (and aligned with) datasets.
        times.append(row['dt'])
    if not datasets:
        raise ValueError('No acquisitions could be loaded for this tile')
    DS = xr.concat(datasets, dim=pd.DatetimeIndex(times, name='time'))
    print('Dataset size (Gb): ', DS.nbytes/1e9)
    return DS

In [6]:
def get_mask(qa_band):
    """Build a boolean "good pixel" mask from an HLS QA band.

    The result is True where none of the rejection conditions below hold and
    False wherever at least one does. Typical use::

        qa_mask = get_mask(dataset[HLSBand.QA])
        ds = dataset.drop_vars(HLSBand.QA)
        masked = ds.where(qa_mask)

    A pixel is rejected when any of these QA bits is set:
        0b1    - cirrus
        0b10   - cloud
        0b100  - adjacent cloud
        0b1000 - cloud shadow
    or when both bits of 0b11000000 are set (high aerosol).
    """
    cirrus, cloud, adjacent_cloud, cloud_shadow = 0b1, 0b10, 0b100, 0b1000
    high_aerosol = 0b11000000

    # Any one of the single-bit flags is enough to reject the pixel.
    single_bit_flags = cirrus | cloud | adjacent_cloud | cloud_shadow
    has_flag = (qa_band & single_bit_flags) > 0
    # High aerosol needs both aerosol bits set at once.
    is_high_aerosol = (qa_band & high_aerosol) == high_aerosol

    bad_quality = has_flag | is_high_aerosol
    # Invert: True where quality is good, False where it is bad.
    return xr.where(bad_quality, False, True)

In [16]:
def calculate_tile_median(job_id, dataframe, bands, groupby, chunks=None):
    """Compute a QA-masked per-group median for one HLS tile and write it to ``{job_id}.zarr``.

    Parameters
    ----------
    job_id : hashable
        Job identifier; the output store is ``f"{job_id}.zarr"``.
    dataframe : pandas.DataFrame
        Acquisitions of a single tile, as accepted by
        ``create_timeseries_multiband_dataset``.
    bands : sequence
        Bands to load. If ``HLSBand.QA`` ends up among the data variables it
        is turned into a mask with ``get_mask`` and then dropped.
    groupby : str or DataArray
        Grouping key for ``Dataset.groupby``, e.g. ``'time.month'``. For a
        string, the part after the last '.' is used as the name of the grouped
        dimension; for a DataArray its ``.name`` is used.
    chunks : dict, optional
        Dask chunks used when opening the rasters. Defaults to
        ``{'band': 1, 'x': 3660, 'y': 3660}`` (one full tile per band). The
        default is explicit so the function does not depend on a notebook
        global defined in a later cell.
    """
    if chunks is None:
        chunks = {'band': 1, 'x': 3660, 'y': 3660}
    tile_ds = create_timeseries_multiband_dataset(dataframe, bands, chunks)

    # apply QA mask
    if HLSBand.QA in tile_ds.data_vars:
        qa_mask = get_mask(tile_ds[HLSBand.QA])
        tile_ds = (tile_ds
            .drop_vars(HLSBand.QA)  # drop QA band
            .where(qa_mask)  # Apply mask
        )

    # groupby(...).median() replaces 'time' with the group dimension,
    # e.g. 'time.month' -> 'month'.
    group_dim = groupby.split('.')[-1] if isinstance(groupby, str) else groupby.name
    # groupby + median changes chunk size...change it back: one group per
    # chunk, and the same spatial chunking the rasters were opened with.
    out_chunks = {group_dim: 1, 'y': chunks.get('y', 3660), 'x': chunks.get('x', 3660)}
    # Zarr needs string variable names: rename Enum-like keys (anything with
    # a ``.name``) and leave keys that are already plain strings alone.
    var_names = {var: var.name for var in tile_ds.data_vars if hasattr(var, 'name')}

    (tile_ds
        .where(tile_ds != -1000)  # -1000 means no data - set those entries to nan
        .groupby(groupby)
        .median()
        .chunk(out_chunks)
        .rename(var_names)
        .to_zarr(f"{job_id}.zarr")
    )

In [8]:
# Open the HLS scene catalog from its zarr store and display the xarray Dataset behind it.
catalog_store = 'fia10.zarr'
catalog = HLSCatalog.from_zarr(catalog_store)
catalog.xr_ds

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type int64 numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,int64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type int64 numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,int64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type int64 numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,int64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type float64 numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type float64 numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type object numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type object numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type object numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 3.54 kB 3.54 kB Shape (443,) (443,) Count 2 Tasks 1 Chunks Type float64 numpy.ndarray",443  1,

Unnamed: 0,Array,Chunk
Bytes,3.54 kB,3.54 kB
Shape,"(443,)","(443,)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [9]:
# HLS data on Azure isn't tiled so we want to read the entire data once (each tile is 3660x3660)...
x_chunk = 3660
y_chunk = 3660
chunks = {'band': 1, 'x': x_chunk, 'y': y_chunk}

# Connect to an already-running dask scheduler. Hard-coding its address ties
# the notebook to one particular cluster launch, so allow overriding it via the
# DASK_SCHEDULER_ADDRESS environment variable. The fallback is the address
# this notebook was originally run against.
SCHEDULER_ADDRESS = os.environ.get('DASK_SCHEDULER_ADDRESS', 'tcp://127.0.0.1:45183')
client = Client(SCHEDULER_ADDRESS)
client

0,1
Client  Scheduler: tcp://127.0.0.1:45183  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 8  Memory: 33.68 GB


In [17]:
# One compositing job per group of the catalog (grouped on 'INDEX'). The group
# key is the job id, and calculate_tile_median writes its result to "<job_id>.zarr".
# TODO: abstract to a function - pass in job function (calculate_tile_median),
#       job id function (act on dataframe), groupby
grps = list(catalog.xr_ds.groupby('INDEX'))
for idx, ds in grps:
    job_id = idx
    df = ds.to_dataframe()
    print(f"Starting {job_id}")
    start = time.perf_counter()
    calculate_tile_median(job_id, df, catalog.xr_ds.attrs['bands'], 'time.month')
    elapsed = time.perf_counter() - start
    print(f"{job_id} finished in {elapsed}")

Starting 2
Dataset size (Gb):  38.485618712
2 finished in 567.722782775003
