### TODO:
- add water masking using Water Observations (WO) (initial version implemented below)
    - https://github.com/GeoscienceAustralia/dea-notebooks/blob/develop/DEA_products/DEA_Fractional_Cover.ipynb
- refine cloud and shadow masking for the Landsat sensors
- try adding Landsat 7 to see what happens
- test in a cloudy area
- test over a time period before Landsat 8

In [1]:
# %xmode verbose


In [2]:
# Pin datacube to 1.9.8.
# NOTE(review): this uninstalls and then reinstalls the same pinned version. As pip
# reminds in the output, restart the kernel afterwards so the reinstalled
# package is the one actually imported.
%pip uninstall datacube -y
%pip install datacube==1.9.8 -q


Found existing installation: datacube 1.9.8
Uninstalling datacube-1.9.8:
  Successfully uninstalled datacube-1.9.8
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [3]:
# Install the fractional cover package from the DEA package index, plus a pinned dea-tools.
# NOTE(review): fractional_cover is not version-pinned, so a re-run may pick up a
# different release than the one these results were produced with.
# %pip uninstall fc -y
# %pip install git+https://github.com/GeoscienceAustralia/fc.git -q
%pip install fractional_cover --find-links="https://packages.dea.ga.gov.au/fc" -q
%pip install dea-tools==0.4.4 -q


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [4]:
import os
import gc
import yaml
import json
import warnings
import datacube
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import calendar
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec

from datacube.drivers.netcdf import write_dataset_to_netcdf
from datacube.utils.cog import write_cog
from fc.fractional_cover import (
    fractional_cover,
)  # import the FC package after installing it above

from odc.geo.xr import assign_crs
from odc.geo.geom import Geometry
from odc.algo import keep_good_only
from datacube.model import Measurement
from datacube.utils import unsqueeze_dataset, masking

from dea_tools.datahandling import load_ard, wofs_fuser
from dea_tools.dask import create_local_dask_cluster
from dea_tools.plotting import rgb

# NOTE(review): with no category/module given, this silences *every* warning for
# the rest of the kernel session (including deprecation warnings), which can hide
# real problems. Consider narrowing it to the specific noisy warnings.
warnings.filterwarnings("ignore")


In [5]:
# List installed package versions (presumably to record the environment).
# The output is long; consider clearing it before sharing the notebook.
!pip list


Package                           Version
--------------------------------- ------------------
absl-py                           2.3.1
access                            1.1.9
affine                            2.4.0
aiobotocore                       2.24.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.12.15
aioitertools                      0.12.0
aiosignal                         1.4.0
alabaster                         1.0.0
alembic                           1.16.4
amply                             0.1.6
annotated-types                   0.7.0
antimeridian                      0.4.3
anyio                             4.10.0
argon2-cffi                       25.1.0
argon2-cffi-bindings              25.1.0
arrow                             1.3.0
asciitree                         0.3.3
astropy                           6.1.7
astropy-iers-data                 0.2025.8.11.0.41.9
asttokens                         3.0.0
astunparse                        1.6.3
asyn

In [6]:
# Start a local Dask cluster for the dask-backed (`dask_chunks`) loads and
# computations below.
client = create_local_dask_cluster(return_client=True)


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/42611/status,

0,1
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/42611/status,Workers: 1
Total threads: 62,Total memory: 456.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:44303,Workers: 0
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/42611/status,Total threads: 0
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:45975,Total threads: 62
Dashboard: /user/jenna.guffogg@ga.gov.au/proxy/43933/status,Memory: 456.00 GiB
Nanny: tcp://127.0.0.1:44709,
Local directory: /tmp/dask-scratch-space/worker-4bzx_tz7,Local directory: /tmp/dask-scratch-space/worker-4bzx_tz7


In [7]:
# Output switches, checked by the export and figure cells further down.
SAVE_OUT_COGS = False    # write per-slice Cloud Optimised GeoTIFFs
SAVE_OUT_NETCDF = True   # write one NetCDF per composite stack
FIGURES = True           # render and save the summary figures

# YAML holding the FC measurements, per-sensor regression coefficients and
# per-sensor band mappings (read in the next cells).
VARIABLES_YAML = "/home/jovyan/git/fc-sub-annual/notebooks/variables_s2-ls-coefficients-current.yaml"


To experiment with different sets of variables (i.e. the YAML file where the Landsat and Sentinel-2 regression coefficients are currently stored), change `VARIABLES_YAML` in the configuration cell above. This could later be turned into a dynamic variable.

Also as a TODO: the output file names for NetCDF files and figures should take the YAML file name as an input, and the variables YAML file should be renamed to reflect the coefficients it contains.

In [8]:
# Load the experiment configuration from VARIABLES_YAML: FC output measurements,
# per-sensor regression coefficients and per-sensor band mappings.
with open(VARIABLES_YAML, "r") as yaml_file:
    config = yaml.safe_load(yaml_file)

(
    MEASUREMENTS,
    LANDSAT_SENSOR_REGRESSION_COEFFICIENTS,
    SENTINEL_SENSOR_REGRESSION_COEFFICIENTS,
    LANDSAT_BAND_MAPPING,
    SENTINEL_BAND_MAPPING,
) = (
    config[key]
    for key in (
        "measurements",
        "landsat_sensor_regression_coefficients",
        "sentinel_sensor_regression_coefficients",
        "landsat_band_mapping",
        "sentinel_band_mapping",
    )
)

# One datacube Measurement per YAML measurement entry; passed to run_fc_multi below.
MEASUREMENTS_OBJ = [Measurement(**measurement_spec) for measurement_spec in MEASUREMENTS]


In [9]:
def run_fc_multi(nbart: "xr.Dataset", measurements, regression_coefficients):
    """
    Run fractional cover on every time step of a multi-temporal dataset.

    ``fractional_cover`` is called on one time slice at a time. Each result has
    its ``time`` dimension restored and the slices are concatenated back along
    ``time`` in the input order.

    Parameters
    ----------
    nbart : xr.Dataset
        Surface-reflectance dataset with a ``time`` dimension, already renamed
        to the band names expected by ``fractional_cover``.
    measurements :
        Measurement definitions passed straight through to ``fractional_cover``
        (``MEASUREMENTS_OBJ`` in this notebook).
    regression_coefficients :
        Sensor-specific regression coefficients passed straight through to
        ``fractional_cover``.

    Returns
    -------
    xr.Dataset
        FC outputs with one entry along ``time`` per input time step.

    Raises
    ------
    ValueError
        If ``nbart`` has no time steps (``xr.concat`` cannot concatenate nothing).
    """
    times = nbart.time.values
    if len(times) == 0:
        raise ValueError("run_fc_multi: input dataset has no time steps")

    results = []
    for i, timestamp in enumerate(times):
        # Positional selection always returns exactly one slice with the time
        # dimension removed, even if two acquisitions share a timestamp
        # (label-based .sel would return several and break).
        input_tile = nbart.isel(time=i)
        fc_tile = fractional_cover(input_tile, measurements, regression_coefficients)
        # Add the time dimension back using the original timestamp.
        results.append(unsqueeze_dataset(fc_tile, "time", timestamp))

    return xr.concat(results, dim="time")


In [10]:
def save_as_cogs(ds, ds_name, region_codes, cadence: str):
    """
    Write one Cloud Optimised GeoTIFF per (year, month) or (year, season) slice.

    Parameters
    ----------
    ds : xarray.Dataset
        Composite dataset with a ``year`` dimension plus a ``month`` dimension
        (cadence="year_month") or a ``season`` dimension (cadence="year_season").
    ds_name : str
        Name used in every output filename.
    region_codes : list of str
        Tile codes; only the first one is used in filenames.
    cadence : str
        Either "year_month" or "year_season".

    Output files go to ``<tile_dir>/<variables name>/`` (module globals
    ``tile_dir`` and ``VARIABLES_YAML``) and are named
    ``{year}_{month|season}_{ds_name}_{region_code}.tif``. Existing .tif files are
    overwritten by ``write_cog(overwrite=True)``. No other file in that shared
    directory (e.g. the NetCDF written by ``save_netcdfs``) is touched.

    Raises
    ------
    ValueError
        If ``cadence`` is not one of the supported values. This is checked
        before anything is written.
    """
    # Each supported cadence maps to the dimension that is paired with "year".
    cadence_dims = {"year_month": "month", "year_season": "season"}
    if cadence not in cadence_dims:
        raise ValueError(
            f"Unsupported cadence: {cadence}. Supported cadences are: {list(cadence_dims)}"
        )
    period_dim = cadence_dims[cadence]

    variables_dir_name = VARIABLES_YAML.split("_")[-1].replace(".yaml", "")
    region_code = region_codes[0]

    output_path = os.path.join(tile_dir, variables_dir_name)
    os.makedirs(output_path, exist_ok=True)

    years = ds["year"].values
    periods = ds[period_dim].values
    for i, year in enumerate(years):
        for j, period in enumerate(periods):
            try:
                slice_da = ds.isel({"year": i, period_dim: j}).to_array()
            except Exception as err:
                # Best effort: skip slices that cannot be extracted, but report them.
                print(f"Skipping {year}/{period}: {err}")
                continue

            period_label = f"{period:02d}" if period_dim == "month" else f"{period}"
            output_fname = os.path.join(
                output_path, f"{year:04d}_{period_label}_{ds_name}_{region_code}.tif"
            )
            written = write_cog(
                geo_im=slice_da,
                fname=output_fname,
                overwrite=True,
            )
            # For dask-backed input write_cog returns a lazy Delayed that must be
            # computed. For in-memory input it has already written the file.
            if hasattr(written, "compute"):
                written.compute()

    print(f"Saved COGs to {output_path}")


In [11]:
def save_netcdfs(ds, ds_name):
    """
    Write ``ds`` to a single NetCDF file for the current tile.

    The file is written to ``<tile_dir>/<variables name>/`` and named
    ``{ds_name}_{start_date}-{end_date}_{region_codes[0]}.nc``. All of these
    are module globals. Any existing file with that name is removed first.

    If ``ds`` carries a ``season`` coordinate, it is replaced by an integer
    ``season_int`` dimension, because the string-valued (object) ``season``
    coordinate has to be dropped before saving to NetCDF. The label -> integer
    mapping is recorded in ``ds.attrs["season_int_mapping"]``.
    """
    variables_dir_name = VARIABLES_YAML.split("_")[-1].replace(".yaml", "")
    season_lookup = {"DJF": 0, "MAM": 1, "JJA": 2, "SON": 3}

    if "season" in ds.coords:
        season_codes = [season_lookup[label] for label in ds["season"].values]
        ds = (
            ds.assign_coords(season_int=("season", season_codes))
            .swap_dims({"season": "season_int"})
            .drop_vars("season")  # object dtype cannot be saved to NetCDF
        )
        # Keep the mapping in the file so season_int values can be decoded later.
        ds.attrs["season_int_mapping"] = str(season_lookup)

    output_path = os.path.join(tile_dir, variables_dir_name)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    output_fname = os.path.join(
        output_path, f"{ds_name}_{start_date}-{end_date}_{region_codes[0]}.nc"
    )
    if os.path.exists(output_fname):
        os.remove(output_fname)
        print(f"Removed existing file: {output_fname}")

    write_dataset_to_netcdf(ds, output_fname)
    print(f"Saved netCDFs to {output_path}")


In [12]:
# DEA ARD bands loaded for each sensor, kept in the same band order in both lists.
ls_measurements = [
    f"nbart_{band}" for band in ("green", "red", "blue", "nir", "swir_1", "swir_2")
]

s2_measurements = [
    f"nbart_{band}"
    for band in (
        "green",
        "red",
        "blue",
        "nir_1",
        "swir_2",  # closest match to landsat swir1
        "swir_3",  # closest match to landsat swir2
    )
]


In [13]:
# Candidate test tiles (swap the code into `region_codes` as needed):
#   x176y085 - marysville
#   x168y092 - hopetoun
#   x140y138 - west macdonnell
#   x148y166 - limmen NT (cloud-heavy area)
#   x175y066 - Tas SW nat park
region_codes = ["x148y166"]  # limmen NT

# Analysis window: 1 March to 30 November of `year`.
# NOTE: the previous December is NOT included, so no complete DJF season is produced.
year = "2024"
start_date, end_date = f"{year}-03-01", f"{year}-11-30"
time = (start_date, end_date)


In [14]:
# Results root, with separate per-tile sub-trees for data products and figures.
output_dir = "/home/jovyan/gdata1/projects/fc-sub-annual/results/"
tile_dir, figs_dir = (
    os.path.join(output_dir, subtree, region_codes[0]) for subtree in ("tiles", "figures")
)

for results_subdir in (tile_dir, figs_dir):
    os.makedirs(results_subdir, exist_ok=True)


In [15]:
# open tiles and select

# Load the test-tile footprints and keep the tile(s) listed in `region_codes`.
# A separate name for the unfiltered table keeps this cell safe to re-run.
tiles_gdf = gpd.read_file(
    "~/gdata1/projects/fc-sub-annual/data/testing_minitile_suite.geojson"
)

gdf = tiles_gdf[tiles_gdf["region_code"].isin(region_codes)]
if gdf.empty:
    raise ValueError(
        f"No tiles matching {region_codes} found in testing_minitile_suite.geojson"
    )

# Only the first matching tile is used as the load geometry.
geom = Geometry(geom=gdf.iloc[0].geometry, crs=gdf.crs)


In [16]:
# # get central coords of the tile and use that to select a SMALL area for testing
# gdf["central_lon"] = gdf.geometry.centroid.x
# gdf["central_lat"] = gdf.geometry.centroid.y
# lon_range = (gdf["central_lon"].values[0] - 0.05, gdf["central_lon"].values[0] + 0.05)
# lat_range = (gdf["central_lat"].values[0] - 0.05, gdf["central_lat"].values[0] + 0.05)


In [17]:
dc = datacube.Datacube(app="fc_ls_test")

# Shared load parameters, reused by the Landsat and Sentinel-2 loads so both
# come back over the same area and dates, on a 30 m EPSG:3577 grid.
query = dict(
    geopolygon=geom,
    time=time,
    resolution=(-30, 30),
    group_by="solar_day",
    output_crs="EPSG:3577",
)


In [18]:
# gdf.explore(
#     tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
#     attr='Esri',
#     name='Esri satellite'
# )


In [19]:
ls_ds = load_ard(
    dc=dc,
    products=["ga_ls8c_ard_3", "ga_ls9c_ard_3"],  # Landsat 8 + Landsat 9 ARD
    measurements=ls_measurements,
    dask_chunks={"time": 10, "x": 1024, "y": 1024},  # lazy, dask-backed load
    skip_broken_datasets=True,
    verbose=True,
    # Pixel-quality masking: Fmask with the "valid", "snow" and "water"
    # categories treated as good, a dilation(5) morphological filter on the
    # pixel-quality mask, and a contiguity mask.
    cloud_mask="fmask",
    fmask_categories=["valid", "snow", "water"],
    mask_filters=[("dilation", 5)],
    mask_pixel_quality=True,
    mask_contiguity=True,
    **query,
)


Finding datasets
    ga_ls8c_ard_3
    ga_ls9c_ard_3
Applying morphological filters to pixel quality mask: [('dilation', 5)]
Applying fmask pixel quality/cloud mask
Applying contiguity mask (oa_nbart_contiguity)
Returning 69 time steps as a dask array


In [20]:
s2_ds = load_ard(
    dc=dc,
    products=["ga_s2am_ard_3", "ga_s2bm_ard_3"],  # Sentinel-2A + Sentinel-2B ARD
    measurements=s2_measurements,
    dask_chunks={"time": 10, "x": 1024, "y": 1024},  # lazy, dask-backed load
    skip_broken_datasets=True,
    verbose=True,
    # Pixel-quality masking with s2cloudless plus a contiguity mask
    # (no morphological filter here, unlike the Landsat load).
    cloud_mask="s2cloudless",
    mask_pixel_quality=True,
    mask_contiguity=True,
    **query,
)


Finding datasets
    ga_s2am_ard_3
    ga_s2bm_ard_3
Applying s2cloudless pixel quality/cloud mask
Applying contiguity mask (oa_nbart_contiguity)
Returning 55 time steps as a dask array


In [21]:
def _first_rename_mapping(band_mapping, sensor):
    """Return the first non-empty ``rename`` dict in a band-mapping list from the YAML config."""
    for mapping in band_mapping:
        rename = mapping.get("rename")
        if rename:
            return rename
    # Fail loudly instead of leaving the rename dict undefined, or silently
    # reusing a stale one from a previous run of this cell.
    raise ValueError(f"No non-empty 'rename' entry found in the {sensor} band mapping")


ls_rename_dict = _first_rename_mapping(LANDSAT_BAND_MAPPING, "Landsat")
s2_rename_dict = _first_rename_mapping(SENTINEL_BAND_MAPPING, "Sentinel-2")

# Rename ARD bands to the names fractional_cover expects for each sensor.
ls_renamed = ls_ds.rename(ls_rename_dict)
s2_renamed = s2_ds.rename(s2_rename_dict)


In [22]:
# Landsat FC, using the Landsat regression coefficients from the YAML config.
ls_fc = run_fc_multi(
    nbart=ls_renamed,
    measurements=MEASUREMENTS_OBJ,
    regression_coefficients=LANDSAT_SENSOR_REGRESSION_COEFFICIENTS,
)


In [23]:
# Sentinel-2 FC, using the Sentinel-2 regression coefficients from the YAML config.
s2_fc = run_fc_multi(
    nbart=s2_renamed,
    measurements=MEASUREMENTS_OBJ,
    regression_coefficients=SENTINEL_SENSOR_REGRESSION_COEFFICIENTS,
)


In [24]:
# Give the FC output every coordinate of the source ARD it does not already have,
# plus the source attributes, then merge FC with the NBART bands.
missing_ls_coords = {
    name: coord for name, coord in ls_ds.coords.items() if name not in ls_fc.coords
}
ls_fc = ls_fc.assign_coords(missing_ls_coords)
ls_fc.attrs = ls_ds.attrs.copy()

merged_ls = xr.merge([ls_ds, ls_fc])


In [25]:
# Give the FC output every coordinate of the source ARD it does not already have,
# plus the source attributes, then merge FC with the NBART bands.
missing_s2_coords = {
    name: coord for name, coord in s2_ds.coords.items() if name not in s2_fc.coords
}
s2_fc = s2_fc.assign_coords(missing_s2_coords)
s2_fc.attrs = s2_ds.attrs.copy()

merged_s2 = xr.merge([s2_ds, s2_fc])


## Merge sensors

In [26]:
# Both Landsat and Sentinel-2 FC datasets are merged here.
# NOTE: This is probably not done correctly and is for demo purposes/ Jenna learning only

# The NIR/SWIR band names differ between the sensors. "nbart_swir_2" also means
# a different band in each: Landsat SWIR2, but the S2 band closest to Landsat SWIR1.
# So all sensor-specific NIR/SWIR bands are dropped, leaving the shared visible
# bands and the FC outputs.
SENSOR_SPECIFIC_BANDS = [
    "nbart_nir",
    "nbart_nir_1",
    "nbart_swir_1",
    "nbart_swir_2",
    "nbart_swir_3",
]

merged_all = (
    xr.concat([merged_ls, merged_s2], dim="time")
    .drop_vars(SENSOR_SPECIFIC_BANDS)
    # concat appends the S2 time steps after all Landsat ones; sort so the
    # combined series is chronological.
    .sortby("time")
)

merged_all


Unnamed: 0,Array,Chunk
Bytes,539.54 MiB,40.00 MiB
Shape,"(124, 1068, 1068)","(10, 1024, 1024)"
Dask graph,52 chunks in 28 graph layers,52 chunks in 28 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 539.54 MiB 40.00 MiB Shape (124, 1068, 1068) (10, 1024, 1024) Dask graph 52 chunks in 28 graph layers Data type float32 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,539.54 MiB,40.00 MiB
Shape,"(124, 1068, 1068)","(10, 1024, 1024)"
Dask graph,52 chunks in 28 graph layers,52 chunks in 28 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,539.54 MiB,40.00 MiB
Shape,"(124, 1068, 1068)","(10, 1024, 1024)"
Dask graph,52 chunks in 28 graph layers,52 chunks in 28 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 539.54 MiB 40.00 MiB Shape (124, 1068, 1068) (10, 1024, 1024) Dask graph 52 chunks in 28 graph layers Data type float32 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,539.54 MiB,40.00 MiB
Shape,"(124, 1068, 1068)","(10, 1024, 1024)"
Dask graph,52 chunks in 28 graph layers,52 chunks in 28 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,539.54 MiB,40.00 MiB
Shape,"(124, 1068, 1068)","(10, 1024, 1024)"
Dask graph,52 chunks in 28 graph layers,52 chunks in 28 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 539.54 MiB 40.00 MiB Shape (124, 1068, 1068) (10, 1024, 1024) Dask graph 52 chunks in 28 graph layers Data type float32 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,539.54 MiB,40.00 MiB
Shape,"(124, 1068, 1068)","(10, 1024, 1024)"
Dask graph,52 chunks in 28 graph layers,52 chunks in 28 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 134.89 MiB 1.00 MiB Shape (124, 1068, 1068) (1, 1024, 1024) Dask graph 496 chunks in 4155 graph layers Data type int8 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 134.89 MiB 1.00 MiB Shape (124, 1068, 1068) (1, 1024, 1024) Dask graph 496 chunks in 4155 graph layers Data type int8 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 134.89 MiB 1.00 MiB Shape (124, 1068, 1068) (1, 1024, 1024) Dask graph 496 chunks in 4155 graph layers Data type int8 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray
"Array Chunk Bytes 134.89 MiB 1.00 MiB Shape (124, 1068, 1068) (1, 1024, 1024) Dask graph 496 chunks in 4155 graph layers Data type int8 numpy.ndarray",1068  1068  124,

Unnamed: 0,Array,Chunk
Bytes,134.89 MiB,1.00 MiB
Shape,"(124, 1068, 1068)","(1, 1024, 1024)"
Dask graph,496 chunks in 4155 graph layers,496 chunks in 4155 graph layers
Data type,int8 numpy.ndarray,int8 numpy.ndarray


## Mask invalid data and water

In [27]:
# Mask nodata values in the merged and per-sensor FC datasets.
merged_all, ls_fc, s2_fc = (
    masking.mask_invalid_data(dataset) for dataset in (merged_all, ls_fc, s2_fc)
)


In [28]:
# mask water

wo_ls = dc.load(
    product="ga_ls_wo_3",
    group_by="solar_day",
    time=(year),
    geopolygon=geom,
    dask_chunks={"time": 1,"x": 1024, "y": 1024},
    fuse_func=wofs_fuser,
    # like=ls_ds,
)

wo_s2 = dc.load(
    product="ga_s2_wo_provisional_3",
    group_by="solar_day",
    time=(year),
    geopolygon=geom,
    dask_chunks={"time":1, "x": 1024, "y": 1024},
    fuse_func=wofs_fuser,
    # like=ls_ds,
)

# wo_s2 = dc.load(
#     product="ga_s2_wo_provisional_3",
#     fuse_func=wofs_fuser,
#     group_by="solar_day",
#     #like=s2_ds,
# )


In [29]:
# Stack Landsat and Sentinel-2 Water Observations into one time series.
wo_merged = xr.concat(objs=[wo_ls, wo_s2], dim="time")
wo_merged


Unnamed: 0,Array,Chunk
Bytes,6.03 GiB,1.09 MiB
Shape,"(158, 3202, 3202)","(1, 534, 534)"
Dask graph,12798 chunks in 31 graph layers,12798 chunks in 31 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 6.03 GiB 1.09 MiB Shape (158, 3202, 3202) (1, 534, 534) Dask graph 12798 chunks in 31 graph layers Data type float32 numpy.ndarray",3202  3202  158,

Unnamed: 0,Array,Chunk
Bytes,6.03 GiB,1.09 MiB
Shape,"(158, 3202, 3202)","(1, 534, 534)"
Dask graph,12798 chunks in 31 graph layers,12798 chunks in 31 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [30]:
# Alternative under consideration: masking.make_mask(wo_merged.water, dry=True)
# Keep only pixels whose WO value is exactly 0. This is presumably "dry, no other
# WO flags set"; confirm against the WO bit-flag definition. Any other value,
# including missing (NaN), gives False.
wo_mask = wo_merged["water"] == 0


In [31]:
# Apply the dry-land mask: pixels where wo_mask is False become NaN.
# NOTE(review): .where aligns on coordinates (inner join when no `other` is
# given), so FC time steps without a matching WO time step are dropped. Confirm
# this is intended.
fc_merged_wo_masked, fc_ls_wo_masked, fc_s2_wo_masked = (
    dataset.where(wo_mask) for dataset in (merged_all, ls_fc, s2_fc)
)


## Monthly and seasonal medians (means and rolling medians: TODO)

In [32]:
# Temporal grouping keys. `season_year` assigns December to the FOLLOWING year's
# DJF season, so a DJF composite covers Dec(Y-1)–Feb(Y). Grouping on the
# calendar year would put Jan/Feb and Dec of the same year into one "DJF".
time_coord = fc_merged_wo_masked["time"]
fc_merged_wo_masked = fc_merged_wo_masked.assign_coords(
    year=time_coord.dt.year,
    month=time_coord.dt.month,
    season=time_coord.dt.season,
    season_year=time_coord.dt.year + (time_coord.dt.month == 12),
)

# Monthly median composites, plus the per-pixel number of non-NaN BS
# observations behind each median.
monthly_medians = fc_merged_wo_masked.groupby(["year", "month"]).median(
    dim="time", keep_attrs=True
)
monthly_counts = fc_merged_wo_masked["BS"].groupby(["year", "month"]).count(dim="time")
monthly_medians["obs_count"] = monthly_counts

# Seasonal median composites grouped by season-year. The grouping dimension is
# renamed back to `year` so the exports and figures below work unchanged.
seasonal_medians = (
    fc_merged_wo_masked.groupby(["season_year", "season"])
    .median(dim="time", keep_attrs=True)
    .rename({"season_year": "year"})
)
seasonal_counts = (
    fc_merged_wo_masked["BS"]
    .groupby(["season_year", "season"])
    .count(dim="time")
    .rename({"season_year": "year"})
)
seasonal_medians["obs_count"] = seasonal_counts


In [None]:
# The true-colour bands are only needed for the figures, so they are not exported.
monthly_export_name = "monthly_fc_blended_medians"
monthly_medians_to_export = monthly_medians.drop_vars(
    ["nbart_red", "nbart_green", "nbart_blue"]
)

# Each export runs only if its switch is enabled in the config cell.
if SAVE_OUT_COGS:
    save_as_cogs(
        monthly_medians_to_export, monthly_export_name, region_codes, cadence="year_month"
    )
if SAVE_OUT_NETCDF:
    save_netcdfs(monthly_medians_to_export, monthly_export_name)


Removed existing file: /home/jovyan/gdata1/projects/fc-sub-annual/results/tiles/x148y166/s2-ls-coefficients-current/monthly_fc_blended_medians_2024-03-01-2024-11-30_x148y166.nc


  dest = _reproject(


In [None]:
# The true-colour bands are only needed for the figures, so they are not exported.
seasonal_export_name = "seasonal_fc_blended_medians"
seasonal_medians_to_export = seasonal_medians.drop_vars(
    ["nbart_red", "nbart_green", "nbart_blue"]
)

# Each export runs only if its switch is enabled in the config cell.
if SAVE_OUT_COGS:
    save_as_cogs(
        seasonal_medians_to_export, seasonal_export_name, region_codes, cadence="year_season"
    )
if SAVE_OUT_NETCDF:
    save_netcdfs(seasonal_medians_to_export, seasonal_export_name)


In [None]:
%%time

if FIGURES:
    single_year_monthly_medians = monthly_medians.isel(year=0).persist()

    months = single_year_monthly_medians["month"].values
    n_months = len(months)

    n_steps = n_months

    bounds = np.arange(0, 10, 1)
    norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=len(bounds) - 1)
    cmap = plt.get_cmap('gnuplot', len(bounds) - 1)

    vmin = 0
    vmax = 10

    fig = plt.figure(figsize=(5 * 4.5, 5 * n_steps))
    gs = gridspec.GridSpec(nrows=n_steps, ncols=5, width_ratios=[1, 1, 1, 1, 0.05])

    axs = np.empty((n_steps, 5), dtype=object)
    for i in range(n_steps):
        for j in range(5):
            axs[i, j] = fig.add_subplot(gs[i, j], projection=ccrs.epsg(3577))

    im_ue = None
    im_obs = None

    for i in range(n_steps):
        month_number = single_year_monthly_medians['month'].values[i]
        month_name = calendar.month_name[month_number]
        
        # True Colour RGB
        rgb_true = xr.concat(
            [
                single_year_monthly_medians["nbart_red"].isel(month=i),
                single_year_monthly_medians["nbart_green"].isel(month=i),
                single_year_monthly_medians["nbart_blue"].isel(month=i),
            ],
            dim="band",
        )
        rgb_true.plot.imshow(ax=axs[i, 0], robust=True, transform=ccrs.epsg(3577))
        axs[i, 0].set_title(f"{month_name} - RGB", fontsize=10)

        # FC RGB
        rgb_fc = xr.concat(
            [
                single_year_monthly_medians["BS"].isel(month=i),
                single_year_monthly_medians["PV"].isel(month=i),
                single_year_monthly_medians["NPV"].isel(month=i),
            ],
            dim="band",
        )
        rgb_fc.plot.imshow(ax=axs[i, 1], robust=True, add_colorbar=False, transform=ccrs.epsg(3577))
        axs[i, 1].set_title(f"FC", fontsize=10)
        
        # median unmixing error
        ue_median = single_year_monthly_medians["UE"].isel(month=i)
        im_ue = ue_median.plot.imshow(
            ax=axs[i, 2], cmap='magma', add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 2].set_title(f"median UE", fontsize=10)
        
        # Observation count
        obs_count = single_year_monthly_medians["obs_count"].isel(month=i)
        im_obs = obs_count.plot.imshow(
            ax=axs[i, 3], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 3].set_title(f"obs count", fontsize=10)

    # Add colorbar for UE
    cbar_ax_ue = fig.add_axes([0.92, 0.65, 0.02, 0.2])
    cbar_ue = fig.colorbar(im_ue, cax=cbar_ax_ue)
    cbar_ue.set_label("Unmixing Error (UE)")

    # Add colorbar for Observation Count
    cbar_ax_obs = fig.add_axes([0.92, 0.35, 0.02, 0.2])
    cbar_obs = fig.colorbar(im_obs, cax=cbar_ax_obs, extend='max')
    cbar_obs.set_label("Observation Count")

    fig.suptitle("Monthly blended Landsat-Sentinel-2 FC, observation counts and median unmixing error", fontsize=16)

    tile_dir = os.path.join(output_dir, f"figs/{region_codes[0]}")

    variables_dir_name = VARIABLES_YAML.split("_")[-1].replace(".yaml", "")

    output_path = os.path.join(
        figs_dir,
        variables_dir_name,
        f"fc_blended_monthly_median_with_obs_count_{region_codes[0]}.png",
    )

    if os.path.exists(output_path):
        os.remove(output_path)

    plt.savefig(
        output_path,
        bbox_inches="tight",
        dpi=200,
    )

    plt.show()


In [None]:
%%time

if FIGURES:
    single_year_seasonal_medians = seasonal_medians.isel(year=0).persist()

    n_steps = len(single_year_seasonal_medians["season"].values)

    bounds = np.arange(0,20, 1)
    norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=len(bounds) - 1)
    cmap = plt.get_cmap('gnuplot', len(bounds)- 1)

    vmin = 0
    vmax = 20

    # Use gridspec to add an extra column for the colorbar
    fig = plt.figure(figsize=(5 * 4, 5 * n_steps))
    gs = gridspec.GridSpec(nrows=n_steps, ncols=5, width_ratios=[1, 1, 1, 1, 0.05])

    axs = np.empty((n_steps, 4), dtype=object)
    for i in range(n_steps):
        for j in range(4):
            axs[i, j] = fig.add_subplot(gs[i, j], projection=ccrs.epsg(3577))


    for i in range(n_steps):
        # get calendar month name for plotting
        season_label = single_year_seasonal_medians['season'].values[i]
        
        # True Colour RGB
        rgb_true = xr.concat(
                [
                    single_year_seasonal_medians["nbart_red"].isel(season=i),
                    single_year_seasonal_medians["nbart_green"].isel(season=i),
                    single_year_seasonal_medians["nbart_blue"].isel(season=i),
                ],
                dim="band",
            )
        rgb_true.plot.imshow(ax=axs[i, 0], robust=True, transform=ccrs.epsg(3577))
        axs[i, 0].set_title(f"{season_label} - RGB", fontsize=10)

        # FC RGB
        rgb_fc = xr.concat(
                [
                    single_year_seasonal_medians["BS"].isel(season=i),
                    single_year_seasonal_medians["PV"].isel(season=i),
                    single_year_seasonal_medians["NPV"].isel(season=i),
                ],
                dim="band",
            )
        rgb_fc.plot.imshow(ax=axs[i, 1], robust=True, add_colorbar=False, transform=ccrs.epsg(3577))
        axs[i, 1].set_title(f"FC", fontsize=10)
        
        # median unmixing error
        ue_median = single_year_seasonal_medians["UE"].isel(season=i)
        im_ue = ue_median.plot.imshow(
            ax=axs[i, 3], cmap='magma', add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 3].set_title(f"median UE error", fontsize=10)
        
        # Observation count
        obs_count = single_year_seasonal_medians["obs_count"].isel(season=i)
        im_obs = obs_count.plot.imshow(
            ax=axs[i, 2], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False, transform=ccrs.epsg(3577)
        )
        axs[i, 2].set_title(f"obs count", fontsize=10)
        
        
    # add colorbar for UE
    cbar_ax_ue = fig.add_axes([0.92, 0.65, 0.02, 0.2])
    cbar_ue = fig.colorbar(im_ue, cax = cbar_ax_ue)
    cbar_ue.set_label("Unmixing Error (UE)")

    # add colorbar for observation count
    cbar_ax_obs = fig.add_axes([0.92, 0.35, 0.02, 0.2])
    cbar_obs = fig.colorbar(im_obs,  cax=cbar_ax_obs, extend='max')
    cbar_obs.set_label("Observation Count")

    fig.suptitle("Seasonal blended Landsat-Sentinel-2 FC, observation counts and median unmixing error", fontsize=16)

    tile_dir = os.path.join(output_dir, f"figs/{region_codes[0]}")

    variables_dir_name = VARIABLES_YAML.split("_")[-1].replace(".yaml", "")

    output_path = os.path.join(
        figs_dir,
        variables_dir_name,
        f"fc_blended_seasonal_median_with_obs_count_{region_codes[0]}.png",
    )

    if os.path.exists(output_path):
        os.remove(output_path)

    plt.savefig(
        output_path,
        bbox_inches="tight",
        dpi=200,
    )

    plt.show()
