# Inference pipeline

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
from plantclef.utils import get_spark
from pyspark.sql import functions as F


spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/27 15:28:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/27 15:28:32 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# paths
gcs_path = "gs://dsgt-clef-plantclef-2024"
test_path = "data/process/test_v1/dino_dct/data"
# data module parameters
input_path = f"{gcs_path}/{test_path}"
# read data
test_df = spark.read.parquet(input_path)
# show data
test_df.show(n=5, truncate=50)

                                                                                

+------------------------+--------------------------------------------------+
|              image_name|                                     dct_embedding|
+------------------------+--------------------------------------------------+
| CBN-Pla-C3-20190723.jpg|[-20098.941, -21163.824, 10775.059, -5062.299, ...|
|   RNNB-3-9-20240117.jpg|[-30963.371, 1738.1318, -22631.05, 7658.0986, 1...|
|CBN-PdlC-C4-20160705.jpg|[-21734.926, 7483.978, 856.27124, -969.2971, 19...|
|CBN-PdlC-F6-20200722.jpg|[-25506.555, 19362.922, -22429.754, -5536.284, ...|
|CBN-PdlC-B1-20190812.jpg|[-30066.883, 16372.629, 13698.087, 13579.195, -...|
+------------------------+--------------------------------------------------+
only showing top 5 rows



In [4]:
import torch
from google.cloud import storage
from plantclef.baseline.model import LinearClassifier


def load_model_from_gcs(
    bucket_name: str, path_in_bucket: str, num_features: int, num_classes: int
):
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(path_in_bucket)
    blob.download_to_filename("last.ckpt")
    checkpoint = torch.load("last.ckpt", map_location=torch.device("cpu"))

    # Instantiate the model first
    model = LinearClassifier(num_features, num_classes)

    # Adjust the state_dict if necessary
    state_dict = checkpoint["state_dict"]
    adjusted_state_dict = {
        k.replace("layer.", "model."): v for k, v in state_dict.items()
    }

    # Load the state dictionary
    load_result = model.load_state_dict(adjusted_state_dict, strict=False)

    if load_result.missing_keys or load_result.unexpected_keys:
        print("Warning: There were missing or unexpected keys during model loading")
        print("Missing keys:", load_result.missing_keys)
        print("Unexpected keys:", load_result.unexpected_keys)

    return model

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# model path
bucket_path = gcs_path.split("gs://")[1]
model_path = "models/torch-petastorm-v1/plantclef-2024/u1p43hb2/checkpoints/last.ckpt"
# get model
model = load_model_from_gcs(bucket_path, model_path, 64, 4797)

In [6]:
# Load the checkpoint file
checkpoint = torch.load("last.ckpt", map_location=torch.device("cpu"))

# You can inspect keys to see what's inside the checkpoint
print(checkpoint.keys())

# Check if 'hyper_parameters' exists in the checkpoint
if "hyper_parameters" in checkpoint:
    hyperparameters = checkpoint["hyper_parameters"]
    print("Hyperparameters:", hyperparameters)
else:
    print("No hyperparameters found in this checkpoint.")

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
Hyperparameters: {'num_features': 64, 'num_classes': 4797}


In [7]:
import torch
import numpy as np
import pandas as pd
from pyspark import SparkContext
from pyspark.sql.functions import pandas_udf, PandasUDFType

sc = SparkContext.getOrCreate()
broadcast_model = sc.broadcast(model)

In [8]:
@pandas_udf("long")  # Adjust the return type based on your model's output
def predict_udf(dct_embedding_series: pd.Series) -> pd.Series:
    local_model = broadcast_model.value  # Access the broadcast variable
    local_model.eval()  # Set the model to evaluation mode

    # Convert the list of numpy arrays to a single numpy array
    embeddings_array = np.array(list(dct_embedding_series))

    # Convert the numpy array to a PyTorch tensor
    embeddings_tensor = torch.tensor(embeddings_array, dtype=torch.float32)

    # Make predictions
    with torch.no_grad():
        outputs = local_model(embeddings_tensor)
        predicted_classes = outputs.argmax(
            dim=1
        ).numpy()  # Get all predicted classes at once

    return pd.Series(predicted_classes)

In [9]:
# get predictions on test_df
result_df = test_df.withColumn("predictions", predict_udf(test_df["dct_embedding"]))

In [10]:
result_df.show()

[Stage 2:>                                                          (0 + 1) / 1]

+--------------------+--------------------+-----------+
|          image_name|       dct_embedding|predictions|
+--------------------+--------------------+-----------+
|CBN-Pla-C3-201907...|[-20098.941, -211...|       4217|
|RNNB-3-9-20240117...|[-30963.371, 1738...|        700|
|CBN-PdlC-C4-20160...|[-21734.926, 7483...|        111|
|CBN-PdlC-F6-20200...|[-25506.555, 1936...|        688|
|CBN-PdlC-B1-20190...|[-30066.883, 1637...|        331|
|CBN-Pla-C4-201509...|[-18970.492, -138...|       2594|
|CBN-Pla-E4-201508...|[-21910.695, -651...|       2252|
|CBN-Pla-A4-201507...|[-20680.252, -102...|        730|
|CBN-Pla-C2-201607...|[-21060.98, -7502...|       1162|
|RNNB-2-9-20230512...|[-21258.816, -157...|        747|
|CBN-PdlC-F1-20180...|[-20736.156, -569...|       2700|
|RNNB-3-7-20230512...|[-26146.121, -410...|        616|
|CBN-PdlC-E4-20140...|[-20004.02, 8810....|       2133|
|CBN-PdlC-F4-20160...|[-24047.156, 1470...|       4066|
|CBN-PdlC-B2-20150...|[-26217.541, 1342...|     

                                                                                

### prepare data for inference

In [11]:
# Path and dataset names
dct_emb_train = "data/process/training_cropped_resized_v2/dino_dct/data"
dct_gcs_path = f"{gcs_path}/{dct_emb_train}"
dct_df = spark.read.parquet(dct_gcs_path)


def remap_index_to_species_id(df, species_image_count: int = 100):
    # Aggregate and filter species based on image count
    grouped_df = (
        df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
        .orderBy(F.col("n").desc(), F.col("species_id"))
        .withColumn("index", F.monotonically_increasing_id())
    ).drop("n")

    # Use broadcast join to optimize smaller DataFrame joining
    filtered_df = df.join(F.broadcast(grouped_df), "species_id", "inner")
    return filtered_df

In [12]:
filtered_df = remap_index_to_species_id(dct_df)
filtered_df.show(n=5, truncate=50)

                                                                                

+----------+--------------------------------------------+--------------------------------------------------+-----+
|species_id|                                  image_name|                                     dct_embedding|index|
+----------+--------------------------------------------+--------------------------------------------------+-----+
|   1742956|170e88ca9af457daa1038092479b251c61c64f7d.jpg|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...| 3810|
|   1367432|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|  623|
|   1360260|2b0301802884c465aa5d835598c4982a7be23a91.jpg|[-26661.809, 14558.946, -22445.26, 8736.277, 18...|   84|
|   1378563|e60616fddca9656d3d6bf3a0433a9652ca1b343a.jpg|[-31823.447, 393.7155, 20305.72, -4179.255, 168...| 3195|
|   1358543|24c155bbe8f0e77ae54ce302fa11384b4aab1341.jpg|[-21757.35, 2965.9988, 13036.157, -7744.346, 41...|   34|
+----------+--------------------------------------------+-----------------------

                                                                                

In [13]:
# Join both dataframes to get species_id
joined_df = result_df.join(
    filtered_df, result_df.predictions == filtered_df.index, "inner"
).select(result_df.image_name, "species_id", "index", "predictions")
joined_df.show(n=5, truncate=50)

                                                                                

+------------------------+----------+-----+-----------+
|              image_name|species_id|index|predictions|
+------------------------+----------+-----+-----------+
|CBN-PdlC-F5-20160726.jpg|   1390691|  893|        893|
|CBN-PdlC-F5-20170906.jpg|   1396797|   99|         99|
|CBN-PdlC-F5-20160726.jpg|   1390691|  893|        893|
|   RNNB-4-2-20240117.jpg|   1396253| 2548|       2548|
| CBN-Pla-C3-20130808.jpg|   1396253| 2548|       2548|
+------------------------+----------+-----+-----------+
only showing top 5 rows



In [14]:
final_df = (
    joined_df.withColumn("plot_id", F.regexp_replace("image_name", "\\.jpg$", ""))
    .groupBy("plot_id")
    .agg(F.collect_set("species_id").alias("species_ids"))
)
# Convert the set of species_ids to a formatted string enclosed in single square brackets
final_df = final_df.withColumn(
    "species_ids", F.concat(F.lit("["), F.concat_ws(", ", "species_ids"), F.lit("]"))
)
final_df.show(n=5, truncate=50)



+--------------------+-----------+
|             plot_id|species_ids|
+--------------------+-----------+
|CBN-PdlC-A1-20130807|  [1356217]|
|CBN-PdlC-A1-20130903|  [1646238]|
|CBN-PdlC-A1-20140721|  [1361520]|
|CBN-PdlC-A1-20140811|  [1646238]|
|CBN-PdlC-A1-20140901|  [1361863]|
+--------------------+-----------+
only showing top 5 rows



                                                                                

In [15]:
import csv

# Convert Spark DataFrame to Pandas DataFrame
final_pandas_df = final_df.toPandas()

# Export to CSV with the specified format
base_dir = Path(os.getcwd()).parents[1]
output_dir = f"{base_dir}/experiments/dsgt_run_v1.csv"
final_pandas_df.to_csv(output_dir, sep=";", index=False, quoting=csv.QUOTE_NONE)

                                                                                

In [16]:
default_root_dir = "gs://dsgt-clef-plantclef-2024/models/torch-petastorm-v1"
output_dir = f"{default_root_dir}/experiments/dsgt_run_v1.csv"
# write to GCS
final_pandas_df.to_csv(output_dir, sep=";", index=False, quoting=csv.QUOTE_NONE)