# Pushing HuggingFace models to Databricks Unity Catalog

A simple notebook to test loading and storing an arbitrary HuggingFace model into Unity Catalog.

Run on a DBR 16.3 ML LTS cluster.

Test with an arbitrary HuggingFace model. We use `microsoft/DialoGPT-small`.

First, we load the pipeline and make sure it works.

In [0]:
from transformers import pipeline

# Load DialoGPT as a text-generation pipeline
chatbot = pipeline("text-generation", model="microsoft/DialoGPT-small")

# Start a conversation
response = chatbot("Hello, who are you?")

# Print the response
print("Bot:", response[0]['generated_text'])
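
By default, a `text-generation` pipeline returns the prompt followed by the continuation in `generated_text`, so the printed line above repeats the question. The sketch below shows one way to keep only the new text. The helper name `extract_reply` and the sample response are hypothetical stand-ins, not part of the original notebook, so the snippet runs without downloading a model.

```python
def extract_reply(prompt: str, generated_text: str) -> str:
    """Return only the text generated after the prompt (hypothetical helper)."""
    # The pipeline echoes the prompt first, so drop it when it is present.
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    return generated_text.strip()


# Stand-in shaped like the list of dicts that `chatbot(prompt)` returns.
prompt = "Hello, who are you?"
sample_response = [{"generated_text": "Hello, who are you? I'm a bot."}]

reply = extract_reply(prompt, sample_response[0]["generated_text"])
print("Bot:", reply)
```

Passing `return_full_text=False` when calling the pipeline should have a similar effect without post-processing.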


Next, let's register the model in Unity Catalog.

Fill in the TODO sections with details relevant to your environment.

In [0]:
import mlflow

# TODO - add your Unity Catalog catalog, schema, and model name here
CATALOG = "..."
SCHEMA = "..."
MODEL_NAME = "..."


fq_model_name = ".".join([CATALOG, SCHEMA, MODEL_NAME])


mlflow.set_registry_uri("databricks-uc")

# Log the pipeline
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=chatbot,
        artifact_path="chatbot",
        task="text-generation",  # Match the task the pipeline was created with
        input_example="A clever and witty question",
        registered_model_name=fq_model_name
    )
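
Unity Catalog models use a three-level name (`catalog.schema.model`), and a registered version can be referenced with a `models:/<name>/<version>` URI. The sketch below builds both and rejects unfilled placeholders. The helper names and the example values (`main`, `default`, `dialogpt_small`) are hypothetical, and the character check is a simplifying assumption rather than the full Unity Catalog naming rules.

```python
def uc_model_name(catalog: str, schema: str, model: str) -> str:
    """Join the three name parts, rejecting empty or placeholder values."""
    parts = [catalog, schema, model]
    for part in parts:
        # Assumption: a part may not be empty or contain '.', '/', or spaces.
        if not part or any(ch in part for ch in "./ "):
            raise ValueError(f"Invalid Unity Catalog name part: {part!r}")
    return ".".join(parts)


def uc_model_uri(full_name: str, version: int) -> str:
    """Build the models:/ URI for a specific registered version."""
    return f"models:/{full_name}/{version}"


# Hypothetical example values; replace with your own catalog, schema, and model.
name = uc_model_name("main", "default", "dialogpt_small")
uri = uc_model_uri(name, 1)
print(name)
print(uri)
```

With the registry URI set to `databricks-uc`, a URI of this form can be passed to `mlflow.pyfunc.load_model` to load the registered version back and check it before serving.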



Finally, let's create a serving endpoint.

In [0]:
from mlflow.deployments import get_deploy_client

# Add your model version and endpoint name here
MODEL_VERSION = 1
ENDPOINT_NAME = "..."

model_name_version = "-".join([MODEL_NAME, str(MODEL_VERSION)])

client = get_deploy_client("databricks")

# Define the endpoint configuration
endpoint_config = {
    "served_entities": [
        {
            "name": model_name_version,  # Explicit name so the traffic route below matches
            "entity_name": fq_model_name,
            "entity_version": str(MODEL_VERSION),
            "workload_size": "Small",  # Small, Medium, or Large
            "scale_to_zero_enabled": True  # Set to False for production workloads
        }
    ],
    "traffic_config": {
        "routes": [
            {
                "served_model_name": model_name_version,
                "traffic_percentage": 100  # Route all traffic to this model
            }
        ]
    }
}

# Create the serving endpoint
endpoint = client.create_endpoint(name=ENDPOINT_NAME, config=endpoint_config)

print(f"Serving endpoint created: {endpoint}")