In [1]:
import xarray as xr
import rioxarray
import numpy as np
import os
import re
import pandas as pd
import geopandas as gpd
from datetime import datetime, timedelta
from tqdm.notebook import tqdm
import time
import json
from shapely import wkt
from itertools import chain
from hlsstack.hls_funcs import fetch
from hlsstack.hls_funcs.masks import mask_hls, shp2mask, bolton_mask, atsa_mask
from hlsstack.hls_funcs.indices import ndvi_func, dfi_func, ndti_func, satvi_func, ndii7_func
from hlsstack.hls_funcs.indices import bai_126_func, bai_136_func, bai_146_func, bai_236_func, bai_246_func, bai_346_func
from hlsstack.hls_funcs.smooth import despike_ts, double_savgol

In [2]:
inDIR = '../data/ground_cln/'
inFILE = 'vor_2013_2022_cln_2023_04_26.csv'

inPATH = os.path.join(inDIR, inFILE)

outDIR = '../data/training/'
# output name = input name with its trailing '.csv' replaced by '_hls_idxs.csv'
# (dot escaped and pattern anchored so only the real extension is rewritten)
outPATH = os.path.join(outDIR, re.sub(r'\.csv$', '_hls_idxs.csv', inFILE))

# optional extra masks (ATSA / Bolton) applied in the main loop on top of the
# native HLS FMASK mask
try_atsa = False
mask_bolton = False

# output column name -> function that computes the index from the HLS dataset
veg_dict = {
    'NDVI': ndvi_func,
    'DFI': dfi_func,
    'NDTI': ndti_func,
    'SATVI': satvi_func,
    'NDII7': ndii7_func,
    'BAI_126': bai_126_func,
    'BAI_136': bai_136_func,
    'BAI_146': bai_146_func,
    'BAI_236': bai_236_func,
    'BAI_246': bai_246_func,
    'BAI_346': bai_346_func
}

# raw HLS bands that are also averaged per plot, smoothed and written out
band_list = ['NIR1', 'SWIR1', 'SWIR2']

In [3]:
print('   setting up Local cluster...')
from dask.distributed import Client, LocalCluster
import dask

# aws=False is handed to fetch.setup_env (presumably: data are not read from
# inside AWS in this run — confirm against hlsstack.fetch)
aws = False
fetch.setup_env(aws=aws)

# 8 worker processes x 2 threads each; show the client so the dashboard link is visible
cluster = LocalCluster(threads_per_worker=2, n_workers=8)
client = Client(cluster)
display(client)

   setting up Local cluster...


2023-04-26 14:54:39,605 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-51fl0s20', purging
2023-04-26 14:54:39,605 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-fsgaphyj', purging
2023-04-26 14:54:39,605 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-gyf46mlk', purging
2023-04-26 14:54:39,612 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-iqqxnc9_', purging


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 16,Total memory: 11.85 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:44579,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 16
Started: Just now,Total memory: 11.85 GiB

0,1
Comm: tcp://127.0.0.1:36963,Total threads: 2
Dashboard: http://127.0.0.1:44003/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:36159,
Local directory: /tmp/dask-worker-space/worker-05ouv9sj,Local directory: /tmp/dask-worker-space/worker-05ouv9sj

0,1
Comm: tcp://127.0.0.1:42009,Total threads: 2
Dashboard: http://127.0.0.1:36549/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:44253,
Local directory: /tmp/dask-worker-space/worker-uur0jysf,Local directory: /tmp/dask-worker-space/worker-uur0jysf

0,1
Comm: tcp://127.0.0.1:36759,Total threads: 2
Dashboard: http://127.0.0.1:34173/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:40615,
Local directory: /tmp/dask-worker-space/worker-_6tr8klt,Local directory: /tmp/dask-worker-space/worker-_6tr8klt

0,1
Comm: tcp://127.0.0.1:45457,Total threads: 2
Dashboard: http://127.0.0.1:38025/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:45223,
Local directory: /tmp/dask-worker-space/worker-kibyxmmk,Local directory: /tmp/dask-worker-space/worker-kibyxmmk

0,1
Comm: tcp://127.0.0.1:40519,Total threads: 2
Dashboard: http://127.0.0.1:38935/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:33885,
Local directory: /tmp/dask-worker-space/worker-_tt1nac8,Local directory: /tmp/dask-worker-space/worker-_tt1nac8

0,1
Comm: tcp://127.0.0.1:37955,Total threads: 2
Dashboard: http://127.0.0.1:42007/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:38937,
Local directory: /tmp/dask-worker-space/worker-uw2l592w,Local directory: /tmp/dask-worker-space/worker-uw2l592w

0,1
Comm: tcp://127.0.0.1:44541,Total threads: 2
Dashboard: http://127.0.0.1:34345/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:34603,
Local directory: /tmp/dask-worker-space/worker-opsw2rj3,Local directory: /tmp/dask-worker-space/worker-opsw2rj3

0,1
Comm: tcp://127.0.0.1:38251,Total threads: 2
Dashboard: http://127.0.0.1:43481/status,Memory: 1.48 GiB
Nanny: tcp://127.0.0.1:44993,
Local directory: /tmp/dask-worker-space/worker-klujqr9j,Local directory: /tmp/dask-worker-space/worker-klujqr9j


In [4]:
# create the output directory (including any missing parents); no-op if it exists
os.makedirs(outDIR, exist_ok=True)

In [5]:
# load csv of ground data as GeoDataFrame: drop rows without a geometry, parse
# the WKT strings into shapely geometries (EPSG:32613, WGS 84 / UTM zone 13N)
# and order the observations by date
df_vor = pd.read_csv(inPATH, parse_dates=[2, 3]).dropna(subset=['geometry']).copy()
df_vor['geometry'] = df_vor['geometry'].astype('str').apply(wkt.loads)
gdf_vor = (gpd.GeoDataFrame(df_vor, geometry='geometry', crs=32613)
           .sort_values('Date'))

In [6]:
# load any existing output data so that years already processed are skipped below;
# None signals that no output file exists yet
df_out = pd.read_csv(outPATH, parse_dates=[2, 3]) if os.path.exists(outPATH) else None

In [7]:
# Main extraction loop. For each year of ground data: fetch the HLS stack covering
# all of that year's plots (+/- 30 days), mask clouds/shadows etc., compute the
# vegetation indices, average them per plot, gap-fill/smooth a daily time series
# and left-join it onto the ground observations. The output CSV is rewritten after
# every year, so a re-run skips years already present in df_out.
for yr in tqdm(gdf_vor['Year'].unique()):
    print(yr)
    # skip if year already in output data
    if df_out is not None and yr in df_out['Year'].unique():
        print('Skipping year - already in output dataset.')
        continue

    # get subset of vor data for year
    gdf_yr = gdf_vor[gdf_vor['Year'] == yr]

    # get the date range for the fetch as a dictionary (30-day buffer on each side,
    # presumably so the smoother has context beyond the first/last field date)
    start_date = gdf_yr['Date'].min().date() - timedelta(days=30)
    end_date = gdf_yr['Date'].max().date() + timedelta(days=30)
    data_dict = {'date_range': [str(start_date), str(end_date)]}

    # set the bounding box for fetching data (the entire subset dataset)
    bbox_yr = np.array(gdf_yr.total_bounds)

    # fetch the data for the entire year's dataset
    hls_ds = fetch.get_hls(hls_data=data_dict,
                           bbox=bbox_yr,
                           stack_chunks=(400, 400),
                           proj_epsg=gdf_yr.crs.to_epsg(),
                           lim=1000).load()

    # create a tile ID coordinate from the third '.'-separated field of each image id
    hls_ds = hls_ds.assign_coords(tile_id=('time', [x.split('.')[2] for x in hls_ds['id'].values]))

    # pick best image for any dates with duplicate images for the same tile.
    # keep=False drops *every* copy of a duplicated time, so a shorter result
    # means at least one tile has more than one image on some date.
    n_nondup = len(hls_ds['time'].groupby('tile_id').apply(lambda x: x.drop_duplicates('time', keep=False)))
    if n_nondup < len(hls_ds['time']):
        hls_mask = mask_hls(hls_ds['FMASK'], mask_types=['all'])
        # percent of each scene flagged by the FMASK ('all' classes)
        hls_ds['maskcov_pct'] = ((hls_mask != 0).sum(['y', 'x']) / hls_ds['FMASK'].isel(time=0).size * 100)
        # ascending sort + keep='first' keeps the least-masked image per (tile, time)
        hls_ds = hls_ds.groupby('tile_id').apply(
            lambda x: x.sortby('maskcov_pct').drop_duplicates('time', keep='first')).sortby('time').compute()

    # compute ATSA mask if possible; reset every year so a stale flag is never reused
    mask_atsa = False
    if try_atsa:
        print('masking out clouds and shadows detected by ATSA')
        if len(np.unique(hls_ds.tile_id)) > 1:
            hls_atsa = hls_ds.groupby('tile_id').apply(lambda x: atsa_mask(x.where(
                x['BLUE'].notnull(), drop=True))).compute()
            hls_atsa = hls_atsa.transpose('time', 'y', 'x')
            # previously never set on this branch, so the computed mask went unused
            # (or mask_atsa was undefined on the first processed year)
            mask_atsa = True
        else:
            hls_ds = hls_ds.reset_coords(drop=True)
            try:
                hls_atsa = atsa_mask(hls_ds).compute()
                mask_atsa = True
            except (ValueError, IndexError):
                # a tuple is required: `except ValueError or IndexError` only caught ValueError
                print('WARNING: Could not compute ATSA cloud/shadow mask')

    if mask_bolton:
        # compute the bolton mask and keep pixels where BOLTON == 0.
        # NOTE(review): assumes bolton_mask returns a Dataset (or named DataArray)
        # with a 'BOLTON' variable and that 0 means "keep" — confirm.
        hls_bolton_mask = bolton_mask(hls_ds).compute()
        # previously merged `hls_atsa` instead and discarded the .where() result
        hls_ds = xr.merge([hls_ds, hls_bolton_mask], join='inner')
        hls_ds = hls_ds.where(hls_ds['BOLTON'] == 0, drop=True)

    # compute native HLS mask
    hls_mask = mask_hls(hls_ds['FMASK'], mask_types=['cirrus',
                                                    'cloud',
                                                    'cloud_adj',
                                                    'shadow',
                                                    'snow',
                                                    'water',
                                                    'high_aerosol'])
    # mask using native HLS mask
    hls_ds = hls_ds.where(hls_mask == 0)
    # mask using ATSA mask, if available
    if mask_atsa:
        # merge ATSA mask with HLS data
        hls_ds = xr.merge([hls_ds, hls_atsa], join='inner')
        hls_ds = hls_ds.where(hls_ds['ATSA'] == 1)

    # in case multiple tile_id's still exist, take the mean by pixel
    if 'tile_id' in hls_ds.coords and len(np.unique(hls_ds.tile_id.values)) > 1:
        hls_ds = hls_ds.groupby('time').mean()

    # rasterize this year's plot polygons onto the HLS grid: each polygon is burned
    # with a 1-based row id, which is then translated to the plot Id ('UNK' for 0).
    # Uses gdf_yr (not all years) so geometries from other years are not burned in.
    mask_info = (gdf_yr.drop_duplicates(subset=['Id', 'Date'])[['Id', 'geometry']]
                 .reset_index(drop=True)
                 .reset_index()
                 .rename(columns={'index': 'id'}))
    mask_shp = [(row.geometry, row.id + 1) for _, row in mask_info.iterrows()]
    plot_mask = shp2mask(shp=mask_shp,
                         transform=hls_ds.rio.transform(),
                         outshape=hls_ds['BLUE'].shape[1:],
                         xr_object=hls_ds['BLUE'])
    mask_dict = {row.id + 1: row.Id for _, row in mask_info.iterrows()}
    mask_dict[0] = 'UNK'
    plot_mask = np.array([mask_dict[i] for i in plot_mask.values.flatten()]).reshape(plot_mask.shape)

    # assign the plot id's to the xarray dataset
    hls_ds = hls_ds.assign(Plot=(['y', 'x'], plot_mask)).chunk({'y': 50, 'x': 50})
    hls_ds = hls_ds.set_coords('Plot')

    # mask out areas outside ground plots
    hls_ds = hls_ds.where(hls_ds['Plot'] != 'UNK')

    # lazy compute all vegetation indices
    for vegidx, veg_func in veg_dict.items():
        hls_ds[vegidx] = veg_func(hls_ds)

    # convert to dataframe at plot scale (mean over all pixels of each plot)
    df_yr = hls_ds[list(veg_dict.keys()) + band_list].groupby('Plot').mean('stacked_y_x').to_dataframe().reset_index()

    # remove all non-plot data
    df_yr = df_yr[df_yr['Plot'] != 'UNK']

    # rename columns to match VOR data
    df_yr = df_yr.rename(columns={'time': 'Date',
                                  'Plot': 'Id'})

    # days in [start_date, end_date] with no observation, for gap-filling
    # (unique() hoisted out of the comprehension instead of recomputed per day)
    dates_observed = df_yr['Date'].unique()
    dates_missing = [x for x in pd.date_range(start_date, end_date).date if x not in dates_observed]

    # one row per (plot, missing date); Dates are datetime64 so the concat below does
    # not produce a mixed Timestamp/date object column (unsortable in pandas >= 2)
    df_missing = pd.MultiIndex.from_product(
        [df_yr['Id'].unique(), pd.to_datetime(dates_missing)],
        names=['Id', 'Date']).to_frame(index=False)

    # combine into one dataframe for gapfilling
    df_yr_ts = pd.concat([df_yr, df_missing]).sort_values(['Id', 'Date'])

    # smooth all vegetation indices and bands per plot to gapfill
    for col in list(veg_dict.keys()) + band_list:
        df_yr_ts[col + '_smooth'] = df_yr_ts.groupby('Id')[col].transform(lambda x: double_savgol(x.values))

    # convert date to datetime
    df_yr_ts['Date'] = pd.to_datetime(df_yr_ts['Date'])

    # rename smoothed columns and drop originals
    df_yr_ts = df_yr_ts.drop(columns=list(veg_dict.keys()) + band_list)
    col_rename_dict = {c: re.sub('_smooth', '', c) for c in df_yr_ts.columns if '_smooth' in c}
    df_yr_ts = df_yr_ts.rename(columns=col_rename_dict)

    # attach the smoothed plot-level values to this year's ground observations
    df_out_yr = pd.merge(gdf_yr,
                         df_yr_ts[['Id', 'Date'] + list(veg_dict.keys()) + band_list],
                         on=['Id', 'Date'],
                         how='left')

    # append to (or create) the output dataset and rewrite the CSV after each year
    if df_out is not None:
        df_out = pd.concat([df_out, df_out_yr])
    else:
        df_out = df_out_yr.copy()
    df_out.to_csv(outPATH, index=False)
    # restart the Dask workers between years (presumably to release worker memory)
    client.restart()

  0%|          | 0/10 [00:00<?, ?it/s]

2013
Skipping year - already in output dataset.
2014
Skipping year - already in output dataset.
2015
Skipping year - already in output dataset.
2016
Skipping year - already in output dataset.
2017




2018




2019




2020




2021




2022


