# Find similar questions

Using a Python NLP library, compute a sentence embedding for each question, then use cosine similarity to find the question most similar to a reference question.
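For reference, the cosine similarity of two embedding vectors $\mathbf{a}$ and $\mathbf{b}$ of dimension $n$ is their dot product divided by the product of their lengths. When both vectors already have unit length, it reduces to the plain dot product:

$$
\cos(\mathbf{a}, \mathbf{b}) = \frac{\sum_{i=1}^{n} a_i b_i}{\lVert\mathbf{a}\rVert\,\lVert\mathbf{b}\rVert},
\qquad
\lVert\mathbf{a}\rVert = \sqrt{\textstyle\sum_{i=1}^{n} a_i^2}
$$

$$
\lVert\mathbf{a}\rVert = \lVert\mathbf{b}\rVert = 1 \;\Longrightarrow\; \cos(\mathbf{a}, \mathbf{b}) = \sum_{i=1}^{n} a_i b_i
$$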

See the [docs](https://pypi.org/project/sentence-transformers/) for the `sentence_transformers` library.

In [None]:
import os
import numpy as np

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, regexp_replace, trim, desc, lit, concat, expr, udf
from pyspark.sql.types import ArrayType, DoubleType

from sentence_transformers import SentenceTransformer

In [None]:
spark = (
    SparkSession
    .builder
    .appName('Text similarity with UDF')
    .getOrCreate()
)

In [None]:
base_path = os.getcwd()

project_path = ('/').join(base_path.split('/')[0:-3]) 

questions_json_input_path = os.path.join(project_path, 'data/questions-json')
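Splitting on `'/'` assumes POSIX-style paths. A portable sketch of the same computation with `pathlib`, assuming (as the slice `[0:-3]` implies) that the notebook runs three directory levels below the project root; `project_root` is a hypothetical helper name:

```python
from pathlib import Path, PurePath, PurePosixPath

def project_root(notebook_dir: PurePath, levels_up: int = 3) -> PurePath:
    # parents[0] is the immediate parent, so parents[levels_up - 1] drops
    # the last `levels_up` components, like base_path.split('/')[0:-3]
    return notebook_dir.parents[levels_up - 1]

# Hypothetical notebook location, used only to illustrate the result
example = PurePosixPath('/home/user/project/notebooks/spark/similarity')
print(project_root(example) / 'data' / 'questions-json')
```

In the notebook itself this would be `str(project_root(Path.cwd()) / 'data' / 'questions-json')`, which also works with Windows-style paths.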

In [None]:
questionsDF = (
    spark
    .read
    .format('json')
    .option('path', questions_json_input_path)
    .load()
)

In [None]:
# Clean the question body: strip HTML tags, escape sequences, and redundant whitespace

def clean_text(df: DataFrame) -> DataFrame:
    return (
        df.withColumn('body', regexp_replace('body', '<[^>]*>', ''))  # Remove HTML tags
        .withColumn('body', regexp_replace('body', '\\\\n|\\\\r|\\\\t|\\n|\\r|\\t', ' '))  # Remove escape characters
        .withColumn('body', regexp_replace('body', '\\s+', ' '))  # Collapse multiple spaces
        .withColumn('body', trim('body'))  # Trim leading/trailing spaces
    )
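Because the regular expressions are the easiest part of `clean_text` to get wrong, it can help to try them outside Spark. Below is a plain-Python mirror of `clean_text` for a single string (`clean_body` is a hypothetical name). Python's `re` and Java's regex engine, which Spark uses, should behave the same for these patterns on ASCII input:

```python
import re
from typing import Optional

def clean_body(text: Optional[str]) -> Optional[str]:
    """Plain-Python mirror of clean_text, applied to one string."""
    if text is None:  # regexp_replace on a null value also yields null
        return None
    text = re.sub(r'<[^>]*>', '', text)                 # remove HTML tags
    text = re.sub(r'\\n|\\r|\\t|\n|\r|\t', ' ', text)  # literal "\n"-style sequences and real control chars
    text = re.sub(r'\s+', ' ', text)                   # collapse whitespace runs into one space
    return text.strip(' ')                             # like Spark's trim(), remove outer spaces

print(clean_body('<p>Hello\\nworld</p>\n  <code>x</code>'))
```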

1) Apply the `clean_text` function to the questions data.
2) Next, overwrite the `title` column with `title` and `body` concatenated (separated by `': '`), so the embedding has more context.

In [None]:
questionsDF = (
    questionsDF
    .transform(clean_text)
    .withColumn('title', concat('title', lit(': '), 'body'))
)
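One caveat with the concatenation step: Spark's `concat` returns null if any of its inputs is null, so a question with a null `body` ends up with a null `title`, whereas `concat_ws` skips null inputs. The difference, mirrored in plain Python with hypothetical helper names:

```python
from typing import Optional

def spark_concat(*parts: Optional[str]) -> Optional[str]:
    """Mimics concat: the result is null if any input is null."""
    if any(p is None for p in parts):
        return None
    return ''.join(parts)

def spark_concat_ws(sep: str, *parts: Optional[str]) -> str:
    """Mimics concat_ws: null inputs are skipped."""
    return sep.join(p for p in parts if p is not None)

print(spark_concat('Title', ': ', None))      # None
print(spark_concat_ws(': ', 'Title', None))   # Title
```

If the data contains questions without a body, `concat_ws(': ', 'title', 'body')` (from `pyspark.sql.functions`) would keep their titles instead of nulling them.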

In [None]:
# We will use the all-MiniLM-L6-v2 model:

model = SentenceTransformer('all-MiniLM-L6-v2')

In [None]:
# This is the reference question which you need to compare with all other questions in the questions DataFrame:

reference_question = 'How can I get the first and last row of each partition in PySpark after using repartition and sortWithinPartitions?'

In [None]:
# Compute the embedding for the reference question and convert it to a Python list:

reference_embedding = model.encode(reference_question).tolist()
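The similarity expression used later relies on the embeddings having unit length. A minimal check, using hypothetical helpers `l2_norm` and `is_unit`:

```python
import math
from typing import Sequence

def l2_norm(vector: Sequence[float]) -> float:
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in vector))

def is_unit(vector: Sequence[float], tol: float = 1e-5) -> bool:
    """True if the vector has length 1 within an absolute tolerance
    (float32 model outputs are rarely exactly 1.0)."""
    return math.isclose(l2_norm(vector), 1.0, abs_tol=tol)

print(is_unit([0.6, 0.8]), is_unit([3.0, 4.0]))  # True False
```

With the notebook's objects, `is_unit(reference_embedding)` should return `True` if the model normalizes its output; if it did not, passing `normalize_embeddings=True` to `model.encode` would produce unit-length vectors.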

In [None]:
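# Implement the UDF that computes the embedding for each question in the DataFrame.
# The UDF must return an array of doubles.

@udf(ArrayType(DoubleType()))
def get_embeddings_udf(text):
    # Null titles (e.g. from concat with a null body) get a null embedding instead of failing
    if text is None:
        return None
    return model.encode(text).tolist()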

1) First, filter the questions DataFrame to the questions whose `tags` contain the string `spark`. This speeds up the calculation considerably: computing the embeddings for the whole DataFrame takes about half an hour, while after the filter it should take around half a minute.

2) Next, compute the embedding for each question using the UDF.

3) Add the embedding for the reference question as a new column to the DataFrame. Then compute the similarity between the reference question and all other questions.

The SQL expression below uses higher-order functions to compute the cosine similarity of two normalized (unit-length) vectors: for such vectors, cosine similarity equals their dot product. The model returns normalized vectors, so no additional normalization is required. Make sure `embedding` and `ref_embedding` are columns of the DataFrame containing the arrays of doubles for the two questions being compared.
```
aggregate(
    zip_with(embedding, ref_embedding, (x, y) -> x * y),
    0D,
    (acc, x) -> acc + x
)
```
4) Finally, sort the result by the computed similarity in descending order and find the question most similar to the reference question.
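The `aggregate`/`zip_with` expression from step 3 is a dot product: `zip_with` multiplies the two arrays element-wise and `aggregate` folds the products into a sum starting from `0D` (a double zero). A plain-Python mirror using `functools.reduce`, checked against the general cosine formula (function names are hypothetical):

```python
import math
from functools import reduce
from typing import Sequence

def dot_like_spark(embedding: Sequence[float], ref_embedding: Sequence[float]) -> float:
    # zip_with(embedding, ref_embedding, (x, y) -> x * y)
    products = [x * y for x, y in zip(embedding, ref_embedding)]
    # aggregate(products, 0D, (acc, x) -> acc + x)
    return reduce(lambda acc, x: acc + x, products, 0.0)

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    # General cosine similarity: dot product over the product of the lengths
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot_like_spark(a, b) / (norm_a * norm_b)

a, b = [0.6, 0.8], [0.8, 0.6]  # both unit length
print(dot_like_spark(a, b), cosine(a, b))
```

One difference from Spark: when the arrays have different lengths, `zip_with` pads the shorter one with nulls (so the Spark sum becomes null), while Python's `zip` silently truncates. Both embeddings here come from the same model and have the same length, so this does not arise.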


In [None]:
# your code here:

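(
    questionsDF
    # 1) Keep only questions whose tags contain 'spark' (case-sensitive substring match)
    .filter(col('tags').like('%spark%'))
    # 2) Embed each question's title (title + body) with the UDF
    .withColumn('embedding', get_embeddings_udf('title'))
    # 3) Attach the reference embedding as a constant array column...
    .withColumn('ref_embedding', lit(reference_embedding))
    #    ...and score each row by the dot product (cosine similarity for unit vectors)
    .withColumn(
        'similarity',
        expr("""
          aggregate(
            zip_with(embedding, ref_embedding, (x, y) -> x * y),
            0D,
            (acc, x) -> acc + x
          )
        """)
    )
    # 4) Most similar questions first
    .orderBy(desc('similarity'))
    .select('title', 'similarity')
).show(truncate=100)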
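To sanity-check the ranking logic end to end without Spark or the model, the same four steps can be run in plain Python on toy data. The titles, tag strings, and 2-dimensional unit vectors below are made up for illustration (real all-MiniLM-L6-v2 embeddings have 384 dimensions), and `rank_similar` is a hypothetical helper:

```python
from typing import List, Tuple

# Hypothetical rows: (title, tags, pre-computed unit-length embedding)
questions = [
    ('Get first row per partition', '<apache-spark><pyspark>', [0.6, 0.8]),
    ('Sort a pandas DataFrame', '<python><pandas>', [0.8, 0.6]),
    ('Spark repartition vs coalesce', '<apache-spark>', [1.0, 0.0]),
]
reference = [0.6, 0.8]  # hypothetical reference embedding

def rank_similar(rows: List[Tuple[str, str, List[float]]],
                 ref: List[float], tag: str = 'spark') -> List[Tuple[str, float]]:
    # Step 1: substring filter, like col('tags').like('%spark%')
    filtered = [row for row in rows if tag in row[1]]
    # Steps 2-3: dot product with the reference (cosine similarity for unit vectors)
    scored = [(title, sum(x * y for x, y in zip(emb, ref))) for title, _, emb in filtered]
    # Step 4: highest similarity first
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

for title, score in rank_similar(questions, reference):
    print(f'{score:.2f}  {title}')
```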

In [None]:
spark.stop()