Run this notebook on serverless compute, or on a cluster with Databricks Runtime (DBR) 16.3 or higher.

In [0]:
# Install the Vector Search Python client, then restart Python so later cells can import it.
# Restarting clears all Python state, so this must run before any variables are defined.
# NOTE(review): consider pinning a version (databricks-vectorsearch==<version>) for reproducible runs.
%pip install databricks-vectorsearch
%restart_python

In [0]:
# Unity Catalog locations used throughout the notebook:
#   source_table  - input notes with their unsupervised/supervised predictions
#   endpoint_name - Vector Search endpoint that will host the index
#   target_table  - output table written by the last cell
catalog = "main"
schema = "default"
source_table = f"{catalog}.{schema}.pilot_notes_supervised_classification"
endpoint_name = "pilot_notes_endpoint"
target_table = f"{catalog}.{schema}.pilot_notes_final_classification"

# Make this catalog/schema the session defaults so unqualified names (such as the
# SQL function created later) resolve there. The identifiers are backtick-quoted
# so catalog/schema names containing characters like '-' are accepted.
spark.sql(f"USE CATALOG `{catalog}`")
spark.sql(f"USE SCHEMA `{schema}`")

In [0]:
from databricks.vector_search.client import VectorSearchClient

# Inside a Databricks notebook the client authenticates automatically from the
# notebook context, so no credentials are passed here.
client = VectorSearchClient()

# To authenticate as a service principal instead, pass its OAuth credentials
# (load them from a secret scope rather than hard-coding them):
# client = VectorSearchClient(service_principal_client_id=<CLIENT_ID>,service_principal_client_secret=<CLIENT_SECRET>)

# create_endpoint_and_wait is not idempotent: the service rejects creating an
# endpoint whose name already exists. Create it only when missing so this cell
# (and Restart & Run All) can be re-run safely.
# NOTE(review): list_endpoints() returns one page of results; with very many
# endpoints an existing one may not be listed - confirm if that matters here.
existing_endpoints = {
    endpoint["name"] for endpoint in client.list_endpoints().get("endpoints", [])
}
if endpoint_name not in existing_endpoints:
    client.create_endpoint_and_wait(
        name=endpoint_name,
        endpoint_type="STANDARD"
    )

In [0]:
# Fully qualified name of the vector index, derived from the source table.
index_name = f"{source_table}_index"

# Delta Sync indexes track the source table through its change data feed, so
# enable it. Setting the property again on a re-run is harmless.
spark.sql(f"ALTER TABLE {source_table} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)")

# create_delta_sync_index_and_wait fails if the index already exists, which
# breaks re-runs. Reuse an existing index instead, and trigger a sync so it
# picks up rows added to the source table since it was last refreshed.
existing_indexes = {
    idx["name"] for idx in client.list_indexes(endpoint_name).get("vector_indexes", [])
}
if index_name in existing_indexes:
    index = client.get_index(endpoint_name=endpoint_name, index_name=index_name)
    # A TRIGGERED pipeline only refreshes on demand.
    # NOTE(review): sync() presumably starts the refresh without waiting for it to finish - confirm.
    index.sync()
else:
    # Embeddings are computed from pilot_notes by the databricks-gte-large-en
    # endpoint; the listed columns are stored with each vector so searches can return them.
    index = client.create_delta_sync_index_and_wait(
        endpoint_name=endpoint_name,
        source_table_name=source_table,
        index_name=index_name,
        pipeline_type="TRIGGERED",
        primary_key="id",
        columns_to_sync=["id", "pilot_notes", "unsupervised_prediction", "supervised_prediction"],
        embedding_source_column="pilot_notes",
        embedding_model_endpoint_name="databricks-gte-large-en"
    )

In [0]:
# Register a SQL function (in the current catalog/schema) that:
#   - retrieves the 10 notes nearest to `query` from the vector index,
#   - groups them by unsupervised_prediction,
#   - returns the most frequent category and how many of the 10 notes have it.
# The declared return type makes doc_count a STRING.
# Ties on doc_count are broken by category name (NULLs last) so the function
# returns the same answer on every run instead of an arbitrary tied category.
index_name = f"{source_table}_index"
spark.sql(f"""
CREATE OR REPLACE FUNCTION get_similar_category(
    query STRING COMMENT "The string to search for similar notes to"
) RETURNS STRUCT<category: STRING, doc_count: STRING>
COMMENT "Returns the most common category related to the query"
RETURN (
    SELECT STRUCT(category, doc_count)
    FROM (
        SELECT unsupervised_prediction as category, COUNT(*) as doc_count
        FROM VECTOR_SEARCH(
                index => '{index_name}',
                query => get_similar_category.query,
                num_results => 10)
        GROUP BY unsupervised_prediction
        ORDER BY doc_count DESC, category ASC NULLS LAST
        LIMIT 1
    )
)
""")

# Smoke-test the function on a sample phrase.
spark.sql("SELECT get_similar_category('getting ready for landing')").display()

In [0]:
# Attach the nearest-neighbour category to every note and persist the result.
# get_similar_category performs a vector search per call, so its result is
# projected once per row in the inner query and its struct fields are unpacked
# afterwards, rather than invoking it separately for each field.
# NOTE(review): every note in source_table is also in the index, so a note's own
# entry is presumably among its 10 neighbours and counts toward its own
# category - confirm whether that is intended.
final_df = spark.sql(f"""
    SELECT
        pilot_notes,
        id,
        similar_result.category AS similar_category,
        similar_result.doc_count AS similar_doc_count
    FROM (
        SELECT pilot_notes, id, get_similar_category(pilot_notes) AS similar_result
        FROM {source_table}
    ) AS notes_with_similar
""")
final_df.write.mode("overwrite").saveAsTable(target_table)
spark.read.table(target_table).display()