# Grid inference pipeline

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import csv
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from plantclef.utils import get_spark
from pyspark import SparkContext
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf
from google.cloud import storage
from plantclef.baseline.model import LinearClassifier


spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/22 15:18:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/05/22 15:18:21 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [32]:
# get dataframes
gcs_path = "gs://dsgt-clef-plantclef-2024"
test_path = "data/process/test_v2/grid_dino_pretrained/data"
cls_emb_train = "data/process/training_cropped_resized_v2/dino_cls_token/data"
default_root_dir = "gs://dsgt-clef-plantclef-2024/models/torch-petastorm-v2-linear-nllloss-pretrained-cls"
# paths to dataframe
test_path = f"{gcs_path}/{test_path}"
cls_gcs_path = f"{gcs_path}/{cls_emb_train}"
# read data
test_df = spark.read.parquet(test_path)
train_df = spark.read.parquet(cls_gcs_path)
# show
test_df.show(n=5, truncate=50)
train_df.show(n=5, truncate=50)

+-------------------------------+------------+--------------------------------------------------+---------+
|                     image_name|patch_number|                                     cls_embedding|sample_id|
+-------------------------------+------------+--------------------------------------------------+---------+
|       CBN-PdlC-B6-20190909.jpg|           7|[0.567645, 1.6921916, -0.29813325, 0.15727657, ...|        6|
|       CBN-PdlC-F1-20200722.jpg|           5|[0.25926074, 2.0660655, -0.5255144, -0.58279055...|        6|
|OPTMix-0598-P4-108-20231207.jpg|           8|[-0.1416022, 0.50105137, -0.94805455, 1.6697326...|        6|
|        CBN-Pla-B6-20130904.jpg|           5|[-0.3612894, 1.0147426, -0.5758655, -1.1154238,...|        6|
|       CBN-PdlC-B5-20190722.jpg|           4|[0.57478356, 1.923769, -0.39382437, -0.6269862,...|        6|
+-------------------------------+------------+--------------------------------------------------+---------+
only showing top 5 rows

+--

### prepare data for inference

In [33]:
limit_species = None
species_image_count = 100


def remap_index_to_species_id(df):
    # Aggregate and filter species based on image count
    grouped_df = (
        df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
        .orderBy(F.col("n").desc(), F.col("species_id"))
        .withColumn("index", F.monotonically_increasing_id())
    ).drop("n")

    # Use broadcast join to optimize smaller DataFrame joining
    filtered_df = df.join(F.broadcast(grouped_df), "species_id", "inner")

    # Optionally limit the number of species
    if limit_species is not None:
        limited_grouped_df = (
            (
                grouped_df.orderBy(F.rand(seed=42))
                .limit(limit_species)
                .withColumn("new_index", F.monotonically_increasing_id())
            )
            .drop("index")
            .withColumnRenamed("new_index", "index")
        )

        filtered_df = filtered_df.drop("index").join(
            F.broadcast(limited_grouped_df), "species_id", "inner"
        )

    return filtered_df

In [35]:
# remap the indexes to species and get dataframe
filtered_df = remap_index_to_species_id(train_df)

# get parameters for the model
feature_col = "cls_embedding"
num_features = int(len(filtered_df.select(feature_col).first()[feature_col]))
num_classes = int(filtered_df.select("species_id").distinct().count())
print(f"num_features: {num_features}")
print(f"num_classes: {num_classes}")



num_features: 768
num_classes: 4797


                                                                                

In [37]:
def load_model_from_gcs(num_features: int, num_classes: int):
    bucket_name = "dsgt-clef-plantclef-2024"
    relative_path = default_root_dir.split(f"{bucket_name}/")[-1]
    path_in_bucket = f"{relative_path}/checkpoints/last.ckpt"
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(path_in_bucket)
    blob.download_to_filename("last.ckpt")
    checkpoint = torch.load("last.ckpt", map_location=torch.device("cpu"))

    # Instantiate the model first
    model = LinearClassifier(num_features, num_classes)

    # Adjust the state_dict if necessary
    state_dict = checkpoint["state_dict"]

    # Load the state dictionary
    load_result = model.load_state_dict(state_dict, strict=False)

    if load_result.missing_keys or load_result.unexpected_keys:
        print("Warning: There were missing or unexpected keys during model loading")
        print("Missing keys:", load_result.missing_keys)
        print("Unexpected keys:", load_result.unexpected_keys)

    return model

In [38]:
# Get model
model = load_model_from_gcs(
    num_features=num_features,
    num_classes=num_classes,
)
# Broadcast the model to send to all executors
sc = SparkContext.getOrCreate()
broadcast_model = sc.broadcast(model)

In [39]:
@pandas_udf("long")  # Adjust the return type based on model's output
def predict_udf(dct_embedding_series: pd.Series) -> pd.Series:
    local_model = broadcast_model.value  # Access the broadcast variable
    local_model.eval()  # Set the model to evaluation mode
    # Convert the list of numpy arrays to a single numpy array
    embeddings_array = np.array(list(dct_embedding_series))
    # Convert the numpy array to a PyTorch tensor
    embeddings_tensor = torch.tensor(embeddings_array, dtype=torch.float32)
    # Make predictions
    with torch.no_grad():
        outputs = local_model(embeddings_tensor)
        predicted_classes = outputs.argmax(
            dim=1
        ).numpy()  # Get all predicted classes at once
    return pd.Series(predicted_classes)


# UDF using the CLS token
@pandas_udf("long")
def predict_with_cls_udf(dct_embedding_series: pd.Series) -> pd.Series:
    local_model = broadcast_model.value
    local_model.eval()
    embeddings_array = np.array(list(dct_embedding_series))
    embeddings_tensor = torch.tensor(embeddings_array, dtype=torch.float32)
    with torch.no_grad():
        outputs = local_model(embeddings_tensor)
        cls_token = outputs[:, 0, :]
        predicted_classes = cls_token.argmax(dim=1).numpy()
    return pd.Series(predicted_classes)

In [41]:
# Set CLS parameter
use_cls_token = False

if use_cls_token:
    result_df = test_df.withColumn(
        "predictions", predict_with_cls_udf(test_df[feature_col])
    ).cache()
else:
    # get predictions on test_df
    result_df = test_df.withColumn(
        "predictions", predict_udf(test_df[feature_col])
    ).cache()

result_df.orderBy("image_name", "patch_number").show(n=10, truncate=50)

24/05/22 15:52:27 WARN CacheManager: Asked to cache already cached data.


+------------------------+------------+--------------------------------------------------+---------+-----------+
|              image_name|patch_number|                                     cls_embedding|sample_id|predictions|
+------------------------+------------+--------------------------------------------------+---------+-----------+
|CBN-PdlC-A1-20130807.jpg|           0|[-0.0986969, 1.3777568, 0.050541468, -0.2633565...|        4|       3515|
|CBN-PdlC-A1-20130807.jpg|           1|[0.28588018, 1.14535, -1.9985845, -0.28868702, ...|        4|       2700|
|CBN-PdlC-A1-20130807.jpg|           2|[0.009646882, 0.67857593, -1.010919, -0.3333199...|        4|       2700|
|CBN-PdlC-A1-20130807.jpg|           3|[0.53748274, 0.93665814, -0.38700807, -0.979181...|        4|       3272|
|CBN-PdlC-A1-20130807.jpg|           4|[0.5349963, 0.56628793, -0.26641238, -0.0379398...|        4|       3272|
|CBN-PdlC-A1-20130807.jpg|           5|[-0.03167449, 0.50434786, -1.0075089, 0.7306768...|      

In [42]:
def prepare_dataframe_submission(result_df, filtered_df):
    # Join both dataframes to get species_id
    joined_df = result_df.join(
        F.broadcast(filtered_df),
        result_df.predictions == filtered_df.index,
        "inner",
    ).select(result_df.image_name, "species_id", "index", "predictions")
    # Create columns for submission
    final_df = (
        joined_df.withColumn("plot_id", F.regexp_replace("image_name", "\\.jpg$", ""))
        .groupBy("plot_id")
        .agg(F.collect_set("species_id").alias("species_ids"))
    )
    # Convert the set of species_ids to a formatted string enclosed in single square brackets
    final_df = final_df.withColumn(
        "species_ids",
        F.concat(F.lit("["), F.concat_ws(", ", "species_ids"), F.lit("]")),
    )
    return final_df.cache()


def write_cvs_to_gcs(final_df):
    # Convert Spark DataFrame to Pandas DataFrame
    final_pandas_df = final_df.toPandas()

    # Export to CSV with the specified format
    output_dir = f"{default_root_dir}/experiments/dsgt_run.csv"
    final_pandas_df.to_csv(output_dir, sep=";", index=False, quoting=csv.QUOTE_NONE)

In [45]:
# Join both dataframes to get species_id
joined_df = result_df.join(
    F.broadcast(filtered_df),
    result_df.predictions == filtered_df.index,
    "inner",
).select(result_df.image_name, "patch_number", "species_id", "index", "predictions")

# show
joined_df.orderBy("image_name", "patch_number").show(n=10, truncate=50)

# # Create columns for submission
# final_df = (
#     joined_df.withColumn("plot_id", F.regexp_replace("image_name", "\\.jpg$", ""))
#     .groupBy("plot_id")
#     .agg(F.collect_set("species_id").alias("species_ids"))
# )
# # Convert the set of species_ids to a formatted string enclosed in single square brackets
# final_df = final_df.withColumn(
#     "species_ids",
#     F.concat(F.lit("["), F.concat_ws(", ", "species_ids"), F.lit("]")),
# )
# final_df.cache()

[Stage 206:=====>                                               (16 + 12) / 151]

+------------------------+------------+----------+-----+-----------+
|              image_name|patch_number|species_id|index|predictions|
+------------------------+------------+----------+-----+-----------+
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
|CBN-PdlC-A1-20130807.jpg|           0|   1394911| 3515|       3515|
+------------------------+------------+----------+-----+-----------+
only showing top 10 rows



                                                                                

In [None]:
# prepare dataframe for submission
final_df = prepare_dataframe_submission(
    result_df=result_df,
    filtered_df=filtered_df,
)
final_df.show(n=10, truncate=50)

# write CSV file to GCS
# write_cvs_to_gcs(final_df=final_df)

In [29]:
# write CSV file to GCS
write_cvs_to_gcs(final_df=final_df)