# Testing the development of an annual Sentinel-2 Geomedian

* [Seasonal geomedian report](https://geoscienceau-my.sharepoint.com/:w:/g/personal/james_miller_ga_gov_au/EVFJCVoYGdRPlpG_MCh95DIB3n5Earq-n9766FWj4lFBZw?e=ihzQer)
* [Cloud masking report](https://geoscienceau.sharepoint.com/:w:/r/sites/DEA/Shared%20Documents/Projects%20%26%20Programs/DEA%20Land%20Cover/Cloud%20Masking%20Validation/Cloud%20Masking%20Technical%20Report%20-%20Public.docx?d=w1c5e6e3d0f664d35ada2eba5b7cce187&csf=1&web=1&e=ieSD8S)
* [DE Africa odc-stats geomedian plugin](https://github.com/opendatacube/odc-stats/blob/37e41140515ea1a5f7033b2144b22d8f43a231f6/odc/stats/plugins/gm.py#L131), which uses this mask filter: `[("opening", 2), ("dilation", 5)]`
* [odc-algo geomedian functions](https://github.com/opendatacube/odc-algo/blob/main/odc/algo/_geomedian.py#L337)

## Import libraries

In [None]:
import os
import warnings
import xarray as xr
import numpy as np
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt

import sys
sys.path.insert(1, '/home/jovyan/git/dea-notebooks/Tools/')
from dea_tools.plotting import rgb

warnings.filterwarnings("ignore")

## Analysis Parameters

In [None]:
# Candidate test tiles, keyed by tile code. Set `region_code` below to one of them.
CANDIDATE_REGIONS = {
    'x12y19': 'Perth City',
    'x43y14': 'SE Aus forests (Alps)',
    'x39y09': 'West Tassie',
    'x33y26': 'Central Aus with salt lakes',
    'x31y43': 'Tropical NT',
    'x19y18': 'Esperance crops and sand dunes',
    'x42y38': 'Qld tropical forests',
    'x39y13': 'Melbourne city and bay + crops',
    'x41y12': 'Complex coastal in Vic.',
}

# Tile to analyse (currently the Qld tropical forests tile).
region_code = 'x42y38'

# Selects the results sub-directory and is shown in the fmask plot title.
# NOTE(review): presumably encodes the cloud-mask morphology filter sizes — confirm.
morph_params = '3_6'

# Year of the annual geomedian, used in the input/output file names.
time = '2022'


## Load tile

In [None]:
def _load_tile(mask_name):
    """Open the annual geomedian tile produced with the given cloud mask, tagged as EPSG:3577."""
    tile_path = (
        f'/gdata1/projects/s2_gm/results/tiles/{morph_params}/'
        f's2_gm_annual_{mask_name}_{time}_{region_code}.nc'
    )
    return assign_crs(xr.open_dataset(tile_path), crs='EPSG:3577')


# Both products share one directory layout and differ only in the mask name.
s2cloudless = _load_tile('s2cloudless')
fmask = _load_tile('fmask')

## Check for NaNs in outputs

In [None]:
# Report, for each product and each visible band, how many NaN pixels
# remain in the geomedian output.
rgb_band_names = ('nbart_red', 'nbart_green', 'nbart_blue')

for product_label, product in (('fmask', fmask), ('s2Cloudless', s2cloudless)):
    for band_name in rgb_band_names:
        nan_count = int(np.isnan(product[band_name]).sum())
        if nan_count > 0:
            print(f'{nan_count} NaNs present in {product_label} {band_name}')
        else:
            print(f'{product_label} {band_name} is clean')

## True colour and count plots

In [None]:
def _count_summary(product):
    """Return (min, mean, max) of the per-pixel clear-observation ``count`` band as Python scalars."""
    clear_count = product['count']
    return clear_count.min().item(), clear_count.mean().item(), clear_count.max().item()


min_clear_fmask, mean_clear_fmask, max_clear_fmask = _count_summary(fmask)
min_clear_s2cloudless, mean_clear_s2cloudless, max_clear_s2cloudless = _count_summary(s2cloudless)

print(f'FMASK (min, mean, max) = {min_clear_fmask}, {mean_clear_fmask:.0f}, {max_clear_fmask}')
print(f'S2Cloudless (min, mean, max) = {min_clear_s2cloudless}, {mean_clear_s2cloudless:.0f}, {max_clear_s2cloudless}')

### Replace NaNs with a vibrant pink colour

In [None]:
# Paint no-data pixels a vibrant pink (red=10000, green=0, blue=10000) so that
# gaps in each geomedian stand out clearly in the true-colour plots below.
PINK_FILL = {'nbart_red': 10000, 'nbart_green': 0, 'nbart_blue': 10000}


def paint_nans_pink(ds, fill=None):
    """Return a shallow copy of ``ds`` with missing pixels painted pink.

    A pixel counts as missing when it is NaN in *any* of the bands named in
    ``fill`` (default ``PINK_FILL``). A gap in a single band is therefore
    highlighted too, not only gaps in ``nbart_red``. At those pixels each of
    the bands is set to its ``fill`` value. Bands listed in ``fill`` but absent
    from ``ds`` are skipped. Other variables (e.g. ``count``) are left untouched.

    Parameters
    ----------
    ds : xarray.Dataset
        Geomedian dataset containing the colour bands.
    fill : mapping of band name -> fill value, optional
        Per-band replacement value. Defaults to ``PINK_FILL``.

    Returns
    -------
    xarray.Dataset
        A new dataset; ``ds`` itself is not modified.
    """
    fill = PINK_FILL if fill is None else fill
    bands = [band for band in fill if band in ds.data_vars]
    if not bands:
        return ds

    # Union of NaNs across the colour bands, so any-band gaps are caught.
    nan_mask = np.isnan(ds[bands[0]])
    for band in bands[1:]:
        nan_mask = nan_mask | np.isnan(ds[band])

    painted = ds.copy()
    for band in bands:
        # Keep valid pixels; substitute the fill value where the mask is set.
        painted[band] = ds[band].where(~nan_mask, fill[band])
    return painted


fmask = paint_nans_pink(fmask)
s2cloudless = paint_nans_pink(s2cloudless)

### Plot

In [None]:
# Two-by-two summary figure: one row per cloud mask, with the true-colour
# geomedian on the left and the clear-observation count on the right.
fig, ax = plt.subplots(2, 2, figsize=(15, 12), layout='constrained')
vmin, vmax = 10, 90  # fixed colour range shared by both count panels

figure_rows = [
    (fmask,
     f'fmask, morph-params={morph_params}',
     f'fmask, clear count. Mean={mean_clear_fmask:.1f}'),
    (s2cloudless,
     's2cloudless, default settings',
     f's2cloudless, clear count. Mean={mean_clear_s2cloudless:.1f}'),
]

for row_index, (product, rgb_title, count_title) in enumerate(figure_rows):
    rgb_ax, count_ax = ax[row_index]

    product[['nbart_red', 'nbart_green', 'nbart_blue']].to_array().plot.imshow(
        robust=True, ax=rgb_ax, add_labels=False)
    product['count'].plot.imshow(
        vmin=vmin, vmax=vmax, cmap='magma', ax=count_ax, add_labels=False)

    rgb_ax.set_title(rgb_title)
    count_ax.set_title(count_title)

    # Hide the coordinate tick labels on both panels of the row.
    for panel in (rgb_ax, count_ax):
        panel.set_yticklabels([])
        panel.set_xticklabels([])

plt.savefig(f'/gdata1/projects/s2_gm/results/processed_figs/s2_gm_annual_{morph_params}_{time}_{region_code}.png', bbox_inches='tight', dpi=300);

## Interactive plots

In [None]:
# Interactive clear-observation count map over Esri satellite imagery.
# Set `count_product = s2cloudless` to inspect the other cloud mask instead.
count_product = fmask
vmin, vmax = 10, 90

esri_imagery = {
    'tiles': 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
    'attr': 'Esri',
    'name': 'Esri Satellite',
}

count_product['count'].odc.explore(vmin=vmin, vmax=vmax, cmap='magma', **esri_imagery)

In [None]:
# Interactive true-colour map over Esri satellite imagery, stretched between
# the 1st and 99th percentiles of the red/green/blue bands.
# Set `rgb_product = fmask` to inspect the other cloud mask instead.
# NOTE(review): if the NaN-to-pink cell has already run, the painted pixels
# (values 0 / 10000) are included in these percentiles.
rgb_product = s2cloudless

vmin, vmax = (
    rgb_product[['nbart_red', 'nbart_green', 'nbart_blue']]
    .to_array()
    .quantile((0.01, 0.99))
    .values
)

rgb_product.odc.explore(
    vmin=vmin,
    vmax=vmax,
    tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
    attr='Esri',
    name='Esri Satellite',
)