# Geospark Tutorial - Spatial Joins

---

### Process
1. Initialise a Spark session with Geospark enabled
2. Load boundary and point datasets
3. Convert them to geospatial DataFrames
4. Perform a point-in-polygon spatial join

---

Import Python packages

In [1]:
import os

from multiprocessing import cpu_count
from pyspark.sql import SparkSession
from geospark.register import upload_jars, GeoSparkRegistrator
from geospark.utils import KryoSerializer, GeoSparkKryoRegistrator

Set parameters

In [2]:
# input path for parquet files
input_path = os.path.join(os.getcwd(), "data")

# number of processes to use (2x the logical CPUs reported by cpu_count())
num_processors = cpu_count() * 2

Copy Geospark's Java libraries to Spark

In [3]:
upload_jars()

True

Create the Spark session

In [4]:
spark = SparkSession \
    .builder \
    .master("local[{}]".format(num_processors)) \
    .appName("query") \
    .config("spark.sql.session.timeZone", "UTC") \
    .config("spark.sql.debug.maxToStringFields", 100) \
    .config("spark.serializer", KryoSerializer.getName) \
    .config("spark.kryo.registrator", GeoSparkKryoRegistrator.getName) \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.executor.cores", 1) \
    .config("spark.cores.max", num_processors) \
    .config("spark.driver.memory", "8g") \
    .config("spark.driver.maxResultSize", "1g") \
    .getOrCreate()

print("Spark {} session initialised".format(spark.version))

Spark 2.4.6 session initialised


Register Geospark's User Defined Types (UDTs) and Functions (UDFs) with the Spark session

In [5]:
GeoSparkRegistrator.registerAll(spark)

True

Load boundary data from gzipped parquet files.

Boundary geometries are polygons stored as OGC Well Known Text (WKT) strings.

In [6]:
bdy_wkt_df = spark.read.parquet(os.path.join(input_path, "boundaries"))
bdy_wkt_df.printSchema()
bdy_wkt_df.show(5)

print("Loaded {} records".format(bdy_wkt_df.count()))

root
 |-- bdy_id: string (nullable = true)
 |-- wkt_geom: string (nullable = true)

+------+--------------------+
|bdy_id|            wkt_geom|
+------+--------------------+
|  RA10|POLYGON((149.1082...|
|  RA10|POLYGON((149.1914...|
|  RA10|POLYGON((149.1914...|
|  RA10|POLYGON((149.1914...|
|  RA10|POLYGON((150.6666...|
+------+--------------------+
only showing top 5 rows

Loaded 17540 records
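The `wkt_geom` strings above are OGC WKT polygons: an outer ring of `x y` pairs (here longitude then latitude), optionally followed by hole rings, with each ring closed by repeating its first vertex. The sketch below shows that structure in plain Python. It assumes 2D `POLYGON` text only (no `MULTIPOLYGON`, Z/M values or `EMPTY`), and `parse_wkt_polygon` is a hypothetical helper for illustration; GeoSpark parses WKT itself with `st_geomFromWKT`.

```python
import re


def parse_wkt_polygon(wkt):
    """Return the rings of a 2D WKT POLYGON as lists of (x, y) float tuples.

    Illustrative sketch only: handles POLYGON (outer ring plus optional holes),
    not MULTIPOLYGON, Z/M coordinates or EMPTY geometries.
    """
    text = wkt.strip()
    if not text.upper().startswith("POLYGON"):
        raise ValueError("not a WKT POLYGON: {!r}".format(wkt[:30]))
    body = text[len("POLYGON"):]
    rings = []
    # Each ring is the text inside an innermost pair of parentheses
    for ring_text in re.findall(r"\(([^()]+)\)", body):
        ring = []
        for pair in ring_text.split(","):
            x, y = pair.split()
            ring.append((float(x), float(y)))
        rings.append(ring)
    return rings
```

The first ring is the outer boundary; any further rings are holes.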


Create a view of the DataFrame to enable SQL queries

In [7]:
bdy_wkt_df.createOrReplaceTempView("bdy_wkt")

Load the point records from parquet files. The coordinates are stored as latitude and longitude (double precision) fields.

In [8]:
point_wkt_df = spark.read.parquet(os.path.join(input_path, "points"))
point_wkt_df.printSchema()
point_wkt_df.show(5, False)

print("Loaded {} records".format(point_wkt_df.count()))

# create view to enable SQL queries
point_wkt_df.createOrReplaceTempView("point_wkt")

root
 |-- point_id: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)

+-----------+-------------------+------------------+
|point_id   |latitude           |longitude         |
+-----------+-------------------+------------------+
|60000010000|-41.397873825618845|148.30298463001887|
|60000020000|-40.91276600258191 |148.3205616825289 |
|60000030000|-40.92217845040206 |148.32129834871355|
|60000040000|-40.914773511737074|148.32338432564558|
|60000050000|-40.91553126097137 |148.32378038710883|
+-----------+-------------------+------------------+
only showing top 5 rows

Loaded 358009 records
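These points are plain latitude/longitude numbers, not geometries. Geometry constructors such as ST_Point take coordinates in x, y order, which for this data means longitude first, as the st_point call further down does. A small sketch of that ordering with a range check that can flag swapped columns; `lonlat_to_wkt` is a hypothetical helper, not a GeoSpark function.

```python
def lonlat_to_wkt(longitude, latitude, scale=6):
    """Build a WKT POINT in x/y order, i.e. longitude first, then latitude.

    Raises ValueError for out-of-range values, which can indicate that the
    latitude and longitude columns have been swapped.
    """
    if not -180.0 <= longitude <= 180.0:
        raise ValueError("longitude out of range: {}".format(longitude))
    if not -90.0 <= latitude <= 90.0:
        raise ValueError("latitude out of range: {}".format(latitude))
    # Format both coordinates to `scale` decimal places
    return "POINT ({:.{p}f} {:.{p}f})".format(longitude, latitude, p=scale)
```

For example, `lonlat_to_wkt(148.30298463001887, -41.397873825618845)` gives `"POINT (148.302985 -41.397874)"`, while passing the arguments the other way round raises, because 148.3 is not a valid latitude.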


Create a DataFrame of boundary IDs and geometries (i.e. geospatial objects).

GeoSpark automatically spatially partitions and indexes these geometries when they're used in a spatial join, which enables fast querying.

In [9]:
bdy_df = spark.sql("select bdy_id, st_geomFromWKT(wkt_geom) as geom from bdy_wkt")
bdy_df.printSchema()
bdy_df.show(5)

# create view to enable SQL queries
bdy_df.createOrReplaceTempView("bdy")

root
 |-- bdy_id: string (nullable = true)
 |-- geom: geometry (nullable = false)

+------+--------------------+
|bdy_id|                geom|
+------+--------------------+
|  RA10|POLYGON ((149.108...|
|  RA10|POLYGON ((149.191...|
|  RA10|POLYGON ((149.191...|
|  RA10|POLYGON ((149.191...|
|  RA10|POLYGON ((150.666...|
+------+--------------------+
only showing top 5 rows
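Spatial indexing is what makes the join fast. Instead of testing every point against all 17,540 boundary polygons, a cheap bounding-box lookup narrows each point to a few candidate polygons (the "filter" step), and only those candidates get the exact geometric test (the "refine" step). The sketch below shows the idea with a uniform grid over bounding boxes. `GridIndex` is a hypothetical class; GeoSpark itself uses its own spatial partitioning and tree-based indexes rather than this simple grid.

```python
from collections import defaultdict
from math import floor


class GridIndex:
    """Minimal uniform-grid index over bounding boxes (illustrative only).

    Each box (min_x, min_y, max_x, max_y) is stored in every grid cell it
    overlaps, so a point lookup only inspects the boxes in one cell.
    """

    def __init__(self, cell_size):
        self.cell_size = cell_size
        self.cells = defaultdict(list)

    def _cell(self, x, y):
        # Integer grid coordinates of the cell containing (x, y)
        return (floor(x / self.cell_size), floor(y / self.cell_size))

    def insert(self, item_id, bbox):
        min_x, min_y, max_x, max_y = bbox
        cx0, cy0 = self._cell(min_x, min_y)
        cx1, cy1 = self._cell(max_x, max_y)
        for cx in range(cx0, cx1 + 1):
            for cy in range(cy0, cy1 + 1):
                self.cells[(cx, cy)].append((item_id, bbox))

    def candidates(self, x, y):
        """Filter step: ids whose bounding box contains (x, y).

        An exact point-in-polygon test (the refine step) is still required.
        """
        return [
            item_id
            for item_id, (min_x, min_y, max_x, max_y) in self.cells.get(self._cell(x, y), [])
            if min_x <= x <= max_x and min_y <= y <= max_y
        ]
```

The cell size is a trade-off: smaller cells return fewer candidates per lookup but store each large polygon in more cells.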



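Create a DataFrame of point IDs and geometries. Note the current GeoSpark limitation: ST_Point requires Decimal (not double) longitude and latitude inputs, so both fields are cast to 6 decimal places, which rounds the coordinates to roughly 0.1 m.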

In [10]:
sql = """select point_id,
                st_point(cast(longitude as decimal(9, 6)), cast(latitude as decimal(8, 6))) as geom
         from point_wkt"""
point_df = spark.sql(sql)
point_df.printSchema()
point_df.show(5, False)

# create view to enable SQL queries
point_df.createOrReplaceTempView("pnt")

root
 |-- point_id: string (nullable = true)
 |-- geom: geometry (nullable = false)

+-----------+-----------------------------+
|point_id   |geom                         |
+-----------+-----------------------------+
|60000010000|POINT (148.302985 -41.397874)|
|60000020000|POINT (148.320562 -40.912766)|
|60000030000|POINT (148.321298 -40.922178)|
|60000040000|POINT (148.323384 -40.914774)|
|60000050000|POINT (148.32378 -40.915531) |
+-----------+-----------------------------+
only showing top 5 rows
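The rounding in the output above comes from the decimal casts. A sketch of it using Python's `decimal` module, assuming Spark's double to decimal cast starts from the double's shortest decimal representation and rounds halves away from zero (HALF_UP). `to_decimal` and `fits_decimal` are hypothetical helpers. It also shows why the precisions differ: decimal(p, 6) leaves p - 6 digits before the decimal point, enough for latitudes in decimal(8, 6) but not for longitudes such as 148.3, which need decimal(9, 6).

```python
from decimal import Decimal, ROUND_HALF_UP


def to_decimal(value, scale=6):
    """Round a double to `scale` decimal places, halves away from zero.

    Assumption: this mirrors Spark's cast(x as decimal(p, scale)).
    """
    quantum = Decimal(1).scaleb(-scale)  # Decimal('1E-6') when scale=6
    # repr() gives the shortest string that round-trips the double
    return Decimal(repr(value)).quantize(quantum, rounding=ROUND_HALF_UP)


def fits_decimal(value, precision, scale=6):
    """True if `value`, after rounding, fits in decimal(precision, scale).

    Assumption: in Spark 2.4's default cast, a value that doesn't fit
    becomes null rather than raising an error.
    """
    return abs(to_decimal(value, scale)) < Decimal(10) ** (precision - scale)
```

For example, `to_decimal(148.32378038710883)` is `Decimal('148.323780')`; the geometry's text output then drops the trailing zero, giving `148.32378` in the table above.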



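Run a spatial join to tag each point with the boundary it falls in. Note it's an inner join, so points that don't intersect any boundary are dropped from the result, and a point that intersects more than one boundary (e.g. one lying on a shared edge) appears once per match.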

In [11]:
sql = """SELECT pnt.point_id,
                bdy.bdy_id,
                pnt.geom
         FROM pnt
         INNER JOIN bdy ON ST_Intersects(pnt.geom, bdy.geom)"""
join_df = spark.sql(sql)

join_count = join_df.count()

join_df.printSchema()
join_df.show(5, False)

print("Boundary tagged {} points".format(join_count))

root
 |-- point_id: string (nullable = true)
 |-- bdy_id: string (nullable = true)
 |-- geom: geometry (nullable = false)

+-----------+------+-----------------------------+
|point_id   |bdy_id|geom                         |
+-----------+------+-----------------------------+
|90000200000|RA94  |POINT (96.888063 -12.197968) |
|90000260000|RA94  |POINT (96.907247 -12.124957) |
|90000101000|RA94  |POINT (105.634587 -10.488663)|
|90000113000|RA94  |POINT (105.679505 -10.420269)|
|90000102000|RA94  |POINT (105.672533 -10.437645)|
+-----------+------+-----------------------------+
only showing top 5 rows

Boundary tagged 357720 points
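A minimal pure-Python sketch of what the join computes, assuming each boundary is a single outer ring without holes. It swaps in a ray-casting point-in-polygon test for ST_Intersects; unlike ST_Intersects, ray casting doesn't reliably count a point lying exactly on an edge as a match. `point_in_ring` and `tag_points` are hypothetical helpers, and this brute-force loop checks every point against every polygon, which is exactly the work GeoSpark's spatial partitioning and indexing avoid. The `keep_unmatched` flag contrasts the inner join used above with left-join behaviour that keeps untagged points.

```python
def point_in_ring(x, y, ring):
    """Ray casting: is (x, y) inside the closed ring [(x0, y0), ..., (x0, y0)]?

    Counts how many ring edges a horizontal ray to the right of the point
    crosses; an odd count means inside. Edge points may go either way.
    """
    inside = False
    for (x1, y1), (x2, y2) in zip(ring, ring[1:]):
        # Only edges that straddle the horizontal line through y can cross the ray
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def tag_points(points, boundaries, keep_unmatched=False):
    """Pair point ids with the ids of boundaries containing them.

    points: [(point_id, x, y)]; boundaries: [(bdy_id, outer_ring)].
    keep_unmatched=False behaves like an INNER JOIN (unmatched points dropped);
    True behaves like a LEFT JOIN (unmatched points get bdy_id None).
    A point inside several polygons yields one row per match.
    """
    rows = []
    for point_id, x, y in points:
        matches = [bdy_id for bdy_id, ring in boundaries if point_in_ring(x, y, ring)]
        if matches:
            rows.extend((point_id, bdy_id) for bdy_id in matches)
        elif keep_unmatched:
            rows.append((point_id, None))
    return rows
```

With overlapping squares "A" and "B", a point inside both produces two rows, and a point outside both disappears unless `keep_unmatched=True`.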


Close the Spark session and release its resources.

In [12]:
spark.stop()

All done!