In [82]:
#Requirements for the spark workflow
from sedona.spark import *
from pyspark.sql.functions import col, count, countDistinct
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DateType
from sedona.spark.sql.st_constructors import ST_Point
from sedona.spark.sql.st_functions import GeometryType, ST_XMin, ST_YMin, ST_XMax, ST_YMax
from sedona.spark import SedonaKepler
from pyspark.sql import functions as F
from sedona.spark.geopandas import GeoDataFrame, read_parquet
from sedona.sql import st_predicates

import sys, os
from shapely.geometry import Point
from itertools import product
import sedona
from sedona.spark import SedonaContext

import sedona.db
import numpy as np
import leafmap

import geopandas as gpd
from pyspark.sql.functions import expr
import leafmap.colormaps as cm

In [2]:
# For anonymous access to public S3 buckets
#sd_cont is needed to read all the csv in
sd_cont = (
    SedonaContext.builder()
    .config(
        "spark.hadoop.fs.s3a.bucket.bucket-name.aws.credentials.provider",
        "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
    )
    .getOrCreate()
)


sd = SedonaContext.create(sd_cont)

25/11/10 13:45:29 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
25/11/10 13:45:31 WARN UDTRegistration: Cannot register UDT for org.geotools.coverage.grid.GridCoverage2D, which is already registered.
25/11/10 13:45:31 WARN SimpleFunctionRegistry: The function rs_union_aggr replaced a previously registered function.
25/11/10 13:45:31 WARN UDTRegistration: Cannot register UDT for org.locationtech.jts.geom.Geometry, which is already registered.
25/11/10 13:45:31 WARN UDTRegistration: Cannot register UDT for org.apache.sedona.common.S2Geography.Geography, which is already registered.
25/11/10 13:45:31 WARN UDTRegistration: Cannot register UDT for org.locationtech.jts.index.SpatialIndex, which is already registered.
25/11/10 13:45:31 WARN SimpleFunctionRegistry: The function st_envelope_aggr replaced a previously registered function.
25/11/10 13:45:31 WARN SimpleFunctionRegistry: The function st_intersection_aggr replaced a previously 

In [None]:

# Path to the directory containing your CSV files
directory_path = "../2024/"

df = sd_cont.read.option("header", True).format("csv").load(directory_path)


In [None]:
#Setting the ST_Point after getting unique values
df.select(ST_Point(col("LONGITUDE"), col("LATITUDE")).alias("GEOMETRY"), "STATION", "DATE", "LATITUDE", "LONGITUDE", "ELEVATION", "NAME", "REPORT_TYPE", 
          "SOURCE", "HourlyDryBulbTemperature", "HourlyPressureChange", "HourlyPressureTendency", "HourlySeaLevelPressure", "HourlyStationPressure", 
          "HourlyWindDirection", "HourlyWindGustSpeed", "HourlyWindSpeed").filter(col("HourlyPressureChange").isNotNull()).write.format("geoparquet").mode("overwrite").save("./p_diff")
df_p_diff_pq = sd.read.format("geoparquet").load("./p_diff")


In [None]:
#I'm not going to use sedonaDB because the df is 8.7gb (from 70gb in text files)
#and I only have 12 gb memmory allocted to docker
df_p_diff_pq.explain('cost')

In [None]:
#This has to be build with unique locations, or at least stations
df_p_diff_pq.select("GEOMETRY").distinct().write.format("geoparquet").mode("overwrite").save("./stations")
df_stations_pq = sd.read.format("geoparquet").load("./stations")
df_stations_pq.show(5)
df_stations_pq.count()

In [None]:
OVERTURE_RELEASE = "2025-09-24.0"
COUNTRY_CODES_OF_INTEREST = ["US"]
SOURCE_DATA_URL = f"s3a://overturemaps-us-west-2/release/{OVERTURE_RELEASE}/theme=divisions/type=division_area"
OUTPUT_FILE = "my_super_cool_data.parquet"

country_overlap_condition = F.arrays_overlap(
    F.col("country"),
    F.array(*[F.lit(x.upper()) for x in COUNTRY_CODES_OF_INTEREST]),
)

source_df = (
    sd.read.format("geoparquet")
    .load(SOURCE_DATA_URL)
    .filter(col("country").isin(COUNTRY_CODES_OF_INTEREST))
    #.filter(col("region")=='US-CA')
    .filter(col("subtype")=='country')
    .withColumn("_overture_release_version", F.lit(OVERTURE_RELEASE))
    .withColumn("_ingest_timestamp", F.current_timestamp())
)

source_df.selectExpr("geometry", "country").write.format("geoparquet").mode("overwrite").save("./USA_geom")
USA_geom_pq = sd.read.format("geoparquet").load("./USA_geom")
USA_geom_pq.show(5)
map = SedonaKepler.create_map(USA_geom_pq, name="USA")
map

In [None]:
# 2. Define a coordinate grid with a 0.5-degree step
longitude_step = .5
latitude_step = .5

longitudes = [i * longitude_step for i in range(int(-180 / longitude_step), int(180 / longitude_step) + 1)]
latitudes = [i * latitude_step for i in range(int(-90 / latitude_step), int(90 / latitude_step) + 1)]

# 3. Generate a list of all coordinate pairs
coordinate_pairs = list(product(longitudes, latitudes))

# 4. Create a Spark DataFrame from the list of coordinates
schema = ["longitude", "latitude"]
df_lat_lon = sd.createDataFrame(coordinate_pairs, schema=schema)

# 5. Create the Sedona geometry points
# ST_Point takes longitude first, then latitude.
spatial_df = df_lat_lon.withColumn(
    "interp_GEOM",
    F.expr(f"ST_Point(longitude, latitude)")
)
# Show the resulting DataFrame clipped to USA boundaries
print("Generated spatial DataFrame:")
df_usa_points = spatial_df.join(
    USA_geom_pq,
    st_predicates.ST_Within(spatial_df["interp_GEOM"], USA_geom_pq["geometry"]),
    "inner"
)
# Show the results
df_usa_points.show(5)
df_usa_points.printSchema()

In [None]:
%%time
#df_interp_pq created below, is a grid of points at 0.5 deg lat lon, clipped to the USA border.
df_usa_points.drop(*["longitude", "latitude", "geometry", "country"]).write.format("geoparquet").mode("overwrite").save("./interp_points")
df_interp_pq = sd.read.format("geoparquet").load("./interp_points")
df_interp_pq.show(5)
df_interp_pq.printSchema()
print(df_interp_pq.count())
print(spatial_df.count())

In [38]:
df_interp_pq = sd.read.format("geoparquet").load("./interp_points")
df_stations_pq = sd.read.format("geoparquet").load("./stations")
df_p_diff_pq = sd.read.format("geoparquet").load("./p_diff")
df_knn_pq = sd.read.format("geoparquet").load("./knn_join")

In [39]:
#df_stations_pq is the list of unique geomtries of stations.  Note, there are duplicates stations with the same geom.  In this case, only one geom is in the df.
df_stations_pq.createOrReplaceTempView("stations")
df_interp_pq.createOrReplaceTempView("interp")

In [24]:
df_stations_pq.show(5)

+--------------------+
|            GEOMETRY|
+--------------------+
|  POINT (14.35 49.2)|
|POINT (136.9 37.3...|
|POINT (-60.983333...|
|   POINT (1.4 44.75)|
|POINT (139.366666...|
+--------------------+
only showing top 5 rows



In [56]:
%%time
#Adapted From https://wherobots.com/blog/introducing-knn-join-for-wherobots-and-apache-sedona/
df_knn_join = sd.sql("""
SELECT
    stations.GEOMETRY AS stations_GEOM,
    interp.interp_GEOM,
    ST_DISTANCESPHERE(stations.GEOMETRY, interp.interp_GEOM) AS DISTANCE
FROM 
    stations
JOIN 
    interp 
ON 
    ST_KNN(interp.interp_GEOM, stations.GEOMETRY, 10, true)
""")
df_knn_join.show(50)
df_knn_join.write.format("geoparquet").mode("overwrite").save("./knn_join")
df_knn_pq = sd.read.format("geoparquet").load("./knn_join")


25/11/10 15:22:12 WARN Executor: Managed memory leak detected; size = 80822548 bytes, task 0.0 in stage 129.0 (TID 336)
                                                                                

+--------------------+-------------------+------------------+
|       stations_GEOM|        interp_GEOM|          DISTANCE|
+--------------------+-------------------+------------------+
|POINT (-94.3077 3...| POINT (-94.5 32.5)|18175.530957096616|
|POINT (-94.71396 ...| POINT (-94.5 32.5)|23455.691075244365|
|POINT (-93.8244 3...| POINT (-94.5 32.5)|63647.204819572355|
|POINT (-93.74616 ...| POINT (-94.5 32.5)| 70818.72531682896|
|POINT (-93.66667 ...| POINT (-94.5 32.5)| 78150.29496084567|
|POINT (-95.404 32...| POINT (-94.5 32.5)| 86286.76365253911|
|POINT (-94.70944 ...| POINT (-94.5 32.5)|104429.19597021754|
|POINT (-93.98777 ...| POINT (-94.5 32.5)|116552.48308789147|
|POINT (-94.75468 ...| POINT (-94.5 32.5)|142600.93265925834|
|POINT (-95.45 33....| POINT (-94.5 32.5)|154004.79493538395|
|POINT (-93.216666...| POINT (-93.5 30.5)| 41612.74451026732|
|POINT (-93.22771 ...| POINT (-93.5 30.5)| 49164.96244722466|
|POINT (-93.18333 ...| POINT (-93.5 30.5)| 68231.10604210442|
|POINT (

[Stage 132:>                                                        (0 + 1) / 1]

CPU times: user 14.1 ms, sys: 2.94 ms, total: 17.1 ms
Wall time: 26.1 s


                                                                                

In [57]:
df_p_diff_pq.createOrReplaceTempView("pressure_diff")
df_knn_pq.createOrReplaceTempView("stations_knn")

In [64]:
%%time
df_iwd_grid = sd.sql("""
SELECT 
    stations_knn.interp_GEOM,
    SUM(ABS(pressure_diff.HourlyPressureChange) / POWER(stations_knn.DISTANCE, 2)) / SUM(1 / POWER(stations_knn.DISTANCE, 2)) AS interpolated_value,
    AVG(stations_knn.DISTANCE) as avg_dist
    from 
      stations_knn
    inner join 
      pressure_diff 
    on 
      stations_knn.stations_GEOM = pressure_diff.GEOMETRY
    group by stations_knn.interp_GEOM
""")
df_iwd_grid.write.format("geoparquet").mode("overwrite").save("./knn_join")
df_iwd_pq = sd.read.format("geoparquet").load("./knn_join")
df_iwd_pq.show(50)




+-------------------+--------------------+------------------+
|        interp_GEOM|  interpolated_value|          avg_dist|
+-------------------+--------------------+------------------+
|  POINT (-111.5 32)| 0.02913730049205619|116643.63984541665|
|  POINT (-112 33.5)|0.032212821789863844| 93923.38095111438|
| POINT (-75.5 44.5)| 0.03257375872371855| 49999.06761207895|
|POINT (-108.5 32.5)| 0.03077595405722897|174162.20661409802|
|    POINT (-110 32)|0.027869790426787798|107285.71865638219|
|  POINT (-111 45.5)| 0.02814154885918009|126412.50629138335|
|  POINT (-117 44.5)| 0.02777864328703037|129429.54084705304|
|     POINT (-89 48)| 0.03341092480090193| 62136.39971433593|
|  POINT (-123 37.5)|0.022024480698810935|65086.522842693776|
|   POINT (-99.5 48)| 0.03154335059684392|128927.16794655411|
|  POINT (-111 39.5)| 0.02544050452125239|151332.11408441883|
|    POINT (-110 38)|0.028262597963739897|178061.54726838128|
|     POINT (-71 42)| 0.03443510010447305|35844.963970935394|
|POINT (

                                                                                

In [65]:
df_iwd_pq.createOrReplaceTempView("iwd")


In [66]:
df_min_max = sd.sql("""
SELECT
    Min(ST_XMin(interp_GEOM)) as x_min,
    Max(ST_XMax(interp_GEOM)) as x_max,
    Min(ST_YMin(interp_GEOM)) as y_min,
    Max(ST_YMax(interp_GEOM)) as y_max
from iwd
WHERE ST_XMax(interp_GEOM) < 0

""")
df_min_max.show()


+------+-----+-----+-----+
| x_min|x_max|y_min|y_max|
+------+-----+-----+-----+
|-179.0|-67.5| 19.0| 71.5|
+------+-----+-----+-----+



In [5]:
# Define parameters  This should be .25 deg off of the points so there is one point per cell
offset_x = df_min_max.first()['x_min'] -0.25
offset_y = df_min_max.first()['y_max'] +0.25
cell_size = 0.5
width = abs(df_min_max.first()['x_max'] - df_min_max.first()['x_min'])/cell_size + 1
height = abs(df_min_max.first()['y_max'] - df_min_max.first()['y_min'])/cell_size + 1
srid = 4326 # Example SRID (WGS84)
num_bands = 1 # Number of bands
pixel_type = 'D' # Pixel type, e.g., 'D' for Double

# Create a raster from the iwd points using SQL expression
df_idw_rs= sd.sql(f"""
SELECT 
    RS_AsRaster(
    iwd.interp_GEOM,
    RS_Clip(
      RS_MakeEmptyRaster({num_bands}, '{pixel_type}', {width}, {height}, {offset_x}, {offset_y}, {cell_size}, {cell_size*-1}, 0, 0, {srid}),
      1, iwd.interp_GEOM, true),
    "D", true, iwd.interpolated_value
    ) rast
from iwd
""")

# Show schema to verify
df_idw_rs.printSchema()
df_idw_rs.show(1)

root
 |-- rast: raster (nullable = true)



[Stage 25:>                                                         (0 + 1) / 1]

+--------------------+
|                rast|
+--------------------+
|GridCoverage2D["g...|
+--------------------+
only showing top 1 row



                                                                                

In [7]:
df_idw_rs.withColumn("raster_binary", expr("RS_AsGeoTiff(rast)"))\
  .write.format("raster").mode("overwrite").save("test_geotiff")

                                                                                

In [67]:
# Define parameters  This should be .25 deg off of the points so there is one point per cell
offset_x = df_min_max.first()['x_min'] -0.25
offset_y = df_min_max.first()['y_max'] +0.25
cell_size = 0.5
width = abs(df_min_max.first()['x_max'] - df_min_max.first()['x_min'])/cell_size + 1
height = abs(df_min_max.first()['y_max'] - df_min_max.first()['y_min'])/cell_size + 1
srid = 4326 # Example SRID (WGS84)
num_bands = 1 # Number of bands
pixel_type = 'D' # Pixel type, e.g., 'D' for Double

# Create a raster from the iwd points using SQL expression
df_idw_poly = sd.sql(f"""
SELECT exploded.geom as geom, exploded.value as p_diff from 
   (Select Explode(RS_PixelAsPolygons(
     RS_AsRaster(
      iwd.interp_GEOM,
      RS_Clip(
        RS_MakeEmptyRaster({num_bands}, '{pixel_type}', {width}, {height}, {offset_x}, {offset_y}, {cell_size}, {cell_size*-1}, 0, 0, {srid}),
        1, iwd.interp_GEOM, true),
    "D", true, iwd.interpolated_value
    ), 1)) as exploded
from iwd)
""")

# Show schema to verify
df_idw_poly.printSchema()
df_idw_poly.show(50)

root
 |-- geom: geometry (nullable = true)
 |-- p_diff: double (nullable = true)

+--------------------+--------------------+
|                geom|              p_diff|
+--------------------+--------------------+
|POLYGON ((-111.75...| 0.02913730049205619|
|POLYGON ((-112.25...|0.032212821789863844|
|POLYGON ((-75.75 ...| 0.03257375872371855|
|POLYGON ((-108.75...| 0.03077595405722897|
|POLYGON ((-110.25...|0.027869790426787798|
|POLYGON ((-111.25...| 0.02814154885918009|
|POLYGON ((-117.25...| 0.02777864328703037|
|POLYGON ((-89.25 ...| 0.03341092480090193|
|POLYGON ((-123.25...|0.022024480698810935|
|POLYGON ((-99.75 ...| 0.03154335059684392|
|POLYGON ((-111.25...| 0.02544050452125239|
|POLYGON ((-110.25...|0.028262597963739897|
|POLYGON ((-71.25 ...| 0.03443510010447305|
|POLYGON ((-119.75...|0.025413040699767284|
|POLYGON ((-100.75...| 0.03170282865539063|
|POLYGON ((-142.25...|0.025700954977418897|
|POLYGON ((-154.25...|0.029233155695884663|
|POLYGON ((-80.25 ...| 0.0305113808696

In [24]:
df_interp_pq = sd.read.format("geoparquet").load("./interp_points")
df_interp_pq.show(5)

+-----------------+
|      interp_geom|
+-----------------+
|  POINT (-108 32)|
|  POINT (-108 32)|
|  POINT (-108 32)|
|  POINT (-108 32)|
|POINT (-108 32.5)|
+-----------------+
only showing top 5 rows



In [68]:
df_temp = df_idw_poly.toPandas()
gdf_idw = gpd.GeoDataFrame(df_temp, geometry="geom")
#df_temp = df_interp_pq.toPandas()
#gdf_points = gpd.GeoDataFrame(df_temp, geometry="interp_GEOM")
#gdf_points.head(5)

                                                                                

In [67]:
gdf_points.memory_usage()

Index           128
interp_GEOM    3032
dtype: int64

In [85]:
#Mapping polys created from raster until I figure out how to map the actual raster.
style = {'opacity': 0, 'weight': 1}
m = leafmap.Map(center=(40, -100), zoom=6)
m.add_data(gdf_idw, column = 'p_diff', cmap='RdYlGn_r', k=20, layer_name="IDW Hourly Pressure Diff", style=style)
m

Map(center=[40, -100], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_out_t…

In [17]:
#This isn't working due to GDAL compatibility issues.
!pip install --upgrade gdal==3.5
!pip install localtileserver
m = leafmap.Map(center=(40, -124), zoom=8)
m.add_raster("./test_geotiff", layer_name="IDW Hourly Pressure Diff")

Collecting gdal==3.5
  Using cached GDAL-3.5.0.tar.gz (752 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: gdal
[33m  DEPRECATION: Building 'gdal' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'gdal'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for gdal (setup.py) ... [?25lerror
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py bdist_wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m [31m[335 lines of output][0m
  [31m   [0m !!
  [31m   [0m 
  [31m   [0m         ****************************************

ImportError: localtileserver is not installed. Please install it before proceeding. https://github.com/banesullivan/localtileserver

ImportError: localtileserver is not installed. Please install it before proceeding. https://github.com/banesullivan/localtileserver

In [None]:
#Player around.  Use some of this code.
df_pressure_diff = df.select(ST_Point(col("LONGITUDE"), col("LATITUDE")),"STATION", "DATE", "LATITUDE", "LONGITUDE", "ELEVATION", "NAME", "REPORT_TYPE", "SOURCE", "HourlyDryBulbTemperature", "HourlyPressureChange", "HourlyPressureTendency", "HourlySeaLevelPressure", "HourlyStationPressure", "HourlyWindDirection", "HourlyWindGustSpeed", "HourlyWindSpeed").filter(col("HourlyPressureChange").isNotNull()).show()
df.filter(col("HourlySeaLevelPressure").isNotNull() | col("HourlyStationPressure").isNotNull()).count()
#counts of records with different filters
df.count()
#130,112,717 37,352,572 77,319,947 77,210,243
aggregated = df.groupby('STATION', 'DATE').agg({abs('HourlyPressureChange'): 'min', abs('HourlyPressureChange'): 'min'})
df['AbsPressure'] = df['HourlyPressureChange'].abs()
from pyspark.sql import functions as F
df = df.withColumn("absPressure", F.abs(F.col("HourlyPressureChange")))
df.tail(10)
df.agg(F.min('HourlyPressureChange')).show()
df.filter(col('HourlyPressureChange').isNotNull()).withColumn("HourlyPressureChange", col("HourlyPressureChange").cast(FloatType())).groupby('STATION', 'DATE')\
.agg({abs('HourlyPressureChange'): 'min', abs('HourlyPressureChange'): 'min'})

In [None]:
#Using Overtures instread
#import geopandas as gpd

#url = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"


#gdf = gpd.read_file(url)
#gdf
#df_conus = sedona.createDataFrame(gdf[(gdf.SOV_A3=='US1') & (gdf.TYPE=='Country')][['SOVEREIGNT', 'geometry']])
#map = SedonaKepler.create_map(df=df_conus, name="CONUS")
#map


In [None]:
spatial_df.first()['geometry']

In [9]:
OVERTURE_RELEASE = "2025-09-24.0"
COUNTRY_CODES_OF_INTEREST = ["US"]
SOURCE_DATA_URL = f"s3a://overturemaps-us-west-2/release/{OVERTURE_RELEASE}/theme=divisions/type=division_area"
OUTPUT_FILE = "my_super_cool_data.parquet"

country_overlap_condition = F.arrays_overlap(
    F.col("country"),
    F.array(*[F.lit(x.upper()) for x in COUNTRY_CODES_OF_INTEREST]),
)

source_df = (
    sd.read.format("geoparquet")
    .load(SOURCE_DATA_URL)
    .filter(col("country").isin(COUNTRY_CODES_OF_INTEREST))
    .filter(col("region")=='US-CA')
    #.filter(col("subtype")=='country')
    .withColumn("_overture_release_version", F.lit(OVERTURE_RELEASE))
    .withColumn("_ingest_timestamp", F.current_timestamp())
)

source_df.selectExpr("geometry", "region").write.format("geoparquet").mode("overwrite").save("./CA_geom")
CA_geom_pq = sd.read.format("geoparquet").load("./CA_geom")
CA_geom_pq.show(5)
map = SedonaKepler.create_map(CA_geom_pq, name="CA")
map

                                                                                

+--------------------+------+
|            geometry|region|
+--------------------+------+
|POLYGON ((-118.31...| US-CA|
|POLYGON ((-118.49...| US-CA|
|POLYGON ((-118.31...| US-CA|
|POLYGON ((-118.30...| US-CA|
|POLYGON ((-118.28...| US-CA|
+--------------------+------+
only showing top 5 rows



KeplerGl(data={'CA': {'index': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, …

In [None]:
#Used for testing with just CA
#CA_geom = source_df.selectExpr("geometry", "region")
#CA_geom.show(5)
#map1 = SedonaKepler.create_map(CA_geom, name="CA")
#map1
#
#CA_geom = source_df.selectExpr("geometry", "region").filter(GeometryType(col('geometry'))=='MULTIPOLYGON')
#CA_geom.show(5)
#map2 = SedonaKepler.create_map(CA_geom, name="CA")
#map2
#
#CA_geom = source_df.selectExpr("geometry", "region").filter(GeometryType(col('geometry'))=='POLYGON')
#CA_geom.show(5)
#map3 = SedonaKepler.create_map(CA_geom, name="CA")
#map3

In [60]:
# 2. Define a coordinate grid with a 0.5-degree step
longitude_step = .5
latitude_step = .5

longitudes = [i * longitude_step for i in range(int(-180 / longitude_step), int(180 / longitude_step) + 1)]
latitudes = [i * latitude_step for i in range(int(-90 / latitude_step), int(90 / latitude_step) + 1)]

# 3. Generate a list of all coordinate pairs
coordinate_pairs = list(product(longitudes, latitudes))

# 4. Create a Spark DataFrame from the list of coordinates
schema = ["longitude", "latitude"]
df_lat_lon = sd.createDataFrame(coordinate_pairs, schema=schema)

# 5. Create the Sedona geometry points
# ST_Point takes longitude first, then latitude.
spatial_df = df_lat_lon.withColumn(
    "interp_GEOM",
    F.expr(f"ST_Point(longitude, latitude)")
)
# Show the resulting DataFrame clipped to USA boundaries
print("Generated spatial DataFrame:")
df_ca_points = spatial_df.join(
    CA_geom_pq,
    st_predicates.ST_Within(spatial_df["interp_GEOM"], CA_geom_pq["geometry"]),
    "inner"
)
# Show the results
df_ca_points.show(5)
df_ca_points.printSchema()

Generated spatial DataFrame:
+---------+--------+-----------------+--------------------+------+
|longitude|latitude|      interp_GEOM|            geometry|region|
+---------+--------+-----------------+--------------------+------+
|   -124.0|    40.0|  POINT (-124 40)|MULTIPOLYGON (((-...| US-CA|
|   -124.0|    40.0|  POINT (-124 40)|MULTIPOLYGON (((-...| US-CA|
|   -124.0|    40.0|  POINT (-124 40)|POLYGON ((-123.48...| US-CA|
|   -124.0|    40.0|  POINT (-124 40)|POLYGON ((-123.50...| US-CA|
|   -124.0|    40.5|POINT (-124 40.5)|MULTIPOLYGON (((-...| US-CA|
+---------+--------+-----------------+--------------------+------+
only showing top 5 rows

root
 |-- longitude: double (nullable = true)
 |-- latitude: double (nullable = true)
 |-- interp_GEOM: geometry (nullable = true)
 |-- geometry: geometry (nullable = true)
 |-- region: string (nullable = true)



In [61]:
df_iwd_pq = sd.read.format("geoparquet").load("./knn_join")

In [65]:
df_iwd_pq.join(
    CA_geom_pq,
    st_predicates.ST_Within(df_iwd_pq["interp_GEOM"], CA_geom_pq["geometry"]),
    "inner"
).drop(*["interpolated_value", "avg_dist", "geometry", "region"]).write.format("geoparquet").mode("overwrite").save("./ca_interp_points")
df_interp_pq = sd.read.format("geoparquet").load("./ca_interp_points")
df_interp_pq.show(5)
df_interp_pq.printSchema()
print(df_interp_pq.count())

+-------------------+
|        interp_GEOM|
+-------------------+
|    POINT (-117 33)|
|    POINT (-117 33)|
|  POINT (-116.5 33)|
|POINT (-117.5 33.5)|
|    POINT (-117 33)|
+-------------------+
only showing top 5 rows

root
 |-- interp_GEOM: geometry (nullable = true)

379


25/11/10 04:54:53 WARN GeoParquetFileFormat: GeoParquet currently does not support vectorized reader. Falling back to parquet-mr
25/11/10 04:54:53 WARN GeoParquetFileFormat: GeoParquet currently does not support vectorized reader. Falling back to parquet-mr
