In [39]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import rand
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
from pyspark.ml.feature import StringIndexer
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.sql.functions import size, max as max_, array
from annoy import AnnoyIndex
from pyspark.sql.functions import col
import numpy as np
import librosa
import os
import gc
import audioread
from audioread.exceptions import NoBackendError

In [2]:
def extract_mfcc(audio_path):
    try:
        audio, sample_rate = librosa.load(audio_path, sr=None)
        mfcc_features = librosa.feature.mfcc(y=audio, sr=sample_rate, n_mfcc=100).mean(axis=1)
        return mfcc_features.tolist()
    except NoBackendError as e:
        print("No suitable audio backend found:", e)

In [3]:
def load_data_from_mongodb(spark_session, input_uri):
    df = spark_session.read.format("com.mongodb.spark.sql.DefaultSource").option("uri", input_uri).load()
    return df

In [4]:
array_to_vector_udf = udf(lambda arr: Vectors.dense(arr), VectorUDT())
def create_feature_vector(df):
    df = df.withColumn("features", array_to_vector_udf(df["mfcc_features"]))
    return df

In [5]:
def normalize_features(df):
    normalizer = Normalizer(inputCol="features", outputCol="normalized_features")
    df = normalizer.transform(df)
    return df

In [6]:
def train_recommendation_model(df, num_trees=10):
    features = df.select("features").collect()
    dim = len(features[0].features)

    print("Dimension of features:", dim)

    annoy_index = AnnoyIndex(dim, 'angular')
    for i, row in enumerate(features):
        vector = row.features
        vector_length = len(vector)
        if vector_length != dim:
            print("Vector length mismatch at index", i, "- Expected:", dim, "Got:", vector_length)
            continue
        annoy_index.add_item(i, vector)

    annoy_index.build(num_trees)
    annoy_index.save('music_recommendation.ann')

    return annoy_index

In [29]:
def load_trained_annoy_index():
    annoy_index = AnnoyIndex(100, 'angular')
    annoy_index.load('music_recommendation.ann')
    return annoy_index

In [8]:
def evaluate_model(model, test_data):
    return

In [9]:
def split_train_test_data(df, train_ratio=0.8):
    train_data = df.filter(rand() < train_ratio)
    test_data = df.subtract(train_data)
    return train_data, test_data

In [10]:
def add_user_id_column(df, user_id_column):
    indexer = StringIndexer(inputCol=user_id_column, outputCol="user_id")
    indexed_df = indexer.fit(df).transform(df)
    return indexed_df

In [11]:
def find_nearest_neighbors(annoy_index, target_vector, k=5):
    nearest_neighbors = annoy_index.get_nns_by_vector(target_vector, k)
    return nearest_neighbors

In [13]:
input_uri = "mongodb://localhost:27017/bda.audio_features"

spark = SparkSession.builder \
    .appName("MusicRecommendationModel") \
    .config("spark.mongodb.input.uri", input_uri) \
    .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:3.0.2") \
    .getOrCreate()

24/05/12 00:25:13 WARN Utils: Your hostname, moaz-HP-ProBook-440-G5 resolves to a loopback address: 127.0.1.1; using 192.168.10.5 instead (on interface wlp2s0)
24/05/12 00:25:13 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


:: loading settings :: url = jar:file:/home/moaz/spark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/moaz/.ivy2/cache
The jars for the packages stored in: /home/moaz/.ivy2/jars
org.mongodb.spark#mongo-spark-connector_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-8bb65ed2-02a9-4846-b149-28dd59acb3a3;1.0
	confs: [default]
	found org.mongodb.spark#mongo-spark-connector_2.12;3.0.2 in central
	found org.mongodb#mongodb-driver-sync;4.0.5 in central
	found org.mongodb#bson;4.0.5 in central
	found org.mongodb#mongodb-driver-core;4.0.5 in central
:: resolution report :: resolve 251ms :: artifacts dl 16ms
	:: modules in use:
	org.mongodb#bson;4.0.5 from central in [default]
	org.mongodb#mongodb-driver-core;4.0.5 from central in [default]
	org.mongodb#mongodb-driver-sync;4.0.5 from central in [default]
	org.mongodb.spark#mongo-spark-connector_2.12;3.0.2 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |

In [14]:
def process_audio_file(audio_path):
    audio_features = extract_mfcc(audio_path)
    audio_features_df = spark.createDataFrame([(audio_path, audio_features)], ["audio_path", "mfcc_features"])
    audio_features_df = create_feature_vector(audio_features_df)
    audio_features_df = normalize_features(audio_features_df)
    return audio_features_df

In [15]:
df = load_data_from_mongodb(spark, input_uri)
df.show()

                                                                                

+--------------------+--------------------+--------------------+
|                 _id|          audio_path|       mfcc_features|
+--------------------+--------------------+--------------------+
|{663f94f62f15021b...|DataSet/067/06732...|[-206.81163024902...|
|{663f94f62f15021b...|DataSet/067/06705...|[-252.43721008300...|
|{663f94f62f15021b...|DataSet/067/06759...|[-218.26321411132...|
|{663f94f62f15021b...|DataSet/067/06776...|[-214.04769897460...|
|{663f94f62f15021b...|DataSet/067/06774...|[-138.56770324707...|
|{663f94f62f15021b...|DataSet/067/06716...|[7.75728654861450...|
|{663f94f62f15021b...|DataSet/096/09671...|[-164.58990478515...|
|{663f94f62f15021b...|DataSet/096/09616...|[-142.40655517578...|
|{663f94f62f15021b...|DataSet/096/09627...|[-340.62588500976...|
|{663f94f62f15021b...|DataSet/096/09602...|[-125.80657958984...|
|{663f94f62f15021b...|DataSet/096/09664...|[-214.04400634765...|
|{663f94f62f15021b...|DataSet/096/09693...|[-64.120185852050...|
|{663f94f62f15021b...|Dat

In [16]:
df = create_feature_vector(df)
df.show()

[Stage 2:>                                                          (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+
|                 _id|          audio_path|       mfcc_features|            features|
+--------------------+--------------------+--------------------+--------------------+
|{663f94f62f15021b...|DataSet/067/06732...|[-206.81163024902...|[-206.81163024902...|
|{663f94f62f15021b...|DataSet/067/06705...|[-252.43721008300...|[-252.43721008300...|
|{663f94f62f15021b...|DataSet/067/06759...|[-218.26321411132...|[-218.26321411132...|
|{663f94f62f15021b...|DataSet/067/06776...|[-214.04769897460...|[-214.04769897460...|
|{663f94f62f15021b...|DataSet/067/06774...|[-138.56770324707...|[-138.56770324707...|
|{663f94f62f15021b...|DataSet/067/06716...|[7.75728654861450...|[7.75728654861450...|
|{663f94f62f15021b...|DataSet/096/09671...|[-164.58990478515...|[-164.58990478515...|
|{663f94f62f15021b...|DataSet/096/09616...|[-142.40655517578...|[-142.40655517578...|
|{663f94f62f15021b...|DataSet/096/09627...|[-340.62588

                                                                                

In [17]:
df = normalize_features(df)
df.show()

[Stage 3:>                                                          (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+--------------------+
|                 _id|          audio_path|       mfcc_features|            features| normalized_features|
+--------------------+--------------------+--------------------+--------------------+--------------------+
|{663f94f62f15021b...|DataSet/067/06732...|[-206.81163024902...|[-206.81163024902...|[-0.7371349737352...|
|{663f94f62f15021b...|DataSet/067/06705...|[-252.43721008300...|[-252.43721008300...|[-0.8747194103813...|
|{663f94f62f15021b...|DataSet/067/06759...|[-218.26321411132...|[-218.26321411132...|[-0.6919126717382...|
|{663f94f62f15021b...|DataSet/067/06776...|[-214.04769897460...|[-214.04769897460...|[-0.8262855612396...|
|{663f94f62f15021b...|DataSet/067/06774...|[-138.56770324707...|[-138.56770324707...|[-0.6987036729530...|
|{663f94f62f15021b...|DataSet/067/06716...|[7.75728654861450...|[7.75728654861450...|[0.06079818933145...|
|{663f94f62f15021b...|DataSet/096/096

                                                                                

In [18]:
df = add_user_id_column(df, "audio_path")
df.show()

24/05/12 00:25:33 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB
[Stage 7:>                                                          (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+--------------------+-------+
|                 _id|          audio_path|       mfcc_features|            features| normalized_features|user_id|
+--------------------+--------------------+--------------------+--------------------+--------------------+-------+
|{663f94f62f15021b...|DataSet/067/06732...|[-206.81163024902...|[-206.81163024902...|[-0.7371349737352...|45921.0|
|{663f94f62f15021b...|DataSet/067/06705...|[-252.43721008300...|[-252.43721008300...|[-0.8747194103813...|45705.0|
|{663f94f62f15021b...|DataSet/067/06759...|[-218.26321411132...|[-218.26321411132...|[-0.6919126717382...|46148.0|
|{663f94f62f15021b...|DataSet/067/06776...|[-214.04769897460...|[-214.04769897460...|[-0.8262855612396...|46279.0|
|{663f94f62f15021b...|DataSet/067/06774...|[-138.56770324707...|[-138.56770324707...|[-0.6987036729530...|46254.0|
|{663f94f62f15021b...|DataSet/067/06716...|[7.75728654861450...|[7.7572865486145

                                                                                

In [19]:
train_data, test_data = split_train_test_data(df)

In [20]:
train_data.show()

24/05/12 00:25:35 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB
[Stage 8:>                                                          (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+--------------------+-------+
|                 _id|          audio_path|       mfcc_features|            features| normalized_features|user_id|
+--------------------+--------------------+--------------------+--------------------+--------------------+-------+
|{663f94f62f15021b...|DataSet/067/06732...|[-206.81163024902...|[-206.81163024902...|[-0.7371349737352...|45921.0|
|{663f94f62f15021b...|DataSet/067/06705...|[-252.43721008300...|[-252.43721008300...|[-0.8747194103813...|45705.0|
|{663f94f62f15021b...|DataSet/067/06759...|[-218.26321411132...|[-218.26321411132...|[-0.6919126717382...|46148.0|
|{663f94f62f15021b...|DataSet/067/06776...|[-214.04769897460...|[-214.04769897460...|[-0.8262855612396...|46279.0|
|{663f94f62f15021b...|DataSet/067/06716...|[7.75728654861450...|[7.75728654861450...|[0.06079818933145...|45763.0|
|{663f94f62f15021b...|DataSet/096/09671...|[-164.58990478515...|[-164.5899047851

                                                                                

In [21]:
test_data.show()

24/05/12 00:25:37 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB
24/05/12 00:25:38 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB
24/05/12 00:25:52 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB

+--------------------+--------------------+--------------------+--------------------+--------------------+--------+
|                 _id|          audio_path|       mfcc_features|            features| normalized_features| user_id|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------+
|{663f94f72f15021b...|DataSet/090/09061...|[-75.388412475585...|[-75.388412475585...|[-0.4684452248217...| 60827.0|
|{663f972c2f15021b...|DataSet/003/00359...|[-262.57876586914...|[-262.57876586914...|[-0.8240702605563...|  1918.0|
|{663f99722f15021b...|DataSet/038/03814...|[-102.44507598876...|[-102.44507598876...|[-0.6650440054836...| 24563.0|
|{663f99cf2f15021b...|DataSet/065/06574...|[-232.19203186035...|[-232.19203186035...|[-0.7309460564184...| 44552.0|
|{663f9bbd2f15021b...|DataSet/130/13072...|[-201.38792419433...|[-201.38792419433...|[-0.6949045812334...| 88460.0|
|{663f9f162f15021b...|DataSet/011/01112...|[-156.21083068847...|[-156.21

24/05/12 00:25:58 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB
                                                                                

In [22]:
selected_train_data = train_data.select("user_id", "normalized_features", "audio_path")
selected_train_data = selected_train_data.withColumnRenamed("normalized_features", "features")
selected_train_data.show()

24/05/12 00:25:58 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB


+-------+--------------------+--------------------+
|user_id|            features|          audio_path|
+-------+--------------------+--------------------+
|45921.0|[-0.7371349737352...|DataSet/067/06732...|
|45705.0|[-0.8747194103813...|DataSet/067/06705...|
|46148.0|[-0.6919126717382...|DataSet/067/06759...|
|46279.0|[-0.8262855612396...|DataSet/067/06776...|
|45763.0|[0.06079818933145...|DataSet/067/06716...|
|64636.0|[-0.6678599763898...|DataSet/096/09671...|
|64329.0|[-0.6462094523192...|DataSet/096/09616...|
|64351.0|[-0.8703146004582...|DataSet/096/09627...|
|64562.0|[-0.7333173542057...|DataSet/096/09664...|
|60952.0|[-0.7540288007909...|DataSet/090/09083...|
|64482.0|[-0.4635524587552...|DataSet/096/09650...|
|60986.0|[-0.7108369184895...|DataSet/090/09089...|
|60937.0|[-0.7867684544029...|DataSet/090/09080...|
|60721.0|[-0.5967708371627...|DataSet/090/09031...|
|78323.0|[-0.7470270151995...|DataSet/117/11745...|
|78228.0|[-0.7159934470662...|DataSet/117/11731...|
|78328.0|[-0

In [23]:
model = train_recommendation_model(selected_train_data)

                                                                                

Dimension of features: 100
Vector length mismatch at index 556 - Expected: 100 Got: 0
Vector length mismatch at index 631 - Expected: 100 Got: 0
Vector length mismatch at index 697 - Expected: 100 Got: 0
Vector length mismatch at index 793 - Expected: 100 Got: 0
Vector length mismatch at index 864 - Expected: 100 Got: 0
Vector length mismatch at index 1390 - Expected: 100 Got: 0
Vector length mismatch at index 1618 - Expected: 100 Got: 0
Vector length mismatch at index 1657 - Expected: 100 Got: 0
Vector length mismatch at index 2125 - Expected: 100 Got: 0
Vector length mismatch at index 2170 - Expected: 100 Got: 0
Vector length mismatch at index 2352 - Expected: 100 Got: 0
Vector length mismatch at index 3106 - Expected: 100 Got: 0
Vector length mismatch at index 3592 - Expected: 100 Got: 0
Vector length mismatch at index 5182 - Expected: 100 Got: 0
Vector length mismatch at index 5470 - Expected: 100 Got: 0
Vector length mismatch at index 5849 - Expected: 100 Got: 0
Vector length mism

In [48]:
audio_path = "Sampled_Dataset/002/002108.mp3"
audio_features_df = process_audio_file(audio_path)
annoy_index = load_trained_annoy_index()

In [49]:
target_features = audio_features_df.collect()[0]["features"]
nearest_neighbors = find_nearest_neighbors(annoy_index, target_features, 100)
neighbors_data = train_data.select("audio_path").collect()
nearest_neighbor_paths = [neighbors_data[i]["audio_path"] for i in nearest_neighbors]

print("Top 100 Nearest Neighbors:")
for i in range(100):
    print(f"Neighbor {i + 1}:")
    print(f"Index: {nearest_neighbors[i]}")
    print(f"Path: {nearest_neighbor_paths[i]}\n")

Top 100 Nearest Neighbors:
Neighbor 1:
Index: 1225
Path: DataSet/117/117994.mp3

Neighbor 2:
Index: 37609
Path: DataSet/032/032480.mp3

Neighbor 3:
Index: 57213
Path: DataSet/014/014306.mp3

Neighbor 4:
Index: 11853
Path: DataSet/082/082695.mp3

Neighbor 5:
Index: 9345
Path: DataSet/029/029227.mp3

Neighbor 6:
Index: 34406
Path: DataSet/114/114573.mp3

Neighbor 7:
Index: 51913
Path: DataSet/050/050507.mp3

Neighbor 8:
Index: 41180
Path: DataSet/138/138668.mp3

Neighbor 9:
Index: 5452
Path: DataSet/019/019271.mp3

Neighbor 10:
Index: 79220
Path: DataSet/000/000648.mp3

Neighbor 11:
Index: 36865
Path: DataSet/032/032100.mp3

Neighbor 12:
Index: 11207
Path: DataSet/001/001308.mp3

Neighbor 13:
Index: 81549
Path: DataSet/092/092882.mp3

Neighbor 14:
Index: 31718
Path: DataSet/133/133446.mp3

Neighbor 15:
Index: 19124
Path: DataSet/124/124091.mp3

Neighbor 16:
Index: 49742
Path: DataSet/060/060718.mp3

Neighbor 17:
Index: 26500
Path: DataSet/044/044029.mp3

Neighbor 18:
Index: 80804
Path: D