In [None]:
import dask, concurrent.futures, time, warnings, os, re, pickle
from osgeo import gdal
import requests as r
import panel as pn
pn.extension()
import param as pm
import pandas as pd
from collections import OrderedDict as odict
import numpy as np
from dask.distributed import LocalCluster, Client
import xarray as xr
import hvplot.pandas
import hvplot.xarray
import warnings
warnings.filterwarnings('ignore')
from urllib.request import urlopen
from xml.etree.ElementTree import parse,fromstring
from affine import Affine
from pandas import to_datetime
import jinja2 as jj2
from rasterio.crs import CRS
from tempfile import NamedTemporaryFile
from datetime import datetime
from netrc import netrc
from subprocess import Popen
from pyproj import Proj
from src.hls_funcs.masks import mask_hls
from src.hls_funcs.predict import pred_cov, pred_bm, pred_bm_se, pred_bm_thresh
import cartopy.crs as ccrs
from bokeh.models.formatters import PrintfTickFormatter


In [None]:
#!/opt/conda/envs/py_geo/bin/pip3 install pysptools

In [None]:
# Look-up table: for each HLS product, common band name -> band code used in
# that product's asset names.
lut = {
    'HLSS30': {
        'COASTAL-AEROSOL': 'B01',
        'BLUE': 'B02',
        'GREEN': 'B03',
        'RED': 'B04',
        'RED-EDGE1': 'B05',
        'RED-EDGE2': 'B06',
        'RED-EDGE3': 'B07',
        'NIR-Broad': 'B08',
        'NIR1': 'B8A',
        'WATER-VAPOR': 'B09',
        'CIRRUS': 'B10',
        'SWIR1': 'B11',
        'SWIR2': 'B12',
        'FMASK': 'Fmask',
    },
    'HLSL30': {
        'COASTAL-AEROSOL': 'B01',
        'BLUE': 'B02',
        'GREEN': 'B03',
        'RED': 'B04',
        'NIR1': 'B05',
        'SWIR1': 'B06',
        'SWIR2': 'B07',
        'CIRRUS': 'B09',
        'TIR1': 'B10',
        'TIR2': 'B11',
        'FMASK': 'Fmask',
    },
}

# Band names accepted as input ('ALL' is listed alongside the individual bands).
# NOTE(review): 'NIR-Broad' exists in lut['HLSS30'] but is not listed here - confirm intent.
all_bands = [
    'ALL',
    'COASTAL-AEROSOL',
    'BLUE',
    'GREEN',
    'RED',
    'RED-EDGE1',
    'RED-EDGE2',
    'RED-EDGE3',
    'NIR1',
    'SWIR1',
    'SWIR2',
    'CIRRUS',
    'TIR1',
    'TIR2',
    'WATER-VAPOR',
    'FMASK',
]

# (earliest, latest) datetimes offered by the date-range slider; the upper
# bound is fixed at the moment this cell runs.
d_bounds = (datetime(2019, 1, 1), datetime.now())

In [None]:
# Names of the <AdditionalAttribute> entries copied out of a granule's CMR record.
_HLS_META_ATTRIBUTES = frozenset([
    'ULX',
    'ULY',
    'NROWS',
    'NCOLS',
    'SPATIAL_RESOLUTION',
    'HORIZONTAL_CS_CODE',
    'SENSING_TIME',
    'MGRS_TILE_ID',
    'CLOUD_COVERAGE',
    'FILLVALUE',
    'QA_FILLVALUE',
    'MEAN_SUN_AZIMUTH_ANGLE',
    'MEAN_SUN_ZENITH_ANGLE',
    'MEAN_VIEW_AZIMUTH_ANGLE',
    'MEAN_VIEW_ZENITH_ANGLE',
    'NBAR_SOLAR_ZENITH',
    'IDENTIFIER_PRODUCT_DOI',
    'IDENTIFIER_PRODUCT_DOI_AUTHORITY',
    'REF_SCALE_FACTOR',
    'ADD_OFFSET',
])


def open_hls_meta(stac_id):
    """Fetch selected metadata attributes for one granule from NASA CMR.

    Downloads ``https://cmr.earthdata.nasa.gov/search/concepts/<stac_id>``,
    parses it as XML and collects the text of the first ``Value`` element of
    every ``AdditionalAttributes/AdditionalAttribute`` whose ``Name`` is one
    of ``_HLS_META_ATTRIBUTES``.

    Parameters
    ----------
    stac_id : str
        Concept id appended verbatim to the CMR URL.

    Returns
    -------
    dict
        ``{stac_id: {attribute_name: value_text}}``. Values are the raw
        strings from the XML; no type conversion is done here.

    Raises
    ------
    urllib.error.URLError
        If the request fails.
    xml.etree.ElementTree.ParseError
        If the response is not well-formed XML.
    """
    # Close the HTTP response as soon as the document has been parsed.
    with urlopen('https://cmr.earthdata.nasa.gov/search/concepts/'+stac_id) as var_url:
        xmldoc = parse(var_url)

    attrs = {}
    for child in xmldoc.findall('.//AdditionalAttributes/AdditionalAttribute'):
        name = child.findtext('Name')  # None when the node has no <Name>
        if name not in _HLS_META_ATTRIBUTES:
            continue
        value = child.find('.//Value')
        if value is None:
            # Attribute declared without a value: skip it so callers get a
            # KeyError naming the attribute rather than an AttributeError here.
            continue
        attrs[name] = value.text
    return({stac_id: attrs})

def build_xr(catalog,bands,is_aws):
    """Build a lazily loaded xarray Dataset (one variable per band) from STAC items.

    Steps, in order:
      1. Fetch per-granule metadata with ``open_hls_meta`` on a small thread pool.
      2. Record, for every granule and band, the GDAL path of the band's asset.
      3. Render one single-band VRT XML document per granule and band.
      4. Stack each band's VRTs into one multi-band VRT with ``gdal.BuildVRT``.
      5. Open each stacked VRT with ``xr.open_rasterio``, rename ``band`` to
         ``time`` and merge the bands into one Dataset.

    Parameters
    ----------
    catalog : iterable of dict
        STAC items; each must provide ``item['id']`` and
        ``item['assets'][<band code>]['href']``.
    bands : str or list of str
        Common band name(s); must be keys of the module-level ``lut`` for
        every product present in ``catalog``.
    is_aws : bool
        If true, asset URLs are rewritten to ``/vsis3/`` paths; otherwise they
        are prefixed with ``/vsicurl/``.

    Returns
    -------
    tuple
        ``(ds, vrt_files)``: the merged Dataset and a dict mapping each band
        name to the text of its stacked VRT.

    Raises
    ------
    ValueError
        If a granule's product DOI matches neither ``S30`` nor ``L30``.
    KeyError
        If a band is not defined for a granule's product, or an expected
        metadata attribute or asset is missing.
    """
    # Normalise first: a bare string would otherwise be iterated character by
    # character in every loop below.
    if isinstance(bands, str):
        bands = [bands]

    # Retrieve metadata using threads - not CPU bound, so this works well.
    l_meta={}
    with concurrent.futures.ThreadPoolExecutor(5) as executor:
        futures = [executor.submit(open_hls_meta, stac_id=stac['id']) for stac in catalog]
        for future in concurrent.futures.as_completed(futures):
            l_meta.update(future.result())

    # Attach the GDAL-readable location of every requested band to the metadata.
    for stac in catalog:
        stac_id = stac['id']
        for band in bands:
            b = _hls_band_code(l_meta[stac_id]['IDENTIFIER_PRODUCT_DOI'], band)
            href = stac['assets'][b]['href']
            if is_aws:
                l_meta[stac_id][b+'_url'] = href.replace('https://lpdaac.earthdata.nasa.gov/lp-prod-protected', '/vsis3/lp-prod-protected')
            else:
                l_meta[stac_id][b+'_url'] = '/vsicurl/'+href
    
    #Setup template file for each raster. Use Jinja to quickly fill in metadata.
    vrt_template = jj2.Template('''
    <VRTDataset rasterXSize="{{rasterXSize}}" rasterYSize="{{rasterYSize}}">
      <SRS>{{SRS}}</SRS>
      <GeoTransform>{{GeoTransform}}</GeoTransform>
      <VRTRasterBand dataType="{{dtype}}" band="1">
        <NoDataValue>{{nodata}}</NoDataValue>
        <Scale>{{scale}}</Scale>
        <Metadata>
          <MDI key="obs_date">{{obs_date}}</MDI>
        </Metadata>
        <SimpleSource>
          <SourceFilename relativeToVRT="1">{{SourceFilename}}</SourceFilename>
          <SourceBand>1</SourceBand>
          <SourceProperties RasterXSize="{{rasterXSize}}" RasterYSize="{{rasterYSize}}" DataType="{{dtype}}" BlockXSize="1024" BlockYSize="1024" />
          <SrcRect xOff="0" yOff="0" xSize="{{rasterXSize}}" ySize="{{rasterYSize}}" />
          <DstRect xOff="0" yOff="0" xSize="{{rasterXSize}}" ySize="{{rasterYSize}}" />
          <NODATA>{{nodata}}</NODATA>
        </SimpleSource>
      </VRTRasterBand>
    </VRTDataset>
    ''')
    
    # For every band, render one VRT document per granule (l_vrt: band -> list of XML strings).
    l_vrt={}
    for band in bands:
        l_tmp = []
        for item in l_meta.values():
            b = _hls_band_code(item['IDENTIFIER_PRODUCT_DOI'], band)
            vrt = vrt_template.render(rasterXSize = int(item['NCOLS']),
                                      rasterYSize = int(item['NROWS']),
                                      SourceFilename = item[b+'_url'],
                                      # NOTE(review): the CRS is fixed to EPSG:32613 and the
                                      # granule's HORIZONTAL_CS_CODE is not consulted - confirm
                                      # all granules are in that zone.
                                      SRS=CRS.from_epsg('32613').wkt,
                                      GeoTransform=item['ULX']+', '+item['SPATIAL_RESOLUTION']+', 0, '+item['ULY']+', 0, -'+item['SPATIAL_RESOLUTION'],
                                      band=band,
                                      obs_date=item['SENSING_TIME'],
                                      dtype='int16',
                                      nodata=item['FILLVALUE'],
                                      scale=item['REF_SCALE_FACTOR'])
            l_tmp.append(vrt)
        l_vrt[band] = l_tmp
    
    #Use GDAL to merge vrts from the same bands
    vrt_files={}
    for b in bands:
        with NamedTemporaryFile() as tmpfile:
            vrt_options = gdal.BuildVRTOptions(separate=True, bandList=[1])
            my_vrt = gdal.BuildVRT(tmpfile.name, l_vrt[b], options=vrt_options)
            # Dropping the reference makes GDAL write the VRT to disk before it is read back.
            my_vrt = None
            vrt_files[b] = tmpfile.read().decode("utf-8")
    
    # Extract the observation dates and lazy-load and merge into a single xarray object.
    b_l = []
    for b in bands:
        v = vrt_files[b]
        xmldoc = fromstring(v)
        dates = []
        for child in xmldoc.findall('''.//VRTRasterBand/ComplexSource/SourceFilename'''):
            # The date is sliced out of the source text at a fixed offset after
            # the 'obs_date' marker (presumably the per-granule VRT XML embedded
            # by BuildVRT - TODO confirm).
            marker = child.text.find('obs_date')
            dates.append(to_datetime(child.text[marker+10:marker+36]).date())
        ds_tmp = xr.open_rasterio(v,chunks={'band':1,
                                            'x':'auto',
                                            'y':'auto'}).to_dataset(name=b)
        ds_tmp = ds_tmp.rename({'band':'time'})
        ds_tmp['time'] = to_datetime(dates)
        b_l.append(ds_tmp)
    
    ds = xr.merge(b_l)
    return(ds,vrt_files)


def _hls_band_code(product_doi, band):
    """Translate a common band name into the band code of the product named in ``product_doi``."""
    if re.search('S30', product_doi):
        return lut['HLSS30'][band]
    if re.search('L30', product_doi):
        return lut['HLSL30'][band]
    # Previously this case left the band code unbound or silently reused the
    # one from the preceding granule.
    raise ValueError('Unrecognised HLS product DOI: {!r}'.format(product_doi))


def NASA_CMR_STAC(hls_data):
    """Search the LPCLOUD STAC catalog for HLS items over the fixed region of interest.

    Walks from the CMR-STAC root to the ``LPCLOUD`` catalog, reads its
    ``search`` link and issues one search request limited to 100 items, a
    hard-coded bounding box and the requested date range.

    Parameters
    ----------
    hls_data : dict
        Must contain ``'date_range'``: a two-element sequence of date strings
        ``[start, end]`` which are joined as ``start/end``.

    Returns
    -------
    list of dict
        Items whose ``collection`` is ``HLSS30.v1.5`` followed by those whose
        ``collection`` is ``HLSL30.v1.5``. Only the first page (at most 100
        features) of the search result is considered.

    Raises
    ------
    requests.HTTPError
        If any of the three requests returns an error status.
    IndexError
        If the ``LPCLOUD`` catalog or its ``search`` link cannot be found.
    """
    def _get_json(url):
        # Fail on HTTP errors here instead of on a confusing KeyError later.
        response = r.get(url)
        response.raise_for_status()
        return response.json()

    stac = 'https://cmr.earthdata.nasa.gov/stac/' # CMR-STAC API Endpoint
    stac_response = _get_json(stac)

    # LPCLOUD is the provider catalog holding the HLS collections; links
    # without a title are simply ignored.
    lp_cloud_href = [s for s in stac_response['links'] if s.get('title') == 'LPCLOUD'][0]['href']
    lp_links = _get_json(lp_cloud_href)['links']
    lp_search = [l['href'] for l in lp_links if l['rel'] == 'search'][0]  # Search endpoint

    lim = 100  # items requested per call; no further pages are fetched
    bbox_num=[-104.79107047,   40.78311181, -104.67687336,   40.87008987]  # ROI bounds
    bbox = f'{bbox_num[0]},{bbox_num[1]},{bbox_num[2]},{bbox_num[3]}'
    date_time = hls_data['date_range'][0]+'/'+hls_data['date_range'][1]  # start / end
    search_query = f"{lp_search}?&limit={lim}&bbox={bbox}&datetime={date_time}"

    features = _get_json(search_query)['features']
    s30_items = [f for f in features if f['collection'] == 'HLSS30.v1.5']
    l30_items = [f for f in features if f['collection'] == 'HLSL30.v1.5']

    # Combine the S30 and L30 items:
    return(s30_items + l30_items)

def setup_netrc(creds,aws):
    """Ensure ``~/.netrc`` has a NASA Earthdata Login entry, then optionally fetch S3 credentials.

    If ``~/.netrc`` is missing, or exists without an entry for
    ``urs.earthdata.nasa.gov``, a ``machine`` / ``login`` / ``password`` entry
    built from ``creds`` is appended. The file is written directly and
    synchronously, so it is complete before this function makes any request.

    Parameters
    ----------
    creds : sequence
        ``creds[0]`` is the Earthdata username, ``creds[1]`` the password.
        They are only read when an entry has to be written.
    aws : bool
        If true, request temporary S3 credentials from LP DAAC.

    Returns
    -------
    dict or str
        The decoded JSON of ``https://lpdaac.earthdata.nasa.gov/s3credentials``
        when ``aws`` is true, otherwise the empty string.
    """
    urs = 'urs.earthdata.nasa.gov'
    netrc_path = os.path.expanduser("~/.netrc")
    try:
        # Raises FileNotFoundError when the file is absent and TypeError when
        # it has no entry for `urs` (authenticators() returns None).
        netrc(netrc_path).authenticators(urs)[0]
    except (FileNotFoundError, TypeError):
        _append_netrc_entry(netrc_path, urs, creds[0], creds[1])

    if aws:
        return(r.get('https://lpdaac.earthdata.nasa.gov/s3credentials').json())
    else:
        return('')


def _append_netrc_entry(path, machine, login, password):
    """Append one machine/login/password entry to ``path``, creating it with mode 0600 if absent."""
    # Start on a fresh line when an existing file does not end with a newline.
    needs_newline = False
    try:
        with open(path, 'rb') as existing:
            existing.seek(0, os.SEEK_END)
            if existing.tell() > 0:
                existing.seek(-1, os.SEEK_END)
                needs_newline = existing.read(1) != b'\n'
    except FileNotFoundError:
        pass

    # The mode only applies when the file is created; an existing file keeps its permissions.
    fd = os.open(path, os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
    with os.fdopen(fd, 'a') as handle:
        if needs_newline:
            handle.write('\n')
        # Values are written verbatim, with no shell involved.
        # NOTE(review): the netrc format is whitespace-delimited, so a login or
        # password containing whitespace will not parse back correctly.
        handle.write('machine {}\nlogin {}\npassword {}\n'.format(machine, login, password))

def get_hls(creds,hls_data={},aws=True):
    """Authenticate, configure GDAL through environment variables and load the HLS bands.

    Parameters
    ----------
    creds : sequence
        Earthdata ``[username, password]`` forwarded to ``setup_netrc``.
    hls_data : dict
        Search options forwarded to ``NASA_CMR_STAC``.
    aws : bool
        Selects the S3 (``True``) or HTTPS (``False``) GDAL configuration and
        is forwarded to ``setup_netrc`` and ``build_xr``.

    Returns
    -------
    tuple
        ``(da, env)``: the Dataset returned by ``build_xr`` and the dict of
        environment variables that was applied to ``os.environ``.
    """
    # Credentials first: the S3 branch needs the temporary keys returned here.
    s3_cred = setup_netrc(creds,aws=aws)

    # Settings shared by both access modes.
    cookie_path = os.path.expanduser('~/cookies.txt')
    env = {
        'GDAL_MAX_RAW_BLOCK_CACHE_SIZE': '200000000',
        'GDAL_SWATH_SIZE': '200000000',
        'VSI_CURL_CACHE_SIZE': '200000000',
        'GDAL_HTTP_COOKIEFILE': cookie_path,
        'GDAL_HTTP_COOKIEJAR': cookie_path,
    }

    if aws:
        # Direct S3 access with the temporary credentials.
        env.update({
            'GDAL_DISABLE_READDIR_ON_OPEN': 'YES',
            'CPL_VSIL_CURL_ALLOWED_EXTENSIONS': 'TIF',
            'GDAL_HTTP_UNSAFESSL': 'YES',
            'AWS_REGION': 'us-west-2',
            'AWS_SECRET_ACCESS_KEY': s3_cred['secretAccessKey'],
            'AWS_ACCESS_KEY_ID': s3_cred['accessKeyId'],
            'AWS_SESSION_TOKEN': s3_cred['sessionToken'],
        })
    else:
        # HTTPS access through /vsicurl/.
        env.update({
            'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR',
            'AWS_NO_SIGN_REQUEST': 'YES',
        })

    os.environ.update(env)

    band_names = ['BLUE', 'GREEN', 'RED', 'NIR1', 'SWIR1', 'SWIR2', 'FMASK']
    catalog = NASA_CMR_STAC(hls_data)
    # The second element (the VRT texts) is not needed by callers of get_hls.
    da, _ = build_xr(catalog, band_names, aws)
    return(da,env)

In [None]:
class HLS_BM_Explorer(pm.Parameterized):
    """Panel application state: loads HLS data on demand and renders derived maps.

    Pressing the ``action`` button triggers ``access_data``, which downloads
    the imagery for the selected date range, masks it and stores the derived
    cover, biomass and standard-error arrays on the instance. The ``load_*``
    methods turn those arrays into plots for the date chosen in ``date_picker``.

    Until data has been loaded, ``data``, ``da``, ``da_cov``, ``da_bm``,
    ``da_se`` and ``bm_mod`` hold the empty string as a placeholder.
    """
    action = pm.Action(lambda x: x.param.trigger('action'), label='Load Data and Run Analysis')
    # NOTE(review): the username field is a PasswordInput, so the typed
    # username is masked - confirm this is intended.
    username_input = pn.widgets.PasswordInput(name='NASA Earthdata Login', placeholder='Enter Username...')
    password_input = pn.widgets.PasswordInput(name='', placeholder='Enter Password...')
    # Placeholder dates; replaced by the observation dates once data is loaded.
    date_picker = pn.widgets.DatePicker(name='Calendar',
                                        value=datetime(2000,1,1).date(),
                                        enabled_dates = [datetime(2000,1,1).date(),datetime(2000,1,2).date()])
    d_range = pn.widgets.DateRangeSlider(name='',end=d_bounds[-1],start=d_bounds[0])
    thresh_picker = pn.widgets.IntSlider(name='Threshold', start=200, end=2000, step=25, value=500,
                                     format=PrintfTickFormatter(format='%d kg/ha'))

    def __init__(self, **params):
        super(HLS_BM_Explorer, self).__init__(**params)
        # '' marks "not loaded yet"; see _is_loaded.
        self.data = ''
        self.da = ''
        self.da_cov = ''
        self.da_bm = ''
        self.da_se = ''
        self.bm_mod = ''

    @staticmethod
    def _is_loaded(obj):
        """Return True once ``obj`` is no longer the string placeholder set in ``__init__``.

        A type check is used because comparing a loaded array with ``''`` would
        be evaluated element-wise, and ``is not ''`` relies on string interning.
        """
        return not isinstance(obj, str)
    
    @pm.depends('action', watch=True)
    def access_data(self):
        """Download HLS data for the selected range and compute all derived layers.

        Runs when ``action`` is triggered. Does nothing until a username has
        been entered. On success the instance attributes ``data``, ``da``,
        ``da_bm``, ``da_se``, ``bm_mod`` and ``da_cov`` are replaced and the
        date picker is limited to the observation dates.

        Returns
        -------
        str
            ``'Success!'``, ``'App Failure'`` (the error is logged with its
            traceback) or ``'Not Yet Launched'``.
        """
        import logging

        if self.username_input.value == '':
            return('Not Yet Launched')

        try:
            d_from = str(self.d_range.value[0].date())
            d_to = str(self.d_range.value[1].date())
            tmp_data, env = get_hls([self.username_input.value,self.password_input.value],
                                   hls_data={'date_range':[d_from,d_to]},aws=True)
            os.environ.update(env)
            # NOTE(review): the cluster is closed again at the end of this
            # `with` block, before any of the lazy arrays below are computed -
            # confirm whether it is meant to stay alive for the analysis.
            with LocalCluster(threads_per_worker=1) as cluster, Client(cluster) as cl:
                bbox_num=[-104.79107047,   40.78311181, -104.67687336,   40.87008987]
                utmProj = Proj("+proj=utm +zone=13U, +north +ellps=WGS84 +datum=WGS84 +units=m +no_defs")
                # Project (west, east) and (north, south) so both slices run
                # in the order of the raster coordinates.
                bbox_utm = utmProj([bbox_num[i] for i in [0, 2]], [bbox_num[i] for i in [3, 1]])
                self.data = tmp_data.loc[dict(x=slice(*tuple(bbox_utm[0])), y=slice(*tuple(bbox_utm[1])))]

            # Calendar dates of all observations, in the order of the time axis.
            obs_dates = [stamp.date() for stamp in to_datetime(self.data.time.values)]
            self.date_picker.enabled_dates = obs_dates
            self.date_picker.value = obs_dates[-1]

            da = self.data
            da_mask = mask_hls(da['FMASK'])
            da = da.where(da_mask == 0)
            # Average observations that share the same time stamp.
            da = da.groupby('time').mean()
            self.da = da

            # NOTE: pickle executes code on load; only use a trusted model file.
            with open('src/models/CPER_HLS_to_VOR_biomass_model_lr_simp.pk', 'rb') as model_file:
                bm_mod = pickle.load(model_file)
            da_bm = pred_bm(da, bm_mod)
            da_bm = da_bm.where(da_bm > 0)
            da_se = pred_bm_se(da, bm_mod)
            da_se = da_se.where(da_bm > 0)
            self.da_bm = da_bm
            self.da_se = da_se
            self.bm_mod = bm_mod

            # Endmember values per cover type passed to pred_cov.
            ends_dict = {
                'SD': {
                    'ndvi': 0.30,
                    'dfi': 16,
                    'bai_126': 155},
                'GREEN': {
                    'ndvi': 0.55,
                    'dfi': 10,
                    'bai_126': 160},
                'BARE': {
                    'ndvi': 0.10,
                    'dfi': 8,
                    'bai_126': 140}}
            da_cov = pred_cov(da, ends_dict)
            da_cov = da_cov.to_array(dim='type')
            # Cap values at 1.0 while leaving missing values untouched.
            da_cov = da_cov.where((da_cov < 1.0) | (da_cov.isnull()), 1.0)
            da_cov = da_cov.where(~(da_cov.any(dim='time').isnull()))
            self.da_cov = da_cov
            return('Success!')
        except Exception:
            # The return value of a watched callback is not displayed, so log
            # the failure instead of dropping it silently.
            logging.getLogger(__name__).exception('HLS data load / analysis failed')
            return('App Failure')
        
    
    @pm.depends('date_picker.param')
    def load_cov(self):
        """Return an RGB map of ``da_cov`` for the picked date, or ``''`` before data is loaded."""
        if not self._is_loaded(self.da_cov):
            return('')
        picked = self.date_picker.value
        cov_map = self.da_cov.sel(time=np.datetime64(picked)).hvplot.rgb(
            x='x',y='y',
            bands='type',
            tiles='EsriImagery',
            crs=ccrs.UTM(13),
            sizing_mode='stretch_both',
            width=350,
            title='CV: '+str(picked))
        return cov_map
    
    @pm.depends('date_picker.param')
    def load_bm(self):
        """Return a map of ``da_bm`` for the picked date, or ``''`` before data is loaded."""
        if not self._is_loaded(self.da_bm):
            return('')
        picked = self.date_picker.value
        bm_map = self.da_bm.sel(time=np.datetime64(picked)).hvplot(
            x='x',y='y',
            tiles='EsriImagery',
            crs=ccrs.UTM(13),
            cmap='inferno',
            clim=(100, 1000),
            colorbar=False,
            sizing_mode='stretch_both',
            width=350,
            title='BM: '+str(picked))
        return bm_map
    
    @pm.depends('date_picker.param', 'thresh_picker.param')
    def load_thresh(self):
        """Return a map of ``pred_bm_thresh`` output for the picked date and threshold.

        Returns ``''`` before data is loaded. The meaning of the plotted
        values is defined by ``pred_bm_thresh`` (presumably a probability
        relative to the chosen biomass threshold - TODO confirm).
        """
        if not self._is_loaded(self.da_bm):
            return('')
        da_thresh = pred_bm_thresh(self.da_bm, self.da_se, self.thresh_picker.value)
        thresh_map = da_thresh.sel(time=np.datetime64(self.date_picker.value)).hvplot(
            x='x', y='y',
            tiles='EsriImagery',
            crs=ccrs.UTM(13),
            cmap='YlOrRd',
            clim=(0.05, 0.95),
            colorbar=False,
            data_aspect=0.6).opts(responsive=True,
                                  xticks=None,
                                  yticks=None)
        return thresh_map
        
    @pm.depends('access_data')
    def showdata(self):
        """Return an HTML pane showing the currently loaded dataset (empty before loading)."""
        return(pn.pane.HTML(self.data,sizing_mode='stretch_both',max_width=250))
# Application state object; its widgets and plot methods are laid out below.
app = HLS_BM_Explorer(name='Central Plains Experimental Range: HLS Biomass')

# Page shell.
template_theme = pn.template.MaterialTemplate(
    title='Central Plains Experimental Range: HLS Biomass',
    logo='https://ltar.ars.usda.gov/wp-content/uploads/2018/10/usda_ltar_logo_header_v3.png',
    header_color='grey',
    header_background='#80b1ed',
)

# Sidebar, top: credentials, date range and the button that starts the download.
download_controls = pn.Column(
    app.username_input,
    app.password_input,
    app.d_range,
    pn.panel(app.param.action),
    sizing_mode='stretch_both',
    name="Download Options",
    css_classes=['panel-widget-box'],
)
template_theme.sidebar.append(download_controls)

# Sidebar, bottom: controls that drive the maps.
display_controls = pn.Column(
    app.date_picker,
    app.thresh_picker,
    sizing_mode='stretch_both',
)
template_theme.sidebar.append(display_controls)

# Main area: one tab per map.
map_tabs = pn.Tabs(
    ('Cover', app.load_cov),
    ('Biomass', app.load_bm),
    ('Biomass threshold', app.load_thresh),
)
maps_card = pn.Card(
    map_tabs,
    title='Maps',
    sizing_mode='stretch_both',
    header_background='#80b1ed',
)
template_theme.main.append(maps_card)

# Main area: collapsed card with the representation of the loaded dataset.
dataset_card = pn.Card(
    app.showdata,
    title='Xarray Object',
    sizing_mode='stretch_both',
    header_background='#80b1ed',
    collapsed=True,
)
template_theme.main.append(dataset_card)

#template_theme.show(port=9000)
template_theme.servable()