In [6]:
from datetime import datetime, timedelta

import builtins

# Keep a reference to the original built-in print. The `builtins` module is
# used instead of `__builtins__`, which is a plain dict outside `__main__`
# (e.g. when this code is imported as a module), where `.print` would fail.
original_print = builtins.print

# Redefine print for the rest of the notebook
def print(*args, **kwargs):
    """Drop-in replacement for print() that prefixes every call with the local time.

    With the default ``sep`` the output looks like
    ``YYYY-MM-DD HH:MM:SS - <args>``; all keyword arguments (sep, end, file,
    flush) are forwarded unchanged to the built-in print.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    original_print(current_time, "-", *args, **kwargs)

In [5]:
from Tesina_Indexs_Utils import calculate_water_index_from_xr_dataset, calculate_fire_index_from_xr_dataset, calculate_vegetation_index_from_xr_dataset, VegetationIndex, FireIndex, WaterIndex, IndexCategory
from Tesina_Indexs_Utils import calculate_cloud_index_from_xr_dataset, CloudIndex
from Tesina_Maps_Utils import get_kml_polygon_masks
from Tesina_Sentinel_Utils import S2Band

In [4]:
from Tesina_Sentinel_Utils import classify_sunny_cloudy_dates_by_scene_classification
from Tesina_Images_Utils import normalize_image_percentile

In [7]:
import numpy as np
import os
from fastkml import  kml
import pandas as pd
import json

In [8]:
# Land-cover classes as (code, RGB colour, label). The codes/labels appear to
# follow the Sentinel-2 Scene Classification (SCL) layer — confirm.
# A single table keeps colours and labels in sync; both dicts are derived from it.
_land_cover_table = [
    (0, (255, 255, 255), 'No data'),
    (1, (255, 0, 0), 'Saturated / Defective'),
    (2, (0, 0, 0), 'Dark Area Pixels'),
    (3, (255, 255, 0), 'Cloud Shadows'),
    (4, (0, 255, 0), 'Vegetation'),
    (5, (255, 0, 255), 'Bare Soils'),
    (6, (0, 0, 255), 'Water'),
    (7, (255, 72, 0), 'Clouds low probability / Unclassified'),
    (8, (255, 100, 0), 'Clouds medium probability'),
    (9, (255, 128, 0), 'Clouds high probability'),
    (10, (255, 255, 255), 'Cirrus'),
    (11, (255, 255, 255), 'Snow / Ice'),
]

# code -> RGB tuple
land_cover_palette = {code: rgb for code, rgb, _label in _land_cover_table}
# code -> human-readable label
land_cover_labels = {code: label for code, _rgb, label in _land_cover_table}

# Colours ordered by class code, as a list and as an (n_classes, 3) array
land_cover_palette_list = [land_cover_palette[code] for code in sorted(land_cover_palette)]
land_cover_palette_array = np.array(land_cover_palette_list)


In [10]:
from lxml import etree

# Read the KML files with the forest polygons from the ./zones folder and build
# zone_dict as {zone (feature) name: {polygon name: {"polygon": geometry}}}.

# Files in the zones folder
files = os.listdir("./zones")
print(files)

zone_dict = {}
# Resolution of the image
resolution = 10
# Approximate size of 1 km in degrees (1 degree of latitude ~ 110.574 km)
km2deg = 1/110.574

for file in files:
    print(f'Processing file: {file}')
    # Use a context manager so the file handle is closed after reading
    with open(os.path.join("./zones", file), "r", encoding='utf-8', errors='ignore') as kml_file:
        doc = kml_file.read()
    k = kml.KML()

    try:
        # from_string is given bytes (the text is re-encoded as UTF-8)
        k.from_string(doc.encode('utf-8'))
    except etree.ParseError:
        print(f"Unable to parse file: {file}")
        continue

    # Features of the first top-level element of the file; each one is a zone
    features = list(list(k.features())[0].features())
    for f in features:
        print(f.name)
        zone_dict[f.name] = {}
        # Skip features without children (no features() method)
        if not callable(getattr(f, "features", None)):
            continue
        # Store each polygon under the feature that actually contains it
        for p in f.features():
            print(p.name)
            polygon = p.geometry
            zone_dict[f.name][p.name] = {"polygon": polygon}
            print("Bounding box: {}".format(polygon.bounds))

2023-12-22 16:11:18 - ['Bosques Bio Bio.kml', 'Incendios.kml', 'Bosques Arauco.kml', 'Provoque.kml']
2023-12-22 16:11:18 - Processing file: Bosques Bio Bio.kml
2023-12-22 16:11:18 - Bosques Bio Bio
2023-12-22 16:11:18 - Bosque 1
2023-12-22 16:11:18 - Bounding box: (-72.45142890120229, -37.1935384751085, -72.43650978092387, -37.18581491700413)
2023-12-22 16:11:18 - Bosque 2
2023-12-22 16:11:18 - Bounding box: (-72.42482075716856, -37.16361946111896, -72.41455274852883, -37.15297897451463)
2023-12-22 16:11:18 - Bosque 3
2023-12-22 16:11:18 - Bounding box: (-72.42423154064022, -37.1722707957546, -72.39075268176586, -37.15535508695594)
2023-12-22 16:11:18 - Processing file: Incendios.kml
2023-12-22 16:11:18 - Incendios
2023-12-22 16:11:18 - Chiguayante 1
2023-12-22 16:11:18 - Bounding box: (-73.14301140515278, -36.97713577738772, -73.07640424786813, -36.93275197623442)
2023-12-22 16:11:18 - Chiguayante 2
2023-12-22 16:11:18 - Bounding box: (-73.14176444168599, -36.9381030333295, -73.075936

In [12]:
from datetime import datetime
import hashlib
from Tesina_General_Utils import Logger, LogLevel
from odc.stac import stac_load
import xarray as xr

def convert_band_uint16_dtype(variable_data):
    """Cast a uint16 (xarray-like) variable to a signed integer dtype.

    The result is int16 when the variable's maximum fits in int16, otherwise
    int32. Variables of any other dtype are returned unchanged (same object).
    """
    if variable_data.dtype != 'uint16':
        return variable_data
    peak = variable_data.max().values.item()
    target_dtype = 'int16' if peak <= np.iinfo(np.int16).max else 'int32'
    return variable_data.astype(target_dtype)

# Maximum fraction of NaN/0 pixels tolerated in a band (the 'wvp' band is exempt)
_MAX_INVALID_FRACTION = 0.05


def _stac_cache_filepath(cache_dir, key_parts, extension, band=None):
    """Return the cache file path for one month of STAC data.

    The file name is the MD5 hex digest of `key_parts` joined with "_" (plus
    "_<band>" when `band` is given), so any change to the query parameters
    yields a different cache entry. Per-band NetCDF files pass `band`; the
    month's item-properties JSON file does not.
    """
    metadata_str = "_".join(f"{part}" for part in key_parts)
    if band is not None:
        metadata_str += f"_{band}"
    metadata_hash = hashlib.md5(metadata_str.encode()).hexdigest()
    return os.path.join(cache_dir, f"{metadata_hash}{extension}")


def _invalid_fraction(values):
    """Fraction (as a numpy value) of the elements of an xarray object that are NaN or exactly 0."""
    invalid_count = values.isnull().sum() + (values == 0).sum()
    return (invalid_count / values.size).values


def _keep_valid_dates(band_data, band, logger, max_invalid_fraction=_MAX_INVALID_FRACTION):
    """Drop time steps of `band_data` whose NaN/0 fraction exceeds `max_invalid_fraction`.

    All dates of the 'wvp' band are kept. Every dropped date is logged as a warning.
    """
    valid_dates = []
    for date in band_data['time']:
        invalid_values_percent = _invalid_fraction(band_data.sel(time=date))
        if invalid_values_percent > max_invalid_fraction and band != 'wvp':
            logger.warning(f"{invalid_values_percent*100}% of the data for band {band} at date {date.values} is NaN or 0. Skipping.")
        else:
            valid_dates.append(date.values)
    return band_data.sel(time=valid_dates)


def fetch_or_cache_stac_data_by_band_and_month(catalog, bands, collections, start_date:str, end_date:str, limit, bbox, cloud_cover, resolution, stac_config=None, crs="EPSG:4326", cache_dir='data_cache', force_download=False, log_level: LogLevel = LogLevel.Info, save_rgb_image=False):
    """Load STAC imagery month by month, caching every band and month on disk.

    For each one-month window starting at `start_date` (clipped to `end_date`
    and to now), every band is looked up in `cache_dir` as a NetCDF file whose
    name is an MD5 of the query parameters. A band is (re-)downloaded with
    `stac_load` (grouped by solar day) when:
      - it is missing from the cache,
      - more than 5% of its cached pixels are NaN/0 (except 'wvp'), or
      - the cached file has no 'time' coordinate.
    Downloaded dates with more than 5% NaN/0 pixels are discarded before
    caching. The month's STAC item properties are cached as JSON.

    Args:
        catalog: STAC client used for `catalog.search`.
        bands: band names to load.
        collections: STAC collection ids, e.g. ["sentinel-2-l2a"].
        start_date, end_date: "YYYY-MM-DD" strings.
        limit: page limit passed to the search.
        bbox: bounding box passed to the search and to `stac_load`.
        cloud_cover: only part of the cache key; it is not applied as a search filter.
        resolution: output resolution passed to `stac_load` (units of `crs`).
        stac_config: `stac_cfg` passed to `stac_load`.
        crs: output CRS.
        cache_dir: directory holding the cache files.
        force_download: ignore cached bands and download all of them.
        log_level: minimum level for the internal Logger.
        save_rgb_image: not implemented; only logs a warning.

    Returns:
        (final_data, properties):
          - final_data is the Dataset concatenated along 'time'. Duplicate
            time stamps and every time step where any band contains NaN are removed.
          - properties is a list of item-property dicts, at most one per
            calendar day, whose datetime matches a remaining time step,
            sorted by datetime.

    Raises:
        ValueError: if no month produced data for every requested band.
        Exception: if a downloaded month is spatially smaller than the first downloaded month.
    """
    logger = Logger(min_log_level=log_level)
    start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
    current_date_dt = datetime.now()

    if save_rgb_image:
        # Saving an RGB preview is not implemented; the flag only triggers this warning.
        logger.warning("save_rgb_image is not implemented; no RGB image will be saved")

    combined_data = []  # One merged Dataset per month that has every band
    properties = []  # STAC item properties accumulated over all months
    # Non-time dimension sizes of the first downloaded month. Later downloads
    # must not be smaller; otherwise xr.concat would presumably NaN-pad them
    # and the NaN filter below would silently drop those dates.
    expected_spatial_sizes = None

    current_month_start = start_date_dt
    while current_month_start <= end_date_dt:
        month_label = f"{current_month_start.date()} - {current_month_start + pd.DateOffset(months=1) - pd.DateOffset(minutes=1)}"
        range_properties = []  # Properties for this month
        logger.debug(f"Processing month {month_label}")

        # Date range of this month's cache entry (never beyond end_date or now)
        cache_start_date = current_month_start
        cache_end_date = min(current_month_start + pd.DateOffset(months=1), end_date_dt)
        if cache_end_date > current_date_dt:
            cache_end_date = current_date_dt
        cache_start_date_str = cache_start_date.strftime("%Y-%m-%d")
        cache_end_date_str = cache_end_date.strftime("%Y-%m-%d")
        key_parts = (collections, cache_start_date_str, cache_end_date_str, limit, bbox, cloud_cover, stac_config, resolution, crs)

        group_data_list = []
        bands_to_download = []
        bands_to_skip = []

        if force_download:
            logger.info(f"Force download data for all bands")
            bands_to_download = list(bands)
        else:
            for band in bands:
                cache_filepath = _stac_cache_filepath(cache_dir, key_parts, ".nc", band)
                if not os.path.exists(cache_filepath):
                    bands_to_download.append(band)
                    continue

                band_data = xr.open_dataset(cache_filepath)
                invalid_values_percent = _invalid_fraction(band_data[band])
                if invalid_values_percent > _MAX_INVALID_FRACTION and band != 'wvp':
                    logger.warning(f"Data found in cache for band {band} but {invalid_values_percent*100}% of the data is NaN or 0. Re-downloading data")
                    # Release the file handle before deleting it (required on Windows)
                    band_data.close()
                    os.remove(cache_filepath)
                    bands_to_download.append(band)
                elif 'time' in band_data.coords:
                    bands_to_skip.append(band)
                    group_data_list.append(band_data)
                else:
                    logger.error(f"Data found in cache for band {band} but it does not have a 'time' dimension. Skipping")
                    band_data.close()
                    bands_to_download.append(band)

            if len(bands_to_skip) == len(bands):
                logger.debug(f"Data found in cache for all bands. Skipping download")
            elif len(bands_to_skip) > 0:
                logger.info(f"Data found in cache for bands {bands_to_skip}. Downloading data for bands {bands_to_download}")

        prop_cache_filepath = _stac_cache_filepath(cache_dir, key_parts, ".json")
        if os.path.exists(prop_cache_filepath):
            logger.debug(f"Data found in cache for month {month_label}")
            with open(prop_cache_filepath, 'r') as f:
                range_properties = json.load(f)

        if bands_to_download or len(range_properties) == 0:
            # cloud_cover is intentionally not used as a search filter here
            query = catalog.search(
                collections=collections,
                datetime=f"{cache_start_date_str}/{cache_end_date_str}",
                limit=limit,
                bbox=bbox,
            )
            items = list(query.get_items())
            if len(items) != 0:
                range_properties = [item.properties for item in items]
                logger.info(f"Saving properties for month {month_label}")
                with open(prop_cache_filepath, 'w') as f:
                    json.dump(range_properties, f)

                # Only call stac_load when some band actually needs downloading
                # (the properties alone may have been missing)
                if bands_to_download:
                    if collections == ["sentinel-2-l1c"]:
                        for item in items:
                            for asset in item.assets.values():
                                asset.href = asset.href.replace('sentinel-s2-l2a', 'sentinel-s2-l1c')

                    data = stac_load(items, resolution=resolution, bands=bands_to_download, bbox=bbox, stac_cfg=stac_config, groupby='solar_day', crs=crs)

                    if len(data) > 0:
                        # Keep only this month's time range
                        month_data = data.sel(time=slice(cache_start_date_str, cache_end_date_str))

                        # Compare against the first downloaded month, whatever the spatial dims are called
                        spatial_sizes = {dim: size for dim, size in month_data.sizes.items() if dim != 'time'}
                        if expected_spatial_sizes is None:
                            expected_spatial_sizes = spatial_sizes
                        if any(spatial_sizes.get(dim, 0) < size for dim, size in expected_spatial_sizes.items()):
                            raise Exception(f"Data for bands {bands_to_download} is smaller than expected. Expected shape: {expected_spatial_sizes}. Actual shape: {spatial_sizes}")

                        logger.info(f"Saving data for bands {bands_to_download}")
                        logger.info(f"Dates for bands: {', '.join([date.strftime('%Y-%m-%d') for date in month_data.time.dt.date.values])}")

                        for band in bands_to_download:
                            cache_filepath = _stac_cache_filepath(cache_dir, key_parts, ".nc", band)
                            band_data = month_data[band].compute()

                            if 'time' not in band_data.coords:
                                logger.error(f"Data for band {band} does not have a 'time' dimension. Here are the coordinates: {band_data.coords}")
                                continue

                            band_data = _keep_valid_dates(band_data, band, logger)
                            if len(band_data.time) == 0:
                                logger.warning(f"No valid data found for band {band}. Skipping")
                                continue

                            logger.info(f"Saving data for band {band} to {cache_filepath}. Shape: {band_data.shape}")
                            band_data.to_netcdf(cache_filepath, engine='netcdf4')
                            group_data_list.append(band_data)
            else:
                logger.warning(f"No data found for bands {bands_to_download}")

        properties += range_properties

        # Only months with every band are kept
        if len(group_data_list) == len(bands):
            group_data_list = [ds.drop_vars('spatial_ref', errors='ignore') for ds in group_data_list]
            combined_data.append(xr.merge(group_data_list))

        current_month_start += pd.DateOffset(months=1)

    for i, ds in enumerate(combined_data):
        if 'time' not in ds.coords:
            logger.error(f"The dataset for month {i} does not have a 'time' dimension. Here are the coordinates: {ds.coords}")
            print(ds.coords)

    if not combined_data:
        raise ValueError(f"No month between {start_date} and {end_date} produced data for all bands {bands}; nothing to concatenate")

    # Concatenate the monthly data along time and drop repeated time stamps
    # (month windows share their boundary day)
    final_data = xr.concat(combined_data, dim='time', coords='minimal')
    final_data = final_data.sel(time=~final_data.indexes['time'].duplicated())
    logger.debug("Concatenated data")

    # Collect every time step where any band has at least one NaN pixel
    nan_times_list = []
    for band in bands:
        data_array = final_data[band]
        # Reduce over all non-time dims so the check does not depend on dim order or rank
        other_dims = [dim for dim in data_array.dims if dim != 'time']
        nan_along_time = data_array.isnull().any(dim=other_dims).values
        times_with_nan = data_array['time'].values[nan_along_time]
        if len(times_with_nan) > 0:
            nan_times_list.extend(times_with_nan)
            logger.debug(f"Band {band} has NaN values on the following dates: {times_with_nan}")

    unique_nan_times = list(set(nan_times_list))
    if len(unique_nan_times) > 0:
        logger.warning(f"Unique dates with NaN values across all bands: {unique_nan_times}")

    final_data = final_data.sel(time=~final_data['time'].isin(unique_nan_times))

    # Keep properties matching a remaining time step (to the second), at most
    # one per calendar day. Matching happens before de-duplication so a day
    # still gets its properties even if its first listed item has a different timestamp.
    xr_time_keys = {t[:19] for t in final_data['time'].values.astype(str).tolist()}
    seen_days = set()
    final_properties = []
    for prop in properties:
        if prop['datetime'][:19] not in xr_time_keys:
            continue
        day = prop['datetime'][:10]
        if day in seen_days:
            continue
        seen_days.add(day)
        final_properties.append(prop)
    sorted_final_properties = sorted(final_properties, key=lambda x: x['datetime'])

    logger.debug("End of stac_load")
    return final_data, sorted_final_properties


In [13]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches





In [14]:
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta
from matplotlib import cm
import matplotlib

def interactive_image_plotter_index(rgb_images, index_images, dates, mask = None, name= None):
    """Display an interactive viewer pairing each RGB image with its index image.

    Controls:
      - a date slider plus Previous/Next buttons select the image pair;
      - two bounded text fields (Umbral A / Umbral B) set the index range
        coloured with viridis. Pixels outside the open interval (A, B) are drawn black.

    Args:
        rgb_images: sequence of RGB images indexed by date position. Each is
            transposed with (1, 2, 0) before display, so presumably
            (3, H, W) — confirm against the caller.
        index_images: sequence/array of 2-D index images aligned with `rgb_images` and `dates`.
        dates: datetimes aligned with the images (used in titles).
        mask: optional 2-D array multiplied into both images.
        name: optional index name for the right-hand title (defaults to "Índice").
    """
    # NaN-aware bounds: a single NaN pixel (e.g. a 0/0 ratio) would otherwise
    # make both widget limits NaN
    vmin_all = np.nanmin(index_images)
    vmax_all = np.nanmax(index_images)
    print(f"Valor mínimo: {vmin_all}")
    print(f"Valor máximo: {vmax_all}")

    if mask is not None:
        mask3 = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

    def plot_images(date_index, threshold_a, threshold_b):
        """Draw the RGB / index image pair for the selected date position."""
        selected_date.value = f"Fecha seleccionada: {dates[date_index].strftime('%Y-%m-%d')}"

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

        transposed_rgb = rgb_images[date_index].transpose((1, 2, 0))
        transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
        transposed_rgb = normalize_image_percentile(transposed_rgb)

        # RGB image
        ax1.imshow(transposed_rgb)
        ax1.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} RGB')
        ax1.axis('off')

        index_img = index_images[date_index]
        span = threshold_b - threshold_a
        if span == 0:
            # A == B: no pixel lies strictly between the thresholds, so the
            # image is fully masked anyway; avoid the 0/0 division
            normed_img = np.zeros(np.shape(index_img))
        else:
            normed_img = (index_img - threshold_a) / span
        my_cmap = matplotlib.colormaps['viridis']
        colored_img = my_cmap(normed_img)[:, :, :3]

        threshold_mask = (index_img > threshold_a) & (index_img < threshold_b)
        threshold_mask3 = np.repeat(threshold_mask[:, :, np.newaxis], 3, axis=2)

        masked_img = colored_img * threshold_mask3
        if mask is not None:
            masked_img = masked_img * mask3

        # Index image. masked_img is already RGB, so imshow does not use
        # cmap/vmin/vmax to render it; presumably they still set the colorbar scale — confirm
        cax = ax2.imshow(masked_img, cmap='viridis', vmin=threshold_a, vmax=threshold_b)
        if name is not None:
            ax2.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} {name}')
        else:
            ax2.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} Índice')
        ax2.axis('off')
        cbar = plt.colorbar(cax, ax=ax2)
        cbar.set_label('Valor del Índice', rotation=270, labelpad=20)

        plt.show()

    # Lower threshold, initialised to the global minimum
    threshold_a = widgets.BoundedFloatText(
        value=vmin_all,
        min=vmin_all,
        max=vmax_all,
        step=0.01,
        description='Umbral A:',
    )

    # Upper threshold, initialised to the global maximum
    threshold_b = widgets.BoundedFloatText(
        value=vmax_all,
        min=vmin_all,
        max=vmax_all,
        step=0.01,
        description='Umbral B:',
    )

    # Slider selecting the date position
    date_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(dates) - 1,
        step=1,
        description='Index:',
        continuous_update=False
    )
    date_slider.layout.width = '500px'
    date_slider.index = dates

    # Buttons for stepping one date backwards/forwards
    prev_button = widgets.Button(description="Previous")
    next_button = widgets.Button(description="Next")

    def on_prev_button_clicked(b):
        date_slider.value = max(0, date_slider.value - 1)

    def on_next_button_clicked(b):
        date_slider.value = min(len(dates) - 1, date_slider.value + 1)

    prev_button.on_click(on_prev_button_clicked)
    next_button.on_click(on_next_button_clicked)

    # Label showing the selected date (updated by plot_images)
    selected_date = widgets.Label(value="")

    interactive_plot = widgets.interactive(plot_images, date_index=date_slider, threshold_a=threshold_a, threshold_b=threshold_b)

    display(widgets.HBox([prev_button, next_button]), selected_date, interactive_plot)



In [15]:
from pystac_client import Client
from odc.stac import configure_rio, stac_load
import dask.distributed

# Element84 Earth Search STAC API (v1) endpoint
EARTH_SEARCH_STAC_URL = "https://earth-search.aws.element84.com/v1"
catalog = Client.open(EARTH_SEARCH_STAC_URL)

# Local Dask cluster; rasterio reads are configured for cloud defaults with
# unsigned AWS requests and routed through this client
client = dask.distributed.Client()
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
display(client)

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 5
Total threads: 20,Total memory: 15.50 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:43975,Workers: 5
Dashboard: http://127.0.0.1:8787/status,Total threads: 20
Started: Just now,Total memory: 15.50 GiB

0,1
Comm: tcp://127.0.0.1:41443,Total threads: 4
Dashboard: http://127.0.0.1:33631/status,Memory: 3.10 GiB
Nanny: tcp://127.0.0.1:44395,
Local directory: /tmp/dask-scratch-space/worker-q1xrn6oj,Local directory: /tmp/dask-scratch-space/worker-q1xrn6oj

0,1
Comm: tcp://127.0.0.1:45593,Total threads: 4
Dashboard: http://127.0.0.1:44319/status,Memory: 3.10 GiB
Nanny: tcp://127.0.0.1:35215,
Local directory: /tmp/dask-scratch-space/worker-2g89_eqy,Local directory: /tmp/dask-scratch-space/worker-2g89_eqy

0,1
Comm: tcp://127.0.0.1:45701,Total threads: 4
Dashboard: http://127.0.0.1:38631/status,Memory: 3.10 GiB
Nanny: tcp://127.0.0.1:39451,
Local directory: /tmp/dask-scratch-space/worker-9gbz8lbo,Local directory: /tmp/dask-scratch-space/worker-9gbz8lbo

0,1
Comm: tcp://127.0.0.1:40549,Total threads: 4
Dashboard: http://127.0.0.1:39423/status,Memory: 3.10 GiB
Nanny: tcp://127.0.0.1:43983,
Local directory: /tmp/dask-scratch-space/worker-5o4pv8ny,Local directory: /tmp/dask-scratch-space/worker-5o4pv8ny

0,1
Comm: tcp://127.0.0.1:43335,Total threads: 4
Dashboard: http://127.0.0.1:41425/status,Memory: 3.10 GiB
Nanny: tcp://127.0.0.1:41517,
Local directory: /tmp/dask-scratch-space/worker-4ch5vhf8,Local directory: /tmp/dask-scratch-space/worker-4ch5vhf8


In [16]:
# Analysis window (hard-coded here rather than asked from the user)
history_start = datetime(2022, 1, 1)
history_end = datetime(2023, 5, 30)

# [start, end] date ranges excluded from the analysis
ignore_ranges = [
    [datetime(2022, 1, 30), datetime(2022, 2, 28)],  # Problems with the data
]

In [17]:
from shapely.geometry import Polygon, MultiPolygon, Point
import importlib

# odc-stac configuration: every asset of sentinel-2-l2a is read as uint16 with nodata=0
stac_cfg = {
    'sentinel-2-l2a': {
        'assets': {'*': {'data_type': 'uint16', 'nodata': 0}},
    }
}

# Threshold lists per spectral index
ndvi_thresholds = [None, -0.1, -0.01]
rvi_thresholds = [None, -0.1, 0, 0.1, 0.25, 1]
savi_thresholds = [None, -0.1, 0, 0.1, 1]
arvi_thresholds = [None, -0.1, 0, 0.1, 1]
bai_thresholds = [0.01, 0.1, 1]
nbr_thresholds = [None, -0.1, 0]
afi_thresholds = [None, -0.15]
msrif_thresholds = [None, 0, None]

# Sentinel-2 bands to download (names as used by the STAC assets)
bands = "coastal blue green red nir nir08 nir09 rededge1 rededge2 rededge3 scl swir16 swir22".split()

# Per-property maxima (STAC item properties) used when classifying sunny vs. cloudy dates
thresholds = {
    'eo:cloud_cover': 20,
    's2:medium_proba_clouds_percentage': 10,
    's2:high_proba_clouds_percentage': 5,
    's2:low_proba_clouds_percentage': 20,
    's2:thin_cirrus_percentage': 10,
    's2:cloud_shadow_percentage': 5,
}

In [18]:
def get_images_by_zone(zone, region, only_sunny_dates = True):
    """Load the Sentinel-2 L2A stack for one polygon of a zone over the history window.

    Data for the bounding box of `zone_dict[zone][region]["polygon"]` is
    fetched (or read from the cache under ./data/input/<zone>/<region>) between
    `history_start` and `history_end`. The stack is then filtered:
      - dates inside any `ignore_ranges` entry (bounds inclusive) are dropped;
      - optionally, only dates classified as sunny inside the polygon mask are kept.

    Args:
        zone: key of `zone_dict` (KML feature name).
        region: polygon name inside that zone.
        only_sunny_dates: keep only the dates classified as sunny.

    Returns:
        (data, dates, polygon, properties):
          - data is the filtered Dataset;
          - dates holds one midnight datetime per remaining time step;
          - polygon is the zone geometry;
          - properties holds the STAC item properties whose datetime matches
            a remaining time step.
    """
    input_path = make_dir(["./data/input", zone, region])
    # Called for its side effect (presumably creating the output folder); the path is not used here
    make_dir(["./data/output", zone, region])
    polygon = zone_dict[zone][region]["polygon"]

    data, properties = fetch_or_cache_stac_data_by_band_and_month(catalog,
                                                    collections=["sentinel-2-l2a"],
                                                    bands=bands,
                                                    start_date=history_start.strftime("%Y-%m-%d"),
                                                    end_date= history_end.strftime("%Y-%m-%d"),
                                                    limit=1000,
                                                    bbox=polygon.bounds,
                                                    resolution=10,
                                                    cloud_cover=100,
                                                    stac_config=stac_cfg,
                                                    crs = "EPSG:3857",
                                                    cache_dir=input_path,
                                                    force_download=False,
                                                    log_level= LogLevel.Debug)

    # Drop acquisitions inside any ignore range
    data = data.sel(time=[t for t in data.time.values if all(t < np.datetime64(start) or t > np.datetime64(end) for start, end in ignore_ranges)])
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])
    if only_sunny_dates:
        sunny_dates, _cloudy_dates = classify_sunny_cloudy_dates_by_scene_classification(data, thresholds, mask)
        data = data.sel(time=[t for t in data.time.values if t in sunny_dates])

    # Re-align properties with the time steps left after *all* filters above
    kept_time_keys = {t[:19] for t in data['time'].values.astype(str).tolist()}
    properties = [d for d in properties if d['datetime'][:19] in kept_time_keys]

    dates = [pd.Timestamp(date).to_pydatetime().replace(hour=0, minute=0, second=0, microsecond=0) for date in data['time'].values]
    return data, dates, polygon, properties

In [19]:
from enum import Enum

class ComparisonType(Enum):
    """Direction of the comparison between an index value and a mask threshold.

    Stored in `ImageMaskConfiguration.comparison`.
    """

    GREATER_THAN = "greater_than"
    LESS_THAN = "less_than"

class IndexSource(Enum):
    """Image an index filter is computed from; values are the (Spanish) display labels."""

    FINAL_IMAGE = "Imagen final"
    REFERENCE_IMAGE = "Imagen de referencia"
    DIFFERENCE = "Diferencia"
    ALL_TOGETHER = "Todas los filtros juntos"

    @staticmethod
    def get_position(value):
        """Return the 0-based declaration position of the member whose value equals `value`.

        Raises:
            ValueError: if no member has that value.
        """
        for position, member in enumerate(IndexSource):
            if member.value == value:
                return position
        raise ValueError(f"{value!r} is not in list")
  

class ImageMaskConfiguration:
    """Settings of one index-threshold mask filter.

    Every attribute is exposed as a property with a validating setter. On
    invalid input a setter prints a message and keeps the previous value (it
    does not raise). The constructor stores its arguments without validation.

    Attributes:
        enable: whether the filter is active (bool).
        source: position of an `IndexSource` member (int).
        index_category: position of the index category (int).
        index_type: position of the index within its category, e.g. a
            `VegetationIndex` position (int).
        comparison: `ComparisonType` applied against `threshold`.
        threshold: numeric threshold; the setter stores it as a float.
    """

    def __init__(self, enable=False, source=0, index_category=0, index_type=0, comparison=ComparisonType.GREATER_THAN, threshold=0.0):
        self._enable = enable
        self._source = source
        self._index_category = index_category
        self._index_type = index_type
        self._comparison = comparison
        self._threshold = threshold

    @property
    def enable(self):
        """Getter for enable attribute"""
        return self._enable

    @enable.setter
    def enable(self, value):
        """Setter for enable attribute; accepts only bool"""
        if isinstance(value, bool):
            self._enable = value
        else:
            print("Enable should be a boolean value.")

    @property
    def source(self):
        """Getter for source attribute"""
        return self._source

    @source.setter
    def source(self, value):
        """Setter for source attribute; accepts only int"""
        if isinstance(value, int):
            self._source = value
        else:
            print("Source should be an integer.")

    @property
    def index_category(self):
        """Getter for index_category attribute"""
        return self._index_category

    @index_category.setter
    def index_category(self, value):
        """Setter for index_category attribute; accepts only int"""
        if isinstance(value, int):
            self._index_category = value
        else:
            print("IndexCategory should be an integer.")

    @property
    def index_type(self):
        """Getter for index_type attribute"""
        return self._index_type

    @index_type.setter
    def index_type(self, value):
        """Setter for index_type attribute; accepts only int"""
        if isinstance(value, int):
            self._index_type = value
        else:
            print("IndexType should be an integer.")

    @property
    def comparison(self):
        """Getter for comparison attribute"""
        return self._comparison

    @comparison.setter
    def comparison(self, value):
        """Setter for comparison attribute; accepts only ComparisonType members"""
        if isinstance(value, ComparisonType):
            self._comparison = value
        else:
            print("Invalid comparison. Please use a value from ComparisonType Enum.")

    @property
    def threshold(self):
        """Getter for threshold attribute"""
        return self._threshold

    @threshold.setter
    def threshold(self, value):
        """Setter for threshold attribute; accepts int or float and stores a float"""
        # bool is a subclass of int; reject it so True/False are not taken as 1.0/0.0
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            self._threshold = float(value)
        else:
            print("Threshold should be a number.")

    def print(self):
        """Print every setting, one per line (uses the module-level print)."""
        print(f"Enable: {self.enable}")
        print(f"IndexSource: {self.source}")
        print(f"IndexCategory: {self.index_category}")
        print(f"IndexType: {self.index_type}")
        print(f"Comparison: {self.comparison}")
        print(f"Threshold: {self.threshold}")
        

class LandscapeLabels(Enum):
    """Landscape classes used to label masks; each member's value is its Spanish display name."""
    DEFORESTED = "Bosque Talado"
    BURNED = "Bosque Quemado"
    WATER = "Agua"
    FOREST = "Bosque"
    OTHER_VEGETATION = "Otra vegetacion"
    PEST_IN_FOREST = "Plaga en Bosque"
    CLOUDS = "Nubes"

    @classmethod
    def get_position(cls, label_name):
        """Return the 0-based declaration position of the member whose value equals `label_name`, or None."""
        return next(
            (position for position, member in enumerate(cls) if label_name == member.value),
            None,
        )

    @classmethod
    def image_filters_configure(cls, label, configurations):
        """
        Load the preset filter thresholds for `label` into `configurations` (mutated in place).

        Every configuration is first reset to a disabled, neutral state. Then the first few
        slots are filled according to the label.
        OTHER_VEGETATION and PEST_IN_FOREST have no preset, so they leave everything disabled.

        Parameters:
        label: a LandscapeLabels member or its string value (converted with cls(label); an unknown
            string raises ValueError).
        configurations: sequence of ImageMaskConfiguration-like objects. BURNED needs at least four
            entries; the other labels need at least two.
        """
        if isinstance(label, str):
            label = cls(label)

        for slot in configurations:
            slot.enable = False
            slot.source = 0
            slot.index_category = 0
            slot.index_type = 0
            slot.comparison = ComparisonType.GREATER_THAN
            slot.threshold = 0.0

        def source_pos(member):
            # Position of an IndexSource member in declaration order.
            return list(IndexSource).index(member)

        def category_pos(member):
            return IndexCategory.get_index_category_position(member)

        def set_filter(slot, source, category, index_type, comparison, threshold):
            # Enable one slot; a source of None leaves the slot's current (reset) source untouched.
            slot.enable = True
            if source is not None:
                slot.source = source
            slot.index_category = category
            slot.index_type = index_type
            slot.comparison = comparison
            slot.threshold = threshold

        if label == cls.DEFORESTED:
            set_filter(configurations[0], source_pos(IndexSource.REFERENCE_IMAGE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.GREATER_THAN, 0.65)
            set_filter(configurations[1], source_pos(IndexSource.DIFFERENCE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.LESS_THAN, -0.2)
        elif label == cls.BURNED:
            set_filter(configurations[0], source_pos(IndexSource.REFERENCE_IMAGE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.GREATER_THAN, 0.65)
            set_filter(configurations[1], source_pos(IndexSource.DIFFERENCE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.LESS_THAN, -0.1)
            set_filter(configurations[2], source_pos(IndexSource.DIFFERENCE),
                       category_pos(IndexCategory.FIRE), FireIndex.get_position(FireIndex.NBR),
                       ComparisonType.LESS_THAN, -0.4)
            set_filter(configurations[3], source_pos(IndexSource.DIFFERENCE),
                       category_pos(IndexCategory.FIRE), FireIndex.get_position(FireIndex.MSRIF),
                       ComparisonType.LESS_THAN, 0.3)
        elif label == cls.WATER:
            # NOTE(review): no source is set here, so both filters keep the reset source 0;
            # confirm that position 0 of IndexSource is the intended image.
            set_filter(configurations[0], None,
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.LESS_THAN, 0.2)
            set_filter(configurations[1], None,
                       category_pos(IndexCategory.WATER), WaterIndex.get_position(WaterIndex.MNDWI),
                       ComparisonType.GREATER_THAN, 0.1)
        elif label == cls.FOREST:
            set_filter(configurations[0], source_pos(IndexSource.REFERENCE_IMAGE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.GREATER_THAN, 0.6)
            set_filter(configurations[1], source_pos(IndexSource.FINAL_IMAGE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.GREATER_THAN, 0.7)
        elif label == cls.CLOUDS:
            set_filter(configurations[0], source_pos(IndexSource.FINAL_IMAGE),
                       category_pos(IndexCategory.CLOUDS), CloudIndex.get_position(CloudIndex.CCI),
                       ComparisonType.LESS_THAN, -0.1)
            set_filter(configurations[1], source_pos(IndexSource.DIFFERENCE),
                       category_pos(IndexCategory.VEGETATION), VegetationIndex.get_position(VegetationIndex.NDVI),
                       ComparisonType.LESS_THAN, -0.1)
            


In [20]:
import xarray as xr

def _compute_config_index(date_data, image_mask_config):
    """
    Compute the index or raw band selected by `image_mask_config` on one time slice.

    Returns None when `index_category` matches none of the known IndexCategory positions.
    """
    category = image_mask_config.index_category
    index_type = image_mask_config.index_type
    position_of = IndexCategory.get_index_category_position
    if category == position_of(IndexCategory.VEGETATION):
        return calculate_vegetation_index_from_xr_dataset(date_data, VegetationIndex.get_index(index_type))
    if category == position_of(IndexCategory.FIRE):
        return calculate_fire_index_from_xr_dataset(date_data, FireIndex.get_index(index_type))
    if category == position_of(IndexCategory.WATER):
        return calculate_water_index_from_xr_dataset(date_data, WaterIndex.get_index(index_type))
    if category == position_of(IndexCategory.CLOUDS):
        return calculate_cloud_index_from_xr_dataset(date_data, CloudIndex.get_index(index_type))
    if category == position_of(IndexCategory.S2BANDS):
        return date_data[S2Band.from_position(index_type).value]
    return None


def _threshold_mask(values, image_mask_config):
    """
    Compare `values` with the configured threshold.

    Uses `>` when comparison is ComparisonType.GREATER_THAN and `<` otherwise.
    """
    if image_mask_config.comparison == ComparisonType.GREATER_THAN:
        return values > image_mask_config.threshold
    return values < image_mask_config.threshold


def calculate_mask_image(data: xr.Dataset, date, image_mask_config: ImageMaskConfiguration):
    """
    Threshold one index computed on the time slice of `data` nearest to `date`.

    Parameters:
    data: dataset with a 'time' dimension.
    date: time to select (nearest match).
    image_mask_config: which index to compute and how to threshold it.

    Returns:
    (mask, index): the boolean comparison result and the computed index, or (None, None) when the
    configuration is disabled.

    Raises:
    ValueError: if the configured category is not recognised.
    """
    if not image_mask_config.enable:
        # Disabled filters no longer pay for the time selection.
        return None, None

    date_data = data.sel(time=date, method='nearest')
    print(f'Calculando {image_mask_config.index_category} {image_mask_config.index_type} para {date}')
    index = _compute_config_index(date_data, image_mask_config)
    if index is None:
        raise ValueError("Index calculation failed or index type not recognized.")
    return _threshold_mask(index, image_mask_config), index


def calculate_mask_difference_of_index_images(data: xr.Dataset, date, date_prev, image_mask_config: ImageMaskConfiguration):
    """
    Threshold the change of one index between two time slices (value at `date` minus value at `date_prev`).

    Both dates are matched to the nearest available time step.

    Returns:
    (mask, index_difference), or (None, None) when the configuration is disabled.

    Raises:
    ValueError: if the configured category is not recognised.
    """
    if not image_mask_config.enable:
        return None, None

    date_data = data.sel(time=date, method='nearest')
    date_prev_data = data.sel(time=date_prev, method='nearest')
    index = _compute_config_index(date_data, image_mask_config)
    index_prev = _compute_config_index(date_prev_data, image_mask_config)
    if index is None or index_prev is None:
        raise ValueError("Index calculation failed or index type not recognized.")

    index_difference = index - index_prev
    return _threshold_mask(index_difference, image_mask_config), index_difference
        


In [21]:
from shapely.affinity import affine_transform
import numpy as np
from skimage.measure import find_contours
from shapely.geometry import Polygon, mapping
import geopandas as gpd

def get_transformation_matrix(polygon_bound, mask_shape):
    """
    Build the affine matrix that maps mask pixel coordinates to map coordinates.

    Pixel coordinates are (x = column, y = row). The mask is assumed to cover exactly the bounding
    box of `polygon_bound`, with row 0 at the top edge (maxy). The result uses shapely's 2-D
    `affine_transform` order [a, b, d, e, xoff, yoff], i.e.
    X = a*x + b*y + xoff and Y = d*x + e*y + yoff.

    Parameters:
    polygon_bound: object with a `.bounds` tuple (minx, miny, maxx, maxy), e.g. a shapely Polygon.
    mask_shape: (rows, cols) of the mask.

    Returns:
    list: [x_scale, 0, 0, -y_scale, minx, maxy].

    Raises:
    ValueError: if the mask has zero rows or zero columns.
    """
    minx, miny, maxx, maxy = polygon_bound.bounds
    print(f"Bounds: {minx}, {miny}, {maxx}, {maxy}")
    y_pixels, x_pixels = mask_shape
    print(f"Shape: {x_pixels}, {y_pixels}")
    if x_pixels == 0 or y_pixels == 0:
        raise ValueError(f"Mask shape must be non-empty, got {tuple(mask_shape)}")

    x_scale = (maxx - minx) / x_pixels
    y_scale = (maxy - miny) / y_pixels
    print(f"Scales: {x_scale}, {y_scale}")

    # The y scale is negative and the origin is (minx, maxy): pixel row 0 maps to the top edge.
    return [x_scale, 0, 0, -y_scale, minx, maxy]


def get_polygons_from_mask(mask, label, polygon_bound, zone, region, type_index, date, ref_img_date = None, ref_filename = None, ref_prev_img_filename = None):
    """
    Convert a binary mask into georeferenced polygons, each carrying label metadata.

    The mask border is treated as background so every traced contour is closed. Contours are traced
    at level 0.5 and mapped to map coordinates with get_transformation_matrix (bounds of
    `polygon_bound`). They are then grouped into shells with holes by associate_holes.

    Parameters:
    mask (numpy array or xarray.DataArray): 2-D binary mask (rows, cols). It is not modified.
    label: class label, stored as 'label'.
    polygon_bound (Polygon): polygon whose bounding box the mask is assumed to cover.
    zone, region, type_index: stored as-is in every record.
    date (str): stored as 'date'.
    ref_img_date (str, optional): stored as 'prev_date' when given.
    ref_filename (str, optional): stored as 'tif' (always present, may be None).
    ref_prev_img_filename (str, optional): stored as 'prev_tif' when given.

    Returns:
    list of dict: one record per output polygon with 'geometry' plus the metadata above.
    """
    raw_values = mask.values if hasattr(mask, 'values') else mask
    # Work on a float copy. Zeroing the border in place used to modify the caller's mask
    # (for an xarray mask, `.values` may share memory with the DataArray).
    mask_values = np.array(raw_values, dtype=float)

    # Force a background frame so find_contours only yields closed contours.
    mask_values[0, :] = 0
    mask_values[-1, :] = 0
    mask_values[:, 0] = 0
    mask_values[:, -1] = 0

    # Transposed so each contour point comes out as (column, row) = (x, y) for shapely.
    contours = find_contours(mask_values.T, level=0.5)
    transformation_matrix = get_transformation_matrix(polygon_bound, mask_values.shape)

    polygons = [
        {'geometry': affine_transform(Polygon(contour), transformation_matrix)}
        for contour in contours
    ]
    polygons = associate_holes(polygons)

    for record in polygons:
        record.update({
            'zone': zone,
            'region': region,
            'type_index': type_index,
            'date': date,
            'label': label,
            'tif': ref_filename,
        })
        if ref_img_date is not None:
            record['prev_date'] = ref_img_date
        if ref_prev_img_filename is not None:
            record['prev_tif'] = ref_prev_img_filename

    return polygons

def associate_holes(polygons):
    """
    Turn a flat list of contour polygons into shells with holes.

    Iso-contours traced from a binary mask are expected to be nested but never to cross. The number
    of other contours that contain a contour (its nesting depth) therefore gives its role: even
    depth marks the outer boundary of a region, odd depth marks the boundary of a hole.
    Each shell receives as holes only the contours exactly one level deeper that lie inside it.
    Deeper contours, such as an island inside a hole, become shells of their own.
    Previously those islands were dropped and also added to the outer shell as overlapping holes.

    Parameters:
    polygons: list of dicts, each with a 'geometry' key (shapely geometry).

    Returns:
    list of dict: copies of the shell records, in input order, whose 'geometry' includes their holes.
    """
    geometries = [record['geometry'] for record in polygons]
    count = len(geometries)

    # containers[i]: indices of the other contours that contain contour i (compared by position, not value).
    containers = [
        [j for j in range(count) if j != i and geometries[i].within(geometries[j])]
        for i in range(count)
    ]
    depths = [len(found) for found in containers]

    associated_polygons_data = []
    for i, record in enumerate(polygons):
        if depths[i] % 2:
            continue  # Odd depth: this contour bounds a hole of some shell.

        shell = geometries[i]
        holes = [
            geometries[h] for h in range(count)
            if depths[h] == depths[i] + 1 and i in containers[h]
        ]
        if holes:
            new_geometry = Polygon(shell.exterior.coords[:], holes=[hole.exterior.coords[:] for hole in holes])
        else:
            new_geometry = shell

        new_record = record.copy()  # Keep the metadata, replace only the geometry.
        new_record['geometry'] = new_geometry
        associated_polygons_data.append(new_record)

    return associated_polygons_data


def filter_polygons_by_size(polygons, min_size):
    """
    Keep only polygons whose area is at least `min_size`, preserving order.

    Accepts bare geometry objects (anything with an `.area` attribute) as well as the metadata
    dicts produced by get_polygons_from_mask. For those dicts the area is read from
    item['geometry'], and the dict itself is returned.
    Area is in the units of the geometry's coordinates. For geometries saved as EPSG:4326 this
    means square degrees, not m².
    """
    def _area(item):
        # Metadata records wrap the geometry; bare geometries expose .area directly.
        geometry = item['geometry'] if isinstance(item, dict) else item
        return geometry.area

    return [item for item in polygons if _area(item) >= min_size]


def save_mask_image(label, zone, region, date, polygon, mask, type_index, output_path='./data/output', ref_img_date = None, tiff_filename = None, tiff_prev_img_filename = None):
    """
    Vectorise a mask and save its polygons as a shapefile.

    The file is written to <output_path>/labels/<zone>/<region>/<label>/<type_index>_<YYYY-MM-DD>.shp
    with CRS EPSG:4326. Each polygon carries the metadata added by get_polygons_from_mask.

    Parameters:
    label, zone, region, type_index: metadata, also used to build the output path.
    date (datetime): date of the analysed image (formatted as YYYY-MM-DD).
    polygon: bounding polygon the mask covers.
    mask: 2-D binary mask.
    output_path (str): base output directory.
    ref_img_date (datetime, optional): date of the reference image.
    tiff_filename, tiff_prev_img_filename (str, optional): source GeoTIFF paths stored as metadata.

    Returns:
    str or None: path of the written shapefile, or None when the mask yields no polygons.
    """
    print(f"Parameters: {zone}, {region}, {date}, {type_index}")

    formatted_date = date.strftime('%Y-%m-%d')
    formatted_ref_img_date = ref_img_date.strftime('%Y-%m-%d') if ref_img_date is not None else None

    polygons_with_metadata = get_polygons_from_mask(mask, label, polygon, zone, region, type_index, formatted_date, formatted_ref_img_date, tiff_filename, tiff_prev_img_filename)

    if not polygons_with_metadata:
        # An empty list gives a GeoDataFrame with no geometry column; recent geopandas versions
        # refuse to assign a CRS to it or write it, so skip saving instead of crashing.
        print(f"No polygons found for {label} in {zone}/{region} on {formatted_date}; nothing saved.")
        return None

    gdf = gpd.GeoDataFrame(polygons_with_metadata, crs='EPSG:4326')

    dir_path = f"{output_path}/labels/{zone}/{region}/{label}"
    os.makedirs(dir_path, exist_ok=True)

    filepath = f"{dir_path}/{type_index}_{formatted_date}.shp"
    gdf.to_file(filepath)
    return filepath

In [22]:
import numpy as np
from scipy.ndimage import label, sum

def remove_small_objects(mask, min_size):
    """
    Remove connected components smaller than `min_size` pixels from a binary mask.

    Components are found with scipy.ndimage.label's default structuring element, which gives
    4-connectivity in 2-D.

    Parameters:
    mask (array-like): binary mask; non-zero pixels are foreground.
    min_size (int or float): minimum number of pixels a component needs in order to be kept.

    Returns:
    tuple: (new_mask, kept, removed).
        new_mask: array like np.zeros_like(mask), with the kept components set to 1.
        kept: number of components kept.
        removed: number of components dropped.
    """
    # Local import: the cell-level `from scipy.ndimage import label, sum` shadows builtins and can be rebound.
    from scipy import ndimage

    labeled_mask, num_labels = ndimage.label(mask)

    # Pixel count per component (index 0 is background). Counting pixels instead of summing mask
    # values keeps sizes correct for masks that are not 0/1 (e.g. 255-valued).
    sizes = np.bincount(labeled_mask.ravel(), minlength=num_labels + 1)[1:]
    keep = sizes >= min_size

    # Lookup table label -> keep: a single vectorised pass replaces one full-image
    # comparison per component.
    keep_by_label = np.concatenate(([False], keep))
    new_mask = np.zeros_like(mask)
    new_mask[keep_by_label[labeled_mask]] = 1

    kept = int(np.count_nonzero(keep))
    removed = int(num_labels) - kept
    return new_mask, kept, removed

In [23]:
import os
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.transform import from_bounds
import xarray as xr

def save_tiff(dataset, date, output_path, base_filename, user_metadata, crs='EPSG:3857'):
    """
    Write every data variable of `dataset` at the time step nearest to `date` as one band of a GeoTIFF.

    Output path: <output_path>/<crs with ':' replaced by '_'>/<base_filename>_<YYYY-MM-DD>.tif.
    Band order follows dataset.data_vars. Each band gets its variable name as description, and
    user_metadata plus 'band_names' and 'date' are stored as dataset-level tags.

    Parameters:
    dataset (xarray.Dataset): has 'time', 'x' and 'y' coordinates; each variable is 2-D after time selection.
    date (datetime): time to select (nearest match); also used in the filename.
    output_path (str): base output directory.
    base_filename (str): filename prefix.
    user_metadata (dict or None): extra tags. The caller's dict is not modified.
    crs (str): CRS written to the file; defaults to the previously hard-coded 'EPSG:3857'.

    Returns:
    str: path of the written file.

    Raises:
    ValueError: if the dataset has no data variables.
    """
    ds_selected = dataset.sel(time=date, method='nearest')
    band_names = list(ds_selected.data_vars)
    if not band_names:
        raise ValueError("Dataset has no data variables to save.")
    bands = [ds_selected[name].values for name in band_names]
    print(f'Bands: {band_names}')
    print(dataset.coords)

    x_min = ds_selected.x.min().values
    x_max = ds_selected.x.max().values
    y_min = ds_selected.y.min().values
    y_max = ds_selected.y.max().values
    height, width = bands[0].shape
    print(f"X min: {x_min}, X max: {x_max}, Y min: {y_min}, Y max: {y_max}")
    print(f"Height: {height}, Width: {width}")
    # NOTE(review): bounds come from the min/max coordinate values and row 0 is assumed to be north.
    # If x/y are pixel centres (or top-left corners), the true raster edges differ by up to one
    # pixel; confirm the loader's coordinate convention.
    transform = from_bounds(west=x_min, south=y_min, east=x_max, north=y_max, width=width, height=height)

    output_dir = os.path.join(output_path, crs.replace(':', '_'))
    output_filepath = os.path.join(output_dir, f"{base_filename}_{date.strftime('%Y-%m-%d')}.tif")
    profile = {
        'driver': 'GTiff',
        'count': len(bands),
        'dtype': bands[0].dtype,  # all bands are assumed to share this dtype
        'width': width,
        'height': height,
        'crs': crs,
        'transform': transform,
        # 'band_names' used to be passed here, which rasterio forwards to GDAL as an unsupported
        # creation option; band names are stored via descriptions and tags instead.
    }

    # Copy so the caller's metadata dict is not mutated.
    tags = dict(user_metadata) if user_metadata is not None else {}
    tags['band_names'] = band_names
    tags['date'] = date.strftime('%Y-%m-%d')

    os.makedirs(output_dir, exist_ok=True)

    with rasterio.open(output_filepath, 'w', **profile) as dst:
        # Write the arrays already materialised above instead of re-reading `.data`.
        for band_number, (name, values) in enumerate(zip(band_names, bands), start=1):
            dst.write(values, band_number)
            dst.set_band_description(band_number, name)
            print(f'Band name: {name}')
        dst.update_tags(**tags)
    print(f"File saved to: {output_filepath}")
    return output_filepath


In [24]:

from IPython.display import display
import time
def process_images_mask(zone, region, data, dates, polygon, properties, visible_images, num_index_to_process = 4, image_index = 0, prev_image_index = -1, labels = None):

    enable_widgets = []
    source_widgets = []
    type_widgets = []
    index_widgets = []
    comparison_widgets = []
    number_inputs = []
    
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])

    index_categories = {
        IndexCategory.VEGETATION.value: [e.name for e in VegetationIndex],
        IndexCategory.FIRE.value: [e.name for e in FireIndex],
        IndexCategory.WATER.value: [e.name for e in WaterIndex],
        IndexCategory.CLOUDS.value: [e.name for e in CloudIndex],
        IndexCategory.S2BANDS.value: S2Band.to_list()
    }
         
    image_type_category = widgets.Dropdown(
            options= [item.value for item in IndexSource],
            description='Imagen:',
            disabled=False,
        )
    
    # Dropdown para seleccionar el color
    color_dropdown = widgets.Dropdown(
        options=['Red', 'Green', 'Blue', 'Orange', 'Yellow', 'Purple'],
        value='Orange',
        description='Color mascara:'
    )

    # Slider para seleccionar la transparencia (alpha)
    alpha_slider = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.01,
        description='Transparencia:',
        continuous_update=False
    )

    configurations = [ImageMaskConfiguration() for _ in range(num_index_to_process)]
    trigger_widget = widgets.Checkbox(value=False, layout=widgets.Layout(display='none'))
    config_widgets_programatic_update = False

    def on_value_change(change):
        nonlocal config_widgets_programatic_update
        new_value = change['new']

        if change['owner'].description == 'Habilitar':
            config = change['owner'].config_ref
            config.enable = new_value
        elif change['owner'].description == 'Índice':
            config = change['owner'].config_ref
            new_index = change['owner'].index 
            config.index_type = new_index
        elif change['owner'].description == 'Comparación':
            config = change['owner'].config_ref
            comparison_mapping = {
                "Menor que": ComparisonType.LESS_THAN,
                "Mayor que": ComparisonType.GREATER_THAN
            }
            config.comparison = comparison_mapping.get(new_value, None)
            if config.comparison is None:
                print(f"Invalid comparison value: {new_value}")
        elif change['owner'].description == 'Umbral':
            config = change['owner'].config_ref
            config.threshold = new_value
            if change['owner'].manual_update == True:
                return
        elif change['owner'].description == 'Fuente':
            config = change['owner'].config_ref
            config.source = IndexSource.get_position(new_value)
         
                
        if config_widgets_programatic_update == False:
            trigger_widget.value = not trigger_widget.value

    
    def create_widgets_for_configuration(config, on_value_change):
        enable_widget = widgets.Checkbox(value=config.enable, description='Habilitar')
        source_selector = widgets.Dropdown(options=[item.value for item in IndexSource][:-1], description='Fuente')
        type_selector = widgets.Dropdown(options= IndexCategory.to_list(), index=IndexCategory.get_index_category_position(config.index_category), description='Tipo índice')
        index_selector = widgets.Dropdown(options=index_categories[IndexCategory.VEGETATION.value], description='Índice')
        comparison_selector = widgets.Dropdown(options=["Mayor que", "Menor que"], description='Comparación')
        number_input = widgets.BoundedFloatText(value=0, min=0, max=1000, step=0.1, description='Umbral')
        
        enable_widget.config_ref = config
        source_selector.config_ref = config
        type_selector.config_ref = config
        index_selector.config_ref = config
        comparison_selector.config_ref = config
        number_input.config_ref = config
        number_input.manual_update = False
            
        def update_indices(change):
            index_selector.options = index_categories[change['new']]
            config = change['owner'].config_ref
            new_index = change['owner'].index  
            config.index_category = new_index
            
        type_selector.observe(update_indices, names='value')
        
        def toggle_widgets(change):
            state = change['new']
            source_selector.disabled = not state
            type_selector.disabled = not state
            index_selector.disabled = not state
            comparison_selector.disabled = not state
            number_input.disabled = not state
        
        enable_widget.observe(toggle_widgets, names='value')

        enable_widget.observe(on_value_change, names='value')
        source_selector.observe(on_value_change, names='value')
        type_selector.observe(on_value_change, names='value')
        index_selector.observe(on_value_change, names='value')
        comparison_selector.observe(on_value_change, names='value')
        number_input.observe(on_value_change, names='value')
        enable_widgets.append(enable_widget)
        source_widgets.append(source_selector)
        type_widgets.append(type_selector)
        index_widgets.append(index_selector)
        comparison_widgets.append(comparison_selector)
        number_inputs.append(number_input)
        return [enable_widget, source_selector, type_selector, index_selector, comparison_selector, number_input]
    
    def update_config_widgets(config_widgets_list):
        nonlocal config_widgets_programatic_update
        config_widgets_programatic_update = True
        for i, config in enumerate(configurations):
            enable_widgets[i].value = config.enable
            source_widgets[i].index = config.source
            type_widgets[i].index = config.index_category
            index_widgets[i].index = config.index_type
            comparison_widgets[i].index = 0 if config.comparison == ComparisonType.GREATER_THAN else 1
            try:
                if number_inputs[i].max < config.threshold:
                    number_inputs[i].max = config.threshold
                if number_inputs[i].min > config.threshold:
                    number_inputs[i].min = config.threshold
                number_inputs[i].value = config.threshold
            except:
                print(f"Error updating number input for index {i}")
        config_widgets_programatic_update = False
        
    color_dropdown.observe(on_value_change, names='value')
    alpha_slider.observe(on_value_change, names='value')

    PIXEL_AREA = 100  # 10m * 10m = 100m²
    min_pixels = 500 / PIXEL_AREA  # Convert area to number of pixels
    
    
    #Slider to select the minimum area (in m²) of the objects to keep
    min_area_slider = widgets.IntSlider(
        value=10,  # Initial value
        min=0,  # Min value
        max=1000,  # Max value
        step=10,  # Step size
        description='Min px:',  # Description or label for the slider
        continuous_update=False  # Update only when the slider stops moving
    )
    
    status_label = widgets.Label(value="")

    def update_threshold_range(position, min_value, max_value):
        nonlocal config_widgets_programatic_update
        config_widgets_programatic_update = True
        if(position <= len(number_inputs) - 1):
            number_input = number_inputs[position]  # Accede al widget específico basado en la posición
           
    
            if number_input.min != min_value or number_input.max != max_value:
                number_input.manual_update = True  # Indica que el usuario cambió el valor
                current_min = number_input.min
                current_max = number_input.max

                if min_value <= current_max:
                    number_input.min = min_value  # Configura el valor mínimo
                    number_input.max = max_value  # Configura el valor máximo
                else:
                    number_input.max = max_value  # Configura el valor máximo
                    number_input.min = min_value  # Configura el valor mínimo
                range_diff = max_value - min_value
                if range_diff > 100000:
                    precision_value = 5
                elif range_diff > 10000:
                    precision_value = 4
                else:
                    precision_value = 3
                
                #number_input.value = round((min_value + max_value) / 2, precision_value)  # Ajusta el valor
                number_input.step = (range_diff) / 100  # Ajusta el valor del paso
                if number_input.step > 0.05:
                    number_input.step = 0.05
                    
                time.sleep(0.1)  # Espera 0.1 segundos
                number_input.manual_update = False  # Indica que el usuario no cambió el valor
        config_widgets_programatic_update = False
            
    color_mapping = {
            'Red': [1, 0, 0],
            'Green': [0, 1, 0],
            'Blue': [0, 0, 1],
            'Orange': [1, 0.5, 0],
            'Yellow': [1, 1, 0],
            'Purple': [0.5, 0, 0.5]
        }
    
    previous_trigger = None
    mask_image = None
    mask_prev_image = None
    mask_diff = None
    mask_to_show = None
    mask_to_save = None 
    mask_calculated = False
    image_type_index = None
    img_index = None
    objects_kept = 0
    objects_removed = 0
    
    def plot_image_mask(trigger, image_type, min_area):
        combined_mask = None
        nonlocal previous_trigger, mask_calculated, mask_image, mask_prev_image, mask_diff, image_type_index, mask_to_show, mask_to_save, objects_kept, objects_removed, img_index
        
        if previous_trigger is None or trigger != previous_trigger:
            mask_calculated = False
            mask_image = None
            mask_prev_image = None
            mask_diff = None
            previous_trigger = trigger
                         
        if mask_calculated == False:
            mask_calculated = True
            image_date = dates[image_index]
            prev_image_date = dates[prev_image_index]
        
            for i, config in enumerate(configurations):
                if config.enable == True:
                    if config.source == list(IndexSource).index(IndexSource.DIFFERENCE):
                        mask_index, image_for_index = calculate_mask_difference_of_index_images(data, image_date, prev_image_date, config)
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        max_index_img = image_for_index.max().item()
                        min_index_img = image_for_index.min().item()
                        update_threshold_range(i, min_index_img, max_index_img)  
                        mask_diff = mask_index if mask_diff is None else mask_diff * mask_index
                    elif config.source == list(IndexSource).index(IndexSource.FINAL_IMAGE):
                        mask_index, image_for_index = calculate_mask_image(data, image_date, config)
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        max_index_img = image_for_index.max().item()
                        min_index_img = image_for_index.min().item()
                        update_threshold_range(i, min_index_img, max_index_img)  
                        mask_image = mask_index if mask_image is None else mask_image * mask_index
                    elif config.source == list(IndexSource).index(IndexSource.REFERENCE_IMAGE):
                        mask_index, image_for_index = calculate_mask_image(data, prev_image_date, config)
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        max_index_img = image_for_index.max().item()
                        min_index_img = image_for_index.min().item()
                        update_threshold_range(i, min_index_img, max_index_img)  
                        mask_prev_image = mask_index if mask_prev_image is None else mask_prev_image * mask_index
                    else:
                        print(f"Invalid index source: {config.source}")
        
            if mask_image is not None and mask_prev_image is not None and mask_diff is not None:
                mask_to_save = mask_image * mask_prev_image * mask_diff
            elif mask_image is not None and mask_prev_image is not None:
                mask_to_save = mask_image * mask_prev_image
            elif mask_image is not None and mask_diff is not None:
                mask_to_save = mask_image * mask_diff
            elif mask_prev_image is not None and mask_diff is not None:
                mask_to_save = mask_prev_image * mask_diff
            elif mask_image is not None:
                mask_to_save = mask_image
            elif mask_prev_image is not None:
                mask_to_save = mask_prev_image
            elif mask_diff is not None:
                mask_to_save = mask_diff        
            else:
                mask_to_save = None

            if mask_to_save is not None:
                mask_to_save = np.logical_and(mask == 1, mask_to_save == 1)
            

        image_type_index = image_type_category.options.index(image_type)

        if image_type_index == list(IndexSource).index(IndexSource.FINAL_IMAGE):
            img_index = image_index            
            img_mask = mask_image
        elif image_type_index == list(IndexSource).index(IndexSource.REFERENCE_IMAGE):
            img_index = prev_image_index
            img_mask = mask_prev_image
        elif image_type_index == list(IndexSource).index(IndexSource.DIFFERENCE):
            img_index = image_index
            img_mask = mask_diff
        else:
            img_index = image_index
            img_mask = mask_to_save
            
        if img_mask is None:
                img_mask = np.zeros((mask.shape[0], mask.shape[1]))
                
        mask_to_show = np.logical_and(mask == 1, img_mask == 1)
  
        if img_index is not None:
            transposed_rgb = visible_images[img_index].transpose((1, 2, 0))
            transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
            transposed_rgb = normalize_image_percentile(transposed_rgb)
                
            #show image
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))
            ax.imshow(transposed_rgb)
            
            if(mask_to_show is not None):
                mask_to_show, objects_kept, objects_removed = remove_small_objects(mask_to_show, min_area)
            else:
                mask_to_show = None
                objects_kept = 0
                objects_removed = 0
                
            status_label.value = f"Removidos: {objects_removed} | Mantenidos: {objects_kept}"
            
            if mask_to_show is not None:
                selected_color = color_mapping[color_dropdown.value] + [alpha_slider.value]
                color_mask = np.zeros((mask_to_show.shape[0], mask_to_show.shape[1], 4))
                color_mask[mask_to_show == 1] = selected_color  # Aplicar color solo donde mask es 1
                ax.imshow(color_mask)
                    
            ax.set_title(f'{dates[img_index].strftime("%Y-%m-%d")} RGB')
            ax.axis('off')
            plt.show()
    
    def labels_dropdown_on_change(change):
    
        LandscapeLabels.image_filters_configure(change['new'], configurations)
        update_config_widgets(configurations)
        trigger_widget.value = not trigger_widget.value

        
    labels_dropdown = widgets.Dropdown(
        options=labels,
        description='Etiqueta:',
        disabled=False,
    )
    
    labels_dropdown.observe(labels_dropdown_on_change, names='value')

    interactive_plot_image = widgets.interactive_output(plot_image_mask, {'trigger': trigger_widget, 'image_type': image_type_category, 'min_area': min_area_slider})
    
    def on_button_process_clicked(b):
        if mask_to_save is not None:
            tiff_output_dir = './data/output/labels/tiff'
            tiff_filename = f"{zone}_{region}"
            user_properties = {
                'zone': zone,
                'region': region,
            }
            tiff_output_filepath = save_tiff(data, dates[image_index], tiff_output_dir, tiff_filename, user_properties)                
            tiff_prev_output_filepath = save_tiff(data, dates[prev_image_index], tiff_output_dir, tiff_filename, user_properties)
            save_mask_image(labels_dropdown.value, zone, region, dates[image_index], polygon, mask_to_save, labels_dropdown.index, './data/output', dates[prev_image_index], tiff_output_filepath, tiff_prev_output_filepath)
            
        else:
            print("No hay máscara para guardar.")
    
    button_process = widgets.Button(description="Guardar Shapefile")
    button_process.on_click(on_button_process_clicked)
    
    display(widgets.HBox([image_type_category, color_dropdown, alpha_slider]), widgets.HBox([min_area_slider, status_label, labels_dropdown, button_process]))

    # # Mostrar widgets
    # for config in configurations:
    #     widgets_list = create_widgets_for_configuration(config, on_value_change)
    #     display(widgets.HBox(widgets_list))
    
    vboxes = []
    for config in configurations:
        widgets_list = create_widgets_for_configuration(config, on_value_change)
        vboxes.append(widgets.VBox(widgets_list))
    display(widgets.HBox(vboxes))

    display(interactive_plot_image)  
    labels_dropdown.index = 1
    time.sleep(0.1)
    labels_dropdown.index = 0
    image_type_category.index = 3
        


In [25]:
def select_images_widget(zone, region, only_sunny_dates = True, labels = None):
    """Show two linked date pickers (reference and target image) for a zone/region.

    Loads the imagery with ``get_images_by_zone`` and displays two sliders, each
    with previous/next buttons and an RGB preview masked by the zone polygon.
    The reference slider is kept strictly earlier than the target slider.
    Pressing "Seleccionar" redraws the pickers and passes the selected pair of
    indices to ``process_images_mask``.

    Args:
        zone: Zone key forwarded to ``get_images_by_zone``.
        region: Region key within ``zone``.
        only_sunny_dates: Forwarded to ``get_images_by_zone``.
        labels: Label names forwarded to ``process_images_mask``.
    """
    data, dates, polygon, properties = get_images_by_zone(zone, region, only_sunny_dates)

    # A reference and a target image are required (reference index < target index).
    # With fewer than two dates the sliders below would get max < min, which
    # ipywidgets rejects, so bail out with an explicit message instead.
    if len(dates) < 2:
        print(f"Zona: {zone}, Sector: {region}. Se necesitan al menos 2 imágenes para comparar; disponibles: {len(dates)}")
        return

    def _scale_to_unit(img):
        # Divide by the per-image maximum; an all-zero image is kept as-is instead
        # of turning into NaNs through a 0/0 division.
        peak = img.max()
        scaled = img if peak == 0 else img / peak
        return scaled.astype(np.float32)

    visible_images = [
        _scale_to_unit(data.sel(time=t)[["red", "green", "blue"]].to_array().values)
        for t in data.time.values
    ]
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])

    # Target image slider starts at 1 so that index 0 is always available as reference.
    date_slider_image = widgets.IntSlider(
        value=1,
        min=1,
        max=len(dates) - 1,
        step=1,
        description='Imagen:',
        continuous_update=False
    )
    date_slider_image.layout.width = '450px'

    # Reference image slider: always strictly before the target image.
    date_slider_ref_image = widgets.IntSlider(
        value=0,
        min=0,
        max=len(dates) - 2,
        step=1,
        description='Referencia:',
        continuous_update=False
    )
    date_slider_ref_image.layout.width = '450px'

    zone_label = widgets.Label(value= f'   Zona: {zone}, Sector: {region}. Imagenes disponibles: {len(dates)}')

    def on_date_slider_image_change(change):
        # Moving the target back drags the reference back so it stays earlier.
        date_slider_ref_image.value = min(date_slider_ref_image.value, date_slider_image.value - 1)

    def on_date_slider_ref_image_change(change):
        # Moving the reference forward drags the target forward so it stays later.
        date_slider_image.value = max(date_slider_image.value, date_slider_ref_image.value + 1)

    date_slider_image.observe(on_date_slider_image_change, names='value')
    date_slider_ref_image.observe(on_date_slider_ref_image_change, names='value')

    # Buttons for single-step navigation on each slider.
    prev_image_button = widgets.Button(description="Previa")
    next_image_button = widgets.Button(description="Siguiente")
    prev_ref_image_button = widgets.Button(description="Previa")
    next_ref_image_button = widgets.Button(description="Siguiente")

    process_button = widgets.Button(description="Seleccionar")

    def on_prev_image_button_clicked(b):
        date_slider_image.value = max(date_slider_image.min, date_slider_image.value - 1)

    def on_next_image_button_clicked(b):
        date_slider_image.value = min(date_slider_image.max, date_slider_image.value + 1)

    def on_prev_ref_image_button_clicked(b):
        date_slider_ref_image.value = max(date_slider_ref_image.min, date_slider_ref_image.value - 1)

    def on_next_ref_image_button_clicked(b):
        date_slider_ref_image.value = min(date_slider_ref_image.max, date_slider_ref_image.value + 1)

    plot_image_in_progress = False

    def plot_image(date_index):
        """Render the masked, percentile-normalized RGB preview for ``date_index``."""
        nonlocal plot_image_in_progress

        # Skip re-entrant calls triggered while a previous render is still running.
        if plot_image_in_progress:
            return

        plot_image_in_progress = True

        try:
            transposed_rgb = visible_images[date_index].transpose((1, 2, 0))
            transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
            transposed_rgb = normalize_image_percentile(transposed_rgb)

            fig, ax = plt.subplots(1, 1, figsize=(10, 5))
            ax.imshow(transposed_rgb)
            ax.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} RGB')
            ax.axis('off')
            plt.show()
        finally:
            plot_image_in_progress = False

    interactive_plot_image = widgets.interactive_output(plot_image, {'date_index': date_slider_image})
    interactive_plot_ref_image = widgets.interactive_output(plot_image, {'date_index': date_slider_ref_image})

    prev_image_button.on_click(on_prev_image_button_clicked)
    next_image_button.on_click(on_next_image_button_clicked)
    prev_ref_image_button.on_click(on_prev_ref_image_button_clicked)
    next_ref_image_button.on_click(on_next_ref_image_button_clicked)

    def _pickers_layout():
        # Two-column picker layout (reference left, target right) plus the action row;
        # shown initially and again after "Seleccionar" clears the output.
        return (
            widgets.HBox([
                widgets.VBox([
                    widgets.HBox([date_slider_ref_image, prev_ref_image_button, next_ref_image_button]),
                    interactive_plot_ref_image,
                ]),
                widgets.VBox([
                    widgets.HBox([date_slider_image, prev_image_button, next_image_button]),
                    interactive_plot_image,
                ]),
            ]),
            widgets.HBox([process_button, zone_label]),
        )

    def on_process_button_clicked(b):
        process_button.disabled = True
        try:
            clear_output(wait=True)
            display(*_pickers_layout())
            # NOTE(review): the meaning of the literal 4 is not visible here —
            # confirm against process_images_mask's signature.
            process_images_mask(zone, region, data, dates, polygon, properties, visible_images, 4,
                                date_slider_image.value, date_slider_ref_image.value, labels)
        finally:
            # Re-enable even when processing fails so another pair can be selected.
            process_button.disabled = False

    process_button.on_click(on_process_button_clicked)

    display(*_pickers_layout())

In [26]:
import ipywidgets as widgets
from IPython.display import clear_output, display

def select_region_widget(select_callback, labels):
    """Show zone/region selectors and run ``select_callback`` for the chosen pair.

    Zones come from the module-level ``zone_dict``; the region dropdown is
    refreshed from ``zone_dict[zone]`` whenever the zone changes.

    Args:
        select_callback: Called as
            ``select_callback(zone, region, only_sunny_dates, labels)`` when
            "Process" is pressed (e.g. ``select_images_widget``).
        labels: Label names forwarded unchanged to ``select_callback``.
    """
    import time

    zone_widget = widgets.Dropdown(
        options=zone_dict.keys(),
        description='Zone:',
        disabled=False
    )

    # .get avoids a KeyError when zone_dict is empty (zone_widget.value is None).
    region_widget = widgets.Dropdown(
        options=zone_dict.get(zone_widget.value, {}).keys(),
        description='Region:',
        disabled=False
    )

    def update_region_options(change):
        region_widget.options = zone_dict.get(change['new'], {}).keys()

    zone_widget.observe(update_region_options, names='value')

    only_sunny_dates = widgets.Checkbox(value=True, description='Descartar imágenes con nubes')

    # Action buttons and status line.
    process_button = widgets.Button(description="Process")
    clear_button = widgets.Button(description="Clear")
    status_label = widgets.Label(value="")

    def on_process_button_clicked(b):
        # Lock the selectors while the callback runs.
        process_button.disabled = True
        zone_widget.disabled = True
        region_widget.disabled = True
        only_sunny_dates.disabled = True
        status_label.value = "Procesando..."
        try:
            select_callback(zone_widget.value, region_widget.value, only_sunny_dates.value, labels)
        except Exception as exc:
            # Surface the failure instead of leaving "Procesando..." forever;
            # re-raise so the traceback is still reported.
            status_label.value = f"Error: {exc}"
            raise
        else:
            time.sleep(1)
            status_label.value = "Procesamiento completado"
        finally:
            # Always allow "Clear" so the selectors can be reset, even after a failure.
            clear_button.disabled = False

    process_button.on_click(on_process_button_clicked)

    def display_region_widget(clear_output_request = False):
        """(Re)display the selector form in its initial, enabled state."""
        if clear_output_request:
            clear_output(wait=True)
        process_button.disabled = False
        zone_widget.disabled = False
        region_widget.disabled = False
        only_sunny_dates.disabled = False
        clear_button.disabled = True
        status_label.value = ""
        display(widgets.HBox([zone_widget, region_widget, only_sunny_dates]), widgets.HBox([process_button, clear_button]), status_label)

    def on_clear_button_clicked(b):
        display_region_widget(True)

    clear_button.on_click(on_clear_button_clicked)

    # Show widgets
    display_region_widget()

In [27]:
# Offer the value of every LandscapeLabels member as a selectable label.
labels = [landscape_label.value for landscape_label in LandscapeLabels]
select_region_widget(select_callback=select_images_widget, labels=labels)

HBox(children=(Dropdown(description='Zone:', options=('Bosques Bio Bio', 'Incendios', 'Bosques Arauco', 'Provo…

HBox(children=(Button(description='Process', style=ButtonStyle()), Button(description='Clear', disabled=True, …

Label(value='')