In [10]:
import dask, concurrent.futures, time, warnings, os, re, pickle
from osgeo import gdal
import requests as r
import panel as pn
pn.extension()
import param as pm
import pandas as pd
from collections import OrderedDict as odict
import numpy as np
from dask.distributed import LocalCluster, Client
import xarray as xr
import hvplot.pandas
import hvplot.xarray
import warnings
warnings.filterwarnings('ignore')
from urllib.request import urlopen
from xml.etree.ElementTree import parse,fromstring
from affine import Affine
from pandas import to_datetime
import jinja2 as jj2
from rasterio.crs import CRS
from tempfile import NamedTemporaryFile
from datetime import datetime
from netrc import netrc
from subprocess import Popen
from pyproj import Proj
from src.hls_funcs.masks import mask_hls
from src.hls_funcs.predict import pred_cov, pred_bm, pred_bm_se, pred_bm_thresh
import cartopy.crs as ccrs
from bokeh.models.formatters import PrintfTickFormatter
import stackstac
from subprocess import Popen, DEVNULL, STDOUT
from getpass import getpass
from sys import platform

In [11]:
# Lookup table: HLS product -> {asset/band id -> common band name}.
# Insertion order of the inner dicts is kept as-is because other code iterates them.
lut = {
    'HLSS30': dict([
        ('B01', 'COASTAL-AEROSOL'),
        ('B02', 'BLUE'),
        ('B03', 'GREEN'),
        ('B04', 'RED'),
        ('B05', 'RED-EDGE1'),
        ('B06', 'RED-EDGE2'),
        ('B07', 'RED-EDGE3'),
        ('B08', 'NIR-Broad'),
        ('B8A', 'NIR1'),
        ('B09', 'WATER-VAPOR'),
        ('B10', 'CIRRUS'),
        ('B11', 'SWIR1'),
        ('B12', 'SWIR2'),
        ('Fmask', 'FMASK'),
    ]),
    'HLSL30': dict([
        ('B01', 'COASTAL-AEROSOL'),
        ('B02', 'BLUE'),
        ('B03', 'GREEN'),
        ('B04', 'RED'),
        ('B05', 'NIR1'),
        ('B06', 'SWIR1'),
        ('B07', 'SWIR2'),
        ('B09', 'CIRRUS'),
        ('B10', 'TIR1'),
        ('B11', 'TIR2'),
        ('Fmask', 'FMASK'),
    ]),
}

# Every band name accepted by the tooling ('ALL' is a wildcard entry).
# NOTE(review): 'NIR-Broad' appears in lut['HLSS30'] but not here - confirm whether that is intended.
all_bands = [
    'ALL',
    'COASTAL-AEROSOL',
    'BLUE',
    'GREEN',
    'RED',
    'RED-EDGE1',
    'RED-EDGE2',
    'RED-EDGE3',
    'NIR1',
    'SWIR1',
    'SWIR2',
    'CIRRUS',
    'TIR1',
    'TIR2',
    'WATER-VAPOR',
    'FMASK',
]

# Bands actually requested when building the data stack.
needed_bands = ['BLUE', 'GREEN', 'RED', 'NIR1', 'SWIR1', 'SWIR2', 'FMASK']

# (earliest, latest) selectable dates; the upper bound is fixed when this cell runs.
d_bounds = (datetime(2019, 1, 1), datetime.now())

In [12]:
def _cmr_get_json(url):
    """GET ``url`` and return its decoded JSON body, raising on an HTTP error status."""
    response = r.get(url)
    # Fail here with a clear HTTPError instead of a confusing KeyError further down.
    response.raise_for_status()
    return response.json()


def _point_assets_at_s3(items):
    """Rewrite, in place, each asset href of ``items`` from the LP DAAC HTTPS prefix to a /vsis3/ path."""
    for item in items:
        for asset in item['assets'].values():
            asset['href'] = asset['href'].replace(
                'https://lpdaac.earthdata.nasa.gov/lp-prod-protected',
                '/vsis3/lp-prod-protected')


def NASA_CMR_STAC(hls_data, aws):
    """Search the CMR-STAC LPCLOUD catalog for HLS items over a fixed bounding box.

    Parameters
    ----------
    hls_data : dict
        Must contain ``'date_range'``: a two-element sequence of strings
        (start, end) that are joined with ``/`` to form the STAC ``datetime`` filter.
    aws : bool
        When truthy, asset hrefs are rewritten to ``/vsis3/lp-prod-protected...``
        paths instead of the HTTPS URLs returned by the API.

    Returns
    -------
    dict
        ``{'S30': [...], 'L30': [...]}`` with the STAC items whose ``collection``
        is ``HLSS30.v1.5`` / ``HLSL30.v1.5`` respectively.

    Raises
    ------
    requests.HTTPError
        If any of the catalog/search requests returns an error status.
    IndexError
        If the LPCLOUD catalog or its search link cannot be found.
    """
    stac_root = 'https://cmr.earthdata.nasa.gov/stac/'  # CMR-STAC API endpoint
    stac_response = _cmr_get_json(stac_root)
    # Only LP-specific catalogs; links without a title are skipped rather than raising.
    stac_lp = [s for s in stac_response['links'] if 'LP' in s.get('title', '')]

    # LPCLOUD is the catalog that is searched for HLS items.
    lp_cloud = _cmr_get_json([s for s in stac_lp if s['title'] == 'LPCLOUD'][0]['href'])
    lp_search = [l['href'] for l in lp_cloud['links'] if l['rel'] == 'search'][0]

    # NOTE(review): only the first page (up to `lim` items) is fetched; longer date
    # ranges may be truncated - confirm whether pagination is needed.
    lim = 100
    bbox_num = [-104.79107047, 40.78311181, -104.67687336, 40.87008987]  # ROI bounds (lon/lat)
    bbox = f'{bbox_num[0]},{bbox_num[1]},{bbox_num[2]},{bbox_num[3]}'
    date_time = hls_data['date_range'][0] + '/' + hls_data['date_range'][1]  # start/end
    search_url = f"{lp_search}?&limit={lim}&bbox={bbox}&datetime={date_time}"

    features = _cmr_get_json(search_url)['features']
    hls_items = [f for f in features if 'HLS' in f['collection']]
    # NOTE(review): collection ids are pinned to v1.5 - verify they are still served.
    s30_items = [h for h in hls_items if h['collection'] == 'HLSS30.v1.5']
    l30_items = [h for h in hls_items if h['collection'] == 'HLSL30.v1.5']

    if aws:
        _point_assets_at_s3(s30_items)
        _point_assets_at_s3(l30_items)
    return {'S30': s30_items,
            'L30': l30_items}

def _append_netrc_entry(netrc_path, machine, login, password):
    """Append a machine/login/password entry to ``netrc_path``, creating the file (mode 0600) if needed."""
    try:
        with open(netrc_path, 'rb') as existing:
            content = existing.read()
    except FileNotFoundError:
        content = b''
    # Avoid gluing the new entry onto an existing last line that lacks a newline.
    prefix = '' if (not content or content.endswith(b'\n')) else '\n'
    # The mode only applies when the file is created; an existing file keeps its permissions.
    fd = os.open(netrc_path, os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
    with os.fdopen(fd, 'a') as fh:
        fh.write('{}machine {}\nlogin {}\npassword {}\n'.format(prefix, machine, login, password))


def setup_netrc(creds, aws):
    """Ensure ``~/.netrc`` has a NASA Earthdata Login entry and optionally fetch S3 credentials.

    The entry is written directly from Python (no shell), so credentials are
    never interpreted by a shell and the file is complete before any request
    is made.

    Parameters
    ----------
    creds : sequence
        ``creds[0]`` is the Earthdata username, ``creds[1]`` the password. Only
        used when no entry for ``urs.earthdata.nasa.gov`` exists yet.
    aws : bool
        When truthy, request temporary S3 credentials from LP DAAC.

    Returns
    -------
    dict or str
        The decoded JSON from the ``s3credentials`` endpoint when ``aws`` is
        truthy, otherwise an empty string.
    """
    urs = 'urs.earthdata.nasa.gov'
    netrc_path = os.path.expanduser("~/.netrc")
    try:
        # authenticators() returns None for an unknown host, so indexing raises TypeError.
        netrc(netrc_path).authenticators(urs)[0]
    except (FileNotFoundError, TypeError):
        # Either there is no .netrc at all, or it has no Earthdata entry: add one.
        _append_netrc_entry(netrc_path, urs, creds[0], creds[1])

    if aws:
        return r.get('https://lpdaac.earthdata.nasa.gov/s3credentials').json()
    return ''

def _stack_hls_product(items, product, drop_coords):
    """Stack the STAC ``items`` of one HLS ``product`` into a Dataset, or return None on ValueError."""
    try:
        band_lut = lut[product]
        stack = stackstac.stack(items, epsg=32613, resolution=30,
                                assets=[b for b in band_lut if band_lut[b] in needed_bands])
        # Replace asset ids (e.g. 'B8A') with common band names (e.g. 'NIR1').
        stack['band'] = [band_lut[b] for b in stack['band'].values]
        # Nanoseconds -> whole seconds -> naive datetimes. int64 is explicit so the
        # conversion does not depend on the platform's default integer width.
        # NOTE(review): fromtimestamp() converts to the machine's local time zone - confirm intended.
        stack['time'] = [datetime.fromtimestamp(t)
                         for t in stack.time.astype('int64').values // 1000000000]
        return stack.to_dataset(dim='band').reset_coords(drop_coords, drop=True)
    except ValueError:
        return None


def build_xr(stac_dict):
    """Build a single chunked xarray Dataset from HLS S30 and L30 STAC items.

    Parameters
    ----------
    stac_dict : dict
        Mapping with keys ``'S30'`` and ``'L30'``, each a list of STAC items.

    Returns
    -------
    xarray.Dataset
        One variable per band in ``needed_bands``; S30 and L30 are concatenated
        along ``time`` when both are available. Chunked as one time step per
        chunk with full ``y``/``x`` extent.

    Raises
    ------
    ValueError
        If neither product could be stacked (previously this path printed a
        message and then failed with an UnboundLocalError).
    """
    s30_stack = _stack_hls_product(stac_dict['S30'], 'HLSS30',
                                   ['end_datetime', 'start_datetime'])
    l30_stack = _stack_hls_product(stac_dict['L30'], 'HLSL30',
                                   ['name', 'end_datetime', 'start_datetime'])

    available = [s for s in (s30_stack, l30_stack) if s is not None]
    if not available:
        raise ValueError('No data found for date range')
    if len(available) == 1:
        hls_stack = available[0]
    else:
        hls_stack = xr.concat(available, dim='time')
    return hls_stack.chunk({'time': 1, 'y': -1, 'x': -1})
    
def get_hls(creds, hls_data={}, aws=False):
    """Authenticate, configure GDAL through environment variables, and load HLS data.

    Parameters
    ----------
    creds : sequence
        Earthdata username and password, forwarded to ``setup_netrc``.
    hls_data : dict
        Search options forwarded to ``NASA_CMR_STAC``.
    aws : bool
        Selects direct S3 access (with temporary credentials) over HTTPS access.

    Returns
    -------
    The object returned by ``build_xr`` for the catalog search result.
    """
    # Credentials first: the S3 token (if any) goes into the GDAL environment below.
    s3_cred = setup_netrc(creds, aws=aws)

    cookie_path = os.path.expanduser('~/cookies.txt')
    # Settings shared by both access modes.
    gdal_env = {
        'GDAL_MAX_RAW_BLOCK_CACHE_SIZE': '200000000',
        'GDAL_SWATH_SIZE': '200000000',
        'VSI_CURL_CACHE_SIZE': '200000000',
        'GDAL_HTTP_COOKIEFILE': cookie_path,
        'GDAL_HTTP_COOKIEJAR': cookie_path,
    }
    if aws:
        # Signed S3 access using the temporary credentials.
        gdal_env.update({
            'GDAL_DISABLE_READDIR_ON_OPEN': 'FALSE',
            'CPL_VSIL_CURL_ALLOWED_EXTENSIONS': 'TIF',
            'GDAL_HTTP_UNSAFESSL': 'YES',
            'AWS_REGION': 'us-west-2',
            'AWS_SECRET_ACCESS_KEY': s3_cred['secretAccessKey'],
            'AWS_ACCESS_KEY_ID': s3_cred['accessKeyId'],
            'AWS_SESSION_TOKEN': s3_cred['sessionToken'],
        })
    else:
        # HTTPS access: no request signing.
        gdal_env.update({
            'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR',
            'AWS_NO_SIGN_REQUEST': 'YES',
        })

    os.environ.update(gdal_env)

    stac_items = NASA_CMR_STAC(hls_data, aws)
    return build_xr(stac_items)

In [13]:
class HLS_BM_Explorer(pm.Parameterized):
    """Parameterized dashboard that loads HLS data and maps predicted cover and biomass.

    Clicking the ``action`` button triggers ``access_data``, which queries HLS
    data for the selected date range, masks it, and stores predictions on the
    instance. The ``load_*`` methods render maps for the date chosen in
    ``date_picker``.

    Result attributes (``data``, ``da``, ``da_cov``, ``da_bm``, ``da_se`` and
    the ``*_sel`` variants) hold an empty string until data has been loaded.
    """
    action = pm.Action(lambda x: x.param.trigger('action'), label='Load Data and Run Analysis')
    # NOTE(review): the username field is a PasswordInput, so the typed username is masked - confirm intended.
    username_input = pn.widgets.PasswordInput(name='NASA Earthdata Login', placeholder='Enter Username...')
    password_input = pn.widgets.PasswordInput(name='', placeholder='Enter Password...')
    # Placeholder dates; replaced with the available acquisition dates once data is loaded.
    date_picker = pn.widgets.DatePicker(name='Calendar',
                                        value=datetime(2000, 1, 1).date(),
                                        enabled_dates=[datetime(2000, 1, 1).date(),
                                                       datetime(2000, 1, 2).date()])
    d_range = pn.widgets.DateRangeSlider(name='', end=d_bounds[-1], start=d_bounds[0])
    thresh_picker = pn.widgets.IntSlider(name='Threshold', start=200, end=2000, step=25, value=500,
                                         format=PrintfTickFormatter(format='%d kg/ha'))

    def __init__(self, **params):
        super().__init__(**params)
        # '' marks "not loaded yet" for every result attribute.
        self.data = ''
        self.da = ''
        self.da_cov = ''
        self.da_cov_sel = ''
        self.da_bm = ''
        self.da_bm_sel = ''
        self.da_se = ''
        self.da_se_sel = ''

    @staticmethod
    def _is_loaded(value):
        """Return True once ``value`` is no longer the string placeholder set in ``__init__``.

        A type check is used instead of ``is not ''`` (identity against a literal
        is implementation-dependent) or ``!= ''`` (element-wise on arrays).
        """
        return not isinstance(value, str)

    @pm.depends('action')
    def access_data(self):
        """Load HLS data for the selected range and compute biomass and cover.

        Returns
        -------
        str
            ``'Not yet launched'`` when no username is entered, ``'Success!'``
            when everything ran, otherwise the last completed stage followed by
            ``': App Failure'``.
        """
        message = 'Not yet launched'
        if self.username_input.value == '':
            return message
        try:
            message = 'button clicked'
            d_from = str(self.d_range.value[0].date())
            d_to = str(self.d_range.value[1].date())
            tmp_data = get_hls([self.username_input.value, self.password_input.value],
                               hls_data={'date_range': [d_from, d_to]}, aws=True)
            message = 'data querried'
            with LocalCluster(threads_per_worker=1) as cluster, Client(cluster) as cl:
                # Clip to the region of interest: lon/lat corners projected to UTM 13.
                bbox_num = [-104.79107047, 40.78311181, -104.67687336, 40.87008987]
                utmProj = Proj("+proj=utm +zone=13U, +north +ellps=WGS84 +datum=WGS84 +units=m +no_defs")
                # x runs west->east, y runs north->south (indices 3 then 1).
                bbox_utm = utmProj([bbox_num[i] for i in [0, 2]], [bbox_num[i] for i in [3, 1]])
                self.data = tmp_data.loc[dict(x=slice(*tuple(bbox_utm[0])),
                                              y=slice(*tuple(bbox_utm[1])))]

            message = 'data loaded'
            # Restrict the calendar to dates that have data; default to the last time step.
            self.date_picker.enabled_dates = [datetime.utcfromtimestamp(x).date() for
                                              x in self.data.time.data.astype('int') * 1e-9]
            self.date_picker.value = datetime.utcfromtimestamp(
                self.data.time[-1].data.astype('int') * 1e-9).date()
            message = 'date picker set'

            da = self.data
            da['time'] = da.time.dt.floor("D")
            da = da.rename(dict(time='date'))
            da_mask = mask_hls(da['FMASK'])
            da = da.where(da_mask == 0)  # keep only pixels where the mask is 0
            message = 'data masked'
            #da = da.groupby('date').mean()
            message = 'data averaged'
            self.da = da
            message = 'data reset'

            # pickle executes arbitrary code on load: only ever point this at trusted model files.
            with open('src/models/CPER_HLS_to_VOR_biomass_model_lr_simp.pk', 'rb') as model_file:
                bm_mod = pickle.load(model_file)
            da_bm = pred_bm(da, bm_mod)
            da_bm = da_bm.where(da_bm > 0)
            da_se = pred_bm_se(da, bm_mod)
            da_se = da_se.where(da_bm > 0)
            self.da_bm = da_bm
            self.da_se = da_se
            self.bm_mod = bm_mod
            message = 'bm calculated'

            # Same trust caveat as above for this pickle file.
            with open('src/models/CPER_HLS_to_LPI_cover_pls_binned_model.pk', 'rb') as model_file:
                cov_mod = pickle.load(model_file)
            da_cov = pred_cov(da, cov_mod)
            message = 'cover calculated'
            da_cov = da_cov[['SD', 'GREEN', 'BARE']].to_array(dim='type')
            message = 'cover converted'
            # Cap values at 1.0 while leaving missing values untouched.
            da_cov = da_cov.where((da_cov < 1.0) | (da_cov.isnull()), 1.0)
            da_cov = da_cov.where(~(da_cov.any(dim='date').isnull()))
            message = 'cover masked'
            # Scale to 0-255 integers for RGB display.
            da_cov = da_cov * 255
            da_cov = da_cov.astype('uint8')
            message = 'conver integered'
            self.da_cov = da_cov
            message = 'cover loaded'
            message = 'Success!'
            return message
        except Exception:
            # Keep the UI alive, but leave a trace for whoever runs the server;
            # `message` tells the user which stage was last completed.
            import traceback
            traceback.print_exc()
            return message + ': App Failure'

    @pm.depends('date_picker.param')
    def load_cov(self):
        """Return an RGB cover map for the selected date, or '' before data is loaded."""
        if not self._is_loaded(self.da_cov):
            return ''
        self.da_cov_sel = self.da_cov.sel(date=np.datetime64(self.date_picker.value))
        cov_map = self.da_cov_sel.hvplot.rgb(
            x='x', y='y',
            bands='type',
            tiles='EsriImagery',
            crs=ccrs.UTM(13),
            data_aspect=0.6,
        ).opts(responsive=True, xticks=None, yticks=None)
        return cov_map

    @pm.depends('date_picker.param')
    def load_bm(self):
        """Return a biomass map for the selected date, or '' before data is loaded."""
        if not self._is_loaded(self.da_bm):
            return ''
        bm_map = self.da_bm.sel(date=np.datetime64(self.date_picker.value)).hvplot(
            x='x', y='y',
            tiles='EsriImagery',
            crs=ccrs.UTM(13),
            cmap='inferno',
            clim=(100, 1000),
            colorbar=False,
            data_aspect=0.6,
        ).opts(responsive=True, xticks=None, yticks=None)
        return bm_map

    @pm.depends('date_picker.param', 'thresh_picker.param')
    def load_thresh(self):
        """Return a map of ``pred_bm_thresh`` output for the selected date and threshold, or ''."""
        if not self._is_loaded(self.da_bm):
            return ''
        da_thresh = pred_bm_thresh(self.da_bm, self.da_se, self.thresh_picker.value)
        thresh_map = da_thresh.sel(date=np.datetime64(self.date_picker.value)).hvplot(
            x='x', y='y',
            tiles='EsriImagery',
            crs=ccrs.UTM(13),
            cmap='YlOrRd',
            clim=(0.05, 0.95),
            colorbar=False,
            data_aspect=0.6,
        ).opts(responsive=True, xticks=None, yticks=None)
        return thresh_map

    @pm.depends('access_data')
    def showdata(self):
        """Return an HTML pane wrapping ``self.data``."""
        return pn.pane.HTML(self.data, sizing_mode='stretch_both', max_width=250)

    @pm.depends('thresh_picker.param')
    def showthresh(self):
        """Return an HTML pane showing the current threshold value as text."""
        return pn.pane.HTML(str(self.thresh_picker.value), sizing_mode='stretch_both', max_width=250)
    
# Instantiate the explorer and assemble the page from three stacked panels.
app = HLS_BM_Explorer(name='Central Plains Experimental Range: HLS Biomass')

# Panel 1: credentials, date range and the "load" button.
download_panel = pn.Column(
    app.username_input,
    app.password_input,
    app.d_range,
    pn.panel(app.param.action),
    sizing_mode='stretch_both',
    name="Download Options",
    css_classes=['panel-widget-box'],
)

# Panel 2: display controls plus the status message returned by access_data.
controls_panel = pn.Column(
    app.date_picker,
    app.thresh_picker,
    app.access_data,
    sizing_mode='stretch_both',
)

# Panel 3: one tab per map product.
map_tabs = pn.Tabs(
    ('Cover', app.load_cov),
    ('Biomass', app.load_bm),
    ('Biomass threshold', app.load_thresh),
    sizing_mode='stretch_both',
)
maps_panel = pn.Column(map_tabs, sizing_mode='stretch_both')

app_layout = pn.Column(download_panel, controls_panel, maps_panel, sizing_mode='stretch_both')

# Alternative launchers kept for reference:
#template_theme.show(port=9000)
#pn.serve(app_layout)
app_layout.servable()