In [0]:
%sql
-- Work inside this catalog and a dedicated `vector_search` schema (created on first run);
-- unqualified objects defined later in this session, such as the UDF below, resolve here.
USE CATALOG lucasbruand_catalog;

CREATE SCHEMA IF NOT EXISTS vector_search;
USE SCHEMA vector_search;

In [0]:
# Pull the API token and the workspace API URL from this notebook's session context;
# both are used to call the Vector Search REST API in the cells below.
_notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
DATABRICKS_TOKEN = _notebook_ctx.apiToken().get()
workspaceUrl = _notebook_ctx.apiUrl().get()

In [0]:
# Never echo the raw API token: cell output is saved with the notebook and travels with it
# when shared or exported. Show only enough to confirm a token was obtained.
print(f"DATABRICKS_TOKEN: <redacted, {len(DATABRICKS_TOKEN)} chars>")
print(workspaceUrl)

In [0]:
import requests

# Vector Search index to query, fully qualified as <catalog>.<schema>.<index>.
# NOTE(review): this index lives in a different catalog than the one selected at the top
# of the notebook — confirm it is the intended index and that this user can read it.
index_name = "robert_mosley.customer_service.product_docs_ind"
# Query endpoint for that index, built from the workspace URL fetched earlier.
url = f"{workspaceUrl}/api/2.0/vector-search/indexes/{index_name}/query"

# Headers
headers = {
    "Authorization": f"Bearer {DATABRICKS_TOKEN}",
    "Content-Type": "application/json;charset=UTF-8",
    "Accept": "application/json, text/plain, */*"
}

# Request payload: a text query, returning the listed columns for up to 1024 hits.
payload = {
    "num_results": 1024,
    "columns": ['product_id', 'indexed_doc'],
    "query_text": "example_query"
}

# Bound the wait so a stalled endpoint cannot hang the cell indefinitely.
REQUEST_TIMEOUT_S = 60
response = requests.post(url, headers=headers, json=payload, timeout=REQUEST_TIMEOUT_S)

# Display the response
print(f"Status Code: {response.status_code}")
# Error responses are not guaranteed to carry a JSON body; show the raw text rather than
# raising a decode error that hides the status above.
try:
    print(f"Response: {response.json()}")
except ValueError:
    print(f"Response (non-JSON): {response.text}")

In [0]:
%sql
-- Query several Vector Search indexes with one query vector and merge ("OR") their hits.
-- Returns a JSON string shaped like a single-index query response:
--   {"manifest": <first manifest received, with a 'source_index' column appended>,
--    "result": {"data_array": <rows from all indexes, sorted by score descending when a
--               'score' column exists>, "num_results": <total row count>}}
-- Each row has the name of the index it came from appended. An index whose request fails
-- contributes one row of nulls with an "ERROR: ..." message in that last position.
-- Any unexpected failure yields {"error": "<message>"} instead.
CREATE OR REPLACE FUNCTION vector_search_rest_multi_index_or(
  index_names ARRAY<STRING>,
  query_vector ARRAY<DOUBLE>,
  columns ARRAY<STRING>,
  num_results INT,
  workspace_url STRING,
  token STRING
)
RETURNS STRING
LANGUAGE PYTHON
AS $$
import requests
import json

# Per-request limit in seconds: without it, one unreachable index stalls the whole SQL query.
REQUEST_TIMEOUT_S = 60

def get_score_index(result_data):
    """Extract score column index and manifest from API response.

    Returns (score_index, manifest). score_index is None when the manifest has no
    column named 'score'; both are None when the response has no 'manifest' key.
    """
    if 'manifest' not in result_data:
        return None, None

    manifest = result_data['manifest']
    score_index = None

    for idx, col in enumerate(manifest.get('columns', [])):
        if col.get('name') == 'score':
            score_index = idx
            break

    return score_index, manifest

try:
    all_results = []
    score_index = None
    manifest = None

    # Headers and payload are identical for every index; only the URL changes.
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json;charset=UTF-8",
        "Accept": "application/json, text/plain, */*"
    }
    payload = {
        "num_results": num_results,
        "columns": columns,
        "query_vector": query_vector
    }

    for index_name in index_names:
        url = f"{workspace_url}/api/2.0/vector-search/indexes/{index_name}/query"

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=REQUEST_TIMEOUT_S)
        except requests.exceptions.RequestException as e:
            # Report network failures (timeout, connection error, ...) per index, the same way
            # HTTP errors are reported below, instead of discarding hits from other indexes.
            # Nulls cover the requested columns plus score; the message sits in source_index.
            error_entry = [None] * (len(columns) + 1)
            error_entry.append(f"ERROR: Index {index_name} request failed: {e}")
            all_results.append(error_entry)
            continue

        if response.status_code == 200:
            result_data = response.json()

            # Find the score column index from the manifest (kept from the first response that has one)
            if score_index is None:
                score_index, manifest = get_score_index(result_data)

            # Add results tagged with their source index name
            if 'result' in result_data and 'data_array' in result_data['result']:
                for row in result_data['result']['data_array']:
                    row.append(index_name)
                    all_results.append(row)
        else:
            # Nulls cover the requested columns plus score; the message sits in source_index.
            error_entry = [None] * (len(columns) + 1)
            error_entry.append(f"ERROR: Index {index_name} failed with status {response.status_code}")
            all_results.append(error_entry)

    # Sort by score (descending) if a score column was found; error rows (null score) sink to the end.
    if score_index is not None and all_results:
        all_results.sort(key=lambda x: float(x[score_index]) if isinstance(x, list) and len(x) > score_index and x[score_index] is not None else float('-inf'), reverse=True)

    # Describe the appended source_index column; tolerate a manifest without a 'columns' list.
    if manifest:
        manifest.setdefault('columns', []).append({'name': 'source_index', 'type': 'string'})

    # Return combined and sorted results
    return json.dumps({
        "manifest": manifest,
        "result": {
            "data_array": all_results,
            "num_results": len(all_results)
        }
    })
except Exception as e:
    return json.dumps({"error": str(e)})
$$

In [0]:
# Read the workspace API URL and session API token under the names used by the UDF call below.
workspace_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
print(workspace_url)
# Never echo the raw token: cell output is saved and shared along with the notebook.
print(f"token: <redacted, {len(token)} chars>")

In [0]:
# Dummy query vector: a constant embedding used only to exercise the multi-index UDF.
# NOTE(review): EMBEDDING_DIM presumably has to match the embedding dimension of the
# queried index — 1024 is assumed here; confirm against the index definition.
EMBEDDING_DIM = 1024
emb = [0.1] * EMBEDDING_DIM
print(len(emb))
# Show a short preview rather than dumping every component into the cell output.
print(",".join(str(x) for x in emb[:5]) + ", ...")

In [0]:
def _sql_str(value):
    """Render a value as a Spark SQL single-quoted string literal.

    Spark SQL treats backslash as the escape character inside string literals, so
    backslashes are doubled and single quotes are written as \\'.
    """
    return "'" + str(value).replace("\\", "\\\\").replace("'", "\\'") + "'"


# Maximum hits requested from each index; the UDF merges them, so up to this many rows per index.
NUM_RESULTS_PER_INDEX = 103
query_vector_sql = ",".join(str(x) for x in emb)

# NOTE(review): the token is inlined into the SQL text, so it is visible wherever query text is
# recorded (query history, Spark UI). TODO: resolve it server-side instead, e.g. via `secret()`.
df = spark.sql(f"""
SELECT vector_search_rest_multi_index_or(
  array({_sql_str(index_name)}, {_sql_str(index_name)}),
  array({query_vector_sql}),
  array('product_id', 'indexed_doc'),
  {NUM_RESULTS_PER_INDEX},
  workspace_url => {_sql_str(workspace_url)},
  token => {_sql_str(token)}
)
""")

In [0]:
# Show the single-row result: one column holding the UDF's JSON string — merged hits with a
# manifest, or {"error": ...} if the UDF failed.
display(df)