In [0]:
from pyspark.sql import functions as F
from delta.tables import DeltaTable
from datetime import datetime
import json

In [0]:
# --- Notebook parameters ------------------------------------------------------
# Two ways to configure this notebook:
#   1. `pipeline_config_json` - a JSON object (produced by 01_config) carrying
#      the table names plus optional `run_mode` / `process_year`.
#   2. No JSON supplied - fall back to hard-coded table names and the
#      `run_mode` widget; every year is processed.
dbutils.widgets.text("pipeline_config_json", "", "Pipeline Config JSON (from 01_config)")
dbutils.widgets.text("run_mode", "full", "Run Mode")  # fallback

pipeline_config_json = dbutils.widgets.get("pipeline_config_json").strip()

if pipeline_config_json:
    pipeline_config = json.loads(pipeline_config_json)

    # Fail fast with an actionable message rather than an opaque error further down.
    if not isinstance(pipeline_config, dict):
        raise TypeError("pipeline_config_json must decode to a JSON object")
    missing_keys = [
        key for key in ("silver_table", "gold_dim_countries")
        if key not in pipeline_config
    ]
    if missing_keys:
        raise KeyError(
            f"pipeline_config_json is missing required key(s): {', '.join(missing_keys)}"
        )

    SOURCE_TABLE = pipeline_config["silver_table"]
    TARGET_TABLE = pipeline_config["gold_dim_countries"]
    run_mode = pipeline_config.get("run_mode", "full")
    process_year = pipeline_config.get("process_year", None)
else:
    CATALOG = "ironman"
    SOURCE_TABLE = f"{CATALOG}.silver.ironman_results"
    TARGET_TABLE = f"{CATALOG}.gold.dim_countries"
    run_mode = dbutils.widgets.get("run_mode")
    process_year = None

# The write step compares run_mode to the exact string "full"; normalise case and
# surrounding whitespace so "Full" or " full" is not silently treated as a merge.
if isinstance(run_mode, str):
    run_mode = run_mode.strip().lower()

print(f"Source: {SOURCE_TABLE}")
print(f"Target: {TARGET_TABLE}")
print(f"Run Mode: {run_mode}")
print(f"Process Year: {process_year if process_year else 'ALL'}")

In [0]:
# Load the source table named by SOURCE_TABLE (set in the configuration cell) as a DataFrame.
silver_df = spark.table(SOURCE_TABLE)

In [0]:
# Optionally narrow the silver data to a single year. A falsy process_year
# (None, empty string, 0) means "keep every year".
if process_year:
    target_year = int(process_year)
    silver_df = silver_df.where(F.col("year") == target_year)
    print(f"Filtered to year: {process_year}")

# Triggers a Spark action to report how many rows remain in scope.
silver_row_count = silver_df.count()
print(f"Silver rows: {silver_row_count:,}")

In [0]:
# One row per non-null country value present in the in-scope silver rows,
# sorted by the country value.
non_null_countries = silver_df.where(F.col("country").isNotNull())
countries_df = non_null_countries.select("country").dropDuplicates().sort("country")

unique_country_count = countries_df.count()
print(f"Unique countries: {unique_country_count}")

# Static lookup of (country_code, country_name, continent) tuples used to enrich
# the distinct `country` values found in the silver table.
# - Codes look like ISO 3166-1 alpha-2 - TODO confirm against the silver data.
# - Each code appears exactly once; keep it that way, because the lookup is
#   left-joined on the code and a duplicate would duplicate dimension rows.
# - Central American and Caribbean entries are classified as "North America".
# - Codes missing from this list are reported as unmapped further down and fall
#   back to the raw code as the name with continent "Unknown".
country_mapping = [
    ("AD", "Andorra", "Europe"),
    ("AE", "United Arab Emirates", "Asia"),
    ("AR", "Argentina", "South America"),
    ("AT", "Austria", "Europe"),
    ("AU", "Australia", "Oceania"),
    ("BE", "Belgium", "Europe"),
    ("BG", "Bulgaria", "Europe"),
    ("BR", "Brazil", "South America"),
    ("CA", "Canada", "North America"),
    ("CH", "Switzerland", "Europe"),
    ("CL", "Chile", "South America"),
    ("CN", "China", "Asia"),
    ("CO", "Colombia", "South America"),
    ("CZ", "Czech Republic", "Europe"),
    ("DE", "Germany", "Europe"),
    ("DK", "Denmark", "Europe"),
    ("EC", "Ecuador", "South America"),
    ("EE", "Estonia", "Europe"),
    ("ES", "Spain", "Europe"),
    ("FI", "Finland", "Europe"),
    ("FR", "France", "Europe"),
    ("GB", "Great Britain", "Europe"),
    ("GR", "Greece", "Europe"),
    ("HK", "Hong Kong", "Asia"),
    ("HR", "Croatia", "Europe"),
    ("HU", "Hungary", "Europe"),
    ("ID", "Indonesia", "Asia"),
    ("IE", "Ireland", "Europe"),
    ("IL", "Israel", "Asia"),
    ("IN", "India", "Asia"),
    ("IS", "Iceland", "Europe"),
    ("IT", "Italy", "Europe"),
    ("JP", "Japan", "Asia"),
    ("KR", "South Korea", "Asia"),
    ("LT", "Lithuania", "Europe"),
    ("LU", "Luxembourg", "Europe"),
    ("LV", "Latvia", "Europe"),
    ("MX", "Mexico", "North America"),
    ("MY", "Malaysia", "Asia"),
    ("NL", "Netherlands", "Europe"),
    ("NO", "Norway", "Europe"),
    ("NZ", "New Zealand", "Oceania"),
    ("PE", "Peru", "South America"),
    ("PH", "Philippines", "Asia"),
    ("PL", "Poland", "Europe"),
    ("PT", "Portugal", "Europe"),
    ("RO", "Romania", "Europe"),
    ("RS", "Serbia", "Europe"),
    ("RU", "Russia", "Europe"),
    ("SA", "Saudi Arabia", "Asia"),
    ("SE", "Sweden", "Europe"),
    ("SG", "Singapore", "Asia"),
    ("SI", "Slovenia", "Europe"),
    ("SK", "Slovakia", "Europe"),
    ("TH", "Thailand", "Asia"),
    ("TR", "Turkey", "Asia"),
    ("TW", "Taiwan", "Asia"),
    ("UA", "Ukraine", "Europe"),
    ("US", "United States", "North America"),
    ("UY", "Uruguay", "South America"),
    ("VE", "Venezuela", "South America"),
    ("ZA", "South Africa", "Africa"),
    # Second alphabetical run (AM-VN): presumably codes appended after the
    # initial list was written - verify no code is repeated when adding more.
    ("AM", "Armenia", "Asia"),
    ("AW", "Aruba", "North America"),
    ("AZ", "Azerbaijan", "Asia"),
    ("BA", "Bosnia and Herzegovina", "Europe"),
    ("BM", "Bermuda", "North America"),
    ("CR", "Costa Rica", "North America"),
    ("CY", "Cyprus", "Europe"),
    ("DO", "Dominican Republic", "North America"),
    ("EG", "Egypt", "Africa"),
    ("GG", "Guernsey", "Europe"),
    ("HN", "Honduras", "North America"),
    ("JE", "Jersey", "Europe"),
    ("KG", "Kyrgyzstan", "Asia"),
    ("KZ", "Kazakhstan", "Asia"),
    ("ME", "Montenegro", "Europe"),
    ("MK", "North Macedonia", "Europe"),
    ("MO", "Macau", "Asia"),
    ("MT", "Malta", "Europe"),
    ("NA", "Namibia", "Africa"),
    ("NG", "Nigeria", "Africa"),
    ("NP", "Nepal", "Asia"),
    ("PA", "Panama", "North America"),
    ("PR", "Puerto Rico", "North America"),
    ("PY", "Paraguay", "South America"),
    ("RE", "Reunion", "Africa"),
    ("UZ", "Uzbekistan", "Asia"),
    ("VI", "U.S. Virgin Islands", "North America"),
    ("VN", "Vietnam", "Asia"),
]

In [0]:
# Turn the static lookup list into a DataFrame so it can be joined.
mapping_schema = ["country_code", "country_name", "continent"]
mapping_df = spark.createDataFrame(country_mapping, mapping_schema)

# Left join keeps every observed country even when the lookup has no entry for
# it; such rows come out with null country_name / continent.
observed = countries_df.alias("c")
lookup = mapping_df.alias("m")
code_matches = F.col("c.country") == F.col("m.country_code")

countries_with_mapping = (
    observed
    .join(lookup, code_matches, "left")
    .select(
        F.col("c.country"),
        F.col("m.country_name"),
        F.col("m.continent"),
    )
)

In [0]:
# Surface any country codes that have no entry in the lookup list.
unmapped = countries_with_mapping.where(F.col("country_name").isNull())
unmapped_count = unmapped.count()

if unmapped_count == 0:
    print("All countries mapped successfully")
else:
    print(f"Warning: {unmapped_count} unmapped countries:")
    display(unmapped.select("country"))

In [0]:
# Fallbacks for codes missing from the lookup: show the raw code as the name
# and label the continent "Unknown".
resolved_name = F.coalesce(F.col("country_name"), F.col("country"))
resolved_continent = F.coalesce(F.col("continent"), F.lit("Unknown"))

# Surrogate key: absolute value of Spark's hash of the country code.
# NOTE(review): F.hash yields an int column, so distinct codes could in
# principle collide - confirm this is acceptable for consumers of the key.
surrogate_key = F.abs(F.hash(F.col("country")))

countries_df = (
    countries_with_mapping
    .withColumn("country_name", resolved_name)
    .withColumn("continent", resolved_continent)
    .withColumn("country_key", surrogate_key)
)

In [0]:
# Distinct athletes per country, aggregated over the FULL source table rather
# than the (possibly year-filtered) silver_df.
# Why: dim_countries has no year column, so athlete_count is an all-time figure.
# When process_year is set, counting only that year's rows would make the
# incremental merge overwrite all-time counts with single-year counts, leaving
# the dimension with a mix of both. When process_year is not set this reads the
# same rows as silver_df, so the result is unchanged.
# Countries outside the current scope are dropped by the later left join from
# countries_df, so the extra rows produced here are harmless.
athlete_counts = (
    spark.table(SOURCE_TABLE)
    .filter(F.col("country").isNotNull())
    .groupBy("country")
    .agg(F.countDistinct("athlete_name").alias("athlete_count"))
)

In [0]:
# Attach athlete counts (0 where a country has no count row) and stamp both
# audit columns with the current timestamp.
# Note: this cell reassigns countries_df, so re-running it on its own would join
# athlete_counts onto a frame that already carries an athlete_count column.
load_timestamp = F.current_timestamp()
athlete_count_or_zero = F.coalesce(F.col("athlete_count"), F.lit(0))

countries_df = (
    countries_df
    .join(athlete_counts, on="country", how="left")
    .withColumn("athlete_count", athlete_count_or_zero)
    .withColumn("created_at", load_timestamp)
    .withColumn("updated_at", load_timestamp)
)

# Final column order of the dimension table.
dim_column_order = [
    "country_key",
    "country",
    "country_name",
    "continent",
    "athlete_count",
    "created_at",
    "updated_at",
]
dim_countries = countries_df.select(*dim_column_order)

final_row_count = dim_countries.count()
print(f"Final row count: {final_row_count}")

In [0]:
# Persist the dimension. An existing table combined with any run mode other
# than "full" is upserted by country; every other case rebuilds the table.
table_exists = spark.catalog.tableExists(TARGET_TABLE)

if table_exists and run_mode != "full":
    print(f"Incremental merge to {TARGET_TABLE}")
    delta_table = DeltaTable.forName(spark, TARGET_TABLE)

    merge_condition = "target.country = source.country"
    # Matched rows keep their existing country_key and created_at; only the
    # descriptive columns and updated_at are refreshed.
    refreshed_columns = {
        "country_name": "source.country_name",
        "continent": "source.continent",
        "athlete_count": "source.athlete_count",
        "updated_at": "source.updated_at"
    }

    upsert = (
        delta_table.alias("target")
        .merge(dim_countries.alias("source"), merge_condition)
        .whenMatchedUpdate(set=refreshed_columns)
        .whenNotMatchedInsertAll()
    )
    upsert.execute()
else:
    print(f"Full load to {TARGET_TABLE}")
    # Replace both data and schema of the target table.
    full_writer = (
        dim_countries.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
    )
    full_writer.saveAsTable(TARGET_TABLE)

print("Write complete")

print("Write complete")

In [0]:
# Read the written table back and show a few sanity summaries.
result_df = spark.table(TARGET_TABLE)

print(f"Table: {TARGET_TABLE}")
print(f"Total countries: {result_df.count()}")

# Countries and summed athlete counts per continent, largest total first.
print("\nCountries by continent:")
continent_summary = result_df.groupBy("continent").agg(
    F.count("*").alias("country_count"),
    F.sum("athlete_count").alias("total_athletes"),
)
display(continent_summary.orderBy(F.col("total_athletes").desc()))

# The ten rows with the highest athlete_count.
print("\nTop 10 countries by athletes:")
top_countries = result_df.orderBy(F.col("athlete_count").desc()).limit(10)
display(top_countries)

In [0]:
# Closing banner: target table name, its current row count, and wall-clock time.
banner = "=" * 50
final_table_rows = spark.table(TARGET_TABLE).count()

print("\n" + banner)
print("DIMENSION TABLE COMPLETE: dim_countries")
print(banner)
print(f"Table: {TARGET_TABLE}")
print(f"Rows: {final_table_rows}")
print(f"Timestamp: {datetime.now()}")
print(banner)