In [None]:
import gc
import math
import os
import time
import urllib.error
import urllib.request

import dask
import dask.array
import dask.distributed
import dask.utils
import geopandas
import numpy
import odc.stac
import planetary_computer
import rasterio
import rasterio.enums
import xarray
from dask_gateway import GatewayCluster
from pystac_client import Client
from scipy import ndimage

In [None]:
def apply_sen2_vld_msk(scns_xa, bands, qa_pxl_msk="SCL", out_no_data_val=0):
    """Set pixels flagged as invalid by the scene classification band to no-data.

    A pixel is treated as invalid when the classification band holds one of:
    0 (no data), 1 (saturation), 2 (cast shadow), 3 (cloud shadow),
    8 (cloud medium probability), 9 (cloud high probability) or
    10 (thin cirrus).

    :param scns_xa: dataset-like object supporting ``copy(deep=True)`` and
                    ``obj[name].values`` returning a writable numpy array
                    (used in this file with xarray datasets via ``map_blocks``).
    :param bands: names of the variables to be masked.
    :param qa_pxl_msk: name of the classification variable used to build the mask.
    :param out_no_data_val: value written to the invalid pixels.
    :return: a masked copy of ``scns_xa``; the input is left untouched.
    """
    invalid_scl_vals = (0, 1, 2, 3, 8, 9, 10)
    # Deep copy so the caller's arrays are never modified in place; a shallow
    # copy would share the underlying buffers with the input.
    scns_lcl_xa = scns_xa.copy(deep=True)
    # Build the mask once, before any band is altered, so the result does not
    # depend on where the classification band sits in ``bands``.
    invalid_pxls = numpy.isin(scns_lcl_xa[qa_pxl_msk].values, invalid_scl_vals)
    for band in bands:
        scns_lcl_xa[band].values[invalid_pxls] = out_no_data_val
    return scns_lcl_xa

def get_img_metadata(img_file):
    """Read the bounding box and pixel resolution of an image.

    :param img_file: path of an image readable by rasterio.
    :return: tuple of (bbox, x_res, y_res) where bbox is
             [left, bottom, right, top] and y_res is returned as a
             negative value.
    """
    # The context manager closes the dataset even if reading a property fails.
    with rasterio.open(img_file) as img_data_obj:
        img_bounds = img_data_obj.bounds
        img_x_res, img_y_res = img_data_obj.res
    img_bbox = [img_bounds.left, img_bounds.bottom, img_bounds.right, img_bounds.top]
    # Force the y resolution to be negative; it is used as the y pixel size
    # when the output image transform is built.
    if img_y_res > 0:
        img_y_res = img_y_res * (-1)
    return img_bbox, img_x_res, img_y_res

def get_img_band_array(img_file, band=1):
    """Read a single band of an image into a numpy array.

    :param img_file: path of an image readable by rasterio.
    :param band: band index passed to the dataset ``read`` call (default 1).
    :return: the array returned by ``read`` for the requested band.
    """
    # The context manager closes the dataset even if the read fails.
    with rasterio.open(img_file) as img_data_obj:
        img_arr = img_data_obj.read(band)
    return img_arr

def test_asset_urls(signed_items):
    chkd_items = list()
    for scn_item in signed_items:
        assets_present = True
        for asset_name in scn_item.assets:
            try:
                if (
                    urllib.request.urlopen(scn_item.assets[asset_name].href).getcode()
                    != 200
                ):
                    assets_present = False
                    break
            except urllib.error.HTTPError:
                assets_present = False
                break
            time.sleep(0.1)
        if assets_present:
            chkd_items.append(scn_item)
    print(f"Before: {len(signed_items)}")
    print(f"After: {len(chkd_items)}")
    return chkd_items

In [None]:
# Start a Dask Gateway cluster and let it scale adaptively.
cluster = GatewayCluster()  # Creates the Dask Scheduler. Might take a minute.
# Adaptive scaling with the minimum/maximum limits given here.
cluster.adapt(minimum=4, maximum=24)
# Print the dashboard link so progress can be monitored.
print(cluster.dashboard_link)

# Connect a client to the cluster; it is also closed in the final cell.
client = dask.distributed.Client(cluster, timeout=10)
# Configure rasterio for cloud data access, passing the client so the
# configuration is applied through it.
odc.stac.configure_rio(cloud_defaults=True, client=client)

In [None]:
# Open the Planetary Computer STAC API used for the scene searches.
catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")

In [None]:
# Date range (start/end) passed to the STAC searches.
time_range = "2018-01-01/2018-12-31"
# NOTE(review): date_str is not referenced in the cells shown here - confirm it is still needed.
date_str = "201812"
# Sentinel-2 bands to be read; SCL is the classification band used for masking.
bands = ["B03", "B04", "B08", "B11", "SCL"]

In [None]:
# Build the structuring elements used for the erosion of the final mask:
# morph_op3 is the 3x3 cross (4-connectivity) element and morph_op5 is the
# 5x5 diamond (roughly circular) element obtained by growing the cross by
# one further step. morph_op5 is stored as a float array of 0s and 1s.
morph_op3 = ndimage.generate_binary_structure(2, 1)
morph_op5 = ndimage.iterate_structure(morph_op3, 2).astype(numpy.float64)

In [None]:
# Read the tile index and pull out the tile names as a plain Python list.
tiles_gdf = geopandas.read_file("../00_base_data/alert_region_tiles.geojson")
tiles = tiles_gdf["tile"].values.tolist()

# Optional filters (uncomment to process a subset or to skip a tile):
#tiles = tiles[1000:1100]
#tiles.remove("N27W077")

# Total number of tiles, used for the progress message.
n_tiles = len(tiles)

In [None]:
# When True, the signed asset URLs of the Sentinel-1 / Sentinel-2 items are
# checked with test_asset_urls before the data are loaded.
check_s1_data = False
check_s2_data = False

In [None]:
# Directory holding the input GMW extent tiles (gmw_{tile}_2018_mng_ext.tif).
in_img_dir = "../gmw_2018_ext_tiles"

In [None]:
# Directory the revised extent images are written to; created if missing.
out_img_dir = "../gmw_2018_revised_ext_opt_s1"
os.makedirs(out_img_dir, exist_ok=True)

In [None]:
# For each tile: load the 2018 GMW extent mask, count how often each pixel
# passes the Sentinel-2 index thresholds below, zero the mask where a count
# exceeds 4 (and where the minimum Sentinel-1 VH backscatter is below -19 dB),
# then write the revised mask and an eroded version to a two band image.
n_tile = 0
# Iterate through the tiles.
for tile in tiles:
    print(f"{tile}: ({n_tile+1} of {n_tiles})")
    out_img_file = os.path.join(out_img_dir, f'gmw_{tile}_2018_alerts_ext.tif')
    # Do not process tile if output file already exists.
    if not os.path.exists(out_img_file):
        # Define the GMW extent image file
        gmw_tile_img = os.path.join(in_img_dir, f"gmw_{tile}_2018_mng_ext.tif")
        # Get the bbox and image resolution of the input image.
        bbox, img_x_res, img_y_res = get_img_metadata(gmw_tile_img)

        # Search for Sentinel-2 scenes with less than 50 % cloud cover.
        search = catalog.search(collections=["sentinel-2-l2a"], bbox=bbox, datetime=time_range, query={"eo:cloud_cover": {"lt": 50}})
        items = search.get_all_items()
        n_items = len(items)
        print(f"\tN Sen2 Scenes: {n_items}")

        ####################################################################
        # Find the Sentinel-1 scenes (VV + VH) for the period of interest.
        s1_search = catalog.search(collections=["sentinel-1-rtc"], bbox=bbox, datetime=time_range, query={"sar:polarizations": {"eq": ["VV", "VH"]}})
        s1_items = s1_search.get_all_items()
        n_s1_items = len(s1_items)
        print(f"\tN Sen1 Items = {n_s1_items}")
        s1_avail = True
        if n_s1_items < 1:
            s1_avail = False
        ####################################################################

        # Only continue analysis if there are Sentinel-2 scenes available.
        if n_items > 0:
            # Read the GMW extent into a numpy array
            gmw_ext_msk = get_img_band_array(gmw_tile_img)

            # Sign all the items
            signed_items = [planetary_computer.sign(item) for item in items]
            if check_s2_data:
                signed_items = test_asset_urls(signed_items)

            # Read the data into dask xarray structure
            sen2_scn_xa = odc.stac.stac_load(
                signed_items,
                bands=bands,
                groupby="solar_day",
                #dtype=numpy.uint16,
                chunks={"time":24, "latitude": 1024, "longitude": 1024},
                bbox=bbox,
                crs="EPSG:4326",
                resolution=img_x_res
            )

            if s1_avail:
                # Read the Sentinel-1 scenes into a dask xarray structure.
                signed_s1_items = [planetary_computer.sign(item) for item in s1_items]
                if check_s1_data:
                    signed_s1_items = test_asset_urls(signed_s1_items)

                sen1_scns_xa = odc.stac.stac_load(
                    signed_s1_items,
                    bands=["vh"],
                    groupby="solar_day",
                    chunks={"time":12, "latitude": 1024, "longitude": 1024},
                    bbox=bbox,
                    crs="EPSG:4326",
                    resolution=img_x_res
                )
                # Convert to dB, drop values at or below -30 dB and take the
                # per-pixel minimum through time (computed immediately).
                sen1_scns_dB_xa = 10 * numpy.log10(sen1_scns_xa)
                sen1_scns_dB_xa["vh"] = sen1_scns_dB_xa.vh.where(sen1_scns_dB_xa.vh>-30)
                sen1_scns_dB_min_xa = sen1_scns_dB_xa.min(dim='time', skipna=True).compute()

            # Store the dataset in the dask cluster memory.
            # Comment out for larger datasets which don't fit into memory.
            sen2_scn_xa = sen2_scn_xa.persist()

            # Apply cloud masks etc.
            sen2_scn_xa = sen2_scn_xa.map_blocks(apply_sen2_vld_msk, kwargs={"bands":bands})

            # 'Clean' up the bands to remove any values which are not greater
            # than zero; this also removes the masked (0) pixels.
            sen2_scn_xa['B03'] = sen2_scn_xa.B03.where(sen2_scn_xa.B03>0)
            sen2_scn_xa['B04'] = sen2_scn_xa.B04.where(sen2_scn_xa.B04>0)
            sen2_scn_xa['B08'] = sen2_scn_xa.B08.where(sen2_scn_xa.B08>0)
            sen2_scn_xa['B11'] = sen2_scn_xa.B11.where(sen2_scn_xa.B11>0)

            # Calculate the NDWI (two versions)
            ndwi1_scn_xa = ((sen2_scn_xa.B03-sen2_scn_xa.B08)/(sen2_scn_xa.B03+sen2_scn_xa.B08))
            ndwi2_scn_xa = ((sen2_scn_xa.B03-sen2_scn_xa.B11)/(sen2_scn_xa.B03+sen2_scn_xa.B11))

            # Calculate the NDVI
            ndvi_scn_xa = ((sen2_scn_xa.B08-sen2_scn_xa.B04)/(sen2_scn_xa.B08+sen2_scn_xa.B04))

            # Calculate the MVI
            mvi_scn_xa = (sen2_scn_xa.B08-sen2_scn_xa.B03)/(sen2_scn_xa.B11+sen2_scn_xa.B03)

            # Apply the thresholds; each result flags, per scene, the pixels
            # which will be counted against the mangrove extent.
            water1_pxls_xa = ndwi1_scn_xa > -0.1
            water2_pxls_xa = ndwi2_scn_xa > 0.15
            veg_pxls_xa = ndvi_scn_xa < 0.1
            mng_pxls_xa = mvi_scn_xa < 0.1

            # Summarise the changes by summing through time.
            water1_pxls_count_xa = water1_pxls_xa.sum(dim="time", skipna=True)
            water2_pxls_count_xa = water2_pxls_xa.sum(dim="time", skipna=True)
            veg_pxls_count_xa = veg_pxls_xa.sum(dim="time", skipna=True)
            mng_pxls_count_xa = mng_pxls_xa.sum(dim="time", skipna=True)

            # Compute the four counts in a single pass and keep the results;
            # dask.compute returns new objects rather than updating its
            # arguments, so without the assignment each .values access below
            # would trigger its own separate computation.
            (
                water1_pxls_count_xa,
                water2_pxls_count_xa,
                veg_pxls_count_xa,
                mng_pxls_count_xa,
            ) = dask.compute(
                water1_pxls_count_xa,
                water2_pxls_count_xa,
                veg_pxls_count_xa,
                mng_pxls_count_xa,
            )

            # Update the GMW extent: remove pixels flagged in more than 4 scenes.
            gmw_ext_msk[water1_pxls_count_xa.values > 4] = 0
            gmw_ext_msk[water2_pxls_count_xa.values > 4] = 0
            gmw_ext_msk[veg_pxls_count_xa.values > 4] = 0
            gmw_ext_msk[mng_pxls_count_xa.values > 4] = 0
            if s1_avail:
                gmw_ext_msk[sen1_scns_dB_min_xa["vh"].values < -19] = 0

            # Apply erosion to resulting mask. binary_erosion returns a boolean
            # array, so cast it explicitly to the uint8 type of the output image.
            gmw_ext_msk_erode = ndimage.binary_erosion(gmw_ext_msk, structure=morph_op5).astype(numpy.uint8)

            # Get the image shape (i.e., number of pixels)
            img_shp = gmw_ext_msk.shape

            # Define the output image spatial transformation
            # (origin at the top-left corner of the bbox).
            out_img_transform = rasterio.transform.Affine(img_x_res, 0.0, bbox[0], 0.0, img_y_res, bbox[3])

            with rasterio.open(out_img_file,
                                'w',
                                driver='COG',
                                height=img_shp[0],
                                width=img_shp[1],
                                count=2,
                                dtype=numpy.uint8,
                                crs='epsg:4326',
                                transform=out_img_transform,
                            ) as out_img_dataset:

                # Write output array to the image file
                out_img_dataset.write(gmw_ext_msk, 1)
                out_img_dataset.set_band_description(1, "MngExt2018")

                out_img_dataset.write(gmw_ext_msk_erode, 2)
                out_img_dataset.set_band_description(2, "MngExt2018Erode")


            # Delete the data arrays as they are not used any more.
            del sen2_scn_xa
            del ndwi1_scn_xa
            del ndwi2_scn_xa
            del ndvi_scn_xa
            del mvi_scn_xa
            del water1_pxls_count_xa
            del water2_pxls_count_xa
            del veg_pxls_count_xa
            del mng_pxls_count_xa
            del gmw_ext_msk
            del gmw_ext_msk_erode
            if s1_avail:
                del sen1_scns_xa
                del sen1_scns_dB_xa
                del sen1_scns_dB_min_xa
            # Restart the dask workers to ensure all the memory etc. is cleared.
            #client.restart(wait_for_workers=False)
    # Increment the tile number for user feedback.
    n_tile += 1
    

In [None]:
# Close the dask cluster
# Disconnect the client first, then close the cluster.
client.close()
cluster.close()