In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split, explode, when, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BooleanType
import time

# Initialize SparkSession with optimized settings
spark = SparkSession.builder \
    .master("spark://192.168.2.205:7077") \
    .appName("LoadLargeLyricsDataToPostgres") \
    .config("spark.executor.instances", "4") \
    .config("spark.executor.cores", "4") \
    .config("spark.executor.memory", "8g") \
    .config("spark.default.parallelism", "16") \
    .config("spark.sql.shuffle.partitions", "16") \
    .getOrCreate()

In [2]:
# Verify cluster setup
print("Cluster Info:")
print(spark.sparkContext._jsc.sc().getExecutorMemoryStatus())

Cluster Info:
Map(group29:43163 -> (384093388,384093388))


In [3]:
# PostgreSQL connection properties
pg_url = "jdbc:postgresql://192.168.2.205:5432/lyricsdb"
pg_properties = {
    "user": "sparkuser",
    "password": "9567",
    "driver": "org.postgresql.Driver",
    "batchsize": "10000"  # Faster writes for large data
}

In [4]:
# Process lyrics files (train/test)
def process_lyrics_file(file_path, is_test):
    start_time = time.time()
    
    raw_df = spark.read.text(file_path).repartition(16)
    
    # Corrected the substr call to include length
    top_words = raw_df.filter(col("value").startswith("%")) \
        .select(split(col("value").substr(2, 1000), ",").alias("words")) \
        .first()["words"]
    
    words_df = spark.createDataFrame(
        [(i + 1, word) for i, word in enumerate(top_words)],
        ["word_id", "word"]
    ).repartition(4)
    
    lyrics_raw = raw_df.filter(~col("value").startswith("#") & 
                              ~col("value").startswith("%") & 
                              (col("value") != ""))
    
    lyrics_split = lyrics_raw.select(
        split(col("value"), ",").alias("parts")
    ).filter(col("parts").getItem(0).isNotNull()) \
     .select(
        col("parts").getItem(0).alias("track_id"),
        col("parts").getItem(1).alias("mxm_tid"),
        split(col("parts").getItem(2), ":").alias("word_counts")  # Assuming word_counts is in "word:count" format
    )
    
    lyrics_exploded = lyrics_split.select(
        "track_id",
        "mxm_tid",
        explode(col("word_counts")).alias("word_count")
    ).select(
        "track_id",
        "mxm_tid",
        split(col("word_count"), ":").getItem(0).cast("int").alias("word_id"),
        split(col("word_count"), ":").getItem(1).cast("int").alias("count")
    ).withColumn("is_test", when(lit(is_test), True).otherwise(False)) \
     .repartition(16)
    
    end_time = time.time()
    print(f"Processing {file_path} took {end_time - start_time:.2f} seconds")
    return words_df, lyrics_exploded

In [5]:
# Total timing start
total_start_time = time.time()

# Load train and test files
print("Loading train data...")
train_words_df, train_lyrics_df = process_lyrics_file("/home/ubuntu/DE1-G29/mxm_dataset_train.txt", False)
print("Loading test data...")
test_words_df, test_lyrics_df = process_lyrics_file("/home/ubuntu/DE1-G29/mxm_dataset_test.txt", True)

# Combine data
print("Combining data...")
start_time = time.time()
words_df = train_words_df.union(test_words_df).distinct().cache()
lyrics_df = train_lyrics_df.union(test_lyrics_df).repartition(16).cache()
end_time = time.time()
print(f"Combining took {end_time - start_time:.2f} seconds")

Loading train data...


                                                                                

Processing /home/ubuntu/DE1-G29/mxm_dataset_train.txt took 4.66 seconds
Loading test data...
Processing /home/ubuntu/DE1-G29/mxm_dataset_test.txt took 0.40 seconds
Combining data...
Combining took 0.24 seconds


In [6]:
print("Loading matches data...")
start_time = time.time()
matches_raw = spark.read.text("/home/ubuntu/DE1-G29/mxm_779k_matches.txt") \
    .filter(~col("value").startswith("#") & (col("value") != "")) \
    .repartition(16)
matches_df = matches_raw.select(
    split(col("value"), "\\|").alias("parts")  # Escape | since it’s a regex special char
).select(
    col("parts").getItem(0).alias("track_id"),
    col("parts").getItem(3).cast("int").alias("mxm_tid"),
    col("parts").getItem(1).alias("msd_artist_name"),
    col("parts").getItem(2).alias("msd_title"),
    col("parts").getItem(4).alias("mxm_artist_name"),
    col("parts").getItem(5).alias("mxm_title")
).repartition(16).cache()
end_time = time.time()
print(f"Loading matches took {end_time - start_time:.2f} seconds")

Loading matches data...
Loading matches took 0.11 seconds


In [7]:
print(spark.sparkContext._jsc.sc().getClass().getClassLoader().getResource("org/postgresql/Driver.class"))


jar:file:/home/ubuntu/spark-3.1.2-bin-hadoop3.2/jars/postgresql-42.5.0.jar!/org/postgresql/Driver.class


In [36]:
# Write to PostgreSQL
print("Writing to PostgreSQL...")
start_time = time.time()
words_df.write.jdbc(url=pg_url, table="words", mode="append", properties=pg_properties)
print(f"Words write took {time.time() - start_time:.2f} seconds")

start_time = time.time()
lyrics_df.write.jdbc(url=pg_url, table="lyrics", mode="append", properties=pg_properties)
print(f"Lyrics write took {time.time() - start_time:.2f} seconds")

start_time = time.time()
matches_df.write.jdbc(url=pg_url, table="matches", mode="append", properties=pg_properties)
print(f"Matches write took {time.time() - start_time:.2f} seconds")

# Total time
total_end_time = time.time()
print(f"Total execution time: {total_end_time - total_start_time:.2f} seconds")

Writing to PostgreSQL...


Py4JJavaError: An error occurred while calling o821.jdbc.
: java.lang.ClassNotFoundException: org.postgresql.Driver
	at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
	at org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(DriverRegistry.scala:46)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.$anonfun$driverClass$1(JDBCOptions.scala:102)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.$anonfun$driverClass$1$adapted(JDBCOptions.scala:102)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.<init>(JDBCOptions.scala:102)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcOptionsInWrite.<init>(JDBCOptions.scala:217)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcOptionsInWrite.<init>(JDBCOptions.scala:221)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:45)
	at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:46)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:70)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:68)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:90)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
	at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:132)
	at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:131)
	at org.apache.spark.sql.DataFrameWriter.$anonfun$runCommand$1(DataFrameWriter.scala:989)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:989)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:438)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:415)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:301)
	at org.apache.spark.sql.DataFrameWriter.jdbc(DataFrameWriter.scala:817)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:750)
