In [10]:
import dask, concurrent.futures, time, warnings, os, re, pickle
from osgeo import gdal
import requests as r
import panel as pn
pn.extension()
import param as pm
import pandas as pd
from collections import OrderedDict as odict
import numpy as np
from dask.distributed import LocalCluster, Client
import xarray as xr
import hvplot.pandas
import hvplot.xarray
import warnings
warnings.filterwarnings('ignore')
from urllib.request import urlopen
from xml.etree.ElementTree import parse,fromstring
from affine import Affine
from pandas import to_datetime
import jinja2 as jj2
from rasterio.crs import CRS
from tempfile import NamedTemporaryFile
from datetime import datetime
from netrc import netrc
from subprocess import Popen
from pyproj import Proj
from src.hls_funcs.masks import mask_hls
from src.hls_funcs.predict import pred_cov, pred_bm, pred_bm_se, pred_bm_thresh
import cartopy.crs as ccrs
from bokeh.models.formatters import PrintfTickFormatter
import stackstac
from subprocess import Popen, DEVNULL, STDOUT
from getpass import getpass
from sys import platform

In [11]:
# Lookup table: for each HLS product, map the band identifier used in the
# asset names to a common band name.
lut = dict(
    HLSS30=dict(
        B01='COASTAL-AEROSOL',
        B02='BLUE',
        B03='GREEN',
        B04='RED',
        B05='RED-EDGE1',
        B06='RED-EDGE2',
        B07='RED-EDGE3',
        B08='NIR-Broad',
        B8A='NIR1',
        B09='WATER-VAPOR',
        B10='CIRRUS',
        B11='SWIR1',
        B12='SWIR2',
        Fmask='FMASK',
    ),
    HLSL30=dict(
        B01='COASTAL-AEROSOL',
        B02='BLUE',
        B03='GREEN',
        B04='RED',
        B05='NIR1',
        B06='SWIR1',
        B07='SWIR2',
        B09='CIRRUS',
        B10='TIR1',
        B11='TIR2',
        Fmask='FMASK',
    ),
)

# Every band name that is accepted, plus the 'ALL' wildcard.
# NOTE(review): 'NIR-Broad' appears in the lookup table above but not here - confirm intended.
all_bands = [
    'ALL',
    'COASTAL-AEROSOL',
    'BLUE',
    'GREEN',
    'RED',
    'RED-EDGE1',
    'RED-EDGE2',
    'RED-EDGE3',
    'NIR1',
    'SWIR1',
    'SWIR2',
    'CIRRUS',
    'TIR1',
    'TIR2',
    'WATER-VAPOR',
    'FMASK',
]

# Bands that are actually loaded when the data cube is built.
needed_bands = [
    'BLUE',
    'GREEN',
    'RED',
    'NIR1',
    'SWIR1',
    'SWIR2',
    'FMASK',
]

In [12]:
# Region of interest (min lon, min lat, max lon, max lat) searched when the
# caller does not pass an explicit bbox.
_DEFAULT_ROI_BBOX = (-104.79107047, 40.78311181, -104.67687336, 40.87008987)


def _cmr_get_json(url):
    """GET `url` and return its decoded JSON body, raising on an HTTP error status."""
    response = r.get(url, timeout=120)
    response.raise_for_status()
    return response.json()


def _hrefs_to_vsis3(items):
    """Rewrite, in place, every asset href of `items` from the LP DAAC HTTPS prefix to `/vsis3/`."""
    for item in items:
        for asset in item['assets'].values():
            asset['href'] = asset['href'].replace('https://lpdaac.earthdata.nasa.gov/lp-prod-protected',
                                                  '/vsis3/lp-prod-protected')


def NASA_CMR_STAC(hls_data, aws, bbox=None):
    """Search the CMR-STAC LPCLOUD catalog for HLS items in a bbox and date range.

    Parameters
    ----------
    hls_data : dict
        Must contain ``'date_range'``: a two-element sequence of start and end
        date strings, joined as ``start/end`` for the STAC ``datetime`` filter.
    aws : bool
        When true, asset hrefs are rewritten from the LP DAAC HTTPS prefix to
        the ``/vsis3/lp-prod-protected`` prefix.
    bbox : sequence of 4 floats, optional
        ``(min_x, min_y, max_x, max_y)`` used for the STAC ``bbox`` filter.
        Defaults to the previously hard-coded region of interest.

    Returns
    -------
    dict
        ``{'S30': [...], 'L30': [...]}`` holding the STAC items whose
        collection is ``HLSS30.v1.5`` and ``HLSL30.v1.5`` respectively.

    Raises
    ------
    IndexError
        If the LPCLOUD catalog or its search link is not advertised.
    requests.HTTPError
        If any request returns an error status.

    Notes
    -----
    Only the first page (up to 100 items) of the search result is read.
    """
    stac = 'https://cmr.earthdata.nasa.gov/stac/'  # CMR-STAC API endpoint
    stac_response = _cmr_get_json(stac)

    # LPCLOUD is the provider catalog that hosts the HLS collections.
    lp_cloud_hrefs = [s['href'] for s in stac_response['links'] if s.get('title') == 'LPCLOUD']
    if not lp_cloud_hrefs:
        raise IndexError('CMR-STAC root catalog does not list an LPCLOUD catalog')
    lp_cloud = _cmr_get_json(lp_cloud_hrefs[0])

    search_hrefs = [l['href'] for l in lp_cloud['links'] if l.get('rel') == 'search']
    if not search_hrefs:
        raise IndexError('LPCLOUD catalog does not advertise a search endpoint')
    lp_search = search_hrefs[0]

    # Build the query: page size, then spatial filter, then temporal filter.
    lim = 100
    bbox_num = _DEFAULT_ROI_BBOX if bbox is None else bbox
    bbox_str = f'{bbox_num[0]},{bbox_num[1]},{bbox_num[2]},{bbox_num[3]}'
    date_time = hls_data['date_range'][0] + '/' + hls_data['date_range'][1]
    search_query = f"{lp_search}?&limit={lim}&bbox={bbox_str}&datetime={date_time}"

    features = _cmr_get_json(search_query)['features']
    s30_items = [f for f in features if f.get('collection') == 'HLSS30.v1.5']
    l30_items = [f for f in features if f.get('collection') == 'HLSL30.v1.5']

    if aws:
        _hrefs_to_vsis3(s30_items)
        _hrefs_to_vsis3(l30_items)
    return {'S30': s30_items,
            'L30': l30_items}

def _append_netrc_entry(path, machine, login, password):
    """Append a machine/login/password entry to the netrc file at `path`, creating it with mode 0600 if absent."""
    try:
        with open(path, 'rb') as fh:
            existing = fh.read()
    except FileNotFoundError:
        existing = b''
    # Start on a fresh line if the existing file does not end with a newline.
    prefix = '' if (not existing or existing.endswith(b'\n')) else '\n'
    # The mode argument only applies when the file is created by this call.
    fd = os.open(path, os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
    with os.fdopen(fd, 'a') as fh:
        fh.write(f'{prefix}machine {machine}\nlogin {login}\npassword {password}\n')


def setup_netrc(creds, aws):
    """Ensure ``~/.netrc`` has a NASA Earthdata Login entry; optionally fetch S3 credentials.

    The entry is written directly from Python. The file is therefore complete
    before this function returns, and the credentials never pass through a
    shell command line.

    Parameters
    ----------
    creds : sequence
        ``creds[0]`` is the login and ``creds[1]`` the password. Only read
        when the netrc file is missing or has no entry for
        ``urs.earthdata.nasa.gov``.
    aws : bool
        When true, request temporary S3 credentials from the LP DAAC endpoint.

    Returns
    -------
    dict or str
        The decoded JSON from the s3credentials endpoint when `aws` is true,
        otherwise the empty string.
    """
    urs = 'urs.earthdata.nasa.gov'
    netrc_path = os.path.expanduser("~/.netrc")
    try:
        has_entry = netrc(netrc_path).authenticators(urs) is not None
    except FileNotFoundError:
        has_entry = False

    if not has_entry:
        _append_netrc_entry(netrc_path, urs, creds[0], creds[1])

    if aws:
        response = r.get('https://lpdaac.earthdata.nasa.gov/s3credentials')
        response.raise_for_status()
        return response.json()
    return ''

def _stack_hls_product(items, product, drop_coords):
    """Stack the STAC `items` of one HLS `product` into a Dataset with one variable per band, or return None."""
    if not items:
        return None
    try:
        stack = stackstac.stack(items, epsg=32613, resolution=30,
                                assets=[b for b in lut[product] if lut[product][b] in needed_bands],
                                chunksize=(4000, 4000))
    except ValueError as err:
        # Best effort, as before: a product that cannot be stacked is skipped,
        # but the reason is reported instead of being dropped silently.
        print(f'Skipping {product}: {err}')
        return None
    # Replace band ids (e.g. 'B04') with the common names from the lookup table.
    stack['band'] = [lut[product][b] for b in stack['band'].values]
    # Truncate nanosecond timestamps to whole seconds. int64 is spelled out
    # because a platform-dependent 'int' may be too narrow for nanoseconds.
    # NOTE(review): datetime.fromtimestamp yields naive *local* time - confirm that is intended.
    stack['time'] = [datetime.fromtimestamp(t) for t in stack.time.astype('int64').values // 1000000000]
    dataset = stack.to_dataset(dim='band')
    # Only drop coordinates that exist, so a missing one cannot discard the product.
    return dataset.reset_coords([c for c in drop_coords if c in dataset.coords], drop=True)


def build_xr(stac_dict):
    """Build one lazily-evaluated Dataset from the S30 and L30 STAC items.

    Parameters
    ----------
    stac_dict : dict
        ``{'S30': [...], 'L30': [...]}`` as returned by ``NASA_CMR_STAC``.

    Returns
    -------
    xarray.Dataset
        One variable per band in ``needed_bands``, on an EPSG:32613 grid at
        30 m resolution. The products that were found are concatenated along
        ``time`` and rechunked to one full x/y chunk per time step.

    Raises
    ------
    ValueError
        If neither product yields any data for the requested date range.
    """
    s30_stack = _stack_hls_product(stac_dict['S30'], 'HLSS30',
                                   ['end_datetime', 'start_datetime'])
    l30_stack = _stack_hls_product(stac_dict['L30'], 'HLSL30',
                                   ['name', 'end_datetime', 'start_datetime'])

    found = [s for s in (s30_stack, l30_stack) if s is not None]
    if not found:
        # Previously this printed the message and then failed with an
        # UnboundLocalError on the return statement.
        raise ValueError('No data found for date range')
    hls_stack = found[0] if len(found) == 1 else xr.concat(found, dim='time')
    return hls_stack.chunk({'time': 1, 'y': -1, 'x': -1})
    
def get_hls(creds, hls_data=None, aws=False):
    """Configure credentials and GDAL, search CMR-STAC and return the HLS data cube.

    Parameters
    ----------
    creds : sequence
        Earthdata Login ``(username, password)``, forwarded to ``setup_netrc``.
    hls_data : dict, optional
        Search options forwarded to ``NASA_CMR_STAC``, which reads its
        ``'date_range'`` key. Defaults to an empty dict.
    aws : bool, default False
        When true, configure GDAL with temporary S3 credentials and read the
        assets through ``/vsis3/``; otherwise request unsigned access.

    Returns
    -------
    xarray.Dataset
        The result of ``build_xr`` for the matching STAC items.

    Notes
    -----
    The GDAL/AWS settings are written to ``os.environ`` and therefore persist
    for the rest of the process.
    """
    if hls_data is None:  # avoid a shared mutable default argument
        hls_data = {}

    # Set up credentials
    s3_cred = setup_netrc(creds, aws=aws)

    # GDAL settings common to both access modes
    cookie_file = os.path.expanduser('~/cookies.txt')
    env = dict(GDAL_MAX_RAW_BLOCK_CACHE_SIZE='200000000',
               GDAL_SWATH_SIZE='200000000',
               VSI_CURL_CACHE_SIZE='200000000',
               GDAL_HTTP_COOKIEFILE=cookie_file,
               GDAL_HTTP_COOKIEJAR=cookie_file)
    if aws:
        env.update(GDAL_DISABLE_READDIR_ON_OPEN='FALSE',
                   CPL_VSIL_CURL_ALLOWED_EXTENSIONS='TIF',
                   # NOTE(review): disables TLS certificate verification for GDAL HTTP requests.
                   GDAL_HTTP_UNSAFESSL='YES',
                   AWS_REGION='us-west-2',
                   AWS_SECRET_ACCESS_KEY=s3_cred['secretAccessKey'],
                   AWS_ACCESS_KEY_ID=s3_cred['accessKeyId'],
                   AWS_SESSION_TOKEN=s3_cred['sessionToken'])
        # A value left behind by an earlier aws=False call would otherwise
        # still request unsigned access despite the credentials set here.
        os.environ.pop('AWS_NO_SIGN_REQUEST', None)
    else:
        env.update(GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR',
                   AWS_NO_SIGN_REQUEST='YES')

    os.environ.update(env)

    catalog = NASA_CMR_STAC(hls_data, aws)
    da = build_xr(catalog)
    return da

In [None]:
# Earthdata Login credentials are taken from the environment, or prompted for
# when the variables are unset. They must never be written into the notebook.
# NOTE: the credentials previously hard-coded here should be considered exposed and rotated.
earthdata_user = os.environ.get('EARTHDATA_USERNAME') or input('NASA Earthdata Login username: ')
earthdata_pass = os.environ.get('EARTHDATA_PASSWORD') or getpass('NASA Earthdata Login password: ')

# Search from 2021-01-15 up to today.
data_dict = {'date_range': [str(datetime(2021, 1, 15).date()), str(datetime.now().date())]}
tmp_data = get_hls([earthdata_user, earthdata_pass], hls_data=data_dict, aws=True)

# pickle can execute arbitrary code on load: only open model files from a trusted source.
with open('src/models/CPER_HLS_to_VOR_biomass_model_lr_simp.pk', 'rb') as model_file:
    bm_mod = pickle.load(model_file)

# Subset to the study-area extent (same CRS as the stack; y is sliced from high to low).
da = tmp_data.loc[dict(x=slice(517587.0, 527283.0), y=slice(4524402.0, 4514699.0))]
da

In [None]:
# Build a mask from the FMASK band and keep only pixels where the mask equals 0.
# NOTE(review): `da` is rebound here, so re-running this cell works on already-masked data.
da_mask = mask_hls(da['FMASK'])
da = da.where(da_mask == 0)
# Predict biomass and keep only positive predictions.
da_bm = pred_bm(da, bm_mod)
da_bm = da_bm.where(da_bm > 0)
# Predict the biomass standard error, restricted to pixels with positive biomass.
da_se = pred_bm_se(da, bm_mod)
da_se = da_se.where(da_bm > 0)
# Time how long it takes to materialise the biomass values of the second time step.
t0 = time.time()
print(da_bm.isel(time=1).values)
t1 = time.time()
print(t1-t0)

In [None]:
# Start a local Dask cluster with one thread per worker and attach a client to it.
cluster = LocalCluster(threads_per_worker=1)
cl = Client(cluster)
# Display the client summary as the cell output.
cl

In [None]:
# Repeat the timing of the second time step now that the Dask client exists
# (the values are computed but not printed this time).
t0 = time.time()
da_bm.isel(time=1).values
t1 = time.time()
print(t1-t0)


In [None]:
from cartopy import crs
import geoviews as gv
import holoviews as hv
from copy import deepcopy
from src.objects.charts import gauge_obj
from holoviews import streams
import affine
from src.hls_funcs.masks import shp2mask

# Floor the timestamps to whole days and expose them as a 'date' dimension.
# assign_coords returns a new object, so `da` keeps its original 'time' values
# (the previous `dat = da` aliased the two names and modified `da` in place).
dat = da.assign_coords(time=da.time.dt.floor("D")).rename(dict(time='date'))

# Load the fitted biomass and cover models.
# pickle can execute arbitrary code on load: only open model files from a trusted source.
with open('src/models/CPER_HLS_to_VOR_biomass_model_lr_simp.pk', 'rb') as bm_file:
    bm_mod = pickle.load(bm_file)
with open('src/models/CPER_HLS_to_LPI_cover_pls_binned_model.pk', 'rb') as cov_file:
    cov_mod = pickle.load(cov_file)

In [None]:


class App(pm.Parameterized):
    """Panel dashboard showing cover, biomass and threshold maps with per-region statistics.

    The class body runs once, when the class is defined. It derives the
    biomass (`da_bm`), cover (`da_cov`), standard-error (`da_se`) and threshold
    (`da_thresh`) arrays from the notebook-level `dat`, `bm_mod` and `cov_mod`,
    persists the first date of each displayed array, and builds the initial maps.

    Parameters (rendered as widgets through ``viewer.param``)
    ---------------------------------------------------------
    basemap : "Satellite" or "Map"; selects the background tile source.
    alpha   : opacity applied to the data layers.
    date    : date whose data are displayed.
    thresh  : threshold passed to ``pred_bm_thresh`` and used in the region statistics.
    action  : trigger that runs `show_hist` for the drawn polygons.
    """

    # --- initial selections ---------------------------------------------------------------
    date_init = dat['date'][0].values
    thresh_init = 500
    bm_mean1 = 0

    # --- derived data ---------------------------------------------------------------------
    da_bm = pred_bm(dat, bm_mod)
    da_bm.name = 'Biomass'
    da_bm = da_bm.where(da_bm > 0)
    da_bm_sel = da_bm.sel(date=date_init).persist()
    da_cov = pred_cov(dat, cov_mod)
    da_cov = da_cov[['SD', 'GREEN', 'BARE']].to_array(dim='type')
    # Cap cover values at 1.0 while leaving missing values missing.
    da_cov = da_cov.where((da_cov < 1.0) | (da_cov.isnull()), 1.0)
    da_cov_sel = da_cov.sel(date=date_init).persist()
    da_se = pred_bm_se(dat, bm_mod)
    da_se = da_se.where(da_bm.notnull())
    da_thresh = pred_bm_thresh(da_bm, da_se, thresh_init)
    da_thresh.name = 'Threshold'
    da_thresh_sel = da_thresh.sel(date=date_init).persist()

    # --- projections and plot options -----------------------------------------------------
    datCRS = crs.UTM(13)
    mapCRS = crs.GOOGLE_MERCATOR
    datProj = Proj(datCRS.proj4_init)
    mapProj = Proj(mapCRS.proj4_init)
    map_args = dict(crs=datCRS, rasterize=True, project=False, dynamic=True)
    map_opts = dict(projection=mapCRS, responsive=False, xticks=None, yticks=None, width=900, height=700,
                    padding=0, tools=['pan', 'wheel_zoom', 'box_zoom'],
                    active_tools=['pan', 'wheel_zoom'], toolbar='left')
    poly_opts = dict(fill_color=['', ''], fill_alpha=[0.0, 0.0], line_color=['#1b9e77', '#d95f02'],
                     line_width=[3, 3])
    gauge_opts = dict(height=200, width=300)

    bg_col = '#ffffff'

    # CSS class applied to each statistics box.
    css = '''
    .bk.box1 {
      background: #ffffff;
      border-radius: 5px;
      border: 1px black solid;
    }
    '''
    pn.extension(raw_css=[css])

    # --- user-facing parameters -----------------------------------------------------------
    basemap = pm.ObjectSelector(default="Satellite", objects=["Satellite", "Map"])
    alpha = pm.Number(default=1.0)
    date = pm.CalendarDate(default=pd.Timestamp(date_init).to_pydatetime())
    thresh = pm.Integer(default=thresh_init)
    action = pm.Action(lambda self: self.param.trigger('action'), 'Compute')

    # --- initial maps ---------------------------------------------------------------------
    cov_map = da_cov_sel.hvplot.rgb(x='x', y='y', bands='type',
                                    **map_args).opts(**map_opts)
    bm_map = da_bm_sel.hvplot(x='x', y='y',
                              cmap='Viridis', clim=(100, 1000), colorbar=False,
                              **map_args).opts(**map_opts)
    thresh_map = da_thresh_sel.hvplot(x='x', y='y',
                                      cmap='YlOrRd', clim=(0.05, 0.95), colorbar=False,
                                      **map_args).opts(**map_opts)

    tiles = gv.tile_sources.EsriImagery.opts(projection=mapCRS, backend='bokeh')

    polys = hv.Polygons([])

    def __init__(self, **params):
        """Create the polygon drawing/editing streams, the zoom extent and the initial tab layout."""
        super(App, self).__init__(**params)

        self.gauge_obj = deepcopy(gauge_obj)

        # Streams that let the user draw (at most two), edit and select polygons.
        self.poly_stream = streams.PolyDraw(source=self.polys, drag=True, num_objects=2,
                                            show_vertices=True, styles=self.poly_opts)
        self.edit_stream = streams.PolyEdit(source=self.polys, shared=True)
        self.select_stream = streams.Selection1D(source=self.polys)

        # Current view extent in data coordinates; starts as the full data extent.
        self.startX, self.endX = (float(self.da_bm['x'].min().values), float(self.da_bm['x'].max().values))
        self.startY, self.endY = (float(self.da_bm['y'].min().values), float(self.da_bm['y'].max().values))
        self.cov_stats = ''
        self.bm_stats = ''
        self.thresh_stats = ''
        self.all_maps = pn.Tabs(('Cover', pn.Row(self.tiles * self.cov_map * self.polys, self.cov_stats)),
                                ('Biomass', pn.Row(self.tiles * self.bm_map * self.polys, self.bm_stats)),
                                ('Threshold', pn.Row(self.tiles * self.thresh_map * self.polys, self.thresh_stats)))
        self.active_tab = self.all_maps.active

    def keep_zoom(self, x_range, y_range):
        """Record the current view extent, converted from map to data coordinates.

        Subscribed to the last stream of the biomass map. Does nothing while
        either range is still undefined.
        """
        if x_range is None or y_range is None:
            return
        map_x_range, map_y_range = self.mapProj(x_range, y_range, inverse=True)
        (self.startX, self.endX), (self.startY, self.endY) = self.datProj(map_x_range, map_y_range, inverse=False)

    @pm.depends('basemap')
    def map_base(self):
        """Return the tile source matching the `basemap` parameter."""
        if self.basemap == "Satellite":
            self.tiles = gv.tile_sources.EsriImagery(projection=self.mapCRS, backend='bokeh')
        elif self.basemap == "Map":
            self.tiles = gv.tile_sources.Wikipedia(projection=self.mapCRS, backend='bokeh')
        return self.tiles

    @pm.depends('date', watch=True)
    def bm_date(self):
        """Rebuild the biomass map for the newly selected date."""
        self.date_init = np.datetime64(self.date)
        self.da_bm_sel = self.da_bm.sel(date=self.date_init).persist()
        self.bm_map = self.da_bm_sel.hvplot(x='x', y='y',
                                            cmap='Viridis', clim=(100, 1000),
                                            colorbar=False,
                                            **self.map_args).opts(alpha=self.alpha,
                                                                  **self.map_opts)
        # The new map has new streams, so the zoom tracker is subscribed again.
        self.bm_map.streams[-1].add_subscriber(self.keep_zoom)
        return self.bm_map

    @pm.depends('date', watch=True)
    def cov_date(self):
        """Rebuild the cover (RGB) map for the newly selected date."""
        self.date_init = np.datetime64(self.date)
        self.da_cov_sel = self.da_cov.sel(date=self.date_init).persist()
        self.cov_map = self.da_cov_sel.hvplot.rgb(x='x', y='y',
                                                  bands='type',
                                                  **self.map_args).opts(alpha=self.alpha,
                                                                        **self.map_opts)
        return self.cov_map

    @pm.depends('date', 'thresh', watch=True)
    def thresh_date(self):
        """Recompute the threshold array and rebuild its map when the date or threshold changes."""
        self.date_init = np.datetime64(self.date)
        self.thresh_init = self.thresh
        self.da_thresh = pred_bm_thresh(self.da_bm, self.da_se, self.thresh_init)
        self.da_thresh_sel = self.da_thresh.sel(date=self.date_init).persist()
        self.thresh_map = self.da_thresh_sel.hvplot(x='x', y='y',
                                                    cmap='YlOrRd', clim=(0.05, 0.95), colorbar=False,
                                                    **self.map_args).opts(alpha=self.alpha,
                                                                          **self.map_opts)
        self.bm_map.streams[-1].add_subscriber(self.keep_zoom)
        return self.thresh_map

    @pm.depends('alpha', watch=True)
    def cov_alpha(self):
        """Apply the current opacity to the cover map, using the biomass map's current ranges."""
        return self.cov_map.opts(alpha=self.alpha,
                                 xlim=self.bm_map.streams[-1].x_range,
                                 ylim=self.bm_map.streams[-1].y_range,
                                 **self.map_opts)

    @pm.depends('alpha', watch=True)
    def bm_alpha(self):
        """Apply the current opacity to the biomass map, using its current ranges."""
        self.bm_map = self.bm_map.opts(alpha=self.alpha,
                                       xlim=self.bm_map.streams[-1].x_range,
                                       ylim=self.bm_map.streams[-1].y_range,
                                       **self.map_opts)
        self.bm_map.streams[-1].add_subscriber(self.keep_zoom)
        return self.bm_map

    @pm.depends('alpha', watch=True)
    def thresh_alpha(self):
        """Apply the current opacity to the threshold map, using the biomass map's current ranges."""
        return self.thresh_map.opts(alpha=self.alpha,
                                    xlim=self.bm_map.streams[-1].x_range,
                                    ylim=self.bm_map.streams[-1].y_range,
                                    **self.map_opts)

    @pm.depends('action', watch=True)
    def show_hist(self):
        """Compute the statistics panels for every drawn polygon.

        For each polygon this builds, from the currently selected date:
        a biomass histogram with the threshold marked and the share of pixels
        below it, a gauge showing mean biomass, and a plot of mean cover per
        cover type. The results are stored in `thresh_stats`, `bm_stats` and
        `cov_stats`; the polygon element and its streams are then re-created
        so the drawn regions are kept.
        """
        if self.poly_stream.data is None:
            self.cov_stats = ''
            self.bm_stats = ''
            self.thresh_stats = ''
            return

        # Imported once per call rather than once per polygon.
        from bokeh.models import NumeralTickFormatter

        thresh_list = []
        bm_list = []
        cov_list = []

        # Transform of the 30 m grid anchored at the data's min x / max y; identical for every polygon.
        ta = affine.Affine(30.0, 0.0, float(self.da_bm_sel['x'].min()),
                           0.0, -30.0, float(self.da_bm_sel['y'].max()))

        for idx, ps_c in enumerate(self.poly_stream.data['line_color']):
            # Polygon vertices: map coordinates -> lon/lat -> data coordinates.
            xs_map, ys_map = self.mapProj(self.poly_stream.data['xs'][idx],
                                          self.poly_stream.data['ys'][idx], inverse=True)
            xs_dat, ys_dat = self.datProj(xs_map, ys_map, inverse=False)
            geometries = {
                "type": "Polygon",
                "coordinates": [
                    list(map(list, zip(xs_dat, ys_dat)))
                ]
            }
            poly_mask = shp2mask([geometries], self.da_bm_sel,
                                 transform=ta, outshape=self.da_bm_sel.shape, default_value=1)

            # Heading shared by the three panels of this region, coloured like its polygon.
            markdown = pn.pane.Markdown('## Region stats', height=50,
                                        style={'font-family': "serif",
                                               'color': ps_c})

            # ---- threshold panel: histogram + share of pixels below the threshold ----
            da_bm_tmp = self.da_bm_sel.where(poly_mask == 1)
            bm_hist_tmp = da_bm_tmp.hvplot.hist('Biomass', xlim=(0, 2000),
                                                bins=np.arange(0, 10000, 20))\
                .opts(height=200, width=300, fill_color=ps_c, fill_alpha=0.6,
                      line_color='black', line_width=0.5, line_alpha=0.6,
                      bgcolor=self.bg_col).options(toolbar=None)
            n_valid = float(da_bm_tmp.count())
            if n_valid > 0:
                n_below = float(da_bm_tmp.where(da_bm_tmp < self.thresh_init).count())
                thresh_pct = round(n_below / n_valid * 100, 0)
                thresh_msg = (f'**{thresh_pct}%** of the region is estimated to have biomass ' +
                              f'less than {self.thresh_init} kg/ha.')
            else:
                # A region without any valid pixel used to raise ZeroDivisionError here.
                thresh_msg = 'No valid biomass pixels were found in this region.'
            thresh_text = pn.pane.Markdown(thresh_msg, style={'font-family': "Helvetica"})
            thresh_list.append(pn.Column(pn.Row(pn.layout.HSpacer(), markdown, pn.layout.HSpacer()),
                                         bm_hist_tmp * hv.VLine(x=self.thresh_init).opts(line_color='black'),
                                         thresh_text,
                                         css_classes=['box1'], margin=5))

            # ---- biomass panel: gauge of the regional mean ----
            bm_gauge_obj = deepcopy(self.gauge_obj)
            bm_mean = float(da_bm_tmp.mean().values)
            # int(nan) raises, so an all-missing region is shown as 0.
            bm_gauge_obj['series'][0]['data'][0]['value'] = int(bm_mean) if np.isfinite(bm_mean) else 0
            bm_gauge_pane = pn.pane.ECharts(bm_gauge_obj, **self.gauge_opts)
            bm_list.append(pn.Column(pn.Row(pn.layout.HSpacer(), markdown, pn.layout.HSpacer()),
                                     bm_gauge_pane,
                                     css_classes=['box1'], margin=5))

            # ---- cover panel: mean cover per type, drawn as spikes with markers ----
            da_cov_tmp = self.da_cov_sel.where(poly_mask == 1)
            cov_factors = list(da_cov_tmp.type.values)
            cov_vals = [round(float(da_cov_tmp.sel(type=f).mean().values), 2) for f in cov_factors]
            pct_fmt = NumeralTickFormatter(format="0%")
            cov_colors = hv.Cycle(['red', 'green', 'blue'])
            cov_scatter_tmp = hv.Overlay([hv.Scatter(f) for f in list(zip(cov_factors, cov_vals))]) \
                .options({'Scatter': dict(xformatter=pct_fmt,
                                          size=15,
                                          fill_color=cov_colors,
                                          line_color=cov_colors,
                                          ylim=(0, 1))})
            cov_spike_tmp = hv.Overlay([hv.Spikes(f) for f in cov_scatter_tmp])\
                .options({'Spikes': dict(color=cov_colors, line_width=4,
                                         labelled=[], invert_axes=True, color_index=None,
                                         ylim=(0, 1))})
            cov_list.append(pn.Column(pn.Row(pn.layout.HSpacer(), markdown, pn.layout.HSpacer()),
                                      (cov_spike_tmp * cov_scatter_tmp).options(height=200,
                                                                                width=300,
                                                                                bgcolor=self.bg_col,
                                                                                toolbar=None),
                                      css_classes=['box1'], margin=5))

        # Keep the drawn polygons (at the current zoom) and attach fresh streams to them.
        self.polys = self.poly_stream.element.opts(xlim=(self.startX, self.endX),
                                                   ylim=(self.startY, self.endY))
        self.poly_stream = streams.PolyDraw(source=self.polys, drag=True, num_objects=2,
                                            show_vertices=True, styles=self.poly_opts)
        self.edit_stream = streams.PolyEdit(source=self.polys, shared=True)
        self.thresh_stats = pn.Column(*thresh_list)
        self.bm_stats = pn.Column(*bm_list)
        self.cov_stats = pn.Column(*cov_list)

    def view_all(self):
        """Return the tabbed layout of the three maps with their statistics panels."""
        # Remember which tab was open so the rebuilt Tabs shows the same one.
        self.active_tab = self.all_maps.active
        self.bm_map.streams[-1].add_subscriber(self.keep_zoom)
        base = hv.DynamicMap(self.map_base)
        cov = self.cov_map
        bm = self.bm_map
        thresh = self.thresh_map
        self.all_maps = pn.Tabs(('Cover', pn.Row(base * cov * self.polys, self.cov_stats)),
                                ('Biomass', pn.Row(base * bm * self.polys, self.bm_stats)),
                                ('Threshold', pn.Row(base * thresh * self.polys, self.thresh_stats)),
                                active=self.active_tab)
        return pn.Column(self.all_maps)



viewer = App()

# Only dates present in the data can be picked in the calendar.
selectable_dates = [pd.Timestamp(x).to_pydatetime().date() for x in dat['date'].values]

# One widget per App parameter, keyed by parameter name.
control_widgets = {
    'date': pn.widgets.DatePicker(name='Calendar',
                                  enabled_dates=selectable_dates,
                                  width=200),
    'alpha': pn.widgets.FloatSlider(name='Map transparency',
                                    value=1.0, start=0.0, end=1.0, step=0.1,
                                    width=200),
    'thresh': pn.widgets.IntSlider(name='Threshold',
                                   start=200, end=2000, step=25, value=500,
                                   format=PrintfTickFormatter(format='%d kg/ha'),
                                   width=200),
    'basemap': pn.widgets.Select(name="Change basemap",
                                 options=["Satellite", "Map"],
                                 width=200),
    'action': pn.widgets.Button(name='Save regions and \ncompute stats',
                                width=200),
}

# Controls on the left, tabbed maps on the right.
controls = pn.Column(pn.Param(viewer.param, widgets=control_widgets))
layout = pn.Row(controls, viewer.view_all)
layout.servable()