# Generating and Storing Embeddings

## Set Up Azure API Keys

In [1]:
from notebookutils.mssparkutils.credentials import getSecret

KEYVAULT_ENDPOINT = "https://rag-demo-east-us-kv.vault.azure.net/"


def _kv_secret(secret_name):
    """Return the named secret from the Key Vault at KEYVAULT_ENDPOINT."""
    return getSecret(KEYVAULT_ENDPOINT, secret_name)


# Azure AI Search: service name and key come from Key Vault, the index name is fixed
AI_SEARCH_NAME = _kv_secret("AI-SEARCH-NAME")
AI_SEARCH_API_KEY = _kv_secret("AI-SEARCH-API-KEY")
AI_SEARCH_INDEX_NAME = "rag-demo-index"

# Azure AI Services
AI_SERVICES_NAME = _kv_secret("AI-SERVICES-NAME")
AI_SERVICES_API_KEY = _kv_secret("AI-SERVICES-API-KEY")
AI_SERVICES_LOCATION = "eastus"

# Azure OpenAI - (if F64 SKU is not used)
OPEN_AI_NAME = _kv_secret("OPEN-AI-NAME")
OPEN_AI_API_KEY = _kv_secret("OPEN-AI-API-KEY")
# Embedding deployment; its vectors have 1536 dimensions
OPEN_AI_EMBEDDING_DEPLOYMENT_NAME = "text-embedding-ada-002"
# GPT deployment name; could be one of {gpt-35-turbo, gpt-35-turbo-16k}
OPEN_AI_GPT_DEPLOYMENT_NAME = "gpt-35-turbo-16k"

StatementMeta(, fa238d40-820f-42e6-9566-09ff21a5597f, 3, Finished, Available)

## Generate Embeddings

In [2]:
from pyspark.sql import SparkSession
# Load a capped sample of the chunked documents to embed
chunk_query = "SELECT * FROM fabric_demo.document_chunks LIMIT 1000"
exploded_df = spark.sql(chunk_query)

StatementMeta(, fa238d40-820f-42e6-9566-09ff21a5597f, 4, Finished, Available)

In [3]:
from synapse.ml.services import OpenAIEmbedding

# Transformer that calls the Azure OpenAI embedding deployment for every row
embedding = (
    OpenAIEmbedding()
    # Which service and deployment to call, and how to authenticate
    .setCustomServiceName(OPEN_AI_NAME)
    .setSubscriptionKey(OPEN_AI_API_KEY)
    .setDeploymentName(OPEN_AI_EMBEDDING_DEPLOYMENT_NAME)
    # Input text column, output vector column, and per-row error column
    .setTextCol("chunk")
    .setOutputCol("embeddings")
    .setErrorCol("error")
)

# Adds the "embeddings" and "error" columns to the chunk dataframe
df_embeddings = embedding.transform(exploded_df)

display(df_embeddings)

StatementMeta(, fa238d40-820f-42e6-9566-09ff21a5597f, 5, Finished, Available)

Failed to fetch cluster details
Traceback (most recent call last):
  File "/home/trusted-service-user/cluster-env/trident_env/lib/python3.10/site-packages/synapse/ml/fabric/token_utils.py", line 188, in _get_openai_mwc_token
    raise Exception(
Exception: get openai mwc token returns 403:b'{"Message":"FT1 SKU Not Supported","Source":"ML","error_code":"PERMISSION_DENIED"}'


SynapseWidget(Synapse.DataFrame, 703f4d87-a9ed-4098-b942-cb23ad79172c)

## Saving Embeddings

In [4]:
import requests
import json

# Length of the embedding vector (OpenAI ada-002 generates embeddings of length 1536)
EMBEDDING_LENGTH = 1536

# Seconds to wait for Azure AI Search before giving up. Without a timeout,
# requests waits forever and an unresponsive service would hang this cell.
INDEX_REQUEST_TIMEOUT = 30

# Create (or update) the AI Search index with fields id, content, and contentVector.
# Note the datatypes for each field below.
url = f"https://{AI_SEARCH_NAME}.search.windows.net/indexes/{AI_SEARCH_INDEX_NAME}?api-version=2023-11-01"
payload = json.dumps(
    {
        "name": AI_SEARCH_INDEX_NAME,
        "fields": [
            # Unique identifier for each document
            {
                "name": "id",
                "type": "Edm.String",
                "key": True,
                "filterable": True,
            },
            # Text content of the document
            {
                "name": "content",
                "type": "Edm.String",
                "searchable": True,
                "retrievable": True,
            },
            # Vector embedding of the text content
            {
                "name": "contentVector",
                "type": "Collection(Edm.Single)",
                "searchable": True,
                "retrievable": True,
                "dimensions": EMBEDDING_LENGTH,
                "vectorSearchProfile": "vectorConfig",
            },
        ],
        # The profile referenced by contentVector points at a cosine HNSW algorithm
        "vectorSearch": {
            "algorithms": [{"name": "hnswConfig", "kind": "hnsw", "hnswParameters": {"metric": "cosine"}}],
            "profiles": [{"name": "vectorConfig", "algorithm": "hnswConfig"}],
        },
    }
)
headers = {"Content-Type": "application/json", "api-key": AI_SEARCH_API_KEY}

response = requests.request(
    "PUT", url, headers=headers, data=payload, timeout=INDEX_REQUEST_TIMEOUT
)
# 201 and 204 are reported as created / updated; anything else is printed
# together with the response body so the cause is visible.
if response.status_code == 201:
    print("Index created!")
elif response.status_code == 204:
    print("Index updated!")
else:
    print(f"HTTP request failed with status code {response.status_code}")
    print(f"HTTP response body: {response.text}")


StatementMeta(, fa238d40-820f-42e6-9566-09ff21a5597f, 6, Finished, Available)

Index updated!


In [5]:
import re

from pyspark.sql.functions import monotonically_increasing_id


def insert_into_index(documents, timeout=(10, 120)):
    """Upload a list of 'documents' to the Azure AI Search index.

    Args:
        documents: List of JSON-serializable dicts, one per document; sent as
            the "value" array of the request body.
        timeout: Timeout handed to ``requests`` (seconds, or a
            ``(connect, read)`` tuple). Without one, an unresponsive service
            would block the calling Spark task indefinitely.

    Returns:
        "Success" when the service answers 200 or 201; otherwise a string
        starting with "Failure: " followed by the response body or, if the
        request itself failed (timeout, connection error), the error message.
    """
    url = f"https://{AI_SEARCH_NAME}.search.windows.net/indexes/{AI_SEARCH_INDEX_NAME}/docs/index?api-version=2023-11-01"

    payload = json.dumps({"value": documents})
    headers = {
        "Content-Type": "application/json",
        "api-key": AI_SEARCH_API_KEY,
    }

    try:
        response = requests.request(
            "POST", url, headers=headers, data=payload, timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        # Report transport errors the same way as HTTP errors so the caller
        # gets a status for this batch instead of an exception.
        return f"Failure: {exc}"

    if response.status_code in (200, 201):
        return "Success"
    return f"Failure: {response.text}"

def make_safe_id(row_id: str):
    """Return row_id with every character outside [0-9a-zA-Z_-] replaced by "_".

    The result is used as the Azure AI Search document ID.
    """
    disallowed = re.compile("[^0-9a-zA-Z_-]")
    return disallowed.sub("_", row_id)


def upload_rows(rows, batch_size=1000):
    """Upload the rows of one Spark partition to Azure AI Search in batches.

    Designed for ``rdd.mapPartitions``: consumes the partition's row iterator
    and yields one status record per uploaded batch.

    Args:
        rows: Iterable of rows that expose "unique_id", "chunk", "embeddings"
            (an object with ``.tolist()``) and "row_index" by key.
        batch_size: Maximum number of documents sent per request. Defaults to
            1000 because of the Azure AI Search API limit on documents per
            upload.
            NOTE(review): the service presumably also caps the request size;
            lower this if large batches are rejected - confirm.

    Yields:
        ``[first_row_index, last_row_index, status]`` for each batch, where
        ``status`` is the value returned by ``insert_into_index``.
    """
    rows = list(rows)
    for start in range(0, len(rows), batch_size):
        row_batch = rows[start : start + batch_size]
        # Build documents from the current batch only. Iterating over `rows`
        # here would resend the whole partition with every batch and exceed
        # the per-request limit.
        documents = [
            {
                "id": make_safe_id(row["unique_id"]),
                "content": row["chunk"],
                "contentVector": row["embeddings"].tolist(),
                "@search.action": "upload",
            }
            for row in row_batch
        ]
        status = insert_into_index(documents)
        yield [row_batch[0]["row_index"], row_batch[-1]["row_index"], status]

# Add ID to help track what rows were successfully uploaded
df_embeddings = df_embeddings.withColumn("row_index", monotonically_increasing_id())

# Run upload_rows on each partition of the dataframe
res = df_embeddings.rdd.mapPartitions(upload_rows)

# Pass an explicit schema instead of bare column names: with names only, Spark
# infers the column types by evaluating the first record, which would execute
# upload_rows on the first partition (re-sending a batch) before display()
# runs the real job.
display(res.toDF("start_index bigint, end_index bigint, insertion_status string"))

StatementMeta(, fa238d40-820f-42e6-9566-09ff21a5597f, 7, Submitted, Running)