k-Nearest Neighbors (kNN) joins in WherobotsDB

A geospatial k-Nearest Neighbors (kNN) join is a kNN join over geospatial data. For each spatial point or region in one dataset, it identifies the k nearest geometries in another dataset by geographic proximity, typically using spatial coordinates and a distance metric such as Euclidean or great-circle distance.
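To make the difference between the two metrics concrete, here is a minimal plain-Python sketch. It is not part of WherobotsDB; the function names and the mean Earth radius constant are assumptions chosen for illustration. It compares planar distance in degrees with haversine great-circle distance in meters.

```python
import math

EARTH_RADIUS_M = 6_371_008.8  # assumed mean Earth radius in meters


def euclidean_degrees(lon1, lat1, lon2, lat2):
    """Planar distance in degrees; treats lon/lat as flat x/y coordinates."""
    return math.hypot(lon2 - lon1, lat2 - lat1)


def great_circle_m(lon1, lat1, lon2, lat2):
    """Haversine great-circle distance in meters on a sphere."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))


# One degree of longitude at the equator and at 45 degrees north.
flat_equator = euclidean_degrees(0, 0, 1, 0)
flat_45 = euclidean_degrees(0, 45, 1, 45)
d_equator = great_circle_m(0, 0, 1, 0)
d_45 = great_circle_m(0, 45, 1, 45)
print(flat_equator, flat_45, round(d_equator), round(d_45))
```

The planar distance is 1.0 in both cases, while the great-circle distance shrinks with latitude, so the two metrics can rank neighbors differently away from the equator.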

Approximate kNN Join

The approximate k-Nearest Neighbor (kNN) algorithm utilizes an approximation method to map and distribute multidimensional data into a single dimension while maintaining some degree of locality. This approach allows for the efficient generation of approximate k-Nearest Neighbors for each geometry in the query dataset.
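The idea of mapping two dimensions onto one while keeping nearby points close can be sketched with a Z-order (Morton) curve. This is an assumption made for illustration: the text above does not say which mapping WherobotsDB uses, and the helper names and the `window` parameter below are hypothetical.

```python
import bisect
import math


def interleave_bits(x, y, bits=16):
    """Interleave the low `bits` bits of x (even positions) and y (odd positions)."""
    key = 0
    for i in range(bits):
        key |= ((x >> i) & 1) << (2 * i)
        key |= ((y >> i) & 1) << (2 * i + 1)
    return key


def z_order_key(lon, lat, bits=16):
    """Quantize lon/lat onto a 2^bits grid and return the cell's Z-order key."""
    max_cell = (1 << bits) - 1
    x = int((lon + 180.0) / 360.0 * max_cell)
    y = int((lat + 90.0) / 180.0 * max_cell)
    return interleave_bits(x, y, bits)


def approx_knn(query, objects, k, window=8):
    """Take up to `window` objects on each side of the query's key, then rank by true distance."""
    keyed = sorted((z_order_key(lon, lat), (lon, lat)) for lon, lat in objects)
    keys = [key for key, _ in keyed]
    pos = bisect.bisect_left(keys, z_order_key(*query))
    lo, hi = max(0, pos - window), min(len(keyed), pos + window)
    candidates = [pt for _, pt in keyed[lo:hi]]
    candidates.sort(key=lambda pt: math.dist(pt, query))
    return candidates[:k]


objects = [(-110.0, 43.0), (-105.0, 41.5), (-107.5, 44.0), (-104.8, 41.1)]
nearest = approx_knn((-104.82, 41.14), objects, k=2)
print(nearest)
```

The result is approximate because two points that are close in space can land far apart on the curve; a larger `window` trades speed for accuracy.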

Exact kNN Join

The exact k-Nearest Neighbor (kNN) algorithm begins by partitioning the dataset to maintain spatial locality. It then constructs an efficient structure over another dataset to quickly find accurate kNN matches. By combining local results, the algorithm delivers the complete kNN join for both datasets.
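The "combine local results" step can be illustrated with a small plain-Python sketch. It is not the WherobotsDB implementation: it uses brute-force distance computation in place of a spatial index, and the function names and toy data are assumptions. It shows that merging each partition's local top-k yields the exact global top-k.

```python
import heapq
import math


def local_knn(query, objects, k):
    """Top-k within one partition, as a sorted list of (distance, object_id)."""
    return heapq.nsmallest(k, ((math.dist(query, pt), oid) for oid, pt in objects))


def knn_join(queries, object_partitions, k):
    """Exact kNN join: merge the local top-k lists from every object partition."""
    result = {}
    for qid, qpt in queries.items():
        local_hits = [hit for part in object_partitions for hit in local_knn(qpt, part, k)]
        result[qid] = heapq.nsmallest(k, local_hits)
    return result


queries = {"q1": (0.0, 0.0), "q2": (10.0, 10.0)}
partitions = [
    [("a", (1.0, 0.0)), ("b", (9.0, 10.0))],
    [("c", (0.0, 2.0)), ("d", (10.0, 13.0))],
]
joined = knn_join(queries, partitions, k=2)
print({qid: [oid for _, oid in hits] for qid, hits in joined.items()})
```

Any object in the global top-k is necessarily in the top-k of its own partition, which is why the merge loses nothing.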

In [1]:
from sedona.spark import *

config = SedonaContext.builder().getOrCreate()
sedona = SedonaContext.create(config)


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
                                                                                

In this use case, the k-Nearest Neighbor (kNN) join is applied to match places with nearby flights. The Queries table contains the locations of Overture Maps places, such as schools, businesses, hospitals, religious organizations, landmarks, and mountain peaks, while the Objects table holds the locations of flights. The goal is to find which flights are closest to each place, which can be crucial for making real-time decisions in air traffic management and ensuring flight safety.

In [2]:
# Wyoming state boundary
spatial_filter = "POLYGON((-104.0556 41.0037,-104.0584 44.9949,-111.0539 44.9998,-111.0457 40.9986,-104.0556 41.0006,-104.0556 41.0037))"
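The polygon above is later passed to ST_Contains to keep only geometries inside it. As a rough illustration of what a point-in-polygon test does, here is a plain-Python ray-casting sketch. It handles only a point against a single-ring polygon, unlike ST_Contains, and the helper names and the two test coordinates (approximate locations of Cheyenne and Denver) are assumptions.

```python
WKT = ("POLYGON((-104.0556 41.0037,-104.0584 44.9949,-111.0539 44.9998,"
       "-111.0457 40.9986,-104.0556 41.0006,-104.0556 41.0037))")


def parse_ring(wkt):
    """Extract the outer ring of a single-ring POLYGON WKT as (lon, lat) tuples."""
    body = wkt[wkt.index("((") + 2 : wkt.rindex("))")]
    return [tuple(float(v) for v in pair.split()) for pair in body.split(",")]


def ring_contains(ring, lon, lat):
    """Ray casting: count edge crossings of a ray going east from the point."""
    inside = False
    for (x1, y1), (x2, y2) in zip(ring, ring[1:]):
        if (y1 > lat) != (y2 > lat):  # edge straddles the point's latitude
            x_cross = x1 + (lat - y1) * (x2 - x1) / (y2 - y1)
            if lon < x_cross:
                inside = not inside
    return inside


ring = parse_ring(WKT)
cheyenne_inside = ring_contains(ring, -104.82, 41.14)  # approximate location
denver_inside = ring_contains(ring, -104.99, 39.74)    # approximate location
print(cheyenne_inside, denver_inside)
```

An odd number of crossings means the point is inside the ring, an even number means it is outside.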

Queries Table: Places
This table contains the objects for which I want to find the nearest neighbors.

In [3]:
from pyspark.sql.functions import col, monotonically_increasing_id

In [39]:
# load data
df_queries = sedona.table("wherobots_open_data.overture.places_place")

In [40]:
df_queries.printSchema()

root
 |-- id: string (nullable = true)
 |-- updatetime: string (nullable = true)
 |-- version: integer (nullable = true)
 |-- names: map (nullable = true)
 |    |-- key: string
 |    |-- value: array (valueContainsNull = true)
 |    |    |-- element: map (containsNull = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |-- categories: struct (nullable = true)
 |    |-- main: string (nullable = true)
 |    |-- alternate: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- confidence: double (nullable = true)
 |-- websites: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- socials: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- emails: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- phones: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- brand: struct (nullable = true)
 |    |-- names: map (nulla

In [42]:
# Build the SQL query
sedona.sql("""
    CREATE OR REPLACE TEMP VIEW selected_columns AS
    SELECT
        id,
        names
    FROM wherobots_open_data.overture.places_place
""")

# Read the selected columns from the view
df_selected = sedona.table("selected_columns")

# You can now work with df_selected
df_selected.show()

+--------------------+--------------------+
|                  id|               names|
+--------------------+--------------------+
|tmp_E0A4EB35FCFB2...|{common -> [{valu...|
|tmp_1F0F886124540...|{common -> [{valu...|
|tmp_0B84DA7C47022...|{common -> [{valu...|
|tmp_8CA74E2D99550...|{common -> [{valu...|
|tmp_609ABFA9C7DBC...|{common -> [{valu...|
|tmp_B69F92A75A9D3...|{common -> [{valu...|
|tmp_89F47CE8D2182...|{common -> [{valu...|
|tmp_2CDA20AD7CBAC...|{common -> [{valu...|
|tmp_D48FFE2496D33...|{common -> [{valu...|
|tmp_000D5178CD984...|{common -> [{valu...|
|tmp_BF199FDEDADEB...|{common -> [{valu...|
|tmp_3614165B0CB29...|{common -> [{valu...|
|tmp_655C61CF86477...|{common -> [{valu...|
|tmp_638F36CFFD826...|{common -> [{valu...|
|tmp_01E78355F2B29...|{common -> [{valu...|
|tmp_270DA60D293EC...|{common -> [{valu...|
|tmp_9A1FB9BEB5147...|{common -> [{valu...|
|tmp_7ED9BE955EFE3...|{common -> [{valu...|
|tmp_53F9A21CD7586...|{common -> [{valu...|
|tmp_44A2704E4276A...|{common ->

                                                                                

In [43]:

df_queries = df_queries.withColumn("id", monotonically_increasing_id())
df_queries = df_queries.filter(f"ST_Contains(ST_GeomFromWKT('{spatial_filter}'), geometry) = true")

In [44]:
df_queries.printSchema()

root
 |-- id: long (nullable = false)
 |-- updatetime: string (nullable = true)
 |-- version: integer (nullable = true)
 |-- names: map (nullable = true)
 |    |-- key: string
 |    |-- value: array (valueContainsNull = true)
 |    |    |-- element: map (containsNull = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |-- categories: struct (nullable = true)
 |    |-- main: string (nullable = true)
 |    |-- alternate: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- confidence: double (nullable = true)
 |-- websites: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- socials: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- emails: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- phones: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- brand: struct (nullable = true)
 |    |-- names: map (nullab

In [46]:
# Register df_queries as a temporary view
df_queries.createOrReplaceTempView("df_queries")

# Now we can run the SQL query
sedona.sql("""
    CREATE OR REPLACE TEMP VIEW selected_columns AS
    SELECT
        id,
        names
    FROM df_queries
""")

# Read the selected columns from the view
df_selected1 = sedona.table("selected_columns")

# You can now work with df_selected1
df_selected1.show()



+------------+--------------------+
|          id|               names|
+------------+--------------------+
|661424963604|{common -> [{valu...|
|661424963605|{common -> [{valu...|
|661424963607|{common -> [{valu...|
|661424963608|{common -> [{valu...|
|661424963622|{common -> [{valu...|
|661424963624|{common -> [{valu...|
|661424963636|{common -> [{valu...|
|661424963653|{common -> [{valu...|
|661424963655|{common -> [{valu...|
|661424963676|{common -> [{valu...|
|661424963682|{common -> [{valu...|
|661424963686|{common -> [{valu...|
|661424963689|{common -> [{valu...|
|661424963692|{common -> [{valu...|
|661424963694|{common -> [{valu...|
|661424963698|{common -> [{valu...|
|661424963714|{common -> [{valu...|
|661424963718|{common -> [{valu...|
|661424963732|{common -> [{valu...|
|661424963736|{common -> [{valu...|
+------------+--------------------+
only showing top 20 rows
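The new id values are large and non-consecutive because of how Spark's monotonically_increasing_id() builds them. Its documentation states that the current implementation puts the partition ID in the upper 31 bits and the record number within the partition in the lower 33 bits. A short sketch that decodes one id from the output above (the helper name is an assumption):

```python
def split_monotonic_id(value):
    """Split a monotonically_increasing_id value into (partition_id, record_number)."""
    partition_id = value >> 33
    record_number = value & ((1 << 33) - 1)
    return partition_id, record_number


decoded = split_monotonic_id(661424963604)
print(decoded)  # (77, 20)
```

The ids are unique and increasing within a partition, but they are not stable across runs if the partitioning changes.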



                                                                                

In [47]:
from pyspark.sql.functions import countDistinct

# Count all rows
total_rows = df_queries.count()

# Count the distinct values in the "id" column
distinct_ids = df_queries.select(countDistinct("id")).collect()[0][0]

# Check whether the row count equals the number of distinct ids
if total_rows != distinct_ids:
    print("There are duplicates in the id column.")
else:
    print("No duplicates in the id column.")



No duplicates in the id column.


                                                                                

In [48]:
df_queries = df_queries.repartition(100)

df_queries.cache()

df_queries.createOrReplaceTempView("queries")

print(df_queries.rdd.getNumPartitions())
print(df_queries.count())

100
29010


25/03/22 19:46:46 WARN CacheManager: Asked to cache already cached data.


In [49]:
df_queries.printSchema()

root
 |-- id: long (nullable = false)
 |-- updatetime: string (nullable = true)
 |-- version: integer (nullable = true)
 |-- names: map (nullable = true)
 |    |-- key: string
 |    |-- value: array (valueContainsNull = true)
 |    |    |-- element: map (containsNull = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |-- categories: struct (nullable = true)
 |    |-- main: string (nullable = true)
 |    |-- alternate: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- confidence: double (nullable = true)
 |-- websites: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- socials: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- emails: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- phones: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- brand: struct (nullable = true)
 |    |-- names: map (nullab

In [50]:
type(df_queries)

pyspark.sql.dataframe.DataFrame

Objects Table: Flights
This table contains the objects that are potential neighbors to the objects in the Queries table.

In [51]:
# Load objects table
df_objects = sedona.read.format("geoparquet").load("s3a://wherobots-examples/data/examples/flights/2024_s2.parquet")
df_objects = df_objects.filter(f"ST_Contains(ST_GeomFromWKT('{spatial_filter}'), geometry) = true")
df_objects = df_objects.repartition(800)

df_objects.cache()

df_objects.createOrReplaceTempView("objects")

print(df_objects.rdd.getNumPartitions())
print(df_objects.count())

25/03/22 19:48:19 WARN CacheManager: Asked to cache already cached data.


800
1431091


In [52]:
df_objects.printSchema()

root
 |-- dbFlags: long (nullable = true)
 |-- desc: string (nullable = true)
 |-- icao: string (nullable = true)
 |-- ownOp: string (nullable = true)
 |-- r: string (nullable = true)
 |-- reg_details: struct (nullable = true)
 |    |-- description: string (nullable = true)
 |    |-- iso2: string (nullable = true)
 |    |-- iso3: string (nullable = true)
 |    |-- nation: string (nullable = true)
 |-- t: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- trace: struct (nullable = true)
 |    |-- aircraft: struct (nullable = true)
 |    |    |-- alert: long (nullable = true)
 |    |    |-- alt_geom: long (nullable = true)
 |    |    |-- baro_rate: long (nullable = true)
 |    |    |-- category: string (nullable = true)
 |    |    |-- emergency: string (nullable = true)
 |    |    |-- flight: string (nullable = true)
 |    |    |-- geom_rate: long (nullable = true)
 |    |    |-- gva: long (nullable = true)
 |    |    |-- ias: long (nullable = true)
 |    |    |-- mac

The next cell runs a spatial SQL query in Apache Sedona (formerly GeoSpark) that computes the k-nearest neighbors (kNN) between the two datasets.
QUERIES and OBJECTS are the temporary views registered above from df_queries and df_objects.
ST_KNN(QUERIES.GEOMETRY, OBJECTS.GEOMETRY, 4, FALSE): Finds the 4 nearest neighbors in OBJECTS for each record in QUERIES. In Sedona's documented signature the fourth argument is use_spheroid, so FALSE selects planar Euclidean distance on the coordinates rather than spheroidal distance.
ST_DISTANCESPHERE(QUERIES.GEOMETRY, OBJECTS.GEOMETRY): Calculates the spherical distance, in meters, between the two geometries. It is reported alongside each match and is separate from the metric ST_KNN uses to rank neighbors.
ST_MAKELINE(QUERIES.GEOMETRY, OBJECTS.GEOMETRY): Creates a line geometry connecting QUERIES.GEOMETRY and OBJECTS.GEOMETRY.
df_knn_join.cache(): Marks the query result for caching. Spark evaluates lazily, so the result is computed and stored by the first action, here count(). Later operations on this DataFrame then reuse the cached data instead of recomputing the join.
Finally, the number of rows in the result is counted and printed.

In [53]:
%%time

df_knn_join = sedona.sql("""
SELECT
    QUERIES.GEOMETRY AS QUERIES_GEOM,
    QUERIES.ID AS QID,
    OBJECTS.GEOMETRY AS OBJECTS_GEOM,
    ST_DISTANCESPHERE(QUERIES.GEOMETRY, OBJECTS.GEOMETRY) AS DISTANCE,
    ST_MAKELINE(QUERIES.GEOMETRY, OBJECTS.GEOMETRY) AS LINE
FROM QUERIES
JOIN OBJECTS ON ST_KNN(QUERIES.GEOMETRY, OBJECTS.GEOMETRY, 4, FALSE)
""")

# cache for further queries and visualization
df_knn_join.cache()

total_count = df_knn_join.count()
print(total_count)

25/03/22 19:48:40 WARN CacheManager: Asked to cache already cached data.

116040
CPU times: user 8.95 ms, sys: 4.94 ms, total: 13.9 ms
Wall time: 5.2 s
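The row count can be sanity-checked with simple arithmetic. This assumes that every query row receives exactly k neighbors whenever the Objects table has at least k rows; the helper name is hypothetical. The inputs are the counts printed earlier (29010 places, 1431091 flights).

```python
def expected_knn_rows(num_queries, num_objects, k):
    """Expected join size when each query gets min(k, num_objects) neighbors."""
    return num_queries * min(k, num_objects)


expected = expected_knn_rows(num_queries=29010, num_objects=1431091, k=4)
print(expected)  # 116040
```

This matches the 116040 rows reported by count().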


                                                                                

In [38]:
df_knn_join.show(3)

+--------------------+------------+--------------------+------------------+--------------------+
|        QUERIES_GEOM|         QID|        OBJECTS_GEOM|          DISTANCE|                LINE|
+--------------------+------------+--------------------+------------------+--------------------+
|POINT (-110.885 4...|661425021290|POINT (-110.88769...|243.42429137166317|LINESTRING (-110....|
|POINT (-110.885 4...|661425021290|POINT (-110.88651...|485.20524474148715|LINESTRING (-110....|
|POINT (-110.885 4...|661425021290|POINT (-110.88520...| 643.5787783557646|LINESTRING (-110....|
+--------------------+------------+--------------------+------------------+--------------------+
only showing top 3 rows



In [55]:
df_knn_join.count()

                                                                                

116040

In [54]:
from pyspark.sql import functions as F

# Check which QIDs appear more than once (with k = 4, each query should have 4 rows)
duplicate_qids = df_knn_join.groupBy("QID").count().filter(F.col("count") > 1)

# Show the results
duplicate_qids.show()



+------------+-----+
|         QID|count|
+------------+-----+
|661425182760|    4|
|661425228876|    4|
|661425021290|    4|
|661425252349|    4|
|661425141725|    4|
|661425113566|    4|
|661425140390|    4|
|661425204631|    4|
|661424981447|    4|
|661424977924|    4|
|661425215417|    4|
|661425014918|    4|
|661425010515|    4|
|661425188311|    4|
|661425220523|    4|
|661425033155|    4|
|661425224658|    4|
|661425230126|    4|
|661424975096|    4|
|661425003025|    4|
+------------+-----+
only showing top 20 rows



                                                                                

In [15]:
# Keep one row per distinct query geometry (places that share a location collapse to one row)
df_unique_qid = df_knn_join.dropDuplicates(["QUERIES_GEOM"])

In [17]:
df_unique_qid.count()

                                                                                

26557

In [18]:
df_unique_qid.show(3)



+--------------------+------------+--------------------+------------------+--------------------+
|        QUERIES_GEOM|         QID|        OBJECTS_GEOM|          DISTANCE|                LINE|
+--------------------+------------+--------------------+------------------+--------------------+
|POINT (-104.0625 ...|661425201683|POINT (-104.06160...| 87.44867615586293|LINESTRING (-104....|
|POINT (-108.28125...|661425132764|POINT (-108.28196...| 970.7136646090928|LINESTRING (-108....|
|POINT (-108.28125...|661425177751|POINT (-108.10712...|15132.365068235968|LINESTRING (-108....|
+--------------------+------------+--------------------+------------------+--------------------+
only showing top 3 rows



                                                                                

In [19]:
# Inner join to get all rows from df_knn_join whose QID appears in df_unique_qid
df_related_rows = df_knn_join.join(df_unique_qid, on="QID", how="inner").select(df_knn_join["*"])

In [20]:
df_related_rows.count()

                                                                                

106228
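The counts above fit together: 26557 distinct query geometries times 4 neighbors gives 106228 rows, fewer than the full 116040 because places sharing a location with an already-kept place are dropped. The toy plain-Python sketch below mimics the dropDuplicates and join steps on made-up rows. The data and helper name are assumptions, and unlike Spark, which keeps an arbitrary row per key, it keeps the first row seen.

```python
# (query_geom, qid, object_id) rows of a toy kNN join with k = 2
knn_rows = [
    ("POINT (1 1)", 10, "a"), ("POINT (1 1)", 10, "b"),
    ("POINT (1 1)", 11, "a"), ("POINT (1 1)", 11, "b"),  # second place at the same location
    ("POINT (2 2)", 12, "c"), ("POINT (2 2)", 12, "d"),
]


def drop_duplicates_by_geom(rows):
    """Keep the first row seen for each distinct query geometry."""
    seen = {}
    for geom, qid, obj in rows:
        seen.setdefault(geom, (geom, qid, obj))
    return list(seen.values())


unique_rows = drop_duplicates_by_geom(knn_rows)
kept_qids = {qid for _, qid, _ in unique_rows}

# Equivalent of the inner join on QID: every kNN row whose QID survived.
related_rows = [row for row in knn_rows if row[1] in kept_qids]
print(len(unique_rows), sorted(kept_qids), len(related_rows))
```

Here QID 11 disappears because it shares a geometry with QID 10, and each surviving QID keeps all k of its neighbor rows.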

In [21]:
df_related_rows.show(3)



+--------------------+------------+--------------------+------------------+--------------------+
|        QUERIES_GEOM|         QID|        OBJECTS_GEOM|          DISTANCE|                LINE|
+--------------------+------------+--------------------+------------------+--------------------+
|POINT (-110.885 4...|661425021290|POINT (-110.88769...|243.42429137166317|LINESTRING (-110....|
|POINT (-110.885 4...|661425021290|POINT (-110.88651...|485.20524474148715|LINESTRING (-110....|
|POINT (-110.885 4...|661425021290|POINT (-110.88520...| 643.5787783557646|LINESTRING (-110....|
+--------------------+------------+--------------------+------------------+--------------------+
only showing top 3 rows



                                                                                

In [22]:
df_knn_join.count()

                                                                                

116040

In [23]:
# create map for the results
map_view = SedonaKepler.create_map(df_unique_qid.select('QUERIES_GEOM'), name="PLACES")
SedonaKepler.add_df(map_view, df=df_related_rows.select('OBJECTS_GEOM', 'DISTANCE').withColumnRenamed("OBJECTS_GEOM", "geometry"), name="FLIGHTS")
SedonaKepler.add_df(map_view, df=df_related_rows.select('LINE', 'DISTANCE').withColumnRenamed("LINE", "geometry"), name="KNN LINES")

# show the map
map_view

User Guide: https://docs.kepler.gl/docs/keplergl-jupyter


                                                                                

KeplerGl(data={'PLACES': {'index': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, …