In [None]:
import geopandas as gpd
from pystac_client import Client
from odc.stac import load, configure_s3_access
from odc.algo import mask_cleanup
from dask.distributed import Client as DaskClient
from pathlib import Path
from numpy import nanpercentile

In [None]:
# Read the region polygons and preview them on an interactive map.
areas_file = 'areas.geojson'
areas = gpd.read_file(areas_file)
areas.explore()

In [None]:
# Trim leading/trailing whitespace (stray newlines included) from the text columns.
for column in ("Capital", "Country", "Name"):
    areas[column] = areas[column].str.strip()

# Keep only the rows whose Country value is exactly "Fiji_1".
is_selected = areas["Country"] == "Fiji_1"
one = areas[is_selected]
one

In [26]:
# Search settings read by mosaic_region: the Sentinel-2 collection id and the
# time range passed to the STAC search.
collection = "sentinel-2-c1-l2a"
datetime = "2024"

# STAC API client used for every item search below.
client = Client.open("https://earth-search.aws.element84.com/v1")


def mosaic_region(region, overwrite=True, sentinel2=True):
    """Build a cloud-masked median RGB mosaic for one region and save it as a COG.

    Searches the module-level STAC ``client`` for items intersecting the
    region's geometry within the module-level ``datetime`` range, loads the
    red/green/blue bands plus the sensor's quality band lazily with Dask,
    masks flagged pixels, takes the per-pixel median over time, stretches the
    result to RGBA and writes it to ``<sensor>_<country>.tif`` in the current
    working directory.

    Args:
        region: Table-like selection whose first row supplies ``Country``
            (used, lower-cased, in the output file name) and ``geometry``
            (used as the search/load footprint).
        overwrite: When False and the output file already exists, nothing is
            computed and ``None`` is returned.
        sentinel2: True to use the module-level Sentinel-2 ``collection``;
            False to use ``"landsat-c2-l2"``.

    Returns:
        The RGBA visualisation that was written to disk, or ``None`` when the
        existing output was kept.
    """
    sensor = "s2" if sentinel2 else "ls"
    output = Path(f"{sensor}_{region.Country.values[0].lower()}.tif")
    geom = region.geometry.values[0]

    # Decide whether to skip *before* a Dask client exists: on this path there
    # is nothing to close, and cleanup code must not reference a client that
    # was never created.
    if output.exists() and not overwrite:
        print(f"Output file {output} already exists, skipping")
        return None

    dask_client = DaskClient(n_workers=2, threads_per_worker=16)
    try:
        print("Searching for items in region")
        if sentinel2:
            items = client.search(
                collections=[collection],
                intersects=geom,
                datetime=datetime,
            ).item_collection()
            print(f"Found {len(items)} items")

            data = load(
                items,
                geopolygon=geom,
                measurements=["red", "green", "blue", "scl"],
                chunks={"x": 2048, "y": 2048},
                groupby="solar_day",
            )
        else:
            # The Landsat assets are read with requester-pays S3 access.
            configure_s3_access(cloud_defaults=True, requester_pays=True)

            # Search for Landsat items
            items = client.search(
                collections=["landsat-c2-l2"],
                intersects=geom,
                datetime=datetime,
            ).item_collection()
            print(f"Found {len(items)} items")

            # Load Landsat with ODC STAC
            data = load(
                items=items,
                geopolygon=geom,
                bands=["red", "green", "blue", "qa_pixel"],
                chunks={"x": 2048, "y": 2048},
                groupby="solar_day",
            )

        print(
            f"Loaded data with dimensions x: {data.x.size}, y: {data.y.size}, time: {data.time.size}"
        )

        if sentinel2:
            # SCL classes to discard: 0 no data, 3 cloud shadow,
            # 8 cloud (medium probability), 9 cloud (high probability).
            mask_flags = [0, 3, 8, 9]
            mask = data.scl.isin(mask_flags)
        else:
            # QA_PIXEL bits 3 (cloud) and 4 (cloud shadow).
            # NOTE(review): bit 0 (fill) is not included, so fill pixels are
            # not masked here - confirm that is intended.
            bitflags = 0b00011000
            mask = (data.qa_pixel & bitflags) != 0

        # Clean up mask: opening drops small isolated detections, closing then
        # fills small gaps inside the remaining masked areas.
        filters = [("opening", 4), ("closing", 12)]
        filtered_mask = mask_cleanup(mask, filters)

        # Blank out masked pixels and drop the quality band so that only the
        # colour bands reach the median.
        quality_band = "scl" if sentinel2 else "qa_pixel"
        masked = data.where(~filtered_mask).drop_vars(quality_band)

        print("Computing median")
        median = masked.median("time").compute()

        print("Working out percentile and making pretty picture")
        percentile_stretch = (1, 99)
        # One shared stretch across all bands, ignoring NaNs left by masking.
        rgb_array = median.to_array().values
        vmin, vmax = nanpercentile(rgb_array, percentile_stretch)
        visualisation = median.odc.to_rgba(vmin=vmin, vmax=vmax).compute()

        visualisation.odc.write_cog(output, overwrite=True)
        print(f"Saved visualisation to {output}")

        return visualisation
    finally:
        # Always release the Dask client, whether the run succeeded or raised.
        dask_client.close()


# Build the Landsat mosaic for the selected region (sentinel2=False picks the
# Landsat branch); the default overwrite=True regenerates any existing file.
visualisation = mosaic_region(region=one, sentinel2=False)

In [None]:
# Display the mosaic returned by mosaic_region (presumably as an interactive
# map - confirm against odc-geo). mosaic_region returns None when it skips an
# existing output, in which case this call fails.
visualisation.odc.explore()