In [0]:
%pip install pymysql

Python interpreter will be restarted.
Collecting pymysql
  Downloading PyMySQL-1.1.1-py3-none-any.whl (44 kB)
Installing collected packages: pymysql
Successfully installed pymysql-1.1.1
Python interpreter will be restarted.


In [0]:
%run ./secrets_notebook

## LOAD AND CLEAN DATA


In [0]:
from pyspark.sql.functions import col, to_timestamp, year, month, dayofmonth, dayofweek, hour, date_format, expr, when, count, sum, avg, date_trunc, round, row_number, lit
from pyspark.sql.window import Window
from pyspark.sql.types import StringType, ArrayType, StructType, StructField
import requests
import base64
import json
import time
import logging
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import os
# import pymysql

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Plays shorter than this many milliseconds are treated as noise and dropped.
MIN_MS_PLAYED = 500

# No schema / inferSchema is given, so every column arrives as a string until the
# explicit casts at the end of this cell. quote == escape == '"' means a doubled
# quote ("") inside a quoted field is read as a literal quote.
df = (
    spark.read
    .option("header", "true")
    .option("quote", "\"")
    .option("escape", "\"")
    .csv("/FileStore/tables/spotify_history.csv")
)

# Null count per column of the raw file. `sum` here is pyspark.sql.functions.sum,
# which the import cell binds over the Python builtin of the same name.
null_counts = df.select([sum(col(c).isNull().cast("int")).alias(c) for c in df.columns])
null_counts.show()

df = df.na.fill("unknown", ["reason_start", "reason_end"])
# ms_played is still a string column here; cast it explicitly so the threshold is
# a numeric comparison rather than depending on Spark's implicit string coercion.
# Values that fail the cast become null and are dropped by the filter.
df = df.filter(col("ms_played").cast("integer") >= MIN_MS_PLAYED)
df_cleaned = df.dropDuplicates()

# Convert data types
df_cleaned = (
    df_cleaned
    .withColumn("ts", to_timestamp("ts", "yyyy-MM-dd HH:mm:ss"))
    .withColumn("ms_played", col("ms_played").cast("integer"))
    .withColumn("shuffle", col("shuffle").cast("boolean"))
    .withColumn("skipped", col("skipped").cast("boolean"))
)



+-----------------+---+--------+---------+----------+-----------+----------+------------+----------+-------+-------+
|spotify_track_uri| ts|platform|ms_played|track_name|artist_name|album_name|reason_start|reason_end|shuffle|skipped|
+-----------------+---+--------+---------+----------+-----------+----------+------------+----------+-------+-------+
|                0|  0|       0|        0|         0|          0|         0|         143|       117|      0|      0|
+-----------------+---+--------+---------+----------+-----------+----------+------------+----------+-------+-------+



## CREATE DIMENSION TABLES

In [0]:
# Each dimension gets a dense 1..N surrogate key from row_number() over a global
# (un-partitioned) window. NOTE: an un-partitioned window pulls all rows into a
# single partition; fine at dimension scale, but it will not scale indefinitely.

# DIM_TIME: one row per distinct event timestamp.
window_time = Window.orderBy("ts")
dim_time = (
    df_cleaned.select("ts").distinct()
    .withColumn("time_id", row_number().over(window_time))
    .withColumn("date_full", date_trunc("day", col("ts")))  # ts truncated to the start of its day (still a timestamp)
    .withColumn("day_of_week", date_format("ts", "E"))       # short day-name text
    .withColumn("hour", hour("ts"))
)

# DIM_ARTIST: one row per distinct artist name.
window_artist = Window.orderBy("artist_name")
dim_artist = (
    df_cleaned.select("artist_name").distinct()
    .withColumn("artist_id", row_number().over(window_artist))
)

# DIM_ALBUM: an album is identified by (album_name, artist), so same-named albums
# by different artists get distinct ids.
dim_album_raw = df_cleaned.select("album_name", "artist_name").distinct()
dim_album = dim_album_raw.join(dim_artist, on="artist_name", how="inner")
window_album = Window.orderBy("album_name", "artist_id")
dim_album = (
    dim_album
    .withColumn("album_id", row_number().over(window_album))
    .select("album_id", "album_name", "artist_id")
)

# DIM_TRACK: keyed by the Spotify URI itself. If a URI appears with several
# track names, keep the name from its most recent play (ties on ts are broken
# arbitrarily by row_number).
window_spec = Window.partitionBy("spotify_track_uri").orderBy(col("ts").desc())
ranked_tracks = (
    df_cleaned.select("spotify_track_uri", "track_name", "ts")
    .withColumn("rank", row_number().over(window_spec))
)
dim_track = (
    ranked_tracks.filter(col("rank") == 1)
    .drop("rank", "ts")
    .withColumnRenamed("spotify_track_uri", "track_id")
)

# DIM_PLATFORM: one row per distinct platform string.
window_platform = Window.orderBy("platform")
dim_platform = (
    df_cleaned.select("platform").distinct()
    .withColumn("platform_id", row_number().over(window_platform))
)

# DIM_REASON: one row per distinct (reason_start, reason_end) pair.
window_reason = Window.orderBy("reason_start", "reason_end")
dim_reason = (
    df_cleaned.select("reason_start", "reason_end").distinct()
    .withColumn("reason_id", row_number().over(window_reason))
)



## CREATE FACT TABLE

In [0]:
# Alias every input so that identically named columns (ts, artist_name,
# album_name, platform, reason_*) can be told apart in the join conditions.
c = df_cleaned.alias("c")
t = dim_time.alias("t")
tr = dim_track.alias("tr")
a = dim_artist.alias("a")
al = dim_album.alias("al")
p = dim_platform.alias("p")
r = dim_reason.alias("r")

# All joins are inner: a play survives only if it resolves in every dimension.
# The album join uses the artist_id resolved by the artist join, because an
# album is identified by (album_name, artist).
fact_streams = (
    c.join(t, col("c.ts") == col("t.ts"), "inner")
    .join(tr, col("c.spotify_track_uri") == col("tr.track_id"), "inner")
    .join(a, col("c.artist_name") == col("a.artist_name"), "inner")
    .join(
        al,
        (col("c.album_name") == col("al.album_name"))
        & (col("a.artist_id") == col("al.artist_id")),
        "inner",
    )
    .join(p, col("c.platform") == col("p.platform"), "inner")
    .join(
        r,
        (col("c.reason_start") == col("r.reason_start"))
        & (col("c.reason_end") == col("r.reason_end")),
        "inner",
    )
    .select(*[
        col(qualified)
        for qualified in (
            "t.time_id",
            "tr.track_id",
            "a.artist_id",
            "al.album_id",
            "p.platform_id",
            "r.reason_id",
            "c.ms_played",
            "c.shuffle",
            "c.skipped",
        )
    ])
)


## ADD TRACK IMAGE URLS FROM SPOTIFY API

In [0]:
# SPOTIFY_CREDENTIALS = [
#     {
#         "CLIENT_ID": "",
#         "CLIENT_SECRET": ""
#     },
#     {
#         "CLIENT_ID": "",
#         "CLIENT_SECRET": ""
#     },
#     {
#         "CLIENT_ID": "",
#         "CLIENT_SECRET": ""
#     }
# ]

# Per-credential token cache: {cred_index: {"token": str, "expires_at": epoch_seconds}}.
token_cache = {}

def get_spotify_token(cred_index=0):
    """Return a Spotify client-credentials access token for one credential slot.

    A cached token is reused while the current time is before its recorded
    ``expires_at``. Otherwise a new token is requested from Spotify's accounts
    service and cached with an expiry 60 s earlier than the reported lifetime
    (3600 s if the response omits ``expires_in``).

    The client id/secret come from ``SPOTIFY_CREDENTIALS`` (not defined in this
    cell), indexed by ``cred_index`` modulo its length. The cache itself is keyed
    by the raw ``cred_index``.

    Raises:
        Whatever the token request or response parsing raises; it is logged
        and re-raised.
    """
    now = int(time.time())

    cached = token_cache.get(cred_index)
    if cached is not None and now < cached["expires_at"]:
        return cached["token"]

    creds = SPOTIFY_CREDENTIALS[cred_index % len(SPOTIFY_CREDENTIALS)]
    basic_auth = base64.b64encode(
        f"{creds['CLIENT_ID']}:{creds['CLIENT_SECRET']}".encode()
    ).decode()
    request_headers = {
        "Authorization": f"Basic {basic_auth}",
        "Content-Type": "application/x-www-form-urlencoded",
    }

    try:
        with requests.Session() as session:
            response = session.post(
                "https://accounts.spotify.com/api/token",
                headers=request_headers,
                data={"grant_type": "client_credentials"},
                timeout=5,
            )
            response.raise_for_status()
            payload = response.json()

            entry = {
                "token": payload["access_token"],
                # Refresh a minute early so a token never expires mid-request.
                "expires_at": now + payload.get("expires_in", 3600) - 60,
            }
            token_cache[cred_index] = entry
            return entry["token"]
    except Exception as e:
        logger.error(f"API #{cred_index}: Token error: {e}")
        raise


2025-06-11 02:50:04,613 - INFO - Received command c on object id p0


In [0]:
def build_credential_batches(track_ids, batch_size, n_credentials):
    """Split ``track_ids`` into batches and assign credentials round-robin by batch number.

    Args:
        track_ids: sequence of track ids/URIs.
        batch_size: maximum number of ids per batch (must be >= 1).
        n_credentials: number of available credential slots (must be >= 1
            when there is anything to batch).

    Returns:
        List of ``(ids_slice, cred_index)`` tuples, where batch k uses
        credential ``k % n_credentials``.

    Raises:
        ValueError: if ``batch_size`` < 1, or ``n_credentials`` < 1 with non-empty input.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not track_ids:
        return []
    if n_credentials < 1:
        raise ValueError("No Spotify credentials configured")
    # Rotate on the batch number, not the element offset: offsets step by
    # batch_size, so `offset % n` collapses to a single credential whenever
    # batch_size is a multiple of n.
    return [
        (track_ids[start:start + batch_size], batch_no % n_credentials)
        for batch_no, start in enumerate(range(0, len(track_ids), batch_size))
    ]


def get_track_images_parallel(track_ids, batch_size=100, max_workers=6):
    """Fetch album image URLs for many tracks concurrently.

    Args:
        track_ids: list of Spotify track URIs/ids.
        batch_size: ids per worker task; credentials rotate per batch.
        max_workers: thread-pool size.

    Returns:
        ``{track_id: image_url}`` containing only tracks for which a URL was
        found. A batch whose task raises is logged and contributes nothing.
    """
    if not track_ids:
        return {}

    batches = build_credential_batches(track_ids, batch_size, len(SPOTIFY_CREDENTIALS))

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_batch = {
            executor.submit(process_track_batch, batch): batch
            for batch in batches
        }
        for future in concurrent.futures.as_completed(future_to_batch):
            batch_ids, _cred_index = future_to_batch[future]
            try:
                results.extend(future.result())
            except Exception as e:
                logger.warning(f"Failed batch: {str(e)[:100]}")
                results.extend((tid, None) for tid in batch_ids)

    return {tid: url for tid, url in results if url}

2025-06-11 02:50:04,816 - INFO - Received command c on object id p0


In [0]:
# Spotify's "Get Several Tracks" endpoint documents a maximum of 50 IDs per
# request, so larger batches are split into sub-requests of this size.
SPOTIFY_MAX_IDS_PER_REQUEST = 50


def largest_image_url(track):
    """Return the URL of the widest album image of one Spotify track object.

    Args:
        track: one element of the ``tracks`` list of the API response; falsy
            entries are mapped to None.

    Returns:
        The ``url`` of the image with the largest ``width``, or None when the
        track, its album, or its image list is missing or empty.
    """
    if not track:
        return None
    images = (track.get("album") or {}).get("images") or []
    if not images:
        return None
    # A null width would make max() compare None with int; rank it as 0 instead.
    widest = max(images, key=lambda image: image.get("width") or 0)
    return widest.get("url")


def _fetch_image_chunk(track_ids, cred_index, max_retries):
    """Fetch image URLs for at most SPOTIFY_MAX_IDS_PER_REQUEST ids, with retries.

    Returns a list of ``(track_id, image_url_or_None)``; if every attempt fails,
    each id maps to None.
    """
    retry_count = 0
    while retry_count < max_retries:
        try:
            token = get_spotify_token(cred_index)
            # Accept full URIs ("spotify:track:<id>") as well as bare ids.
            cleaned_ids = [tid.split(":")[-1] for tid in track_ids]

            with requests.Session() as session:
                response = session.get(
                    f"https://api.spotify.com/v1/tracks?ids={','.join(cleaned_ids)}",
                    headers={"Authorization": f"Bearer {token}"},
                    timeout=10
                )

                if response.status_code == 429:
                    # Rate limited: wait as instructed, then retry with a fresh token.
                    retry_after = int(response.headers.get("Retry-After", 2))
                    time.sleep(retry_after)
                    token_cache.pop(cred_index, None)
                    retry_count += 1
                    continue

                if response.status_code == 401:
                    # Token rejected: evict it so the next attempt fetches a new one.
                    token_cache.pop(cred_index, None)
                    retry_count += 1
                    continue

                response.raise_for_status()
                tracks = response.json().get("tracks") or []

                # Pair results positionally with the requested ids (the API is
                # assumed to return tracks in request order).
                return [
                    (tid, largest_image_url(track))
                    for tid, track in zip(track_ids, tracks)
                ]

        except Exception as e:
            logger.debug(f"API #{cred_index}: Attempt {retry_count+1} failed: {str(e)[:100]}")
            retry_count += 1
            time.sleep(1)

    logger.warning(
        f"API #{cred_index}: giving up on {len(track_ids)} tracks after {max_retries} attempts"
    )
    return [(tid, None) for tid in track_ids]


def process_track_batch(batch_data):
    """Look up album image URLs for one batch of tracks.

    Args:
        batch_data: ``(track_ids, cred_index)``, where ``cred_index`` selects
            the credential slot passed to ``get_spotify_token``.

    Returns:
        List of ``(track_id, image_url_or_None)``; empty for an empty batch.
    """
    track_ids, cred_index = batch_data
    max_retries = 2
    results = []
    for start in range(0, len(track_ids), SPOTIFY_MAX_IDS_PER_REQUEST):
        chunk = track_ids[start:start + SPOTIFY_MAX_IDS_PER_REQUEST]
        results.extend(_fetch_image_chunk(chunk, cred_index, max_retries))
    return results

2025-06-11 02:50:04,919 - INFO - Received command c on object id p0


In [0]:
def update_track_images():
    """Return ``dim_track`` with an ``image_url`` column from the Spotify API.

    Reads the module-level ``dim_track``, collects its distinct ``track_id``
    values, fetches album-image URLs via ``get_track_images_parallel`` and
    left-joins them back. Tracks with no URL get a null ``image_url``.

    An ``image_url`` column already present on ``dim_track`` (e.g. when the
    calling cell is re-run after ``dim_track`` was reassigned to this result)
    is replaced rather than duplicated.

    Raises:
        RuntimeError: if there are tracks but not a single URL came back,
            which indicates the API lookups failed wholesale.
    """
    # drop() of a missing column is a no-op, so this works on the first run too.
    base_tracks = dim_track.drop("image_url")
    # Used twice below (id collection and the join).
    base_tracks.cache()

    track_ids = [row.track_id for row in base_tracks.select("track_id").distinct().collect()]

    image_dict = get_track_images_parallel(track_ids)

    if not image_dict:
        if track_ids:
            raise RuntimeError(
                f"No image URLs returned for any of {len(track_ids)} tracks; "
                "the Spotify API lookups most likely failed."
            )
        # No tracks at all: still expose image_url so downstream filters work.
        return base_tracks.withColumn("image_url", lit(None).cast("string"))

    image_df = spark.createDataFrame(
        list(image_dict.items()),
        "track_id string, image_url string"
    )

    # image_df is a small local DataFrame (one row per track that has a URL).
    return base_tracks.join(
        image_df.hint("broadcast"),
        "track_id",
        "left"
    ).select(
        base_tracks["*"],
        image_df["image_url"]
    )

In [0]:
# Replace dim_track with its image-enriched version; any failure is logged and
# re-raised so the notebook stops here.
try:
    dim_track = update_track_images()
except Exception as e:
    logger.error(f"Critical error: {str(e)}")
    raise
else:
    logger.info("Track images update completed successfully!")

2025-06-11 02:50:05,128 - INFO - Received command c on object id p0
2025-06-11 02:50:48,072 - INFO - Closing down clientserver connection
2025-06-11 02:50:49,536 - INFO - Track images update completed successfully!


In [0]:
# Tracks with no album image are removed from the model: first from dim_track,
# then every fact row that references them.
invalid_track_ids = (
    dim_track
    .where(col("image_url").isNull())
    .select("track_id")
)
dim_track = dim_track.where(col("image_url").isNotNull())

fact_streams = fact_streams.join(invalid_track_ids, "track_id", "left_anti")

# Keep only artists and albums that are still referenced by at least one fact row.
valid_artist_ids = fact_streams.select("artist_id").distinct()
dim_artist = dim_artist.join(valid_artist_ids, on="artist_id", how="inner")

valid_album_ids = fact_streams.select("album_id").distinct()
dim_album = dim_album.join(valid_album_ids, on="album_id", how="inner")

In [0]:
# Persist the star schema as managed Delta tables, replacing any previous
# contents. Dimensions are written first, the fact table last.
for delta_table_name, delta_frame in [
    ("dim_time", dim_time),
    ("dim_artist", dim_artist),
    ("dim_album", dim_album),
    ("dim_track", dim_track),
    ("dim_platform", dim_platform),
    ("dim_reason", dim_reason),
    ("fact_streams", fact_streams),
]:
    (
        delta_frame.write
        .format("delta")
        .mode("overwrite")
        .saveAsTable(delta_table_name)
    )

logger.info("Successfully wrote all data to Delta tables")


2025-05-22 04:47:22,559 - INFO - Received command c on object id p0
2025-05-22 04:51:30,902 - INFO - Successfully wrote all data to Delta tables


## Connect and write data to AWS RDS MySQL

In [0]:
# jdbc_url = ""
# db_user = ""
# db_password = ""

def test_connection():
    """Check that RDS MySQL is reachable and ensure the target database exists.

    Connects using the module-level ``my_host``, ``db_user`` and ``db_password``
    (not defined in this notebook; presumably provided by ``./secrets_notebook``).
    It then runs ``CREATE DATABASE IF NOT EXISTS`` for ``database_name``, so
    despite its name this function has a side effect on the server.

    Returns:
        True on success; False if importing the driver, connecting, creating the
        database or closing the connection failed (the error is logged).
    """
    try:
        # The top-of-notebook import is commented out; import here so this
        # function does not depend on some other cell having imported it.
        import pymysql

        conn = pymysql.connect(
            host=my_host,
            user=db_user,
            password=db_password
        )
        try:
            with conn.cursor() as cursor:
                # Identifiers cannot be bound as query parameters; database_name
                # is assumed to come from trusted configuration, not user input.
                cursor.execute(f"CREATE DATABASE IF NOT EXISTS {database_name}")
        finally:
            # Close even when the statement fails, so the connection never leaks.
            conn.close()
        logger.info("Successfully connected to RDS MySQL")
        return True
    except Exception as e:
        logger.error(f"Failed to connect to RDS MySQL: {str(e)}")
        return False

# Write DataFrame to MySQL
def write_to_mysql(df, table_name, mode="overwrite", batch_size=10000):
    """Write a Spark DataFrame to a MySQL table over JDBC.

    Empty DataFrames are skipped with a warning. Connection settings come from
    the module-level ``jdbc_url``, ``db_user`` and ``db_password``.

    Args:
        df: DataFrame to write.
        table_name: target table name.
        mode: Spark save mode. With "overwrite", Spark's JDBC writer by default
            drops and recreates the table, so existing keys and indexes are lost.
        batch_size: JDBC ``batchsize`` (rows per insert batch).

    Raises:
        Any exception from counting or writing; it is logged and re-raised.
    """
    try:
        # count() executes the whole lineage; do it once and reuse the result.
        row_count = df.count()
        if row_count == 0:
            logger.warning(f"DataFrame for {table_name} is empty. Skipping.")
            return

        logger.info(f"Starting to write {row_count} rows to {table_name}")

        (
            df.write
            .format("jdbc")
            .option("url", jdbc_url)
            .option("dbtable", table_name)
            .option("user", db_user)
            .option("password", db_password)
            .option("driver", "com.mysql.cj.jdbc.Driver")
            .option("batchsize", batch_size)
            # Caps write parallelism (concurrent connections) at 10.
            .option("numPartitions", min(df.rdd.getNumPartitions(), 10))
            # Options below are not Spark-specific; presumably forwarded to the
            # MySQL driver as connection properties.
            .option("rewriteBatchedStatements", "true")
            .option("connectTimeout", 30000)
            .option("socketTimeout", 300000)
            .mode(mode)
            .save()
        )

        logger.info(f"Successfully wrote data to {table_name}")

    except Exception as e:
        logger.error(f"Error writing to {table_name}: {str(e)}")
        raise

# Push every table to RDS MySQL. Failures are logged but not re-raised, so this
# cell does not fail the notebook.
try:
    if not test_connection():
        logger.error("Failed to establish connection to RDS. Data not written.")
    else:
        for mysql_frame, mysql_table in [
            (dim_time, "dim_time"),
            (dim_artist, "dim_artist"),
            (dim_album, "dim_album"),
            (dim_track, "dim_track"),
            (dim_platform, "dim_platform"),
            (dim_reason, "dim_reason"),
        ]:
            write_to_mysql(mysql_frame, mysql_table)

        # The fact table is written last, with a larger JDBC batch size.
        write_to_mysql(fact_streams, "fact_streams", batch_size=50000)

        logger.info("All data successfully written to AWS RDS MySQL!")
except Exception as e:
    logger.error(f"An error occurred in the MySQL data writing process: {str(e)}")

2025-05-13 04:10:14,561 - INFO - Received command c on object id p0
2025-05-13 04:10:15,452 - INFO - Successfully connected to RDS MySQL
2025-05-13 04:10:21,203 - INFO - Starting to write 134819 rows to dim_time
2025-05-13 04:10:39,231 - INFO - Successfully wrote data to dim_time
2025-05-13 04:10:43,581 - INFO - Starting to write 4068 rows to dim_artist
2025-05-13 04:10:51,697 - INFO - Successfully wrote data to dim_artist
2025-05-13 04:11:00,082 - INFO - Starting to write 8337 rows to dim_album
2025-05-13 04:11:11,929 - INFO - Successfully wrote data to dim_album
2025-05-13 04:11:25,076 - INFO - Starting to write 16268 rows to dim_track
2025-05-13 04:11:37,753 - INFO - Successfully wrote data to dim_track
2025-05-13 04:11:41,390 - INFO - Starting to write 6 rows to dim_platform
2025-05-13 04:11:48,631 - INFO - Successfully wrote data to dim_platform
2025-05-13 04:11:53,170 - INFO - Starting to write 88 rows to dim_reason
2025-05-13 04:12:00,567 - INFO - Successfully wrote data to dim_