In [None]:
# Run this cell to import pyspark and to define start_spark() and stop_spark()

import findspark

findspark.init()

import getpass
import pandas
import pyspark
import random
import re

from IPython.display import display, HTML
from pyspark import SparkContext
from pyspark.sql import SparkSession


# Constants used to interact with Azure Blob Storage using the hdfs command or Spark

import os

# Login name with any "@domain" suffix stripped; used in the Spark app name and in the
# per-user output paths under the user container.
username = re.sub('@.*', '', getpass.getuser())

# Azure Blob Storage account plus the container the input data is read from and the
# container per-user outputs are written to.
azure_account_name = "madsstorage002"
azure_data_container_name = "campus-data"
azure_user_container_name = "campus-user"

# SAS token used for both containers. Prefer supplying it through the AZURE_USER_TOKEN
# environment variable; the embedded value is only a fallback so existing runs keep working.
# SECURITY: a credential committed in plain text is exposed to anyone who can read this
# notebook — it should be rotated and removed from source.
azure_user_token = os.environ.get("AZURE_USER_TOKEN") or r"sp=racwdl&st=2025-08-01T09:41:33Z&se=2026-12-30T16:56:33Z&spr=https&sv=2024-11-04&sr=c&sig=GzR1hq7EJ0lRHj92oDO1MBNjkc602nrpfB5H8Cl7FFY%3D"


# Functions used below

def dict_to_html(d):
    """Render a mapping as a full-width, monospace, two-column HTML table.

    Each key/value pair becomes one row, with the key left-aligned in the first cell
    and the value in the second. Keys and values are interpolated as-is (no HTML
    escaping), so callers may pass pre-formatted HTML.

    Returns:
        The table markup as a single string.
    """
    rows = ''.join(
        f'<tr><td style="text-align:left;">{key}</td><td>{value}</td></tr>'
        for key, value in d.items()
    )
    return (
        '<table width="100%" style="width:100%; font-family: monospace;">'
        + rows
        + '</table>'
    )


def start_spark(executor_memory='4g', executor_cores=4, dynamic_allocation=True,
                max_executors=20, min_executors=1, initial_executors=2):
    """Create (or reuse) the notebook's SparkSession with the given resources.

    Args:
        executor_memory: Value for ``spark.executor.memory`` (e.g. '4g').
        executor_cores: Value for ``spark.executor.cores``.
        dynamic_allocation: When true, enable dynamic allocation and apply the
            executor bounds below; when false those settings are not touched.
        max_executors, min_executors, initial_executors: Dynamic allocation bounds.

    Returns:
        A ``(spark, sc)`` tuple: the SparkSession and its SparkContext.

    Note: ``getOrCreate()`` hands back an already-running session if there is one,
    in which case this configuration may not take effect.
    """
    settings = [
        ('spark.app.name', f'{username}-notebook'),
        ('spark.executor.memory', executor_memory),
        ('spark.executor.cores', executor_cores),
        ('spark.sql.adaptive.enabled', 'true'),
        ('spark.sql.adaptive.coalescePartitions.enabled', 'true'),
        ('spark.sql.execution.arrow.pyspark.enabled', 'true'),
    ]

    # Both the data container and the user container authenticate with the same SAS token.
    for container in (azure_data_container_name, azure_user_container_name):
        settings.append(
            (f'fs.azure.sas.{container}.{azure_account_name}.blob.core.windows.net', azure_user_token)
        )

    if dynamic_allocation:
        settings += [
            ('spark.dynamicAllocation.enabled', 'true'),
            ('spark.dynamicAllocation.maxExecutors', str(max_executors)),
            ('spark.dynamicAllocation.minExecutors', str(min_executors)),
            ('spark.dynamicAllocation.initialExecutors', str(initial_executors)),
        ]

    config = pyspark.SparkConf().setAll(settings)

    spark = SparkSession.builder.config(conf=config).getOrCreate()
    sc = spark.sparkContext

    for line in (f"Spark session started for user: {username}",
                 f"Spark version: {spark.version}",
                 f"Spark UI: {sc.uiWebUrl}"):
        print(line)

    return spark, sc


def stop_spark():
    """Stop the active SparkSession, if any, reporting the outcome on stdout.

    This is best-effort: any exception raised while looking up or stopping the
    session is printed rather than propagated.
    """
    try:
        active = SparkSession.getActiveSession()
        if not active:
            print("No active Spark session found")
            return
        active.stop()
        print("Spark session stopped successfully")
    except Exception as exc:
        print(f"Error stopping Spark session: {exc}")

In [None]:
# Start the session with the default configuration; `spark` is used by the analysis cells below.
spark, sc = start_spark()

In [None]:
import time
import os
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
import geopandas as gpd
import pandas as pd
import numpy as np
from math import radians, sin, cos, sqrt, atan2

from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.functions import udf

# Wall-clock start time; the final summary cell reports total runtime relative to this.
notebook_run_time = time.time()

def bprint(s):
    """Print ``s`` centred in a 60-character underscore banner, framed by one extra underscore each side."""
    banner = format(s, "_^60")
    print("_" + banner + "_")

def show_as_html(df, limit=20):
    """Render the first ``limit`` rows of a Spark DataFrame as an HTML table in the notebook.

    Only ``limit`` rows are collected to the driver (via pandas) before rendering.
    """
    html_table = df.limit(limit).toPandas().to_html()
    display(HTML(html_table))

def show_df(df, name="DataFrame", limit=5):
    """Print a DataFrame's schema followed by up to ``limit`` untruncated sample rows.

    Args:
        df: DataFrame exposing ``printSchema()`` and ``show(n, truncate=...)``.
        name: Label used in the ``[diag]`` header lines.
        limit: Number of rows passed to ``show``.
    """
    prefix = f"[diag] {name}"
    print(f"{prefix} schema:")
    df.printSchema()
    print(f"{prefix} sample:")
    df.show(limit, truncate=False)

## Q1

### a)

In [None]:
bprint("Analysis Q1(a)1 - Load enriched stations")

# Build the path from the shared Azure constants rather than repeating the account/container literals.
enriched_write_name = (
    f"wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/"
    f"{username}/enriched_stations.parquet"
)
# Cached because nearly every later cell filters or aggregates this station table.
enriched = spark.read.parquet(enriched_write_name).cache()

show_as_html(enriched)

In [None]:
bprint("Analysis Q1(a)2 - Total and active stations")

# A station counts as active when its last recorded year is 2025 or later.
total_stations = enriched.count()
active_stations = enriched.filter(F.col("station_last") >= 2025)
active_count = active_stations.count()

print(f"Total stations: {total_stations:,}")
print(f"Active stations in 2025: {active_count:,}")

In [None]:
bprint("Analysis Q1(a)3 - Network counts")

# One pass over the stations instead of three separate count() jobs.
# F.count ignores nulls, so count(when(cond, 1)) counts the rows where cond is true and
# yields 0 (not null) when nothing matches.
network_counts = enriched.agg(
    F.count(F.when(F.col("gsn_flag") == "GSN", 1)).alias("gsn"),
    F.count(F.when(F.col("hcn_crn_flag") == "HCN", 1)).alias("hcn"),
    F.count(F.when(F.col("hcn_crn_flag") == "CRN", 1)).alias("crn"),
).first()

gsn = network_counts["gsn"]
hcn = network_counts["hcn"]
crn = network_counts["crn"]

print(f"GSN stations: {gsn:,}")
print(f"HCN stations: {hcn:,}")
print(f"CRN stations: {crn:,}")

In [None]:
bprint("Analysis Q1(a)4 - Network overlaps")

# Both overlaps with GSN in a single pass (count ignores the nulls produced by when()).
is_gsn = F.col("gsn_flag") == "GSN"
overlap_counts = enriched.agg(
    F.count(F.when(is_gsn & (F.col("hcn_crn_flag") == "HCN"), 1)).alias("gsn_hcn"),
    F.count(F.when(is_gsn & (F.col("hcn_crn_flag") == "CRN"), 1)).alias("gsn_crn"),
).first()

gsn_hcn = overlap_counts["gsn_hcn"]
gsn_crn = overlap_counts["gsn_crn"]
print(f"GSN ∩ HCN stations: {gsn_hcn:,}")
print(f"GSN ∩ CRN stations: {gsn_crn:,}")

# HCN and CRN membership are both encoded in the single hcn_crn_flag column, so a station
# can carry only one of the two values. A filter requiring flag == "HCN" AND flag == "CRN"
# can never match; the intersection is empty by construction and needs no Spark job.
hcn_crn = 0
print(f"HCN ∩ CRN stations: {hcn_crn:,}")

In [None]:
bprint("Analysis Q1(a)5 - Network overlap visualization")

os.makedirs("figures", exist_ok=True)

# "All three" requires HCN and CRN together, which the single hcn_crn_flag column cannot
# express, so it is 0 by construction.
counts = dict([
    ("GSN", gsn),
    ("HCN", hcn),
    ("CRN", crn),
    ("GSN ∩ HCN", gsn_hcn),
    ("GSN ∩ CRN", gsn_crn),
    ("HCN ∩ CRN", hcn_crn),
    ("All three", 0),
])

labels, values = list(counts), list(counts.values())

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(labels, values, color='orange')
ax.set_ylabel("Number of stations")
ax.set_title("Station membership and overlaps (GSN, HCN, CRN)")
for tick_label in ax.get_xticklabels():
    tick_label.set_rotation(0)
    tick_label.set_horizontalalignment("center")

# Place each count 1% of the tallest bar above its bar.
label_offset = max(values) * 0.01
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2, height + label_offset,
            f'{int(height):,}', ha='center', va='bottom')

fig.tight_layout()
fig.savefig("figures/network_overlaps.png", dpi=300, bbox_inches="tight")
plt.show()

### b)

In [None]:
bprint("Analysis Q1(b)1 - Southern Hemisphere stations")

# Southern Hemisphere = strictly negative latitude (stations on the equator are excluded).
in_southern_hemisphere = F.col("latitude") < 0
southern_stations = enriched.filter(in_southern_hemisphere).count()
print(f"Southern Hemisphere stations: {southern_stations:,}")

In [None]:
bprint("Analysis Q1(b)2 - US territories")

# Territories: stations whose country name mentions the United States but whose
# country code is not the mainland 'US' code.
names_united_states = F.col("country_name").contains("United States")
not_mainland_code = F.col('country_code') != 'US'
us_territories = enriched.filter(names_united_states & not_mainland_code)

us_territories_count = us_territories.count()
print(f"US territories stations: {us_territories_count:,}")

# Only print sample rows when there is something to show.
if us_territories_count:
    us_territories.show(10)

In [None]:
bprint("Analysis Q1(b)3 - Hemispheric analysis visualization")


def _annotated_bar_chart(labels, values, color, title, path, figsize=(10, 5)):
    """Draw a station-count bar chart with comma-formatted value labels, save it to `path`, and show it."""
    # Create the output directory here so the cell does not depend on an earlier cell having done it.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    fig, ax = plt.subplots(figsize=figsize)
    bars = ax.bar(labels, values, color=color)
    ax.set_ylabel("Number of stations")
    ax.set_title(title)
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(0)
        tick_label.set_horizontalalignment("center")

    # Lift each label 1% of the tallest bar above its bar.
    offset = max(values) * 0.01
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2, height + offset,
                f'{int(height):,}', ha='center', va='bottom')

    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.show()


north_hem = enriched.filter(F.col("latitude") >= 0).count()
south_hem = southern_stations
# Stations whose record spans at least 100 years between first and last recorded year.
century_global = enriched.filter((F.col("station_last") - F.col("station_first")) >= 100).count()
nz_stations = enriched.filter(F.col("country_code") == "NZ").count()
us_territories_south = us_territories.filter(F.col("latitude") < 0).count()

labels1 = ["Southern\nHemisphere", "US territories\n(south of equator)", "≥100 years (global)"]
values1 = [south_hem, us_territories_south, century_global]
_annotated_bar_chart(labels1, values1, 'skyblue',
                     "Coverage (southern focus)", "figures/southern_focus.png")

labels2 = ["Northern\nHemisphere", "Southern\nHemisphere", "New Zealand\n(all)"]
values2 = [north_hem, south_hem, nz_stations]
_annotated_bar_chart(labels2, values2, 'lightcoral',
                     "Coverage by hemisphere with New Zealand", "figures/hemispheric_comparison.png")

### c)

In [None]:
bprint("Analysis Q1(c)1 - Country station counts")

# Stations per country, largest first.
country_counts = (
    enriched
    .groupBy("country_code", "country_name")
    .agg(F.count("*").alias("station_count"))
    .orderBy(F.desc("station_count"))
)

print("Top 10 countries by station count:")
show_as_html(country_counts, 10)

# Build the output path from the shared Azure constants instead of repeating the literals.
country_counts_write = (
    f"wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/"
    f"{username}/country_station_counts.parquet"
)
country_counts.write.mode("overwrite").parquet(country_counts_write)
print(f"[info] Country counts saved to: {country_counts_write}")

In [None]:
bprint("Analysis Q1(c)2 - Core elements coverage")

core_elements = ['PRCP', 'SNOW', 'SNWD', 'TMAX', 'TMIN']

# Number of *core* elements each station reports. size(elements) alone would count any element
# (so ">= 5" would match stations with five arbitrary elements). array_intersect de-duplicates,
# so a value of len(core_elements) means every core element is present.
# NOTE(review): assumes `elements` is an array<string> of element codes — confirm against the schema.
core_present = F.size(
    F.array_intersect(F.col("elements"), F.array(*[F.lit(e) for e in core_elements]))
)

core_coverage = enriched.agg(
    F.count(F.when(core_present >= 1, 1)).alias("any_core"),
    F.count(F.when(core_present == len(core_elements), 1)).alias("all_core"),
).first()
stations_with_core = core_coverage["any_core"]
stations_with_all5 = core_coverage["all_core"]

print(f"Stations with ≥1 core element: {stations_with_core:,}")
print(f"Stations with all 5 core elements: {stations_with_all5:,}")

labels = ['≥1 core element', 'All 5 core elements']
values = [stations_with_core, stations_with_all5]

plt.figure(figsize=(8, 5))
bars = plt.bar(labels, values, color='lightgreen')
plt.ylabel("Number of stations")
plt.title("Core element coverage")
plt.xticks(rotation=0, ha="center")

for b in bars:
    height = b.get_height()
    plt.text(b.get_x() + b.get_width()/2, height + max(values)*0.01,
             f'{int(height):,}', ha='center', va='bottom')

plt.tight_layout()
plt.savefig("figures/core_elements_coverage.png", dpi=300, bbox_inches="tight")
plt.show()

## Q2

### a)

In [None]:
# This cell defines the great-circle distance helper and wraps it for use on Spark columns.
bprint("Analysis Q2(a)1 - Haversine distance function")

def haversine_distance(lat1, lon1, lat2, lon2, radius_km=6371.0):
    """Calculate the great-circle distance between two points on a sphere.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in decimal degrees.
        lat2, lon2: Latitude and longitude of the second point, in decimal degrees.
        radius_km: Sphere radius. Defaults to the mean Earth radius in kilometres;
            the result is in the same unit as this value.

    Returns:
        The distance as a float, or None if any coordinate is None (so the Spark UDF
        yields null for rows with missing coordinates).
    """
    if any(x is None for x in (lat1, lon1, lat2, lon2)):
        return None

    phi1, phi2 = radians(lat1), radians(lat2)
    dphi = phi2 - phi1
    dlambda = radians(lon2) - radians(lon1)

    a = sin(dphi / 2) ** 2 + cos(phi1) * cos(phi2) * sin(dlambda / 2) ** 2
    # Floating-point rounding can push `a` fractionally outside [0, 1] for (near-)antipodal
    # points, which would make sqrt(1 - a) raise "math domain error"; clamp it into range.
    a = min(1.0, max(0.0, a))
    c = 2 * atan2(sqrt(a), sqrt(1 - a))

    return radius_km * c

# Wrap the pure-Python function as a Spark UDF returning DoubleType so it can be applied to
# DataFrame columns. It is only wrapped for DataFrame use, not registered for SQL queries
# (that would need spark.udf.register).
haversine_udf = udf(haversine_distance, DoubleType())
print("[info] Haversine distance UDF registered")

In [None]:
bprint("Analysis Q2(a)2 - Distance function test")

test_stations = enriched.limit(5)

# Self cross join; keeping only a.id < b.id gives each unordered pair once and drops self-pairs.
cross_df = (
    test_stations.alias("a")
    .crossJoin(test_stations.alias("b"))
    .filter(F.col("a.id") < F.col("b.id"))
)

result_df = cross_df.withColumn(
    'distance_km',
    haversine_udf(F.col("a.latitude"), F.col("a.longitude"),
                  F.col("b.latitude"), F.col("b.longitude"))
)

# Give each side distinct output names: selecting "a.x" and "b.x" unaliased produces two
# columns both named x, which makes the displayed table ambiguous.
result_df = result_df.select(
    F.col("a.country_code").alias("country_code_a"),
    F.col("a.station_name").alias("name_a"),
    F.col("a.latitude").alias("latitude_a"),
    F.col("a.longitude").alias("longitude_a"),
    F.col("b.country_code").alias("country_code_b"),
    F.col("b.station_name").alias("name_b"),
    F.col("b.latitude").alias("latitude_b"),
    F.col("b.longitude").alias("longitude_b"),
    "distance_km",
)

print("Distance calculation test results:")
show_as_html(result_df)

### b)

In [None]:
bprint("Analysis Q2(b)1 - New Zealand pairwise distances")

nz_stations_df = enriched.filter(F.col("country_code") == "NZ")
nz_count = nz_stations_df.count()
print(f"New Zealand stations: {nz_count}")

if nz_count > 0:
    # Unique unordered station pairs (a.id < b.id excludes self-pairs and mirrored duplicates).
    nz_cross = (
        nz_stations_df.alias("a")
        .crossJoin(nz_stations_df.alias("b"))
        .filter(F.col("a.id") < F.col("b.id"))
    )

    nz_distances = nz_cross.withColumn(
        'distance_km',
        haversine_udf(F.col("a.latitude"), F.col("a.longitude"),
                      F.col("b.latitude"), F.col("b.longitude"))
    )

    # Cached because the result feeds count(), two displays and the parquet write; without it
    # the cross join and the Python UDF would be recomputed for every action.
    nz_result = nz_distances.select(
        F.col("a.id").alias("station_a"),
        F.col("a.station_name").alias("name_a"),
        F.col("b.id").alias("station_b"),
        F.col("b.station_name").alias("name_b"),
        "distance_km"
    ).orderBy("distance_km").cache()

    pairs_count = nz_result.count()
    print(f"Total station pairs: {pairs_count}")

    print("\nClosest station pairs:")
    show_as_html(nz_result, 5)

    print("\nFarthest station pairs:")
    show_as_html(nz_result.orderBy(F.desc("distance_km")), 5)

    # Build the output path from the shared Azure constants instead of repeating the literals.
    nz_distances_write = (
        f"wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/"
        f"{username}/nz_station_distances.parquet"
    )
    nz_result.write.mode("overwrite").parquet(nz_distances_write)
    print(f"\n[info] NZ distances saved to: {nz_distances_write}")
else:
    print("[warning] No New Zealand stations found in dataset")

In [None]:
bprint("Analysis Q2(b)2 - Precipitation analysis preparation")

# Explicit schema for the headerless daily CSV files (avoids a schema-inference pass).
daily_schema = StructType([
    StructField("id", StringType(), True),
    StructField("date", StringType(), True),
    StructField("element", StringType(), True),
    StructField("value", IntegerType(), True),
    StructField("measurement_flag", StringType(), True),
    StructField("quality_flag", StringType(), True),
    StructField("source_flag", StringType(), True),
    StructField("observation_time", StringType(), True),
])

# Input path assembled from the shared Azure constants (data container) instead of literals.
daily_path = (
    f"wasbs://{azure_data_container_name}@{azure_account_name}.blob.core.windows.net/"
    "ghcnd/daily/*.csv.gz"
)

daily_df = (
    spark.read.format("csv")
    .option("header", "false")
    .option("sep", ",")
    .schema(daily_schema)
    .load(daily_path)
)

print("[info] Daily data loaded")
# limit(1000).count() is a cheap smoke test that the files are readable; it reports at most 1000.
print(f"[info] Daily data sample count: {daily_df.limit(1000).count()}")

In [None]:
bprint("Analysis Q2(b)3 - Precipitation by country and year")

# PRCP rows with the country prefix of the station id, the observation year, and the raw
# value divided by 10 to give millimetres; negative values are discarded.
prcp_df = (
    daily_df
    .filter(F.col("element") == "PRCP")
    .withColumn("country_code", F.substring("id", 1, 2))
    .withColumn("year", F.year(F.to_date("date", "yyyyMMdd")))
    .withColumn("value_mm", F.col("value") / 10.0)
    .filter(F.col("value_mm") >= 0)
)

# Cached: this aggregate is displayed, written, and filtered again by the choropleth cell,
# and each of those actions would otherwise rescan the entire daily dataset.
prcp_agg = (
    prcp_df
    .groupBy("year", "country_code")
    .agg(F.avg("value_mm").alias("avg_daily_rainfall"))
    .orderBy("year", "country_code")
    .cache()
)

print("Sample of precipitation aggregation:")
show_as_html(prcp_agg, 10)

# Build the output path from the shared Azure constants instead of repeating the literals.
prcp_agg_write = (
    f"wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/"
    f"{username}/precipitation_by_country_year.parquet"
)
prcp_agg.write.mode("overwrite").parquet(prcp_agg_write)
print(f"\n[info] Precipitation data saved to: {prcp_agg_write}")

In [None]:
bprint("Analysis Q2(b)4 - Global rainfall 2024 choropleth map")

# Collect the 2024 rows of the per-country averages (one row per country code) to pandas for plotting.
prcp_2024 = prcp_agg.filter(F.col("year") == 2024)
prcp_2024_pandas = prcp_2024.toPandas()

print(f"Countries with 2024 rainfall data: {len(prcp_2024_pandas)}")

if len(prcp_2024_pandas) > 0:
    # Country polygons are downloaded at run time, so this cell needs network access.
    # NOTE(review): the merge below assumes this GeoJSON has an `ISO_A2` property — confirm
    # against the current upstream file, whose schema is outside this notebook's control.
    world_url = "https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson"
    world = gpd.read_file(world_url)
    
    # Force the rainfall column to numeric and drop rows that cannot be converted, so they are
    # excluded from both the map and the statistics below.
    prcp_2024_pandas['avg_daily_rainfall'] = pd.to_numeric(prcp_2024_pandas['avg_daily_rainfall'], errors='coerce')
    prcp_2024_pandas = prcp_2024_pandas.dropna(subset=['avg_daily_rainfall'])
    
    # how='left' keeps every country polygon; polygons without a matching rainfall row get NaN
    # and are drawn in light grey via missing_kwds.
    # NOTE(review): `country_code` is the first two characters of the GHCN station id. GHCN
    # documents these as FIPS 10-4 codes, which differ from ISO 3166-1 alpha-2 for many
    # countries (e.g. UK vs GB; FIPS "AS" is Australia but ISO "AS" is American Samoa), so
    # joining them to ISO_A2 will mis-assign or drop those countries. TODO: map FIPS -> ISO
    # before merging.
    world_merged = world.merge(prcp_2024_pandas, left_on='ISO_A2', right_on='country_code', how='left')
    
    fig, ax = plt.subplots(figsize=(15, 8))
    world_merged.plot(
        column='avg_daily_rainfall', 
        ax=ax, 
        legend=True, 
        cmap='Blues',
        missing_kwds={'color': 'lightgrey'},
        legend_kwds={'label': "Average Daily Rainfall (mm)", 'orientation': "vertical"}
    )
    
    plt.title('Global Average Daily Rainfall (2024)', fontsize=16, pad=20)
    plt.axis('off')
    plt.tight_layout()
    
    plt.savefig('figures/global_rainfall_2024.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Statistics below are over the country rows that survived dropna, not over map polygons.
    print(f"\nRainfall statistics for 2024:")
    print(f"Countries with data: {len(prcp_2024_pandas):,}")
    print(f"Average rainfall range: {prcp_2024_pandas['avg_daily_rainfall'].min():.2f} - {prcp_2024_pandas['avg_daily_rainfall'].max():.2f} mm")
    print(f"Global mean: {prcp_2024_pandas['avg_daily_rainfall'].mean():.2f} mm")
    
    print("\nTop 10 wettest countries in 2024:")
    top_wet = prcp_2024_pandas.nlargest(10, 'avg_daily_rainfall')[['country_code', 'avg_daily_rainfall']]
    for _, row in top_wet.iterrows():
        print(f"  {row['country_code']}: {row['avg_daily_rainfall']:.2f} mm")
        
else:
    print("[warning] No 2024 precipitation data available for choropleth map")

## Q3

### a)

In [None]:
bprint("Analysis Q3(a)1 - Daily row count")

# Time a full count() of daily_df; cell_time ends up holding the elapsed seconds.
started_at = time.time()
total_daily_rows = daily_df.count()
print(f"Total rows in daily dataset: {total_daily_rows:,}")

cell_time = time.time() - started_at
print(f"[time] Count operation took: {cell_time:.2f} seconds")

### b)

In [None]:
bprint("Analysis Q3(b)1 - Core elements observation counts")

core_elements = ['PRCP', 'SNOW', 'SNWD', 'TMAX', 'TMIN']

# Seeded so the 0.1% sample, and therefore these counts, is reproducible across runs.
SAMPLE_FRACTION = 0.001
SAMPLE_SEED = 42
daily_subset = daily_df.sample(fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED)

core_element_counts = (
    daily_subset
    .filter(F.col("element").isin(core_elements))
    .groupBy("element")
    .agg(F.count("*").alias("observation_count"))
    .orderBy(F.desc("observation_count"))
)

# Collect once: the result has at most len(core_elements) rows, and every Spark action on
# core_element_counts would otherwise rescan the whole daily dataset.
core_counts_pandas = core_element_counts.toPandas()

print("Core elements observation counts (from sample):")
display(HTML(core_counts_pandas.to_html()))

if len(core_counts_pandas) > 0:
    # Fixed colour per element so bars keep their colour regardless of count ordering.
    element_colors = {'PRCP': 'blue', 'SNOW': 'lightblue', 'SNWD': 'cyan',
                      'TMAX': 'red', 'TMIN': 'orange'}
    bar_colors = [element_colors.get(e, 'grey') for e in core_counts_pandas['element']]

    plt.figure(figsize=(10, 6))
    bars = plt.bar(core_counts_pandas['element'], core_counts_pandas['observation_count'],
                   color=bar_colors)
    plt.ylabel("Observation count (sample)")
    plt.title("Core Elements Observation Counts")
    plt.xticks(rotation=0)

    label_offset = max(core_counts_pandas['observation_count']) * 0.01
    for b in bars:
        height = b.get_height()
        plt.text(b.get_x() + b.get_width()/2, height + label_offset,
                 f'{int(height):,}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig("figures/core_elements_counts.png", dpi=300, bbox_inches="tight")
    plt.show()

### c)

In [None]:
bprint("Analysis Q3(c)1 - TMAX without TMIN analysis")

# Sample by (id, date) key, not by row. An independent 1% row sample keeps a TMAX row but
# almost always drops its matching TMIN row, so nearly every sampled TMAX would look unpaired.
# Keeping whole station-days (all their elements) removes that bias, and hashing the key makes
# the sample deterministic. hash % 100 == 0 keeps roughly 1 in 100 keys (Spark's % can return
# a negative remainder, but a zero remainder is zero either way).
SAMPLE_BUCKETS = 100
daily_sample = daily_df.filter((F.hash("id", "date") % SAMPLE_BUCKETS) == 0)

# Only TMAX/TMIN rows are needed; cached because several actions below reuse them.
temp_sample = (
    daily_sample
    .filter(F.col("element").isin("TMAX", "TMIN"))
    .select("id", "date", "element")
    .cache()
)
tmax_obs = temp_sample.filter(F.col("element") == "TMAX")
tmin_obs = temp_sample.filter(F.col("element") == "TMIN")

tmax_count = tmax_obs.count()
print(f"TMAX observations in sample: {tmax_count:,}")
print(f"TMIN observations in sample: {tmin_obs.count():,}")

# left_anti keeps TMAX rows that have no TMIN row for the same station and date.
tmax_without_tmin = tmax_obs.join(tmin_obs, on=["id", "date"], how="left_anti")

tmax_no_tmin_count = tmax_without_tmin.count()
unique_stations = tmax_without_tmin.select("id").distinct().count()

print(f"\nTMAX observations without corresponding TMIN: {tmax_no_tmin_count:,}")
print(f"Unique stations contributing: {unique_stations:,}")

if tmax_no_tmin_count > 0:
    # tmax_count > 0 here, since every unpaired TMAX is itself a TMAX observation.
    percentage = (tmax_no_tmin_count / tmax_count) * 100
    print(f"Percentage of TMAX without TMIN: {percentage:.2f}%")

    print("\nSample stations with TMAX but no TMIN:")
    station_sample = tmax_without_tmin.select("id").distinct().limit(10)
    show_as_html(station_sample)

In [None]:
bprint("Additional Analysis 1 - NZ temperature time series")

if nz_count > 0:
    nz_daily = daily_df.join(nz_stations_df.select("id"), on="id", how="inner")
    nz_temp = (
        nz_daily
        .filter(F.col("element").isin("TMIN", "TMAX"))
        .select("id", "date", "element", "value")
    )

    # Seeded 10% sample so the plotted series is reproducible from run to run.
    nz_temp_sample = nz_temp.sample(fraction=0.1, seed=42)
    nz_temp_pandas = nz_temp_sample.toPandas()

    if len(nz_temp_pandas) > 0:
        nz_temp_pandas['date'] = pd.to_datetime(nz_temp_pandas['date'], format='%Y%m%d')
        # Raw values are divided by 10 to give degrees Celsius.
        nz_temp_pandas['temp_c'] = nz_temp_pandas['value'] / 10.0

        # One row per (station, day) with one column per element.
        nz_pivot = nz_temp_pandas.pivot_table(
            index=['id', 'date'],
            columns='element',
            values='temp_c'
        ).reset_index()

        # A sample may lack one element entirely (pivot_table also drops all-NaN columns);
        # add it as all-NaN so selecting both columns below cannot raise KeyError.
        for element in ('TMIN', 'TMAX'):
            if element not in nz_pivot.columns:
                nz_pivot[element] = np.nan

        nz_pivot['year_month'] = nz_pivot['date'].dt.to_period('M')
        monthly_avg = nz_pivot.groupby('year_month')[['TMIN', 'TMAX']].mean().reset_index()
        monthly_avg['year_month'] = monthly_avg['year_month'].dt.to_timestamp()

        plt.figure(figsize=(12, 6))
        plt.plot(monthly_avg['year_month'], monthly_avg['TMIN'], label='Avg TMIN', color='blue', alpha=0.7)
        plt.plot(monthly_avg['year_month'], monthly_avg['TMAX'], label='Avg TMAX', color='red', alpha=0.7)
        plt.title("Average TMIN and TMAX in New Zealand")
        plt.xlabel("Date")
        plt.ylabel("Temperature (°C)")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig("figures/nz_temperature_timeseries.png", dpi=300, bbox_inches="tight")
        plt.show()

        print(f"[info] Processed {len(nz_temp_pandas):,} NZ temperature observations")
    else:
        print("[warning] No NZ temperature data in sample")
else:
    print("[warning] No New Zealand stations available")

In [None]:
bprint("Final Summary - Analysis completion")

# Summary text, one entry per printed line (empty strings produce blank lines).
rule = "=" * 60
summary_lines = [
    rule,
    "ANALYSIS COMPLETION SUMMARY",
    rule,
    "",
    "✅ Q1 Analysis:",
    f"   • Total stations: {total_stations:,}",
    f"   • Active in 2025: {active_count:,}",
    f"   • Southern Hemisphere: {southern_stations:,}",
    "   • Network analysis complete",
    "   • Hemispheric visualizations complete",
    "",
    "✅ Q2 Analysis:",
    "   • Haversine distance function implemented",
    f"   • NZ station analysis: {nz_count} stations",
    "   • Precipitation analysis complete",
    "   • Global rainfall 2024 choropleth map created",
    "",
    "✅ Q3 Analysis:",
    f"   • Daily dataset rows: {total_daily_rows:,}",
    "   • Core elements analysis complete",
    "   • TMAX/TMIN correspondence analysis complete",
    "",
    "📊 Visualizations Created:",
    "   • Network overlaps chart",
    "   • Hemispheric analysis charts",
    "   • Core elements coverage chart",
    "   • Global rainfall 2024 choropleth map",
    "   • Core elements observation counts",
    "   • NZ temperature time series",
    "",
]
for line in summary_lines:
    print(line)

# Elapsed wall-clock time since the helper cell recorded notebook_run_time.
total_runtime = time.time() - notebook_run_time
print(f"📈 Total Analysis Runtime: {total_runtime/60:.2f} minutes")
print(rule)

In [None]:
bprint("Stop Spark session")

# Stop the active Spark session now that the analysis cells have finished.
stop_spark()