In [None]:
import logging
import os
from pathlib import Path
from typing import Callable
from datacube.testutils.io import rio_slurp_xarray

import click
import datacube
import fsspec
import geopandas as gpd
import pandas as pd
import numpy as np

from skimage import measure, morphology
from skimage.segmentation import watershed
import scipy.ndimage as ndi

from deafrica_tools.spatial import xr_vectorize

from deafrica_waterbodies.cli.logs import logging_setup
from deafrica_waterbodies.io import (
    check_dir_exists,
    check_file_exists,
    check_if_s3_uri,
    find_parquet_files,
)
from deafrica_waterbodies.make_polygons import (
    check_wetness_thresholds,
    get_polygons_from_tile_with_land_sea_mask_filtering,
    merge_polygons_at_tile_boundaries
)
from deafrica_waterbodies.tiling import (
    filter_tiles,
    get_tiles_ids,
    tile_wofs_ls_summary_alltime,
)

In [None]:
import os

# Default AWS credential variables that the Analysis Sandbox sets in the environment.
aws_default_config = {
    # "AWS_NO_SIGN_REQUEST": "YES",
    "AWS_SECRET_ACCESS_KEY": "fake",
    "AWS_ACCESS_KEY_ID": "fake",
}

# Public buckets can only be read once these credentials are gone; otherwise access fails with
#   PermissionError: The AWS Access Key Id you provided does not exist in our records.
# Drop each of them from the environment (missing variables are ignored).
for env_var_name in aws_default_config:
    os.environ.pop(env_var_name, None)

In [None]:
# ---- Run configuration: edit these before running the notebook. ----

# Outputs: destination (S3 URI or local path) and whether existing per-tile outputs are regenerated.
output_directory = "s3://deafrica-waterbodies-dev/test_out_dir/raster_processing/continental"
overwrite = False

# Land/sea mask raster (relative path, resolved against the current working directory).
land_sea_mask_fp = "data/af_msk_3s.tif"

# Optional vector file describing the area of interest; None means no AOI is loaded.
aoi_vector_file = None

# Tiling and parallelism: passed to tile_wofs_ls_summary_alltime and filter_tiles respectively.
tile_size_factor = 4
num_workers = 16

# Wetness-frequency thresholds and the minimum number of clear observations per pixel.
primary_threshold: float = 0.1
secondary_threshold: float = 0.05
minimum_valid_observations: int = 128

# Logging verbosity passed to logging_setup.
verbose = 1

In [None]:
import xarray as xr
def filter_hydrosheds_land_mask(hydrosheds_land_mask: xr.DataArray) -> xr.DataArray:
    """
    Convert the HydroSHEDS land mask into a boolean keep-mask.

    Indicator values: 1 = land, 2 = ocean sink, 3 = inland sink, 255 = no data.
    Pixels equal to 2 (ocean sink) or 255 (no data) become False; every other
    value becomes True.
    """
    has_data = hydrosheds_land_mask != 255
    not_ocean_sink = hydrosheds_land_mask != 2
    return has_data & not_ocean_sink

In [None]:
# Configure logging at the chosen verbosity, then fetch this notebook's logger.
logging_setup(verbose=verbose)
_log = logging.getLogger(name=__name__)

In [None]:
# Accept pathlib.Path inputs by converting them to plain strings; a missing AOI stays None.
aoi_vector_file = None if aoi_vector_file is None else str(aoi_vector_file)
output_directory = str(output_directory)

In [None]:
# Dask chunk sizes for loading datasets: 3200 x 3200 along x/y, one time step per chunk.
dask_chunks = dict(x=3200, y=3200, time=1)

In [None]:
# Load the optional area of interest as a GeoDataFrame; without an AOI file, aoi_gdf is None.
if aoi_vector_file is None:
    aoi_gdf = None
else:
    try:
        aoi_gdf = gpd.read_file(aoi_vector_file)
    except Exception as error:
        _log.exception(f"Could not read the file {aoi_vector_file}")
        raise error

In [None]:
# Split the wofs_ls_summary_alltime product into tiles. `tiles` maps tile ids to tile objects;
# `grid_workflow` is what later loads the data for each tile.
tiles, grid_workflow = tile_wofs_ls_summary_alltime(
    tile_size_factor
)

In [None]:
# Keep only the tiles that filter_tiles selects for the area of interest (tile order is preserved).
filtered_tile_ids = filter_tiles(tiles, aoi_gdf, num_workers)
# Test membership against a set: `k in <list>` inside the comprehension would make this
# filter quadratic in the number of tiles.
filtered_tile_id_set = set(filtered_tile_ids)
filtered_tiles = {k: v for k, v in tiles.items() if k in filtered_tile_id_set}

In [None]:
# Per-tile waterbody polygon files are written into this sub-directory of the output directory.
polygons_from_thresholds_dir = os.path.join(
    output_directory,
    "polygons_from_thresholds",
)

In [None]:
# Use an S3 filesystem when the output directory is an S3 URI (per check_if_s3_uri), else the local one.
fs = fsspec.filesystem("s3" if check_if_s3_uri(polygons_from_thresholds_dir) else "file")

In [None]:
# Create the per-tile output directory when it does not exist yet.
output_dir_missing = not check_dir_exists(polygons_from_thresholds_dir)
if output_dir_missing:
    fs.mkdirs(polygons_from_thresholds_dir, exist_ok=True)
    _log.info(f"Created directory {polygons_from_thresholds_dir}")

In [None]:
# Check that the wetness thresholds are set correctly (secondary first, then primary)
# and log the message returned by the check.
minimum_wet_thresholds = [secondary_threshold, primary_threshold]
threshold_check_message = check_wetness_thresholds(minimum_wet_thresholds)
_log.info(threshold_check_message)

In [None]:
def load_wofs_frequency(
    tile: tuple[tuple[int, int], datacube.api.grid_workflow.Tile],
    grid_workflow: datacube.api.GridWorkflow,
    dask_chunks: dict[str, int] = {"x": 3200, "y": 3200, "time": 1},
    resolution: tuple[int, int] = (-30, 30),
    output_crs: str = "EPSG:6933",
    min_valid_observations: int = 128,
    primary_threshold: float = 0.1,
    secondary_threshold: float = 0.05,
    land_sea_mask_fp: str | Path = "",
    resampling_method: str = "bilinear",
    filter_land_sea_mask: Callable = filter_hydrosheds_land_mask,
):
    """
    Load the WOfS All-Time Summary for one tile and threshold it into
    detection (primary) and extent (secondary) boolean masks.

    Parameters
    ----------
    tile : tuple
        ``(tile_id, tile_object)`` pair, e.g. one item of ``tiles.items()``.
    grid_workflow : datacube.api.GridWorkflow
        Workflow used to load ``tile_object``.
    dask_chunks : dict[str, int]
        Chunk sizes passed to ``grid_workflow.load``.
    resolution, output_crs :
        Accepted for interface compatibility; not used by this function.
    min_valid_observations : int
        Pixels whose ``count_clear`` is below this value are excluded from both masks.
    primary_threshold : float
        Pixels with ``frequency`` strictly greater than this form the detection mask.
    secondary_threshold : float
        Pixels with ``frequency`` strictly greater than this form the extent mask.
    land_sea_mask_fp : str | Path
        Raster used to mask the summary; masking is skipped when empty.
    resampling_method : str
        Resampling passed to ``rio_slurp_xarray`` when reading the mask onto the tile's geobox.
        NOTE(review): the HydroSHEDS mask holds class codes (see ``filter_hydrosheds_land_mask``);
        "bilinear" can produce interpolated codes at class boundaries. Confirm whether
        "nearest" was intended.
    filter_land_sea_mask : Callable
        Turns the loaded mask into a boolean keep-mask.

    Returns
    -------
    tuple
        ``(valid_detection, valid_extent)`` boolean masks for the tile.

    Raises
    ------
    Exception
        Any error raised while loading or masking the tile is logged and re-raised,
        so callers see the real cause instead of an ``UnboundLocalError``.
    """
    tile_id, tile_object = tile

    try:
        _log.info(f"Generating water body polygons for tile {tile_id}")

        # Load the data for the tile.
        wofs_alltime_summary = grid_workflow.load(tile_object, dask_chunks=dask_chunks).squeeze()

        # Optionally mask the summary with the land/sea mask.
        if land_sea_mask_fp:
            land_sea_mask = rio_slurp_xarray(
                fname=land_sea_mask_fp,
                gbox=wofs_alltime_summary.geobox,
                resampling=resampling_method,
            )
            boolean_land_sea_mask = filter_land_sea_mask(land_sea_mask)
            wofs_alltime_summary = wofs_alltime_summary.where(boolean_land_sea_mask)

        # Set the no-data values to NaN.
        # Masking here uses the frequency measurement because, in many areas, NaN values
        # appear in frequency while the -999 no-data value is absent from count_clear
        # and count_wet.
        # Note: some pixels with NaN frequency appear to have zero in count_clear and/or count_wet.
        wofs_alltime_summary = wofs_alltime_summary.where(~np.isnan(wofs_alltime_summary.frequency))

        # Pixels observed clearly at least min_valid_observations times.
        valid_clear_count = wofs_alltime_summary.count_clear >= min_valid_observations

        # Detection (primary) and extent (secondary) thresholds, restricted to valid pixels.
        valid_detection = (wofs_alltime_summary.frequency > primary_threshold) & valid_clear_count
        valid_extent = (wofs_alltime_summary.frequency > secondary_threshold) & valid_clear_count

    except Exception:
        _log.exception(
            f"\nDataset {str(tile_id)} did not run. \n"
            "This is probably because there are no waterbodies present in this scene."
        )
        # Re-raise: the masks were never assigned, so returning would fail anyway.
        raise

    return valid_detection, valid_extent

def remove_small_waterbodies(waterbody_raster, min_size=6):
    """
    Label connected waterbody pixels and drop objects smaller than ``min_size`` pixels.

    Parameters
    ----------
    waterbody_raster : array-like
        Raster whose non-zero pixels are waterbody; 0 is background.
    min_size : int
        Objects with fewer pixels than this are set to 0.

    Returns
    -------
    The labelled array with small objects removed. Remaining objects keep the labels
    assigned by ``morphology.label``; they are not renumbered.

    NOTE(review): ``connectivity=1`` is passed to ``remove_small_objects``, but the input
    is already labelled. Per the skimage docs, connectivity is only used for boolean
    input, so objects follow ``morphology.label``'s default connectivity. Confirm this
    is intended.
    """
    return morphology.remove_small_objects(
        morphology.label(waterbody_raster, background=0),
        min_size=min_size,
        connectivity=1,
    )

# Only the largest waterbodies are candidates for segmentation.
def select_waterbodies_for_segmentation(waterbodies_labelled, min_size=1000):
    """
    Build a 0/1 mask of the labelled objects that have more than ``min_size`` pixels.

    Parameters
    ----------
    waterbodies_labelled : np.ndarray
        Integer label image (0 = background).
    min_size : int
        Objects need strictly more pixels than this to be selected.

    Returns
    -------
    np.ndarray
        1 where a pixel belongs to a selected object, 0 elsewhere.
    """
    large_labels = [
        region.label
        for region in measure.regionprops(waterbodies_labelled)
        if region.num_pixels > min_size
    ]
    return np.where(np.isin(waterbodies_labelled, large_labels), 1, 0)

def generate_segmentation_markers(marker_source, erosion_radius=1, min_size=100):
    """
    Derive watershed markers by eroding ``marker_source``, then labelling what remains.

    Parameters
    ----------
    marker_source : np.ndarray
        Raster the markers come from (the detection mask in this notebook).
    erosion_radius : int
        Radius of the disk footprint used for erosion.
    min_size : int
        Markers with fewer pixels than this are discarded.

    Returns
    -------
    np.ndarray
        Labelled marker image (0 = no marker).
    """
    disk_footprint = morphology.disk(radius=erosion_radius)
    eroded_source = morphology.erosion(marker_source, footprint=disk_footprint)
    labelled_markers = morphology.label(eroded_source, background=0)
    return morphology.remove_small_objects(labelled_markers, min_size=min_size, connectivity=1)

def run_watershed(waterbodies_for_segementation, segmentation_markers):
    """
    Split waterbodies with a marker-based watershed on their negated distance transform.

    Parameters
    ----------
    waterbodies_for_segementation : np.ndarray
        Mask of the waterbodies to split; also limits where labels are assigned.
    segmentation_markers : np.ndarray
        Labelled seed image for the watershed.

    Returns
    -------
    np.ndarray
        Label image produced by ``watershed``.
    """
    # distance_transform_edt measures each non-zero pixel's distance to the nearest zero pixel.
    # Negating it turns waterbody interiors into basins for the watershed to flood.
    inverted_distance = -ndi.distance_transform_edt(waterbodies_for_segementation)
    return watershed(
        inverted_distance,
        segmentation_markers,
        mask=waterbodies_for_segementation,
    )

def confirm_extent_contains_detection(extent, detection):
    """
    Keep only the extent objects whose pixels sum to more than zero in ``detection``.

    Parameters
    ----------
    extent : np.ndarray
        Integer label image of extent objects (0 = background).
    detection : np.ndarray
        Detection raster (0/1 in this notebook), summed over each extent object.

    Returns
    -------
    np.ndarray
        ``extent`` with every object that contains no detection set to 0.
    """
    # regionprops exposes each extra property under its function name.
    def sum_intensity(regionmask, intensity_image):
        return np.sum(intensity_image[regionmask])

    regions = measure.regionprops(extent, intensity_image=detection, extra_properties=(sum_intensity,))
    labels_with_detection = [region.label for region in regions if region.sum_intensity > 0]
    return np.where(np.isin(extent, labels_with_detection), extent, 0)

def process_raster_polygons(
    tile: tuple[tuple[int, int], datacube.api.grid_workflow.Tile],
    grid_workflow: datacube.api.GridWorkflow,
    dask_chunks: dict[str, int] = {"x": 3200, "y": 3200, "time": 1},
    resolution: tuple[int, int] = (-30, 30),
    output_crs: str = "EPSG:6933",
    min_valid_observations: int = 128,
    primary_threshold: float = 0.1,
    secondary_threshold: float = 0.05,
    dc: datacube.Datacube | None = None,
    land_sea_mask_fp: str | Path = "",
    filter_land_sea_mask: Callable = filter_hydrosheds_land_mask,
    min_waterbody_size: int = 6,
    min_segmentation_size: int = 1000,
    marker_erosion_radius: int = 1,
    min_marker_size: int = 100,
):
    """
    Derive waterbody polygons for one tile from the WOfS All-Time Summary.

    Steps:
      1. Load detection (primary threshold) and extent (secondary threshold) masks
         with ``load_wofs_frequency``.
      2. Label the extent and drop objects smaller than ``min_waterbody_size`` pixels.
      3. Split extent objects larger than ``min_segmentation_size`` pixels with a
         watershed seeded from eroded detection pixels.
      4. Keep only extent objects that contain at least one detection pixel.
      5. Relabel, drop small objects again, and vectorize.

    Parameters
    ----------
    tile : tuple
        ``(tile_id, tile_object)`` pair, e.g. one item of ``tiles.items()``.
    grid_workflow : datacube.api.GridWorkflow
        Workflow used to load the tile.
    dask_chunks, resolution, min_valid_observations, primary_threshold,
    secondary_threshold, land_sea_mask_fp, filter_land_sea_mask :
        Forwarded to ``load_wofs_frequency``.
    output_crs : str
        Forwarded to ``load_wofs_frequency`` and used as the CRS of the vectorized output.
    dc : datacube.Datacube | None
        Not used; kept for interface compatibility.
    min_waterbody_size : int
        Objects with fewer pixels than this are removed, before and after segmentation.
    min_segmentation_size : int
        Only extent objects with more pixels than this are segmented.
    marker_erosion_radius : int
        Radius of the disk used to erode detection pixels into watershed markers.
    min_marker_size : int
        Markers with fewer pixels than this are discarded.

    Returns
    -------
    The result of ``xr_vectorize`` for the final labelled extent (presumably a
    GeoDataFrame, since the caller writes it with ``to_parquet``).

    NOTE(review): pixels of a segmentation candidate that no marker reaches stay 0 after
    the watershed, and they are not in the unsegmented set either. A large waterbody with
    no marker of at least ``min_marker_size`` pixels is therefore dropped. Confirm this
    is intended.
    """
    xr_detection, xr_extent = load_wofs_frequency(
        tile,
        grid_workflow,
        dask_chunks=dask_chunks,
        resolution=resolution,
        output_crs=output_crs,
        min_valid_observations=min_valid_observations,
        primary_threshold=primary_threshold,
        secondary_threshold=secondary_threshold,
        land_sea_mask_fp=land_sea_mask_fp,
        filter_land_sea_mask=filter_land_sea_mask,
    )

    # Convert to numpy arrays for image processing.
    np_detection = xr_detection.to_numpy().astype(int)
    np_extent = xr_extent.to_numpy().astype(int)

    # Label the extent and remove objects smaller than min_waterbody_size pixels.
    np_extent_small_removed = remove_small_waterbodies(np_extent, min_size=min_waterbody_size)

    # Split the extent into large objects (to segment) and the rest (kept as-is).
    np_extent_segment = select_waterbodies_for_segmentation(
        np_extent_small_removed, min_size=min_segmentation_size
    )
    np_extent_nosegment = np.where(np_extent_segment > 0, 0, np_extent_small_removed)

    # Watershed markers: eroded detection pixels, dropping markers smaller than min_marker_size.
    segmentation_markers = generate_segmentation_markers(
        np_detection, erosion_radius=marker_erosion_radius, min_size=min_marker_size
    )

    # Run segmentation.
    np_segmented_extent = run_watershed(np_extent_segment, segmentation_markers)

    # Watershed labels (from the markers) and the unsegmented labels are numbered
    # independently, so the same value can occur in both. Offset the watershed labels
    # past the largest unsegmented label. Otherwise confirm_extent_contains_detection,
    # which groups pixels by label value, could keep an extent object with no detection
    # just because another object shares its label.
    label_offset = int(np_extent_nosegment.max()) if np_extent_nosegment.size else 0
    np_combined_extent = np.where(
        np_segmented_extent > 0, np_segmented_extent + label_offset, np_extent_nosegment
    )

    # Only keep extent objects that contain a detection pixel.
    np_combined_extent_contains_detection = confirm_extent_contains_detection(
        np_combined_extent, np_detection
    )

    # Relabel and remove small objects.
    np_combined_clean_label = remove_small_waterbodies(
        np_combined_extent_contains_detection, min_size=min_waterbody_size
    )

    # Convert back to xarray on the extent's grid.
    xr_combined_extent = xr.DataArray(
        np_combined_clean_label, coords=xr_extent.coords, dims=xr_extent.dims, attrs=xr_extent.attrs
    )

    # Vectorize the labelled (non-zero) pixels.
    vector_combined_extent = xr_vectorize(
        xr_combined_extent, crs=output_crs, mask=xr_combined_extent.values > 0
    )

    return vector_combined_extent

In [None]:
# Generate the raster-derived waterbody polygons for each filtered tile and write one parquet
# file per tile. Existing outputs are skipped unless `overwrite` is True. A tile that fails is
# logged and skipped, so the remaining tiles still run.
for tile in filtered_tiles.items():
    tile_id = tile[0]

    raster_polygons_fp = os.path.join(
        polygons_from_thresholds_dir, f"{tile_id[0]}_{tile_id[1]}_raster_polygons.parquet"
    )

    if not overwrite:
        _log.info(f"Checking existence of {raster_polygons_fp}")
        if check_file_exists(raster_polygons_fp):
            continue

    try:
        # Forward the run configuration so edits in the config cells take effect.
        tile_raster_polygons = process_raster_polygons(
            tile,
            grid_workflow,
            dask_chunks=dask_chunks,
            min_valid_observations=minimum_valid_observations,
            primary_threshold=primary_threshold,
            secondary_threshold=secondary_threshold,
            land_sea_mask_fp=land_sea_mask_fp,
            filter_land_sea_mask=filter_hydrosheds_land_mask,
        )

        # Write the polygons to parquet files.
        tile_raster_polygons.to_parquet(raster_polygons_fp)

    except Exception:
        # _log.exception already records the traceback; log it once.
        _log.exception(
            f"\nDataset {str(tile_id)} did not run. \n"
        )


In [None]:
# Collect the footprint of every filtered tile into a GeoDataFrame in the grid's CRS.
# merge_polygons_at_tile_boundaries uses it below.
crs = grid_workflow.grid_spec.crs
filtered_tiles_extents_geoms = [
    tile_object.geobox.extent.geom for tile_object in filtered_tiles.values()
]
filtered_tiles_extents_gdf = gpd.GeoDataFrame(geometry=filtered_tiles_extents_geoms, crs=crs)

In [None]:
# Find all per-tile raster polygon parquet files.
raster_polygon_paths = find_parquet_files(path=polygons_from_thresholds_dir, pattern=".*raster_polygons.*")
_log.info(f"Found {len(raster_polygon_paths)} parquet files for the raster polygons.")

# pd.concat([]) raises an opaque "No objects to concatenate"; fail with a clear message instead.
if not raster_polygon_paths:
    raise FileNotFoundError(
        f"No raster polygon parquet files found in {polygons_from_thresholds_dir}"
    )

# Load all the raster polygons into a single GeoDataFrame.
_log.info("Loading the raster polygons parquet files..")
raster_polygons = pd.concat(
    [gpd.read_parquet(path) for path in raster_polygon_paths], ignore_index=True
)
_log.info(f"Found {len(raster_polygons)} raster polygons.")

# Merge waterbody polygons that were split across tile boundaries.
_log.info("Merging raster waterbody polygons located at tile boundaries...")
raster_polygons_merged = merge_polygons_at_tile_boundaries(
    raster_polygons, filtered_tiles_extents_gdf
)
_log.info(f"Raster polygons count {len(raster_polygons_merged)}.")

# Write the merged polygons next to (not inside) the per-tile directory.
_log.info("Writing raster polygons merged at tile boundaries to disk..")
raster_polygons_output_fp = os.path.join(
    output_directory, "raster_polygons_merged_at_tile_boundaries.parquet"
)

raster_polygons_merged.to_parquet(raster_polygons_output_fp)
_log.info(f"Polygons written to {raster_polygons_output_fp}")

In [None]:
# # Find all parquet files for the primary threshold.
# primary_threshold_polygons_paths = find_parquet_files(path=polygons_from_thresholds_dir, pattern=".*primary.*")
# _log.info(f"Found {len(primary_threshold_polygons_paths)} parquet files for the primary threshold polygons.")

In [None]:
# # Load all the primary threshold polygons into a single GeoDataFrame.
# _log.info("Loading the primary threshold polygons parquet files..")
# primary_threshold_polygons_list = []
# for path in primary_threshold_polygons_paths:
#     gdf = gpd.read_parquet(path)
#     primary_threshold_polygons_list.append(gdf)

# primary_threshold_polygons = pd.concat(primary_threshold_polygons_list, ignore_index=True)
# _log.info(f"Found {len(primary_threshold_polygons)} primary threshold polygons.")

In [None]:
# _log.info("Merging primary threshold waterbody polygons located at tile boundaries...")
# primary_threshold_polygons_merged = merge_polygons_at_tile_boundaries(
#     primary_threshold_polygons, filtered_tiles_extents_gdf
# )
# _log.info(f"Primary threshold polygons count {len(primary_threshold_polygons_merged)}.")

In [None]:
# _log.info("Writing primary threshold polygons merged at tile boundaries to disk..")
# primary_threshold_polygons_output_fp = os.path.join(
#     output_directory, "primary_threshold_polygons_merged_at_tile_boundaries.parquet"
# )

# primary_threshold_polygons_merged.to_parquet(primary_threshold_polygons_output_fp)
# _log.info(f"Polygons written to {primary_threshold_polygons_output_fp}")

In [None]:
# # Find all parquet files for the secondary threshold.
# secondary_threshold_polygons_paths = find_parquet_files(path=polygons_from_thresholds_dir, pattern=".*secondary.*")
# _log.info(f"Found {len(secondary_threshold_polygons_paths)} parquet files for the secondary threshold polygons.")

In [None]:
# # Load all the secondary threshold polygons into a single GeoDataFrame.
# _log.info("Loading the secondary threshold polygons parquet files...")
# secondary_threshold_polygons_list = []
# for path in secondary_threshold_polygons_paths:
#     gdf = gpd.read_parquet(path)
#     secondary_threshold_polygons_list.append(gdf)

# secondary_threshold_polygons = pd.concat(secondary_threshold_polygons_list, ignore_index=True)
# _log.info(f"Found {len(secondary_threshold_polygons)} secondary threshold polygons.")

In [None]:
# _log.info("Merging secondary threshold waterbody polygons located at dataset/scene boundaries...")
# secondary_threshold_polygons_merged = merge_polygons_at_tile_boundaries(
#     secondary_threshold_polygons, filtered_tiles_extents_gdf
# )
# _log.info(f"Secondary threshold polygons count {len(secondary_threshold_polygons_merged)}.")

In [None]:
# _log.info("Writing secondary threshold polygons merged at tile boundaries to disk..")
# secondary_threshold_polygons_output_fp = os.path.join(
#     output_directory, "secondary_threshold_polygons_merged_at_ds_boundaries.parquet"
# )

# secondary_threshold_polygons_merged.to_parquet(secondary_threshold_polygons_output_fp)

# _log.info(f"Polygons written to {secondary_threshold_polygons_output_fp}")