In [41]:
import os
import io
import functools

import rioxarray
import xarray as xr
import numpy as np
from pyproj import Transformer
from shapely.geometry import Point, Polygon, shape
from shapely import ops
from shapely.affinity import affine_transform
import matplotlib.pyplot as plt
from PIL import Image as PILImage
from PIL import ImageDraw
import bqplot
from bqplot import Scatter
from ipyevents import Event
from affine import Affine

import ipywidgets as ipw
from IPython.display import display, clear_output

In [2]:
# Location of the NRT test cube. Override it with the NRT_CUBE_PATH environment
# variable instead of editing this machine-specific default path.
CUBE_PATH = os.environ.get('NRT_CUBE_PATH', '/home/loic/Downloads/czechia_nrt_test.nc')
cube = xr.open_dataset(CUBE_PATH)

In [52]:
# Sample GeoJSON-like Point features to browse. Their coordinates are used directly
# with cube.sel(x=..., y=...), so they must be in the cube's CRS; 'idx' is a 1-based id.
fc = [
    {'type': 'Feature',
     'geometry': {'type': 'Point', 'coordinates': [x_coord, y_coord]},
     'properties': {'idx': sample_id}}
    for sample_id, (x_coord, y_coord) in enumerate(
        [(4813210, 2935950), (4815170, 2936650)], start=1)
]

class ColorComposite(object):
    """Transform to create a stretched RGB image in numpy format from a multivariate xarray Dataset

    Works only for a Dataset with a single temporal slice (e.g. only x and y coordinate variables)

    Args:
        b, g, r (str): Names of the Dataset variables used as blue, green and red channels
        blim, glim, rlim (sequence): (min, max) values between which to stretch each band
    """
    def __init__(self, b='B02_20', g='B03_20', r='B04_20',
                 blim=(20, 2000), glim=(50, 2000), rlim=(20, 2000)):
        self.b = b
        self.g = g
        self.r = r
        # Keep private list copies. Storing a (formerly mutable default) argument by
        # reference let a change to one instance's limits leak into every other instance.
        self.blim = list(blim)
        self.glim = list(glim)
        self.rlim = list(rlim)

    @staticmethod
    def stretch(arr, blim=(20, 2000), glim=(50, 2000), rlim=(20, 2000)):
        """Apply linear color stretching and [0,1] clipping to a 3 bands image

        Args:
            arr (np.ndarray): Array with bands as last dimension in RGB order, typically
                (rows, cols, 3). (if read in 'wavelength order' satellite images are usually BGR)
            blim,glim,rlim (sequence): min and max values between which to stretch the individual bands

        Returns:
            np.ndarray: float64 array with the same shape as ``arr`` and values in [0, 1]
                (NaN inputs stay NaN)
        """
        # 1-D limits broadcast against the last (band) axis whatever the rank of ``arr``
        bottom = np.array([rlim[0], glim[0], blim[0]], dtype=np.float64)
        top = np.array([rlim[1], glim[1], blim[1]], dtype=np.float64)
        # Compute in float64 regardless of the input dtype
        arr_stretched = (np.asarray(arr, dtype=np.float64) - bottom) / (top - bottom)
        return np.clip(arr_stretched, 0.0, 1.0)

    def __call__(self, ds):
        """Stack the r, g and b variables of ``ds`` along a new last axis and stretch them

        Args:
            ds (xarray.Dataset): Single-date Dataset containing the ``r``, ``g`` and ``b`` variables

        Returns:
            np.ndarray: (..., 3) RGB array with values in [0, 1]
        """
        rgb = np.stack([ds[self.r].values, ds[self.g].values, ds[self.b].values], axis=-1)
        return ColorComposite.stretch(rgb, blim=self.blim, glim=self.glim, rlim=self.rlim)


class NDVI(object):
    """Normalized Difference Vegetation Index calculator: (NIR - red) / (NIR + red).

    Args:
        red (str): Name of the red band variable.
        nir (str): Name of the near-infrared band variable.
    """
    def __init__(self, red='B04_20', nir='B8A'):
        self.red = red
        self.nir = nir

    def __call__(self, ds):
        """Compute NDVI from the ``red`` and ``nir`` variables of ``ds``.

        Both bands are cast to float32 *before* any arithmetic. Previously only the
        numerator was cast, so integer band data could overflow in the sum.

        Args:
            ds: Object indexable by band name (e.g. an xarray Dataset) whose
                entries support ``.astype``.

        Returns:
            Element-wise NDVI values (same container type as the bands).
        """
        nir = ds[self.nir].astype(np.float32)
        red = ds[self.red].astype(np.float32)
        # The small epsilon avoids a division by zero where both bands are 0
        return (nir - red) / (nir + red + 0.0000001)


def np2ipw(arr, geom=None, transform=None, res=20, scale=4, outline_color='magenta'):
    """Render an RGB array as an ipywidgets Image, optionally outlining a geometry on it.

    Args:
        arr (np.ndarray): (rows, cols, 3) RGB array with values expected in [0, 1]
            (e.g. the output of ``ColorComposite``). NaNs are drawn black and values
            outside [0, 1] are clipped.
        geom (dict): GeoJSON-like geometry in the CRS of ``transform``. A Point is
            drawn as a square of side ``res`` centred on it. If None, no outline is drawn.
        transform (affine.Affine): Transform mapping (col, row) pixel coordinates of
            ``arr`` to CRS coordinates. Required when ``geom`` is given.
        res (float): Side length, in CRS units, of the square drawn around Point geometries.
        scale (int): Nearest-neighbour upsampling factor applied before drawing.
        outline_color (str): PIL color of the outline.

    Returns:
        ipw.Image: PNG-encoded image widget.

    Raises:
        ValueError: If ``geom`` is given without a ``transform``.
    """
    # Replace NaN by 0 and clip so the uint8 cast below is well defined
    pixels = np.clip(np.nan_to_num(np.asarray(arr, dtype=np.float64), nan=0.0), 0.0, 1.0)
    n_rows, n_cols = pixels.shape[:2]
    img = PILImage.fromarray((pixels * 255).astype(np.uint8))
    # PIL sizes are (width, height) == (cols, rows), not (rows, cols); the wrong order
    # distorts non-square chips and misplaces the outline
    img = img.resize(size=(n_cols * scale, n_rows * scale), resample=PILImage.NEAREST)

    if geom is not None:
        if transform is None:
            raise ValueError("np2ipw: a transform is required to draw a geometry outline")
        # Inverse of the upscaled transform: CRS coordinates -> pixel coordinates of ``img``
        world_to_pixel = ~(transform * Affine.scale(1 / scale))
        outline = shape(geom)
        if isinstance(outline, Point):
            # Square cap style turns the buffered point into a square of side ``res``
            outline = outline.buffer(res / 2, cap_style=3)
        outline = affine_transform(outline, world_to_pixel.to_shapely())
        ImageDraw.Draw(img).polygon(list(outline.exterior.coords), fill=None, outline=outline_color)

    # Encode as PNG in memory and wrap the bytes in an ipywidgets Image
    with io.BytesIO() as fileobj:
        img.save(fileobj, 'PNG')
        img_bytes = fileobj.getvalue()
    return ipw.Image(value=img_bytes)
        

class Interface(object):
    """Interactive browser linking image chips and a vegetation-index time series.

    For the current feature it displays, inside ``output_area``, a bqplot scatter
    of the vegetation index over time at the feature location, and a scrollable
    gallery with one RGB chip per time step around the feature. Hovering a chip
    highlights the matching time-series point in red; clicking a chip toggles a
    blue border on it. ``Previous``/``Next`` buttons move between features.
    """
    def __init__(self, cube, features,
                 compositor=ColorComposite(),
                 vi_calculator=NDVI(),
                 window_size=500,
                 outline_color='magenta'):
        """
        Args:
            cube (xarray.Dataset): Cube with ``time``, ``x`` and ``y`` coordinates,
                georeferenced so that the rioxarray ``.rio`` accessor works.
            features (list of dict): GeoJSON-like features; only ``feature['geometry']``
                is used, with coordinates in the same CRS as ``cube``.
            compositor (callable): Maps a single-date slice of ``cube`` to an RGB
                array (e.g. ``ColorComposite``, values in [0, 1]).
            vi_calculator (callable): Maps a subset of ``cube`` to a vegetation
                index DataArray with a ``time`` coordinate (e.g. ``NDVI``).
            window_size (float): Half-width, in CRS units, of the box clipped
                around each feature to build the chips.
            outline_color (str): Color of the feature outline drawn on the chips.
        """
        self.cube = cube
        self.features = features
        self.compositor = compositor
        self.vi_calculator = vi_calculator
        self.window_size = window_size
        self.current_idx = 0
        self.outline_color = outline_color
        # ipyevents listeners attached to the chips most recently built by link_components
        self._chip_events = []

        # Widgets
        self.next_button = ipw.Button(description='Next')
        self.previous_button = ipw.Button(description='Previous')
        # TODO: label_input is created but not displayed nor observed yet
        self.label_input = ipw.Text(description='Label')
        self.output_area = ipw.Output()

        # Bind events
        self.next_button.on_click(self.next_sample)
        self.previous_button.on_click(self.previous_sample)

        # Initialize display
        self.update_display()

    def next_sample(self, b):
        """Button callback: move to the next feature, if any."""
        if self.current_idx < len(self.features) - 1:
            self.current_idx += 1
            self.update_display()

    def previous_sample(self, b):
        """Button callback: move to the previous feature, if any."""
        if self.current_idx > 0:
            self.current_idx -= 1
            self.update_display()

    def update_display(self):
        """Clear ``output_area`` and redraw it for the current feature."""
        with self.output_area:
            clear_output(wait=True)
            if len(self.features) == 0:
                print("No samples to display")
                return
            print(f"Sample {self.current_idx + 1}/{len(self.features)}")
            display(self.link_components(self.current_idx))

    def display_interface(self):
        """Display the navigation buttons and the output area."""
        display(ipw.HBox([self.previous_button, self.next_button]))
        display(self.output_area)

    def get_chips(self, idx):
        """Build one image widget per time step of the window around feature ``idx``.

        Returns:
            list of ipw.Image: Chips in the order of the clipped cube's ``time`` values.
        """
        geom = self.features[idx]['geometry']
        bbox = shape(geom).buffer(self.window_size).bounds
        cube_sub = self.cube.rio.clip_box(*bbox)
        transform = cube_sub.rio.transform()
        res = cube_sub.rio.resolution()[0]
        chips = []
        for date in cube_sub.time.values:
            rgb = self.compositor(cube_sub.sel(time=date))
            chips.append(np2ipw(rgb, geom=geom, transform=transform,
                                res=res, outline_color=self.outline_color))
        return chips

    def get_ts(self, idx):
        """Extract the vegetation-index time series at feature ``idx`` (nearest pixel).

        Returns:
            tuple: ``(dates, values)`` numpy arrays.

        Raises:
            NotImplementedError: If the feature geometry is not a Point.
        """
        point = shape(self.features[idx]['geometry'])
        if not isinstance(point, Point):
            # TODO: aggregate the index over polygon footprints
            raise NotImplementedError("Time series extraction currently supports Point geometries only")
        cube_sub = self.cube.sel(x=point.x, y=point.y, method='nearest')
        ts = self.vi_calculator(cube_sub)
        return ts.time.values, ts.values

    def link_components(self, idx):
        """Build the time-series figure and the chip gallery for feature ``idx`` and wire them together.

        Returns:
            ipw.VBox: The figure stacked above the scrollable chip gallery.
        """
        chips = self.get_chips(idx)
        dates, values = self.get_ts(idx)
        # Prepare figure
        # NOTE(review): x values are datetimes; bqplot.DateScale may be the intended scale here
        x_sc = bqplot.LinearScale()
        y_sc = bqplot.LinearScale()
        x_ax = bqplot.Axis(label='Time', scale=x_sc, tick_format='%m-%Y', tick_rotate=45)
        y_ax = bqplot.Axis(label='Vegetation Index', scale=y_sc, orientation='vertical', side='left')
        vi_values = bqplot.Scatter(x=dates, y=values, scales={'x': x_sc, 'y': y_sc})
        # Highlight marker parked at a dummy location; preserve_domain keeps that
        # location from stretching the axes, so the marker stays out of view
        highlighted_point = Scatter(x=[-1000], y=[-1000],
                                    scales={'x': x_sc, 'y': y_sc},
                                    preserve_domain={'x': True, 'y': True},
                                    colors=['red'])
        ts_fig = bqplot.Figure(marks=[vi_values, highlighted_point], axes=[x_ax, y_ax], title='Sample temporal profile')
        ts_fig.layout.width = '100%'
        ts_fig.layout.height = '400px'

        def _handle_chip_event(chip_idx, event):
            """Highlight the time-series point of a hovered chip; toggle a border on click."""
            if event['type'] == 'mouseenter':
                highlighted_point.x = [dates[chip_idx]]
                highlighted_point.y = [values[chip_idx]]
            elif event['type'] == 'mouseleave':
                # Park the marker back at the dummy location
                highlighted_point.x = [-1000]
                highlighted_point.y = [-1000]
            elif event['type'] == 'click':
                chip = chips[chip_idx]
                if chip.layout.border == '2px solid blue':
                    chip.layout.border = ''
                else:
                    chip.layout.border = '2px solid blue'

        # Chips and time-series points share the same time ordering, so the chip
        # position doubles as the index into ``dates``/``values``
        self._chip_events = []
        for chip_idx, chip in enumerate(chips):
            event = Event(source=chip, watched_events=['mouseenter', 'mouseleave', 'click'])
            event.on_dom_event(functools.partial(_handle_chip_event, chip_idx))
            self._chip_events.append(event)

        box_layout = ipw.Layout(
            display='flex',
            flex_flow='row wrap',
            align_items='stretch',
            width='100%',
            height='800px',  # fixed height; the gallery scrolls beyond it
            overflow='auto'
        )
        box = ipw.Box(children=chips, layout=box_layout)
        return ipw.VBox([ts_fig, box])
            




# Build the chip / time-series browser over the sample features and render its controls
interface = Interface(features=fc, cube=cube)
interface.display_interface()

        

    

HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()

In [4]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import sqlite3

class DataLabelingInterface:
    """Minimal widget UI to browse text samples and attach a free-text label to each.

    Labels are stored in an in-memory SQLite table ``labels(id, label)`` where
    ``id`` is the 1-based position of the sample in ``data_samples``.
    """
    def __init__(self, data_samples):
        self.data_samples = data_samples
        self.current_index = 0
        # True while update_display() sets the text box programmatically, so the
        # resulting 'value' change is not written back to the database
        self._syncing = False

        # SQLite Database Setup
        self._create_label_table(len(data_samples))

        # Widgets
        self.next_button = widgets.Button(description='Next')
        self.previous_button = widgets.Button(description='Previous')
        self.label_input = widgets.Text(description='Label')
        self.output_area = widgets.Output()

        # Bind events
        self.next_button.on_click(self.next_sample)
        self.previous_button.on_click(self.previous_sample)
        self.label_input.observe(self.label_changed, names='value')

        # Initialize display
        self.update_display()

    def _create_label_table(self, n_samples):
        """Create the in-memory ``labels`` table holding one empty label per sample."""
        self.conn = sqlite3.connect(':memory:')
        self.cursor = self.conn.cursor()
        self.cursor.execute("CREATE TABLE labels (id INTEGER PRIMARY KEY, label TEXT)")
        # Explicit ids guarantee that row ``i + 1`` belongs to sample ``i``
        self.cursor.executemany("INSERT INTO labels (id, label) VALUES (?, ?)",
                                [(i + 1, '') for i in range(n_samples)])
        self.conn.commit()

    def update_display(self):
        """Show the current sample and its stored label in the output area."""
        with self.output_area:
            clear_output(wait=True)
            if len(self.data_samples) == 0:
                print("No samples to label")
                return
            print(f"Sample {self.current_index + 1}/{len(self.data_samples)}: {self.data_samples[self.current_index]}")
            self._syncing = True
            try:
                self.label_input.value = self.get_label(self.current_index)
            finally:
                self._syncing = False
            display(self.label_input)

    def display_interface(self):
        """Display the navigation buttons and the output area."""
        display(widgets.HBox([self.previous_button, self.next_button]))
        display(self.output_area)

    def next_sample(self, b):
        """Button callback: move to the next sample, if any."""
        if self.current_index < len(self.data_samples) - 1:
            self.current_index += 1
            self.update_display()

    def previous_sample(self, b):
        """Button callback: move to the previous sample, if any."""
        if self.current_index > 0:
            self.current_index -= 1
            self.update_display()

    def label_changed(self, change):
        """Text-box observer: persist user edits of the label for the current sample."""
        if self._syncing:
            return
        self.update_label(self.current_index, change['new'])

    def update_label(self, index, label):
        """Store ``label`` for the sample at 0-based ``index``."""
        self.cursor.execute("UPDATE labels SET label = ? WHERE id = ?", (label, index + 1))
        self.conn.commit()

    def get_label(self, index):
        """Return the label of the sample at 0-based ``index`` ('' if there is no such row)."""
        self.cursor.execute("SELECT label FROM labels WHERE id = ?", (index + 1,))
        result = self.cursor.fetchone()
        return result[0] if result else ''

# Sample data
# Toy text samples (one per line) used to exercise the labelling workflow
data_samples = """\
The quick brown fox jumps over the lazy dog
A stitch in time saves nine
An apple a day keeps the doctor away
Early to bed and early to rise, makes a man healthy, wealthy, and wise""".splitlines()

# Build the labelling UI, then render its navigation buttons and output area
labeling_interface = DataLabelingInterface(data_samples)
labeling_interface.display_interface()



HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()

In [45]:
# Inspect the cube's x resolution (pixel size); Interface.get_chips reads the same
# quantity on the clipped cube and passes it to np2ipw as `res` to size point outlines
cube.rio.resolution()[0]

20.0