# Question 1 
Scenario: "You join a users dataframe (1M rows) with a transactions dataframe (10M rows) on user_id. Your resulting dataframe has 50M rows. Why might this be happening, and how would you debug it?"

# Question 1 
Scenario: "You join a users dataframe (1M rows) with a transactions dataframe (10M rows) on user_id. Your resulting dataframe has 50M rows. Why might this be happening, and how would you debug it?"

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
import random

# Initialize Spark
spark = SparkSession.builder.appName("JoinExplosionDebug").getOrCreate()

# Create Users DataFrame (with intentional duplicates)
users_data = [
    (1, "Alice"),
    (2, "Bob"),
    (3, "Charlie"),
    (4, "Diana"),
    (5, "Evan"),
]

# Simulate duplicates — some user_ids appear twice
users_data *= 20  # total 100 rows
users_df = spark.createDataFrame(users_data, ["user_id", "user_name"])

# Create Transactions DataFrame (multiple transactions per user)
transactions_data = []
for i in range(100):
    user_id = random.choice([1, 2, 3, 4, 5])
    transactions_data.append((i, user_id, random.randint(10, 500)))

transactions_df = spark.createDataFrame(transactions_data, ["transaction_id", "user_id", "amount"])

# Perform join
joined_df = users_df.join(transactions_df, on="user_id", how="inner")

print("Users count:", users_df.count())
print("Transactions count:", transactions_df.count())
print("Joined count:", joined_df.count())  # Expect this to 'explode'



Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/10/26 11:32:38 WARN Utils: Your hostname, Jiaxins-MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 192.168.0.23 instead (on interface en0)
25/10/26 11:32:38 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/26 11:32:38 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/26 11:32:39 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

Users count: 100
Transactions count: 100
Joined count: 2000


In [22]:
# you can also do this 
users_df.printSchema()
users_df.describe().show()

transactions_df.printSchema()
transactions_df.describe().show()

root
 |-- user_id: long (nullable = true)
 |-- user_name: string (nullable = true)

+-------+------------------+---------+
|summary|           user_id|user_name|
+-------+------------------+---------+
|  count|               100|      100|
|   mean|               3.0|     NULL|
| stddev|1.4213381090374029|     NULL|
|    min|                 1|    Alice|
|    max|                 5|     Evan|
+-------+------------------+---------+

root
 |-- transaction_id: long (nullable = true)
 |-- user_id: long (nullable = true)
 |-- amount: long (nullable = true)

+-------+------------------+----------------+------------------+
|summary|    transaction_id|         user_id|            amount|
+-------+------------------+----------------+------------------+
|  count|               100|             100|               100|
|   mean|              49.5|            3.21|            257.69|
| stddev|29.011491975882016|1.35806606405706|131.18255419548944|
|    min|                 0|               1|      

cartisina join.....

In [2]:
joined_df.show()


+-------+---------+--------------+------+
|user_id|user_name|transaction_id|amount|
+-------+---------+--------------+------+
|      1|    Alice|             5|   306|
|      1|    Alice|             6|   469|
|      1|    Alice|            13|   444|
|      1|    Alice|            16|   227|
|      1|    Alice|            19|    22|
|      1|    Alice|            21|   447|
|      1|    Alice|            25|   400|
|      1|    Alice|            52|   286|
|      1|    Alice|            62|    88|
|      1|    Alice|            66|   221|
|      1|    Alice|            68|    31|
|      1|    Alice|            69|   215|
|      1|    Alice|            74|   434|
|      1|    Alice|            82|    54|
|      1|    Alice|            94|   349|
|      1|    Alice|            97|   308|
|      1|    Alice|             5|   306|
|      1|    Alice|             6|   469|
|      1|    Alice|            13|   444|
|      1|    Alice|            16|   227|
+-------+---------+--------------+

In [None]:
# Step 1: Check for duplicate keys in users
dedup_user = users_df.groupBy('user_id').count().filter(col('count') > 1).show()


+-------+-----+
|user_id|count|
+-------+-----+
|      5|   20|
|      1|   20|
|      3|   20|
|      2|   20|
|      4|   20|
+-------+-----+



In [14]:
# Step 2: Check for duplicate keys in transactions
dedup_tran = transactions_df.groupBy('user_id').count().filter(col('user_id') > 1 ).show()



+-------+-----+
|user_id|count|
+-------+-----+
|      5|   21|
|      3|   24|
|      4|   25|
|      2|   14|
+-------+-----+



In [16]:
# Step 3: Clean users by deduplicating
users_clean_df = users_df.dropDuplicates(["user_id"])

# Re-run join
clean_join_df = users_clean_df.join(transactions_df, on="user_id", how="inner")
clean_join_df.show()
print("Cleaned join count:", clean_join_df.count())

+-------+---------+--------------+------+
|user_id|user_name|transaction_id|amount|
+-------+---------+--------------+------+
|      5|     Evan|             1|   394|
|      5|     Evan|             4|   372|
|      1|    Alice|             5|   306|
|      1|    Alice|             6|   469|
|      3|  Charlie|             0|   147|
|      3|  Charlie|             3|   131|
|      4|    Diana|             2|   162|
|      4|    Diana|             7|   134|
|      5|     Evan|             8|    49|
|      5|     Evan|            10|   209|
|      5|     Evan|            11|   225|
|      1|    Alice|            13|   444|
|      3|  Charlie|            14|   170|
|      2|      Bob|             9|   186|
|      2|      Bob|            12|    23|
|      2|      Bob|            15|   286|
|      1|    Alice|            16|   227|
|      1|    Alice|            19|    22|
|      1|    Alice|            21|   447|
|      3|  Charlie|            18|   316|
+-------+---------+--------------+

# Question 2
You join a customers DataFrame (1M rows) with an orders DataFrame (5M rows) on customer_id.
After the join, your resulting DataFrame has only 300K rows — far fewer than expected.

You expected at least 1M (every customer should appear), but most are missing.

In [26]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("JoinDebugScenario2").getOrCreate()

# ----- Customers DataFrame -----
customers_data = [
    (1, "Alice", "UK"),
    (2, "Bob", "US"),
    (3, "Charlie", "CA"),
    (4, "Diana", "UK"),
    (5, "Evan", "US"),
]

# Simulate data quality issues:
# - duplicate customer_ids
# - inconsistent formatting (string IDs, trailing spaces)
customers_data_extended = []
for i in range(100):  # 100 rows total
    entry = customers_data[i % 5]
    if i % 10 == 0:
        # introduce trailing spaces or string type ID
        customers_data_extended.append((str(entry[0]) + " ", entry[1], entry[2]))
    else:
        customers_data_extended.append(entry)

customers_df = spark.createDataFrame(customers_data_extended, ["customer_id", "customer_name", "country"])


# ----- Purchases DataFrame -----
purchases_data = []
for i in range(100):
    customer_id = str((i % 5) + 1)  # also string, to simulate schema mismatch
    purchases_data.append((i, customer_id, round(20 + i * 0.75, 2)))

purchases_df = spark.createDataFrame(purchases_data, ["purchase_id", "customer_id", "amount"])

# ----- Perform join -----
joined_df = customers_df.join(purchases_df, on="customer_id", how="inner")

print("Customers count:", customers_df.count())
print("Purchases count:", purchases_df.count())
print("Joined count:", joined_df.count())



Customers count: 100
Purchases count: 100
Joined count: 1800


In [28]:
# first check the schema 
customers_df.printSchema()
purchases_df.printSchema()

root
 |-- customer_id: string (nullable = true)
 |-- customer_name: string (nullable = true)
 |-- country: string (nullable = true)

root
 |-- purchase_id: long (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- amount: double (nullable = true)



In [29]:
# Check duplicate in customers_df

dedup_cust = customers_df.groupBy('customer_id').count().filter(col('count') > 1).show()

+-----------+-----+
|customer_id|count|
+-----------+-----+
|          3|   20|
|          5|   20|
|         1 |   10|
|          1|   10|
|          4|   20|
|          2|   20|
+-----------+-----+



since both customer_id is string and it has spaces....

In [30]:
# Describe customer_id distribution in purchases
purchases_df.describe(["customer_id"]).show()

+-------+------------------+
|summary|       customer_id|
+-------+------------------+
|  count|               100|
|   mean|               3.0|
| stddev|1.4213381090374029|
|    min|                 1|
|    max|                 5|
+-------+------------------+



In [31]:
# Fixing data: trim strings, cast IDs to int, deduplicate
customers_clean = (
    customers_df.withColumn("customer_id", F.trim(F.col("customer_id")).cast(IntegerType()))
    .dropDuplicates(["customer_id"])
)
purchases_clean = (
    purchases_df.withColumn("customer_id", F.trim(F.col("customer_id")).cast(IntegerType()))
)


In [35]:
clean_df_joined = customers_clean.join(purchases_clean , 'customer_id', 'inner')
print(clean_df_joined.count())

100
