You are given a dataset containing information about various cities around the world. Your task is to calculate the total population of all cities in Japan. The country code for Japan (stored in the `country_code` column) is "JPN".

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr

# Create (or reuse) the Spark session.
spark = (
    SparkSession.builder
    .appName("DeltaWriteExample")
    .getOrCreate()
)

# Base seed data (used for expansion): (id, name, country_code, district, population).
base_data = [
    (1, "Tokyo", "JPN", "Kanto", 13929286),
    (2, "Osaka", "JPN", "Kansai", 2691167),
    (3, "Los Angeles", "USA", "West", 3847480),
    (4, "New York", "USA", "East", 8336817),
    (5, "London", "GBR", "England", 8982000),
    (6, "Mumbai", "IND", "Maharashtra", 12442373),
    (7, "Delhi", "IND", "Delhi", 11007835),
    (8, "Paris", "FRA", "Ile-de-France", 2148327),
]

columns = ["id", "name", "country_code", "district", "population"]

# Number of synthetic copies generated for every seed city.
# Keep this below 100: the new ids are built as id * 100 + replica, so a
# larger value would let ids of neighbouring seed cities collide.
NUM_REPLICAS = 20

# Create base DataFrame
df = spark.createDataFrame(base_data, columns)

# Expand every seed row into NUM_REPLICAS rows
# (8 seed cities x 20 replicas = exactly 160 rows).
# sequence(1, N) is inclusive, so replica takes the values 1..N. Each copy gets
# id = id * 100 + replica and population increased by replica * 1000, so the
# copies are distinguishable from one another.
expanded_df = (
    df.withColumn("replica", expr(f"explode(sequence(1, {NUM_REPLICAS}))"))
    .withColumn("id", col("id") * 100 + col("replica"))
    .withColumn("population", col("population") + expr("replica * 1000"))
    .drop("replica")
)

# Final row count check: len(base_data) * NUM_REPLICAS rows (160).
# As the cell's last expression, the notebook displays the result.
expanded_df.count()


# Optional: persist the expanded data as a managed Delta table.
# # Write to Delta table (managed table)
# expanded_df.write \
#     .format("delta") \
#     .mode("overwrite") \
#     .saveAsTable("silver.city_population")


In [0]:
# Render the expanded city table in the notebook. `display` is not defined in
# this file; it is presumably the notebook environment's built-in renderer
# (e.g. Databricks) — confirm before running this outside a notebook.
display(expanded_df)

In [0]:
def etl(expanded_df):
    """Compute the total population of all cities in Japan.

    Keeps only the rows whose ``country_code`` equals ``"JPN"`` and sums their
    ``population`` column.

    Args:
        expanded_df: Spark DataFrame with at least ``country_code`` and
            ``population`` columns.

    Returns:
        A one-row Spark DataFrame with a single ``total_population`` column.
        The value is 0 when no row has ``country_code == "JPN"``.
    """
    # Local import so the aggregate is always Spark's ``sum``, independent of
    # whatever ``sum`` is bound at notebook (module) level.
    from pyspark.sql import functions as F

    # Compare the *column* against the literal. Writing
    # ``"country_code" == "JPN"`` compares two Python strings and hands
    # ``False`` to ``filter`` instead of a Spark predicate.
    japan_cities = expanded_df.filter(F.col("country_code") == "JPN")

    # A DataFrame has no ``.sum`` method; aggregate with ``agg``. Spark's sum
    # over zero rows is null, so coalesce it to 0.
    population = japan_cities.agg(
        F.coalesce(F.sum("population"), F.lit(0)).alias("total_population")
    )
    return population

In [0]:
# Use the module alias instead of `from pyspark.sql.functions import sum`,
# which would shadow Python's built-in `sum` for the rest of the notebook.
from pyspark.sql import functions as F

# Total population of all cities whose country_code is "JPN". This is a one-row
# DataFrame; a null sum (no matching rows) is reported as 0.
population = (
    expanded_df
    .filter(F.col("country_code") == "JPN")
    .agg(F.coalesce(F.sum("population"), F.lit(0)).alias("total_population"))
)
display(population)