# Grid inference pipeline

In [1]:
# Enable IPython autoreload (mode 2) so edited modules are re-imported automatically.
%load_ext autoreload
%autoreload 2

In [2]:
import csv
import os

import numpy as np
import pandas as pd
import torch
from google.cloud import storage
from pyspark import SparkContext
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf

from plantclef.baseline.model import LinearClassifier
from plantclef.utils import get_spark


# Obtain the Spark session from the project helper and render it in the cell output.
spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/23 16:01:05 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/05/23 16:01:05 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Bucket root, dataset locations relative to it, and the model root directory
# (checkpoints are read from it and CSV exports are written under it).
gcs_path = "gs://dsgt-clef-plantclef-2024"
test_path = "data/process/test_v2/grid_dino_pretrained/data"
cls_emb_train = "data/process/training_cropped_resized_v2/dino_cls_token/data"
default_root_dir = "gs://dsgt-clef-plantclef-2024/models/torch-petastorm-v2-linear-nllloss-pretrained-cls"
# Build absolute gs:// URIs; note that test_path is rebound from its relative
# value above to the absolute one here.
test_path = f"{gcs_path}/{test_path}"
cls_gcs_path = f"{gcs_path}/{cls_emb_train}"
# Read the test-grid embeddings and the training CLS-token embeddings.
test_df = spark.read.parquet(test_path)
train_df = spark.read.parquet(cls_gcs_path)
# Preview the first rows of each DataFrame.
test_df.show(n=5, truncate=50)
train_df.show(n=5, truncate=50)

                                                                                

+-------------------------------+------------+--------------------------------------------------+---------+
|                     image_name|patch_number|                                     cls_embedding|sample_id|
+-------------------------------+------------+--------------------------------------------------+---------+
|       CBN-PdlC-B6-20190909.jpg|           7|[0.567645, 1.6921916, -0.29813325, 0.15727657, ...|        6|
|       CBN-PdlC-F1-20200722.jpg|           5|[0.25926074, 2.0660655, -0.5255144, -0.58279055...|        6|
|OPTMix-0598-P4-108-20231207.jpg|           8|[-0.1416022, 0.50105137, -0.94805455, 1.6697326...|        6|
|        CBN-Pla-B6-20130904.jpg|           5|[-0.3612894, 1.0147426, -0.5758655, -1.1154238,...|        6|
|       CBN-PdlC-B5-20190722.jpg|           4|[0.57478356, 1.923769, -0.39382437, -0.6269862,...|        6|
+-------------------------------+------------+--------------------------------------------------+---------+
only showing top 5 rows

+--

### prepare data for inference

In [4]:
# Read as globals by remap_index_to_species_id:
# limit_species: None keeps every eligible species; an int keeps that many randomly chosen ones.
# species_image_count: minimum number of rows a species needs to be kept.
limit_species = None
species_image_count = 100


def remap_index_to_species_id(df):
    """Keep well-represented species and attach a zero-based class ``index``.

    Reads the module-level settings ``species_image_count`` (minimum number of
    rows a species needs) and ``limit_species`` (optional cap on the number of
    species kept).

    Args:
        df: Spark DataFrame with a ``species_id`` column.

    Returns:
        The rows of ``df`` whose species passed the filter, with an extra
        ``index`` column (long). Indices are consecutive starting at 0, ordered
        by descending row count and then by ``species_id``. When
        ``limit_species`` is set, only a random subset (seed 42) of species is
        kept and re-indexed from 0 in the sampled order.
    """
    # Species with enough images.
    species_counts = (
        df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
    )

    # monotonically_increasing_id() only guarantees unique, increasing values;
    # they are not consecutive once the data spans several partitions.
    # row_number() over a global ordering always yields 1..N, so subtracting 1
    # gives class indices in [0, N). The un-partitioned window pulls all rows
    # into one partition, which is acceptable for one row per species.
    count_order = Window.orderBy(F.col("n").desc(), F.col("species_id"))
    grouped_df = species_counts.withColumn(
        "index", (F.row_number().over(count_order) - 1).cast("long")
    ).drop("n")

    # Use broadcast join to optimize smaller DataFrame joining
    filtered_df = df.join(F.broadcast(grouped_df), "species_id", "inner")

    # Optionally limit the number of species
    if limit_species is not None:
        sampled_df = grouped_df.orderBy(F.rand(seed=42)).limit(limit_species)
        # Number the sampled rows in their current order, again with
        # guaranteed consecutive values.
        sample_order = Window.orderBy(F.monotonically_increasing_id())
        limited_grouped_df = (
            sampled_df.withColumn(
                "new_index",
                (F.row_number().over(sample_order) - 1).cast("long"),
            )
            .drop("index")
            .withColumnRenamed("new_index", "index")
        )

        filtered_df = filtered_df.drop("index").join(
            F.broadcast(limited_grouped_df), "species_id", "inner"
        )

    return filtered_df

In [5]:
# Filter the training rows to eligible species and attach the class `index` column.
filtered_df = remap_index_to_species_id(train_df)

# Derive the model dimensions from the data:
# num_features is the length of the first embedding vector,
# num_classes is the number of distinct species left after filtering.
feature_col = "cls_embedding"
num_features = int(len(filtered_df.select(feature_col).first()[feature_col]))
num_classes = int(filtered_df.select("species_id").distinct().count())
print(f"num_features: {num_features}")
print(f"num_classes: {num_classes}")



num_features: 768
num_classes: 4797


                                                                                

In [6]:
def load_model_from_gcs(num_features: int, num_classes: int):
    """Download ``checkpoints/last.ckpt`` from GCS and load it into a LinearClassifier.

    The object path is derived from the module-level ``default_root_dir``. The
    checkpoint is saved as ``last.ckpt`` in the current working directory and
    deserialized onto the CPU.

    Args:
        num_features: Input size passed to ``LinearClassifier``.
        num_classes: Output size passed to ``LinearClassifier``.

    Returns:
        The ``LinearClassifier`` with the checkpoint's ``state_dict`` applied.
        Keys are loaded non-strictly; any mismatch is printed, not raised.
    """
    bucket_name = "dsgt-clef-plantclef-2024"
    # Whatever follows "<bucket>/" in the model root is the object prefix.
    object_prefix = default_root_dir.split(f"{bucket_name}/")[-1]
    checkpoint_blob = (
        storage.Client()
        .bucket(bucket_name)
        .blob(f"{object_prefix}/checkpoints/last.ckpt")
    )
    checkpoint_blob.download_to_filename("last.ckpt")
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load("last.ckpt", map_location=torch.device("cpu"))

    classifier = LinearClassifier(num_features, num_classes)
    key_report = classifier.load_state_dict(checkpoint["state_dict"], strict=False)

    missing = key_report.missing_keys
    unexpected = key_report.unexpected_keys
    if missing or unexpected:
        print("Warning: There were missing or unexpected keys during model loading")
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)

    return classifier

In [7]:
# Get model
model = load_model_from_gcs(
    num_features=num_features,
    num_classes=num_classes,
)
# Broadcast the model to send to all executors
sc = SparkContext.getOrCreate()
broadcast_model = sc.broadcast(model)

In [8]:
@pandas_udf("long")  # one predicted class index per input row
def predict_udf(embedding_series: pd.Series) -> pd.Series:
    """Return the arg-max class index for each embedding in the batch.

    Args:
        embedding_series: Series whose elements are embedding vectors.

    Returns:
        Series of class indices, taken over dimension 1 of the model output.
    """
    classifier = broadcast_model.value  # model shipped via the broadcast variable
    classifier.eval()
    # Stack the per-row vectors into a single float32 batch tensor.
    batch = torch.tensor(np.array(list(embedding_series)), dtype=torch.float32)
    with torch.no_grad():
        class_indices = classifier(batch).argmax(dim=1).numpy()
    return pd.Series(class_indices)


# UDF using the CLS token
@pandas_udf("long")
def predict_with_cls_udf(embedding_series: pd.Series) -> pd.Series:
    """Return the arg-max class index taken at position 0 of the model output.

    NOTE(review): the ``[:, 0, :]`` slice needs a 3-D model output. Whether
    LinearClassifier produces one for these inputs is not visible here —
    confirm before setting ``use_cls_token = True``.

    Args:
        embedding_series: Series whose elements are embedding arrays.

    Returns:
        Series of class indices, one per input row.
    """
    classifier = broadcast_model.value
    classifier.eval()
    batch = torch.tensor(np.array(list(embedding_series)), dtype=torch.float32)
    with torch.no_grad():
        first_position = classifier(batch)[:, 0, :]
        class_indices = first_position.argmax(dim=1).numpy()
    return pd.Series(class_indices)

In [9]:
# Set CLS parameter
use_cls_token = False

# Both UDFs take the embedding column, so pick one and apply it once.
prediction_udf = predict_with_cls_udf if use_cls_token else predict_udf

# get predictions on test_df
result_df = test_df.withColumn(
    "predictions", prediction_udf(test_df[feature_col])
).cache()

result_df.orderBy("image_name", "patch_number").show(n=10, truncate=50)

                                                                                

+------------------------+------------+--------------------------------------------------+---------+-----------+
|              image_name|patch_number|                                     cls_embedding|sample_id|predictions|
+------------------------+------------+--------------------------------------------------+---------+-----------+
|CBN-PdlC-A1-20130807.jpg|           0|[-0.0986969, 1.3777568, 0.050541468, -0.2633565...|        4|       3515|
|CBN-PdlC-A1-20130807.jpg|           1|[0.28588018, 1.14535, -1.9985845, -0.28868702, ...|        4|       2700|
|CBN-PdlC-A1-20130807.jpg|           2|[0.009646882, 0.67857593, -1.010919, -0.3333199...|        4|       2700|
|CBN-PdlC-A1-20130807.jpg|           3|[0.53748274, 0.93665814, -0.38700807, -0.979181...|        4|       3272|
|CBN-PdlC-A1-20130807.jpg|           4|[0.5349963, 0.56628793, -0.26641238, -0.0379398...|        4|       3272|
|CBN-PdlC-A1-20130807.jpg|           5|[-0.03167449, 0.50434786, -1.0075089, 0.7306768...|      

In [10]:
def prepare_grid_dataframe_submission(result_df, filtered_df):
    """Turn per-patch class predictions into one species list per plot.

    Args:
        result_df: Spark DataFrame with ``image_name``, ``patch_number`` and
            ``predictions`` (predicted class index) columns.
        filtered_df: Spark DataFrame carrying the ``species_id`` / ``index``
            pairing.

    Returns:
        Cached Spark DataFrame with ``plot_id`` (image name without the
        ``.jpg`` suffix) and ``species_ids``, the distinct predicted species of
        the plot formatted as the string ``"[id1, id2, ...]"``.
    """
    # filtered_df repeats each (species_id, index) pair on every row it has for
    # that species. Reduce it to the distinct lookup first, so the join does
    # not multiply each prediction by all of those rows before distinct().
    index_to_species = filtered_df.select("species_id", "index").distinct()

    # Join both dataframes to get species_id
    joined_df = (
        result_df.join(
            F.broadcast(index_to_species),
            result_df["predictions"] == index_to_species["index"],
            "inner",
        )
        .select("image_name", "patch_number", "species_id", "index", "predictions")
        .distinct()
    )
    # Create columns for submission
    final_df = (
        joined_df.withColumn("plot_id", F.regexp_replace("image_name", "\\.jpg$", ""))
        .groupBy("plot_id")
        .agg(F.collect_set("species_id").alias("species_ids"))
    )
    # Convert the set of species_ids to a formatted string enclosed in single square brackets
    final_df = final_df.withColumn(
        "species_ids",
        F.concat(F.lit("["), F.concat_ws(", ", "species_ids"), F.lit("]")),
    )
    return final_df.cache()


def write_cvs_to_gcs(df, file_name: str = "dsgt_run"):
    """Write ``df`` as a ``;``-separated CSV under ``default_root_dir``/experiments.

    Args:
        df: Spark DataFrame with a ``plot_id`` column; it is sorted by that
            column and collected to pandas before writing.
        file_name: Base name of the CSV file (without extension).
    """
    # Sort in Spark, then collect to the driver as a pandas DataFrame.
    submission_pdf = df.orderBy("plot_id").toPandas()

    # No quoting and no index column in the exported file.
    csv_path = f"{default_root_dir}/experiments/{file_name}.csv"
    submission_pdf.to_csv(csv_path, sep=";", index=False, quoting=csv.QUOTE_NONE)

In [11]:
# prepare dataframe for submission
# Aggregate the per-patch predictions into one species list per plot
# and preview the first rows.
final_df = prepare_grid_dataframe_submission(
    result_df=result_df,
    filtered_df=filtered_df,
)
final_df.show(n=10, truncate=50)



+--------------------+--------------------------------------------------+
|             plot_id|                                       species_ids|
+--------------------+--------------------------------------------------+
|CBN-PdlC-A1-20160705|[1392608, 1392561, 1394911, 1556096, 1392535, 1...|
|CBN-PdlC-A1-20170906|[1392608, 1392561, 1529289, 1392534, 1412857, 1...|
|CBN-PdlC-A1-20190812|[1392608, 1394438, 1395807, 1743159, 1359835, 1...|
|CBN-PdlC-A1-20200629|[1355967, 1701831, 1373840, 1738847, 1361820, 1...|
|CBN-PdlC-A2-20130903|[1397468, 1394322, 1396144, 1395807, 1391079, 1...|
|CBN-PdlC-A2-20140721|              [1395944, 1395807, 1360910, 1362471]|
|CBN-PdlC-A2-20140901|[1395944, 1390929, 1609716, 1359835, 1361585, 1...|
|CBN-PdlC-A2-20160705|[1395944, 1395807, 1394911, 1355992, 1392669, 1...|
|CBN-PdlC-A2-20160726|     [1393200, 1396144, 1395944, 1395807, 1389597]|
|CBN-PdlC-A2-20180723|[1393200, 1390674, 1395807, 1394366, 1394911, 1...|
+--------------------+----------------

                                                                                

In [12]:
# write CSV file to GCS (left disabled; uncomment to export final_df)
# write_cvs_to_gcs(final_df)

### top k probabilities

In [13]:
# Cache a 10-row subset of the test data.
# NOTE(review): limit_test_df is not referenced by the later cells visible in
# this notebook — confirm whether it is still needed.
limit_test_df = test_df.limit(10).cache()

In [14]:
import torch
import numpy as np
from pyspark.sql.types import (
    ArrayType,
    StructType,
    StructField,
    LongType,
    FloatType,
)
from pyspark.sql.functions import pandas_udf

In [54]:
# Map positions 0..N-1 to the distinct species ids in ascending species_id order.
# NOTE(review): this numbering follows ascending species_id, while the `index`
# column built by remap_index_to_species_id is ordered by descending image
# count — confirm which numbering the checkpoint's output units use.
species_id_list = filtered_df.select("species_id").distinct().collect()
species_ids = sorted(row["species_id"] for row in species_id_list)
species_id_dict = dict(enumerate(species_ids))
print(f"species id count: {len(species_id_dict)}")

species id count: 4797


In [55]:
# Hard-coded model dimensions; these overwrite the values derived from the
# data earlier in the notebook.
num_features = 768
num_classes = 4797  # total number of plant species
# Local path of the class-mapping text file read by load_class_mapping.
local_directory = "/mnt/data/models/pretrained_models"
class_mapping_file = f"{local_directory}/class_mapping.txt"
# Model
# NOTE(review): `device` is not referenced by the cells visible here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Get model
model = load_model_from_gcs(num_features=num_features, num_classes=num_classes)
model.eval()
# Broadcast the model to send to all executors
sc = SparkContext.getOrCreate()
broadcast_model = sc.broadcast(model)
# NOTE(review): `sql_statement` is not referenced by the cells visible here.
sql_statement = "SELECT image_name, patch_number, predictions FROM __THIS__"


# NOTE(review): this repeats the load_model_from_gcs definition from an
# earlier cell of this notebook; this later definition replaces it once run.
def load_model_from_gcs(num_features: int, num_classes: int):
    """Download ``checkpoints/last.ckpt`` from GCS and load it into a LinearClassifier.

    The object path is derived from the module-level ``default_root_dir``. The
    checkpoint is saved as ``last.ckpt`` in the current working directory and
    deserialized onto the CPU.

    Args:
        num_features: Input size passed to ``LinearClassifier``.
        num_classes: Output size passed to ``LinearClassifier``.

    Returns:
        The ``LinearClassifier`` with the checkpoint's ``state_dict`` applied.
        Keys are loaded non-strictly; any mismatch is printed, not raised.
    """
    bucket_name = "dsgt-clef-plantclef-2024"
    # Whatever follows "<bucket>/" in the model root is the object prefix.
    object_prefix = default_root_dir.split(f"{bucket_name}/")[-1]
    checkpoint_blob = (
        storage.Client()
        .bucket(bucket_name)
        .blob(f"{object_prefix}/checkpoints/last.ckpt")
    )
    checkpoint_blob.download_to_filename("last.ckpt")
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load("last.ckpt", map_location=torch.device("cpu"))

    classifier = LinearClassifier(num_features, num_classes)
    key_report = classifier.load_state_dict(checkpoint["state_dict"], strict=False)

    missing = key_report.missing_keys
    unexpected = key_report.unexpected_keys
    if missing or unexpected:
        print("Warning: There were missing or unexpected keys during model loading")
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)

    return classifier


def load_class_mapping():
    """Read ``class_mapping_file`` into a ``{line position: stripped line}`` dict.

    Returns:
        Dict keyed by the zero-based line position, with the line's text
        (surrounding whitespace removed) as the value.
    """
    mapping = {}
    with open(class_mapping_file) as handle:
        for position, raw_line in enumerate(handle):
            mapping[position] = raw_line.strip()
    return mapping


@pandas_udf(
    ArrayType(
        StructType(
            [
                StructField("species_id", LongType(), True),
                StructField("probability", FloatType(), True),
            ]
        )
    )
)
def make_predict_udf(embedding_series):
    """Return the top-k species and their probabilities for every embedding.

    The whole batch goes through the model in one forward pass, matching
    ``predict_udf``, instead of one model call per row. The model is read from
    ``broadcast_model``. Class positions are translated to species ids with
    the module-level ``species_id_dict``; unknown positions map to -1.

    Args:
        embedding_series: pandas Series whose elements are embedding vectors.

    Returns:
        pandas Series aligned with the input. Each element is a list of up to
        ``limit_logits`` dicts ``{"species_id": int, "probability": float}``,
        sorted by descending probability. Probabilities are softmax outputs
        scaled to percentages and rounded to 6 decimals.
    """
    limit_logits = 5
    top_k_proba = 5

    # Nothing to score; return an empty result rather than feeding the model
    # an empty tensor.
    if len(embedding_series) == 0:
        return pd.Series([], index=embedding_series.index, dtype=object)

    local_model = broadcast_model.value  # Access the broadcast variable
    local_model.eval()

    # Convert the list of numpy arrays to a single (batch, features) tensor
    embeddings_array = np.array(list(embedding_series))
    embeddings_tensor = torch.tensor(embeddings_array, dtype=torch.float32)

    # Make predictions for the whole batch at once
    with torch.no_grad():
        outputs = local_model(embeddings_tensor)
        probabilities = torch.softmax(outputs, dim=1) * 100
        top_probs, top_indices = torch.topk(probabilities, k=top_k_proba, dim=1)
    top_probs = top_probs.cpu().numpy()
    top_indices = top_indices.cpu().numpy()

    results = []
    for row_indices, row_probs in zip(top_indices, top_probs):
        row_result = [
            {
                "species_id": int(species_id_dict.get(int(index), -1)),
                "probability": float(round(prob, 6)),
            }
            for index, prob in zip(row_indices, row_probs)
        ]
        # Sort by score in descending order
        row_result.sort(key=lambda item: -item["probability"])
        results.append(row_result[:limit_logits])

    return pd.Series(results, index=embedding_series.index)

In [56]:
# get predictions on test_df
# Add a `predictions` column holding the top-k (species_id, probability)
# structs for each test row, and cache the result.
grid_result_df = test_df.withColumn(
    "predictions", make_predict_udf(test_df["cls_embedding"])
).cache()

In [57]:
# Preview the predictions ordered by image and patch.
grid_result_df.orderBy("image_name", "patch_number").show(n=10, truncate=50)

[Stage 111:>                                                      (0 + 4) / 151]



+------------------------+------------+--------------------------------------------------+---------+--------------------------------------------------+
|              image_name|patch_number|                                     cls_embedding|sample_id|                                       predictions|
+------------------------+------------+--------------------------------------------------+---------+--------------------------------------------------+
|CBN-PdlC-A1-20130807.jpg|           0|[-0.0986969, 1.3777568, 0.050541468, -0.2633565...|        4|[{1396736, 51.744686}, {1393635, 43.78267}, {17...|
|CBN-PdlC-A1-20130807.jpg|           1|[0.28588018, 1.14535, -1.9985845, -0.28868702, ...|        4|[{1392569, 99.999}, {1396736, 6.08E-4}, {139363...|
|CBN-PdlC-A1-20130807.jpg|           2|[0.009646882, 0.67857593, -1.010919, -0.3333199...|        4|[{1392569, 100.0}, {1393792, 0.0}, {1396736, 0....|
|CBN-PdlC-A1-20130807.jpg|           3|[0.53748274, 0.93665814, -0.38700807, -0.979181..

                                                                                

In [58]:
# get row from DataFrame with the predictions
example_row = grid_result_df.select("predictions").first()
# Despite the name, these are softmax percentages produced by make_predict_udf,
# not raw logits.
dino_logits = example_row["predictions"]

# To visualize the content of the first row
import pprint

pp = pprint.PrettyPrinter(indent=2)
pp.pprint(dino_logits)

[ Row(species_id=1392569, probability=100.0),
  Row(species_id=1363259, probability=0.0),
  Row(species_id=1395097, probability=0.0),
  Row(species_id=1649444, probability=0.0),
  Row(species_id=1392023, probability=0.0)]


In [59]:
import itertools

# Collect the per-row prediction lists of each image into one
# `all_predictions` array (a list of lists) per image_name.
grid_grouped_df = grid_result_df.groupBy("image_name").agg(
    F.collect_list("predictions").alias("all_predictions")
)


def flatten_and_sort(predictions):
    """Merge nested prediction lists and order them by descending probability.

    Args:
        predictions: Iterable of iterables; each inner element supports
            ``element["probability"]``.

    Returns:
        One flat list of all inner elements, highest probability first.
        Elements with equal probability keep their original relative order.
    """
    merged = [entry for patch in predictions for entry in patch]
    # Negated key on a stable sort: descending order, ties left untouched.
    merged.sort(key=lambda entry: -entry["probability"])
    return merged


# register UDF
# The struct fields must be named "species_id" and "probability": the cell
# that extracts the top-k species reads entry["species_id"] from these rows,
# so the previous "species_idf" spelling made that lookup fail.
flatten_sort_udf = F.udf(
    flatten_and_sort,
    ArrayType(
        StructType(
            [
                StructField("species_id", LongType(), True),
                StructField("probability", FloatType(), True),
            ]
        )
    ),
)

# apply UDF
grid_sorted_df = grid_grouped_df.withColumn(
    "sorted_predictions", flatten_sort_udf(F.col("all_predictions"))
)
grid_sorted_df = grid_sorted_df.select("image_name", "sorted_predictions")
grid_sorted_df.orderBy("image_name").show(n=10, truncate=50)





+------------------------+--------------------------------------------------+
|              image_name|                                sorted_predictions|
+------------------------+--------------------------------------------------+
|CBN-PdlC-A1-20130807.jpg|[{1392569, 100.0}, {1358641, 100.0}, {1392569, ...|
|CBN-PdlC-A1-20130903.jpg|[{1358641, 100.0}, {1478279, 99.999954}, {13925...|
|CBN-PdlC-A1-20140721.jpg|[{1395287, 99.999985}, {1358641, 99.99691}, {13...|
|CBN-PdlC-A1-20140811.jpg|[{1392569, 100.0}, {1392569, 99.99998}, {139528...|
|CBN-PdlC-A1-20140901.jpg|[{1358641, 100.0}, {1395287, 99.99962}, {141838...|
|CBN-PdlC-A1-20150701.jpg|[{1358641, 100.0}, {1418584, 99.999275}, {14185...|
|CBN-PdlC-A1-20150720.jpg|[{1392569, 100.0}, {1358641, 100.0}, {1395287, ...|
|CBN-PdlC-A1-20150831.jpg|[{1395287, 100.0}, {1358641, 100.0}, {1395287, ...|
|CBN-PdlC-A1-20160705.jpg|[{1395287, 100.0}, {1358641, 100.0}, {1396736, ...|
|CBN-PdlC-A1-20160726.jpg|[{1358641, 100.0}, {1395287, 99.999954

                                                                                

In [62]:
from pyspark.sql.types import ArrayType, IntegerType


# Define the function to extract top k species IDs
def get_top_k_species(predictions, k=5):
    """Return the distinct species ids found among the first ``k`` predictions.

    Args:
        predictions: Sequence of entries supporting ``entry["species_id"]``,
            already sorted best-first by the caller.
        k: Number of leading predictions to inspect.

    Returns:
        List of unique species ids. It can be shorter than ``k`` when the
        leading predictions repeat a species. The order follows set iteration.
    """
    unique_ids = set()
    for entry in predictions[:k]:
        unique_ids.add(entry["species_id"])
    return list(unique_ids)


# Define a closure that captures the value of k
def get_top_k_species_udf(k):
    """Build a Spark UDF that applies ``get_top_k_species`` with a fixed ``k``.

    Args:
        k: Number of leading predictions the UDF inspects per row.

    Returns:
        A UDF returning ``ArrayType(IntegerType())``.
    """

    def _top_k_for_row(preds):
        return get_top_k_species(preds, k)

    return F.udf(_top_k_for_row, ArrayType(IntegerType()))


# Apply the UDF to get top k species IDs
# Note: k counts the leading predictions that are inspected; duplicates are
# removed afterwards, so a plot can end up with fewer than k species.
top_k_proba = 5  # Number of top species to extract
grid_final_df = grid_sorted_df.withColumn(
    "species_ids", get_top_k_species_udf(top_k_proba)(F.col("sorted_predictions"))
)

# Derive plot_id by stripping the ".jpg" suffix from image_name and render
# species_ids as the string "[id1, id2, ...]".
grid_final_df = (
    grid_final_df.withColumn("plot_id", F.regexp_replace("image_name", "\\.jpg$", ""))
    .withColumn(
        "species_ids",
        F.concat(F.lit("["), F.concat_ws(", ", "species_ids"), F.lit("]")),
    )
    .select("plot_id", "species_ids")
)

# Show the final result
grid_final_df.show(truncate=False)



+--------------------+---------------------------------------------+
|plot_id             |species_ids                                  |
+--------------------+---------------------------------------------+
|CBN-PdlC-A1-20130807|[1392569, 1395287, 1358641]                  |
|CBN-PdlC-A1-20140721|[1358641, 1400052, 1395287]                  |
|CBN-PdlC-A1-20140901|[1396225, 1418380, 1412238, 1358641, 1395287]|
|CBN-PdlC-A1-20150720|[1392569, 1393635, 1395287, 1358641]         |
|CBN-PdlC-A1-20160705|[1396736, 1358641, 1412238, 1395287]         |
|CBN-PdlC-A1-20160726|[1411968, 1358641, 1395097, 1395287]         |
|CBN-PdlC-A1-20190812|[1360744, 1358641, 1392569, 1395287]         |
|CBN-PdlC-A2-20130903|[1718504, 1392819, 1400182, 1395287, 1392569]|
|CBN-PdlC-A2-20190909|[1392569, 1743938, 1390691, 1395287]         |
|CBN-PdlC-A3-20130903|[1358641, 1392569, 1395287]                  |
|CBN-PdlC-A3-20140721|[1392569, 1358641]                           |
|CBN-PdlC-A3-20140901|[1743938, 17

                                                                                

In [61]:
# write CSV file to GCS
# The DataFrame is passed positionally: write_cvs_to_gcs names its first
# parameter `df`, so the previous `final_df=` keyword raised a TypeError.
write_cvs_to_gcs(grid_final_df, file_name="dsgt_grid_top_5_species_run")

                                                                                

In [64]:
# Show the top-k grid result and the arg-max result side by side for the same plots.
grid_final_df.orderBy("plot_id").show(n=5, truncate=False)
final_df.orderBy("plot_id").show(n=5, truncate=False)



+--------------------+---------------------------------------------+
|plot_id             |species_ids                                  |
+--------------------+---------------------------------------------+
|CBN-PdlC-A1-20130807|[1392569, 1395287, 1358641]                  |
|CBN-PdlC-A1-20130903|[1478279, 1392654, 1358641, 1395287, 1392569]|
|CBN-PdlC-A1-20140721|[1358641, 1400052, 1395287]                  |
|CBN-PdlC-A1-20140811|[1392569, 1358641, 1395287]                  |
|CBN-PdlC-A1-20140901|[1396225, 1418380, 1412238, 1358641, 1395287]|
+--------------------+---------------------------------------------+
only showing top 5 rows

+--------------------+------------------------------------------------------------------------+
|plot_id             |species_ids                                                             |
+--------------------+------------------------------------------------------------------------+
|CBN-PdlC-A1-20130807|[1392608, 1393836, 1395807, 1394911, 1412857

                                                                                

In [25]:
# Print both schemas to check that the two submission frames have the same columns.
final_df.printSchema()
grid_final_df.printSchema()

root
 |-- plot_id: string (nullable = true)
 |-- species_ids: string (nullable = false)

root
 |-- plot_id: string (nullable = true)
 |-- species_ids: string (nullable = false)



In [35]:
# Inspect the species_id -> index pairing attached to the training rows.
filtered_df.show(n=10, truncate=50)

                                                                                

+----------+--------------------------------------------+--------------------------------------------------+-----+
|species_id|                                  image_name|                                     cls_embedding|index|
+----------+--------------------------------------------+--------------------------------------------------+-----+
|   1390691|5ddad155a99ff9b22355b940c100ee588fd73587.jpg|[1.2776474, 1.4595255, 1.7950346, 2.156743, 2.8...|  893|
|   1360260|ea74d65858ceb7d55981560684234b826a7645e2.jpg|[-1.9118708, 3.0542357, -3.143496, 1.4037858, 0...|   84|
|   1390699|86eb8b2e1a7b6a6d4bdd7b49a1e4bd090b867cf5.jpg|[0.57074785, 1.7904518, -1.0096772, 1.4558841, ...|  904|
|   1396486|4dd8c6a10041b1c028823d667316679da4c457fd.jpg|[0.85863805, -2.9571414, 1.3469055, -0.8254508,...| 1528|
|   1356608|1d5d139f575031bd13940c1ef10bf4e73dac6d41.jpg|[-0.4914153, -1.3314899, 0.33835393, -0.6760831...|   63|
|   1360185|c4f245cd59fa8751cdfa8fe83bcd757e2bd0c44b.jpg|[-0.9461619, 1.1640061,