### Testing Sentinel-1 to detect flooded forests
#### Note: works on the Microsoft Planetary Computer (MPC)

- [ ] Bring in a GeoJSON AOI
- [ ] Get tides for the area using FES2014 from the DEA coastal tools (need to ask Robbi how to go about this on MPC)
- [ ] Load Sentinel-1 for high- and low-tide examples, filter, and plot
- Could write a notebook that processes the AOI for 2020 and exports a GeoTIFF with one band per date, then look at the range of values for STF? - yes
- See how long a full coastal tile takes to load
- If it takes too long, use a GeoJSON AOI input instead
- Make the output file name match the GeoJSON name, with the years etc.

In [1]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import geopandas as gpd
from osgeo import gdal
import pystac
import pystac_client
import planetary_computer
import requests

import odc.stac
from datacube.utils.cog import write_cog

from scipy.ndimage import uniform_filter
from scipy.ndimage import variance

from IPython.display import Image

# Tide modelling tools
sys.path.insert(1, "/home/jovyan/code/dea-notebooks/Tools") # needed pip3 install OWSLib
from dea_tools.coastal import model_tides,tidal_tag, pixel_tides, tidal_stats


# GRD perhaps?
# VV seems a winner

In [2]:
# Connect to the Microsoft Planetary Computer STAC API. The
# planetary_computer.sign_inplace modifier is applied to returned items so
# their assets can be accessed.
catalog = pystac_client.Client.open(
    url="https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

In [3]:
# # Defining a file path
# vector_file = '../data/geojson/hawksnest.geojson'
# # Define time period of interest
# time = "2020-01-01/2020-12-31"


# gdf = gpd.read_file(vector_file)

# # Visualizing 
# gdf.explore()

In [15]:
# GA coastal summary-grid tiles; keep only the tiles whose 'type' is 'mainland'
vector_file = '../data/geojson/ga_summary_grid_c3_coastal.geojson'
attribute_col = 'geometry'  # not referenced by the cells shown here

gdf = gpd.read_file(vector_file)
mainland_grid = gdf[gdf['type'] == 'mainland']

# Ids of all mainland tiles. Series.tolist() avoids iterrows(), which builds a
# Series per row (slow) and can upcast the id dtype along the way.
id_list = mainland_grid['id'].tolist()
mainland_grid.explore()

In [17]:
# Use a single mainland coastal tile (id 255) as the area of interest
gdf = mainland_grid[mainland_grid['id'].eq(255)]
gdf.explore()

In [19]:
# Time period of interest. The cell that originally defined `time` is
# commented out above, so define it here to keep this cell runnable on a
# fresh kernel.
time = "2020-01-01/2020-12-31"

# STAC searches (and the bbox passed to odc.stac.load later) expect lon/lat,
# so take the AOI bounds in EPSG:4326. total_bounds is
# [minx, miny, maxx, maxy] over all geometries in the frame.
bbox = list(gdf.to_crs("EPSG:4326").total_bounds)

# search MPC collections
search = catalog.search(
    collections=["sentinel-1-rtc"], bbox=bbox, datetime=time)
items = search.item_collection()
print(f"Found {len(items)} items")
if len(items) == 0:
    raise ValueError(
        f"No sentinel-1-rtc items found for bbox={bbox}, datetime={time}"
    )
item = items[0]

Found 142 items


In [None]:
%%time
# load as odc stac dataset
# loading whole coastal tile at 10m res takes too long to load, need to do smaller AOI
ds_s1 = odc.stac.load(items,
                        bbox=bbox,
                        crs="EPSG:3577",
                        resolution=10,
                        groupby='solar_day')

In [6]:
# Scale to plot data in decibels. Non-positive backscatter (e.g. zero-valued
# pixels) is masked to NaN first: log10(0) is -inf and raises a
# RuntimeWarning, which also distorts plots and statistics.
ds_s1["vv_dB"] = 10 * np.log10(ds_s1.vv.where(ds_s1.vv > 0))
ds_s1["vh_dB"] = 10 * np.log10(ds_s1.vh.where(ds_s1.vh > 0))

  result_data = func(*input_data)


In [7]:
# Adapted from https://stackoverflow.com/questions/39785970/speckle-lee-filter-in-python
def lee_filter(da, size):
    """Apply a Lee speckle filter to a single 2-D image.

    Each output pixel is a weighted mix of the input pixel and its local
    ``size`` x ``size`` mean. The weight is
    local_variance / (local_variance + whole_image_variance).

    :param da: 2-D image: an xarray DataArray (its ``.values`` are used) or
               anything ``numpy.asarray`` accepts.
    :param size: side length of the square moving window, in pixels.
    :return: filtered image as a NumPy array with the same shape as ``da``.
    """
    img = np.asarray(getattr(da, "values", da))
    # uniform_filter keeps the input dtype, so integer images would be
    # truncated; work in floating point instead.
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)

    img_mean = uniform_filter(img, (size, size))
    img_sqr_mean = uniform_filter(img**2, (size, size))
    # E[x^2] - E[x]^2 can come out slightly negative through rounding; a
    # variance cannot be negative.
    img_variance = np.maximum(img_sqr_mean - img_mean**2, 0)

    overall_variance = variance(img)

    # For a constant image (e.g. an all-zero, fully filled slice) both
    # variances are 0 and 0/0 would turn the whole output into NaN. Use a
    # weight of 0 there, i.e. fall back to the local mean.
    denominator = img_variance + overall_variance
    img_weights = np.divide(
        img_variance,
        denominator,
        out=np.zeros_like(denominator),
        where=denominator > 0,
    )
    img_output = img_mean + img_weights * (img - img_mean)
    return img_output

In [8]:
# Set any null values to 0 before applying the filter: NaNs would otherwise
# propagate through uniform_filter's moving-window means.
ds_s1_filled = ds_s1.fillna(0)

# Create new entries corresponding to Lee-filtered (7x7 window) VV and VH
# data, filtering each time step independently. GroupBy.map replaces the
# deprecated GroupBy.apply.
ds_s1["filtered_vv"] = ds_s1_filled.vv.groupby("time").map(lee_filter, size=7)
ds_s1["filtered_vh"] = ds_s1_filled.vh.groupby("time").map(lee_filter, size=7)

CPU times: user 2.61 s, sys: 3.08 s, total: 5.69 s
Wall time: 6.34 s


In [9]:
# Scale to plot data in decibels
ds_s1["filtered_vv_dB"] = 10 * np.log10(ds_s1.filtered_vv)
ds_s1["filtered_vh_dB"] = 10 * np.log10(ds_s1.filtered_vh)

# # Plot all filtered VH observations for the year
# ds_s1.filtered_vh_dB.plot(cmap="Greys_r", robust=True, col="time", col_wrap=5)
# ds_s1.filtered_vv_dB.plot(cmap="Greys_r", robust=True, col="time", col_wrap=5)

# plt.show()

In [None]:
# output filtered_vv_dB time to geotiff and filtered_vh_dB time to geotiff

In [11]:
# Split filtered_vv_dB into one 2-D variable per time step, named after its
# timestamp, and merge them into one Dataset (one band per date when written).
variables_xarray_list = []

for time_step in ds_s1.filtered_vv_dB.time:
    # Selecting a scalar time drops the time dimension, leaving a scalar
    # 'time' coordinate. Drop that coordinate with drop_vars (Dataset.drop
    # is deprecated). No squeeze() here: it would also collapse a length-1
    # x/y dimension on a very small AOI.
    current_data = ds_s1.filtered_vv_dB.sel(time=time_step)
    current_data_ds = current_data.to_dataset(name=str(current_data.time.values)).drop_vars('time')
    variables_xarray_list.append(current_data_ds)

stacked_xarray = xr.merge(variables_xarray_list)

In [12]:
# Display the merged per-date Dataset
stacked_xarray

In [13]:
def set_band_names(input_img: str, band_names: list, feedback: bool = False):
    """
    Set the description (band name) of each band of a raster file, in order.

    NOTE: Adapted from RSGISLib's ``imageutils.set_band_names``, which may not
    be available in this environment.

    :param input_img: path of the raster file to update (opened with GA_Update).
    :param band_names: list of band names; ``band_names[i]`` is applied to
                       raster band ``i + 1``.
    :param feedback: if True, print each band name as it is set
                     (default False: nothing printed).
    :raises RuntimeError: if ``input_img`` cannot be opened for update.

    Names for bands that do not exist in the image (more names than bands)
    are skipped with a ``UserWarning``.

    .. code:: python

        input_img = 'injune_p142_casi_sub_utm.kea'
        band_names = ['446nm','530nm','549nm','569nm','598nm','633nm','680nm','696nm',
                      '714nm','732nm','741nm','752nm','800nm','838nm']

        set_band_names(input_img, band_names)

    """
    import warnings

    dataset = gdal.Open(input_img, gdal.GA_Update)
    # Without gdal.UseExceptions(), a failed open returns None
    if dataset is None:
        raise RuntimeError(f"Could not open {input_img!r} for update")

    try:
        for band, band_name in enumerate(band_names, start=1):
            img_band = dataset.GetRasterBand(band)
            # Check the image band is available
            if img_band is None:
                warnings.warn(
                    f"Could not open image band {band} of {input_img!r}; skipping"
                )
                continue
            if feedback:
                print('Setting Band {0} to "{1}"'.format(band, band_name))
            img_band.SetDescription(band_name)
    finally:
        # Flush and release the dataset so the descriptions are written out
        dataset.FlushCache()
        dataset = None

In [14]:
# Write the per-date stack as a multi-band Cloud Optimised GeoTIFF, then
# label each band with its date. to_array() stacks the data variables in
# data_vars order, so band_names lines up with the written band order.
input_img = 'test.tif'
band_names = list(stacked_xarray.data_vars.keys())

write_cog(
    geo_im=stacked_xarray.to_array(),
    # fname='initial_STF_model_coastal_tile_gridID_' + gdf['id'].astype(str).item() +'_year_' + time_range[0] + '.tif',
    fname=input_img,
    overwrite=True,
    # NOTE(review): NaN / -inf dB pixels are not equal to 0.0, so they are
    # not covered by this nodata value - confirm the intended nodata.
    nodata=0.0,
)

set_band_names(input_img, band_names)

