# Non-Maximum Suppression Classification


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

# Create the Spark session that the later cells use to read and write parquet.
# NOTE(review): `cores=4` presumably limits local parallelism — confirm in plantclef.spark.get_spark.
spark = get_spark(cores=4)
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/04 18:53:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/05/04 18:53:28 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
import os
from pathlib import Path

# Resolve the user's home directory; the data paths below are built from it
root = Path(os.path.expanduser("~"))
! date

Sun May  4 06:53:31 PM EDT 2025


### NMS detections

In [4]:
# Path and dataset names
# Both paths are built from `root` (the user's home directory) defined in an earlier cell.
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data"
detect_path = f"{data_path}/detection/test_2025/test_2025_detection_v1"
# Read the stored detection output and print its schema.
detection_df = spark.read.parquet(detect_path)
detection_df.printSchema()

root
 |-- image_name: string (nullable = true)
 |-- output: struct (nullable = true)
 |    |-- extracted_bbox: array (nullable = true)
 |    |    |-- element: binary (containsNull = true)
 |    |-- boxes: array (nullable = true)
 |    |    |-- element: array (containsNull = true)
 |    |    |    |-- element: integer (containsNull = true)
 |    |-- scores: array (nullable = true)
 |    |    |-- element: float (containsNull = true)
 |    |-- text_labels: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- sample_id: integer (nullable = true)



In [5]:
detection_df.count()

                                                                                

2105

In [6]:
from pyspark.sql import functions as F

# explode the extracted_bbox list so each row has one bounding box
exploded_df = (
    detection_df.select(
        "image_name", F.explode("output.extracted_bbox").alias("extracted_bbox")
    )
    # redistribute into 100 partitions keyed on image_name, so rows of one image share a partition
    .repartition(100, "image_name")
    # cache the result: later cells reuse it (the 20-row sample and the batched parquet write)
    .persist()
)
# count() is an action, so it also materializes the persisted dataframe
display(exploded_df.count())
exploded_df.printSchema()

                                                                                

17053

root
 |-- image_name: string (nullable = true)
 |-- extracted_bbox: binary (nullable = true)



In [7]:
import timm
import torch
from plantclef.serde import deserialize_image
from plantclef.config import get_class_mappings_file
from plantclef.model_setup import setup_fine_tuned_model


def load_class_mapping(class_mapping_file=None):
    """Load the class-index -> class-name mapping from a text file.

    The file is read line by line: the zero-based line number becomes the
    class index and the whitespace-stripped line content becomes the class
    name. Blank lines are kept (as empty names) so indices stay aligned
    with line numbers.

    Parameters
    ----------
    class_mapping_file : str or path-like, optional
        Path to the mapping file. When omitted, the path returned by
        ``get_class_mappings_file()`` is used.

    Returns
    -------
    dict[int, str]
        Mapping from class index to class name.
    """
    if class_mapping_file is None:
        # The declared default used to be passed straight to open(), which
        # fails on None; fall back to the project's default mapping file.
        class_mapping_file = get_class_mappings_file()
    # Explicit encoding keeps the result independent of the platform locale.
    with open(class_mapping_file, encoding="utf-8") as f:
        return {i: line.strip() for i, line in enumerate(f)}


num_classes = 7806  # total number of plant species
# run on the GPU when one is visible, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pretrained=False: weights are loaded from the checkpoint path returned by
# setup_fine_tuned_model() rather than from timm's pretrained weights
model = timm.create_model(
    "vit_base_patch14_reg4_dinov2.lvd142m",
    pretrained=False,
    num_classes=num_classes,
    checkpoint_path=setup_fine_tuned_model(),
)
# data transform
# derive the preprocessing config from the model and build the evaluation (non-training) transform
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
# move model to GPU if available
model.to(device)
# switch to evaluation mode (disables training-only behaviour such as dropout)
model.eval()
# path for class_mappings.txt file
class_mapping_file = get_class_mappings_file()


# load class mappings
# cid_to_spid maps the model's class index (line number in the file) to the class name on that line
cid_to_spid = load_class_mapping(class_mapping_file)


def make_predict_fn():
    """Build a prediction callable that closes over the notebook's model.

    The returned function takes one serialized image, runs it through the
    module-level ``transforms`` and ``model`` on ``device``, and returns the
    softmax probabilities of the single-image batch as a plain Python list.
    """

    def predict(input_data):
        # bytes -> PIL image
        pil_image = deserialize_image(input_data)
        # preprocess, add the batch dimension, and move to the model's device
        batch = transforms(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            class_probs = torch.softmax(model(batch), dim=1)

        # drop the batch dimension and convert to a list of floats
        return class_probs[0].cpu().numpy().tolist()

    return predict

In [8]:
# build the prediction function (a plain Python callable applied with pandas below,
# not a registered Spark UDF)
predict_fn = make_predict_fn()

# get subset of data for testing
# NOTE: limit(20) has no ordering, so which 20 rows come back is not guaranteed
subset_pd = exploded_df.limit(20).toPandas()
# run the model once per row; each cell holds that row's full probability list
subset_pd["probabilities"] = subset_pd["extracted_bbox"].apply(predict_fn)

                                                                                

In [9]:
display(subset_pd.head(5))

Unnamed: 0,image_name,extracted_bbox,probabilities
0,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[2.6421303118695505e-06, 3.864691393573594e-07..."
1,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[4.0289813796334784e-07, 3.7558137933046964e-0..."
2,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[4.55143845101702e-06, 5.1167003221053164e-06,..."
3,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[6.224210210348247e-06, 1.1222463399462868e-05..."
4,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[3.508395707285672e-07, 6.856546406197594e-07,..."


In [10]:
subset_pd["probabilities"].iloc[0][:10]

[2.6421303118695505e-06,
 3.864691393573594e-07,
 2.1716316496167565e-06,
 1.0423846106277779e-06,
 2.137569936166983e-06,
 5.375565160647966e-05,
 0.00014674248814117163,
 4.8739006160758436e-05,
 0.0016026853118091822,
 0.001182828564196825]

In [11]:
import numpy as np


# get top-K predictions for each row
def get_top_n_predictions(probabilities: list, n=5, class_mapping=None):
    """Return the ``n`` most probable classes as ``(class_name, probability)`` pairs.

    Parameters
    ----------
    probabilities : list
        Per-class probabilities, indexed by class index.
    n : int, default 5
        Number of top entries to return. Values ``<= 0`` yield an empty list;
        values larger than ``len(probabilities)`` return every class.
    class_mapping : mapping, optional
        Class index -> class name lookup. Defaults to the notebook-level
        ``cid_to_spid`` mapping.

    Returns
    -------
    list[tuple]
        Pairs ordered from most to least probable.
    """
    if n <= 0:
        # `arr[-0:]` selects the whole array, so n=0 used to return every
        # class instead of none.
        return []
    mapping = cid_to_spid if class_mapping is None else class_mapping
    proba_arr = np.asarray(probabilities)
    # ascending argsort -> keep the last n -> reverse to descending order
    top_n_indices = proba_arr.argsort()[-n:][::-1]
    # int() turns numpy integers into plain ints for the dict and list lookups
    return [(mapping[int(i)], probabilities[int(i)]) for i in top_n_indices]


# number of highest-probability classes to keep per bounding box
top_k = 10
# adds a `top_{top_k}_predictions` column holding (class name, probability) pairs per row
subset_pd[f"top_{top_k}_predictions"] = subset_pd["probabilities"].apply(
    lambda proba: get_top_n_predictions(proba, n=top_k)
)

In [12]:
subset_pd[f"top_{top_k}_predictions"].iloc[0]

[('1622901', 0.06579609960317612),
 ('1396439', 0.06303632259368896),
 ('1396408', 0.028585851192474365),
 ('1398779', 0.02830541878938675),
 ('1418211', 0.024407442659139633),
 ('1647128', 0.016660762950778008),
 ('1647677', 0.014911099337041378),
 ('1412857', 0.013472514227032661),
 ('1647150', 0.012758580036461353),
 ('1722440', 0.011723535135388374)]

In [13]:
subset_pd.head(5)

Unnamed: 0,image_name,extracted_bbox,probabilities,top_10_predictions
0,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[2.6421303118695505e-06, 3.864691393573594e-07...","[(1622901, 0.06579609960317612), (1396439, 0.0..."
1,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[4.0289813796334784e-07, 3.7558137933046964e-0...","[(1396408, 0.7895054817199707), (1396362, 0.01..."
2,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[4.55143845101702e-06, 5.1167003221053164e-06,...","[(1425722, 0.048355188220739365), (1396408, 0...."
3,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[6.224210210348247e-06, 1.1222463399462868e-05...","[(1396408, 0.07225219160318375), (1425722, 0.0..."
4,CBN-PdlC-D3-20200722.jpg,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,"[3.508395707285672e-07, 6.856546406197594e-07,...","[(1396408, 0.26717033982276917), (1399082, 0.0..."


### scalable pipeline

In [14]:
# write the exploded dataframe to parquet as 200 partitions
# (each part file is later processed as one inference batch)
output_dir = (
    f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/detection/batched_extracted_bbox"
)
# mode("overwrite") replaces any previous contents of output_dir
exploded_df.repartition(200).write.mode("overwrite").parquet(output_dir)

                                                                                

In [19]:
import os
from tqdm import tqdm


def run_batch_inference(spark, input_path, output_path):
    """Run the classifier over one parquet batch and store the probabilities.

    Parameters
    ----------
    spark : SparkSession
        Session used to read the input parquet file.
    input_path : str
        Parquet file with ``image_name`` and ``extracted_bbox`` columns.
    output_path : str
        Destination parquet file; receives ``image_name`` and
        ``probabilities`` columns.

    Notes
    -----
    The result is written to ``<output_path>.tmp`` first and then moved into
    place with ``os.replace``. The resume logic in
    ``run_inference_on_all_batches`` treats an existing output file as
    "already processed", so an interrupted run must not leave a partial file
    at ``output_path``.
    """
    df = spark.read.parquet(input_path).toPandas()
    predict_fn = make_predict_fn()
    df["probabilities"] = df["extracted_bbox"].apply(predict_fn)

    tmp_path = f"{output_path}.tmp"
    try:
        df[["image_name", "probabilities"]].to_parquet(tmp_path, index=False)
        # os.replace renames the file in one step, so output_path is never partial
        os.replace(tmp_path, output_path)
    finally:
        # only still present when the write or the rename failed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def run_inference_on_all_batches(spark, input_dir, output_dir):
    """Run batch inference for every parquet file found in ``input_dir``.

    Files are processed in sorted name order. For each ``<name>.parquet`` the
    result is stored in ``output_dir`` as ``<name>_out.parquet``; batches
    whose output file already exists are skipped, so the run can be resumed.
    """
    os.makedirs(output_dir, exist_ok=True)

    batch_names = [
        entry for entry in os.listdir(input_dir) if entry.endswith(".parquet")
    ]
    batch_names.sort()

    for batch_name in tqdm(batch_names, desc="Running inference on batches"):
        source_file = os.path.join(input_dir, batch_name)
        result_name = batch_name.replace(".parquet", "_out.parquet")
        result_file = os.path.join(output_dir, result_name)

        # only process batches that have no stored result yet
        if not os.path.exists(result_file):
            run_batch_inference(spark, source_file, result_file)

In [20]:
# directory holding the batched parquet files of extracted bounding boxes
input_dir = (
    f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/detection/batched_extracted_bbox"
)
# NOTE: rebinds `output_dir`, which an earlier cell used for the batched_extracted_bbox path
output_dir = (
    f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/detection/inference_outputs"
)

# one output parquet per input batch; batches with an existing output file are skipped
run_inference_on_all_batches(spark, input_dir, output_dir)

Running inference on batches: 100%|██████████| 200/200 [30:44<00:00,  9.22s/it] 
