In [1]:
from common.session import get_spark_session


# Obtain the Spark session for this notebook from the project helper.
# The string argument is presumably used as the application name — confirm in
# common.session.get_spark_session.
spark = get_spark_session("03_transactions_analysis")

# Bare expression as the last statement so the notebook displays the session object.
spark

In [2]:
# --- Customers: one row per customer --------------------------------------
# Only column names are supplied, so Spark infers the column types from the
# Python values in the tuples.
cust_columns = ["customer_id", "name", "country"]
cust_data = [
    (1, "Alice", "US"),
    (2, "Bob", "UK"),
    (3, "Charlie", "US"),
    (4, "Diana", "DE"),
]

df_customers = spark.createDataFrame(cust_data, schema=cust_columns)
df_customers.show()


# --- Transactions: one row per transaction, keyed to a customer ------------
# The timestamp values are plain Python strings. An explicit DDL schema could
# be passed instead of the name list, e.g.:
#   "transaction_id int, customer_id int, amount int, timestamp string"
tran_columns = ["transaction_id", "customer_id", "amount", "timestamp"]
tran_data = [
    (101, 1, 250, "2023-05-01 10:00:00"),
    (102, 1, 300, "2023-05-03 12:00:00"),
    (103, 2, 450, "2023-05-04 09:30:00"),
    (104, 3, 100, "2023-05-05 14:00:00"),
    (105, 1, 500, "2023-05-06 16:00:00"),
    (106, 4, 700, "2023-05-07 18:00:00"),
]

df_transactions = spark.createDataFrame(tran_data, schema=tran_columns)
df_transactions.show()


+-----------+-------+-------+
|customer_id|   name|country|
+-----------+-------+-------+
|          1|  Alice|     US|
|          2|    Bob|     UK|
|          3|Charlie|     US|
|          4|  Diana|     DE|
+-----------+-------+-------+

+--------------+-----------+------+-------------------+
|transaction_id|customer_id|amount|          timestamp|
+--------------+-----------+------+-------------------+
|           101|          1|   250|2023-05-01 10:00:00|
|           102|          1|   300|2023-05-03 12:00:00|
|           103|          2|   450|2023-05-04 09:30:00|
|           104|          3|   100|2023-05-05 14:00:00|
|           105|          1|   500|2023-05-06 16:00:00|
|           106|          4|   700|2023-05-07 18:00:00|
+--------------+-----------+------+-------------------+



In [26]:
# Exercise covered by this cell:
#   - For each customer, find their top 2 highest transaction amounts.
#   - Calculate total spend per customer and per country.
#   - If the transactions dataset is huge, how would you optimize the join?
#
# NOTE: importing `sum` from pyspark.sql.functions shadows Python's built-in
# sum() from this cell onward. The import line is kept unchanged so any other
# cell that relies on these names keeps working.
from pyspark.sql.functions import row_number, desc, sum, col
from pyspark.sql.window import Window

# Rank each customer's transactions from the largest amount down.
# transaction_id is a secondary sort key: row_number() hands out distinct ranks
# even when amounts are equal, so without a tie-breaker the rows that survive
# the "top 2" filter could differ between runs when a customer has tied amounts.
window = Window.partitionBy("customer_id").orderBy(
    df_transactions["amount"].desc(),
    df_transactions["transaction_id"].asc(),
)

# Keep only the two highest-ranked transactions per customer.
df_transactions_ranked = (
    df_transactions
    .withColumn("amount_rank", row_number().over(window))
    .where("amount_rank <= 2")
)
df_transactions_ranked.show()


# Total spend per (country, customer): inner-join transactions to customers on
# customer_id, then sum the amounts within each group. The aliases 'a'/'b'
# disambiguate customer_id, which exists on both sides of the join.
df_tran_by_user_country = (
    df_transactions.alias('a')
    .join(df_customers.alias('b'), col('a.customer_id') == col('b.customer_id'))
    .groupBy(['country', 'a.customer_id'])
    .agg(sum('amount').alias('amount'))
)

df_tran_by_user_country.show()

+--------------+-----------+------+-------------------+-----------+
|transaction_id|customer_id|amount|          timestamp|amount_rank|
+--------------+-----------+------+-------------------+-----------+
|           105|          1|   500|2023-05-06 16:00:00|          1|
|           102|          1|   300|2023-05-03 12:00:00|          2|
|           103|          2|   450|2023-05-04 09:30:00|          1|
|           104|          3|   100|2023-05-05 14:00:00|          1|
|           106|          4|   700|2023-05-07 18:00:00|          1|
+--------------+-----------+------+-------------------+-----------+

+-------+-----------+------+
|country|customer_id|amount|
+-------+-----------+------+
|     US|          1|  1050|
|     UK|          2|   450|
|     US|          3|   100|
|     DE|          4|   700|
+-------+-----------+------+



In [29]:
# Mark the customers DataFrame with a broadcast hint so the join with
# transactions can be done without shuffling the transactions side.
from pyspark.sql.functions import broadcast

df_customer_transactions = df_transactions.join(
    broadcast(df_customers),
    on="customer_id",
)
df_customer_transactions.show()

+-----------+--------------+------+-------------------+-------+-------+
|customer_id|transaction_id|amount|          timestamp|   name|country|
+-----------+--------------+------+-------------------+-------+-------+
|          1|           101|   250|2023-05-01 10:00:00|  Alice|     US|
|          1|           102|   300|2023-05-03 12:00:00|  Alice|     US|
|          2|           103|   450|2023-05-04 09:30:00|    Bob|     UK|
|          3|           104|   100|2023-05-05 14:00:00|Charlie|     US|
|          1|           105|   500|2023-05-06 16:00:00|  Alice|     US|
|          4|           106|   700|2023-05-07 18:00:00|  Diana|     DE|
+-----------+--------------+------+-------------------+-------+-------+



In [30]:
# Repartition both DataFrames on the join key to optimize parallelism.
# NOTE: this rebinds df_transactions and df_customers to their repartitioned
# versions, so later cells see the repartitioned frames.
df_transactions, df_customers = (
    df_transactions.repartition("customer_id"),
    df_customers.repartition("customer_id"),
)

# Join the repartitioned frames on customer_id and display the result.
df_transactions.join(df_customers, on="customer_id").show()


+-----------+--------------+------+-------------------+-------+-------+
|customer_id|transaction_id|amount|          timestamp|   name|country|
+-----------+--------------+------+-------------------+-------+-------+
|          1|           101|   250|2023-05-01 10:00:00|  Alice|     US|
|          1|           102|   300|2023-05-03 12:00:00|  Alice|     US|
|          1|           105|   500|2023-05-06 16:00:00|  Alice|     US|
|          2|           103|   450|2023-05-04 09:30:00|    Bob|     UK|
|          3|           104|   100|2023-05-05 14:00:00|Charlie|     US|
|          4|           106|   700|2023-05-07 18:00:00|  Diana|     DE|
+-----------+--------------+------+-------------------+-------+-------+

