# HnswLib Quick Start


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jelmerk/hnswlib/blob/master/hnswlib-examples/hnswlib-examples-pyspark-google-colab/quick_start_google_colab.ipynb)

We will first set up the runtime environment and give it a quick test.

In [None]:
!wget https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh -O - | bash

--2022-01-08 02:32:40--  https://raw.githubusercontent.com/jelmerk/hnswlib/master/scripts/colab_setup.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1269 (1.2K) [text/plain]
Saving to: ‘STDOUT’


2022-01-08 02:32:41 (73.4 MB/s) - written to stdout [1269/1269]

setup Colab for PySpark 3.0.3 and Hnswlib 1.0.0
Installing PySpark 3.0.3 and Hnswlib 1.0.0
     |████████████████████████████████| 209.1 MB 73 kB/s
     |████████████████████████████████| 198 kB 80.2 MB/s
  Building wheel for pyspark (setup.py) ... done


In [None]:
import pyspark_hnsw

from pyspark.ml import Pipeline
from pyspark_hnsw.knn import BruteForceSimilarity, HnswSimilarity
from pyspark.ml.feature import HashingTF, IDF, Tokenizer
from pyspark.sql.functions import col, posexplode

In [None]:
spark = pyspark_hnsw.start()

In [None]:
print("Hnswlib version: {}".format(pyspark_hnsw.version()))
print("Apache Spark version: {}".format(spark.version))

Hnswlib version: 1.0.0
Apache Spark version: 3.0.3


Load the product data from the [Instacart Market Basket Analysis Kaggle competition](https://www.kaggle.com/c/instacart-market-basket-analysis/data?select=products.csv.zip).

In [None]:
!wget -O /tmp/products.csv "https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz"

--2022-01-08 03:58:45--  https://drive.google.com/uc?export=download&id=1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz
Resolving drive.google.com (drive.google.com)... 173.194.79.100, 173.194.79.102, 173.194.79.101, ...
Connecting to drive.google.com (drive.google.com)|173.194.79.100|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download [following]
--2022-01-08 03:58:45--  https://doc-10-b4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4nf11kob2m4ai6bvlueodufo0oocm0t2/1641614325000/16131524327083715076/*/1iIF7QpTyuUGYG7lquP7NkplYC9n8Gxuz?e=download
Resolving doc-10-b4-docs.googleusercontent.com (doc-10-b4-docs.googleusercontent.com)... 108.177.127.132, 2a00:1450:4013:c07::84
Connecting to doc-10-b4-docs.googleusercontent.com (doc-1

In [None]:
productData = spark.read.option("header", "true").csv("/tmp/products.csv")

In [None]:
productData.count()

49688

In [None]:
tokenizer = Tokenizer(inputCol="product_name", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures")
idf = IDF(inputCol="rawFeatures", outputCol="features")

Create a simple TF/IDF model that turns product names into sparse word vectors and adds them to an exact k-NN index.
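The tokenizer, `HashingTF` and `IDF` stages defined above can be approximated in plain Python to show what ends up in the `features` column. This is a minimal sketch, not Spark's implementation: it assumes Spark ML's documented smoothed IDF formula, `idf(t) = log((m + 1) / (df(t) + 1))`, and, for readability, replaces `HashingTF`'s hashing trick with the terms themselves as vector keys. The helper names `tokenize` and `tf_idf` are hypothetical.

```python
import math
from collections import Counter


def tokenize(text):
    # Mirrors Spark's Tokenizer: lowercase, then split on whitespace
    return text.lower().split()


def tf_idf(docs):
    """Return one sparse {term: weight} dict per document.

    Assumes Spark ML's smoothed IDF: log((m + 1) / (df + 1)), where m is
    the number of documents and df the number of documents containing the term.
    """
    tokenized = [tokenize(d) for d in docs]
    m = len(tokenized)
    # Document frequency: each term counted at most once per document
    df = Counter(t for words in tokenized for t in set(words))
    idf = {t: math.log((m + 1) / (n + 1)) for t, n in df.items()}
    vectors = []
    for words in tokenized:
        tf = Counter(words)  # raw term counts, as HashingTF produces by default
        vectors.append({t: c * idf[t] for t, c in tf.items()})
    return vectors


names = ["Alcaparrado Manzanilla Olives", "Manzanilla Olives", "Stuffed Manzanilla Olives"]
for vec in tf_idf(names):
    print(vec)
```

A term that occurs in every document (here "manzanilla") gets an IDF of `log(1) = 0`, so rare words such as "alcaparrado" dominate the similarity between product names.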

An exact (brute-force) index returns 100% correct results and is quick to build, but it is very slow to query, so it is only appropriate during development or for comparisons against an approximate index.
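Conceptually, a brute-force index compares each query against every indexed vector, which is why it is exact but slow to query. The sketch below illustrates the idea over sparse `{term: weight}` dicts; it is not the library's implementation. It assumes cosine distance is defined as 1 - cosine similarity and mimics `excludeSelf=True` by skipping the query's own identifier. The names `cosine_distance` and `brute_force_knn` are hypothetical.

```python
import heapq
import math


def cosine_distance(a, b):
    # 1 - cosine similarity between two sparse {term: weight} vectors
    dot = sum(w * b.get(t, 0.0) for t, w in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    if norm_a == 0.0 or norm_b == 0.0:
        return 1.0  # assumption: treat an all-zero vector as maximally distant
    return 1.0 - dot / (norm_a * norm_b)


def brute_force_knn(index, query_id, query_vec, k=5, exclude_self=True):
    """Exact k-NN: score every item in `index` (a dict of id -> vector)."""
    candidates = (
        (cosine_distance(query_vec, vec), item_id)
        for item_id, vec in index.items()
        if not (exclude_self and item_id == query_id)
    )
    # One distance computation per indexed item, so query cost grows linearly
    return [(item_id, dist) for dist, item_id in heapq.nsmallest(k, candidates)]


index = {1: {"a": 1.0}, 2: {"a": 1.0, "b": 1.0}, 3: {"b": 1.0}}
print(brute_force_knn(index, 1, index[1], k=2))
```

Building such an index is trivial (the vectors are just stored), which matches the "quick to index, slow to query" trade-off described above.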

In [None]:
bruteforce = BruteForceSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', featuresCol='features',
                                  distanceFunction='cosine', numPartitions=10, excludeSelf=True, k=5)

In [None]:
exact_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, bruteforce])

In [None]:
exact_model = exact_pipeline.fit(productData)

Next, create the same model, but add the TF/IDF vectors to an HNSW index.
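An HNSW index avoids the full scan by walking a proximity graph. The sketch below shows only the greedy best-first search that HNSW performs within a single layer, keeping a result list of size `ef`; it deliberately omits the layer hierarchy, graph construction and neighbor selection, and it is not hnswlib's code. The graph, the 1-D points and the name `search_layer` are illustrative assumptions.

```python
import heapq


def search_layer(graph, distance, query, entry, ef):
    """Greedy best-first search over one proximity-graph layer.

    graph: dict mapping node id -> list of neighbor ids
    distance: function(query, node_id) -> float
    Returns up to `ef` (distance, node) pairs, closest first.
    """
    d0 = distance(query, entry)
    visited = {entry}
    candidates = [(d0, entry)]  # min-heap: closest unexplored node first
    results = [(-d0, entry)]    # max-heap via negation: worst kept result on top
    while candidates:
        d, node = heapq.heappop(candidates)
        if d > -results[0][0]:
            break  # the closest candidate is worse than our worst result: stop
        for nb in graph[node]:
            if nb in visited:
                continue
            visited.add(nb)
            d_nb = distance(query, nb)
            if len(results) < ef or d_nb < -results[0][0]:
                heapq.heappush(candidates, (d_nb, nb))
                heapq.heappush(results, (-d_nb, nb))
                if len(results) > ef:
                    heapq.heappop(results)  # drop the current worst result
    return sorted((-neg_d, n) for neg_d, n in results)


# Hypothetical example: five points on a line, each linked to its neighbors
points = {0: 0.0, 1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}
graph = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
print(search_layer(graph, lambda q, n: abs(q - points[n]), 3.2, entry=0, ef=2))
```

Only nodes reachable along improving paths are scored, which is what makes querying fast; the price is that a true neighbor can occasionally be missed.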

In [None]:
hnsw = HnswSimilarity(identifierCol='product_id', queryIdentifierCol='product_id', featuresCol='features',
                      distanceFunction='cosine', numPartitions=10, excludeSelf=True, k=5)

In [None]:
hnsw_pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, hnsw])

In [None]:
hnsw_model = hnsw_pipeline.fit(productData)

Select a record to query.

In [None]:
queries = productData.filter(col("product_id") == 43572)

In [None]:
queries.show(truncate=False)

+----------+-----------------------------+--------+-------------+
|product_id|product_name                 |aisle_id|department_id|
+----------+-----------------------------+--------+-------------+
|43572     |Alcaparrado Manzanilla Olives|110     |13           |
+----------+-----------------------------+--------+-------------+



Show the results from the exact model.

In [None]:
exact_model.transform(queries) \
  .select(posexplode(col("prediction")).alias("pos", "item")) \
  .select(col("pos"), col("item.neighbor").alias("product_id"), col("item.distance").alias("distance")) \
  .join(productData, ["product_id"]) \
  .show(truncate=False)

+----------+---+-------------------+----------------------------------+--------+-------------+
|product_id|pos|distance           |product_name                      |aisle_id|department_id|
+----------+---+-------------------+----------------------------------+--------+-------------+
|27806     |0  |0.2961162117528633 |Manzanilla Olives                 |110     |13           |
|25125     |1  |0.40715716898722976|Stuffed Manzanilla Olives         |110     |13           |
|16721     |2  |0.40715716898722976|Manzanilla Stuffed Olives         |110     |13           |
|39833     |3  |0.49516580877903393|Pimiento Sliced Manzanilla Olives |110     |13           |
|33495     |4  |0.514201828085252  |Manzanilla Pimiento Stuffed Olives|110     |13           |
+----------+---+-------------------+----------------------------------+--------+-------------+
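As the `item.neighbor` and `item.distance` selections above imply, each query row's `prediction` is an array of structs with `neighbor` and `distance` fields. `posexplode` turns that array into one row per element together with its position, and the join then looks up the product details. A plain-Python emulation of those two steps, using two rows from the table above as sample data (the helper name `flatten_predictions` is hypothetical):

```python
def flatten_predictions(prediction, products):
    """Emulate posexplode(prediction) followed by an inner join on product_id.

    prediction: list of {"neighbor": id, "distance": float} dicts
    products: dict mapping product_id -> product_name
    """
    rows = []
    for pos, item in enumerate(prediction):  # posexplode yields (pos, item)
        product_id = item["neighbor"]
        if product_id in products:  # an inner join drops unmatched ids
            rows.append((product_id, pos, item["distance"], products[product_id]))
    # Rows here stay in pos order; Spark's join does not guarantee any order
    return rows


prediction = [
    {"neighbor": 27806, "distance": 0.2961162117528633},
    {"neighbor": 25125, "distance": 0.40715716898722976},
]
products = {27806: "Manzanilla Olives", 25125: "Stuffed Manzanilla Olives"}
for row in flatten_predictions(prediction, products):
    print(row)
```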



Show the results from the HNSW model.

In [None]:
hnsw_model.transform(queries) \
  .select(posexplode(col("prediction")).alias("pos", "item")) \
  .select(col("pos"), col("item.neighbor").alias("product_id"), col("item.distance").alias("distance")) \
  .join(productData, ["product_id"]) \
  .show(truncate=False)

+----------+---+-------------------+----------------------------------+--------+-------------+
|product_id|pos|distance           |product_name                      |aisle_id|department_id|
+----------+---+-------------------+----------------------------------+--------+-------------+
|27806     |0  |0.2961162117528633 |Manzanilla Olives                 |110     |13           |
|25125     |1  |0.40715716898722976|Stuffed Manzanilla Olives         |110     |13           |
|16721     |2  |0.40715716898722976|Manzanilla Stuffed Olives         |110     |13           |
|33495     |3  |0.514201828085252  |Manzanilla Pimiento Stuffed Olives|110     |13           |
|41472     |4  |0.514201828085252  |Pimiento Stuffed Manzanilla Olives|110     |13           |
+----------+---+-------------------+----------------------------------+--------+-------------+
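The two result sets differ only in the last positions: the HNSW model does not return product 39833, which the exact model ranks at position 3, and returns 41472 instead. A common way to quantify this is recall@k, the fraction of the exact top-k neighbors that the approximate index also returned. A minimal sketch for the query above, with the neighbor ids copied from the two tables (the function name `recall_at_k` is hypothetical):

```python
def recall_at_k(exact_ids, approx_ids):
    # Fraction of the exact top-k neighbors that the approximate index also found
    exact = set(exact_ids)
    return len(exact & set(approx_ids)) / len(exact)


exact_ids = [27806, 25125, 16721, 39833, 33495]  # exact model, pos 0-4
hnsw_ids = [27806, 25125, 16721, 33495, 41472]   # HNSW model, pos 0-4
print(recall_at_k(exact_ids, hnsw_ids))  # 0.8
```

A single query says little on its own; averaging recall@k over a sample of queries gives a more reliable picture of how closely the HNSW index tracks the exact one.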

