In [1]:
# Automatically reload edited imported modules (e.g. the local plantclef package)
# before each cell executes, so source edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Put the local plantclef-2025 checkout on the import path so that
# `plantclef` (imported further down) can be resolved.
PROJECT_ROOT = os.path.expanduser("~/clef/plantclef-2025/")
sys.path.append(PROJECT_ROOT)

In [3]:
# Spark memory / scratch-dir settings, set before the session is created in the next cell.
# NOTE(review): presumably consumed by plantclef.spark.get_spark — TODO confirm.
# NOTE(review): Spark's own environment variable is SPARK_LOCAL_DIRS (plural); the
# singular SPARK_LOCAL_DIR only has an effect if get_spark reads it explicitly — verify.
%env PYSPARK_DRIVER_MEMORY=32g
%env PYSPARK_EXECUTOR_MEMORY=16g
%env SPARK_LOCAL_DIR=/tmp/spark-tmp

env: PYSPARK_DRIVER_MEMORY=32g
env: PYSPARK_EXECUTOR_MEMORY=16g
env: SPARK_LOCAL_DIR=/tmp/spark-tmp


In [4]:
from plantclef.spark import get_spark

# Start Spark through the project helper; the %env values from the previous cell
# are presumably read inside get_spark (TODO confirm).
spark = get_spark()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/06 03:17:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/03/06 03:17:40 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [5]:
from pathlib import Path

# Parquet dataset of precomputed training embeddings
# (columns: image_name, species_id, cls_embedding, sample_id).
home = Path.home()
train_embeddings_path = home.joinpath("scratch", "plantclef", "data", "embeddings", "train", "data")

train_df = spark.read.parquet(str(train_embeddings_path))
train_df.printSchema()
train_df.show(5)

                                                                                

root
 |-- image_name: string (nullable = true)
 |-- species_id: integer (nullable = true)
 |-- cls_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)



                                                                                

+--------------------+----------+--------------------+---------+
|          image_name|species_id|       cls_embedding|sample_id|
+--------------------+----------+--------------------+---------+
|3a2c58a78ee93b471...|   1363472|[0.9020945, 0.016...|       15|
|0a0bf86d70307e8db...|   1361957|[-0.26025677, -0....|       15|
|7990901729be71186...|   1363472|[-0.1633016, -0.0...|       15|
|80257a4818f5955f9...|   1392612|[-0.7547744, 0.39...|       15|
|e13e476d0dc36ed7b...|   1360562|[0.14017674, 0.05...|       15|
+--------------------+----------+--------------------+---------+
only showing top 5 rows



In [6]:
from pyspark.sql import functions as F

# One row per species, holding the list of every CLS embedding for that species.
grouped_df = (
    train_df
    .groupBy("species_id")
    .agg(F.collect_list("cls_embedding").alias("embeddings_list"))
)

# Sanity check: dimensionality of a single CLS embedding.
first_embedding = train_df.select("cls_embedding").first()["cls_embedding"]
len(first_embedding)

768

In [None]:
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType, FloatType
import pandas as pd
import numpy as np
import faiss


# Upper bound on K-means centroids per species (species with fewer
# embeddings get one centroid per embedding instead).
num_centroids = 50

# Output schema of the applyInPandas step: one row per (species, centroid).
schema = StructType([
    StructField(column_name, column_type)
    for column_name, column_type in (
        ("centroid_id", IntegerType()),
        ("species_id", IntegerType()),
        ("embedding", ArrayType(FloatType())),
    )
])

def compute_centroids(pdf):
    """Cluster each species' embeddings with FAISS K-means and return the centroids.

    Intended as the function passed to ``applyInPandas``; the returned frame
    matches ``schema``.

    Args:
        pdf: pandas DataFrame with a ``species_id`` column and an
            ``embeddings_list`` column whose cells are sequences of embedding
            vectors (built with ``collect_list("cls_embedding")``). It may hold
            one species or several; each species is clustered independently.

    Returns:
        pandas DataFrame with columns ``centroid_id`` (0..k-1), ``species_id``
        and ``embedding`` (list of floats), one row per centroid, where
        ``k = min(num_centroids, number of embeddings for that species)``.
        Species with no embeddings at all produce no rows.
    """
    results = []

    for species_id, group_pdf in pdf.groupby("species_id"):
        # Flatten the per-row embedding lists into one list of vectors.
        all_embeddings = [vector for row in group_pdf["embeddings_list"] for vector in row]
        if not all_embeddings:
            # np.vstack([]) would raise and fail the whole Spark task.
            continue

        embeddings = np.vstack(all_embeddings).astype(np.float32)
        # Take the dimensionality from the data instead of hard-coding it.
        n_points, embedding_dim = embeddings.shape
        # K-means cannot have more clusters than training points.
        k = min(num_centroids, n_points)

        # CPU flat L2 index; for GPU training, wrap it with faiss.index_cpu_to_gpu.
        index = faiss.IndexFlatL2(embedding_dim)
        kmeans = faiss.Clustering(embedding_dim, k)
        kmeans.niter = 20
        # Lower FAISS's per-centroid minimum so species with few images are accepted quietly.
        kmeans.min_points_per_centroid = 1
        kmeans.train(embeddings, index)

        centroids = faiss.vector_float_to_array(kmeans.centroids).reshape(k, embedding_dim)
        results.extend(
            (centroid_id, species_id, centroid.tolist())
            for centroid_id, centroid in enumerate(centroids)
        )

    return pd.DataFrame(results, columns=["centroid_id", "species_id", "embedding"])

# Group by species so Spark invokes compute_centroids once per species and spreads
# the K-means work over all executors. An empty groupBy() would put every species'
# embeddings into a single pandas frame processed by one task (serial, OOM-prone).
num_partitions = grouped_df.rdd.getNumPartitions()
centroids_df = (
    grouped_df
    .groupBy("species_id")
    .applyInPandas(compute_centroids, schema)
    .repartition(num_partitions)
)



In [9]:
# One output directory per centroid count, e.g. .../train_centroids/num_centroids=50.
output_path = (
    home
    / "scratch" / "plantclef" / "data" / "parquet" / "train_centroids"
    / f"num_centroids={num_centroids}"
)
centroids_df.write.parquet(str(output_path), mode="overwrite")

                                                                                

In [8]:
# Preview the persisted centroids. centroids_df is not cached, so calling
# centroids_df.show() would rerun the entire K-means job just to print 5 rows.
centroids_from_disk = spark.read.parquet(str(output_path))
centroids_from_disk.show(5)

[Stage 9:>                                                          (0 + 1) / 1]

+-----------+----------+--------------------+
|centroid_id|species_id|           embedding|
+-----------+----------+--------------------+
|          6|   1391964|[1.0990983, 0.259...|
|          9|   1360169|[-0.5898131, 0.29...|
|          5|   1389806|[0.5803415, -0.23...|
|          3|   1393591|[-0.24192038, -0....|
|          7|   1361603|[0.030800506, -0....|
+-----------+----------+--------------------+
only showing top 5 rows



                                                                                