In [None]:
import datacube
import xarray as xr
import rioxarray
import dask
from dask.distributed import performance_report
from dask.distributed import Client

from datacube.utils import masking
import utils.bandindices as bandindices
from utils.geometry import geojson_x_y

%matplotlib inline
import hvplot.xarray

## Open Data Cube handle used by the dc.load() calls below
dc = datacube.Datacube()

## Local Dask cluster: a single worker with a pool of threads.
## processes=False runs the worker inside this kernel process rather than in
## separate worker processes.
## https://docs.dask.org/en/latest/setup/single-distributed.html#localcluster
n_workers = 1
n_threads = 12
cluster_options = dict(
    processes=False,
    n_workers=n_workers,
    threads_per_worker=n_threads,
)
client = Client(**cluster_options)

In [None]:
## Spatial subset: geojson_x_y() supplies the x and y extents of the area of
## interest, which are forwarded unchanged to dc.load() below.
## NOTE(review): presumably these are the (min, max) ranges of the GeoJSON
## bounding box -- confirm against utils.geometry.
aoi = geojson_x_y("./misc/roda.geojson")
x = aoi[0]
y = aoi[1]

## Sentinel-1 loading is disabled; ds_s1_asc / ds_s1_desc are not referenced in
## the cells that follow.  Kept as real comments rather than parked in a bare
## string literal, which the kernel would still evaluate as an expression.
# ds_s1_asc = dc.load(product="s1_ARD_asc",
#                     x=x, y=y,
#                     dask_chunks={'time':-1, 'y':-1, 'x':-1})
#
# ds_s1_desc = dc.load(product="s1_ARD_desc",
#                      x=x, y=y,
#                      dask_chunks={'time':-1, 'y':-1, 'x':-1})

## Lazy (dask-backed) loads: -1 keeps the whole time axis in a single chunk,
## while y/x are split into blocks of 1100 x 1000 pixels.
ds_s2 = dc.load(product="s2_ARD",
                x=x, y=y,
                dask_chunks={'time':-1, 'y':1100, 'x':1000})

ds_l8 = dc.load(product="l8_ARD",
                x=x, y=y,
                resolution=(-10, 10),  # resample to get same array size as s2
                dask_chunks={'time':-1, 'y':1100, 'x':1000})


---

In [None]:
## Keep only summer acquisitions (season label 'JJA': June, July, August) and
## derive the vegetation indices from them.
## Background on this kind of indexing:
## https://xarray.pydata.org/en/stable/time-series.html#datetime-components
## https://xarray.pydata.org/en/stable/indexing.html#masking-with-where
def summer_indices(ds):
    """Return the kNDVI and NDVI indices for the JJA time steps of ``ds``.

    Time steps whose ``time.season`` label is not ``'JJA'`` are dropped, then
    ``bandindices.optical`` is called on the remainder with ``inplace=True``,
    ``drop=True`` and ``normalise=False``.

    NOTE(review): ``drop=True`` presumably discards the input bands so that
    only the indices remain -- confirm against utils.bandindices.
    """
    summer = ds.where(ds['time.season'] == 'JJA', drop=True)
    return bandindices.optical(summer,
                               index=["kNDVI", "NDVI"],
                               inplace=True,
                               drop=True,
                               normalise=False)


s2 = summer_indices(ds_s2)
l8 = summer_indices(ds_l8)

In [None]:
## Keep only NDVI values in the interval (0, 1]; everything else becomes NaN,
## and coordinate labels left without any valid value are dropped (drop=True).
##
## The index is loaded into memory with .compute() first and filtered
## afterwards: where(..., drop=True) has to evaluate its condition eagerly to
## find the labels to drop, so applying it to the lazy dask array would run
## the NDVI graph several times before the final .compute().
s2_ndvi = s2.NDVI.compute()
s2_ndvi = s2_ndvi.where((s2_ndvi > 0) & (s2_ndvi <= 1), drop=True)

l8_ndvi = l8.NDVI.compute()
l8_ndvi = l8_ndvi.where((l8_ndvi > 0) & (l8_ndvi <= 1), drop=True)

In [None]:
## Merge both sensors into a single NDVI array: s2 values take priority and
## l8 values fill the positions where s2 has no value.
ndvi = s2_ndvi.combine_first(l8_ndvi)

## Yearly median: group the time steps by calendar year, then reduce 'time'
ndvi_by_year = ndvi.groupby('time.year')
ndvi_mpy = ndvi_by_year.median(dim='time')

In [None]:
## Yearly median maps, selected by year label instead of by position so that
## each variable is guaranteed to hold the year it is named after.  A year
## that is missing from the data raises a KeyError instead of silently
## shifting the mapping.
ndvi_2017 = ndvi_mpy.sel(year=2017)
ndvi_2018 = ndvi_mpy.sel(year=2018)
ndvi_2019 = ndvi_mpy.sel(year=2019)

## Differences relative to 2017 (2017 minus the later year): positive values
## mean the median NDVI was higher in 2017 than in the later year.
diff_2018 = ndvi_2017 - ndvi_2018
diff_2019 = ndvi_2017 - ndvi_2019

---

In [None]:
## Fixed colour limits (NDVI was restricted to (0, 1] before aggregation) keep
## the yearly maps visually comparable; the title lets the figure stand alone.
ndvi_2017.hvplot(height=500, width=900, cmap="viridis",
                 clim=(0, 1), title="Median summer (JJA) NDVI, 2017")

In [None]:
## Fixed colour limits (NDVI was restricted to (0, 1] before aggregation) keep
## the yearly maps visually comparable; the title lets the figure stand alone.
ndvi_2018.hvplot(height=500, width=900, cmap="viridis",
                 clim=(0, 1), title="Median summer (JJA) NDVI, 2018")

In [None]:
## Fixed colour limits (NDVI was restricted to (0, 1] before aggregation) keep
## the yearly maps visually comparable; the title lets the figure stand alone.
ndvi_2019.hvplot(height=500, width=900, cmap="viridis",
                 clim=(0, 1), title="Median summer (JJA) NDVI, 2019")

In [None]:
## symmetric=True centres the diverging RdBu colormap on zero, so the two
## colour ramps separate positive from negative differences.
diff_2018.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
                 title="Median summer NDVI difference, 2017 minus 2018")

In [None]:
## symmetric=True centres the diverging RdBu colormap on zero, so the two
## colour ramps separate positive from negative differences.
diff_2019.hvplot(height=500, width=900, cmap="RdBu", symmetric=True,
                 title="Median summer NDVI difference, 2017 minus 2019")

---

In [None]:
## "bad" example: NDVI time series of the grid cell nearest to the given x/y.
## NOTE(review): the notebook does not say what makes this location "bad"
## (vegetation condition? data quality?) -- confirm with the author.
bad_pixel = ndvi.sel(method='nearest', x=-586000, y=-437260)
bad_pixel.hvplot.scatter()

In [None]:
## "good" example: NDVI time series of the grid cell nearest to the given x/y.
## NOTE(review): the notebook does not say what makes this location "good"
## (vegetation condition? data quality?) -- confirm with the author.
good_pixel = ndvi.sel(method='nearest', x=-584580, y=-436800)
good_pixel.hvplot.scatter()