In [78]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import Word2Vec
from pyspark.sql.functions import col
from pyspark.sql import functions as F
from pyspark.ml.feature import OneHotEncoder, StringIndexer, MinMaxScaler, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from pyspark.ml.linalg import Vectors
from pyspark.ml.linalg import DenseVector


In [2]:
# Initialize Spark session
spark = SparkSession.builder \
        .appName("SpotifyHybridEmbedding")\
        .config("spark.executor.memory", "8g") \
        .config("spark.driver.memory", "8g") \
        .config("spark.executor.cores", "4") \
        .config("spark.dynamicAllocation.enabled", "true") \
        .config("spark.dynamicAllocation.initialExecutors", "2") \
        .config("spark.dynamicAllocation.maxExecutors", "10") \
        .config("spark.memory.fraction", "0.6") \
        .config("spark.memory.storageFraction", "0.2") \
        .config("spark.sql.shuffle.partitions", "200") \
        .config("spark.executor.extraJavaOptions", "-XX:+UseG1GC -XX:InitiatingHeapOccupancyPercent=35") \
        .getOrCreate()

24/12/14 01:49:31 WARN Utils: Your hostname, Muhammads-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.50.15.243 instead (on interface en0)
24/12/14 01:49:31 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/14 01:49:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


### Generate Contextual Embeddings

In [143]:
# Load playlist data
playlist_df = spark.read.parquet("../data/processed/df_playlists.parquet", header=True, inferSchema=True)
playlist_df

DataFrame[pid: bigint, tid: bigint, pos: bigint]

In [144]:
playlist_df.count()

66346428

In [145]:
playlist_df = playlist_df.limit(10000)

In [146]:
playlist_df.count()

10000

In [147]:
# # Convert all columns to string type
# for column in playlist_df.columns:
#     playlist_df = playlist_df.withColumn(column, col(column).cast("string"))

In [244]:
playlist_df.show()

+------+---+---+
|   pid|tid|pos|
+------+---+---+
|549000|  0|  0|
|549000|  1|  1|
|549000|  2|  2|
|549000|  3|  3|
|549000|  4|  4|
|549000|  5|  5|
|549000|  6|  6|
|549000|  7|  7|
|549000|  8|  8|
|549000|  9|  9|
|549000| 10| 10|
|549000| 11| 11|
|549000| 12| 12|
|549000| 13| 13|
|549000| 14| 14|
|549000| 15| 15|
|549000| 16| 16|
|549000| 17| 17|
|549000| 18| 18|
|549000| 19| 19|
+------+---+---+
only showing top 20 rows



In [148]:
playlist_df.printSchema()

root
 |-- pid: long (nullable = true)
 |-- tid: long (nullable = true)
 |-- pos: long (nullable = true)



In [149]:
t_list = playlist_df.select("tid").rdd.flatMap(lambda x: x).collect()

In [150]:
t_list

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 11,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 

In [151]:
# Group track IDs by playlist
playlist_sequences = playlist_df.groupBy("pid").agg(
    F.collect_list("tid").alias("track_sequence")
)


In [247]:
# Group by 'pid' and collect 'tid' into a list
playlists = playlist_df.groupBy("pid").agg(F.collect_list("tid").alias("tid_list"))

# Convert to a list of tuples or a Pandas DataFrame if needed
playlists_list = playlists.rdd.map(lambda row: (row["pid"], row["tid_list"])).collect()

In [246]:
playlist_sequences.show()



+------+--------------------+
|   pid|      track_sequence|
+------+--------------------+
|549000|[0, 1, 2, 3, 4, 5...|
|549001|[75, 76, 77, 78, ...|
|549002|[135, 136, 137, 1...|
|549003|[279, 280, 281, 2...|
|549004|[317, 318, 93, 31...|
|549005|[331, 332, 333, 3...|
|549006|[359, 360, 361, 3...|
|549007|[386, 387, 388, 3...|
|549008|[465, 466, 467, 4...|
|549009|[581, 582, 583, 5...|
|549010|[604, 605, 606, 6...|
|549011|[672, 673, 674, 6...|
|549012|[787, 788, 789, 7...|
|549013|[928, 929, 930, 9...|
|549014|[966, 967, 968, 9...|
|549015|[364, 1010, 1011,...|
|549016|[1033, 1034, 1035...|
|549017|[1069, 1070, 1071...|
|549018|[1141, 482, 1142,...|
|549019|[1152, 1153, 1154...|
+------+--------------------+
only showing top 20 rows



                                                                                

In [152]:
# Convert track IDs to strings in track_sequence
playlist_sequences = playlist_sequences.withColumn(
    "track_sequence", col("track_sequence").cast("array<string>")
)

In [153]:
# Train Word2Vec model
word2vec = Word2Vec(vectorSize=25, windowSize=3, minCount=5, inputCol="track_sequence", outputCol="contextual_embedding")
model = word2vec.fit(playlist_sequences)

In [154]:
# Transform to get contextual embeddings
playlist_with_embeddings = model.transform(playlist_sequences)
contextual_embeddings = playlist_with_embeddings.select("pid", "contextual_embedding")

In [216]:
contextual_embeddings.show()

+------+--------------------+
|   pid|contextual_embedding|
+------+--------------------+
|549000|[1.71278392275174...|
|549001|[0.0,0.0,0.0,0.0,...|
|549002|[5.59382088896301...|
|549003|[0.0,0.0,0.0,0.0,...|
|549004|[-5.0520233344286...|
|549005|[0.0,0.0,0.0,0.0,...|
|549006|[0.0,0.0,0.0,0.0,...|
|549007|[0.0,0.0,0.0,0.0,...|
|549008|[2.72288608695726...|
|549009|[0.0,0.0,0.0,0.0,...|
|549010|[0.0,0.0,0.0,0.0,...|
|549011|[-1.5225472009700...|
|549012|[-1.4411177908186...|
|549013|[0.0,0.0,0.0,0.0,...|
|549014|[0.0,0.0,0.0,0.0,...|
|549015|[0.0,0.0,0.0,0.0,...|
|549016|[3.75169902466810...|
|549017|[2.08801960191094...|
|549018|[0.00264966914740...|
|549019|[0.0,0.0,0.0,0.0,...|
+------+--------------------+
only showing top 20 rows



### Generate Feature-Based Embeddings

In [156]:
df_tracks = spark.read.parquet('../data/processed/df_tracks.parquet', header=True, inferSchema=True)

In [157]:
df_tracks.printSchema()

root
 |-- album_name: string (nullable = true)
 |-- album_uri: string (nullable = true)
 |-- artist_name: string (nullable = true)
 |-- artist_uri: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- track_name: string (nullable = true)
 |-- track_uri: string (nullable = true)
 |-- tid: long (nullable = true)



In [158]:
# Convert a single column to string type
# df_tracks = df_tracks.withColumn("tid", col("tid").cast("string"))

In [159]:
df_tracks.printSchema()

root
 |-- album_name: string (nullable = true)
 |-- album_uri: string (nullable = true)
 |-- artist_name: string (nullable = true)
 |-- artist_uri: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- track_name: string (nullable = true)
 |-- track_uri: string (nullable = true)
 |-- tid: long (nullable = true)



In [160]:
# filtered_df_tracks = df_tracks.filter(df_tracks.tid.isin(t_list))
df_tracks = df_tracks.filter(col("tid").isin(t_list))

In [161]:
df_tracks.count()

7911

In [162]:
# Index and encode categorical features
artist_indexer = StringIndexer(inputCol="artist_name", outputCol="artist_index")
album_indexer = StringIndexer(inputCol="album_name", outputCol="album_index")
artist_encoder = OneHotEncoder(inputCol="artist_index", outputCol="artist_vec")
album_encoder = OneHotEncoder(inputCol="album_index", outputCol="album_vec")


In [163]:
# VectorAssembler to convert duration_ms into a vector
duration_vector_assembler = VectorAssembler(inputCols=["duration_ms"], outputCol="duration_vector")

# MinMaxScaler for duration
scaler = MinMaxScaler(inputCol="duration_vector", outputCol="scaled_duration")


In [164]:
# Combine features
assembler = VectorAssembler(inputCols=["artist_vec", "album_vec", "scaled_duration"], outputCol="metadata_embedding")

In [165]:
# Build pipeline
# Pipeline
pipeline = Pipeline(stages=[
    artist_indexer,
    album_indexer,
    artist_encoder,
    album_encoder,
    duration_vector_assembler,  # Converts duration_ms to vector
    scaler,                     # Scales the vector
    assembler                   # Combines all features
])
metadata_model = pipeline.fit(df_tracks)


PipelineModel_218772bf1a61

In [166]:
track_with_metadata = metadata_model.transform(df_tracks)

In [167]:
# Select metadata embeddings
metadata_embeddings = track_with_metadata.select("tid", "metadata_embedding")

In [217]:
metadata_embeddings.show()

+---+--------------------+
|tid|  metadata_embedding|
+---+--------------------+
|  0|(7796,[67,7299,77...|
|  1|(7796,[67,3252,77...|
|  2|(7796,[2104,7086,...|
|  3|(7796,[67,4065,77...|
|  4|(7796,[67,4065,77...|
|  5|(7796,[67,5736,77...|
|  6|(7796,[67,4598,77...|
|  7|(7796,[67,3252,77...|
|  8|(7796,[67,3252,77...|
|  9|(7796,[248,3086,7...|
| 10|(7796,[67,6218,77...|
| 11|(7796,[6,7097,779...|
| 12|(7796,[666,4262,7...|
| 13|(7796,[2365,7664,...|
| 14|(7796,[248,3086,7...|
| 15|(7796,[1384,3751,...|
| 16|(7796,[492,4705,7...|
| 17|(7796,[1832,4657,...|
| 18|(7796,[273,3484,7...|
| 19|(7796,[449,7151,7...|
+---+--------------------+
only showing top 20 rows



### Combine Contextual and Metadata Embeddings

In [226]:
contextual_embeddings.printSchema()

root
 |-- pid: long (nullable = true)
 |-- contextual_embedding: vector (nullable = true)



In [227]:
metadata_embeddings.printSchema()

root
 |-- tid: long (nullable = true)
 |-- metadata_embedding: vector (nullable = true)



In [228]:
# Define UDF to concatenate embeddings
def combine_embeddings(contextual, metadata):
    return list(contextual) + list(metadata)

In [229]:
combine_udf = udf(combine_embeddings, ArrayType(FloatType()))

In [234]:
# Perform the join
hybrid_df = contextual_embeddings.join(
    metadata_embeddings,
    contextual_embeddings.pid == metadata_embeddings.tid,
    "outer"
).select(
    metadata_embeddings.tid.alias("tid"),  # Keep the track ID
    contextual_embeddings.contextual_embedding,
    metadata_embeddings.metadata_embedding
)

AssertionError: on should be Column or list of Column

In [232]:
hybrid_df.count()

8063

In [233]:
hybrid_df.show()

24/12/14 02:34:54 WARN DAGScheduler: Broadcasting large task binary with size 1072.2 KiB

+---+--------------------+--------------------+
|tid|contextual_embedding|  metadata_embedding|
+---+--------------------+--------------------+
|  0|                NULL|(7796,[67,7299,77...|
|  1|                NULL|(7796,[67,3252,77...|
|  2|                NULL|(7796,[2104,7086,...|
|  3|                NULL|(7796,[67,4065,77...|
|  4|                NULL|(7796,[67,4065,77...|
|  5|                NULL|(7796,[67,5736,77...|
|  6|                NULL|(7796,[67,4598,77...|
|  7|                NULL|(7796,[67,3252,77...|
|  8|                NULL|(7796,[67,3252,77...|
|  9|                NULL|(7796,[248,3086,7...|
| 10|                NULL|(7796,[67,6218,77...|
| 11|                NULL|(7796,[6,7097,779...|
| 12|                NULL|(7796,[666,4262,7...|
| 13|                NULL|(7796,[2365,7664,...|
| 14|                NULL|(7796,[248,3086,7...|
| 15|                NULL|(7796,[1384,3751,...|
| 16|                NULL|(7796,[492,4705,7...|
| 17|                NULL|(7796,[1832,46

24/12/14 02:35:28 WARN DAGScheduler: Broadcasting large task binary with size 1519.5 KiB
                                                                                

In [199]:
# Join datasets
hybrid_df = contextual_embeddings.join(metadata_embeddings, contextual_embeddings.pid == metadata_embeddings.tid, "outer")

In [200]:
hybrid_df.count()

8063

In [201]:
# Add hybrid embeddings
hybrid_df = hybrid_df.withColumn(
    "hybrid_embedding", combine_udf(col("contextual_embedding"), col("metadata_embedding"))
)

In [207]:
hybrid_df.head()

24/12/14 02:28:27 WARN DAGScheduler: Broadcasting large task binary with size 1072.2 KiB
24/12/14 02:28:27 WARN DAGScheduler: Broadcasting large task binary with size 1540.7 KiB
24/12/14 02:28:28 ERROR Executor: Exception in task 0.0 in stage 283.0 (TID 602)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:94)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:75)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at sc

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable


### Compute Similarity Using Hybrid Embeddings

It will take long time to run because of the dataset,and it will crash if you haven't allocated memory

In [204]:
# Collect embeddings
# hybrid_data = hybrid_df.select("tid", "hybrid_embedding").toPandas()

# Write hybrid embeddings to disk
hybrid_df.select("tid", "hybrid_embedding").write.parquet("../data/processed/embeddings/hybrid_embeddings.parquet" , mode="overwrite")

# Later, read the file back in smaller chunks for processing
# hybrid_df_chunk = spark.read.parquet("hybrid_embeddings.parquet")

24/12/14 02:27:55 WARN DAGScheduler: Broadcasting large task binary with size 1072.2 KiB
24/12/14 02:27:56 WARN DAGScheduler: Broadcasting large task binary with size 1734.3 KiB
24/12/14 02:27:56 ERROR Utils: Aborting task
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:94)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:75)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.has

Py4JJavaError: An error occurred while calling o31729.parquet.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 262.0 failed 1 times, most recent failure: Lost task 0.0 in stage 262.0 (TID 565) (10.50.15.243 executor driver): org.apache.spark.SparkException: [TASK_WRITE_FAILED] Task failed while writing rows to file:/Users/muhammadehsansiddique/Downloads/MillionPlaylistProject/data/processed/embeddings/hybrid_embeddings.parquet.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.taskFailedWhileWritingRowsError(QueryExecutionErrors.scala:775)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:420)
	at org.apache.spark.sql.execution.datasources.WriteFilesExec.$anonfun$doExecuteWrite$1(WriteFiles.scala:100)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:893)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:893)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1589)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:94)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:75)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.datasources.FileFormatDataWriter.writeWithIterator(FileFormatDataWriter.scala:91)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$executeTask$1(FileFormatWriter.scala:403)
	at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1397)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:410)
	... 17 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$executeWrite$4(FileFormatWriter.scala:307)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.writeAndCommit(FileFormatWriter.scala:271)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeWrite(FileFormatWriter.scala:304)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:190)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:190)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:113)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:111)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.executeCollect(commands.scala:125)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$executeCollect$1(AdaptiveSparkPlanExec.scala:390)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.withFinalPlanUpdate(AdaptiveSparkPlanExec.scala:418)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.executeCollect(AdaptiveSparkPlanExec.scala:390)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.$anonfun$applyOrElse$1(QueryExecution.scala:107)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:107)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:98)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:461)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:76)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:461)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:32)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:267)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:263)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:32)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:32)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:437)
	at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted$lzycompute(QueryExecution.scala:85)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:83)
	at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:142)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:869)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:391)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:364)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:243)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:802)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:76)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:52)
	at java.base/java.lang.reflect.Method.invoke(Method.java:578)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:1589)
Caused by: org.apache.spark.SparkException: [TASK_WRITE_FAILED] Task failed while writing rows to file:/Users/muhammadehsansiddique/Downloads/MillionPlaylistProject/data/processed/embeddings/hybrid_embeddings.parquet.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.taskFailedWhileWritingRowsError(QueryExecutionErrors.scala:775)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:420)
	at org.apache.spark.sql.execution.datasources.WriteFilesExec.$anonfun$doExecuteWrite$1(WriteFiles.scala:100)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:893)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:893)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	... 1 more
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:94)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:75)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.datasources.FileFormatDataWriter.writeWithIterator(FileFormatDataWriter.scala:91)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$executeTask$1(FileFormatWriter.scala:403)
	at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1397)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:410)
	... 17 more


In [176]:
hybrid_df_chunk = spark.read.parquet("../data/processed/embeddings/hybrid_embeddings.parquet")

In [177]:
hybrid_df_chunk.count()

0

For small datasets

In [203]:
# Collect embeddings
hybrid_data = hybrid_df.select("tid", "hybrid_embedding").collect()

24/12/14 02:27:36 WARN DAGScheduler: Broadcasting large task binary with size 1072.2 KiB
24/12/14 02:27:38 WARN DAGScheduler: Broadcasting large task binary with size 1533.1 KiB
24/12/14 02:27:39 ERROR Executor: Exception in task 0.0 in stage 254.0 (TID 547)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:94)
	at org.apache.spark.sql.execution.python.BasePythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:75)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at sc

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/var/folders/wj/fxlzyhvj6vbcslcq88wj0vf40000gn/T/ipykernel_9918/2405206578.py", line 3, in combine_embeddings
TypeError: 'NoneType' object is not iterable


In [179]:
hybrid_data

[]

In [180]:
# Prepare embeddings for similarity computation
track_ids = [row["tid"] for row in hybrid_data]
hybrid_embeddings = np.array([row["hybrid_embedding"] for row in hybrid_data])

In [181]:
hybrid_embeddings

array([], dtype=float64)

In [182]:
# Compute similarity
def compute_hybrid_similarity(embed1, embed2):
    return cosine_similarity([embed1], [embed2])[0][0]

In [183]:
# Example: Compute similarity between two tracks
track1_id, track2_id = '0', '1'
similarity_score = compute_hybrid_similarity(hybrid_embeddings[track1_id], hybrid_embeddings[track2_id])
print(f"Hybrid Similarity Score between Track {track1_id} and Track {track2_id}: {similarity_score:.2f}")

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [None]:
# Compute similarity matrix
num_tracks = len(hybrid_embeddings)
similarity_matrix = cosine_similarity(hybrid_embeddings)

# Convert to Pandas DataFrame for visualization
import pandas as pd

similarity_df = pd.DataFrame(similarity_matrix, index=track_ids, columns=track_ids)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 8))
sns.heatmap(similarity_df, cmap="coolwarm", xticklabels=False, yticklabels=False)
plt.title("Hybrid Track Similarity Heatmap")
plt.xlabel("Track ID")
plt.ylabel("Track ID")
plt.show()

So, Instead of collecting the embeddings to the driver, compute similarities in Spark

In [27]:
# Define a UDF for cosine similarity
def cosine_similarity_udf(v1, v2):
    v1 = np.array(v1)
    v2 = np.array(v2)
    sim = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return float(sim)

cosine_similarity = udf(cosine_similarity_udf, FloatType())


[Stage 25:>                                                         (0 + 8) / 9]

In [23]:
# Compute pairwise similarities in Spark
hybrid_df = hybrid_df.withColumn("similarity", cosine_similarity(col("hybrid_embedding"), col("hybrid_embedding")))

In [24]:
hybrid_df.write.parquet("../data/processed/embeddings/hybrid/hybrid_embeddings.parquet", mode="overwrite")

24/12/13 10:28:44 WARN DAGScheduler: Broadcasting large task binary with size 96.7 MiB
24/12/13 10:29:07 WARN DAGScheduler: Broadcasting large task binary with size 180.0 MiB
ERROR:root:KeyboardInterrupt while sending command.                 (0 + 8) / 9]
Traceback (most recent call last):
  File "/Users/muhammadehsansiddique/miniforge3/lib/python3.9/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/Users/muhammadehsansiddique/miniforge3/lib/python3.9/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/Users/muhammadehsansiddique/miniforge3/lib/python3.9/socket.py", line 704, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

In [32]:
# Example: Compute similarity between two tracks
track1_id, track2_id = 0, 1
similarity_score = compute_hybrid_similarity(hybrid_df[track1_id], hybrid_df[track2_id])
print(f"Hybrid Similarity Score between Track {track1_id} and Track {track2_id}: {similarity_score:.2f}")

TypeError: unsupported format string passed to Column.__format__

In [26]:
# Cartesian product of the DataFrame with itself
track_pairs = hybrid_df.alias("df1").crossJoin(hybrid_df.alias("df2")) \
    .select(
        col("df1.tid").alias("track1_id"),
        col("df2.tid").alias("track2_id"),
        col("df1.hybrid_embedding").alias("embedding1"),
        col("df2.hybrid_embedding").alias("embedding2")
    )

[Stage 25:>                                                         (0 + 8) / 9]

In [28]:
# Add a column for similarity
similarity_df = track_pairs.withColumn(
    "similarity",
    cosine_similarity(col("embedding1"), col("embedding2"))
)

In [29]:
# Example: Similarity between specific tracks
track1_id = "0"  # Replace with your track1 ID
track2_id = "1"  # Replace with your track2 ID

specific_similarity = similarity_df.filter(
    (col("track1_id") == track1_id) & (col("track2_id") == track2_id)
)

specific_similarity.show()


24/12/13 10:47:00 WARN DAGScheduler: Broadcasting large task binary with size 96.7 MiB
24/12/13 10:47:11 WARN DAGScheduler: Broadcasting large task binary with size 96.7 MiB
ERROR:root:KeyboardInterrupt while sending command.][Stage 27:>   (0 + 0) / 8]
Traceback (most recent call last):
  File "/Users/muhammadehsansiddique/miniforge3/lib/python3.9/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/Users/muhammadehsansiddique/miniforge3/lib/python3.9/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/Users/muhammadehsansiddique/miniforge3/lib/python3.9/socket.py", line 704, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

In [None]:
# Example: Similarity for a specific track
specific_track_id = "0"  # Replace with your track ID

track_similarity = similarity_df.filter(
    (col("track1_id") == specific_track_id) | (col("track2_id") == specific_track_id)
)

track_similarity.show()


In [8]:
spark.stop()

ConnectionRefusedError: [Errno 61] Connection refused