In [None]:
from pyspark.sql import SparkSession
from pyspark_hnsw.knn import *
import boto3
import pyspark_hnsw
import requests

from pyspark.ml import Pipeline
from pyspark_hnsw.knn import *
from pyspark.ml.feature import HashingTF, IDF, Tokenizer
from pyspark.sql.functions import col, posexplode

# Connect to the standalone cluster and have Spark resolve the hnswlib-spark
# package (artifact suffix suggests a Spark 3.5 / Scala 2.12 build).
spark = (
    SparkSession.builder
    .master("spark://spark-master:7077")
    .config("spark.executor.memory", "12g")
    .config("spark.jars.packages", "com.github.jelmerk:hnswlib-spark_3_5_2.12:2.0.0-alpha.8")
    .config("spark.ui.showConsoleProgress", "false")
    .getOrCreate()
)

# Keep cell output readable by suppressing everything below ERROR.
spark.sparkContext.setLogLevel("ERROR")

:: loading settings :: url = jar:file:/Workspace/.venv/lib/python3.12/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/vscode/.ivy2/cache
The jars for the packages stored in: /home/vscode/.ivy2/jars
com.github.jelmerk#hnswlib-spark_3_5_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-3b46e5aa-622f-4d9b-87b2-fdae476ee09a;1.0
	confs: [default]
	found com.github.jelmerk#hnswlib-spark_3_5_2.12;2.0.0-alpha.8 in central
:: resolution report :: resolve 57ms :: artifacts dl 2ms
	:: modules in use:
	com.github.jelmerk#hnswlib-spark_3_5_2.12;2.0.0-alpha.8 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   1   |   0   |   0   |   0   ||   1   |   0   |
	---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark

In [2]:
# Show the Spark version and the pyspark_hnsw version, one per line.
print(spark.version, pyspark_hnsw.version(), sep="\n")

3.5.4
2.0.0a8


Load the product data from the [Instacart Market Basket Analysis Kaggle competition](https://www.kaggle.com/c/instacart-market-basket-analysis/data?select=products.csv.zip).


In [3]:
url = "https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz"

# Stream the CSV from Google Drive straight into S3 without staging it locally.
# Using the response as a context manager releases the connection when done.
with requests.get(url, stream=True, timeout=60) as response:
    # Fail loudly on an HTTP error instead of uploading an error page as products.csv.
    response.raise_for_status()
    # response.raw yields the undecoded wire bytes; decode any gzip/deflate
    # Content-Encoding so S3 receives the plain CSV.
    response.raw.decode_content = True

    s3 = boto3.client("s3")
    s3.upload_fileobj(response.raw, "spark", "sparse_vector_input/products.csv")

In [4]:
product_df = spark.read.csv("s3a://spark/sparse_vector_input/products.csv", header=True)  # no schema inference: all columns are strings

In [5]:
# Sanity-check the load before building indexes. product_id must be unique
# because later cells join neighbour ids back onto product_df on that column,
# and duplicates would silently repeat result rows.
product_count = product_df.count()
assert product_count > 0, "products.csv loaded no rows"
assert product_df.select("product_id").distinct().count() == product_count, "product_id is not unique"
product_count

49688

In [6]:
# Featurization shared by both pipelines: split product names into words,
# hash the words into raw term-frequency vectors, then reweight them by IDF.
tokenizer = Tokenizer(outputCol="words", inputCol="product_name")
hashingTF = HashingTF(outputCol="rawFeatures", inputCol="words")
idf = IDF(outputCol="features", inputCol="rawFeatures")

Create a simple TF-IDF model that turns product names into sparse word vectors and adds them to an exact k-NN index.

An exact (brute-force) index gives 100% correct results and is quick to build, but it is very slow to query. It is only appropriate during development or for comparisons against an approximate index.


In [7]:
# Exact (brute-force) k-NN over the TF-IDF "features" vectors, keyed by
# product_id, using cosine distance and a single partition/thread.
bruteforce = BruteForceSimilarity(
    identifierCol="product_id",
    featuresCol="features",
    distanceFunction="cosine",
    k=5,
    numPartitions=1,
    numThreads=1,
)

In [8]:
exact_pipeline = Pipeline().setStages([tokenizer, hashingTF, idf, bruteforce])  # featurize, then exact index

In [9]:
exact_model = exact_pipeline.fit(dataset=product_df)  # fits IDF and builds the exact index

Next, create the same model, but add the TF-IDF vectors to an HNSW index.

In [10]:
# Approximate k-NN via an HNSW index over the same TF-IDF "features" vectors,
# keyed by product_id, using cosine distance on one partition with two threads.
hnsw = HnswSimilarity(
    identifierCol="product_id",
    featuresCol="features",
    distanceFunction="cosine",
    k=5,
    numPartitions=1,
    numThreads=2,
)

In [11]:
hnsw_pipeline = Pipeline().setStages([tokenizer, hashingTF, idf, hnsw])  # featurize, then HNSW index

In [12]:
hnsw_model = hnsw_pipeline.fit(dataset=product_df)  # fits IDF and builds the HNSW index

Select a record to query

In [13]:
# product_id is a string column (the CSV was read without schema inference),
# so compare against a string literal instead of relying on an implicit cast.
QUERY_PRODUCT_ID = "43572"
queries = product_df.filter(col("product_id") == QUERY_PRODUCT_ID)

In [14]:
queries.show(20, False)  # up to 20 rows, without truncating long product names

+----------+-----------------------------+--------+-------------+
|product_id|product_name                 |aisle_id|department_id|
+----------+-----------------------------+--------+-------------+
|43572     |Alcaparrado Manzanilla Olives|110     |13           |
+----------+-----------------------------+--------+-------------+



In [15]:
# Flatten the exact model's prediction array into one row per neighbour,
# keeping its rank (pos) and distance, then attach the product details.
exact_neighbors = (
    exact_model.transform(queries)
    .select(posexplode(col("prediction")).alias("pos", "item"))
    .select(
        col("pos"),
        col("item.neighbor").alias("product_id"),
        col("item.distance").alias("distance"),
    )
)
exact_neighbors.join(product_df, ["product_id"]).orderBy(["pos"]).show(truncate=False)

+----------+---+----------------------+---------------------------------+--------+-------------+
|product_id|pos|distance              |product_name                     |aisle_id|department_id|
+----------+---+----------------------+---------------------------------+--------+-------------+
|43572     |0  |-2.220446049250313E-16|Alcaparrado Manzanilla Olives    |110     |13           |
|27806     |1  |0.2961162117528633    |Manzanilla Olives                |110     |13           |
|16721     |2  |0.40715716898722976   |Manzanilla Stuffed Olives        |110     |13           |
|25125     |3  |0.40715716898722976   |Stuffed Manzanilla Olives        |110     |13           |
|39833     |4  |0.49516580877903393   |Pimiento Sliced Manzanilla Olives|110     |13           |
+----------+---+----------------------+---------------------------------+--------+-------------+



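Show the results from the HNSW model. They match the exact results. The only difference is that the two products tied at the same distance ("Manzanilla Stuffed Olives" and "Stuffed Manzanilla Olives") come back in a different order.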

In [16]:
# Same flattening as for the exact model: one row per neighbour with its rank
# and distance, joined back to the product table and shown in rank order.
(
    hnsw_model.transform(queries)
    .select(posexplode(col("prediction")).alias("pos", "item"))
    .select(
        col("pos"),
        col("item.neighbor").alias("product_id"),
        col("item.distance").alias("distance"),
    )
    .join(product_df, ["product_id"])
    .orderBy(["pos"])
    .show(truncate=False)
)

+----------+---+----------------------+---------------------------------+--------+-------------+
|product_id|pos|distance              |product_name                     |aisle_id|department_id|
+----------+---+----------------------+---------------------------------+--------+-------------+
|43572     |0  |-2.220446049250313E-16|Alcaparrado Manzanilla Olives    |110     |13           |
|27806     |1  |0.2961162117528633    |Manzanilla Olives                |110     |13           |
|25125     |2  |0.40715716898722976   |Stuffed Manzanilla Olives        |110     |13           |
|16721     |3  |0.40715716898722976   |Manzanilla Stuffed Olives        |110     |13           |
|39833     |4  |0.49516580877903393   |Pimiento Sliced Manzanilla Olives|110     |13           |
+----------+---+----------------------+---------------------------------+--------+-------------+



Dispose of the resources held by the k-NN stages (the last stage of each fitted pipeline).

In [17]:
# Release the resources held by the fitted HNSW stage (the last pipeline stage).
# Index by position rather than unpacking into `_`:
#   - assigning `_` makes IPython stop updating its `_`/`__`/`___` output history;
#   - `stages[-1]` keeps working if the number of featurization stages changes.
hnsw_stage = hnsw_model.stages[-1]
hnsw_stage.dispose()

In [18]:
# Release the resources held by the fitted brute-force stage (the last pipeline stage).
# Index by position rather than unpacking into `_`:
#   - assigning `_` makes IPython stop updating its `_`/`__`/`___` output history;
#   - `stages[-1]` keeps working if the number of featurization stages changes.
bruteforce_stage = exact_model.stages[-1]
bruteforce_stage.dispose()