In [None]:
import geopandas as gpd
from pystac_client import Client
from odc.stac import load, configure_s3_access
from odc.algo import mask_cleanup
from dask.distributed import Client as DaskClient
from pathlib import Path

In [None]:
# Load the areas of interest. The text columns carry stray leading/trailing
# whitespace (e.g. newlines), so trim it from each of them in turn.
areas = gpd.read_file('areas.geojson')
for text_column in ["Capital", "Country", "Name"]:
    areas[text_column] = areas[text_column].str.strip()

In [None]:
def _load_items(client, geom, sentinel2, year):
    """Search the STAC catalog over `geom` for `year` and lazily load RGB plus the QA band.

    Returns None when the search finds no items, since there is nothing to load.
    """
    if sentinel2:
        collection, qa_band = "sentinel-2-c1-l2a", "scl"
    else:
        # Landsat Collection 2 lives in a requester-pays bucket.
        # NOTE(review): no `client=` is passed, so this configures the current
        # process only; confirm dask workers pick the settings up.
        configure_s3_access(cloud_defaults=True, requester_pays=True)
        collection, qa_band = "landsat-c2-l2", "qa_pixel"

    items = client.search(
        collections=[collection],
        intersects=geom,
        datetime=year,
    ).item_collection()
    print(f"Found {len(items)} items")
    if len(items) == 0:
        return None

    # Same spatial selection, band spec and chunking for both sensors.
    return load(
        items,
        geopolygon=geom,
        bands=["red", "green", "blue", qa_band],
        chunks={"x": 4096, "y": 4096},
        groupby="solar_day",
    )


def _invalid_pixel_mask(data, sentinel2):
    """Return a boolean mask that is True where a pixel must be excluded from the median."""
    if sentinel2:
        # SCL classes: 0 no data, 3 cloud shadow, 8/9 cloud medium/high probability.
        return data.scl.isin([0, 3, 8, 9])
    # QA_PIXEL bit 3 = cloud, bit 4 = cloud shadow.
    qa_mask = (data.qa_pixel & 0b00011000) != 0
    nodata_mask = data.red == data.red.odc.nodata
    return qa_mask | nodata_mask


def mosaic_region(
    area,
    overwrite=True,
    sentinel2=True,
    catalog="https://earth-search.aws.element84.com/v1",
    year="2024",
    write_data=True,
    return_data=False,
    dask_kwargs=None,
):
    """Build a cloud-masked median RGB mosaic for one area.

    Parameters
    ----------
    area : row from ``areas.itertuples()``; its ``Country`` names the output
        file and its ``geometry`` is the search/load region.
    overwrite : if False and the mosaic GeoTIFF already exists, skip the area.
        Only applies when a mosaic would actually be written
        (``write_data=True`` and ``return_data=False``).
    sentinel2 : True for Sentinel-2 (``sentinel-2-c1-l2a``), False for Landsat
        (``landsat-c2-l2``).
    catalog : STAC API URL to search.
    year : value passed as the STAC ``datetime`` filter.
    write_data : write the median to ``{sensor}_data_{country}.tif``.
    return_data : return the lazily loaded, unmasked per-solar-day data instead
        of computing a mosaic (nothing is written).
    dask_kwargs : keyword arguments for the ``dask.distributed.Client`` used to
        compute the median; defaults to 4 workers x 24 threads x 250GB.

    Returns
    -------
    The computed median Dataset; the lazy Dataset when ``return_data`` is True;
    or None when the area is skipped or the search finds no items.
    """
    country = area.Country.lower().replace(" ", "_")
    sensor = "s2" if sentinel2 else "ls"
    output = Path(f"{sensor}_data_{country}.tif")

    # Only a mosaic write can be skipped; raw-data requests never produce `output`.
    if write_data and not return_data and output.exists() and not overwrite:
        print(f"Skipping: {country}, {output} already exists")
        return None
    print(f"Working on: {country}")

    # Opened only after the skip check so skipped areas never contact the catalog.
    client = Client.open(catalog)
    print("Searching for items in area")
    data = _load_items(client, area.geometry, sentinel2, year)
    if data is None:
        print(f"No items found for {country} in {year}")
        return None

    print(
        f"Loaded data with dimensions x: {data.x.size}, y: {data.y.size}, time: {data.time.size}"
    )

    # The data is still lazy here; no cluster is needed to hand it back.
    if return_data:
        return data

    # Opening drops isolated masked specks; closing fills small gaps inside masked areas.
    filtered_mask = mask_cleanup(
        _invalid_pixel_mask(data, sentinel2), [("opening", 4), ("closing", 12)]
    )
    qa_band = "scl" if sentinel2 else "qa_pixel"
    masked = data.where(~filtered_mask).drop_vars(qa_band)

    if dask_kwargs is None:
        dask_kwargs = {"n_workers": 4, "threads_per_worker": 24, "memory_limit": "250GB"}

    # The cluster only lives for the one eager step: the temporal median.
    with DaskClient(**dask_kwargs):
        print("Computing median")
        median = masked.median("time").compute()

    if write_data:
        median.to_array().odc.write_cog(output, overwrite=True)
        print(f"Saved data to {output}")

    return median

In [None]:
# Visualise one region: build its mosaic in memory (nothing is written to disk).
REGION = "Melsisi"

# Take the first matching row so the expensive mosaic is computed once, and fail
# loudly on a typo instead of leaving `mosaic` undefined.
region_area = next(
    (area for area in areas.itertuples() if area.Country == REGION), None
)
if region_area is None:
    raise ValueError(f"No area with Country == {REGION!r} found in areas")
mosaic = mosaic_region(region_area, sentinel2=True, write_data=False)

mosaic

In [None]:
# Interactive map of the mosaic, with the display range fixed by vmin/vmax.
mosaic.odc.explore(
    vmax=3000,
    vmin=1000,
)

In [None]:
# # Write a single region

# for area in areas.itertuples():
#     if area.Country == "Melsisi":
#         visualisation = mosaic_region(area, sentinel2=True, overwrite=False, write_data=True)

In [None]:
# # Write all regions

# for area in areas.itertuples():
#     visualisation = mosaic_region(area, sentinel2=True, overwrite=False, write_data=True)

In [None]:
# Load all scenes without compositing them into a mosaic (one time step per solar day)

# for area in areas.itertuples():
#     if area.Country == "Baravet":
#         break

# data = mosaic_region(area, sentinel2=True, overwrite=False, write_data=True, return_data=True)
# data = data.drop_vars("scl")

# data

In [None]:
# Write a single day

# one = data.sel(time="2024-05-07", method="nearest")

# one.to_array().odc.write_cog("s2_data_baravet_20240507.tif", overwrite=True)