# Cloud Assignment
## Luke Winters, Jeff Bowers
### Data Pipeline for Global Influenza Surveillance Data Enriched with Population Data Using Apache Spark

# ---------------------------------------------------------------------

### Let's start by getting a Spark session going

In [1]:
from pyspark.sql import SparkSession

# Create or get a Spark session
spark = SparkSession.builder.appName("FluNetAssignment").getOrCreate()

spark


## Bronze Layer - Raw FluNet Ingestion and Initial Cleaning
### We can union the datasets

In [2]:
from pyspark.sql import functions as F

# Read each WHO region file
df_afr = (spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/FluNet_AFR.csv")
)

df_amr = (spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/FluNet_AMR.csv")
)

df_emr = (spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/FluNet_EMR.csv")
)

df_eur = (spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/FluNet_EUR.csv")
)

df_sear = (spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/FluNet_SEAR.csv")
)

df_wpr = (spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/FluNet_WPR.csv")
)

# Combine them into one big raw DataFrame
df_raw = (
    df_afr
    .unionByName(df_amr)
    .unionByName(df_emr)
    .unionByName(df_eur)
    .unionByName(df_sear)
    .unionByName(df_wpr)
)

print("Total rows across all WHO regions:", df_raw.count())

# Quick look at schema and some rows
df_raw.printSchema()
df_raw.show(5)


Total rows across all WHO regions: 184351
root
 |-- WHOREGION: string (nullable = true)
 |-- FLUSEASON: string (nullable = true)
 |-- HEMISPHERE: string (nullable = true)
 |-- ITZ: string (nullable = true)
 |-- COUNTRY_CODE: string (nullable = true)
 |-- COUNTRY_AREA_TERRITORY: string (nullable = true)
 |-- ISO_WEEKSTARTDATE: string (nullable = true)
 |-- ISO_YEAR: string (nullable = true)
 |-- ISO_WEEK: string (nullable = true)
 |-- MMWR_WEEKSTARTDATE: string (nullable = true)
 |-- MMWR_YEAR: integer (nullable = true)
 |-- MMWR_WEEK: string (nullable = true)
 |-- ORIGIN_SOURCE: string (nullable = true)
 |-- SPEC_PROCESSED_NB: integer (nullable = true)
 |-- SPEC_RECEIVED_NB: integer (nullable = true)
 |-- AH1N12009: integer (nullable = true)
 |-- AH1: integer (nullable = true)
 |-- AH3: integer (nullable = true)
 |-- AH5: integer (nullable = true)
 |-- AH7N9: string (nullable = true)
 |-- ANOTSUBTYPED: integer (nullable = true)
 |-- ANOTSUBTYPABLE: string (nullable = true)
 |-- AOTHER_

### Now we're going to do some additional null profiling to ensure data quality. We'll do this by getting an idea of how many nulls are in each of the most important columns, by percentage.

In [3]:
from pyspark.sql import functions as F

important_cols = [
    "WHOREGION",
    "COUNTRY_CODE",
    "COUNTRY_AREA_TERRITORY",
    "ISO_YEAR",
    "ISO_WEEK",
    "SPEC_PROCESSED_NB",
    "INF_ALL",
    "INF_NEGATIVE"
]

# Calculate percentage of NULL values for each important column in the raw dataset
null_profile_raw = (
    df_raw
    .select([
        F.round(
            F.avg(F.col(c).isNull().cast("double")) * 100.0, 2
        ).alias(f"{c}_null_pct")
        for c in important_cols
    ])
)

print("Percentage of nulls in key columns (raw df_raw):")
null_profile_raw.show(truncate=False)


Percentage of nulls in key columns (raw df_raw):
+------------------+---------------------+-------------------------------+-----------------+-----------------+--------------------------+----------------+---------------------+
|WHOREGION_null_pct|COUNTRY_CODE_null_pct|COUNTRY_AREA_TERRITORY_null_pct|ISO_YEAR_null_pct|ISO_WEEK_null_pct|SPEC_PROCESSED_NB_null_pct|INF_ALL_null_pct|INF_NEGATIVE_null_pct|
+------------------+---------------------+-------------------------------+-----------------+-----------------+--------------------------+----------------+---------------------+
|0.0               |3.37                 |3.48                           |3.82             |3.83             |10.47                     |46.99           |62.71                |
+------------------+---------------------+-------------------------------+-----------------+-----------------+--------------------------+----------------+---------------------+



### We do a quick count to make sure the union included all lines from the files individually

In [4]:
# Row counts per regional file
count_afr  = df_afr.count()
count_amr  = df_amr.count()
count_emr  = df_emr.count()
count_eur  = df_eur.count()
count_sear = df_sear.count()
count_wpr  = df_wpr.count()

print("AFR rows: ", count_afr)
print("AMR rows: ", count_amr)
print("EMR rows: ", count_emr)
print("EUR rows: ", count_eur)
print("SEAR rows:", count_sear)
print("WPR rows: ", count_wpr)


# Add all the counts together, to make sure it all equals
total_from_parts = (
    count_afr + count_amr + count_emr +
    count_eur + count_sear + count_wpr
)

print("\nTotal rows from individual parts:", total_from_parts)
print("df_raw.count():                   ", df_raw.count())


AFR rows:  22627
AMR rows:  41427
EMR rows:  13670
EUR rows:  77974
SEAR rows: 8508
WPR rows:  20145

Total rows from individual parts: 184351
df_raw.count():                    184351


### Check for duplicates: First we check for exact duplicates

In [5]:
# Exact duplicate check (whole row)
total_rows = df_raw.count()
distinct_rows = df_raw.distinct().count()

print("Total rows:   ", total_rows)
print("Distinct rows:", distinct_rows)
print("Duplicate rows (exact matches):", total_rows - distinct_rows)


Total rows:    184351
Distinct rows: 183799
Duplicate rows (exact matches): 552


### Check for duplicates: Next we check for duplicate keys

In [6]:
key_cols = ["WHOREGION", "COUNTRY_CODE", "ISO_YEAR", "ISO_WEEK"]

dup_keys = (
    df_raw
    .groupBy(key_cols)
    .count()
    .filter(F.col("count") > 1)
)

print("Number of key combinations with duplicates:", dup_keys.count())
dup_keys.show(10)


Number of key combinations with duplicates: 36635
+--------------------+------------+--------+--------+-----+
|           WHOREGION|COUNTRY_CODE|ISO_YEAR|ISO_WEEK|count|
+--------------------+------------+--------+--------+-----+
|                 AFR|         GIN|    2023|       2|    2|
|                 AFR|         GIN|    2023|       4|    2|
|                 AFR|         MDG|    2021|      44|    2|
|                 AFR|         SLE|    2022|      23|    2|
|                 AFR|         SLE|    2022|      24|    2|
|                 AFR|         SLE|    2022|      22|    2|
|                 AFR|         SLE|    2022|      15|    2|
|                 AFR|         GIN|    2023|       6|    2|
|Case 2. Female 4 ...|        NULL|    NULL|    NULL|    2|
|                 AFR|         GIN|    2023|       7|    2|
+--------------------+------------+--------+--------+-----+
only showing top 10 rows



## Removing exact duplicate rows

In [7]:
# Remove exact (row-wise) duplicates
df_raw_nodup = df_raw.dropDuplicates()

print("After dropDuplicates:")
print("Total rows:   ", df_raw_nodup.count())
print("Distinct rows:", df_raw_nodup.distinct().count())


After dropDuplicates:
Total rows:    183799
Distinct rows: 183799


### We need to filter out obviously “invalid” rows (like that WHOREGION = Case 2. Female 4 ...) and keep only real country–weeks.

### This next step will do that by starting from df_raw_nodup (with exact duplicates already removed), and keep only rows where WHOREGION is one of the 6 real WHO Regions, and where COUNTRY_CODE, ISO_YEAR, ISO_WEEK are not null.


In [8]:
from pyspark.sql import functions as F

# Keep only rows with valid WHO region + non-null country/year/week
valid_regions = ["AFR", "AMR", "EMR", "EUR", "SEAR", "WPR"]

df_valid = (
    df_raw_nodup
    .filter(
        (F.col("WHOREGION").isin(valid_regions)) &
        F.col("COUNTRY_CODE").isNotNull() &
        F.col("ISO_YEAR").isNotNull() &
        F.col("ISO_WEEK").isNotNull()
    )
)

print("Rows after filtering invalid records:", df_valid.count())

# Check remaining duplicates by key: region + country + iso year/week
key_cols = ["WHOREGION", "COUNTRY_CODE", "ISO_YEAR", "ISO_WEEK"]

dup_keys = (
    df_valid
    .groupBy(key_cols)
    .count()
    .filter(F.col("count") > 1)
)

print("Number of key combinations with duplicates (after filtering):", dup_keys.count())
dup_keys.show(10)


Rows after filtering invalid records: 177224
Number of key combinations with duplicates (after filtering): 36216
+---------+------------+--------+--------+-----+
|WHOREGION|COUNTRY_CODE|ISO_YEAR|ISO_WEEK|count|
+---------+------------+--------+--------+-----+
|      AMR|         COL|    2021|      45|    2|
|      AMR|         GTM|    2023|      39|    2|
|      AMR|         BRA|    2023|      32|    2|
|      AMR|         USA|    2024|      14|    2|
|      AMR|         USA|    2020|      21|    2|
|      EMR|         AFG|    2023|      24|    2|
|      EMR|         EGY|    2023|      40|    2|
|      EMR|         ARE|    2023|      23|    2|
|      EUR|         AZE|    2024|      28|    2|
|      EUR|         ARM|    2014|      40|    2|
+---------+------------+--------+--------+-----+
only showing top 10 rows



### Next we are going to ensure that the years and weeks make sense. The FluNet dataset has records going back to 1995, so we'll start from there and ensure make sense (Only 52 weeks in a year so there can't be anything like Week 102, for example)

In [9]:
from pyspark.sql import functions as F


# Adding sensible year/week ranges
df_valid = df_valid.filter(
    (F.col("ISO_YEAR") >= 1995) &       
    (F.col("ISO_WEEK").between(1, 53))
)

print("Rows after enforcing year/week ranges:", df_valid.count())


Rows after enforcing year/week ranges: 177224


### The raw FluNet data contains multiple records per country–week.
### This pipeline will aggregate them to a single country–week record by summing all numeric surveillance counts.

In [10]:
from pyspark.sql import functions as F

# Define the grouping (key) columns
group_cols = [
    "WHOREGION",
    "COUNTRY_CODE",
    "COUNTRY_AREA_TERRITORY",
    "HEMISPHERE",
    "ITZ",
    "ISO_YEAR",
    "ISO_WEEK",
]

# Numeric columns we want to sum
numeric_cols = [
    "SPEC_PROCESSED_NB",
    "SPEC_RECEIVED_NB",
    "AH1N12009",
    "AH1",
    "AH3",
    "AH5",
    "ANOTSUBTYPED",
    "AOTHER_SUBTYPE",
    "BVIC_NODEL",
    "BYAM",
    "BNOTDETERMINED",
    "INF_B",
    "INF_ALL",
    "INF_NEGATIVE",
    "ILI_ACTIVITY",
    "ADENO",
    "BOCA",
    "METAPNEUMO",
    "PARAINFLUENZA",
    "RHINO",
    "RSV_PROCESSED",
    "RSV",
    "OTHERRESPVIRUS",
    "PSOURCE_SUBTYPE_INF",
    "PSOURCE_PPOS_INF",
    "PSOURCE_RSV",
]

# Build aggregation expressions: sum each numeric column
agg_exprs = [F.sum(F.col(c)).alias(c) for c in numeric_cols]

# Aggregate to one row per (region, country, year, week)
df_agg = df_valid.groupBy(*group_cols).agg(*agg_exprs)

print("Rows after aggregation:", df_agg.count())

# Check that we no longer have duplicate keys
dup_keys_after = (
    df_agg
    .groupBy("WHOREGION", "COUNTRY_CODE", "ISO_YEAR", "ISO_WEEK")
    .count()
    .filter(F.col("count") > 1)
)

print("Number of key combinations with duplicates after aggregation:",
      dup_keys_after.count())

df_agg.show(5)


Rows after aggregation: 140108
Number of key combinations with duplicates after aggregation: 0
+---------+------------+----------------------+----------+-----------+--------+--------+-----------------+----------------+---------+----+---+---+------------+--------------+----------+----+--------------+-----+-------+------------+------------+-----+----+----------+-------------+-----+-------------+----+--------------+-------------------+----------------+-----------+
|WHOREGION|COUNTRY_CODE|COUNTRY_AREA_TERRITORY|HEMISPHERE|        ITZ|ISO_YEAR|ISO_WEEK|SPEC_PROCESSED_NB|SPEC_RECEIVED_NB|AH1N12009| AH1|AH3|AH5|ANOTSUBTYPED|AOTHER_SUBTYPE|BVIC_NODEL|BYAM|BNOTDETERMINED|INF_B|INF_ALL|INF_NEGATIVE|ILI_ACTIVITY|ADENO|BOCA|METAPNEUMO|PARAINFLUENZA|RHINO|RSV_PROCESSED| RSV|OTHERRESPVIRUS|PSOURCE_SUBTYPE_INF|PSOURCE_PPOS_INF|PSOURCE_RSV|
+---------+------------+----------------------+----------+-----------+--------+--------+-----------------+----------------+---------+----+---+---+------------+----

## Silver Layer - Cleaned Weekly, Country Level Dataset

### So now we have exactly one row per region, country, ISO year, ISO week

### We now make this data clean by giving nice names to columns

In [11]:
# Cleaned, renamed version of the aggregated table
df_clean = (
    df_agg
    .select(
        # Dimensions
        F.col("WHOREGION").alias("who_region"),
        F.col("HEMISPHERE").alias("hemisphere"),
        F.col("ITZ").alias("itz"),
        F.col("COUNTRY_AREA_TERRITORY").alias("country_name"),
        F.col("COUNTRY_CODE").alias("country_code"),
        F.col("ISO_YEAR").cast("int").alias("iso_year"),
        F.col("ISO_WEEK").cast("int").alias("iso_week"),

        # Lab counts
        F.col("SPEC_RECEIVED_NB").alias("spec_received_nb"),
        F.col("SPEC_PROCESSED_NB").alias("spec_processed_nb"),
        F.col("AH1N12009").alias("ah1n1_2009"),
        F.col("AH1").alias("ah1"),
        F.col("AH3").alias("ah3"),
        F.col("AH5").alias("ah5"),
        F.col("ANOTSUBTYPED").alias("a_not_subtyped"),
        F.col("AOTHER_SUBTYPE").alias("a_other_subtype"),
        F.col("BVIC_NODEL").alias("b_vic_nodel"),
        F.col("BYAM").alias("b_yam"),
        F.col("BNOTDETERMINED").alias("b_not_determined"),
        F.col("INF_B").alias("inf_b"),
        F.col("INF_ALL").alias("inf_all"),
        F.col("INF_NEGATIVE").alias("inf_negative"),
        F.col("ILI_ACTIVITY").alias("ili_activity"),

        # Other respiratory viruses
        F.col("ADENO").alias("adeno"),
        F.col("BOCA").alias("boca"),
        F.col("METAPNEUMO").alias("metapneumo"),
        F.col("PARAINFLUENZA").alias("parainfluenza"),
        F.col("RHINO").alias("rhino"),
        F.col("RSV_PROCESSED").alias("rsv_processed"),
        F.col("RSV").alias("rsv"),
        F.col("OTHERRESPVIRUS").alias("other_resp_virus"),

        # Source / meta counts
        F.col("PSOURCE_SUBTYPE_INF").alias("psource_subtype_inf"),
        F.col("PSOURCE_PPOS_INF").alias("psource_ppos_inf"),
        F.col("PSOURCE_RSV").alias("psource_rsv"),
    )
)

df_clean.printSchema()
df_clean.show(5)


root
 |-- who_region: string (nullable = true)
 |-- hemisphere: string (nullable = true)
 |-- itz: string (nullable = true)
 |-- country_name: string (nullable = true)
 |-- country_code: string (nullable = true)
 |-- iso_year: integer (nullable = true)
 |-- iso_week: integer (nullable = true)
 |-- spec_received_nb: long (nullable = true)
 |-- spec_processed_nb: long (nullable = true)
 |-- ah1n1_2009: long (nullable = true)
 |-- ah1: long (nullable = true)
 |-- ah3: long (nullable = true)
 |-- ah5: long (nullable = true)
 |-- a_not_subtyped: long (nullable = true)
 |-- a_other_subtype: long (nullable = true)
 |-- b_vic_nodel: long (nullable = true)
 |-- b_yam: long (nullable = true)
 |-- b_not_determined: long (nullable = true)
 |-- inf_b: long (nullable = true)
 |-- inf_all: long (nullable = true)
 |-- inf_negative: long (nullable = true)
 |-- ili_activity: long (nullable = true)
 |-- adeno: long (nullable = true)
 |-- boca: long (nullable = true)
 |-- metapneumo: long (nullable = true

### Next step in cleaning our dataset is handly negative counts and completely empty weeks

In [12]:
from pyspark.sql import functions as F

count_cols = [
    "spec_received_nb",
    "spec_processed_nb",
    "inf_all",
    "inf_negative",
]

# Clip any negative lab counts to 0 
for c in count_cols:
    df_clean = df_clean.withColumn(
        c,
        F.when(F.col(c) < 0, 0).otherwise(F.col(c))
    )

# Drop rows that have no lab information at all
df_clean = df_clean.filter(
    (F.col("spec_processed_nb").isNotNull()) |
    (F.col("inf_all").isNotNull()) |
    (F.col("inf_negative").isNotNull())
)

print("Rows in df_clean after cleaning:", df_clean.count())

Rows in df_clean after cleaning: 134983


### Next we can add a new feature: positivity rate

In [13]:
# Add some derived features on top of df_clean
df_features = (
    df_clean
    # Convenience alias for clarity
    .withColumn("total_specimens", F.col("spec_processed_nb"))
    # Positivity rate: fraction of processed specimens that were positive
    .withColumn(
        "positivity_rate",
        F.when(
            (F.col("spec_processed_nb") > 0) & F.col("inf_all").isNotNull(),
            F.col("inf_all") / F.col("spec_processed_nb")
        ).otherwise(None)  # avoid divide-by-zero / missing
    )
)

# Look at a few rows where positivity_rate is not null
df_features.select(
    "who_region",
    "country_name",
    "iso_year",
    "iso_week",
    "total_specimens",
    "inf_all",
    "inf_negative",
    "positivity_rate"
).where(F.col("positivity_rate").isNotNull()).show(10)


+----------+--------------------+--------+--------+---------------+-------+------------+--------------------+
|who_region|        country_name|iso_year|iso_week|total_specimens|inf_all|inf_negative|     positivity_rate|
+----------+--------------------+--------+--------+---------------+-------+------------+--------------------+
|       AFR|              Guinea|    2020|      48|             21|      1|        NULL|0.047619047619047616|
|       AFR|                Mali|    2020|      34|             17|      2|        NULL| 0.11764705882352941|
|       AFR|               Niger|    2011|       6|             24|      7|          17|  0.2916666666666667|
|       AFR|              Zambia|    2021|      28|             64|      5|          59|            0.078125|
|       AFR|        South Africa|    2016|      20|            237|     29|         208| 0.12236286919831224|
|       AFR|        Burkina Faso|    2022|      34|              3|      2|           1|  0.6666666666666666|
|       AM

## Next step in the pipeline is to create a 4-week rolling features for each country (current week + previous 3), and calculate the sum of positive flu cases and the moving average positivity rate. 

## The pipeline then removes weeks where nothing happens in the lab, eg. no specimens are processed, and no positive or negative results.

In [14]:
# Temporal features + filtering structurally empty weeks
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 4-week rolling window by country (current week + previous 3)
w_country_4wk = (
    Window
    .partitionBy("country_code")
    .orderBy("iso_year", "iso_week")
    .rowsBetween(-3, 0)
)

df_features = (
    df_features
    .withColumn(
        "inf_all_4wk_sum",
        F.sum("inf_all").over(w_country_4wk)
    )
    .withColumn(
        "pos_rate_4wk_ma",
        F.avg("positivity_rate").over(w_country_4wk)
    )
)

# Drop weeks with no specimens processed, no positives, and no negatives
df_features = df_features.filter(
    ~(
        (F.col("spec_processed_nb").isNull() | (F.col("spec_processed_nb") == 0))
        & F.col("inf_all").isNull()
        & F.col("inf_negative").isNull()
    )
)

print("Rows in df_features after rolling features & filtering:", df_features.count())
df_features.select(
    "who_region", "country_name", "iso_year", "iso_week",
    "positivity_rate", "pos_rate_4wk_ma", "inf_all", "inf_all_4wk_sum"
).show(10)

Rows in df_features after rolling features & filtering: 128025
+----------+------------+--------+--------+-------------------+-------------------+-------+---------------+
|who_region|country_name|iso_year|iso_week|    positivity_rate|    pos_rate_4wk_ma|inf_all|inf_all_4wk_sum|
+----------+------------+--------+--------+-------------------+-------------------+-------+---------------+
|       AMR|       Aruba|    2017|       1|              0.275|              0.275|     11|             11|
|       AMR|       Aruba|    2017|       2|               NULL|              0.275|   NULL|             11|
|       AMR|       Aruba|    2017|       3|0.21428571428571427|0.24464285714285716|      3|             14|
|       AMR|       Aruba|    2017|       4|0.08333333333333333| 0.1908730158730159|      1|             15|
|       AMR|       Aruba|    2017|       5|               NULL| 0.1488095238095238|   NULL|              4|
|       AMR|       Aruba|    2017|       6|0.16666666666666666|0.15476190

### Next step in the data pipeline is to save what what has been done so far as a proper “processed” dataset. The dataset will be written out as Parquet, partitioned by year + region.

In [15]:
# Path inside the container (and under your cloud_assignment/data folder on Windows)
output_path = "data/processed/flunet_features"

(
    df_features
    .write
    .mode("overwrite")                 # safe to re-run pipeline
    .partitionBy("iso_year", "who_region")  # good for time/region queries
    .parquet(output_path)
)

print("Written processed data to:", output_path)


Written processed data to: data/processed/flunet_features


In [16]:
df_gold = spark.read.parquet("data/processed/flunet_features")

print("Rows in df_gold:", df_gold.count())
df_gold.show(5)


Rows in df_gold: 128025
+----------+-------------+------------+------------+--------+----------------+-----------------+----------+----+----+----+--------------+---------------+-----------+-----+----------------+-----+-------+------------+------------+-----+----+----------+-------------+-----+-------------+---+----------------+-------------------+----------------+-----------+---------------+--------------------+---------------+--------------------+--------+----------+
|hemisphere|          itz|country_name|country_code|iso_week|spec_received_nb|spec_processed_nb|ah1n1_2009| ah1| ah3| ah5|a_not_subtyped|a_other_subtype|b_vic_nodel|b_yam|b_not_determined|inf_b|inf_all|inf_negative|ili_activity|adeno|boca|metapneumo|parainfluenza|rhino|rsv_processed|rsv|other_resp_virus|psource_subtype_inf|psource_ppos_inf|psource_rsv|total_specimens|     positivity_rate|inf_all_4wk_sum|     pos_rate_4wk_ma|iso_year|who_region|
+----------+-------------+------------+------------+--------+----------------+

### Now the pipeline will enrich the FluNet data with the external population data

In [17]:
# Read the population CSV
df_pop_raw = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/external/population.csv")
)

df_pop_raw.printSchema()
df_pop_raw.show(5)


root
 |-- Country Name: string (nullable = true)
 |-- Country Code: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Value: double (nullable = true)

+------------+------------+----+-------+
|Country Name|Country Code|Year|  Value|
+------------+------------+----+-------+
|       Aruba|         ABW|1960|54922.0|
|       Aruba|         ABW|1961|55578.0|
|       Aruba|         ABW|1962|56320.0|
|       Aruba|         ABW|1963|57002.0|
|       Aruba|         ABW|1964|57619.0|
+------------+------------+----+-------+
only showing top 5 rows



### Clean the population data by giving nice names

In [18]:
# Cleaned population table
df_pop = (
    df_pop_raw
    .select(
        F.col("Country Code").alias("country_code"),
        F.col("Year").cast("int").alias("year"),
        F.col("Value").cast("double").alias("population")
    )
)

df_pop.printSchema()
df_pop.show(5)


root
 |-- country_code: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- population: double (nullable = true)

+------------+----+----------+
|country_code|year|population|
+------------+----+----------+
|         ABW|1960|   54922.0|
|         ABW|1961|   55578.0|
|         ABW|1962|   56320.0|
|         ABW|1963|   57002.0|
|         ABW|1964|   57619.0|
+------------+----+----------+
only showing top 5 rows



## Gold Layer - Enriched Analytical Dataset
### We can  now join the population data and the FluNet data.
### We want to join on: country_code (3-letter ISO3) & iso_year (from FluNet) = year (from population)
### We can do a left join so we can keep all FluNet rows, even if some don't have population data
### We'll then add a column for the number of cases per 100k.

In [19]:
# Join FluNet features with population data
df_enriched = (
    df_features
    .join(
        df_pop,
        (df_features.country_code == df_pop.country_code) &
        (df_features.iso_year == df_pop.year),
        how="left"
    )
    # Drop duplicate columns from population side 
    .drop(df_pop.country_code)
    .drop("year")
)

# Add per-100k influenza positive rate
df_enriched = df_enriched.withColumn(
    "inf_all_per_100k",
    F.when(
        (F.col("inf_all").isNotNull()) &
        (F.col("population").isNotNull()) &
        (F.col("population") > 0),
        (F.col("inf_all") / F.col("population")) * 100000.0
    ).otherwise(None)
)

# Inspect some joined rows where we have population and inf_all > 0
df_enriched.select(
    "who_region",
    "country_name",
    "country_code",
    "iso_year",
    "iso_week",
    "population",
    "inf_all",
    "inf_all_per_100k",
    "positivity_rate"
).where(
    (F.col("population").isNotNull()) & (F.col("inf_all") > 0)
).show(10)


+----------+------------+------------+--------+--------+------------+-------+--------------------+--------------------+
|who_region|country_name|country_code|iso_year|iso_week|  population|inf_all|    inf_all_per_100k|     positivity_rate|
+----------+------------+------------+--------+--------+------------+-------+--------------------+--------------------+
|       AFR|      Guinea|         GIN|    2020|      48| 1.3371183E7|      1|0.007478769829116841|0.047619047619047616|
|       AFR|        Mali|         MLI|    2020|      34| 2.1713836E7|      2|0.009210717074587834| 0.11764705882352941|
|       AFR|       Niger|         NER|    2011|       6| 1.7176283E7|      7|  0.0407538697400363|  0.2916666666666667|
|       AFR|      Zambia|         ZMB|    2021|      28| 1.9603607E7|      5|0.025505510286958924|            0.078125|
|       AFR|South Africa|         ZAF|    2016|      20| 5.7259551E7|     29| 0.05064657248185547| 0.12236286919831224|
|       AFR|Burkina Faso|         BFA|  

### Next the pipeline restricts the dataset so only rows where population is known are included. It then reports how many rows are left. 


In [20]:
# Handling missing population and extra rate-based features
from pyspark.sql import functions as F

total_rows_enriched = df_enriched.count()

# Keep only rows where we know the population for rate-based analyses
df_enriched = df_enriched.filter(F.col("population").isNotNull())

print("Rows with known population:", df_enriched.count(), "out of", total_rows_enriched)


df_enriched.select(
    "who_region", "country_name", "iso_year", "iso_week",
    "population", "inf_all_per_100k"
).show(10)

Rows with known population: 108185 out of 128025
+----------+------------+--------+--------+------------+--------------------+
|who_region|country_name|iso_year|iso_week|  population|    inf_all_per_100k|
+----------+------------+--------+--------+------------+--------------------+
|       AFR|Sierra Leone|    2018|      23|   7554563.0|                NULL|
|       AFR|      Guinea|    2020|      48| 1.3371183E7|0.007478769829116841|
|       AFR|        Mali|    2020|      34| 2.1713836E7|0.009210717074587834|
|       AFR|       Niger|    2011|       6| 1.7176283E7|  0.0407538697400363|
|       AFR|      Zambia|    2021|      28| 1.9603607E7|0.025505510286958924|
|       AFR|South Africa|    2016|      20| 5.7259551E7| 0.05064657248185547|
|       AFR|South Africa|    2010|      16| 5.2344051E7|                NULL|
|       AFR|Burkina Faso|    2022|      34| 2.2509038E7|0.008885319754669213|
|       AMR|        Peru|    2021|      11| 3.3155882E7|                NULL|
|       AMR|   

### Now we can save this enriched “Gold” dataset

In [21]:
# Final "gold" dataset path
gold_path = "data/processed/flunet_enriched"

(
    df_enriched
    .write
    .mode("overwrite")
    .partitionBy("iso_year", "who_region")
    .parquet(gold_path)
)

print("Written enriched data to:", gold_path)


Written enriched data to: data/processed/flunet_enriched


In [22]:
df_gold_enriched = spark.read.parquet("data/processed/flunet_enriched")

print("Rows in df_gold_enriched:", df_gold_enriched.count())

df_gold_enriched.select(
    "who_region",
    "country_name",
    "country_code",
    "iso_year",
    "iso_week",
    "population",
    "inf_all",
    "inf_all_per_100k",
    "positivity_rate",
    "pos_rate_4wk_ma",
    "spec_processed_nb"
).show(5)


Rows in df_gold_enriched: 108185
+----------+------------+------------+--------+--------+-----------+-------+------------------+-----------------+------------------+-----------------+
|who_region|country_name|country_code|iso_year|iso_week| population|inf_all|  inf_all_per_100k|  positivity_rate|   pos_rate_4wk_ma|spec_processed_nb|
+----------+------------+------------+--------+--------+-----------+-------+------------------+-----------------+------------------+-----------------+
|       EUR|     Belgium|         BEL|    2023|       1|1.1787423E7|   1745|14.803914307648075|         54.53125| 55.62374853529087|               32|
|       EUR|     Belgium|         BEL|    2023|       2|1.1787423E7|    989| 8.390298710752978|            39.56|58.039389560931895|               25|
|       EUR|     Belgium|         BEL|    2023|       3|1.1787423E7|    718|  6.09123809334746|37.78947368421053| 51.13953575976231|               19|
|       EUR|     Belgium|         BEL|    2023|       4|1.178

### Finally, we create high level overview table from our final "gold" dataset.

### To do this, each WHO region and year is taken, and the total number of positive flu specimens is calculated, the total number of specimens tested is calculated, the average positivity rate, and the average positivity rate per 100,000 incidence is calculated.

### This gives us a nice and compact view which is perfect for dashboards or plots.



In [23]:
from pyspark.sql import functions as F

# Start from your saved gold dataset
df_gold_enriched = spark.read.parquet("data/processed/flunet_enriched")

# Summary by WHO region and year
df_region_year = (
    df_gold_enriched
    .groupBy("who_region", "iso_year")
    .agg(
        F.sum("inf_all").alias("total_influenza_positives"),
        F.sum("total_specimens").alias("total_specimens_tested"),
        F.avg("positivity_rate").alias("avg_positivity_rate"),
        F.avg("inf_all_per_100k").alias("avg_inf_all_per_100k")
    )
    .orderBy("who_region", "iso_year")
)

df_region_year.show(20)


+----------+--------+-------------------------+----------------------+-------------------+--------------------+
|who_region|iso_year|total_influenza_positives|total_specimens_tested|avg_positivity_rate|avg_inf_all_per_100k|
+----------+--------+-------------------------+----------------------+-------------------+--------------------+
|       AFR|    1996|                        9|                   186| 0.1741593567251462|0.024815225828481127|
|       AFR|    1997|                      186|                   980| 0.4706595971722022|0.021692869884744732|
|       AFR|    1998|                      390|                   494| 0.3161726508785332| 0.07799111605760091|
|       AFR|    1999|                      551|                  3094|0.26828093738554776| 0.09683751029029794|
|       AFR|    2000|                      154|                  1488|0.17248935974509344| 0.05553801358099915|
|       AFR|    2001|                      157|                  1366|0.21671782537183099| 0.06094444512

### Additionally, we created a CSV output for downstream analysis

In [24]:
csv_output_path = "data/processed/flunet_enriched_csv"

(
    df_gold_enriched
    .coalesce(1)  # put all data into 1 partition so we get a single CSV part file
    .write
    .mode("overwrite")
    .option("header", True)
    .csv(csv_output_path)
)

print(f"Wrote CSV export of gold dataset to: {csv_output_path}")
print("Note: Spark writes a folder with one part-*.csv file inside it.")

Wrote CSV export of gold dataset to: data/processed/flunet_enriched_csv
Note: Spark writes a folder with one part-*.csv file inside it.
