In [None]:

# Switch the kernel's working directory to the project checkout before importing
# from the `fungiclef` package (presumably so the import below resolves — confirm
# whether the package is installed in the environment instead).
# NOTE(review): this is a hard-coded absolute path under one user's home
# directory; it has to be adjusted when the notebook runs anywhere else.
%cd /home/teresakim/fungiclef-2024/
from fungiclef.utils import get_spark

# https://knowledge.informatica.com/s/article/000196886?language=en_US
# The vectorized reader runs out of memory (8 GB) with the default batch size, so
# disabling it is one way of handling the issue. This is likely because the data
# column is very large and is treated as binary data instead of something like a string.
# We might also be able to avoid this if we don't cache the fields in memory, but
# this needs to be validated by hand.
# Build the Spark session with the vectorized Parquet reader switched off.
# The alternative knob, "spark.sql.parquet.columnarReaderBatchSize" (e.g. 512),
# is intentionally not set here.
spark = get_spark(**{"spark.sql.parquet.enableVectorizedReader": False})

# Location of the source Parquet dataset on GCS.
gcs_parquet_path = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px/"

# Read the dataset as a Spark DataFrame.
df = spark.read.parquet(gcs_parquet_path)
# Inspection helpers, left commented out; uncomment to look at the schema, a
# single row, or the row count (295938 when this note was written).
# df.printSchema()
# df.show(1, vertical=True, truncate=False)
# df.count() # 295938

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch
import io

# load pretrained model
# torchvision ResNet-18 with pretrained weights, switched to eval mode once here.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument — check the installed version before changing it.
# NOTE(review): binary_image_to_embedding reads this global, so the model is
# presumably serialized along with the UDF and shipped to the executors — confirm
# that is acceptable for the cluster.
model = models.resnet18(pretrained=True).eval()

In [None]:
# Preprocessing pipeline applied to every PIL image before it reaches the model.
transform = transforms.Compose(
    [
        # With an int argument the shorter side is scaled to 256 px and the
        # aspect ratio is kept (the result is not necessarily 256x256).
        transforms.Resize(256),
        # Central 224x224 crop; a common choice that has not been tuned here.
        transforms.CenterCrop(224),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel normalization: (input[channel] - mean[channel]) / std[channel]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# UDF to convert binary image data to embedding
def binary_image_to_embedding(image_data):
    """Decode one encoded image and run it through the module-level ``model``.

    Args:
        image_data: Bytes of an encoded image file that PIL can open, or
            ``None`` for a null cell.

    Returns:
        The model output for the single image, flattened into a list of Python
        floats. ``None`` is returned for a ``None`` input so that a null cell
        produces a null result instead of failing the whole task.

    Raises:
        Whatever ``PIL.Image.open`` raises when the bytes cannot be decoded.

    NOTE(review): ``model`` is called as-is, so the vector is whatever its
    forward pass returns. For a stock torchvision ResNet-18 that is the output
    of the final classification layer rather than the pooled features —
    confirm this is the intended "embedding".
    """
    if image_data is None:
        return None

    with Image.open(io.BytesIO(image_data)) as img:
        # The normalization step and the network both expect exactly three
        # channels; grayscale, palette, RGBA or CMYK images would otherwise
        # fail, so force RGB before preprocessing.
        img_t = transform(img.convert("RGB"))

    # Add the leading batch dimension the model expects.
    batch_t = img_t.unsqueeze(0)
    with torch.no_grad():  # inference only, no autograd bookkeeping
        embedding = model(batch_t)

    return embedding.cpu().numpy().flatten().tolist()

In [None]:
# register UDF with the appropriate return type
# Wraps binary_image_to_embedding as a Spark UDF whose declared return type is
# an array of floats, matching the list of floats the function returns.
binary_to_embedding_udf = udf(binary_image_to_embedding, ArrayType(FloatType()))

In [None]:
# apply UDF
# Adds an "embeddings" column obtained by passing the "data" column through the
# UDF; `df` itself is left untouched and the result gets its own name.
df_with_embeddings = df.withColumn("embeddings", binary_to_embedding_udf(df["data"]))


In [None]:
# Destination for the DataFrame with the added "embeddings" column.
gcs_embedding_path = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px_resnet18/"

# Preview the first rows before writing.
df_with_embeddings.show()

# Write the result as Parquet, replacing anything already at the destination.
(
    df_with_embeddings.write
    .mode("overwrite")
    .parquet(gcs_embedding_path)
)