In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.sql.functions import col
from pyspark.ml.functions import array_to_vector
from pyspark.sql.functions import desc

In [2]:
import pyspark

print(pyspark.__version__)  # VERSION MUST MATCH THE SPARK CONTAINER VERSION

3.5.3


In [3]:
NEO4J_URI = "bolt://neo4j:password@neo4j:7687"
graph_name = "AnalysisGraph"

In [4]:
spark = (
    SparkSession.builder.appName("KeywordSimilarity")
    .master("spark://spark:7077")
    .config("spark.jars.packages", "neo4j-contrib:neo4j-spark-connector:5.3.1-s_2.12")
    .config("neo4j.url", NEO4J_URI)
    .config("neo4j.authentication.basic.username", "neo4j")
    .config("neo4j.authentication.basic.password", "password")
    .config("neo4j.database", "neo4j")
    .getOrCreate()
)
spark

In [5]:
# Drop Graph if exists
graph_exists = (
    spark.read.format("org.neo4j.spark.DataSource")
    .option("gds", "gds.graph.exists")
    .option("gds.graphName", graph_name)
    .load()
)

In [6]:
graph_exists = graph_exists.first()["exists"]
graph_exists

False

In [7]:
if not graph_exists:
    (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("gds", "gds.graph.project")
        .option("gds.graphName", graph_name)
        .option("gds.nodeProjection", ["Keyword", "Paper", "Volume"])
        .option(
            "gds.relationshipProjection",
            """
            {
            "KEYWORD": {"orientation": "UNDIRECTED"},
            "CONTAINS": {"orientation": "UNDIRECTED"}
            }
            """,
        )
        .load()
        .show(truncate=False)
    )

+------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+---------+-----------------+-------------+
|nodeProjection                                                                                                                            |relationshipProjection                                                                                                                                                                                                                          |graphName    |nodeCount|relationshipCount|projectMillis|
+---------------------------------------------------------------------------------------------------------------------------

In [9]:
fastRP_df = (
    spark.read.format("org.neo4j.spark.DataSource")
    .option("gds", "gds.fastRP.stream")
    .option("gds.graphName", graph_name)
    .option("gds.configuration.embeddingDimension", "64")
    .option("gds.configuration.randomSeed", "42")
    .load()
)
fastRP_df.show()

+------+--------------------+
|nodeId|           embedding|
+------+--------------------+
| 63067|[0.20194181799888...|
| 63068|[-0.0544314756989...|
| 63069|[0.03589031472802...|
| 63070|[-0.0724608153104...|
| 63071|[-0.3429939150810...|
| 63072|[0.10177309811115...|
| 63073|[-0.0202710330486...|
| 63074|[0.12917548418045...|
| 63075|[-0.1356983631849...|
| 63076|[-0.0529525578022...|
| 63077|[0.03949951007962...|
| 63078|[0.17643597722053...|
| 63079|[0.18553991615772...|
| 63080|[0.20973035693168...|
| 63081|[0.17643597722053...|
| 63082|[0.12219828367233...|
| 63083|[0.11186875402927...|
| 63084|[0.16258206963539...|
| 63085|[0.12238671630620...|
| 63086|[0.10621956735849...|
+------+--------------------+
only showing top 20 rows



In [17]:
# Convert list of floats to Spark Vectors
df = fastRP_df.withColumn("features", array_to_vector(col("embedding")))

lsh = BucketedRandomProjectionLSH(
    inputCol="features", outputCol="hashes", bucketLength=1.0, numHashTables=3
)
model = lsh.fit(df)

In [18]:
similar_items = model.approxSimilarityJoin(df, df, threshold=0.05).select(
    col("datasetA.nodeId").alias("node1"),
    col("datasetB.nodeId").alias("node2"),
    "distCol",
)

In [19]:
similar_items.orderBy(desc("distCol")).show()

+-----+-----+-------------------+
|node1|node2|            distCol|
+-----+-----+-------------------+
|63072|63992|0.03234734930197196|
|64055|63072|0.03234734930197196|
|64022|63072|0.03234734930197196|
|63072|64037|0.03234734930197196|
|64025|63072|0.03234734930197196|
|63999|63072|0.03234734930197196|
|64036|63072|0.03234734930197196|
|63072|64028|0.03234734930197196|
|63072|64032|0.03234734930197196|
|64048|63072|0.03234734930197196|
|63072|64053|0.03234734930197196|
|63072|63995|0.03234734930197196|
|64071|63072|0.03234734930197196|
|63072|64065|0.03234734930197196|
|63072|63990|0.03234734930197196|
|64031|63072|0.03234734930197196|
|63072|64036|0.03234734930197196|
|64001|63072|0.03234734930197196|
|63072|64000|0.03234734930197196|
|63072|64047|0.03234734930197196|
+-----+-----+-------------------+
only showing top 20 rows



In [20]:
keywords = (
    spark.read.format("org.neo4j.spark.DataSource")
    .option("labels", ":Keyword")
    .load()
    .select(col("<id>").alias("nodeId"))
)

In [21]:
# Alias your DataFrames for clarity
keywords1 = keywords.alias("k1")
keywords2 = keywords.alias("k2")
sim_items = similar_items.alias("s")

# Join once for node1 and once for node2 with clear aliases
result = (
    sim_items.join(keywords1, col("s.node1") == col("k1.nodeId"))
    .join(keywords2, col("s.node2") == col("k2.nodeId"))
    .select(
        col("k1.nodeId").alias("node1"),
        col("k2.nodeId").alias("node2"),
        col("s.distCol"),
    )
)
result.show()

+-----+-----+-------+
|node1|node2|distCol|
+-----+-----+-------+
|68587|68585|    0.0|
|69157|69161|    0.0|
|69227|69230|    0.0|
|69236|69236|    0.0|
|69389|69389|    0.0|
|69542|69542|    0.0|
|69583|69583|    0.0|
|69840|69839|    0.0|
|70166|70164|    0.0|
|70537|70537|    0.0|
|70551|70551|    0.0|
|70653|70654|    0.0|
|70673|70676|    0.0|
|70704|70706|    0.0|
|70727|70727|    0.0|
|70728|70725|    0.0|
|70914|70913|    0.0|
|71068|71069|    0.0|
|71151|71151|    0.0|
|71203|71203|    0.0|
+-----+-----+-------+
only showing top 20 rows



In [22]:
spark.stop()