In [58]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when, abs

In [2]:
 # Initialize SparkSession
spark = SparkSession.builder \
        .appName("Playlist Similarity") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", "4") \
        .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow") \
        .getOrCreate()

24/12/14 03:57:55 WARN Utils: Your hostname, Muhammads-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.50.15.243 instead (on interface en0)
24/12/14 03:57:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/14 03:57:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
df = spark.read.parquet('../data/processed/df_tracks.parquet', header=True, inferSchema=True)

In [4]:
df.count()

2262292

In [7]:
df = df.limit(10000)

In [8]:
df.show()

                                                                                

+--------------------+--------------------+-----------------+--------------------+-----------+--------------------+--------------------+---+
|          album_name|           album_uri|      artist_name|          artist_uri|duration_ms|          track_name|           track_uri|tid|
+--------------------+--------------------+-----------------+--------------------+-----------+--------------------+--------------------+---+
|The Times They Ar...|spotify:album:7DZ...|        Bob Dylan|spotify:artist:74...|     277106|Boots of Spanish ...|spotify:track:6QH...|  0|
|Bringing It All B...|spotify:album:1lP...|        Bob Dylan|spotify:artist:74...|     330533|  Mr. Tambourine Man|spotify:track:3Rk...|  1|
|The Best: Loggins...|spotify:album:5BW...|Loggins & Messina|spotify:artist:7e...|     254653|        Danny's Song|spotify:track:0ju...|  2|
|The Freewheelin' ...|spotify:album:0o1...|        Bob Dylan|spotify:artist:74...|     412200|A Hard Rain's A-G...|spotify:track:7ny...|  3|
|The Freewhee

In [9]:
# Perform Cartesian join to compute all pairwise comparisons
pairwise_df = df.alias("a").crossJoin(df.alias("b")).filter(col("a.tid") < col("b.tid"))


In [10]:
pairwise_df.show()

+--------------------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+---+--------------------+--------------------+-----------------+--------------------+-----------+--------------------+--------------------+---+
|          album_name|           album_uri|artist_name|          artist_uri|duration_ms|          track_name|           track_uri|tid|          album_name|           album_uri|      artist_name|          artist_uri|duration_ms|          track_name|           track_uri|tid|
+--------------------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+---+--------------------+--------------------+-----------------+--------------------+-----------+--------------------+--------------------+---+
|The Times They Ar...|spotify:album:7DZ...|  Bob Dylan|spotify:artist:74...|     277106|Boots of Spanish ...|spotify:track:6QH...|  0|Bringing It All B...|spotify:album:1lP...|  

In [11]:
# Define similarity calculation
def compute_similarity(df):
    return df.withColumn(
        "artist_similarity", when(col("a.artist_name") == col("b.artist_name"), lit(1.0)).otherwise(lit(0.0))
    ).withColumn(
        "album_similarity", when(col("a.album_name") == col("b.album_name"), lit(1.0)).otherwise(lit(0.0))
    ).withColumn(
        "duration_diff", abs(col("a.duration_ms") - col("b.duration_ms"))
    ).withColumn(
        "duration_similarity", 1 - (col("duration_diff") / lit(500000))  # Scale duration diff to [0, 1]
    ).withColumn(
        "similarity_score",
        (col("artist_similarity") * 0.5) + (col("album_similarity") * 0.3) + (col("duration_similarity") * 0.2)
    )

In [12]:
result_df = compute_similarity(pairwise_df)

In [14]:
result_df.show()

+--------------------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+---+--------------------+--------------------+-----------------+--------------------+-----------+--------------------+--------------------+---+-----------------+----------------+-------------+-------------------+-------------------+
|          album_name|           album_uri|artist_name|          artist_uri|duration_ms|          track_name|           track_uri|tid|          album_name|           album_uri|      artist_name|          artist_uri|duration_ms|          track_name|           track_uri|tid|artist_similarity|album_similarity|duration_diff|duration_similarity|   similarity_score|
+--------------------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+---+--------------------+--------------------+-----------------+--------------------+-----------+--------------------+--------------------+---+-

In [15]:
spark.stop()

In [28]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler, MinMaxScaler, Normalizer
from pyspark.sql.functions import col, expr, round
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
from pyspark.ml.stat import MultivariateGaussian
from pyspark.ml.functions import vector_to_array
import math



# Function to set up PySpark session
def setup_spark(app_name="TrackSimilarity"):
    return SparkSession.builder \
        .appName("Track_Similarity") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", "4") \
        .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow") \
        .config("spark.sql.shuffle.partitions", "200") \
        .config("spark.executor.extraJavaOptions", "-XX:+UseG1GC -XX:InitiatingHeapOccupancyPercent=35") \
        .getOrCreate()

# Function to perform feature engineering
def feature_engineering(df):
    # Encode categorical variables
    album_indexer = StringIndexer(inputCol="album_name", outputCol="album_index")
    artist_indexer = StringIndexer(inputCol="artist_name", outputCol="artist_index")
    df = album_indexer.fit(df).transform(df)
    df = artist_indexer.fit(df).transform(df)
    
    # Combine features into a vector
    feature_cols = ["album_index", "artist_index", "duration_ms"]
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    df = assembler.transform(df)
    return df

# Function to scale features
def scale_features(df):
    scaler = MinMaxScaler(inputCol="features", outputCol="scaled_features")
    scaler_model = scaler.fit(df)
    df = scaler_model.transform(df)
    
    # Normalize features
    normalizer = Normalizer(inputCol="scaled_features", outputCol="normalized_features", p=2.0)
    df = normalizer.transform(df)
    # Convert normalized vector to array
    df = df.withColumn("normalized_array", vector_to_array(col("normalized_features")))
    return df

# Define a UDF for the dot product
def dot_product(vector1, vector2):
    return float(sum(x * y for x, y in zip(vector1, vector2)))

# Define a UDF for the magnitude (norm) of a vector
def magnitude(vector):
    return float(math.sqrt(sum(x**2 for x in vector)))

# Define the final cosine similarity function
def cosine_similarity(vector1, vector2):
    dot_prod = dot_product(vector1, vector2)
    mag1 = magnitude(vector1)
    mag2 = magnitude(vector2)
    
    # To prevent division by zero, return 0 if either magnitude is zero
    if mag1 == 0 or mag2 == 0:
        return 0.0
    else:
        return dot_prod / (mag1 * mag2)

# Register the UDF
cosine_similarity_udf = udf(cosine_similarity, DoubleType())

# Update pairwise similarity computation
def compute_pairwise_similarity_with_udf(df):
    # Perform Cartesian join for pairwise comparisons
    pairwise_df = df.alias("a").crossJoin(df.alias("b")).filter(col("a.tid") < col("b.tid"))
    
    # Calculate cosine similarity using the UDF
    pairwise_df = pairwise_df.withColumn(
        "cosine_similarity",
        cosine_similarity_udf(col("a.normalized_array"), col("b.normalized_array"))
    )
    return pairwise_df

def display_results(pairwise_df):
    pairwise_df.select(
        col("a.track_name").alias("track_1"),
        col("b.track_name").alias("track_2"),
        round("cosine_similarity", 3).alias("track similarity")
    ).orderBy(col("track similarity").asc()).show(truncate=False)

In [29]:
# Main function to orchestrate the workflow
def main():
    # Step 1: Set up Spark session
    spark = setup_spark()

    # Step 2: Load data
    df = spark.read.parquet('../data/processed/df_tracks.parquet', header=True, inferSchema=True)
    df = df.limit(1000)

    # Step 3: Feature engineering
    df = feature_engineering(df)

    # Step 4: Scale and normalize features
    df = scale_features(df)

    # Step 5: Compute pairwise similarity
    pairwise_df = compute_pairwise_similarity_with_udf(df)

    # Step 6: Display results
    display_results(pairwise_df)

    # Stop Spark session
    spark.stop()

# Entry point
if __name__ == "__main__":
    main()

                                                                                

+--------------------------------+-------------+----------------+
|track_1                         |track_2      |track similarity|
+--------------------------------+-------------+----------------+
|Ego - Remix                     |Birthday Cake|0.01            |
|Halo                            |Birthday Cake|0.011           |
|If I Were a Boy                 |Birthday Cake|0.012           |
|El Equipo Codiciado             |Birthday Cake|0.012           |
|Ego                             |Birthday Cake|0.013           |
|Upgrade U                       |Birthday Cake|0.013           |
|Sobrino del Doctor Veterinario  |Birthday Cake|0.013           |
|El Pariente                     |Birthday Cake|0.013           |
|N1, El Perfil O El Chavalón     |Birthday Cake|0.013           |
|El Xof                          |Birthday Cake|0.014           |
|Video Phone                     |Birthday Cake|0.015           |
|Deja Vu                         |Birthday Cake|0.016           |
|Diva     

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col, lit, monotonically_increasing_id
from pyspark.sql.functions import broadcast

def create_dataframes(data_path, challenge_file, output_path, partitions):
    """
    Process playlist and track data using Spark, and save the output as CSV files.

    Args:
        data_path (str): Path to the directory containing JSON files.
        challenge_file (str): Path to the challenge_set.json file.
        output_path (str): Path to save the resulting CSV files.

    Returns:
        None
    """

     # Initialize SparkSession with parallelism configurations
    spark = SparkSession.builder \
        .appName("Create Spark DataFrames for Playlist Dataset") \
        .config("spark.sql.shuffle.partitions", str(partitions)) \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", "4") \
        .config("spark.driver.memory", "4g") \
        .getOrCreate()
    
    
    # Define column sets
    playlist_col = ['collaborative', 'duration_ms', 'modified_at', 
                    'name', 'num_albums', 'num_artists', 'num_edits',
                    'num_followers', 'num_tracks', 'pid']
    tracks_col = ['album_name', 'album_uri', 'artist_name', 'artist_uri', 
                  'duration_ms', 'track_name', 'track_uri'] 
    playlist_test_col = ['name', 'num_holdouts', 'num_samples', 'num_tracks', 'pid']

    # Read all JSON files in the directory
    df_raw = spark.read.option("multiline", "true").json(f"{data_path}/*.json").repartition(partitions)

    # Extract playlist data
    df_playlists = df_raw.select(explode(col("playlists")).alias("playlist"))

    # Extract playlist-level information
    df_playlists_info = df_playlists.select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_col])

    # Extract track-level information
    df_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("track.track_uri").alias("track_uri1"),
        *[col(f"track.{cols}").alias(cols) for cols in tracks_col]
    ).drop_duplicates()

    
    
    # Add unique track ID (tid) in parallel
    df_tracks = df_tracks.withColumn("tid", monotonically_increasing_id())

    # Join playlist and track information to create a relationship DataFrame
    df_playlists_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.track_uri").alias("track_uri"),
        col("track.pos").alias("pos")
    )
    
    df_playlists_tracks = df_playlists_tracks.join(broadcast(df_tracks), on="track_uri", how="left")
    
    # Join with track ID (tid)
#     df_playlists_tracks = df_playlists_tracks.join(df_tracks, on="track_uri", how="left")

    # Process challenge set
    df_challenge_raw = spark.read.option("multiline", "true").json(challenge_file).repartition(partitions)
    df_playlists_test_info = df_challenge_raw.select(
        explode(col("playlists")).alias("playlist")
    ).select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_test_col])

    df_playlists_test = df_challenge_raw.select(
        explode(col("playlists")).alias("playlist")
    ).select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.track_uri").alias("track_uri"),
        col("track.pos").alias("pos")
    ).join(broadcast(df_tracks), on="track_uri", how="left")

    # Save DataFrames as CSV files
    df_playlists_info.write.parquet(f"{output_path}/df_playlists_info_spark", header=True, mode="overwrite")
    df_tracks.write.parquet(f"{output_path}/df_tracks_spark", header=True, mode="overwrite")
    df_playlists_tracks.write.parquet(f"{output_path}/df_playlists_spark", header=True, mode="overwrite")
    df_playlists_test_info.write.parquet(f"{output_path}/df_playlists_test_info_spark", header=True, mode="overwrite")
    df_playlists_test.write.parquet(f"{output_path}/df_playlists_test_spark", header=True, mode="overwrite")

    print("DataFrames successfully created and saved as parquet files.")

if __name__ == "__main__":
    # Define paths
    data_path = "../data/raw/data"  # Path to directory with JSON files
    challenge_file = "../data/raw/challenge_set.json"  # Challenge set file
    output_path = "../data/processed/"  # Output directory for CSV files
     # Number of partitions for parallelism
    num_partitions = 50  # Adjust based on your system's resources

    # Run the function
    create_dataframes(data_path, challenge_file, output_path, partitions=num_partitions)

In [35]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col, lit, monotonically_increasing_id
from pyspark.sql.functions import broadcast


data_path = "../data/raw/data"
# Initialize SparkSession with parallelism configurations
spark = SparkSession.builder \
        .appName("Create Spark DataFrames for Playlist Dataset") \
        .config("spark.sql.shuffle.partitions", 200) \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", "4") \
        .config("spark.driver.memory", "4g") \
        .getOrCreate()
    
    
    # Define column sets
playlist_col = ['collaborative', 'duration_ms', 'modified_at', 
                    'name', 'num_albums', 'num_artists', 'num_edits',
                    'num_followers', 'num_tracks', 'pid']
tracks_col = ['album_name', 'album_uri', 'artist_name', 'artist_uri', 
                  'duration_ms', 'track_name', 'track_uri'] 
playlist_test_col = ['name', 'num_holdouts', 'num_samples', 'num_tracks', 'pid']

    # Read all JSON files in the directory
df_raw = spark.read.option("multiline", "true").json(f"{data_path}/*.json").repartition(200)

    # Extract playlist data
df_playlists = df_raw.select(explode(col("playlists")).alias("playlist"))

    # Extract playlist-level information
df_playlists_info = df_playlists.select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_col])

                                                                                

In [44]:
  # Extract track-level information
df_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        *[col(f"track.{cols}").alias(cols) for cols in tracks_col]
    ).drop_duplicates()

df_tracks = df_tracks.withColumn("tid", monotonically_increasing_id())

In [45]:
df_playlists.printSchema()

root
 |-- playlist: struct (nullable = true)
 |    |-- collaborative: string (nullable = true)
 |    |-- description: string (nullable = true)
 |    |-- duration_ms: long (nullable = true)
 |    |-- modified_at: long (nullable = true)
 |    |-- name: string (nullable = true)
 |    |-- num_albums: long (nullable = true)
 |    |-- num_artists: long (nullable = true)
 |    |-- num_edits: long (nullable = true)
 |    |-- num_followers: long (nullable = true)
 |    |-- num_tracks: long (nullable = true)
 |    |-- pid: long (nullable = true)
 |    |-- tracks: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- album_name: string (nullable = true)
 |    |    |    |-- album_uri: string (nullable = true)
 |    |    |    |-- artist_name: string (nullable = true)
 |    |    |    |-- artist_uri: string (nullable = true)
 |    |    |    |-- duration_ms: long (nullable = true)
 |    |    |    |-- pos: long (nullable = true)
 |    |    |    |-- track_name:

In [46]:
df_tracks.printSchema()

root
 |-- album_name: string (nullable = true)
 |-- album_uri: string (nullable = true)
 |-- artist_name: string (nullable = true)
 |-- artist_uri: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- track_name: string (nullable = true)
 |-- track_uri: string (nullable = true)
 |-- tid: long (nullable = false)



In [57]:
# Add unique track ID (tid) in parallel
df_tracks = df_tracks.withColumn("tid", monotonically_increasing_id())

    # Join playlist and track information to create a relationship DataFrame
df_playlists_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.tid").alias("tid"),
        col("track.pos").alias("pos")
    )

AnalysisException: [FIELD_NOT_FOUND] No such struct field `tid` in `album_name`, `album_uri`, `artist_name`, `artist_uri`, `duration_ms`, `pos`, `track_name`, `track_uri`.

In [49]:
df_playlists_tracks.printSchema()

root
 |-- pid: long (nullable = true)
 |-- pos: long (nullable = true)



In [5]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col, lit, monotonically_increasing_id
from pyspark.sql.functions import broadcast

def create_dataframes(data_path, challenge_file, output_path, partitions):
    """
    Process playlist and track data using Spark, and save the output as CSV files.

    Args:
        data_path (str): Path to the directory containing JSON files.
        challenge_file (str): Path to the challenge_set.json file.
        output_path (str): Path to save the resulting CSV files.

    Returns:
        None
    """

     # Initialize SparkSession with parallelism configurations
    spark = SparkSession.builder \
        .appName("Create Spark DataFrames for Playlist Dataset") \
        .config("spark.sql.shuffle.partitions", str(partitions)) \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", "4") \
        .config("spark.driver.memory", "4g") \
        .getOrCreate()
    
    
    # Define column sets
    playlist_col = ['collaborative', 'duration_ms', 'modified_at', 
                    'name', 'num_albums', 'num_artists', 'num_edits',
                    'num_followers', 'num_tracks', 'pid']
    tracks_col = ['album_name', 'album_uri', 'artist_name', 'artist_uri', 
                  'duration_ms', 'track_name', 'track_uri'] 
    playlist_test_col = ['name', 'num_holdouts', 'num_samples', 'num_tracks', 'pid']

    # Read all JSON files in the directory
    df_raw = spark.read.option("multiline", "true").json(f"{data_path}/*.json").repartition(partitions)

    # Extract playlist data
    df_playlists = df_raw.select(explode(col("playlists")).alias("playlist"))

    # Extract playlist-level information
    df_playlists_info = df_playlists.select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_col])

    # Extract track-level information
    df_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("track.track_uri").alias("track_uri1"),
        *[col(f"track.{cols}").alias(cols) for cols in tracks_col]
    ).drop_duplicates()

    
    
    # Add unique track ID (tid) in parallel
    df_tracks = df_tracks.withColumn("tid", monotonically_increasing_id())

    # Join playlist and track information to create a relationship DataFrame
    df_playlists_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.track_uri").alias("track_uri"),
        col("track.pos").alias("pos")
    )
    
    df_playlists_tracks = df_playlists_tracks.join(broadcast(df_tracks), on="track_uri", how="left")
    
    # Join with track ID (tid)
#     df_playlists_tracks = df_playlists_tracks.join(df_tracks, on="track_uri", how="left")

    # Process challenge set
    df_challenge_raw = spark.read.option("multiline", "true").json(challenge_file).repartition(partitions)
    df_playlists_test_info = df_challenge_raw.select(
        explode(col("playlists")).alias("playlist")
    ).select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_test_col])

    df_playlists_test = df_challenge_raw.select(
        explode(col("playlists")).alias("playlist")
    ).select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.track_uri").alias("track_uri"),
        col("track.pos").alias("pos")
    ).join(broadcast(df_tracks), on="track_uri", how="left")
    
    print("track schema:{}".format(df_tracks.printSchema()))
    print("playlist schema:{}".format(df_playlists_info.printSchema()))
    print("playlist track schema:{}".format(df_playlists_tracks.printSchema()))
    print("playlist test info schema:{}".format(df_playlists_test_info.printSchema()))
    print("playlist test schema:{}".format(df_playlists_test.printSchema()))
    
    
    
    

#     # Save DataFrames as CSV files
#     df_playlists_info.write.parquet(f"{output_path}/df_playlists_info_spark", header=True, mode="overwrite")
#     df_tracks.write.parquet(f"{output_path}/df_tracks_spark", header=True, mode="overwrite")
#     df_playlists_tracks.write.parquet(f"{output_path}/df_playlists_spark", header=True, mode="overwrite")
#     df_playlists_test_info.write.parquet(f"{output_path}/df_playlists_test_info_spark", header=True, mode="overwrite")
#     df_playlists_test.write.parquet(f"{output_path}/df_playlists_test_spark", header=True, mode="overwrite")

    print("DataFrames successfully created and saved as parquet files.")
    spark.stop()

if __name__ == "__main__":
    # Define paths
    data_path = "../data/raw/data"  # Path to directory with JSON files
    challenge_file = "../data/raw/challenge_set.json"  # Challenge set file
    output_path = "../data/processed/"  # Output directory for CSV files
     # Number of partitions for parallelism
    num_partitions = 50  # Adjust based on your system's resources

    # Run the function
    create_dataframes(data_path, challenge_file, output_path, partitions=num_partitions)

                                                                                

root
 |-- track_uri1: string (nullable = true)
 |-- album_name: string (nullable = true)
 |-- album_uri: string (nullable = true)
 |-- artist_name: string (nullable = true)
 |-- artist_uri: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- track_name: string (nullable = true)
 |-- track_uri: string (nullable = true)
 |-- tid: long (nullable = false)

track schema:None
root
 |-- collaborative: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- modified_at: long (nullable = true)
 |-- name: string (nullable = true)
 |-- num_albums: long (nullable = true)
 |-- num_artists: long (nullable = true)
 |-- num_edits: long (nullable = true)
 |-- num_followers: long (nullable = true)
 |-- num_tracks: long (nullable = true)
 |-- pid: long (nullable = true)

playlist schema:None
root
 |-- track_uri: string (nullable = true)
 |-- pid: long (nullable = true)
 |-- pos: long (nullable = true)
 |-- track_uri1: string (nullable = true)
 |-- album_name: string (nul

In [9]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col, lit, monotonically_increasing_id
from pyspark.sql.functions import broadcast


def create_dataframes(data_path, challenge_file, output_path, partitions):
    """
    Process playlist and track data using Spark, and save the output as CSV files.

    Args:
        data_path (str): Path to the directory containing JSON files.
        challenge_file (str): Path to the challenge_set.json file.
        output_path (str): Path to save the resulting CSV files.

    Returns:
        None
    """

    # Initialize SparkSession with parallelism configurations, 4 cores, 4GB memory
    spark = SparkSession.builder \
        .appName("Create Spark DataFrames for Playlist Dataset") \
        .config("spark.sql.shuffle.partitions", str(partitions)) \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", "4") \
        .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow") \
        .getOrCreate()

    # Define column sets
    playlist_col = ['collaborative', 'duration_ms', 'modified_at',
                    'name', 'num_albums', 'num_artists', 'num_edits',
                    'num_followers', 'num_tracks', 'pid']
    tracks_col = ['album_name', 'album_uri', 'artist_name', 'artist_uri',
                  'duration_ms', 'track_name', 'track_uri']
    playlist_test_col = ['name', 'pid']

    # Read all JSON files in the directory
    df_raw = spark.read.option("multiline", "true").json(f"{data_path}/*.json").repartition(partitions)

    # Extract playlist data
    df_playlists = df_raw.select(explode(col("playlists")).alias("playlist"))

    # Extract playlist-level information
    df_playlists_info = df_playlists.select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_col])

    # Extract track-level information
    df_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        *[col(f"track.{cols}").alias(cols) for cols in tracks_col]
    ).drop_duplicates()

    # Add unique track ID (tid) in parallel
    df_tracks = df_tracks.withColumn("tid", monotonically_increasing_id())

    # Join playlist and track information to create a relationship DataFrame
    df_playlists_tracks = df_playlists.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.track_uri").alias("track_uri"),
        col("track.pos").alias("pos")
    )

    df_playlists_tracks = df_playlists_tracks.join(broadcast(df_tracks), on="track_uri", how="left")

    # Join with track ID (tid)
    #     df_playlists_tracks = df_playlists_tracks.join(df_tracks, on="track_uri", how="left")

    # Process challenge set
    df_challenge_raw = spark.read.option("multiline", "true").json(challenge_file).repartition(partitions)

    # Extract playlist data
    df_playlists_challenge = df_challenge_raw.select(explode(col("playlists")).alias("playlist"))

    # Extract playlist-level information
    df_playlists_challenge_info = df_playlists_challenge.select(*[col(f"playlist.{cols}").alias(cols) for cols in playlist_test_col])
#     df_playlists_challenge_info.printSchema()

    # Extract track-level information
    df_tracks_challenge = df_playlists_challenge.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        *[col(f"track.{cols}").alias(cols) for cols in tracks_col]
    ).drop_duplicates()

    # Add unique track ID (tid) in parallel
    df_tracks_challenge = df_tracks_challenge.withColumn("tid", monotonically_increasing_id())
#     df_tracks_challenge.printSchema()

    # Join playlist and track information to create a relationship DataFrame
    df_playlists_tracks_challenge = df_playlists_challenge.select(
        col("playlist.pid").alias("pid"),
        explode(col("playlist.tracks")).alias("track")
    ).select(
        col("pid"),
        col("track.track_uri").alias("track_uri"),
        col("track.pos").alias("pos")
    )
    
    print("playlist_schema:{}".format(df_playlists_info.printSchema()))
    print("playlist_schema:{}".format(df_tracks.printSchema()))
    print("playlist_schema:{}".format(df_playlists_tracks.printSchema()))
          
    df_playlists_tracks_challenge = df_playlists_tracks_challenge.join(broadcast(df_tracks_challenge), on="track_uri", how="left")
#     df_playlists_tracks_challenge.printSchema()

    

#     # Save DataFrames as parquet files
#     df_playlists_info.write.parquet(f"{output_path}/df_playlists_info_spark", mode="overwrite")
#     df_tracks.write.parquet(f"{output_path}/df_tracks_spark", mode="overwrite")
#     df_playlists_tracks.write.parquet(f"{output_path}/df_playlists_spark", mode="overwrite")

#     # Save challenge DataFrames as parquet files
#     df_playlists_challenge_info.write.parquet(f"{output_path}/df_playlists_test_info_spark", mode="overwrite")
#     df_tracks_challenge.write.parquet(f"{output_path}/df_tracks_test_spark", mode="overwrite")
#     df_playlists_tracks_challenge.write.parquet(f"{output_path}/df_playlists_test_spark", mode="overwrite")

    print("DataFrames successfully created and saved as parquet files.")

    # Stop the SparkSession
    spark.stop()


if __name__ == "__main__":
    # Define paths
    data_path = "../data/raw/data"  # Path to directory with JSON files
    challenge_file = "../data/raw/challenge_set.json"  # Challenge set file
    output_path = "../data/processed/"  # Output directory for CSV files
    # Number of partitions for parallelism
    num_partitions = 50  # Adjust based on your system's resources

    # Run the function
    create_dataframes(data_path, challenge_file, output_path, partitions=num_partitions)


                                                                                

root
 |-- collaborative: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- modified_at: long (nullable = true)
 |-- name: string (nullable = true)
 |-- num_albums: long (nullable = true)
 |-- num_artists: long (nullable = true)
 |-- num_edits: long (nullable = true)
 |-- num_followers: long (nullable = true)
 |-- num_tracks: long (nullable = true)
 |-- pid: long (nullable = true)

playlist_schema:None
root
 |-- album_name: string (nullable = true)
 |-- album_uri: string (nullable = true)
 |-- artist_name: string (nullable = true)
 |-- artist_uri: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- track_name: string (nullable = true)
 |-- track_uri: string (nullable = true)
 |-- tid: long (nullable = false)

playlist_schema:None
root
 |-- track_uri: string (nullable = true)
 |-- pid: long (nullable = true)
 |-- pos: long (nullable = true)
 |-- album_name: string (nullable = true)
 |-- album_uri: string (nullable = true)
 |-- artist_name: string (