# Pretrained DinoV2 model

In [1]:
%load_ext autoreload
%autoreload 2

Create a directory to store the models.

In [2]:
# ! mkdir -p /mnt/data/models
# %cd /mnt/data/models

/mnt/data/models


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


Copy the pretrained-model archive from GCS to the `/mnt/data/models` directory.

In [None]:
# ! gsutil cp gs://dsgt-clef-plantclef-2024/data/models/PlantNet_PlantCLEF2024_pretrained_models_on_the_flora_of_south-western_europe.tar /mnt/data/models/

Extract the `.tar` archive into `/mnt/data/models`.

In [None]:
# ! tar -xvf /mnt/data/models/PlantNet_PlantCLEF2024_pretrained_models_on_the_flora_of_south-western_europe.tar -C /mnt/data/models

Use the pretrained model to run inference on a single image.

In [None]:
from argparse import ArgumentParser
import pandas as pd
from urllib.request import urlopen
from PIL import Image
import timm
import torch


def load_class_mapping(class_list_file):
    """Read a class-list file and map each zero-based line number to its text.

    Args:
        class_list_file: Path of a text file holding one class name per line.

    Returns:
        dict: ``{line_index: stripped_line}`` for every line of the file, in
        file order (blank lines are kept as empty strings).
    """
    with open(class_list_file) as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        # Materialise inside the ``with`` block while the file is still open.
        return dict(enumerate(stripped_lines))


def load_species_mapping(species_map_file):
    """Build a ``species_id -> species`` lookup from a ``;``-separated file.

    Args:
        species_map_file: Path of a delimited file with at least the columns
            ``species_id`` and ``species``.

    Returns:
        dict: Species name keyed by species id. Ids are read as strings, so
        they are not reinterpreted as numbers.
    """
    species_by_id = pd.read_csv(
        species_map_file,
        sep=";",
        quoting=1,
        dtype={"species_id": str},
    ).set_index("species_id")["species"]
    return species_by_id.to_dict()


def main(image_url, class_mapping, species_mapping, pretrained_path):
    """Classify one image with the fine-tuned DINOv2 checkpoint and print the top species.

    Args:
        image_url: ``http(s)://`` URL or local file path of the image. When
            ``None`` there is nothing to classify and the function returns
            immediately without loading the mappings or the model.
        class_mapping: Path of the class-list file (one species id per line;
            the line index is the model's class index).
        species_mapping: Path of the ``;``-separated file with ``species_id``
            and ``species`` columns.
        pretrained_path: Path of the checkpoint passed to ``timm.create_model``
            as ``checkpoint_path``.

    Returns:
        None. For each of the (up to) five most probable classes one line
        ``species_id species probability`` is printed, with the probability
        expressed as a percentage.
    """
    # Guard first: the original tested `"https://" in image_url` before checking
    # for None, which raised TypeError instead of skipping the image.
    if image_url is None:
        return

    cid_to_spid = load_class_mapping(class_mapping)
    spid_to_sp = load_species_mapping(species_mapping)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = timm.create_model(
        "vit_base_patch14_reg4_dinov2.lvd142m",
        pretrained=False,
        num_classes=len(cid_to_spid),
        checkpoint_path=pretrained_path,
    )
    model = model.to(device)
    model = model.eval()

    # get model specific transforms (normalization, resize)
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)

    # Only a leading scheme marks a URL; a local path that merely contains
    # "http://" somewhere is still opened from disk.
    if str(image_url).startswith(("https://", "http://")):
        # convert() forces the pixel data to be read before the response closes.
        with urlopen(image_url) as response:
            img = Image.open(response).convert("RGB")
    else:
        img = Image.open(image_url).convert("RGB")

    # unsqueeze single image into batch of 1
    batch = transforms(img).unsqueeze(0).to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(batch)
        # Never request more classes than the classifier head provides.
        k = min(5, output.shape[1])
        top5_probabilities, top5_class_indices = torch.topk(
            output.softmax(dim=1) * 100, k=k
        )
    top5_probabilities = top5_probabilities.cpu().numpy()
    top5_class_indices = top5_class_indices.cpu().numpy()

    for proba, cid in zip(top5_probabilities[0], top5_class_indices[0]):
        species_id = cid_to_spid[int(cid)]
        species = spid_to_sp[species_id]
        print(species_id, species, proba)

In [None]:
# Where the extracted pretrained-model bundle lives, and which checkpoint to use
path = "/mnt/data/models/pretrained_models"
model_path = "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all"
pretrained_path = f"{path}/{model_path}/model_best.pth.tar"

# Lookup tables shipped alongside the checkpoints
class_mapping_file = f"{path}/class_mapping.txt"
species_mapping_file = f"{path}/species_id_to_name.txt"

# Sample image to classify
image_url = "https://lab.plantnet.org/LifeCLEF/PlantCLEF2024/single_plant_training_data/PlantCLEF2024singleplanttrainingdata/test/1361687/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg"

main(
    image_url=image_url,
    class_mapping=class_mapping_file,
    species_mapping=species_mapping_file,
    pretrained_path=pretrained_path,
)

In [None]:
from plantclef.utils import get_spark
from pyspark.sql import functions as F

# Obtain the Spark entry point from the project helper (presumably a
# SparkSession — it is used below via `spark.read.parquet`) and render it
# in the notebook output.
spark = get_spark()
display(spark)

In [None]:
# Location of the PlantCLEF 2024 test set (parquet) inside the project bucket
gcs_path = "gs://dsgt-clef-plantclef-2024"
test_data_path = "data/parquet_files/PlantCLEF2024_test"
test_path = f"{gcs_path}/{test_data_path}"

# Load the test set and preview five rows, truncating long cell values
test_df = spark.read.parquet(test_path)
test_df.show(5, truncate=50)

In [None]:
import io

import numpy as np
import timm
import torch
from PIL import Image
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType, MapType, StringType
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from plantclef.model_setup import setup_pretrained_model
from pyspark.sql import DataFrame
from pyspark.ml.functions import vector_to_array
from pyspark.ml import Pipeline
from pyspark.ml.feature import SQLTransformer

In [None]:
class PretrainedDinoV2(
    Transformer,
    HasInputCol,
    HasOutputCol,
    DefaultParamsReadable,
    DefaultParamsWritable,
):
    """Spark ML transformer that classifies images with a fine-tuned DINOv2 model.

    The input column must hold encoded image bytes. For every row the output
    column receives an array of single-entry maps ``{species_id: probability}``,
    ordered from most to least probable, with probabilities expressed as
    percentages. Rows that cannot be decoded or scored get an empty array.
    """

    def __init__(
        self,
        pretrained_path: str,
        input_col: str = "input",
        output_col: str = "output",
        model_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
        batch_size: int = 8,
        top_k: int = 20,
    ):
        """Load the checkpoint and prepare the evaluation transforms.

        Args:
            pretrained_path: Checkpoint file handed to ``timm.create_model``
                as ``checkpoint_path``.
            input_col: Name of the column containing the image bytes.
            output_col: Name of the column that receives the predictions.
            model_name: timm architecture name.
            batch_size: Stored on the instance; the row-wise UDF does not
                currently use it.
            top_k: Number of highest-probability classes kept per image
                (clamped to the number of classes).

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        super().__init__()
        self._setDefault(inputCol=input_col, outputCol=output_col)
        self.model_name = model_name
        self.batch_size = batch_size
        self.top_k = top_k
        self.pretrained_path = pretrained_path
        self.num_classes = 7806  # total number of plant species
        self.local_directory = "/mnt/data/models/pretrained_models"
        self.class_mapping_file = f"{self.local_directory}/class_mapping.txt"
        # Model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = timm.create_model(
            self.model_name,
            pretrained=False,
            num_classes=self.num_classes,
            checkpoint_path=self.pretrained_path,
        )
        self.model.to(self.device)
        self.model.eval()
        # Data transform
        self.data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(
            **self.data_config, is_training=False
        )
        # Select the configured output column rather than a hard-coded name, so
        # the pipeline also works when output_col is not "dino_logits".
        self.sql_statement = f"SELECT image_name, {output_col} FROM __THIS__"

    def _load_class_mapping(self):
        """Return ``{class_index: species_id}`` read from ``class_mapping_file``."""
        with open(self.class_mapping_file) as f:
            return dict(enumerate(line.strip() for line in f))

    def _make_predict_fn(self):
        """Return a row-wise prediction function suitable for ``F.udf``.

        The returned closure captures only the objects it needs (model,
        transforms, device, class mapping) rather than ``self``, so the whole
        Transformer does not have to be serialised with the UDF.
        """
        cid_to_spid = self._load_class_mapping()
        self.cid_to_spid = cid_to_spid  # kept as an attribute, as before
        model = self.model
        transforms = self.transforms
        device = self.device
        top_k = min(self.top_k, self.num_classes)

        def predict(image_bytes) -> list:
            """Score one encoded image; return ``[{species_id: prob}, ...]``."""
            # F.udf calls this once per row with a single value. Spark may
            # deliver binary columns as bytearray, so accept any bytes-like type.
            if not isinstance(image_bytes, (bytes, bytearray, memoryview)):
                print("Error: Input is not bytes.")
                return []

            try:
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                batch_input = transforms(image).unsqueeze(0).to(device)

                with torch.no_grad():
                    outputs = model(batch_input)
                    probabilities = torch.softmax(outputs, dim=1) * 100
                    top_probs, top_indices = torch.topk(probabilities, k=top_k)

                # One single-entry map per class keeps the ranking order.
                return [
                    {cid_to_spid.get(int(index), "Unknown"): float(prob)}
                    for index, prob in zip(
                        top_indices[0].tolist(), top_probs[0].tolist()
                    )
                ]
            except Exception as e:
                # Best effort: a single unreadable image must not fail the job.
                print(f"Failed to process input due to: {str(e)}")
                return []

        return predict

    def _transform(self, df):
        """Append the prediction column computed by the row-wise UDF."""
        predict_udf = F.udf(
            self._make_predict_fn(), ArrayType(MapType(StringType(), FloatType()))
        )
        return df.withColumn(self.getOutputCol(), predict_udf(df[self.getInputCol()]))

    def transform(self, df, params=None) -> DataFrame:
        """Transform ``df`` and make sure every feature column is an array.

        Args:
            df: Input DataFrame containing the input column.
            params: Optional param map overriding this instance's params, as
                accepted by ``pyspark.ml.Transformer.transform``.

        Returns:
            DataFrame with the output column added.
        """
        if params:
            # Same contract as the base class: apply overrides on a copy.
            return self.copy(params).transform(df)

        transformed = self._transform(df)

        for c in self.feature_columns:
            # check if the feature is a vector and convert it to an array
            if "array" in transformed.schema[c].simpleString():
                continue
            transformed = transformed.withColumn(c, vector_to_array(F.col(c)))
        return transformed

    @property
    def feature_columns(self) -> list:
        """Columns produced by this transformer (the configured output column)."""
        return [self.getOutputCol()]

    def pipeline(self):
        """Return a Pipeline of this transformer followed by the column projection."""
        return Pipeline(stages=[self, SQLTransformer(statement=self.sql_statement)])

    def run(self, df: DataFrame) -> DataFrame:
        """Fit the pipeline on ``df`` and return the transformed DataFrame."""
        model = self.pipeline().fit(df)
        transformed = model.transform(df)

        return transformed

In [None]:
# Resolve the checkpoint via the project helper (presumably returns the local
# path of the pretrained weights — its result is passed on as `pretrained_path`),
# then build the transformer that reads the `data` column and writes its
# predictions to `dino_logits`.
pretrained_path = setup_pretrained_model()
pretrained_dino = PretrainedDinoV2(
    pretrained_path=pretrained_path,
    input_col="data",
    output_col="dino_logits",
)

In [None]:
# Run the transformer pipeline (model stage + SQL projection) over the test set
transformed_df = pretrained_dino.run(df=test_df)