# Embedding Text with Local (Per-Node) NVIDIA TensorRT Acceleration and GPU-Based Approximate Nearest Neighbor (ANN) Search

This demo extends the existing [Azure OpenAI based demo](https://github.com/microsoft/SynapseML/blob/master/docs/Explore%20Algorithms/OpenAI/Quickstart%20-%20OpenAI%20Embedding%20and%20GPU%20based%20KNN.ipynb), in which encoding was performed through OpenAI requests and KNN used GPU-based brute-force search. This tutorial shows how to compute fast local embeddings with [multilingual E5 text embeddings](https://arxiv.org/abs/2402.05672) and run fast approximate nearest neighbor search with the IVFFlat algorithm. All stages of the tutorial are accelerated on NVIDIA GPUs using [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt) and [Spark Rapids ML](https://github.com/NVIDIA/spark-rapids-ml). The tutorial folder contains two benchmark notebooks that demonstrate the advantages of this GPU-based approach compared to the [previous CPU-based demo](https://github.com/microsoft/SynapseML/blob/master/docs/Explore%20Algorithms/OpenAI/Quickstart%20-%20OpenAI%20Embedding.ipynb).
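For comparison, the brute-force search used by the earlier KNN demo scores every stored vector against the query and keeps the k closest; an ANN index such as IVFFlat approximates this result while scanning far fewer vectors. The following is a minimal CPU-only sketch in plain Python; the function names and toy vectors are illustrative and not part of the tutorial code.

```python
import heapq
import math


def euclidean(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def brute_force_knn(query, vectors, k):
    """Exact k-nearest-neighbor search: compare the query with every vector.

    Returns a list of (id, distance) pairs sorted by increasing distance.
    The cost is O(n * d) per query, which is what ANN indexes try to avoid.
    """
    scored = ((vid, euclidean(query, vec)) for vid, vec in vectors.items())
    return heapq.nsmallest(k, scored, key=lambda pair: pair[1])


# Toy 2-D "embeddings" keyed by id.
vectors = {0: [0.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 2.0], 3: [5.0, 5.0]}
print(brute_force_knn([0.1, 0.1], vectors, k=2))
```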

The key prerequisites for this quickstart include a working Azure OpenAI resource and an Apache Spark cluster with SynapseML installed. We suggest creating a Synapse workspace; this notebook, however, was run on a Databricks GPU cluster of Standard_NC24ads_A100_v4 nodes with 6 workers. The Databricks Runtime was 13.3 LTS ML (includes Apache Spark 3.4.1, GPU, Scala 2.12), with the related [init_script](https://github.com/microsoft/SynapseML/tree/master/tools/init_scripts) installing all required packages.


## Step 1: Prepare Environment

Import the required libraries and apply the initial settings, including recording the demo start time.

In [0]:

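# Suppress noisy warnings and log messages from third-party libraries
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='tritonclient.grpc')
import logging
logging.getLogger('py4j').setLevel(logging.ERROR)
import mlflow
import datetime
import pytz
# GPU-accelerated approximate nearest neighbors from Spark RAPIDS ML
from spark_rapids_ml.knn import ApproximateNearestNeighbors, ApproximateNearestNeighborsModel
# EmbeddingTransformer: the tutorial's embedding helper module
from sentence_embedding_transformer import EmbeddingTransformer

logging.getLogger('sentence_transformers.SentenceTransformer').setLevel(logging.ERROR)
# Disable MLflow autologging for this demo
mlflow.autolog(disable=True)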

# Define the PST timezone
pst_timezone = pytz.timezone('US/Pacific')

# Get the current time in UTC and convert it to PST
current_start_time_utc = datetime.datetime.now(pytz.utc)
current_time_pst = current_start_time_utc.astimezone(pst_timezone)

print("Current time in PST:", current_time_pst.strftime('%Y-%m-%d %H:%M:%S %Z%z'))


Start demo run with 1000 input rows
Current time in PST: 2024-06-11 14:51:24 PDT-0700
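Step 1 records the start time as wall-clock UTC and converts it to PST for display. If only the duration matters, a monotonic clock such as `time.perf_counter()` is unaffected by system clock adjustments. The sketch below wraps the same hours/minutes/seconds split used in the final cell; the `timed` and `format_elapsed` helpers are hypothetical and not part of the tutorial.

```python
import time
from contextlib import contextmanager


def format_elapsed(total_seconds):
    """Split a duration in seconds into (hours, minutes, seconds)."""
    hours, remainder = divmod(int(total_seconds), 3600)
    minutes, seconds = divmod(remainder, 60)
    return hours, minutes, seconds


@contextmanager
def timed(label):
    """Measure the wall-clock duration of a block with a monotonic clock."""
    start = time.perf_counter()
    try:
        yield
    finally:
        h, m, s = format_elapsed(time.perf_counter() - start)
        print(f"{label}: h: {h}, min: {m}, sec: {s}")


with timed("Demo"):
    time.sleep(0.1)  # stand-in for the demo steps
```

For example, `format_elapsed(276)` returns `(0, 4, 36)`, the same split that the final cell prints for this run.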


## Step 2: Load Data

In this demo, we explore a dataset of fine food reviews.

In [0]:
dataTransformer = EmbeddingTransformer(inputCol="combined", outputCol="embeddings", useTRTFlag=True, batchSize=16)

# Load the food reviews, limiting the number of rows (1000 in this run; the limit can be raised up to 1000000)
df = dataTransformer.load_data_food_reviews(spark=spark, limit=1000).repartition(10).cache()

## Step 3: Generate Embeddings

First, generate embeddings using a SentenceTransformer model optimized with NVIDIA TensorRT.

In [0]:
all_embeddings = dataTransformer.transform(df)
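The `batchSize=16` argument controls how many texts are encoded per model call. As an assumption about what `EmbeddingTransformer` does internally (its source lives in the tutorial folder and is not shown here), the idea can be sketched as chunking each partition's texts into fixed-size batches before calling the encoder. `encode_fn` and `toy_encoder` are hypothetical stand-ins for the TensorRT-optimized model.

```python
from typing import Callable, Iterator, List, Sequence


def batched(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def encode_in_batches(texts: Sequence[str],
                      encode_fn: Callable[[Sequence[str]], List[List[float]]],
                      batch_size: int = 16) -> List[List[float]]:
    """Encode texts batch by batch and concatenate the results in order."""
    embeddings: List[List[float]] = []
    for batch in batched(texts, batch_size):
        embeddings.extend(encode_fn(batch))
    return embeddings


def toy_encoder(batch):
    # Hypothetical 2-D "embedding": (text length, vowel count).
    return [[float(len(t)), float(sum(c in "aeiou" for c in t))] for t in batch]


texts = [f"review {i}" for i in range(40)]
vectors = encode_in_batches(texts, toy_encoder, batch_size=16)
print(len(vectors))  # 40 embeddings, produced in batches of 16, 16 and 8
```

Larger batches usually improve GPU utilization at the cost of more device memory per call.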

## Step 4: Build the queries against the embeddings

Compute the query embeddings with a standard SentenceTransformer running only on the driver, then convert the results to a DataFrame.

In [0]:
# Sample queries
queries = ["desserts", "disgusting"]

# Create an EmbeddingTransformer instance that encodes on the driver only,
# which speeds up processing of a small number of queries
#embedding_transformer = EmbeddingTransformer(driverOnly=True, spark=spark)
embedding_transformer = EmbeddingTransformer(driverOnly=True)
query_embeddings = embedding_transformer.transform(queries, spark=spark)

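Both `fit` and `kneighbors` in the following steps select an `id` column and an `embeddings` column, so the query DataFrame needs that shape. The sketch below builds such rows from driver-side results, assuming ids start at 1 (as the `query_id` values in the Step 7 output suggest). `to_query_rows` and `toy_encoder` are hypothetical; the real driver-side SentenceTransformer returns much longer vectors.

```python
from typing import Callable, List, Sequence, Tuple

Row = Tuple[int, str, List[float]]


def to_query_rows(queries: Sequence[str],
                  encode_fn: Callable[[Sequence[str]], List[List[float]]],
                  first_id: int = 1) -> List[Row]:
    """Pair each query with a numeric id and its embedding vector."""
    vectors = encode_fn(queries)
    if len(vectors) != len(queries):
        raise ValueError("encoder must return one vector per query")
    return [(first_id + i, q, [float(x) for x in v])
            for i, (q, v) in enumerate(zip(queries, vectors))]


def toy_encoder(batch):
    # Hypothetical 2-D embedding: (text length, constant 1.0).
    return [[float(len(t)), 1.0] for t in batch]


rows = to_query_rows(["desserts", "disgusting"], toy_encoder)
print(rows)  # [(1, 'desserts', [8.0, 1.0]), (2, 'disgusting', [10.0, 1.0])]
# With PySpark, these rows could then become a DataFrame:
# query_df = spark.createDataFrame(rows, ["id", "query", "embeddings"])
```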

## Step 5: Build a fast vector index over the review embeddings

We use the fast GPU-based ApproximateNearestNeighbors estimator from NVIDIA Spark Rapids ML.

In [0]:
rapids_knn = ApproximateNearestNeighbors(k=5)
rapids_knn.setInputCol("embeddings").setIdCol("id")

rapids_knn_model = rapids_knn.fit(all_embeddings.select("id", "embeddings"))
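As a rough CPU-only illustration of the IVFFlat idea (not the RAPIDS GPU implementation): the index clusters the vectors into `nlist` cells with k-means, stores each vector uncompressed ("flat") in the inverted list of its nearest centroid, and at query time scans only the `nprobe` cells whose centroids are closest to the query. The names `nlist` and `nprobe` follow common IVF terminology; all functions below are hypothetical helpers written for this sketch.

```python
import math
import random
from collections import defaultdict


def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest(point, centroids):
    """Index of the centroid closest to point."""
    return min(range(len(centroids)), key=lambda c: l2(point, centroids[c]))


def kmeans(vectors, nlist, iters=10, seed=0):
    """Plain Lloyd's k-means, used as the coarse quantizer."""
    rng = random.Random(seed)
    centroids = [list(v) for v in rng.sample(vectors, nlist)]
    for _ in range(iters):
        groups = defaultdict(list)
        for v in vectors:
            groups[nearest(v, centroids)].append(v)
        for c, members in groups.items():
            centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def build_ivfflat(vectors, nlist):
    """Return centroids and the inverted lists of (id, vector) per cell."""
    centroids = kmeans(vectors, nlist)
    lists = defaultdict(list)
    for vid, v in enumerate(vectors):
        lists[nearest(v, centroids)].append((vid, v))
    return centroids, lists


def search_ivfflat(query, centroids, lists, k, nprobe):
    """Scan only the nprobe closest cells, with exact distances inside them."""
    cells = sorted(range(len(centroids)), key=lambda c: l2(query, centroids[c]))
    candidates = [(vid, l2(query, v)) for c in cells[:nprobe] for vid, v in lists[c]]
    return sorted(candidates, key=lambda pair: pair[1])[:k]


# Two well-separated toy clusters of 2-D points.
rng = random.Random(42)
data = [[rng.gauss(0, 0.1), rng.gauss(0, 0.1)] for _ in range(50)]
data += [[rng.gauss(5, 0.1), rng.gauss(5, 0.1)] for _ in range(50)]
centroids, lists = build_ivfflat(data, nlist=4)
print(search_ivfflat([5.0, 5.0], centroids, lists, k=3, nprobe=2))
```

Raising `nprobe` trades speed for accuracy; with `nprobe == nlist` every cell is scanned and the result equals exact brute-force search.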

## Step 6: Find top k Nearest Neighbors

We use the fast IVFFlat ANN algorithm from Rapids.

In [0]:
(_, _, knn_df) = rapids_knn_model.kneighbors(query_embeddings.select("id", "embeddings"))
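Because IVFFlat scans only part of the index, its neighbors can differ from an exact search. A common way to quantify this is recall@k: the fraction of the exact top-k ids that the approximate search also returns. The sketch below assumes you have both result lists per query (for example, from a brute-force run on a sample); the helper name and the ids are illustrative only, not results from this demo.

```python
from typing import Dict, Sequence


def recall_at_k(approx: Dict[int, Sequence[int]],
                exact: Dict[int, Sequence[int]],
                k: int) -> float:
    """Average fraction of exact top-k neighbors recovered by the ANN search."""
    if not exact:
        raise ValueError("exact results are empty")
    total = 0.0
    for query_id, true_ids in exact.items():
        truth = set(true_ids[:k])
        found = set(approx.get(query_id, [])[:k])
        total += len(truth & found) / len(truth)
    return total / len(exact)


# Toy ids per query id.
approx = {1: [10, 11, 12, 13, 14], 2: [20, 21, 22, 23, 24]}
exact = {1: [10, 11, 12, 13, 99], 2: [20, 21, 22, 23, 24]}
print(recall_at_k(approx, exact, k=5))  # (0.8 + 1.0) / 2 = 0.9
```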

## Step 7: Collect and display results

In [0]:
display(knn_df)

print("Demo finished")

# Get the current time in UTC and convert it to PST
current_end_time_utc = datetime.datetime.now(pytz.utc)
current_time_pst = current_end_time_utc.astimezone(pst_timezone)

print("Current time in PST:", current_time_pst.strftime('%Y-%m-%d %H:%M:%S %Z%z'))

dif = current_end_time_utc - current_start_time_utc

# Extract hours, minutes, and seconds from the difference
total_seconds = int(dif.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)

# Print the difference in the desired format
print(f"Difference: h: {hours}, min: {minutes}, sec: {seconds}")

query_id,indices,distances
1,"List(737, 595, 308, 106, 591)","List(0.6997133, 0.70115274, 0.7020032, 0.7047766, 0.7060983)"
2,"List(58, 860, 194, 827, 614)","List(0.6791535, 0.6887464, 0.702577, 0.70618147, 0.7084421)"


Demo finished
Current time in PST: 2024-06-11 14:56:00 PDT-0700
Difference: h: 0, min: 4, sec: 36
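`knn_df` holds one row per query with parallel `indices` and `distances` arrays. To join the neighbors back to the review text, it is often handier to have one row per (query, neighbor) pair. Below is a plain-Python sketch of that reshaping; `explode_neighbors` is a hypothetical helper, and in PySpark a similar result could be obtained with `arrays_zip` and `explode` from `pyspark.sql.functions` (not used in the original notebook).

```python
from typing import List, Sequence, Tuple


def explode_neighbors(rows: Sequence[Tuple[int, Sequence[int], Sequence[float]]]
                      ) -> List[Tuple[int, int, int, float]]:
    """Turn (query_id, indices, distances) rows into
    (query_id, rank, neighbor_id, distance) rows, with rank starting at 1."""
    out = []
    for query_id, indices, distances in rows:
        if len(indices) != len(distances):
            raise ValueError(f"query {query_id}: indices/distances length mismatch")
        for rank, (nid, dist) in enumerate(zip(indices, distances), start=1):
            out.append((query_id, rank, nid, dist))
    return out


# The first two neighbors of query 1 from the output above.
rows = [(1, [737, 595], [0.6997133, 0.70115274])]
for r in explode_neighbors(rows):
    print(r)
```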
