In [None]:
from pathlib import Path

import datacube
import xarray as xr
import rioxarray
import dask
from dask.distributed import performance_report
from dask.distributed import Client

from datacube.utils import masking
import utils.bandindices as bandindices
from utils.geometry import geojson_x_y

%matplotlib inline
import matplotlib.pyplot as plt
import hvplot.xarray

## Connect to the Open Data Cube
dc = datacube.Datacube()

## Start a local, in-process (threaded) Dask cluster: processes=False keeps the
## worker(s) in this kernel's process.
## https://docs.dask.org/en/latest/setup/single-distributed.html#localcluster
n_workers, n_threads = 1, 24
client = Client(
    processes=False,
    n_workers=n_workers,
    threads_per_worker=n_threads,
)

In [None]:
## Spatial subset based on the GeoJSON bounding box
aoi = geojson_x_y("./misc/roda.geojson")
x, y = aoi[0], aoi[1]

## (Lazy) loading of all datasets.
## Each dimension is a single dask chunk: the AOI is fairly small and enough RAM is available.
def _load_aoi(product, **load_kwargs):
    """Lazily load `product` restricted to the AOI (x, y) as one dask chunk over time/y/x."""
    return dc.load(product=product,
                   x=x, y=y,
                   dask_chunks={'time': -1, 'y': -1, 'x': -1},
                   **load_kwargs)

ds_s1_asc = _load_aoi("s1_ARD_asc")
ds_s1_desc = _load_aoi("s1_ARD_desc")
ds_s2 = _load_aoi("s2_ARD", measurements=["red", "nir", "pixel_qa"])
# Resample L8 to get the same array size as S2
ds_l8 = _load_aoi("l8_ARD", measurements=["red", "nir", "pixel_qa"], resolution=(-10, 10))


## Landsat 8 & Sentinel-2

In [None]:
## Create cloud/cloud-shadow masks from the QAI band.
## Pixels flagged as cloud, cloud shadow or nodata are False in the mask and
## become NaN in the masked datasets.
flags = {'valid_data': 'valid',
         'cloud_state': 'clear',
         'cloud_shadow_flag': False}
mask_s2 = masking.make_mask(ds_s2.pixel_qa, **flags)
mask_l8 = masking.make_mask(ds_l8.pixel_qa, **flags)

## Apply masks; .where() returns new objects, so ds_s2/ds_l8 stay untouched and
## re-running this cell starts again from the unmodified loads.
s2 = ds_s2.where(mask_s2)
l8 = ds_l8.where(mask_l8)

## Add NDVI and kNDVI variables to the masked datasets (in place on s2/l8)
bandindices.optical(s2, index=["NDVI", "kNDVI"], inplace=True, drop=False, normalise=False)
# Trailing semicolon on the last statement suppresses the cell output
# (a bare ";" line is not valid Python and only works via IPython's auto-quote escape).
bandindices.optical(l8, index=["NDVI", "kNDVI"], inplace=True, drop=False, normalise=False);

In [None]:
## Merge the NDVI of S2 and (resampled) L8 and reduce it to one median per summer (JJA) per year.
## mpspy = median per summer/season per year
index = "ndvi"  # only used further down to build the output filenames

# Silence dask's warning about large chunks produced by the slicing below
with dask.config.set({'array.slicing.split_large_chunks': False}):
    # combine_first: S2 values win, L8 only fills where S2 is missing
    ind = s2.NDVI.combine_first(l8.NDVI)

    is_summer = ind['time.season'] == 'JJA'
    in_range = (ind >= 0) & (ind <= 1)
    ind_spy = ind.where(is_summer & in_range)

    yearly_median = ind_spy.groupby('time.year').median(dim='time', skipna=True)
    # Rechunk to one chunk per year, then compute into memory for faster plotting
    mpspy = yearly_median.chunk({'year': 1, 'x': -1, 'y': -1}).compute()


In [None]:
## Calculate changes of the summer median in relation to 2017.
## Years are selected by label rather than position, so the result does not silently
## depend on which years happen to be present (or on the dimension order of mpspy).
diff_18 = mpspy.sel(year=2018) - mpspy.sel(year=2017)
diff_19 = mpspy.sel(year=2019) - mpspy.sel(year=2017)
diff_1819 = ((mpspy.sel(year=2019) + mpspy.sel(year=2018)) / 2) - mpspy.sel(year=2017)

---

In [None]:
## Median of 2017 summer period (selected by year label, not by position)
mpspy.sel(year=2017).hvplot(height=500, width=900, cmap='viridis',
                            xformatter="%.0f", yformatter="%.0f",
                            title="Median NDVI, summer (JJA) 2017")

In [None]:
## Median of 2018 summer period (selected by year label, not by position)
mpspy.sel(year=2018).hvplot(height=500, width=900, cmap="viridis",
                            xformatter="%.0f", yformatter="%.0f",
                            title="Median NDVI, summer (JJA) 2018")

In [None]:
## Median of 2019 summer period (selected by year label, not by position)
mpspy.sel(year=2019).hvplot(height=500, width=900, cmap="viridis",
                            xformatter="%.0f", yformatter="%.0f",
                            title="Median NDVI, summer (JJA) 2019")

In [None]:
## Difference 2018 to 2017; symmetric=True centres the diverging colour scale on zero
diff_18.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
               xformatter="%.0f", yformatter="%.0f",
               title="NDVI difference: summer median 2018 - 2017")

In [None]:
## Difference 2019 to 2017; symmetric=True centres the diverging colour scale on zero
diff_19.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
               xformatter="%.0f", yformatter="%.0f",
               title="NDVI difference: summer median 2019 - 2017")

In [None]:
## Difference 2018/2019 (mean of both summer medians) to 2017;
## symmetric=True centres the diverging colour scale on zero
diff_1819.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
                 xformatter="%.0f", yformatter="%.0f",
                 title="NDVI difference: mean summer median 2018/2019 - 2017")

In [None]:
## Sample pixel locations
## Forest patch with stable/slight decrease of median NDVI (ca. -0.015)
good_xy = {'x': -586725, 'y': -437065}
## Forest patch with decrease of median NDVI (ca. -0.435)
bad_xy = {'x': -586565, 'y': -436605}

# Full NDVI time series at the nearest grid cell, pulled into memory
px_good = ind.sel(**good_xy, method='nearest').compute()
px_bad = ind.sel(**bad_xy, method='nearest').compute()

In [None]:
# NDVI time series of the stable forest sample pixel
px_good.hvplot(kind="scatter")

In [None]:
# NDVI time series of the declining forest sample pixel
px_bad.hvplot(kind="scatter")

In [None]:
## Save arrays as GeoTIFF
out_dir_optical = Path("./output/roda_optical")
out_dir_optical.mkdir(parents=True, exist_ok=True)  # make sure the target folder exists before writing

# File suffix -> array; yearly medians are selected by year label so the file name
# always matches the year actually written.
optical_rasters = {
    "median_17": mpspy.sel(year=2017),
    "median_18": mpspy.sel(year=2018),
    "median_19": mpspy.sel(year=2019),
    "diff_18": diff_18,
    "diff_19": diff_19,
    "diff_1819": diff_1819,
}
for suffix, arr in optical_rasters.items():
    arr.rio.to_raster(str(out_dir_optical / f"{index}_{suffix}.tif"), dtype="float32")

---

## Sentinel-1

In [None]:
## Add the "VVVH" and "VHVV" index variables to both S1 datasets (in place)
bandindices.sar(ds_s1_asc, index=["VVVH", "VHVV"], inplace=True, drop=False)
# Trailing semicolon on the last statement suppresses the cell output
# (a bare ";" line is not valid Python and only works via IPython's auto-quote escape).
bandindices.sar(ds_s1_desc, index=["VVVH", "VHVV"], inplace=True, drop=False);

In [None]:
## Filter by nodata (some dates contain rasters with only nodata values) and remove
## outliers by keeping only values between the 5th and 95th percentile.
def _filter_nodata_and_outliers(da, nodata, lower=0.05, upper=0.95):
    """Return `da` with nodata and values outside its [lower, upper] quantile range set to NaN.

    Nodata is masked *before* the quantiles are computed, so nodata pixels (which can
    fill whole dates) cannot pull the thresholds towards the nodata value.
    Quantiles are taken over all dimensions of `da` (one threshold pair per array);
    NaNs are skipped by DataArray.quantile for float data.
    """
    valid = da.where(da != nodata)
    lo = valid.quantile(lower)
    hi = valid.quantile(upper)
    return valid.where((valid >= lo) & (valid <= hi))

# Each orbit direction uses the nodata value of its own dataset
nodata = ds_s1_asc.VH.attrs['nodata']
nodata_desc = ds_s1_desc.VH.attrs['nodata']

vh_asc = _filter_nodata_and_outliers(ds_s1_asc.VH, nodata)
vh_desc = _filter_nodata_and_outliers(ds_s1_desc.VH, nodata_desc)
# NOTE(review): VVVH is compared against the VH nodata value; this only masks anything if
# bandindices.sar propagates that value into the index band — TODO confirm.
vvvh_asc = _filter_nodata_and_outliers(ds_s1_asc.VVVH, nodata)
vvvh_desc = _filter_nodata_and_outliers(ds_s1_desc.VVVH, nodata_desc)

In [None]:
## Calculate seasonal median aggregates (mpspy = median per season per year) and compute
## into memory for faster plotting
def _summer_median_per_year(da):
    """Median of the summer (JJA) observations of `da` per calendar year, computed into memory.

    The season mask is built from `da`'s own time coordinate, so every array is masked
    consistently with itself.
    """
    summer = da.where(da['time.season'] == 'JJA')
    yearly = summer.groupby('time.year').median(dim='time', skipna=True)
    # One chunk per year before pulling the result into memory
    return yearly.chunk({'year': 1, 'x': -1, 'y': -1}).compute()

mpspy_vh_asc = _summer_median_per_year(vh_asc)
mpspy_vh_desc = _summer_median_per_year(vh_desc)
mpspy_vvvh_asc = _summer_median_per_year(vvvh_asc)
mpspy_vvvh_desc = _summer_median_per_year(vvvh_desc)

In [None]:
## Change of the mean 2018/2019 summer median relative to the 2017 summer median
def _diff_1819(mpspy_da):
    """Return mean(median 2018, median 2019) - median 2017, selecting years by label."""
    return ((mpspy_da.sel(year=2019) + mpspy_da.sel(year=2018)) / 2) - mpspy_da.sel(year=2017)

diff_1819_vh_asc = _diff_1819(mpspy_vh_asc)
diff_1819_vh_desc = _diff_1819(mpspy_vh_desc)

diff_1819_vvvh_asc = _diff_1819(mpspy_vvvh_asc)
diff_1819_vvvh_desc = _diff_1819(mpspy_vvvh_desc)

---

In [None]:
# Summer median VH (ascending) 2017, selected by year label
mpspy_vh_asc.sel(year=2017).hvplot(height=500, width=900, cmap='viridis',
                                   xformatter="%.0f", yformatter="%.0f",
                                   title="Median VH (ascending), summer (JJA) 2017")

In [None]:
# Summer median VH (ascending) 2018, selected by year label
mpspy_vh_asc.sel(year=2018).hvplot(height=500, width=900, cmap='viridis',
                                   xformatter="%.0f", yformatter="%.0f",
                                   title="Median VH (ascending), summer (JJA) 2018")

In [None]:
# Summer median VH (ascending) 2019, selected by year label
mpspy_vh_asc.sel(year=2019).hvplot(height=500, width=900, cmap='viridis',
                                   xformatter="%.0f", yformatter="%.0f",
                                   title="Median VH (ascending), summer (JJA) 2019")

In [None]:
# VH (ascending) change 2018/2019 vs. 2017; symmetric=True centres the colour scale on zero
diff_1819_vh_asc.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
                        xformatter="%.0f", yformatter="%.0f",
                        title="VH (ascending) difference: mean summer median 2018/2019 - 2017")

In [None]:
# VH (descending) change 2018/2019 vs. 2017; symmetric=True centres the colour scale on zero
diff_1819_vh_desc.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
                         xformatter="%.0f", yformatter="%.0f",
                         title="VH (descending) difference: mean summer median 2018/2019 - 2017")

In [None]:
## Filtered VH series at the two forest sample locations used for the NDVI series
px_locations = {
    'good': dict(x=-586725, y=-437065),
    'bad': dict(x=-586565, y=-436605),
}
px_good_vh_asc = vh_asc.sel(**px_locations['good'], method='nearest').compute()
px_good_vh_desc = vh_desc.sel(**px_locations['good'], method='nearest').compute()
px_bad_vh_asc = vh_asc.sel(**px_locations['bad'], method='nearest').compute()
px_bad_vh_desc = vh_desc.sel(**px_locations['bad'], method='nearest').compute()

In [None]:
# VH (ascending) time series of the stable forest sample pixel
px_good_vh_asc.hvplot(kind="scatter")

In [None]:
# VH (descending) time series of the stable forest sample pixel
px_good_vh_desc.hvplot(kind="scatter")

In [None]:
# VH (ascending) time series of the declining forest sample pixel
px_bad_vh_asc.hvplot(kind="scatter")

In [None]:
# VH (descending) time series of the declining forest sample pixel
px_bad_vh_desc.hvplot(kind="scatter")

In [None]:
## Save arrays as GeoTIFF
out_dir_sar = Path("./output/roda_SAR")
out_dir_sar.mkdir(parents=True, exist_ok=True)  # make sure the target folder exists before writing

# File name -> array; yearly medians are selected by year label so the file name
# always matches the year actually written.
sar_rasters = {
    "vh_asc_median_17": mpspy_vh_asc.sel(year=2017),
    "vh_asc_median_18": mpspy_vh_asc.sel(year=2018),
    "vh_asc_median_19": mpspy_vh_asc.sel(year=2019),
    "vh_asc_diff_1819": diff_1819_vh_asc,
    "vh_desc_diff_1819": diff_1819_vh_desc,
}
for name, arr in sar_rasters.items():
    arr.rio.to_raster(str(out_dir_sar / f"{name}.tif"), dtype="float32")

## Combined time-series plot

In [None]:
import matplotlib.pyplot as plt

In [None]:
## NDVI series of the stable forest pixel, limited to summer (JJA) and the range [0, 1]
is_jja = px_good['time.season'] == 'JJA'
ndvi = px_good.where(is_jja & (px_good >= 0) & (px_good <= 1))

In [None]:
## Weekly mean VH of the stable forest pixel, one named series per orbit direction
weekly_vh = {
    label: series.resample(time="1W", skipna=True).mean().rename(label)
    for label, series in (("VH_asc", px_good_vh_asc), ("VH_desc", px_good_vh_desc))
}
vh_asc_avg = weekly_vh["VH_asc"]
vh_desc_avg = weekly_vh["VH_desc"]

In [None]:
## Join the summer NDVI series and both weekly VH series on time into one Dataset.
## compat="override" keeps the first object's value for conflicting non-index variables
## (presumably the scalar x/y coords of the selected pixels, which differ between grids).
series_to_merge = [ndvi, vh_asc_avg, vh_desc_avg]
merged = xr.merge(series_to_merge, compat="override")

In [None]:
## Combined time series of the stable forest pixel: summer NDVI (left axis) and
## weekly VH ascending/descending (two separate right-hand axes).
# Dedicated names so the AOI bounds `x`/`y` from the loading cell are not overwritten.
t = merged.time
ndvi_series = merged.NDVI
vh_asc_series = merged.VH_asc.interpolate_na(dim="time", method="linear")
vh_desc_series = merged.VH_desc.interpolate_na(dim="time", method="linear")

fig, ax1 = plt.subplots()
fig.subplots_adjust(right=0.75)  # leave room for the offset third axis
ax2 = ax1.twinx()
ax3 = ax1.twinx()
# Without this offset, ax2 and ax3 draw their ticks and labels on top of each other
ax3.spines['right'].set_position(('axes', 1.2))

ax1.plot(t, ndvi_series, 'o', c='green')
ax2.plot(t, vh_asc_series, '-', c='black')
ax3.plot(t, vh_desc_series, '-', c='blue')

ax1.set_xlabel('time')
ax1.set_ylabel('NDVI', color='green')
ax2.set_ylabel('VH_asc', color='black')
ax3.set_ylabel('VH_desc', color='blue')

ax1.set_ylim(0, 1)
ax2.set_ylim(-20, -10)
ax3.set_ylim(-20, -10)

ax1.set_title('Stable forest pixel: summer NDVI vs. weekly VH')
plt.show()


In [None]:
## Summer (JJA) NDVI of the stable forest pixel within [0, 1]
## ("spy" = summer per year, same filter as ind_spy). Defined here because no
## other cell creates px_good_spy, so this cell would raise NameError on a fresh kernel.
px_good_spy = px_good.where((px_good['time.season'] == 'JJA') & (px_good >= 0) & (px_good <= 1))
px_good_spy.hvplot.scatter(ylim=(0,1))

In [None]:
# VH (ascending) series of the stable forest pixel on the same value range as the combined plot
px_good_vh_asc.hvplot(kind="scatter", ylim=(-20, -10))

---

In [None]:
## Summer (JJA) NDVI of the declining forest pixel within [0, 1]
## ("spy" = summer per year, same filter as ind_spy). Defined here because no
## other cell creates px_bad_spy, so this cell would raise NameError on a fresh kernel.
px_bad_spy = px_bad.where((px_bad['time.season'] == 'JJA') & (px_bad >= 0) & (px_bad <= 1))
px_bad_spy.hvplot.scatter(ylim=(0,1))