In [9]:
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.transform import from_bounds
from pystac_client import Client
import stackstac
import rioxarray
from omnicloudmask import predict_from_array
import torch
# Check whether PyTorch can use a CUDA GPU in this environment.
torch.cuda.is_available()

In [4]:
def save_pred_mask(pred_mask, xda_data, output_path):
    """
    Write a 2-D prediction mask to a single-band GeoTIFF.

    Georeferencing is borrowed from ``xda_data``: its bounds are spread over
    the mask's rows and columns to build the affine transform, and its CRS is
    copied to the output file.

    Args:
        pred_mask (numpy.ndarray): Mask to write, indexed as (row, column).
        xda_data (xarray.DataArray): Data array supplying the bounds and CRS
            through its ``rio`` accessor.
        output_path (str): Destination path of the GeoTIFF.
    """
    n_rows = pred_mask.shape[0]
    n_cols = pred_mask.shape[1]

    # Affine transform mapping the mask grid onto the extent of xda_data
    extent = xda_data.rio.bounds()
    affine = from_bounds(*extent, width=n_cols, height=n_rows)

    # Creation options for the output raster
    profile = {
        "driver": "GTiff",
        "height": n_rows,
        "width": n_cols,
        "count": 1,  # single band
        "dtype": pred_mask.dtype,
        "crs": xda_data.rio.crs,
        "transform": affine,
    }

    with rasterio.open(output_path, "w", **profile) as raster_out:
        raster_out.write(pred_mask, 1)
        
def process_tile(tile, date, output_path_base):
    """
    Run OmniCloudMask on one Sentinel-2 L2A scene and save the mask as a GeoTIFF.

    The scene is looked up in the Earth Search STAC catalog by MGRS tile and
    date; if several items match, the first one returned is used. The mask is
    written to ``f"{output_path_base}_{tile}_pred_mask.tif"``.

    Args:
        tile (str): Sentinel-2 MGRS tile ID, e.g. "26XPF" (2-character UTM
            zone, latitude band letter, 2-letter grid square).
        date (str): Date of the image (YYYY-MM-DD).
        output_path_base (str): Prefix of the output GeoTIFF path.

    Raises:
        ValueError: If the search returns no item for the tile and date.
    """
    # Connect to STAC API
    catalog_url = "https://earth-search.aws.element84.com/v1"
    catalog = Client.open(catalog_url)

    # Search for the tile: the 5-character MGRS ID is split into the three
    # properties used in the query.
    query = {
        "mgrs:utm_zone": {"eq": tile[:2]},
        "mgrs:latitude_band": {"eq": tile[2:3]},
        "mgrs:grid_square": {"eq": tile[3:5]},
    }
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        query=query,
        datetime=f"{date}/{date}",
    )
    items = list(search.items())
    if not items:
        raise ValueError("No items found for the specified tile and date.")
    item = items[0]

    # Only these bands are fed to the model, in this order. Restricting the
    # stack to them avoids fetching assets that are never used.
    model_bands = ["red", "green", "nir"]
    epsg = item.properties["proj:epsg"]

    xda_data = stackstac.stack(
        item,
        assets=model_bands,
        epsg=epsg,
        # rioxarray derives bounds assuming coordinates are pixel centres;
        # stackstac's default ("topleft") would shift the saved mask by half
        # a pixel.
        xy_coords="center",
    )
    # Tag the array with the scene's native CRS (no reprojection happens here)
    # so that the rio accessor can report it when the mask is saved.
    xda_data = xda_data.rio.write_crs(epsg, inplace=True)

    # Materialise the model input once as a (band, y, x) array for the first
    # (and, for a single item, only) time step.
    model_input = xda_data.sel(band=model_bands).isel(time=0).values

    # Predict cloud and cloud shadow masks using omnicloudmask
    pred_mask = predict_from_array(model_input, batch_size=1)

    # Save the predicted mask as a GeoTIFF
    output_path = f"{output_path_base}_{tile}_pred_mask.tif"
    save_pred_mask(pred_mask[0], xda_data, output_path)

    print("Processing complete. Outputs saved to:")
    print(f" - OmniCloudMask Masked RGB: {output_path}")


In [8]:
# Parameters
# STAC endpoint and collection
catalog_url = "https://earth-search.aws.element84.com/v1"
collection = "sentinel-2-l2a"

# Single tile/date selection. Dates of interest per tile:
#   26XPF: 20240826 and 20230823
#   21WXS: 20230708 and 20230807
#   20QLF: 20230228
#   34WET: 20240527 and 20240714
tile = "26XPF"
date = "2024-08-26"

# (MGRS tile ID, date) pairs to process
tile_date_pair = [
    ("26XPF", "2024-08-26"),
    ("26XPF", "2023-08-23"),
    ("21WXS", "2023-07-08"),
    ("21WXS", "2023-08-07"),
    ("20QLF", "2023-02-28"),
    ("34WET", "2024-07-14"),
    ("34WET", "2024-05-27"),
]

In [9]:
# Process every (tile, date) pair; the date doubles as the output path prefix.
for pair in tile_date_pair:
    mgrs_tile, acquisition_date = pair
    process_tile(mgrs_tile, acquisition_date, acquisition_date)

Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2024-08-26_26XPF_pred_mask.tif
Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2023-08-23_26XPF_pred_mask.tif
Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2023-07-08_21WXS_pred_mask.tif
Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2023-08-07_21WXS_pred_mask.tif
Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2023-02-28_20QLF_pred_mask.tif
Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2024-07-14_34WET_pred_mask.tif
Processing complete. Outputs saved to:
 - OmniCloudMask Masked RGB: 2024-05-27_34WET_pred_mask.tif
