In [0]:
# %pip install databricks-sdk
%pip install --upgrade databricks-vectorsearch pydantic mlflow
%restart_python

In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

In [0]:
import sys, os
sys.path.append(os.path.abspath('..'))
from configs.project import ProjectConfig

from databricks.vector_search.client import VectorSearchClient
from datetime import timedelta
import time


In [0]:
import yaml

# Key into `projectConfig.vector_search_attributes` selecting which index
# configuration the rest of this notebook operates on.
# TODO: choose the correct index here, most likely "id_1".
VECTOR_SEARCH_CONFIG_ID = "id_1"

# Read the project configuration. The encoding is given explicitly so parsing
# does not depend on the platform's default locale encoding.
with open("../configs/project.yml", "r", encoding="utf-8") as file:
    data = yaml.safe_load(file)

# Build the project config from the parsed YAML mapping and pick the entry
# selected above.
projectConfig = ProjectConfig(**data)
_config = projectConfig.vector_search_attributes[VECTOR_SEARCH_CONFIG_ID]

# Echo the selected settings so the values used by the later cells are visible
# in the notebook output.
for k, v in _config.model_dump().items():
    print(k, v)

In [0]:
# globals().update(_config)

In [0]:
# Vector Search client shared by the later cells that create/fetch the endpoint
# and the index.
# NOTE(review): `disable_notice=True` presumably suppresses the client's
# informational startup notice — confirm against the SDK docs.
vsc = VectorSearchClient(disable_notice=True)

In [0]:
# Prepare the source table: make the primary-key column NOT NULL, add a
# PRIMARY KEY constraint named "<primary_key>_pk" on it, and enable Delta
# Change Data Feed.
# NOTE(review): Change Data Feed is presumably required by the Delta Sync index
# created later in this notebook — confirm against the Vector Search docs.
spark.sql(f"ALTER TABLE {_config.source_table_name} ALTER COLUMN {_config.primary_key} SET NOT NULL")
try:
    spark.sql(f"ALTER TABLE {_config.source_table_name} ADD CONSTRAINT {_config.primary_key}_pk PRIMARY KEY( {_config.primary_key} )")
except Exception as e:
    # Adding the constraint stays best-effort so the cell can be re-run, but
    # only claim "already exists" when the error actually says so; otherwise
    # surface the real cause instead of hiding it behind a misleading message.
    if "already exists" in str(e).lower():
        print(f"Constraint {_config.primary_key}_pk already exists.")
    else:
        print(f"Could not add constraint {_config.primary_key}_pk: {e}")
spark.sql(f"ALTER TABLE {_config.source_table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true) ")


In [0]:
# Preview the source table that the index will be built from.
source_table_df = spark.table(_config.source_table_name)
display(source_table_df)

In [0]:
# Create the Vector Search endpoint if it does not exist yet, then wait until
# it is ready and keep its description in `ep` for the following cells.
try:
    vsc.create_endpoint(name=_config.endpoint_name,
                        endpoint_type="STANDARD")
except Exception as e:
    # Only the "already exists" case is expected; anything else is a real
    # failure and is re-raised unchanged.
    if "already exists" not in str(e):
        raise
    print(f"Endpoint named {_config.endpoint_name} already exists.")
else:
    # Short pause after a fresh creation request before polling its status.
    time.sleep(5)

# Wait in both cases: an endpoint that already exists may still be
# provisioning (e.g. when a previous run was interrupted), and the later cells
# need it to be ready.
vsc.wait_for_endpoint(name=_config.endpoint_name,
                      timeout=timedelta(minutes=60),
                      verbose=True)

print(f"Endpoint named {_config.endpoint_name} is ready.")

ep = vsc.get_endpoint(name=_config.endpoint_name)


In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import iam

w = WorkspaceClient()

# Access-control entry giving the "users" group CAN_MANAGE on the endpoint.
users_can_manage = iam.AccessControlRequest(
    group_name="users",
    permission_level=iam.PermissionLevel.CAN_MANAGE,
)

# Apply the ACL to the Vector Search endpoint held in `ep`.
# NOTE(review): `permissions.set` presumably replaces the object's existing
# access-control list rather than adding to it — confirm this is intended.
w.permissions.set(
    request_object_type="vector-search-endpoints",
    request_object_id=ep["id"],
    access_control_list=[users_can_manage],
)

Test Embedding Endpoint

In [0]:
import mlflow
import mlflow.deployments

# MLflow deployments client for the "databricks" target; the following cells
# use it to list serving endpoints and to query the embedding model endpoint.
client = mlflow.deployments.get_deploy_client("databricks")


In [0]:
# Show the serving endpoint(s) whose name matches the configured embedding
# model endpoint; an empty list means no endpoint with that name was listed.
matching_endpoints = [
    endpoint
    for endpoint in client.list_endpoints()
    if endpoint["name"] == _config.embedding_model_endpoint_name
]
matching_endpoints


In [0]:
# Smoke-test the embedding endpoint with a single sample sentence; the
# response is shown as the cell output.
sample_inputs = {"input": ["What is Apache Spark?"]}
client.predict(
    endpoint=_config.embedding_model_endpoint_name,
    inputs=sample_inputs,
)


# Create Vector Search Index

In [0]:

# Create the Delta Sync index on the configured endpoint and wait for the call
# to return, or fetch the existing index when it has already been created.
try:
    vector_search_index = vsc.create_delta_sync_index_and_wait(
        endpoint_name=_config.endpoint_name,
        index_name=_config.index_name,
        source_table_name=_config.source_table_name,
        primary_key=_config.primary_key,
        embedding_source_column=_config.embedding_source_column,
        embedding_model_endpoint_name=_config.embedding_model_endpoint_name,
        pipeline_type=_config.pipeline_type,
        verbose=True,
    )
except Exception as e:
    # Only the "already exists" case is expected; anything else is a real
    # failure and is re-raised unchanged.
    if "already exists" not in str(e):
        raise
    # Report the index name here (the original message printed the endpoint
    # name by mistake).
    print(f"Index named {_config.index_name} already exists.")
    vector_search_index = vsc.get_index(
        endpoint_name=_config.endpoint_name,
        index_name=_config.index_name,
    )

In [0]:
# Grant SELECT on the index (addressed by its table name) to the
# `account users` principal.
grant_select_sql = f"GRANT SELECT ON TABLE {_config.index_name} TO `account users` "
spark.sql(grant_select_sql)
