In [0]:
%spark.pyspark

from pyspark.sql import functions as F
from pyspark.sql import Window

# Toy fixtures for the mart below:
#   users    - one row per user with their city
#   orders   - one row per order line (quantity and unit price)
#   products - product id -> display name
# Column types are left to Spark's schema inference from the Python literals.
users = spark.createDataFrame(
    data=[
        ("u1", "Berlin"),
        ("u2", "Berlin"),
        ("u3", "Munich"),
        ("u4", "Hamburg"),
    ],
    schema=["user_id", "city"],
)
orders = spark.createDataFrame(
    data=[
        ("o1", "u1", "p1", 2, 10.0),
        ("o2", "u1", "p2", 1, 30.0),
        ("o3", "u2", "p1", 1, 10.0),
        ("o4", "u2", "p3", 5, 7.0),
        ("o5", "u3", "p2", 3, 30.0),
        ("o6", "u3", "p3", 1, 7.0),
        ("o7", "u4", "p1", 10, 10.0),
    ],
    schema=["order_id", "user_id", "product_id", "qty", "price"],
)
products = spark.createDataFrame(
    data=[
        ("p1", "Ring VOLA"),
        ("p2", "Ring POROG"),
        ("p3", "Ring TISHINA"),
    ],
    schema=["product_id", "product_name"],
)


In [1]:
%spark.pyspark

# Preview the users fixture (first 20 rows, long values truncated).
users.show(n=20, truncate=True)

In [2]:
%spark.pyspark

# Preview the order lines fixture (first 20 rows, long values truncated).
orders.show(n=20, truncate=True)

In [3]:
%spark.pyspark

# Preview the products fixture (first 20 rows, long values truncated).
products.show(n=20, truncate=True)

In [4]:
%spark.pyspark

# Mart: top products per city, ranked by revenue.
#   revenue      = qty * price for each order line
#   orders_cnt   = number of order lines per (city, product)
#   qty_sum      = total quantity per (city, product)
#   revenue_sum  = total revenue per (city, product)
#   city_top     = 1-based rank of the product inside its city (highest revenue first)
# Users missing from `users` / products missing from `products` are kept (left joins)
# and end up with a null city / product_name.
TOP_N_PER_CITY = 2  # how many products to keep per city

mart_city_top_products = (
    orders
        # Revenue of an order line is quantity times unit price (was wrongly qty + price).
        .withColumn('revenue', F.col('qty') * F.col('price'))
        .join(users, 'user_id', 'left')
        .join(products, 'product_id', 'left')
        # Plain column names instead of `users.city` / `orders.product_id`: after the
        # joins and the aggregation, parent-DataFrame column references are fragile.
        .groupBy('city', 'product_id', 'product_name')
        .agg(
            F.count('order_id').alias("orders_cnt"),
            F.sum('qty').alias("qty_sum"),
            F.sum('revenue').alias("revenue_sum"),
        )
        .withColumn(
            'city_top',
            F.row_number().over(
                Window.partitionBy('city')
                # product_id breaks revenue ties so the ranking is deterministic
                # (e.g. Berlin's p1 and p2 both total 30.0 in the fixture data).
                .orderBy(F.desc_nulls_last('revenue_sum'), F.asc('product_id'))
            )
        )
        .filter(F.col('city_top') <= TOP_N_PER_CITY)
)

mart_city_top_products.printSchema()

In [5]:
%spark.pyspark

# Persist the mart as parquet on HDFS, replacing the output of any previous run.
(
    mart_city_top_products.write
    .mode('overwrite')
    .parquet('hdfs:///tmp/sandbox_zeppelin/mart_city_top_products/')
)

In [6]:
%sh

# List the files written by the parquet export cell.
hdfs dfs -ls 'hdfs:///tmp/sandbox_zeppelin/mart_city_top_products/'

In [7]:
%spark.pyspark

# Export the mart to S3 as CSV with a header row, overwriting previous output.
# coalesce(1) collapses the data into one partition so a single CSV part file is
# produced; that is only reasonable because the mart is small (top rows per city).
(
    mart_city_top_products
    .coalesce(1)
    .write
    .csv('s3a://hse-s3-bucket/mart_city_top_products', mode='overwrite', header="true")
)

In [8]:
%spark.pyspark

# Read the parquet mart back from HDFS to check what was actually written.
df = spark.read.parquet('hdfs:///tmp/sandbox_zeppelin/mart_city_top_products/')
df.printSchema()

In [9]:
%spark.pyspark

# Show only the rank-1 product of each city.
df.where(F.col('city_top') == 1).show()