# Testing development of annual Sentinel-2 Geomedian

* [Seasonal geomedian report](https://geoscienceau-my.sharepoint.com/:w:/g/personal/james_miller_ga_gov_au/EVFJCVoYGdRPlpG_MCh95DIB3n5Earq-n9766FWj4lFBZw?e=ihzQer)
* [Cloud masking report](https://geoscienceau.sharepoint.com/:w:/r/sites/DEA/Shared%20Documents/Projects%20%26%20Programs/DEA%20Land%20Cover/Cloud%20Masking%20Validation/Cloud%20Masking%20Technical%20Report%20-%20Public.docx?d=w1c5e6e3d0f664d35ada2eba5b7cce187&csf=1&web=1&e=ieSD8S)
* [DE Africa odc-plugin](https://github.com/opendatacube/odc-stats/blob/37e41140515ea1a5f7033b2144b22d8f43a231f6/odc/stats/plugins/gm.py#L131), used this mask filter: `[("opening", 2), ("dilation", 5)]`
* [odc-algo geomedian functions](https://github.com/opendatacube/odc-algo/blob/main/odc/algo/_geomedian.py#L337)

Notes on the geomedians (GMs) from coastal:
* They run everything at 10 m resolution, with cubic resampling of the 20 m bands.
* They use s2cloudless with no additional mask filters.

## Import libraries

In [None]:
import os
import warnings
import datacube
import xarray as xr
import geopandas as gpd
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt
from odc.geo.geom import Geometry
from odc.algo import geomedian_with_mads

import sys
sys.path.insert(1, '/home/jovyan/git/dea-notebooks/Tools/')
from dea_tools.datahandling import load_ard
from dea_tools.dask import create_local_dask_cluster
from dea_tools.plotting import rgb

warnings.filterwarnings("ignore")

In [None]:
# Call the dea_tools helper (presumably starts a local Dask cluster) and keep
# the returned client object for the rest of the session.
client = create_local_dask_cluster(return_client=True)

## Analysis Parameters

In [None]:
# Region codes of the tiles to process, with a short description of each.
region_codes = [
    'x43y14',  # se aus forests Alps.
    'x39y09',  # West tassie
    'x33y26',  # Central Aus with salt lakes
    'x31y43',  # Tropical NT
    'x19y18',  # Esperance crops and sand dunes
    'x42y38',  # Qld tropical forests
    'x39y13',  # Melbourne city and bay+crops
    'x12y19',  # Perth City
    'x41y12',  # Complex coastal in Vic.
]

time = '2022'
resolution = (-10, 10)

# Morphological filters applied to the fmask cloud mask: (operation, size).
mask_filters = [("opening", 3), ("dilation", 6)]

# Label used in the output folder names. Derived from the filter sizes
# (here '3_6') so it cannot drift from the filters that were actually applied.
filters_id = '_'.join(str(size) for _, size in mask_filters)

measurements = ['nbart_green', 'nbart_red', 'nbart_blue']
dask_chunks = dict(x=1000, y=1000)

## Set up dc query

In [None]:
# Open a connection to the datacube
dc = datacube.Datacube(app='s2_gm_test')

# Load settings that are reused for every tile
query = dict(
    time=time,
    measurements=measurements,
    resolution=resolution,
    dask_chunks=dask_chunks,
    group_by='solar_day',
    output_crs='EPSG:3577',
)

## Open tiles and select

In [None]:
# Read the summary grid and keep only the tiles listed in region_codes
gdf = (
    gpd.read_file('~/gdata1/data/albers_grids/ga_summary_grid_c3.geojson')
    .loc[lambda tiles: tiles['region_code'].isin(region_codes)]
)

In [None]:
# gdf.explore(
#         tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#         attr = 'Esri',
#         name = 'Esri Satellite',
# )

## Run function for geomedians

In [None]:
def _geomedian_in_memory(ds, crs):
    """Compute the geomedian of ``ds``, load it and label it with ``crs``."""
    gm = geomedian_with_mads(
        ds,
        reshape_strategy='mem',
        compute_mads=False
    )
    gm = assign_crs(gm.load(), crs=crs)

    # Drop any per-variable 'grid_mapping' attribute before export
    # (presumably it conflicts with what to_netcdf writes - TODO confirm).
    for var in gm.data_vars:
        gm[var].attrs.pop('grid_mapping', None)

    return gm


def gm_run_function(dc, geom, query, time, region_code, filters_id):
    """
    Build, export and plot annual Sentinel-2 geomedians for one tile.

    Two geomedians are computed from the same query: one from data masked
    with fmask (plus the module-level ``mask_filters``) and one from data
    masked with s2cloudless (no extra filters). Each is written to NetCDF
    and both are plotted side by side, with the fmask 'count' band, in a
    PNG.

    Parameters
    ----------
    dc : datacube connection passed to ``load_ard``.
    geom : geometry used as the query's 'geopolygon'.
    query : dict of load settings forwarded to ``load_ard``. It is copied,
        not modified. It must contain 'output_crs', which is also used to
        label the outputs.
    time : label used only in output file names and plot titles; the
        loaded time range comes from ``query``.
    region_code : tile identifier used in output file names and titles.
    filters_id : label used as the sub-folder name for the outputs.

    Returns
    -------
    None. Results are written under '/gdata1/projects/s2_gm/results/'.

    Notes
    -----
    The fmask load reads ``mask_filters`` from module scope rather than
    from a parameter, so ``filters_id`` must be kept in step with it.
    """
    # Work on a copy so the caller's reusable query dict is left untouched
    load_query = {**query, 'geopolygon': geom}

    # The data are loaded in this CRS, so the outputs are labelled with it
    out_crs = load_query['output_crs']

    products = ['ga_s2am_ard_3', 'ga_s2bm_ard_3', 'ga_s2cm_ard_3']
    rgb_bands = ['nbart_red', 'nbart_green', 'nbart_blue']

    # Load available data from Sentinel 2 satellites, masked with fmask
    s2_fmask = load_ard(dc=dc,
                        products=products,
                        resampling={"oa_fmask": "nearest", "*": "cubic"},
                        mask_filters=mask_filters,
                        verbose=False,
                        mask_contiguity=True,
                        skip_broken_datasets=True,
                        **load_query
                        )

    # Same data, masked with s2cloudless and no additional filters
    s2_cloudless = load_ard(dc=dc,
                            products=products,
                            cloud_mask='s2cloudless',
                            resampling={"oa_s2cloudless_mask": "nearest", "*": "cubic"},
                            verbose=False,
                            mask_contiguity=True,
                            skip_broken_datasets=True,
                            **load_query
                            )

    # s2_cloudless = s2_cloudless.isel(time=range(0,10))
    # s2_fmask = s2_fmask.isel(time=range(0,10))

    # ---export locations-----------------
    tiles_folder = f'/gdata1/projects/s2_gm/results/tiles/{filters_id}/'
    figs_folder = f'/gdata1/projects/s2_gm/results/figs/{filters_id}/'
    os.makedirs(tiles_folder, exist_ok=True)

    # ---geomedians----------------------
    # ---------------Fmask-
    s2_fmask_gm = _geomedian_in_memory(s2_fmask, out_crs)
    s2_fmask_gm[rgb_bands + ['count']].to_netcdf(
        os.path.join(tiles_folder, f's2_gm_annual_fmask_{time}_{region_code}.nc')
    )

    # ---------------S2 Cloudless-----
    s2_cloudless_gm = _geomedian_in_memory(s2_cloudless, out_crs)
    s2_cloudless_gm[rgb_bands + ['count']].to_netcdf(
        os.path.join(tiles_folder, f's2_gm_annual_s2cloudless_{time}_{region_code}.nc')
    )

    # ------------plot-----------------------
    fig, ax = plt.subplots(1, 3, figsize=(20, 6), layout='constrained')

    s2_cloudless_gm[rgb_bands].to_array().plot.imshow(robust=True, ax=ax[0], add_labels=False)
    s2_fmask_gm[rgb_bands].to_array().plot.imshow(robust=True, ax=ax[1], add_labels=False)
    s2_fmask_gm['count'].plot.imshow(robust=True, ax=ax[2], add_labels=False)

    ax[0].set_title(f'S2 annual GM, s2cloudless, {time}, {region_code}')
    ax[1].set_title(f'S2 annual GM, fmask, {time}, {region_code}')
    ax[2].set_title(f'fmask count, {time}, {region_code}')

    # Hide tick labels on every panel
    for panel in ax:
        panel.set_yticklabels([])
        panel.set_xticklabels([])

    os.makedirs(figs_folder, exist_ok=True)

    # Save via the figure object so the right figure is written regardless
    # of which figure pyplot currently considers active.
    fig.savefig(
        os.path.join(figs_folder, f's2_gm_annual_{time}_{region_code}.png'),
        bbox_inches='tight',
        dpi=300,
    )

## Loop through tiles and export images


In [None]:
# Loop through polygons in geodataframe and extract satellite data
# Run the geomedian workflow for each selected tile, reporting progress
n_tiles = len(gdf)
for i, (_index, row) in enumerate(gdf.iterrows(), start=1):
    print(f'Feature: {i}/{n_tiles}')

    # Wrap the tile polygon together with the grid's CRS
    geom = Geometry(geom=row.geometry, crs=gdf.crs)

    gm_run_function(dc, geom, query, time, row['region_code'], filters_id)