In [None]:
import sys
from pathlib import Path
# Append the parent of the current working directory to sys.path
# (presumably the repository root, so that the `src` and `utils` imports
# in the following cells resolve — confirm against the repo layout).
sys.path.append(str(Path().absolute().parent))

In [None]:
import ee 
import geemap

# Initialize the Earth Engine client for the "thurgau-irrigation" project.
# Every ee.* call in the cells below depends on this having run first.
ee.Initialize(project="thurgau-irrigation")

In [None]:
from src.data_processing.downscaling import resample_image, Downscaler
from src.data_processing.sentinel_preprocessing import load_sentinel2_data

from utils.date_utils import (
    set_to_first_of_month,
    print_collection_dates,
    create_centered_date_ranges,
)
from utils.ee_utils import harmonized_ts, export_image_to_asset, back_to_int, back_to_float
from utils.harmonic_regressor import HarmonicRegressor

from typing import List, Callable, Tuple

---

## Constants

In [None]:
# --- Earth Engine input assets ------------------------------------------
# NOTE(review): the AOI asset id is spelled "thrugau" — confirm it matches
# the real asset name before "fixing" the spelling.
PATH_TO_AOI = "projects/thurgau-irrigation/assets/Thurgau/thrugau_borders_2024"
PATH_TO_ET_PRODUCT = "projects/thurgau-irrigation/assets/ET_products/Thurgau/WaPOR_300m"

# --- Processing year ----------------------------------------------------
YEAR = 2023
# Forwarded as `use_SR` to load_sentinel2_data; enabled from 2019 onwards.
USE_SR = YEAR >= 2019

# --- Sentinel-2 time-series aggregation ---------------------------------
# Forwarded as `buffer_days` to create_centered_date_ranges.
BUFFER_DAYS = 5
BAND_TO_RESAMPLE = "ET"
# Forwarded as `band_list` to harmonized_ts.
BANDS_TO_HARMONIZE = ["B3", "B4", "B8", "B11", "B12"]
# Forwarded as `options` to harmonized_ts.
AGGREGATION_OPTIONS = dict(
    agg_type="mosaic",
    mosaic_type="least_cloudy",
    band_name="NDVI",
)

# --- Harmonic regression / downscaling bands ----------------------------
# Indexes that get a "fitted_<index>" and a "gap_filled_<index>" band.
INDEXES_FOR_HARMONIZATION = ["NDVI", "NDWI", "NDBI"]
# Predictor bands handed to the Downscaler (note: order differs from the list above).
INDEPENDENT_BANDS = ["gap_filled_NDVI", "gap_filled_NDBI", "gap_filled_NDWI"]
# Response band(s); only the first entry is handed to the Downscaler.
DEPENDENT_BAND = ["ET"]

# --- Export -------------------------------------------------------------
# Number of time steps per year and their type ("dekadal" or "monthly").
NUMBER_OF_IMAGES = 36
TEMPORAL_RESOLUTION = "dekadal"
DOWNSCALED_ASSET_PATH = (
    f"projects/thurgau-irrigation/assets/ET_products/Thurgau/WaPOR_10m_{YEAR}"
)

## 1. Load ET product

In [None]:
# Load the area-of-interest features and reduce them to a single geometry,
# simplified with a maxError of 500 to keep the geometry small.
aoi_feature_collection = ee.FeatureCollection(PATH_TO_AOI)
aoi_geometry = aoi_feature_collection.geometry().simplify(500)

# Grow the simplified geometry by 100; this buffered geometry is the `aoi`
# used by every later cell.
# NOTE(review): both distances are in metres by Earth Engine default since no
# projection is passed — confirm that is the intended unit.
aoi = aoi_geometry.buffer(100)

In [None]:
# Coarse ET images intersecting the AOI for YEAR, in chronological order.
# NOTE(review): Earth Engine's filterDate treats the end date as exclusive, so
# an image whose system:time_start falls on 31 Dec would be left out —
# confirm this is intended for the ET product in use.
et_coarse_collection = (
    ee.ImageCollection(PATH_TO_ET_PRODUCT)
    .filterBounds(aoi)
    .filterDate(f"{YEAR}-01-01", f"{YEAR}-12-31")
    .sort("system:time_start")
)

# The same images as an ee.List; consumed by create_centered_date_ranges.
et_coarse_list = et_coarse_collection.toList(et_coarse_collection.size())

# Debug aid: uncomment to inspect the dates present in the collection.
# print_collection_dates(et_coarse_collection)

## 2. Load Sentinel-2 data

In [None]:
# Load Sentinel-2 imagery over the AOI for the single year YEAR.
s2_collection = load_sentinel2_data((YEAR, YEAR), aoi=aoi, use_SR=USE_SR)

# Date ranges derived from the coarse ET images with a BUFFER_DAYS buffer
# (presumably one range centred on each ET image — see create_centered_date_ranges).
time_intervals = create_centered_date_ranges(et_coarse_list, buffer_days=BUFFER_DAYS)

# Aggregate the Sentinel-2 collection over those intervals for the bands in
# BANDS_TO_HARMONIZE, using the settings in AGGREGATION_OPTIONS.
s2_harmonized = harmonized_ts(
    masked_collection=s2_collection,
    band_list=BANDS_TO_HARMONIZE,
    time_intervals=time_intervals,
    options=AGGREGATION_OPTIONS,
)

# Debug aid: uncomment to list the band names of the first image.
# s2_harmonized.first().bandNames().getInfo()

### 2.1 Prepare the Sentinel-2 data for harmonization

In [None]:
def add_temporal_bands(collection: ee.ImageCollection) -> ee.ImageCollection:
    """Append a time band and a constant band to every image of a collection.

    Args:
        collection (ee.ImageCollection): Images carrying a 'system:time_start' property.

    Returns:
        ee.ImageCollection: The same images with two extra bands:
            't' (float, years elapsed since 1970-01-01 at 'system:time_start')
            and 'constant' (the value 1). Both are given the projection of
            the image's first band as their default projection.
    """

    def _with_time_and_constant(img: ee.Image) -> ee.Image:
        # Projection of the first band is reused for both synthetic bands.
        reference_projection = img.select([0]).projection()

        acquisition = ee.Date(img.get('system:time_start'))
        elapsed_years = acquisition.difference(ee.Date('1970-01-01'), 'year')

        t_band = ee.Image(elapsed_years).float().rename('t')
        one_band = ee.Image.constant(1).rename('constant')

        extra_bands = [
            t_band.setDefaultProjection(reference_projection),
            one_band.setDefaultProjection(reference_projection),
        ]
        return img.addBands(extra_bands)

    return collection.map(_with_time_and_constant)

# Rebind s2_harmonized to the version carrying the extra 't' and 'constant' bands.
# NOTE(review): the cell overwrites its own input, so re-running it without
# re-running the previous cell applies add_temporal_bands a second time.
s2_harmonized = add_temporal_bands(s2_harmonized)
# Debug aid: uncomment to list the band names of the first image.
# s2_harmonized.first().bandNames().getInfo()

In [None]:
def compute_vegetation_indexes(image: ee.Image) -> ee.Image:
    """
    Append NDVI, NDWI and NDBI bands to an image.

    Each index is the normalized difference of a pair of bands:
        NDVI from (B8, B4), NDWI from (B3, B8), NDBI from (B11, B8).

    Args:
        image (ee.Image): Image containing the bands B3, B4, B8 and B11

    Returns:
        ee.Image: The input image followed by the bands "NDVI", "NDWI"
            and "NDBI", in that order

    """
    # (output band name, band pair handed to normalizedDifference), in output order.
    index_definitions = (
        ("NDVI", ["B8", "B4"]),
        ("NDWI", ["B3", "B8"]),
        ("NDBI", ["B11", "B8"]),
    )

    enriched = image
    for index_name, band_pair in index_definitions:
        # Always derive the index from the untouched input image.
        index_band = image.normalizedDifference(band_pair).rename(index_name)
        enriched = enriched.addBands(index_band)
    return enriched

# Add the NDVI, NDWI and NDBI bands to every image of the harmonized collection.
s2_harmonized_w_vegetation_indexes = s2_harmonized.map(compute_vegetation_indexes)

In [None]:
# s2_harmonized_w_vegetation_indexes.first().bandNames().getInfo()    

### 2.2 Fit a Harmonic Regressor to the Sentinel-2 data

In [None]:
# Start from the collection with index bands; the loop below appends one
# "fitted_<index>" band per entry of INDEXES_FOR_HARMONIZATION.
s2_harmonized_gaps_filled = s2_harmonized_w_vegetation_indexes

for index in INDEXES_FOR_HARMONIZATION:
    # A fresh regressor per index, configured for that single band.
    regressor = HarmonicRegressor(
        omega=1, max_harmonic_order=2, band_to_harmonize=index
    )

    # Fit on, and predict for, the same collection (the one without fitted bands).
    regressor.fit(s2_harmonized_w_vegetation_indexes)
    fitted_collection = regressor.predict(s2_harmonized_w_vegetation_indexes)

    # Keep only the "fitted" band and give it an index-specific name.
    fitted_collection = fitted_collection.map(
        lambda img: img.select(["fitted"]).rename(f"fitted_{index}")
    )

    # For every image, look up the fitted image with the same date and attach
    # its band.
    # NOTE(review): the lambdas read `index` and `fitted_collection` from the
    # loop scope; this presumably works because map() calls the function while
    # the graph is being built, i.e. within the same iteration — confirm.
    s2_harmonized_gaps_filled = s2_harmonized_gaps_filled.map(
        lambda img: img.addBands(
            fitted_collection.filterDate(img.date()).first().select([f"fitted_{index}"])
        )
    )

In [None]:
# s2_harmonized_gaps_filled.first().bandNames().getInfo()

### 2.3 Use the fitted harmonic bands to fill the gaps in the original bands

In [None]:
def fill_gaps(
    img: ee.Image, source_band: str, fill_band: str, output_name: str
) -> ee.Image:
    """Build a band that keeps `source_band` where valid and uses `fill_band` elsewhere.

    Args:
        img (ee.Image): Input image containing both bands
        source_band (str): Name of the band whose masked pixels are the gaps
        fill_band (str): Name of the band supplying values for the gaps
        output_name (str): Name given to the resulting band

    Returns:
        ee.Image: The gap-filled band named `output_name`, with the default
            projection of `img` and a "scale" property holding that
            projection's nominal scale
    """
    observed = img.select(source_band)
    replacement = img.select(fill_band)

    # A pixel is a gap wherever the source band's mask is zero.
    is_gap = observed.mask().Not()

    # Unmask first so the gap pixels can be overwritten, then take the
    # replacement values exactly where the source was masked.
    patched = observed.unmask().where(is_gap, replacement).rename(output_name)

    img_projection = img.projection()
    return patched.setDefaultProjection(img_projection).set(
        "scale", img_projection.nominalScale()
    )


def apply_gap_filling(img: ee.Image, indexes: List[str]) -> ee.Image:
    """Append one gap-filled band per index to an image.

    For every name `X` in `indexes`, band `X` is gap-filled from band
    `fitted_X` and the result is appended as `gap_filled_X`.

    Args:
        img (ee.Image): Input image holding each index band and its fitted counterpart
        indexes (list[str]): Index names to process (e.g., ['NDVI', 'NDWI', 'NDBI'])

    Returns:
        ee.Image: The input image followed by the gap-filled bands, in the order of `indexes`
    """
    # Every gap-filled band is derived from the untouched input image.
    gap_filled_bands = [
        fill_gaps(
            img=img,
            source_band=name,
            fill_band=f"fitted_{name}",
            output_name=f"gap_filled_{name}",
        )
        for name in indexes
    ]

    enriched = img
    for band in gap_filled_bands:
        enriched = enriched.addBands(band)
    return enriched


# Collection-level wrapper around apply_gap_filling
def process_collection(
    collection: ee.ImageCollection, indexes: List[str]
) -> ee.ImageCollection:
    """Run apply_gap_filling on every image of a collection.

    Args:
        collection (ee.ImageCollection): Collection whose images hold the index and fitted bands
        indexes (List[str]): Index names to gap-fill in each image

    Returns:
        ee.ImageCollection: The mapped collection, each image extended with its gap-filled bands
    """

    def _fill_image(image: ee.Image) -> ee.Image:
        return apply_gap_filling(image, indexes)

    return collection.map(_fill_image)


# Add a "gap_filled_<index>" band for every entry of INDEXES_FOR_HARMONIZATION.
# NOTE(review): the cell overwrites its own input, so re-running it without
# re-running the regression cell applies the gap filling a second time.
s2_harmonized_gaps_filled = process_collection(s2_harmonized_gaps_filled, INDEXES_FOR_HARMONIZATION)

# Debug aid: uncomment to list the band names of the first image.
# s2_harmonized_gaps_filled.first().bandNames().getInfo()

## 3. Downscale the ET product to 10m resolution

In [None]:
def _time_step_label(step_index: int, time_step_type: str) -> str:
    """Return the asset-name suffix for the zero-based time step `step_index`.

    "dekadal" yields "<MM>_D<1..3>" (three dekads per month) and "monthly"
    yields "<MM>"; any other type raises ValueError.
    """
    if time_step_type == "dekadal":
        dekad = step_index % 3 + 1
        month = step_index // 3 + 1
        return f"{month:02d}_D{dekad}"
    if time_step_type == "monthly":
        return f"{step_index + 1:02d}"
    raise ValueError("time_step_type must be either 'dekadal' or 'monthly'")


def process_and_export_downscaled_ET(
    downscaler: Downscaler,
    s2_indices: ee.ImageCollection,
    independent_vars: ee.ImageCollection,
    dependent_vars: ee.ImageCollection,
    aoi: ee.Geometry,
    year: str,
    scale_coarse: float,
    asset_path: str,
    scale_fine: float = 10,
    time_steps: int = 36,
    time_step_type: str = "dekadal",
) -> List[ee.batch.Task]:
    """
    Process and export downscaled WaPOR ET images to Earth Engine assets.

    The three collections are read positionally: the i-th image of each one is
    used for the i-th time step, so they must be aligned and hold at least
    `time_steps` images.

    Args:
        downscaler (Downscaler): The Downscaler object used to downscale the images.
        s2_indices (ee.ImageCollection): The Sentinel-2 indices image collection (fine resolution).
        independent_vars (ee.ImageCollection): The resampled independent variables image collection.
        dependent_vars (ee.ImageCollection): The dependent variables image collection.
        aoi (ee.Geometry): The area of interest geometry.
        year (str): The year for which the images are processed. Only used in
            names and forwarded to export_image_to_asset, so an int works as well.
        scale_coarse (float): The scale of the images before downscaling.
        asset_path (str): Asset folder/collection the images are exported into;
            each asset id is "<asset_path>/<task_name>".
        scale_fine (float): The scale of the images after downscaling.
        time_steps (int): Number of time steps in the year (36 for dekadal, 12 for monthly).
        time_step_type (str): Type of time step ("dekadal" or "monthly").

    Returns:
        List[ee.batch.Task]: One entry per time step, holding whatever
            export_image_to_asset returned for it.

    Raises:
        ValueError: If `time_step_type` is neither "dekadal" nor "monthly".
    """
    s2_indices_list = s2_indices.toList(s2_indices.size())
    independent_vars_list = independent_vars.toList(independent_vars.size())
    dependent_vars_list = dependent_vars.toList(dependent_vars.size())

    tasks = []
    for i in range(time_steps):
        time_step_name = _time_step_label(i, time_step_type)

        s2_index = ee.Image(s2_indices_list.get(i))
        ind_vars = ee.Image(independent_vars_list.get(i))
        dep_vars = ee.Image(dependent_vars_list.get(i))

        # Perform downscaling
        et_image_downscaled = downscaler.downscale(
            coarse_independent_vars=ind_vars,
            coarse_dependent_var=dep_vars,
            fine_independent_vars=s2_index,
            geometry=aoi,
            resolution=scale_coarse,
        )

        # Post-process the downscaled image (presumably scales by 100 and
        # casts to integer for storage — see back_to_int).
        et_image_downscaled = back_to_int(et_image_downscaled, 100)

        task_name = f"Downscaled_ET_gap_filled_{time_step_type}_{year}_{time_step_name}"
        asset_id = f"{asset_path}/{task_name}"

        task = export_image_to_asset(
            et_image_downscaled,
            asset_id,
            task_name,
            year,
            aoi,
            crs="EPSG:32632",
            scale=scale_fine,
        )
        tasks.append(task)

    # The original fell off the end here and returned None despite the
    # declared return type; hand the collected tasks back to the caller.
    return tasks

In [None]:
# Nominal scale of the coarse ET product, fetched client-side with getInfo().
scale = et_coarse_collection.first().projection().nominalScale().getInfo()

# Fine-resolution predictors: only the gap-filled index bands.
s2_indices = s2_harmonized_gaps_filled.select(INDEPENDENT_BANDS)
# The same predictors passed through resample_image with the coarse scale.
independent_vars = s2_indices.map(
    lambda img: resample_image(img, scale, INDEPENDENT_BANDS)
)

# Coarse-resolution response: the ET band of the coarse collection.
dependent_vars = et_coarse_collection.select(DEPENDENT_BAND)


# Initialize the Downscaler
downscaler = Downscaler(
    independent_bands=INDEPENDENT_BANDS,  
    dependent_band=DEPENDENT_BAND[0],
)


# Downscale every time step and export each result under DOWNSCALED_ASSET_PATH.
# NOTE(review): YEAR is an int while the function annotates `year` as str.
tasks = process_and_export_downscaled_ET(
    downscaler=downscaler,
    s2_indices=s2_indices,
    independent_vars=independent_vars,  
    dependent_vars=dependent_vars,
    aoi=aoi,
    year=YEAR,
    scale_coarse=scale,
    scale_fine=10,
    time_steps=NUMBER_OF_IMAGES,
    time_step_type=TEMPORAL_RESOLUTION,
    asset_path=DOWNSCALED_ASSET_PATH,
)

In [None]:
# Display the band names of the first resampled predictor image
# (getInfo() fetches the value from Earth Engine).
independent_vars.first().bandNames().getInfo()

## Sanity checks
### 1. Check the Sentinel-2 harmonization

In [None]:
# Map = geemap.Map()

# sentinel_vis_params = {
#     "bands": ["NDVI"],
#     "min": 0,
#     "max": 1,
#     "palette": ["white", "green"],
# }

# fitted_vis_params = {
#     "bands": ["fitted_NDVI"],
#     "min": 0,
#     "max": 1,
#     "palette": ["white", "green"],
# }

# gap_filled_vis_params = {
#     "bands": ["gap_filled_NDVI"],
#     "min": 0,
#     "max": 1,
#     "palette": ["white", "green"],
# }


# s2_image = ee.Image(s2_harmonized_gaps_filled.filterBounds(aoi).toList(36).get(-3))
# s2_image_resampled = ee.Image(independent_vars.toList(36).get(0))

# Map.addLayer(s2_image, sentinel_vis_params, "Sentinel 2")
# # # Map.addLayer(s2_image, fitted_vis_params, "Sentinel 2 Fitted")
# Map.addLayer(s2_image, gap_filled_vis_params, "Sentinel 2 Gap Filled")
# # Map.addLayer(s2_image_resampled, gap_filled_vis_params, "Sentinel 2 Resampled")


# Map.centerObject(aoi, 12)
# Map

In [None]:
# print_collection_dates(s2_harmonized_gaps_filled)

### 2. Compute the ET downscaled images

In [None]:
# collection = ee.ImageCollection("projects/thurgau-irrigation/assets/ET_products/Thurgau/WaPOR_10m_2018").map(lambda img: back_to_float(img, 100))

# print_collection_dates(collection)

In [None]:
# collection_list = collection.toList(collection.size())


In [None]:
# Map = geemap.Map()

# et_vis_params = {
#     "bands": ["downscaled"],
#     "min": 0,
#     "max": 5,
#     "palette": ["white", "lightblue", "blue", "green", "yellow", "orange", "red", "darkred"],
# }

# et_image = ee.Image(collection_list.get(14))

# Map.addLayer(et_image, et_vis_params, "ET")

# Map.centerObject(aoi, 12)

# Map

In [None]:
# sentinel_collection = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED").filterBounds(aoi).filterDate("2019-5-01", "2019-05-31").sort("system:time_start")

# # Get list of unique dates and merge same-day images
# def merge_daily_images(collection):
#     # Get list of unique dates
#     dates = collection.aggregate_array('system:time_start').distinct()
    
#     def merge_images_for_date(date):
#         # Get images for this date
#         daily_images = collection.filter(ee.Filter.eq('system:time_start', date))
#         # Merge images (mosaic takes the first non-masked pixel)
#         merged = daily_images.mosaic()
#         # Set the date property
#         return merged.set('system:time_start', date)
    
#     # Map merge function over dates
#     merged_collection = ee.ImageCollection(dates.map(lambda date: merge_images_for_date(date)))
#     return merged_collection

# # Apply the merging
# merged_sentinel = merge_daily_images(sentinel_collection)

# # Visualize first merged image
# merged_image = ee.Image(merged_sentinel.first())

# # Print number of images before and after
# print(f"Original collection size: {sentinel_collection.size().getInfo()}")
# print(f"Merged collection size: {merged_sentinel.size().getInfo()}")

In [None]:
# Map = geemap.Map()

# sentinel_vis_params = {
#     "bands": ["B4", "B3", "B2"],
#     "min": 0,
#     "max": 2000,
# }

# sentinel_image = ee.Image(sentinel_collection.first())

# Map.addLayer(sentinel_image, sentinel_vis_params, "Sentinel 2")
# Map.addLayer(merged_image, sentinel_vis_params, "Merged Sentinel 2")

# Map.centerObject(aoi, 12)

# Map