In [None]:
from dask.distributed import LocalCluster, Client
import xarray as xr
import dask
import intake
import os
#import fsspec, os, netrc, aiohttp,dask
from satsearch import Search
import hvplot.pandas
import hvplot.xarray
import warnings
warnings.filterwarnings('ignore')
import gdal
import requests
import concurrent.futures
from urllib.request import urlopen
from xml.etree.ElementTree import parse,fromstring
from affine import Affine
from pandas import to_datetime
import time
import jinja2 as jj2
from rasterio.crs import CRS
from tempfile import NamedTemporaryFile

In [None]:
# AUTHENTICATION CONFIGURATION
from netrc import netrc
from subprocess import Popen
from getpass import getpass


def _ensure_earthdata_netrc(urs='urs.earthdata.nasa.gov'):
    """Make sure ~/.netrc holds NASA Earthdata Login credentials for `urs`.

    If the file is missing, or exists without an entry for `urs`, the user is
    prompted for a username and password and a ``machine/login/password``
    entry is appended. An existing entry is left untouched.

    The entry is written with plain file I/O rather than shell ``echo``
    commands: nothing is passed through a shell (so special characters in the
    password are stored verbatim and cannot be interpreted as shell syntax),
    and the three lines are written in one operation, in order.
    """
    netrc_path = os.path.expanduser('~/.netrc')
    try:
        if netrc(netrc_path).authenticators(urs) is not None:
            return  # credentials already configured
    except FileNotFoundError:
        pass  # no netrc yet: it is created below

    username = getpass(prompt='Enter NASA Earthdata Login Username \n'
                              '(or create an account at urs.earthdata.nasa.gov): ')
    password = getpass(prompt='Enter NASA Earthdata Login Password: ')

    # The 0o600 mode only applies when the file is created, so a new netrc is
    # private to the user while the permissions of an existing file are kept.
    fd = os.open(netrc_path, os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
    with os.fdopen(fd, 'a') as netrc_file:
        netrc_file.write('machine {}\nlogin {}\npassword {}\n'.format(urs, username, password))


_ensure_earthdata_netrc()

In [None]:
# Request temporary S3 credentials from LP DAAC.
s3_cred_response = requests.get('https://lpdaac.earthdata.nasa.gov/s3credentials', timeout=60)
# Fail here with a clear HTTP error instead of later with a KeyError on the JSON body.
s3_cred_response.raise_for_status()
s3_cred = s3_cred_response.json()
# Show only the field names: echoing the dict itself would store the secret key
# and session token in the saved notebook output.
sorted(s3_cred)

In [None]:
# GDAL configuration, exported through the process environment.
# NOTE(review): AWS_NO_SIGN_REQUEST='YES' is set together with explicit AWS
# credentials -- confirm which of the two is actually intended.
# NOTE(review): GDAL_HTTP_UNSAFESSL='YES' presumably turns off TLS certificate
# checks -- confirm this is acceptable outside a demo.
env = {
    'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR',
    'AWS_NO_SIGN_REQUEST': 'YES',
    'GDAL_MAX_RAW_BLOCK_CACHE_SIZE': '200000000',
    'GDAL_SWATH_SIZE': '200000000',
    'VSI_CURL_CACHE_SIZE': '200000000',
    'CPL_VSIL_CURL_ALLOWED_EXTENSIONS': 'TIF',
    'GDAL_HTTP_UNSAFESSL': 'YES',
    # The same cookie file is used for reading and for writing.
    'GDAL_HTTP_COOKIEFILE': os.path.expanduser('~/cookies.txt'),
    'GDAL_HTTP_COOKIEJAR': os.path.expanduser('~/cookies.txt'),
    'AWS_REGION': 'us-west-2',
    # Temporary credentials fetched into s3_cred by an earlier cell.
    'AWS_SECRET_ACCESS_KEY': s3_cred['secretAccessKey'],
    'AWS_ACCESS_KEY_ID': s3_cred['accessKeyId'],
    'AWS_SESSION_TOKEN': s3_cred['sessionToken'],
}

os.environ.update(env)

In [None]:
# Set the dashboard link to the localhost:8888 proxy URL for port 8787.
dask.config.set({
    'distributed.dashboard.link': 'http://localhost:8888/proxy/8787/status',
})
# Local cluster with two threads per worker; the client is displayed as the cell output.
cluster = LocalCluster(threads_per_worker=2)
client = Client(cluster)
client

In [None]:
# STAC search parameters
url = 'https://cmr.earthdata.nasa.gov/stac/LPCLOUD'
collection = 'C1711924822-LPCLOUD'  # HLS
bbox = [-104.79107047, 40.78311181, -104.67687336, 40.87008987]
dates = '2020-01-01/2021-01-10'

# The bounding box is passed as one comma-separated string.
# NOTE(review): get_STAC_items is neither defined nor imported in this
# notebook -- it must be provided before this cell can run on a fresh kernel.
cat = get_STAC_items(url, collection, dates, ','.join(str(coord) for coord in bbox))

In [None]:
# CMR AdditionalAttribute names that are copied into the per-granule metadata.
_HLS_META_ATTRIBUTES = frozenset([
    'ULX',
    'ULY',
    'NROWS',
    'NCOLS',
    'SPATIAL_RESOLUTION',
    'HORIZONTAL_CS_CODE',
    'SENSING_TIME',
    'MGRS_TILE_ID',
    'CLOUD_COVERAGE',
    'FILLVALUE',
    'QA_FILLVALUE',
    'MEAN_SUN_AZIMUTH_ANGLE',
    'MEAN_SUN_ZENITH_ANGLE',
    'MEAN_VIEW_AZIMUTH_ANGLE',
    'MEAN_VIEW_ZENITH_ANGLE',
    'NBAR_SOLAR_ZENITH',
    'IDENTIFIER_PRODUCT_DOI',
    'IDENTIFIER_PRODUCT_DOI_AUTHORITY',
    'REF_SCALE_FACTOR',
    'ADD_OFFSET',
])


def open_hls_meta(stac_id):
    """Build the metadata record for one STAC item from its CMR concept document.

    Downloads ``https://cmr.earthdata.nasa.gov/search/concepts/<stac_id>``,
    parses it as XML and collects the ``AdditionalAttribute`` entries whose
    ``Name`` is listed in ``_HLS_META_ATTRIBUTES``.

    Parameters
    ----------
    stac_id : str
        Identifier appended to the CMR concepts URL.

    Returns
    -------
    dict
        ``{stac_id: {attribute_name: value_text}}``. Values are the raw XML
        text (strings); entries without a ``Name`` or ``Value`` element are
        skipped.
    """
    # Parse inside the context manager so the HTTP response is always closed.
    with urlopen('https://cmr.earthdata.nasa.gov/search/concepts/' + stac_id) as response:
        xmldoc = parse(response)

    meta = {}
    for attribute in xmldoc.findall('.//AdditionalAttributes/AdditionalAttribute'):
        name = attribute.findtext('Name')
        if name not in _HLS_META_ATTRIBUTES:
            continue
        value = attribute.find('.//Value')
        if value is not None:
            meta[name] = value.text
    return {stac_id: meta}

#Convert STAC catalog to VRT file(s) (1 vrt file per band)
#Outputs as single xarray dataset object (dims = x,y,t ; variables = bands)
def build_xr(catalog, bands):
    """Build a lazily loaded xarray Dataset from a STAC catalog via GDAL VRTs.

    For every item in `catalog` the granule metadata is fetched with
    ``open_hls_meta``; one single-raster VRT is rendered per (item, band),
    the VRTs of each band are merged with ``gdal.BuildVRT(separate=True)``
    and each merged VRT is opened with ``xr.open_rasterio``.

    Parameters
    ----------
    catalog : mapping-like
        Iterating it yields the item ids; ``catalog[item_id][band].urlpath``
        must give the URL of that band's raster.
    bands : str or iterable of str
        Band name(s) to load. A single string is treated as one band.

    Returns
    -------
    tuple
        ``(ds, vrt_files)`` where ``ds`` is the merged Dataset (one variable
        per band, dimension ``band`` renamed to ``t`` and labelled with the
        observation dates) and ``vrt_files`` maps band name to the merged
        VRT XML text.

    Raises
    ------
    RuntimeError
        If ``gdal.BuildVRT`` returns ``None`` for a band.

    Notes
    -----
    Elapsed time of each stage is printed. Items are processed in catalog
    order, so the order of the ``t`` axis is reproducible between runs.
    """
    # Normalise `bands` before it is first iterated; otherwise a single band
    # passed as a string would be looped over character by character.
    if isinstance(bands, str):
        bands = [bands]
    else:
        bands = list(bands)
    stac_ids = list(catalog)

    #Retrieve metadata using threads - not cpu bound, so works well
    t1 = time.time()
    l_meta = {}
    with concurrent.futures.ThreadPoolExecutor(5) as executor:
        # executor.map yields results in submission order (unlike as_completed)
        for meta in executor.map(open_hls_meta, stac_ids):
            l_meta.update(meta)
    for stac_id in stac_ids:
        for b in bands:
            l_meta[stac_id][b + '_url'] = catalog[stac_id][b].urlpath

    t2 = time.time()
    print('Get File Meta', t2 - t1)

    #Setup template file for each raster. Use Jinja to quickly fill in metadata.
    vrt_template = jj2.Template('''
    <VRTDataset rasterXSize="{{rasterXSize}}" rasterYSize="{{rasterYSize}}">
      <SRS>{{SRS}}</SRS>
      <GeoTransform>{{GeoTransform}}</GeoTransform>
      <VRTRasterBand dataType="{{dtype}}" band="1">
        <NoDataValue>{{nodata}}</NoDataValue>
        <Scale>{{scale}}</Scale>
        <Metadata>
          <MDI key="obs_date">{{obs_date}}</MDI>
        </Metadata>
        <SimpleSource>
          <SourceFilename relativeToVRT="1">/vsicurl/{{SourceFilename}}</SourceFilename>
          <SourceBand>1</SourceBand>
          <SourceProperties RasterXSize="{{rasterXSize}}" RasterYSize="{{rasterYSize}}" DataType="{{dtype}}" BlockXSize="1024" BlockYSize="1024" />
          <SrcRect xOff="0" yOff="0" xSize="{{rasterXSize}}" ySize="{{rasterYSize}}" />
          <DstRect xOff="0" yOff="0" xSize="{{rasterXSize}}" ySize="{{rasterYSize}}" />
          <NODATA>{{nodata}}</NODATA>
        </SimpleSource>
      </VRTRasterBand>
    </VRTDataset>
    ''')

    #Render one VRT document per (band, item) into a dictionary keyed by band (l_vrt)
    # NOTE(review): every band, Fmask included, is declared int16 with FILLVALUE
    # as nodata; QA_FILLVALUE is fetched but never used -- confirm this is intended.
    l_vrt = {}
    for b in bands:
        l_tmp = []
        for item in l_meta.values():
            vrt = vrt_template.render(rasterXSize=int(item['NCOLS']),
                                      rasterYSize=int(item['NROWS']),
                                      SourceFilename=item[b + '_url'],
                                      SRS=CRS.from_epsg(item['HORIZONTAL_CS_CODE'].split(':')[1]).wkt,
                                      GeoTransform=item['ULX'] + ', ' + item['SPATIAL_RESOLUTION'] + ', 0, ' + item['ULY'] + ', 0, -' + item['SPATIAL_RESOLUTION'],
                                      band=b,
                                      obs_date=item['SENSING_TIME'],
                                      dtype='int16',
                                      nodata=item['FILLVALUE'],
                                      scale=item['REF_SCALE_FACTOR'])
            l_tmp.append(vrt)
        l_vrt[b] = l_tmp

    t3 = time.time()
    print('Build VRT', t3 - t2)

    #Use GDAL to merge vrts from the same bands
    vrt_files = {}
    for b in bands:
        with NamedTemporaryFile() as tmpfile:
            vrt_options = gdal.BuildVRTOptions(separate=True, bandList=[1])
            my_vrt = gdal.BuildVRT(tmpfile.name, l_vrt[b], options=vrt_options)
            if my_vrt is None:
                raise RuntimeError('gdal.BuildVRT failed for band ' + str(b))
            # Drop the reference so the VRT is written out before it is read back
            my_vrt = None
            vrt_files[b] = tmpfile.read().decode("utf-8")

    t4 = time.time()
    print('Merge VRT', t4 - t3)

    #Extract metadata and lazy-load and merge into single xarray object
    b_l = []
    for b in bands:
        v = vrt_files[b]
        xmldoc = fromstring(v)
        dates = []
        for child in xmldoc.findall('.//VRTRasterBand/ComplexSource/SourceFilename'):
            # The source "filename" is the per-item VRT text rendered above; the
            # timestamp is the 26 characters that follow 'obs_date">' in it.
            start = child.text.find('obs_date') + 10
            dates.append(to_datetime(child.text[start:start + 26]).date())
        ds_tmp = xr.open_rasterio(v, chunks={'band': 1,
                                             'x': 'auto',
                                             'y': 'auto'}).to_dataset(name=b)
        ds_tmp = ds_tmp.rename({'band': 't'})
        ds_tmp['t'] = dates
        b_l.append(ds_tmp)

    t5 = time.time()
    print('Create/Merge Xarray', t5 - t4)

    ds = xr.merge(b_l)
    return ds, vrt_files

In [None]:
# Build the data cube for the five listed bands and report the wall-clock time taken.
t1 = time.time()
da, vrts = build_xr(cat, bands=['B8A', 'B03', 'B04', 'B12', 'Fmask'])
print('Completed IN:', time.time() - t1, 'seconds')
da

In [None]:
# Rename the 't' dimension to 'time' and store it as datetime64[ns].
# The rename is guarded so the cell can be re-run: once 't' has been renamed,
# an unconditional rename would fail because 't' no longer exists.
if 't' in da.dims:
    da = da.rename(t='time')
da['time'] = da['time'].astype('datetime64[ns]')
da

In [None]:
from src.hls_funcs.masks import mask_hls
# Derive a mask from the 'Fmask' variable with the project helper mask_hls.
# NOTE(review): a later cell keeps only the pixels where this mask equals 0 --
# confirm that 0 is the "keep" value produced by mask_hls.
da_mask = mask_hls(da['Fmask'])
da_mask

In [None]:
from pyproj import Proj
# Project the lon/lat bounding box into UTM zone 13 (northern hemisphere, WGS84).
# The zone is given as a plain integer: the previous string used "13U," (a
# latitude-band letter plus a stray comma) and a "+north" flag; in PROJ the
# northern hemisphere is simply the default when "+south" is absent.
utmProj = Proj("+proj=utm +zone=13 +datum=WGS84 +units=m +no_defs")
# x inputs are (west, east); y inputs are (north, south), so the resulting y
# pair runs from the larger to the smaller northing.
bbox_utm = utmProj([bbox[i] for i in [0, 2]], [bbox[i] for i in [3, 1]])
tuple(bbox_utm[1])

In [None]:
# Label-based crop to the UTM bounding box, then keep only pixels whose mask value is 0.
da_sub = (
    da
    .sel(x=slice(*bbox_utm[0]), y=slice(*bbox_utm[1]))
    .where(da_mask == 0)
)
da_sub

In [None]:
# Flatten (y, x) into a single 'z' dimension, then rechunk:
# 50 time steps per chunk, with all of 'z' in one chunk.
da_stacked = (
    da_sub
    .stack(z=('y', 'x'))
    .chunk({'time': 50, 'z': -1})
)
da_stacked

In [None]:
from src.hls_funcs.predict import pred_bm
import pickle
# Load the model inside a context manager so the file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load model files from a
# trusted source.
with open('src/models/CPER_HLS_to_VOR_biomass_model_lr_simp.pk', 'rb') as model_file:
    bm_mod = pickle.load(model_file)
# Apply the model with the project helper pred_bm along the stacked 'z' dimension.
da_bm = pred_bm(da_stacked, bm_mod, dim='z')
da_bm

In [None]:
# Restore the (y, x) grid and persist the result.
# Guarded so the cell can be re-run: after the first run 'z' no longer exists
# and an unconditional unstack('z') would fail.
# NOTE(review): assumes da_bm exposes `.dims` (i.e. pred_bm returns an xarray object).
if 'z' in da_bm.dims:
    da_bm = da_bm.unstack('z').persist()
da_bm

In [None]:
import param
import panel as pn
import datetime as dt


def load_map(date):
    """Plot the biomass raster for one time step over Esri imagery tiles.

    Parameters
    ----------
    date : int
        Integer position along the ``time`` dimension of ``da_bm`` (an index,
        not a calendar date).

    Returns
    -------
    The rasterized hvplot object for that time step.
    """
    # The CRS is given as a proj4 string because `ccrs` (cartopy.crs) is never
    # imported in this notebook, so the earlier ccrs.UTM(13) raised NameError.
    return da_bm.isel(time=date).hvplot(x='x', y='y', rasterize=True, tiles='EsriImagery',
                                        crs='+proj=utm +zone=13 +datum=WGS84 +units=m +no_defs',
                                        cmap='inferno', clim=(100, 1000))


# The slider's `end` is inclusive, so the last valid position is len - 1;
# using len would let the slider select an out-of-range time index.
date_slider = pn.widgets.IntSlider(name='Date Slider',
                                   start=0, end=len(da_bm.time) - 1, value=0)


@pn.depends(date=date_slider.param.value)
def load_map_cb(date):
    """Re-render the map whenever the slider value changes."""
    return load_map(date)


pn.Row(pn.WidgetBox('Select date', date_slider), load_map_cb)