# Inference pipeline

In [1]:
# Automatically re-import edited modules (e.g. the local plantclef package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import csv
import os
import tempfile
from pathlib import Path
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import torch
from google.cloud import storage
from pyspark import SparkContext
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf

from plantclef.baseline.model import LinearClassifier
from plantclef.utils import get_spark


# Obtain the Spark session from the project helper and render its summary.
spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/04 23:23:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/05/04 23:23:29 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Bucket holding the PlantCLEF 2024 data and model artefacts.
gcs_path = "gs://dsgt-clef-plantclef-2024"

# DCT embeddings for the test set and for the cropped/resized training set.
dct_emb_train = "data/process/training_cropped_resized_v2/dino_dct/data"
test_path = f"{gcs_path}/data/process/test_v1/dino_dct/data"
dct_gcs_path = f"{gcs_path}/{dct_emb_train}"

# Trained model run to score with; its "limit-species-5" suffix should agree
# with the limit_species setting used when preparing the data.
default_root_dir = f"{gcs_path}/models/torch-petastorm-v1-limit-species-5"

# Load both parquet datasets, then preview the first rows of each.
test_df, dct_df = (spark.read.parquet(path) for path in (test_path, dct_gcs_path))
for preview_df in (test_df, dct_df):
    preview_df.show(n=5, truncate=50)

                                                                                

+------------------------+--------------------------------------------------+
|              image_name|                                     dct_embedding|
+------------------------+--------------------------------------------------+
| CBN-Pla-C3-20190723.jpg|[-20098.941, -21163.824, 10775.059, -5062.299, ...|
|   RNNB-3-9-20240117.jpg|[-30963.371, 1738.1318, -22631.05, 7658.0986, 1...|
|CBN-PdlC-C4-20160705.jpg|[-21734.926, 7483.978, 856.27124, -969.2971, 19...|
|CBN-PdlC-F6-20200722.jpg|[-25506.555, 19362.922, -22429.754, -5536.284, ...|
|CBN-PdlC-B1-20190812.jpg|[-30066.883, 16372.629, 13698.087, 13579.195, -...|
+------------------------+--------------------------------------------------+
only showing top 5 rows

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+---------

### Prepare data for inference

In [4]:
limit_species = 5
species_image_count = 100


def remap_index_to_species_id(df):
    """Attach a class ``index`` to every row whose species is kept.

    Species with fewer than ``species_image_count`` rows are dropped. If
    ``limit_species`` is not None, only a seeded random subset of that many
    of the remaining species is kept.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Must contain a ``species_id`` column.

    Returns
    -------
    pyspark.sql.DataFrame
        The rows of ``df`` whose species survived filtering, plus an ``index``
        column numbering the kept species ``0 .. k-1``.
    """
    # Rank species by image count (ties broken by id). monotonically_increasing_id
    # is unique and increasing, but NOT consecutive across partitions.
    grouped_df = (
        df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
        .orderBy(F.col("n").desc(), F.col("species_id"))
        .withColumn("index", F.monotonically_increasing_id())
    ).drop("n")

    if limit_species is None:
        # Re-rank so indexes are exactly 0..k-1 and line up with the
        # classifier's output positions. Ordering by the raw id preserves the
        # (count desc, species_id) order. An unpartitioned window runs in one
        # partition, which is fine for a per-species frame.
        contiguous_df = grouped_df.withColumn(
            "index",
            (F.row_number().over(Window.orderBy("index")) - 1).cast("long"),
        )
        return df.join(F.broadcast(contiguous_df), "species_id", "inner")

    # NOTE(review): F.rand(seed=42) is only reproducible for an identical plan
    # and partitioning. grouped_df is built exactly as before so that the same
    # species subset is drawn.
    limited_grouped_df = (
        grouped_df.orderBy(F.rand(seed=42))
        .limit(limit_species)
        .drop("index")
        # A global limit leaves the rows in a single partition, so this id is
        # 0..limit_species-1 in the random order.
        .withColumn("index", F.monotonically_increasing_id())
    )
    return df.join(F.broadcast(limited_grouped_df), "species_id", "inner")

In [6]:
# Keep only the species the model knows about and give each a class index.
filtered_df = remap_index_to_species_id(dct_df)

# Model dimensions: length of one row's embedding and number of kept species.
feature_col = "dct_embedding"
sample_row = filtered_df.select(feature_col).first()
num_features = int(len(sample_row[feature_col]))
num_classes = int(filtered_df.select("species_id").distinct().count())

                                                                                

In [7]:
def load_model_from_gcs(num_features: int, num_classes: int):
    """Download ``checkpoints/last.ckpt`` under ``default_root_dir`` and load it.

    Parameters
    ----------
    num_features : int
        Input dimensionality passed to ``LinearClassifier``.
    num_classes : int
        Number of output classes passed to ``LinearClassifier``.

    Returns
    -------
    LinearClassifier
        Model with the checkpoint's ``state_dict`` loaded non-strictly. Key
        mismatches are printed, not raised.
    """
    # Derive the bucket and object prefix from the gs:// URL instead of
    # repeating the bucket name here.
    parsed = urlparse(default_root_dir)
    bucket_name = parsed.netloc
    path_in_bucket = f"{parsed.path.strip('/')}/checkpoints/last.ckpt"

    client = storage.Client()
    blob = client.bucket(bucket_name).blob(path_in_bucket)

    # Download into a temporary directory so no checkpoint file is left in,
    # or overwritten in, the notebook's working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, "last.ckpt")
        blob.download_to_filename(local_path)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(local_path, map_location=torch.device("cpu"))

    model = LinearClassifier(num_features, num_classes)

    # strict=False tolerates key mismatches (e.g. a prefixed Lightning
    # state_dict). Missing keys leave those parameters at their initial
    # values, so any mismatch is reported loudly below.
    load_result = model.load_state_dict(checkpoint["state_dict"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print("Warning: There were missing or unexpected keys during model loading")
        print("Missing keys:", load_result.missing_keys)
        print("Unexpected keys:", load_result.unexpected_keys)

    return model

In [8]:
# Load the classifier on the driver, then broadcast it so every executor
# running the pandas UDFs gets its own read-only copy.
sc = SparkContext.getOrCreate()
model = load_model_from_gcs(num_features, num_classes)
broadcast_model = sc.broadcast(model)

In [9]:
@pandas_udf("long")  # Adjust the return type based on model's output
def predict_udf(dct_embedding_series: pd.Series) -> pd.Series:
    """Predict a class index for each embedding in one pandas batch.

    The batch's embeddings are stacked into a float32 tensor and scored by the
    broadcast model. The arg-max over dimension 1 of the output is returned as
    the predicted class index for each row.
    """
    classifier = broadcast_model.value
    classifier.eval()  # inference mode for the broadcast copy
    batch = torch.tensor(
        np.array(list(dct_embedding_series)), dtype=torch.float32
    )
    with torch.no_grad():
        scores = classifier(batch)
        labels = scores.argmax(dim=1).numpy()
    return pd.Series(labels)


# UDF using the CLS token
@pandas_udf("long")
def predict_with_cls_udf(dct_embedding_series: pd.Series) -> pd.Series:
    """Predict a class index per row, using only the CLS token when present.

    If each row holds a sequence of token embeddings, the stacked batch is
    3-D ``(batch, tokens, dim)``. The first token is then treated as the CLS
    token and is the only one fed to the classifier. Flat rows (a 2-D batch,
    as with ``dct_embedding``) are fed unchanged.
    """
    local_model = broadcast_model.value
    local_model.eval()
    if dct_embedding_series.empty:
        return pd.Series([], dtype="int64")
    # Stacking each row first turns nested token lists (object arrays of
    # arrays) into (tokens, dim) arrays; flat rows stay 1-D.
    embeddings_array = np.stack([np.stack(row) for row in dct_embedding_series])
    embeddings_tensor = torch.tensor(embeddings_array, dtype=torch.float32)
    # Select the CLS token on the input. Indexing the output with [:, 0, :]
    # fails whenever the model returns 2-D logits, which is presumably the
    # case for flat embeddings.
    if embeddings_tensor.ndim == 3:
        embeddings_tensor = embeddings_tensor[:, 0, :]
    with torch.no_grad():
        outputs = local_model(embeddings_tensor)
        predicted_classes = outputs.argmax(dim=1).numpy()
    return pd.Series(predicted_classes)

In [10]:
def prepare_dataframe_submission(result_df, filtered_df):
    """Turn per-image predictions into one submission row per plot.

    Parameters
    ----------
    result_df : pyspark.sql.DataFrame
        Needs ``image_name`` and a ``predictions`` column of class indexes.
    filtered_df : pyspark.sql.DataFrame
        Needs ``species_id`` and ``index`` (class index) columns.

    Returns
    -------
    pyspark.sql.DataFrame
        Cached frame with ``plot_id`` (image name without ``.jpg``) and
        ``species_ids`` formatted as ``"[id1, id2, ...]"`` in ascending order.
    """
    # Only the index -> species_id lookup is needed. Broadcasting the full
    # training frame would ship every embedding and repeat each prediction
    # once per training image of its species.
    index_to_species = filtered_df.select("index", "species_id").distinct()
    joined_df = result_df.join(
        F.broadcast(index_to_species),
        result_df["predictions"] == index_to_species["index"],
        "inner",
    ).select("image_name", "species_id", "index", "predictions")
    # Create columns for submission
    final_df = (
        joined_df.withColumn("plot_id", F.regexp_replace("image_name", "\\.jpg$", ""))
        .groupBy("plot_id")
        # collect_set returns ids in arbitrary order; sorting makes the
        # submission deterministic across runs.
        .agg(F.array_sort(F.collect_set("species_id")).alias("species_ids"))
    )
    # Format the id list as a single bracketed, comma-separated string.
    final_df = final_df.withColumn(
        "species_ids",
        F.concat(F.lit("["), F.concat_ws(", ", "species_ids"), F.lit("]")),
    )
    return final_df.cache()


def write_cvs_to_gcs(final_df):
    """Write the submission to ``<default_root_dir>/experiments/dsgt_run.csv``.

    The Spark frame is collected to the driver with ``toPandas`` and written
    as ``;``-separated CSV, without the pandas index and without quoting.
    NOTE(review): writing a ``gs://`` path via pandas presumably needs
    ``gcsfs`` installed — confirm.
    """
    destination = f"{default_root_dir}/experiments/dsgt_run.csv"
    submission = final_df.toPandas()
    submission.to_csv(
        destination,
        sep=";",
        index=False,
        quoting=csv.QUOTE_NONE,
    )

In [11]:
# Choose the prediction UDF: the CLS-token variant or the plain-embedding one.
use_cls_token = False
prediction_udf = predict_with_cls_udf if use_cls_token else predict_udf

# Score every test image. The result is cached because it feeds the
# submission join below.
result_df = test_df.withColumn(
    "predictions", prediction_udf(test_df[feature_col])
).cache()

# Map class indexes back to species ids and group them per plot.
final_df = prepare_dataframe_submission(
    result_df=result_df,
    filtered_df=filtered_df,
)
final_df.show(n=10, truncate=50)

# Optional: also write the CSV to GCS.
# write_cvs_to_gcs(final_df=final_df)



+--------------------+-----------+
|             plot_id|species_ids|
+--------------------+-----------+
|CBN-PdlC-A1-20130807|  [1363875]|
|CBN-PdlC-A1-20130903|  [1363875]|
|CBN-PdlC-A1-20140721|  [1363875]|
|CBN-PdlC-A1-20140811|  [1363875]|
|CBN-PdlC-A1-20140901|  [1363875]|
|CBN-PdlC-A1-20150701|  [1363875]|
|CBN-PdlC-A1-20150720|  [1363875]|
|CBN-PdlC-A1-20150831|  [1363875]|
|CBN-PdlC-A1-20160705|  [1363875]|
|CBN-PdlC-A1-20160726|  [1363875]|
+--------------------+-----------+
only showing top 10 rows



                                                                                

In [12]:
# Collect the per-plot submission frame to the driver.
final_pandas_df = final_df.toPandas()

# Write the ';'-separated submission into an "experiments" folder two levels
# above the working directory (presumably the repository root). Create the
# folder first, because to_csv does not create missing directories.
base_dir = Path.cwd().parents[1]
output_path = base_dir / "experiments" / "dsgt_run_v1.csv"
output_path.parent.mkdir(parents=True, exist_ok=True)
output_dir = str(output_path)  # keep the original variable name for later cells
final_pandas_df.to_csv(output_dir, sep=";", index=False, quoting=csv.QUOTE_NONE)

## Extract CLS Token

In [21]:
# Bucket holding the PlantCLEF 2024 data.
gcs_path = "gs://dsgt-clef-plantclef-2024"

# Training-set DINO embeddings and (presumably) their CLS-token-only version.
dino_emb_train = "data/process/training_cropped_resized_v2/dino/data"
cls_emb_path = "data/process/training_cropped_resized_v2/dino_cls_token/data"
dino_gcs_path, cls_gcs_path = (
    f"{gcs_path}/{relative}" for relative in (dino_emb_train, cls_emb_path)
)

# Load both parquet datasets, then preview the first rows of each.
dino_df, cls_df = (spark.read.parquet(path) for path in (dino_gcs_path, cls_gcs_path))
for preview_df in (dino_df, cls_df):
    preview_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+---------+
|                                  image_name|species_id|                                    dino_embedding|sample_id|
+--------------------------------------------+----------+--------------------------------------------------+---------+
|8384311a03a9cff67a54a2825dbeb4d3e8a891a3.jpg|   1397608|[0.75137013, 0.3275455, 1.6707572, 0.45285824, ...|        9|
|b38e87b2a2bcfeefcbc6adbeb4aad0437b9e1839.jpg|   1397608|[0.35812917, 1.4896353, 2.4680657, 0.7175607, 0...|        9|
|b56d8dc9553c1014cb6aecffa93c734aaa997ccf.jpg|   1363992|[1.5729619, -0.0512933, 0.5113419, -1.4510978, ...|        9|
|ff815358961d1c0dbd1a95e1ac5f9dff0e5e13fc.jpg|   1358357|[1.8277985, 0.011137103, -0.058480255, -0.32846...|        9|
|e2b977b6461d35a266c28806cb75e11930614866.jpg|   1363814|[0.37129048, -1.4583586, -3.1828911, 1.3645148,...|        9|
+--------------------------------------------+--