In [2]:
# pip/conda installed
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rtree
import xarray as xr
from dask.distributed import Client

In [3]:
from utils.hls import HLSBand
from utils.hls import HLSCatalog
from utils.hls import HLSTileLookup
from utils.hls import fia_csv_to_data_catalog_input
from utils.hls import scene_to_urls

## Set up the necessary utility functions/classes

In [4]:
# Build the HLS tile lookup once; it is handed to the catalog constructor
# as `tile_lookup`. Construction reports how many tile extents were read
# (see the cell output).
lookup = HLSTileLookup()

Reading tile extents...
Read tile extents for 56686 tiles


In [5]:
# Catalog built from the point records in ./fia_10.csv, restricted to the
# narrow-NIR band plus the QA band that is later turned into a quality mask.
point_catalog = HLSCatalog.from_point_pandas(
    df=fia_csv_to_data_catalog_input('./fia_10.csv'),
    bands=[HLSBand.NIR_NARROW, HLSBand.QA],
    tile_lookup=lookup,
)

In [6]:
# Inspect the catalog contents. `xr_ds` is the object that is later grouped
# by 'INDEX' and whose `attrs['bands']` supplies the band list.
point_catalog.xr_ds

In [7]:
%%time
# Search parameters for a bounding-box catalog. Only the values are defined
# here; the call that would consume them is commented out below.
# presumably [min_lon, min_lat, max_lon, max_lat] — TODO confirm
bbox = [-124.98046874999999, 24.367113562651262, -66.70898437499999, 49.49667452747045]
years = [2019]
bands=[HLSBand.NIR_NARROW, HLSBand.QA]
# NOTE(review): the disabled call below passes `landsat_bands` and
# `sentinel_bands`, which are not defined in the visible cells (only `bands`
# is). Confirm the current `HLSCatalog.from_bbox` signature before re-enabling.
# bbox_catalog = HLSCatalog.from_bbox(bbox, years, landsat_bands, sentinel_bands, lookup)

CPU times: user 0 ns, sys: 6 µs, total: 6 µs
Wall time: 8.58 µs


In [8]:
def create_multiband_dataset(row, bands, chunks):
    '''Load the requested bands of one scene into a single xarray dataset.

    Args:
        row: mapping-like record providing the 'scene' and 'sensor' entries
            that are handed to `scene_to_urls`.
        bands: sequence of band identifiers; each one becomes the name of a
            data variable in the result.
        chunks: chunk specification forwarded to `xr.open_rasterio`.

    Returns:
        The `xr.merge` of one single-variable dataset per band.
    '''
    band_datasets = []
    urls = scene_to_urls(row['scene'], row['sensor'], bands)
    # zip() pairs each band with its URL, which relies on scene_to_urls
    # returning the URLs in the same order as `bands`.
    for band, url in zip(bands, urls):
        # Deliberately not named `da`: that would shadow the module-level
        # `dask.array as da` alias inside this function.
        raster = xr.open_rasterio(url, chunks=chunks)
        # Squeeze out length-1 dimensions, then remove the 'band' coordinate.
        # `drop_vars` replaces the deprecated `DataArray.drop(labels=...)`
        # and matches the `drop_vars` call used elsewhere in this notebook.
        raster = raster.squeeze().drop_vars('band')
        band_datasets.append(raster.to_dataset(name=band))
    return xr.merge(band_datasets)

def create_timeseries_multiband_dataset(df, bands, chunks):
    '''For a single HLS tile create a multi-date, multi-band xarray dataset.

    Args:
        df: DataFrame with one row per acquisition. Each row is passed to
            `create_multiband_dataset`, and its 'dt' value becomes that
            acquisition's entry on the time coordinate.
        bands: band identifiers forwarded to `create_multiband_dataset`.
        chunks: chunk specification forwarded to `create_multiband_dataset`.

    Returns:
        The per-acquisition datasets concatenated along a new 'time'
        dimension. Acquisitions that fail to load are reported and skipped.

    Raises:
        ValueError: if no acquisition could be loaded.
    '''
    datasets = []
    loaded_dates = []
    # Walk the dates alongside the rows so that a skipped acquisition also
    # skips its date. Building the time index from every row of `df` would
    # give xr.concat an index longer than the list of datasets as soon as a
    # single load fails.
    for dt, (_, row) in zip(df['dt'].tolist(), df.iterrows()):
        try:
            ds = create_multiband_dataset(row, bands, chunks)
        except Exception as e:
            # Best-effort loading: report the failure and move on.
            print('ERROR loading, skipping acquistion!')
            print(e)
            continue
        datasets.append(ds)
        loaded_dates.append(dt)
    if not datasets:
        raise ValueError('no acquisitions could be loaded; nothing to concatenate')
    DS = xr.concat(datasets, dim=pd.DatetimeIndex(loaded_dates, name='time'))
    print('Dataset size (Gb): ', DS.nbytes/1e9)
    return DS

In [9]:
def get_mask(qa_band):
    """Build a boolean good-quality mask from an HLS QA band.

    The result is True where none of the rejected QA conditions is flagged
    and False everywhere else, so it can be handed directly to ``where``:

        qa_mask = get_mask(dataset[HLSBand.QA])
        ds = dataset.drop_vars(HLSBand.QA)
        masked = ds.where(qa_mask)
    """
    # Single-bit flags: a pixel is rejected as soon as any one of them is set.
    CIRRUS, CLOUD, ADJACENT_CLOUD, CLOUD_SHADOW = 0b1, 0b10, 0b100, 0b1000
    # Two-bit field: rejected only when BOTH bits are set.
    HIGH_AEROSOL = 0b11000000

    def is_bad_quality(qa):
        flagged = (qa & CIRRUS) > 0
        for bit in (CLOUD, ADJACENT_CLOUD, CLOUD_SHADOW):
            flagged = flagged | ((qa & bit) > 0)
        return flagged | ((qa & HIGH_AEROSOL) == HIGH_AEROSOL)

    # Invert: False where the pixel is bad, True where it is usable.
    return xr.where(is_bad_quality(qa_band), False, True)

In [10]:
# Connect to a Dask scheduler at a fixed local address.
# NOTE(review): the address (including the port) is hard-coded, so this cell
# has to be edited by hand whenever the scheduler comes up somewhere else —
# consider reading it from configuration instead.
client = Client("tcp://127.0.0.1:33727")

In [11]:
# Square spatial chunks of 732 pixels per side, one band per chunk.
x_chunk = y_chunk = 366 * 2
chunks = dict(band=1, x=x_chunk, y=y_chunk)

# Materialise the (INDEX value, group) pairs so they can be iterated below.
grps = list(point_catalog.xr_ds.groupby('INDEX'))

# Each job pairs an INDEX value with that group's records as a DataFrame.
jobs = []
for idx, ds in grps:
    df = ds.to_dataframe()
    jobs += [(idx, df)]

In [12]:
# Only the first job is loaded here; jobs[0] is an (INDEX value, DataFrame) pair.
training_ds = create_timeseries_multiband_dataset(
    jobs[0][1],
    point_catalog.xr_ds.attrs['bands'],
    chunks,
)

Dataset size (Gb):  6.791629112


In [13]:
# Derive the quality mask from the QA band, then remove that band from
# `training_ds` so it is not treated as a measurement.
# NOTE(review): this cell rebinds `training_ds`. Running it a second time
# finds no QA band and resets `qa_mask` to None — re-run the loading cell
# first if this cell needs to be repeated.
qa_mask = None
if HLSBand.QA in training_ds.data_vars:
    # The mask has to be computed before the QA variable is dropped.
    qa_mask = get_mask(training_ds[HLSBand.QA])
    training_ds = training_ds.drop_vars(HLSBand.QA)

In [14]:
# Keep values only where the QA mask is True. `qa_mask` is None when the
# dataset carried no QA band; in that case there is nothing to mask, so the
# data is passed through unchanged instead of handing None to `.where`.
masked = training_ds if qa_mask is None else training_ds.where(qa_mask)

In [20]:
# Median per calendar month: acquisitions are grouped by the month of their
# 'time' coordinate. Nothing is computed yet — the result is still a lazy,
# chunked array (see its repr in the next cell).
median = masked.groupby('time.month').median()

In [21]:
# Show the lazy result's repr (shape, chunking, task count) without computing it.
median

Unnamed: 0,Array,Chunk
Bytes,1.29 GB,4.29 MB
Shape,"(12, 3660, 3660)","(1, 732, 732)"
Count,111388 Tasks,300 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 1.29 GB 4.29 MB Shape (12, 3660, 3660) (1, 732, 732) Count 111388 Tasks 300 Chunks Type float64 numpy.ndarray",3660  3660  12,

Unnamed: 0,Array,Chunk
Bytes,1.29 GB,4.29 MB
Shape,"(12, 3660, 3660)","(1, 732, 732)"
Count,111388 Tasks,300 Chunks
Type,float64,numpy.ndarray


In [None]:
# Execute the task graph and display the result.
# NOTE(review): the computed value is not assigned to a name, so it is only
# available as this cell's output.
median.compute()

In [None]:
%%time

def job_to_median(job):
    xr
    x_chunk = 366*2
    y_chunk = 366*2
    da_lst = [
        xr.open_rasterio(url, chunks={'band': 1, 'x': x_chunk, 'y': y_chunk}).data # get underlying dask array because xarray doesn't support median w/ dask
        for url in urls
    ]
    year_array = da.concatenate(da_lst, axis=0)
    median_array = da.median(year_array, axis=0)
    median_array.compute()
    return median_array

for idx, month, urls in jobs[:12]:
    med = urls_to_median(urls)
    print(f"Completed median for {idx}, {month}")

Completed median for 2, 1
Completed median for 2, 2
Completed median for 2, 3
Completed median for 2, 4
Completed median for 2, 5
Completed median for 2, 6
Completed median for 2, 7
Completed median for 2, 8
Completed median for 2, 9
Completed median for 2, 10
