In [0]:
# Simple CLIP Batch Processing for Delta Tables

import mlflow
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, FloatType

In [0]:
# Configuration - update these values before running.
# Fully-qualified Unity Catalog name of the registered MLflow model to load.
MODEL_NAME = "autobricks.agriculture.clip_embedding-356"
# Source table and the column whose values are fed to the model, one row at a time.
SOURCE_TABLE = "autobricks.agriculture.crop_images_directory"
INPUT_COLUMN = "image_base64"
# Forwarded to model.predict as params={"input_type": INPUT_TYPE}.
# "image" matches the image column above; NOTE(review): a text column
# presumably needs "text" here - confirm against the model's signature.
INPUT_TYPE = "image"
# Destination table for the source rows plus an "embeddings" column
# (written with mode("overwrite"), so existing contents are replaced).
OUTPUT_TABLE = "autobricks.agriculture.crop_images_directory_embeddings"

In [0]:
# Load version 1 of the registered model from Unity Catalog.
# URI format: models:/<catalog.schema.model>/<version>
model_uri = "models:/{}/{}".format(MODEL_NAME, 1)
model = mlflow.pyfunc.load_model(model_uri=model_uri)
print("Loaded model:", MODEL_NAME)

# Create UDF
def get_embedding(input_text):
    """Embed one input value with the module-level MLflow pyfunc ``model``.

    Intended as the body of a row-at-a-time Spark UDF declared with
    ``ArrayType(FloatType())``, so it must return plain Python floats.

    Args:
        input_text: The value to embed. Despite the name, it is whatever is
            stored in ``INPUT_COLUMN``; the model is told how to interpret it
            via ``INPUT_TYPE``. ``None`` is passed through.

    Returns:
        The first element of the model's prediction as a ``list`` of Python
        ``float``s. Returns ``None`` if the input is ``None``, the prediction is
        ``None``, or prediction/conversion fails. Failures are logged as
        warnings rather than raised, so one bad row does not fail the job.
    """
    if input_text is None:
        return None

    input_data = pd.DataFrame({"input_data": [input_text]})
    params = {"input_type": INPUT_TYPE}

    try:
        result = model.predict(input_data, params=params)
        embedding = result[0]
        if embedding is None:
            return None
        # Spark cannot pickle numpy arrays/scalars back into ArrayType(FloatType());
        # convert to a plain list of Python floats.
        return [float(value) for value in embedding]
    except Exception as exc:  # best-effort per row, but no longer traps KeyboardInterrupt/SystemExit
        import logging

        logging.getLogger(__name__).warning(
            "get_embedding: prediction failed (%s: %s)", type(exc).__name__, exc
        )
        return None

# Wrap get_embedding as a row-at-a-time Spark UDF returning array<float>.
embedding_udf = udf(f=get_embedding, returnType=ArrayType(elementType=FloatType()))

In [0]:
# Process table.
# ROW_LIMIT caps how many source rows are embedded (and therefore written by
# the save step); set it to None to process the whole table.
ROW_LIMIT = 10

df = spark.table(SOURCE_TABLE)
if ROW_LIMIT is not None:
    # NOTE: limit() without an ordering does not pin down which rows are
    # selected, and each Spark action below may re-evaluate this plan.
    df = df.limit(ROW_LIMIT)
print(f"Processing {df.count()} rows")

# Lazy: the UDF only runs when an action (display, the Delta write) evaluates
# result_df. It is not cached here, so each action may re-run inference.
result_df = df.withColumn("embeddings", embedding_udf(col(INPUT_COLUMN)))

display(result_df)

In [0]:
# Save to Delta table, replacing whatever OUTPUT_TABLE currently holds.
# NOTE(review): if OUTPUT_TABLE already exists with a different schema, Delta
# may reject the overwrite unless .option("overwriteSchema", "true") is set.
(
    result_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(OUTPUT_TABLE)
)

print("Results saved to", OUTPUT_TABLE)