# Annual Sentinel-2 Geomedian run with odc-stats

Useful links:
* [odc-stats](https://github.com/opendatacube/odc-stats)
* [crop-mask plugin](https://github.com/digitalearthafrica/crop-mask/blob/main/production/cm_tools/cm_tools/gm_ml_pred.py)
* [odc-algo geomedians](https://github.com/opendatacube/odc-algo/blob/main/odc/algo/_geomedian.py#L337)
* [example geomedian config files](https://github.com/GeoscienceAustralia/dea-config/tree/09fa937a9c79e3505e85d2364a30bc002ca0c5f3/dev/services/odc-stats/geomedian)
* DEA-config for other [geomedian runs](https://github.com/GeoscienceAustralia/dea-config/tree/09fa937a9c79e3505e85d2364a30bc002ca0c5f3/dev/services/odc-stats/geomedian)

In [None]:
# Remove any installed copy of s2_gm_tools, then reinstall it from the local
# `s2_gm_tools/` directory (presumably to pick up local edits to the package).
!pip uninstall s2_gm_tools -y
!pip install s2_gm_tools/

In [None]:
import os
import json
import warnings
import xarray as xr
import rioxarray as rxr
import geopandas as gpd
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs
from odc.stats.tasks import TaskReader
from odc.stats.model import OutputProduct

# Suppress all Python warnings from here on to keep the notebook output readable.
warnings.filterwarnings("ignore")

## Analysis Parameters

Some tile IDs to run:
* 'x43y14' # SE Australian forests (Alps)
* 'x39y09' # Western Tasmania
* 'x33y26' # Central Australia with salt lakes
* 'x31y43' # Tropical NT
* 'x19y18' # Esperance crops and sand dunes
* 'x42y38' # Qld tropical forests
* 'x39y13' # Melbourne city and bay + crops
* 'x12y19' # Perth city
* 'x41y12' # Complex coastal area in Vic.

In [None]:
# Disabled one-off helper, kept for reference: subsets the testing tile suite
# GeoJSON to the region codes listed in `tiles` and writes the result to a new file.
# tiles = ['x30y34','x36y52','x61y30','x58y22','x57y28', 'x61y29', 'x64y32', 'x65y40', 'x60y53' ,'x55y51', 'x46y58', 'x46y46', 'x36y34']
# gdf = gpd.read_file('~/gdata1/projects/s2_gm/testing_tile_suite.geojson')

# gdf = gdf[gdf['region_code'].isin(tiles)]
# gdf.reset_index(drop=True).to_file('~/gdata1/projects/s2_gm/testing_tile_suite_13tiles.geojson')

In [None]:
# --- Run configuration -------------------------------------------------------
# Year to process; interpolated into the task-database name and the CLI calls.
year = '2022'

# Tile id to run as (x, y), i.e. x19y18.
t = (19, 18)

# Can coarsen resolution to run to speed up testing.
resolution = 30

# Hyphen-joined input products: use all S2 observations.
products = 'ga_s2am_ard_3-ga_s2bm_ard_3-ga_s2cm_ard_3'

# Output product name and version.
name = 'ga_s2_gm_cyear_3'
version = '0-0-1'

# Where are we outputting results?
results = '/gdata1/projects/s2_gm/results/'

# Passed to `odc-stats run` as --threads and --memory-limit respectively.
ncpus = 30
mem = '220Gi'

## Save tasks database etc.

In [None]:
# Assemble the `odc-stats save-tasks` command for the chosen year and input
# products. Later cells open `{products}_{year}--P1Y.db`, which is presumably
# the task database this command writes into the working directory.
save_tasks_cmd = " ".join([
    "odc-stats save-tasks",
    "--grid au-10",
    f"--year {year}",
    f"--input-products {products}",
])

# The shell exit status is the cell's output value.
os.system(save_tasks_cmd)

## Find the tile ID to run

We'll pass this index to odc-stats next to tell it to run this tile

In [None]:
## Open the task database to find out tiles
# TaskReader is given an OutputProduct here; the location points at a dummy
# bucket and only a single measurement is declared, so this product looks like
# a placeholder used just to open the database (NOTE(review): confirm).
op = OutputProduct(
    name=name,
    version=version,
    short_name=name,
    location=f"s3://dummy-bucket/{name}/{version}",
    properties={"odc:file_format": "GeoTIFF"},
    measurements=['nbart_red'],
)

taskdb = TaskReader(f'{products}_{year}--P1Y.db', product=op)
task = taskdb.load_task((f'{year}--P1Y', t[0], t[1]))

# Now find index of the tile we want to run
# We'll pass this index to odc-stats next to tell it to run this tile
# Each entry of `all_tiles` is indexed as (period, x, y); match on x and y.
all_tiles = list(taskdb.all_tiles)
tile_index_to_run = [
    index
    for index, tile in enumerate(all_tiles)
    if tile[1] == t[0] and tile[2] == t[1]
]

# Fail here with a clear message rather than with a bare IndexError later,
# when the run cell reads tile_index_to_run[0].
if not tile_index_to_run:
    raise ValueError(
        f"Tile x{t[0]}y{t[1]} was not found in {products}_{year}--P1Y.db"
    )

for index in tile_index_to_run:
    print(index)

### Optionally view tile to check location

The next cell will plot the tile extent on an interactive map so you can ensure it's the tile you want to run.

In [None]:
# Alternative (disabled): dump the tile footprint to a GeoJSON file instead.
# with open('task_tile_check.geojson', 'w') as fh:
#     json.dump(task.geobox.extent.to_crs('epsg:4326').json, fh, indent=2)

# Reproject the task's footprint to lat/lon and wrap it in a one-row
# GeoDataFrame so it can be shown on an interactive map.
tile_footprint = task.geobox.extent.to_crs('epsg:4326')
gdf = gpd.GeoDataFrame(
    geometry=[tile_footprint.geom],
    index=[0],
    crs='epsg:4326',
)
gdf.explore()

## Run the geomedian algo using odc-stats

Paste this link into the Dask dashboard panel to view progress, replacing the email address with your own: https://app.sandbox.dea.ga.gov.au/user/chad.burton@ga.gov.au/proxy/8787/status

In [None]:
# Reinstall s2_gm_tools from the local `s2_gm_tools/` directory before running
# (presumably so the run uses the current on-disk version of the package).
!pip uninstall s2_gm_tools -y
!pip install s2_gm_tools/

In [None]:
%%time
os.system("odc-stats run "\
          f"{products}_{year}--P1Y.db "\
          "--config=s2_gm_tools/s2_gm_tools/config/config_gm_s2_annual_s2Cloudless_enhanced.yaml "\
          f"--resolution={resolution} "\
          f"--threads={ncpus} "\
          f"--memory-limit={mem} "\
          f"--location=file:///home/jovyan/{results}{name}/{version} " +str(tile_index_to_run[0])
         )

## Plot the RGBA output

In [None]:
# Product name/version and results directory used by the plotting cells below.
# Uncomment `t` to look at a different tile than the one just run.
# t = 3,19  # tile id
name = 'ga_s2_gm_cyear_3'
version = '0-0-1'
results = '/gdata1/projects/s2_gm/results/'

In [None]:
# Tile labels as they appear in the output directory layout and file names.
x, y = f'x{t[0]}', f'y{t[1]}'

# Path to the RGBA GeoTIFF for this tile and year.
path = f'{results}{name}/{version}/{x}/{y}/{year}--P1Y/{name}_{x}{y}_{year}--P1Y_final_rgba.tif'

# Open the raster and tag it with EPSG:3577.
rgba = assign_crs(rxr.open_rasterio(path), crs='EPSG:3577')

# Trailing semicolons suppress the notebook's repr output.
rgba.plot.imshow(size=10);
plt.title(x + y);

## Interactively explore results

In [None]:
# Every per-band output for this tile/year shares the same path prefix and
# differs only in the band suffix.
_band_prefix = f'{results}{name}/{version}/{x}/{y}/{year}--P1Y/{name}_{x}{y}_{year}--P1Y_final_'
red_path = f'{_band_prefix}nbart_red.tif'
green_path = f'{_band_prefix}nbart_green.tif'
blue_path = f'{_band_prefix}nbart_blue.tif'
count_path = f'{_band_prefix}count.tif'


def _load_band(tif_path, var_name):
    """Open a GeoTIFF, squeeze it, drop the 'band' coord, tag EPSG:3577 and rename it."""
    band = rxr.open_rasterio(tif_path).squeeze().drop_vars('band')
    return assign_crs(band, crs='EPSG:3577').rename(var_name)


r = _load_band(red_path, 'nbart_red')
g = _load_band(green_path, 'nbart_green')
b = _load_band(blue_path, 'nbart_blue')

# Combine the three bands into one Dataset and re-attach the CRS.
ds = assign_crs(xr.merge([r, g, b]), crs='EPSG:3577')

In [None]:
# Load the `count` band (presumably the per-pixel number of clear
# observations; confirm against the plugin) and tag it with EPSG:3577.
count = rxr.open_rasterio(count_path).squeeze().drop_vars('band')
count = assign_crs(count, crs='EPSG:3577')

# Summary statistics over the whole tile, as plain Python scalars.
min_clear_updated = count.min().item()
mean_clear_updated = count.mean().item()
max_clear_updated = count.max().item()

summary = f'Updated masking clear counts (min, mean, max) = {min_clear_updated}, {mean_clear_updated:.0f}, {max_clear_updated}'
print(summary)

# Fixed colour limits so that plots of different tiles are comparable.
count.plot(cmap='magma', vmin=10, vmax=90, size=8);

In [None]:
# Stretch the display between the 1st and 99th percentiles taken across all
# three RGB bands together.
rgb_stack = ds[['nbart_red', 'nbart_green', 'nbart_blue']].to_array()
vmin, vmax = rgb_stack.quantile((0.01, 0.99)).values

ds.odc.explore(
    vmin=vmin,
    vmax=vmax,
    # Optional Esri satellite basemap (disabled):
    # tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
    # attr = 'Esri',
    # name = 'Esri Satellite'
)

## Remove all files

In [None]:
# Recursively delete the product's output folder under the relative `results/`
# directory; -f means a missing path is not an error.
# NOTE(review): this is relative to the working directory, whereas the
# `results` variable used for writing/reading outputs is an absolute path.
# Confirm this removes the files you intend.
!rm -r -f results/ga_s2_gm_cyear_3/