In [1]:
%cd /home/afischer/snakeclef-2024/
from snakeclef.utils import get_spark

# https://knowledge.informatica.com/s/article/000196886?language=en_US
# The vectorized reader will run out of memory (8gb) with the default batch size, so
# this is one way of handling the issue. This is likely due to the fact that the data
# column is so damn big, and treated as binary data instead of something like a string.
# We might also be able to avoid this if we don't cache the fields into memory, but this
# this needs to be validated by hand. 
spark = get_spark(**{
    # "spark.sql.parquet.columnarReaderBatchSize": 512,
    "spark.sql.parquet.enableVectorizedReader": False, 
})

size = 'small' # small, medium, large
gcs_parquet_path = "gs://dsgt-clef-snakeclef-2024/data/parquet_files/"
input_folder = f"SnakeCLEF2023-train-{size}_size/"

df = spark.read.parquet(gcs_parquet_path+input_folder)
df.printSchema()
# df.show(1, vertical=True, truncate=True)
df.count()

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/afischer/snakeclef-2024


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/02/21 00:00:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/02/21 00:00:16 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
                                                                                

root
 |-- image_path: string (nullable = true)
 |-- path: string (nullable = true)
 |-- folder_name: string (nullable = true)
 |-- year: string (nullable = true)
 |-- binomial_name: string (nullable = true)
 |-- file_name: string (nullable = true)
 |-- data: binary (nullable = true)
 |-- observation_id: integer (nullable = true)
 |-- endemic: boolean (nullable = true)
 |-- code: string (nullable = true)
 |-- class_id: integer (nullable = true)
 |-- subset: string (nullable = true)



                                                                                

68495

## Batch inference: DINOv2 image embeddings via `predict_batch_udf`

In [2]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import BinaryType, ArrayType, FloatType
import io
from PIL import Image
import torch
import numpy as np

def make_predict_fn(model_name: str = "facebook/dinov2-base"):
    """Build the batch-predict function passed to ``predict_batch_udf``.

    Per the ``predict_batch_udf`` contract, Spark calls this zero-argument
    factory on the Python worker to load the model once. It then applies the
    returned ``predict`` to successive batches of rows.

    Args:
        model_name: Hugging Face model id used for both the image processor and
            the backbone. Defaults to ``facebook/dinov2-base``, the model this
            notebook was written for.

    Returns:
        ``predict(inputs) -> np.ndarray``. It decodes a batch of encoded image
        bytes, runs the model, and returns one flattened float vector per image.
    """
    from transformers import AutoImageProcessor, AutoModel

    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    # from_pretrained already returns the model in eval mode. This call makes
    # that explicit, so dropout is guaranteed to be off during inference.
    model.eval()

    def predict(inputs: np.ndarray) -> np.ndarray:
        """Embed one batch of images.

        Args:
            inputs: numpy array with one encoded image (bytes) per row of the
                input column. Assumed 1-D; presumably JPEG — TODO confirm.

        Returns:
            Float array of shape ``(batch, tokens * hidden)``. This is the
            model's full ``last_hidden_state`` with its token and hidden axes
            merged. For dinov2-base at the processor's default crop, this is
            presumably 257 tokens x 768 = 197,376 values per image (verify).
            If a compact global descriptor is enough, the CLS token
            (``last_hidden_state[:, 0]``) or ``pooler_output`` is far smaller.
        """
        images = [Image.open(io.BytesIO(blob)) for blob in inputs]
        model_inputs = processor(images=images, return_tensors="pt")

        with torch.no_grad():
            hidden = model(**model_inputs).last_hidden_state

        embeddings = hidden.numpy()
        # Merge the trailing (tokens, hidden) axes so each image becomes a
        # single flat vector, matching the UDF's ArrayType(FloatType()) result.
        return embeddings.reshape(embeddings.shape[:-2] + (-1,))

    return predict
    
# Wrap the DINOv2 predict function as a Spark batch-inference UDF. Spark hands
# `predict` batches of up to 8 values from the input column and expects one
# float array back per row.
apply_dino_pbudf = predict_batch_udf(
    make_predict_fn,
    return_type=ArrayType(FloatType()),
    batch_size=8,
)

In [3]:
# Number of rows to embed. Only a small sample is processed here; set to None to
# embed every row. NOTE: the output folder written below is named after the
# whole `size` subset, so it contains only this many rows when a cap is set.
N_SAMPLES = 24

# `limit` without an explicit ordering is not guaranteed to pick the same rows
# on every run. That is fine for a smoke test, but not for a reproducible sample.
df_input = df.limit(N_SAMPLES) if N_SAMPLES is not None else df

# Apply the UDF to turn each image's bytes into an embedding vector.
df_transformed = df_input.withColumn(
    "transformed_image", apply_dino_pbudf(df_input["data"])
)

df_transformed.printSchema()

root
 |-- image_path: string (nullable = true)
 |-- path: string (nullable = true)
 |-- folder_name: string (nullable = true)
 |-- year: string (nullable = true)
 |-- binomial_name: string (nullable = true)
 |-- file_name: string (nullable = true)
 |-- data: binary (nullable = true)
 |-- observation_id: integer (nullable = true)
 |-- endemic: boolean (nullable = true)
 |-- code: string (nullable = true)
 |-- class_id: integer (nullable = true)
 |-- subset: string (nullable = true)
 |-- transformed_image: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [4]:
output_folder = f"DINOv2-embeddings-{size}_size/"

# Persist the embeddings next to the input data. Mode "overwrite" replaces
# whatever is already stored at this path.
(
    df_transformed.write
    .mode("overwrite")
    .parquet(f"{gcs_parquet_path}{output_folder}")
)

                                                                                

## Check Outputs

In [5]:
# Read the written embeddings back and sanity-check the schema and row count.
output_df = spark.read.parquet(f"{gcs_parquet_path}{output_folder}")
output_df.printSchema()
output_df.count()

root
 |-- image_path: string (nullable = true)
 |-- path: string (nullable = true)
 |-- folder_name: string (nullable = true)
 |-- year: string (nullable = true)
 |-- binomial_name: string (nullable = true)
 |-- file_name: string (nullable = true)
 |-- data: binary (nullable = true)
 |-- observation_id: integer (nullable = true)
 |-- endemic: boolean (nullable = true)
 |-- code: string (nullable = true)
 |-- class_id: integer (nullable = true)
 |-- subset: string (nullable = true)
 |-- transformed_image: array (nullable = true)
 |    |-- element: float (containsNull = true)



24

In [6]:
from pyspark.sql.functions import col

# How many written rows actually received an embedding (non-null array)?
embedded_rows = output_df.where(col("transformed_image").isNotNull())
embedded_rows.count()

                                                                                

24

In [7]:
# Peek at a single written row, one field per line.
output_df.show(n=1, vertical=True)

                                                                                

-RECORD 0---------------------------------
 image_path        | 1993/Phrynonax_po... 
 path              | /SnakeCLEF2023-sm... 
 folder_name       | SnakeCLEF2023-sma... 
 year              | 1993                 
 binomial_name     | Phrynonax_polylepis  
 file_name         | 102870166.jpg        
 data              | [FF D8 FF E0 00 1... 
 observation_id    | 64030606             
 endemic           | false                
 code              | EC                   
 class_id          | 1287                 
 subset            | train                
 transformed_image | [2.94488, -1.4730... 
only showing top 1 row

