# Single-label image classification

In [1]:
# Automatically reload edited Python modules (e.g. the project's plantclef
# package) before each cell executes, so library edits apply without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.spark import get_spark

# Spark session (via the project helper) used below to read the parquet data.
SPARK_CORES = 4
spark = get_spark(cores=SPARK_CORES)
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/16 13:40:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/04/16 13:40:09 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


### embeddings

In [3]:
import os
from pathlib import Path

# Resolve the current user's home directory; the shared PlantCLEF data paths
# used later in the notebook are built relative to it.
root = Path(os.path.expanduser("~"))
# Record when this run happened (IPython shell escape).
! date

Wed Apr 16 01:40:11 PM EDT 2025


### Load test images

In [None]:
# Path and dataset names
data_path = f"{root}/p-dsgt_clef2025-0/shared/plantclef/data/parquet"

# Define the path to the parquet files
test_path = f"{data_path}/test_2025"

# Read the parquet files into a spark DataFrame
test_df = spark.read.parquet(test_path)

# Show the data
test_df.printSchema()
test_df.show(n=5)

                                                                                

root
 |-- image_name: string (nullable = true)
 |-- path: string (nullable = true)
 |-- data: binary (nullable = true)



[Stage 1:>                                                          (0 + 1) / 1]

In [None]:
import timm
import torch
from plantclef.model_setup import setup_fine_tuned_model


num_classes = 7806  # total number of plant species
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the timm architecture without downloading pretrained weights, then
# load the fine-tuned weights from the checkpoint path the project helper returns.
model = timm.create_model(
    "vit_base_patch14_reg4_dinov2.lvd142m",
    pretrained=False,
    num_classes=num_classes,
    checkpoint_path=setup_fine_tuned_model(),
)

# Inference-time preprocessing derived from the model's own data config.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Move to GPU if available and switch to eval mode for inference.
# `.to()` / `.eval()` return the module itself; rebinding (rather than leaving
# `model.eval()` as the cell's last expression) avoids printing the full
# model repr as cell output.
model = model.to(device).eval()

In [None]:
import pandas as pd


# load the cluster probabilities dataframes
def get_cluster_probability_dfs(
    clustering_path="~/p-dsgt_clef2025-0/shared/plantclef/data/clustering",
):
    """Load the test-set cluster assignments and the per-cluster priors.

    Args:
        clustering_path: Directory containing ``test_2025_dominant_clusters.csv``
            and ``test_2025_embed_probabilities_clustered`` (read with
            ``pd.read_parquet``). A leading ``~`` is expanded explicitly rather
            than relying on the reader to do it.

    Returns:
        ``(cluster_df, probabilities_df)``: the CSV and parquet contents as
        pandas DataFrames. Downstream code reads ``image_name`` /
        ``kmeans_cluster`` from the first and ``dominant_cluster`` /
        ``renormalized_probabilities`` from the second.

    Raises:
        FileNotFoundError: if either input does not exist.
    """
    base_path = os.path.expanduser(clustering_path)
    test_cluster_csv = os.path.join(base_path, "test_2025_dominant_clusters.csv")
    test_cluster_probabilities = os.path.join(
        base_path, "test_2025_embed_probabilities_clustered"
    )
    cluster_df = pd.read_csv(test_cluster_csv)
    probabilities_df = pd.read_parquet(test_cluster_probabilities)
    return cluster_df, probabilities_df


def get_prior_for_image(image_name, cluster_df, probabilities_df):
    """Return the cluster prior associated with ``image_name``.

    Looks up the image's ``kmeans_cluster`` in ``cluster_df``, then returns the
    ``renormalized_probabilities`` value of the first ``probabilities_df`` row
    whose ``dominant_cluster`` equals it. The value is returned as stored;
    presumably one weight per class, since ``predict`` multiplies it
    elementwise with the class probabilities (confirm against the data).

    Args:
        image_name: Image identifier matched against ``cluster_df["image_name"]``.
        cluster_df: Table with ``image_name`` and ``kmeans_cluster`` columns.
        probabilities_df: Table with ``dominant_cluster`` and
            ``renormalized_probabilities`` columns.

    Raises:
        IndexError: if the image is not in ``cluster_df`` or its cluster has no
            row in ``probabilities_df``. This is the same exception type a bare
            ``.iloc[0]`` on an empty selection raises, but with a clear message.
    """
    cluster_ids = cluster_df.loc[cluster_df["image_name"] == image_name, "kmeans_cluster"]
    if cluster_ids.empty:
        raise IndexError(f"image {image_name!r} not found in cluster_df")
    cluster_id = cluster_ids.iloc[0]

    priors = probabilities_df.loc[
        probabilities_df["dominant_cluster"] == cluster_id, "renormalized_probabilities"
    ]
    if priors.empty:
        raise IndexError(
            f"no prior found for cluster {cluster_id!r} (image {image_name!r})"
        )
    return priors.iloc[0]


cluster_df, probabilities_df = get_cluster_probability_dfs()
# Preview both lookup tables: image -> cluster, and cluster -> prior.
display(cluster_df.head(5), probabilities_df.head(5))

In [None]:
from plantclef.config import get_class_mappings_file

# Inference switches read by predict():
#   use_grid  - classify split_into_grid() tiles instead of the whole image
#   use_prior - multiply class probabilities by the image's cluster prior
use_grid = use_prior = True


def load_class_mapping(class_mapping_file):
    """Read a class-mapping file into ``{line_index: stripped_line}``.

    Line ``k`` (0-based) of the file holds the label for model output index
    ``k``; surrounding whitespace is removed from each line.
    """
    with open(class_mapping_file) as handle:
        return dict(enumerate(raw.strip() for raw in handle))


def split_into_grid(image, grid_size=4):
    """Cut ``image`` into ``grid_size * grid_size`` equal tiles.

    Tile size is ``(width // grid_size, height // grid_size)``, so any remainder
    pixels on the right/bottom edges are not included. Tiles are ordered
    column by column: all rows of the leftmost column first, then the next column.
    """
    width, height = image.size
    tile_w = width // grid_size
    tile_h = height // grid_size
    return [
        image.crop(
            (col * tile_w, row * tile_h, col * tile_w + tile_w, row * tile_h + tile_h)
        )
        for col in range(grid_size)
        for row in range(grid_size)
    ]


# Lookup table from model output index to species id, one id per file line.
class_mapping_file = get_class_mappings_file()
cid_to_spid = load_class_mapping(
    class_mapping_file
)

In [None]:
from plantclef.serde import deserialize_image


# predict single image
def predict(input_data, image_name, top_k_proba=10, limit_logits=10):
    """Classify one serialized image, optionally tiling it and applying a cluster prior.

    Uses notebook globals: ``model``, ``transforms``, ``device``, ``cid_to_spid``,
    ``use_grid``, ``use_prior``, and (when ``use_prior``) ``cluster_df`` /
    ``probabilities_df`` loaded earlier.

    Args:
        input_data: Serialized image, decoded with ``deserialize_image``.
        image_name: Image identifier used to look up its cluster prior.
        top_k_proba: Number of top classes kept per tile (default 10).
        limit_logits: Max entries taken from each tile before merging (default 10).

    Returns:
        A list of single-entry dicts ``{species_id: score}`` gathered from every
        tile (or from the whole image when ``use_grid`` is False), sorted by
        score in descending order. Scores are softmax probabilities scaled by
        100, multiplied elementwise by the prior when ``use_prior`` is set.
    """
    img = deserialize_image(input_data)  # from bytes to PIL image
    tiles = split_into_grid(img) if use_grid else [img]

    # The prior depends only on the image, not the tile: look it up and build
    # the device tensor once. (Previously this called get_prior_for_image with
    # only one of its three required arguments, raising TypeError.)
    prior_tensor = None
    if use_prior:
        prior = get_prior_for_image(image_name, cluster_df, probabilities_df)
        prior_tensor = torch.tensor(prior).to(device)

    results = []
    for tile in tiles:
        processed_image = transforms(tile).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(processed_image)
            probabilities = torch.softmax(outputs, dim=1) * 100
            if prior_tensor is not None:
                probabilities = probabilities * prior_tensor
            top_probs, top_indices = torch.topk(probabilities, k=top_k_proba)
        # Batch dimension is 1 (unsqueeze(0) above), so take row 0.
        top_probs = top_probs.cpu().numpy()[0]
        top_indices = top_indices.cpu().numpy()[0]
        results.append(
            [
                {cid_to_spid[index]: float(prob)}
                for index, prob in zip(top_indices, top_probs)
            ]
        )

    # Merge per-tile predictions (species may repeat across tiles), best first.
    flattened_results = [
        item for tile_result in results for item in tile_result[:limit_logits]
    ]
    return sorted(flattened_results, key=lambda x: -list(x.values())[0])

In [None]:
# Select the first test image. A single `first()` fetches the whole row
# (including its binary `data`) in one Spark job, instead of one job to read the
# name plus a second filtered scan to re-fetch the same row.
image_data = test_df.first()
image_name = image_data["image_name"]
image_name