In [None]:
import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import asc, desc, sum as ssum, udf
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType

In [None]:
# Obtain the session for this notebook; getOrCreate() reuses an already
# running session instead of starting a second one.
session_builder = SparkSession.builder.appName("Wrangling Data")
spark = session_builder.getOrCreate()

In [None]:
# Location of the JSON event log, relative to the notebook's working directory.
path = "data/sparkify_log_small.json"
# Read the log into a Spark DataFrame.
user_log = spark.read.json(path)

In [None]:
# Pull the first five records to the driver for a quick look at the raw rows.
user_log.take(5)

In [None]:
# Print the column names and types of the loaded log.
user_log.printSchema()

In [None]:
# Summary statistics for the log's columns.
user_log.describe().show()

In [None]:
# Summary statistics restricted to the "artist" column.
user_log.describe("artist").show()

In [None]:
# Summary statistics restricted to the "sessionId" column.
user_log.describe("sessionId").show()

In [None]:
# Total number of rows in the log.
user_log.count()

In [None]:
# List every distinct value of "page", in sorted order.
(user_log
 .select("page")
 .dropDuplicates()
 .sort("page")
 .show())

In [None]:
# Collect the page/song history rows recorded for user "1046".
(user_log
 .select("userId", "firstname", "page", "song")
 .where(user_log["userId"] == "1046")
 .collect())

In [None]:
# UDF: timestamp -> hour of day (0-23).
# The input is divided by 1000 before conversion, so `ts` is presumably epoch
# milliseconds -- TODO confirm against the data source.
# NOTE(review): fromtimestamp() converts using the local timezone of whichever
# process runs the UDF; confirm that is the intended timezone.
# The return type is declared as IntegerType so "hour" is a numeric column;
# without it udf() defaults to StringType and the hours come back as strings.
# A null timestamp yields a null hour instead of raising TypeError.
get_hour = udf(
    lambda x: None if x is None else datetime.datetime.fromtimestamp(x / 1000.0).hour,
    IntegerType(),
)

In [None]:
# Derive an "hour" column from each event's "ts" value via the get_hour UDF.
user_log = user_log.withColumn(
    "hour",
    get_hour("ts"),
)

In [None]:
# Show the first row, which now includes the derived "hour" column.
user_log.head()

In [None]:
# Count "NextSong" events per hour of day.
# The ordering key is cast to float so the hours sort numerically.
songs_in_hour = (
    user_log
    .filter(user_log["page"] == "NextSong")
    .groupby("hour")
    .count()
    .orderBy(user_log["hour"].cast("float"))
)

In [None]:
# Display the per-hour song counts.
songs_in_hour.show()

In [None]:
# Bring the small aggregated result to the driver as a pandas DataFrame and
# make sure "hour" is numeric before plotting.
songs_in_hour_pd = songs_in_hour.toPandas()
songs_in_hour_pd["hour"] = pd.to_numeric(songs_in_hour_pd["hour"])

In [None]:
# Scatter plot of songs played against hour of day.
plt.scatter(songs_in_hour_pd["hour"], songs_in_hour_pd["count"])
plt.xlabel("Hour")
plt.ylabel("Songs played")
# Pad the x-range by one hour on each side and leave 20% headroom above the
# tallest point.
plt.xlim(-1, 24)
plt.ylim(0, 1.2 * max(songs_in_hour_pd["count"]));

In [None]:
# Discard rows where either "userId" or "sessionId" is null.
user_log_valid = user_log.dropna(
    subset=["userId", "sessionId"],
    how="any",
)

In [None]:
# Row count after dropping rows with a null userId or sessionId.
user_log_valid.count()

In [None]:
# List the distinct userId values in the unfiltered log, sorted.
(user_log
 .select("userId")
 .dropDuplicates()
 .sort("userId")
 .show())

In [None]:
# Additionally discard rows whose userId is the empty string.
has_user_id = user_log_valid.userId != ""
user_log_valid = user_log_valid.filter(has_user_id)

In [None]:
# Row count after also removing rows with an empty-string userId.
user_log_valid.count()

In [None]:
# Show the rows whose page is "Submit Downgrade".
(user_log_valid
 .filter(user_log_valid["page"] == "Submit Downgrade")
 .show())

In [None]:
# Collect the page/level/song history rows recorded for user "1138".
(user_log
 .select("userId", "firstname", "page", "level", "song")
 .where(user_log["userId"] == "1138")
 .collect())

In [None]:
# UDF: 1 when the given page value is exactly "Submit Downgrade", otherwise 0.
flag_downgrade_event = udf(
    lambda x: int(x == "Submit Downgrade"),
    IntegerType(),
)

In [None]:
# Add a 0/1 "downgraded" column marking the "Submit Downgrade" events.
user_log_valid = user_log_valid.withColumn(
    "downgraded",
    flag_downgrade_event(user_log_valid["page"]),
)

In [None]:
# Show the first row, which now includes the "downgraded" flag.
user_log_valid.head()

In [None]:
# Per-user window ordered from newest to oldest event. The frame runs from the
# start of that ordering up to the current row, so for each event it covers
# every event of the same user with ts >= the current event's ts.
newest_first = Window.partitionBy("userId").orderBy(desc("ts"))
windowval = newest_first.rangeBetween(Window.unboundedPreceding, Window.currentRow)

In [None]:
# "phase" is the running total of the "downgraded" flag over `windowval`.
# The aggregate is `ssum`: that is the alias pyspark.sql.functions.sum is
# imported under in this notebook (the previous name, `Fsum`, was never
# defined and raised NameError on a fresh kernel).
user_log_valid = user_log_valid.withColumn("phase", ssum("downgraded").over(windowval))

In [None]:
# Collect user "1138"'s events in chronological order together with the
# computed "phase" value. The filter references the column on user_log_valid,
# the frame actually being queried, rather than on the separate user_log frame.
(user_log_valid
 .select(["userId", "firstname", "ts", "page", "level", "phase"])
 .where(user_log_valid.userId == "1138")
 .sort("ts")
 .collect())