In [None]:
import xarray as xr
import rioxarray
import rasterio as rio
import numpy as np
import os
import re
import pandas as pd
import geopandas as gpd
import hvplot.pandas
from datetime import datetime, timedelta
import dask
from hpc_setup import launch_dask
from hlsstack.hls_funcs import fetch
from hlsstack.hls_funcs.masks import mask_hls, shp2mask, bolton_mask_xr, bolton_mask_np, atsa_mask

In [None]:
# Initialise the environment used by hlsstack's fetch helpers with aws=False
# (consistent with aws=False passed to fetch.get_hls in the extraction loop).
fetch.setup_env(aws=False)

In [None]:
# Dev helper: uncomment to reload the parameter module after editing it, without restarting the kernel.
# The module name must match the one imported below (co_wss_params_extract).
# NOTE: reload() does not rebind names already pulled in with `from ... import *`;
# re-run the star-import cell afterwards to pick up the new values.
#from importlib import reload
#import sys
#reload(sys.modules["co_wss_params_extract"])

In [None]:
# Make the site-specific parameter module importable and load its settings.
# NOTE(review): absolute cluster path — consider making it configurable.
# The star import presumably provides the names used below but never defined in this notebook
# (inPATH, date_col, id_col, x_coord_col, y_coord_col, preprocess, input_epsg, output_epsg,
#  buffer, cluster_loc, outDIR, basename, veg_dict, band_list) — confirm against the module.
import sys
sys.path.insert(1, '/project/cper_neon_aop/hls_nrt/params/co_wss_params/')

from co_wss_params_extract import *

In [None]:
# Load the ground-data AOI (a CSV of point coordinates or a shapefile), put it in the
# output CRS and buffer each geometry by `buffer` so surrounding pixels can be extracted later.
in_ext = os.path.splitext(inPATH)[1].lower()
if in_ext == '.csv':
    # load csv of ground data
    df_aoi_txt = pd.read_csv(inPATH, parse_dates=[date_col])
    if preprocess is not None:
        df_aoi_txt = preprocess(df_aoi_txt)
    # drop any rows with missing coordinates or IDs
    df_aoi_txt = df_aoi_txt.dropna(subset=[id_col, x_coord_col, y_coord_col])
    # NOTE(review): hard-coded subset to NAME == 'Weld'; only applied to CSV input
    df_aoi_txt = df_aoi_txt[df_aoi_txt['NAME'] == 'Weld']

    # convert to GeoDataFrame using coordinates
    gdf_aoi = gpd.GeoDataFrame(
        df_aoi_txt, geometry=gpd.points_from_xy(df_aoi_txt[x_coord_col], df_aoi_txt[y_coord_col]))

elif in_ext == '.shp':
    gdf_aoi = gpd.read_file(inPATH)
    if preprocess is not None:
        gdf_aoi = preprocess(gdf_aoi)
else:
    # fail clearly instead of a NameError on gdf_aoi further down
    raise ValueError('Unsupported input file type (expected .csv or .shp): ' + inPATH)

# CSV points carry no CRS: assume they are in input_epsg
if gdf_aoi.crs is None:
    gdf_aoi = gdf_aoi.set_crs(epsg=input_epsg)
# reproject based on the frame's actual CRS (a shapefile may carry a CRS other than input_epsg)
if gdf_aoi.crs.to_epsg() != output_epsg:
    gdf_aoi = gdf_aoi.to_crs(epsg=output_epsg)
# buffer the points to extract surrounding pixels later
gdf_aoi.geometry = gdf_aoi.buffer(buffer)

In [None]:
# Preview the AOI GeoDataFrame after loading, reprojection and buffering.
gdf_aoi

In [None]:
# Give gdf_aoi a clean 0..n-1 RangeIndex; the previous (unnamed) index is kept as an 'index' column.
# NOTE(review): not idempotent — a second run inserts a 'level_0' column and a third run raises ValueError.
gdf_aoi = gdf_aoi.reset_index()

In [None]:
# True when at least one (buffered) geometry has zero area, i.e. is empty/degenerate.
bool((gdf_aoi.area == 0).any())

In [None]:
# If any ID has more than one distinct geometry, replace every geometry of that ID with a
# buffer (radius `buffer`) around the centroid of the union of its geometries, so all rows
# of an ID share one footprint. IDs with a single geometry keep an equivalent footprint
# (centroid of one buffered point, re-buffered by the same radius).
geom_counts = gdf_aoi.geometry.astype(str).groupby(gdf_aoi[id_col]).nunique()
multi_geom_ids = list(geom_counts.index[geom_counts > 1])
if multi_geom_ids:
    print('Averaging plots to centroid for multi-polygon plots:')
    print(multi_geom_ids)
    print()
    # one buffered centroid per ID, indexed by ID
    centroid_polys = gdf_aoi.dissolve(by=id_col).centroid.buffer(buffer)
    # look up each row's new geometry by its ID (no reliance on index alignment with a merge result)
    new_geoms = centroid_polys.reindex(gdf_aoi[id_col].values)
    gdf_aoi.geometry = gpd.GeoSeries(new_geoms.values, index=gdf_aoi.index, crs=gdf_aoi.crs)
else:
    print('No multiple geometries found for unique ids')

In [None]:
# Check whether any plot polygon intersects another plot's polygon within the same year.
# Result: gdf_overlapping (one row per overlapping ID and year) or None when there are none;
# gdf_non_overlapping is a list of per-year frames of the remaining plots.
gdf_overlapping = []
gdf_non_overlapping = []
for yr in gdf_aoi[date_col].dt.year.unique():
    gdf_aoi_yr = gdf_aoi[gdf_aoi[date_col].dt.year == yr]
    gdf_aoi_nodup_yr = gdf_aoi_yr.drop_duplicates(subset=[id_col])
    geom_list = list(gdf_aoi_nodup_yr.geometry)
    # positional (iloc) indices into gdf_aoi_nodup_yr
    overlapping = []
    non_overlapping = []
    for n, p in enumerate(geom_list):
        # `overlaps` implies `intersects`, so a single intersects test covers both
        if any(p.intersects(g) for j, g in enumerate(geom_list) if j != n):
            overlapping.append(n)
        else:
            non_overlapping.append(n)
    if overlapping:
        gdf_overlapping.append(gdf_aoi_nodup_yr.iloc[overlapping])
    if non_overlapping:
        gdf_non_overlapping.append(gdf_aoi_nodup_yr.iloc[non_overlapping])
if gdf_overlapping:
    gdf_overlapping = pd.concat(gdf_overlapping)
    print('Overlapping polygons detected:', len(gdf_overlapping))
else:
    gdf_overlapping = None
    print('No overlapping polygons detected.')

In [None]:
# preview the plots
# Optional interactive preview (uncomment): one polygon per ID drawn over Esri imagery tiles,
# with the ID shown on hover.
# p1 = gdf_aoi.drop_duplicates(id_col).hvplot(tiles='EsriImagery', crs=gdf_aoi.crs.to_epsg(), 
#                color='red', alpha=0.5, hover=True, hover_cols=[id_col],
#                frame_height=540, frame_width=800)
# p1

In [None]:
# Teardown helper for the Dask cluster/client; run manually once extraction is finished.
# NOTE: `client` is only created by the cluster-launch cell further down, so this cannot run at this point
# in a top-to-bottom execution.
#client.cluster.close()
#client.close()

In [None]:
# Quick static map of the buffered plot polygons (trailing ';' suppresses the Text repr).
ax = gdf_aoi.geometry.plot()
ax.set_title('Buffered AOI plot polygons');

In [None]:
import psutil

# Network interfaces for the SLURM/Dask workers, in order of preference;
# the first one present on this node is used.
interface_prefs = [
        'ibp175s0',
        'ibp59s0',
        'enp24s0f0',
        'ens7f0']
available_interfaces = list(psutil.net_if_addrs().keys())
matching_interfaces = [x for x in interface_prefs if x in available_interfaces]
if not matching_interfaces:
    # fail fast with a clear message (indexing [0] on an empty list would raise IndexError)
    raise RuntimeError('ERROR: Preferred interfaces not found on node! Available: '
                       + ', '.join(available_interfaces))
interface = matching_interfaces[0]
print(interface)

num_jobs = 30
client = launch_dask(cluster_loc=cluster_loc,
                     num_jobs=num_jobs,
                     mem_gb_per=4.0,
                     partition='medium',
                     duration='06:00:00',
                     slurm_opts={'interface': interface},
                     wait_timeout=400,
                     debug=False)
display(client)

In [None]:
# Masking switches read by the extraction loop:
#   try_atsa             -> compute/apply the ATSA cloud & shadow mask when possible
#   mask_bolton_by_pixel -> additionally apply the Bolton mask per pixel
try_atsa, mask_bolton_by_pixel = True, False

In [None]:
# Block-wise extraction of HLS imagery for every plot, one year at a time.
#
# For each year, the buffered extent of that year's plots is tiled into blocks of at most
# BLOCK_SIZE_M per side (plus BLOCK_PAD_M padding). For each block: fetch HLS imagery spanning the
# block's ground-data dates +/- DATE_PAD, drop duplicate scenes / merge tiles, mask with the native
# FMASK (plus ATSA / Bolton when enabled), compute the indices in `veg_dict`, average pixels within
# each plot (skipna=False, so any masked pixel makes that plot/date NaN) and write one CSV per
# block to outDIR/tmp. Blocks whose CSV already exists are skipped, so the cell can be re-run to resume.

PIXEL_SIZE_M = 30                  # HLS pixel size (m)
BLOCK_SIZE_M = PIXEL_SIZE_M * 400  # max block extent before padding (400 pixels)
BLOCK_PAD_M = 45                   # padding added on each side of a block
AOI_BUFFER_M = 150                 # buffer around the year's plots when computing the tiling extent
ATSA_MIN_PX = 100                  # when try_atsa, blocks are widened to at least this many pixels
DATE_PAD = timedelta(days=184)     # imagery window added before/after the block's ground-data dates
MAX_LOAD_TRY = 5                   # fetch/load attempts per block
MIN_WORKERS = int(num_jobs * 0.1)  # workers required before fetching
WORKER_TIMEOUT = 200               # seconds to wait for MIN_WORKERS
HLS_MASK_TYPES = ['cirrus', 'cloud', 'cloud_adj', 'shadow', 'snow', 'water', 'high_aerosol']


def _grid_edges(lo, hi, step=BLOCK_SIZE_M):
    """Return block edges from `lo` to `hi` spaced `step` apart, always ending exactly at `hi`.

    When the span is no wider than `step`, the result is just [lo, hi].
    """
    if hi - lo > step:
        return np.append(np.arange(lo, hi, step), hi)
    return np.array([lo, hi])


# NaT dates would yield a NaN year (empty subset, NaN bounds, crash in int()); skip them
for yr in gdf_aoi[date_col].dt.year.dropna().unique():
    yr = int(yr)
    print(yr)
    gdf_aoi_sub_yr = gdf_aoi[gdf_aoi[date_col].dt.year == yr]
    bbox_full = np.array(gdf_aoi_sub_yr.buffer(AOI_BUFFER_M).total_bounds)
    x_coords = _grid_edges(bbox_full[0], bbox_full[2])
    y_coords = _grid_edges(bbox_full[1], bbox_full[3])

    for xi in range(len(x_coords) - 1):
        for yi in range(len(y_coords) - 1):
            minx = x_coords[xi] - BLOCK_PAD_M
            maxx = x_coords[xi + 1] + BLOCK_PAD_M
            miny = y_coords[yi] - BLOCK_PAD_M
            maxy = y_coords[yi + 1] + BLOCK_PAD_M
            # expand window if small and running ATSA mask
            if try_atsa:
                if (maxx - minx) / PIXEL_SIZE_M < ATSA_MIN_PX:
                    maxx = minx + (PIXEL_SIZE_M * ATSA_MIN_PX)
                if (maxy - miny) / PIXEL_SIZE_M < ATSA_MIN_PX:
                    maxy = miny + (PIXEL_SIZE_M * ATSA_MIN_PX)
            print('Lower left: ', minx, ',', miny)
            # per-block output: <basename stem>_<year>_<minx>_<miny>.csv
            # (splitext rather than re.sub('.csv', ...), whose unescaped, unanchored '.' matched any character)
            outPATH_tmp = os.path.join(
                outDIR,
                'tmp',
                os.path.splitext(basename)[0] + '_' + '_'.join([str(yr), str(int(minx)), str(int(miny))]) + '.csv')
            if os.path.exists(outPATH_tmp):
                print('Extraction already complete for coords. Moving on.')
                continue
            # all of this year's plots touching the block (overlapping plots included)
            gdf_blk = gdf_aoi_sub_yr.cx[minx:maxx, miny:maxy]
            if len(gdf_blk) == 0:
                print('No plots in grid block. Moving on.')
                continue

            # save the date range as a dictionary for fetching
            start_date = gdf_blk[date_col].min() - DATE_PAD
            end_date = gdf_blk[date_col].max() + DATE_PAD
            data_dict = {'date_range': [str(start_date.date()),
                                        str(end_date.date())]}

            # fetch + load with retries; hls_ds stays None if every attempt fails
            hls_ds = None
            for attempt in range(MAX_LOAD_TRY):
                try:
                    # make sure there are at least some workers before fetching data
                    client.wait_for_workers(n_workers=MIN_WORKERS, timeout=WORKER_TIMEOUT)
                    hls_ds = fetch.get_hls(hls_data=data_dict,
                                           bbox=np.array([minx, miny, maxx, maxy]),
                                           stack_chunks=(1, -1, 450, 450),
                                           debug=True,
                                           proj_epsg=gdf_aoi.crs.to_epsg(),
                                           lim=1000,
                                           aws=False)
                    # create a tile ID coordinate from the third '.'-separated field of each granule id
                    hls_ds = hls_ds.assign_coords(tile_id=('time', [x.split('.')[2] for x in hls_ds['id'].values]))
                    # sortby returns a new object, so the result must be kept
                    hls_ds = hls_ds.sortby('time')
                    hls_ds = hls_ds.load()
                    break
                except (rio.errors.RasterioIOError, RuntimeError):
                    hls_ds = None
                    print('Warning: error loading data. Retrying ' + str(attempt + 1) + ' of ' + str(MAX_LOAD_TRY))
                    # client.restart(wait_for_workers=False)
                    client.wait_for_workers(n_workers=MIN_WORKERS, timeout=WORKER_TIMEOUT)
            if hls_ds is None:
                # without data there is nothing to extract for this block (would otherwise be a NameError)
                print('Fetching HLS failed for the max number of tries. Skipping grid block.')
                continue

            # pick best image (lowest masked-pixel %) for any dates with duplicate images for the same tile
            if len(hls_ds['time'].groupby('tile_id').apply(lambda x: x.drop_duplicates('time', False))) < len(hls_ds['time']):
                print('Dropping duplicate images for same tile.')
                hls_mask = mask_hls(hls_ds['FMASK'], mask_types=['all'])
                hls_ds['maskcov_pct'] = ((hls_mask != 0).sum(['y', 'x']) / hls_ds['FMASK'].isel(time=0).size * 100)
                hls_ds = hls_ds.groupby('tile_id').apply(
                    lambda x: x.sortby('maskcov_pct').drop_duplicates('time', keep='first')).sortby('time').compute()

            # merge and drop tile_id if multiple tiles exist, but don't overlap
            if 'tile_id' in hls_ds.coords and \
            len(np.unique(hls_ds.tile_id.values)) > 1 and \
            len(np.unique(hls_ds.drop_duplicates(dim=['time', 'y', 'x'])['time'])) < len(np.unique(hls_ds['time'])):
                print('Multiple, overlapping tiles ids still exist for the same date. Need to figure out how to deal with this and keep mask intact.')
            elif 'tile_id' in hls_ds.coords and len(np.unique(hls_ds.tile_id.values)) > 1:
                print('Dropping tile_id by taking mean across time dimension.')
                hls_ds = hls_ds.groupby('time').mean()

            display(hls_ds)

            # compute ATSA mask if possible
            if try_atsa:
                print('masking out clouds and shadows detected by ATSA')
                if len(np.unique(hls_ds.drop_duplicates(dim=['time', 'y', 'x'])['time'])) < len(np.unique(hls_ds['time'])):
                    print('Overlapping tiles found. Computing masks separately by tile id.')
                    hls_atsa = hls_ds.groupby('tile_id').apply(lambda x: atsa_mask(x.where(
                        x['BLUE'].notnull(), drop=True))).compute()
                    hls_atsa = hls_atsa.transpose('time', 'y', 'x')
                    mask_atsa = True
                else:
                    hls_ds = hls_ds.reset_coords(drop=True)
                    try:
                        hls_atsa = atsa_mask(hls_ds).compute()
                        mask_atsa = True
                    except (ValueError, IndexError):
                        print('WARNING: Could not compute ATSA cloud/shadow mask')
                        mask_atsa = False
            else:
                mask_atsa = False

            if mask_bolton_by_pixel:
                # NOTE(review): `bolton_mask` is not among the names imported from hlsstack masks
                # (bolton_mask_xr / bolton_mask_np are); presumably provided by the params module — confirm.
                # NOTE(review): assumes the result carries a 'BOLTON' variable — confirm.
                hls_bolton_mask = bolton_mask(hls_ds).compute()
                hls_ds = xr.merge([hls_ds, hls_bolton_mask], join='inner')
                hls_ds = hls_ds.where(hls_ds['BOLTON'] == 0, drop=True)

            # compute native HLS mask, including all aerosol flags, and mask with it
            hls_mask = mask_hls(hls_ds['FMASK'], mask_types=HLS_MASK_TYPES)
            hls_ds = hls_ds.where(hls_mask == 0)

            # mask using ATSA mask, if available
            if mask_atsa:
                print('Applying ATSA mask to dataset.')
                # merge ATSA mask with HLS data
                hls_ds = xr.merge([hls_ds, hls_atsa], join='inner')
                hls_ds = hls_ds.where(hls_ds['ATSA'] == 1)

            # in case multiple tile_id's still exist, take the mean by pixel
            if 'tile_id' in hls_ds.coords and \
            len(np.unique(hls_ds.tile_id.values)) > 1 and \
            len(np.unique(hls_ds.drop_duplicates(dim=['time', 'y', 'x'])['time'])) < len(np.unique(hls_ds['time'])):
                print('Multiple, overlapping tiles ids still exist, taking mean by pixel for each date')
                hls_ds = hls_ds.groupby('time').mean()

            # lazy compute all vegetation indices
            for vegidx in veg_dict:
                hls_ds[vegidx] = veg_dict[vegidx](hls_ds)

            # Overlapping plots share pixels, so each is extracted on its own. The remaining plots go
            # through the single labelled mask below. Use a block-local subset: reassigning
            # gdf_aoi_sub_yr here would leak the filter into later blocks of this year.
            gdf_plots_main = gdf_aoi_sub_yr
            df_overlapping = None
            if gdf_overlapping is not None:
                gdf_overlapping_yr = gdf_overlapping[gdf_overlapping[date_col].dt.year == yr].drop_duplicates(subset=[id_col])
                if len(gdf_overlapping_yr) > 0:
                    print('subsetting overlapping polygons')
                    gdf_plots_main = gdf_aoi_sub_yr[~gdf_aoi_sub_yr[id_col].isin(gdf_overlapping_yr[id_col].unique())]
                    df_list_overlapping = []
                    # only overlapping plots touching this block have pixels in hls_ds
                    for _, ovl_row in gdf_overlapping_yr.cx[minx:maxx, miny:maxy].iterrows():
                        # rasterize this single plot (burn value 1) and relabel pixels with its ID
                        ovl_mask = shp2mask(shp=[(ovl_row.geometry, 1)],
                                            transform=hls_ds.rio.transform(),
                                            outshape=hls_ds['BLUE'].shape[1:],
                                            xr_object=hls_ds['BLUE'])
                        # labels are stored as strings, so compare against str(ID) (a numeric ID never matched)
                        ovl_label = str(ovl_row[id_col])
                        ovl_dict = {1: ovl_label, 0: 'UNK'}
                        ovl_mask.values = np.array([ovl_dict[i] for i in ovl_mask.values.flatten()]).reshape(ovl_mask.shape)

                        # convert to dataframe for plot, if any pixels are NA, result is NA with skipna=False
                        df_yr_sub_tmp = hls_ds.where(ovl_mask == ovl_label, drop=True).mean(['y', 'x'], skipna=False).to_dataframe().reset_index()
                        df_yr_sub_tmp['Plot'] = ovl_row[id_col]
                        df_list_overlapping.append(df_yr_sub_tmp)
                    if df_list_overlapping:
                        df_overlapping = pd.concat(df_list_overlapping)

            # create an xarray mask from the ground data: burn values 1..n, 0 outside any plot
            plots_unique = gdf_plots_main.drop_duplicates(subset=[id_col])
            mask_shp = [(geom, i + 1) for i, geom in enumerate(plots_unique.geometry)]
            plot_mask = shp2mask(shp=mask_shp,
                                 transform=hls_ds.rio.transform(),
                                 outshape=hls_ds['BLUE'].shape[1:],
                                 xr_object=hls_ds['BLUE'])
            mask_dict = {i + 1: pid for i, pid in enumerate(plots_unique[id_col])}
            mask_dict[0] = 'UNK'
            plot_mask = np.array([mask_dict[i] for i in plot_mask.values.flatten()]).reshape(plot_mask.shape)

            # assign the plot id's to the xarray dataset
            hls_ds = hls_ds.assign(Plot=(['y', 'x'], plot_mask)).chunk({'y': 50, 'x': 50})
            hls_ds = hls_ds.set_coords('Plot')

            # mask out areas outside ground plots
            hls_ds = hls_ds.where(hls_ds['Plot'] != 'UNK')

            # convert to dataframe at plot scale, if any pixels are NA, result is NA with skipna=False
            df_yr_sub = hls_ds[list(veg_dict.keys()) + band_list].groupby(
                'Plot').mean('stacked_y_x', skipna=False).to_dataframe().reset_index()
            # drop outside plots
            df_yr_sub = df_yr_sub[df_yr_sub['Plot'] != 'UNK']

            # add in overlapping polygon data if it exists
            if df_overlapping is not None:
                df_yr_sub = pd.concat([df_yr_sub, df_overlapping])

            # write to disk
            df_yr_sub.to_csv(outPATH_tmp, index=False)

            # delete datasets to free memory
            del hls_ds, df_yr_sub, plot_mask, df_overlapping
            if mask_atsa:
                del hls_atsa
            if mask_bolton_by_pixel:
                del hls_bolton_mask

            # client.restart()

In [None]:
# Extraction finished: one CSV per processed grid block is in os.path.join(outDIR, 'tmp').