# Register CLIP Text Embedding Model for Databricks Model Serving

This notebook demonstrates how to:
- Create a custom MLflow PyFunc wrapper for CLIP text embeddings
- Follow Databricks embedding model conventions for consistent API usage
- Register the model to Unity Catalog for production use
- Deploy the model to a Databricks Model Serving endpoint

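The resulting endpoint generates CLIP text embeddings and is queried with a request pattern modeled on `databricks-gte-large-en`. Note that the response contains raw embedding vectors (one list of floats per input text), not the OpenAI-style `data` objects returned by that endpoint.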

## Import Required Libraries

Import necessary libraries for MLflow model creation and CLIP text embedding handling.

In [0]:
import mlflow
import mlflow.pyfunc
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel

## Dataset Configuration

Configure the model registration parameters for the text embedding model.

In [0]:
# Registration / deployment settings -- edit these to match your workspace.
CATALOG_NAME = "autobricks"            # Unity Catalog catalog that will own the model
SCHEMA_NAME = "agriculture"            # schema within CATALOG_NAME
MODEL_NAME = "clip_text_embedding"     # model name; registered as catalog.schema.model
ENDPOINT_NAME = "clip-text-embedding"  # Model Serving endpoint to create

## Define CLIP Text Embedding Model Class

Create a custom MLflow PyFunc class that wraps CLIP for text embeddings following Databricks conventions.

In [0]:
class CLIPTextEmbedding(mlflow.pyfunc.PythonModel):
    """MLflow PyFunc wrapper that turns text into CLIP text embeddings.

    The CLIP checkpoint is fetched with ``from_pretrained`` inside
    ``load_context``; no model weights are stored as MLflow artifacts by this
    class.
    """

    # Hugging Face checkpoint used for both the model and its processor.
    MODEL_ID = "openai/clip-vit-large-patch14"
    # Number of texts tokenized and encoded per forward pass.
    BATCH_SIZE = 32

    def load_context(self, context):
        """Load the CLIP model and processor, placing the model on GPU if available.

        Args:
            context: MLflow PythonModelContext; unused here (may be None).
        """
        import torch
        from transformers import CLIPModel, CLIPProcessor

        self.model = CLIPModel.from_pretrained(self.MODEL_ID)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_ID)

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only: make sure dropout-style training behavior is off.
        self.model.eval()

    @staticmethod
    def _to_text_list(model_input):
        """Normalize every supported input shape into a plain list of texts."""
        if isinstance(model_input, str):
            return [model_input]
        if isinstance(model_input, (list, tuple)):
            return list(model_input)
        if isinstance(model_input, pd.Series):
            return model_input.tolist()
        if isinstance(model_input, pd.DataFrame):
            # Prefer an explicit 'input' column; otherwise take the first column.
            if "input" in model_input.columns:
                return model_input["input"].tolist()
            return model_input.iloc[:, 0].tolist()
        if isinstance(model_input, dict) and "input" in model_input:
            # Embedding-API style payload: {"input": "text"} or {"input": [...]}.
            return CLIPTextEmbedding._to_text_list(model_input["input"])
        raise ValueError(f"Unsupported input type: {type(model_input)}")

    def predict(self, context, model_input, params=None):
        """Embed one or more texts with CLIP's text encoder.

        Args:
            context: MLflow PythonModelContext; unused (may be None).
            model_input: a string, a list/tuple of strings, a pandas Series, a
                DataFrame (its ``input`` column if present, otherwise its first
                column), or a dict whose ``input`` key holds any of those.
            params: unused; accepted to match the MLflow PythonModel interface.

        Returns:
            A list with one embedding (a list of floats) per input text, in
            input order. An empty input yields an empty list.

        Raises:
            ValueError: if ``model_input`` has an unsupported type.
        """
        import torch

        texts = self._to_text_list(model_input)

        embeddings = []
        for start in range(0, len(texts), self.BATCH_SIZE):
            batch = texts[start:start + self.BATCH_SIZE]
            # CLIP's text encoder has a fixed maximum sequence length; truncate
            # so long texts are embedded instead of failing inside the model.
            inputs = self.processor(
                text=batch, return_tensors="pt", padding=True, truncation=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)

            embeddings.extend(text_features.cpu().tolist())

        return embeddings

## Register Text Embedding Model

Log the CLIP text embedding model in an MLflow run and register it to Unity Catalog. Unity Catalog requires a model signature, so one is inferred from a sample prediction.

In [0]:
# Log the model in an MLflow run and register it to Unity Catalog.
import transformers
from mlflow.models import infer_signature

# Three-level names (catalog.schema.model) are only valid in the Unity Catalog
# registry; set it explicitly instead of relying on the workspace default.
mlflow.set_registry_uri("databricks-uc")
mlflow.set_experiment("/Shared/clip_text_embedding_experiment")

registered_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.{MODEL_NAME}"

with mlflow.start_run() as run:
    # Load CLIP into a throwaway instance only to compute a sample prediction for
    # the signature. log_model pickles the python_model object itself, so the
    # instance that gets logged must stay unloaded -- otherwise the loaded
    # weights are serialized into the artifact, even though load_context
    # reloads them at serving time anyway.
    signature_model = CLIPTextEmbedding()
    signature_model.load_context(None)

    sample_input = "This is a sample text for embedding"
    sample_output = signature_model.predict(None, sample_input)
    signature = infer_signature(sample_input, sample_output)
    del signature_model  # release the driver-side copy of the weights

    # Pin to the versions running here so the serving container matches the
    # driver. Strip local build labels such as "+cu121", which PyPI cannot resolve.
    requirements = [
        f"torch=={torch.__version__.split('+')[0]}",
        f"transformers=={transformers.__version__}",
        "pillow",  # CLIPProcessor also builds an image processor, which needs Pillow
    ]

    mlflow.pyfunc.log_model(
        artifact_path="clip_text_model",
        python_model=CLIPTextEmbedding(),
        input_example=sample_input,
        signature=signature,
        pip_requirements=requirements,
        registered_model_name=registered_name,
    )

    print(f"Model registered as: {registered_name}")

## Create Serving Endpoint

Deploy the registered text embedding model to a Databricks Model Serving endpoint.

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput

w = WorkspaceClient()

full_model_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.{MODEL_NAME}"

# Serve the newest registered version rather than a hard-coded "1", so that
# re-running the registration cell followed by this one deploys the new model.
registered_versions = [
    int(mv.version)
    for mv in w.model_versions.list(full_name=full_model_name)
    if mv.version is not None
]
if not registered_versions:
    raise RuntimeError(f"No registered versions found for {full_model_name}")
latest_version = str(max(registered_versions))

served_entities = [
    ServedEntityInput(
        entity_name=full_model_name,
        entity_version=latest_version,
        workload_size="Small",
        scale_to_zero_enabled=True,
    )
]

# create() fails if the endpoint already exists, so update it in place on re-runs.
existing_endpoint_names = {ep.name for ep in w.serving_endpoints.list()}
if ENDPOINT_NAME in existing_endpoint_names:
    w.serving_endpoints.update_config(name=ENDPOINT_NAME, served_entities=served_entities)
    print(f"Text embedding endpoint updated: {ENDPOINT_NAME}")
else:
    w.serving_endpoints.create(
        name=ENDPOINT_NAME,
        config=EndpointCoreConfigInput(served_entities=served_entities),
    )
    print(f"Text embedding endpoint created: {ENDPOINT_NAME}")

print(f"Model: {full_model_name} (version {latest_version})")

## Test the Model

Wait for the endpoint to finish deploying, then send a sample request and report the size of the returned embedding.

In [0]:
# Wait for the endpoint to finish deploying, then send a sample request.
import time

from databricks.sdk.service.serving import EndpointStateConfigUpdate, EndpointStateReady

POLL_INTERVAL_SECONDS = 30
MAX_WAIT_SECONDS = 60 * 60  # give up rather than poll forever

print("Waiting for endpoint to be ready...")
deadline = time.monotonic() + MAX_WAIT_SECONDS
while True:
    try:
        endpoint = w.serving_endpoints.get(ENDPOINT_NAME)
    except Exception as exc:
        # Treat API errors as transient, but surface them instead of hiding them.
        print(f"Could not fetch endpoint status ({exc}); retrying...")
        endpoint = None

    state = endpoint.state if endpoint is not None else None
    if state is not None:
        if state.config_update == EndpointStateConfigUpdate.UPDATE_FAILED:
            raise RuntimeError(
                f"Endpoint {ENDPOINT_NAME} failed to deploy; check its build and event logs."
            )
        # `ready` is an Enum: NOT_READY is truthy as well, so compare explicitly.
        # Also wait for any in-flight config update (e.g. a new model version).
        if (
            state.ready == EndpointStateReady.READY
            and state.config_update != EndpointStateConfigUpdate.IN_PROGRESS
        ):
            break

    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Endpoint {ENDPOINT_NAME} not ready after {MAX_WAIT_SECONDS} seconds"
        )
    time.sleep(POLL_INTERVAL_SECONDS)

print("Endpoint ready, testing...")

# Test with standard Databricks embedding API pattern.
# NOTE(review): custom PyFunc endpoints document dataframe_split / dataframe_records /
# instances / inputs payloads; confirm the `input` field is accepted here.
try:
    response = w.serving_endpoints.query(
        name=ENDPOINT_NAME,
        input="This is a test sentence for text embedding",
    )
    # QueryEndpointResponse has no len(): custom PyFunc output arrives under
    # `predictions`, while embedding-style endpoints return `data` objects.
    if response.predictions:
        embedding = response.predictions[0]
    elif response.data:
        embedding = response.data[0].embedding
    else:
        raise ValueError(f"Unexpected response shape: {response}")
except Exception as e:
    print(f"Test failed: {e}")
else:
    print(f"Text embedding generated successfully: {len(embedding)} dimensions")
    print(f"Text embedding model ready for use: {ENDPOINT_NAME}")