In [None]:
import geopandas as gpd
from pystac_client import Client
from odc.stac import load, configure_s3_access
from odc.algo import mask_cleanup
from dask.distributed import Client as DaskClient
from pathlib import Path
from numpy import nanpercentile

In [None]:
areas = gpd.read_file('areas.geojson')
# Trim surrounding whitespace (e.g. stray trailing newlines) from the text attributes
# used to name outputs and label results.
areas = areas.assign(
    **{column: areas[column].str.strip() for column in ("Capital", "Country", "Name")}
)

In [None]:
def _cloud_mask(data, sentinel2):
    """Return a boolean mask that is True for pixels to discard before compositing.

    Sentinel-2 pixels are masked by their Scene Classification (``scl``) class;
    Landsat pixels are masked by ``qa_pixel`` cloud/shadow bits or by a nodata
    value in the ``red`` band.
    """
    if sentinel2:
        # SCL classes: 0 = no data, 3 = cloud shadow, 8/9 = cloud (medium/high probability).
        return data.scl.isin([0, 3, 8, 9])

    # QA_PIXEL bit 3 = cloud, bit 4 = cloud shadow.
    cloud_bits = 0b00011000
    qa_mask = (data.qa_pixel & cloud_bits) != 0
    # Also drop pixels where the red band holds its declared nodata value.
    nodata_mask = data.red == data.red.odc.nodata
    return qa_mask | nodata_mask


def mosaic_region(
    area,
    overwrite=True,
    sentinel2=True,
    catalog="https://earth-search.aws.element84.com/v1",
    year="2024",
    write_data=False,
    dask_kwargs=None,
    percentile_stretch=(1, 99),
):
    """Build a cloud-masked median mosaic for one area and write it as a COG.

    Parameters
    ----------
    area : row-like object
        Must expose ``Country`` (used to name the output file) and ``geometry``
        (used to search the STAC catalog and to bound the load), e.g. a row
        from ``GeoDataFrame.itertuples()``.
    overwrite : bool
        When False and the output file already exists, nothing is done.
    sentinel2 : bool
        True: use the ``sentinel-2-c1-l2a`` collection, masked with ``scl``.
        False: use ``landsat-c2-l2``, masked with ``qa_pixel`` and red nodata;
        requester-pays S3 access is configured.
    catalog : str
        STAC API URL to search.
    year : str
        Datetime filter passed to the STAC search.
    write_data : bool
        True: write the median bands to ``{sensor}_data_{country}.tif`` and
        return the median Dataset. False: write a percentile-stretched RGBA
        image to ``{sensor}_{country}.tif`` and return it.
    dask_kwargs : dict, optional
        Keyword arguments for the local ``dask.distributed.Client``. Defaults
        to 4 workers x 24 threads with a 250GB per-worker memory limit.
    percentile_stretch : tuple of two numbers
        Lower/upper percentiles (over all RGB values) used as the colour
        stretch limits when ``write_data`` is False.

    Returns
    -------
    The median Dataset (``write_data=True``), the RGBA visualisation
    (``write_data=False``), or None when the area is skipped because the
    output exists or the search found no items.
    """
    country = area.Country.lower().replace(" ", "_")
    sensor = "s2" if sentinel2 else "ls"
    if write_data:
        output = Path(f"{sensor}_data_{country}.tif")
    else:
        output = Path(f"{sensor}_{country}.tif")

    # Check for existing output before touching the catalog so skipped areas
    # cost no network round-trip.
    if output.exists() and not overwrite:
        print(f"Skipping: {country}, {output} already exists")
        return None
    print(f"Working on: {country}")

    if sentinel2:
        collection, mask_band = "sentinel-2-c1-l2a", "scl"
    else:
        collection, mask_band = "landsat-c2-l2", "qa_pixel"

    if dask_kwargs is None:
        dask_kwargs = {"n_workers": 4, "threads_per_worker": 24, "memory_limit": "250GB"}

    geom = area.geometry
    client = Client.open(catalog)

    with DaskClient(**dask_kwargs) as dask_client:
        if not sentinel2:
            # Landsat assets are read from a requester-pays bucket. Apply the
            # settings locally and on every Dask worker, because the chunked
            # reads execute on the workers, not in this process.
            for target in (None, dask_client):
                configure_s3_access(
                    cloud_defaults=True, requester_pays=True, client=target
                )

        print("Searching for items in area")
        items = client.search(
            collections=[collection],
            intersects=geom,
            datetime=year,
        ).item_collection()
        print(f"Found {len(items)} items")
        if len(items) == 0:
            # Nothing to composite: return early rather than letting load()
            # fail and abort a loop over many areas.
            print(f"No items found for {country} in {year}, skipping")
            return None

        data = load(
            items,
            bands=["red", "green", "blue", mask_band],
            geopolygon=geom,
            chunks={"x": 4096, "y": 4096},
            groupby="solar_day",
        )
        print(
            f"Loaded data with dimensions x: {data.x.size}, y: {data.y.size}, time: {data.time.size}"
        )

        # Opening drops small isolated masked specks; closing then fills small
        # unmasked gaps between masked regions.
        filtered_mask = mask_cleanup(
            _cloud_mask(data, sentinel2), [("opening", 4), ("closing", 12)]
        )
        masked = data.where(~filtered_mask).drop_vars(mask_band)

        print("Computing median")
        median = masked.median("time").compute()

        if write_data:
            median.to_array().odc.write_cog(output, overwrite=True)
            print(f"Saved data to {output}")
            return median

        print("Working out percentile and making pretty picture")
        # One pair of stretch limits is computed across all three bands so the
        # colour balance between red, green and blue is preserved.
        rgb_array = median.to_array().values
        vmin, vmax = nanpercentile(rgb_array, percentile_stretch)
        visualisation = median.odc.to_rgba(vmin=vmin, vmax=vmax).compute()

        visualisation.odc.write_cog(output, overwrite=True)
        print(f"Saved visualisation to {output}")
        return visualisation

In [None]:
# Build the Sentinel-2 median composite for every area, skipping any whose output
# GeoTIFF already exists. With write_data=True each call returns the median
# Dataset (not an RGBA image); only the last area's result remains bound.
for area in areas.itertuples():
    visualisation = mosaic_region(
        area,
        sentinel2=True,
        overwrite=False,
        write_data=True,
    )