# Register CLIP Image Embedding Model for Databricks Model Serving

This notebook demonstrates how to:
- Create a custom MLflow PyFunc wrapper for CLIP image embeddings
- Follow Databricks embedding model conventions for consistent API usage
- Register the model to Unity Catalog for production use
- Deploy the model to a Databricks Model Serving endpoint

The resulting serving endpoint accepts base64-encoded images and returns one image embedding (a list of floats) per input image.

## Import Required Libraries

Import necessary libraries for MLflow model creation and CLIP image embedding handling.

In [0]:
import mlflow
import mlflow.pyfunc
import pandas as pd
import base64
import torch
from PIL import Image
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel

## Configuration

Set the Unity Catalog location, the registered model name, and the serving endpoint name used throughout this notebook.

In [0]:
# ---- Configuration: edit these values for your workspace ----

# Unity Catalog catalog and schema that will hold the registered model.
CATALOG_NAME = "autobricks"
SCHEMA_NAME = "agriculture"

# Registered model name; the full UC name is CATALOG_NAME.SCHEMA_NAME.MODEL_NAME.
MODEL_NAME = "clip_image_embedding"

# Model Serving endpoint that will serve the model.
ENDPOINT_NAME = "clip-image-embedding"

## Define CLIP Image Embedding Model Class

Create a custom MLflow PyFunc class that wraps CLIP. It accepts base64-encoded images (a single string, a list of strings, or a DataFrame column) and returns one embedding per image.

In [0]:
class CLIPImageEmbedding(mlflow.pyfunc.PythonModel):
    """MLflow PyFunc wrapper that produces CLIP image embeddings.

    ``predict`` accepts base64-encoded image strings (or raw bytes / file-like
    objects) and returns one embedding, a list of floats, per input image.
    """

    # Hugging Face Hub id of the CLIP checkpoint loaded in ``load_context``.
    MODEL_ID = "openai/clip-vit-large-patch14"
    # Images per forward pass; bounds peak memory for large requests.
    BATCH_SIZE = 16

    def load_context(self, context):
        """Load the CLIP model and processor and place the model on GPU if present.

        ``context`` is not used: weights are loaded by Hub id via ``from_pretrained``
        rather than from logged artifacts. NOTE(review): the serving container
        presumably needs Hugging Face Hub access (or a pre-populated cache) for this.
        """
        from transformers import CLIPProcessor, CLIPModel

        self.model = CLIPModel.from_pretrained(self.MODEL_ID)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_ID)

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Make inference mode explicit (dropout and similar training-only layers off).
        self.model.eval()

    @staticmethod
    def _to_image_list(model_input):
        """Normalize the supported ``predict`` input shapes into a list of image inputs.

        Raises:
            ValueError: if ``model_input`` is not a str/bytes, list, Series or DataFrame.
        """
        if isinstance(model_input, (str, bytes)):
            # Single base64 image string (or raw bytes)
            return [model_input]
        if isinstance(model_input, list):
            # List of base64 image strings
            return model_input
        if isinstance(model_input, pd.DataFrame):
            # No rows (or no columns): nothing to embed; avoids IndexError on iloc[:, 0].
            if model_input.empty:
                return []
            # Prefer an explicit "input" column, otherwise fall back to the first column.
            if "input" in model_input.columns:
                return model_input["input"].tolist()
            return model_input.iloc[:, 0].tolist()
        if isinstance(model_input, pd.Series):
            return model_input.tolist()
        raise ValueError(f"Unsupported input type: {type(model_input)}")

    def _process_image(self, image_input):
        """Convert one supported image input into an RGB ``PIL.Image``.

        Accepts a base64-encoded string (optionally prefixed as a data URI,
        e.g. ``data:image/png;base64,...``), raw bytes, or a file-like object.

        Raises:
            ValueError: if the input type is unsupported, or a string cannot be
                decoded/opened as an image (the original error is chained).
        """
        if isinstance(image_input, str):
            # Tolerate data URIs by keeping only the payload after the first comma.
            if image_input.startswith("data:") and "," in image_input:
                image_input = image_input.split(",", 1)[1]
            try:
                decoded_bytes = base64.b64decode(image_input)
                image = Image.open(BytesIO(decoded_bytes))
            except Exception as exc:
                raise ValueError("Invalid base64 image string") from exc
        elif isinstance(image_input, (bytes, bytearray)):
            # Raw bytes
            image = Image.open(BytesIO(image_input))
        elif hasattr(image_input, 'read'):
            # File-like object
            image = Image.open(image_input)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_input)}")

        # CLIP expects 3-channel input; convert palette/grayscale/RGBA images.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        return image

    def predict(self, context, model_input, params=None):
        """Return one CLIP image embedding per input image.

        Args:
            context: Unused; part of the PyFunc interface.
            model_input: A single base64 string (or bytes), a list of them, a
                ``pd.Series``, or a ``pd.DataFrame`` whose ``input`` column (or,
                failing that, first column) holds the images.
            params: Unused.

        Returns:
            list[list[float]]: embeddings in input order; ``[]`` for empty input.

        Raises:
            ValueError: for unsupported input types or undecodable images.
        """
        image_inputs = self._to_image_list(model_input)

        embeddings = []
        # Batch images through the model instead of one forward pass per image.
        for start in range(0, len(image_inputs), self.BATCH_SIZE):
            batch = [
                self._process_image(item)
                for item in image_inputs[start:start + self.BATCH_SIZE]
            ]
            inputs = self.processor(images=batch, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)

            embeddings.extend(image_features.cpu().numpy().tolist())

        return embeddings

## Register Image Embedding Model

Log the CLIP image embedding model in an MLflow run, with a signature and input example, and register it to Unity Catalog.

In [0]:
# Three-level model names (catalog.schema.model) must be registered in the Unity Catalog registry.
mlflow.set_registry_uri("databricks-uc")
mlflow.set_experiment("/Shared/clip_image_embedding_experiment")

import io
from mlflow.models import infer_signature

full_model_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.{MODEL_NAME}"

# Sample input for the signature / input example: a 224x224 white JPEG, base64-encoded.
test_image = Image.new('RGB', (224, 224), color='white')
buffer = io.BytesIO()
test_image.save(buffer, format='JPEG')
sample_input = base64.b64encode(buffer.getvalue()).decode('utf-8')

# Use a throwaway, loaded instance to produce a real output for signature inference.
# It must NOT be the instance that gets logged: `python_model` is pickled, and a loaded
# instance would carry the full CLIP weights (CUDA tensors if run on GPU) into the
# artifact, bloating it and breaking deserialization on CPU-only serving containers.
signature_model = CLIPImageEmbedding()
signature_model.load_context(None)
sample_output = signature_model.predict(None, sample_input)
signature = infer_signature(sample_input, sample_output)
del signature_model
torch.cuda.empty_cache()  # release cached GPU memory from the throwaway model; safe without CUDA

# Runtime dependencies installed in the serving environment.
requirements = [
    "torch>=1.9.0",
    "transformers>=4.21.0",
    "Pillow>=8.3.0"
]

with mlflow.start_run() as run:
    # Log a fresh, unloaded instance; weights are loaded by load_context at serving time.
    mlflow.pyfunc.log_model(
        artifact_path="clip_image_model",
        python_model=CLIPImageEmbedding(),
        input_example=sample_input,
        signature=signature,
        pip_requirements=requirements,
        registered_model_name=full_model_name,
    )

print(f"Model registered as: {full_model_name}")

## Create Serving Endpoint

Deploy the registered image embedding model to a Databricks Model Serving endpoint.

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput

w = WorkspaceClient()

full_model_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.{MODEL_NAME}"

# Serve the newest registered version instead of a hard-coded "1": every run of the
# registration cell creates a new model version.
mlflow.set_registry_uri("databricks-uc")
model_versions = mlflow.MlflowClient().search_model_versions(f"name='{full_model_name}'")
if not model_versions:
    raise RuntimeError(f"No registered versions found for {full_model_name}")
latest_version = str(max(int(v.version) for v in model_versions))

served_entities = [
    ServedEntityInput(
        entity_name=full_model_name,
        entity_version=latest_version,
        workload_size="Small",
        scale_to_zero_enabled=True,
    )
]

# `create` fails if the endpoint already exists, so re-runs update its config instead.
existing_endpoints = {endpoint.name for endpoint in w.serving_endpoints.list()}
if ENDPOINT_NAME in existing_endpoints:
    w.serving_endpoints.update_config(name=ENDPOINT_NAME, served_entities=served_entities)
    print(f"Image embedding endpoint updated: {ENDPOINT_NAME}")
else:
    w.serving_endpoints.create(
        name=ENDPOINT_NAME,
        config=EndpointCoreConfigInput(served_entities=served_entities),
    )
    print(f"Image embedding endpoint created: {ENDPOINT_NAME}")

print(f"Model: {full_model_name} (version {latest_version})")

## Test the Model

Wait for the endpoint to become ready. Then send it a sample base64-encoded image and check that an embedding comes back.

In [0]:
# Wait for the endpoint to finish (re)deploying, then send it a sample image.
import time

from databricks.sdk.service.serving import EndpointStateConfigUpdate, EndpointStateReady

POLL_INTERVAL_S = 30       # seconds between status checks
READY_TIMEOUT_S = 60 * 60  # give up after an hour instead of polling forever

print("Waiting for endpoint to be ready...")
deadline = time.monotonic() + READY_TIMEOUT_S
while True:
    try:
        endpoint_state = w.serving_endpoints.get(ENDPOINT_NAME).state
    except Exception as exc:  # transient API errors: report and retry until the deadline
        print(f"Could not fetch endpoint status ({exc}); retrying...")
        endpoint_state = None

    if endpoint_state is not None:
        if endpoint_state.config_update == EndpointStateConfigUpdate.UPDATE_FAILED:
            raise RuntimeError(f"Endpoint {ENDPOINT_NAME} config update failed")
        # `ready` is an enum whose NOT_READY member is truthy, so compare to READY
        # explicitly; also wait out an in-progress update so a re-deploy is not tested early.
        if (
            endpoint_state.ready == EndpointStateReady.READY
            and endpoint_state.config_update != EndpointStateConfigUpdate.IN_PROGRESS
        ):
            break

    if time.monotonic() >= deadline:
        raise TimeoutError(f"Endpoint {ENDPOINT_NAME} not ready after {READY_TIMEOUT_S} s")
    time.sleep(POLL_INTERVAL_S)

print("Endpoint ready, testing...")

# Build a sample 224x224 JPEG and send it base64-encoded, as the model expects.
query_buffer = BytesIO()
Image.new('RGB', (224, 224), color='blue').save(query_buffer, format='JPEG')
test_image_b64 = base64.b64encode(query_buffer.getvalue()).decode('utf-8')

# Custom PyFunc endpoints take MLflow scoring payloads (`inputs`, `instances`,
# `dataframe_*`) and return the model output under `predictions`.
response = w.serving_endpoints.query(name=ENDPOINT_NAME, inputs=[test_image_b64])
embeddings = response.predictions
assert embeddings and len(embeddings) == 1, f"Unexpected endpoint response: {response}"
print(f"Image embedding generated successfully: {len(embeddings[0])} dimensions")

print(f"Image embedding model ready for use: {ENDPOINT_NAME}")