# Centroid grid predictions

Classify test data using centroid probabilities for a grid of tiles.
We're using probabilities for the entire test image.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

# Obtain the Spark session from the project helper, requesting 4 cores,
# and render the session summary as the cell output.
spark = get_spark(cores=4)
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/20 17:12:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/04/20 17:12:54 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
25/04/20 17:12:54 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
import os
from pathlib import Path

# Resolve the user's home directory; the shared project data paths below are built from it
root = Path(os.path.expanduser("~"))
! date

Sun Apr 20 05:12:56 PM EDT 2025


### Faiss centroid probabilities 

In [4]:
# Directory holding the stored embedding datasets (train centroids, test embeddings)
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/embeddings"


# Build the parquet path of the train centroids for a given number of centroids
def get_faiss_embed_path(num_centroids: int = 10) -> str:
    """Return the path of the train-centroid partition for `num_centroids`.

    The path is built from the notebook-global `data_path`, which is looked up
    when the function is called, so the result follows whatever `data_path`
    is bound to at that moment.

    Args:
        num_centroids: value of the `num_centroids=` partition directory to use.

    Returns:
        Path string of the form `<data_path>/train_centroids/num_centroids=<n>`.
    """
    return f"{data_path}/train_centroids/num_centroids={num_centroids}"


# Load the centroid tables for each centroid count (10, 20 and 50) as Spark DataFrames
faiss10_df, faiss20_df, faiss50_df = (
    spark.read.parquet(get_faiss_embed_path(count)) for count in (10, 20, 50)
)

# Inspect the 10-centroid table: schema followed by a small sample
faiss10_df.printSchema()
faiss10_df.show(n=5)

root
 |-- centroid_id: integer (nullable = true)
 |-- species_id: integer (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)



                                                                                

+-----------+----------+--------------------+
|centroid_id|species_id|           embedding|
+-----------+----------+--------------------+
|          6|   1398243|[1.0948665, -0.35...|
|          7|   1647175|[-0.30202305, 0.7...|
|          3|   1360020|[0.049416766, 0.9...|
|          3|   1361527|[0.3046188, 0.885...|
|          8|   1359277|[-0.09973065, 0.6...|
+-----------+----------+--------------------+
only showing top 5 rows



In [5]:
# Load the test-set embeddings and logits stored under `data_path`
# (one row per image with a struct holding `cls_token` and `logits`).
test_path = f"{data_path}/test_2025/test_2025_embed_logits"
test_df = spark.read.parquet(test_path)
test_df.printSchema()

root
 |-- image_name: string (nullable = true)
 |-- output: struct (nullable = true)
 |    |-- cls_token: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |    |-- logits: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)



In [6]:
# Path and dataset names
# Keep the data root under its own name instead of rebinding `data_path`:
# `get_faiss_embed_path` and the test-embedding cell read `data_path` when they
# run, so overwriting it here would make them resolve wrong paths on a re-run.
data_root = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data"
test_parquet_path = f"{data_root}/parquet/test_2025"

# Raw test images: one row per image with its name, path and encoded bytes
df = spark.read.parquet(test_parquet_path)
df.printSchema()

root
 |-- image_name: string (nullable = true)
 |-- path: string (nullable = true)
 |-- data: binary (nullable = true)



In [None]:
import pandas as pd
from PIL import Image
from plantclef.serde import deserialize_image, serialize_image
from pyspark.sql import functions as F
from pyspark.sql.types import (
    BinaryType,
    ArrayType,
    StructType,
    StructField,
    IntegerType,
)


def split_into_grid(image: "Image.Image", grid_size: int = 4):
    """Split an image into a ``grid_size`` x ``grid_size`` grid of tiles.

    Tile boundaries are derived from the full width and height, so the tiles
    cover every pixel even when a dimension is not divisible by ``grid_size``
    (tile sizes then differ by at most one pixel). When both dimensions are
    divisible, every tile is exactly ``w // grid_size`` by ``h // grid_size``.

    Args:
        image: object exposing ``size`` as ``(width, height)`` and
            ``crop((left, upper, right, lower))``, e.g. a PIL image.
        grid_size: number of tiles along each axis; must be at least 1.

    Returns:
        List of ``grid_size ** 2`` tiles in column-major order: the left-most
        column from top to bottom, then the next column, and so on.

    Raises:
        ValueError: if ``grid_size`` is smaller than 1.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")

    w, h = image.size
    # Edge k sits at k/grid_size of the dimension; the last edge is exactly w (or h),
    # so remainder pixels are distributed over the tiles instead of being dropped.
    x_edges = [i * w // grid_size for i in range(grid_size + 1)]
    y_edges = [j * h // grid_size for j in range(grid_size + 1)]

    # Outer loop over columns (x), inner loop over rows (y): column-major order.
    return [
        image.crop((x_edges[i], y_edges[j], x_edges[i + 1], y_edges[j + 1]))
        for i in range(grid_size)
        for j in range(grid_size)
    ]


@F.pandas_udf(
    ArrayType(
        StructType(
            [
                StructField("tile_index", IntegerType()),
                StructField("tile", BinaryType()),
            ]
        )
    )
)
def get_image_tiles(data_series: pd.Series, grid_size: int = 4) -> pd.Series:
    """Pandas UDF turning each serialized image into a list of serialized tiles.

    Args:
        data_series: serialized images, one per row.
        grid_size: number of tiles per side, forwarded to `split_into_grid`.
            NOTE(review): Spark hands UDF arguments over as pandas Series, so
            this presumably only works through its default - confirm.

    Returns:
        Series with one list per input image; each list holds
        `{"tile_index", "tile"}` records in the order `split_into_grid` yields.
    """

    def _tile_records(raw_image):
        # decode -> split -> re-encode every tile together with its position
        pieces = split_into_grid(deserialize_image(raw_image), grid_size=grid_size)
        return [
            {"tile_index": position, "tile": serialize_image(piece)}
            for position, piece in enumerate(pieces)
        ]

    return pd.Series([_tile_records(raw_image) for raw_image in data_series])


# Tile every test image. The UDF is called with the image column only: Spark UDF
# arguments must be columns (or column names), so a Python int such as
# `grid_size=4` cannot be passed at call time. The UDF's default grid_size of 4
# gives the intended 4x4 grid.
df_with_tiles = df.withColumn("tiles", get_image_tiles("data"))

# One output row per (image, tile): explode the tile array, then flatten the struct
df_tiles = df_with_tiles.select("image_name", F.explode("tiles").alias("tile_struct"))
df_tiles = df_tiles.select(
    "image_name",
    F.col("tile_struct.tile_index").alias("tile_index"),
    F.col("tile_struct.tile").alias("tile"),
)
df_tiles.printSchema()
df_tiles.show(n=16)

root
 |-- image_name: string (nullable = true)
 |-- tile_index: integer (nullable = true)
 |-- tile: binary (nullable = true)



                                                                                

+--------------------+----------+--------------------+
|          image_name|tile_index|                tile|
+--------------------+----------+--------------------+
|CBN-Pla-B3-201907...|         0|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         1|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         2|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         3|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         4|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         5|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         6|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         7|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         8|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|         9|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|        10|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|        11|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|        12|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|        13|[89 50 4E 47 0D 0...|
|CBN-Pla-B3-201907...|        14|[89 50 4E 47 0D 0...|
|CBN-Pla-B

In [None]:
import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType


def avg_embeddings_udf(embeddings):
    """Average a collection of equal-length embedding vectors element-wise.

    Args:
        embeddings: sequence of vectors (e.g. a list of lists of floats).

    Returns:
        Plain Python list holding the per-dimension mean of the vectors.
    """
    stacked = np.array(embeddings)
    # axis 0 runs over the vectors, so the result keeps one value per dimension
    return stacked.mean(axis=0).tolist()


# Wrap the plain-Python averaging function as a Spark UDF returning array<float>
average_embeddings = F.udf(avg_embeddings_udf, ArrayType(FloatType()))

# group and apply the UDF
# All centroid embeddings of a species are collected into one list, which the
# UDF then averages into a single vector per species (10-centroid table only).
avg_embeddings_df = (
    faiss10_df.groupBy("species_id")
    .agg(F.collect_list("embedding").alias("embedding_list"))
    .withColumn("avg_embeddings", average_embeddings(F.col("embedding_list")))
)
avg_embeddings_df.printSchema()
avg_embeddings_df.show(n=10, truncate=50)

### classifier-based probabilities

Calculate probabilities based on embedding distances

In [None]:
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from scipy.special import softmax
from plantclef.config import get_class_mappings_file

# load class mappings
class_mappings_file = get_class_mappings_file()
with open(class_mappings_file) as f:
    # one species id per line; blank lines are skipped instead of crashing int()
    sorted_species_ids = [int(line.strip()) for line in f if line.strip()]

# get (species_id, avg_embeddings) from Spark
centroids_pd = avg_embeddings_df.select("species_id", "avg_embeddings").toPandas()

# filter + reorder centroids to match sorted_species_ids
centroids_dict = dict(zip(centroids_pd["species_id"], centroids_pd["avg_embeddings"]))
if not centroids_dict:
    raise ValueError("avg_embeddings_df returned no centroids")

# Species without a centroid get an all-zero placeholder of the embedding size;
# build it once instead of once per missing species.
zero_embedding = np.zeros_like(next(iter(centroids_dict.values())))
filtered_embeddings = [
    centroids_dict[species_id] if species_id in centroids_dict else zero_embedding
    for species_id in sorted_species_ids
]

# shape: (num_species, embedding_dim)
train_embeddings = np.stack(filtered_embeddings)
train_embeddings.shape

In [None]:
# get test embeddings and image names
test_pd = test_df.select("image_name", "output.cls_token").toPandas()
test_embeddings = np.stack(test_pd["cls_token"].values)
image_names = test_pd["image_name"].values

# compute cosine similarity and softmax
cos_similarities = cosine_similarity(test_embeddings, train_embeddings)
cos_probabilities = softmax(cos_similarities, axis=1)  # shape: (num_test, num_species)

# inverse-squared-distance scores, normalized so each row sums to 1
eucliden_dist = euclidean_distances(test_embeddings, train_embeddings)
# Clamp squared distances away from zero: an exact match would otherwise give
# 1/0 = inf and turn the whole normalized row into NaN.
min_sq_dist = np.finfo(eucliden_dist.dtype).eps
inv_sq_dist = 1 / np.maximum(eucliden_dist**2, min_sq_dist)
euclidean_score = inv_sq_dist / np.sum(inv_sq_dist, axis=1, keepdims=True)

# create final DataFrame with aligned probabilities
final_df = pd.DataFrame(
    {
        "image_name": image_names,
        "cos_probabilities": list(cos_probabilities),
        "euclidean_score": list(euclidean_score),
    }
)
final_df.head()