# Species Prior
PaCMAP repo: https://github.com/YingfanWang/PaCMAP

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

spark = get_spark(cores=4)
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/16 11:16:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/04/16 11:16:45 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


### embeddings

In [3]:
import os
from pathlib import Path

# Use the home directory as the root for the shared project paths
root = Path(os.path.expanduser("~"))
! date

Wed Apr 16 11:16:48 AM EDT 2025


### test embeddings

In [4]:
# Path and dataset names
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/embeddings"

# Define the path to the test parquet files
test_path = f"{data_path}/test_2025/test_2025_embed_logits"

# Read the parquet files into a spark DataFrame
test_df = spark.read.parquet(test_path)

# Show the data
test_df.printSchema()
test_df.show(n=5)

                                                                                

root
 |-- image_name: string (nullable = true)
 |-- output: struct (nullable = true)
 |    |-- cls_token: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |    |-- logits: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)



                                                                                

+--------------------+--------------------+---------+
|          image_name|              output|sample_id|
+--------------------+--------------------+---------+
|CBN-Pla-A1-201908...|{[0.47354543, 1.5...|        0|
|CBN-Pla-D6-201908...|{[-0.39621377, 1....|        0|
|CBN-PdlC-C5-20140...|{[-0.5331654, 0.2...|        0|
|LISAH-BOU-0-37-20...|{[1.2480925, 0.47...|        0|
|CBN-Pla-E4-201308...|{[0.7065191, 1.70...|        0|
+--------------------+--------------------+---------+
only showing top 5 rows



In [5]:
# count number of rows: should be 2105 for grid=1x1
print(f"Number of rows: {test_df.count()}")

Number of rows: 2105


In [6]:
regions = [
    "2024-CEV3",
    "CBN-can",
    "CBN-PdlC",
    "CBN-Pla",
    "CBN-Pyr",
    "GUARDEN-AMB",
    "GUARDEN-CBNMed",
    "LISAH-BOU",
    "LISAH-BVD",
    "LISAH-JAS",
    "LISAH-PEC",
    "OPTMix",
    "RNNB",
]

len(regions)

13

In [7]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, regexp_extract


def prepare_emb_df(df: DataFrame, embed_col: str = "output.cls_token") -> DataFrame:
    """
    Tag every row with the region prefix that its image name starts with.

    A ``grouped_regions`` column is added, holding whichever entry of the
    module-level ``regions`` list matches the start of ``image_name``.
    Some diagnostics are printed along the way (unmatched-row count, a
    10-row preview, and per-region row counts). No column is renamed or
    dropped.

    Args:
        df: Spark DataFrame with an ``image_name`` column.
        embed_col: Column displayed next to the region in the preview;
            it is only used for display.

    Returns:
        ``df`` with the extra ``grouped_regions`` column.
    """
    # One start-anchored alternative per region, wrapped in a single capture group
    anchored = [f"^{region}" for region in regions]
    capture_group = "(" + "|".join(anchored) + ")"
    region_col = regexp_extract(col("image_name"), capture_group, 1)
    tagged_df = df.withColumn("grouped_regions", region_col)

    # Rows whose name matched no region carry an empty string
    is_unmatched = col("grouped_regions") == ""
    unmatched_count = tagged_df.filter(is_unmatched).count()
    print(f"Number of rows with unmatched regions: {unmatched_count}")

    # Preview the new column next to the image name and the embedding
    preview = tagged_df.select("image_name", "grouped_regions", embed_col)
    preview.show(10, truncate=True)

    # Rows per region, largest group first
    per_region = tagged_df.groupBy("grouped_regions").count()
    per_region.orderBy(col("count").desc()).show(20)
    return tagged_df


test_df_with_regions = prepare_emb_df(test_df, "output.cls_token")

Number of rows with unmatched regions: 0
+--------------------+---------------+--------------------+
|          image_name|grouped_regions|           cls_token|
+--------------------+---------------+--------------------+
|CBN-Pla-A1-201908...|        CBN-Pla|[0.47354543, 1.55...|
|CBN-Pla-D6-201908...|        CBN-Pla|[-0.39621377, 1.2...|
|CBN-PdlC-C5-20140...|       CBN-PdlC|[-0.5331654, 0.21...|
|LISAH-BOU-0-37-20...|      LISAH-BOU|[1.2480925, 0.478...|
|CBN-Pla-E4-201308...|        CBN-Pla|[0.7065191, 1.709...|
|CBN-PdlC-D6-20150...|       CBN-PdlC|[-0.32394692, 0.4...|
|CBN-PdlC-F2-20170...|       CBN-PdlC|[1.4019761, 1.783...|
|CBN-PdlC-A6-20180...|       CBN-PdlC|[-0.49399343, 1.1...|
|RNNB-3-12-2023051...|           RNNB|[-0.37940657, 0.1...|
|CBN-PdlC-F4-20150...|       CBN-PdlC|[-0.26687536, 1.2...|
+--------------------+---------------+--------------------+
only showing top 10 rows

+---------------+-----+
|grouped_regions|count|
+---------------+-----+
|       CBN-PdlC|  81

In [8]:
test_df_with_regions.count()

2105

In [9]:
import pacmap
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, TransformerMixin


class PaCMAPTransformer(BaseEstimator, TransformerMixin):
    """
    scikit-learn style wrapper that exposes PaCMAP as a pipeline step.

    The low-dimensional embedding is computed once in ``fit`` and cached in
    ``embedding_``; ``transform`` hands back that cached embedding. The step
    therefore only makes sense when ``transform`` receives the same rows that
    were passed to ``fit`` (as happens in ``Pipeline.fit_predict``).

    Args:
        n_components: Number of output dimensions passed to ``pacmap.PaCMAP``.
        random_state: Seed passed to ``pacmap.PaCMAP``.
    """

    def __init__(self, n_components=2, random_state=42):
        self.n_components = n_components
        self.random_state = random_state
        # Populated by fit(); kept here so the attributes always exist.
        self.reducer = None
        self.embedding_ = None

    def fit(self, X, y=None):
        """Fit PaCMAP on ``X`` and cache the resulting embedding; ``y`` is ignored."""
        self.reducer = pacmap.PaCMAP(
            n_components=self.n_components, random_state=self.random_state
        )
        self.embedding_ = self.reducer.fit_transform(X)
        return self

    def transform(self, X):
        """
        Return the embedding cached by ``fit``.

        Raises:
            NotFittedError: if ``fit`` has not been called yet.
            ValueError: if ``X`` does not have as many rows as the cached
                embedding, i.e. it cannot be the data the embedding was
                computed from.
        """
        embedding = getattr(self, "embedding_", None)
        if embedding is None:
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "PaCMAPTransformer is not fitted yet; call fit() before transform()."
            )
        # The cached embedding is only valid for the rows seen in fit(); refuse
        # to silently return it for an input of a different size.
        n_rows = X.shape[0] if hasattr(X, "shape") else len(X)
        if n_rows != len(embedding):
            raise ValueError(
                f"PaCMAPTransformer.transform received {n_rows} rows but the cached "
                f"embedding has {len(embedding)}; refit on the new data instead."
            )
        return embedding


clustering_pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("pacmap", PaCMAPTransformer(n_components=2, random_state=42)),
        ("cluster", KMeans(n_clusters=5, random_state=42)),
    ]
)

In [10]:
import pickle

# model directory: location under the user's home that holds the saved PaCMAP pipeline
model_dir = (
    Path(os.path.expanduser("~")) / "p-dsgt_clef2025-0/shared/plantclef/models/pacmap"
)
pipeline_filename = model_dir / "plant_clustering_pipeline.pkl"


# load the clustering model
# This rebinds `clustering_pipeline`, replacing the pipeline constructed in the previous cell.
# NOTE(review): pickle.load can execute arbitrary code, so only load files from a trusted source.
# NOTE(review): presumably the pickle refers to PaCMAPTransformer by name, so that class
# must already be defined in this kernel before loading — confirm.
with open(pipeline_filename, "rb") as file:
    clustering_pipeline = pickle.load(file)

In [11]:
import numpy as np

# Convert to Pandas DF
# Selecting the nested field "output.cls_token" yields a column named "cls_token",
# which is why the second select uses the short name.
col_name = "output.cls_token"
df = test_df_with_regions.select([col_name, "grouped_regions"])
pandas_df = df.select(["cls_token", "grouped_regions"]).toPandas()

# Fit and predict clusters
# Stack the per-row embedding arrays into a single 2-D array (one row per image).
embeddings = np.stack(pandas_df["cls_token"].values)
# NOTE(review): fit_predict re-fits the pipeline on these test embeddings, so whatever
# the loaded pipeline had learned is replaced — confirm this is intended.
labels = clustering_pipeline.fit_predict(embeddings)

# Attach results to the DataFrame
pandas_df["cluster"] = labels
# The two PaCMAP coordinates come from the embedding cached on the "pacmap" step during fit.
pandas_df["pacmap_1"] = clustering_pipeline.named_steps["pacmap"].embedding_[:, 0]
pandas_df["pacmap_2"] = clustering_pipeline.named_steps["pacmap"].embedding_[:, 1]



In [12]:
labels[:10]

array([2, 2, 1, 4, 2, 3, 1, 3, 0, 1], dtype=int32)

### TODO

**Question:** what species belong to each cluster?

1. get the entire classification logits for each test image
2. average the probabilities (softmax the logits) for each cluster
3. get the most probable species for each cluster --> get the most probable species for each location
4. group them by genus and family and repeat steps 2 & 3

Can we narrow this down to ~100 genera per cluster?

**Main goal:** select a subset of candidate species to choose from when making predictions

In [13]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# Region prefix -> dominant cluster id, retrieved from the clustering notebook
dominant_clusters = {
    "GUARDEN-CBNMed": 0,
    "RNNB": 0,
    "LISAH-BOU": 0,
    "OPTMix": 0,
    "LISAH-BVD": 0,
    "GUARDEN-AMB": 0,
    "LISAH-PEC": 0,
    "LISAH-JAS": 0,
    "CBN-Pyr": 0,
    "2024-CEV3": 0,
    "CBN-PdlC": 1,
    "CBN-can": 1,
    "CBN-Pla": 2,
}


# Look up the dominant cluster for a value of the grouped_regions column
def get_dominant_cluster(region):
    """Return the dominant cluster id for ``region``, or 0 when the region is not listed."""
    try:
        return dominant_clusters[region]
    except KeyError:
        # Regions missing from the table fall back to cluster 0
        return 0


get_dominant_cluster_udf = F.udf(get_dominant_cluster, IntegerType())
cluster_df = test_df_with_regions.withColumn(
    "dominant_cluster", get_dominant_cluster_udf("grouped_regions")
)
# Show the updated DataFrame
cluster_df.show(n=10, truncate=50)

[Stage 16:>                                                         (0 + 1) / 1]

+---------------------------+--------------------------------------------------+---------+---------------+----------------+
|                 image_name|                                            output|sample_id|grouped_regions|dominant_cluster|
+---------------------------+--------------------------------------------------+---------+---------------+----------------+
|    CBN-Pla-A1-20190814.jpg|{[0.47354543, 1.5568701, -1.6330245, -1.3648611...|        0|        CBN-Pla|               2|
|    CBN-Pla-D6-20190814.jpg|{[-0.39621377, 1.2026826, 0.27647698, -0.661421...|        0|        CBN-Pla|               2|
|   CBN-PdlC-C5-20140901.jpg|{[-0.5331654, 0.21328913, -1.2809799, 0.1238243...|        0|       CBN-PdlC|               1|
|LISAH-BOU-0-37-20230512.jpg|{[1.2480925, 0.4781976, 0.69301766, 0.4653994, ...|        0|      LISAH-BOU|               0|
|    CBN-Pla-E4-20130808.jpg|{[0.7065191, 1.7097996, -1.2477401, 1.3419615, ...|        0|        CBN-Pla|               2|
|   CBN-

                                                                                

In [14]:
import torch
from pyspark.sql.types import ArrayType, FloatType


# get the probabilities from the output.logits column using softmax
def get_probabilities(logits):
    """
    Turn one row of classification logits into probabilities via softmax.

    Args:
        logits: Sequence of floats for a single image, or ``None``. The
            ``output.logits`` column is nullable in the printed schema, so a
            null cell is passed through instead of crashing the UDF.

    Returns:
        A list of floats of the same length as ``logits`` that sums to ~1,
        or ``None`` when ``logits`` is ``None``.
    """
    if logits is None:
        return None
    logits_tensor = torch.tensor(logits)
    # dim=0: the input is a single 1-D vector of class scores
    probabilities = torch.softmax(logits_tensor, dim=0)
    return probabilities.tolist()


get_probabilities_udf = F.udf(get_probabilities, ArrayType(FloatType()))
cluster_df = cluster_df.withColumn(
    "probabilities", get_probabilities_udf("output.logits")
)
cluster_df.show(n=10, truncate=50)

[Stage 17:>                                                         (0 + 1) / 1]

+---------------------------+--------------------------------------------------+---------+---------------+----------------+--------------------------------------------------+
|                 image_name|                                            output|sample_id|grouped_regions|dominant_cluster|                                     probabilities|
+---------------------------+--------------------------------------------------+---------+---------------+----------------+--------------------------------------------------+
|    CBN-Pla-A1-20190814.jpg|{[0.47354543, 1.5568701, -1.6330245, -1.3648611...|        0|        CBN-Pla|               2|[4.5976332E-5, 1.9088793E-5, 2.4028224E-5, 8.92...|
|    CBN-Pla-D6-20190814.jpg|{[-0.39621377, 1.2026826, 0.27647698, -0.661421...|        0|        CBN-Pla|               2|[4.5638313E-5, 7.7455275E-5, 6.034234E-5, 1.287...|
|   CBN-PdlC-C5-20140901.jpg|{[-0.5331654, 0.21328913, -1.2809799, 0.1238243...|        0|       CBN-PdlC|               1|[4

                                                                                

In [15]:
def avg_probabilities_udf(probabilities_list):
    """
    Average a collection of equal-length probability vectors element-wise.

    Args:
        probabilities_list: List of probability vectors (one per image).

    Returns:
        A list holding the per-position mean across all vectors.
    """
    stacked = torch.tensor(probabilities_list)
    # dim=0 collapses the "one row per image" axis, leaving one value per class
    return stacked.mean(dim=0).tolist()


average_probabilities = F.udf(avg_probabilities_udf, ArrayType(FloatType()))

# group and apply the UDF
avg_probabilities_df = (
    cluster_df.groupBy("dominant_cluster")
    .agg(F.collect_list("probabilities").alias("proba_list"))
    .withColumn("avg_probabilities", average_probabilities(col("proba_list")))
)
avg_probabilities_df.show(n=10, truncate=50)

                                                                                

+----------------+--------------------------------------------------+--------------------------------------------------+
|dominant_cluster|                                        proba_list|                                 avg_probabilities|
+----------------+--------------------------------------------------+--------------------------------------------------+
|               1|[[4.9514707E-5, 3.681495E-5, 2.595474E-5, 1.529...|[2.362997E-5, 4.499179E-5, 4.3323707E-5, 1.6415...|
|               2|[[4.5976332E-5, 1.9088793E-5, 2.4028224E-5, 8.9...|[2.4171377E-5, 2.2901053E-5, 1.4540791E-5, 7.44...|
|               0|[[1.5335014E-5, 1.0890696E-5, 1.909136E-5, 6.78...|[5.1414485E-5, 1.0827277E-4, 1.0225503E-4, 6.50...|
+----------------+--------------------------------------------------+--------------------------------------------------+



In [16]:
# renormalize the probabilities so that each vector sums to one
def renormalize_probabilities(probabilities):
    """
    Rescale a probability vector by its total.

    Args:
        probabilities: Sequence of floats.

    Returns:
        A list with every entry divided by the sum of all entries.
    """
    vec = torch.tensor(probabilities)
    vec.div_(vec.sum())
    return vec.tolist()


renormalize_probabilities_udf = F.udf(renormalize_probabilities, ArrayType(FloatType()))
avg_probabilities_df = avg_probabilities_df.withColumn(
    "renormalized_probabilities", renormalize_probabilities_udf("avg_probabilities")
)
avg_probabilities_df.select(
    "dominant_cluster",
    "renormalized_probabilities",
).show(truncate=50)

[Stage 27:>                                                         (0 + 2) / 2]

+----------------+--------------------------------------------------+
|dominant_cluster|                        renormalized_probabilities|
+----------------+--------------------------------------------------+
|               1|[2.3629971E-5, 4.4991793E-5, 4.332371E-5, 1.641...|
|               2|[2.4171377E-5, 2.2901053E-5, 1.4540791E-5, 7.44...|
|               0|[5.1414485E-5, 1.0827277E-4, 1.0225503E-4, 6.50...|
+----------------+--------------------------------------------------+



                                                                                

In [21]:
# write the DataFrame to parquet
output_path = f"{data_path}/test_2025/test_2025_embed_probabilities_clustered"
probabilities_df = avg_probabilities_df.select(
    "dominant_cluster",
    "renormalized_probabilities",
)
probabilities_df.write.mode("overwrite").parquet(output_path)
print(f"Probabilities saved to {output_path}")



Probabilities saved to /storage/home/hcoda1/9/mgustineli3/p-dsgt_clef2025-0/shared/plantclef/data/embeddings/test_2025/test_2025_embed_probabilities_clustered


                                                                                

In [17]:
from pyspark.sql.types import StringType, MapType
from plantclef.config import get_class_mappings_file

class_mappings_file = get_class_mappings_file()
with open(class_mappings_file) as f:
    class_index_to_name = {i: line.strip() for i, line in enumerate(f)}

num_classes = len(class_index_to_name)
print(f"Number of classes: {num_classes}")


def map_species_probabilities(probabilities, k=5):
    """
    Map the ``k`` largest probabilities to their class names.

    Args:
        probabilities: 1-D sequence of per-class probabilities; position ``i``
            is looked up as key ``i`` in the module-level ``class_index_to_name``.
        k: How many of the top entries to keep. Values larger than the
            vector length are capped at the vector length.

    Returns:
        Dict of class name -> probability, inserted in descending order of
        probability. Indices missing from ``class_index_to_name`` all share
        the key "Unknown", so only the last such entry is retained.
    """
    probabilities_tensor = torch.tensor(probabilities)
    # torch.topk raises when k exceeds the number of elements, so cap it
    k = min(k, probabilities_tensor.numel())
    top_probs, top_indices = torch.topk(probabilities_tensor, k=k)
    # .tolist() yields plain Python ints/floats, suitable as dict keys and values
    return {
        class_index_to_name.get(index, "Unknown"): prob
        for index, prob in zip(top_indices.tolist(), top_probs.tolist())
    }


species_probabilities_udf = F.udf(
    lambda probs: map_species_probabilities(probs, k=num_classes),
    MapType(StringType(), FloatType()),
)
avg_probabilities_df = avg_probabilities_df.withColumn(
    "species_probabilities", species_probabilities_udf("renormalized_probabilities")
)
avg_probabilities_df.select("dominant_cluster", "species_probabilities").show(
    truncate=50
)

Number of classes: 7806


[Stage 32:>                                                         (0 + 2) / 2]

+----------------+--------------------------------------------------+
|dominant_cluster|                             species_probabilities|
+----------------+--------------------------------------------------+
|               1|{1482309 -> 2.7289409E-5, 1356646 -> 3.14462E-5...|
|               2|{1482309 -> 2.0991134E-5, 1356646 -> 3.1725456E...|
|               0|{1482309 -> 3.854579E-5, 1356646 -> 7.4256466E-...|
+----------------+--------------------------------------------------+



                                                                                

In [18]:
# convert to pandas
pandas_df = avg_probabilities_df.select(
    "dominant_cluster",
    "species_probabilities",
).toPandas()
pandas_df.head(10)

                                                                                

Unnamed: 0,dominant_cluster,species_probabilities
0,1,"{'1482309': 2.728940853558015e-05, '1356646': ..."
1,2,"{'1482309': 2.0991134078940377e-05, '1356646':..."
2,0,"{'1482309': 3.8545789720956236e-05, '1356646':..."


### write to csv for Bayesian prior inference

In [19]:
probabilities_df = pandas_df[
    [
        "dominant_cluster",
        "species_probabilities",
    ]
]


# write pandas to PACE as CSV
def write_to_pace(df, path):
    """
    Save ``df`` as a CSV file at ``path`` and report where it was written.

    Args:
        df: pandas DataFrame to save; the index is not written.
        path: Destination file path.
    """
    df.to_csv(path_or_buf=path, index=False)
    confirmation = f"DataFrame written to {path}"
    print(confirmation)


# write the DataFrame to csv
file_name = "test_2025_embed_probabilities_clustered.csv"
output_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/clustering/{file_name}"
write_to_pace(probabilities_df, output_path)

DataFrame written to /storage/home/hcoda1/9/mgustineli3/p-dsgt_clef2025-0/shared/plantclef/data/clustering/test_2025_embed_probabilities_clustered.csv


In [21]:
import pandas as pd

# define thresholds to analyze
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]


# function to count species needed to reach each threshold
def cumsum_species_counts(prob_dict, thresholds):
    """
    Count how many of the most probable species are needed to reach each
    cumulative-probability threshold.

    Args:
        prob_dict: Mapping of species key -> probability.
        thresholds: Iterable of cumulative-probability cut-offs (e.g. 0.9).

    Returns:
        Dict mapping ``"cutoff_<percent>"`` to the smallest number of species,
        taken in descending order of probability, whose probabilities sum to
        at least the threshold. When the threshold is never reached the total
        number of species is returned.
    """
    probs = np.array(sorted(prob_dict.values(), reverse=True))
    cumsum = np.cumsum(probs)
    n_species = len(probs)
    counts = {}
    for t in thresholds:
        # round() instead of int(): 0.29 * 100 is 28.999..., which int() would
        # truncate to 28 and mislabel (or collide with) the 0.28 cut-off.
        label = f"cutoff_{int(round(float(t) * 100))}"
        # searchsorted gives the first index whose running total is >= t;
        # +1 turns the index into a species count. Cap at n_species because an
        # unreachable threshold would otherwise report one species too many.
        needed = int(np.searchsorted(cumsum, t)) + 1
        counts[label] = min(needed, n_species)
    return counts


# Apply the function across rows
cumsum_df = pandas_df["species_probabilities"].apply(
    lambda row: pd.Series(cumsum_species_counts(row, thresholds))
)

# Combine with original DataFrame if needed
result_df = pd.concat([pandas_df, cumsum_df], axis=1)

# Example: show first few rows
print(result_df.head())

# Or compute the mean number of species needed at each threshold across the dataset
mean_cutoffs = cumsum_df.mean().astype(int)
print("Mean species count per cumulative threshold:")
print(mean_cutoffs)

   dominant_cluster                              species_probabilities  \
0                 1  {'1482309': 2.728940853558015e-05, '1356646': ...   
1                 2  {'1482309': 2.0991134078940377e-05, '1356646':...   
2                 0  {'1482309': 3.8545789720956236e-05, '1356646':...   

   cutoff_10  cutoff_20  cutoff_30  cutoff_40  cutoff_50  cutoff_60  \
0          4         14         39         97        219        453   
1          2         14         43        112        245        486   
2         11         40         96        197        386        729   

   cutoff_70  cutoff_80  cutoff_90  cutoff_95  
0        917       1806       3447       4825  
1        921       1707       3164       4458  
2       1362       2484       4349       5713  
Mean species count per cumulative threshold:
cutoff_10       5
cutoff_20      22
cutoff_30      59
cutoff_40     135
cutoff_50     283
cutoff_60     556
cutoff_70    1066
cutoff_80    1999
cutoff_90    3653
cutoff_95    4998
d

In [22]:
row = pandas_df["species_probabilities"].iloc[0]
max_val = max(row.values())
max_val

0.044477567076683044

In [23]:
species_metadata = "~/p-dsgt_clef2025-0/shared/plantclef/data/species_metadata.csv"
species_df = pd.read_csv(species_metadata)


# print the top_n species and probabilities for each dominant cluster
def print_top_species_by_cluster(df, species_df, top_n=5):
    """
    Print the most probable species of every cluster as percentages.

    Args:
        df: pandas DataFrame with a ``dominant_cluster`` column and a
            ``species_probabilities`` column holding dicts of
            species id -> probability. The first row of each cluster is used.
        species_df: pandas DataFrame with ``species_id`` and ``species``
            columns, used to translate ids into names.
        top_n: Number of species printed per cluster.

    Species ids that are missing from ``species_df`` or that cannot be parsed
    as integers (e.g. "Unknown") are printed with a placeholder label instead
    of raising.
    """
    # Build the id -> name lookup once; keep the first name for duplicated ids
    unique_species = species_df.drop_duplicates(subset="species_id", keep="first")
    id_to_name = dict(zip(unique_species["species_id"], unique_species["species"]))

    for cluster in df["dominant_cluster"].unique():
        print(f"CLUSTER: {cluster}")
        in_cluster = df["dominant_cluster"] == cluster
        cluster_probs = df.loc[in_cluster, "species_probabilities"].iloc[0]
        ranked = sorted(cluster_probs.items(), key=lambda item: item[1], reverse=True)
        for species_id, prob in ranked[:top_n]:
            try:
                species_name = id_to_name.get(int(species_id))
            except (TypeError, ValueError):
                # Non-numeric key such as "Unknown"
                species_name = None
            if species_name is None:
                species_name = f"Unknown species ({species_id})"
            print(f"  {species_name}: {prob * 100:.1f}%")
        print()


print_top_species_by_cluster(result_df, species_df, top_n=10)

CLUSTER: 1
  Salix herbacea L.: 4.4%
  Geum montanum L.: 2.4%
  Festuca ovina L.: 2.3%
  Carex curvula All.: 2.0%
  Omalotheca supina (L.) DC.: 1.8%
  Festuca nigrescens Lam.: 1.7%
  Carex atrata L.: 0.9%
  Salix serpillifolia Scop.: 0.9%
  Veronica repens Clarion ex DC.: 0.7%
  Scorzoneroides helvetica (Mérat) Holub: 0.7%

CLUSTER: 2
  Festuca ovina L.: 7.7%
  Salix herbacea L.: 2.5%
  Festuca quadriflora Honck.: 1.8%
  Salix serpillifolia Scop.: 1.6%
  Tephroseris integrifolia (L.) Holub: 1.0%
  Asplenium cuneifolium Viv.: 0.8%
  Festuca nigrescens Lam.: 0.8%
  Veronica repens Clarion ex DC.: 0.7%
  Oreochloa elegans (Sennen) A.W.Hill: 0.7%
  Botrychium simplex E.Hitchc.: 0.6%

CLUSTER: 0
  Salicornia fruticosa (L.) L.: 2.0%
  Thinopyrum junceum (L.) Á.Löve: 1.5%
  Taeniatherum caput-medusae (L.) Nevski: 1.1%
  Calamagrostis arenaria (L.) Roth: 0.9%
  Medicago marina L.: 0.9%
  Lotus creticus L.: 0.7%
  Galium spurium L.: 0.7%
  Bromus madritensis L.: 0.7%
  Agrostis gigantea Roth: 0

### use prior with the test data

In [25]:
test_cluster_csv = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/clustering/test_2025_dominant_clusters.csv"
test_cluster_logits = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/clustering/test_2025_embed_probabilities_clustered.csv"

test_cluster_df = pd.read_csv(test_cluster_csv)
display(test_cluster_df.head())

probabilities_df = pd.read_csv(test_cluster_logits)
display(probabilities_df.head())

Unnamed: 0,image_name,kmeans_cluster
0,CBN-Pla-A1-20190814.jpg,2
1,CBN-Pla-D6-20190814.jpg,2
2,CBN-PdlC-C5-20140901.jpg,1
3,LISAH-BOU-0-37-20230512.jpg,0
4,CBN-Pla-E4-20130808.jpg,2


Unnamed: 0,dominant_cluster,species_probabilities
0,1,"{'1482309': 2.728940853558015e-05, '1356646': ..."
1,2,"{'1482309': 2.0991134078940377e-05, '1356646':..."
2,0,"{'1482309': 3.8545789720956236e-05, '1356646':..."


In [26]:
import ast

# The CSV stores each dict as its string repr; ast.literal_eval parses it back into a dict.
# NOTE(review): this rebinds probabilities_df from a DataFrame to a Series of dicts, so the
# dominant_cluster column is dropped and rows are tied to clusters only by position —
# confirm downstream code accounts for that.
probabilities_df = probabilities_df["species_probabilities"].apply(ast.literal_eval)
probabilities_df.head()

0    {'1482309': 2.728940853558015e-05, '1356646': ...
1    {'1482309': 2.0991134078940377e-05, '1356646':...
2    {'1482309': 3.8545789720956236e-05, '1356646':...
Name: species_probabilities, dtype: object