In [None]:

%cd /home/teresakim/fungiclef-2024/
from fungiclef.utils import get_spark

# https://knowledge.informatica.com/s/article/000196886?language=en_US
# The vectorized reader will run out of memory (8gb) with the default batch size, so
# this is one way of handling the issue. This is likely due to the fact that the data
# column is so damn big, and treated as binary data instead of something like a string.
# We might also be able to avoid this if we don't cache the fields into memory, but this
# this needs to be validated by hand. 
spark = get_spark(**{
    # "spark.sql.parquet.columnarReaderBatchSize": 512,
    "spark.sql.parquet.enableVectorizedReader": False, 
})

gcs_parquet_path = "gs://dsgt-clef-fungiclef-2024/data/parquet/DF20_300px/"

df = spark.read.parquet(gcs_parquet_path)
df.printSchema()
df.show(1, vertical=True, truncate=False)
df.count() # 295938

In [None]:
from torchvision import transforms
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType
from PIL import Image
import io

# define a UDF to convert binary to PIL image, apply transformations, and return a flattened image tensor
def preprocess_image(data):
    """Decode image bytes into a normalized, flattened tensor for a Spark UDF.

    The image is converted to RGB and resized to 224x224; the aspect ratio is
    not preserved. It is then converted with ``transforms.ToTensor`` (which
    yields a channels-first tensor scaled to [0, 1]) and normalized with the
    per-channel mean/std commonly used for torchvision's ImageNet models.

    :param data: encoded image bytes (any format PIL can open), or None.
    :returns: the tensor flattened to a list of 3 * 224 * 224 floats in
        (C, H, W) order, or None when ``data`` is None. Returning None means
        a null cell yields a null result instead of failing the Spark task.
    """
    if data is None:
        return None
    # Close the decoder explicitly. convert() returns a new, fully loaded
    # image, so it is still usable after the `with` block exits.
    with Image.open(io.BytesIO(data)) as img:
        image = img.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image)
    # Flatten so the result fits Spark's ArrayType(FloatType()); shape is lost.
    return image_tensor.numpy().flatten().tolist()

# Wrap the Python preprocessing function as a Spark UDF returning array<float>.
preprocess_udf = udf(preprocess_image, returnType=ArrayType(FloatType()))

# Add the flattened tensor as a new column. This is a lazy transformation:
# the UDF only runs when an action (show, count, write, ...) is triggered.
df = df.withColumn("processed_image", preprocess_udf("data"))
