# Image to Embedding using DINOv2

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.utils import get_spark

spark = get_spark(cores=4, memory="8g")
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/03/31 19:29:11 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/03/31 19:29:11 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Get list of stored filed in cloud bucket
root = "gs://dsgt-clef-plantclef-2024"
! date
! gcloud storage ls {root}/data/parquet_files

Sun Mar 31 19:29:14 UTC 2024
gs://dsgt-clef-plantclef-2024/data/parquet_files/
gs://dsgt-clef-plantclef-2024/data/parquet_files/PlantCLEF2022_web_training_images_1/
gs://dsgt-clef-plantclef-2024/data/parquet_files/PlantCLEF2022_web_training_images_4/
gs://dsgt-clef-plantclef-2024/data/parquet_files/PlantCLEF2024_test/
gs://dsgt-clef-plantclef-2024/data/parquet_files/PlantCLEF2024_training/
gs://dsgt-clef-plantclef-2024/data/parquet_files/PlantCLEF2024_training_cropped_resized/
gs://dsgt-clef-plantclef-2024/data/parquet_files/PlantCLEF2024_training_cropped_resized_v2/


In [4]:
# Path and dataset names
gcs_path = "gs://dsgt-clef-plantclef-2024/data/parquet_files"
train = "PlantCLEF2024_training_cropped_resized_v2"

# Define the GCS path to the Train parquet file
train_gcs_path = f"{gcs_path}/{train}"

# Read the Parquet file into a DataFrame
train_df = spark.read.parquet(train_gcs_path)

# Show the data (for example, first few rows)
train_df.show(n=3)

                                                                                

+--------------------+--------------------+------+----------+----------+--------------------+-------+--------------------+--------+-------------+-------------+---------------+--------------------+----------+-------------+--------+-----------+--------------------+--------------------+---------+--------------------+--------------------+
|          image_name|                path| organ|species_id|    obs_id|             license|partner|              author|altitude|     latitude|    longitude|gbif_species_id|             species|     genus|       family| dataset|  publisher|          references|                 url|learn_tag|    image_backup_url|                data|
+--------------------+--------------------+------+----------+----------+--------------------+-------+--------------------+--------+-------------+-------------+---------------+--------------------+----------+-------------+--------+-----------+--------------------+--------------------+---------+--------------------+-----------

### pipeline

In [5]:
from pyspark.sql import functions as F
from pyspark.ml import Pipeline
from plantclef.transforms import WrappedDinoV2, DCTN

# Get subset of images to test pipeline
train100_df = (
    train_df.where(F.col("species_id").isin([1361703, 1355927]))
    .orderBy(F.rand(1000))
    .limit(200)
    .cache()
)

# Init DINOv2 wrapper
dino = WrappedDinoV2(input_col="data", output_col="transformed_data")

# Init Descrite Cosine Transform wrapper
dctn = DCTN(input_col="transformed_data", output_col="dctn_data")

# Create Pipeline
pipeline = Pipeline(stages=[dino, dctn])

# Fit pipeline to DF
model = pipeline.fit(train100_df)

# Apply the model to transform the DF
transformed_df = model.transform(train100_df).cache()

# Show results
transformed_df.select(["image_name", "species_id", "dctn_data"]).show(n=10)

[Stage 5:>                                                          (0 + 1) / 1]

+--------------------+----------+--------------------+
|          image_name|species_id|           dctn_data|
+--------------------+----------+--------------------+
|b4660ccafa8567718...|   1361703|[-15065.38, -446....|
|d61f7d5ba5a3554cb...|   1361703|[-20283.373, 3034...|
|24d80097e70d5f914...|   1361703|[-17341.75, -1007...|
|d4dc1b782195687c2...|   1355927|[-26710.031, -457...|
|9c51869fcb57794f9...|   1361703|[-28332.748, -205...|
|e6c09450ef071b82b...|   1355927|[-30168.426, 1062...|
|bddb8be8e7927aa58...|   1361703|[-29858.95, 31081...|
|5937bba593427705f...|   1355927|[-24212.04, 17660...|
|915978130f13a5fc2...|   1355927|[-25134.674, -847...|
|7e91e5a4c7780887e...|   1355927|[-29754.826, 5821...|
+--------------------+----------+--------------------+
only showing top 10 rows



                                                                                

In [6]:
transformed_df.groupBy("species_id").count().show()

+----------+-----+
|species_id|count|
+----------+-----+
|   1361703|  108|
|   1355927|   92|
+----------+-----+



### build a classifier

In [None]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.functions import array_to_vector

# Load training data
training = transformed_df.select(
    array_to_vector("dctn_data").alias("dctn_data"), "image_name", "species_id"
)

# Create pipeline
lr_pipe = Pipeline(
    stages=[
        LogisticRegression(
            featuresCol="dctn_data",
            labelCol="species_id",
            maxIter=10,
            regParam=0.3,
            elasticNetParam=0.8,
        ),
    ]
)

# Fit the model
lrModel = lr_pipe.fit(training)

# Print the coefficients and intercept for multinomial logistic regression
print("Coefficients: \n" + str(lrModel.coefficientMatrix))
print("Intercept: " + str(lrModel.interceptVector))

trainingSummary = lrModel.summary

# Obtain the objective per iteration
objectiveHistory = trainingSummary.objectiveHistory
print("objectiveHistory:")
for objective in objectiveHistory:
    print(objective)

# for multiclass, we can inspect metrics on a per-label basis
print("False positive rate by label:")
for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):
    print("label %d: %s" % (i, rate))

print("True positive rate by label:")
for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):
    print("label %d: %s" % (i, rate))

print("Precision by label:")
for i, prec in enumerate(trainingSummary.precisionByLabel):
    print("label %d: %s" % (i, prec))

print("Recall by label:")
for i, rec in enumerate(trainingSummary.recallByLabel):
    print("label %d: %s" % (i, rec))

print("F-measure by label:")
for i, f in enumerate(trainingSummary.fMeasureByLabel()):
    print("label %d: %s" % (i, f))

accuracy = trainingSummary.accuracy
falsePositiveRate = trainingSummary.weightedFalsePositiveRate
truePositiveRate = trainingSummary.weightedTruePositiveRate
fMeasure = trainingSummary.weightedFMeasure()
precision = trainingSummary.weightedPrecision
recall = trainingSummary.weightedRecall
print(
    "Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
    % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall)
)