# Notebook version of 2.3_dps.py, which runs over a single tile
#### Can't get 2.3_dps.py to work, so the alternate approach is to loop over tile IDs and return individual tile CSVs of ATL08

In [1]:
#import pdal
import argparse
import glob
import itertools
import json
import os
import sys
import time

import geopandas as gpd
import pandas as pd
from pyproj import CRS, Transformer

from maap.maap import MAAP
maap = MAAP()

sys.path.append("/projects/icesat2_boreal/notebooks/3.Gridded_product_development/")
sys.path.append('/projects/code/icesat2_boreal/notebooks/3.Gridded_product_development')

sys.path.append('/projects/code/icesat2_boreal/notebooks/2.ICESat-2_processing')

#TODO: how to get this import right if its in a different dir
from CovariateUtils import get_index_tile
from FilterUtils import *
from ExtractUtils import *

  shapely_geos_version, geos_capi_version_string


In [2]:
def get_h5_list(tile_num, tile_fn="/projects/maap-users/alexdevseed/boreal_tiles.gpkg", layer="boreal_tiles_albers",DATE_START='06-01', DATE_END='09-30', YEARS=[2019, 2020, 2021]):
    '''
    Return a list of ATL08 h5 basenames that intersect a tile for a given
    month-day range, repeated across a set of years.

    Parameters
    ----------
    tile_num : tile identifier handed to get_index_tile()
    tile_fn : path to the tile index file handed to get_index_tile()
    layer : layer name handed to get_index_tile()
    DATE_START, DATE_END : 'MM-DD' strings; each year in YEARS is prefixed to
        them to build one temporal filter per year
    YEARS : iterable of years (only read here, never mutated)

    Returns
    -------
    list of str : basenames of the download URLs returned by maap.searchGranule()
    '''
    tile_id = get_index_tile(tile_fn, tile_num, buffer=0, layer=layer)

    # The tile's 4326 bbox, comma-joined, is what goes into the query's bounding_box
    in_bbox = ",".join(str(coord) for coord in tile_id['bbox_4326'])
    print("\tTILE_NUM: {} ({})".format(tile_num, in_bbox))

    season_start = DATE_START + 'T00:00:00Z'  # SUMMER start
    season_end = DATE_END + 'T23:59:59Z'      # SUMMER end
    date_filters = [f'{year}-{season_start},{year}-{season_end}' for year in YEARS]

    base_query = {
        'short_name': "ATL08",
        'version': "003",
        'bounding_box': in_bbox,
    }
    # One query per season; each is the base query plus that season's temporal filter
    queries = [dict(base_query, temporal=date_filter) for date_filter in date_filters]

    # Query CMR as many seasons as necessary and walk the results as one stream
    result_chain = itertools.chain.from_iterable(maap.searchGranule(**query) for query in queries)

    # This is the list of ATL08 that intersect the tile bounds.
    # Use it to identify the ATL08 h5/CSV files you have already DPS'd.
    # Only the h5 basenames are kept (the s3 url prefix is dropped).
    out_h5_list = [os.path.basename(item.getDownloadUrl()) for item in result_chain]

    print("\t\t# ATL08 for tile {}: {}".format(tile_num, len(out_h5_list)))

    return out_h5_list

def reorder_4326_bounds(boreal_tile_index_path, test_tile_id, buffer, layer):
    '''
    Look up a tile with get_index_tile() and return its 'bbox_4326' entries
    as a 4-item list with the 2nd and 3rd entries swapped.

    The caller in this notebook treats the result as [xmin, xmax, ymin, ymax].
    '''
    bbox_4326 = get_index_tile(boreal_tile_index_path, test_tile_id, buffer=buffer, layer=layer)['bbox_4326']

    return [bbox_4326[0], bbox_4326[2], bbox_4326[1], bbox_4326[3]]

def get_granules_list(granules):
    '''
    Collect the download URL of every granule returned by maap.searchGranule().

    URLs beginning with 's3://' are rewritten to their
    'https://<bucket>.s3.amazonaws.com/<key>' form; all other URLs are
    returned unchanged. Order follows the input.
    '''
    s3_scheme = 's3://'
    raw_urls = [granule.getDownloadUrl() for granule in granules]

    https_urls = []
    for raw_url in raw_urls:
        if raw_url[:len(s3_scheme)] != s3_scheme:
            https_urls.append(raw_url)
            continue
        # First path component is the bucket; everything after it is the key
        bucket, slash, key = raw_url[len(s3_scheme):].partition('/')
        https_urls.append('https://' + bucket + '.s3.amazonaws.com' + slash + key)
    return https_urls

def prep_filter_atl08_qual(atl08):
    '''
    Data prep for a df built from all CSVs from a DPS of extract_atl08.py (v003 of ATL08).

    Adds a 'beam_type' column derived from 'orb_orient' and 'gt', then coerces
    a fixed set of columns to numeric (unparseable values become NaN).
    The input frame is modified in place and the same object is returned.
    '''
    print("\nPre-filter data cleaning...")
    print("\nGet beam type from orbit orientation and ground track...")

    # (orb_orient value, substring looked for in gt, beam type) - applied in this order
    beam_rules = (
        (1, 'r', 'Strong'),
        (1, 'l', 'Weak'),
        (0, 'r', 'Weak'),
        (0, 'l', 'Strong'),
    )
    for orient, gt_side, beam in beam_rules:
        rows = (atl08.orb_orient == orient) & (atl08['gt'].str.contains(gt_side))
        atl08.loc[rows, "beam_type"] = beam
    print(atl08.beam_type.unique())

    cols_float = ['lat', 'lon', 'h_can', 'h_te_best', 'ter_slp']
    print(f"Cast some columns to type float: {cols_float}")
    atl08[cols_float] = atl08[cols_float].apply(pd.to_numeric, errors='coerce')

    cols_int = ['n_ca_ph', 'n_seg_ph', 'n_toc_ph']
    print(f"Cast some columns to type integer: {cols_int}")
    atl08[cols_int] = atl08[cols_int].apply(pd.to_numeric, downcast='signed', errors='coerce')

    # TODO: the 'yr'/'m'/'d' clean-up and the derived "date" column are not applied here.
    print('DEBUG prep')

    print("Returning a prepared dataframe.")
    return atl08

def filter_atl08_bounds_tile_ept(in_ept_fn, in_tile_fn, in_tile_num, in_tile_layer, output_dir):
    '''Get bounds from a tile_id and apply them to an EPT database.

    Builds a PDAL pipeline (EPT reader -> crop -> GeoJSON text writer) and
    hands it to run_pipeline() together with the output path.

    Parameters
    ----------
    in_ept_fn : path/URL of the EPT database; its stem is used in the output name
    in_tile_fn, in_tile_num, in_tile_layer : passed to get_index_tile() to look up the tile
    output_dir : directory the output GEOJSON path is built in

    Returns
    -------
    str : path to a GEOJSON that is a subset of the ATL08 db
    '''
    # Return the 4326 representation of the input <tile_id> geometry
    tile_parts = get_index_tile(in_tile_fn, in_tile_num, buffer=0, layer=in_tile_layer)
    geom_4326 = tile_parts["geom_4326"]

    # NOTE(review): this slicing assumes geom_4326 is a 4-item sequence ordered
    # [xmin, xmax, ymin, ymax]; the 'bbox_4326' entry used elsewhere in this file
    # needs a reorder to get that layout - confirm against get_index_tile().
    xmin, xmax = geom_4326[0:2]
    ymin, ymax = geom_4326[2:]

    # The crop bounds are expressed in EPSG:3857, so project the two corners
    transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)
    xmin, ymax = transformer.transform(xmin, ymax)
    xmax, ymin = transformer.transform(xmax, ymin)
    pdal_tile_bounds = f"([{xmin}, {xmax}], [{ymin}, {ymax}])"

    # Spatial subset
    pipeline_def = [
        {
            "type": "readers.ept",
            "filename": in_ept_fn
        },
        {
            "type": "filters.crop",
            "bounds": pdal_tile_bounds
        },
        {
            "type": "writers.text",
            "format": "geojson",
            "write_header": True
        }
    ]

    # Output the spatial subset as a geojson named <ept stem>_<tile>.geojson.
    # str() keeps this working for integer tile ids as well as strings.
    ept_stem = os.path.split(os.path.splitext(in_ept_fn)[0])[1]
    out_fn = os.path.join(output_dir, ept_stem + "_" + str(in_tile_num) + ".geojson")
    run_pipeline(pipeline_def, out_fn)

    return out_fn

def filter_atl08_bounds(atl08_df=None, in_bounds=None, in_ept_fn=None, in_tile_fn=None, in_tile_num=None, in_tile_layer=None, output_dir=None, return_pdf=False):
    '''
    Filter an ATL08 database using bounds.
    Bounds can come from an input vector tile or a list: [xmin,xmax,ymin,ymax]

    Two modes, chosen by which arguments are supplied:
      * EPT mode - in_ept_fn, in_tile_fn, in_tile_num, in_tile_layer and
        output_dir are all given: the subset is written by
        filter_atl08_bounds_tile_ept(). Returns the GEOJSON path, or the file
        read back with geopandas when return_pdf is True.
      * Data frame mode - atl08_df and in_bounds are given: the frame is run
        through prep_filter_atl08_qual() and rows strictly inside the bounds
        are returned. A data frame is always returned in this mode,
        regardless of return_pdf.

    Raises
    ------
    ValueError : when neither set of arguments is complete.
    '''
    ept_args = [in_ept_fn, in_tile_fn, in_tile_num, in_tile_layer, output_dir]

    if all(v is not None for v in ept_args):
        out_fn = filter_atl08_bounds_tile_ept(in_ept_fn, in_tile_fn, in_tile_num, in_tile_layer, output_dir)
        if return_pdf:
            print("Filtered bounds, returning a pandas dataframe.")
            # geopandas exposes read_file(); there is no gpd.read()
            return gpd.read_file(out_fn)
        print(out_fn)
        return out_fn

    if in_bounds is not None and atl08_df is not None:
        print("Filtering by bounds: {}".format(in_bounds))
        xmin = float(in_bounds[0])
        xmax = float(in_bounds[1])
        ymin = float(in_bounds[2])
        ymax = float(in_bounds[3])

        print("Returning a data frame")
        atl08_df_prepd = prep_filter_atl08_qual(atl08_df)

        print('DEBUG atl08_df_prepd')
        # Strict inequalities: observations exactly on the bounds are excluded
        inside_bounds = (
            (atl08_df_prepd.lat > ymin) &
            (atl08_df_prepd.lat < ymax) &
            (atl08_df_prepd.lon > xmin) &
            (atl08_df_prepd.lon < xmax)
        )
        print("Filtered bounds, returning a pandas dataframe.")
        return atl08_df_prepd[inside_bounds]

    # Raise instead of os._exit(): exiting the process would take the notebook
    # kernel down with it and can drop buffered output.
    raise ValueError("Missing input args; can't filter. Check call.")

def filter_atl08_qual(input_fn=None, subset_cols_list=['rh25','rh50','rh60','rh70','rh75','rh80','rh85','rh90','rh95','h_can','h_max_can'], filt_cols = ['h_can','h_dif_ref','m','msw_flg','beam_type','seg_snow'], thresh_h_can=None, thresh_h_dif=None, month_min=None, month_max=None, SUBSET_COLS=True):
    '''
    Quality filtering Function
    Returns a data frame
    Note: beams 1 & 5 strong (better radiometric perf, sensitive), then beam 3 [NOT IMPLEMENTED]

    Parameters
    ----------
    input_fn : a pd.DataFrame, or a path ending in 'geojson' or 'csv'
    subset_cols_list : columns to return (after 'lon','lat') when SUBSET_COLS is True
    filt_cols : columns that must exist in the prepared frame for the filter to run
    thresh_h_can : keep rows with h_can below this value
    thresh_h_dif : keep rows with h_dif_ref below this value
    month_min, month_max : keep rows with month_min <= m <= month_max
    SUBSET_COLS : return only ['lon','lat'] + subset_cols_list when True, else all columns

    Raises
    ------
    ValueError : on a missing/unsupported argument or a missing filter column.
    '''
    # TODO: filt col names: make sure you have these in the EPT db

    # Argument checks raise (rather than os._exit) so a bad call does not kill the kernel
    if not subset_cols_list:
        raise ValueError("filter_atl08: Must supply a list of strings matching ATL08 column names returned from the input EPT")
    if thresh_h_can is None:
        raise ValueError("filter_atl08: Must supply a threshold for h_can")
    if thresh_h_dif is None:
        raise ValueError("filter_atl08: Must supply a threshold for h_dif_ref")
    if month_min is None or month_max is None:
        raise ValueError("filter_atl08: Must supply a month_min and month_max")

    # Without this check a missing input fell through to an undefined-name error
    if input_fn is None:
        raise ValueError("Input filename must be a CSV, GEOJSON, or pd.DataFrame")

    if isinstance(input_fn, pd.DataFrame):
        atl08_df = input_fn
    elif str(input_fn).endswith('geojson'):
        # geopandas exposes read_file(); there is no gpd.read()
        atl08_df = gpd.read_file(input_fn)
    elif str(input_fn).endswith('csv'):
        atl08_df = pd.read_csv(input_fn)
    else:
        raise ValueError("Input filename must be a CSV, GEOJSON, or pd.DataFrame")

    # Run the prep to get fields needed (v003)
    atl08_df_prepd = prep_filter_atl08_qual(atl08_df)
    atl08_df = None

    # Check that you have the cols that are required for the filter
    filt_cols_not_in_df = [col for col in filt_cols if col not in atl08_df_prepd.columns]
    if filt_cols_not_in_df:
        raise ValueError("These filter columns not found in input df: {}".format(filt_cols_not_in_df))

    # Filter list (keep):
    #   h_ref_diff < thresh_h_dif
    #   h_can < thresh_h_can
    #   no LC forest masking: only forest LC classes no good b/c trees outside of forest area of interest (woodlands, etc)
    #   msw = 0
    #   night better (but might exclude too much good summer data in the high northern lats)
    #   strong beam
    #   summer (june - mid sept)
    #   seg_snow == 'snow free land'

    print("\nFiltering for quality:\n\tfor clear skies + strong beam + snow free land,\n\th_can < {},\n\televation diff from ref < {},\n\tmonths {}-{}".format(thresh_h_can, thresh_h_dif, month_min, month_max))
    atl08_df_filt = atl08_df_prepd[
        (atl08_df_prepd.h_can < thresh_h_can) &
        (atl08_df_prepd.h_dif_ref < thresh_h_dif) &
        (atl08_df_prepd.m >= month_min) &
        (atl08_df_prepd.m <= month_max) &
        # Hard coded quality flags for ABoVE AGB
        (atl08_df_prepd.msw_flg == 0) &
        #(atl08_df.night_flg == 'night') & # might exclude too much good summer data in the high northern lats
        (atl08_df_prepd.beam_type == 'Strong') &
        (atl08_df_prepd.seg_snow == 'snow free land')
    ]

    print(f"Before quaity filtering: {atl08_df_prepd.shape[0]} observations in the input dataframe.")
    print(f"After quality filtering: {atl08_df_filt.shape[0]} observations in the output dataframe.")

    atl08_df_prepd = None

    if SUBSET_COLS:
        # Builds a new list, so the (shared) default argument is never mutated
        subset_cols_list = ['lon','lat'] + subset_cols_list
        print("Returning a pandas data frame of filtered observations for columns: {}".format(subset_cols_list))
        print(f"Shape: {atl08_df_filt[subset_cols_list].shape} ")
        return atl08_df_filt[subset_cols_list]

    print("Returning a pandas data frame of filtered observations for all columns")
    return atl08_df_filt

In [3]:
# --- Run switches ---
maap_query = True       # True: find a tile's ATL08 via a MAAP query + the DPS CSVs; False: use the EPT path
dps_dir = '/projects/jabba/dps_output/2.3_output'#'/projects/r2d2/dps_output/run_extract_atl08_orig_ubuntu'
output_dir = '/projects/jabba/data/out_tiles'
TEST = True             # when True the loop below looks for CSVs with no segment-length suffix
do_30m = True           # when True (and TEST is False) the loop below looks for '_30m' CSVs, else '_100m'
extract_covars = False  # when True the loop below also extracts topo/Landsat covariates

# --- Tile index ---
in_tile_fn = '/projects/maap-users/alexdevseed/boreal_tiles.gpkg'
in_tile_layer = 'boreal_tiles_albers'

# --- Quality filter thresholds ---
thresh_h_can = 100
thresh_h_dif = 100
month_min = 6
month_max = 9

# --- Season used for the granule search ('MM-DD') ---
date_start = '06-01'
date_end = '09-30'

# NA tiles
# Read the boreal tile index file
# NOTE(review): read_file() is called without layer=in_tile_layer - confirm the default layer is the intended one.
boreal_tile_index = gpd.read_file(in_tile_fn)
# Coordinate-based selection after reprojecting to EPSG:4326: x in -170..-50, y in 50..75
boreal_tile_index_subset = boreal_tile_index.to_crs(4326).cx[-170:-50, 50:75]

# Boreal NA tiles: need just a list of tile_ids
INPUT_TILE_NUM_LIST = boreal_tile_index_subset['layer'].astype(int).tolist()

# Peek at the first few tile ids (last expression of the cell, so it is displayed)
INPUT_TILE_NUM_LIST[0:5]


# Loop over input tiles to output ATL08 tile CSVs

In [None]:
%%time

#for in_tile_num in INPUT_TILE_NUM_LIST[100:105]:
for in_tile_num in [30542]:#, 30543, 30821, 30822, 30823]:
    
    # TODO: make this an arg
    years_list = [2019, 2020, 2021]
    
    seg_str = '_100m'
    if do_30m:
        seg_str = '_30m'
    if TEST:
        seg_str = ''
    
    if maap_query and dps_dir is not None:
        
        print("\nDoing MAAP query by tile bounds to find all intersecting ATL08 ")
        # Get a list of all ATL08 H5 granule names intersecting the tile (this will be a small list)
        # all_atl08_for_tile = ExtractUtils.get_h5_list() #<- when you get import to work, change back to this
        all_atl08_for_tile = get_h5_list(tile_num=in_tile_num, tile_fn=in_tile_fn, layer=in_tile_layer, DATE_START=date_start, DATE_END=date_end, YEARS=years_list)
        
        # Change the small ATL08 H5 granule names to match the output filenames from extract_atl08.py (eg, ATL08_*_30m.csv)
        all_atl08_csvs_for_tile_BASENAME = [os.path.basename(f).replace('.h5', seg_str+'.csv') for f in all_atl08_for_tile]
        ##print(all_atl08_csvs_for_tile_BASENAME)
        # Get a list of all ATL08 CSV files from (from extract_atl08) (this will be a large boreal list)
        print("\tDPS dir to find ATL08 CSVs: {}".format(dps_dir))
        all_atl08_csvs = glob.glob(dps_dir + "/**/ATL08*" + seg_str + ".csv", recursive=True)
        all_atl08_csvs_BASENAME = [os.path.basename(f) for f in all_atl08_csvs]
        print('Length of all_atl08_csvs: {}'.format(len(all_atl08_csvs)))
        # Get index of ATL08 in tile bounds from the large list of all ATL08 CSVs
        ###names = [name for i, name in enumerate(all_atl08_csvs_for_tile_BASENAME) if name in set(all_atl08_csvs_BASENAME)]
        ###print(names)
        idx = [i for i, name in enumerate(all_atl08_csvs_for_tile_BASENAME) if name in set(all_atl08_csvs_BASENAME)]
        # Get the subset of all ATL08 CSVs that just correspond to the ATL08 H5 intersecting the current tile
        all_atl08_h5_with_csvs_for_tile = [all_atl08_for_tile[x] for x in idx]       
        
        # Check to make sure these are in fact files (necessary?)
        all_atl08_csvs_NOT_FOUND = []
        all_atl08_csvs_FOUND = []
        for file in all_atl08_h5_with_csvs_for_tile:
            file = os.path.join(dps_dir, os.path.basename(file).replace('.h5',seg_str+'.csv'))       
            if not os.path.isfile(file):
                all_atl08_csvs_NOT_FOUND.append(file)
            else:
                all_atl08_csvs_FOUND.append(file)

        #all_atl08_csvs_FOUND = [x for x in all_atl08_h5_with_csvs_for_tile if x not in all_atl08_csvs_NOT_FOUND]
        print("\t# of ATL08 CSV found for tile {}: {}".format(in_tile_num, len(all_atl08_csvs_FOUND)))
        if len(all_atl08_csvs_FOUND) == 0:
            print('\tNo ATL08 extracted for this tile.')
            continue
        
        # Merge all ATL08 CSV files for the current tile into a pandas df
        print("Creating pandas data frame...")
        atl08 = pd.concat([pd.read_csv(f) for f in all_atl08_csvs_FOUND ], sort=False)
        
        print("\nFiltering by tile: {}".format(in_tile_num))
    
        # Get tile bounds as xmin,xmax,ymin,ymax
        in_bounds = reorder_4326_bounds(in_tile_fn, in_tile_num, buffer=0, layer=in_tile_layer)
        
        # Now filter ATL08 obs by tile bounds
        atl08 = filter_atl08_bounds(atl08_df=atl08, in_bounds=in_bounds)
        
    elif maap_query and dps_dir is None:
        print("\nNo DPS dir specified: cant get ATL08 CSV list to match with tile bound results from MAAP query.\n")
        os._exit(1)
    else:
        # Filter by bounds: EPT with a the bounds from an input tile
        atl08 = filter_atl08_bounds_tile_ept(in_ept_fn, in_tile_fn, in_tile_num, in_tile_layer, output_dir, return_pdf=True)
    
    ## Filter by quality: based on a standard filter_atl08_qual() function that we use across all notebooks, scripts, etc
    #atl08_pdf_filt = FilterUtils.filter_atl08_qual(atl08, out_cols_list)
    # Filter by quality
    print('DEBUG qual filter.')
    atl08_pdf_filt = filter_atl08_qual(atl08, SUBSET_COLS=True, 
                                                       subset_cols_list=['rh25','rh50','rh60','rh70','rh75','rh80','rh85','rh90','rh95','h_can','h_max_can'], 
                                                       filt_cols=['h_can','h_dif_ref','m','msw_flg','beam_type','seg_snow', 'seg_landcov'], 
                                                       thresh_h_can=100, thresh_h_dif=100, month_min=6, month_max=9)
    atl08=None
    
    # Convert to geopandas data frame in lat/lon
    atl08_gdf = GeoDataFrame(atl08_pdf_filt, geometry=gpd.points_from_xy(atl08_pdf_filt.lon, atl08_pdf_filt.lat), crs='epsg:4326')
    out_name_stem = "atl08_filt"
    atl08_pdf_filt=None
    
    if extract_covars:
        ### Below here should be re-worked to follow final chunk of nb 2.3 (6/15/2021)
        #
        
        # Extract topo covar values to ATL08 obs (doing a reproject to tile crs)
        # TODO: consider just running 3.1.5_dpy.py here to produce this topo stack right before extracting its values
        topo_covar_fn = do_3_1_5_dp.main(in_tile_fn=in_tile_fn, in_tile_num=in_tile_num, tile_buffer_m=120, in_tile_layer=in_tile_layer, topo_tile_fn='https://maap-ops-dataset.s3.amazonaws.com/maap-users/alexdevseed/dem30m_tiles.geojson')
        atl08_gdf_out = ExtractUtils.extract_value_gdf(topo_covar_fn, atl08_gdf, ["elevation","slope","tsri","tpi", "slopemask"], reproject=True)
        out_name_stem = out_name_stem + "_topo"

        # Extract landsat covar values to ATL08 obs
        # TODO: consider just running 3.1.2_dpy.py here
        landsat_covar_fn = do_3_1_2_dps.main(in_tile_fn=in_tile_fn, in_tile_num=in_tile_num, in_tile_layer=in_tile_layer, sat_api='https://landsatlook.usgs.gov/sat-api', local=args.local)
        atl08_gdf_out = ExtractUtils.extract_value_gdf(landsat_covar_fn, atl08_gdf_out, ['Blue', 'Green', 'Red', 'NIR', 'SWIR', 'NDVI', 'SAVI', 'MSAVI', 'NDMI', 'EVI', 'NBR', 'NBR2', 'TCB', 'TCG', 'TCW', 'ValidMask', 'Xgeo', 'Ygeo'], reproject=False)
        out_name_stem = out_name_stem + "_landsat"
        
    # CSV the file
    cur_time = time.strftime("%Y%m%d%H%M%S")
    out_csv_fn = os.path.join(output_dir, out_name_stem + "_" + cur_time + ".csv")
    atl08_gdf_out.to_csv(out_csv_fn,index=False, encoding="utf-8-sig")
    
    print("Wrote output csv of filtered ATL08 obs with topo and Landsat covariates for tile {}: {}".format(in_tile_num, out_csv_fn) )


Doing MAAP query by tile bounds to find all intersecting ATL08 
	TILE_NUM: 30542 (-117.10749852280769,50.78795362739066,-116.50936927974429,51.16389512140189)
		# ATL08 for tile 30542: 23
	DPS dir to find ATL08 CSVs: /projects/jabba/dps_output/2.3_output
Length of all_atl08_csvs: 6718
	# of ATL08 CSV found for tile 30542: 15
Creating pandas data frame...

Filtering by tile: 30542
Filtering by bounds: [-117.10749852280769, -116.50936927974429, 50.78795362739066, 51.16389512140189]
Returning a data frame

Pre-filter data cleaning...

Get beam type from orbit orientation and ground track...
['Weak' 'Strong']
Cast some columns to type float: ['lat', 'lon', 'h_can', 'h_te_best', 'ter_slp']
