# WMS/WMTS Segmentation with SAM3

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/wms_wmts_segmentation.ipynb)

This notebook demonstrates how to perform segmentation on imagery from WMS (Web Map Service) and WMTS (Web Map Tile Service) sources using SAM3, producing georeferenced polygon outputs.

## Install package
This notebook requires the `geoai-py` and `segment-geospatial` packages. If they are not already installed in your environment, uncomment and run the following cell.

In [None]:
# %pip install "geoai-py[all]" "segment-geospatial[samgeo3]"

## Import libraries

In [None]:
import leafmap
from samgeo import SamGeo3
import geopandas as gpd
import pandas as pd
from pathlib import Path

## Request access to SAM3

To use SAM3, you need to request access by filling out this form on Hugging Face: https://huggingface.co/facebook/sam3

Once you have access, uncomment and run the following code block:

In [None]:
# from huggingface_hub import login
# login()

## Example 1: Download imagery from XYZ tile service

First, let's download imagery from a basemap tile service for a specific area of interest. We'll use leafmap's built-in functionality to download tiles and create a GeoTIFF.

In [None]:
# Define area of interest (bbox format: [west, south, east, north])
# Example: UC Berkeley campus area
bbox = [-122.2625, 37.8685, -122.2535, 37.8755]

# Download tiles from ESRI World Imagery
output_path = "wms_imagery.tif"
leafmap.map_tiles_to_geotiff(
    output=output_path,
    bbox=bbox,
    zoom=18,
    source="Esri.WorldImagery",
    overwrite=True,
)
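
Before downloading, it can help to estimate how detailed and how large the resulting image will be. The sketch below is not part of leafmap; it applies the standard Web Mercator tiling formulas and assumes 256-pixel tiles. The helper names `ground_resolution` and `estimate_image_size` are hypothetical, and the actual GeoTIFF dimensions may differ slightly depending on how the tiles are mosaicked and cropped.

```python
import math

EARTH_CIRCUMFERENCE_M = 40075016.686  # equatorial circumference used by Web Mercator
TILE_SIZE = 256  # assumed tile size in pixels


def ground_resolution(lat_deg, zoom):
    """Approximate meters per pixel at a given latitude and zoom level."""
    return (
        EARTH_CIRCUMFERENCE_M
        * math.cos(math.radians(lat_deg))
        / (TILE_SIZE * 2**zoom)
    )


def estimate_image_size(bbox, zoom):
    """Approximate (width, height) in pixels for bbox = [west, south, east, north]."""
    west, south, east, north = bbox
    world_px = TILE_SIZE * 2**zoom  # width of the whole world in pixels at this zoom

    def merc_y(lat):
        # Normalized Web Mercator y in [0, 1], measured from the top of the map
        s = math.sin(math.radians(lat))
        return 0.5 - math.log((1 + s) / (1 - s)) / (4 * math.pi)

    width = (east - west) / 360.0 * world_px
    height = (merc_y(south) - merc_y(north)) * world_px
    return round(width), round(height)


bbox = [-122.2625, 37.8685, -122.2535, 37.8755]
center_lat = (bbox[1] + bbox[3]) / 2
print(f"Resolution: {ground_resolution(center_lat, 18):.2f} m/pixel")
print(f"Estimated size (width, height): {estimate_image_size(bbox, 18)}")
```

Each additional zoom level halves the pixel size and roughly quadruples the number of pixels to download, which is worth checking before requesting a large area.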

## Visualize downloaded imagery

In [None]:
m = leafmap.Map()
m.add_raster(output_path, layer_name="Downloaded Imagery")
m.center_object(output_path, zoom=17)
m

## Initialize SAM3

Initialize the SAM3 model with the transformers backend.

In [None]:
sam3 = SamGeo3(
    backend="transformers",
    device=None,  # Auto-detect GPU/CPU
    checkpoint_path=None,
    load_from_HF=True,
)

## Set the image

In [None]:
sam3.set_image(output_path)

## Perform segmentation with text prompts

Segment specific features using text prompts. For example, let's detect buildings.

In [None]:
# Segment buildings
sam3.generate_masks(prompt="building")

## Save results as georeferenced vector

Export the segmentation results as georeferenced polygons in GeoPackage format.

In [None]:
# Save as GeoPackage with georeferencing
vector_output = "buildings.gpkg"
sam3.save_masks(
    vector_output,
    output_format="vector",
    simplify_tolerance=1.0,
    min_area=10,
)

## Visualize results on map

In [None]:
sam3.show_anns()

In [None]:
# Display results on interactive map
m2 = leafmap.Map()
m2.add_raster(output_path, layer_name="Imagery")
m2.add_vector(
    vector_output,
    layer_name="Segmented Buildings",
    style={"color": "red", "fillOpacity": 0.3},
)
m2.center_object(output_path, zoom=17)
m2

## Example 2: Using WMS service directly

When working with a WMS service, first preview the layer on an interactive map to confirm the endpoint and layer name before downloading a specific area. The cell below covers the preview step.

In [None]:
# Example: USGS NAIP imagery WMS
wms_url = "https://imagery.nationalmap.gov/arcgis/services/USGSNAIPImagery/ImageServer/WMSServer"

# Display WMS layer on map
m3 = leafmap.Map(center=[40.7, -100], zoom=4)
m3.add_wms_layer(
    url=wms_url,
    layers="USGSNAIPImagery:USGSNAIPImagery",
    name="NAIP Imagery",
    format="image/png",
    transparent=True,
    attribution="USGS",
)
m3
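
To move from previewing to downloading, a WMS client sends a `GetMap` request for a bounding box and an output size in pixels. The sketch below only builds such a URL with the standard library; `build_getmap_url` is a hypothetical helper, not a leafmap or samgeo function. It assumes the server accepts WMS 1.1.1 requests in `EPSG:4326` and the layer name used above, which should be verified against the service's `GetCapabilities` response.

```python
from urllib.parse import urlencode


def build_getmap_url(base_url, layer, bbox, width, height, image_format="image/png"):
    """Build a WMS 1.1.1 GetMap URL for bbox = [west, south, east, north]."""
    params = {
        "SERVICE": "WMS",
        "VERSION": "1.1.1",  # assumption: 1.1.1 uses lon/lat axis order for EPSG:4326
        "REQUEST": "GetMap",
        "LAYERS": layer,
        "STYLES": "",  # empty value requests the default style
        "SRS": "EPSG:4326",
        "BBOX": ",".join(str(v) for v in bbox),
        "WIDTH": width,
        "HEIGHT": height,
        "FORMAT": image_format,
    }
    return f"{base_url}?{urlencode(params)}"


wms_url = "https://imagery.nationalmap.gov/arcgis/services/USGSNAIPImagery/ImageServer/WMSServer"
url = build_getmap_url(
    wms_url,
    layer="USGSNAIPImagery:USGSNAIPImagery",
    bbox=[-122.2625, 37.8685, -122.2535, 37.8755],
    width=1024,
    height=1024,
)
print(url)
```

The image returned by such a request carries no georeferencing of its own, so the requested bounding box would still need to be attached (for example by writing a GeoTIFF with rasterio) before passing it to SAM3.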

## Example 3: Sliding window segmentation for large areas

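For large areas, segmenting the whole image in a single pass may be impractical. The function below instead splits a downloaded GeoTIFF into overlapping windows, segments each window with SAM3, and saves one GeoPackage per window.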

In [None]:
from pathlib import Path

import rasterio
from rasterio.windows import Window


def sliding_window_segmentation(
    image_path,
    sam_model,
    prompt,
    window_size=1024,
    overlap=128,
    output_dir="sliding_results",
):
    """
    Perform sliding window segmentation on a large image.

    Args:
        image_path: Path to the input GeoTIFF
        sam_model: Initialized SAM3 model
        prompt: Text prompt for segmentation
        window_size: Size of each window in pixels
        overlap: Overlap between windows in pixels
        output_dir: Directory to save individual window results
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with rasterio.open(image_path) as src:
        width, height = src.width, src.height
        transform = src.transform
        crs = src.crs

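        # Calculate stride; it must be positive or the loops below cannot advance
        if overlap >= window_size:
            raise ValueError("overlap must be smaller than window_size")
        stride = window_size - overlap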

        # Iterate over windows
        window_idx = 0
        for row in range(0, height, stride):
            for col in range(0, width, stride):
                # Calculate window dimensions
                win_width = min(window_size, width - col)
                win_height = min(window_size, height - row)

                # Skip edge slivers narrower than 256 px; any pixels they hold
                # beyond the previous window's overlap are left unprocessed
                if win_width < 256 or win_height < 256:
                    continue

                # Create window
                window = Window(col, row, win_width, win_height)

                # Read window data
                window_data = src.read(window=window)

                # Calculate window transform
                window_transform = rasterio.windows.transform(window, transform)

                # Save window as temporary GeoTIFF
                temp_window_path = output_dir / f"window_{window_idx}.tif"
                with rasterio.open(
                    temp_window_path,
                    "w",
                    driver="GTiff",
                    height=win_height,
                    width=win_width,
                    count=src.count,
                    dtype=window_data.dtype,
                    crs=crs,
                    transform=window_transform,
                ) as dst:
                    dst.write(window_data)

                # Process window with SAM3
                print(f"Processing window {window_idx} at row={row}, col={col}")
                sam_model.set_image(str(temp_window_path))
                sam_model.generate_masks(prompt=prompt)

                # Save window results
                window_output = output_dir / f"window_{window_idx}_masks.gpkg"
                sam_model.save_masks(
                    str(window_output),
                    output_format="vector",
                    simplify_tolerance=1.0,
                    min_area=10,
                )

                window_idx += 1

    print(f"\nProcessed {window_idx} windows. Results saved to {output_dir}")
    return output_dir
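
To see how many windows a given image will produce, and whether the 256-pixel minimum leaves part of the edge unprocessed, the window layout can be computed without running the model. The sketch below mirrors the loop logic of `sliding_window_segmentation` along one axis; `window_spans` and `count_windows` are hypothetical helper names introduced only for illustration.

```python
def window_spans(size, window_size=1024, overlap=128, min_size=256):
    """Return (offset, length) pairs along one image axis."""
    if overlap >= window_size:
        raise ValueError("overlap must be smaller than window_size")
    stride = window_size - overlap
    spans = []
    for start in range(0, size, stride):
        length = min(window_size, size - start)
        # Same rule as the function above: drop windows shorter than min_size
        if length >= min_size:
            spans.append((start, length))
    return spans


def count_windows(width, height, **kwargs):
    """Number of windows processed for an image of the given pixel size."""
    return len(window_spans(width, **kwargs)) * len(window_spans(height, **kwargs))


spans = window_spans(2000)
covered_end = max(start + length for start, length in spans)
print(spans)
print(f"Pixels covered along this axis: {covered_end} of 2000")
print(f"Windows for a 2000 x 2000 image: {count_windows(2000, 2000)}")
```

In this example the last window would be only 208 pixels wide, so it is skipped and the final 80 pixels of the axis are not segmented. Lowering the minimum size or shifting the last window inward would close that gap.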

## Run sliding window segmentation (optional)

Uncomment and run the following code to process a large image using sliding windows:

In [None]:
# # Download a larger area for sliding window demo
# large_bbox = [-122.27, 37.86, -122.24, 37.88]  # Larger area
# large_image_path = "large_area.tif"
#
# leafmap.map_tiles_to_geotiff(
#     output=large_image_path,
#     bbox=large_bbox,
#     zoom=17,
#     source="Esri.WorldImagery",
#     overwrite=True,
# )
#
# # Run sliding window segmentation
# results_dir = sliding_window_segmentation(
#     image_path=large_image_path,
#     sam_model=sam3,
#     prompt="building",
#     window_size=1024,
#     overlap=128,
#     output_dir="sliding_windows_results",
# )

## Merge results from multiple windows

After processing multiple windows, we can merge the results into a single GeoPackage file.

In [None]:
def merge_window_results(results_dir, output_file="merged_results.gpkg"):
    """
    Merge multiple GeoPackage files from sliding window results.

    Args:
        results_dir: Directory containing window result GeoPackage files
        output_file: Path to save merged results
    """
    results_dir = Path(results_dir)
    gpkg_files = list(results_dir.glob("*_masks.gpkg"))

    if not gpkg_files:
        print("No GeoPackage files found to merge.")
        return None

    # Read all GeoDataFrames
    gdfs = []
    for gpkg_file in gpkg_files:
        try:
            gdf = gpd.read_file(gpkg_file)
            if not gdf.empty:
                gdfs.append(gdf)
        except Exception as e:
            print(f"Error reading {gpkg_file}: {e}")

    if not gdfs:
        print("No valid geometries found to merge.")
        return None

    # Merge all GeoDataFrames
    merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))

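    # Remove exact duplicate geometries only; a feature detected in two
    # overlapping windows usually yields slightly different polygons,
    # which this step does not merge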
    merged_gdf = merged_gdf.drop_duplicates(subset="geometry")

    # Save merged results
    merged_gdf.to_file(output_file, driver="GPKG")
    print(f"Merged {len(gdfs)} files into {output_file}")
    print(f"Total features: {len(merged_gdf)}")

    return merged_gdf
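
Because a building that falls in the overlap between two windows is typically segmented twice with slightly different outlines, exact-match deduplication often leaves near-duplicates behind. One possible follow-up is to flag features whose bounding boxes overlap heavily. The sketch below is a simplified, pure-Python illustration that works on bounding boxes rather than true polygon geometry; `bbox_iou`, `near_duplicate_indices`, and the 0.8 threshold are assumptions chosen for the example, not part of samgeo or geopandas.

```python
def bbox_iou(a, b):
    """Intersection over union of two boxes given as (minx, miny, maxx, maxy)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def near_duplicate_indices(bounds, iou_threshold=0.8):
    """Indices of boxes that largely overlap an earlier box that was kept."""
    kept, dropped = [], []
    for i, box in enumerate(bounds):
        if any(bbox_iou(box, bounds[j]) >= iou_threshold for j in kept):
            dropped.append(i)
        else:
            kept.append(i)
    return dropped


# Two near-identical detections of one building plus one separate building
bounds = [(0, 0, 10, 10), (0.2, 0, 10, 10.1), (20, 20, 30, 30)]
print(near_duplicate_indices(bounds))
```

With a GeoDataFrame, the boxes could be taken from `merged_gdf.bounds` and the flagged rows dropped by position. The pairwise comparison is quadratic in the number of features, so a spatial index would be preferable for large results.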

In [None]:
# # Example: Merge sliding window results
# merged_gdf = merge_window_results(
#     results_dir="sliding_windows_results",
#     output_file="merged_buildings.gpkg"
# )
#
# # Visualize merged results
# if merged_gdf is not None:
#     m4 = leafmap.Map()
#     m4.add_raster(large_image_path, layer_name="Imagery")
#     m4.add_gdf(merged_gdf, layer_name="Merged Buildings", style={"color": "red", "fillOpacity": 0.3})
#     m4.center_object(large_image_path)
#     display(m4)

## Summary

This notebook demonstrated:

1. **Downloading imagery from a web tile service** as a GeoTIFF using leafmap's tile downloading functionality, and previewing a WMS layer on an interactive map
2. **Performing SAM3 segmentation** on the downloaded imagery with text prompts
3. **Exporting georeferenced polygons** as GeoPackage files
4. **Implementing sliding window segmentation** for processing large areas with configurable overlap
5. **Merging results** from multiple windows to create a unified output

### Key Features:

- **Georeferenced outputs**: All results maintain proper coordinate reference systems
- **Flexible input**: Works with various tile services (Esri, OpenStreetMap, etc.) and WMS endpoints
- **Scalable**: The sliding window approach processes areas that are too large to segment in a single pass
- **Configurable overlap**: Reduces edge artifacts by processing overlapping windows

### Tips:

- Use higher zoom levels (18-19) for detailed feature extraction
- Adjust `simplify_tolerance` and `min_area` to control output polygon complexity
- For very large areas, consider processing in batches to manage memory usage
- Set `overlap` to roughly 10-20% of `window_size` (the default of 128 pixels is 12.5% of 1024) so that features near window edges are captured