In [0]:
# CLIP Image Embeddings Batch Processing
# This notebook processes the crop_images_directory table to generate CLIP embeddings

import mlflow
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, FloatType
import base64
from io import BytesIO
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np

In [0]:
# Load the registered CLIP model from the model registry.
# The version is pinned explicitly; note that MLflow addresses aliases with the
# `models:/<name>@<alias>` URI form rather than a trailing `/<version>` segment.
model_name = "autobricks.agriculture.clip_embedding-356"
model_version = 1  # pinned registry version

print(f"Loading model: {model_name}")
model_uri = f"models:/{model_name}/{model_version}"
loaded_model = mlflow.pyfunc.load_model(model_uri)

In [0]:
# Define UDF for generating embeddings
def generate_clip_embedding(base64_image_str):
    """
    Generate a CLIP embedding for a single base64 encoded image.

    Args:
        base64_image_str: Base64 string of the image, or None.

    Returns:
        The embedding as a list of plain Python floats, or None when the
        input is None or the model call fails. Failures are printed rather
        than raised so that one bad image does not abort the whole batch.
    """
    if base64_image_str is None:
        return None

    try:
        model_input = {"input_data": [base64_image_str]}
        params = {"input_type": "image"}

        # `loaded_model` comes from mlflow.pyfunc.load_model, i.e. it is a
        # PyFuncModel whose predict() signature is (data, params=None).
        # `context` / `model_input` are arguments of PythonModel.predict and
        # are rejected here with a TypeError, which the handler below would
        # swallow and turn into a null embedding for every row.
        embedding = loaded_model.predict(model_input, params=params)

        # The UDF is declared as ArrayType(FloatType()); hand Spark built-in
        # floats, since numpy scalars are not reliably converted.
        return [float(value) for value in embedding[0]]

    except Exception as e:
        # Deliberate best effort: report the failure and emit a null embedding.
        print(f"Error processing image: {str(e)}")
        return None

# Wrap the Python function as a Spark UDF whose result type is array<float>
embedding_udf = udf(f=generate_clip_embedding, returnType=ArrayType(FloatType()))

In [0]:
# Load the source table that holds the images to embed
source_table = "autobricks.agriculture.crop_images_directory"
print(f"Reading source table: {source_table}")

df = spark.read.table(source_table)

# Inspect the schema, then preview a handful of rows
print("Source table schema:")
df.printSchema()
print("\nSample data:")
sample_rows = df.limit(5)
display(sample_rows)

In [0]:
# Derive the CLIP_embedding column by applying the UDF to image_base64
print("Generating CLIP embeddings...")
clip_embedding_col = embedding_udf(col("image_base64"))
df_with_embeddings = df.withColumn("CLIP_embedding", clip_embedding_col)

# Schema of the enriched DataFrame
print("Schema with embeddings:")
df_with_embeddings.printSchema()

# Tally rows overall and rows whose embedding is non-null.
# NOTE(review): Spark evaluates lazily, so each action on df_with_embeddings
# (these counts, later display/write) may re-run the UDF — confirm cost is acceptable.
has_embedding = col("CLIP_embedding").isNotNull()
total_records = df_with_embeddings.count()
successful_embeddings = df_with_embeddings.where(has_embedding).count()
print(
    f"Total records: {total_records}",
    f"Successful embeddings: {successful_embeddings}",
    f"Failed embeddings: {total_records - successful_embeddings}",
    sep="\n",
)

In [0]:
# Preview a few rows, including the new CLIP_embedding column
embedding_preview = df_with_embeddings.limit(5)
display(embedding_preview)

In [0]:
# Destination table for the enriched rows
output_table = "autobricks.agriculture.crop_images_with_embeddings"

# Persist as a Delta table, replacing any existing data and schema
print(f"Saving results to Delta table: {output_table}")
delta_writer = (
    df_with_embeddings.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
)
delta_writer.saveAsTable(output_table)

print(f"✅ Successfully created Delta table: {output_table}")

In [0]:
# Read the saved table back and spot-check its contents
print("\nVerifying results...")
result_df = spark.table(output_table)
preview_columns = ["file_name", "folder", "size_bytes", "CLIP_embedding"]
result_df.select(*preview_columns).show(n=5, truncate=False)

# Report the vector length of one non-null embedding, if any exists
sample_embedding = (
    result_df
    .where(col("CLIP_embedding").isNotNull())
    .select("CLIP_embedding")
    .first()
)
if sample_embedding:
    sample_vector = sample_embedding["CLIP_embedding"]
    if sample_vector:
        embedding_dim = len(sample_vector)
        print(f"Embedding dimensions: {embedding_dim}")

print("✅ Batch processing complete!")