In [1]:
%load_ext autoreload
%autoreload 2
from geolifeclef.utils import get_spark
from pyspark.sql import functions as F

spark = get_spark()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/02 08:40:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/02 08:40:59 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
24/04/02 08:41:00 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
metadata = spark.read.parquet(
    "gs://dsgt-clef-geolifeclef-2024/data/processed/metadata_clean/v1"
)
metadata.printSchema()
metadata.show()

                                                                                

root
 |-- dataset: string (nullable = true)
 |-- surveyId: integer (nullable = true)
 |-- lat_proj: double (nullable = true)
 |-- lon_proj: double (nullable = true)
 |-- lat: double (nullable = true)
 |-- lon: double (nullable = true)
 |-- year: integer (nullable = true)
 |-- geoUncertaintyInM: double (nullable = true)
 |-- speciesId: double (nullable = true)



                                                                                

+--------+--------+-------------------+--------------------+---------+---------+----+-----------------+---------+
| dataset|surveyId|           lat_proj|            lon_proj|      lat|      lon|year|geoUncertaintyInM|speciesId|
+--------+--------+-------------------+--------------------+---------+---------+----+-----------------+---------+
|      po|  599428|-3546710.1309880773|1.5869697752944473E7| 41.35689|-3.323538|2020|              2.0|   3932.0|
|pa_train| 3707965|-1614858.9153893404| 1.690952987363603E7| 57.04466|  9.05137|2017|             10.0|   7739.0|
|pa_train|  331966|-1882413.1908692368|1.5621125512619067E7| 46.32355|   13.909|2020|              0.0|   2322.0|
|pa_train| 2118745| -1581919.826367132|1.6785045224298324E7| 56.27395| 10.46924|2019|             10.0|   2885.0|
|      po|  388457| -2949950.789550384|1.6390396871864012E7| 47.28818|-1.629925|2019|             45.6|   4617.0|
|pa_train|  226994|-1604956.8148379833| 1.690090014821878E7| 57.02554|  9.25348|2018|   

In [3]:
metadata.select("geoUncertaintyInM").describe().show()

[Stage 2:>                                                          (0 + 4) / 4]

+-------+------------------+
|summary| geoUncertaintyInM|
+-------+------------------+
|  count|           6555615|
|   mean|16.874068541243098|
| stddev| 19.06872761715543|
|    min|               0.0|
|    max|             100.0|
+-------+------------------+



                                                                                

In [4]:
# let's find nearest neighbors via lsh
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, BucketedRandomProjectionLSH

pipeline = Pipeline(
    stages=[
        VectorAssembler(inputCols=["lat_proj", "lon_proj"], outputCol="features"),
        BucketedRandomProjectionLSH(
            inputCol="features", outputCol="hashes", bucketLength=20, numHashTables=5
        ),
    ]
)

train = metadata.where(F.col("speciesId").isNotNull())
model = pipeline.fit(train)
transformed = (
    model.transform(metadata)
    .select("speciesId", "features", "hashes")
    .limit(1000)
    .cache()
)

In [6]:
joined = model.stages[-1].approxSimilarityJoin(transformed, transformed, 20)
joined.printSchema()
joined.show()

root
 |-- datasetA: struct (nullable = false)
 |    |-- speciesId: double (nullable = true)
 |    |-- features: vector (nullable = true)
 |    |-- hashes: array (nullable = true)
 |    |    |-- element: vector (containsNull = true)
 |-- datasetB: struct (nullable = false)
 |    |-- speciesId: double (nullable = true)
 |    |-- hashes: array (nullable = true)
 |    |    |-- element: vector (containsNull = true)
 |    |-- features: vector (nullable = true)
 |-- distCol: double (nullable = false)

+--------------------+--------------------+-------+
|            datasetA|            datasetB|distCol|
+--------------------+--------------------+-------+
|{3932.0, [-354671...|{3932.0, [[363540...|    0.0|
|{7739.0, [-161485...|{7739.0, [[282345...|    0.0|
|{2322.0, [-188241...|{2322.0, [[279785...|    0.0|
|{2885.0, [-158191...|{2885.0, [[279245...|    0.0|
|{4617.0, [-294995...|{4617.0, [[340865...|    0.0|
|{581.0, [-1604956...|{581.0, [[281760....|    0.0|
|{9816.0, [-155215...|{9816.0, [