# Creating an odc-stats plugin; or how I learned to stop worrying and love odc-stats

> Important: for this notebook to work in the Sandbox, both odc-stats and odc-algo require upgrading

Useful links:
* [odc-stats](https://github.com/opendatacube/odc-stats)
* [odc-stats plugins](https://github.com/opendatacube/odc-stats/tree/develop/odc/stats/plugins)

In [1]:
# !pip install git+https://github.com/opendatacube/odc-algo.git
# !pip install -U odc-stats

## Background

**ODC-Statistician** ([Open Data Cube Statistician](https://github.com/opendatacube/odc-stats)) is a framework of tools for generating statistical summaries (usually across time) of large collections of Earth Observation imagery managed in an Open Data Cube instance.

`odc-stats` is a powerful and flexible tool for running batch processing of tiled Earth Observation summary products across many EC2 instances on a cloud compute environment. However, for an Earth Observation scientist who is more familiar with developing projects on a single machine (think the Sandbox), it can be confusing to transition code to odc-stats.  This is partly because of a lack of documentation of odc-stats functions, which requires reading source code to understand how it works, and partly because the structure and nomenclature of odc-stats differs from the usual parlance of EO scientists working within the Sandbox ecosystem.  

**Importantly**, this notebook is not intended for developers looking for instructions on how to coordinate a large scale batch run of odc-stats on kubernetes. Instead, its intention is to <ins>provide guidance to EO scientists on how to translate their code from the Sandbox to odc-stats</ins>.  

## Description

**The aims of this notebook are two-fold**:
1. Demystify the use of odc-stats by demonstrating development of a minimal odc-stats "plugin".
2. Provide example code for running odc-stats functions on a local machine, thus demonstrating how to develop and test an odc-stats "plugin".

**Two key things are required to use odc-stats:** a `plugin`, which is essentially a Python class containing functions for summarising satellite images within an ODC environment, and a `config` file (a .yaml file), which provides arguments to the plugin.  In this notebook we will systematically build these files.

The notebook is broken up into <ins>two main sections</ins>.
1. Firstly, we will develop a simple odc-stats plugin, and run it _without using odc-stats_. This can be useful for testing purposes when developing a function, and it will also help us understand the transition from 'sandbox-esque' code to odc-stats code.
2. In the second section, we will demonstrate running the plugin using odc-stats.


## Import libraries 

In [None]:
import os
import yaml
import json
import warnings
import xarray as xr
import rioxarray as rxr
import geopandas as gpd
from pprint import pprint
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs
from odc.stats.tasks import TaskReader
from odc.stats.model import OutputProduct

warnings.filterwarnings("ignore")

## Section 1: Create a plugin, run it _without_ using odc-stats

### Creating an odc-stats plugin

To begin, let's create a Python class, based on `StatsPluginInterface`, that summarises a Landsat time series of NDVI.

The base `StatsPluginInterface` Python class is described [here](https://github.com/opendatacube/odc-stats/blob/develop/odc/stats/plugins/_base.py).  To define a custom plugin, we generally need to define a few key functions:

1. <ins>`measurements`</ins>: a function that describes the output 'measurements' of the final product (AKA 'bands'). In this example, the final product will have a measurement called 'ndvi_median'
   
2. <ins>`native_transform`</ins>: this function is passed to an upstream function called [odc.algo.io.load_with_native_transform](https://github.com/opendatacube/odc-algo/blob/bd2fb6828beafed60b5f58f465df8da78cb071e2/odc/algo/io.py#L157). Its role is to define pre-processing steps that are applied individually to every satellite image. This is usually used for things like masking cloud, nodata, and non-contiguous pixels.  The 'load_with_native_transform' function sits within a higher order function called [input_data](https://github.com/opendatacube/odc-stats/blob/7f34c86bdbd481340c41b5be7e0d0873ce3b3e1c/odc/stats/plugins/_base.py#L65) (which itself is within the `StatsPluginInterface` class). For relatively standard odc-stats operations where the main tasks are loading, masking, and summarising a time series of satellite images, defining a "native_transform" function for masking is all that's required (and this is passed to "input_data"). However, for more flexibility, we can define our own custom "input_data" function and odc-stats will run this instead. For example, see [this plugin](https://github.com/opendatacube/odc-stats/blob/develop/odc/stats/plugins/lc_ml_treelite.py) that runs a machine learning prediction.

3. <ins>`reduce`</ins>: This function describes how we summarise a time series to a single image. For example, by taking a temporal median. However, this function can be highly flexible. For example, we could load a machine learning model in this step to classify data. We could even load other ancillary datasets if needed.
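
To make the division of labour between these three pieces concrete, the sketch below mimics them with plain Python lists instead of xarray. `ToyNDVIPlugin` and its data layout are hypothetical stand-ins, not odc-stats code; only the method names are borrowed from the real interface, and the pixel values are made up.

```python
from statistics import median
from typing import Dict, List, Optional, Tuple

Image = Dict[str, list]  # one timestep: per-pixel lists for each band


class ToyNDVIPlugin:
    """Hypothetical stand-in that mirrors the shape of an odc-stats plugin."""

    @property
    def measurements(self) -> Tuple[str, ...]:
        # names of the output bands
        return ("ndvi_median",)

    def native_transform(self, image: Image) -> Image:
        # applied to each image separately: blank out cloudy pixels
        keep = [not cloudy for cloudy in image["cloud"]]
        return {
            band: [v if ok else None for v, ok in zip(image[band], keep)]
            for band in ("red", "nir")
        }

    def reduce(self, images: List[Image]) -> Dict[str, List[Optional[float]]]:
        # applied to the whole stack: per-pixel NDVI, then a median through time
        n_pixels = len(images[0]["red"])
        out = []
        for p in range(n_pixels):
            series = [
                (img["nir"][p] - img["red"][p]) / (img["nir"][p] + img["red"][p])
                for img in images
                if img["red"][p] is not None  # skip masked observations
            ]
            out.append(median(series) if series else None)
        return {self.measurements[0]: out}


# three timesteps, two pixels; values are made up
stack = [
    {"red": [1000, 2000], "nir": [3000, 6000], "cloud": [False, True]},
    {"red": [2000, 2000], "nir": [6000, 2000], "cloud": [False, False]},
    {"red": [1000, 1000], "nir": [1000, 3000], "cloud": [False, False]},
]

plugin = ToyNDVIPlugin()
masked = [plugin.native_transform(img) for img in stack]
result = plugin.reduce(masked)
print(result)
```

The second pixel is cloudy in the first timestep, so its median is taken over the two remaining observations only. The real plugin follows the same order of operations, with masking applied per image and the summary applied across the stack.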


Now, let's create a simple `StatsPlugin` class that, when provided with a series of DEA Landsat images, will mask for clouds and bad data, calculate NDVI, and then 'reduce' the time series using a temporal median.

In [None]:
#masking functions
from typing import Optional, Sequence, Tuple
from datacube.utils.masking import mask_invalid_data
from odc.algo._masking import (
    enum_to_bool,
    mask_cleanup,
    erase_bad
)
#odc-stats functions for registering a plugin
from odc.stats.plugins._registry import register, StatsPluginInterface

class StatsNDVI(StatsPluginInterface):
    """
    Define a class for summarising time 
    series of NDVI using the median.
    """
    
    NAME = "ndvi_median"
    SHORT_NAME = NAME
    VERSION = "1.0"
    PRODUCT_FAMILY = "ndvi"

    def __init__(
        self,
        input_bands: Optional[Sequence[str]] = None,
        output_bands: Optional[str] = None,
        mask_band: Optional[str] = None,
        contiguity_band: Optional[str] = None,
        group_by: str = "solar_day",
        **kwargs,
    ):
        
        self.input_bands = input_bands
        self.output_bands = output_bands
        self.mask_band = mask_band
        self.contiguity_band = contiguity_band
        self.group_by = group_by

        ## These params get passed to the upstream 
        #  base StatsPluginInterface class
        super().__init__(
            input_bands=tuple(input_bands)+(mask_band,)+(contiguity_band,),
            **kwargs
        )

        
    @property
    def measurements(self) -> Tuple[str, ...]:
        """
        Here we define the output bands. In this example we
        pass the name of the output band in through the config,
        but equally we could define the output names within this function,
        for example by adding a suffix to the input bands.
        """
        # the return type is a tuple of band names, so wrap a single name
        if isinstance(self.output_bands, str):
            return (self.output_bands,)
        return tuple(self.output_bands)

    def native_transform(self, xx):
        """
        This function is passed to an upstream function
        called "odc.algo.io.load_with_native_transform".
        The function described here is applied to every time
        step of data and is usually used for things like
        masking clouds, nodata, and contiguity masking.
        """
        #grab the QA band from the Landsat data
        mask = xx[self.mask_band]

        # create boolean arrays from the mask for cloud
        # and cloud shadows, and nodata
        bad = enum_to_bool(mask, ("nodata",))
        non_contiguent = xx.get(self.contiguity_band, 1) == 0
        bad = bad | non_contiguent
        
        cloud_mask = enum_to_bool(mask, ("cloud", "shadow"))
        bad =  cloud_mask | bad

        # drop masking bands
        xx = xx.drop_vars([self.mask_band] + [self.contiguity_band])
        
        ## Mask the bad data (clouds etc)
        xx = erase_bad(xx, bad)

        return xx

    def reduce(self, xx: xr.Dataset) -> xr.Dataset:
        """
        Calculate NDVI and summarise time series with a median.
        """
        # convert to float and convert nodata to NaN so NDVI
        # isn't calculated on the nodata integer values
        xx = mask_invalid_data(xx)
        
        # Calculate NDVI
        ndvi = (xx['nbart_nir'] - xx['nbart_red']) / (xx['nbart_nir'] + xx['nbart_red'])

        # calculate temporal median NDVI. 
        # !!!!!!!!!!!!!!!!!!!!!!!!!!!! 
        # Note that we use 'spec' here and not 'time', this is an odc-stats thing
        # where the dimensions are labelled as spec, x, and y.
        # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
        ndvi = ndvi.median('spec').rename(self.output_bands)
        
        return ndvi.to_dataset()

# now let's 'register' the plugin with odc-stats
register("NDVI-median", StatsNDVI)

### Create a config

We need to create a config to describe the input parameters for the plugin. For now we will create a dictionary; however, when it comes time to run the plugin with odc-stats, the config is usually stored in a .yaml file.

In [None]:
config = dict(
    input_bands=["nbart_red",  "nbart_nir"],
    output_bands = 'ndvi_median',
    mask_band="oa_fmask",
    contiguity_band='oa_nbart_contiguity'
)
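
Because the plugin's constructor accepts `**kwargs`, a misspelled config key is not caught by `StatsNDVI.__init__` itself. The sketch below shows one way to compare a config against a constructor signature before instantiating. `unexpected_keys` and `DemoPlugin` are hypothetical names introduced here for illustration; `DemoPlugin` only copies the argument names used above so that the example runs without odc-stats installed.

```python
import inspect


def unexpected_keys(plugin_cls, config):
    """Return config keys that plugin_cls.__init__ does not name explicitly."""
    params = inspect.signature(plugin_cls.__init__).parameters
    named = {
        name
        for name, p in params.items()
        if name != "self" and p.kind is not inspect.Parameter.VAR_KEYWORD
    }
    return sorted(set(config) - named)


class DemoPlugin:
    """Hypothetical stand-in with the same constructor arguments as StatsNDVI."""

    def __init__(self, input_bands=None, output_bands=None, mask_band=None,
                 contiguity_band=None, group_by="solar_day", **kwargs):
        self.extra = kwargs  # anything unrecognised ends up here


good = dict(
    input_bands=["nbart_red", "nbart_nir"],
    output_bands="ndvi_median",
    mask_band="oa_fmask",
    contiguity_band="oa_nbart_contiguity",
)
typo = dict(good, mask_bnad="oa_fmask")  # deliberate misspelling

print(unexpected_keys(DemoPlugin, good))
print(unexpected_keys(DemoPlugin, typo))
```

A flagged key is not necessarily an error, since the base class takes options of its own through `**kwargs`, but it is a prompt to double-check the spelling.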

### Run the plugin code

To test our plugin, we will load datasets in a way that's more familiar to those used to working on the Sandbox. These steps mimic the inputs that odc-stats expects, but we load them in a more 'conventional' way than odc-stats (more on this in the next section).

Below we define an analysis area, and set up an ODC query. Then we load a list of datasets and the geobox that describes the geographical extent of the datasets.

In [None]:
import datacube
dc = datacube.Datacube(app="odc-stats example")

In [None]:
#analysis params
lat, lon = -34.134, 140.747
buffer = 0.05
time_range = '2024'
resolution = (-30, 30)

lat_range = (lat-buffer, lat+buffer)
lon_range = (lon-buffer, lon+buffer)

#set up query object
query = {
    'x': lon_range,
    'y': lat_range,
    'time': time_range,
    'resolution': resolution,
    'output_crs':'epsg:3577',
    'measurements':['nbart_red','nbart_nir','oa_fmask']
}

# load some data, but we'll just extract the geobox parameter
# because 'input_data' expects a geobox
gbox = dc.load(product=['ga_ls8c_ard_3'], dask_chunks={}, **query).geobox

# load a list of datasets to mimic odc-stats cached ".db" files
dss = dc.find_datasets(product=['ga_ls8c_ard_3'], **query)

print('Number of datasets:', len(dss))

### Run the plugin functions

These are evaluated 'lazily' with dask, so will evaluate quickly. Once we run `.load()` below, the functions will be executed and the result will be brought into memory.

In [None]:
#call the function
func=StatsNDVI(**config)

# run the separate functions
ndvi = func.input_data(datasets=dss, geobox=gbox)

result = func.reduce(ndvi)
result

### Bring into memory and plot

This will take about 20 seconds to load with the default example.

In [None]:
result.load()

result['ndvi_median'].plot(vmin=0, vmax=0.7, size=5, add_labels=False)
plt.title('Annual Median NDVI');

## Section 2: Run the plugin using odc-stats



### Saving tasks
Before we can run a plugin with odc-stats, we need to extract datasets. In the Sandbox (or on a local machine), we would ordinarily do this by running [datacube.load](https://opendatacube.readthedocs.io/en/latest/api/indexed-data/generate/datacube.Datacube.load.html) (or sometimes `dea_tools.load_ard`). However, odc-stats instead works by caching a copy of the database to disk, thus providing a list of 'tasks' for odc-stats to run.  This is achieved with the command [odc-stats save-tasks](https://github.com/opendatacube/odc-stats/blob/develop/odc/stats/_cli_save_tasks.py). When run, it outputs three files:
1. A .csv file listing all tasks for all the years in the database e.g., `ga_ls8c_ard_3_2017--P1Y.csv`
2. A database cache file used by statistician when running jobs, e.g. `ga_ls8c_ard_3_2017--P1Y.db`
3. A GeoJSON file per year, for visualising the prospective run e.g. `ga_ls8c_ard_3_2017--P1Y.geojson`

The [save-tasks](https://github.com/opendatacube/odc-stats/blob/develop/odc/stats/_cli_save_tasks.py) function has a number of parameters that can be passed. Below we outline an example, and then list the main parameters.

For example:

    odc-stats save-tasks --frequency annual --grid au-extended-30 --year 2017 --input-products ga_ls8c_ard_3

This would save tasks for all the Landsat 8 satellite imagery across Australia for the year 2017, on a 30m grid. `save-tasks` is quite flexible, so we can adjust these parameters to suit the kinds of product we're building: 
* **--input-products**: To cache datasets from more than one product, join the product names with a hyphen. For example, for both Landsat 8 and Landsat 9 this would be `ga_ls8c_ard_3-ga_ls9c_ard_3`, and for Sentinel-2 it could be `ga_s2am_ard_3-ga_s2bm_ard_3-ga_s2cm_ard_3`. In other cases a product needs bands that are stored in separate products, and the inputs must be 'fused'. Fusing finds the datasets from each product that share the same time and place and combines them into a single virtual product containing the bands from both. Note that for datasets to be fused they must have the same `center_time` and `region_code`. An example of this is fc-percentiles, which uses the fractional cover bands in `ga_ls_fc_3` to calculate the percentiles and the `ga_ls_wo_3` band to mask out bad data. The input-products parameter in this case is `ga_ls_fc_3+ga_ls_wo_3`.

* **--frequency**: This determines the temporal binning of datasets. For example, for 3-month rolling we would use: `rolling-3months`. A list of supported values is [here](https://github.com/opendatacube/odc-stats/blob/7f34c86bdbd481340c41b5be7e0d0873ce3b3e1c/odc/stats/_cli_save_tasks.py#L24).
* **--temporal-range**: Only extract datasets for a given time range, e.g. `2020-05--P1M` for the month of May in 2020, or `2017--P1Y` for one year's worth of data from 2017.
* **--grid**: The spatial resolution and grid to use. For Australia this is `au-extended` plus a resolution, one of: `{10|20|30|60}`, e.g. for Sentinel-2 we would use `au-extended-10`.
* **--gqa**: Only save datasets that pass `gqa_iterative_mean_xy <= gqa` test.
* **--dataset-filter**: We can use this to filter based on metadata, for example: `{"dataset_maturity": "final"}`
* **--year**: Use this flag as a shortcut for `--temporal-range=<int>--P1Y`, it will extract tasks for a single calendar year.
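
When experimenting with different combinations of these parameters, it can help to assemble the command in Python and let `shlex` handle the quoting. This is a minimal sketch that uses only the flags described above; `build_save_tasks_command` is a hypothetical helper, not part of odc-stats, and passing the dataset filter as a JSON string is an assumption based on the example given for `--dataset-filter`.

```python
import json
import shlex


def build_save_tasks_command(products, temporal_range, grid, frequency,
                             fuse=False, dataset_filter=None):
    """Build an `odc-stats save-tasks` command string (nothing is executed)."""
    # '-' lists products side by side, '+' fuses them into one virtual product
    separator = "+" if fuse else "-"
    cmd = [
        "odc-stats", "save-tasks",
        "--frequency", frequency,
        "--grid", grid,
        "--temporal-range", temporal_range,
        "--input-products", separator.join(products),
    ]
    if dataset_filter is not None:
        # assumption: the filter is passed as a JSON string
        cmd += ["--dataset-filter", json.dumps(dataset_filter)]
    return shlex.join(cmd)


command = build_save_tasks_command(
    products=["ga_ls8c_ard_3", "ga_ls9c_ard_3"],
    temporal_range="2024--P1Y",
    grid="au-extended-30",
    frequency="annual",
    dataset_filter={"dataset_maturity": "final"},
)
print(command)

fused = build_save_tasks_command(
    products=["ga_ls_fc_3", "ga_ls_wo_3"],
    temporal_range="2024--P1Y",
    grid="au-extended-30",
    frequency="annual",
    fuse=True,
)
print(fused)
```

The printed string can be pasted after a `!` in a notebook cell. `shlex.join` wraps the JSON filter in single quotes so the shell passes it through as one argument.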

***
### Running save-tasks

Let's run save-tasks in a way that mimics the dataset loading we did in Section 1.  The `!` instructs the notebook to run this on the command line; odc-stats is built to run through a command line interface.  You could also trigger this by wrapping the command in `os.system('odc-stats save-tasks ...')`.

Note that this command will output datasets for all tiles in Australia. In the next step we will index the list of tiles so we only run one tile for testing.  Equally, we could have passed to save-tasks `--tiles <grid-index>` and it would only export datasets for a single tile or a list of tiles. We would need to know the index of the tile though for this to work.

Remember, this will output three files: a `.db` file, a `.geojson`, and a `.csv`.

This will take about a minute to run.

In [None]:
!odc-stats save-tasks --frequency annual --grid au-extended-30 --temporal-range 2024--P1Y --input-products ga_ls8c_ard_3

### Find a tile to run

Use the interactive map below to find a "region_code" to run (hover over a tile). Add the region_code numbers (e.g. `t = 36,17` if the region_code is 'x36y17') to the cell below the interactive map.

In [None]:
gdf  = gpd.read_file('~/gdata1/projects/s2_gm/ga_ls8c_ard_3_2024--P1Y-2024--P1Y.geojson')

gdf.explore()
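
If you would rather paste the region code straight from the map than retype the numbers, a small parser can turn it into the `(x, y)` pair. This sketch assumes region codes always follow the `x<number>y<number>` pattern seen in the example 'x36y17'; `parse_region_code` is a hypothetical helper, not an odc-stats function.

```python
import re


def parse_region_code(region_code):
    """Convert a region code such as 'x36y17' into an (x, y) tuple of ints."""
    match = re.fullmatch(r"x(-?\d+)y(-?\d+)", region_code.strip())
    if match is None:
        raise ValueError(f"Unrecognised region code: {region_code!r}")
    return int(match.group(1)), int(match.group(2))


t = parse_region_code("x36y17")
print(t)
```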

### Running ODC-Statistician

First, set up a few parameters

In [None]:
t = 36,17  # tile id to run i.e. x36y17
resolution = 60 # can coarsen resolution to run to speed up testing
results = '/gdata1/projects/s2_gm/results/' # where are we outputting resulting geotiffs? This could equally be an s3 path.
name, version = 'ndvi_ls_median', '0-0-1' # product name and version (appended to results path)

# Dask client parameters
ncpus=7 #how many cpus to run on?
mem='60Gi' # How much memory?

#### Find the tile index to run

We selected a region code, but the odc-stats "run" command expects a zero-based index to determine which tile to run.  Below we open the cached task database and find the index of the tile we want to run.  We'll pass this index to odc-stats next.

In [None]:
## Open the task database to find our tile, we need to create the OutputProduct class
#  to open the taskdb but it doesn't do anything.
op = OutputProduct(
            name=name,
            version=version,
            short_name=name,
            location=f"s3://dummy-bucket/{name}/{version}", #this is a fake path
            properties={"odc:file_format": "GeoTIFF"},
            measurements=['nbart_red'], #any measurements, doesn't matter.
        )

taskdb = TaskReader(f'ga_ls8c_ard_3_2024--P1Y.db', product=op)

#select our individual task i.e. our tile
task = taskdb.load_task((f'2024--P1Y', t[0], t[1]))

# Now find index of the tile we want to run
# We'll pass this index to odc-stats next
tile_index_to_run = []
all_tiles = list(taskdb.all_tiles)
for index, tile in enumerate(all_tiles):
    # compare the x and y entries of each tile against our selection
    if tile[1] == t[0] and tile[2] == t[1]:
        tile_index_to_run.append(index)
        print('Tile index =', index)
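
The lookup above matches on x and y only, which is fine when the task database holds a single year. If the database covered several periods, the same tile would appear more than once. The sketch below also matches on the period and fails loudly when nothing is found. It assumes each entry is a `(period, x, y)` tuple, as suggested by the argument passed to `load_task`; `find_tile_index` and the demo list are hypothetical.

```python
from typing import Sequence, Tuple

Tile = Tuple[str, int, int]  # (period, x, y), as passed to load_task above


def find_tile_index(tiles: Sequence[Tile], period: str, x: int, y: int) -> int:
    """Return the zero-based position of (period, x, y) in tiles."""
    matches = [i for i, tile in enumerate(tiles) if tuple(tile) == (period, x, y)]
    if not matches:
        raise LookupError(f"No task found for {period} x{x}y{y}")
    return matches[0]


# made-up task list standing in for list(taskdb.all_tiles)
demo_tiles = [
    ("2023--P1Y", 36, 17),
    ("2024--P1Y", 35, 17),
    ("2024--P1Y", 36, 17),
]

index = find_tile_index(demo_tiles, "2024--P1Y", 36, 17)
print("Tile index =", index)
```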

#### Optionally view tile to check location

The next cell will plot the tile extent on an interactive map so you can ensure it's the tile you want to run.

In [None]:
gdf = gpd.GeoDataFrame(index=[0], crs='epsg:4326', geometry=[task.geobox.extent.to_crs('epsg:4326').geom])
gdf.explore()

#### Running the plugin using odc-stats

This is where it gets a little complicated. In order for `odc-stats` to 'see' our plugin, we need to put our plugin within an **installable python module** (note that [plugins](https://github.com/opendatacube/odc-stats/tree/develop/odc/stats/plugins) that are already within the odc-stats repository are available to use by default).  Similarly, we need to put our **configuration parameters into an external .yaml file**.

This has already been done (the file is called `config_ndvi_ls_median.yaml`), and below we open the config yaml to view its contents:

The important parts are:
* `plugin`: an import path to the installed python plugin
* `plugin_config`: The parameters names and values that are passed to the plugin
* `product`: These are key metadata fields for the product you're building

The other parameters relate to batch runs on kubernetes, and the exported COG attributes such as compression levels etc.  These fields can usually be copied over from other similar product configs, such as this example for [Landsat geomedian](https://github.com/GeoscienceAustralia/dea-config/blob/3953ea18eee702a41867458c720e7480bd785c10/prod/services/odc-stats/geomedian/ga_ls8cls9c_gm_cyear_3.yaml). 
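
As a rough picture of how those three sections fit together, here is the same structure written as a Python dictionary, with a check that the sections are present. This is illustrative only: the import path is hypothetical, the exact fields under `product` are an assumption (the names mirror those used for `OutputProduct` earlier), and the real file contains additional batch and COG settings not shown here.

```python
import json

REQUIRED_SECTIONS = ("plugin", "plugin_config", "product")


def missing_sections(config: dict) -> list:
    """List the top-level sections discussed above that are absent from config."""
    return [key for key in REQUIRED_SECTIONS if key not in config]


# Illustrative only: the import path and the `product` fields are assumptions
example_config = {
    "plugin": "my_tools.ndvi.StatsNDVI",  # hypothetical import path
    "plugin_config": {
        "input_bands": ["nbart_red", "nbart_nir"],
        "output_bands": "ndvi_median",
        "mask_band": "oa_fmask",
        "contiguity_band": "oa_nbart_contiguity",
    },
    "product": {
        "name": "ndvi_ls_median",
        "short_name": "ndvi_ls_median",
        "version": "0-0-1",
    },
}

print(missing_sections(example_config))
print(json.dumps(example_config["plugin_config"], indent=2))
```

The same check can be applied to the dictionary returned by `yaml.safe_load` once the real config has been read.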

In [None]:
yaml_path= 'config_ndvi_ls_median.yaml'

with open(yaml_path, 'r') as file:
    data = yaml.safe_load(file)

pprint(data)

#### Install our python plugin

Describing how to install a python package is beyond the scope of this notebook. A simple example of setting up an installable module on a local machine is [here](https://github.com/digitalearthafrica/crop-mask/tree/main/production/cm_tools).

In [None]:
# !pip install s2_gm_tools/

#### Run odc-stats

We will use `os.system('odc-stats run ...')` to call the command 'odc-stats run' so we can pass in variables defined earlier as python objects.

"odc-stats run" has a number of parameters, some of which are described here. The key information to pass in is the name of the cached database file, the location of the config file, the output location, and the tile index to run.

* **--filedb**: The name of the .db output by save-tasks e.g. "ga_ls8c_ard_3_2024--P1Y.db"
* **--config**: Path to the config for plugin in yaml format
* **--location**: Output location prefix as a uri or local file path: `s3://bucket/path/`
* **--resolution**: Override output resolution, use this to speed up testing, e.g. '60'
* **--threads**: Number of worker threads for the dask cluster, as an integer
* **--memory-limit**: Limit memory used by Dask cluster, e.g. '100Gi'

To monitor the progress of odc-stats, open the **dask dashboard**. Replace the email address in this link with your own:

https://app.sandbox.dea.ga.gov.au/user/chad.burton@ga.gov.au/proxy/8787/status

In [None]:
%%time
os.system(
    "odc-stats run "
    "--filedb=ga_ls8c_ard_3_2024--P1Y.db "
    f"--config={yaml_path} "
    f"--resolution={resolution} "
    f"--threads={ncpus} "
    f"--memory-limit={mem} "
    f"--location=file:///home/jovyan/{results}{name}/{version} "
    f"{tile_index_to_run[0]}"
)
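
An alternative to building one long string is to assemble the arguments as a list, which avoids quoting problems and makes optional flags easy to switch on and off. The sketch below only builds and prints the command. `build_run_command` is a hypothetical helper, the output location is a placeholder path, and the tile index of 0 is a dummy value to be replaced with the index found earlier.

```python
import shlex


def build_run_command(filedb, config, location, tile_index,
                      resolution=None, threads=None, memory_limit=None):
    """Assemble the `odc-stats run` arguments as a list; nothing is executed."""
    cmd = [
        "odc-stats", "run",
        f"--filedb={filedb}",
        f"--config={config}",
        f"--location={location}",
    ]
    # optional overrides are only added when a value is given
    for flag, value in (("--resolution", resolution),
                        ("--threads", threads),
                        ("--memory-limit", memory_limit)):
        if value is not None:
            cmd.append(f"{flag}={value}")
    cmd.append(str(tile_index))  # positional: zero-based task index
    return cmd


cmd = build_run_command(
    filedb="ga_ls8c_ard_3_2024--P1Y.db",
    config="config_ndvi_ls_median.yaml",
    location="file:///tmp/results/ndvi_ls_median/0-0-1",  # placeholder path
    tile_index=0,  # dummy value; use the index found earlier
    resolution=60,
    threads=7,
    memory_limit="60Gi",
)
print(shlex.join(cmd))
```

Passing this list to `subprocess.run(cmd, check=True)` would execute it and raise an exception on a non-zero exit code, whereas `os.system` simply returns the exit status.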

### Plot the results

In [None]:
x= f'x{t[0]}'
y= f'y{t[1]}'

ndvi_path = f'{results}{name}/{version}/{x}/{y}/2024--P1Y/{name}_{x}{y}_2024--P1Y_final_ndvi_median.tif'
ndvi=rxr.open_rasterio(ndvi_path).squeeze().drop_vars('band')
ndvi=assign_crs(ndvi, crs='EPSG:3577')

ndvi.plot(vmin=0, vmax=0.7, size=8, add_labels=False)
plt.title('Annual Median NDVI');
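
If the product has several output bands, or several tiles were run, the file path above can be generated rather than typed. This sketch simply reproduces the folder and file naming seen in `ndvi_path`; it is inferred from that one example rather than from odc-stats documentation. `output_tif_path` is a hypothetical helper, the `results` folder is a placeholder, and treating the `final` token as a maturity label is an assumption.

```python
from pathlib import PurePosixPath


def output_tif_path(results, name, version, tile, period, band, maturity="final"):
    """Rebuild the GeoTIFF path layout used in the plotting cell above."""
    x, y = f"x{tile[0]}", f"y{tile[1]}"
    filename = f"{name}_{x}{y}_{period}_{maturity}_{band}.tif"
    return PurePosixPath(results) / name / version / x / y / period / filename


path = output_tif_path(
    results="results",  # placeholder output folder
    name="ndvi_ls_median",
    version="0-0-1",
    tile=(36, 17),
    period="2024--P1Y",
    band="ndvi_median",
)
print(path)
```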

## Remove all files

In [None]:
# !rm -r -f results/ndvi_ls_median/