### EXAMPLE: Extract the HLS indices for individual plots
This example covers only the WTGN study, for simplicity. It is a good workflow for point data: it takes a weighted mean of the pixels surrounding
each plot point location (you can specify how many pixels to use).
Other studies in the TBNG dataset require more preprocessing (see /extract_hls_veg_idx.ipynb). More efficient methods may exist
for plots that represent areas (e.g., polygons), such as using a plot mask (see /examples/extract_smooth_hls_by_plot.ipynb).
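
As a rough illustration of the idea (not the notebook's xarray implementation), the sketch below picks the N pixel centers nearest to a point and averages their values with caller-supplied weights. The function names, the 3 x 3 grid, and the NDVI values are all hypothetical. The extraction cell later in this notebook passes the distance array itself as the weights; inverse distance is shown here only as an alternative one might consider.

```python
import math


def nearest_pixels(pixel_centers, point, n):
    """Return (distance, center) pairs for the n pixel centers closest to point."""
    px, py = point
    ranked = sorted(
        (math.hypot(x - px, y - py), (x, y)) for x, y in pixel_centers
    )
    return ranked[:n]


def weighted_mean(values, weights):
    """Weighted mean of values; the weights are whatever the caller chooses."""
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    return sum(v * w for v, w in zip(values, weights)) / total


# Hypothetical 3 x 3 grid of 30 m pixel centers (projected metres) and a plot point.
centers = [(x, y) for y in (15.0, 45.0, 75.0) for x in (15.0, 45.0, 75.0)]
ndvi = {c: 0.30 + 0.01 * i for i, c in enumerate(centers)}  # made-up values
plot_point = (40.0, 52.0)

# Keep only the 4 pixels closest to the plot point.
window = nearest_pixels(centers, plot_point, n=4)
dists = [d for d, _ in window]
vals = [ndvi[c] for _, c in window]

# Distance as the weight (farther pixels count more) versus inverse distance
# (nearer pixels count more). Which one is appropriate is a modelling choice.
print(weighted_mean(vals, dists))
print(weighted_mean(vals, [1.0 / d for d in dists]))
```

The `pixels` parameter set in the next section plays the role of `n` here.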

#### Load required packages

In [1]:
import xarray as xr
import rioxarray
import rasterio as rio
import numpy as np
import os
import re
import pandas as pd
import geopandas as gpd
from datetime import datetime, timedelta
from tqdm.notebook import tqdm
import time
import json
from src.hls_funcs import fetch
from src.hls_funcs.masks import mask_hls, shp2mask, bolton_mask
from src.hls_funcs.indices import ndvi_func, dfi_func, ndti_func, satvi_func, ndii7_func
from src.hls_funcs.indices import bai_126_func, bai_136_func, bai_146_func, bai_236_func, bai_246_func, bai_346_func
from src.hls_funcs.smooth import smooth_xr, despike_ts_xr

#### Specify the input/output paths and other parameters

In [2]:
###############
### Outputs ###
###############
# output directory
outDIR = '../examples/outputs/'
# output file name
outFILE = 'example_TB_wtgn_vor_idxs.csv'

##############
### Inputs ###
##############
# input path
df_aoi = pd.read_csv('../data/vor/TB_wtgn_vor.csv', parse_dates=[0])

# unique ID column name
id_col = 'Join_ID'

# date column name
date_col = 'Date'

# coordinate column names
x_coord_col = 'Mean.GPS_E.Biomass'
y_coord_col = 'Mean.GPS_N.Biomass'

################
### HLS info ###
################
# dictionary of vegetation indices to be extracted and functions to create them
veg_dict = {
    'ndvi': ndvi_func,
    'dfi': dfi_func,
    'ndti': ndti_func,
    'satvi': satvi_func,
    'ndii7': ndii7_func,
    'bai_126': bai_126_func,
    'bai_136': bai_136_func,
    'bai_146': bai_146_func,
    'bai_236': bai_236_func,
    'bai_246': bai_246_func,
    'bai_346': bai_346_func
}

# list of individual bands to be extracted
band_list = ['NIR1', 'SWIR1', 'SWIR2']

# specify number of pixels surrounding the point
pixels = 4

#### Setup the dask client

In [3]:
from dask.distributed import LocalCluster, Client
import dask
fetch.setup_env(aws=False)
cluster = LocalCluster(n_workers=8, threads_per_worker=2)
client = Client(cluster)
display(client)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 40601 instead
2023-04-05 14:14:47,998 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-4ynx5yya', purging
2023-04-05 14:14:47,999 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-f8b9a8on', purging
2023-04-05 14:14:47,999 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-y3x612am', purging
2023-04-05 14:14:47,999 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-m_4c3k3n', purging
2023-04-05 14:14:47,999 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-ueyrm9dy', purging
2023-04-05 14:14:47,999 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-a7vd4dk2', purging
2023-04-05 14:14:47,999 - distributed.diskutils - IN

Client
Connection method: Cluster object, Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:40601/status

Cluster info
Workers: 8, Total threads: 16, Total memory: 11.85 GiB
Status: running, Using processes: True

Scheduler
Comm: tcp://127.0.0.1:44149, Started: Just now

Workers
Each of the 8 workers: Total threads: 2, Memory: 1.48 GiB


#### Prepare data and directories

##### Plot data prep

In [4]:
### Prepare the plot data ###
# convert to GeoDataFrame using coordinates.
gdf_aoi = gpd.GeoDataFrame(
    df_aoi, geometry=gpd.points_from_xy(df_aoi[x_coord_col], df_aoi[y_coord_col]))
# set the coordinate system
gdf_aoi = gdf_aoi.set_crs(epsg=32613)
# buffer the points to extract surrounding pixels later
gdf_aoi.geometry = gdf_aoi.buffer(150)
# convert the 'Date' column to date instead of timestamp
gdf_aoi[date_col] = gdf_aoi[date_col].dt.date

##### Directory prep

In [6]:
# create the output path if it doesn't already exist
os.makedirs(outDIR, exist_ok=True)

# load any previously saved data so that if you encounter an error during extraction, you don't have to start over!
if os.path.exists(os.path.join(outDIR, outFILE)):
    # load saved data
    df_out = pd.read_csv(os.path.join(outDIR, outFILE))
    # subset plot data to only those that have not already been processed
    gdf_aoi_sub = gdf_aoi[~gdf_aoi[id_col].isin(df_out[id_col])]
else:
    # set the output dataframe to none
    df_out = None
    # set the subset to the entire plot dataset
    gdf_aoi_sub = gdf_aoi

#### Extract HLS indices for each plot
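
Two pieces of the loop below can be read in isolation: choosing the fetch window from the observation year, and retrying a failing fetch a bounded number of times. The sketch below restates both as small standalone functions using only the standard library. The names `fetch_window` and `with_retries` are hypothetical and are not part of `src.hls_funcs`. One deliberate difference from the loop below: this sketch re-raises after the final attempt, so a fetch that never succeeded cannot be mistaken for a successful one.

```python
import time
from datetime import date, timedelta


def fetch_window(obs_date):
    """Return (start, end) dates around obs_date, wider for sparser early years."""
    if obs_date.year < 2017:
        half_width = 60
    elif obs_date.year == 2017:
        half_width = 50
    else:
        half_width = 45
    delta = timedelta(days=half_width)
    return obs_date - delta, obs_date + delta


def with_retries(func, max_tries=5, wait_s=10, retry_on=(RuntimeError,), on_retry=None):
    """Call func(), retrying on the listed exceptions up to max_tries times."""
    for attempt in range(1, max_tries + 1):
        try:
            return func()
        except retry_on as err:
            print(f"Warning: {err!r}. Attempt {attempt} of {max_tries} failed.")
            if attempt == max_tries:
                raise  # give up instead of continuing with missing data
            if on_retry is not None:
                on_retry()  # e.g. restart the dask client
            time.sleep(wait_s)


# Usage with a stand-in fetch that fails twice before succeeding.
attempts = {"n": 0}


def flaky_fetch():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("simulated connection error")
    return "data"


start, end = fetch_window(date(2020, 6, 22))
result = with_retries(flaky_fetch, wait_s=0)
print(start, end, result)
```

In the notebook's setting, `on_retry` would correspond to `client.restart()` and `retry_on` to the three exception types caught in the loop.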

In [7]:
%%time
# loop through all the plots that have not already been processed. 
# idx: the index of the plot
# gdf_plot: the individual row, corresponding to one plot observation
for idx, gdf_plot in tqdm(gdf_aoi_sub.iterrows(), total=gdf_aoi_sub.shape[0]):
    # print the plot ID and date
    print(gdf_plot[id_col], ': ', gdf_plot[date_col])

    ### fetch the COG data
    # get the date range for the fetch based on the year. 
    # we pull a longer date range for smoothing prior to 2017 since Sentinel 2B was not yet launched, 
    # so satellite observations are more sparse
    # use the date column configured above rather than a hard-coded attribute name
    plot_date = gdf_plot[date_col]
    if plot_date.year < 2017:
        start_date = plot_date - timedelta(days=60)
        end_date = plot_date + timedelta(days=60)
    elif plot_date.year == 2017:
        start_date = plot_date - timedelta(days=50)
        end_date = plot_date + timedelta(days=50)
    else:
        start_date = plot_date - timedelta(days=45)
        end_date = plot_date + timedelta(days=45)
    
    # save the date range as a dictionary for fetching
    data_dict = {'date_range': [str(start_date), 
                                str(end_date)]}
    
    # skip plot if any coordinates are NaN
    if np.any(np.isnan(gdf_plot[[x_coord_col, y_coord_col]].astype(float).values)):
        print('    SKIPPED! Missing coordinates.')
        continue
    else:
        # set values for retrying up to 5 times
        idx_load_try = 0
        max_load_try = 5
        while idx_load_try < max_load_try:
            try:
                # fetch the data
                hls_ds = fetch.get_hls(hls_data=data_dict,
                                       bbox=np.array(gdf_plot.geometry.bounds), 
                                       stack_chunks=(20, 20),
                                       proj_epsg=gdf_aoi.crs.to_epsg(),
                                       lim=1000,
                                       aws=False).drop_vars(['SZA', 'SAA', 'VZA', 'VAA']).chunk({'time': -1}).persist()

                # mask the dataset using the native HLS mask
                hls_mask = mask_hls(hls_ds['FMASK'], mask_types=['all'])
                hls_ds = hls_ds.where(hls_mask == 0)

                # mask the dataset using the Bolton mask. Note this will slow things considerably
                #hls_bolton_mask = bolton_mask(hls_ds)
                #hls_ds = hls_ds.where(hls_bolton_mask == 0, drop=True)

                # pick best image (based on the mask) for any dates with duplicate images
                if len(np.unique(hls_ds.time.values)) < len(hls_ds.time.values):
                    print('    reducing along id dimension to single observation for each date, keeping least-masked image')
                    hls_mask = mask_hls(hls_ds['FMASK'], mask_types=['all'])
                    hls_ds['maskcov_pct'] = ((hls_mask != 0).sum(['y', 'x']) / hls_ds['FMASK'].isel(time=0).size * 100)
                    hls_ds = hls_ds.groupby('maskcov_pct').apply(
                        lambda x: x.sortby('maskcov_pct')).drop_duplicates(
                        'time', keep='first').sortby('time').drop_vars('maskcov_pct')

                # convert the timestamp to date
                hls_ds['time'] = pd.to_datetime(hls_ds['time'])
                hls_ds['time'] = hls_ds['time'].dt.date
                # drop the non-index coordinates and the mask band
                hls_ds = hls_ds.reset_coords(drop=True).drop_vars('FMASK')

                # calculate the distance of each pixel center to the plot center
                hls_ds['dist'] = np.sqrt((hls_ds['y'] - gdf_plot[y_coord_col])**2 + ((hls_ds['x'] - gdf_plot[x_coord_col])**2))
                # set the distance as a coordinate
                hls_ds = hls_ds.set_coords('dist')
                # get the maximum distance of the N pixels surrounding the plot, with N specified earlier
                all_dists = hls_ds.stack(z=['y', 'x']).sortby('dist')['dist'].values
                # index N-1 of the sorted distances is the distance to the N-th nearest pixel
                max_dist = all_dists[~np.isnan(all_dists)][pixels-1]
                # mask out HLS data that is more than approximately N/2 pixels from the plot center
                hls_plot = hls_ds.where(hls_ds['dist'] <= max_dist)
                # load just the window of HLS data surrounding the plot center
                hls_plot = hls_plot.where(hls_plot['BLUE'].notnull()).compute()
                # if successful, set the 'try' counter to the max to stop this while loop
                idx_load_try = max_load_try

            ### deal with common errors ###
            except RuntimeError as e:            
                print('Warning: error connecting to lpdaac. Retrying ' + str(idx_load_try+1) + ' of ' + str(max_load_try))
                client.restart()
                idx_load_try += 1
                time.sleep(10)
            except rio.errors.RasterioIOError as e:
                print('Warning: error loading data. Retrying ' + str(idx_load_try+1) + ' of ' + str(max_load_try))
                client.restart()
                idx_load_try += 1
                time.sleep(10)
            except json.decoder.JSONDecodeError:
                print('Warning: JSON decoding error (usually related to 502 Bad Gateway error).',
                      'Retrying ' + str(idx_load_try+1) + ' of ' + str(max_load_try))
                client.restart()
                idx_load_try += 1
                time.sleep(10)

        ##############################
        ### smooth the time series ###
        ##############################
        print('   creating daily template for output...')
        # set the date range for analysis
        date_rng = pd.date_range(start=start_date, end=end_date)

        # create empty numpy array matching plot xarray dims
        dat_out_nans = np.zeros((len([x.date() for x in date_rng if x.date() not in hls_plot['time'].values]), 
                                 hls_plot.dims['y'], 
                                 hls_plot.dims['x'])) * np.nan

        # create dictionary to map variables to date and coords
        xr_empty_dict = {}
        for veg in veg_dict:
            xr_empty_dict[veg] = (['time', 'y', 'x'],
                                       dat_out_nans)
        for band in band_list:
            xr_empty_dict[band] = (['time', 'y', 'x'],
                                       dat_out_nans)

        # create an empty list for combining all the vegetation indices and bands
        xr_veg_list = []

        # loop through the specified vegetation indices, compute them and append 
        for veg in veg_dict:
            xr_veg_tmp = veg_dict[veg](hls_plot)
            xr_veg_tmp.name = veg
            xr_veg_list.append(xr_veg_tmp)
        # loop through the specified bands, pull them and append
        for band in band_list:
            xr_veg_tmp = hls_plot[band]
            xr_veg_tmp.name = band
            xr_veg_list.append(xr_veg_tmp)

        # combine all the indices and bands into one Dataset
        ds_plot = xr.merge(xr_veg_list)

        # create the empty xarray Dataset for populating dates without imagery
        ds_empty = xr.Dataset(data_vars=xr_empty_dict,
                              coords={'time': [x.date() for x in date_rng if x.date() not in hls_plot['time'].values],
                                      'x': hls_plot.x,
                                      'y': hls_plot.y})

        # combine the empty Dataset with the original, now all dates are present
        ds_plot = xr.concat([ds_empty, ds_plot], dim='time').sortby('time').chunk({'time': -1, 'y': 1, 'x': 1})

        # convert all timestamps to dates, just to be sure
        ds_plot['time'] = pd.to_datetime(ds_plot['time'])
        ds_plot['time'] = ds_plot['time'].dt.date

        # first despike the NDVI index and then remove from the Dataset any dates/pixels removed by NDVI despiking
        # this provides a faster and more consistent method compared to trying to despike all bands/indices
        ds_plot['ndvi'] = despike_ts_xr(ds_plot['ndvi'], dims=['time'], dat_thresh=0.075, days_thresh=60).persist()
        ds_plot = ds_plot.where(ds_plot['ndvi'].notnull())

        # smooth all the vegetation indices and bands
        ds_plot_smooth = ds_plot.map(smooth_xr, dims=['time'])

        # flatten the dataset
        ds_plot_smooth = ds_plot_smooth.stack(z=['y', 'x'])
        # remove any pixels that are too far from the plot (this should have already been done...not sure why it is here again)
        ds_plot_smooth = ds_plot_smooth.where(ds_plot_smooth['dist'].notnull())

        # get the weighted mean of pixels around the plot point
        df_plot_tmp = ds_plot_smooth.sel(
            time=[gdf_plot[date_col]]).weighted(
            ds_plot_smooth.dist).mean('z').rename({'time':
                                                          date_col}).to_dataframe().reset_index()
        # combine the extracted indices/bands for the plot with the rest of the plot data
        df_plot_out = pd.merge(gdf_plot.to_frame().transpose(), df_plot_tmp, left_on=date_col, right_on=date_col)
        # rename the band columns
        df_plot_out = df_plot_out.rename(columns={'NIR1': 'nir',
                                                  'SWIR1': 'swir1',
                                                  'SWIR2': 'swir2'})
        if df_out is None:
            print('creating initial dataframe')
            # create the output dataframe from the plot
            df_out = df_plot_out
        else:
            # combine the plot data with the already processed data
            df_out = pd.concat([df_out, df_plot_out])
        # write to disk after every plot (including the first) so progress is saved
        df_out.to_csv(os.path.join(outDIR, outFILE), index=False)

        # restart the dask client every 50 plots to avoid memory issues (shouldn't be necessary, but apparently is)
        if idx%50 == 0:
            client.restart()

  0%|          | 0/136 [00:00<?, ?it/s]

Section_4_Pt-1_No :  2020-06-04
Center_Owens_Pt1_No :  2020-06-22
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Center_Owens_Pt3_No :  2020-06-22
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Center_Owens_Pt9_No :  2020-06-22
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Center_Owens_Pt10_No :  2020-06-22
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
North_Owens_Pt2_Yes :  2020-07-02
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
West_Weiss_Pt21_No :  2020-07-21
    reducing along id dimension to single observation for each date, keeping le



Hogsback_Pt4_No :  2021-07-19
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Hogsback_Pt5_No :  2021-07-19
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Hogsback_Pt6_No :  2021-07-19
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Hogsback_Pt8_No :  2021-07-19
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Hogsback_Pt10_No :  2021-07-19
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output...
Hogsback_Pt14_No :  2021-07-19
    reducing along id dimension to single observation for each date, keeping least-masked image
   creating daily template for output