In [0]:
%pip install sentence_transformers
%restart_python

In [0]:
# Declare the notebook's two text widgets, each with an empty default value.
# NOTE(review): the key "emebdding_model" is misspelled; it is kept as-is here
# because anything that sets this parameter by name would presumably break if
# it were renamed — confirm before changing.
for widget_name in ("emebdding_model", "catalog_name"):
    dbutils.widgets.text(widget_name, "")

# Model identifier handed to SentenceTransformer below,
# e.g. sentence-transformers/all-MiniLM-L6-v2
embedding_model = dbutils.widgets.get("emebdding_model")
# First component of the registered model name built below.
catalog_name = dbutils.widgets.get("catalog_name")

In [0]:
import mlflow
from sentence_transformers import SentenceTransformer

# Point the MLflow model registry at "databricks-uc" before anything is registered.
mlflow.set_registry_uri("databricks-uc")

# The model id with every "/" replaced by "-"; the same value is reused as the
# artifact path, as the last part of the registered model name and (in later
# cells) as the serving endpoint name.
artifact_path = model_path = "-".join(embedding_model.split("/"))
mlflow_model_name = ".".join((catalog_name, "gold", model_path))

# Load the embedding model and encode one sample sentence so the model
# signature can be inferred from a real input/output pair.
model = SentenceTransformer(embedding_model)
sample_input = "Isn't embeddings awesome bridge between natural and machine language!"
sample_embedding = model.encode(sample_input)
signature = mlflow.models.infer_signature(
    model_input=sample_input,
    model_output=sample_embedding,
)

# Packages recorded alongside the logged model.
serving_requirements = ["mlflow","torch", "transformers","accelerate", "torchvision", "sentence_transformers"]

# Log the model inside a run and register it under mlflow_model_name in the
# same call; model_info is consumed by the cells that follow.
with mlflow.start_run():
    model_info = mlflow.sentence_transformers.log_model(
        model=model,
        artifact_path=artifact_path,
        signature=signature,
        input_example=sample_input,
        registered_model_name=mlflow_model_name,
        pip_requirements=serving_requirements,
    )


In [0]:
from mlflow.models import validate_serving_input

# Example request body in "dataframe_split" form: one row holding one sentence.
serving_payload = """{
  "dataframe_split": {
    "data": [
      [
        "Hello, Geeks!"
      ]
    ]
  }
}"""

# URI of the model that was logged in the previous cell.
model_uri = model_info.model_uri

# Check that the logged model accepts the payload; the result is left as the
# cell's last expression so the notebook displays it.
validate_serving_input(model_uri, serving_payload)


In [0]:
from mlflow.deployments import get_deploy_client


def create_serving_endpoint(
    endpoint_name,
    entity_name,
    entity_version=1,
    workload_size="Small",
    workload_type="CPU",
    scale_to_zero_enabled=True,
):
    """Create a serving endpoint unless one with the same name is already listed.

    Args:
        endpoint_name: Name of the endpoint to look for and, if absent, create.
        entity_name: Name of the entity (registered model) to serve.
        entity_version: Version of the entity to serve.
        workload_size: Value sent as ``workload_size`` for the served entity.
        workload_type: Value sent as ``workload_type`` for the served entity.
        scale_to_zero_enabled: Value sent as ``scale_to_zero_enabled``.

    Returns:
        Whatever ``client.create_endpoint`` returns when a new endpoint is
        created, or ``None`` (after printing a notice) when an endpoint named
        ``endpoint_name`` already appears in ``client.list_endpoints()``.
    """
    deploy_client = get_deploy_client("databricks")

    # Nothing is created or modified when the name is already taken.
    name_taken = any(
        existing["name"] == endpoint_name
        for existing in deploy_client.list_endpoints()
    )
    if name_taken:
        print("Endpoint Already Exists.")
        return

    served_entity = {
        "entity_name": entity_name,
        "entity_version": entity_version,
        "workload_size": workload_size,
        "workload_type": workload_type,
        "scale_to_zero_enabled": scale_to_zero_enabled,
    }
    return deploy_client.create_endpoint(
        name=endpoint_name,
        config={"served_entities": [served_entity]},
    )

In [0]:
# Serve the version that the log_model call in this notebook just registered.
# A hard-coded version 1 would deploy a stale model whenever the registered
# model already had earlier versions before this run.
endpoint = create_serving_endpoint(
    model_path,
    mlflow_model_name,
    entity_version=model_info.registered_model_version,
)

In [0]:
# Best-effort endpoint creation for the version registered in this run.
# create_serving_endpoint already prints a notice and returns None when the
# endpoint name is taken, so an exception here signals some other failure;
# report the real cause instead of claiming the endpoint exists.
try:
    endpoint = create_serving_endpoint(
        model_path,
        mlflow_model_name,
        entity_version=model_info.registered_model_version,
    )
except Exception as err:
    # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
    print(f"Endpoint creation failed: {type(err).__name__}: {err}")