In [None]:
from pyspark.sql import SparkSession, types as T, functions as F, window as W
from pyspark.context import SparkConf

# Local Spark session for interactive exploration of the plays data.
# NOTE(review): with master "local[*]" all tasks run inside the driver JVM, so
# spark.driver.memory is what bounds usable memory; spark.executor.memory has no
# effect in local mode. Driver memory is only applied when this call actually
# launches the JVM -- getOrCreate() reuses a session already running in the kernel.
spark_settings = [
    ("spark.executor.memory", "8g"),
    ("spark.driver.memory", "8g"),
    # Limit on the total serialized size of results an action (e.g. collect) returns to the driver.
    ("spark.driver.maxResultSize", "4g"),
    # Arrow-based columnar transfer for toPandas() / createDataFrame(pandas_df).
    ("spark.sql.execution.arrow.pyspark.enabled", "true"),
]
conf = SparkConf().setAppName("exploration").setMaster("local[*]").setAll(spark_settings)
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Read the data

In [None]:
# The dataset must be downloaded and extracted next to this notebook first
# ("Last.fm Dataset - 1K users", http://ocelma.net/MusicRecommendationDataset/lastfm-1K.html).
PLAYS_PATH = "lastfm-dataset-1K/userid-timestamp-artid-artname-traid-traname.tsv"

# Column layout of the headerless TSV, in file order (matches the file name).
schema = T.StructType([
    T.StructField("user_id", T.StringType(), True),
    T.StructField("timestamp", T.TimestampType(), True),
    T.StructField("artist_id", T.StringType(), True),
    T.StructField("artist_name", T.StringType(), True),
    T.StructField("track_id", T.StringType(), True),
    T.StructField("track_name", T.StringType(), True),
])

df = spark.read.csv(
    PLAYS_PATH,
    sep="\t",
    header=False,
    schema=schema,
    # The file is plain tab-separated text, not quoted CSV. With the default quote
    # character ('"'), any artist/track name that starts with a double quote is
    # parsed as a quoted field and mangled; an empty string disables quoting.
    quote="",
)
# Expose the raw plays to Spark SQL (queried as `plays_raw` further down).
df.createOrReplaceTempView("plays_raw")
df.show(vertical=False, truncate=True, n=5)


# PySpark DataFrames

In [None]:
# What are the top 10 songs played in the top 50 longest sessions by tracks count?
#
# A session is a run of one user's plays in which every play starts at most
# SESSION_GAP_SECONDS after the previous play's start. "Longest" is measured by
# the number of tracks played in the session, as the question asks; elapsed
# duration is only used to break ties.
SESSION_GAP_SECONDS = 20 * 60
N_TOP_SESSIONS = 50
N_TOP_TRACKS = 10

# Each user's plays in chronological order (used by lag()).
w1 = W.Window.partitionBy("user_id").orderBy("timestamp")
# Same ordering with an explicit row frame for the running session counter.
# lag() requires its own fixed frame, so it keeps using the frameless w1.
w1_cumulative = w1.rowsBetween(W.Window.unboundedPreceding, W.Window.currentRow)

plays_df = (
    df
    .select(
        '*',
        # Seconds since the user's previous play; null for the user's first play.
        (F.col('timestamp').cast('long') - F.lag('timestamp').over(w1).cast('long')).alias('inactive_time'),
    )
    .select(
        *df.columns,
        # Running count of over-threshold gaps = 0-based session number within the user.
        # A null gap (first play) fails the comparison and counts as 0.
        F.sum(F.when(F.col('inactive_time') > SESSION_GAP_SECONDS, 1).otherwise(0))
        .over(w1_cumulative)
        .alias('session_id'),
    )
)

# Top sessions ranked by number of plays. Ties are broken by duration and then
# by key, so the limit picks a deterministic set.
sessions_df = (
    plays_df
    .groupBy('user_id', 'session_id')
    .agg(
        F.count('*').alias('track_count'),
        # Subtract epoch seconds rather than timestamps: timestamp - timestamp
        # yields an interval, which Spark does not reliably cast to long.
        (F.max('timestamp').cast('long') - F.min('timestamp').cast('long')).alias('session_duration'),
    )
    .orderBy(
        F.col('track_count').desc(),
        F.col('session_duration').desc(),
        F.col('user_id'),
        F.col('session_id'),
    )
    .limit(N_TOP_SESSIONS)
)

top10 = (
    sessions_df.select('user_id', 'session_id')
    .join(plays_df, on=['user_id', 'session_id'], how="inner")
    # Plays with a null track_id cannot be attributed to a song; drop them
    # explicitly instead of letting them form a null group.
    .where(F.col('track_id').isNotNull())
    .groupBy('track_id')
    .agg(F.count('*').alias('count'))
    .orderBy(F.col('count').desc(), F.col('track_id'))
    .limit(N_TOP_TRACKS)
)
top10.show(vertical=False, truncate=True, n=N_TOP_TRACKS)

# Spark SQL

In [None]:
# What are the top 10 songs played in the top 50 longest sessions by tracks count?
#
# Spark SQL formulation over the `plays_raw` view. A new session starts when a
# play begins more than 20 minutes (20*60 s) after the user's previous play.
# "Longest" is measured by the number of tracks in the session; duration and
# then the session key only break ties, so each LIMIT is deterministic.
top10 = spark.sql("""
    with plays_raw_extended as (
        -- seconds since the user's previous play (null for the first play)
        select
            *,
            cast(timestamp as long)
              - cast(lag(timestamp) over (partition by user_id order by timestamp) as long) as inactive_time
        from plays_raw
    ),
    plays as (
        -- running count of over-threshold gaps = 0-based session number per user
        select
            *,
            sum(case when inactive_time > 20*60 then 1 else 0 end) over (
                partition by user_id
                order by timestamp
                rows between unbounded preceding and current row
            ) as session_id
        from plays_raw_extended
    ),
    sessions as (
        select
            user_id,
            session_id,
            count(*) as track_count,
            -- subtract epoch seconds: timestamp - timestamp yields an interval,
            -- which Spark does not reliably cast to long
            cast(max(timestamp) as long) - cast(min(timestamp) as long) as session_duration
        from plays
        group by user_id, session_id
    ),
    top_50_sessions as (
        select *
        from sessions
        order by track_count desc, session_duration desc, user_id, session_id
        limit 50
    ),
    top_10_tracks as (
        -- plays with a null track_id cannot be attributed to a song
        select
            track_id,
            count(*) as plays
        from plays p
        join top_50_sessions s
          on p.user_id = s.user_id and p.session_id = s.session_id
        where track_id is not null
        group by track_id
        order by plays desc, track_id
        limit 10
    )
    select * from top_10_tracks
    order by plays desc, track_id
""")
top10.show(n=10)