# Jupyter hnswlib example

This notebook demonstrates how to use hnswlib with PySpark in a Jupyter notebook.

In [None]:
import io
import shutil
import tempfile
import zipfile

import boto3
import requests
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark_hnsw.conversion import VectorConverter
from pyspark_hnsw.knn import *
from pyspark_hnsw.linalg import Normalizer
from pyspark.sql import functions as F


# Session-level Spark properties, keyed by property name:
#   - executor memory and the extra jars resolved at start-up,
#   - the Delta SQL extension and catalog classes,
#   - console progress bars switched off,
#   - the S3A filesystem pointed at http://minio:9000 with path-style access.
spark_conf = {
    "spark.executor.memory": "12g",
    "spark.jars.packages": "com.github.jelmerk:hnswlib-spark_3_5_2.12:2.0.0-alpha.6,io.delta:delta-spark_2.12:3.3.0,org.apache.hadoop:hadoop-aws:3.3.4,com.amazonaws:aws-java-sdk-bundle:1.12.262",
    "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
    "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    "spark.ui.showConsoleProgress": "false",
    "spark.hadoop.fs.s3a.endpoint": "http://minio:9000",
    "spark.hadoop.fs.s3a.path.style.access": "true",
    "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
}

# Start from a builder bound to the cluster master, then apply each property.
session_builder = SparkSession.builder.master("spark://spark-master:7077")
for conf_key, conf_value in spark_conf.items():
    session_builder = session_builder.config(conf_key, conf_value)

spark = session_builder.getOrCreate()

# Only surface ERROR-level log messages from Spark.
spark.sparkContext.setLogLevel("ERROR")

:: loading settings :: url = jar:file:/Workspace/.venv/lib/python3.12/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/vscode/.ivy2/cache
The jars for the packages stored in: /home/vscode/.ivy2/jars
com.github.jelmerk#hnswlib-spark_3_5_2.12 added as a dependency
io.delta#delta-spark_2.12 added as a dependency
org.apache.hadoop#hadoop-aws added as a dependency
com.amazonaws#aws-java-sdk-bundle added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-14700c9a-6d77-411f-b345-c4c17b72803d;1.0
	confs: [default]
	found com.github.jelmerk#hnswlib-spark_3_5_2.12;2.0.0-alpha.2 in central
	found io.delta#delta-spark_2.12;3.3.0 in central
	found io.delta#delta-storage;3.3.0 in central
	found org.antlr#antlr4-runtime;4.9.3 in central
	found org.apache.hadoop#hadoop-aws;3.3.4 in central
	found com.amazonaws#aws-java-sdk-bundle;1.12.262 in central
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
:: resolution report :: resolve 132ms :: artifacts dl 6ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.12.262 from central in [defau

Download the word embeddings

In [2]:
# Download the GloVe archive and upload every archive member to the "spark"
# bucket under the "input/" prefix.
# NOTE(review): no endpoint_url is passed to boto3 here, while the Spark
# session is configured for http://minio:9000 - presumably the endpoint and
# credentials are supplied through the environment; confirm.
s3 = boto3.client("s3")
bucket_name = "spark"

with requests.get(
    "https://huggingface.co/stanfordnlp/glove/resolve/main/glove.42B.300d.zip",
    stream=True
) as response:
    response.raise_for_status()

    # Spool the archive to a temporary file on disk instead of an in-memory
    # buffer: ZipFile needs a seekable file, but the whole download does not
    # have to be held in RAM at once.
    with tempfile.TemporaryFile() as archive:
        shutil.copyfileobj(response.raw, archive)
        archive.seek(0)

        with zipfile.ZipFile(archive, "r") as zip_ref:
            for file_name in zip_ref.namelist():
                # Stream each member straight from the archive to the bucket.
                with zip_ref.open(file_name) as file:
                    s3.upload_fileobj(file, bucket_name, f"input/{file_name}")

Read the data as a spark dataframe

In [2]:
# Read the space-delimited GloVe text file with an inferred schema and rename
# the first column to "id".
# NOTE(review): the quote character is set to NUL, presumably so that quote
# characters occurring in the tokens are not treated as CSV quoting - confirm.
words_df = (
    spark.read
    .options(delimiter=" ", inferSchema="true", quote="\u0000")
    .csv("s3a://spark/input/glove.42B.300d.txt")
    .withColumnRenamed("_c0", "id")
)

Inspect the schema

In [3]:
# Print the column names and inferred types of words_df.
words_df.printSchema()

root
 |-- id: string (nullable = true)
 |-- _c1: double (nullable = true)
 |-- _c2: double (nullable = true)
 |-- _c3: double (nullable = true)
 |-- _c4: double (nullable = true)
 |-- _c5: double (nullable = true)
 |-- _c6: double (nullable = true)
 |-- _c7: double (nullable = true)
 |-- _c8: double (nullable = true)
 |-- _c9: double (nullable = true)
 |-- _c10: double (nullable = true)
 |-- _c11: double (nullable = true)
 |-- _c12: double (nullable = true)
 |-- _c13: double (nullable = true)
 |-- _c14: double (nullable = true)
 |-- _c15: double (nullable = true)
 |-- _c16: double (nullable = true)
 |-- _c17: double (nullable = true)
 |-- _c18: double (nullable = true)
 |-- _c19: double (nullable = true)
 |-- _c20: double (nullable = true)
 |-- _c21: double (nullable = true)
 |-- _c22: double (nullable = true)
 |-- _c23: double (nullable = true)
 |-- _c24: double (nullable = true)
 |-- _c25: double (nullable = true)
 |-- _c26: double (nullable = true)
 |-- _c27: double (nullable = true

## Fit the model

The cosine distance is obtained with the inner product after normalizing all vectors to unit norm. This is faster than calculating the cosine distance directly.

In [4]:
# Every column after the leading "id" column is fed into the feature vector.
feature_columns = words_df.columns[1:]

# Stage 1: assemble the individual columns into a single vector column.
vector_assembler = VectorAssembler(
    inputCols=feature_columns,
    outputCol="features_as_vector",
)

# Stage 2: convert "features_as_vector" into the "features" column.
converter = VectorConverter(
    inputCol="features_as_vector",
    outputCol="features",
)

# Stage 3: normalize "features" into "normalized_features".
normalizer = Normalizer(
    inputCol="features",
    outputCol="normalized_features",
)

# Stage 4: the HNSW index over the normalized vectors, using the inner product
# as distance function; results are written to the "approximate" column.
hnsw = HnswSimilarity(
    identifierCol="id",
    featuresCol="normalized_features",
    distanceFunction="inner-product",
    m=48,
    ef=5,
    k=10,
    efConstruction=200,
    numPartitions=1,
    numThreads=6,
    predictionCol="approximate",
)

pipeline_stages = [vector_assembler, converter, normalizer, hnsw]
pipeline = Pipeline(stages=pipeline_stages)

model = pipeline.fit(words_df)

## Query the index

Find the words most similar to three sample words.

In [10]:
# Keep only the rows for the sample words we want neighbours for.
sample_words = ["king", "family", "car"]
words_df_sample = words_df.filter(F.col("id").isin(sample_words))

# Run the fitted pipeline, then flatten the "approximate" array column into
# one row per (id, neighbor) pair.
neighbors_df = (
    model.transform(words_df_sample)
    .select("id", F.explode("approximate").alias("n"))
    .select("id", "n.neighbor")
)

# Show up to 100 rows without truncating cell contents.
neighbors_df.show(n=100, truncate=False)

+------+----------+
|id    |neighbor  |
+------+----------+
|family|family    |
|family|families  |
|family|parents   |
|family|friends   |
|family|children  |
|family|home      |
|family|mother    |
|family|father    |
|family|lives     |
|family|living    |
|car   |car       |
|car   |cars      |
|car   |vehicle   |
|car   |automobile|
|car   |truck     |
|car   |auto      |
|car   |vehicles  |
|car   |driving   |
|car   |suv       |
|car   |rental    |
|king  |king      |
|king  |queen     |
|king  |prince    |
|king  |kings     |
|king  |henry     |
|king  |kingdom   |
|king  |reign     |
|king  |throne    |
|king  |george    |
|king  |lord      |
+------+----------+



## Save the model

In [11]:
# Persist the fitted pipeline model. overwrite() replaces a model left at this
# path by an earlier run; a plain save() fails when the path already exists,
# which would break re-running the notebook from the top.
model.write().overwrite().save("s3a://spark/model/")