In [0]:
%spark.pyspark 
from pyspark.sql import functions as sf
from pyspark.sql import Window as sw
# Toy dimension table: one row per user and the city the user is assigned to.
users = spark.createDataFrame(
    [
        ("u1", "Berlin"),
        ("u2", "Berlin"),
        ("u3", "Munich"),
        ("u4", "Hamburg"),
    ],
    schema=["user_id", "city"],
)

# Toy fact table: one row per order (buyer, product, quantity, price).
# qty and price are multiplied together downstream to obtain revenue.
orders = spark.createDataFrame(
    [
        ("o1", "u1", "p1", 2, 10.0),
        ("o2", "u1", "p2", 1, 30.0),
        ("o3", "u2", "p1", 1, 10.0),
        ("o4", "u2", "p3", 5, 7.0),
        ("o5", "u3", "p2", 3, 30.0),
        ("o6", "u3", "p3", 1, 7.0),
        ("o7", "u4", "p1", 10, 10.0),
    ],
    schema=["order_id", "user_id", "product_id", "qty", "price"],
)

# Toy dimension table: product id to display name.
products = spark.createDataFrame(
    [
        ("p1", "Ring VOLA"),
        ("p2", "Ring POROG"),
        ("p3", "Ring TISHINA"),
    ],
    schema=["product_id", "product_name"],
)

# Print the three inputs so the reader can sanity-check them.
users.show()
orders.show()
products.show()

In [1]:
%spark.pyspark
# Enrich every order with its revenue (qty * price), the buyer's city and the
# product name. Both joins use the default join type (inner), so an order
# without a matching user or product would be dropped.
mart_city_top_products_joined = (
    orders
    .withColumn("revenue", sf.col("qty") * sf.col("price"))
    .join(users, on="user_id")
    .join(products, on="product_id")
    .cache()  # kept in memory: shown here and aggregated in the next cell
)
mart_city_top_products_joined.show()

In [2]:
%spark.pyspark
# Collapse the enriched orders to one row per (city, product) with the number
# of orders, the total quantity and the total revenue.
mart_city_top_products_stg = (
    mart_city_top_products_joined
    .groupBy("city", "product_id", "product_name")
    .agg(
        sf.count(sf.col("order_id")).alias("orders_cnt"),
        sf.sum(sf.col("qty")).alias("qty_sum"),
        sf.sum(sf.col("revenue")).alias("revenue_sum"),
    )
    .cache()  # kept in memory: shown here and ranked in the next cell
)
mart_city_top_products_stg.show(n=100)

In [3]:
%spark.pyspark
# Window over all rows of one city, highest revenue first.
#
# The frame is widened to the whole partition so that first() looks at every
# row of the city no matter which row is currently being evaluated.
#
# product_id is a secondary sort key on purpose: with revenue_sum alone,
# products with equal revenue are ordered arbitrarily, so the top-1/top-2
# picks could differ between runs. In the sample data Berlin's p1 and p2 both
# total 30.0, which makes Berlin's top-2 product ambiguous without it.
w = sw.partitionBy("city").orderBy(
    sf.col("revenue_sum").desc(),
    sf.col("product_id").asc()
).rowsBetween(
    sw.unboundedPreceding,
    sw.unboundedFollowing
)
def get_top_1_prod_col(col_name: str):
    """Build the ``top_1_<col_name>`` column for the best product of each city.

    The result is the value of ``col_name`` taken from the first row of the
    module-level window ``w`` (the city's rows in that window's order).

    Args:
        col_name: Name of the column whose top-1 value should be exposed.

    Returns:
        A Column expression aliased as ``top_1_<col_name>``.
    """
    best_value = sf.first(col_name).over(w)
    return best_value.alias("top_1_{}".format(col_name))
    
def get_top_2_prod_col(col_name: str):
    """Build the ``top_2_<col_name>`` column for the runner-up product of each city.

    Rows that belong to the top-1 product are masked to null, so the first
    non-null value in the module-level window ``w`` comes from the best of the
    remaining products. The DataFrame this is selected from must already
    contain a ``top_1_product_id`` column. The result is null when the city
    has no product other than the top-1 one.

    Args:
        col_name: Name of the column whose top-2 value should be exposed.

    Returns:
        A Column expression aliased as ``top_2_<col_name>``.
    """
    is_not_top_1 = sf.col("product_id") != sf.col("top_1_product_id")
    # when() without otherwise() yields null for the top-1 rows.
    masked_value = sf.when(is_not_top_1, sf.col(col_name))
    runner_up = sf.first(masked_value, ignorenulls=True).over(w)
    return runner_up.alias("top_2_{}".format(col_name))
    
# Columns carried through unchanged, and the product attributes that are
# reported for the top-1 and top-2 products of each city.
base_cols = [
    "city", "product_id", "product_name",
    "orders_cnt", "qty_sum", "revenue_sum",
]
winner_attrs = ["product_id", "product_name", "revenue_sum"]

# Step 1: attach the top-1 columns. This has to be a separate select because
# the top-2 expressions compare against the top_1_product_id column.
with_top_1 = mart_city_top_products_stg.select(
    *base_cols,
    *[get_top_1_prod_col(attr) for attr in winner_attrs],
)

# Step 2: keep the top-1 columns and add the top-2 ones.
mart_city_top_products = with_top_1.select(
    *base_cols,
    *["top_1_{}".format(attr) for attr in winner_attrs],
    *[get_top_2_prod_col(attr) for attr in winner_attrs],
)

# Persist the mart as parquet, replacing whatever a previous run left at this path.
mart_city_top_products.write.parquet(
    "hdfs:///tmp/sandbox_zeppelin/mart_city_top_products",
    mode="overwrite",
)

In [4]:
%spark.pyspark
# Read the persisted mart back and print one line per distinct
# (city, top-1 product name, top-2 product name) combination.
(
    spark.read
    .parquet("hdfs:///tmp/sandbox_zeppelin/mart_city_top_products")
    .select("city", "top_1_product_name", "top_2_product_name")
    .distinct()
    .show()
)