In [0]:
from pyspark.sql.functions import col, sum as spark_sum, count, lpad, broadcast

# Load the two input tables from the Hive metastore's default schema.
zipcode_data_df = spark.table("hive_metastore.default.zipcode_data")
fake_customers_df = spark.table("hive_metastore.default.fake_customers")

# zipcode_data only holds figures for the year 2010.
# Build the join key by casting zipCode to a string and left-padding it with
# zeros to 5 characters, then total the population so the result has exactly
# one row per zip_code.
zip_agg_df = (
    zipcode_data_df
    .select(
        lpad(col("zipCode").cast("string"), 5, "0").alias("zip_code"),
        col("population"),
    )
    .groupBy("zip_code")
    .agg(spark_sum(col("population")).alias("total_population"))
)

# The inner join drops zip codes that have no customers. In real life we would
# probably want to keep them, but that would mean preserving every row of
# zip_agg_df, and then zip_agg_df could no longer be the broadcast side.
joined_df = fake_customers_df.join(
    broadcast(zip_agg_df),
    on=["zip_code"],
    how="inner",
)

# One output row per zip code: its population next to its customer count.
# zip_agg_df has a single row per zip_code, so grouping on total_population as
# well does not split groups; it only carries the value into the result.
result_df = joined_df.groupBy(
    col("zip_code"),
    col("total_population"),
).agg(
    count(col("customer_id")).alias("total_customers"),
)

# Persist the result as a Delta table, replacing any previous contents.
result_df.write.saveAsTable(
    "hive_metastore.default.fake_broadcast_join",
    format="delta",
    mode="overwrite",
)

