## Predict from SAFE

In [1]:
from pathlib import Path
from omnicloudmask import predict_from_load_func, load_s2

In [2]:
# Sentinel-2 .SAFE scene directories to process (L1C and/or L2A products).
scene_paths = [
    Path("/home/data/S2B_MSIL2A_20240414T090559_N0510_R050_T34SFG_20240414T104805.SAFE"),
]

# Run the prediction for every listed scene.
# Note: this call also saves each prediction as a GeoTIFF automatically.
pred_paths = predict_from_load_func(
    scene_paths,
    load_s2,
)

(…)nety_004.pycls_in1k_PT_state.safetensors:   0%|          | 0.00/72.3M [00:00<?, ?B/s]

(…)next_small.usi_in1k_PT_state.safetensors:   0%|          | 0.00/37.2M [00:00<?, ?B/s]

Running inference using cpu float32:   0%|          | 0/1 [00:00<?, ?it/s]

## Predict from arrays

In [1]:
from pathlib import Path
import numpy as np
import rasterio

from omnicloudmask import predict_from_array 

In [10]:
def load_s2_rgbnir_10m(safe_path: str | Path) -> np.ndarray:
    """
    Return a 3xH×W float32 array with Red, Green, NIR (B04, B03, B08) at 10 m
    resolution from a Sentinel-2 .SAFE product.
    
    Parameters
    ----------
    safe_path : pathlib.Path
        This is the path of the .SAFE directory.
        
    
    Returns
    -------
    numpy.ndarray
        A ``(3, H, W)`` float32 array containing, in order,
        ``(Red, Green, NIR)`` reflectance at 10 m resolution.

    Raises
    ------
    FileNotFoundError
        If any of the required 10 m JP2 band files cannot be located inside
        ``safe_path``.
    """

    # Collect the three band files – works for L1C and L2A
    band_files = {  # key → JP2 pattern
        "B04": list(safe_path.glob("**/*B04_10m.jp2")),
        "B03": list(safe_path.glob("**/*B03_10m.jp2")),
        "B08": list(safe_path.glob("**/*B08_10m.jp2")),
    }

    # Convert lists to single Path objects & sanity-check
    bands = {}
    for b, files in band_files.items():
        if not files:
            raise FileNotFoundError(f"{b} not found in {safe_path}")
        bands[b] = files[0]

    # Read, rescale and stack
    arrays = []
    for b in ("B04", "B03", "B08"): # R, G, NIR order
        with rasterio.open(bands[b]) as src:
            img = src.read(1).astype("float32")
            # L1C/L2A reflectances are stored as integers ×10 000
            if img.max() > 1.0:  # quick test for unscaled data
                img /= 10_000.0
            arrays.append(img)

    return np.stack(arrays, axis=-1).transpose(2, 0, 1) # H × W × 3

def save_mask_as_geotiff(mask: np.ndarray,
                         ref_band_path: str | Path,
                         out_path: str | Path,
                         nodata: int | None = 0):
    """
    Save a (1, H, W) or (H, W) mask array as GeoTIFF, copying geotransform & CRS
    from a reference Sentinel-2 band (e.g. B04_10m.jp2).

    Parameters
    ----------
    mask : np.ndarray
        Cloud/-shadow mask, either (1, H, W) or (H, W). dtype will be promoted
        to uint8 on disk.
    ref_band_path : str | Path
        Path to any 10 m Sentinel-2 JP2 band in the same tile (B04, B03, B08…).
    out_path : str | Path
        Where to write the GeoTIFF (use .tif or .tiff).
    nodata : int or None, optional
        Value to mark NoData pixels. Set to None to omit the tag.
    """
    mask = np.squeeze(mask)            # (H, W) if it was (1, H, W)

    with rasterio.open(ref_band_path) as src:
        meta = src.meta.copy()

    meta.update(
        driver="GTiff",
        count=1,
        dtype=rasterio.uint8,
        compress="lzw",                # loss-less & widely supported
        nodata=nodata if nodata is not None else None,
    )

    # write
    with rasterio.open(out_path, "w", **meta) as dst:
        dst.write(mask.astype("uint8"), 1)  # band index = 1

In [5]:
# Locate the Sentinel-2 product and build the model input from its bands.
safe = Path(
    "/home/data/S2B_MSIL2A_20240414T090559_N0510_R050_T34SFG_20240414T104805.SAFE"
)
input_array = load_s2_rgbnir_10m(safe_path=safe)  # shape (3, height, width)

# Run OmniCloudMask on the stacked array.
pred_mask = predict_from_array(input_array)

In [9]:
# Show the predicted mask; as the cell's last expression it is rendered as the cell output.
pred_mask

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], shape=(1, 10980, 10980), dtype=uint8)

In [12]:
# Write the predicted mask to disk as a GeoTIFF.
# The first B04 10 m band found in the product serves as the reference
# raster (another 10 m band such as B08_10m.jp2 could be picked instead).
b04_candidates = safe.glob("**/*B04_10m.jp2")
ref_band = next(b04_candidates)

save_mask_as_geotiff(pred_mask, ref_band, "/home/data/cloud_shadow_mask_10m.tif")