# Similarity Pipeline Demo

This notebook demonstrates how to assemble the similarity pipeline components shipped with `spark-fuse` to cluster semantically similar rows and select per-cluster representatives.

## 1. Start a local Spark session

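For demonstration purposes, the session runs in local mode with two worker threads (`local[2]`). In production you would rely on your existing cluster settings. The first cell prints the Python interpreters that PySpark is configured to use for the workers (`PYSPARK_PYTHON`) and the driver (`PYSPARK_DRIVER_PYTHON`).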

In [8]:
import os
print('PYSPARK_PYTHON:', os.environ.get('PYSPARK_PYTHON'))
print('PYSPARK_DRIVER_PYTHON:', os.environ.get('PYSPARK_DRIVER_PYTHON'))


PYSPARK_PYTHON: /Users/kevin/Github/spark-fuse/.venv/bin/python
PYSPARK_DRIVER_PYTHON: /Users/kevin/Github/spark-fuse/.venv/bin/python
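
In the output above, both variables point at the same virtual environment. A mismatch between the driver and worker interpreters is a common cause of PySpark start-up errors, so an explicit check can be handy. The following is a minimal sketch, not part of `spark-fuse`: the helper name `check_pyspark_interpreters` is hypothetical, and it compares paths textually, so a bare command name such as `python3` is reported as a mismatch even if it resolves to the same interpreter.

```python
import os
import sys


def check_pyspark_interpreters(env=None, current=None):
    """Return a list of warnings; an empty list means no mismatch was found.

    Hypothetical helper for illustration only (not a spark-fuse API).
    """
    env = os.environ if env is None else env
    current = sys.executable if current is None else current
    warnings = []
    for name in ("PYSPARK_PYTHON", "PYSPARK_DRIVER_PYTHON"):
        value = env.get(name)
        if value is None:
            warnings.append(f"{name} is not set")
        # Textual comparison only: symlinks are deliberately not resolved,
        # because a virtualenv interpreter is often a symlink to the base one.
        elif os.path.normpath(value) != os.path.normpath(current):
            warnings.append(
                f"{name}={value} differs from the running interpreter {current}"
            )
    return warnings


for message in check_pyspark_interpreters():
    print("warning:", message)
```

Running this in a notebook before creating the session prints nothing when both variables match the kernel's interpreter.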


In [9]:
from spark_fuse.spark import create_session

spark = create_session(app_name="spark-fuse-similarity-demo", master="local[2]")

## 2. Create a sample dataset

Each row carries a product ID, a text description that the pipeline will embed, and a quality score that we can later use to pick representatives.

In [10]:
data = [
    (1, "Crunchy Red Apple", 4.7),
    (2, "Sweet Gala Apple", 4.9),
    (3, "Fresh Cavendish Banana", 4.6),
    (4, "Ripe Plantain", 4.5),
    (5, "Classic Spiral Notebook", 4.4),
]
columns = ["product_id", "description", "score"]
df = spark.createDataFrame(data, columns)

df.show(truncate=False)


+----------+-----------------------+-----+
|product_id|description            |score|
+----------+-----------------------+-----+
|1         |Crunchy Red Apple      |4.7  |
|2         |Sweet Gala Apple       |4.9  |
|3         |Fresh Cavendish Banana |4.6  |
|4         |Ripe Plantain          |4.5  |
|5         |Classic Spiral Notebook|4.4  |
+----------+-----------------------+-----+



## 3. Configure and run the similarity pipeline

The pipeline embeds the product descriptions with Hugging Face `sentence-transformers`, normalizes the vectors for cosine similarity, clusters them with K-Means, and keeps the highest-scoring item per cluster. If `sentence-transformers` is not installed in your environment, the generator falls back to a deterministic stub so the demo continues to run.
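
To make "deterministic stub" concrete, here is one way such a fallback could look. This is an illustrative assumption, not `spark-fuse`'s actual fallback, whose implementation is not shown in this notebook; the function name `stub_embedding` and the `dim` parameter are hypothetical.

```python
import hashlib
import math


def stub_embedding(text, dim=8):
    """Hypothetical stand-in: derive a unit-length vector from a hash of the text.

    The same text always yields the same vector, with no model download.
    """
    if not 1 <= dim <= 32:
        raise ValueError("dim must be between 1 and 32 (SHA-256 yields 32 bytes)")
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    # Map the first `dim` bytes (0..255) to floats in [-1, 1].
    raw = [(b / 127.5) - 1.0 for b in digest[:dim]]
    # L2-normalize so that dot products behave like cosine similarities.
    norm = math.sqrt(sum(x * x for x in raw))
    return [x / norm for x in raw]


vec_a = stub_embedding("Crunchy Red Apple")
vec_b = stub_embedding("Crunchy Red Apple")
print(vec_a == vec_b)  # identical input gives an identical vector
```

Vectors built this way are reproducible but carry no semantic meaning, so with a stub of this kind the clusters would not be expected to group related products together.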


In [11]:
from spark_fuse.similarity import (
    CosineSimilarity,
    KMeansPartitioner,
    MaxColumnChoice,
    SentenceEmbeddingGenerator,
    SimilarityPipeline,
)

embedding_generator = SentenceEmbeddingGenerator(
    input_col="description",
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    normalize=True,
    use_vectorized=False,
    device="cpu",
)
metric = CosineSimilarity(embedding_col="embedding")
partitioner = KMeansPartitioner(k=3, seed=7)
choice = MaxColumnChoice(column="score")

pipeline = SimilarityPipeline(
    embedding_generator=embedding_generator,
    partitioner=partitioner,
    similarity_metric=metric,
    choice_function=choice,
)

clustered_df = pipeline.run(df)
clustered_df.select("product_id", "cluster_id", "description").orderBy("cluster_id").show(truncate=False)


+----------+----------+-----------------------+
|product_id|cluster_id|description            |
+----------+----------+-----------------------+
|3         |0         |Fresh Cavendish Banana |
|4         |0         |Ripe Plantain          |
|1         |1         |Crunchy Red Apple      |
|2         |1         |Sweet Gala Apple       |
|5         |2         |Classic Spiral Notebook|
+----------+----------+-----------------------+
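
The normalization step matters because standard K-Means minimizes squared Euclidean distance, not cosine distance. For unit-length vectors the two are tied together by the identity ‖a − b‖² = 2 − 2·cos(a, b), so clustering normalized embeddings by Euclidean distance ranks neighbors the same way cosine similarity does. The sketch below checks that identity on two made-up vectors; it assumes nothing about how `KMeansPartitioner` is implemented internally.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def squared_euclidean(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Two arbitrary example vectors (not real embeddings).
a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]

cos = cosine_similarity(a, b)
dist_sq = squared_euclidean(l2_normalize(a), l2_normalize(b))

# For unit vectors: squared distance == 2 - 2 * cosine similarity.
assert math.isclose(dist_sq, 2.0 - 2.0 * cos)
```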



## 4. Retrieve representatives

With `MaxColumnChoice(column="score")`, the representative of each cluster is simply the row with the largest `score` in that cluster.

In [12]:
representatives = pipeline.select_representatives(clustered_df)
representatives.select("cluster_id", "product_id", "description", "score").orderBy("cluster_id").show(truncate=False)

+----------+----------+-----------------------+-----+
|cluster_id|product_id|description            |score|
+----------+----------+-----------------------+-----+
|0         |3         |Fresh Cavendish Banana |4.6  |
|1         |2         |Sweet Gala Apple       |4.9  |
|2         |5         |Classic Spiral Notebook|4.4  |
+----------+----------+-----------------------+-----+
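
The selection rule can be reproduced without Spark. The sketch below applies a per-cluster maximum to the clustered rows shown earlier; `pick_representatives` is a hypothetical helper written for this illustration, and its tie-breaking (the first row seen wins) is an assumption, since the notebook does not say how `MaxColumnChoice` resolves ties.

```python
# Clustered rows copied from the pipeline output above.
rows = [
    {"product_id": 3, "cluster_id": 0, "description": "Fresh Cavendish Banana", "score": 4.6},
    {"product_id": 4, "cluster_id": 0, "description": "Ripe Plantain", "score": 4.5},
    {"product_id": 1, "cluster_id": 1, "description": "Crunchy Red Apple", "score": 4.7},
    {"product_id": 2, "cluster_id": 1, "description": "Sweet Gala Apple", "score": 4.9},
    {"product_id": 5, "cluster_id": 2, "description": "Classic Spiral Notebook", "score": 4.4},
]


def pick_representatives(rows, cluster_key="cluster_id", score_key="score"):
    """Return the highest-scoring row of each cluster, ordered by cluster id."""
    best = {}
    for row in rows:
        cluster = row[cluster_key]
        # Strict '>' keeps the first row seen when scores tie (an assumption).
        if cluster not in best or row[score_key] > best[cluster][score_key]:
            best[cluster] = row
    return [best[cluster] for cluster in sorted(best)]


for rep in pick_representatives(rows):
    print(rep["cluster_id"], rep["product_id"], rep["description"], rep["score"])
```

The selected product IDs are 3, 2, and 5, which matches the representatives table above.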



## 5. Shut down Spark

In [13]:
spark.stop()