In [1]:
import json
import os
import shutil
import time
import urllib.error
import urllib.request

import dask
import dask.distributed
import geopandas
import numpy
import odc.stac
import planetary_computer
import rasterio
import rasterio.shutil
import xarray
from azure.storage.blob import BlobClient
from dask_gateway import GatewayCluster
from pystac_client import Client
from scipy import ndimage

In [2]:
def apply_sen2_vld_msk(scns_xa, bands, qa_pxl_msk="SCL", out_no_data_val=0):
    """Overwrite pixels flagged as unusable by the QA/classification band.

    For every band in ``bands``, pixels whose ``qa_pxl_msk`` value equals one
    of the invalid class codes listed below are set to ``out_no_data_val``.

    :param scns_xa: dataset-like object supporting ``.copy()`` and
                    ``obj[name].values`` (writable arrays).
    :param bands: names of the variables to mask.
    :param qa_pxl_msk: name of the variable holding the class codes.
    :param out_no_data_val: value written into the masked pixels.
    :return: the object returned by ``scns_xa.copy()`` after masking.

    Note: the mask is re-read from the QA variable for every class code, so
    if ``qa_pxl_msk`` itself is listed in ``bands`` it is masked as well and
    bands processed after it see the already-modified codes.
    """
    # Class codes treated as invalid, applied in this order.
    invalid_class_codes = (
        0,  # No Data
        1,  # Saturation
        2,  # Cast Shadow
        3,  # Cloud Shadows
        8,  # Cloud Medium Probability
        9,  # Cloud High Probability
        10,  # Thin Cirrus
    )

    masked_xa = scns_xa.copy()
    for band_name in bands:
        for class_code in invalid_class_codes:
            masked_xa[band_name].values[
                masked_xa[qa_pxl_msk].values == class_code
            ] = out_no_data_val
    return masked_xa


def _offset_and_clean_bands(scns_xa, img_bands, offset):
    """Add ``offset`` to each band in ``img_bands`` and zero out-of-range values.

    The bands are re-assigned on ``scns_xa`` itself, which is also returned.
    Values that fall outside 0..10000 after the offset are replaced with 0.
    """
    for band in img_bands:
        band_xa = scns_xa[band]
        if band_xa.dtype.kind == "u":
            # Unsigned data cannot hold the (negative) intermediate values and
            # adding a negative Python int to an unsigned array is an error
            # under NumPy 2 promotion rules, so move to a signed type first.
            band_xa = band_xa.astype(numpy.promote_types(band_xa.dtype, numpy.int32))
        band_xa = band_xa + offset
        # ``where`` keeps values where the condition holds and returns a NEW
        # object, so the result has to be assigned back.
        scns_xa[band] = band_xa.where((band_xa >= 0) & (band_xa <= 10000), 0)
    return scns_xa


def apply_sen2_offset(sen2_scns_xa, offset=-1000):
    """Apply the reflectance offset to scenes acquired on/after 2022-01-25.

    Scenes before the cut-off date are returned untouched. For scenes on or
    after it, ``offset`` is added to every Sentinel-2 image band present in
    the dataset and values outside 0..10000 are set to 0. Non-image
    variables (e.g. a classification band) are never modified.

    :param sen2_scns_xa: dataset with a ``time`` coordinate and band variables.
    :param offset: value added to the image bands (default -1000).
    :return: dataset with the offset applied where required. When every scene
             is on/after the cut-off the bands are re-assigned on the input
             dataset itself; when the series spans the cut-off a new dataset
             is built with ``xarray.concat`` (earlier scenes first).
    """
    # Define the date splitting whether the offset should be applied.
    off_date = numpy.datetime64("2022-01-25")

    # List of all bands for which offset should be applied if present.
    s2_img_bands = [
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B8A",
        "B09",
        "B10",
        "B11",
        "B12",
    ]
    img_bands = [band for band in sen2_scns_xa.data_vars if band in s2_img_bands]

    # True for every scene that needs the offset. A boolean mask puts each
    # scene on exactly one side of the cut-off (label slices are inclusive at
    # both ends and would duplicate a scene stamped exactly at the cut-off).
    post_off_msk = numpy.asarray(sen2_scns_xa.time.values) >= off_date

    if (not img_bands) or (not post_off_msk.any()):
        # Nothing to offset: no image bands, or all scenes pre-date the cut-off.
        return sen2_scns_xa

    if post_off_msk.all():
        # All scenes are on/after the offset date: apply to everything.
        return _offset_and_clean_bands(sen2_scns_xa, img_bands, offset)

    # Series crosses the offset date: only the later part needs the offset.
    sen2_scns_xa_pre_off = sen2_scns_xa.isel(time=~post_off_msk)
    sen2_scns_xa_post_off = _offset_and_clean_bands(
        sen2_scns_xa.isel(time=post_off_msk), img_bands, offset
    )
    return xarray.concat([sen2_scns_xa_pre_off, sen2_scns_xa_post_off], dim="time")


def get_img_metadata(img_file):
    """Read the bounding box and pixel resolution of a raster.

    :param img_file: path or URL passed to ``rasterio.open``.
    :return: tuple ``(bbox, x_res, y_res)`` where ``bbox`` is
             ``[left, bottom, right, top]`` and ``y_res`` is always returned
             as a non-positive value (north-up convention).
    """
    # The context manager closes the dataset even if reading metadata fails.
    with rasterio.open(img_file) as img_data_obj:
        img_bounds = img_data_obj.bounds
        img_x_res, img_y_res = img_data_obj.res

    img_bbox = [img_bounds.left, img_bounds.bottom, img_bounds.right, img_bounds.top]
    # Force a negative y resolution so it can be used directly in a transform.
    if img_y_res > 0:
        img_y_res = -img_y_res
    return img_bbox, img_x_res, img_y_res


def get_img_band_array(img_file, band=1):
    """Read a single band of a raster into an array.

    :param img_file: path or URL passed to ``rasterio.open``.
    :param band: band index handed to ``read`` (default 1).
    :return: the array returned by ``read(band)``.
    """
    # The context manager closes the dataset as soon as the band is read.
    with rasterio.open(img_file) as img_data_obj:
        return img_data_obj.read(band)


def test_asset_urls(signed_items):
    chkd_items = list()
    for scn_item in signed_items:
        assets_present = True
        for asset_name in scn_item.assets:
            try:
                if (
                    urllib.request.urlopen(scn_item.assets[asset_name].href).getcode()
                    != 200
                ):
                    assets_present = False
                    break
            except urllib.error.HTTPError:
                assets_present = False
                break
            time.sleep(0.1)
        if assets_present:
            chkd_items.append(scn_item)
    print(f"Before: {len(signed_items)}")
    print(f"After: {len(chkd_items)}")
    return chkd_items


In [3]:
# Start a Dask Gateway cluster and let it scale adaptively (minimum=4, maximum=24).
cluster = GatewayCluster()  # Creates the Dask Scheduler. Might take a minute.
cluster.adapt(minimum=4, maximum=24)
# Print the dashboard link so the computation can be monitored.
print(cluster.dashboard_link)

# Connect a distributed client to the cluster and hand it to odc.stac's
# rasterio configuration (with its cloud defaults enabled).
client = dask.distributed.Client(cluster, timeout=10)
odc.stac.configure_rio(cloud_defaults=True, client=client)

https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.c0ec3f9522d34d32a6aecd5758236ce6/status


In [4]:
# Open the Planetary Computer STAC API; used for every scene search below.
catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")

In [5]:
# Date range (start/end) passed as the `datetime` filter of the scene searches.
time_range = "2018-01-01/2018-12-31"
# NOTE(review): date_str is not referenced anywhere else in the visible notebook - confirm it is still needed.
date_str = "201812"
# Sentinel-2 bands to be read (the four image bands used by the indices plus the SCL classification band).
bands = ["B03", "B04", "B08", "B11", "SCL"]

In [6]:
# Build the structuring elements used for the mask erosion:
#  - morph_op3: 3x3 cross (2D, connectivity 1).
#  - morph_op5: 5x5 diamond, obtained by dilating a single centre pixel twice
#    with the 3x3 cross; stored with the seed array's (float) dtype.
morph_op3 = ndimage.generate_binary_structure(2, 1)
seed_op5 = numpy.zeros((5, 5))
seed_op5[2, 2] = 1
morph_op5 = ndimage.binary_dilation(
    seed_op5, structure=morph_op3, iterations=2
).astype(seed_op5.dtype)

In [7]:
# Load the tile index and pull the tile names out as a plain Python list.
tiles_gdf = geopandas.read_file("../00_base_data/alert_region_tiles.geojson")
tiles = tiles_gdf["tile"].values.tolist()

# Optional, for partial runs: process a subset or skip an individual tile.
# tiles = tiles[1000:1100]
# tiles.remove("N27W077")

# Total number of tiles, used for the progress messages.
n_tiles = len(tiles)

In [8]:
# When True, every asset URL of the signed Sentinel-1 / Sentinel-2 items is
# requested via test_asset_urls() before loading, and items with an
# unreachable asset are dropped. Slow, so disabled by default.
check_s1_data = False
check_s2_data = False

In [9]:
# Directory names joined onto the storage URL from the SAS info file:
# in_img_dir holds the input extent images, out_img_dir receives the outputs.
in_img_dir = "gmw_2018_baseline"
out_img_dir = "gmw_2018_revised"

In [10]:
# Local scratch directory where each output image is written before upload.
tmp_path = "gmw_2018_revised_tmp"
# exist_ok avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(tmp_path, exist_ok=True)

In [11]:
# Load the storage access details from a JSON file kept outside the notebook.
# The processing loop reads the "url" and "sas_token" keys from it.
sas_info_file = "/home/jovyan/azure_info.json"
with open(sas_info_file) as f:
    sas_token_info = json.load(f)

In [None]:
# Iterate through the tiles. For each tile without an existing output:
#   1. read the baseline extent mask and its grid (bbox / resolution),
#   2. load the Sentinel-2 (and, if available, Sentinel-1) scenes for the period,
#   3. count per pixel how often the index thresholds are met through time,
#   4. zero the mask where the counts (or the Sentinel-1 minimum) exceed the limits,
#   5. write the mask and its eroded version to a COG and upload it.
for n_tile, tile in enumerate(tiles):
    print(f"{tile}: ({n_tile+1} of {n_tiles})")
    out_img_file = f"gmw_{tile}_2018_alerts_ext.tif"
    out_img_url = os.path.join(sas_token_info["url"], out_img_dir, out_img_file)
    out_img_url_signed = f"{out_img_url}?{sas_token_info['sas_token']}"
    # Do not process tile if output file already exists.
    if BlobClient.from_blob_url(out_img_url_signed).exists():
        continue

    # Define the GMW extent image file
    gmw_tile_img = f"gmw_{tile}_2018_mng_ext.tif"
    gmw_tile_url = os.path.join(sas_token_info["url"], in_img_dir, gmw_tile_img)
    gmw_tile_url_signed = f"{gmw_tile_url}?{sas_token_info['sas_token']}"

    # Get the bbox and image resolution of the input image.
    bbox, img_x_res, img_y_res = get_img_metadata(gmw_tile_url_signed)

    # Search for Sentinel-2 scenes (cloud cover below 50).
    search = catalog.search(collections=["sentinel-2-l2a"], bbox=bbox, datetime=time_range, query={"eo:cloud_cover": {"lt": 50}})
    items = search.get_all_items()
    n_items = len(items)
    print(f"\tN Sen2 Items: {n_items}")

    ####################################################################
    # Find the Sentinel-1 scenes for the period of interest.
    s1_search = catalog.search(collections=["sentinel-1-rtc"], bbox=bbox, datetime=time_range, query={"sar:polarizations": {"eq": ["VV", "VH"]}})
    s1_items = s1_search.get_all_items()
    n_s1_items = len(s1_items)
    print(f"\tN Sen1 Items = {n_s1_items}")
    s1_avail = n_s1_items >= 1
    ####################################################################

    # Only continue analysis if there are Sentinel-2 scenes available.
    if n_items < 1:
        continue

    # Read the GMW extent into a numpy array
    gmw_ext_msk = get_img_band_array(gmw_tile_url_signed)

    # Sign all the items
    signed_items = [planetary_computer.sign(item) for item in items]
    if check_s2_data:
        signed_items = test_asset_urls(signed_items)

    # Read the data into dask xarray structure
    sen2_scn_xa = odc.stac.stac_load(
        signed_items,
        bands=bands,
        groupby="solar_day",
        chunks={"time":24, "latitude": 1024, "longitude": 1024},
        bbox=bbox,
        crs="EPSG:4326",
        resolution=img_x_res
    )

    if s1_avail:
        # Read the Sentinel-1 scenes into dask array structure.
        signed_s1_items = [planetary_computer.sign(item) for item in s1_items]
        if check_s1_data:
            signed_s1_items = test_asset_urls(signed_s1_items)

        sen1_scns_xa = odc.stac.stac_load(
            signed_s1_items,
            bands=["vh"],
            groupby="solar_day",
            chunks={"time":12, "latitude": 1024, "longitude": 1024},
            bbox=bbox,
            crs="EPSG:4326",
            resolution=img_x_res
        )
        # Convert to dB, drop values at or below -30 dB and take the
        # per-pixel minimum through time (computed straight away).
        sen1_scns_dB_xa = 10 * numpy.log10(sen1_scns_xa)
        sen1_scns_dB_xa["vh"] = sen1_scns_dB_xa.vh.where(sen1_scns_dB_xa.vh>-30)
        sen1_scns_dB_min_xa = sen1_scns_dB_xa.min(dim='time', skipna=True).compute()

    # Store the dataset in dask cluster memory
    # Comment out for larger datasets which don't fit into memory.
    sen2_scn_xa = sen2_scn_xa.persist()

    # Apply Offset
    sen2_scn_xa = apply_sen2_offset(sen2_scn_xa)

    # Apply cloud masks etc.
    sen2_scn_xa = sen2_scn_xa.map_blocks(apply_sen2_vld_msk, kwargs={"bands":bands})

    # 'Clean' up the bands to remove any values less than zero - shouldn't be needed but just incase...
    # (values that are not > 0, including the no-data value 0, become NaN).
    sen2_scn_xa['B03'] = sen2_scn_xa.B03.where(sen2_scn_xa.B03>0)
    sen2_scn_xa['B04'] = sen2_scn_xa.B04.where(sen2_scn_xa.B04>0)
    sen2_scn_xa['B08'] = sen2_scn_xa.B08.where(sen2_scn_xa.B08>0)
    sen2_scn_xa['B11'] = sen2_scn_xa.B11.where(sen2_scn_xa.B11>0)

    # Calculate the NDWI (two versions)
    ndwi1_scn_xa = ((sen2_scn_xa.B03-sen2_scn_xa.B08)/(sen2_scn_xa.B03+sen2_scn_xa.B08))
    ndwi2_scn_xa = ((sen2_scn_xa.B03-sen2_scn_xa.B11)/(sen2_scn_xa.B03+sen2_scn_xa.B11))

    # Calculate the NDVI
    ndvi_scn_xa = ((sen2_scn_xa.B08-sen2_scn_xa.B04)/(sen2_scn_xa.B08+sen2_scn_xa.B04))

    # Calculate the MVI
    mvi_scn_xa = (sen2_scn_xa.B08-sen2_scn_xa.B03)/(sen2_scn_xa.B11+sen2_scn_xa.B03)

    # Apply thresholds: each result flags the observations that argue
    # against the pixel staying in the extent mask.
    water1_pxls_xa = ndwi1_scn_xa > -0.1
    water2_pxls_xa = ndwi2_scn_xa > 0.15
    veg_pxls_xa = ndvi_scn_xa < 0.1
    mng_pxls_xa = mvi_scn_xa < 0.1

    # Summarise the changes by summing through time.
    water1_pxls_count_xa = water1_pxls_xa.sum(dim="time", skipna=True)
    water2_pxls_count_xa = water2_pxls_xa.sum(dim="time", skipna=True)
    veg_pxls_count_xa = veg_pxls_xa.sum(dim="time", skipna=True)
    mng_pxls_count_xa = mng_pxls_xa.sum(dim="time", skipna=True)

    # Evaluate the four counts in a single pass and keep the computed
    # results; dask.compute() does not modify its arguments in place, so
    # without the assignment every `.values` below would recompute.
    (
        water1_pxls_count_xa,
        water2_pxls_count_xa,
        veg_pxls_count_xa,
        mng_pxls_count_xa,
    ) = dask.compute(
        water1_pxls_count_xa, water2_pxls_count_xa, veg_pxls_count_xa, mng_pxls_count_xa
    )

    # Update the GMW extent: remove pixels flagged in more than 4 scenes.
    gmw_ext_msk[water1_pxls_count_xa.values > 4] = 0
    gmw_ext_msk[water2_pxls_count_xa.values > 4] = 0
    gmw_ext_msk[veg_pxls_count_xa.values > 4] = 0
    gmw_ext_msk[mng_pxls_count_xa.values > 4] = 0
    if s1_avail:
        # Remove pixels whose minimum VH backscatter is below -19 dB.
        gmw_ext_msk[sen1_scns_dB_min_xa["vh"].values < -19] = 0

    # Apply erosion to resulting mask. binary_erosion returns a boolean
    # array, so cast it explicitly to the uint8 dtype of the output image.
    gmw_ext_msk_erode = ndimage.binary_erosion(gmw_ext_msk, structure=morph_op5).astype(numpy.uint8)

    # Get the image shape (i.e., number of pixels)
    img_shp = gmw_ext_msk.shape

    # Define the output image spatial transformation.
    out_img_transform = rasterio.transform.Affine(img_x_res, 0.0, bbox[0], 0.0, img_y_res, bbox[3])

    out_img_file_path = os.path.join(tmp_path, out_img_file)
    with rasterio.open(out_img_file_path,
                        'w',
                        driver='COG',
                        height=img_shp[0],
                        width=img_shp[1],
                        count=2,
                        dtype=numpy.uint8,
                        crs='epsg:4326',
                        transform=out_img_transform,
                    ) as out_img_dataset:

        # Write output array to the image file
        out_img_dataset.write(gmw_ext_msk, 1)
        out_img_dataset.set_band_description(1, "MngExt2018")

        out_img_dataset.write(gmw_ext_msk_erode, 2)
        out_img_dataset.set_band_description(2, "MngExt2018Erode")

    if os.path.exists(out_img_file_path):
        # Use a dedicated name for the blob client: `client` is the Dask
        # client created earlier and must stay bound to it.
        out_blob_client = BlobClient.from_blob_url(out_img_url_signed)
        with open(out_img_file_path, 'rb') as data:
            out_blob_client.upload_blob(data)
        out_blob_client = None
        rasterio.shutil.delete(out_img_file_path, driver="COG")

    # Delete the data arrays as they are not used any more.
    del sen2_scn_xa
    del ndwi1_scn_xa
    del ndwi2_scn_xa
    del ndvi_scn_xa
    del mvi_scn_xa
    del water1_pxls_count_xa
    del water2_pxls_count_xa
    del veg_pxls_count_xa
    del mng_pxls_count_xa
    del gmw_ext_msk
    del gmw_ext_msk_erode
    if s1_avail:
        del sen1_scns_xa
        del sen1_scns_dB_xa
        del sen1_scns_dB_min_xa
    # Restart the dask workers to ensure all the memory etc. is cleared.
    #client.restart(wait_for_workers=False)
    

S05W081: (1 of 484)
N01W081: (2 of 484)
N02W080: (3 of 484)
S03W081: (4 of 484)
N02W079: (5 of 484)
S03W080: (6 of 484)
N01W080: (7 of 484)
S02W080: (8 of 484)
S04W082: (9 of 484)
N00W081: (10 of 484)
S02W081: (11 of 484)
N22W091: (12 of 484)
N19W104: (13 of 484)
N19W089: (14 of 484)
N17W095: (15 of 484)
N19W094: (16 of 484)
N21W098: (17 of 484)
N21W097: (18 of 484)
N19W103: (19 of 484)
N22W106: (20 of 484)
N19W093: (21 of 484)
N27W110: (22 of 484)
N19W092: (23 of 484)
N21W106: (24 of 484)
N22W098: (25 of 484)
N19W095: (26 of 484)
N17W094: (27 of 484)
N19W088: (28 of 484)
N21W091: (29 of 484)
N19W105: (30 of 484)
N22W090: (31 of 484)
N17W096: (32 of 484)
N27W114: (33 of 484)
N24W098: (34 of 484)
N19W097: (35 of 484)
N22W087: (36 of 484)
N24W108: (37 of 484)
N22W088: (38 of 484)
N29W112: (39 of 484)
N27W113: (40 of 484)
N24W107: (41 of 484)
N17W101: (42 of 484)
N17W099: (43 of 484)
N17W098: (44 of 484)
N17W100: (45 of 484)
N27W112: (46 of 484)
N21W088: (47 of 484)
N29W113: (48 of 484)
N

In [None]:
# Close the dask client connection first, then shut down the gateway cluster.
client.close()
cluster.close()

In [None]:
#if os.path.exists(tmp_path):
#    shutil.rmtree(tmp_path)