# Apache Sedona Tutorial - Spatial Join

This tutorial shows how to create and analyse geospatial dataframes in Spark using Apache Sedona, and how to visualise the results.

---

### Process
1. Initialise a Spark session with Sedona enabled
2. Load boundary & point dataframes from parquet files
3. Convert them to geospatial dataframes
4. Perform a spatial join
    - Tune performance by repartitioning the input dataframes
5. Export the result to a Geopandas dataframe
6. Map the points, coloured by boundary type

---

### Data Used

- Boundary data is the ABS 2016 Census Remoteness Areas
- Point data is a randomised set of points, based on ABS 2016 Census Meshblock centroids

© Australian Bureau of Statistics (ABS), Commonwealth of Australia



---

### Import packages and set parameters

In [None]:
# import Python packages

import os

from multiprocessing import cpu_count

from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from sedona.register import SedonaRegistrator
from sedona.utils import SedonaKryoRegistrator, KryoSerializer

In [None]:
# set input path for parquet files
input_path = os.path.join(os.getcwd(), "data")
print(input_path)

# set max number of processes (defaults to the number of logical CPUs reported by the OS)
num_processors = cpu_count()
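
`cpu_count()` reports logical CPUs (hardware threads), which can be more than the number of physical cores. The sketch below is one possible way to cap local parallelism through the master string instead: `local[N]` and `local[*]` are standard Spark master URLs, while `local_master` is a hypothetical helper introduced here for illustration and is not part of PySpark.

```python
import os


def local_master(max_cores=None):
    """Build a Spark master URL for local mode, e.g. 'local[4]' or 'local[*]'.

    Hypothetical helper, not part of PySpark.
    """
    if max_cores is None:
        return "local[*]"  # let Spark use every logical core
    available = os.cpu_count() or 1  # os.cpu_count() can return None
    # clamp the request to the range 1..available
    return "local[{}]".format(max(1, min(max_cores, available)))


print(local_master())
print(local_master(1))
```

The returned string could be passed to `.master(...)` when building the session.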

### Create the Spark session

In [None]:
spark = (SparkSession
         .builder
         .master("local[*]")
         .appName("Spatial Join Tutorial")
         .config("spark.serializer", KryoSerializer.getName)
         .config("spark.kryo.registrator", SedonaKryoRegistrator.getName)
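         # note: spark.cores.max is honoured by standalone/Mesos clusters; in local mode "local[*]" already uses all logical cores
         .config("spark.cores.max", num_processors)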
         .getOrCreate()
         )

print("Spark {} session initialised (ignore any 'illegal reflective access' warnings - it's a minor Java 11 issue)".format(spark.version))

In [None]:
# Register Sedona's User Defined Types (UDTs) and Functions (UDFs) with the Spark session
SedonaRegistrator.registerAll(spark)

### Load dataframes

#### 1. Load boundary data from gzipped parquet files

Boundary geometries are polygons stored as OGC Well Known Text (WKT) strings.
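
To give a feel for the format, here is a minimal pure-Python sketch that reads the exterior ring of a simple WKT polygon. It is not how Sedona parses WKT: `parse_wkt_polygon` and the sample string are assumptions made for illustration, and the sketch ignores holes, `MULTIPOLYGON` and other geometry types that the real boundary data may contain.

```python
def parse_wkt_polygon(wkt):
    """Return the exterior ring of a simple 'POLYGON ((x y, ...))' string as (x, y) tuples.

    Illustrative only: ignores interior rings (holes), MULTIPOLYGON, Z/M values and SRIDs.
    """
    text = wkt.strip()
    if not text.upper().startswith("POLYGON"):
        raise ValueError("not a POLYGON: {!r}".format(wkt))
    # take the text between the first '((' and the first ')', i.e. the exterior ring
    start = text.index("((") + 2
    end = text.index(")", start)
    ring = []
    for pair in text[start:end].split(","):
        x, y = pair.split()
        ring.append((float(x), float(y)))
    return ring


# hypothetical square; WKT coordinates are in x y (longitude latitude) order
sample_wkt = "POLYGON ((151.0 -34.0, 151.5 -34.0, 151.5 -33.5, 151.0 -33.5, 151.0 -34.0))"
ring = parse_wkt_polygon(sample_wkt)
print(len(ring), ring[0])
```

A valid polygon ring is closed, so its first and last coordinates are identical.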

In [None]:
bdy_wkt_df = spark.read.parquet(os.path.join(input_path, "boundaries"))
bdy_wkt_df.printSchema()
bdy_wkt_df.show(5)

print("Loaded {} records".format(bdy_wkt_df.count()))

In [None]:
# add bdy number (last character of bdy ID) to bdy type - to enable display ordering in map
bdy_wkt_df2 = bdy_wkt_df \
    .withColumn("bdy_type", f.concat(f.substring(bdy_wkt_df["bdy_id"], -1, 1), f.lit(" - "), bdy_wkt_df["bdy_type"]))

# show 5 rows ordered randomly
bdy_wkt_df2.orderBy(f.rand()).show(5)
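
The prefix matters because the map legend sorts labels alphabetically. A small pure-Python sketch of the same string logic follows (Spark's `substring` with a negative position counts from the end of the string, like Python's `[-1]`). The IDs and names below are hypothetical sample values, and `label_bdy_type` is an illustrative helper; the real values come from the boundaries parquet files.

```python
def label_bdy_type(bdy_id, bdy_type):
    """Mimic concat(substring(bdy_id, -1, 1), ' - ', bdy_type) for one row."""
    return "{} - {}".format(bdy_id[-1], bdy_type)


# hypothetical rows (bdy_id, bdy_type), not read from the tutorial data
rows = [
    ("12", "Outer Regional Australia"),
    ("10", "Major Cities of Australia"),
    ("11", "Inner Regional Australia"),
]
labels = sorted(label_bdy_type(bdy_id, bdy_type) for bdy_id, bdy_type in rows)
for label in labels:
    print(label)
```

Without the numeric prefix, an alphabetical sort would put "Inner Regional" before "Major Cities"; with it, the labels sort by the final digit of the ID.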

In [None]:
# Create a view of the DataFrame to enable SQL queries
bdy_wkt_df2.createOrReplaceTempView("bdy_wkt")

#### 2. Load point data

Point coordinates are stored in double-precision latitude & longitude fields.

In [None]:
point_wkt_df = spark.read.parquet(os.path.join(input_path, "points"))
point_wkt_df.printSchema()
point_wkt_df.show(5, False)

print("Loaded {} records".format(point_wkt_df.count()))

# create view to enable SQL queries
point_wkt_df.createOrReplaceTempView("point_wkt")

### Create geospatial dataframes

#### 1. Create boundary geometries from WKT strings

In [None]:
bdy_df = spark.sql("select bdy_id, bdy_type, state, ST_GeomFromWKT(wkt_geom) as geom from bdy_wkt") \
    .repartition(96, "state")
bdy_df.printSchema()
bdy_df.show(5)

print("{} partitions".format(bdy_df.rdd.getNumPartitions()))

# create view to enable SQL queries
bdy_df.createOrReplaceTempView("bdy")

#### 2. Create point geometries from lat/long fields

In [None]:
point_df = spark.sql("select point_id, state, ST_Point(longitude, latitude) as geom from point_wkt") \
    .repartition(96, "state")
point_df.printSchema()
point_df.show(5, False)

print("{} partitions".format(point_df.rdd.getNumPartitions()))

# create view to enable SQL queries
point_df.createOrReplaceTempView("pnt")
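
The query above passes longitude first and latitude second, i.e. x then y. Swapping the two is an easy mistake to make. The sketch below is a simple range check that could be run on a sample of rows before building geometries; `looks_swapped` is a hypothetical helper, and it only catches swaps where one value falls outside the valid latitude range, which happens to be the case for Australian coordinates.

```python
def looks_swapped(longitude, latitude):
    """True if the pair is out of range as given but valid with the two values exchanged."""
    valid = -180.0 <= longitude <= 180.0 and -90.0 <= latitude <= 90.0
    valid_if_swapped = -180.0 <= latitude <= 180.0 and -90.0 <= longitude <= 90.0
    return (not valid) and valid_if_swapped


# hypothetical coordinates near Sydney
print(looks_swapped(151.2, -33.9))  # correct order
print(looks_swapped(-33.9, 151.2))  # latitude and longitude exchanged
```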

### Run a spatial join to boundary tag the points

##### Note:
1. One of the dataframes will be spatially indexed automatically to speed up the query
2. It's an inner join, so points that fall outside every boundary (e.g. off the coast or in gaps between boundaries) are dropped from the result
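
The second note can be illustrated without Spark. The sketch below is a plain-Python stand-in, not Sedona's implementation: it uses a ray-casting point-in-polygon test and no spatial index, and all names and coordinates are hypothetical. Points exactly on a boundary edge are treated differently here than by `ST_Intersects`, which counts them as intersecting.

```python
def point_in_ring(x, y, ring):
    """Ray-casting test: True if (x, y) is inside the closed ring of (x, y) tuples."""
    inside = False
    for (x1, y1), (x2, y2) in zip(ring, ring[1:]):
        # does this edge straddle the horizontal line through the point?
        if (y1 > y) != (y2 > y):
            # x coordinate where the edge crosses that line
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


# hypothetical data: two square boundaries with a gap between them
boundaries = {
    "A": [(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)],
    "B": [(2, 0), (3, 0), (3, 1), (2, 1), (2, 0)],
}
points = {"p1": (0.5, 0.5), "p2": (2.5, 0.5), "p3": (1.5, 0.5)}  # p3 sits in the gap

# inner join: keep only (point, boundary) pairs where the point is inside the boundary
joined = [(pid, bid) for pid, (x, y) in points.items()
          for bid, ring in boundaries.items() if point_in_ring(x, y, ring)]
# points that matched no boundary are missing from the join result
unmatched = sorted(set(points) - {pid for pid, _ in joined})

print(joined)
print(unmatched)
```

Comparing the input point count with the joined count, as the sketch does with `unmatched`, is a quick way to see how many points an inner join dropped.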

In [None]:
from datetime import datetime

start_time = datetime.now()

sql = """SELECT pnt.point_id,
                bdy.bdy_id,
                bdy.bdy_type,
                bdy.state,
                pnt.geom
         FROM pnt
         INNER JOIN bdy ON ST_Intersects(pnt.geom, bdy.geom)"""
join_df = spark.sql(sql) \
    .cache()  # cache can save processing time when calling the same dataframe more than once
join_df.printSchema()
join_df.show(5, False)

join_count = join_df.count()

print("Boundary tagged {} points".format(join_count))
print("Query took {}".format(datetime.now() - start_time))

### Export result to Geopandas for visualisation

*Note: `toPandas()` collects every row onto the driver, so this step doesn't scale to big data*

In [None]:
import geopandas

# convert to Pandas dataframe first (ordered by bdy_type for better visualisation)
pandas_df = join_df.orderBy(f.desc("bdy_type")).toPandas()

# then convert to Geopandas dataframe
geopandas_df = geopandas.GeoDataFrame(pandas_df, geometry="geom")

geopandas_df

### Map the result

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, figsize=(20, 20))

# set background colour
ax.set_facecolor('#EEEEEE')

# create map of points by bdy type
geopandas_df.plot(
    column="bdy_type",
    legend=True,
    cmap='YlOrRd',
    ax=ax
)

In [None]:
# map NSW only
fig2, ax2 = plt.subplots(1, figsize=(20, 20))
ax2.set_facecolor('#EEEEEE')

geopandas_df.loc[geopandas_df["state"] == "New South Wales"].plot(
    column="bdy_type",
    legend=True,
    cmap='YlOrRd',
    ax=ax2
)

### Close the Spark session

In [None]:
spark.stop()