In [None]:
!pip uninstall deafrica-tools -y
!pip install ../Tools

In [None]:
# Load modules
import datacube
import seaborn as sns
import matplotlib.pyplot as plt
from odc.ui import select_on_a_map
from datacube.utils.geometry import CRS
from ipyleaflet import WMSLayer, basemaps, basemap_to_tiles, Map, DrawControl, WidgetControl, LayerGroup
from traitlets import Unicode
from ipywidgets import GridspecLayout, Button, Layout, jslink, IntText, IntSlider, DatePicker, HBox, VBox, Text, BoundedFloatText, HTML, Dropdown
import json
import geopandas as gpd
from io import BytesIO

# from deafrica_tools.spatial import reverse_geocode
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.wetlands import WIT_drill

In [None]:
def create_map(center=(4, 20), zoom=3):
    """Create the base ipyleaflet map used by the app.

    Parameters
    ----------
    center : tuple of float, optional
        (lat, lon) the map is initially centred on; the default is roughly
        the middle of the African continent.
    zoom : int, optional
        Initial Leaflet zoom level.

    Returns
    -------
    ipyleaflet.Map
        Map whose basemap is the CartoDB Positron tile layer.
    """
    basemap_cartodb = basemap_to_tiles(basemaps.CartoDB.Positron)

    return Map(center=center, zoom=zoom, basemap=basemap_cartodb)

In [None]:
class _TimeWMSLayer(WMSLayer):
    """WMSLayer with an extra ``time`` option sent along with the WMS request.

    Defined once at module level rather than inside ``create_deafrica_layer``
    so a new widget class is not created on every call.
    """
    # sync=True mirrors the trait to the front end; o=True is ipyleaflet's
    # marker for traits passed through as Leaflet layer options.
    time = Unicode("").tag(sync=True, o=True)


def create_deafrica_layer(product, date, url="https://ows.digitalearth.africa/"):
    """Create a time-aware DE Africa WMS layer for ``product``.

    Parameters
    ----------
    product : str
        WMS layer name, e.g. 'gm_s2_annual' or 'wofs_ls_summary_annual'.
    date : str
        Value for the layer's ``time`` option (the app passes 'YYYY-01-01').
    url : str, optional
        WMS endpoint; defaults to the DE Africa OWS service.

    Returns
    -------
    _TimeWMSLayer
        Transparent PNG WMS layer attributed to Digital Earth Africa.
    """
    return _TimeWMSLayer(
        url=url,
        layers=product,
        time=date,
        format="image/png",
        transparent=True,
        attribution="Digital Earth Africa",
    )

In [None]:
def create_datepicker(description):
    """Return an enabled ``DatePicker`` labelled with ``description``.

    The picker starts with no value selected.
    """
    return DatePicker(disabled=False, description=description)

In [None]:
def create_drawcontrol():
    """Build a DrawControl that only offers the rectangle and polygon tools.

    Both shapes are drawn with a fully opaque fill (``fillOpacity`` 1.0).
    Polygons are configured with ``allowIntersection=False`` and a
    ``drawError`` style/message. The marker, circle, circle-marker and
    polyline tools are disabled.

    Returns
    -------
    ipyleaflet.DrawControl
    """
    control = DrawControl()

    rectangle_shape = {"fillColor": "#fca45d", "color": "#fca45d", "fillOpacity": 1.0}
    control.rectangle = {"shapeOptions": rectangle_shape}

    polygon_shape = {"fillColor": "#6be5c3", "color": "#6be5c3", "fillOpacity": 1.0}
    polygon_error = {"color": "#dd253b", "message": "Error!"}
    control.polygon = {
        "shapeOptions": polygon_shape,
        "drawError": polygon_error,
        "allowIntersection": False,
    }

    # Disable the remaining drawing tools by giving them empty option dicts.
    for tool_name in ("marker", "circle", "circlemarker", "polyline"):
        setattr(control, tool_name, {})

    return control

In [None]:
def create_inputtext(value, placeholder, description):
    """Return an enabled single-line ``Text`` widget.

    Parameters
    ----------
    value : str
        Initial contents of the box.
    placeholder : str
        Hint text shown while the box is empty.
    description : str
        Label displayed next to the box.
    """
    options = dict(
        value=value,
        placeholder=placeholder,
        description=description,
        disabled=False,
    )
    return Text(**options)
    

In [None]:
def create_boundedfloattext(value, description, min_val, max_val, step_val):
    """Return an enabled ``BoundedFloatText`` constrained to [min_val, max_val].

    Parameters
    ----------
    value : float
        Initial value.
    description : str
        Label displayed next to the input.
    min_val, max_val : float
        Lower and upper bounds passed as ``min`` / ``max``.
    step_val : float
        Increment passed as ``step``.
    """
    bounds = {"min": min_val, "max": max_val, "step": step_val}
    return BoundedFloatText(
        value=value,
        description=description,
        disabled=False,
        **bounds,
    )

In [None]:
def create_html(value):
    """Return an ``HTML`` widget whose content is ``value``."""
    return HTML(value=value)

In [None]:
def create_dropdown(options, value, description):
    """Return a ``Dropdown`` widget.

    Parameters
    ----------
    options : list
        Choices; the app passes (label, value) pairs.
    value
        Initially selected value.
    description : str
        Label displayed next to the dropdown.
    """
    widget_kwargs = {"options": options, "value": value, "description": description}
    return Dropdown(**widget_kwargs)

In [None]:
def create_expanded_button(description, button_style):
    """Return a Button with auto height and width so it can stretch with its container."""
    stretch_layout = Layout(height='auto', width='auto')
    return Button(
        description=description,
        button_style=button_style,
        layout=stretch_layout,
    )
    


class wit_app(HBox):
    """Interactive Wetlands Insight Tool (WIT) app.

    Lays out a parameter panel (date range, minimum good-data fraction, output
    file names, optional DE Africa WMS overlay), a map with a draw control, a
    Run button and an HTML log. Clicking Run calls ``WIT_drill`` on the most
    recently drawn shape and exports the result to ``out_csv``.
    """

    def __init__(self):
        super().__init__()

        from datetime import date

        ##########################################################

        # set any initial attributes here
        self.startdate = '2019-01-01'
        self.enddate = '2019-03-01'
        self.mingooddata = 0.8
        self.out_csv = 'example_WIT.csv'
        self.out_plot = 'example_WIT.png'
        self.product_list = [('None', 'none'), ('Sentinel-2 Geomedian', 'gm_s2_annual'), ('Water Observations from Space', 'wofs_ls_summary_annual')]
        self.product = self.product_list[0][1]
        self.product_year = '2019-01-01'
        self.target = None
        self.action = None
        self.gdf_drawn = None
        # DataFrame returned by the most recent successful WIT run, else None.
        self.wit_results = None
        # The local dask cluster is started on the first run only.
        self._dask_started = False

        ##########################################################

        # Group holding the optional DE Africa WMS overlay
        self.deafrica_layers = LayerGroup(layers=())

        # Create map widget
        self.m = create_map()
        draw_control = create_drawcontrol()
        self.m.add_control(draw_control)
        self.basemap = self.m.basemap
        self.m.add_layer(self.deafrica_layers)

        ##########################################################

        # Create parameter widgets
        startdate_picker = create_datepicker('Start Date')
        enddate_picker = create_datepicker('End Date')
        # Show the defaults in the pickers. The observers are attached below,
        # so these assignments do not overwrite self.startdate / self.enddate.
        startdate_picker.value = date.fromisoformat(self.startdate)
        enddate_picker.value = date.fromisoformat(self.enddate)
        min_good_data = create_boundedfloattext(self.mingooddata, 'Min Good Data', 0.0, 1.0, 0.05)
        output_csv = create_inputtext(self.out_csv, self.out_csv, 'Output CSV')
        output_plot = create_inputtext(self.out_plot, self.out_plot, 'Output Plot')
        basemap_dropdown = create_dropdown(self.product_list, self.product_list[0][1], 'Basemap')

        parameter_selection = VBox(
            [
                startdate_picker,
                enddate_picker,
                min_good_data,
                output_csv,
                output_plot,
                basemap_dropdown,
            ]
        )

        ##########################################################

        # Run button and log
        run_button = create_expanded_button('Run', 'info')
        self.paramlog = create_html('Log')

        ##########################################################

        # Create the layout #[rowspan, colspan]
        grid = GridspecLayout(5, 3, height='800px', width='1000px')
        grid[0, :1] = parameter_selection
        grid[0:2, 1:] = self.m
        grid[1, :1] = run_button
        grid[2:, :] = self.paramlog

        # Display using HBox children attribute
        self.children = [grid]

        ##########################################################

        # Run update functions whenever various widgets are changed.
        startdate_picker.observe(self.update_startdate, 'value')
        enddate_picker.observe(self.update_enddate, 'value')
        min_good_data.observe(self.update_mingooddata, 'value')
        output_csv.observe(self.update_outputcsv, 'value')
        output_plot.observe(self.update_outputplot, 'value')
        basemap_dropdown.observe(self.update_basemap, 'value')
        run_button.on_click(self.run_app)
        draw_control.on_draw(self.update_geojson)

        ##########################################################

    def update_geojson(self, target, action, geo_json):
        """DrawControl callback: remember the most recently drawn/edited shape.

        Parameters
        ----------
        target : DrawControl
            Control that fired the event (not used).
        action : str
            Draw action, e.g. 'created', 'edited' or 'deleted'.
        geo_json : dict
            GeoJSON feature of the shape the action applied to.
        """
        self.action = action

        if action == 'deleted':
            # geo_json is the shape that was just removed, so it must not
            # become the run polygon; the user has to draw a new one.
            self.gdf_drawn = None
            return

        gdf = gpd.read_file(BytesIO(json.dumps(geo_json).encode()))
        gdf.crs = "EPSG:4326"

        self.gdf_drawn = gdf

    # set the start date to the new edited date
    def update_startdate(self, change):
        """Store the new start date and derive the overlay year from it."""
        self.startdate = change.new
        if change.new is None:
            # Picker was cleared; run_app refuses to run until a date is set.
            return
        self.product_year = f'{change.new.year}-01-01'
        # Keep any visible overlay in step with the selected year.
        self._refresh_layer()

    # set the end date to the new edited date
    def update_enddate(self, change):
        self.enddate = change.new

    # set the min good data
    def update_mingooddata(self, change):
        self.mingooddata = change.new

    # set the output csv
    def update_outputcsv(self, change):
        self.out_csv = change.new

    # set the output plot
    def update_outputplot(self, change):
        self.out_plot = change.new

    def _refresh_layer(self):
        """Redraw the DE Africa overlay for the current product and year."""
        self.deafrica_layers.clear_layers()
        if self.product != 'none':
            layer = create_deafrica_layer(self.product, self.product_year)
            self.deafrica_layers.add_layer(layer)

    # Update product
    def update_basemap(self, change):
        """Switch the DE Africa overlay to the selected product ('none' hides it)."""
        self.product = change.new
        self._refresh_layer()

    def run_app(self, change):
        """Run WIT on the drawn shape with the current parameters.

        Parameters
        ----------
        change : ipywidgets.Button
            The clicked button, supplied by ``Button.on_click`` (not used).
        """
        from html import escape

        # Validate inputs before starting dask or calling WIT_drill.
        if self.gdf_drawn is None:
            self.paramlog.value = 'Please draw a rectangle or polygon on the map before running WIT.'
            return
        if self.startdate is None or self.enddate is None:
            self.paramlog.value = 'Please select both a start date and an end date.'
            return

        text = ' '.join([f'{self.target}', f'{self.action}', f'{self.gdf_drawn.crs}', f'{self.product_year}', f'{self.startdate}', f'{self.enddate}', f'{self.mingooddata}', f'{self.out_csv}', f'{self.out_plot}', f'{self.product}'])
        self.paramlog.value = text

        # Configure local dask cluster once per app instance rather than on
        # every click (presumably each call would start another cluster).
        if not self._dask_started:
            create_local_dask_cluster()
            self._dask_started = True

        # Set any defaults
        resample_frequency = '1M'
        TCW_threshold = -0.035
        dask_chunks = dict(x=1000, y=1000, time=1)

        self.paramlog.value = 'Running WIT'
        try:
            # run wetlands polygon drill
            df = WIT_drill(
                gdf=self.gdf_drawn,
                time=(self.startdate, self.enddate),
                min_gooddata=self.mingooddata,
                resample_frequency=resample_frequency,
                TCW_threshold=TCW_threshold,
                export_csv=self.out_csv,
                dask_chunks=dask_chunks,
                verbose=False,
            )
        except Exception as err:
            # Surface the failure in the visible log (escaped, since the log
            # is an HTML widget), then re-raise so the traceback is kept.
            self.paramlog.value = escape(f'WIT failed: {err}')
            raise
        self.wit_results = df
        self.paramlog.value = 'WIT Complete'

        # TODO: self.out_plot is not used until this plotting code is enabled.
        # Before enabling it: call plt.savefig before plt.show (saving after
        # show can write a blank figure), and drop fontsize from plt.axis
        # (recent matplotlib rejects unknown keyword arguments there).

#         # ---Plotting------------------------------

#         fontsize = 17
#         # set up color palette
#         pal = [
#             sns.xkcd_rgb["cobalt blue"],
#             sns.xkcd_rgb["neon blue"],
#             sns.xkcd_rgb["grass"],
#             sns.xkcd_rgb["beige"],
#             sns.xkcd_rgb["brown"],
#         ]

#         # make a stacked area plot
#         plt.clf()
#         fig = plt.figure(figsize=(22, 6))
#         plt.stackplot(
#             df.index,
#             df.wofs_area_percent,
#             df.wet_percent,
#             df.green_veg_percent,
#             df.dry_veg_percent,
#             df.bare_soil_percent,
#             labels=[
#                 "open water",
#                 "wet",
#                 "green veg",
#                 "dry veg",
#                 "bare soil",
#             ],
#             colors=pal,
#             alpha=0.6,
#         )

#         # set axis limits to the min and max
#         plt.axis(xmin=df.index[0], xmax=df.index[-1], ymin=0, ymax=100, fontsize=fontsize)
#         plt.tick_params(labelsize=fontsize)
#         # add a legend and a tight plot box
#         plt.legend(loc="lower left", framealpha=0.6, fontsize=fontsize)
#         plt.title("Fractional Cover, Wetness, and Water", fontsize=fontsize)
#         plt.tight_layout()
#         plt.show()
#         if export_plot:
#             if verbose:
#                 print("exporting plot: " + export_plot)
#             # save the figure
#             plt.savefig(f"{export_plot}")
        
    

In [None]:
# Keep a named handle on the app so its state (e.g. the drawn polygon in
# app.gdf_drawn) can be inspected from later cells; the bare last-line
# expression displays the widget.
app = wit_app()
app