# Centroid predictions

Classify test data using centroid probabilities.
We're using probabilities for the entire test image.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

# create a Spark session through the project helper (requesting 4 cores) and display it
# NOTE(review): none of the cells visible below use `spark`; the data is read with
# pandas instead — confirm whether this session is still needed.
spark = get_spark(cores=4)
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/19 20:01:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/04/19 20:01:21 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
25/04/19 20:01:22 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
import os
from pathlib import Path

# resolve the current user's home directory; later cells build data paths from it
root = Path(os.path.expanduser("~"))
! date

Sat Apr 19 08:01:23 PM EDT 2025


### Faiss centroid probabilities 

In [4]:
import numpy as np
import pandas as pd
from plantclef.config import get_class_mappings_file

# path and dataset names
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data"

# locations of the centroid probabilities (parquet) and the species metadata (CSV)
centroid_path = f"{data_path}/clustering/test_2025_centroid_probabilities"
metadata_path = f"{data_path}/species_metadata.csv"

# read both datasets into pandas DataFrames and preview the first rows
# NOTE(review): species_meta_df is only displayed here; no visible cell below uses it.
cent_df = pd.read_parquet(centroid_path)
species_meta_df = pd.read_csv(metadata_path)
display(cent_df.head(5))
display(species_meta_df.head(5))

# load species_ids in the correct order: one integer ID per line of the mappings file
# (presumably position i matches index i of each probability vector — TODO confirm)
class_mappings_file = get_class_mappings_file()
with open(class_mappings_file) as f:
    sorted_species_ids = np.array([int(line.strip()) for line in f])

Unnamed: 0,image_name,probabilities
0,CBN-Pla-A1-20190814.jpg,"[0.00013234721, 0.00012035529, 0.00012586203, ..."
1,CBN-Pla-D6-20190814.jpg,"[0.00013238179, 0.00012068612, 0.00012430409, ..."
2,CBN-PdlC-C5-20140901.jpg,"[0.00013084458, 0.00012031096, 0.000125613, 0...."
3,LISAH-BOU-0-37-20230512.jpg,"[0.00011587113, 0.00012632042, 0.00013190121, ..."
4,CBN-Pla-E4-20130808.jpg,"[0.00013049232, 0.00011688124, 0.00012663532, ..."


Unnamed: 0,species_id,species,genus,family
0,1355868,Lactuca virosa L.,Lactuca,Asteraceae
1,1355869,Crepis capillaris (L.) Wallr.,Crepis,Asteraceae
2,1355870,Crepis foetida L.,Crepis,Asteraceae
3,1355871,Hypochaeris glabra L.,Hypochaeris,Asteraceae
4,1355872,Hypochaeris radicata L.,Hypochaeris,Asteraceae


In [None]:
import torch

# use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def top_k_species(probabilities, top_k: int = 9):
    """Return the species IDs of the ``top_k`` most probable classes.

    Args:
        probabilities: 1-D sequence/array of per-class probabilities; index ``i``
            is looked up in the module-level ``sorted_species_ids`` array.
        top_k: Maximum number of species to return. If it exceeds the number of
            classes, all classes are returned instead of raising.

    Returns:
        A list of ``int`` species IDs ordered from most to least probable.
    """
    probs_tensor = torch.tensor(probabilities).to(device)
    # torch.topk raises when k is larger than the dimension size, so clamp it
    k = min(top_k, probs_tensor.shape[-1])
    # only the indices are needed; the probability values themselves are discarded
    top_indices = torch.topk(probs_tensor, k=k).indices.cpu().numpy()
    return [int(sorted_species_ids[i]) for i in top_indices]


# number of species kept for every test image
top_k_proba = 9
# map each probability vector to the IDs of its top-k species
cent_df["species_ids"] = cent_df["probabilities"].apply(
    top_k_species, top_k=top_k_proba
)
cent_df.head(5)

Unnamed: 0,image_name,probabilities,species_ids
0,CBN-Pla-A1-20190814.jpg,"[0.00013234721, 0.00012035529, 0.00012586203, ...","[1610792, 1360655, 1356268, 1359837, 1391100, ..."
1,CBN-Pla-D6-20190814.jpg,"[0.00013238179, 0.00012068612, 0.00012430409, ...","[1393563, 1362192, 1743219, 1744525, 1452625, ..."
2,CBN-PdlC-C5-20140901.jpg,"[0.00013084458, 0.00012031096, 0.000125613, 0....","[1393563, 1391841, 1359837, 1743383, 1610792, ..."
3,LISAH-BOU-0-37-20230512.jpg,"[0.00011587113, 0.00012632042, 0.00013190121, ...","[1356064, 1360223, 1398244, 1357072, 1357432, ..."
4,CBN-Pla-E4-20130808.jpg,"[0.00013049232, 0.00011688124, 0.00012663532, ...","[1610792, 1360655, 1397083, 1356268, 1359690, ..."


In [6]:
# sanity check: show the full list of predicted species IDs for the first row
cent_df["species_ids"].iloc[0]

[1610792,
 1360655,
 1356268,
 1359837,
 1391100,
 1397083,
 1732737,
 1396785,
 1357003]

In [7]:
# keep only the submission columns and rename image_name to quadrat_id
preds_df = cent_df[["image_name", "species_ids"]].rename(
    columns={"image_name": "quadrat_id"}
)
preds_df.head()

Unnamed: 0,quadrat_id,species_ids
0,CBN-Pla-A1-20190814.jpg,"[1610792, 1360655, 1356268, 1359837, 1391100, ..."
1,CBN-Pla-D6-20190814.jpg,"[1393563, 1362192, 1743219, 1744525, 1452625, ..."
2,CBN-PdlC-C5-20140901.jpg,"[1393563, 1391841, 1359837, 1743383, 1610792, ..."
3,LISAH-BOU-0-37-20230512.jpg,"[1356064, 1360223, 1398244, 1357072, 1357432, ..."
4,CBN-Pla-E4-20130808.jpg,"[1610792, 1360655, 1397083, 1356268, 1359690, ..."


In [8]:
def format_species_ids(species_ids: list) -> str:
    """Render species IDs as one bracketed, comma-separated string, e.g. "[1, 2, 3]"."""
    return "[" + ", ".join(map(str, species_ids)) + "]"


def prepare_and_write_submission(
    pandas_df: pd.DataFrame, col: str = "image_name"
) -> pd.DataFrame:
    """Format predictions into the two-column submission layout.

    Despite its name, this function does not write anything to disk; it only
    builds and returns the formatted DataFrame.

    Args:
        pandas_df: DataFrame holding an identifier column ``col`` and a
            ``species_ids`` column with one iterable of species IDs per row.
        col: Name of the column whose values become ``quadrat_id``.

    Returns:
        A new DataFrame with a fresh 0..n-1 index and two columns:
        ``quadrat_id`` (identifier with a trailing ``.jpg`` removed) and
        ``species_ids`` (IDs rendered by ``format_species_ids``). An empty
        input yields an empty DataFrame with the same two columns.
    """
    submission_df = pd.DataFrame(
        {
            # strip the extension only when it is a suffix, so an id that merely
            # contains ".jpg" somewhere in the middle is left untouched
            "quadrat_id": pandas_df[col].str.replace(r"\.jpg\Z", "", regex=True),
            "species_ids": pandas_df["species_ids"].map(format_species_ids),
        }
    )
    # drop the caller's index so the result is numbered 0..n-1
    return submission_df.reset_index(drop=True)


# preview the formatted submission; write_csv_to_pace formats preds_df again
# on its own before saving, so final_df is only used for inspection here
final_df = prepare_and_write_submission(preds_df, col="quadrat_id")
final_df.head(10)

Unnamed: 0,quadrat_id,species_ids
0,CBN-Pla-A1-20190814,"[1610792, 1360655, 1356268, 1359837, 1391100, ..."
1,CBN-Pla-D6-20190814,"[1393563, 1362192, 1743219, 1744525, 1452625, ..."
2,CBN-PdlC-C5-20140901,"[1393563, 1391841, 1359837, 1743383, 1610792, ..."
3,LISAH-BOU-0-37-20230512,"[1356064, 1360223, 1398244, 1357072, 1357432, ..."
4,CBN-Pla-E4-20130808,"[1610792, 1360655, 1397083, 1356268, 1359690, ..."
5,CBN-PdlC-D6-20150701,"[1722589, 1357093, 1576091, 1397370, 1391044, ..."
6,CBN-PdlC-F2-20170906,"[1576091, 1397698, 1397083, 1395669, 1453535, ..."
7,CBN-PdlC-A6-20180905,"[1360187, 1396343, 1563707, 1389817, 1392601, ..."
8,RNNB-3-12-20230512,"[1358255, 1358426, 1395021, 1399781, 1361769, ..."
9,CBN-PdlC-F4-20150810,"[1389603, 1395194, 1628936, 1362498, 1394319, ..."


In [9]:
# sanity check: row count, shape and column names of the formatted submission
len(final_df), final_df.shape, final_df.columns

(2105, (2105, 2), Index(['quadrat_id', 'species_ids'], dtype='object'))

In [10]:
import csv


def get_plantclef_dir() -> str:
    """Return the shared PlantCLEF project directory under the user's home."""
    home = Path(os.path.expanduser("~"))
    return "/".join([str(home), "p-dsgt_clef2025-0", "shared", "plantclef"])


def write_csv_to_pace(df, file_name: str, col: str = "quadrat_id") -> None:
    """Format the predictions and save them as a fully quoted CSV on PACE.

    The file is written to ``<plantclef dir>/submissions/centroids/<file_name>``;
    missing parent directories are created first.
    """
    # build the submission frame before touching the filesystem
    formatted_df = prepare_and_write_submission(df, col)

    centroid_dir = f"{get_plantclef_dir()}/submissions/centroids"
    target_file = f"{centroid_dir}/{file_name}"

    # make sure the folder that will hold the file exists
    parent_dir = os.path.dirname(target_file)
    os.makedirs(parent_dir, exist_ok=True)

    formatted_df.to_csv(target_file, sep=",", index=False, quoting=csv.QUOTE_ALL)
    print(f"Submission file saved to: {target_file}")


# name the submission after the top-k setting, then format and save it
file_name = f"dsgt_centroid_topk{top_k_proba}.csv"
write_csv_to_pace(preds_df, file_name)

Submission file saved to: /storage/home/hcoda1/9/mgustineli3/p-dsgt_clef2025-0/shared/plantclef/submissions/centroids/dsgt_centroid_topk9.csv


In [11]:
# read the saved CSV back from disk to check what was actually written
# NOTE(review): this path repeats the directory layout built inside
# write_csv_to_pace; keep the two in sync if either changes.
submission_path = (
    f"~/p-dsgt_clef2025-0/shared/plantclef/submissions/centroids/{file_name}"
)
df = pd.read_csv(submission_path)
df.head(10)

Unnamed: 0,quadrat_id,species_ids
0,CBN-Pla-A1-20190814,"[1610792, 1360655, 1356268, 1359837, 1391100, ..."
1,CBN-Pla-D6-20190814,"[1393563, 1362192, 1743219, 1744525, 1452625, ..."
2,CBN-PdlC-C5-20140901,"[1393563, 1391841, 1359837, 1743383, 1610792, ..."
3,LISAH-BOU-0-37-20230512,"[1356064, 1360223, 1398244, 1357072, 1357432, ..."
4,CBN-Pla-E4-20130808,"[1610792, 1360655, 1397083, 1356268, 1359690, ..."
5,CBN-PdlC-D6-20150701,"[1722589, 1357093, 1576091, 1397370, 1391044, ..."
6,CBN-PdlC-F2-20170906,"[1576091, 1397698, 1397083, 1395669, 1453535, ..."
7,CBN-PdlC-A6-20180905,"[1360187, 1396343, 1563707, 1389817, 1392601, ..."
8,RNNB-3-12-20230512,"[1358255, 1358426, 1395021, 1399781, 1361769, ..."
9,CBN-PdlC-F4-20150810,"[1389603, 1395194, 1628936, 1362498, 1394319, ..."
