In [3]:
import numpy as np
import rasterio
from pystac_client import Client
import stackstac
from rasterio.windows import Window
from omnicloudmask import predict_from_array

In [5]:
# Run configuration: which Sentinel-2 scene to process and how to tile the work.
tile, date = "26XPF", "2024-08-26"  # MGRS tile ID, acquisition date (YYYY-MM-DD)
chunk_size = 1024  # edge length, in pixels, of each square processing window

In [None]:
# Run the cloud-masking pipeline for the configured scene; the tile ID doubles
# as the filename prefix of the GeoTIFFs that are written.
# NOTE(review): process_tile is defined in a later cell of this notebook, so on
# a fresh kernel that cell has to be executed before this one.
process_tile(
    tile=tile,
    date=date,
    output_path_base=tile,
    chunk_size=chunk_size,
)

In [7]:
def process_tile(tile, date, output_path_base, chunk_size):
    """
    Cloud-mask one Sentinel-2 L2A scene and write three RGB GeoTIFFs.

    The scene is looked up in the Earth Search STAC catalogue by MGRS tile and
    date, then streamed window by window so the full raster never has to be
    held in memory. For every window the red/green/blue bands are written
    three times: unmasked, masked with the OmniCloudMask prediction, and masked
    with the scene classification (SCL) layer. Masked pixels are set to NaN.

    Args:
        tile (str): Sentinel-2 MGRS tile ID, e.g. "26XPF" (2-character UTM
            zone, 1-character latitude band, 2-character grid square).
        date (str): Date of the image (YYYY-MM-DD).
        output_path_base (str): Prefix for the output files. The function
            writes "<prefix>_original.tif", "<prefix>_omni_masked.tif" and
            "<prefix>_scl_masked.tif".
        chunk_size (int): Edge length in pixels of the square processing
            windows. Windows on the right/bottom edge are clipped to the
            raster extent.

    Raises:
        ValueError: If chunk_size is not positive, or if the catalogue has no
            item for the requested tile and date.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of pixels.")

    # Connect to STAC API
    catalog_url = "https://earth-search.aws.element84.com/v1"
    catalog = Client.open(catalog_url)

    # Search for the tile: the MGRS ID is split into its three STAC properties.
    query = {
        "mgrs:utm_zone": {"eq": tile[:2]},
        "mgrs:latitude_band": {"eq": tile[2:3]},
        "mgrs:grid_square": {"eq": tile[3:5]},
    }
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        query=query,
        datetime=f"{date}/{date}",
    )
    items = list(search.items())
    if not items:
        raise ValueError("No items found for the specified tile and date.")
    # Several items can match one tile/date; the first one returned is used.
    item = items[0]

    assets = item.assets
    red_url = assets["red"].href
    green_url = assets["green"].href
    blue_url = assets["blue"].href
    nir_url = assets["nir"].href
    scl_url = assets["scl"].href

    # Paths for the three outputs
    original_output_path = f"{output_path_base}_original.tif"
    omni_masked_output_path = f"{output_path_base}_omni_masked.tif"
    scl_masked_output_path = f"{output_path_base}_scl_masked.tif"

    # SCL classes treated as cloud-contaminated: 3 = cloud shadow,
    # 8 = cloud medium probability, 9 = cloud high probability, 10 = thin cirrus.
    scl_cloud_classes = [3, 8, 9, 10]

    # Open the bands using rasterio
    with rasterio.open(red_url) as red_src, \
         rasterio.open(green_url) as green_src, \
         rasterio.open(nir_url) as nir_src, \
         rasterio.open(blue_url) as blue_src, \
         rasterio.open(scl_url) as scl_src:

        # Outputs share the red band's grid/georeferencing but hold 3 float32
        # bands so that masked pixels can be stored as NaN.
        profile = red_src.profile
        profile.update(dtype="float32", count=3, compress="lzw")

        with rasterio.open(original_output_path, "w", **profile) as dst_original, \
             rasterio.open(omni_masked_output_path, "w", **profile) as dst_omni, \
             rasterio.open(scl_masked_output_path, "w", **profile) as dst_scl:
            # Process the image in chunks
            for row in range(0, red_src.height, chunk_size):
                # Clip the last row/column of windows to the raster extent so
                # the data read and the window written always have equal size.
                win_h = min(chunk_size, red_src.height - row)
                for col in range(0, red_src.width, chunk_size):
                    win_w = min(chunk_size, red_src.width - col)
                    window = Window(col, row, win_w, win_h)

                    # Read chunks
                    red_chunk = red_src.read(1, window=window)
                    green_chunk = green_src.read(1, window=window)
                    blue_chunk = blue_src.read(1, window=window)
                    nir_chunk = nir_src.read(1, window=window)

                    # The SCL layer may sit on a different pixel grid than the
                    # red band (Sentinel-2 L2A distributes it at 20 m), so the
                    # same pixel offsets would address a different ground area.
                    # Translate the window through geographic bounds and let
                    # rasterio resample (nearest neighbour is its default,
                    # which suits categorical data) to the chunk's shape.
                    scl_window = rasterio.windows.from_bounds(
                        *red_src.window_bounds(window),
                        transform=scl_src.transform,
                    )
                    scl_chunk = scl_src.read(
                        1, window=scl_window, out_shape=(win_h, win_w)
                    )

                    # Assemble the (red, green, nir) input for OmniCloudMask.
                    input_chunk = np.stack([
                        red_chunk,
                        green_chunk,
                        nir_chunk,
                    ])

                    # Predict cloud mask for the chunk
                    pred_chunk = predict_from_array(
                        input_chunk,
                        patch_size=int(chunk_size),
                        patch_overlap=int(np.floor(chunk_size / 2)),
                    )

                    rgb_chunk = np.stack(
                        [red_chunk, green_chunk, blue_chunk]
                    ).astype("float32")

                    # OmniCloudMask-based masking: keep only pixels predicted
                    # as class 0. The prediction carries a leading axis of
                    # length 1, hence the [0].
                    omni_mask = pred_chunk[0] != 0
                    rgb_omni_masked_chunk = np.where(omni_mask, np.nan, rgb_chunk)

                    # SCL-based masking; the 2-D mask broadcasts over the
                    # three RGB bands.
                    scl_mask = np.isin(scl_chunk, scl_cloud_classes)
                    rgb_scl_masked_chunk = np.where(scl_mask, np.nan, rgb_chunk)

                    # Write original RGB
                    dst_original.write(rgb_chunk, window=window)

                    # Write OmniCloudMask-masked RGB
                    dst_omni.write(rgb_omni_masked_chunk.astype("float32"), window=window)

                    # Write SCL-masked RGB
                    dst_scl.write(rgb_scl_masked_chunk.astype("float32"), window=window)

    print("Processing complete. Outputs saved to:")
    print(f" - Original RGB: {original_output_path}")
    print(f" - OmniCloudMask Masked RGB: {omni_masked_output_path}")
    print(f" - SCL Masked RGB: {scl_masked_output_path}")
