In [1]:
import numpy as np
import rasterio
from pystac_client import Client
import stackstac
from rasterio.windows import Window
from omnicloudmask import predict_from_array

In [3]:
def process_tile(tile, date, output_path_base, chunk_size):
    """
    Cloud-mask one Sentinel-2 L2A scene and save the masked RGB as a GeoTIFF.

    The scene is looked up on the Earth Search STAC API by MGRS tile and date.
    The first matching item is processed in square windows of ``chunk_size``
    pixels: the red, green and NIR bands of each window are passed to
    ``predict_from_array``, and the red/green/blue values are kept only where
    the prediction equals 0 (every other pixel is written as NaN).

    Args:
        tile (str): Sentinel-2 MGRS tile ID, e.g. "26XPF" (2-character UTM
            zone, 1-character latitude band, 2-character grid square).
        date (str): Date of the image (YYYY-MM-DD).
        output_path_base (str): Base path for the output; the file is written
            to ``f"{output_path_base}_omni_masked.tif"``.
        chunk_size (int): Edge length, in pixels, of the processing windows.

    Raises:
        ValueError: If the STAC search returns no item for the tile and date.

    Notes:
        A failed window write does not abort the run: the window is reported
        and processing continues, and a summary of failures is printed at the
        end. Errors from reading the bands or from the prediction are not
        caught and propagate to the caller.
    """
    # Connect to STAC API
    catalog_url = "https://earth-search.aws.element84.com/v1"
    catalog = Client.open(catalog_url)

    # Search for the tile by splitting the MGRS ID into its three components
    query = {
        "mgrs:utm_zone": {"eq": tile[:2]},
        "mgrs:latitude_band": {"eq": tile[2:3]},
        "mgrs:grid_square": {"eq": tile[3:5]},
    }
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        query=query,
        datetime=f"{date}/{date}",
    )
    items = list(search.items())
    if not items:
        raise ValueError("No items found for the specified tile and date.")
    # Only the first match is processed, even if the search returns several.
    item = items[0]

    # Get asset URLs (only the bands that are actually read below)
    assets = item.assets
    red_url = assets["red"].href
    green_url = assets["green"].href
    blue_url = assets["blue"].href
    nir_url = assets["nir"].href

    omni_masked_output_path = f"{output_path_base}_omni_masked.tif"
    failed_windows = []

    with rasterio.open(red_url) as red_src, \
         rasterio.open(green_url) as green_src, \
         rasterio.open(blue_url) as blue_src, \
         rasterio.open(nir_url) as nir_src:

        # The output reuses the red band's profile, switched to 3 float32
        # bands so that NaN can mark masked pixels.
        # NOTE(review): any nodata value inherited from the red band's profile
        # is left unchanged — confirm whether it should be set to NaN.
        profile = red_src.profile
        profile.update(dtype="float32", count=3, compress="lzw")

        height, width = red_src.height, red_src.width

        with rasterio.open(omni_masked_output_path, "w", **profile) as dst_omni:
            for row in range(0, height, chunk_size):
                for col in range(0, width, chunk_size):
                    # Windows on the bottom/right edges are clipped to the raster.
                    window = Window.from_slices(
                        slice(row, min(row + chunk_size, height)),
                        slice(col, min(col + chunk_size, width)),
                    )

                    red_chunk = red_src.read(1, window=window)
                    green_chunk = green_src.read(1, window=window)
                    blue_chunk = blue_src.read(1, window=window)
                    nir_chunk = nir_src.read(1, window=window)

                    # Model input is red, green, NIR (blue is only used for output).
                    input_chunk = np.stack([
                        red_chunk,
                        green_chunk,
                        nir_chunk,
                    ])

                    pred_chunk = predict_from_array(input_chunk).squeeze()

                    # Pixels predicted as class 0 are kept (presumably "clear"
                    # in OmniCloudMask's output — confirm against its docs).
                    keep = pred_chunk == 0
                    rgb_omni_masked_chunk = np.stack([
                        np.where(keep, band_chunk, np.nan)
                        for band_chunk in (red_chunk, green_chunk, blue_chunk)
                    ]).astype("float32")

                    try:
                        dst_omni.write(rgb_omni_masked_chunk, window=window)
                    except Exception as exc:
                        # Stay best-effort, but never silently: a swallowed
                        # failure leaves an unexplained hole in the output.
                        # `Exception` (not a bare except) lets KeyboardInterrupt
                        # stop a long-running cell.
                        failed_windows.append(window)
                        print(f"Warning: failed to write chunk {window}: {exc!r}")

    print("Processing complete. Outputs saved to:")
    print(f" - OmniCloudMask Masked RGB: {omni_masked_output_path}")
    if failed_windows:
        print(f"Warning: {len(failed_windows)} chunk(s) could not be written; the output is incomplete.")

In [2]:
# Run parameters for this notebook: MGRS tile, acquisition date, and the
# edge length (in pixels) of the windows processed at a time.
tile = "26XPF"
date = "2024-08-26"
chunk_size = 1024

# The tile ID doubles as the output base path, so the result is written
# next to the notebook as "<tile>_omni_masked.tif".
process_tile(
    tile=tile,
    date=date,
    output_path_base=tile,
    chunk_size=chunk_size,
)