In [68]:
import os
import io
import functools

import rioxarray
import xarray as xr
import numpy as np
from pyproj import Transformer
from shapely.geometry import Point, Polygon, shape, mapping
from shapely import ops
from shapely.affinity import affine_transform
import matplotlib.pyplot as plt
from PIL import Image as PILImage
from PIL import ImageDraw
import bqplot
from bqplot import Scatter, Lines
from ipyevents import Event
from affine import Affine
from rasterio import warp
from pyproj import CRS


import ipywidgets as ipw
from IPython.display import display, clear_output
from ipyleaflet import Map, TileLayer, GeoJSON, basemaps, WidgetControl

In [11]:
# Path to the NetCDF data cube. Set the CUBE_PATH environment variable to point
# elsewhere instead of editing a user-specific absolute path in the notebook.
CUBE_PATH = os.environ.get('CUBE_PATH', '/home/loic/Downloads/czechia_nrt_test.nc')
cube = xr.open_dataset(CUBE_PATH)

In [64]:
# ESRI World Imagery Wayback releases as (slider label, release id) pairs; the
# release id is the number inserted into the wayback tile URL below.
date_timeId_mapping = [
    ('2016-01-13', 3515),
    ('2017-01-11', 577),
    ('2018-01-08', 13161),
    ('2019-01-09', 6036),
    ('2020-01-08', 23001),
    ('2021-01-13', 1049),
    ('2022-01-12', 42663),
    ('2023-01-11', 11475),
]

# Sample features to label: GeoJSON-like Point features whose coordinates are in
# the CRS of the cube; 'idx' is a 1-based sample number.
fc = [
    {'type': 'Feature',
     'geometry': {'type': 'Point', 'coordinates': [x, y]},
     'properties': {'idx': sample_no}}
    for sample_no, (x, y) in enumerate([(4813210, 2935950), (4815170, 2936650)], start=1)
]


######
## Google webmap
# The Map's basemap is OpenStreetMap; Google tiles (presumably satellite imagery,
# `lyrs=s`) are stacked on top as an extra tile layer.
google = TileLayer(url="https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}")
leaflet_google = Map(
    basemap=basemaps.OpenStreetMap.Mapnik,
    center=(0, 0),
    scroll_wheel_zoom=True,
    layout=ipw.Layout(height='400px', width='100%'),
)
leaflet_google.add(google)

#################################
# ESRI with wayback functionality
# Tile URL of the ESRI World Imagery Wayback WMTS service. `%d` receives the
# release id selected on the slider; {z}/{y}/{x} are filled in by Leaflet.
# Defined once so the initial layer and the slider callback cannot drift apart.
WAYBACK_URL_TEMPLATE = ('https://wayback.maptiles.arcgis.com/arcgis/rest/services/'
                        'world_imagery/wmts/1.0.0/default028mm/mapserver/tile/%d/{z}/{y}/{x}')

# Date slider widget: option labels are dates, option values are release ids
wayback_slider = ipw.SelectionSlider(
    options=date_timeId_mapping,
    value=date_timeId_mapping[0][1],
    description='Date',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True
)
time_control = WidgetControl(widget=wayback_slider, position='topright')

# Prepare map, tilelayer and draw it
leaflet_wayback = Map(center=(0, 0),
                      scroll_wheel_zoom=True)
wayback = TileLayer(url=WAYBACK_URL_TEMPLATE % wayback_slider.value,
                    max_zoom=25)
# Map.add handles both layers and controls (add_layer/add_control are deprecated)
leaflet_wayback.add(wayback)
leaflet_wayback.add(time_control)


def on_date_change(*args):
    """Point the wayback tile layer at the release selected on the slider and redraw it."""
    wayback.url = WAYBACK_URL_TEMPLATE % wayback_slider.value
    wayback.redraw()


wayback_slider.observe(on_date_change, 'value')

In [65]:
class ColorComposite(object):
    """Transform to create a stretched RGB image in numpy format from a multivariate xarray Dataset

    Works only for a Dataset with a single temporal slice (e.g. only x and y coordinate variables)

    Args:
        b, g, r (str): Names of the variables used as blue, green and red bands.
        blim, glim, rlim (sequence of 2 numbers): (min, max) stretch limits per band.
    """
    def __init__(self, b='B02_20', g='B03_20', r='B04_20',
                 blim=(20, 2000), glim=(50, 2000), rlim=(20, 2000)):
        # Tuples instead of lists: defaults are evaluated once and shared by all
        # instances, so they should be immutable.
        self.b = b
        self.g = g
        self.r = r
        self.blim = blim
        self.glim = glim
        self.rlim = rlim

    @staticmethod
    def stretch(arr, blim=(20, 2000), glim=(50, 2000), rlim=(20, 2000)):
        """Apply linear color stretching and [0, 1] clipping to a 3-band image

        Args:
            arr (np.ndarray): Array whose last dimension holds the bands in RGB order,
                e.g. ``(rows, cols, 3)``. (If read in 'wavelength order', satellite
                images are usually BGR.)
            blim, glim, rlim (sequence): (min, max) values between which to stretch
                the individual bands.

        Returns:
            np.ndarray: Array of the same shape as ``arr`` with values in [0, 1]
            (NaN inputs stay NaN).
        """
        # Shape (3,) broadcasts against the trailing band axis whatever the number
        # of leading dimensions, so the output keeps the input's shape.
        bottom = np.array([rlim[0], glim[0], blim[0]])
        top = np.array([rlim[1], glim[1], blim[1]])
        arr_stretched = (arr - bottom) / (top - bottom)
        return np.clip(arr_stretched, 0.0, 1.0)

    def __call__(self, ds):
        """Return the stretched ``(rows, cols, 3)`` RGB composite of a single-time slice ``ds``."""
        rgb = np.stack([ds[self.r].values, ds[self.g].values, ds[self.b].values], axis=-1)
        return ColorComposite.stretch(rgb, blim=self.blim, glim=self.glim, rlim=self.rlim)


class NDVI(object):
    """Normalized Difference Vegetation Index, (NIR - red) / (NIR + red).

    Args:
        red (str): Name of the red band variable.
        nir (str): Name of the near-infrared band variable.
    """
    def __init__(self, red='B04_20', nir='B8A'):
        self.red = red
        self.nir = nir

    def __call__(self, ds):
        """Compute NDVI element-wise from ``ds[self.nir]`` and ``ds[self.red]``.

        Both bands are cast to float32 before any arithmetic. Previously only the
        numerator was cast, so with integer bands (e.g. uint16) the denominator
        ``nir + red`` could overflow and yield wrong values. A tiny epsilon in the
        denominator avoids division by zero.
        """
        nir = ds[self.nir].astype(np.float32)
        red = ds[self.red].astype(np.float32)
        return (nir - red) / (nir + red + 0.0000001)


def np2ipw(arr, geom=None, transform=None, res=20, scale=4, outline_color='magenta'):
    """Render an RGB array as an ``ipywidgets.Image`` with a geometry outline drawn on it.

    Args:
        arr (np.ndarray): ``(rows, cols, 3)`` RGB array meant to hold values in [0, 1];
            values outside that range are clipped and NaN is rendered as 0 (black).
        geom (dict): GeoJSON-like geometry, in the CRS of ``transform``, to outline.
            A Point is drawn as a square of side ``res`` around it. When ``geom`` or
            ``transform`` is None, no outline is drawn.
        transform (affine.Affine): Pixel-to-CRS transform of ``arr``.
        res (float): Pixel size, used to turn a Point into a pixel-sized square.
        scale (int): Nearest-neighbour upscaling factor applied to the image.
        outline_color (str): PIL colour of the outline.

    Returns:
        ipywidgets.Image: PNG-encoded image widget.
    """
    # NaN / out-of-range values produce garbage when cast to uint8
    arr = np.clip(np.nan_to_num(arr, nan=0.0), 0.0, 1.0)
    img = PILImage.fromarray((arr * 255).astype(np.uint8))
    # PIL sizes are (width, height), i.e. (cols, rows) — not arr.shape order
    img = img.resize(size=(arr.shape[1] * scale, arr.shape[0] * scale),
                     resample=PILImage.NEAREST)

    if geom is not None and transform is not None:
        # Inverse of the upscaled pixel transform: CRS coords -> pixel coords of img
        pixel_transform = ~(transform * Affine.scale(1 / scale))

        # Get polygon coordinates in PIL format
        shape_ = shape(geom)
        if isinstance(shape_, Point):
            # cap_style=3 -> square buffer covering the pixel around the point
            shape_ = shape_.buffer(res / 2, cap_style=3)
        shape_ = affine_transform(shape_, pixel_transform.to_shapely())
        x, y = shape_.exterior.coords.xy
        polygon_coordinates = list(zip(x, y))

        # Draw polygon on image
        ImageDraw.Draw(img).polygon(polygon_coordinates, fill=None, outline=outline_color)

    # Save image to fileobject and reload as ipw.Image
    with io.BytesIO() as fileobj:
        img.save(fileobj, 'PNG')
        img_b = fileobj.getvalue()
    return ipw.Image(value=img_b)

In [70]:
class Interface(object):
    """Interactive interface to inspect and label breakpoints on a list of features.

    The user steps through ``features`` with Previous/Next. For the current feature
    the interface shows the feature outline on ``webmap``, a temporal profile of
    ``vi_calculator`` at the feature location, and one colour-composite chip per
    time step of ``cube``. Clicking a chip toggles a breakpoint at its date, stored
    in ``self.breakpoints``.
    """
    def __init__(self, cube, features,
                 compositor=ColorComposite(),
                 vi_calculator=NDVI(),
                 window_size=500,
                 outline_color='magenta',
                 webmap=leaflet_wayback):
        """
        Args:
            cube (xarray.Dataset): Data cube with ``x``, ``y`` and ``time`` coordinates
                and a CRS readable through ``cube.rio``.
            features (list): GeoJSON-like feature dicts whose geometries are in the CRS
                of ``cube``. ``get_ts`` currently only supports Point geometries.
            compositor (callable): Turns a single-time slice of the cube into an RGB
                array in [0, 1] (e.g. ``ColorComposite``).
            vi_calculator (callable): Computes a vegetation index from the cube
                (e.g. ``NDVI``).
            window_size (float): Distance, in cube CRS units, by which the feature is
                buffered to get the chip extent (half-width of the window for a Point).
            outline_color (str): Colour of the feature outline on chips and web map.
            webmap (ipyleaflet.Map): Map on which the current feature is drawn.
        """
        self.cube = cube
        self.features = features
        self.compositor = compositor
        self.vi_calculator = vi_calculator
        self.window_size = window_size
        self.current_idx = 0
        self.outline_color = outline_color
        self.webmap = webmap
        self.webmap_geom = None
        # Dates toggled by clicking chips.
        # NOTE(review): shared by all samples, not reset when navigating.
        self.breakpoints = []

        # Widgets
        self.next_button = ipw.Button(description='Next')
        self.previous_button = ipw.Button(description='Previous')
        self.label_input = ipw.Text(description='Label')
        self.output_area = ipw.Output()

        # Bind events
        self.next_button.on_click(self.next_sample)
        self.previous_button.on_click(self.previous_sample)

        # Initialize display
        self.draw_webmap()
        self.update_display()

    def _current_geom_wgs84(self):
        """Return the current feature footprint reprojected to EPSG:4326, and its centroid.

        A Point feature is expanded to a square one pixel wide (cube x resolution) so
        it is visible on the map.
        """
        current_shape = shape(self.features[self.current_idx]['geometry'])
        # TODO: that part needs to change if a list of cubes is accepted by the constructor
        res = abs(self.cube.rio.resolution()[0])
        if isinstance(current_shape, Point):
            # cap_style=3 -> square buffer
            current_shape = current_shape.buffer(res / 2, cap_style=3)
        current_geom = warp.transform_geom(src_crs=self.cube.rio.crs,
                                           dst_crs=CRS.from_epsg(4326),
                                           geom=mapping(current_shape))
        return current_geom, shape(current_geom).centroid

    def _outline_style(self):
        """Leaflet style of the feature outline: unfilled thin line in ``outline_color``."""
        return {'opacity': 1, 'fillOpacity': 0,
                'weight': 1, 'color': self.outline_color}

    def draw_webmap(self):
        """Add the current feature outline to ``webmap`` and centre/zoom the map on it."""
        current_geom, centroid = self._current_geom_wgs84()
        self.webmap_geom = GeoJSON(data=current_geom, style=self._outline_style())
        self.webmap.add(self.webmap_geom)
        self.webmap.center = [centroid.y, centroid.x]
        self.webmap.zoom = 17

    def update_webmap(self):
        """Move the outline layer to the current feature and recentre the map on it."""
        current_geom, centroid = self._current_geom_wgs84()
        # Update GeoJSON layers in place so layers do not stack up in self.webmap.layers.
        # NOTE(review): this rewrites *every* GeoJSON layer on the map (including ones
        # left by earlier Interface instances sharing the default map), not only
        # self.webmap_geom.
        for layer in list(self.webmap.layers):
            if isinstance(layer, GeoJSON):
                layer.data = current_geom
        # Re-assigning the style is meant to force the layer to refresh
        self.webmap_geom.style = self._outline_style()
        self.webmap.center = [centroid.y, centroid.x]
        self.webmap.zoom = 17

    def next_sample(self, b):
        """Button callback: go to the next feature (no-op on the last one)."""
        if self.current_idx < len(self.features) - 1:
            self.current_idx += 1
            self.update_webmap()
            self.update_display()

    def previous_sample(self, b):
        """Button callback: go to the previous feature (no-op on the first one)."""
        if self.current_idx > 0:
            self.current_idx -= 1
            self.update_webmap()
            self.update_display()

    def add_breakpoint(self, date):
        """Placeholder, not implemented yet."""
        pass

    def remove_breakpoint(self, date):
        """Placeholder, not implemented yet."""
        pass

    def update_display(self):
        """Re-render the figure and chips of the current feature in the output area."""
        with self.output_area:
            clear_output(wait=True)
            print(f"Sample {self.current_idx + 1}/{len(self.features)}")
            interface = self.link_components(self.current_idx)
            display(interface)

    def display_interface(self):
        """Method called by the user to display the full interface"""
        display(ipw.HBox([self.previous_button, self.next_button, self.webmap]))
        display(self.output_area)

    def get_chips(self, idx):
        """Return one ``ipw.Image`` chip per time step around feature ``idx``.

        The cube is clipped to the bounds of the feature buffered by ``window_size``;
        each time slice is rendered with ``compositor`` and the feature outline is
        drawn on it.
        """
        geom = self.features[idx]['geometry']
        bbox = shape(geom).buffer(self.window_size).bounds
        cube_sub = self.cube.rio.clip_box(*bbox)
        transform = cube_sub.rio.transform()
        res = cube_sub.rio.resolution()[0]
        chips = []
        for date in cube_sub.time.values:
            rgb = self.compositor(cube_sub.sel(time=date))
            chips.append(np2ipw(rgb, geom=geom, transform=transform,
                                res=res, outline_color=self.outline_color))
        return chips

    def get_ts(self, idx):
        """Return ``(dates, values)`` of the vegetation index at the pixel nearest to feature ``idx``."""
        # TODO: handle that differently if it's a point or a polygon (Points only for now)
        cube_sub = self.cube.sel(x=self.features[idx]['geometry']['coordinates'][0],
                                 y=self.features[idx]['geometry']['coordinates'][1],
                                 method='nearest')
        ds = self.vi_calculator(cube_sub)
        return ds.time.values, ds.values

    def link_components(self, idx):
        """Build the view of feature ``idx``: time-series figure above a grid of chips.

        Hovering a chip highlights the matching point of the time series. Clicking a
        chip toggles a breakpoint at its date: blue chip border, vertical red line on
        the figure, and the date added to / removed from ``self.breakpoints``.

        Returns:
            ipw.VBox: The figure and the chip grid.
        """
        chips = self.get_chips(idx)
        dates, values = self.get_ts(idx)
        # Prepare figure
        x_sc = bqplot.LinearScale()
        y_sc = bqplot.LinearScale()
        x_ax = bqplot.Axis(label='Time', scale=x_sc, tick_format='%m-%Y', tick_rotate=45)
        y_ax = bqplot.Axis(label='Vegetation Index', scale=y_sc, orientation='vertical', side='left')
        vi_values = bqplot.Scatter(x=dates, y=values, scales={'x': x_sc, 'y': y_sc})
        # Highlight marker parked at a dummy location out of view; preserve_domain keeps
        # it from stretching the axes
        highlighted_point = Scatter(x=[-1000], y=[-1000],
                                    scales={'x': x_sc, 'y': y_sc},
                                    preserve_domain={'x': True, 'y': True},
                                    colors=['red'])
        ts_fig = bqplot.Figure(marks=[vi_values, highlighted_point], axes=[x_ax, y_ax],
                               title='Sample temporal profile')
        ts_fig.layout.width = '100%'
        ts_fig.layout.height = '400px'

        selected_border = '2px solid blue'
        breakpoint_lines = {}  # date -> vertical Lines mark currently on ts_fig

        def _add_line(date):
            """Draw a vertical red line at ``date`` on the time-series figure."""
            line = bqplot.Lines(x=[date, date], y=[0, 1],
                                scales={'x': x_sc, 'y': y_sc}, colors=['red'])
            breakpoint_lines[date] = line
            # Build a new list: marks is presumably list-valued, so `marks + (line,)`
            # would raise TypeError; this form works for lists and tuples alike.
            ts_fig.marks = [*ts_fig.marks, line]

        def _remove_line(date):
            """Remove the vertical line drawn at ``date``, if any."""
            line = breakpoint_lines.pop(date, None)
            if line is not None:
                ts_fig.marks = [mark for mark in ts_fig.marks if mark is not line]

        def _handle_chip_event(chip_idx, event):
            """Sync highlight point and breakpoints with mouse events on chip ``chip_idx``."""
            if event['type'] == 'mouseenter':
                highlighted_point.x = [dates[chip_idx]]
                highlighted_point.y = [values[chip_idx]]
            elif event['type'] == 'mouseleave':
                # Reset dummy location
                highlighted_point.x = [-1000]
                highlighted_point.y = [-1000]
            elif event['type'] == 'click':
                clicked_date = dates[chip_idx]
                if clicked_date in self.breakpoints:
                    self.breakpoints.remove(clicked_date)
                    _remove_line(clicked_date)
                    chips[chip_idx].layout.border = ''
                else:
                    self.breakpoints.append(clicked_date)
                    _add_line(clicked_date)
                    chips[chip_idx].layout.border = selected_border

        for chip_idx, chip in enumerate(chips):
            # Show already recorded breakpoints so the display matches self.breakpoints
            if dates[chip_idx] in self.breakpoints:
                chip.layout.border = selected_border
                _add_line(dates[chip_idx])
            event = Event(source=chip, watched_events=['mouseenter', 'mouseleave', 'click'])
            event.on_dom_event(functools.partial(_handle_chip_event, chip_idx))

        box_layout = ipw.Layout(
            display='flex',
            flex_flow='row wrap',
            align_items='stretch',
            width='100%',
            height='800px',  # Fixed height; the chip grid scrolls inside it
            overflow='auto'
        )
        box = ipw.Box(children=chips, layout=box_layout)
        return ipw.VBox([ts_fig, box])
            




# Build the labelling interface over the sample features and render it
interface = Interface(features=fc, cube=cube)
interface.display_interface()

        

    

HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()

In [71]:
# Inspect the dates collected by clicking chips (one list shared by all samples)
interface.breakpoints

[numpy.datetime64('2018-11-17T10:02:59.024000000')]

In [48]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import sqlite3

class DataLabelingInterface:
    """Minimal text-labelling widget with one free-text label per sample.

    Labels live in an in-memory SQLite table ``labels(id, label)`` where row ``id``
    holds the label of sample ``id - 1``. Every edit of the text box is written to
    the table immediately.
    """

    def __init__(self, data_samples):
        self.data_samples = data_samples
        self.current_index = 0

        # In-memory store pre-filled with one empty label per sample
        self.conn = sqlite3.connect(':memory:')
        self.cursor = self.conn.cursor()
        self.cursor.execute("CREATE TABLE labels (id INTEGER PRIMARY KEY, label TEXT)")
        self.cursor.executemany("INSERT INTO labels (label) VALUES (?)",
                                [('',)] * len(data_samples))

        # Navigation buttons, label box and the area the current sample is drawn in
        self.next_button = widgets.Button(description='Next')
        self.previous_button = widgets.Button(description='Previous')
        self.label_input = widgets.Text(description='Label')
        self.output_area = widgets.Output()

        self.next_button.on_click(self.next_sample)
        self.previous_button.on_click(self.previous_sample)
        self.label_input.observe(self.label_changed, names='value')

        self.update_display()

    def update_display(self):
        """Redraw the output area for the current sample and load its stored label."""
        position = self.current_index
        with self.output_area:
            clear_output(wait=True)
            print(f"Sample {position + 1}/{len(self.data_samples)}: {self.data_samples[position]}")
            self.label_input.value = self.get_label(position)
            display(self.label_input)

    def display_interface(self):
        """Show the navigation buttons followed by the sample area."""
        display(widgets.HBox([self.previous_button, self.next_button]))
        display(self.output_area)

    def next_sample(self, b):
        """Button callback: advance one sample, unless already on the last one."""
        if self.current_index >= len(self.data_samples) - 1:
            return
        self.current_index += 1
        self.update_display()

    def previous_sample(self, b):
        """Button callback: go back one sample, unless already on the first one."""
        if self.current_index <= 0:
            return
        self.current_index -= 1
        self.update_display()

    def label_changed(self, change):
        """Text-box observer: persist the new value as the current sample's label."""
        self.update_label(self.current_index, change['new'])

    def update_label(self, index, label):
        """Store ``label`` for sample ``index`` and commit."""
        row_id = index + 1  # table ids are 1-based
        self.cursor.execute("UPDATE labels SET label = ? WHERE id = ?", (label, row_id))
        self.conn.commit()

    def get_label(self, index):
        """Return the stored label of sample ``index``, or '' when there is no such row."""
        row = self.cursor.execute("SELECT label FROM labels WHERE id = ?", (index + 1,)).fetchone()
        return '' if row is None else row[0]

# Sample data: toy text snippets used to exercise the labelling widget
data_samples = [
    "The quick brown fox jumps over the lazy dog",
    "A stitch in time saves nine",
    "An apple a day keeps the doctor away",
    "Early to bed and early to rise, makes a man healthy, wealthy, and wise",
]

# Create and display the labeling interface
labeling_interface = DataLabelingInterface(data_samples)
labeling_interface.display_interface()



HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()

In [35]:
# Scratch check of the web-map zoom level (Interface sets it to 17 on each feature)
interface.webmap.zoom

17.0

In [5]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import sqlite3

class DataLabelingInterface:
    """Text-labelling widget allowing several free-text labels per sample.

    Labels are stored in an in-memory SQLite table ``labels(sample_id, label)``
    keyed by the 0-based sample index. The labels of the current sample are written
    to the table when the user moves with Previous/Next (labels typed on a sample
    and never navigated away from are not saved).
    """

    def __init__(self, data_samples):
        self.data_samples = data_samples
        self.current_index = 0

        # SQLite Database Setup
        self.conn = sqlite3.connect(':memory:')
        self.cursor = self.conn.cursor()
        self.cursor.execute("CREATE TABLE labels (sample_id INTEGER, label TEXT)")

        # Widgets
        self.next_button = widgets.Button(description='Next')
        self.previous_button = widgets.Button(description='Previous')
        self.add_label_button = widgets.Button(description='Add Label')
        self.label_inputs = [widgets.Text(description='Label 1')]
        self.output_area = widgets.Output()

        # Bind events
        self.next_button.on_click(self.next_sample)
        self.previous_button.on_click(self.previous_sample)
        self.add_label_button.on_click(self.add_label_input)

        # Load the stored labels of the first sample (none yet) and render it
        self.reset_label_inputs()
        self.update_display()

    def update_display(self):
        """Render the current sample with the current label inputs.

        Does not reload labels from the database: inputs added with "Add Label" and
        values typed but not yet saved must survive a redraw. Callers that change
        sample call ``reset_label_inputs`` first.
        """
        with self.output_area:
            clear_output(wait=True)
            print(f"Sample {self.current_index + 1}/{len(self.data_samples)}: {self.data_samples[self.current_index]}")
            for label_input in self.label_inputs:
                display(label_input)
            display(self.add_label_button)

    def reset_label_inputs(self):
        """Rebuild the label inputs from the labels stored for the current sample."""
        labels = self.get_all_labels_for_sample(self.current_index)
        if not labels:
            labels = ['']  # At least one label input if no labels are stored
        self.label_inputs = [widgets.Text(description=f'Label {i+1}', value=label)
                             for i, label in enumerate(labels)]

    def get_all_labels_for_sample(self, sample_id):
        """Return the labels stored for ``sample_id`` in insertion order."""
        # ORDER BY rowid: without it SQLite does not guarantee any row order
        self.cursor.execute("SELECT label FROM labels WHERE sample_id = ? ORDER BY rowid",
                            (sample_id,))
        return [label[0] for label in self.cursor.fetchall()]

    def display_interface(self):
        """Display the navigation buttons followed by the sample area."""
        display(widgets.HBox([self.previous_button, self.next_button]))
        display(self.output_area)

    def next_sample(self, b):
        """Button callback: save current labels, then go to the next sample if any."""
        self.save_labels()
        if self.current_index < len(self.data_samples) - 1:
            self.current_index += 1
            self.reset_label_inputs()
            self.update_display()

    def previous_sample(self, b):
        """Button callback: save current labels, then go to the previous sample if any."""
        self.save_labels()
        if self.current_index > 0:
            self.current_index -= 1
            self.reset_label_inputs()
            self.update_display()

    def add_label_input(self, b):
        """Button callback: append an empty label input for the current sample and redraw."""
        new_label_input = widgets.Text(description=f'Label {len(self.label_inputs) + 1}')
        self.label_inputs.append(new_label_input)
        # update_display keeps self.label_inputs as-is, so the new input stays visible
        self.update_display()

    def save_labels(self):
        """Replace the stored labels of the current sample with the non-blank input values."""
        self.cursor.execute("DELETE FROM labels WHERE sample_id = ?", (self.current_index,))
        for label_input in self.label_inputs:
            if label_input.value.strip():
                self.cursor.execute("INSERT INTO labels (sample_id, label) VALUES (?, ?)",
                                    (self.current_index, label_input.value))
        self.conn.commit()

    def get_label(self, sample_id, label_index):
        """Return the ``label_index``-th label (insertion order) of ``sample_id``, or ''."""
        self.cursor.execute("SELECT label FROM labels WHERE sample_id = ? ORDER BY rowid LIMIT 1 OFFSET ?",
                            (sample_id, label_index))
        result = self.cursor.fetchone()
        return result[0] if result else ''

# Sample data: toy text snippets used to exercise the multi-label widget
data_samples = [
    "The quick brown fox jumps over the lazy dog",
    "A stitch in time saves nine",
    "An apple a day keeps the doctor away",
    "Early to bed and early to rise, makes a man healthy, wealthy, and wise",
]

# Create and display the labeling interface
labeling_interface = DataLabelingInterface(data_samples)
labeling_interface.display_interface()


HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()

In [25]:
from shapely.geometry import Point
# Scratch check: a shapely Point exposes its coordinates as .x / .y
Point(1,2).x

1.0