# DATA420 A1 — Analysis + Visualisations (additions only)
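These cells can be appended to the existing **Analysis.ipynb** notebook; the visualisation cells at the end can go in (or also be copied to) the **Visualizations** notebook. Run them top to bottom in a fresh kernel: later cells reuse `spark`, `F`, `enriched_stations`, `daily_all` and `haversine_udf` defined by earlier cells.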

In [None]:
# Setup Spark and load enriched stations
from pyspark.sql import SparkSession, functions as F

print("Starting SparkSession for Analysis additions...")
spark = SparkSession.builder.getOrCreate()
print("Spark version:", spark.version)

# Azure Blob Storage (wasbs://) locations: GHCN-D inputs live in the data
# container; this user's outputs live under dew59/ in the user container.
azure_account_name = "madsstorage002"
azure_data_container_name = "campus-data"
azure_user_container_name = "campus-user"

WASBS_DATA = f"wasbs://{azure_data_container_name}@{azure_account_name}.blob.core.windows.net/ghcnd"
WASBS_DAILY = f"{WASBS_DATA}/daily"
WASBS_USER_BASE = f"wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/dew59"

enriched_path = f"{WASBS_USER_BASE}/enriched_stations.parquet/"
print("Loading enriched stations:", enriched_path)
# The stations table is queried by many later cells (counts, joins, filters);
# caching it avoids re-reading the parquet from blob storage for every action.
# The count() below is the first action and materialises the cache.
enriched_stations = spark.read.parquet(enriched_path).cache()
print("Enriched stations count:", enriched_stations.count())

In [None]:
# Q1(a): Totals and network overlaps
# All figures come from ONE aggregation job over enriched_stations instead of
# one Spark job per .count().
# F.count(F.when(cond, 1)) counts rows where cond is true; rows where cond is
# false or null produce null and are skipped, which matches what
# .filter(cond).count() counts.
is_gsn = F.col("gsn_flag") == "GSN"
is_hcn = F.col("hcn_crn_flag") == "HCN"
is_crn = F.col("hcn_crn_flag") == "CRN"

q1a_counts = enriched_stations.agg(
    F.count("*").alias("total_stations"),
    F.count(F.when(F.col("last_year") >= 2025, 1)).alias("active_2025"),
    F.count(F.when(is_gsn, 1)).alias("gsn"),
    F.count(F.when(is_hcn, 1)).alias("hcn"),
    F.count(F.when(is_crn, 1)).alias("crn"),
    F.count(F.when(is_gsn & is_hcn, 1)).alias("gsn_hcn"),
    F.count(F.when(is_gsn & is_crn, 1)).alias("gsn_crn"),
).first()

total_stations = q1a_counts["total_stations"]
active_2025 = q1a_counts["active_2025"]
gsn = q1a_counts["gsn"]
hcn = q1a_counts["hcn"]
crn = q1a_counts["crn"]
gsn_hcn = q1a_counts["gsn_hcn"]
gsn_crn = q1a_counts["gsn_crn"]
# HCN and CRN membership share the single hcn_crn_flag column, so a station can
# carry at most one of the two values: HCN∩CRN is empty by construction. A
# filter requiring hcn_crn_flag to equal both "HCN" and "CRN" can never match.
hcn_crn = 0

print("Total stations:", total_stations)
print("Active in 2025:", active_2025)
print("GSN:", gsn, "HCN:", hcn, "CRN:", crn)
print("GSN∩HCN:", gsn_hcn, "GSN∩CRN:", gsn_crn, "HCN∩CRN:", hcn_crn)

In [None]:
# Q1(b): Southern Hemisphere and U.S. territories outside 'US'
# Southern Hemisphere means strictly negative latitude (equator rows excluded).
is_southern = F.col("latitude") < 0
south_hemi = enriched_stations.where(is_southern).count()

# Stations whose country name mentions "United States" but whose code is not
# the plain "US" code. NOTE(review): presumably these are territory entries in
# the countries table; confirm the naming there.
names_us = F.col("country_name").contains("United States")
code_is_not_us = F.col("country_code") != "US"
us_territory_like = enriched_stations.where(names_us & code_is_not_us).count()

print("Southern Hemisphere stations:", south_hemi)
print("United States territories (excluding US):", us_territory_like)

In [None]:
# Q1(c): Counts per country and per state
# Number of stations per (country_code, country_name).
by_country = (enriched_stations
              .groupBy("country_code", "country_name")
              .agg(F.count("*").alias("station_count")))

# Number of stations per (state, state_name), restricted to stations that
# actually have a state code. Blank/whitespace codes are excluded as well as
# nulls: depending on how the stations file was parsed, stations without a
# state may hold "" rather than null, which would otherwise form a spurious
# blank-state group.
has_state = F.col("state").isNotNull() & (F.trim(F.col("state")) != "")
by_state = (enriched_stations
            .filter(has_state)
            .groupBy("state", "state_name")
            .agg(F.count("*").alias("station_count")))

out_countries = f"{WASBS_USER_BASE}/enriched_countries.parquet/"
out_states = f"{WASBS_USER_BASE}/enriched_states.parquet/"
print("Writing:", out_countries)
by_country.write.mode("overwrite").parquet(out_countries)
print("Writing:", out_states)
by_state.write.mode("overwrite").parquet(out_states)

print("Sample countries:")
by_country.orderBy(F.desc("station_count")).show(10, truncate=False)

In [None]:
# Q2(a): Haversine UDF and small CROSS JOIN demo
import math
from pyspark.sql import types as T

def haversine_km(lat1, lon1, lat2, lon2, radius_km=6371.0):
    """Great-circle distance between two points on a sphere, in kilometres.

    Uses the haversine formula on a sphere of radius ``radius_km`` (default
    6371.0 km, the mean Earth radius).

    Args:
        lat1, lon1: latitude/longitude of the first point, in decimal degrees.
        lat2, lon2: latitude/longitude of the second point, in decimal degrees.
        radius_km: sphere radius in kilometres.

    Returns:
        The distance in kilometres, or None if any coordinate is None. Returning
        None lets a Spark UDF wrapper emit null for rows with missing
        coordinates instead of failing the task with a TypeError from
        math.radians(None).
    """
    if lat1 is None or lon1 is None or lat2 is None or lon2 is None:
        return None
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlambda = math.radians(lon2 - lon1)
    a = (math.sin(dphi / 2.0) ** 2
         + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2.0) ** 2)
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points; math.sqrt(1 - a) would then raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2.0 * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))
    return radius_km * c

# Wrap the pure-Python distance function so Spark can apply it per row; the
# resulting column is a double holding kilometres.
haversine_udf = F.udf(haversine_km, T.DoubleType())

# Small demo sample. Ordering before limit() makes the sample deterministic:
# an un-ordered limit may return different rows on each evaluation, and `demo`
# may be evaluated independently for each side of the cross join below.
demo = (enriched_stations
        .select("id","name","latitude","longitude")
        .orderBy("id")
        .limit(10))
pairs = demo.crossJoin(demo.select(
    F.col("id").alias("id2"),
    F.col("name").alias("name2"),
    F.col("latitude").alias("lat2"),
    F.col("longitude").alias("lon2")
)).filter(F.col("id") < F.col("id2"))  # each unordered pair once, no self-pairs

distances_demo = pairs.withColumn("km", haversine_udf("latitude","longitude","lat2","lon2"))
print("Demo pairwise distances (top 10):")
distances_demo.orderBy("km").show(10, truncate=False)

In [None]:
# Q2(b): Pairwise distances across all NZ stations; save and show closest pair
nz = (enriched_stations
      .filter(F.col("country_code") == "NZ")
      .select("id", "name", "latitude", "longitude"))
left, right = nz.alias("a"), nz.alias("b")

# The strict a.id < b.id condition keeps each unordered station pair exactly
# once and drops self-pairs. Output columns, in order:
# id1, name1, lat1, lon1, id2, name2, lat2, lon2.
nz_pairs = (
    left.join(right, F.col("a.id") < F.col("b.id"))
        .select(*[F.col(f"{side}.{src}").alias(f"{dst}{idx}")
                  for side, idx in (("a", 1), ("b", 2))
                  for src, dst in (("id", "id"), ("name", "name"),
                                   ("latitude", "lat"), ("longitude", "lon"))])
)

nz_dist = nz_pairs.withColumn("km", haversine_udf("lat1", "lon1", "lat2", "lon2"))

out_nz_pairs = f"{WASBS_USER_BASE}/nz_station_pairwise_distances.parquet/"
print("Writing:", out_nz_pairs)
nz_dist.write.mode("overwrite").parquet(out_nz_pairs)

# Shortest distance first; the single remaining row is the closest pair.
closest = nz_dist.orderBy("km").limit(1)
print("Closest pair in NZ:")
closest.show(truncate=False)

In [None]:
# Q3(a): Core element observation counts
from pyspark.sql import types as T

# Column layout of the daily CSV files (read without a header row). Every
# field is nullable; only VALUE is parsed as an integer.
_daily_columns = [
    ("ID", T.StringType()),
    ("DATE", T.StringType()),
    ("ELEMENT", T.StringType()),
    ("VALUE", T.IntegerType()),
    ("MFLAG", T.StringType()),
    ("QFLAG", T.StringType()),
    ("SFLAG", T.StringType()),
    ("OBSTIME", T.StringType()),
]
daily_schema = T.StructType(
    [T.StructField(col_name, col_type, True) for col_name, col_type in _daily_columns]
)

daily_glob = f"{WASBS_DAILY}/*.csv.gz"
daily_all = spark.read.csv(daily_glob, schema=daily_schema, header=False, mode="PERMISSIVE")

# Observation totals for the five core elements, largest first.
core = ["PRCP","SNOW","SNWD","TMAX","TMIN"]
is_core = F.col("ELEMENT").isin(core)
core_counts = (
    daily_all
    .where(is_core)
    .groupBy("ELEMENT")
    .agg(F.count("*").alias("obs_count"))
    .orderBy(F.desc("obs_count"))
)

print("Core element counts:")
core_counts.show(truncate=False)

out_core = f"{WASBS_USER_BASE}/core_element_counts.parquet/"
print("Writing:", out_core)
core_counts.write.mode("overwrite").parquet(out_core)

In [None]:
# Q3(b): TMAX observations without corresponding TMIN
# (station, date) keys for each element.
tmax = daily_all.filter(F.col("ELEMENT")=="TMAX").select(F.col("ID").alias("id"), F.col("DATE").alias("date"))
tmin = daily_all.filter(F.col("ELEMENT")=="TMIN").select(F.col("ID").alias("id"), F.col("DATE").alias("date"))

# left_anti keeps TMAX rows that have no TMIN row for the same station and date.
tmax_only = tmax.join(tmin, on=["id","date"], how="left_anti")

# Both figures come from a single Spark job. Two separate actions would each
# rescan the full daily dataset (once per element). countDistinct ignores a
# null id, so rows with a missing station ID are not counted as a "station".
tmax_only_summary = tmax_only.agg(
    F.count("*").alias("n_obs"),
    F.countDistinct("id").alias("n_stations"),
).first()
tmax_only_count = tmax_only_summary["n_obs"]
tmax_only_stations = tmax_only_summary["n_stations"]

print("TMAX without TMIN observations:", tmax_only_count)
print("Distinct stations contributing:", tmax_only_stations)

In [None]:
# Visualisation 1: TMIN and TMAX subplots for NZ stations + national average
import os
import pandas as pd
import matplotlib.pyplot as plt

# Optional cap on the number of NZ stations to plot (None = every NZ station).
# The weekly average that follows is taken over the stations in `pdf`, so a cap
# also narrows that average.
max_nz_stations = None

# Order by id so the station selection (and any cap) is identical on every run;
# an un-ordered limit() is free to return a different subset each time.
nz_ids = (enriched_stations.filter(F.col("country_code")=="NZ")
          .select("id")
          .orderBy("id"))
if max_nz_stations is not None:
    nz_ids = nz_ids.limit(max_nz_stations)
nz_ids_list = [r["id"] for r in nz_ids.collect()]

# Drop observations that failed a quality check (non-blank QFLAG) so flagged
# outliers don't dominate the plots.
passed_qc = F.col("QFLAG").isNull() | (F.trim(F.col("QFLAG")) == "")
nz_daily = (daily_all.filter(F.col("ID").isin(nz_ids_list)
                             & F.col("ELEMENT").isin(["TMIN","TMAX"])
                             & passed_qc)
            .select("ID","DATE","ELEMENT","VALUE"))

pdf = nz_daily.toPandas()
# DATE is a YYYYMMDD string; unparseable dates become NaT and are dropped.
pdf["date"] = pd.to_datetime(pdf["DATE"], format="%Y%m%d", errors="coerce")
pdf = pdf.dropna(subset=["date"])
# VALUE is stored in tenths of a degree Celsius.
pdf["value_c"] = pdf["VALUE"] / 10.0

stations = sorted(pdf["ID"].unique())
n = len(stations)
cols = 3
rows = -(-n // cols)  # integer ceiling division

out_png1 = "/mnt/data/dew59_nz_tmin_tmax_subplots.png"
if n == 0:
    # plt.subplots(0, cols) would raise; report instead of crashing the cell.
    print("No NZ TMIN/TMAX observations to plot; skipping", out_png1)
else:
    fig, axes = plt.subplots(rows, cols, figsize=(16, 4*rows), squeeze=False)
    for i, sid in enumerate(stations):
        ax = axes[i//cols][i%cols]
        sub = pdf[pdf["ID"]==sid]
        for elt in ["TMIN","TMAX"]:
            ss = sub[sub["ELEMENT"]==elt].sort_values("date")
            ax.plot(ss["date"], ss["value_c"], label=elt)
        ax.set_title(sid)
        ax.set_xlabel("Date")
        ax.set_ylabel("°C")
        ax.legend(loc="best")
    # Hide grid cells left empty when n is not a multiple of cols.
    for ax in axes.flat[n:]:
        ax.set_visible(False)

    plt.tight_layout()
    os.makedirs(os.path.dirname(out_png1), exist_ok=True)
    fig.savefig(out_png1, dpi=160)
    print("Saved figure:", out_png1)

# Weekly mean TMIN/TMAX, averaged over every station/day present in `pdf`.
import os

# Bucket each observation by the start timestamp of its pandas "W" period.
# The vectorised .dt.start_time replaces a per-row Python lambda.
pdf["week"] = pdf["date"].dt.to_period("W").dt.start_time
avg_nat = (pdf.groupby(["week","ELEMENT"])["value_c"].mean().reset_index())

fig_nat, ax_nat = plt.subplots(figsize=(10,5))
for elt in ["TMIN","TMAX"]:
    s = avg_nat[avg_nat["ELEMENT"]==elt].sort_values("week")
    ax_nat.plot(s["week"], s["value_c"], label=elt)
ax_nat.legend()
ax_nat.set_xlabel("Week")
ax_nat.set_ylabel("°C")
ax_nat.set_title("NZ average weekly TMIN/TMAX")

out_png2 = "/mnt/data/dew59_nz_tmin_tmax_country.png"
fig_nat.tight_layout()
os.makedirs(os.path.dirname(out_png2), exist_ok=True)
fig_nat.savefig(out_png2, dpi=160)
print("Saved figure:", out_png2)

In [None]:
# Visualisation 2: Choropleth for 2024 average daily rainfall by country
import os
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd

prcp_path = f"{WASBS_USER_BASE}/q2a_prcp_year_country.parquet/"
print("Loading:", prcp_path)
prcp = spark.read.parquet(prcp_path)

# NOTE(review): assumes this table has columns year, country_code and
# avg_prcp_mm — confirm against the cell that wrote it.
prcp_2024 = prcp.filter(F.col("year")==2024).toPandas()
countries = enriched_stations.select("country_code","country_name").distinct().toPandas()

# Attach country names unless the rainfall table already carries them; merging
# a second country_name column would rename both to country_name_x/_y.
if "country_name" in prcp_2024.columns:
    df = prcp_2024.copy()
else:
    df = prcp_2024.merge(countries, on="country_code", how="left")


def _load_world():
    """Return Natural Earth 1:110m country polygons with a 'name' column."""
    try:
        return gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
    except (AttributeError, ValueError):
        # GeoPandas >= 1.0 removed the bundled datasets; read the same Natural
        # Earth admin-0 layer directly (requires network access).
        ne = gpd.read_file(
            "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
        )
        return ne.rename(columns={"NAME": "name"})


world = _load_world()

# Names are the only join key shared by both sources. Normalise whitespace and
# case, and map known spelling differences (extend as needed). Countries that
# still don't match are drawn in light grey via missing_kwds.
_name_aliases = {"united states": "united states of america"}
df["country_name_l"] = df["country_name"].str.strip().str.lower().replace(_name_aliases)
world["name_l"] = world["name"].str.strip().str.lower()

gdf = world.merge(df, left_on="name_l", right_on="country_name_l", how="left")

ax = gdf.plot(column="avg_prcp_mm", legend=True, figsize=(14,7), missing_kwds={"color":"lightgrey"})
ax.set_title("Average daily rainfall (mm) by country, 2024")
ax.set_axis_off()

out_png3 = "/mnt/data/dew59_2024_rainfall_choropleth.png"
plt.tight_layout()
os.makedirs(os.path.dirname(out_png3), exist_ok=True)
ax.figure.savefig(out_png3, dpi=160)
print("Saved figure:", out_png3)