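# Inference pipeline

Load the PlantCLEF 2024 test images from GCS, restore the trained linear classifier from its checkpoint, and prepare the data needed for inference.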

In [1]:
%load_ext autoreload
# Re-import edited modules (e.g. the project's `plantclef` package) automatically before each cell runs.
%autoreload 2

In [2]:
from pathlib import Path

from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F

from plantclef.utils import get_spark


# Obtain the Spark session from the project helper; leaving it as the cell's
# last expression renders its summary.
spark = get_spark()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/22 00:02:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/22 00:02:35 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [4]:
# Location of the PlantCLEF 2024 test set (stored as parquet) inside the project bucket
gcs_path = "gs://dsgt-clef-plantclef-2024"
test_path = "data/parquet_files/PlantCLEF2024_test"
input_path = "/".join([gcs_path, test_path])

# Load the test images and preview a few rows (long values truncated to 50 chars)
test_df = spark.read.parquet(input_path)
test_df.show(5, 50)

                                                                                

+-------------------------------------------+------------------------+--------------------------------------------------+
|                                       path|              image_name|                                              data|
+-------------------------------------------+------------------------+--------------------------------------------------+
| /PlantCLEF2024test/CBN-Pla-B4-20160728.jpg| CBN-Pla-B4-20160728.jpg|[FF D8 FF E0 00 10 4A 46 49 46 00 01 01 01 00 4...|
| /PlantCLEF2024test/CBN-Pla-D3-20130808.jpg| CBN-Pla-D3-20130808.jpg|[FF D8 FF E0 00 10 4A 46 49 46 00 01 01 01 00 4...|
|/PlantCLEF2024test/CBN-PdlC-E4-20150701.jpg|CBN-PdlC-E4-20150701.jpg|[FF D8 FF E0 00 10 4A 46 49 46 00 01 01 01 00 4...|
| /PlantCLEF2024test/CBN-Pla-F5-20150901.jpg| CBN-Pla-F5-20150901.jpg|[FF D8 FF E0 00 10 4A 46 49 46 00 01 01 01 00 4...|
| /PlantCLEF2024test/CBN-Pla-D1-20180724.jpg| CBN-Pla-D1-20180724.jpg|[FF D8 FF E0 00 10 4A 46 49 46 00 01 01 01 00 4...|
+-----------------------

In [30]:
import torch
from google.cloud import storage
from plantclef.baseline.model import LinearClassifier
from pytorch_lightning import Trainer


def load_model_from_gcs(
    bucket_name: str,
    path_in_bucket: str,
    num_features: int,
    num_classes: int,
    local_path: str = "last.ckpt",
):
    """Download a checkpoint from GCS and load its weights into a LinearClassifier.

    The checkpoint's ``state_dict`` keys have every ``"layer."`` replaced with
    ``"model."`` before loading (presumably to match the parameter names used by
    ``LinearClassifier`` — NOTE(review): confirm against the model definition).

    Args:
        bucket_name: GCS bucket name, without the ``gs://`` scheme.
        path_in_bucket: Object path of the checkpoint inside the bucket.
        num_features: Input dimensionality passed to ``LinearClassifier``.
        num_classes: Number of output classes passed to ``LinearClassifier``.
        local_path: Local file the checkpoint is downloaded to. Defaults to
            ``"last.ckpt"`` in the working directory; the file is left in place
            so it can be inspected afterwards.

    Returns:
        The ``LinearClassifier`` instance with the checkpoint weights loaded.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict by default, assuming
            ``LinearClassifier`` is a ``torch.nn.Module``) if the renamed keys do
            not match the model's parameters.
    """
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(path_in_bucket)
    blob.download_to_filename(local_path)

    # Map every tensor onto the CPU regardless of the device it was saved from.
    checkpoint = torch.load(local_path, map_location="cpu")
    adjusted_state_dict = {
        key.replace("layer.", "model."): value
        for key, value in checkpoint["state_dict"].items()
    }

    model = LinearClassifier(num_features, num_classes)
    # load_state_dict mutates `model` in place and returns a NamedTuple of
    # missing/unexpected keys, not the module — so it must not be chained.
    model.load_state_dict(adjusted_state_dict)
    return model


# model path
bucket_path = gcs_path.split("gs://")[1]
model_path = "models/torch-petastorm-v1/plantclef-2024/u1p43hb2/checkpoints/last.ckpt"
# get model
model = load_model_from_gcs(bucket_path, model_path, 64, 4797)

In [28]:
# Load the checkpoint file
checkpoint = torch.load("last.ckpt", map_location=torch.device("cpu"))

# You can inspect keys to see what's inside the checkpoint
print(checkpoint.keys())

# Check if 'hyper_parameters' exists in the checkpoint
if "hyper_parameters" in checkpoint:
    hyperparameters = checkpoint["hyper_parameters"]
    print("Hyperparameters:", hyperparameters)
else:
    print("No hyperparameters found in this checkpoint.")

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
Hyperparameters: {'num_features': 64, 'num_classes': 4797}


### Prepare data for inference

In [None]:
# Training-set embeddings stored under the `dino_dct` directory of the
# cropped/resized (v2) training data in the project bucket
dct_emb_train = "data/process/training_cropped_resized_v2/dino_dct/data"
dct_gcs_path = "/".join([gcs_path, dct_emb_train])
dct_df = spark.read.parquet(dct_gcs_path)


def remap_index_to_species_id(df, species_image_count: int = 100):
    """Keep well-represented species and attach a dense ``index`` column to each row.

    Species are ranked by row count (descending, ties broken by ascending
    ``species_id``) and numbered ``0 .. K-1`` in that order, where ``K`` is the
    number of species with at least ``species_image_count`` rows.

    Args:
        df: Spark DataFrame containing a ``species_id`` column.
        species_image_count: Minimum number of rows a species needs to be kept.

    Returns:
        ``df`` restricted to the kept species (inner join on ``species_id``),
        with an added long ``index`` column.
    """
    # monotonically_increasing_id() is unique but NOT consecutive across
    # partitions (the partition id occupies the upper bits), so it cannot be
    # used as a 0..K-1 class index. A global row_number over the per-species
    # table is consecutive; the un-partitioned window runs on a single
    # partition, which is acceptable for this one-row-per-species table.
    ranking = Window.orderBy(F.col("n").desc(), F.col("species_id"))
    species_index_df = (
        df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
        .withColumn("index", (F.row_number().over(ranking) - 1).cast("long"))
        .drop("n")
    )

    # Use broadcast join to optimize joining the smaller per-species table
    return df.join(F.broadcast(species_index_df), "species_id", "inner")