# Can we improve cloud masking?

For example:
* Thresholding blue bands?
* Utilising S2Cloudless probability layer?

## Import libraries

In [None]:
import os
import warnings
import datacube
import numpy as np
import xarray as xr
import geopandas as gpd
from odc.geo.xr import assign_crs
import matplotlib.pyplot as plt
from odc.algo import mask_cleanup
from odc.geo.geom import Geometry
from odc.algo import xr_quantile
from odc.algo._masking import mask_cleanup
from odc.algo import geomedian_with_mads

import sys
sys.path.insert(1, '/home/jovyan/git/dea-notebooks/Tools/')
from dea_tools.datahandling import load_ard
from dea_tools.dask import create_local_dask_cluster
from dea_tools.plotting import rgb

warnings.filterwarnings("ignore")

In [None]:
# Start a local Dask cluster and keep a handle on its client.
client = create_local_dask_cluster(
    return_client=True,
)

## Analysis Parameters

Candidate tiles for testing:

* 'x33y26' # Central Australia with salt lakes
* 'x19y18' # Esperance crops and sand dunes
* 'x42y38' # Queensland tropical forests
* 'x41y12' # Complex coastline in Victoria
* 'x39y13' # Melbourne city, bay and crops
* 'x39y09' # Western Tasmania
* 'x40y07' # South-west Tasmania
* 'x31y43' # Tropical NT

In [None]:
# Tile(s) from the GA summary grid; only the first entry is used for the extent.
region_code = ['x42y38']

# Load settings for the annual composite
time = '2022'
resolution = (-30, 30)
dask_chunks = {'x': 1024, 'y': 1024}

# Cloud-masking settings
s2cloudless_threshold = 0.4  # default S2Cloudless cloud-probability threshold
cp_threshold = 0.1  # long-term 10th-percentile probability above which the threshold is raised
mask_filters = [("opening", 2), ("dilation", 3)]  # morphological clean-up of the enhanced mask

## Set up dc query

In [None]:
# Connect to the Open Data Cube
dc = datacube.Datacube(app='s2_gm_test')

# Load arguments shared by the surface-reflectance loads for the analysis year
query = dict(
    time=time,
    resolution=resolution,
    dask_chunks=dask_chunks,
    group_by='solar_day',
    output_crs='EPSG:3577',
)

## Open tiles and select

In [None]:
# Read the GA summary tile grid and keep only the requested tile(s).
gdf = gpd.read_file('~/gdata1/data/albers_grids/ga_summary_grid_c3.geojson')
gdf = gdf[gdf['region_code'].isin(region_code)]

# Fail early with a clear message instead of an IndexError from .iloc[0] below.
if gdf.empty:
    raise ValueError(f'No tiles in the summary grid match region_code={region_code}')

# Only the first matching tile is used as the load extent.
geom = Geometry(geom=gdf.iloc[0].geometry, crs=gdf.crs)

query.update({'geopolygon': geom})

In [None]:
# gdf.explore(
#         tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#         attr = 'Esri',
#         name = 'Esri Satellite',
# )

## Load long-term cloud probabilities

In [None]:
# Long-term (2020-2025) S2Cloudless cloud probabilities over the tile.
# Pixel-quality masking is disabled so cloudy observations keep their
# probabilities; contiguity masking is still applied.
long_term_load = dict(
    products=['ga_s2am_ard_3', 'ga_s2bm_ard_3', 'ga_s2cm_ard_3'],
    measurements=['oa_s2cloudless_prob'],
    time=('2020', '2025'),
    resolution=resolution,
    geopolygon=geom,
    dask_chunks=dask_chunks,
    group_by='solar_day',
    output_crs='EPSG:3577',
    cloud_mask='s2cloudless',
    resampling="cubic",
    verbose=False,
    mask_pixel_quality=False,
    mask_contiguity=True,
    skip_broken_datasets=True,
)
cp_probs = load_ard(dc=dc, **long_term_load)


## Compute quantiles

In [None]:
%%time
prob_quantiles = xr_quantile(cp_probs[['oa_s2cloudless_prob']].chunk(dict(time=-1)), quantiles=[0.1], nodata=np.nan).compute()

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(14, 6), layout='constrained', sharey=True)

cp_q10 = prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1)

# Left: long-term 10th-percentile cloud probability. The colour scale is capped
# at the default S2Cloudless threshold, so pixels at or above it saturate.
cp_q10.plot.imshow(ax=ax[0], vmin=0, vmax=s2cloudless_threshold, add_labels=False)
ax[0].set_title('2020-2025 cloud probability quantile=0.1');

# Right: pixels whose 10th percentile exceeds cp_threshold, i.e. where the
# raised (enhanced) cloud-probability threshold will be applied.
(cp_q10 > cp_threshold).plot.imshow(ax=ax[1], add_labels=False)
ax[1].set_title(f'quantile=0.1 > {cp_threshold}');


In [None]:
# prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1).odc.explore(vmin=0, vmax=0.1,
#     tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#     attr = 'Esri',
#     name = 'Esri Satellite'
# )

## Load SR data

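Load the data twice: once with the standard S2Cloudless mask applied, and once unmasked (including the cloud-probability band). This lets the enhanced cloud mask be compared against the standard mask.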

In [None]:
# Keyword arguments shared by both surface-reflectance loads
# (time, extent and grid settings come from `query`).
shared_load_kwargs = dict(
    dc=dc,
    cloud_mask='s2cloudless',
    verbose=False,
    mask_contiguity=True,
    skip_broken_datasets=True,
    **query,
)

# Baseline: standard S2Cloudless pixel-quality mask applied at load time.
s2_masked = load_ard(
    products=['ga_s2am_ard_3', 'ga_s2bm_ard_3', 'ga_s2cm_ard_3'],
    measurements=['nbart_green', 'nbart_red', 'nbart_blue'],
    resampling={"oa_s2cloudless_mask": "nearest", "*": "cubic"},
    mask_pixel_quality=True,
    **shared_load_kwargs,
)

# Unmasked: pixel-quality masking disabled and the cloud-probability band
# loaded, so the enhanced cloud mask can be derived from it.
s2_unmasked = load_ard(
    products=['ga_s2am_ard_3', 'ga_s2bm_ard_3', 'ga_s2cm_ard_3'],
    measurements=['nbart_green', 'nbart_red', 'nbart_blue', 'oa_s2cloudless_prob'],
    resampling={"*": "cubic"},
    mask_pixel_quality=False,
    **shared_load_kwargs,
)

## Percentiles of cloud probability

* Calculate a long-term (multi-year) 10th percentile of cloud probability to identify pixels that are consistently over-classified as cloud (e.g. bright beaches, urban areas). Then only consider a pixel to be cloud if it is X% cloudier than that long-term baseline.
* The default threshold is 0.4; the aim is to raise it where the long-term lower quantile is near this threshold.

In [None]:
# prob_quantiles = xr_quantile(s2_unmasked[['oa_s2cloudless_prob']].chunk(dict(time=-1)), quantiles=[0.05, 0.1], nodata=np.nan).compute()

# fig,axes = plt.subplots(1,2, figsize=(14,6), layout='constrained', sharey=True)

# for q,ax in zip(prob_quantiles['quantile'].values, axes.ravel()):
#     prob_quantiles['oa_s2cloudless_prob'].sel(quantile=q).plot.imshow(ax=ax, vmin=0)
#     ax.set_title(q)

In [None]:
# fig,ax = plt.subplots(1,4, figsize=(20,6), layout='constrained', sharey=True)
# prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1).plot.imshow(vmin=0, vmax=0.4, ax=ax[0])
# ax[1].set_title('10th percentile probabilities')

# xr.where(prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1)>=0.1, 1, 0).plot.imshow(ax=ax[1])
# ax[1].set_title('10th percentile, threshold=0.1')

# xr.where(prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1)>=0.2, 1, 0).plot.imshow(ax=ax[2])
# ax[2].set_title('10th percentile, threshold=0.2')

# xr.where(prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1)>=0.4, 1, 0).plot.imshow(ax=ax[3])
# ax[3].set_title('10th percentile, threshold=0.4');

In [None]:
# xr.where(prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1)>0.4, 1, 0).odc.explore()

## Enhanced cloud masking with S2Cloudless probability

Approaches (open question: should the options be combined?)

`Option 1 (highly conservative)`
1. Load long-term cloud probabilities (say five years) and take the 10th percentile.
2. Where the 10th-percentile CP is greater than or equal to the default probability threshold (0.4), double the threshold before a pixel is counted as 'cloud'.
3. Generate two separate cloud masks:
   * Where the 10th percentile is >= `cp_10th_percentile_threshold`, double the probability threshold before declaring a pixel cloud.
   * Where the 10th percentile is < `cp_10th_percentile_threshold`, use the default S2Cloudless probability threshold (0.4).
4. Calculate geomedians and compare.


`Option 2 (less conservative, easier to implement)`
1. Load long-term cloud probabilities (say five years) and take the 10th percentile.
2. Add 0.4 (the default threshold) to the long-term percentile; the sum is the new per-pixel cloud-probability threshold.
3. Where bright targets don't confuse S2Cloudless, the threshold stays at (or close to) 0.4. Where such targets are commonly confused with cloud, the threshold will be substantially higher.
4. Clip the maximum threshold to 0.95.

`Option 3: Synthesis of option 1 and 2`
1. Load long-term cloud probabilities (say five years) and take the 10th percentile.
2. Where the 10th-percentile CP is greater than `cp_threshold` (0.1), add 0.4 (the default threshold) to the long-term percentile; this is the new cloud-probability threshold. Everywhere else, use the default 0.4 threshold.
3. Clip the maximum threshold to 0.95.

Morphological filtering:
* For simplicity, we use the odc `mask_filters` approach for now rather than s2cloudless's `opencv` approach. The s2cloudless approach smooths the probability layer first, which gets awkward because we are separating the probabilities into two different conditions.

https://github.com/sentinel-hub/sentinel2-cloud-detector/blob/711d86176416b5afc2960963777407062adf852f/s2cloudless/cloud_detector.py#L127

<!-- # import cv2

# average_over = 4
# dilation_size = 2

# def cv2_disk(radius: int) -> np.ndarray:
#     """Recreates the disk structural element from skimage.morphology using OpenCV."""
#     return cv2.circle(  # type: ignore[call-overload]
#         np.zeros((radius * 2 + 1, radius * 2 + 1), dtype=np.uint8), (radius, radius), radius, color=1, thickness=-1
#     )

# disk = cv2_disk(average_over)
# conv_filter = disk / np.sum(disk)
# dilation_filter = cv2_disk(dilation_size)

# s2cloudless does this:
# cloud_masks = np.asarray(
#     [cv2.filter2D(cloud_prob, -1, conv_filter, borderType=cv2.BORDER_REFLECT) for cloud_prob in updated_cloud_mask],
#         dtype=np.uint8,
#     )

# cloud_masks = np.asarray(
#     [cv2.dilate(cloud_mask, dilation_filter) for cloud_mask in updated_cloud_mask], dtype=np.uint8
#     ) -->

### Option 1

In [None]:
# # Calculate cloud probability percentiles
# prob_quantile = prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1)

# # cloud mask for regions repeatedly misclassified as cloud 
# quant_mask = xr.where(prob_quantile>=cp_10th_percentile_threshold, True, False) 
# quant_mask_probabilities = s2_unmasked['oa_s2cloudless_prob'].where(quant_mask)
# quant_mask_probabilities_mask = xr.where(quant_mask_probabilities>=s2cloudless_threshold*2, True, False)

# # cloud mask for regions NOT repeatedly misclassified as cloud
# nonquant_mask = xr.where(prob_quantile<cp_10th_percentile_threshold, True, False)
# nonquant_mask_probabilities = s2_unmasked['oa_s2cloudless_prob'].where(nonquant_mask)
# nonquant_mask_probabilities_mask = xr.where(nonquant_mask_probabilities>s2cloudless_threshold, True, False)

# ## Combine cloud masks
# updated_cloud_mask = np.logical_or(
#     quant_mask_probabilities_mask, nonquant_mask_probabilities_mask
#             )

# # apply morphological filters
# updated_cloud_mask_filtered = mask_cleanup(updated_cloud_mask, mask_filters=mask_filters).compute()

# # Apply updated cloud mask to observations
# s2_updated_masked = s2_unmasked[['nbart_green', 'nbart_red', 'nbart_blue']].where(~updated_cloud_mask_filtered)
# s2_updated_masked = s2_updated_masked.drop_vars('quantile')

In [None]:
# s2_unmasked['nbart_red'].isel(time=2).plot.imshow(size=6)
# s2_masked['nbart_red'].isel(time=2).plot.imshow(size=6)
# s2_updated_masked['nbart_red'].isel(time=2).plot.imshow(size=6)
# (updated_cloud_mask).isel(time=2).plot.imshow(size=6)
# updated_cloud_mask_filtered.isel(time=2).plot.imshow(size=6)

In [None]:
# nonquant_mask.compute().plot()
# nonquant_mask_probabilities_mask.sel(time='13-03-2022').squeeze().plot.imshow()
# quant_mask_probabilities_mask.sel(time='13-03-2022').squeeze().plot.imshow()

### Option 2

In [None]:
# Add long-term 0.1 quantile to the default 0.4 threshold. But clip range to 0.95 so threshold can't
# be larger than 95 %
# enhanced_prob_thresh = (prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1) + s2cloudless_threshold).clip(0, 0.95)

# #create binary cloud mask 
# updated_cloud_mask = s2_unmasked['oa_s2cloudless_prob'] > enhanced_prob_thresh
# updated_cloud_mask = updated_cloud_mask.drop_vars('quantile')

In [None]:
# # apply morphological filters
# updated_cloud_mask_filtered = mask_cleanup(updated_cloud_mask, mask_filters=mask_filters).compute()

# # Apply updated cloud mask to observations
# s2_updated_masked = s2_unmasked[['nbart_green', 'nbart_red', 'nbart_blue']].where(~updated_cloud_mask_filtered)
# s2_updated_masked

In [None]:
# s2_unmasked['nbart_red'].isel(time=2).plot.imshow(size=6)
# s2_masked['nbart_red'].isel(time=2).plot.imshow(size=6)
# s2_updated_masked['nbart_red'].isel(time=20).plot.imshow(size=6)
# (updated_cloud_mask).isel(time=2).plot.imshow(size=6)
# updated_cloud_mask_filtered.isel(time=20).plot.imshow(size=6)

### Option 3

In [None]:
##---S2Cloudless means of filtering, works on numpy arrays but not dask arrays--------

# average_over = 4
# dilation_size = 2

# cloud_probs = s2_unmasked['oa_s2cloudless_prob'].data.compute()

# import cv2
# def cv2_disk(radius: int) -> np.ndarray:
#     """Recreates the disk structural element from skimage.morphology using OpenCV."""
#     return cv2.circle(  # type: ignore[call-overload]
#         np.zeros((radius * 2 + 1, radius * 2 + 1), dtype=np.uint8), (radius, radius), radius, color=1, thickness=-1
#     )

# def convolve_filter(ds, average_over=4):
#     disk = cv2_disk(average_over)
#     conv_filter = disk / np.sum(disk)
#     return cv2.filter2D(ds, 0, conv_filter, borderType=cv2.BORDER_REFLECT)

# def dilate(ds, dilation_size=2):
#     dilation_filter = cv2_disk(dilation_size)
#     return cv2.dilate(ds, dilation_filter)

# # s2cloudless does this, smoothing the probabolity array which removes
# # small cloud speckles. OpenCV does not work with dask!
# smoothed_cloud_probs = np.asarray(
#     [convolve_filter(cloud_prob, average_over=average_over) for cloud_prob in cloud_probs],
#         dtype=np.uint8,
#     )

# updated_cloud_mask_filtered = np.asarray(
#     [dilate(cloud_mask, dilation_size) for cloud_mask in updated_cloud_mask], dtype=np.uint8
#     )

In [None]:
# cp_threshold is set once in the Analysis Parameters cell. Do not re-assign it
# here, as that would silently override the value chosen there; just report it.
print(f'cp_threshold = {cp_threshold}')

In [None]:
# Long-term 10th-percentile cloud probability per pixel. The scalar 'quantile'
# coordinate left by .sel() is dropped here so it does not propagate into the
# masks and masked observations below.
cp_10th_percentile = prob_quantiles['oa_s2cloudless_prob'].sel(quantile=0.1).drop_vars('quantile')

# Cap on the raised per-pixel cloud-probability threshold.
# NOTE(review): the Option 3 description says 0.95, but this cell has been run
# with 0.90 -- confirm which is intended.
max_enhanced_threshold = 0.90

# Per-pixel (y, x) cloud-probability threshold:
#  * where the long-term 10th percentile exceeds cp_threshold (targets that
#    S2Cloudless repeatedly flags as cloud), use percentile + default threshold,
#    capped at max_enhanced_threshold;
#  * everywhere else -- including pixels whose long-term percentile is NaN
#    (no valid observations) -- use the default threshold. With separate
#    `> cp_threshold` / `<= cp_threshold` masks, NaN-percentile pixels
#    satisfied neither test and were never masked as cloud.
# Building the threshold on the 2-D grid keeps xr.where off the 3-D dask cube.
enhanced_threshold = xr.where(
    cp_10th_percentile > cp_threshold,
    (cp_10th_percentile + s2cloudless_threshold).clip(0, max_enhanced_threshold),
    s2cloudless_threshold,
)

# Cloud = probability strictly above the per-pixel threshold (the same `>`
# comparison in both regions). NaN probabilities compare False, i.e. not cloud.
updated_cloud_mask = s2_unmasked['oa_s2cloudless_prob'] > enhanced_threshold

# Morphological clean-up of the binary mask using the configured mask_filters
updated_cloud_mask_filtered = mask_cleanup(updated_cloud_mask, mask_filters=mask_filters)

# Apply the enhanced cloud mask to the unmasked observations
s2_updated_masked = s2_unmasked[['nbart_green', 'nbart_red', 'nbart_blue']].where(~updated_cloud_mask_filtered)

In [None]:
# s2_unmasked['nbart_red'].isel(time=38).plot.imshow(size=6, vmin=0, vmax=3000)
# s2_masked['nbart_red'].isel(time=38).plot.imshow(size=6, vmin=0, vmax=3000)
# s2_updated_masked['nbart_red'].isel(time=38).plot.imshow(size=6, vmin=0, vmax=3000)
# updated_cloud_mask.isel(time=38).plot.imshow(size=6)
# updated_cloud_mask_filtered.isel(time=38).plot.imshow(size=6)

## Geomedians

In [None]:
def _annual_geomedian(ds):
    """Return an in-memory geomedian (no MADs) of `ds`, tagged as EPSG:3577."""
    gm = geomedian_with_mads(ds, reshape_strategy='mem', compute_mads=False)
    return assign_crs(gm.load(), crs='EPSG:3577')


# Standard GM: observations masked with the default S2Cloudless mask only
s2_gm_standard = _annual_geomedian(s2_masked)

# Enhanced GM: observations masked with the enhanced cloud mask
s2_gm_updated = _annual_geomedian(s2_updated_masked)

### Difference in clear counts

In [None]:
# Per-pixel change in clear-observation count (enhanced minus standard).
# Both counts are cast to float32 before subtracting; presumably 'count' is an
# integer type, possibly unsigned -- TODO confirm.
updated_clear = s2_gm_updated['count'].astype(np.float32)
standard_clear = s2_gm_standard['count'].astype(np.float32)
diff_count = assign_crs(updated_clear - standard_clear, crs='EPSG:3577')

diff_count.plot.imshow(size=5, cmap='RdBu_r', vmin=-5, vmax=5)
plt.title('Enhanced clear count minus original');

In [None]:
# Interactive map of the clear-count difference over Esri satellite imagery
esri_tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
diff_count.odc.explore(
    tiles=esri_tiles,
    attr='Esri',
    name='Esri Satellite',
    cmap='RdBu_r',
    vmin=-5,
    vmax=5,
)

### Count NaNs

In [None]:
# Report how many NaN pixels remain in each band of both geomedians
for label, gm in (('standard', s2_gm_standard), ('enhanced', s2_gm_updated)):
    for band in ['nbart_red', 'nbart_green', 'nbart_blue']:
        nan_total = int(np.isnan(gm[band]).sum())
        if not nan_total:
            print(f'{label} masking {band} is clean')
            continue
        print(f'{nan_total} NaNs present in {label} masking {band}')

### Make NaNs appear pink

In [None]:
# Paint pixels with no geomedian value pink (red=10000, green=0, blue=10000)
# so gaps stand out in the RGB plots. Checking nbart_red alone is enough
# because the data were masked for contiguity.
pink_fill = {'nbart_red': 10000, 'nbart_blue': 10000, 'nbart_green': 0}

for gm in (s2_gm_updated, s2_gm_standard):
    nan_mask = np.isnan(gm['nbart_red'])
    for band, fill_value in pink_fill.items():
        if band in gm.data_vars:
            gm[band] = xr.where(nan_mask, fill_value, gm[band])

### Summary Stats

In [None]:
def _clear_count_stats(gm):
    """Return (min, mean, max) of a geomedian's 'count' band as Python scalars."""
    clear = gm['count']
    return clear.min().item(), clear.mean().item(), clear.max().item()


min_clear_updated, mean_clear_updated, max_clear_updated = _clear_count_stats(s2_gm_updated)
min_clear_standard, mean_clear_standard, max_clear_standard = _clear_count_stats(s2_gm_standard)

print(f'Updated masking clear counts (min, mean, max) = {min_clear_updated}, {mean_clear_updated:.0f}, {max_clear_updated}')
print(f'Standard masking clear counts (min, mean, max) = {min_clear_standard}, {mean_clear_standard:.0f}, {max_clear_standard}')

## Plot

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(15, 12), layout='constrained')
vmin, vmax = 10, 90  # colour range for the clear-count panels

# One row per masking approach: RGB geomedian (left), clear count (right)
rows = [
    (s2_gm_standard, 'Standard masking',
     f'Standard masking, clear count. Mean={mean_clear_standard:.1f}'),
    (s2_gm_updated, 'Enhanced masking',
     f'Enhanced masking clear count. Mean={mean_clear_updated:.1f}'),
]

for row_idx, (gm, rgb_title, count_title) in enumerate(rows):
    rgb_ax, count_ax = ax[row_idx, 0], ax[row_idx, 1]

    gm[['nbart_red', 'nbart_green', 'nbart_blue']].to_array().plot.imshow(robust=True, ax=rgb_ax, add_labels=False)
    gm['count'].plot.imshow(vmin=vmin, vmax=vmax, cmap='magma', ax=count_ax, add_labels=False)

    rgb_ax.set_title(rgb_title)
    count_ax.set_title(count_title)

    for panel in (rgb_ax, count_ax):
        panel.set_yticklabels([])
        panel.set_xticklabels([])

# NOTE(review): inputs are read from '~/gdata1/...' but the figure is written
# to '/gdata1/...' -- confirm this absolute path exists on the target system.
plt.savefig(f'/gdata1/projects/s2_gm/results/processed_figs/s2_gm_annual_{region_code[0]}_improvedcloudmasking.png', bbox_inches='tight', dpi=300);

## Interactive plots

In [None]:
# vmin, vmax = s2_gm_updated[['nbart_red', 'nbart_green', 'nbart_blue']].to_array().quantile((0.01, 0.99)).values

# assign_crs(s2_gm_updated, crs='EPSG:3577').odc.explore(
#     vmin=vmin,
#     vmax=vmax,
#     tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#     attr = 'Esri',
#     name = 'Esri Satellite',
# )

# Interactive view of the standard-masking geomedian, stretched to the
# 1st-99th percentile of its RGB values
rgb_stack = s2_gm_standard[['nbart_red', 'nbart_green', 'nbart_blue']].to_array()
vmin, vmax = rgb_stack.quantile((0.01, 0.99)).values

assign_crs(s2_gm_standard, crs='EPSG:3577').odc.explore(
    tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
    attr='Esri',
    name='Esri Satellite',
    vmin=vmin,
    vmax=vmax,
)