In [None]:
%cd /home/afischer/snakeclef-2024/
from snakeclef.utils import get_spark

# https://knowledge.informatica.com/s/article/000196886?language=en_US
# The vectorized reader will run out of memory (8 GB) with the default batch size, so
# disabling it is one way of handling the issue. This is likely because the data
# column is very large and is treated as binary data instead of something like a string.
# We might also be able to avoid this if we don't cache the fields into memory, but
# this needs to be validated by hand.
# Start Spark with the vectorized parquet reader switched off. The alternative of
# shrinking the columnar batch size is kept below, commented out.
spark = get_spark(
    **{
        # "spark.sql.parquet.columnarReaderBatchSize": 512,
        "spark.sql.parquet.enableVectorizedReader": False,
    }
)

# Dataset variant to load: small, medium, or large
size = 'small'
gcs_parquet_path = "gs://dsgt-clef-snakeclef-2024/data/parquet_files/"
input_folder = f"SnakeCLEF2023-train-{size}_size/"

# Read the training parquet files, then inspect the schema, the first row and the row count
df = spark.read.parquet(f"{gcs_parquet_path}{input_folder}")
df.printSchema()
df.show(n=1, truncate=False, vertical=True)
df.count()

In [None]:
from transformers import AutoImageProcessor, AutoModel

# Load the image preprocessor and the model from the same pretrained checkpoint
dino_checkpoint = 'facebook/dinov2-base'
processor = AutoImageProcessor.from_pretrained(dino_checkpoint)
model = AutoModel.from_pretrained(dino_checkpoint)

In [None]:
from pyspark.sql.functions import udf
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import BinaryType, ArrayType, FloatType
import io
from PIL import Image
import torch
import pickle

# import torchvision.transforms as transforms

# Define a function to convert binary data to an image, apply DINO, and return binary
def apply_dino(binary_data):
    """Run one encoded image through the DINO model and return the result as bytes.

    Uses the notebook-level ``processor`` and ``model`` objects.

    Args:
        binary_data: Bytes-like contents of an encoded image file, or ``None``
            (Spark passes SQL NULL values to Python UDFs as ``None``).

    Returns:
        The model's ``last_hidden_state`` converted to a numpy array and
        serialized with ``pickle.dumps``; ``None`` when ``binary_data`` is ``None``.
    """
    # Propagate nulls instead of letting PIL raise on an empty buffer, which
    # would fail the whole Spark job because of a single missing image.
    if binary_data is None:
        return None

    # Decode the raw bytes into a PIL image
    image = Image.open(io.BytesIO(binary_data))

    # Preprocess the image into the tensors the model expects
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed
    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_states = outputs['last_hidden_state']

    numpy_array = last_hidden_states.numpy()

    # Serialize the numpy array into bytes so it fits a BinaryType column.
    # NOTE: readers have to unpickle this; only unpickle data from trusted sources.
    return pickle.dumps(numpy_array)

# Register the row-at-a-time UDF; apply_dino returns bytes, hence BinaryType
apply_dino_udf = udf(apply_dino, BinaryType())

# Batch prediction UDF (not used by the cells visible in this notebook)
## TODO: Update function; it should receive data as an np array or dict of numpy arrays
# NOTE(review): predict_batch_udf expects a zero-argument factory (make_predict_fn)
# that returns the predict function. apply_dino takes one positional argument and
# returns pickled bytes rather than an array of floats, so this UDF is expected to
# fail when applied to a column until the TODO above is resolved. Also confirm that
# input_tensor_shapes describes the input columns, not the model output shape.
apply_dino_pbudf = predict_batch_udf(
    apply_dino,
    input_tensor_shapes=[[1,257,768]],
    return_type=ArrayType(FloatType()),
    batch_size=16,
)

In [None]:
# Smoke-test apply_dino on one image, outside of Spark's UDF machinery
value = df.select("data").first()[0]
# Print only the type, size and leading bytes; printing the whole encoded image
# floods the cell output without being readable.
print(f"{type(value).__name__}: {len(value)} bytes, head={bytes(value[:16])!r}")
output = apply_dino(value)

In [None]:
# Show both the serialized size and the leading bytes. Only the last expression of
# a cell is displayed, so the two values are combined into one tuple.
len(output), output[:10]

In [None]:
# Run the DINO UDF over the "data" column of a 3-row sample of the dataset
df_transformed = (
    df.limit(3)
    .withColumn("transformed_image", apply_dino_udf(df["data"]))
)

In [None]:
# Inspect the schema and the first transformed row (vertical layout, untruncated)
df_transformed.printSchema()
df_transformed.show(n=1, truncate=False, vertical=True)


In [None]:
# Write the transformed rows as parquet next to the input data, replacing any
# previous output for this dataset size
output_folder = f'DINOv2-embeddings-{size}_size/'
(
    df_transformed.write
    .mode("overwrite")
    .parquet(f"{gcs_parquet_path}{output_folder}")
)

In [None]:
# Check outputs: read the written parquet back and inspect its schema and row count
output_df = spark.read.parquet(f"{gcs_parquet_path}{output_folder}")
output_df.printSchema()
# output_df.show(1, vertical=True, truncate=True)
output_df.count()

In [None]:
from pyspark.sql.functions import col
# Count the rows whose "transformed_image" value is not null
output_df.where(col("transformed_image").isNotNull()).count()