<a href="https://colab.research.google.com/github/nptan2005/spark401_colab/blob/main/notebooks/Spark_excercises_01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!apt-get install -y openjdk-17-jdk

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  at-spi2-core fonts-dejavu-core fonts-dejavu-extra gsettings-desktop-schemas
  libatk-bridge2.0-0 libatk-wrapper-java libatk-wrapper-java-jni libatk1.0-0
  libatk1.0-data libatspi2.0-0 libgail-common libgail18 libgtk2.0-0
  libgtk2.0-bin libgtk2.0-common librsvg2-common libxcomposite1 libxt-dev
  libxtst6 libxxf86dga1 openjdk-17-jre session-migration x11-utils
Suggested packages:
  gvfs libxt-doc openjdk-17-demo openjdk-17-source visualvm mesa-utils
The following NEW packages will be installed:
  at-spi2-core fonts-dejavu-core fonts-dejavu-extra gsettings-desktop-schemas
  libatk-bridge2.0-0 libatk-wrapper-java libatk-wrapper-java-jni libatk1.0-0
  libatk1.0-data libatspi2.0-0 libgail-common libgail18 libgtk2.0-0
  libgtk2.0-bin libgtk2.0-common librsvg2-common libxcomposite1 libxt-dev
  libxtst6 libxxf86dga1 openjdk-17-jdk openjdk-17-jr

In [3]:
!wget https://archive.apache.org/dist/spark/spark-4.0.1/spark-4.0.1-bin-hadoop3.tgz
!tar xf spark-4.0.1-bin-hadoop3.tgz

--2026-01-05 06:38:42--  https://archive.apache.org/dist/spark/spark-4.0.1/spark-4.0.1-bin-hadoop3.tgz
Resolving archive.apache.org (archive.apache.org)... 65.108.204.189, 2a01:4f9:1a:a084::2
Connecting to archive.apache.org (archive.apache.org)|65.108.204.189|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 548955321 (524M) [application/x-gzip]
Saving to: ‘spark-4.0.1-bin-hadoop3.tgz’


2026-01-05 06:39:50 (7.82 MB/s) - ‘spark-4.0.1-bin-hadoop3.tgz’ saved [548955321/548955321]



In [4]:
# ===============================
# Spark 4.0.1 Setup (REQUIRED)
# ===============================
import os

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-4.0.1-bin-hadoop3"
os.environ["PATH"] += ":/content/spark-4.0.1-bin-hadoop3/bin"

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Spark401-Training") \
    .master("local[*]") \
    .config("spark.sql.shuffle.partitions", "4") \
    .getOrCreate()

print("Spark version:", spark.version)

Spark version: 4.0.1


# 1) Tạo dữ liệu giả lập (có skew)

## 1.1 Helper timing + imports

In [5]:
import time, random
from pyspark.sql import functions as F
from pyspark.sql import types as T

def timed(label, fn):
    t0 = time.time()
    out = fn()
    t1 = time.time()
    print(f"[{label}] took {t1 - t0:.2f}s")
    return out

## 1.2 Create customers (dim)

In [6]:
from pyspark.sql import functions as F

n_customers = 50_000
segments = ["MASS", "AFFLUENT", "SME"]
risk = ["LOW", "MED", "HIGH"]

seg_arr  = F.array(*[F.lit(x) for x in segments])
risk_arr = F.array(*[F.lit(x) for x in risk])

seg_idx = (F.pmod(F.col("id"), F.lit(len(segments))) + F.lit(1)).cast("int")
risk_idx = (F.pmod(F.col("id"), F.lit(len(risk))) + F.lit(1)).cast("int")

customers = (
    spark.range(0, n_customers)
    .select(
        (F.col("id") + F.lit(1)).cast("string").alias("customer_id"),
        F.element_at(seg_arr, seg_idx).alias("segment"),
        F.element_at(risk_arr, risk_idx).alias("risk_tier"),
        # created_date: hôm trước - (id % 365) ngày
        (F.date_sub(F.current_date(), (F.pmod(F.col("id"), F.lit(365))).cast("int"))).alias("created_date"),
    )
)

customers.cache()
print("count =", customers.count())
customers.show(5, truncate=False)
customers.printSchema()

count = 50000
+-----------+--------+---------+------------+
|customer_id|segment |risk_tier|created_date|
+-----------+--------+---------+------------+
|1          |MASS    |LOW      |2026-01-05  |
|2          |AFFLUENT|MED      |2026-01-04  |
|3          |SME     |HIGH     |2026-01-03  |
|4          |MASS    |LOW      |2026-01-02  |
|5          |AFFLUENT|MED      |2026-01-01  |
+-----------+--------+---------+------------+
only showing top 5 rows
root
 |-- customer_id: string (nullable = false)
 |-- segment: string (nullable = false)
 |-- risk_tier: string (nullable = false)
 |-- created_date: date (nullable = true)



## 1.3 Create orders (fact) with skew customer_id

In [7]:
from pyspark.sql import functions as F

# orders: 2M rows (có thể giảm nếu Colab lag)
n_orders = 2_000_000

channels = ["POS", "ECOM", "ATM"]
countries = ["VN", "SG", "TH", "ID", "MY"]
statuses = ["SUCCESS", "FAILED", "REVERSED"]

HOT_CUSTOMER = "1"   # hot key
HOT_RATIO = 0.25     # 25% orders thuộc 1 customer (skew rõ)

ch_arr = F.array(*[F.lit(x) for x in channels])
print(f'test charr = {ch_arr}')
ct_arr = F.array(*[F.lit(x) for x in countries])
st_arr = F.array(*[F.lit(x) for x in statuses])

ch_idx = (F.pmod(F.col("id"), F.lit(len(channels))) + F.lit(1)).cast("int")
ct_idx = (F.pmod(F.col("id"), F.lit(len(countries))) + F.lit(1)).cast("int")
st_idx = (F.pmod(F.col("id"), F.lit(len(statuses))) + F.lit(1)).cast("int")

orders = (
    spark.range(0, n_orders)
    .select(
        (F.col("id") + 1).cast("string").alias("order_id"),

        # skew: 25% = HOT_CUSTOMER, còn lại pseudo-random trong [2..n_customers]
        F.when(F.rand(seed=7) < F.lit(HOT_RATIO), F.lit(HOT_CUSTOMER))
         .otherwise((F.pmod(F.col("id") * 17, F.lit(n_customers - 1)) + 2).cast("string"))
         .alias("customer_id"),

        (F.rand(seed=11) * 5000).cast("double").alias("amount"),

        # timestamp: now - (id % 30) days
        (F.current_timestamp() - F.expr("INTERVAL 1 DAYS") - (F.pmod(F.col("id"), F.lit(30)).cast("int") * F.expr("INTERVAL 1 DAYS")))
          .alias("order_ts"),

        F.element_at(ch_arr, ch_idx).alias("channel"),
        F.element_at(ct_arr, ct_idx).alias("country"),
        F.element_at(st_arr, st_idx).alias("status"),
    )
)

orders.cache()
print("orders count =", orders.count())
orders.show(5, truncate=False)
orders.printSchema()

test charr = Column<'array('POS', 'ECOM', 'ATM')'>
orders count = 2000000
+--------+-----------+------------------+--------------------------+-------+-------+--------+
|order_id|customer_id|amount            |order_ts                  |channel|country|status  |
+--------+-----------+------------------+--------------------------+-------+-------+--------+
|1       |2          |171.13196569036427|2026-01-04 06:43:43.452688|POS    |VN     |SUCCESS |
|2       |19         |584.1125228224664 |2026-01-03 06:43:43.452688|ECOM   |SG     |FAILED  |
|3       |36         |1626.1139047259715|2026-01-02 06:43:43.452688|ATM    |TH     |REVERSED|
|4       |53         |2179.4255232405444|2026-01-01 06:43:43.452688|POS    |ID     |SUCCESS |
|5       |70         |1391.575355397565 |2025-12-31 06:43:43.452688|ECOM   |MY     |FAILED  |
+--------+-----------+------------------+--------------------------+-------+-------+--------+
only showing top 5 rows
root
 |-- order_id: string (nullable = false)
 |-- custo

# LEVEL 1 — Transform/Action + Cache/Partition

## Bài 1.1 — Lazy + cache đúng


In [8]:
filtered = orders.filter(F.col("status") == F.lit("SUCCESS"))

print("Explain BEFORE cache:")
filtered.explain()

timed("count #1 (no cache)", lambda: filtered.count())

filtered_cached = filtered.cache()
timed("count #2 (after cache)", lambda: filtered_cached.count())

print("Explain AFTER cache (có InMemoryTableScan/Cache):")
filtered_cached.explain()

Explain BEFORE cache:
== Physical Plan ==
*(1) Filter (status#233 = SUCCESS)
+- InMemoryTableScan [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], [(status#233 = SUCCESS)]
      +- InMemoryRelation [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], StorageLevel(disk, memory, deserialized, 1 replicas)
            +- *(1) Project [cast((id#226L + 1) as string) AS order_id#227, CASE WHEN (rand(7) < 0.25) THEN 1 ELSE cast((pmod((id#226L * 17), 49999) + 2) as string) END AS customer_id#228, (rand(11) * 5000.0) AS amount#229, 2026-01-04 06:43:43.452688 + -(INTERVAL '1' DAY * cast(pmod(id#226L, 30) as int)) AS order_ts#230, element_at([POS,ECOM,ATM], cast((pmod(id#226L, 3) + 1) as int), None, true) AS channel#231, element_at([VN,SG,TH,ID,MY], cast((pmod(id#226L, 5) + 1) as int), None, true) AS country#232, element_at([SUCCESS,FAILED,REVERSED], cast((pmod(id#226L, 3) + 1) as int), None, true) AS statu

## Bài 1.2 — repartition vs coalesce

In [9]:
orders_small = orders.filter(F.col("country") == F.lit("VN"))

print("Partitions original:", orders_small.rdd.getNumPartitions())

a = orders_small.repartition(200)
print("Partitions after repartition(200):", a.rdd.getNumPartitions())
print("Explain repartition:")
a.explain()

timed("repartition(200) count", lambda: a.count())

b = orders_small.coalesce(10)
print("Partitions after coalesce(10):", b.rdd.getNumPartitions())
print("Explain coalesce:")
b.explain()

timed("coalesce(10) count", lambda: b.count())

Partitions original: 2
Partitions after repartition(200): 200
Explain repartition:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   ResultQueryStage 2
   +- ShuffleQueryStage 1
      +- Exchange RoundRobinPartitioning(200), REPARTITION_BY_NUM, [plan_id=296]
         +- *(1) Filter (country#232 = VN)
            +- TableCacheQueryStage 0
               +- InMemoryTableScan [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], [(country#232 = VN)]
                     +- InMemoryRelation [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], StorageLevel(disk, memory, deserialized, 1 replicas)
                           +- *(1) Project [cast((id#226L + 1) as string) AS order_id#227, CASE WHEN (rand(7) < 0.25) THEN 1 ELSE cast((pmod((id#226L * 17), 49999) + 2) as string) END AS customer_id#228, (rand(11) * 5000.0) AS amount#229, 2026-01-04 06:43:43.452688 + -(INTERVAL '1' DAY 

400000

# LEVEL 2 — Join Strategies

## Bài 2.1 — Join mặc định vs broadcast + tắt autobroadcast

In [10]:
# (1) Join default
print("=== DEFAULT JOIN ===")
orders.join(customers, "customer_id").explain()

# (2) Force broadcast
from pyspark.sql.functions import broadcast
print("=== FORCE BROADCAST JOIN ===")
orders.join(broadcast(customers), "customer_id").explain()

# (3) Disable auto broadcast
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
print("=== AUTO BROADCAST DISABLED ===")
orders.join(customers, "customer_id").explain()

# reset (optional)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024)  # 10MB default-ish

=== DEFAULT JOIN ===
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [customer_id#228, order_id#227, amount#229, order_ts#230, channel#231, country#232, status#233, segment#2, risk_tier#3, created_date#4]
   +- BroadcastHashJoin [customer_id#228], [customer_id#1], Inner, BuildRight, false
      :- Filter isnotnull(customer_id#228)
      :  +- InMemoryTableScan [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], [isnotnull(customer_id#228)]
      :        +- InMemoryRelation [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], StorageLevel(disk, memory, deserialized, 1 replicas)
      :              +- *(1) Project [cast((id#226L + 1) as string) AS order_id#227, CASE WHEN (rand(7) < 0.25) THEN 1 ELSE cast((pmod((id#226L * 17), 49999) + 2) as string) END AS customer_id#228, (rand(11) * 5000.0) AS amount#229, 2026-01-04 06:43:43.452688 + -(INTERVAL '1' DAY * cast(pmod(id#226L, 30) a

## Chạy thử timing join (nhẹ thôi)

In [11]:
joined = orders.join(customers, "customer_id")
timed("join count()", lambda: joined.count())

[join count()] took 2.10s


2000000

## Bài 2.2 — Join gây “đốt tiền” (customers_big)

In [12]:
# phóng to customers bằng cách cross join với range nhỏ
# chú ý: đừng phóng quá lớn kẻo Colab nổ RAM
mult = 10
customers_big = customers.crossJoin(spark.range(0, mult).select(F.col("id").alias("k"))).drop("k")

print("customers:", customers.count())
print("customers_big:", customers_big.count())

# join thử
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024)  # bật lại
print("=== JOIN orders with customers_big ===")
orders.join(customers_big, "customer_id").explain()

timed("join customers_big count()", lambda: orders.join(customers_big, "customer_id").count())

customers: 50000
customers_big: 500000
=== JOIN orders with customers_big ===
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [customer_id#228, order_id#227, amount#229, order_ts#230, channel#231, country#232, status#233, segment#2, risk_tier#3, created_date#4]
   +- SortMergeJoin [customer_id#228], [customer_id#1], Inner
      :- Sort [customer_id#228 ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(customer_id#228, 4), ENSURE_REQUIREMENTS, [plan_id=874]
      :     +- Filter isnotnull(customer_id#228)
      :        +- InMemoryTableScan [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], [isnotnull(customer_id#228)]
      :              +- InMemoryRelation [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233], StorageLevel(disk, memory, deserialized, 1 replicas)
      :                    +- *(1) Project [cast((id#226L + 1) as string) AS order_id#227, CASE WHEN (r

20000000

# LEVEL 3 — Skew + Salting

## Bài 3.1 — Detect skew nhanh

In [13]:
orders.groupBy("customer_id").count().orderBy(F.desc("count")).show(20, truncate=False)

+-----------+------+
|customer_id|count |
+-----------+------+
|1          |500755|
|32714      |40    |
|4313       |39    |
|46379      |39    |
|7413       |39    |
|8935       |38    |
|26025      |38    |
|39886      |38    |
|15245      |38    |
|3731       |38    |
|18348      |38    |
|39393      |38    |
|8281       |38    |
|5447       |38    |
|20013      |38    |
|40015      |38    |
|29221      |38    |
|41091      |38    |
|10279      |38    |
|25600      |38    |
+-----------+------+
only showing top 20 rows


## Bài 3.2 — Salting (thực chiến)

In [14]:
N_SALT = 16  # số nhánh chia hot key

# orders salted: chỉ salt với hot key
orders_salted = (
    orders.withColumn(
        "salt",
        F.when(F.col("customer_id") == F.lit(HOT_CUSTOMER), (F.pmod(F.col("order_id").cast("long"), F.lit(N_SALT))).cast("int"))
         .otherwise(F.lit(0))
    )
)

# customers salted: với hot key thì explode ra N_SALT bản ghi; còn lại salt=0
customers_hot = customers.filter(F.col("customer_id") == F.lit(HOT_CUSTOMER)) \
    .withColumn("salt", F.explode(F.sequence(F.lit(0), F.lit(N_SALT - 1))))
customers_cold = customers.filter(F.col("customer_id") != F.lit(HOT_CUSTOMER)) \
    .withColumn("salt", F.lit(0))

customers_salted = customers_hot.unionByName(customers_cold)

print("Explain salted join:")
orders_salted.join(customers_salted, ["customer_id", "salt"]).explain()

timed("salted join count()", lambda: orders_salted.join(customers_salted, ["customer_id", "salt"]).count())

Explain salted join:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [customer_id#228, salt#3765, order_id#227, amount#229, order_ts#230, channel#231, country#232, status#233, segment#2, risk_tier#3, created_date#4]
   +- BroadcastHashJoin [customer_id#228, salt#3765], [customer_id#1, salt#3767], Inner, BuildRight, false
      :- Project [order_id#227, customer_id#228, amount#229, order_ts#230, channel#231, country#232, status#233, CASE WHEN (customer_id#228 = 1) THEN cast(pmod(cast(order_id#227 as bigint), 16) as int) ELSE 0 END AS salt#3765]
      :  +- Filter ((isnotnull((customer_id#228 = 1)) AND isnotnull(customer_id#228)) AND CASE WHEN (customer_id#228 = 1) THEN isnotnull(cast(pmod(cast(order_id#227 as bigint), 16) as int)) ELSE true END)
      :     +- InMemoryTableScan [amount#229, channel#231, country#232, customer_id#228, order_id#227, order_ts#230, status#233], [isnotnull((customer_id#228 = 1)), isnotnull(customer_id#228), CASE WHEN (customer_id#228 = 1) T

2000000

# LEVEL 4 — Gold-ish KPI + Window

## Bài 4.1 — KPI theo ngày + segment

In [15]:
orders_enriched = orders.join(customers, "customer_id") \
    .withColumn("order_date", F.to_date("order_ts"))

kpi_daily = (orders_enriched
    .groupBy("order_date", "segment")
    .agg(
        F.count("*").alias("txn_cnt"),
        F.sum("amount").alias("revenue"),
        F.sum(F.when(F.col("status") == "SUCCESS", 1).otherwise(0)).alias("success_cnt")
    )
    .orderBy("order_date", "segment")
)

kpi_daily.explain()
kpi_daily.show(20, truncate=False)

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [order_date#4389 ASC NULLS FIRST, segment#2 ASC NULLS FIRST], true, 0
   +- Exchange rangepartitioning(order_date#4389 ASC NULLS FIRST, segment#2 ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [plan_id=1583]
      +- HashAggregate(keys=[order_date#4389, segment#2], functions=[count(1), sum(amount#229), sum(CASE WHEN (status#233 = SUCCESS) THEN 1 ELSE 0 END)])
         +- Exchange hashpartitioning(order_date#4389, segment#2, 4), ENSURE_REQUIREMENTS, [plan_id=1580]
            +- HashAggregate(keys=[order_date#4389, segment#2], functions=[partial_count(1), partial_sum(amount#229), partial_sum(CASE WHEN (status#233 = SUCCESS) THEN 1 ELSE 0 END)])
               +- Project [amount#229, status#233, segment#2, cast(order_ts#230 as date) AS order_date#4389]
                  +- BroadcastHashJoin [customer_id#228], [customer_id#1], Inner, BuildRight, false
                     :- Filter isnotnull(customer_id#228)
                     :

## Bài 4.2 — Window: top 10 customers mỗi ngày theo revenue

In [16]:
from pyspark.sql.window import Window

daily_by_customer = (orders
    .withColumn("order_date", F.to_date("order_ts"))
    .groupBy("order_date", "customer_id")
    .agg(F.sum("amount").alias("daily_revenue"))
)

w = Window.partitionBy("order_date").orderBy(F.desc("daily_revenue"))

top10 = (daily_by_customer
    .withColumn("rk", F.dense_rank().over(w))
    .filter(F.col("rk") <= 10)
    .orderBy("order_date", "rk")
)

top10.explain()
top10.show(50, truncate=False)

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [order_date#4877 ASC NULLS FIRST, rk#4888 ASC NULLS FIRST], true, 0
   +- Exchange rangepartitioning(order_date#4877 ASC NULLS FIRST, rk#4888 ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [plan_id=1792]
      +- Filter (rk#4888 <= 10)
         +- Window [dense_rank(daily_revenue#4878) windowspecdefinition(order_date#4877, daily_revenue#4878 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#4888], [order_date#4877], [daily_revenue#4878 DESC NULLS LAST]
            +- WindowGroupLimit [order_date#4877], [daily_revenue#4878 DESC NULLS LAST], dense_rank(daily_revenue#4878), 10, Final
               +- Sort [order_date#4877 ASC NULLS FIRST, daily_revenue#4878 DESC NULLS LAST], false, 0
                  +- Exchange hashpartitioning(order_date#4877, 4), ENSURE_REQUIREMENTS, [plan_id=1786]
                     +- WindowGroupLimit [order_date#4877], [daily_revenue#4878 DESC NULLS LAST], dense

# LEVEL 5 — Idempotent output (mô phỏng local)

Colab không có GCS mặc định, nên mình demo output ra /content/output/... để bạn hiểu “idempotent partition overwrite”.

In [17]:
import os
base_path = "/content/output/gold/orders_kpi"
dt = "2025-01-05"  # thử đổi dt

out_path = f"{base_path}/dt={dt}"

# giả lập: mỗi lần rerun cùng dt => overwrite đúng partition dt
(
    kpi_daily.filter(F.col("order_date") == F.lit(dt))
    .write.mode("overwrite")
    .parquet(out_path)
)

print("Wrote:", out_path)
print("Files:", os.listdir(out_path)[:5])

Wrote: /content/output/gold/orders_kpi/dt=2025-01-05
Files: ['_SUCCESS', '._SUCCESS.crc', 'part-00000-0f00c31e-0aa3-4ab2-ae77-3b9c96029297-c000.snappy.parquet', '.part-00000-0f00c31e-0aa3-4ab2-ae77-3b9c96029297-c000.snappy.parquet.crc']
