In [48]:
from datetime import datetime, timedelta

# Guarda una referencia a la función print original
original_print = __builtins__.print

# Redefine la función print
def print(*args, **kwargs):
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    original_print(current_time, "-", *args, **kwargs)

In [49]:
from tesina_utils import calculate_water_index_from_xr_dataset, calculate_fire_index_from_xr_dataset, calculate_vegetation_index_from_xr_dataset, VegetationIndex, FireIndex, WaterIndex, IndexCategory


In [50]:
import os
import sys
import json
import time
import requests
import numpy as np
import pandas as pd
from fastkml import  kml


In [51]:
from typing import Any, Optional, Tuple
def polygon_coords_to_px_coords(polygon, width: int, height: int) -> list[Tuple[int, int]]:
    """Transforms polygon coordinates from lat/lon to pixel coordinates

    :param polygon: list of coordinates in lat/lon
    :param bbox: bounding box
    :param width: width of bounding box in pixels
    :param height: height of bounding box in pixels
    :return: list of coordinates in pixels
    """
    east1, north1 = polygon.bounds[0], polygon.bounds[1]
    east2, north2 = polygon.bounds[2], polygon.bounds[3]
    div_x = (east2 - east1) / width
    div_y = (north2 - north1) / height

    pixel_coords = []

    for coord in polygon.exterior.coords:
        # Scale and translate the coordinate
        px_coord = ((coord[0] - east1) / div_x, (coord[1] - north1) / div_y)
        # Append to the list of pixel coordinates
        pixel_coords.append(px_coord)
    return pixel_coords
    

In [52]:
import cv2
def get_kml_polygon_masks(polygon, width: int, height: int) -> list[Tuple[np.ndarray, np.ndarray]]:
    #get the polygon mask
    polygon_coords = polygon_coords_to_px_coords(polygon, width, height)
    polygon_mask = np.zeros((height, width))
    polygon_mask = cv2.fillPoly(polygon_mask, np.array([polygon_coords], dtype=np.int32), 1)
    polygon_mask = np.flipud(polygon_mask).astype(np.uint8)
    return polygon_mask, np.dstack([polygon_mask] * 3)

In [53]:
import cv2

def equalize_image(image):
    """
    Applies Contrast Limited Adaptive Histogram Equalization (CLAHE) to an image.
    
    Parameters:
    image: ndarray, an RGB image to be equalized.

    Returns:
    image_clahe: ndarray, the equalized image.
    """
    # Convert the image from RGB to Lab
    image_lab = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2Lab)

    # Split the Lab channels
    l, a, b = cv2.split(image_lab)

    # Create the CLAHE object
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))

    # Apply CLAHE to the L channel and merge back
    l_clahe = clahe.apply(l)
    image_lab_clahe = cv2.merge((l_clahe,a,b))

    # Convert the image from Lab to BGR
    image_clahe = cv2.cvtColor(image_lab_clahe, cv2.COLOR_Lab2BGR)

    return image_clahe


In [54]:
from enum import Enum
class S2Band(Enum):
    B01 = 0     # Aerosols
    B02 = 1     # Blue, 492.4 nm (S2A), 492.1 nm (S2B)
    B03 = 2     # Green, 559.8 nm (S2A), 559 nm (S2B)
    B04 = 3     # Red, 664.6 nm (S2A), 665 nm (S2B)
    B05 = 4     # Vegetation red edge, 704.1 nm (S2A), 703.8 nm (S2B)
    B06 = 5     # Vegetation red edge, 740.5 nm (S2A), 739.1 nm (S2B)
    B07 = 6     # Vegetation red edge, 782.8 nm (S2A), 779.7 nm (S2B)
    B08 = 7     # NIR, 832.8 nm (S2A), 833 nm (S2B)
    B8A = 8     # Narrow NIR, 864.7 nm (S2A), 864 nm (S2B)
    B09 = 9     # Water vapour, 945 nm (S2A), 943.2 nm (S2B)

class S2BandNames(Enum):
    AEROSOLS = 0,
    BLUE = 1,
    GREEN = 2,
    RED = 3,
    VEGETATION_RED_EDGE_1 = 4,
    VEGETATION_RED_EDGE_2 = 5,
    VEGETATION_RED_EDGE_3 = 6,
    NIR = 7,
    NARROW_NIR = 8,
    WATER_VAPOUR = 9


In [55]:
from sklearn.preprocessing import RobustScaler
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def robust_scaler_normalize_images(images):
    robust_scaler = RobustScaler()

    # Stack all the images into a single numpy array
    all_images = np.stack(images)

    # Retain the dimensions of the individual images for later
    image_shape = all_images.shape[1:]

    # Reshape the array so that each row is an image
    all_images = all_images.reshape(-1, np.prod(image_shape))

    # Normalize all the images at once
    all_images_normalized = robust_scaler.fit_transform(all_images)

    # Reshape the images back to their original shape
    images_normalized = [image.reshape(image_shape) for image in all_images_normalized]

    return images_normalized

def plot_distribution_before_after_robust_scaler(images, images_normalized):
    all_images = np.stack(images).reshape(-1, np.prod(np.stack(images).shape[1:]))
    all_images_normalized = np.stack(images_normalized).reshape(-1, np.prod(np.stack(images_normalized).shape[1:]))

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Distribution before normalization
    sns.histplot(all_images.flatten(), ax=axs[0], color='blue', kde=True)
    axs[0].set_title('Distribution before normalization')

    # Distribution after normalization
    sns.histplot(all_images_normalized.flatten(), ax=axs[1], color='green', kde=True)
    axs[1].set_title('Distribution after normalization')

    plt.show()



In [56]:
#make directory if it does not exist, concatenate the path from the list of directories
def make_dir(dir_list: list):
    dir_path = os.path.join(*dir_list)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    return dir_path

In [57]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import os

def save_image_with_palette_and_labels(image, directory, filename, palette, labels):
    # Create an empty RGB image with the same shape as the input image
    rgb_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)

    # Replace each pixel with the corresponding RGB color in the palette
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            rgb_image[i, j] = palette[image[i, j]]

    # Create a legend with percentage of each class
    unique_labels, counts = np.unique(image, return_counts=True)
    total_pixels = image.size
    patches = [mpatches.Patch(color=np.array(palette[label])/255., 
                label=f"{labels[label]} [{100 * count / total_pixels:.2f}%]") 
                for label, count in zip(unique_labels, counts)]

    # Plot the image and the legend
    plt.figure(figsize=(10,10))
    plt.imshow(rgb_image)
    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0. )
    plt.axis('off')

    # Combine the directory and filename to form the output path
    output_path = os.path.join(directory, filename)
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()




In [58]:
# Define the RGB color mapping for each land cover class
land_cover_palette = {
    0: (255, 255, 255),  # No data
    1: (255, 0, 0),  # Saturated / Defective
    2: (0, 0, 0),  # Dark Area Pixels
    3: (255, 255, 0),  # Cloud Shadows
    4: (0, 255, 0),  # Vegetation
    5: (255, 0, 255),  # Bare Soils
    6: (0, 0, 255),  # Water
    7: (255, 72, 0),  # Clouds low probability / Unclassified
    8: (255, 100, 0),  # Clouds medium probability
    9: (255, 128, 0),  # Clouds high probability
    10: (255, 255, 255),  # Cirrus
    11: (255, 255, 255)  # Snow / Ice
}

# Define descriptive labels for each land cover class
land_cover_labels = {
    0: 'No data',
    1: 'Saturated / Defective',
    2: 'Dark Area Pixels',
    3: 'Cloud Shadows',
    4: 'Vegetation',
    5: 'Bare Soils',
    6: 'Water',
    7: 'Clouds low probability / Unclassified',
    8: 'Clouds medium probability',
    9: 'Clouds high probability',
    10: 'Cirrus',
    11: 'Snow / Ice'
}

land_cover_palette_list = [land_cover_palette[i] for i in range(len(land_cover_palette))]
land_cover_palette_array = np.array(land_cover_palette_list)


In [59]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_breaks_histogram(break_image, monitoring_start, dates, name=None):
    """
    Plots a histogram of breaks.

    Parameters:
    break_image: ndarray, an image where each pixel value indicates the break number.
    monitoring_start: datetime, the start date of monitoring.
    dates: list, list of dates corresponding to break values.
    name: str, optional name for the title.
    """
    # Set Seaborn style
    sns.set(style="whitegrid")

    # Flattened breaks
    breaks_flattened = break_image.flatten()

    # Remove zeros if you don't want to count them
    breaks_flattened = breaks_flattened[breaks_flattened > 0]

    # Determine bins based on integer values
    bins = np.arange(1, breaks_flattened.max() + 1) - 0.5

    plt.figure(figsize=(15, 4))

    # Create a histogram with Seaborn
    sns.histplot(breaks_flattened, bins=bins, kde=False, color='skyblue', edgecolor='black')

    # Set integer x-axis ticks
    plt.xticks(np.arange(breaks_flattened.min(), breaks_flattened.max() + 1))

    # Title
    if name is not None:
        plt.title("Histogram of breaks " + name + " " + str(monitoring_start.date()) + " - " + str(dates[-1]))
    else:
        plt.title("Histogram of breaks " + str(monitoring_start.date()) + " - " + str(dates[-1]))

   


In [60]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_histogram(data, title, x_label='Value', y_label='Frequency', bins=80, discard_outliers=0.02, discard_value: int = 0):
    """
    Plots a histogram of the given data.

    Parameters:
    data: ndarray, a 2D array where each pixel value represents a data point.
    title: str, the title of the plot.
    x_label: str, label for the x-axis.
    y_label: str, label for the y-axis.
    bins: int, the number of bins in the histogram.
    discard_outliers: float, fraction of data points to be discarded from both tails.
    """

    # Set Seaborn style
    sns.set(style="whitegrid")

    # Flattened data
    data_flattened = data.flatten()

    # Discard the given fraction of the highest and lowest values
    data_flattened = np.sort(data_flattened)[int(len(data_flattened) * discard_outliers):int(len(data_flattened) * (1 - discard_outliers))]

    if discard_value is not None:
        data_flattened = data_flattened[data_flattened != discard_value]

    plt.figure(figsize=(15, 4))

    # Create a histogram with Seaborn
    sns.histplot(data_flattened, bins=bins, kde=False, color='skyblue', edgecolor='black')

    # Labels and title
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    plt.show()


In [61]:
from skimage import io
import numpy as np
import os
from typing import List

def save_images_to_folder(images: List[np.ndarray], 
                          image_names: List[str],                           
                          mask: np.ndarray = None,                           
                          folder_path: str = './', 
                          subfolder_name: str = 'images', 
                          normalize: bool=False) -> None:
    """
    Save images to folder
    :param images: images to save
    :param mask: mask to apply to images
    :param folder_path: path to folder
    :param subfolder_name: name of subfolder
    :param image_names: names of images
    :param normalize: Boolean flag to normalize image or not
    :return: None
    """
        
    path_to_folder = os.path.join(folder_path, subfolder_name)
    os.makedirs(path_to_folder, exist_ok=True)
        
    if mask is not None:
        if mask.ndim == 2:  # If the mask is a single channel
            mask = np.repeat(mask[np.newaxis, :, :], 3, axis=0)  # Duplicate the mask channel to match the image's shape


        # Save images
        for i, image in enumerate(images):
            if image.ndim == 2:  # If image is grayscale
                masked_image = image * mask[:, :, 0]  # Apply single channel mask
            elif image.ndim == 3 and image.shape[2] == 4:  # If image is RGBA
                # Apply mask to RGB channels and leave alpha channel intact
                masked_image = np.zeros_like(image)
                masked_image[:, :, :3] = image[:, :, :3] * mask
                masked_image[:, :, 3] = image[:, :, 3]
            else:  # If image is RGB
                masked_image = image * mask  # Apply multi-channel mask
                
            if normalize:
                masked_image = ((masked_image) / 
                                (masked_image.max())) * 255
            masked_image = masked_image.astype(np.uint8)
            masked_image = masked_image.transpose(1, 2, 0)
            imageio.imwrite(os.path.join(path_to_folder, image_names[i]), masked_image)
    else:
        for i, image in enumerate(images):
            if normalize:
                image = ((image) / 
                         (image.max())) * 255
            image = image.astype(np.uint8)
            image = image.transpose(1, 2, 0)
            imageio.imwrite(os.path.join(path_to_folder, image_names[i]), image)
            
    return None


In [62]:
from lxml import etree

#Read kml with forests polygons and convert it to GeoDataFrame
#Data is in zones folder, read every file and convert it to GeoDataFrame

#Get the list of files from the zones folder
files = os.listdir("./zones")
print(files)

#Create the empty list to store the GeoDataFrame
gdf_list = []
zone_dict = {}
#resolution of the image
resolution = 10
#Get the data from the sentinel-2
km2deg = 1/110.574 # 1km in degrees

#Loop through the files
for file in files:
    print(f'Processing file: {file}')
    doc = open("./zones/" + file, "r", encoding='utf-8', errors='ignore').read()
    k = kml.KML()
    
    try:
        k.from_string(doc.encode('utf-8'))
    except etree.ParseError:
        print(f"Unable to parse file: {file}")
        continue

    features = list(list(k.features())[0].features())   
    for f in features:
        print(f.name)
        zone_dict[f.name] = {}

    polygons = list(features[0].features())
    for p in polygons:
        print(p.name)
        #read polygon and get square of it
        polygon = p.geometry
        #add polygon to the dictionary
        zone_dict[f.name][p.name] = {}
        zone_dict[f.name][p.name]["polygon"] = polygon
        print("Bounding box: {}".format(polygon.bounds))

2023-10-19 09:23:19 - ['Bosques Bio Bio.kml', 'Incendios.kml', 'Bosques Arauco.kml']
2023-10-19 09:23:19 - Processing file: Bosques Bio Bio.kml
2023-10-19 09:23:19 - Bosques Bio Bio
2023-10-19 09:23:19 - Bosque 1
2023-10-19 09:23:19 - Bounding box: (-72.45142890120229, -37.1935384751085, -72.43650978092387, -37.18581491700413)
2023-10-19 09:23:19 - Bosque 2
2023-10-19 09:23:19 - Bounding box: (-72.42482075716856, -37.16361946111896, -72.41455274852883, -37.15297897451463)
2023-10-19 09:23:19 - Bosque 3
2023-10-19 09:23:19 - Bounding box: (-72.42423154064022, -37.1722707957546, -72.39075268176586, -37.15535508695594)
2023-10-19 09:23:19 - Processing file: Incendios.kml
2023-10-19 09:23:19 - Incendios
2023-10-19 09:23:19 - Chiguayante 1
2023-10-19 09:23:19 - Bounding box: (-73.14301140515278, -36.97713577738772, -73.07640424786813, -36.93275197623442)
2023-10-19 09:23:19 - Chiguayante 2
2023-10-19 09:23:19 - Bounding box: (-73.14176444168599, -36.9381030333295, -73.0759365158813, -36.895

In [63]:
def convert_bounds(bbox, invert_y=False):
    """
    Helper method for changing bounding box representation to leaflet notation

    ``(lon1, lat1, lon2, lat2) -> ((lat1, lon1), (lat2, lon2))``
    """
    x1, y1, x2, y2 = bbox
    if invert_y:
        y1, y2 = y2, y1
    return ((y1, x1), (y2, x2))

In [64]:
import dask.distributed
#import folium
#import folium.plugins
#import geopandas as gpd
#import shapely.geometry
from IPython.display import HTML, display
from pystac_client import Client


In [65]:
from termcolor import colored
from enum import Enum

class LogLevel(Enum):
    Trace = 0
    Debug = 1
    Info = 2
    Warning = 3
    Error = 4
    Disabled = 5

class Logger:

    def __init__(self, min_log_level=LogLevel.Info):
        self.min_log_level = min_log_level

    def log_message(self, level, message):
        if level.value < self.min_log_level.value:
            return

        if level == LogLevel.Error:
            print(colored(message, 'red'))
        elif level == LogLevel.Warning:
            print(colored(message, 'yellow'))
        elif level == LogLevel.Info:
            print(colored(message, 'blue'))
        elif level == LogLevel.Debug:
            print(colored(message, 'cyan'))
        elif level == LogLevel.Trace:
            print(colored(message, 'magenta'))
        else:
            print(message)

    def error(self, message):
        self.log_message(LogLevel.Error, message)

    def warning(self, message):
        self.log_message(LogLevel.Warning, message)

    def info(self, message):
        self.log_message(LogLevel.Info, message)

    def debug(self, message):
        self.log_message(LogLevel.Debug, message)

    def trace(self, message):
        self.log_message(LogLevel.Trace, message)


In [66]:
from datetime import datetime
import hashlib

def convert_band_uint16_dtype(variable_data):
    if variable_data.dtype == 'uint16':
        max_value = variable_data.max().values.item()
        if max_value <= np.iinfo(np.int16).max:
            new_dtype = 'int16'
        else:
            new_dtype = 'int32'
        return variable_data.astype(new_dtype)
    else:
        return variable_data

def fetch_or_cache_stac_data_by_band_and_month(catalog, bands, collections, start_date:str, end_date:str, limit, bbox, cloud_cover, resolution, stac_config=None, crs="EPSG:4326", cache_dir='data_cache', force_download=False, log_level: LogLevel = LogLevel.Info, save_rgb_image=False):

    logger = Logger(min_log_level=log_level)
    start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
    end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
    current_date_dt = datetime.now()

    combined_data = []  # List to hold the data for each month
    properties = []  # List to hold the properties for each month

    current_month_start = start_date_dt
    while current_month_start <= end_date_dt:    
        range_properties = []  # List to hold the properties for a specific month
        logger.debug(f"Processing month {current_month_start.date()} - {current_month_start + pd.DateOffset(months=1) - pd.DateOffset(minutes=1)}")
        # Determine the date range for the cache file
        cache_start_date = current_month_start
        cache_end_date = min(current_month_start + pd.DateOffset(months=1), end_date_dt)
        if cache_end_date > current_date_dt:
            cache_end_date = current_date_dt

        cache_start_date_str = cache_start_date.strftime("%Y-%m-%d")
        cache_end_date_str = cache_end_date.strftime("%Y-%m-%d")

        group_data_list = []
        bands_to_download = []
        bands_to_skip = []
        dims = None
        items = None

        if force_download:
            logger.info(f"Force download data for all bands")
            bands_to_download = bands
        else:   
            # Check for each band
            for band in bands:
                metadata_str = f"{collections}_{cache_start_date_str}_{cache_end_date_str}_{limit}_{bbox}_{cloud_cover}_{stac_config}_{resolution}_{crs}_{band}"
                metadata_hash = hashlib.md5(metadata_str.encode()).hexdigest()
                cache_filepath = os.path.join(cache_dir, f"{metadata_hash}.nc")

                if os.path.exists(cache_filepath):
                    bands_to_skip.append(band)
                    band_data = xr.open_dataset(cache_filepath)
                    variable_data =  band_data[band]
                    invalid_values_count = variable_data.isnull().sum() + (variable_data == 0).sum()
                    invalid_values_percent = (invalid_values_count / variable_data.size).values

                    # If more of 5% of the data is NaN or 0 then re-download the data
                    if invalid_values_percent > 0.05 and band != 'wvp':
                        logger.warning(f"Data found in cache for band {band} but {invalid_values_percent*100}% of the data is NaN or 0. Re-downloading data")
                        # Remove the file
                        os.remove(cache_filepath)
                        bands_to_download.append(band)
                        bands_to_skip.remove(band)
                    elif 'time' in band_data.coords:
                        group_data_list.append(band_data)
                    else:
                        logger.error(f"Data found in cache for band {band} but it does not have a 'time' dimension. Skipping")
                        bands_to_download.append(band)
                        bands_to_skip.remove(band)
                else:
                    bands_to_download.append(band)
                
            if len(bands_to_skip) == len(bands):
                logger.debug(f"Data found in cache for all bands. Skipping download")
            elif len(bands_to_skip) > 0:
                logger.info(f"Data found in cache for bands {bands_to_skip}. Downloading data for bands {bands_to_download}")
        
        prop_metadata_str = f"{collections}_{cache_start_date_str}_{cache_end_date_str}_{limit}_{bbox}_{cloud_cover}_{stac_config}_{resolution}_{crs}"
        prop_metadata_hash = hashlib.md5(prop_metadata_str.encode()).hexdigest()
        prop_cache_filepath = os.path.join(cache_dir, f"{prop_metadata_hash}.json")

        if os.path.exists(prop_cache_filepath):
            logger.debug(f"Data found in cache for month {current_month_start.date()} - {current_month_start + pd.DateOffset(months=1) - pd.DateOffset(minutes=1)}")
            with open(prop_cache_filepath, 'r') as f:
                range_properties = json.load(f)

        if bands_to_download or len(range_properties) == 0:
            if collections ==  ["sentinel-2-l2a"]:
                query = catalog.search(
                    collections=collections,
                    datetime=f"{cache_start_date_str}/{cache_end_date_str}",
                    limit=limit,
                    bbox=bbox,
                    #query=[f'eo:cloud_cover<={cloud_cover}']
                )
            else:
                query = catalog.search(
                    collections=collections,
                    datetime=f"{cache_start_date_str}/{cache_end_date_str}",
                    limit=limit,
                    bbox=bbox,
                )
            
            items = list(query.get_items())
            if len(items) != 0:
                #list with item properties to json
                range_properties = [item.properties for item in items]
                logger.info(f"Saving properties for month {current_month_start.date()} - {current_month_start + pd.DateOffset(months=1) - pd.DateOffset(minutes=1)}")
                with open(prop_cache_filepath, 'w') as f:
                    json.dump(range_properties, f)
                
                if collections ==  ["sentinel-2-l1c"]:
                    for i in items:
                        for a in i.assets:
                            i.assets[a].href = i.assets[a].href.replace('sentinel-s2-l2a', 'sentinel-s2-l1c')

                data = stac_load(items, resolution=10, bands=bands_to_download, bbox=bbox, stac_cfg=stac_config, groupby='solar_day', crs=crs)


                #if there is data to download
                if len(data) > 0:
                    # Filter for the month
                    month_data = data.sel(time=slice(cache_start_date_str, cache_end_date_str))
                
                    #print shape (w,h)
                    if dims is None:
                        dims = month_data.dims
                    
                    # Check if the data is the expected shape is smaller than expected then raise an error
                    if month_data.dims['x'] < dims['x'] or month_data.dims['y'] < dims['y']:
                        raise Exception(f"Data for bands {bands_to_download} is smaller than expected. Expected shape: {dims}. Actual shape: {month_data.dims}")

                    # Compute and save each band                    
                    logger.info(f"Saving data for bands {bands_to_download}")   
                    #print only date part             
                    logger.info(f"Dates for bands: {', '.join([date.strftime('%Y-%m-%d') for date in month_data.time.dt.date.values])}")


                    for band in bands_to_download:
                        metadata_str = f"{collections}_{cache_start_date_str}_{cache_end_date_str}_{limit}_{bbox}_{cloud_cover}_{stac_config}_{resolution}_{crs}_{band}"
                        metadata_hash = hashlib.md5(metadata_str.encode()).hexdigest()
                        cache_filepath = os.path.join(cache_dir, f"{metadata_hash}.nc")

                        band_data = month_data[band].compute()

                        if 'time' not in band_data.coords:
                            logger.error(f"Data for band {band} does not have a 'time' dimension. Here are the coordinates: {band_data.coords}")
                            continue
                        
                        # If more of 50% of the data is NaN or 0 print a warning (check by date and band)
                        # Create a list to store the dates that you want to keep
                        valid_dates = []

                        for date in band_data['time']:
                            # Select the data for the current date
                            date_data = band_data.sel(time=date)

                            # Calculate the number of NaN values
                            nan_count = date_data.isnull().sum().values

                            # Calculate the number of 0 values
                            zero_count = (date_data == 0).sum().values

                            # Calculate the total percentage of NaN or zero values
                            invalid_values_percent = (nan_count + zero_count) / date_data.size

                            # Make a decision based on that percentage
                            if invalid_values_percent > 0.05 and band != 'wvp':
                                logger.warning(f"{invalid_values_percent*100}% of the data for band {band} at date {date.values} is NaN or 0. Skipping.")
                            else:
                                valid_dates.append(date.values)

                        # Filter band_data to only the valid dates
                        band_data = band_data.sel(time=valid_dates)

                        if len(band_data.time) == 0:
                            logger.warning(f"No valid data found for band {band}. Skipping")
                            continue
                        
                        logger.info(f"Saving data for band {band} to {cache_filepath}. Shape: {band_data.shape}")
                        
                        band_data.to_netcdf(cache_filepath, engine='netcdf4')
                        group_data_list.append(band_data)
            else:
                logger.warning(f"No data found for bands {bands_to_download}")

        properties += range_properties
        
        if len(group_data_list) == len(bands):
            group_data_list = [ds.drop_vars('spatial_ref') for ds in group_data_list]
            # Combine the bands for this month
            group_data = xr.merge(group_data_list)
            combined_data.append(group_data)
            
        if save_rgb_image and len(bands_to_download) > 0:
            # Check if the RGB image is in group_data_list
            rgb_bands = ['red', 'green', 'blue']
            if all(band in group_data_list for band in rgb_bands):
                # Take the first image in the month
                rgb_data = group_data_list[0][rgb_bands]
        # Move to the next month
        current_month_start += pd.DateOffset(months=1)

    for i, ds in enumerate(combined_data):
        if 'time' not in ds.coords:
            logger.error(f"The dataset for month {i} does not have a 'time' dimension. Here are the coordinates: {ds.coords}")
            print(ds.coords)
        
    # Concatenate the monthly data along the time dimension
    final_data = xr.concat(combined_data, dim='time', coords='minimal')
    # Remove redundant time values
    final_data = final_data.sel(time=~final_data.indexes['time'].duplicated())
    logger.debug("Concatenated data")

    # Initialize an empty list to collect times with NaN values
    nan_times_list = []

    for band in bands:
        data_array = final_data[band]
        numpy_array = data_array.data  # This is a NumPy array
        
        # Check for NaN in the NumPy array
        where_nan = np.isnan(numpy_array)
        
        # Sum along all dimensions to find any slice along 'time' that contains at least one NaN
        nan_along_time = np.any(where_nan, axis=(1, 2))  # Here I assumed the time is the first dimension
        
        # Retrieve the corresponding time values
        times_with_nan = data_array['time'].values[nan_along_time]
        if len(times_with_nan) > 0:            
            #nan_times_list.extend(np.unique(times_with_nan.values))
            nan_times_list.extend(times_with_nan)
            logger.debug(f"Band {band} has NaN values on the following dates: {times_with_nan}")

    # Remove duplicates from the list
    unique_nan_times = list(set(nan_times_list))

    logger.warning(f"Unique dates with NaN values across all bands: {unique_nan_times}")

    # Filter out the dates with NaN values from final_data
    final_data = final_data.sel(time=~final_data['time'].isin(unique_nan_times))

    # Remove elements from properties that are duplicated in datetime
    xr_times = final_data['time'].values.astype(str).tolist()

    seen = set()
    filtered_properties = [seen.add(d['datetime'][:10]) or d for d in properties if d['datetime'][:10] not in seen]
    final_properties = [d for d in filtered_properties if d['datetime'][:19] in {x[:19] for x in xr_times}]
    sorted_final_properties = sorted(final_properties, key=lambda x: x['datetime'])

    logger.debug("End of stac_load")
    return final_data, sorted_final_properties


In [67]:
import xarray as xr
import numpy as np
from typing import List, Tuple

def classify_sunny_cloudy_dates_by_scene_classification(ds: xr.Dataset, thresholds: dict, mask: Optional[np.ndarray] = None) -> Tuple[List[str], List[str]]:
    """
    Classify dates in xarray dataset as sunny or cloudy based on the SCL band and thresholds.

    Args:
    - ds (xr.Dataset): xarray dataset containing Sentinel-2 bands.
    - thresholds (dict): Dictionary containing cloud coverage criteria.
    - mask (np.ndarray, optional): Binary mask to filter the region of interest. Default is None.

    Returns:
    - Tuple[List[str], List[str]]: Two lists of dates categorized as sunny and cloudy.
    """
    
    logger = Logger(min_log_level=LogLevel.Info)
    sunny_dates = []
    cloudy_dates = []

    for t in ds.time.values:
        # Calculate percentages for the specific date
        current_scl = ds['scl'].sel(time=t).data
        
        # Apply the mask to the current_scl if provided
        if mask is not None:
            masked_scl = current_scl * mask
        else:
            masked_scl = current_scl
        
        if mask is not None:
            total_pixels = np.sum(mask)  # Total pixels are the sum of the mask
        else:
            total_pixels = current_scl.size  # Original total pixels

        cloud_cover_percent = 100 * (np.sum((masked_scl == 7) | (masked_scl == 8) | (masked_scl == 9)) / total_pixels)
        low_proba_clouds_percent = 100 * (np.sum(masked_scl == 7) / total_pixels)
        medium_proba_clouds_percent = 100 * (np.sum(masked_scl == 8) / total_pixels)
        high_proba_clouds_percent = 100 * (np.sum(masked_scl == 9) / total_pixels)
        thin_cirrus_percent = 100 * (np.sum(masked_scl == 10) / total_pixels)
        cloud_shadow_percent = 100 * (np.sum(masked_scl == 3) / total_pixels)

        if any([cloud_cover_percent, medium_proba_clouds_percent, high_proba_clouds_percent, thin_cirrus_percent, cloud_shadow_percent]):
            logger.debug(f"Date: {t}. Cloud cover: {cloud_cover_percent}. Medium proba clouds: {medium_proba_clouds_percent}. High proba clouds: {high_proba_clouds_percent}. Thin cirrus: {thin_cirrus_percent}. Cloud shadow: {cloud_shadow_percent}")
        
        properties = {
            'eo:cloud_cover': cloud_cover_percent,
            's2:low_proba_clouds_percentage': low_proba_clouds_percent, 
            's2:medium_proba_clouds_percentage': medium_proba_clouds_percent,
            's2:high_proba_clouds_percentage': high_proba_clouds_percent,
            's2:thin_cirrus_percentage': thin_cirrus_percent,
            's2:cloud_shadow_percentage': cloud_shadow_percent
        }

        # Check thresholds
        if all(properties[key] < value for key, value in thresholds.items()):
            sunny_dates.append(str(t))
        else:
            cloudy_dates.append(str(t))
            logger.trace(f"Cloudy date: {t}")
            for key, value in thresholds.items():
                if properties[key] >= value:
                    logger.debug(f"Cloudy date: {t}. Coverage {key} exceeded threshold {value}, value was {properties[key]}")
                    
    # Convert to datetime64
    sunny_dates = np.array(sunny_dates, dtype='datetime64')
    cloudy_dates = np.array(cloudy_dates, dtype='datetime64')

    return sunny_dates, cloudy_dates


In [68]:
def save_images_with_palette_and_labels(images, directory, filenames, palette, labels, overwrite=False):
    for image, filename in zip(images, filenames):        
        output_path = os.path.join(directory, filename)
        
        # Check if the file exists at the beginning
        if os.path.exists(output_path) and not overwrite:
            #print(f"File {output_path} already exists! Skipping...")
            continue
        
        rgb_image = palette[image]
        unique_labels, counts = np.unique(image, return_counts=True)
        total_pixels = image.size
        patches = [mpatches.Patch(color=np.array(palette[label]) / 255.,
                                  label=f"{labels[label]} [{100 * count / total_pixels:.2f}%]")
                   for label, count in zip(unique_labels, counts)]

        plt.figure(figsize=(10, 10))
        plt.imshow(rgb_image)
        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.axis('off')

        plt.savefig(output_path, bbox_inches='tight')
        plt.close()


In [69]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def normalize_image_percentile(image, lower_percentile=1, upper_percentile=99):
    """Normaliza la imagen utilizando percentiles."""
    lower = np.percentile(image, lower_percentile)
    upper = np.percentile(image, upper_percentile)
    image = np.clip(image, lower, upper)
    return (image - lower) / (upper - lower)


def save_images_with_palette_and_labels(images, directory, filenames, palette, labels, 
                                        overwrite=False, visible_images=None, mask=None):
    
    
    # Si se proporcionan imágenes visibles, deben tener la misma longitud que las imágenes originales
    if visible_images is not None:
        assert len(images) == len(visible_images), "Images and visible_images must have the same length"
    
    if mask is not None:
            mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

    for idx, (image, filename) in enumerate(zip(images, filenames)):
        output_path = os.path.join(directory, filename)
        
        # Check if the file exists at the beginning
        if os.path.exists(output_path) and not overwrite:
            continue
        
        rgb_image = palette[image]
         
        # Calculate aspect ratio
        height, width = image.shape[:2]
        aspect_ratio = height / width

        # Adjust figure height based on aspect ratio
        if visible_images is not None:
            fig_width = 22
            fig_height = (fig_width * aspect_ratio) / 1.8  # We divide by 2 because two images will be shown side by side
        else:
            fig_width = 10
            fig_height = fig_width * aspect_ratio

        if mask is not None:
            assert rgb_image.shape == mask.shape, "Image and mask must have the same shape"
            rgb_image = rgb_image * mask

        plt.figure(figsize=(fig_width, fig_height))

        unique_labels, counts = np.unique(image, return_counts=True)
        total_pixels = image.size
        patches = [mpatches.Patch(color=np.array(palette[label]) / 255.,
                                  label=f"{labels[label]} [{100 * count / total_pixels:.2f}%]")
                   for label, count in zip(unique_labels, counts)]

        
        # If there's a visible_image in the list, we show it alongside the original image
        if visible_images is not None:
            visible_image = visible_images[idx].transpose(1, 2, 0)
            if mask is not None:
                visible_image = visible_image * mask
            visible_image = normalize_image_percentile(visible_image)  # Normalización usando percentiles
            plt.subplot(1, 2, 1)
            plt.imshow(visible_image)
            plt.axis('off')
            
            plt.subplot(1, 2, 2)

        plt.imshow(rgb_image)
        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.axis('off')

        plt.savefig(output_path, bbox_inches='tight')
        plt.close()


In [70]:
import os
import json
import numpy as np

def process_and_save_images(bands, sunny_dates, cloudy_dates, properties, output_folder, mask, overwrite=False):
    """
    Process and save images based on given dates and properties.

    Args:
    bands: xarray dataset containing image bands
    sunny_dates: list of datetime64 objects representing sunny dates
    cloudy_dates: list of datetime64 objects representing cloudy dates
    properties: list of properties used to determine image classification
    output_folder: str, path to output folder
    mask: array-like object used to mask images

    Returns:
    None
    """
    logger = Logger(min_log_level=LogLevel.Debug)

    # Create the visible images
    visible_images = [(img / img.max()).astype(np.float32) for img in 
                      [bands.sel(time=time)[["red", "green", "blue"]].to_array().values for time in bands.time.values]]

    # List of names for the images with dates
    images_names = [f"{str(date)[:10]}.png" for date in bands.time.values]

    # Check if the properties have changed
    json_path = os.path.join(output_folder, 'visible', 'properties.json')
    if os.path.exists(json_path):
        with open(json_path, 'r') as f:
            existing_properties = json.load(f)
            if existing_properties != properties:
                # If properties change, clear content of the 'visible' folder
                for root, dirs, files in os.walk(os.path.join(output_folder, 'visible'), topdown=False):
                    for name in files:
                        os.remove(os.path.join(root, name))
                    for name in dirs:
                        os.rmdir(os.path.join(root, name))
                with open(json_path, 'w') as f:
                    json.dump(properties, f)
    else:
        if not os.path.exists(os.path.join(output_folder, 'visible')):
            os.makedirs(os.path.join(output_folder, 'visible'))
        with open(json_path, 'w') as f:
            json.dump(properties, f)

    sunny_images = []
    sunny_image_names = []
    
    cloudy_images = []
    cloudy_image_names = []
    
    unclassified_images = []
    unclassified_image_names = []

    # Classify images into their corresponding lists
    for img, img_name in zip(visible_images, images_names):
        date_str = img_name.split('.')[0]
        current_date = np.datetime64(date_str, 'D')  # Convert to just date (ignoring time)
        if current_date in [np.datetime64(date, 'D') for date in sunny_dates]:
            sunny_images.append(img)
            sunny_image_names.append(img_name)
        elif current_date in [np.datetime64(date, 'D') for date in cloudy_dates]:
            cloudy_images.append(img)
            cloudy_image_names.append(img_name)
        else:
            unclassified_images.append(img)
            unclassified_image_names.append(img_name)

    # Save images for each category
    if sunny_images:
        save_path = os.path.join(output_folder, 'visible')
        save_images_to_folder(sunny_images, sunny_image_names, mask, save_path, 'sunny', True)
    
    if cloudy_images:
        save_path = os.path.join(output_folder, 'visible')
        save_images_to_folder(cloudy_images, cloudy_image_names, mask, save_path, 'cloudy', True)
    
    if unclassified_images:
        save_path = os.path.join(output_folder, 'visible')
        save_images_to_folder(unclassified_images, unclassified_image_names, mask, save_path, 'unclassified', True)
        
    logger.info('Visible images saved')
        
    if 'scl' in bands.data_vars:
        # Obtener imágenes 'scl' para todas las fechas
        scl_images = [bands.sel(time=time)['scl'].values.astype(np.uint8) for time in bands.time.values]
        
        # Crear nombres scl
        scl_names = [img_name.replace('.png', '_scl.png') for img_name in images_names]

        # Clasificar imágenes 'scl' en sus listas correspondientes
        sunny_scl_images = []
        cloudy_scl_images = []
        unclassified_scl_images = []

        sunny_scl_names = []
        cloudy_scl_names = []
        unclassified_scl_names = []

        for scl_img, img_name, scl_name in zip(scl_images, images_names, scl_names):
            date_str = img_name.split('.')[0]
            current_date = np.datetime64(date_str, 'D')
            if current_date in [np.datetime64(date, 'D') for date in sunny_dates]:
                sunny_scl_images.append(scl_img)
                sunny_scl_names.append(scl_name)
            elif current_date in [np.datetime64(date, 'D') for date in cloudy_dates]:
                cloudy_scl_images.append(scl_img)
                cloudy_scl_names.append(scl_name)
            else:
                unclassified_scl_images.append(scl_img)
                unclassified_scl_names.append(scl_name)

        # Guardar las imágenes 'scl' con sufijo '_scl' en el nombre del archivo
        if sunny_scl_images:
            save_path = os.path.join(output_folder, 'visible', 'sunny')
            save_images_with_palette_and_labels(sunny_scl_images, save_path, sunny_scl_names, land_cover_palette_array, land_cover_labels, visible_images=sunny_images, overwrite=overwrite, mask=mask)
        if cloudy_scl_images:
            save_path = os.path.join(output_folder, 'visible', 'cloudy')
            save_images_with_palette_and_labels(cloudy_scl_images, save_path, cloudy_scl_names, land_cover_palette_array, land_cover_labels, visible_images=cloudy_images, overwrite=overwrite, mask=mask)
        if unclassified_scl_images:
            save_path = os.path.join(output_folder, 'visible', 'unclassified')
            save_images_with_palette_and_labels(unclassified_scl_images, save_path, unclassified_scl_names, land_cover_palette_array, land_cover_labels, visible_images=unclassified_images, overwrite=overwrite, mask=mask)

        logger.info('SCL images saved')


In [71]:
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta
from matplotlib import cm
import matplotlib

def interactive_image_plotter_index(rgb_images, index_images, dates, mask = None, name= None):
    
    vmin_all = np.min(index_images)
    vmax_all = np.max(index_images)
    print(f"Valor mínimo: {vmin_all}")
    print(f"Valor máximo: {vmax_all}")
    
    if mask is not None:
        mask3 = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
   
    def plot_images(date_index, threshold_a, threshold_b):
        """
        Dibuja un par de imágenes RGB e índice basado en el índice de fecha seleccionado.
        """        
        nonlocal selected_date
        selected_date.value = f"Fecha seleccionada: {dates[date_index].strftime('%Y-%m-%d')}"
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        
        transposed_rgb = rgb_images[date_index].transpose((1, 2, 0))
        transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
        transposed_rgb = normalize_image_percentile(transposed_rgb)
        
        # Mostrar imagen RGB
        ax1.imshow(transposed_rgb)
        ax1.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} RGB')
        ax1.axis('off')
        
        index_img = index_images[date_index]
        normed_img = (index_img - threshold_a) / (threshold_b - threshold_a)  # Normalize the data
        my_cmap = matplotlib.colormaps['viridis']
        colored_img = my_cmap(normed_img)
        colored_img = colored_img[:, :, :3]

        threshold_mask = (index_img > threshold_a) & (index_img < threshold_b)
        threshold_mask3 = np.repeat(threshold_mask[:, :, np.newaxis], 3, axis=2)
        
        masked_img = colored_img * threshold_mask3  * mask3 if mask is not None else colored_img * threshold_mask3

        # Mostrar imagen de índice
        cax = ax2.imshow(masked_img, cmap='viridis', vmin=threshold_a, vmax=threshold_b)
        if name is not None:
            ax2.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} {name}')
        else:
            ax2.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} Índice')
        ax2.axis('off')
        cbar = plt.colorbar(cax, ax=ax2)
        cbar.set_label('Valor del Índice', rotation=270, labelpad=20)

        plt.show()
    
    threshold_a = widgets.BoundedFloatText(
        value=vmin_all,  # Inicializar al mínimo global
        min=vmin_all,    # Establecer el mínimo permisible
        max=vmax_all,    # Establecer el máximo permisible
        step=0.01,
        description='Umbral A:',
    )

    threshold_b = widgets.BoundedFloatText(
        value=vmax_all,  # Inicializar al máximo global
        min=vmin_all,    # Establecer el mínimo permisible
        max=vmax_all,    # Establecer el máximo permisible
        step=0.01,
        description='Umbral B:',
    )
    
    # Create slider widget for date selection
    date_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(dates) - 1,
        step=1,
        description='Index:',
        continuous_update=False
    )
    date_slider.layout.width = '500px'
    date_slider.index = dates
    
    # Create buttons for easier navigation
    prev_button = widgets.Button(description="Previous")
    next_button = widgets.Button(description="Next")
    
    def on_prev_button_clicked(b):
        date_slider.value = max(0, date_slider.value - 1)

    def on_next_button_clicked(b):
        date_slider.value = min(len(dates) - 1, date_slider.value + 1)
    
   
        
    prev_button.on_click(on_prev_button_clicked)
    next_button.on_click(on_next_button_clicked)
    
    # Create a label to display the selected date
    selected_date = widgets.Label(value="")

    # Create the interactive plot
    interactive_plot = widgets.interactive(plot_images, date_index=date_slider, threshold_a=threshold_a, threshold_b=threshold_b)
    
    # Show the widgets
    display(widgets.HBox([prev_button, next_button]), selected_date, interactive_plot)



In [72]:
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta
from matplotlib import cm
import matplotlib

def interactive_image_plotter(data:xr.Dataset, dates, mask = None):
    
    rgb_images = [(img / img.max()).astype(np.float32) for img in 
                  [data.sel(time=time)[["red", "green", "blue"]].to_array().values for time in data.time.values]]
    
    index_images = None
    
    index_category = None
    index_type = None
    
    # Diccionarios para los índices
    index_categories = {
        "Vegetativo": [e.name for e in VegetationIndex],
        "Fuego": [e.name for e in FireIndex],
        "Agua": [e.name for e in WaterIndex]
    }

    # Dropdown para Tipo de Índice
    index_category_selector = widgets.Dropdown(
        options=["Vegetativo", "Fuego", "Agua"],
        description='Tipo de Índice:',
        disabled=False,
        value=None
    )

    # Dropdown para Índice Específico
    index_type_selector = widgets.Dropdown(
        options=index_categories["Vegetativo"],
        description='Índice Específico:',
        disabled=False,
        value=None
    )

    # Actualizar opciones de índice cuando cambia el tipo de índice
    def update_index_category(change):
        global index_category
        index_category = change['new']
        index_type_selector.options = index_categories[index_category]
        index_type_selector.value = None
        threshold_a.disabled = True
        threshold_b.disabled = True
           
    def update_index_type(change):
        global index_images, index_type, index_category, vmin_all, vmax_all
        index_type = change['new']

        if index_category == "Vegetativo":
            enum_index_type = VegetationIndex[index_type]
            index_images = calculate_vegetation_index_from_xr_dataset(data, enum_index_type)
        elif index_category == "Fuego":
            enum_index_type = FireIndex[index_type]
            index_images = calculate_fire_index_from_xr_dataset(data, enum_index_type)
        elif index_category == "Agua":
            enum_index_type = WaterIndex[index_type]
            index_images = calculate_water_index_from_xr_dataset(data, enum_index_type)
        else:
            index_images = None
            
        index_images = np.nan_to_num(index_images) # Replace NaN with 0

            
        if index_images is not None:
            vmin_all = np.min(index_images)
            vmax_all = np.max(index_images)
            # Actualizar y habilitar los widgets de umbral
            threshold_a.min = vmin_all
            threshold_a.max = vmax_all
            threshold_a.value = vmin_all
            threshold_a.disabled = False
            
            threshold_b.min = vmin_all
            threshold_b.max = vmax_all
            threshold_b.value = vmax_all
            threshold_b.disabled = False
        else:
            # Deshabilitar los widgets de umbral si no hay imágenes de índice
            threshold_a.disabled = True
            threshold_b.disabled = True
            
    index_category_selector.observe(update_index_category, names='value')
    index_type_selector.observe(update_index_type, names='value')
    
    if mask is not None:
        mask3 = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
   
    def plot_images(date_index, threshold_a, threshold_b, name = None):
        """
        Dibuja un par de imágenes RGB e índice basado en el índice de fecha seleccionado.
        """        
        nonlocal selected_date
        selected_date.value = f"Fecha seleccionada: {dates[date_index].strftime('%Y-%m-%d')}"
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        
        transposed_rgb = rgb_images[date_index].transpose((1, 2, 0))
        transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
        transposed_rgb = normalize_image_percentile(transposed_rgb)
        
        # Mostrar imagen RGB
        ax1.imshow(transposed_rgb)
        ax1.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} RGB')
        ax1.axis('off')
        
        if index_images is not None:
            index_img = index_images[date_index]
            normed_img = (index_img - threshold_a) / (threshold_b - threshold_a)  # Normalize the data
            my_cmap = matplotlib.colormaps['viridis']
            colored_img = my_cmap(normed_img)
            colored_img = colored_img[:, :, :3]

            threshold_mask = (index_img > threshold_a) & (index_img < threshold_b)
            threshold_mask3 = np.repeat(threshold_mask[:, :, np.newaxis], 3, axis=2)
            
            masked_img = colored_img * threshold_mask3  * mask3 if mask is not None else colored_img * threshold_mask3

            # Mostrar imagen de índice
            cax = ax2.imshow(masked_img, cmap='viridis', vmin=threshold_a, vmax=threshold_b)
            if name is not None:
                ax2.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} {name}')
            else:
                ax2.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} Índice')
            ax2.axis('off')
            cbar = plt.colorbar(cax, ax=ax2)
            cbar.set_label('Valor del Índice', rotation=270, labelpad=20)
        else:
            ax2.imshow(np.zeros_like(transposed_rgb))
            ax2.set_title("Índice no disponible")
            ax2.axis('off')

        plt.show()
    
    threshold_a = widgets.BoundedFloatText(
        step=0.01,
        description='Umbral A:',
        disabled=True
    )

    threshold_b = widgets.BoundedFloatText(
        step=0.01,
        description='Umbral B:',
        disabled=True
    )
    
    # Create slider widget for date selection
    date_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(dates) - 1,
        step=1,
        description='Index:',
        continuous_update=False
    )
    date_slider.layout.width = '500px'
    date_slider.index = dates
    
    # Create buttons for easier navigation
    prev_button = widgets.Button(description="Previous")
    next_button = widgets.Button(description="Next")
    
    def on_prev_button_clicked(b):
        date_slider.value = max(0, date_slider.value - 1)

    def on_next_button_clicked(b):
        date_slider.value = min(len(dates) - 1, date_slider.value + 1)
    
   
        
    prev_button.on_click(on_prev_button_clicked)
    next_button.on_click(on_next_button_clicked)
    
    # Create a label to display the selected date
    selected_date = widgets.Label(value="")

    navigation_box = widgets.HBox([prev_button, next_button])
    index_selector_box = widgets.HBox([index_category_selector, index_type_selector])
    threshold_box = widgets.HBox([threshold_a, threshold_b])


    # Create the interactive plot
    interactive_plot = widgets.interactive_output(plot_images, {'date_index': date_slider, 'threshold_a': threshold_a, 'threshold_b': threshold_b})
    
    # Show the widgets
    display(widgets.VBox([navigation_box, selected_date, interactive_plot, date_slider, index_selector_box, threshold_box]))



In [73]:

def proccess_vegetation_index_from_xr_dataset(bands: xr.Dataset,
                                    dates: list[datetime],
                                    monitoring_start: datetime, 
                                    mask: np.ndarray, 
                                    input_path_folder: str,
                                    output_folder: str, 
                                    normalize: bool, 
                                    vegetation_index: VegetationIndex = VegetationIndex.NDVI, 
                                    index_range_a: float = -1,
                                    index_range_b: float = 1):
    
    visible_images = [(img / img.max()).astype(np.float32) for img in 
                  [bands.sel(time=time)[["red", "green", "blue"]].to_array().values for time in bands.time.values]]
    
    vegetation_indexes = calculate_vegetation_index_from_xr_dataset(bands, vegetation_index) # Calculate vegetation index
    vegetation_indexes = np.nan_to_num(vegetation_indexes) # Replace NaN with 0
   
    if index_range_a is not None and index_range_b is not None:
        vegetation_indexes = np.clip(vegetation_indexes, index_range_a, index_range_b)
    elif index_range_a is not None:
        vegetation_indexes = np.clip(vegetation_indexes, index_range_a, np.max(vegetation_indexes))
    elif index_range_b is not None:
        vegetation_indexes = np.clip(vegetation_indexes, np.min(vegetation_indexes), index_range_b)

    #aply mask to the fire indexes (xr DataArray) if mask not None if not in mask then 0
    if mask is not None:
        vegetation_indexes = [np.where(mask, index, 0) for index in vegetation_indexes]
    
    interactive_image_plotter_index(visible_images, vegetation_indexes, dates, mask, vegetation_index) # Plot images


def proccess_fire_index_from_xr_dataset(bands: xr.Dataset,
                                    dates: list[datetime],
                                    monitoring_start: datetime, 
                                    mask: np.ndarray, 
                                    input_path_folder: str,
                                    output_folder: str, 
                                    normalize: bool, 
                                    fire_index: FireIndex = FireIndex.NBR,
                                    index_range_a: float = -1,
                                    index_range_b: float = 1):
    
    visible_images = [(img / img.max()).astype(np.float32) for img in 
                  [bands.sel(time=time)[["red", "green", "blue"]].to_array().values for time in bands.time.values]]
    fire_indexes = calculate_fire_index_from_xr_dataset(bands, fire_index) # Calculate the fire index
    fire_indexes = np.nan_to_num(fire_indexes) # Replace NaN with 0
    print('Ranges for fire index are:')
    print(f"Valor mínimo: {index_range_a}")
    print(f"Valor máximo: {index_range_b}")
    if index_range_a is not None and index_range_b is not None:
        fire_indexes = np.clip(fire_indexes, index_range_a, index_range_b)
    elif index_range_a is not None:
        fire_indexes = np.clip(fire_indexes, index_range_a, np.max(fire_indexes))
    elif index_range_b is not None:
        fire_indexes = np.clip(fire_indexes, np.min(fire_indexes), index_range_b)

    #aply mask to the fire indexes (xr DataArray) if mask not None if not in mask then 0
    if mask is not None:
        fire_indexes = [np.where(mask, index, 0) for index in fire_indexes]
        
    interactive_image_plotter_index(visible_images, fire_indexes, dates, mask, fire_index) # Plot images
    
def proccess_water_index_from_xr_dataset(bands: xr.Dataset,
                                    dates: list[datetime],
                                    monitoring_start: datetime, 
                                    mask: np.ndarray, 
                                    input_path_folder: str,
                                    output_folder: str, 
                                    normalize: bool, 
                                    water_index: WaterIndex = WaterIndex.MNDWI,
                                    index_range_a: float = -1,
                                    index_range_b: float = 1):
    
    visible_images = [(img / img.max()).astype(np.float32) for img in 
                  [bands.sel(time=time)[["red", "green", "blue"]].to_array().values for time in bands.time.values]]
    water_indexes = calculate_water_index_from_xr_dataset(bands, water_index) # Calculate the fire index
    water_indexes = np.nan_to_num(water_indexes) # Replace NaN with 0
    
    # range selection
    if index_range_a is not None and index_range_b is not None:
        water_indexes = np.clip(water_indexes, index_range_a, index_range_b)
    elif index_range_a is not None:
        water_indexes = np.clip(water_indexes, index_range_a, np.max(water_indexes))
    elif index_range_b is not None:
        water_indexes = np.clip(water_indexes, np.min(water_indexes), index_range_b)

    #aply mask to the fire indexes (xr DataArray) if mask not None if not in mask then 0
    if mask is not None:
        water_indexes = [np.where(mask, index, 0) for index in water_indexes]
        
    
        
    interactive_image_plotter_index(visible_images, water_indexes, dates, mask, water_index) # Plot images
    

def proccess_index_from_xr_dataset(bands: xr.Dataset,
                                    dates: list[datetime],
                                    monitoring_start: datetime, 
                                    mask: np.ndarray, 
                                    input_path_folder: str,
                                    output_folder: str, 
                                    normalize: bool):

    interactive_image_plotter(bands, dates, mask) # Plot images


In [74]:
from pystac_client import Client
from odc.stac import configure_rio, stac_load

catalog = Client.open("https://earth-search.aws.element84.com/v1")
client = dask.distributed.Client()
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
display(client)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 35401 instead


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:35401/status,

0,1
Dashboard: http://127.0.0.1:35401/status,Workers: 4
Total threads: 12,Total memory: 15.55 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:42741,Workers: 4
Dashboard: http://127.0.0.1:35401/status,Total threads: 12
Started: Just now,Total memory: 15.55 GiB

0,1
Comm: tcp://127.0.0.1:32887,Total threads: 3
Dashboard: http://127.0.0.1:45773/status,Memory: 3.89 GiB
Nanny: tcp://127.0.0.1:37591,
Local directory: /tmp/dask-scratch-space/worker-8p2k968g,Local directory: /tmp/dask-scratch-space/worker-8p2k968g

0,1
Comm: tcp://127.0.0.1:42777,Total threads: 3
Dashboard: http://127.0.0.1:40259/status,Memory: 3.89 GiB
Nanny: tcp://127.0.0.1:35837,
Local directory: /tmp/dask-scratch-space/worker-ofs1gtt8,Local directory: /tmp/dask-scratch-space/worker-ofs1gtt8

0,1
Comm: tcp://127.0.0.1:36665,Total threads: 3
Dashboard: http://127.0.0.1:38801/status,Memory: 3.89 GiB
Nanny: tcp://127.0.0.1:46045,
Local directory: /tmp/dask-scratch-space/worker-in8fjftq,Local directory: /tmp/dask-scratch-space/worker-in8fjftq

0,1
Comm: tcp://127.0.0.1:43123,Total threads: 3
Dashboard: http://127.0.0.1:39583/status,Memory: 3.89 GiB
Nanny: tcp://127.0.0.1:37811,
Local directory: /tmp/dask-scratch-space/worker-ih_ke6b5,Local directory: /tmp/dask-scratch-space/worker-ih_ke6b5


In [75]:
#Get the start date and end date from the user (year and month)
#history_start  = datetime(2020, 1, 1)
history_start  = datetime(2022, 1, 1)
monitoring_start = datetime(2022, 12, 30)
#history_end = datetime(2022, 10, 31)
history_end = datetime(2023, 4, 30)
ignore_ranges = []
ignore_ranges.append([datetime(2022, 1, 30), datetime(2022, 2, 28)])

In [76]:
from shapely.geometry import Polygon, MultiPolygon, Point
import importlib

stac_cfg = {
    'sentinel-2-l2a': {
        'assets': {'*': {'data_type': 'uint16', 'nodata': 0}},
    }
}

ndvi_thresholds = [None, -0.1, -0.01]
rvi_thresholds = [None, -0.1, 0, 0.1, 0.25, 1]
savi_thresholds = [None, -0.1, 0, 0.1, 1]
arvi_thresholds = [None, -0.1, 0, 0.1, 1]
bai_thresholds = [0.01, 0.1, 1]
nbr_thresholds = [None, -0.1, 0]
afi_thresholds = [None,-0.15]
msrif_thresholds = [None, 0, None]
#bands = ['coastal', 'blue', 'green', 'red', 'nir', 'nir08', 'nir09', 'rededge1', 'rededge2', 'rededge3', 'scl', 'swir16', 'swir22']
bands = ['blue', 'green', 'red', 'nir', 'nir08', 'nir09', 'rededge1', 'rededge2', 'rededge3', 'scl', 'swir16', 'swir22']



thresholds = {
    'eo:cloud_cover': 20,
    's2:medium_proba_clouds_percentage': 10,
    's2:high_proba_clouds_percentage': 5,
    's2:low_proba_clouds_percentage': 20,
    's2:thin_cirrus_percentage': 10,
    's2:cloud_shadow_percentage': 5
}

In [77]:
import ipywidgets as widgets
from IPython.display import display

def process_region(zone, region, index_type=None, index_range_a = None, index_range_b = None):
    input_path = make_dir(["./data/input", zone, region])
    output_path = make_dir(["./data/output", zone, region])
    polygon = zone_dict[zone][region]["polygon"]
        
    data, properties = fetch_or_cache_stac_data_by_band_and_month(catalog, 
                                                    collections=["sentinel-2-l2a"],
                                                    bands=bands,
                                                    start_date=history_start.strftime("%Y-%m-%d"),
                                                    end_date= history_end.strftime("%Y-%m-%d"), 
                                                    limit=1000,
                                                    bbox=polygon.bounds,
                                                    resolution=10,
                                                    cloud_cover=100,
                                                    stac_config=stac_cfg,
                                                    #crs="EPSG:4326",
                                                    crs = "EPSG:3857",
                                                    cache_dir=input_path,
                                                    force_download=False,
                                                    log_level= LogLevel.Info,)    
    
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])
        
    data = data.sel(time=[t for t in data.time.values if all(t < np.datetime64(start) or t > np.datetime64(end) for start, end in ignore_ranges)])
    xr_times = data['time'].values.astype(str).tolist()
    properties = [d for d in properties if d['datetime'][:19] in {x[:19] for x in xr_times}]
    
    sunny_dates, cloudy_dates = classify_sunny_cloudy_dates_by_scene_classification(data, thresholds, mask)        
    
    # remove cloudy images
    data = data.sel(time=[t for t in data.time.values if t in sunny_dates])
    dates = [pd.Timestamp(date).to_pydatetime().replace(hour=0, minute=0, second=0, microsecond=0)  for date in data['time'].values]
        
    if index_type is None:
        proccess_index_from_xr_dataset(data, dates, monitoring_start, mask, input_path, output_path, True)
    else:
        if index_type in [e.name for e in VegetationIndex]:
            enum_index_type = VegetationIndex[index_type]
            proccess_vegetation_index_from_xr_dataset(data, dates, monitoring_start, mask, input_path, output_path, True, enum_index_type, index_range_a, index_range_b)
        elif index_type in [e.name for e in FireIndex]:
            enum_index_type = FireIndex[index_type]
            proccess_fire_index_from_xr_dataset(data, dates, monitoring_start, mask, input_path, output_path, True, enum_index_type, index_range_a, index_range_b)
        elif index_type in [e.name for e in WaterIndex]:
            enum_index_type = WaterIndex[index_type]
            proccess_water_index_from_xr_dataset(data, dates, monitoring_start, mask, input_path, output_path, True, enum_index_type, index_range_a, index_range_b)    


In [78]:
from IPython.display import clear_output, display

# Diccionarios para los índices
index_categories = {
    "Vegetativo": [e.name for e in VegetationIndex],
    "Fuego": [e.name for e in FireIndex],
    "Agua": [e.name for e in WaterIndex]
}

# Dropdown para Tipo de Índice
type_selector = widgets.Dropdown(
    options=["Vegetativo", "Fuego", "Agua"],
    description='Tipo de Índice:',
    disabled=False,
)

# Dropdown para Índice Específico
index_selector = widgets.Dropdown(
    options=index_categories["Vegetativo"],
    description='Índice Específico:',
    disabled=False,
)

# Actualizar opciones de índice cuando cambia el tipo de índice
def update_indices(change):
    index_selector.options = index_categories[change['new']]

type_selector.observe(update_indices, names='value')


zone_widget = widgets.Dropdown(
    options=zone_dict.keys(),
    description='Zone:',
    disabled=False
)

region_widget = widgets.Dropdown(
    options=zone_dict[zone_widget.value].keys(),
    description='Region:',
    disabled=False
)

def update_region_options(change):
    region_widget.options = zone_dict[change['new']].keys()

zone_widget.observe(update_region_options, names='value')

# Botón para procesar
process_button = widgets.Button(description="Process")

# Función para manejar evento del botón
def on_process_button_clicked(b):  
    clear_output(wait=True)  # Limpiar la salida anterior
    display(type_selector, index_selector, zone_widget, region_widget, process_button)  # Volver a mostrar los widgets
    process_region(zone_widget.value, region_widget.value, index_selector.value)

process_button.on_click(on_process_button_clicked)

# Mostrar widgets
display(type_selector, index_selector, zone_widget, region_widget, process_button)

Dropdown(description='Tipo de Índice:', options=('Vegetativo', 'Fuego', 'Agua'), value='Vegetativo')

Dropdown(description='Índice Específico:', options=('DVI', 'RVI', 'PVI', 'IPVI', 'WDVI', 'TNDVI', 'GNDVI', 'GE…

Dropdown(description='Zone:', options=('Bosques Bio Bio', 'Incendios', 'Bosques Arauco'), value='Bosques Bio B…

Dropdown(description='Region:', options=('Bosque 1', 'Bosque 2', 'Bosque 3'), value='Bosque 1')

Button(description='Process', style=ButtonStyle())

In [79]:
zone_widget = widgets.Dropdown(
    options=zone_dict.keys(),
    description='Zone:',
    disabled=False
)

region_widget = widgets.Dropdown(
    options=zone_dict[zone_widget.value].keys(),
    description='Region:',
    disabled=False
)

def update_region_options(change):
    region_widget.options = zone_dict[change['new']].keys()

zone_widget.observe(update_region_options, names='value')

# Botón para procesar
process_button = widgets.Button(description="Process")
status_label = widgets.Label(value="")

# Función para manejar evento del botón
def on_process_button_clicked(b):  
    process_button.disabled = True
    status_label.value = "Procesando..."
    process_region(zone_widget.value, region_widget.value)
    process_button.disabled = False
    status_label.value = "Procesamiento completado"
    
process_button.on_click(on_process_button_clicked)

# Mostrar widgets
display(zone_widget, region_widget, process_button, status_label)

Dropdown(description='Zone:', options=('Bosques Bio Bio', 'Incendios', 'Bosques Arauco'), value='Bosques Bio B…

Dropdown(description='Region:', options=('Bosque 1', 'Bosque 2', 'Bosque 3'), value='Bosque 1')

Button(description='Process', style=ButtonStyle())

Label(value='')

In [80]:
def get_images_by_zone(zone, region, only_sunny_dates = True):
    input_path = make_dir(["./data/input", zone, region])
    output_path = make_dir(["./data/output", zone, region])
    polygon = zone_dict[zone][region]["polygon"]
        
    data, properties = fetch_or_cache_stac_data_by_band_and_month(catalog, 
                                                    collections=["sentinel-2-l2a"],
                                                    bands=bands,
                                                    start_date=history_start.strftime("%Y-%m-%d"),
                                                    end_date= history_end.strftime("%Y-%m-%d"), 
                                                    limit=1000,
                                                    bbox=polygon.bounds,
                                                    resolution=10,
                                                    cloud_cover=100,
                                                    stac_config=stac_cfg,
                                                    #crs="EPSG:4326",
                                                    crs = "EPSG:3857",
                                                    cache_dir=input_path,
                                                    force_download=False,
                                                    log_level= LogLevel.Info,)    
    
    #mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])
        
    data = data.sel(time=[t for t in data.time.values if all(t < np.datetime64(start) or t > np.datetime64(end) for start, end in ignore_ranges)])
    xr_times = data['time'].values.astype(str).tolist()
    properties = [d for d in properties if d['datetime'][:19] in {x[:19] for x in xr_times}]
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])
    if(only_sunny_dates):
        sunny_dates, cloudy_dates = classify_sunny_cloudy_dates_by_scene_classification(data, thresholds, mask)        
        data = data.sel(time=[t for t in data.time.values if t in sunny_dates])
    
    dates = [pd.Timestamp(date).to_pydatetime().replace(hour=0, minute=0, second=0, microsecond=0)  for date in data['time'].values]
    return data, dates, polygon, properties

In [81]:
from enum import Enum

class ComparisonType(Enum):
    GREATER_THAN = "greater_than"
    LESS_THAN = "less_than"

class ImageMaskConfiguration:
    def __init__(self, enable=False, index_category=0, index_type=0, comparison=ComparisonType.GREATER_THAN, threshold=0.0):
        self._enable = enable
        self._index_category = index_category
        self._index_type = index_type
        self._comparison = comparison
        self._threshold = threshold
        
    @property
    def enable(self):
        """Getter for enable attribute"""
        return self._enable

    @enable.setter
    def enable(self, value):
        """Setter for enable attribute with basic validation"""
        if isinstance(value, bool):
            self._enable = value
        else:
            print("Enable should be a boolean value.")
        
    @property
    def index_category(self):
        """Getter for index_category attribute"""
        return self._index_category

    @index_category.setter
    def index_category(self, value):
        """Setter for index_category attribute with basic validation"""
        if isinstance(value, int):
            self._index_category = value
        else:
            print("IndexType should be an integer.")
            
    @property
    def index_type(self):
        """Getter for index_type attribute"""
        return self._index_type

    @index_type.setter
    def index_type(self, value):
        """Setter for index_type attribute with basic validation"""
        if isinstance(value, int):
            self._index_type = value
        else:
            print("IndexValue should be an integer.")
            
    @property
    def comparison(self):
        """Getter for comparison attribute"""
        return self._comparison

    @comparison.setter
    def comparison(self, value):
        """Setter for comparison attribute with validation against Enum"""
        if isinstance(value, ComparisonType):
            self._comparison = value
        else:
            print("Invalid comparison. Please use a value from ComparisonType Enum.")
            
    @property
    def threshold(self):
        """Getter for threshold attribute"""
        return self._threshold

    @threshold.setter
    def threshold(self, value):
        """Setter for threshold attribute with basic validation"""
        if isinstance(value, float):
            self._threshold = value
        else:
            print("Threshold should be a float.")
            
    def print(self):
        print(f"Enable: {self.enable}")
        print(f"IndexType: {self.index_category}")
        print(f"IndexValue: {self.index_type}")
        print(f"Comparison: {self.comparison}")
        print(f"Threshold: {self.threshold}")


In [82]:
def calculate_mask_image(data: xr.Dataset, date, image_mask_config: ImageMaskConfiguration):
    
    date_data = data.sel(time=date, method='nearest')
    
    if image_mask_config.enable:
        index = None
        if image_mask_config.index_category == IndexCategory.VEGETATION.value:
            enum_index_type = VegetationIndex(list(VegetationIndex)[image_mask_config.index_type])
            index = calculate_vegetation_index_from_xr_dataset(date_data, enum_index_type)
        elif image_mask_config.index_category == IndexCategory.FIRE.value:
            enum_index_type = FireIndex(list(FireIndex)[image_mask_config.index_type])
            index = calculate_fire_index_from_xr_dataset(date_data, enum_index_type)
        elif image_mask_config.index_category == IndexCategory.WATER.value:
            enum_index_type = WaterIndex(list(WaterIndex)[image_mask_config.index_type])
            index = calculate_water_index_from_xr_dataset(date_data, enum_index_type)
        
        if index is None:
            raise ValueError("Index calculation failed or index type not recognized.")

        if image_mask_config.comparison == ComparisonType.GREATER_THAN:
            mask_by_index = index > image_mask_config.threshold
        else:
            mask_by_index = index < image_mask_config.threshold
        return mask_by_index, index
    else:
        return None, None

def calculate_mask_difference_of_index_images(data: xr.Dataset, date, date_prev, image_mask_config: ImageMaskConfiguration):
    date_data = data.sel(time=date, method='nearest')
    date_prev_data = data.sel(time=date_prev, method='nearest')
    
    if image_mask_config.enable:
        index = None
        index_prev = None
        
        if image_mask_config.index_category == IndexCategory.VEGETATION.value:
            enum_index_type = VegetationIndex(list(VegetationIndex)[image_mask_config.index_type])
            index = calculate_vegetation_index_from_xr_dataset(date_data, enum_index_type)
            index_prev = calculate_vegetation_index_from_xr_dataset(date_prev_data, enum_index_type)
        elif image_mask_config.index_category == IndexCategory.FIRE.value:
            enum_index_type = FireIndex(list(FireIndex)[image_mask_config.index_type])
            index = calculate_fire_index_from_xr_dataset(date_data, enum_index_type)
            index_prev = calculate_fire_index_from_xr_dataset(date_prev_data, enum_index_type)
        elif image_mask_config.index_category == IndexCategory.WATER.value:
            enum_index_type = WaterIndex(list(WaterIndex)[image_mask_config.index_type])
            index = calculate_water_index_from_xr_dataset(date_data, enum_index_type)
            index_prev = calculate_water_index_from_xr_dataset(date_prev_data, enum_index_type)
        
        if index is None or index_prev is None:
            raise ValueError("Index calculation failed or index type not recognized.")
        
        index_difference = index - index_prev
        
        if image_mask_config.comparison == ComparisonType.GREATER_THAN:
            mask_by_index = index_difference > image_mask_config.threshold
        else:
            mask_by_index = index_difference < image_mask_config.threshold
            
        return mask_by_index, index_difference
    else:
        return None, None
        


In [83]:
from shapely.affinity import affine_transform
import numpy as np
from skimage.measure import find_contours
from shapely.geometry import Polygon, mapping
import geopandas as gpd

def get_transformation_matrix(polygon_bound, mask_shape):
    """
    Calculate the transformation matrix based on the polygon_bound and mask_shape.
    """
    minx, miny, maxx, maxy = polygon_bound.bounds  # Obtaining the bounds of the polygon
    print(f"Bounds: {minx}, {miny}, {maxx}, {maxy}")
    y_pixels, x_pixels = mask_shape
    print(f"Shape: {x_pixels}, {y_pixels}")

    x_scale = (maxx - minx) / x_pixels
    y_scale = (maxy - miny) / y_pixels
    print(f"Scales: {x_scale}, {y_scale}")

    # Crear una matriz de transformación para convertir las coordenadas de píxeles a coordenadas geográficas
    # [x_scale, 0, 0, y_scale, minx, miny]
    transformation_matrix = [x_scale, 0, 0, -y_scale, minx, maxy]  

    return transformation_matrix


def get_polygons_from_mask(mask, label, polygon_bound, zone, region, type_index, date, ref_img_date = None, ref_filename = None):
    """
    Convert a mask to a list of polygons with metadata.

    Parameters:
    mask (numpy array or xarray): A binary mask to convert to polygons.
    polygon_bound (Polygon): The bounding polygon.

    Returns:
    list: A list of transformed polygons.
    """
    polygons = []

    # Check if mask is an xarray, and if so, extract the numpy array
    if hasattr(mask, 'values'):
        mask_values = mask.values
    else:
        mask_values = mask

    mask_values = np.transpose(mask_values)

    contours = find_contours(mask_values, level=0.5)
    transformation_matrix = get_transformation_matrix(polygon_bound, mask.shape)  # Define this function

    for contour in contours:
        # Convert contour coordinates to polygon
        poly = Polygon(contour)

        # Apply the transformation
        poly_transformed = affine_transform(poly, transformation_matrix)
        polygons.append({
            'geometry': poly_transformed,
        })

    polygons = associate_holes(polygons)  # Function to classify polygons as 'outer' or 'hole'

    for poly_data in polygons:
        poly_data.update({
            'zone': zone,
            'region': region,
            'type_index': type_index,
            'date': date,  # Adding date here
            'label': label,
            'ref_filename': ref_filename
        })
        if ref_img_date is not None:
            poly_data.update({
                'ref_date': ref_img_date
            })

    return polygons

def associate_holes(polygons):
    """
    Associate holes with their corresponding outer polygons.
    """
    associated_polygons_data = []
    
    for poly_data in polygons:
        poly = poly_data['geometry']  # Accessing the polygon

        # Identifying outer polygons
        if not any(poly.within(p['geometry']) for p in polygons if p != poly_data):
            # If it's an outer polygon, check for holes
            holes = [p['geometry'] for p in polygons if p['geometry'].within(poly) and p != poly_data]
            
            # Create a new polygon with the identified holes
            if holes:
                holes_coords = [hole.exterior.coords[:] for hole in holes]
                new_poly = Polygon(poly.exterior.coords[:], holes=holes_coords)
            else:
                new_poly = poly
            
            # Keep the metadata and update the geometry
            new_poly_data = poly_data.copy()  # Copy the metadata
            new_poly_data['geometry'] = new_poly  # Update the geometry
            associated_polygons_data.append(new_poly_data)
                
    return associated_polygons_data


def filter_polygons_by_size(polygons, min_size):
    """
    Filter polygons by size.
    """
    return [p for p in polygons if p.area >= min_size]


def save_mask_image(label, zone, region, date, polygon, mask, type_index, output_path='./data/output', ref_img_date = None, tiff_filename = None):
    """
    Save mask image as a shapefile containing various polygons.
    """

    print(f"Parameters: {zone}, {region}, {date}, {type_index}")
    
    # Formatting the date to exclude time
    formatted_date = date.strftime('%Y-%m-%d')
    if ref_img_date is not None:
        formatted_ref_img_date = ref_img_date.strftime('%Y-%m-%d')
    else:
        formatted_ref_img_date = None
    
    # Convert mask to polygons and get metadata
    polygons_with_metadata = get_polygons_from_mask(mask, label, polygon, zone, region, type_index, formatted_date, formatted_ref_img_date, tiff_filename)

    # Filter polygons by size
    # valid_polygons = filter_polygons_by_size(polygons_with_metadata, min_size)
    # Creating a GeoDataFrame
    gdf = gpd.GeoDataFrame(polygons_with_metadata, crs='EPSG:4326')

    # Creating the filename
    filename = f"{type_index}_{formatted_date}.shp"

    # Ensure the directory exists
    dir_path = f"{output_path}/labels/{zone}/{region}/{label}"
    os.makedirs(dir_path, exist_ok=True)
    
    # Saving the GeoDataFrame as a Shapefile
    filepath = f"{output_path}/labels/{zone}/{region}/{label}/{filename}"

    gdf.to_file(filepath)

In [84]:
import numpy as np
from scipy.ndimage import label, sum

def remove_small_objects(mask, min_size):
    """
    Remove connected components that are smaller than min_size from the binary mask.
    
    Parameters:
    mask (numpy.ndarray): binary mask where objects are represented by ones and zeros.
    min_size (int): minimum size (number of pixels) that a connected component must have to be kept.
    
    Returns:
    numpy.ndarray: a new binary mask with the small connected components removed.
    """
    
    # Label each connected component in the mask
    labeled_mask, num_labels = label(mask)
    
    # For each labeled object, calculate its size
    sizes = sum(mask, labeled_mask, range(1, num_labels+1))
    
    # Create a new mask where small objects are removed
    new_mask = np.zeros_like(mask)
    removed = 0
    kept = 0
    for label_num, size in enumerate(sizes):
        if size >= min_size:
            new_mask[labeled_mask == label_num+1] = 1
            kept += 1
        else:
            removed += 1
            
    return new_mask, kept, removed

In [91]:
import os
import rasterio
from rasterio.transform import from_bounds
import xarray as xr

def save_tiff(dataset, date, output_path, base_filename, user_metadata):    

    ds_selected = dataset.sel(time=date, method='nearest')
    bands = [ds_selected[banda].values for banda in ds_selected.data_vars]

    # Asumiendo que las coordenadas son latitud y longitud
    lon_min, lat_min, lon_max, lat_max = ds_selected.longitude.min(), ds_selected.latitude.min(), ds_selected.longitude.max(), ds_selected.latitude.max()
    
    height, width = len(ds_selected.latitude), len(ds_selected.longitude)
    
    transform = from_bounds(west=lon_min, south=lat_min, east=lon_max, north=lat_max, width=width, height=height)
    
    output_filename = f"{base_filename}_{date.strftime('%Y-%m-%d')}.tif"
    output_filepath = os.path.join(output_path, output_filename)
    
    metadata = {
    'driver': 'GTiff',
    'count': len(bands),  # número de bandas
    'dtype': bands[0].dtype,  # asegúrate de que todas las bandas tengan el mismo dtype
    'width': bands[0].shape[1],
    'height': bands[0].shape[0],
    'crs': 'EPSG:4326',  # o el CRS que estés utilizando,
    'transform': transform
    }
    metadata['tags'] = user_metadata
    metadata['date'] = date.strftime('%Y-%m-%d')
    
    
    with rasterio.open(
            output_filepath,
            'w',
            driver='GTiff',
            **metadata
    ) as dst:
        for i, var_name in enumerate(ds_selected.data_vars):
            dst.write(ds_selected[var_name].data, i+1)
            dst.set_band_description(i+1, var_name)
            
    print(f"File saved to: {output_filepath}")
    return output_filepath


In [86]:

from IPython.display import display
import time
def process_images_mask(zone, region, data, dates, polygon, properties, visible_images, num_index_to_process = 3, image_index = 0, prev_image_index = -1, labels = None):

    number_inputs = []
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])
    index_categories = {
        "Vegetativo": [e.name for e in VegetationIndex],
        "Fuego": [e.name for e in FireIndex],
        "Agua": [e.name for e in WaterIndex]
    }

    radiobuttons = widgets.RadioButtons(
        options=['Imagen individual', 'Calcular máscaras primero', 'Calcular máscaras después'],
        description='Método:',
        disabled=False
    )
    
    
    def on_change_radio_buttons(change):
        if change['new'] == 'Imagen individual':
            image_type_category.options = ["Imagen final"]
            image_type_category.index = 0
        if change['new'] == 'Calcular máscaras primero':
            image_type_category.options = ["Imagen final", "Imagen de referencia", "Diferencia"]
            image_type_category.index = 2
        else:
            image_type_category.options = ["Imagen final", "Imagen de referencia"]
            image_type_category.index = 0
                
    radiobuttons.observe(on_change_radio_buttons, names='value')
                
    image_type_category = widgets.Dropdown(
            options=["Imagen final", "Imagen de referencia", "Diferencia"],
            description='Imagen:',
            disabled=False,
        )
    # Dropdown para seleccionar el color
    color_dropdown = widgets.Dropdown(
        options=['Red', 'Green', 'Blue', 'Orange', 'Yellow', 'Purple'],
        value='Orange',
        description='Color mascara:'
    )

    # Slider para seleccionar la transparencia (alpha)
    alpha_slider = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.01,
        description='Transparencia:',
        continuous_update=False
    )

    configurations = [ImageMaskConfiguration() for _ in range(num_index_to_process)]
    trigger_widget = widgets.Checkbox(value=False, layout=widgets.Layout(display='none'))

    def on_value_change(change):
        new_value = change['new']

        if change['owner'].description == 'Habilitar':
            config = change['owner'].config_ref
            config.enable = new_value
        elif change['owner'].description == 'Índice Específico':
            config = change['owner'].config_ref
            new_index = change['owner'].index 
            config.index_type = new_index
        elif change['owner'].description == 'Comparación':
            config = change['owner'].config_ref
            comparison_mapping = {
                "Menor que": ComparisonType.LESS_THAN,
                "Mayor que": ComparisonType.GREATER_THAN
            }
            config.comparison = comparison_mapping.get(new_value, None)
            if config.comparison is None:
                print(f"Invalid comparison value: {new_value}")
        elif change['owner'].description == 'Umbral':
            config = change['owner'].config_ref
            config.threshold = new_value
            if change['owner'].manual_update == True:
                return
            
        trigger_widget.value = not trigger_widget.value

            
    def create_widgets_for_configuration(config, on_value_change):
        enable_widget = widgets.Checkbox(value=config.enable, description='Habilitar')
        type_selector = widgets.Dropdown(options=["Vegetativo", "Fuego", "Agua"], description='Tipo de Índice')
        index_selector = widgets.Dropdown(options=index_categories["Vegetativo"], description='Índice Específico')
        comparison_selector = widgets.Dropdown(options=["Mayor que", "Menor que"], description='Comparación')
        number_input = widgets.BoundedFloatText(value=0, min=0, max=1000, step=0.1, description='Umbral')
        
        enable_widget.config_ref = config
        type_selector.config_ref = config
        index_selector.config_ref = config
        comparison_selector.config_ref = config
        number_input.config_ref = config
        number_input.manual_update = False
            
        def update_indices(change):
            index_selector.options = index_categories[change['new']]
            config = change['owner'].config_ref
            new_index = change['owner'].index  
            config.index_category = new_index
            
        type_selector.observe(update_indices, names='value')
        
        def toggle_widgets(change):
            state = change['new']
            type_selector.disabled = not state
            index_selector.disabled = not state
            comparison_selector.disabled = not state
            number_input.disabled = not state
        
        enable_widget.observe(toggle_widgets, names='value')

        enable_widget.observe(on_value_change, names='value')
        type_selector.observe(on_value_change, names='value')
        index_selector.observe(on_value_change, names='value')
        comparison_selector.observe(on_value_change, names='value')
        number_input.observe(on_value_change, names='value')
        number_inputs.append(number_input)
        return [enable_widget, type_selector, index_selector, comparison_selector, number_input]
    
    color_dropdown.observe(on_value_change, names='value')
    alpha_slider.observe(on_value_change, names='value')

    PIXEL_AREA = 100  # 10m * 10m = 100m²
    min_pixels = 500 / PIXEL_AREA  # Convert area to number of pixels
    
    
    #Slider to select the minimum area (in m²) of the objects to keep
    min_area_slider = widgets.IntSlider(
        value=10,  # Initial value
        min=0,  # Min value
        max=1000,  # Max value
        step=10,  # Step size
        description='Min px:',  # Description or label for the slider
        continuous_update=False  # Update only when the slider stops moving
    )
    
    status_label = widgets.Label(value="")

    def update_threshold_range(position, min_value, max_value):
        if(position <= len(number_inputs) - 1):
            number_input = number_inputs[position]  # Accede al widget específico basado en la posición
            number_input.manual_update = True  # Indica que el usuario cambió el valor
            number_input.min = min_value  # Ajusta el valor min
            number_input.max = max_value  # Ajusta el valor max
            number_input.value = (min_value + max_value) / 2  # Ajusta el valor a la mitad del rango
            time.sleep(0.1)  # Espera 0.1 segundos
            number_input.manual_update = False  # Indica que el usuario no cambió el valor
            
    color_mapping = {
            'Red': [1, 0, 0],
            'Green': [0, 1, 0],
            'Blue': [0, 0, 1],
            'Orange': [1, 0.5, 0],
            'Yellow': [1, 1, 0],
            'Purple': [0.5, 0, 0.5]
        }
    
    previous_trigger = None
    mask_image = None
    mask_prev_image = None
    mask_diff = None
    mask_calculated = False
    mask_final = None
    image_date = None 
    image_type_index = 0
    objects_kept = 0
    objects_removed = 0
    def plot_image_mask(trigger, image_type, min_area, radio_button_value):
        combined_mask = None
        nonlocal previous_trigger, mask_calculated, mask_image, mask_prev_image, mask_diff, image_date, image_type_index, mask_final, objects_kept, objects_removed
        print(radio_button_value)
        if previous_trigger is None or trigger != previous_trigger:
            mask_calculated = False
            mask_image = None
            mask_prev_image = None
            mask_diff = None
            previous_trigger = trigger
                 
        image_type_index = image_type_category.options.index(image_type)

        if(radio_button_value == 'Imagen individual'):
            if mask_calculated == False:
                image_date = dates[image_index]
                for i, config in enumerate(configurations):
                    if config.enable == True:
                        mask_index, image_for_index = calculate_mask_image(data, image_date, config)
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        max_index_img = image_for_index.max().item()
                        min_index_img = image_for_index.min().item()
                        update_threshold_range(i, min_index_img, max_index_img)  
                        mask_image = mask_index if mask_image is None else mask_image * mask_index
            img_index = image_index
            img_mask = mask_image
        elif(radio_button_value == 'Calcular máscaras primero'):   
            if mask_calculated == False:
                image_date = dates[image_index]
                prev_image_date = dates[prev_image_index]

                for i, config in enumerate(configurations):
                    if(config.enable):
                        mask_index, image_for_index = calculate_mask_image(data, image_date, config)                
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        max_index_img = image_for_index.max().item()
                        min_index_img = image_for_index.min().item()
                        update_threshold_range(i, min_index_img, max_index_img)  
                        mask_image = mask_index if mask_image is None else mask_image * mask_index
      
                for i, config in enumerate(configurations):
                    if(config.enable):
                        mask_index, image_for_index = calculate_mask_image(data, prev_image_date, config)
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        mask_prev_image = mask_index if mask_prev_image is None else mask_prev_image * mask_index
                        mask_diff = mask_image ^ mask_prev_image
            if(image_type_index == 0):
                img_index = image_index
                img_mask = mask_image
            elif(image_type_index == 1):
                img_index = prev_image_index
                img_mask = mask_prev_image
            else:
                img_index = image_index
                img_mask = mask_diff
        else:
            if mask_calculated == False:
                image_date = dates[image_index]
                prev_image_date = dates[prev_image_index]
                
                for i, config in enumerate(configurations):
                    if(config.enable):
                        mask_index, image_for_index = calculate_mask_difference_of_index_images(data, image_date, prev_image_date, config)
                        mask_index = np.logical_and(mask == 1, mask_index == 1)
                        max_index_img = image_for_index.max().item()
                        min_index_img = image_for_index.min().item()
                        update_threshold_range(i, min_index_img, max_index_img)  
                        mask_diff = mask_index if mask_diff is None else mask_diff * mask_index
        
            if(image_type_index == 0):
                img_index = image_index
                img_mask = mask_diff
            elif(image_type_index == 1):
                img_index = prev_image_index
                img_mask = mask_diff
            
        if img_mask is not None:
            combined_mask = np.logical_and(mask == 1, img_mask == 1)
        
        transposed_rgb = visible_images[img_index].transpose((1, 2, 0))
        transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
        transposed_rgb = normalize_image_percentile(transposed_rgb)
              
        #show image
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        ax.imshow(transposed_rgb)
        
        if(combined_mask is not None):
            mask_final, objects_kept, objects_removed = remove_small_objects(combined_mask, min_area)
        else:
            mask_final = None
            objects_kept = 0
            objects_removed = 0
            
        status_label.value = f"Objetos removidos: {objects_removed} | Objetos mantenidos: {objects_kept}"
        
        if mask_final is not None:
            selected_color = color_mapping[color_dropdown.value] + [alpha_slider.value]
            color_mask = np.zeros((mask_final.shape[0], mask_final.shape[1], 4))
            color_mask[mask_final == 1] = selected_color  # Aplicar color solo donde mask es 1
            ax.imshow(color_mask)
        
        ax.set_title(f'{dates[img_index].strftime("%Y-%m-%d")} RGB')
        ax.axis('off')
        plt.show()
    
    labels_dropdown = widgets.Dropdown(
        options=labels,
        description='Etiqueta:',
        disabled=False,
    )
     
    interactive_plot_image = widgets.interactive_output(plot_image_mask, {'trigger': trigger_widget, 'image_type': image_type_category, 'min_area': min_area_slider, 'radio_button_value': radiobuttons})
    
    def on_button_process_clicked(b):
        if mask_final is not None:
            tiff_output_dir = './data/output/labels/tiff'
            tiff_filename = f"{zone}_{region}"
            user_properties = {
                'zone': zone,
                'region': region,
            }
            tiff_output_filepath = save_tiff(data, dates[image_index], tiff_output_dir, tiff_filename, user_properties)
            save_mask_image(labels_dropdown.value, zone, region, dates[prev_image_index] if image_type_index == 1 else image_date, polygon, mask_final, labels_dropdown.index, './data/output',  dates[prev_image_index] if image_type_index == 2 else None)
        else:
            print("No hay máscara para guardar.")
    
    button_process = widgets.Button(description="Guardar Shapefile")
    button_process.on_click(on_button_process_clicked)
    
    display(widgets.HBox([image_type_category, color_dropdown, alpha_slider]), widgets.HBox([radiobuttons, min_area_slider, status_label]))

    # Mostrar widgets
    for config in configurations:
        widgets_list = create_widgets_for_configuration(config, on_value_change)
        display(widgets.HBox(widgets_list))
    
    display(interactive_plot_image)
    display(widgets.HBox([labels_dropdown, button_process]))  
        


In [87]:
def select_images_widget(zone, region, only_sunny_dates = True, labels = None):
    data, dates, polygon, properties = get_images_by_zone(zone, region, only_sunny_dates)
    visible_images = [(img / img.max()).astype(np.float32) for img in 
                  [data.sel(time=time)[["red", "green", "blue"]].to_array().values for time in data.time.values]]
    mask, mask3 = get_kml_polygon_masks(polygon, data['red'].shape[2], data['red'].shape[1])
    
    date_slider_image = widgets.IntSlider(
        value=1,
        min=1,
        max=len(dates) - 1,
        step=1,
        description='Imagen:',
        continuous_update=False
    )
    date_slider_image.layout.width = '450px'
    date_slider_image.index = dates
    
    def on_date_slider_image_change(change):
        date_slider_ref_image.value = min(date_slider_ref_image.value, date_slider_image.value - 1)
        
    date_slider_image.observe(on_date_slider_image_change, names='value')
    
    # Create buttons for easier navigation
    prev_image_button = widgets.Button(description="Previa")
    next_image_button = widgets.Button(description="Siguiente")

    date_slider_ref_image = widgets.IntSlider(
        value=0,
        min=0,
        max=len(dates) - 2,
        step=1,
        description='Referencia:',
        continuous_update=False
    )
    date_slider_ref_image.layout.width = '450px'
    date_slider_ref_image.index = dates
    
    def on_date_slider_ref_image_change(change):
        date_slider_image.value = max(date_slider_image.value, date_slider_ref_image.value + 1)
    
    date_slider_ref_image.observe(on_date_slider_ref_image_change, names='value')
    
    # Create buttons for easier navigation
    prev_ref_image_button = widgets.Button(description="Previa")
    next_ref_image_button = widgets.Button(description="Siguiente")
        
    process_button = widgets.Button(description="Seleccionar")   
    
    def on_prev_image_button_clicked(b):
        date_slider_image.value = max(date_slider_image.min, date_slider_image.value - 1)
    
    def on_next_image_button_clicked(b):
        date_slider_image.value = min(date_slider_image.max, date_slider_image.value + 1)
        
    def on_prev_ref_image_button_clicked(b):
        date_slider_ref_image.value = max(date_slider_ref_image.min, date_slider_ref_image.value - 1)
    
    def on_next_ref_image_button_clicked(b):
        date_slider_ref_image.value = min(date_slider_ref_image.max, date_slider_ref_image.value + 1)
    
    def plot_image(date_index):
        transposed_rgb = visible_images[date_index].transpose((1, 2, 0))
        transposed_rgb = transposed_rgb * mask3 if mask is not None else transposed_rgb
        transposed_rgb = normalize_image_percentile(transposed_rgb)
        
        #show image
        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        ax.imshow(transposed_rgb)
        ax.set_title(f'{dates[date_index].strftime("%Y-%m-%d")} RGB')
        ax.axis('off')
        plt.show()
    
    interactive_plot_image = widgets.interactive_output(plot_image, {'date_index': date_slider_image})
    interactive_plot_ref_image = widgets.interactive_output(plot_image, {'date_index': date_slider_ref_image})
    
    prev_image_button.on_click(on_prev_image_button_clicked)
    next_image_button.on_click(on_next_image_button_clicked)
    prev_ref_image_button.on_click(on_prev_ref_image_button_clicked)
    next_ref_image_button.on_click(on_next_ref_image_button_clicked)

    def on_process_button_clicked(b):
        process_button.disabled = True
        clear_output(wait=True)
        display(widgets.HBox([ 
            widgets.VBox([
                widgets.HBox([date_slider_ref_image, prev_ref_image_button, next_ref_image_button]),
                interactive_plot_ref_image
                ])
            ,
            widgets.VBox([
                widgets.HBox([date_slider_image, prev_image_button, next_image_button]),
                interactive_plot_image
            ])
        ]),
        process_button)
        process_images_mask(zone, region, data, dates, polygon, properties, visible_images, 3, date_slider_image.value, date_slider_ref_image.value, labels)
        process_button.disabled = False
      
    process_button.on_click(on_process_button_clicked)

    display(widgets.HBox([ 
        widgets.VBox([
            widgets.HBox([date_slider_ref_image, prev_ref_image_button, next_ref_image_button]),
            interactive_plot_ref_image
            ])
        ,
        widgets.VBox([
            widgets.HBox([date_slider_image, prev_image_button, next_image_button]),
            interactive_plot_image
        ])
    ]),
    process_button)

In [88]:
import ipywidgets as widgets
from IPython.display import clear_output, display

def select_region_widget(select_callback, labels):
    zone_widget = widgets.Dropdown(
        options=zone_dict.keys(),
        description='Zone:',
        disabled=False
    )

    region_widget = widgets.Dropdown(
        options=zone_dict[zone_widget.value].keys(),
        description='Region:',
        disabled=False
    )

    def update_region_options(change):
        region_widget.options = zone_dict[change['new']].keys()

    zone_widget.observe(update_region_options, names='value')

    only_sunny_dates = widgets.Checkbox(value=True, description='Descartar imágenes con nubes')

    # Botón para procesar
    process_button = widgets.Button(description="Process")
    clear_button = widgets.Button(description="Clear")
    status_label = widgets.Label(value="")

    # Función para manejar evento del botón
    def on_process_button_clicked(b):  
        process_button.disabled = True
        zone_widget.disabled = True
        region_widget.disabled = True
        only_sunny_dates.disabled = True
        status_label.value = "Procesando..."
        select_callback(zone_widget.value, region_widget.value, only_sunny_dates.value, labels)
        status_label.value = "Procesamiento completado"
        clear_button.disabled = False
        
    process_button.on_click(on_process_button_clicked)
    
    def display_region_widget(clear_output_request = False):
        if clear_output_request:
            clear_output(wait=True)
        process_button.disabled = False
        zone_widget.disabled = False
        region_widget.disabled = False
        only_sunny_dates.disabled = False
        clear_button.disabled = True
        status_label.value = ""
        display(widgets.HBox([zone_widget, region_widget, only_sunny_dates]), widgets.HBox([process_button, clear_button]), status_label)
    
    
    def on_clear_button_clicked(b):
        display_region_widget(True)
        
    clear_button.on_click(on_clear_button_clicked)
    
    # Show widgets
    display_region_widget()

In [89]:
labels = ['Talado', 'Incendio', 'Inundacion', 'Bosque', 'Otra vegetacion', 'Plaga']
select_region_widget(select_images_widget, labels)

HBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='Referenc…

Button(description='Seleccionar', disabled=True, style=ButtonStyle())

HBox(children=(Dropdown(description='Imagen:', options=('Imagen final', 'Imagen de referencia', 'Diferencia'),…

HBox(children=(RadioButtons(description='Método:', options=('Imagen individual', 'Calcular máscaras primero', …

HBox(children=(Checkbox(value=False, description='Habilitar'), Dropdown(description='Tipo de Índice', options=…

HBox(children=(Checkbox(value=False, description='Habilitar'), Dropdown(description='Tipo de Índice', options=…

HBox(children=(Checkbox(value=False, description='Habilitar'), Dropdown(description='Tipo de Índice', options=…

Output()

HBox(children=(Dropdown(description='Etiqueta:', options=('Talado', 'Incendio', 'Inundacion', 'Bosque', 'Otra …

2023-10-19 09:23:40 - No hay máscara para guardar.


KeyError: "not all values found in index 'time'. Try setting the `method` keyword argument (example: method='nearest')."

In [45]:
import ipywidgets as widgets
from IPython.display import display

def process_region_mask(zone, region, image, num_index_to_process = 3, prev_image = None, comparison = False):
    lines = []  # Lista para almacenar las referencias a los widgets de cada línea

    index_categories = {
        "Vegetativo": [e.name for e in VegetationIndex],
        "Fuego": [e.name for e in FireIndex],
        "Agua": [e.name for e in WaterIndex]
    }

    def create_line():
        enable_widget = widgets.Checkbox(value=True, description='Habilitar')
        
        type_selector = widgets.Dropdown(options=["Vegetativo", "Fuego", "Agua"], description='Tipo de Índice:')
        index_selector = widgets.Dropdown(options=index_categories["Vegetativo"], description='Índice Específico:')
        comparison_selector = widgets.Dropdown(options=["Mayor que", "Menor que"], description='Comparación:')
        number_input = widgets.BoundedFloatText(value=0, min=0, max=1000, step=0.1, description='Número:')
        
        def update_indices(change):
            index_selector.options = index_categories[change['new']]
            
        type_selector.observe(update_indices, names='value')
        
        def toggle_widgets(change):
            state = change['new']
            type_selector.disabled = not state
            index_selector.disabled = not state
            comparison_selector.disabled = not state
            number_input.disabled = not state
        
        enable_widget.observe(toggle_widgets, names='value')
        
        def on_value_change(change):
            line_index = lines.index(line_widgets)
            print(f"Valor cambiado en la línea {line_index+1}: {change['owner'].description} ahora es {change['new']}")

        enable_widget.observe(on_value_change, names='value')
        type_selector.observe(on_value_change, names='value')
        index_selector.observe(on_value_change, names='value')
        comparison_selector.observe(on_value_change, names='value')
        number_input.observe(on_value_change, names='value')
        
        line_widgets = {
            "enable": enable_widget,
            "type_selector": type_selector,
            "index_selector": index_selector,
            "comparison_selector": comparison_selector,
            "number_input": number_input
        }
        
        lines.append(line_widgets)
        
        return widgets.HBox([enable_widget, type_selector, index_selector, comparison_selector, number_input])

    # Botón para procesar
    process_button = widgets.Button(description="Procesar")
    process_button.on_click(on_process_button_clicked)

    # Botón para guardar
    save_button = widgets.Button(description="Guardar Máscara")

    # Función para manejar evento del botón de procesar
    def on_process_button_clicked(b):  
        for line in lines:
            if line["enable"].value:  
                # Puedes procesar cada línea de widgets aquí como lo necesites
                
                print(f"Procesando línea con tipo de índice: {line['type_selector'].value}")

    # Función para manejar evento del botón de guardar
    def on_save_button_clicked(b):  
        print("Guardando máscara...")

    save_button.on_click(on_save_button_clicked)

    # Mostrar widgets
    display(process_button, save_button)
    for _ in range(3):
        display(create_line())

In [46]:
zone_widget = widgets.Dropdown(
    options=zone_dict.keys(),
    description='Zone:',
    disabled=False
)

region_widget = widgets.Dropdown(
    options=zone_dict[zone_widget.value].keys(),
    description='Region:',
    disabled=False
)

def update_region_options(change):
    region_widget.options = zone_dict[change['new']].keys()

zone_widget.observe(update_region_options, names='value')

# Botón para procesar
process_button = widgets.Button(description="Process")
status_label = widgets.Label(value="")

# Función para manejar evento del botón
def on_process_button_clicked(b):  
    process_button.disabled = True
    status_label.value = "Procesando..."
    process_region(zone_widget.value, region_widget.value)
    process_button.disabled = False
    status_label.value = "Procesamiento completado"
    
process_button.on_click(on_process_button_clicked)

# Mostrar widgets
display(widgets.HBox([zone_widget, region_widget, process_button]), status_label)

HBox(children=(Dropdown(description='Zone:', options=('Bosques Bio Bio', 'Incendios', 'Bosques Arauco'), value…

Label(value='')

2023-10-19 09:25:53 - [33mUnique dates with NaN values across all bands: [][0m


VBox(children=(HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', …