### todo:
- add water masking using WOs
    - https://github.com/GeoscienceAustralia/dea-notebooks/blob/develop/DEA_products/DEA_Fractional_Cover.ipynb
- refine cloud and shadow masking for landsat sensors
- try adding landsat 7 to see what happens
- test in cloudy area
- test in area pre landsat-8

In [1]:
%xmode verbose


Exception reporting mode: Verbose


In [2]:
# %pip uninstall fc -y
%pip uninstall datacube -y


Found existing installation: datacube 1.8.20
Uninstalling datacube-1.8.20:
  Successfully uninstalled datacube-1.8.20
Note: you may need to restart the kernel to use updated packages.


In [3]:
%pip install datacube==1.8.20 -q


Collecting datacube==1.8.20
  Using cached datacube-1.8.20-py2.py3-none-any.whl.metadata (9.6 kB)
Using cached datacube-1.8.20-py2.py3-none-any.whl (377 kB)
Installing collected packages: datacube
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datacube-ows 1.9.4 requires datacube[performance,s3]>=1.9.4, but you have datacube 1.8.20 which is incompatible.
eodatasets3 1.9.3 requires datacube>=1.9.0, but you have datacube 1.8.20 which is incompatible.
odc-apps-dc-tools 1.9.3 requires datacube>=1.9.6, but you have datacube 1.8.20 which is incompatible.
odc-dscache 1.9.1 requires datacube>=1.9, but you have datacube 1.8.20 which is incompatible.
odc-stats 1.9.2 requires datacube>=1.9.6, but you have datacube 1.8.20 which is incompatible.
odc-stats 1.9.2 requires distributed>=2025.4, but you have distributed 2024.10.0 which is incompatible.[0m[31m
[0mSuccess

In [4]:
#%pip install git+https://github.com/GeoscienceAustralia/fc.git -q
%pip install fractional_cover --find-links="https://packages.dea.ga.gov.au/fc" -q


Looking in links: https://packages.dea.ga.gov.au/fc
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import gc
import yaml
import json
import warnings
import datacube
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import calendar
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec

from datacube.drivers.netcdf import write_dataset_to_netcdf
from datacube.utils.cog import write_cog
from fc.fractional_cover import (
    fractional_cover,
)  # import the FC package after installing it above

from odc.geo.xr import assign_crs
from odc.geo.geom import Geometry
from odc.algo import keep_good_only
from odc.algo._percentile import xr_quantile_bands
from datacube.model import Measurement
from datacube.utils import unsqueeze_dataset, masking

import sys

sys.path.insert(1, "/home/jovyan/dev/Tools/")
from dea_tools.datahandling import load_ard, wofs_fuser
from dea_tools.dask import create_local_dask_cluster
from dea_tools.plotting import rgb

warnings.filterwarnings("ignore")


In [None]:
!pip list


Package                           Version
--------------------------------- ------------------
absl-py                           2.3.1
access                            1.1.9
affine                            2.4.0
aiobotocore                       2.24.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.12.15
aioitertools                      0.12.0
aiosignal                         1.4.0
alabaster                         1.0.0
alembic                           1.16.4
amply                             0.1.6
annotated-types                   0.7.0
antimeridian                      0.4.3
anyio                             4.10.0
argon2-cffi                       25.1.0
argon2-cffi-bindings              25.1.0
arrow                             1.3.0
asciitree                         0.3.3
astropy                           6.1.7
astropy-iers-data                 0.2025.8.11.0.41.9
asttokens                         3.0.0
astunparse                        1.6.3
asyn

In [None]:
client = create_local_dask_cluster(return_client=True)


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/8787/status,

0,1
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/8787/status,Workers: 1
Total threads: 62,Total memory: 456.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:36583,Workers: 1
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/8787/status,Total threads: 62
Started: Just now,Total memory: 456.00 GiB

0,1
Comm: tcp://127.0.0.1:38213,Total threads: 62
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/42025/status,Memory: 456.00 GiB
Nanny: tcp://127.0.0.1:39667,
Local directory: /tmp/dask-scratch-space/worker-8a33tjmr,Local directory: /tmp/dask-scratch-space/worker-8a33tjmr


In [None]:
SAVE_COGS = True
SAVE_NETCDF = True
FIGURES = False


In [None]:
# Read YAML file
with open("/home/jovyan/git/fc-sub-annual/notebooks/variables.yaml", "r") as f:
    config = yaml.safe_load(f)

# Extract variables
MEASUREMENTS = config["measurements"]
LANDSAT_SENSOR_REGRESSION_COEFFICIENTS = config[
    "landsat_sensor_regression_coefficients"
]
SENTINEL_SENSOR_REGRESSION_COEFFICIENTS = config[
    "sentinel_sensor_regression_coefficients"
]
LANDSAT_BAND_MAPPING = config["landsat_band_mapping"]
SENTINEL_BAND_MAPPING = config["sentinel_band_mapping"]

# Convert measurements to Measurement objects
MEASUREMENTS_OBJ = [Measurement(**m) for m in MEASUREMENTS]


In [None]:
def run_fc_multi(nbart: xr.Dataset, measurements, regression_coefficients):
    results = []
    times = nbart.time.values
    for t in times:
        # Select one time slice and remove the time dimension
        input_tile = nbart.sel(time=t)
        print(input_tile.coords)
        if "time" in input_tile.dims:
            input_tile = input_tile.squeeze("time").drop("time")
        data = fractional_cover(input_tile, measurements, regression_coefficients)
        # Add time back in
        output_tile = unsqueeze_dataset(data, "time", t)
        results.append(output_tile)
    # Concatenate along time
    combined = xr.concat(results, dim="time")
    return combined


In [None]:
ls_measurements = [
    "nbart_green",
    "nbart_red",
    "nbart_blue",
    "nbart_nir",
    "nbart_swir_1",
    "nbart_swir_2",
]

s2_measurements = [
    "nbart_green",
    "nbart_red",
    "nbart_blue",
    "nbart_nir_1",
    "nbart_swir_2",  # closest match to landsat swir1
    "nbart_swir_3",  # closest match to landsat swir2
]


In [None]:
# region_code = ['x176y085'] #marysville
# region_code = ['x168y092'] #hopetoun
# region_code = ['x140y138'] #west macdonnell
# region_codes = ['x148y166'] #limmen NT - cloud heavy area
# region_codes = ['x175y066'] # Tas SW nat park


region_codes = ["x148y166"]

# including DEC previous year to get full season for DJF
start_date = "2024-03-01"
end_date = "2024-05-31"
time = (start_date, end_date)


In [None]:
# save outputs
output_dir = "/home/jovyan/gdata1/projects/fc-sub-annual/results/"
tile_dir = os.path.join(output_dir, "tiles", region_codes[0])
figs_dir = os.path.join(output_dir, "figures", region_codes[0])

for d in [tile_dir, figs_dir]:
    os.makedirs(d, exist_ok=True)


In [None]:
# open tiles and select

gdf = gpd.read_file(
    "~/gdata1/projects/fc-sub-annual/data/testing_minitile_suite.geojson"
)

gdf = gdf[gdf["region_code"].isin(region_codes)]
geom = Geometry(geom=gdf.iloc[0].geometry, crs=gdf.crs)


In [None]:
dc = datacube.Datacube(app="fc_ls_test")

query = {
    "time": time,
    "resolution": (-150, 150),
    "geopolygon": geom,
    "group_by": "solar_day",
    "output_crs": "EPSG:3577",
}


In [None]:
# gdf.explore(
#     tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#     attr='Esri',
#     name='Esri satellite'
# )


In [None]:
ls_ds = load_ard(
    dc=dc,
    products=["ga_ls8c_ard_3", "ga_ls9c_ard_3"],
    measurements=ls_measurements,
    cloud_mask="fmask",
    mask_pixel_quality=True,
    fmask_categories=["valid", "snow", "water"],
    mask_filters=[("dilation", 5)],
    mask_contiguity=True,
    skip_broken_datasets=True,
    verbose=True,
    dask_chunks={"time": 1, "x": 1024, "y": 1024},
    **query
)


Finding datasets
    ga_ls8c_ard_3
    ga_ls9c_ard_3
Applying morphological filters to pixel quality mask: [('dilation', 5)]
Applying fmask pixel quality/cloud mask
Applying contiguity mask (oa_nbart_contiguity)
Returning 24 time steps as a dask array


In [None]:
s2_ds = load_ard(
    dc=dc,
    products=["ga_s2am_ard_3", "ga_s2bm_ard_3"],
    measurements=s2_measurements,
    cloud_mask="s2cloudless",
    mask_pixel_quality=True,
    mask_contiguity=True,
    skip_broken_datasets=True,
    verbose=True,
    dask_chunks={"time": 1, "x": 1024, "y": 1024},
    **query,
)


Finding datasets
    ga_s2am_ard_3
    ga_s2bm_ard_3
Applying s2cloudless pixel quality/cloud mask
Applying contiguity mask (oa_nbart_contiguity)
Returning 18 time steps as a dask array


In [None]:
for mapping in LANDSAT_BAND_MAPPING:
    if mapping["rename"]:
        ls_rename_dict = mapping["rename"]
        break

for mapping in SENTINEL_BAND_MAPPING:
    if mapping["rename"]:
        s2_rename_dict = mapping["rename"]
        break

ls_renamed = ls_ds.rename(ls_rename_dict)

s2_renamed = s2_ds.rename(s2_rename_dict)


In [None]:
results = []
for i in range(ls_renamed.sizes["time"]):
    test = ls_renamed.isel(time=i)
    test = test.drop_vars("time")
    print(test.coords)
    data = fractional_cover(
        test, MEASUREMENTS_OBJ, LANDSAT_SENSOR_REGRESSION_COEFFICIENTS
    )
    results.append(data)


Coordinates:
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
  * y       

In [None]:
ls_fc = run_fc_multi(
    ls_renamed, MEASUREMENTS_OBJ, LANDSAT_SENSOR_REGRESSION_COEFFICIENTS
)


Coordinates:
    time         datetime64[ns] 8B 2024-03-03T01:05:13.674885
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
    time         datetime64[ns] 8B 2024-03-04T00:58:47.784852
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
    time         datetime64[ns] 8B 2024-03-11T01:04:54.495205
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
    time         datetime64[ns] 8B 2024-03-12T00:59:02.670070
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
   

In [None]:
s2_fc = run_fc_multi(
    s2_renamed, MEASUREMENTS_OBJ, LANDSAT_SENSOR_REGRESSION_COEFFICIENTS
)


Coordinates:
    time         datetime64[ns] 8B 2024-03-03T01:21:22.665343
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
    time         datetime64[ns] 8B 2024-03-08T01:21:16.695865
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
    time         datetime64[ns] 8B 2024-03-13T01:21:22.331488
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
    spatial_ref  int32 4B 3577
Coordinates:
    time         datetime64[ns] 8B 2024-03-18T01:21:21.503141
  * y            (y) float64 2kB -1.568e+06 -1.568e+06 ... -1.6e+06 -1.6e+06
  * x            (x) float64 2kB 3.2e+05 3.202e+05 ... 3.518e+05 3.52e+05
   

In [None]:
for coord in ls_ds.coords:
    if coord not in ls_fc.coords:
        ls_fc = ls_fc.assign_coords({coord: ls_ds.coords[coord]})
ls_fc.attrs = ls_ds.attrs.copy()

merged_ls = xr.merge([ls_ds, ls_fc])


In [None]:
for coord in s2_ds.coords:
    if coord not in s2_fc.coords:
        s2_fc = s2_fc.assign_coords({coord: s2_ds.coords[coord]})
s2_fc.attrs = s2_ds.attrs.copy()

merged_s2 = xr.merge([s2_ds, s2_fc])


## merge sensors

In [None]:
# both Landsat and Sentinel-2 FC datasets merged here.
# NOTE: This is probably not done correctly and is for demo purposes/ Jenna learning only

merged_all = xr.concat([merged_ls, merged_s2], dim="time")
merged_all = merged_all.drop_vars(
    ["nbart_nir", "nbart_nir_1", "nbart_swir_1", "nbart_swir_2", "nbart_swir_3"]
)  # drop duplicate NIR band from S2

merged_all


Unnamed: 0,Array,Chunk
Bytes,7.34 MiB,178.89 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 23 graph layers,42 chunks in 23 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 7.34 MiB 178.89 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 23 graph layers Data type float32 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,7.34 MiB,178.89 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 23 graph layers,42 chunks in 23 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.34 MiB,178.89 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 23 graph layers,42 chunks in 23 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 7.34 MiB 178.89 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 23 graph layers Data type float32 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,7.34 MiB,178.89 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 23 graph layers,42 chunks in 23 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.34 MiB,178.89 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 23 graph layers,42 chunks in 23 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 7.34 MiB 178.89 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 23 graph layers Data type float32 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,7.34 MiB,178.89 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 23 graph layers,42 chunks in 23 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 1.83 MiB 44.72 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 1436 graph layers Data type int8 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 1.83 MiB 44.72 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 1436 graph layers Data type int8 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 1.83 MiB 44.72 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 1436 graph layers Data type int8 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 1.83 MiB 44.72 kiB Shape (42, 214, 214) (1, 214, 214) Dask graph 42 chunks in 1436 graph layers Data type int8 numpy.ndarray",214  214  42,

Unnamed: 0,Array,Chunk
Bytes,1.83 MiB,44.72 kiB
Shape,"(42, 214, 214)","(1, 214, 214)"
Dask graph,42 chunks in 1436 graph layers,42 chunks in 1436 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray


## calculating percentiles

In [None]:
merged_all = masking.mask_invalid_data(merged_all)
ls_fc = masking.mask_invalid_data(ls_fc)
s2_fc = masking.mask_invalid_data(s2_fc)


In [None]:
# mask water

wo_ls = dc.load(
    product="ga_ls_wo_3",
    group_by="solar_day",
    fuse_func=wofs_fuser,
    like=ls_ds,
)

wo_s2 = dc.load(
    product="ga_s2_wo_provisional_3",
    fuse_func=wofs_fuser,
    group_by="solar_day",
    like=s2_ds,
)


In [None]:
wo_merged = xr.concat([wo_ls, wo_s2], dim="time")
wo_mask = masking.make_mask(wo_merged.water, dry=True)

fc_merged_wo_masked = merged_all.where(wo_mask)
fc_ls_wo_masked = ls_fc.where(wo_mask)
fc_s2_wo_masked = s2_fc.where(wo_mask)


In [None]:
stacked_pc = xr_quantile_bands(
    fc_merged_wo_masked, quantiles=[0.1, 0.5, 0.9], nodata=np.nan
)
ls_pc = xr_quantile_bands(fc_ls_wo_masked, quantiles=[0.1, 0.5, 0.9], nodata=np.nan)
s2_pc = xr_quantile_bands(fc_s2_wo_masked, quantiles=[0.1, 0.5, 0.9], nodata=np.nan)


## Means, Medians and rolling medians oh my...

In [None]:
def save_as_netcdf(ds, ds_name):
    """
    Save the dataset to a NetCDF file."""

    output_path = os.path.join(
        tile_dir, f"{ds_name}_{start_date}-{end_date}_{region_codes[0]}.nc"
    )

    if os.path.exists(output_path):
        os.remove(output_path)
        print(f"Removed existing file: {output_path}")

    write_dataset_to_netcdf(ds, os.path.join(output_path))


In [None]:
def save_as_cogs(ds, ds_name, region_codes, cadence: str):
    region_code = region_codes[0]

    if not os.path.exists(tile_dir):
        os.makedirs(tile_dir)

    time_dims = ["time", "year_month", "year_season"]
    if cadence not in time_dims:
        raise ValueError(
            f"Unsupported cadence: {cadence}. Supported cadences are: {time_dims}"
        )

    if cadence == "year_month":
        years = ds["year"].values
        months = ds["month"].values
        # loop over all year/ month combos
        for i in range(ds.sizes["year"]):
            for j in range(ds.sizes["month"]):
                year = years[i]
                month = months[j]
                try:
                    singletimestamp_da = ds.isel(year=i, month=j).to_array()
                except Exception:
                    continue
                output_fname = os.path.join(
                    tile_dir, f"{year:04d}_{month:02d}_{ds_name}_{region_code}.tif"
                )
                cog_file = write_cog(
                    geo_im=singletimestamp_da,
                    fname=output_fname,
                    overwrite=True,
                ).compute()

    elif cadence == "year_season":
        years = ds["year"].values
        seasons = ds["season"].values
        for i in range(ds.sizes["year"]):
            for j in range(ds.sizes["season"]):
                year = years[i]
                season = seasons[j]
                try:
                    singletimestamp_da = ds.isel(year=i, season=j).to_array()
                except Exception:
                    continue
                output_fname = os.path.join(
                    tile_dir, f"{year:04d}_{season}_{ds_name}_{region_code}.tif"
                )
                cog_file = write_cog(
                    geo_im=singletimestamp_da,
                    fname=output_fname,
                    overwrite=True,
                ).compute()
    else:
        raise ValueError(
            f"Unsupported cadence: {cadence}. Supported cadences are: 'year_month', 'year_season'"
        )

    print(f"Saved COGs to {tile_dir}")


In [None]:
fc_merged_wo_masked = fc_merged_wo_masked.assign_coords(
    year=fc_merged_wo_masked["time"].dt.year,
    month=fc_merged_wo_masked["time"].dt.month,
    season=fc_merged_wo_masked["time"].dt.season,
)

monthly_medians = fc_merged_wo_masked.groupby(["year", "month"]).median(
    dim="time", keep_attrs=True
)

monthly_counts = fc_merged_wo_masked["BS"].groupby(["year", "month"]).count(dim="time")
monthly_medians["obs_count"] = monthly_counts


seasonal_medians = fc_merged_wo_masked.groupby(["year", "season"]).median(
    dim="time", keep_attrs=True
)

seasonal_counts = (
    fc_merged_wo_masked["BS"].groupby(["year", "season"]).count(dim="time")
)
seasonal_medians["obs_count"] = seasonal_counts


In [None]:
# # Persist all required arrays before the loop
# monthly_red = monthly_medians["nbart_red"].persist()
# monthly_green = monthly_medians["nbart_green"].persist()
# monthly_blue = monthly_medians["nbart_blue"].persist()
# monthly_bs = monthly_medians["BS"].persist()
# monthly_pv = monthly_medians["PV"].persist()
# monthly_npv = monthly_medians["NPV"].persist()

# n_steps = len(monthly_medians["year_month"].values)

# for step in range(n_steps):
#     month_number = monthly_medians["year_month"].values[step]

#     fig, axs = plt.subplots(
#         1,
#         4,
#         figsize=(20, 5),
#         layout="constrained",
#         subplot_kw={"projection": ccrs.epsg(3577)},
#     )

#     rgb_true = xr.concat(
#         [
#             monthly_red.isel(year_month=step),
#             monthly_green.isel(year_month=step),
#             monthly_blue.isel(year_month=step),
#         ],
#         dim="band",
#     )
#     rgb_true.plot.imshow(ax=axs[0], robust=True, transform=ccrs.epsg(3577))
#     axs[0].set_title("RGB True Colour")

#     monthly_bs.isel(year_month=step).plot.imshow(
#         ax=axs[1],
#         cmap="Oranges",
#         vmin=0,
#         vmax=100,
#         add_colorbar=False,
#         transform=ccrs.epsg(3577),
#     )
#     axs[1].set_title("Bare Soil component")

#     monthly_pv.isel(year_month=step).plot.imshow(
#         ax=axs[2],
#         cmap="Greens",
#         vmin=0,
#         vmax=100,
#         add_colorbar=False,
#         transform=ccrs.epsg(3577),
#     )
#     axs[2].set_title("Green Vegetation component")

#     monthly_npv.isel(year_month=step).plot.imshow(
#         ax=axs[3],
#         cmap="Blues",
#         vmin=0,
#         vmax=100,
#         add_colorbar=False,
#         transform=ccrs.epsg(3577),
#     )
#     axs[3].set_title("Non-green Vegetation component")

#     for ax in axs:
#         ratio = 1.0
#         x_left, x_right = ax.get_xlim()
#         y_low, y_high = ax.get_ylim()
#         ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)

#     plt.suptitle(
#         f"Fractional Cover Monthly Median blended - {region_codes[0]} - {month_number}"
#     )

#     output_path = os.path.join(
#         figs_dir,
#         f"fc_blended_median_{region_codes[0]}_{month_number}.png",
#     )

#     if os.path.exists(output_path):
#         os.remove(output_path)

#     plt.savefig(
#         output_path,
#         bbox_inches="tight",
#         dpi=100,
#     )
#     plt.show()


In [34]:
%%time

if FIGURES:
    single_year_monthly_medians = monthly_medians.isel(year=0).persist()

    months = single_year_monthly_medians["month"].values
    n_months = len(months)

    n_steps = n_months

    bounds = np.arange(0, 10, 1)
    norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=len(bounds) - 1)
    cmap = plt.get_cmap('gnuplot', len(bounds) - 1)

    vmin = 0
    vmax = 10

    fig = plt.figure(figsize=(5 * 4.5, 5 * n_steps))
    gs = gridspec.GridSpec(nrows=n_steps, ncols=5, width_ratios=[1, 1, 1, 1, 0.05])

    axs = np.empty((n_steps, 5), dtype=object)
    for i in range(n_steps):
        for j in range(5):
            axs[i, j] = fig.add_subplot(gs[i, j], projection=ccrs.epsg(3577))

    im_ue = None
    im_obs = None

    for i in range(n_steps):
        month_number = single_year_monthly_medians['month'].values[i]
        month_name = calendar.month_name[month_number]
        
        # True Colour RGB
        rgb_true = xr.concat(
            [
                single_year_monthly_medians["nbart_red"].isel(month=i),
                single_year_monthly_medians["nbart_green"].isel(month=i),
                single_year_monthly_medians["nbart_blue"].isel(month=i),
            ],
            dim="band",
        )
        rgb_true.plot.imshow(ax=axs[i, 0], robust=True, transform=ccrs.epsg(3577))
        axs[i, 0].set_title(f"{month_name} - RGB", fontsize=10)

        # FC RGB
        rgb_fc = xr.concat(
            [
                single_year_monthly_medians["BS"].isel(month=i),
                single_year_monthly_medians["PV"].isel(month=i),
                single_year_monthly_medians["NPV"].isel(month=i),
            ],
            dim="band",
        )
        rgb_fc.plot.imshow(ax=axs[i, 1], robust=True, add_colorbar=False, transform=ccrs.epsg(3577))
        axs[i, 1].set_title(f"FC", fontsize=10)
        
        # median unmixing error
        ue_median = single_year_monthly_medians["UE"].isel(month=i)
        im_ue = ue_median.plot.imshow(
            ax=axs[i, 2], cmap='magma', add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 2].set_title(f"median UE", fontsize=10)
        
        # Observation count
        obs_count = single_year_monthly_medians["obs_count"].isel(month=i)
        im_obs = obs_count.plot.imshow(
            ax=axs[i, 3], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 3].set_title(f"obs count", fontsize=10)

    # Add colorbar for UE
    cbar_ax_ue = fig.add_axes([0.92, 0.65, 0.02, 0.2])
    cbar_ue = fig.colorbar(im_ue, cax=cbar_ax_ue)
    cbar_ue.set_label("Unmixing Error (UE)")

    # Add colorbar for Observation Count
    cbar_ax_obs = fig.add_axes([0.92, 0.35, 0.02, 0.2])
    cbar_obs = fig.colorbar(im_obs, cax=cbar_ax_obs, extend='max')
    cbar_obs.set_label("Observation Count")

    fig.suptitle("Monthly blended Landsat-Sentinel-2 FC, observation counts and median unmixing error", fontsize=16)

    tile_dir = os.path.join(output_dir, f"figs/{region_codes[0]}")

    output_path = os.path.join(
        figs_dir,
        f"fc_blended_monthly_median_with_obs_count_{region_codes[0]}.png",
    )

    if os.path.exists(output_path):
        os.remove(output_path)

    plt.savefig(
        output_path,
        bbox_inches="tight",
        dpi=200,
    )

    plt.show()


CPU times: user 5 μs, sys: 0 ns, total: 5 μs
Wall time: 7.15 μs


In [35]:
%%time

if FIGURES:
    single_year_seasonal_medians = seasonal_medians.isel(year=0).persist()

    n_steps = len(single_year_seasonal_medians["season"].values)

    bounds = np.arange(0,20, 1)
    norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=len(bounds) - 1)
    cmap = plt.get_cmap('gnuplot', len(bounds)- 1)

    vmin = 0
    vmax = 20

    # Use gridspec to add an extra column for the colorbar
    fig = plt.figure(figsize=(5 * 4, 5 * n_steps))
    gs = gridspec.GridSpec(nrows=n_steps, ncols=5, width_ratios=[1, 1, 1, 1, 0.05])

    axs = np.empty((n_steps, 4), dtype=object)
    for i in range(n_steps):
        for j in range(4):
            axs[i, j] = fig.add_subplot(gs[i, j], projection=ccrs.epsg(3577))


    for i in range(n_steps):
        # get calendar month name for plotting
        season_label = single_year_seasonal_medians['season'].values[i]
        
        # True Colour RGB
        rgb_true = xr.concat(
                [
                    single_year_seasonal_medians["nbart_red"].isel(season=i),
                    single_year_seasonal_medians["nbart_green"].isel(season=i),
                    single_year_seasonal_medians["nbart_blue"].isel(season=i),
                ],
                dim="band",
            )
        rgb_true.plot.imshow(ax=axs[i, 0], robust=True, transform=ccrs.epsg(3577))
        axs[i, 0].set_title(f"{season_label} - RGB", fontsize=10)

        # FC RGB
        rgb_fc = xr.concat(
                [
                    single_year_seasonal_medians["BS"].isel(season=i),
                    single_year_seasonal_medians["PV"].isel(season=i),
                    single_year_seasonal_medians["NPV"].isel(season=i),
                ],
                dim="band",
            )
        rgb_fc.plot.imshow(ax=axs[i, 1], robust=True, add_colorbar=False, transform=ccrs.epsg(3577))
        axs[i, 1].set_title(f"FC", fontsize=10)
        
        # median unmixing error
        ue_median = single_year_seasonal_medians["UE"].isel(season=i)
        im_ue = ue_median.plot.imshow(
            ax=axs[i, 3], cmap='magma', add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 3].set_title(f"median UE error", fontsize=10)
        
        # Observation count
        obs_count = single_year_seasonal_medians["obs_count"].isel(season=i)
        im_obs = obs_count.plot.imshow(
            ax=axs[i, 2], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 2].set_title(f"obs count", fontsize=10)
        
        
    # add colorbar for UE
    cbar_ax_ue = fig.add_axes([0.92, 0.65, 0.02, 0.2])
    cbar_ue = fig.colorbar(im_ue, cax = cbar_ax_ue)
    cbar_ue.set_label("Unmixing Error (UE)")

    # add colorbar for observation count
    cbar_ax_obs = fig.add_axes([0.92, 0.35, 0.02, 0.2])
    cbar_obs = fig.colorbar(im_obs,  cax=cbar_ax_obs, extend='max')
    cbar_obs.set_label("Observation Count")

    fig.suptitle("Seasonal blended Landsat-Sentinel-2 FC, observation counts and median unmixing error", fontsize=16)

    tile_dir = os.path.join(output_dir, f"figs/{region_codes[0]}")

    output_path = os.path.join(
        figs_dir,
        f"fc_blended_seasonal_median_with_obs_count_{region_codes[0]}.png",
    )

    if os.path.exists(output_path):
        os.remove(output_path)

    plt.savefig(
        output_path,
        bbox_inches="tight",
        dpi=200,
    )

    plt.show()


CPU times: user 4 μs, sys: 0 ns, total: 4 μs
Wall time: 6.91 μs


In [None]:
monthly_medians_to_export = monthly_medians.drop_vars(
    ["nbart_red", "nbart_green", "nbart_blue"]
)

# if save_cogs is true, files will be saved. Otherwise, they will not be saved.
if SAVE_COGS:
    save_as_cogs(
        monthly_medians_to_export,
        "monthly_fc_blended_medians",
        region_codes,
        cadence="year_month",
    )

if SAVE_NETCDF:
    save_as_netcdf(
        monthly_medians_to_export,
        "monthly_fc_blended_medians",
    )


  dest = _reproject(


Saved COGs to /home/jovyan/gdata1/projects/fc-sub-annual/results/tiles/x148y166
Removed existing file: /home/jovyan/gdata1/projects/fc-sub-annual/results/tiles/x148y166/monthly_fc_blended_medians_2024-03-01-2024-05-31_x148y166.nc


In [None]:
seasonal_medians_to_export = seasonal_medians.drop_vars(
    ["nbart_red", "nbart_green", "nbart_blue"]
)

if SAVE_COGS:
    save_as_cogs(
        seasonal_medians_to_export,
        "seasonal_fc_blended_medians",
        region_codes,
        cadence="year_season",
    )

if SAVE_NETCDF:
    save_as_netcdf(
        seasonal_medians_to_export,
        "seasonal_fc_blended_medians",
    )
