In [None]:
import numpy as np
import xarray as xr

import rasterio.mask
import rasterio.features
import stackstac
import pystac_client
import json
import pyproj
import dask
import rasterio
import os
import glob

from dask_gateway import GatewayCluster
from shapely.geometry import shape, mapping
from shapely.ops import transform
from tqdm import tqdm
from rasterio.crs import CRS
from rasterio.profiles import DefaultGTiffProfile
from rasterio.merge import merge
from random import choices
from dask.distributed import wait

In [None]:
# Datetime ranges ("start/end") passed to the STAC search, keyed by the label
# of the period they represent.
periods = dict(
    before='2022-05-01/2022-05-31',
    after='2022-08-01/2022-09-30',
)

In [None]:
# Bounds for adaptive scaling of the cluster's worker count.
MIN_WORKERS = 4
MAX_WORKERS = 100

# Creates the Dask Scheduler (might take a minute), then connects a client.
cluster = GatewayCluster()
client = cluster.get_client()

# Let the cluster scale itself between the two bounds.
cluster.adapt(minimum=MIN_WORKERS, maximum=MAX_WORKERS)

# Show where the Dask dashboard for this cluster can be reached.
print(cluster.dashboard_link)

In [None]:
@dask.delayed
def get_ndwi(task_id, feature, epsg_code, dt_range):
    """Build a time-median NDWI raster for one tile.

    Args:
        task_id: Identifier that is handed back unchanged with the result.
        feature: Geometry mapping accepted by ``shapely.geometry.shape``, with
            coordinates expressed in ``EPSG:{epsg_code}``.
        epsg_code: EPSG code of the tile geometry; the imagery is stacked in
            this same CRS.
        dt_range: Datetime range forwarded to the STAC search.

    Returns:
        A ``(task_id, ndwi)`` tuple, where ``ndwi`` is the computed result of
        ``(band[1] - band[0]) / (band[1] + band[0])`` on the median composite.
    """
    # Reproject the tile to EPSG:4326 to obtain the bounding box for the search.
    to_wgs84 = pyproj.Transformer.from_crs(
        pyproj.CRS(f"EPSG:{epsg_code}"),
        pyproj.CRS("EPSG:4326"),
        always_xy=True,
    ).transform
    tile_geom = shape(feature)
    search_bbox = rasterio.features.bounds(mapping(transform(to_wgs84, tile_geom)))

    # Sentinel-2 L2A scenes over the tile with less than 50 % cloud cover.
    catalog = pystac_client.Client.open("https://earth-search.aws.element84.com/v0")
    scenes = catalog.search(
        bbox=search_bbox,
        datetime=dt_range,
        collections=["sentinel-s2-l2a-cogs"],
        query={"eo:cloud_cover": {"lt": 50}},
    ).item_collection()

    stacked = stackstac.stack(
        scenes,
        assets=["B08", "B03"],
        chunksize=4096,
        resolution=10,
        epsg=epsg_code,
    )
    # sentinel-2 uses 0 as nodata, so anything not strictly positive becomes NaN
    stacked = stacked.where(lambda x: x > 0, other=np.nan)
    # label the band axis with the assets' common names
    stacked = stacked.assign_coords(band=lambda x: x.common_name.rename("band"))

    # Crop to the tile's bounds in its own CRS. y is sliced from max to min
    # (assumes the stacked y axis is descending -- TODO confirm).
    west, south, east, north = tile_geom.bounds
    stacked = stacked.loc[..., north:south, west:east].persist()

    composite = stacked.median('time')

    # Positional band access: index 0 / 1 are presumed to follow the order of
    # `assets` above (B08, then B03) -- verify against stackstac's ordering.
    first_band, second_band = composite[0], composite[1]
    ndwi = (second_band - first_band) / (second_band + first_band)

    return task_id, ndwi.compute()

In [None]:
# Load the tile grid. The EPSG code is taken from the last ':'-separated field
# of the collection's CRS name.
with open('sindh_grid.geojson', 'r') as f:
    fc = json.load(f)
    epsg_code = int(fc['crs']['properties']['name'].split(':')[-1])


# One task per (period, grid feature), keyed by '<period>_<feature index>'.
tasks = {}

for period, dt_range in periods.items():
    # exist_ok=True replaces the separate exists() check: no window between
    # the check and the creation, and re-running the cell is harmless.
    os.makedirs(f'ndwi_tiles/{period}', exist_ok=True)

    for i, feature in enumerate(fc['features']):
        task_id = f'{period}_{i:03d}'
        tasks[task_id] = {
            'period': period,
            'dt_range': dt_range,
            'tile_id': f'{i:03d}',
            'feature': feature,
            'epsg_code': epsg_code
        }

In [None]:
def sample(x, k):
    """Pick at most ``k`` entries of ``x`` at random, without repeats.

    Args:
        x: Sequence to draw from.
        k: Maximum number of entries to return.

    Returns:
        ``x`` itself when it holds ``k`` or fewer entries; otherwise a new
        list of ``k`` entries taken from distinct positions of ``x``.
    """
    # Local import: at module level the name `sample` is this function, so
    # `from random import sample` cannot be used there.
    import random

    if len(x) <= k:
        return x

    # random.choices draws WITH replacement and could hand back the same entry
    # several times; random.sample draws without replacement. max() keeps a
    # negative k returning an empty list instead of raising.
    return random.sample(x, max(k, 0))

def save_ndwi(task_id, ndwi):
    """Write one tile's NDWI to ``ndwi_tiles/<period>/<tile_id>.tif``.

    The values are rescaled with ``ndwi * 127 + 128`` (so an input range of
    [-1, 1] maps onto [1, 255]) and stored as a single-band uint8 GeoTIFF
    georeferenced to the bounds of the task's geometry.

    Args:
        task_id: Key into the module-level ``tasks`` mapping.
        ndwi: 2-D array-like indexed as (rows, columns).
    """
    task = tasks[task_id]
    tile_geom = shape(task['feature']['geometry'])

    # NOTE(review): NaN pixels rely on the float -> uint8 conversion at write
    # time; confirm they end up as the profile's nodata value.
    scaled = ndwi * 127 + 128
    n_rows, n_cols = scaled.shape[0], scaled.shape[1]

    # Affine transform stretching the tile bounds over the pixel grid.
    geo_transform = rasterio.transform.from_bounds(
        *tile_geom.bounds,
        width=n_cols,
        height=n_rows,
    )

    tile_path = f'ndwi_tiles/{task["period"]}/{task["tile_id"]}.tif'
    profile = DefaultGTiffProfile(
        count=1,
        height=n_rows,
        width=n_cols,
        dtype=rasterio.uint8,
        transform=geo_transform,
        crs=CRS.from_string(f"EPSG:{task['epsg_code']}"),
    )

    with rasterio.open(tile_path, 'w', **profile) as dst_dataset:
        dst_dataset.write(scaled, 1)

def get_missing_tasks():
    """List the ids of tasks whose output tile is not on disk yet.

    Returns:
        Task ids from the module-level ``tasks`` mapping, in its iteration
        order, for which ``ndwi_tiles/<period>/<tile_id>.tif`` does not exist.
    """
    return [
        task_id
        for task_id, task in tasks.items()
        if not os.path.exists(f'ndwi_tiles/{task["period"]}/{task["tile_id"]}.tif')
    ]

# Submit batches of up to 50 missing tiles until every task has a file on disk.
# NOTE(review): a task that fails on every attempt keeps this loop running
# forever (failures are only printed) -- consider capping the retries.
while True:
    # Scan the output folders once per round.
    missing_task_ids = get_missing_tasks()
    if not missing_task_ids:
        break
    print(f'{len(missing_task_ids):03d} Tasks Remaining')

    batch_ids = sample(missing_task_ids, 50)
    working_tasks = [
        get_ndwi(
            task_id,
            tasks[task_id]['feature']['geometry'],
            tasks[task_id]['epsg_code'],
            tasks[task_id]['dt_range'],
        )
        for task_id in batch_ids
    ]

    # Keep a reference to the persisted collections: they are what holds the
    # results on the cluster. Discarding them and calling .compute() on the
    # un-persisted `working_tasks` allows the results to be released and every
    # task to be recomputed one at a time.
    persisted_tasks = dask.persist(*working_tasks)
    wait(persisted_tasks)

    for batch_id, task_result in zip(batch_ids, persisted_tasks):
        try:
            task_id, ndwi = task_result.compute()
            # Kept from the original; get_ndwi already returns a computed array.
            ndwi = ndwi.compute()
            save_ndwi(task_id, ndwi)
        except Exception as e:
            # Best effort: report which tile failed; it stays "missing" and
            # is picked up again in a later round.
            print(f'{batch_id} failed: {e!r}')

print('Done')

In [None]:
# Geometry used to clip the mosaics: the first feature of the boundary file.
with open('sindh_utm.geojson', 'r') as f:
    fc = json.load(f)
    sindh = shape(fc['features'][0]['geometry'])

# Make sure the destination folder exists before anything is written to it.
os.makedirs('output', exist_ok=True)

folder_names = list(glob.glob('ndwi_tiles/*'))
for folder_name in tqdm(folder_names):
    # basename instead of split('/') so the path separator does not matter.
    period_name = os.path.basename(folder_name)

    # Sorted so the merge order (and therefore the mosaic) is reproducible.
    tile_paths = sorted(glob.glob(f'{folder_name}/*.tif'))
    if not tile_paths:
        # Nothing to merge; merging an empty list would fail.
        print(f'No tiles found in {folder_name}; skipping')
        continue

    # Open every tile for the merge and always close them again afterwards,
    # even if opening one of them or the merge itself fails.
    merge_files = []
    try:
        for tile_path in tile_paths:
            merge_files.append(rasterio.open(tile_path, 'r'))

        mosaic, out_trans = merge(merge_files)
        out_meta = merge_files[0].meta.copy()
    finally:
        for tile in merge_files:
            tile.close()

    # mosaic is (bands, rows, cols); the profile must match its size/transform.
    out_meta.update({"driver": "GTiff",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": out_trans
        }
    )

    with rasterio.open(f'output/{period_name}_ndwi.tif', "w", **out_meta) as dest:
        dest.write(mosaic)

    # Re-open the mosaic and crop it to the boundary geometry.
    with rasterio.open(f'output/{period_name}_ndwi.tif', 'r') as src:
        out_image, out_transform = rasterio.mask.mask(src, [sindh], crop=True)
        out_meta = src.meta

        out_meta.update({"driver": "GTiff",
                 "height": out_image.shape[1],
                 "width": out_image.shape[2],
                 "transform": out_transform})

        with rasterio.open(f'output/{period_name}_ndwi_clipped.tif', "w", **out_meta) as dest:
            dest.write(out_image)