# PySpark Sedona Test Notebook

Basic template to verify Apache Sedona is loaded and working with PySpark.

In [None]:
from sedona.spark import *
from pyspark import SparkConf

# Maven coordinates fetched when the session starts: the Sedona artifact built
# for Spark 3.5 / Scala 2.12 and the GeoTools wrapper of the same release.
packages = [
    "org.apache.sedona:sedona-spark-3.5_2.12:1.7.2",
    "org.datasyslab:geotools-wrapper:1.7.2-28.5",
]

# Spark settings, added in the same order as before.
config = {"spark.driver.memory": "6G"}
config["spark.jars.packages"] = ",".join(packages)
config["spark.jars.repositories"] = "https://artifacts.unidata.ucar.edu/repository/unidata-all"

spark_conf = SparkConf()
spark_conf.setAll(config.items())

# NOTE(review): despite its name, `sedona_conf` holds whatever the builder's
# getOrCreate() returns (presumably the session), which is then handed to
# SedonaContext.create().
session_builder = SedonaContext.builder().config(conf=spark_conf)
sedona_conf = session_builder.getOrCreate()
sedona = SedonaContext.create(sedona_conf)

# Last expression of the cell: shows the session summary in the notebook.
sedona

In [None]:
# Amazon rainforest region (Mato Grosso, Brazil)
aoi_bounds = [-61.7573, -10.9065, -61.0926, -10.2531]  # (minx, miny, maxx, maxy)

# The four corner values are rendered with str(), which is exactly what
# f-string interpolation of each element would produce.
envelope_args = ", ".join(str(coord) for coord in aoi_bounds[:4])

# One-row DataFrame holding the AOI rectangle as a geometry with SRID 4326.
aoi_geom_df = sedona.sql(f"""
    SELECT ST_MakeEnvelope({envelope_args}, 4326) as aoi_geometry
""")

aoi_geom_df.show()

In [None]:
import math

def get_global_chips(spark, aoi_bounds, cell_size=0.023):
    """Generate the global-grid chips that cover an AOI using Sedona SQL.

    The grid is anchored at (-180, -90) and every cell is ``cell_size``
    degrees on a side, so a given ``(x, y)`` index always denotes the same
    footprint no matter which AOI asked for it.

    Args:
        spark: Object exposing ``sql(query)`` (the Sedona/Spark session).
        aoi_bounds: ``(minx, miny, maxx, maxy)`` in degrees.
        cell_size: Chip edge length in degrees. The default of 0.023 comes
            from 256 px at 10 m resolution = 2560 m, roughly 0.023 degrees
            at the equator.

    Returns:
        The result of ``spark.sql`` for a query yielding one row per chip
        with columns ``chip_id``, ``x``, ``y``, ``minx``, ``miny``, ``maxx``,
        ``maxy`` and ``chip_geometry``.

    Raises:
        ValueError: If ``cell_size`` is not positive or the bounds are
            inverted (max smaller than min).
    """
    if cell_size <= 0:
        raise ValueError(f"cell_size must be positive, got {cell_size}")

    minx, miny, maxx, maxy = (aoi_bounds[i] for i in range(4))
    if maxx < minx or maxy < miny:
        raise ValueError(f"aoi_bounds must be (minx, miny, maxx, maxy), got {aoi_bounds}")

    x_start = math.floor((minx + 180) / cell_size)
    y_start = math.floor((miny + 90) / cell_size)

    # ceil(...) - 1 drops below the start index when the AOI is a point or
    # line sitting exactly on a grid line. Spark's sequence(start, stop)
    # then counts downward and would emit a chip outside the AOI, so the
    # stop index is clamped to the start index.
    x_stop = max(x_start, math.ceil((maxx + 180) / cell_size) - 1)
    y_stop = max(y_start, math.ceil((maxy + 90) / cell_size) - 1)

    return spark.sql(f"""
        SELECT
            concat('chip_', x, '_', y) as chip_id, x, y,
            -180 + (x * {cell_size}) as minx,
            -90 + (y * {cell_size}) as miny,
            -180 + ((x + 1) * {cell_size}) as maxx,
            -90 + ((y + 1) * {cell_size}) as maxy,
            ST_MakeEnvelope(
                -180 + (x * {cell_size}),
                -90 + (y * {cell_size}),
                -180 + ((x + 1) * {cell_size}),
                -90 + ((y + 1) * {cell_size}),
                4326
            ) as chip_geometry
        FROM (
            SELECT explode(sequence({x_start}, {x_stop})) as x
        ) CROSS JOIN (
            SELECT explode(sequence({y_start}, {y_stop})) as y
        )
    """)

# Preview the grid chips covering the AOI.
aoi_chips_preview = get_global_chips(sedona, aoi_bounds)
aoi_chips_preview.show()

In [None]:
from sedona.stac.client import Client

# STAC API endpoint queried for Sentinel-2 scenes.
client = Client.open("https://earth-search.aws.element84.com/v1")

# All sentinel-2-l2a items intersecting the AOI within the given date range,
# returned as a DataFrame.
scene_search = client.search(
    collection_id="sentinel-2-l2a",
    bbox=aoi_bounds,
    datetime=["2023-01-01T00:00:00Z", "2024-01-01T00:00:00Z"],
    return_dataframe=True,
)

# Keep only the columns used by the later cells.
sentinel_df = scene_search.select("id", "datetime", "grid:code", "geometry", "assets")

# Cached because several later cells reuse this DataFrame.
sentinel_df.cache()

sentinel_df.show()

In [None]:
from keplergl import KeplerGl
import pandas as pd
from IPython.display import HTML

# Bring the scene footprints to the driver; only these columns are mapped.
sentinel_pandas = sentinel_df.select("id", "datetime", "geometry").toPandas()
min_date = sentinel_pandas['datetime'].min()

# Initial position of the time filter: the first three days after the
# earliest acquisition, in epoch milliseconds.
filter_start_ms = min_date.timestamp() * 1000
filter_end_ms = (min_date + pd.Timedelta(days=3)).timestamp() * 1000

# Data is attached before the config is assigned.
map_viz = KeplerGl(height=800)
aoi_pandas = aoi_geom_df.toPandas()
map_viz.add_data(aoi_pandas, name="aoi")
grid_pandas = get_global_chips(sedona, aoi_bounds).toPandas()
map_viz.add_data(grid_pandas, name="chips")
map_viz.add_data(sentinel_pandas, name="sentinel")

# Time-range filter bound to the "sentinel" dataset's datetime field.
sentinel_time_filter = {
    'dataId': ['sentinel'],
    'name': ['datetime'],
    'type': 'timeRange',
    'fieldType': 'date',
    'enlarged': True,
    'speed': 0.05,
    'value': [filter_start_ms, filter_end_ms],
}

initial_view = {
    'latitude': -10.88,
    'longitude': -61.13,
    'zoom': 8,
}

config = {
    'version': 'v1',
    'config': {
        'mapState': initial_view,
        'visState': {
            'layerBlending': 'subtractive',
            'filters': [sentinel_time_filter],
        },
    },
}

map_viz.config = config
map_viz.save_to_html(file_name='sentinel_map.html')

# Last expression of the cell: a link to the saved standalone HTML map.
HTML('<a href="sentinel_map.html" target="_blank">Open Full Map in New Tab</a>')

In [None]:
# Bare expression: shows the current `map_viz` object as this cell's output.
map_viz

In [None]:
from pyspark.sql.functions import *

# Bands whose asset href is rewritten from the bucket's HTTPS endpoint to an
# s3:// URI. Each one produces a "<band>_s3" column.
asset_bands = ("blue", "green", "red", "nir", "scl")
s3_column_names = [f"{band}_s3" for band in asset_bands]

# Dots are escaped and the pattern is anchored so only the literal HTTPS
# prefix at the start of the href is replaced.
https_prefix_pattern = r"^https://sentinel-cogs\.s3\.us-west-2\.amazonaws\.com"

s3_asset_columns = [
    regexp_replace(col(f"assets.{band}.href"), https_prefix_pattern, "s3://sentinel-cogs").alias(f"{band}_s3")
    for band in asset_bands
]

scenes_s3 = sentinel_df.select(
    "id", "datetime", "geometry",
    *s3_asset_columns,
    # UTM EPSG code: 327xx when the footprint centroid is south of the
    # equator, 326xx otherwise, plus the zone number parsed from the
    # "MGRS-<zone>..." grid code.
    concat(
        lit("EPSG:"), (
            when(expr("ST_Y(ST_Centroid(geometry))") < 0, 32700).otherwise(32600) +
            regexp_extract(col("`grid:code`"), r"MGRS-(\d+)", 1).cast("int")
        ).cast("string")
    ).alias("epsg_code"),
)

# Create and broadcast scene URL lookup, keyed by (scene id, str(datetime)).
scene_urls = scenes_s3.select("id", "datetime", *s3_column_names).collect()
url_lookup = {
    (row['id'], str(row['datetime'])): {name: row[name] for name in s3_column_names}
    for row in scene_urls
}

broadcast_urls = sedona.sparkContext.broadcast(url_lookup)

# Create scene-chip pairs
scenes_s3_limited = scenes_s3.limit(3)  # limit for testing

# Every chip that touches a scene footprint is paired with that scene.
# `is_complete` marks chips lying entirely inside the footprint. Bounds are
# cast to double rather than float so the chip edges are not truncated to
# single precision before they are used as raster bounds.
scene_chip_pairs = scenes_s3_limited.crossJoin(get_global_chips(sedona, aoi_bounds)) \
    .where("ST_Intersects(geometry, chip_geometry)") \
    .withColumn("is_complete", expr("ST_Contains(geometry, chip_geometry)")) \
    .select(
        "id", "datetime", "chip_id", "is_complete",
        col("minx").cast("double").alias("minx"),
        col("miny").cast("double").alias("miny"),
        col("maxx").cast("double").alias("maxx"),
        col("maxy").cast("double").alias("maxy"),
    )

scene_chip_pairs.show()


In [None]:
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.transform import from_bounds
from io import BytesIO
import pandas as pd
from pyspark.sql.types import *

@pandas_udf(returnType=StructType([
    StructField("chip_id", StringType()),
    StructField("datetime", TimestampType()),
    StructField("chip_raster", BinaryType()),
    StructField("is_complete", BooleanType()),
    StructField("cloud_coverage", FloatType())
]), functionType=PandasUDFType.GROUPED_MAP)
def process_scene_chips(df):
    """Cut every chip of one scene out of the scene's band rasters.

    Intended to receive one group of scene/chip rows (all rows are assumed
    to share the first row's ``id`` and ``datetime``). The band URLs are
    looked up in the ``broadcast_urls`` broadcast by
    ``(id, str(datetime))``. Each chip is reprojected to a 256x256
    EPSG:4326 grid over its ``minx/miny/maxx/maxy`` bounds and packed into
    a five-band GeoTIFF (blue, green, red, nir, scl).

    Args:
        df: pandas DataFrame with columns ``id``, ``datetime``, ``chip_id``,
            ``is_complete``, ``minx``, ``miny``, ``maxx``, ``maxy``.

    Returns:
        pandas DataFrame with ``chip_id``, ``datetime``, ``chip_raster``
        (GeoTIFF bytes), ``is_complete`` and ``cloud_coverage``. The latter
        is the percentage of non-zero SCL pixels classified as cloud, or
        NaN when the chip has no such pixels. Chips that fail are logged
        and skipped. An empty frame is returned when no URLs are known for
        the scene.
    """
    chip_size = 256
    # Sentinel-2 L2A scene-classification codes counted as cloud:
    # 8 = cloud medium probability, 9 = cloud high probability, 10 = thin cirrus.
    cloud_class_min, cloud_class_max = 8, 10

    def create_multiband_geotiff(bands_dict, transform, crs='EPSG:4326'):
        """Serialize the bands into one in-memory GeoTIFF, one band per dict entry."""
        buffer = BytesIO()

        # 'lz4' is not a GTiff compression codec; 'deflate' is a supported
        # lossless one.
        with rasterio.open(buffer, 'w', driver='GTiff', compress='deflate',
                          height=chip_size, width=chip_size, count=len(bands_dict),
                          dtype=list(bands_dict.values())[0].dtype,
                          crs=crs, transform=transform) as dst:

            for i, (name, band) in enumerate(bands_dict.items(), 1):
                dst.write(band, i)
                dst.set_band_description(i, name)

        return buffer.getvalue()

    def extract_reprojected_chip(src, target_transform, resampling):
        """Reproject band 1 of `src` onto the chip's EPSG:4326 grid."""
        target_array = np.empty((chip_size, chip_size), dtype=src.dtypes[0])

        reproject(
            source=rasterio.band(src, 1),
            destination=target_array,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=target_transform,
            dst_crs='EPSG:4326',
            resampling=resampling
        )

        return target_array

    def cloud_percentage(scl):
        """Percentage of non-zero SCL pixels that carry a cloud class."""
        valid = scl > 0
        if not valid.any():
            return float('nan')
        classified = scl[valid]
        is_cloud = (classified >= cloud_class_min) & (classified <= cloud_class_max)
        return float(is_cloud.mean() * 100.0)

    # Get scene info
    scene_id = df.iloc[0]['id']
    datetime_str = str(df.iloc[0]['datetime'])
    urls = broadcast_urls.value.get((scene_id, datetime_str))

    if not urls:
        return pd.DataFrame()

    results = []
    rasterio_config = {
        'GDAL_CACHEMAX': 512,
        'CPL_VSIL_CURL_CACHE_SIZE': 200000000,
        'GDAL_HTTP_MULTIPLEX': 'YES',
        'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR'
    }

    with rasterio.Env(**rasterio_config):
        with rasterio.open(urls['blue_s3']) as blue_src, \
             rasterio.open(urls['green_s3']) as green_src, \
             rasterio.open(urls['red_s3']) as red_src, \
             rasterio.open(urls['nir_s3']) as nir_src, \
             rasterio.open(urls['scl_s3']) as scl_src:

            # Joined outside the f-string: quotes and backslashes inside an
            # f-string expression are a SyntaxError before Python 3.12.
            url_listing = "\n".join(urls.values())
            print(f"Processing {len(df)} chips for urls:\n{url_listing}")

            for _, row in df.iterrows():
                try:
                    chip_bounds = (float(row['minx']), float(row['miny']),
                                   float(row['maxx']), float(row['maxy']))
                    target_transform = from_bounds(*chip_bounds, chip_size, chip_size)

                    # Reflectance bands are interpolated. SCL holds class
                    # codes, so it uses nearest-neighbour to avoid blending
                    # two classes into a meaningless third.
                    bands = {
                        'blue': extract_reprojected_chip(blue_src, target_transform, Resampling.bilinear),
                        'green': extract_reprojected_chip(green_src, target_transform, Resampling.bilinear),
                        'red': extract_reprojected_chip(red_src, target_transform, Resampling.bilinear),
                        'nir': extract_reprojected_chip(nir_src, target_transform, Resampling.bilinear),
                        'scl': extract_reprojected_chip(scl_src, target_transform, Resampling.nearest)
                    }

                    # Create multiband GeoTIFF with named bands
                    chip_raster = create_multiband_geotiff(bands, target_transform)

                    results.append({
                        'chip_id': row['chip_id'],
                        'datetime': row['datetime'],
                        'chip_raster': chip_raster,
                        'is_complete': row['is_complete'],
                        'cloud_coverage': cloud_percentage(bands['scl'])
                    })

                except Exception as e:
                    # Best effort: one bad chip must not fail the whole scene.
                    print(f"Error processing {row['chip_id']}: {e}")
                    continue

    return pd.DataFrame(results)


In [None]:
# One group per scene (id + acquisition datetime); process_scene_chips is
# applied to each group.
pairs_by_scene = scene_chip_pairs.groupBy("id", "datetime")
all_chips = pairs_by_scene.apply(process_scene_chips)

# Cached because several later cells reuse this DataFrame.
all_chips.cache()
all_chips.show()

In [None]:
from keplergl import KeplerGl
import pandas as pd

# Get chip completeness with geometry: one row per distinct
# (chip_id, is_complete) pair, joined back to the grid for its bounds.
chip_completeness = all_chips.select("chip_id", "is_complete").distinct() \
    .join(get_global_chips(sedona, aoi_bounds), "chip_id") \
    .select("chip_id", "is_complete", "minx", "miny", "maxx", "maxy", "chip_geometry")

# Get scene boundaries for reference
scene_boundaries = scenes_s3_limited.select("id", "geometry")

# Convert to pandas
chips_pandas = chip_completeness.toPandas()
scenes_pandas = scene_boundaries.toPandas()

# Create the map once and attach both datasets. The earlier version built
# and populated a first KeplerGl instance that was immediately discarded.
map_viz = KeplerGl(height=800)
map_viz.add_data(chips_pandas, name="chips")    # coloured by completeness
map_viz.add_data(scenes_pandas, name="scenes")  # footprints for reference

config = {
    'version': 'v1',
    'config': {
        'visState': {
            'layers': [
                {
                    'id': 'chips',
                    'type': 'geojson',
                    'config': {
                        'dataId': 'chips',
                        'label': 'chips',
                        'columns': {'geojson': 'chip_geometry'},
                        'isVisible': True,
                        'visConfig': {
                            'opacity': 0.4,  # Lower opacity to see boundaries
                            'strokeOpacity': 0.8,
                            'thickness': 0.5,
                            'colorRange': {
                                'name': 'Custom',
                                'type': 'custom',
                                'colors': ['#FF0000', '#00FF00']
                            },
                            'filled': True,
                            'stroked': True,
                            'strokeColor': [255, 255, 255]
                        }
                    },
                    'visualChannels': {
                        'colorField': {'name': 'is_complete', 'type': 'boolean'},
                        'colorScale': 'ordinal'
                    }
                },
                {
                    'id': 'scenes',
                    'type': 'geojson',
                    'config': {
                        'dataId': 'scenes',
                        'label': 'scenes',
                        'columns': {'geojson': 'geometry'},
                        'isVisible': True,
                        'visConfig': {
                            'opacity': 0.1,
                            'strokeOpacity': 0.8,
                            'thickness': 1,
                            'filled': True,
                            'stroked': True
                        }
                    }
                }
            ]
        }
    }
}

map_viz.config = config
map_viz


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from io import BytesIO

# Filter for clear chips in Spark
clear_chips = all_chips.filter("cloud_coverage < 10.0")

# Get complete and incomplete clear chips (at most two of each)
complete_clear = clear_chips.filter("is_complete = true").limit(2).collect()
incomplete_clear = clear_chips.filter("is_complete = false").limit(2).collect()
sample_chips = complete_clear + incomplete_clear

print(f"Visualizing {len(sample_chips)} clear chips:")
for chip in sample_chips:
    print(f"Chip: {chip['chip_id']}, Complete: {chip['is_complete']}, Cloud: {chip['cloud_coverage']:.2f}")

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
axes = axes.flatten()

# Iterate over the axes, not the chips, so panels without a chip are hidden.
# The earlier check lived inside a loop over the chips and could never fire.
for i, ax in enumerate(axes):
    if i >= len(sample_chips):
        ax.set_visible(False)
        continue

    chip = sample_chips[i]

    # Read multiband GeoTIFF
    with rasterio.open(BytesIO(chip['chip_raster'])) as src:
        # Read RGB bands (blue=1, green=2, red=3)
        blue_data = src.read(1)
        green_data = src.read(2)
        red_data = src.read(3)

    # Create RGB composite with a 2-98 percentile stretch over non-zero pixels
    rgb = np.stack([red_data, green_data, blue_data], axis=-1)
    nonzero_values = rgb[rgb > 0]
    if nonzero_values.size == 0:
        # Nothing to stretch: an all-zero chip is shown as black.
        rgb_norm = np.zeros(rgb.shape, dtype=np.uint8)
    else:
        p2, p98 = np.percentile(nonzero_values, [2, 98])
        # Guard against a flat chip where both percentiles coincide.
        stretch = p98 - p2 if p98 > p2 else 1.0
        rgb_norm = np.clip((rgb - p2) / stretch * 255, 0, 255).astype(np.uint8)

    ax.imshow(rgb_norm)
    ax.set_title(f"{chip['chip_id']}\nComplete: {chip['is_complete']}\nCloud: {chip['cloud_coverage']:.2f}")
    ax.axis('off')

plt.suptitle('Clear RGB Chips (Cloud < 10.0): Complete vs Incomplete', fontsize=16)
plt.tight_layout()
plt.show()


In [None]:
from pyspark.sql.functions import expr, col

def create_rgb_visualization(df, filter_condition="cloud_coverage < 10.0", limit=None, include_bounds=False):
    """
    Build a base64 RGB image per chip raster with Sedona raster SQL only.

    Rows of `df` are filtered by `filter_condition` and, when `limit` is
    truthy, capped to `limit` rows. Bands 2, 1 and 0 of `chip_raster` are
    each scaled to byte range, stacked into one raster in that order and
    encoded into the `rgb_image` column.

    With `include_bounds=True` the result is joined to the chip grid (via the
    notebook-level `sedona` and `aoi_bounds`) and also carries
    minx/miny/maxx/maxy, with those and cloud_coverage cast to float.
    Otherwise only chip_id, is_complete, cloud_coverage and rgb_image are
    returned.
    """
    def byte_scaled_band_sql(band_index, scale_factor=2000.0):
        # Map-algebra script: divide by the scale factor, stretch to 0-255
        # and cap at 255.
        jiffle_script = f"scaled = rast[{band_index}] / {scale_factor} * 255.0; out = scaled > 255.0 ? 255.0 : scaled;"
        return f"RS_MapAlgebra(RS_FromGeoTiff(chip_raster), 'B', '{jiffle_script}')"

    previews = df.filter(filter_condition)
    if limit:
        previews = previews.limit(limit)

    # Derived columns in dependency order: three scaled bands, then the
    # two-band and three-band stacks, then the encoded image.
    derived_columns = (
        ("red_scaled", byte_scaled_band_sql(2)),
        ("green_scaled", byte_scaled_band_sql(1)),
        ("blue_scaled", byte_scaled_band_sql(0)),
        ("rg_raster", "RS_AddBandFromArray(red_scaled, RS_BandAsArray(green_scaled, 1))"),
        ("rgb_raster", "RS_AddBandFromArray(rg_raster, RS_BandAsArray(blue_scaled, 1))"),
        ("rgb_image", "RS_AsBase64(rgb_raster)"),
    )
    for column_name, sql_expression in derived_columns:
        previews = previews.withColumn(column_name, expr(sql_expression))

    if not include_bounds:
        return previews.select("chip_id", "is_complete", "cloud_coverage", "rgb_image")

    def as_float(column_name):
        return col(column_name).cast("float").alias(column_name)

    with_grid = previews.join(get_global_chips(sedona, aoi_bounds), "chip_id")
    return with_grid.select(
        "chip_id", "is_complete",
        as_float("cloud_coverage"),
        "rgb_image",
        as_float("minx"),
        as_float("miny"),
        as_float("maxx"),
        as_float("maxy"),
    )

def create_folium_map():
    """Build a folium map with scene footprints and chip RGB overlays.

    Reads the notebook-level `sentinel_df` and `all_chips`: the first three
    scene geometries are drawn as one GeoJSON layer, and up to 500 chips from
    `create_rgb_visualization` are added as image overlays, each with an
    outline that is green for complete chips and red otherwise.

    Returns:
        The populated `folium.Map`.
    """
    import folium
    from shapely.geometry import mapping

    def scene_to_feature(scene_row):
        # Convert the scene's Shapely geometry to a GeoJSON feature
        return {
            "type": "Feature",
            "geometry": mapping(scene_row['geometry']),
            "properties": {"id": scene_row['id']}
        }

    def scene_style(x):
        return {'fillColor': 'blue', 'color': 'blue', 'weight': 2, 'fillOpacity': 0.3}

    # Base map
    m = folium.Map(location=[-10.58, -61.42], zoom_start=12, tiles='OpenStreetMap')

    # Scene boundaries, collected with a single query
    scene_rows = sentinel_df.select("id", "geometry").limit(3).collect()
    scene_collection = {
        "type": "FeatureCollection",
        "features": [scene_to_feature(scene_row) for scene_row in scene_rows]
    }
    scene_layer = folium.GeoJson(
        scene_collection,
        style_function=scene_style,
        name="Scene Boundaries"
    )
    scene_layer.add_to(m)

    # Chip rasters, collected with a single query
    chip_rows = create_rgb_visualization(all_chips, limit=500, include_bounds=True).collect()

    for chip_row in chip_rows:
        # folium expects [[south, west], [north, east]]
        south_west = [chip_row['miny'], chip_row['minx']]
        north_east = [chip_row['maxy'], chip_row['maxx']]
        bounds = [south_west, north_east]

        overlay = folium.raster_layers.ImageOverlay(
            image=f"data:image/png;base64,{chip_row['rgb_image']}",
            bounds=bounds,
            opacity=0.8,
            name=f"Chip {chip_row['chip_id']}"
        )
        overlay.add_to(m)

        outline_color = 'green' if chip_row['is_complete'] else 'red'
        outline = folium.Rectangle(
            bounds=bounds,
            popup=f"Chip: {chip_row['chip_id']}<br>Complete: {chip_row['is_complete']}<br>Cloud: {chip_row['cloud_coverage']:.1f}%",
            color=outline_color,
            weight=2,
            fill=False
        )
        outline.add_to(m)

    folium.LayerControl().add_to(m)
    return m

# Build the folium map and show it as this cell's output.
# NOTE(review): the name `map` shadows Python's built-in map() for every cell
# run afterwards. Consider renaming once no other cell depends on it.
map = create_folium_map()
map

In [None]:
# UDF for simplified temporal merging
@udf(returnType=BinaryType())
def simple_temporal_merge(current_chip_bytes, current_height, current_width,
                         previous_chip_bytes, previous_height, previous_width):
    """Fill the zero pixels of the current chip with the previous chip's values.

    Both chips are arrays serialized with ``np.save``. For 3-D arrays only
    the first band is used. The result is always a 256x256 float32 array,
    zero wherever neither chip provides a value.

    Args:
        current_chip_bytes: Serialized current chip, or None.
        current_height, current_width: Declared size of the current chip.
            Clamped to the real array shape and to 256. None means "use the
            array shape".
        previous_chip_bytes: Serialized previous chip, or None when there
            is no earlier observation.
        previous_height, previous_width: Declared size of the previous chip,
            treated like the current ones.

    Returns:
        ``np.save`` bytes of the merged array, or None when there is no
        current chip.
    """
    import numpy as np
    from io import BytesIO

    chip_size = 256

    def deserialize_raster(bytes_data):
        array = np.load(BytesIO(bytes_data))
        # Take first band of a multi-band array
        return array[0] if array.ndim == 3 else array

    def serialize_raster(array):
        buffer = BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()

    def usable_extent(array, height, width):
        # Never trust the declared size beyond what the array really holds.
        rows = array.shape[0] if height is None else min(height, array.shape[0])
        cols = array.shape[1] if width is None else min(width, array.shape[1])
        return min(rows, chip_size), min(cols, chip_size)

    if current_chip_bytes is None:
        return None

    # Create 256x256 target array and place the current data in its corner
    target_array = np.zeros((chip_size, chip_size), dtype=np.float32)
    current_data = deserialize_raster(current_chip_bytes)
    h, w = usable_extent(current_data, current_height, current_width)
    target_array[:h, :w] = current_data[:h, :w]

    # Fill gaps with previous data if available
    if previous_chip_bytes is not None:
        previous_data = deserialize_raster(previous_chip_bytes)
        prev_h, prev_w = usable_extent(previous_data, previous_height, previous_width)

        # Work on the overlapping window only. The old code indexed the full
        # 256x256 target with a cropped mask, which raises IndexError as
        # soon as the previous chip is smaller than 256 in either direction.
        # `overlap` is a view, so writing through it updates target_array.
        overlap = target_array[:prev_h, :prev_w]
        gaps = overlap == 0
        overlap[gaps] = previous_data[:prev_h, :prev_w][gaps]

    return serialize_raster(target_array)

# Apply temporal merging
# `Window` is not imported by any earlier cell shown here, so import it where
# it is used.
from pyspark.sql.window import Window

# For every chip, rows are ordered by acquisition time so lag() yields the
# previous observation of the same chip (null for the first one).
window_spec = Window.partitionBy("chip_id").orderBy("datetime")
previous_height = lag("chip_height").over(window_spec)
previous_width = lag("chip_width").over(window_spec)

# NOTE(review): the schema declared for process_scene_chips only has
# chip_id, datetime, chip_raster, is_complete and cloud_coverage. The
# per-band columns and chip_height/chip_width used below are not produced by
# any cell shown here. Confirm where `all_chips` gets them before running.
band_columns = ["b02_chip", "b03_chip", "b04_chip", "b08_chip"]

merged_chips = all_chips.select(
    "chip_id", "datetime",
    *[
        simple_temporal_merge(
            band_column, "chip_height", "chip_width",
            lag(band_column).over(window_spec),
            previous_height,
            previous_width,
        ).alias(f"{band_column}_filled")
        for band_column in band_columns
    ],
)