# Species Prior
PaCMAP repo: https://github.com/YingfanWang/PaCMAP

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/14 21:48:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/04/14 21:48:29 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
25/04/14 21:48:30 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


### embeddings

In [3]:
import os
from pathlib import Path

# Get list of stored filed in cloud bucket
root = Path(os.path.expanduser("~"))
! date

Mon Apr 14 09:48:35 PM EDT 2025


### test embeddings

In [4]:
# Path and dataset names
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/embeddings"

# Define the path to the train and test parquet files
test_path = f"{data_path}/test_2025/test_2025_embed_logits"

# Read the parquet files into a spark DataFrame
test_df = spark.read.parquet(test_path)

# Show the data
test_df.printSchema()
test_df.show(n=5)

                                                                                

root
 |-- image_name: string (nullable = true)
 |-- output: struct (nullable = true)
 |    |-- cls_token: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |    |-- logits: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)



                                                                                

+--------------------+--------------------+---------+
|          image_name|              output|sample_id|
+--------------------+--------------------+---------+
|CBN-Pla-A1-201908...|{[0.47354543, 1.5...|        0|
|CBN-Pla-D6-201908...|{[-0.39621377, 1....|        0|
|CBN-PdlC-C5-20140...|{[-0.5331654, 0.2...|        0|
|LISAH-BOU-0-37-20...|{[1.2480925, 0.47...|        0|
|CBN-Pla-E4-201308...|{[0.7065191, 1.70...|        0|
+--------------------+--------------------+---------+
only showing top 5 rows



In [5]:
# count number of rows: should be 2105 for grid=1x1
print(f"Number of rows: {test_df.count()}")

Number of rows: 2105


In [6]:
regions = [
    "2024-CEV3",
    "CBN-can",
    "CBN-PdlC",
    "CBN-Pla",
    "CBN-Pyr",
    "GUARDEN-AMB",
    "GUARDEN-CBNMed",
    "LISAH-BOU",
    "LISAH-BVD",
    "LISAH-JAS",
    "LISAH-PEC",
    "OPTMix",
    "RNNB",
]

len(regions)

13

In [7]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, regexp_extract


def prepare_emb_df(df: DataFrame, embed_col: str = "output.cls_token") -> DataFrame:
    """
    Prepare the DataFrame by renaming the embedding column and selecting relevant columns.
    """
    regex_pattern = "|".join([f"^{region}" for region in regions])
    # add the grouped_regions column using regexp_extract
    test_df_with_regions = df.withColumn(
        "grouped_regions", regexp_extract(col("image_name"), f"({regex_pattern})", 1)
    )

    # check if there are any unmatched regions (empty strings)
    unmatched_count = test_df_with_regions.filter(col("grouped_regions") == "").count()
    print(f"Number of rows with unmatched regions: {unmatched_count}")

    # show the result with the new column
    test_df_with_regions.select("image_name", "grouped_regions", embed_col).show(
        10, truncate=True
    )

    # Count occurrences of each region
    region_counts = (
        test_df_with_regions.groupBy("grouped_regions")
        .count()
        .orderBy(col("count").desc())
    )
    region_counts.show(20)
    return test_df_with_regions


test_df_with_regions = prepare_emb_df(test_df, "output.cls_token")

Number of rows with unmatched regions: 0
+--------------------+---------------+--------------------+
|          image_name|grouped_regions|           cls_token|
+--------------------+---------------+--------------------+
|CBN-Pla-A1-201908...|        CBN-Pla|[0.47354543, 1.55...|
|CBN-Pla-D6-201908...|        CBN-Pla|[-0.39621377, 1.2...|
|CBN-PdlC-C5-20140...|       CBN-PdlC|[-0.5331654, 0.21...|
|LISAH-BOU-0-37-20...|      LISAH-BOU|[1.2480925, 0.478...|
|CBN-Pla-E4-201308...|        CBN-Pla|[0.7065191, 1.709...|
|CBN-PdlC-D6-20150...|       CBN-PdlC|[-0.32394692, 0.4...|
|CBN-PdlC-F2-20170...|       CBN-PdlC|[1.4019761, 1.783...|
|CBN-PdlC-A6-20180...|       CBN-PdlC|[-0.49399343, 1.1...|
|RNNB-3-12-2023051...|           RNNB|[-0.37940657, 0.1...|
|CBN-PdlC-F4-20150...|       CBN-PdlC|[-0.26687536, 1.2...|
+--------------------+---------------+--------------------+
only showing top 10 rows



[Stage 9:==>                                                      (1 + 19) / 20]

+---------------+-----+
|grouped_regions|count|
+---------------+-----+
|       CBN-PdlC|  816|
|        CBN-Pla|  628|
| GUARDEN-CBNMed|  165|
|           RNNB|  141|
|      LISAH-BOU|   82|
|         OPTMix|   78|
|      LISAH-BVD|   76|
|    GUARDEN-AMB|   36|
|      LISAH-PEC|   35|
|        CBN-can|   30|
|      LISAH-JAS|   15|
|        CBN-Pyr|    2|
|      2024-CEV3|    1|
+---------------+-----+



                                                                                

In [8]:
test_df_with_regions.count()

2105

In [12]:
import pacmap
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, TransformerMixin


class PaCMAPTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, n_components=2, random_state=42):
        self.n_components = n_components
        self.random_state = random_state
        self.reducer = None
        self.embedding_ = None

    def fit(self, X, y=None):
        self.reducer = pacmap.PaCMAP(
            n_components=self.n_components, random_state=self.random_state
        )
        self.embedding_ = self.reducer.fit_transform(X)
        return self

    def transform(self, X):
        # just return the already-computed embedding
        return self.embedding_


clustering_pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("pacmap", PaCMAPTransformer(n_components=2, random_state=42)),
        ("cluster", KMeans(n_clusters=5, random_state=42)),
    ]
)

In [13]:
import pickle

# model directory
model_dir = (
    Path(os.path.expanduser("~")) / "p-dsgt_clef2025-0/shared/plantclef/models/pacmap"
)
pipeline_filename = model_dir / "plant_clustering_pipeline.pkl"


# load the clustering model
with open(pipeline_filename, "rb") as file:
    clustering_pipeline = pickle.load(file)

In [14]:
import numpy as np

# Convert to Pandas DF
col_name = "output.cls_token"
df = test_df_with_regions.select([col_name, "grouped_regions"])
pandas_df = df.select(["cls_token", "grouped_regions"]).toPandas()

# Fit and predict clusters
embeddings = np.stack(pandas_df["cls_token"].values)
labels = clustering_pipeline.fit_predict(embeddings)

# Attach results to the DataFrame
pandas_df["cluster"] = labels
pandas_df["pacmap_1"] = clustering_pipeline.named_steps["pacmap"].embedding_[:, 0]
pandas_df["pacmap_2"] = clustering_pipeline.named_steps["pacmap"].embedding_[:, 1]



In [15]:
labels[:10]

array([2, 2, 1, 4, 2, 3, 1, 3, 0, 1], dtype=int32)

### TODO

**Question:** what species belong to each cluster?

1. get the entire classification logits for each test image
2. average the probabilities (softmax of the logits) for each cluster
3. get the most probable species for each cluster --> get the most probable species for each location
4. group them by genus and family and do the same for steps 3 & 4

Can we narrow down genus ~100 per cluster?

**Main goal:** select a subset of candidate species to choose from when making predictions

In [25]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

dominant_clusters = {
    "GUARDEN-CBNMed": 0,
    "RNNB": 0,
    "LISAH-BOU": 0,
    "OPTMix": 0,
    "LISAH-BVD": 0,
    "GUARDEN-AMB": 0,
    "LISAH-PEC": 0,
    "LISAH-JAS": 0,
    "CBN-Pyr": 0,
    "2024-CEV3": 0,
    "CBN-PdlC": 1,
    "CBN-can": 1,
    "CBN-Pla": 2,
}


# add dominant cluster to the test_df_with_regions based on grouped_regions
def get_dominant_cluster(region):
    return dominant_clusters.get(region, 0)  # Default to 0 if region not found


get_dominant_cluster_udf = F.udf(get_dominant_cluster, IntegerType())
cluster_df = test_df_with_regions.withColumn(
    "dominant_cluster", get_dominant_cluster_udf("grouped_regions")
)
# Show the updated DataFrame
cluster_df.show()

                                                                                

+--------------------+--------------------+---------+---------------+----------------+
|          image_name|              output|sample_id|grouped_regions|dominant_cluster|
+--------------------+--------------------+---------+---------------+----------------+
|CBN-Pla-A1-201908...|{[0.47354543, 1.5...|        0|        CBN-Pla|               2|
|CBN-Pla-D6-201908...|{[-0.39621377, 1....|        0|        CBN-Pla|               2|
|CBN-PdlC-C5-20140...|{[-0.5331654, 0.2...|        0|       CBN-PdlC|               1|
|LISAH-BOU-0-37-20...|{[1.2480925, 0.47...|        0|      LISAH-BOU|               0|
|CBN-Pla-E4-201308...|{[0.7065191, 1.70...|        0|        CBN-Pla|               2|
|CBN-PdlC-D6-20150...|{[-0.32394692, 0....|        0|       CBN-PdlC|               1|
|CBN-PdlC-F2-20170...|{[1.4019761, 1.78...|        0|       CBN-PdlC|               1|
|CBN-PdlC-A6-20180...|{[-0.49399343, 1....|        0|       CBN-PdlC|               1|
|RNNB-3-12-2023051...|{[-0.37940657, 0....|

In [26]:
import torch
from pyspark.sql.types import ArrayType, FloatType


# get the probabilities from the output.logits column using softmax
def get_probabilities(logits):
    logits_tensor = torch.tensor(logits)
    probabilities = torch.softmax(logits_tensor, dim=0)
    return probabilities.tolist()


get_probabilities_udf = F.udf(get_probabilities, ArrayType(FloatType()))
cluster_df = cluster_df.withColumn(
    "probabilities", get_probabilities_udf("output.logits")
)
cluster_df.show(truncate=False)

                                                                                

+---------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [29]:
def avg_probabilities(probabilities):
    probabilities_tensor = torch.tensor(probabilities)
    avg_probabilities = torch.mean(probabilities_tensor, dim=0)
    return avg_probabilities.tolist()


# apply the UDF to calculate average probabilities for each row
cluster_df = cluster_df.withColumn(
    "avg_probabilities",
    F.udf(avg_probabilities, ArrayType(FloatType()))("probabilities"),
)

# group by grouped_regions and average probabilities
avg_probabilities_df = cluster_df.groupBy("grouped_regions").agg(
    F.avg("avg_probabilities").alias("avg_probabilities")
)
avg_probabilities_df.show(truncate=False)

AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "avg(avg_probabilities)" due to data type mismatch: Parameter 1 requires the "NUMERIC" or "ANSI INTERVAL" type, however "avg_probabilities" has the type "ARRAY<FLOAT>".;
'Aggregate [grouped_regions#28], [grouped_regions#28, avg(avg_probabilities#246) AS avg_probabilities#262]
+- Project [image_name#0, output#1, sample_id#2, grouped_regions#28, dominant_cluster#131, probabilities#160, avg_probabilities(probabilities#160)#245 AS avg_probabilities#246]
   +- Project [image_name#0, output#1, sample_id#2, grouped_regions#28, dominant_cluster#131, probabilities#160, get_avg_probabilities(probabilities#160)#195 AS avg_probabilities#196]
      +- Project [image_name#0, output#1, sample_id#2, grouped_regions#28, dominant_cluster#131, get_probabilities(output#1.logits)#159 AS probabilities#160]
         +- Project [image_name#0, output#1, sample_id#2, grouped_regions#28, get_dominant_cluster(grouped_regions#28)#130 AS dominant_cluster#131]
            +- Project [image_name#0, output#1, sample_id#2, regexp_extract(image_name#0, (^2024-CEV3|^CBN-can|^CBN-PdlC|^CBN-Pla|^CBN-Pyr|^GUARDEN-AMB|^GUARDEN-CBNMed|^LISAH-BOU|^LISAH-BVD|^LISAH-JAS|^LISAH-PEC|^OPTMix|^RNNB), 1) AS grouped_regions#28]
               +- Relation [image_name#0,output#1,sample_id#2] parquet
