# Head

In [1]:
import os
import json
import warnings
from math import ceil
from io import StringIO
from datetime import datetime

import requests
import numpy as np
import pandas as pd
import xarray as xr
from shapely.geometry import shape

import matplotlib.pyplot as plt
from matplotlib import cm, colors
from IPython.display import display

import ipyleaflet as mwg
from ipyleaflet import Map, LayerGroup, GeoJSON, CircleMarker
from ipywidgets import Layout, Button, IntProgress, Output, HBox, VBox, HTML, interactive


# ----------------------------------------------------------------------------
# app settings

#font = {"family": "normal", "weight": "normal", "size": 16}  # matplotlib ->
#plt.rcParams['figure.figsize'] = [14, 5]
#plt.rc("font", **font) # <- some matplotlib settings

warnings.filterwarnings('ignore')
auth = dict(ORNL_DAAC_USER_NUM=str(32863))

smvdownload = "https://daac.ornl.gov/cgi-bin/viz/download.pl?"
smvdatasets = pd.read_csv(
    "docs/smvdatasets.csv", 
    index_col="dataset", 
    header=0)

ignore_variables = [
    "sample","time","stat","lat","lon","FLUXNET_surface","FLUXNET_rootzone"]

# usfs shapefile fields; only applies to Yaxing's workshop -->>
fields = [
    "FORESTNUMB","DISTRICTNU","REGION","GIS_ACRES","MIN","MEDIAN","MAX",
    "RANGE","SUM","VARIETY","MINORITY","MAJORITY","COUNT"]

site_details = """
REGION:   {REGION}
ACRES:    {GIS_ACRES}
MIN:      {MIN}
MEDIAN:   {MEDIAN}
MAX:      {MAX}
RANGE:    {RANGE}
SUM:      {SUM}
VARIETY:  {VARIETY}
MINORITY: {MINORITY}
MAJORITY: {MAJORITY}
COUNT:    {COUNT}
"""
#{FORESTNAME} ({FORESTNUMB})
#{DISTRICTNA} ({DISTRICTNU})

# ----------------------------------------------------------------------------
# widget settings

bmap = mwg.basemap_to_tiles(mwg.basemaps.Esri.WorldImagery)    # map widget 
map_args = dict(
    center=(32.75, -109), 
    zoom=7, 
    scroll_wheel_zoom=True)

submit_args = dict(                                            # submit button
    description="Submit", 
    disabled=True, 
    button_style="success")

progress_args = dict(                                          # progress bar
    description="Progress: ", 
    layout=Layout(width="95%"))

# ----------------------------------------------------------------------------
# functions that don't support the all-in-one application


def poly_mapper(lyr, mw="25%", ow="75%", zoom=8):
    """Generates the map/plot side-by-side widget container."""

    l = (lyr["layer"].layer, lyr.points, bmap,)
    m = Map(layers=l, center=(lyr["lat"], lyr["lon"]), zoom=zoom, width=mw)
    o = Output(layout={"width": ow})
    
    return((m,o))


# Other helpers

In [2]:
# ----------------------------------------------------------------------------
# other helpers

figure_args = dict(ncols=1, sharex=True, figsize=(10,5))
bar_args = dict(stacked=True, colormap="tab20c", legend=False)
legend_ncols = lambda df: ceil(len(df.index)/10)   
download_msg = """<p style="text-align:center;">
Click <b>Submit</b> to download Soil Moisture Visualizer data for this site.
<br></p>"""


def get_ancillary_data(geojson):
    """ """
    features, cols = from_geojson(geojson)            # get features and cols

    layers = []                                       # a temporary list 
    for i, feat in enumerate(features):               # loop over features
        lyrd = get_layer_data(i, feat, opac=0, samp=False)
        lyrm = GeoJSON(
            data=lyrd[0], 
            hover_style={"color": cols[i], "fillOpacity": 0.1})
        layers.append((i, lyrd[2], lyrd[3], lyrm, lyrd[5]))
    layers = pd.DataFrame(layers, columns=["id","lat","lon","layer","attrs"])

    return(layers)


def get_nan_summary(xrdataset): 
    """ """
    nandict = {"in situ": {}, "airborne": {}, "spaceborne": {}}

    for pt in nandict.keys():

        # for percentage: ----------------------------------------------------
        """# get the datasets for the current platform
        ds = xrds.filter_by_attrs(type=pt).sel(stat="Mean", drop=True)

        # get fraction of null dataset by timestep for all samples
        dsnull = ds.isnull()

        # get variables with nodata; variables with data; valid counts
        nodata, yesdata, data = [], [], {}
        for name, dataset in dsnull.items():
            if dataset.data.all(): #dataset.isnull().mean()==1
                nodata.append(name)
            else:
                yesdata.append(name)
                dataset.data = np.logical_not(dataset.data)
                data[name] = dataset.mean("sample").data"""
        # --------------------------------------------------------------------

        # get the datasets for the current platform
        pds = xrdataset.filter_by_attrs(type=pt).sel(stat="Mean", drop=True)
        timelen, samplelen = pds.time.size, pds.sample.size
        potential_obs_count = timelen*samplelen

        # get variables with nodata; variables with data; valid counts
        nodata, yesdata, obscount = [], [], {}
        for name, dataset in pds.items():
            if allnan(dataset):
                nodata.append(name)
            else:
                yesdata.append(name)
                obscount[name], obstotal = [], 0
                for i in range(samplelen):
                    samp = dataset.sel(sample=i)
                    count = numvalid(samp)
                    obscount[name].append(count) #/potential_obs_count*100
                    obstotal += count
                obscount[name].append(potential_obs_count-obstotal)
        
        # update summary dictionary
        ix = list(range(samplelen))+["nan"]   
        nandict[pt].update({                      
            "nodata": nodata, 
            "yesdata": yesdata, 
            "summary": pd.DataFrame(obscount, index=ix)})

    return(nandict)


def get_nan_plot(nandict):
    """ """

    stypes = []
    for stype, nandata in nandict.items():
        cnt = len(nandata["yesdata"])
        if cnt!=0:
            stypes.append((stype, nandata["summary"], cnt))

    if len(stypes)==1:
        df = stypes[0][1]
        df.T.plot.barh(stacked=True, colormap="tab20c", figsize=(10,5))
        plt.legend(ncol=legend_ncols(df), title="sample")
    else:
        st = sorted(stypes, key=lambda x: x[2], reverse=True)
        fig, axs = plt.subplots(
            nrows=len(st), 
            gridspec_kw={'height_ratios': [i[2] for i in st]}, 
            **figure_args)
        for i, d in enumerate(st):
            d[1].T.plot.barh(ax=axs[i],**bar_args)
        fig.tight_layout()
        axs[0].legend(ncol=legend_ncols(st[0][1]), title="sample")
        axs[0].set_title("n observations by dataset")
        for i, ax in enumerate(axs):
            xlim = ax.get_xlim()
            ax.set_xlim(xlim[0], xlim[1]/2)
            label = ax.set_xlabel(str(ceil(xlim[1]))+" total")
            ax.xaxis.set_label_coords(1.05, -0.05)

    plt.show()

# Sample

In [25]:
# ----------------------------------------------------------------------------
# SMV sample dataset formatters

numvalid = lambda v: np.count_nonzero(~np.isnan(v.data))
allnan = lambda v: numvalid(v)==0

latatts = dict(
    standard_name="latitude",
    long_name="sample latitude",
    units="degrees_north")

lonatts = dict(
    standard_name="latitude",
    long_name="sample latitude",
    units="degrees_north")

pt_style = dict(
    radius=6, 
    stroke=False,
    fill_opacity=0.6, 
    fill_color="black")


def txt_to_pd(response_text):
    """Parses response.text to data frame with date index."""
    
    f = StringIO(response_text)                      # get file from string

    df = pd.read_csv(f, header=4, index_col="time")  # read to df
    df.index = pd.to_datetime(df.index)              # convert index to dates
    
    return(df)


def split_pd(col):
    """Splits pd column by ; and set all values to float, nan."""
    
    df = col.str.split(";",n=2,expand=True)           # split col by ;
    df = df.replace('', np.nan)                       # set '' to nan
    df = df.astype(float)                             # set all to float
    df.columns = ["Max","Mean","Min"]                 # add column names
    
    return(df)


def pd_to_xr(dataset, df):
    """Makes an xr.Dataset from a pandas column (series) and coords."""

    a = smvdatasets.loc[dataset].to_dict()
    x = xr.DataArray(df, name=dataset, attrs=a)
    x = x.rename(dict(dim_1="stat"))
    
    return(x)


def get_sample_xr(samp):
    """ """

    d = ["sample"]                          
    s = xr.DataArray(data=[samp.id], dims=d) # get sample, lat, lon xr arrays
    y = xr.DataArray(data=[samp.lat], coords=[s], dims=d, attrs=latatts)
    x = xr.DataArray(data=[samp.lon], coords=[s], dims=d, attrs=lonatts)

    df = samp.df                                         # get the sample df
    ds = {}
    for dataset in df.columns:
        if ("FLUXNET" not in dataset) & ("PBOH2O" not in dataset): # disabled
            split_column = split_pd(df[dataset])
            ds[dataset] = pd_to_xr(dataset, split_column)

    xds = xr.merge(ds.values())                          # merge to one xr
    xds = xds.assign_coords(lat=y, lon=x)                # add coord arrays
    
    return(xds)


def get_null_summary(xrdataset):
    """ """
    percent = {}
    typebool = {"in situ": False, 
                "airborne": False, 
                "spaceborne": False}                             # summarize
    for name, dataset in xrdataset.items():
        t = dataset.attrs["type"]
        dsnull = dataset.isnull().sel(stat="Mean", drop=True)    # get null dataset
        if (not dsnull.data.all()) & (name not in ignore_variables):
            typebool[t] = True
            percent[name] = np.logical_not(dsnull.data).mean() 
    
    return((typebool, percent))


def get_symbology(typebool):
    """ """
    return({
        "fill_color": "#fdc086" if typebool["in situ"] else "gray",
        "color": "#beaed4" if typebool["airborne"] else None,
        "stroke": True if typebool["airborne"] else None})


class Sample(object):


    def __init__(self, i, lat, lon):
        """Inits with id,lat,lon; makes request string, map point."""
        self.id = i
        self.lat = lat 
        self.lon = lon

        self.on = False
        self.pt = CircleMarker(location=(lat, lon), **pt_style)  
        self.dl = smvdownload+"lt={0}&ln={1}&d=smap".format(lat, lon)

    def update(self, **kwargs):
        for arg, val in kwargs.items():
            setattr(self.pt, arg, val)


    def toggle(self, event, type, coordinates):
        opac = 0.1 if self.on else 0.6                    # determine opacity
        self.update(opacity=opac)                         # set opacity
        self.on = False if self.on else True              # toggle status


    def submit(self):
        """Called by parent. Downloads url. Updates status."""
        
        self.response = requests.get(self.dl, cookies=auth)# download
        self.df = txt_to_pd(self.response.text)            # read to df
        self.xr = get_sample_xr(self)                      # get xr dataset
        self.pt.on_click(self.toggle)                      # callback switch style
        self.summary = get_null_summary(self.xr)           # get null summary
        self.symbology = get_symbology(self.summary[0])
        self.on = True                                     # toggle status on

In [26]:
test = Sample(1, lat=30, lon=-90)
test.submit()

In [27]:
dir(test)
test.symbology

{'fillColor': 'gray', 'color': None, 'stroke': None}

# Layer

In [None]:
# ----------------------------------------------------------------------------
# input polygon data

latf = "docs/EASE2_M09km.lats.3856x1624x1.double"
lonf = "docs/EASE2_M09km.lons.3856x1624x1.double"

lats = np.fromfile(latf, dtype=np.float64).flatten() 
lons = np.fromfile(lonf, dtype=np.float64).flatten()
crds = np.dstack((lats,lons))[0]


def get_colors(n, cmap=cm.Set3):
    """ """

    cspace = np.linspace(0.0, 1.0, n)           # 1
    rgb = cmap(cspace)                          # 2
    cols = [colors.to_hex(c[0:3]) for c in rgb] # 3

    return(cols)

def from_geojson(input_geojson):
    """ """
    with open(input_geojson, "r") as f:
        shapes = json.load(f)
    features = shapes["features"]
    cols = get_colors(len(features))
    return((features, cols))


def get_ease(shapely_geom):
    """ """

    bnds = shapely_geom.bounds 
    ease = crds[
        (bnds[1]<lats) & (lats<bnds[3]) &     # ybnds < lat < ybnds
        (bnds[0]<lons) & (lons<bnds[2])]      # xbnds < lon < xbnds

    ease_reduced = []
    for p in ease:
        shapely_pt = shape({                  # input to shapely.shape is a
            "type": "Point",                  # python dict equivalent of
            "coordinates": (p[1], p[0])})     # geojson point geometry
        
        if shapely_geom.contains(shapely_pt): # if point inside poly
            ease_reduced.append([p[0], p[1]]) # return lat, lon tuple

    return(ease_reduced)


def get_properties(prop):
    """ """

    details, stats = {}, {"MEAN": [], "STD": []}
    for key, value in prop.items():
        if key in fields:
            details[key] = value
        elif "MEAN_" in key:
            stats["MEAN"].append(value)
        elif "STD_" in key:
            stats["STD"].append(value)
        else:
            pass
    yr = pd.date_range(
        start="1985",
        freq="1Y",
        periods=len(stats["MEAN"]))
    stats = pd.DataFrame(stats, index=yr)

    return((details, stats))


def get_layer_data(i, feat, col="#FFFFFF", opac=0.4, samp=True):
    """ """

    shapely_geom = shape(feat["geometry"])              # shapely geom
    ease = get_ease(shapely_geom) if samp else None     # ease grid points
    cent = shapely_geom.centroid                        # centroid
    lat, lon = cent.y, cent.x                           # lat, lon
    details, stats = get_properties(feat["properties"])
    feat["properties"].update({
        "id": i, 
        "style": {
            "weight": 0.75,
            "color": col,
            "fillColor": col,
            "fillOpacity": opac}})

    return((feat, ease, lat, lon, stats, details))


class Layer(object):


    def __init__(self, i, feat, col=None):
        """Inits with id,lat,lon; makes request string, map point."""
        
        layer_data = get_layer_data(i, feat, col)
        
        self.id = i
        self.feat = layer_data[0]
        self.ease = layer_data[1]
        self.lat = layer_data[2]
        self.lon = layer_data[3]
        self.stats = layer_data[4]
        self.details = layer_data[5]

        self.layer = GeoJSON(
            data=self.feat,
            hover_style={
                "color": "white", 
                "fillOpacity": 0.8})
        self.layer.on_click(self.toggle)

        self.dl = False    # downloaded or nah?
        self.on = False    # toggle on or nah?


    def toggle(self, **kwargs):
        """Routine for when a new USFS polygon is selected."""

        if list(kwargs.keys()) != ['event', 'properties']: # check event
            return(None)                                   # skip basemap
        
        self.on = False if self.on else True               # update status
    

    def update(self, **kwargs):
        for arg, val in kwargs.items():
            setattr(self.layer, arg, val)
            
    
    def getui(self):
        """ """
        self.groupbyui = radio_checkbox(plot_options)
        return(self.groupbyui)

# JupyterSMV

In [None]:
# ----------------------------------------------------------------------------
# app

pt_status_on = dict(
    opacity=0.5,
    stroke=True, 
    color="white")

pt_status_off = dict(
    opacity=0.6,
    stroke=False, 
    color="black")

sample_header = [
    "id",
    "lat",
    "lon",
    "samp"
]

layer_header = [
    "id",
    "lat",
    "lon",
    "layer",
    "samples",
    "points",
    "xr"
]


def get_output_layout(w="95%", b="1px solid lightgray"):
    """ """
    return({"width": w, "border": b})


class JupyterSMV(object):
    """App."""


    def __init__(self, primary=None, ancillary=None, freedom=False):

        self.polys = LayerGroup()
        self.points = LayerGroup()
        self.apolys = LayerGroup()
        self.mapw = Map(
            layers=(bmap, self.apolys, self.polys, self.points,), **map_args)

        self.submit = Button(**submit_args)
        self.submit.on_click(self.submit_handler)
        self.progress = IntProgress(**progress_args)
        
        if primary:                                   # if given, 
            self.load_features(primary)               # load input features
        if ancillary:
            self.load_ancillary(ancillary)

        layout = [self.mapw, HBox([self.submit, self.progress])]
        self.out1 = Output(layout=get_output_layout(w="80%"))
        self.out2 = Output(layout=get_output_layout(w="20%"))
        self.ui = VBox(layout + [HBox([self.out1, self.out2])])


    def load_features(self, infeats):
        """ """
        features, cols = from_geojson(infeats)        # get features and cols

        layers = []                                   # a temporary list 
        for i, feat in enumerate(features):           # loop over features
            
            poly = Layer(i, feat, cols[i])            # get Layer class
            poly.layer.on_click(self.layer_click_handler)  # global callback
            self.polys.add_layer(poly.layer)          # add to poly grp

            pts, samps = LayerGroup(), []             # points group; Samples
            for j, p in enumerate(poly.ease):         # loop EASE grid pts
                s = Sample(j, p[0], p[1])             # make Sample instance
                pts.add_layer(s.pt)                   # add to points group
                samps.append((j, p[0], p[1], s))      # append tuple to list  

            samps = pd.DataFrame(samps, columns=sample_header)       # samples
            layers.append((i,poly.lat,poly.lon,poly,samps,pts,None)) # append
        
        self.layers = pd.DataFrame(layers, columns=layer_header)     # layers
        self.selected = None


    def load_ancillary(self, infeats):
        """ """          
        self.alayers = get_ancillary_data(infeats)
        for layer in self.alayers.layer:
            self.apolys.add_layer(layer)


    def submit_handler(self, b):
        """Resets UI and sends requests to SMV when new submit."""
        
        layer_row = self.layers.iloc[self.selected]
        sample = layer_row.samples["samp"].tolist()

        self.progress.min = 0                      # reset progress bar
        self.progress.max = len(sample)
        self.progress.value = 0
        
        for s in sample:                           # loop over sample pts
            self.progress.value += 1               # update progress bar
            s.update(**s.symbology)                # update style
            s.submit()                             # download the data
        layer_row.layer.dl = True                  # set dl status to True
        
        xrds = xr.concat([s.xr for s in sample], "sample")
        lnan = get_nan_summary(xrds)
        self.layers.at[self.selected,"xr"] = xrds  # make xr dataset
        self.layers.iloc[self.selected].layer.nan = lnan

        self.out1.clear_output(); self.out2.clear_output()
        with self.out1:                            # display a summary of nan
            get_nan_plot(lnan)
        with self.out2:
            print(site_details.format(**layer_row.layer.details))


    def layer_click_handler(self, **kwargs): 
        """ Evaluates when new polygon is selected. Layer.toggle first!"""
        
        if list(kwargs.keys()) != ['event', 'properties']: # check event
            return(None)                                   # skip basemap

        i = int(kwargs["properties"]["id"])             # get selected poly id
        layer_row = self.layers.iloc[i]                 # get row for selected
        layer_inst = layer_row.layer                    # get Layer class inst
        self.selected = i

        self.points.clear_layers()
        self.points.add_layer(layer_row["points"]) 
        self.mapw.center = (layer_inst.lat, layer_inst.lon)
        self.mapw.zoom = 9
        self.submit.disabled = True if layer_inst.dl else False

        self.out1.clear_output(); self.out2.clear_output()
        with self.out1:                             # display nan summary
            if layer_row.layer.dl:
                get_nan_plot(layer_row.layer.nan)
            else:
                display(HTML(download_msg))
        with self.out2:
            print(site_details.format(**layer_inst.details))


# Run

In [None]:
%matplotlib inline
from smvjupyter import *                                    # import the UI

usfs_sites = "docs/usfs_sites/Sites_lf_geo.json"               # USFS sites
usfs_regions = "docs/usfs_admin/USFS_Regional_Boundaries.json" # admin regions

app = JupyterSMV(usfs_sites)#, ancillary=usfs_regions          # init UI
app.ui     

# Testing -->>



## Introduction

Some text ...



## Workshop

In [28]:
%matplotlib inline
from smvjupyter import *                                    # import the UI

usfs_sites = "docs/usfs_sites/Sites_lf_geo.json"               # USFS sites
usfs_regions = "docs/usfs_admin/USFS_Regional_Boundaries.json" # admin regions

app = JupyterSMV(usfs_sites)#, ancillary=usfs_regions          # init UI
app.ui     

VBox(children=(Map(basemap={'url': 'https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19, 'attr…

In [29]:
#xr_complete = xr.open_dataset("usfs_sites_smv.nc")  # open netCDF
#xr_complete
app.layers

Unnamed: 0,id,lat,lon,layer,samples,points,xr
0,0,31.774734,-109.327589,<smvjupyter.Layer object at 0x000001FCDFDF4908>,id lat lon \ 0 0 32.08...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
1,1,31.505073,-110.536967,<smvjupyter.Layer object at 0x000001FCE0F017B8>,id lat lon \ 0 0 31.83...,LayerGroup(layers=(CircleMarker(color='#377eb8...,"[SoilSCAPE_surface, SoilSCAPE_rootzone, AirMOS..."
2,2,33.653808,-108.56681,<smvjupyter.Layer object at 0x000001FCE24943C8>,id lat lon \ 0 0 33.84...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
3,3,35.298089,-111.641646,<smvjupyter.Layer object at 0x000001FCE1453160>,id lat lon \ 0 0 35.55...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
4,4,33.996685,-108.673179,<smvjupyter.Layer object at 0x000001FCE2515860>,id lat lon \ 0 0 34.17...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
5,5,34.020165,-109.401942,<smvjupyter.Layer object at 0x000001FCE2515A58>,id lat lon \ 0 0 34.17...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
6,6,47.606681,-103.517068,<smvjupyter.Layer object at 0x000001FCE2877908>,id lat lon \ 0 0 48.10...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
7,7,46.752166,-103.55641,<smvjupyter.Layer object at 0x000001FCE27C12B0>,id lat lon \ 0 0 47.27...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
8,8,37.714688,-106.863892,<smvjupyter.Layer object at 0x000001FCE28E0E80>,id lat lon \ 0 0 37.91...,LayerGroup(layers=(CircleMarker(fill_color='bl...,
9,9,37.337869,-103.076888,<smvjupyter.Layer object at 0x000001FCE2DDB908>,id lat lon \ 0 0 37.82...,LayerGroup(layers=(CircleMarker(fill_color='bl...,


In [30]:
app.layers.iloc[1].layer.nan

{'in situ': {'nodata': ['SoilSCAPE_surface',
   'SoilSCAPE_rootzone',
   'AirMOSS_in-ground_surface',
   'AirMOSS_in-ground_rootzone',
   'COSMOS_surface',
   'COSMOS_rootzone',
   'SCAN_surface',
   'SCAN_rootzone',
   'SNOTEL_surface',
   'SNOTEL_rootzone'],
  'yesdata': ['CRN_surface', 'CRN_rootzone'],
  'summary':      CRN_surface  CRN_rootzone
  0              0             0
  1              0             0
  2              0             0
  3              0             0
  4           2197          2197
  5              0             0
  6              0             0
  7              0             0
  8              0             0
  9              0             0
  10             0             0
  11             0             0
  12             0             0
  13             0             0
  14             0             0
  15             0             0
  16             0             0
  17             0             0
  nan       109025        109025},
 'airborne': {'nodat

In [31]:
from ipywidgets import RadioButtons, Checkbox, ToggleButtons

#dir(ToggleButtons)

In [53]:
def calc_xrarrays(xrds):
    """ """
    
    stack = np.stack([xrds[v].data for v in xrds])    # collapse dataset into a stacked array
    data = dict(
        mean = np.nanmean(stack, axis=0),             # calculate mean over time axis (0)
        std = np.nanstd(stack, axis=0),               # calculate mean over time axis (0)
        time = xrds.time.data)
    
    return(pd.DataFrame(
        data, 
        columns=["Time", "Mean", "Std"], 
        index=data["time"]))


class Plotter:
    """Generates the map/plot side-by-side widget container."""
    
    figure_args = dict(sharex=True, figsize=(15, 8))
    legend_args = dict(loc=0, framealpha=1)
    error_msg = """
        <p style="text-align:center;">
        <b>Error:</b> unable to plot this combination for some reason.
        </p>
        """
    
    def __init__(self, layer, nrow=2, ncol=1, mw="25%", ow="75%", zoom=8):

        self.layerr = layer
        self.nrow, self.ncol = nrow, ncol
        
        self.layero = layer.layer
        self.xr = layer.xr                                      # xarray.Dataset
        self.nan = self.layero.nan                              # nan summary
        
        variables = [i for L in  [n["yesdata"] for n in self.nan.values()] for i in L]
        self.xrs = self.xr[variables]

        # --------------------------------------------------------------------
        # selection ui
        
        Type = list(set([self.xrs[d].attrs["type"] for d in self.xrs]))
        self.ChkbxType = [Checkbox(description=a, value=True) for a in Type]
        self.ChkbxZone = [Checkbox(description=a, value=True) for a in ['surface', 'rootzone']]
        
        self.TypeUI = VBox([HTML("<b>Platform type: </b>"), interactive(
            self.type_handler,
            **{c.description: c.value for c in self.ChkbxType})])
        
        self.ZoneUI = VBox([HTML("<b>Soil zone:</b>"), interactive(
            self.zone_handler,
            **{c.description: c.value for c in self.ChkbxZone})])
  
        self.bybutton = ToggleButtons(
            options=["day", "week", "month", "year"],
            value="day",
            description='Interval: ',
            button_style="info")
        self.bybutton.observe(self.bttn_handler)
        
        self.selectionui = VBox([self.bybutton, HBox([
            self.TypeUI, 
            self.ZoneUI])])
        
        # --------------------------------------------------------------------
        # map/plot ui
        
        self.mapw = Map(                                             # map widget
            layers=(self.layero.layer, layer.points, bmap,), 
            center=(layer.lat, layer.lon), 
            zoom=zoom, width=mw)
        self.output = Output(layout={"width": ow})                   # Output widget 
        self.mapplotui = HBox([self.mapw, self.output])              # Map/plot ui
        
        # --------------------------------------------------------------------
        # master ui and init
        
        self.selected = {"type":Type, "soil_zone":['surface', 'rootzone']}
        self.ui = VBox([self.mapplotui, self.selectionui])
        self.draw()

    # ------------------------------------------------------------------------

    def type_handler(self, **kwargs):
        """ """
        self.selected["type"] = [k for k,v in kwargs.items() if v]
    
    def zone_handler(self, **kwargs):
        """ """
        self.selected["soil_zone"] = [k for k,v in kwargs.items() if v]
    
    def bttn_handler(self, b):
        """ """
        self.selected["By"] = self.bybutton.value

    
    def draw(self):
        """ """
        self.output.clear_output()                               
        data = self.xrs
        
        By = self.bybutton.value
        if By is not "day":                                     # maybe aggregate
            data = data.groupby("time."+str(By)).mean() 
        
        with self.output:
            fig, axs = plt.subplots(nrows=self.nrow, ncols=self.ncol, **self.figure_args)

            # ax1: soil moisture
            for T in self.selected["type"]:
                D1 = data.filter_by_attrs(type=T)
                for Z in self.selected["soil_zone"]:
                    D2 = D1.filter_by_attrs(soil_zone=Z).sel(stat="Mean").mean(dim="sample")
                    self.tmpdf = calc_xrarrays(D2)
                    self.tmpdf.plot(x="Time", y="Mean", ax=axs[0])            
            axs[0].set_title("Mean soil moisture by platform type")                  # axis title
            axs[0].set_ylabel("m3/m3")
            
            # ax2: productivity
            data2 = data.filter_by_attrs(units="g m-2 d-1")
            for name, ds in data2.items():
                axs[1].set_title(ds.attrs["description"])                        # axis title
                tdata = ds.sel(stat="Mean")
                tmean, tstd = tdata.mean("sample"), tdata.std("sample")# get mean,std of sample dimension
                tmean.plot.line(label=name, ax=axs[1], add_legend=False, marker=None)         # plot a line
            axs[1].set_ylabel("g m-2 d-1")                                       # axis labels
            
            plt.show()

In [54]:
selected_layer = app.layers.iloc[app.selected]
p = Plotter(selected_layer)
p.ui

VBox(children=(HBox(children=(Map(basemap={'url': 'https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_z…

In [58]:
p.tmpdf.max(skipna=True)

Time   NaN
Mean   NaN
Std    NaN
dtype: float64

In [None]:
fds = p.xr.filter_by_attrs(type='airborne').sel(stat="Mean")
xr.concat([fds[d] for d in fds]).expand_dims()

In [None]:
selected_layer = app.layers.iloc[app.selected]
xrdataset = selected_layer.xr

mapw1, output1 = poly_mapper(selected_layer)                                 # get some widgets
display(HBox([mapw1, output1]))

plotters = [(xrdataset["SMAP_surface"], 0, None, None),                      # a list of lines to plot and
            (xrdataset["SMAP_rootzone"], 0, None, "m3/m3"),                  #   their respective properties:
            (xrdataset["GPP_mean"], 1, "green", None),                       #   dataset, axis, color
            (xrdataset["NEE_mean"], 1, "purple", "g m-2 d-1")]



with output1:
    fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(15, 8))  # initialize the figure
    for p in plotters:                                                       # loop over list above
        data, axis_num, line_color, y_label = p                              # dataset, axis, color
        line_args = {"x": "time", "add_legend": False,                       # a dictionary of line properties
                    "ax": axs[axis_num], 
                    "color": line_color}                                     
        data = data.sel(stat="Mean")
        tmpmean, tmpstd = data.mean("sample"), data.std("sample")            # get mean,std of sample dimension
        tmpmean.plot.line(label=data.attrs["description"], **line_args)      # plot a line
        axs[axis_num].set_title(None)                                        # axis title
        axs[axis_num].set_ylabel(y_label)                                    # plot axis labels
    
    axs[0].legend(loc=0, framealpha=1); axs[1].legend(loc=0, framealpha=1)   # set legend properties
    axs[0].set_title("SMAP Datasets: USFS site 1", loc='left')               # plot title
    axs[0].set_xlabel(None)                                                  # remove axis 1 xlabel
    plt.show()

In [None]:
get_nan_plot(layer_row.layer.nan)

In [None]:
lyr = app.layers.iloc[1]
lyrlyr = lyr.layer
testnan = lyrlyr.nan

a = testnan["in situ"]
b = testnan["airborne"]
c = testnan["spaceborne"]

a["summary"].head(5)

In [None]:
ds = lyr.xr
dsnull = ds.sel(stat="Mean").isnull()
NEE_mean = dsnull["NEE_mean"]
np.count_nonzero(NEE_mean.data)

In [None]:
dsnull.apply(lambda x: x is True, dims)

In [None]:
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as dt


def get_nan_summary(xrds): 
    """ """
    nandict = {"in situ": dict(), "airborne": dict(), "spaceborne": dict()}

    for pt in nandict.keys():

        # get the datasets for the current platform
        ds = xrds.filter_by_attrs(type=pt).sel(stat="Mean", drop=True)

        # get fraction of null dataset by timestep for all samples
        dsnull = ds.isnull()

        # get variables with nodata; variables with data; valid counts
        nodata, yesdata, data = [], [], {}
        for name, dataset in dsnull.items():
            if dataset.data.all(): #dataset.isnull().mean()==1
                nodata.append(name)
            else:
                yesdata.append(name)
                dataset.data = np.logical_not(dataset.data)
                data[name] = dataset.mean("sample").data
        
        nandict[pt].update({
            "nodata": nodata, 
            "yesdata": yesdata, 
            "summary": pd.DataFrame(data, index=ds.time.data)})
        
    return(nandict)

In [None]:
nands = get_nan_summary(ds)
spaceborne = nands['spaceborne']["summary"]
spaceborne.head(5)

In [None]:
spaceborne.()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)
ax.imshow(
    nstack, 
    aspect='auto', 
    cmap="tab20c", 
    interpolation='nearest')

In [None]:
tnancount = np.count_nonzero(data)
                yesdata[name].append(potential_obs_count-obstotal)

In [None]:
        # update summary dictionary
        ix = list(range(samplelen))+["nan"]   
        nandict[pt].update({                      
            "nodata": nodata, 
            "yesdata": yesdata, 
            "summary": pd.DataFrame(obscount, index=ix)})

In [None]:
        if len(data)>0:


In [None]:
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as dt

df = pd.read_csv('data.csv')
df.amin = pd.to_datetime(df.amin).astype(datetime)
df.amax = pd.to_datetime(df.amax).astype(datetime)

fig = plt.figure()
ax = fig.add_subplot(111)
ax = ax.xaxis_date()
ax = plt.hlines(df.index, dt.date2num(df.amin), dt.date2num(df.amax))

In [None]:
def poly_mapper(lyr, mw="25%", ow="75%", zoom=8):
    """Generates the map/plot side-by-side widget container."""

    l = (lyr["layer"].layer, lyr.points, bmap,)
    m = Map(layers=l, center=(lyr["lat"], lyr["lon"]), zoom=zoom, width=mw)
    o = Output(layout={"width": ow})
    
    return((m,o))

In [None]:
units = [
    'g m-2 d-1', 
    'm3/m3', 
    'degrees C', 
    'mm/day']

In [None]:
lyr = app.layers.iloc[5]
lyrxr = lyr.xr

In [None]:
def get_plottable(xds):
    """ """
    plotvars = []
    for v in list(xds.variables):
        if v not in ignorevars:
            if not allnan(xds[v]):
                plotvars.append(v)
    return(plotvars)

In [None]:
mapw2, output2 = poly_mapper(lyr)
display(HBox([output2, mapw2]))

insitu = lyrxr.filter_by_attrs(type="in situ")            # select in situ, and
airborne = lyrxr.filter_by_attrs(type="airborne")         # airborne datasets

widgets = dict(
    Source=,
    Type=["in situ", "airborne"],
    Depth=['surface', 'rootzone'],
    By=["None", "year", "month", "week", "day"])


def update(Source, Type, Depth):
    
    

with output2:
    fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(15, 8))
    plotters = [(insitu, axs[0], "In situ soil moisture"), 
                (airborne, axs[1], "Airborne soil moisture")]
    for p in plotters:
        data, ax, title = p
        data = data.sel(stat="Mean").mean(dim="sample")         # select stat "Mean" and site N. avg sample dim
        tmpstack = np.stack([data[v].data for v in data])       # collapse dataset into a stacked array
        tmpmean = np.nanmean(tmpstack, axis=0)                  # calculate mean over time axis (0)
        tmpstd = np.nanstd(tmpstack, axis=0)                    # calculate mean over time axis (0)
        tmptime = data.time.data
        
        ax.grid('on', alpha=0.25)
        ax.plot_date(x=tmptime, y=tmpmean, color="black", linestyle="solid", marker=None)
        ax.fill_between(tmptime, (tmpmean-tmpstd), (tmpmean+tmpstd), color="gray", alpha=0.2)
        ax.set_title(title+" [avg of available SMV datasets]")
        ax.set_ylabel("m3/m3")
        if not np.count_nonzero(~np.isnan(tmpmean))==0:
            ax.set_ylim(np.nanmin(tmpmean), np.nanmax(tmpmean))
        
    
    widgets = dict(Dataset=plotvars, Statistic=["Mean","Min","Max"])
    p = interactive(update, **widgets)
    display(p)

In [None]:
lyr = app.layers.iloc[0]
xds = lyr.xr
xdsm = xds.sel(stat="Mean", drop=True)
xdsm

In [None]:
xdsm["SMAP_surface"]

In [None]:
samp = lyr.samples.iloc[0].samp
sampdf = samp.df.AirMOSS_L4_rootzone

In [None]:
platform_types = {"in situ": {}, "airborne": {}, "spaceborne": {}}

for pt in platform_types.keys():
    
    # get the datasets for the current platform
    pds = xds.filter_by_attrs(type=pt).sel(stat="Mean", drop=True)
    timelen, samplelen = pds.time.size, pds.sample.size
    potential_obs_count = timelen*samplelen
    
    # get variables with nodata; variables with data; valid counts
    nodata, yesdata, obscount = [], [], {}
    for name, dataset in pds.items():
        if allnan(dataset):
            nodata.append(name)
        else:
            yesdata.append(name)
            obscount[name], obstotal = [], 0
            for i in range(samplelen):
                samp = dataset.sel(sample=i)
                count = numvalid(samp)
                obscount[name].append(count)
                obstotal += count
            obscount[name].append(potential_obs_count-obstotal)

    # update summary dictionary
    ix = list(range(samplelen))+["nan"]
    platform_types[pt].update({
        "nodata": nodata, 
        "yesdata": yesdata, 
        "summary": pd.DataFrame(obscount, index=ix)})

In [None]:
spaceborne = platform_types["spaceborne"]
spacebornedf = spaceborne["summary"]
spacebornedf

In [None]:
def get_nan_summary(nandict): 
    
    n, stypes, sratio = 0, [], {'height_ratios': []}
    for stype, nandata in nandict.items():
        cnt = len(nandata["yesdata"])
        if cnt!=0:
            n += 1
            stypes.append((stype, nandata["summary"]))
            sratio['height_ratios'].append(cnt)

    fig, axs = plt.subplots(n, 1, sharex=True, figsize=(7,4), gridspec_kw=sratio)
    for i, d in enumerate(stypes):
        d[1].T.plot.barh(stacked=True, ax=axs[i], colormap="tab20c", legend=False)
    
    fig.tight_layout()
    leg = dict(title="samples", loc="top right", bbox_to_anchor=(1, 1), framealpha=1) 
    axs[0].legend(ncol=ceil(len(stypes[0][1].index)/10), **leg)
    axs[0].set_title("n observations by dataset")
    
    return(fig)

In [None]:
testfig = get_nan_summary(platform_types)
plt.show()

In [None]:
get_legend_ncols = lambda df: ceil(len(df.index)/10)
legends_args = dict(
    title="samples", 
    loc=0, #"top left", 
    #bbox_to_anchor=(0.75, 1), 
    framealpha=1) 

def get_nan_summary(xrdataset): 
    """ """
    nandict = {"in situ": {}, "airborne": {}, "spaceborne": {}}

    for pt in nandict.keys():

        # get the datasets for the current platform
        pds = xrdataset.filter_by_attrs(type=pt).sel(stat="Mean", drop=True)
        timelen, samplelen = pds.time.size, pds.sample.size
        potential_obs_count = timelen*samplelen

        # get variables with nodata; variables with data; valid counts
        nodata, yesdata, obscount = [], [], {}
        for name, dataset in pds.items():
            if allnan(dataset):
                nodata.append(name)
            else:
                yesdata.append(name)
                obscount[name], obstotal = [], 0
                for i in range(samplelen):
                    samp = dataset.sel(sample=i)
                    count = numvalid(samp)
                    obscount[name].append(count)
                    obstotal += count
                obscount[name].append(potential_obs_count-obstotal)
        
        # update summary dictionary
        ix = list(range(samplelen))+["nan"]   
        nandict[pt].update({                      
            "nodata": nodata, 
            "yesdata": yesdata, 
            "summary": pd.DataFrame(obscount, index=ix)})

    n, stypes, sratio = 0, [], {'height_ratios': []}
    for stype, nandata in nandict.items():
        cnt = len(nandata["yesdata"])
        if cnt!=0:
            n += 1
            stypes.append((stype, nandata["summary"]))
            sratio['height_ratios'].append(cnt)

    if n>1:
        fig, axs = plt.subplots(nrows=n, gridspec_kw=sratio, **figure_args)
        for i, d in enumerate(stypes):
            d[1].T.plot.barh(ax=axs[i],**bar_args)
    else:
        fig, ax = plt.plot(figsize=(7,4))
        stypes[0][1].T.plot.barh(ax=ax,**bar_args)

    fig.tight_layout()
    axs[0].legend(ncol=get_legend_ncols(stypes[0][1]), **legends_args)
    axs[0].set_title("n observations by dataset")

    plt.show()

In [None]:
get_nan_summary(xds)

In [None]:
%matplotlib inline
font = {"family": "normal", "weight": "normal", "size": 14}       # some matplotlib settings ->
plt.rc("font", **font)      
plt.rcParams['figure.figsize'] = [14, 5]                          #