# **Gold Layer**

The Gold layer produces high-quality analytical datasets, climate normals, anomaly metrics, regional summaries, and machine-learning outputs derived from the cleaned Silver daily data.

**Key steps:**
- Join Silver daily data with station metadata and restrict to Norway only.
- Aggregate to:
  - **Station–monthly** climate metrics (mean TMAX/TMIN/TAVG, total PRCP, wet days, completeness flags).
  - **Station–yearly** summaries (annual means, totals, and completeness per station).
- Compute **climate normals** (using 2010–2020) per station and month, then derive:
  - **Monthly anomalies** (temperature anomaly, precipitation ratio).
  - **Yearly anomalies** per station.
- Build **regional aggregates** for Norway:
  - Monthly and yearly mean anomalies and precipitation ratios.
- Train and save ML models:
  - Linear regression for regional temperature anomaly trend + simple forecast (2026–2030).
  - Random forest for yearly station rainfall prediction.
  - GBT regressor for yearly station temperature anomaly.
  - K-means clustering to define Norwegian climate zones.
  - Logistic regression to classify “heatwave years” based on summer anomalies.

## 01. Config

In [30]:
# GOLD: GHCN-Daily – aggregates, normals and anomalies

from pyspark.sql import SparkSession, functions as F, types as T

spark = (SparkSession.builder
         .appName("GHCN-Gold")
         .getOrCreate())
spark.sparkContext.setLogLevel("WARN")

# Make shuffle partitions smaller
spark.conf.set("spark.sql.shuffle.partitions", "400")

print("Spark:", spark.version)

# ---- Paths ----
SILVER_PATH   = "/home/ubuntu/spark-notebooks/project/data/silver"
META_DIR      = "/home/ubuntu/spark-notebooks/project/data/silver_meta"
GOLD_DIR      = "/home/ubuntu/spark-notebooks/project/data/gold"

STATIONS_PQ   = f"{META_DIR}/stations.parquet"
INVENTORY_PQ  = f"{META_DIR}/inventory.parquet"
COVERAGE_PQ   = f"{META_DIR}/coverage.parquet"

# Gold outputs
OUT_STN_MONTHLY = f"{GOLD_DIR}/station_monthly"
OUT_STN_YEARLY  = f"{GOLD_DIR}/station_yearly"
OUT_NORM_9120   = f"{GOLD_DIR}/normals_1991_2020"
OUT_ANOM_MONTH  = f"{GOLD_DIR}/anomalies_monthly"
OUT_ANOM_YEAR   = f"{GOLD_DIR}/anomalies_yearly"
OUT_REG_MONTH   = f"{GOLD_DIR}/region_monthly"
OUT_REG_YEAR    = f"{GOLD_DIR}/region_yearly"

# ML outputs
OUT_LR_REGION_FORECAST = f"{GOLD_DIR}/ml_lr_region_forecast"
OUT_RF_PRCPT           = f"{GOLD_DIR}/ml_rf_station_prcp"
OUT_GBT_TANOM          = f"{GOLD_DIR}/ml_gbt_station_tanom"
OUT_KMEANS_CLUSTERS    = f"{GOLD_DIR}/ml_kmeans_clusters"
OUT_LOGR_HEATWAVE      = f"{GOLD_DIR}/ml_logr_heatwave"

YEAR_MIN = 2010
YEAR_MAX = 2025
NORMALS_START = 2010
NORMALS_END   = 2020

DAYS_PER_MONTH_TEMP_MIN = 20
DAYS_PER_MONTH_PRCP_MIN = 20
MONTHS_PER_YEAR_MIN     = 10

Spark: 3.5.0


## 02. Load Silver + meta and build silver_enriched

In [None]:
# --- Load Silver daily facts ---
silver = (spark.read.parquet(SILVER_PATH)
          .where((F.col("year") >= YEAR_MIN) & (F.col("year") <= YEAR_MAX)))

# --- Load metadata from Silver-meta ---
stations  = spark.read.parquet(STATIONS_PQ)
coverage  = spark.read.parquet(COVERAGE_PQ)
inventory = spark.read.parquet(INVENTORY_PQ)

print("Stations:", stations.count())
print("Coverage rows:", coverage.count())

country_expr = F.when(F.col("id").isNotNull(), F.substring(F.col("id"), 1, 2)).otherwise(F.lit(None))

silver_enriched = (silver
    .withColumn(
        "tavg_c_fallback",
        F.when(
            F.col("tavg_c").isNull()
            & F.col("tmax_c").isNotNull()
            & F.col("tmin_c").isNotNull(),
            (F.col("tmax_c") + F.col("tmin_c")) / 2.0,
        ).otherwise(F.col("tavg_c"))
    )
    .join(stations.select("id", "lat", "lon", "elev", "state", "name"), "id", "left")
    .withColumn("country", country_expr)
    .withColumn("has_tmax", F.when(F.col("tmax_c").isNotNull(), 1).otherwise(0))
    .withColumn("has_tmin", F.when(F.col("tmin_c").isNotNull(), 1).otherwise(0))
    .withColumn("has_tavg", F.when(F.col("tavg_c_fallback").isNotNull(), 1).otherwise(0))
    .withColumn("has_prcp", F.when(F.col("prcp_mm").isNotNull(), 1).otherwise(0))
)

# Limit to a region (adjust as needed)
TARGET_COUNTRIES = ["NO"]  # Nordic ["NO", "SE", "DK", "FI", "IS"]

silver_enriched = silver_enriched.filter(F.col("country").isin(TARGET_COUNTRIES))

print("Rows after region filter:", silver_enriched.count())

silver_enriched.select("id", "country", "date", "tmax_c", "tmin_c",
                       "tavg_c", "tavg_c_fallback", "prcp_mm").show(5, truncate=False)
print("silver_enriched is ready for monthly/annual aggregations")

Stations: 129649
Coverage rows: 129461


                                                                                

Rows after region filter: 1787937
+-----------+-------+----------+------+------+------+---------------+-------+
|id         |country|date      |tmax_c|tmin_c|tavg_c|tavg_c_fallback|prcp_mm|
+-----------+-------+----------+------+------+------+---------------+-------+
|NOE00134382|NO     |2011-10-23|NULL  |NULL  |NULL  |NULL           |0.0    |
|NOE00134694|NO     |2011-02-24|NULL  |NULL  |NULL  |NULL           |0.0    |
|NOE00110635|NO     |2011-03-28|NULL  |NULL  |NULL  |NULL           |9.8    |
|NOE00133566|NO     |2011-04-15|10.0  |7.0   |8.5   |8.5            |4.5    |
|NOE00110680|NO     |2011-06-01|NULL  |NULL  |NULL  |NULL           |30.5   |
+-----------+-------+----------+------+------+------+---------------+-------+
only showing top 5 rows

silver_enriched is ready for monthly/annual aggregations


## 03. Station-level monthly aggregates + completeness flags

In [7]:
# --- Station-level monthly aggregates ---

monthly = (silver_enriched
    .groupBy("id", "year", "month")
    .agg(
        F.first("country", ignorenulls=True).alias("country"),
        F.first("state", ignorenulls=True).alias("state"),
        F.first("name", ignorenulls=True).alias("name"),
        F.first("lat", ignorenulls=True).alias("lat"),
        F.first("lon", ignorenulls=True).alias("lon"),
        F.first("elev", ignorenulls=True).alias("elev"),

        F.sum("has_tmax").alias("days_tmax_obs"),
        F.sum("has_tmin").alias("days_tmin_obs"),
        F.sum("has_tavg").alias("days_tavg_obs"),
        F.sum("has_prcp").alias("days_prcp_obs"),

        F.avg("tmax_c").alias("tmax_mean_c"),
        F.avg("tmin_c").alias("tmin_mean_c"),
        F.avg("tavg_c_fallback").alias("tavg_mean_c"),

        F.sum(F.coalesce(F.col("prcp_mm"), F.lit(0.0))).alias("prcp_total_mm"),
        F.sum(F.when(F.col("prcp_mm") > 0, 1).otherwise(0)).alias("wet_days")
    )
    .withColumn(
        "is_complete_temp",
        (F.col("days_tmax_obs") >= DAYS_PER_MONTH_TEMP_MIN) &
        (F.col("days_tmin_obs") >= DAYS_PER_MONTH_TEMP_MIN)
    )
    .withColumn(
        "is_complete_prcp",
        (F.col("days_prcp_obs") >= DAYS_PER_MONTH_PRCP_MIN)
    )
)

(monthly
    .write
    .mode("overwrite")
    .parquet(OUT_STN_MONTHLY)
)

print("Wrote station monthly to:", OUT_STN_MONTHLY)

25/11/13 22:38:43 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/11/13 22:38:59 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers
25/11/13 22:38:59 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 84.47% for 8 writers


Wrote station monthly to: /home/ubuntu/spark-notebooks/project/data/gold/station_monthly


25/11/13 22:38:59 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers
                                                                                

## 04. Station-level yearly aggregates (built from monthly)

In [8]:
# Read the monthly station data we just wrote
monthly = spark.read.parquet(OUT_STN_MONTHLY)

print("Monthly rows:", monthly.count())

station_yearly = (monthly
    .groupBy("id", "year")
    .agg(
        # Keep meta (first non-null)
        F.first("country", ignorenulls=True).alias("country"),
        F.first("state", ignorenulls=True).alias("state"),
        F.first("name", ignorenulls=True).alias("name"),
        F.first("lat", ignorenulls=True).alias("lat"),
        F.first("lon", ignorenulls=True).alias("lon"),
        F.first("elev", ignorenulls=True).alias("elev"),

        # Count how many months are "complete" for temp / prcp
        F.sum(F.when(F.col("is_complete_temp"), 1).otherwise(0)).alias("n_complete_temp_months"),
        F.sum(F.when(F.col("is_complete_prcp"), 1).otherwise(0)).alias("n_complete_prcp_months"),

        # Annual averages / totals based on the *monthly* metrics
        F.avg("tmax_mean_c").alias("year_tmax_mean_c"),
        F.avg("tmin_mean_c").alias("year_tmin_mean_c"),
        F.avg("tavg_mean_c").alias("year_tavg_mean_c"),

        F.sum("prcp_total_mm").alias("year_prcp_total_mm"),
        F.sum("wet_days").alias("year_wet_days")
    )
    # completeness filter: keep years with enough good months
    .withColumn(
        "is_complete_year_temp",
        F.col("n_complete_temp_months") >= MONTHS_PER_YEAR_MIN
    )
    .withColumn(
        "is_complete_year_prcp",
        F.col("n_complete_prcp_months") >= MONTHS_PER_YEAR_MIN
    )
)

station_yearly.show(5, truncate=False)

(station_yearly
    .write
    .mode("overwrite")
    .parquet(OUT_STN_YEARLY)
)

print("Wrote station yearly to:", OUT_STN_YEARLY)

Monthly rows: 59100


                                                                                

+-----------+----+-------+-----+------------+-------+-------+-----+----------------------+----------------------+------------------+--------------------+------------------+------------------+-------------+---------------------+---------------------+
|id         |year|country|state|name        |lat    |lon    |elev |n_complete_temp_months|n_complete_prcp_months|year_tmax_mean_c  |year_tmin_mean_c    |year_tavg_mean_c  |year_prcp_total_mm|year_wet_days|is_complete_year_temp|is_complete_year_prcp|
+-----------+----+-------+-----+------------+-------+-------+-----+----------------------+----------------------+------------------+--------------------+------------------+------------------+-------------+---------------------+---------------------+
|NO000001026|2010|NO     |NULL |TROMSO      |69.6539|18.9281|100.0|12                    |12                    |5.450768049155147 |-0.32379544290834583|2.5634863031233994|1139.5            |215          |true                 |true                 |


## 05. Climate Normals (2010–2025) per station & month

In [9]:
normals = (monthly
    .where(
        (F.col("year") >= NORMALS_START) &
        (F.col("year") <= NORMALS_END) &
        (F.col("is_complete_temp") | F.col("is_complete_prcp"))
    )
    .groupBy("id", "month")
    .agg(
        F.first("country", ignorenulls=True).alias("country"),
        F.first("state", ignorenulls=True).alias("state"),
        F.first("name", ignorenulls=True).alias("name"),
        F.first("lat", ignorenulls=True).alias("lat"),
        F.first("lon", ignorenulls=True).alias("lon"),
        F.first("elev", ignorenulls=True).alias("elev"),

        # "normal" monthly mean temps
        F.avg("tmax_mean_c").alias("normal_tmax_c"),
        F.avg("tmin_mean_c").alias("normal_tmin_c"),
        F.avg("tavg_mean_c").alias("normal_tavg_c"),

        # "normal" monthly precipitation
        F.avg("prcp_total_mm").alias("normal_prcp_total_mm")
    )
)

normals.show(5, truncate=False)

(normals
    .write
    .mode("overwrite")
    .parquet(OUT_NORM_9120)
)

print("Wrote normals to:", OUT_NORM_9120)

+-----------+-----+-------+-----+------------+-------+-------+-----+------------------+-------------------+------------------+--------------------+
|id         |month|country|state|name        |lat    |lon    |elev |normal_tmax_c     |normal_tmin_c      |normal_tavg_c     |normal_prcp_total_mm|
+-----------+-----+-------+-----+------------+-------+-------+-----+------------------+-------------------+------------------+--------------------+
|NO000001026|9    |NO     |NULL |TROMSO      |69.6539|18.9281|100.0|12.12151515151515 |5.927272727272728  |9.02439393939394  |83.94545454545454   |
|NO000001026|10   |NO     |NULL |TROMSO      |69.6539|18.9281|100.0|5.890615835777126 |1.492375366568915  |3.691495601173021 |111.36363636363639  |
|NO000001026|11   |NO     |NULL |TROMSO      |69.6539|18.9281|100.0|2.703030303030303 |-1.3684848484848484|0.6672727272727272|101.55454545454545  |
|NO000001465|2    |NO     |NULL |TORUNGEN FYR|58.3831|8.7917 |12.0 |2.9885356023287057|-0.8603560232870577|1.064

## 06. Station monthly anomalies (T & PRCP)

In [10]:
# Reload
monthly = spark.read.parquet(OUT_STN_MONTHLY)
normals = spark.read.parquet(OUT_NORM_9120)

# Join normals onto monthly by (id, month)
monthly_with_norms = (monthly.alias("m")
    .join(
        normals.select(
            "id", "month",
            "normal_tmax_c", "normal_tmin_c", "normal_tavg_c",
            "normal_prcp_total_mm"
        ).alias("n"),
        on=["id", "month"],
        how="left"
    )
)

anomalies_monthly = (monthly_with_norms
    .withColumn("tavg_anom_c", F.col("tavg_mean_c") - F.col("normal_tavg_c"))
    .withColumn("tmax_anom_c", F.col("tmax_mean_c") - F.col("normal_tmax_c"))
    .withColumn("tmin_anom_c", F.col("tmin_mean_c") - F.col("normal_tmin_c"))
    .withColumn("prcp_ratio",
                F.when(F.col("normal_prcp_total_mm") > 0,
                       F.col("prcp_total_mm") / F.col("normal_prcp_total_mm"))
                 .otherwise(None))
)

anomalies_monthly.select(
    "id", "year", "month",
    "tavg_mean_c", "normal_tavg_c", "tavg_anom_c",
    "prcp_total_mm", "normal_prcp_total_mm", "prcp_ratio"
).show(5, truncate=False)

(anomalies_monthly
    .write
    .mode("overwrite")
    .parquet(OUT_ANOM_MONTH)
)

print("Wrote monthly anomalies to:", OUT_ANOM_MONTH)

+-----------+----+-----+--------------------+-------------------+-------------------+-----------------+--------------------+-------------------+
|id         |year|month|tavg_mean_c         |normal_tavg_c      |tavg_anom_c        |prcp_total_mm    |normal_prcp_total_mm|prcp_ratio         |
+-----------+----+-----+--------------------+-------------------+-------------------+-----------------+--------------------+-------------------+
|NO000001026|2010|2    |-7.132142857142857  |-3.170101880877743 |-3.962040976265114 |71.60000000000001|88.64545454545454   |0.8077120295354324 |
|NO000001026|2012|3    |-0.07903225806451611|-1.7576246334310852|1.678592375366569  |164.2            |130.48181818181817  |1.2584128753570683 |
|NO000001026|2012|12   |-3.9935483870967743 |-1.2998533724340176|-2.6936950146627567|10.2             |110.2909090909091   |0.09248269040553905|
|NO000001026|2013|3    |-4.261290322580645  |-1.7576246334310852|-2.5036656891495594|111.9            |130.48181818181817  |0.8575

25/11/13 22:46:06 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers
25/11/13 22:46:06 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 84.47% for 8 writers
25/11/13 22:46:06 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers

Wrote monthly anomalies to: /home/ubuntu/spark-notebooks/project/data/gold/anomalies_monthly


                                                                                

## 07. Station yearly anomalies + regional aggregates

In [11]:
# Station-year anomalies: average of monthly anomalies
anom_month = spark.read.parquet(OUT_ANOM_MONTH)

anom_yearly = (anom_month
    .groupBy("id", "year")
    .agg(
        F.first("country", ignorenulls=True).alias("country"),
        F.first("state", ignorenulls=True).alias("state"),
        F.first("name", ignorenulls=True).alias("name"),
        F.first("lat", ignorenulls=True).alias("lat"),
        F.first("lon", ignorenulls=True).alias("lon"),
        F.first("elev", ignorenulls=True).alias("elev"),

        F.avg("tavg_anom_c").alias("year_tavg_anom_c"),
        F.avg("tmax_anom_c").alias("year_tmax_anom_c"),
        F.avg("tmin_anom_c").alias("year_tmin_anom_c"),
        F.avg("prcp_ratio").alias("year_prcp_ratio_mean")
    )
)

anom_yearly.show(5, truncate=False)

(anom_yearly
    .write
    .mode("overwrite")
    .parquet(OUT_ANOM_YEAR)
)

print("Wrote station yearly anomalies to:", OUT_ANOM_YEAR)

+-----------+----+-------+-----+------+-------+-------+-----+--------------------+-------------------+--------------------+--------------------+
|id         |year|country|state|name  |lat    |lon    |elev |year_tavg_anom_c    |year_tmax_anom_c   |year_tmin_anom_c    |year_prcp_ratio_mean|
+-----------+----+-------+-----+------+-------+-------+-----+--------------------+-------------------+--------------------+--------------------+
|NO000001026|2010|NO     |NULL |TROMSO|69.6539|18.9281|100.0|-1.2308624893058988 |-1.1571914621344546|-1.3045335164773426 |1.0781935017499642  |
|NO000001026|2014|NO     |NULL |TROMSO|69.6539|18.9281|100.0|0.08963194233772642 |0.02260628492135693|0.15665759975409577 |0.8585165502621788  |
|NO000001026|2017|NO     |NULL |TROMSO|69.6539|18.9281|100.0|-0.18797590456447563|-0.1784581031994826|-0.19749370592946827|0.9902688574969831  |
|NO000001026|2022|NO     |NULL |TROMSO|69.6539|18.9281|100.0|0.2873661973300456  |0.2757719275224834 |0.2989604671376083  |1.25101

## 08. Regional (country-level) monthly aggregates

In [13]:
region_monthly = (anom_month
    .groupBy("country", "year", "month")
    .agg(
        F.countDistinct("id").alias("n_stations"),
        F.avg("tavg_anom_c").alias("region_tavg_anom_c"),
        F.avg("prcp_ratio").alias("region_prcp_ratio_mean")
    )
    .orderBy("country", "year", "month")
)

region_monthly.show(20, truncate=False)

(region_monthly
    .write
    .mode("overwrite")
    .parquet(OUT_REG_MONTH)
)

print("Wrote regional monthly aggregates to:", OUT_REG_MONTH)

+-------+----+-----+----------+--------------------+----------------------+
|country|year|month|n_stations|region_tavg_anom_c  |region_prcp_ratio_mean|
+-------+----+-----+----------+--------------------+----------------------+
|NO     |2010|1    |348       |-3.3252784201821326 |0.4406968178025895    |
|NO     |2010|2    |348       |-4.336489114490003  |0.518110927793653     |
|NO     |2010|3    |348       |-1.5885612159021112 |1.0557093624208183    |
|NO     |2010|4    |346       |-0.16914578541299163|0.8744069668105932    |
|NO     |2010|5    |346       |-0.6597762316217446 |0.6861378098453258    |
|NO     |2010|6    |345       |-1.1344352556890005 |0.9184086631009946    |
|NO     |2010|7    |344       |0.010082198100524757|1.1322376652448332    |
|NO     |2010|8    |344       |-0.3503631406892781 |0.9119924055043602    |
|NO     |2010|9    |345       |-0.9333221877887713 |0.7921106875852602    |
|NO     |2010|10   |345       |-0.539972512021546  |1.0281952184033654    |
|NO     |201

## 09. Regional (country-level) annual aggregates

In [14]:
# Regional = aggregate across stations, by country & year
region_yearly = (anom_yearly
    .groupBy("country", "year")
    .agg(
        F.countDistinct("id").alias("n_stations"),
        F.avg("year_tavg_anom_c").alias("region_tavg_anom_c"),
        F.avg("year_tmax_anom_c").alias("region_tmax_anom_c"),
        F.avg("year_tmin_anom_c").alias("region_tmin_anom_c"),
        F.avg("year_prcp_ratio_mean").alias("region_prcp_ratio_mean")
    )
    .orderBy("country", "year")
)

region_yearly.show(20, truncate=False)

(region_yearly
    .write
    .mode("overwrite")
    .parquet(OUT_REG_YEAR)
)

print("Wrote regional yearly aggregates to:", OUT_REG_YEAR)

+-------+----+----------+--------------------+--------------------+--------------------+----------------------+
|country|year|n_stations|region_tavg_anom_c  |region_tmax_anom_c  |region_tmin_anom_c  |region_prcp_ratio_mean|
+-------+----+----------+--------------------+--------------------+--------------------+----------------------+
|NO     |2010|351       |-1.984448717343891  |-1.8988769678497743 |-2.0665353612505957 |0.7764641719750219    |
|NO     |2011|362       |0.4323245502791186  |0.35742375872330484 |0.5102872481272246  |1.1090946366788985    |
|NO     |2012|354       |-0.7191595359834009 |-0.9111802208446663 |-0.5357508134640188 |0.9792685698777025    |
|NO     |2013|351       |-0.29206199412510403|-0.09506434188070689|-0.4983655262137265 |0.9218520926672795    |
|NO     |2014|339       |0.8207395237695982  |0.7715410810000193  |0.8671929175267532  |0.9942268925590289    |
|NO     |2015|333       |0.48755628094910386 |0.400006196346523   |0.5780451355410658  |1.08641707310576

## 10. Sanity checks & quick peeks

In [19]:
print("station_monthly sample")
spark.read.parquet(OUT_STN_MONTHLY).orderBy("year","month").show(10, truncate=False)

print("normals sample")
spark.read.parquet(OUT_NORM_9120).orderBy("id","month").show(10, truncate=False)

print("anomalies_yearly sample")
spark.read.parquet(OUT_ANOM_YEAR).orderBy("year").show(10, truncate=False)

station_monthly sample
+-----------+----+-----+-------+-----+-----------------------+-------+-------+-----+-------------+-------------+-------------+-------------+--------------------+-------------------+-------------------+------------------+--------+----------------+----------------+
|id         |year|month|country|state|name                   |lat    |lon    |elev |days_tmax_obs|days_tmin_obs|days_tavg_obs|days_prcp_obs|tmax_mean_c         |tmin_mean_c        |tavg_mean_c        |prcp_total_mm     |wet_days|is_complete_temp|is_complete_prcp|
+-----------+----+-----+-------+-----+-----------------------+-------+-------+-----+-------------+-------------+-------------+-------------+--------------------+-------------------+-------------------+------------------+--------+----------------+----------------+
|NOE00109903|2010|1    |NO     |NULL |BJORNHOLT              |60.0508|10.6864|360.0|31           |31           |31           |31           |-8.312903225806451  |-14.57741935483871 |-11.

## 11. ML Models

In [20]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression, RandomForestRegressor, GBTRegressor
from pyspark.ml.clustering import KMeans
from pyspark.ml.classification import LogisticRegression

#### 11.1 Linear Regression — predict temperature anomaly from year

In [32]:
# Linear Regression: region temp anomaly vs year (trend)

region_yearly = spark.read.parquet(OUT_REG_YEAR)

# Keep only rows with anomaly present
reg_lr_input = (region_yearly
    .where(region_yearly.region_tavg_anom_c.isNotNull())
    .withColumn("year_centered", F.col("year") - 2010)  # helps numerics a bit
)

# Features: just "year_centered" for a simple trend line
lr_assembler = VectorAssembler(
    inputCols=["year_centered"],
    outputCol="features"
)

reg_lr_data = lr_assembler.transform(reg_lr_input).select(
    "country", "year", "features", "region_tavg_anom_c"
).withColumnRenamed("region_tavg_anom_c", "label")

train_df, test_df = reg_lr_data.randomSplit([0.8, 0.2], seed=42)

lr = LinearRegression(featuresCol="features", labelCol="label")
lr_model = lr.fit(train_df)

print("Linear Regression coefficients:", lr_model.coefficients)
print("Intercept:", lr_model.intercept)
print("RMSE:", lr_model.summary.rootMeanSquaredError)
print("R²:", lr_model.summary.r2)

# Inspect some predictions on test set
pred_test = lr_model.transform(test_df)
pred_test.select("country", "year", "label", "prediction").show(10, truncate=False)

# OPTIONAL: forecast future years for NO
future_years = spark.createDataFrame(
    [( "NO", y ) for y in range(2026, 2031)],
    ["country", "year"]
).withColumn("year_centered", F.col("year") - 2010)

future_features = lr_assembler.transform(future_years)
future_pred = lr_model.transform(future_features)

print("Forecast region_tavg_anom_c for NO 2026–2030:")
future_pred.select("country", "year", "prediction").show()

# Save LR forecast for region anomalies (Norway)
(future_pred
    .select("country", "year", "prediction")
    .write
    .mode("overwrite")
    .parquet(OUT_LR_REGION_FORECAST)
)

print("Saved LR region forecast to:", OUT_LR_REGION_FORECAST)


25/11/14 09:45:58 WARN Instrumentation: [86f62532] regParam is zero, which might cause numerical instability and overfitting.


Linear Regression coefficients: [0.08860696139001684]
Intercept: -0.5073308939577166
RMSE: 0.6664193171032775
R²: 0.28974818961791715
+-------+----+-------------------+-------------------+
|country|year|label              |prediction         |
+-------+----+-------------------+-------------------+
|NO     |2012|-0.7191595359834009|-0.3301169711776829|
|NO     |2016|0.2935185081363368 |0.02431087438238444|
|NO     |2018|0.10659460260468707|0.2015247971624181 |
|NO     |2023|-0.3435268405073371|0.6445596041125023 |
+-------+----+-------------------+-------------------+

Forecast region_tavg_anom_c for NO 2026–2030:
+-------+----+------------------+
|country|year|        prediction|
+-------+----+------------------+
|     NO|2026|0.9103804882825528|
|     NO|2027|0.9989874496725697|
|     NO|2028|1.0875944110625866|
|     NO|2029|1.1762013724526033|
|     NO|2030|1.2648083338426201|
+-------+----+------------------+

Saved LR region forecast to: /home/ubuntu/spark-notebooks/project/data/g

#### 11.2 Random Forest — predict yearly rainfall (station-level)

In [33]:
# Random Forest Regressor: predict station yearly rainfall (FIXED)

stn_yearly = spark.read.parquet(OUT_STN_YEARLY)

rf_input = (stn_yearly
    .where(
        (F.col("year_prcp_total_mm").isNotNull()) &
        (F.col("is_complete_year_prcp") == True) &
        F.col("year").isNotNull() &
        F.col("lat").isNotNull() &
        F.col("lon").isNotNull() &
        F.col("elev").isNotNull() &
        F.col("year_tavg_mean_c").isNotNull() &
        F.col("year_tmax_mean_c").isNotNull() &
        F.col("year_tmin_mean_c").isNotNull()
    )
    .select(
        "id", "country", "year",
        "lat", "lon", "elev",
        "year_tavg_mean_c", "year_tmax_mean_c", "year_tmin_mean_c",
        "year_prcp_total_mm"
    )
)

rf_assembler = VectorAssembler(
    inputCols=[
        "year", "lat", "lon", "elev",
        "year_tavg_mean_c", "year_tmax_mean_c", "year_tmin_mean_c"
    ],
    outputCol="features",
    handleInvalid="skip"   # extra safety
)

rf_data = (rf_assembler
    .transform(rf_input)
    .select("id", "country", "year", "features", "year_prcp_total_mm")
    .withColumnRenamed("year_prcp_total_mm", "label")
)

print("RF training rows:", rf_data.count())

rf_train, rf_test = rf_data.randomSplit([0.8, 0.2], seed=42)

rf = RandomForestRegressor(featuresCol="features", labelCol="label", numTrees=50)
rf_model = rf.fit(rf_train)

rf_pred = rf_model.transform(rf_test)
rf_pred.select(
    "id", "country", "year", "label", "prediction"
).show(10, truncate=False)

# Save RF station rainfall predictions
(rf_pred
    .select("id", "country", "year", "label", "prediction")
    .write
    .mode("overwrite")
    .parquet(OUT_RF_PRCPT)
)

print("Saved RF station rainfall predictions to:", OUT_RF_PRCPT)



RF training rows: 1204
+-----------+-------+----+------------------+------------------+
|id         |country|year|label             |prediction        |
+-----------+-------+----+------------------+------------------+
|NO000001026|NO     |2013|1220.4            |1056.3305599362527|
|NO000001465|NO     |2010|597.0             |1191.3651437573544|
|NO000001465|NO     |2015|1080.1999999999998|1222.2440186911263|
|NO000005350|NO     |2018|607.2             |821.5371605254228 |
|NO000014030|NO     |2016|931.0             |1332.7370499693113|
|NO000050540|NO     |2011|2680.9            |2703.173544302458 |
|NO000098550|NO     |2012|559.2             |648.0629969443205 |
|NO000099710|NO     |2015|484.50000000000006|488.66293450038205|
|NOE00105467|NO     |2015|1313.0            |1152.1706160166523|
|NOE00105467|NO     |2018|1132.0            |1193.7391077212196|
+-----------+-------+----+------------------+------------------+
only showing top 10 rows

Saved RF station rainfall predictions to:

#### 11.3 GBT Regressor — predict yearly temperature anomaly (station-level)

In [34]:
# GBT Regressor: predict station yearly temp anomaly (FIXED)

anom_yearly = spark.read.parquet(OUT_ANOM_YEAR)

gbt_input = (anom_yearly
    .where(F.col("year_tavg_anom_c").isNotNull())
    .select(
        "id", "country", "year",
        "lat", "lon", "elev",
        "year_tavg_anom_c"
    )
)

gbt_assembler = VectorAssembler(
    inputCols=["year", "lat", "lon", "elev"],
    outputCol="features"
)

gbt_data = (gbt_assembler
    .transform(gbt_input)
    .select("id", "country", "year", "features", "year_tavg_anom_c")
    .withColumnRenamed("year_tavg_anom_c", "label")
)

gbt_train, gbt_test = gbt_data.randomSplit([0.8, 0.2], seed=42)

gbt = GBTRegressor(featuresCol="features", labelCol="label", maxDepth=5, maxIter=30)
gbt_model = gbt.fit(gbt_train)

gbt_pred = gbt_model.transform(gbt_test)
gbt_pred.select(
    "id", "country", "year", "label", "prediction"
).show(10, truncate=False)

# Save GBT station temp anomaly predictions
(gbt_pred
    .select("id", "country", "year", "label", "prediction")
    .write
    .mode("overwrite")
    .parquet(OUT_GBT_TANOM)
)

print("Saved GBT station temp anomaly predictions to:", OUT_GBT_TANOM)



+-----------+-------+----+-------------------+--------------------+
|id         |country|year|label              |prediction          |
+-----------+-------+----+-------------------+--------------------+
|NO000001026|NO     |2013|0.377187305880993  |0.19196534633858917 |
|NO000001026|NO     |2019|-0.7127109276059503|-0.7073096274145144 |
|NO000001026|NO     |2021|-0.4822257765562829|-0.38402873398523746|
|NO000001465|NO     |2012|-0.6900538667703565|-0.8077697685572962 |
|NO000001465|NO     |2019|0.2194459632877325 |0.23828567448715204 |
|NO000001465|NO     |2024|0.4511205863256758 |0.5697838129581089  |
|NO000005350|NO     |2019|-0.5047142679065196|-0.15577375767259502|
|NO000014030|NO     |2014|1.0850482080880008 |1.0993079101416579  |
|NO000050540|NO     |2010|-2.136346486676715 |-2.2029237543474567 |
|NO000050540|NO     |2011|0.24627544112666422|0.07945347682615679 |
+-----------+-------+----+-------------------+--------------------+
only showing top 10 rows

Saved GBT station temp

#### 11.4 K-Means — cluster stations into climate zones

In [35]:
# K-Means: cluster stations into climate zones

normals = spark.read.parquet(OUT_NORM_9120)

stn_normals = (normals
    .groupBy("id")
    .agg(
        F.first("country", ignorenulls=True).alias("country"),
        F.first("lat", ignorenulls=True).alias("lat"),
        F.first("lon", ignorenulls=True).alias("lon"),
        F.first("elev", ignorenulls=True).alias("elev"),
        F.avg("normal_tavg_c").alias("annual_normal_tavg_c"),
        F.avg("normal_prcp_total_mm").alias("annual_normal_prcp_mm")
    )
    .where(
        F.col("lat").isNotNull() &
        F.col("lon").isNotNull() &
        F.col("annual_normal_tavg_c").isNotNull() &
        F.col("annual_normal_prcp_mm").isNotNull()
    )
)

# Optionally focus on Norway only
stn_normals_no = stn_normals.where(F.col("country") == "NO")

kmeans_assembler = VectorAssembler(
    inputCols=["lat", "lon", "annual_normal_tavg_c", "annual_normal_prcp_mm"],
    outputCol="features"
)

kmeans_data = kmeans_assembler.transform(stn_normals_no)

k = 4  # number of climate zones, adjust as you like
kmeans = KMeans(k=k, seed=42, featuresCol="features", predictionCol="cluster")
k_model = kmeans.fit(kmeans_data)

k_result = k_model.transform(kmeans_data)

k_result.select(
    "id", "country", "lat", "lon", "annual_normal_tavg_c",
    "annual_normal_prcp_mm", "cluster"
).show(20, truncate=False)

# Save K-Means climate clusters
(k_result
    .select("id", "country", "lat", "lon",
            "annual_normal_tavg_c", "annual_normal_prcp_mm", "cluster")
    .write
    .mode("overwrite")
    .parquet(OUT_KMEANS_CLUSTERS)
)

print("Saved K-Means clusters to:", OUT_KMEANS_CLUSTERS)



+-----------+-------+-------+-------+--------------------+---------------------+-------+
|id         |country|lat    |lon    |annual_normal_tavg_c|annual_normal_prcp_mm|cluster|
+-----------+-------+-------+-------+--------------------+---------------------+-------+
|NO000001026|NO     |69.6539|18.9281|3.794348792429299   |89.06818181818181    |3      |
|NO000001465|NO     |58.3831|8.7917 |8.668056276855637   |80.75757575757575    |1      |
|NO000005350|NO     |60.3883|11.5603|6.165564882345844   |69.5590909090909     |1      |
|NO000014030|NO     |59.3   |4.883  |8.46171597984395    |87.08465488215488    |3      |
|NO000050540|NO     |60.3831|5.3331 |8.853986335627047   |214.28863636363636   |2      |
|NO000080700|NO     |66.8167|13.9831|6.086739570614571   |0.0                  |0      |
|NO000098550|NO     |70.367 |31.1   |3.0012944721858053  |47.48666666666667    |1      |
|NO000099710|NO     |74.5167|19.0167|0.30937277089181997 |38.326515151515146   |1      |
|NOE00100574|NO     |

#### 11.5 Logistic Regression — “heatwave year” classification

In [36]:
# Logistic Regression: classify "heatwave years" (FIXED)

anom_month = spark.read.parquet(OUT_ANOM_MONTH)

# Step 1: compute summer (JJA) mean anomaly per station-year
summer_anom = (anom_month
    .where(F.col("month").isin(6, 7, 8))  # June, July, August
    .groupBy("id", "year")
    .agg(
        F.first("country", ignorenulls=True).alias("country"),
        F.first("lat", ignorenulls=True).alias("lat"),
        F.first("lon", ignorenulls=True).alias("lon"),
        F.first("elev", ignorenulls=True).alias("elev"),
        F.avg("tavg_anom_c").alias("summer_tavg_anom_c"),
        F.avg("prcp_ratio").alias("summer_prcp_ratio")
    )
    .where(F.col("summer_tavg_anom_c").isNotNull())
)

# Step 2: create label: 1 if summer anomaly >= 2°C, else 0
heatwave_df = (summer_anom
    .withColumn(
        "label",
        F.when(F.col("summer_tavg_anom_c") >= 2.0, 1.0).otherwise(0.0)
    )
    # Drop rows where any feature is null
    .where(
        F.col("summer_prcp_ratio").isNotNull() &
        F.col("lat").isNotNull() &
        F.col("lon").isNotNull()
    )
)

# Assemble features
lr_cls_assembler = VectorAssembler(
    inputCols=["summer_tavg_anom_c", "summer_prcp_ratio", "lat", "lon"],
    outputCol="features",
    handleInvalid="skip"   # extra protection, but we already filtered
)

lr_cls_data = lr_cls_assembler.transform(heatwave_df).select(
    "id", "country", "year", "label", "features"
)

train_cls, test_cls = lr_cls_data.randomSplit([0.8, 0.2], seed=42)

logreg = LogisticRegression(featuresCol="features", labelCol="label")
logreg_model = logreg.fit(train_cls)

pred_cls = logreg_model.transform(test_cls)

pred_cls.select(
    "id", "country", "year", "label", "probability", "prediction"
).show(20, truncate=False)

# Save Logistic Regression heatwave predictions
(pred_cls
    .select("id", "country", "year", "label", "probability", "prediction")
    .write
    .mode("overwrite")
    .parquet(OUT_LOGR_HEATWAVE)
)

print("Saved heatwave classification to:", OUT_LOGR_HEATWAVE)


+-----------+-------+----+-----+-----------+----------+
|id         |country|year|label|probability|prediction|
+-----------+-------+----+-----+-----------+----------+
|NO000001026|NO     |2012|0.0  |[1.0,0.0]  |0.0       |
|NO000001026|NO     |2016|0.0  |[1.0,0.0]  |0.0       |
|NO000001026|NO     |2018|0.0  |[1.0,0.0]  |0.0       |
|NO000001026|NO     |2023|0.0  |[1.0,0.0]  |0.0       |
|NO000001465|NO     |2013|0.0  |[1.0,0.0]  |0.0       |
|NO000001465|NO     |2017|0.0  |[1.0,0.0]  |0.0       |
|NO000001465|NO     |2023|0.0  |[1.0,0.0]  |0.0       |
|NO000005350|NO     |2018|0.0  |[1.0,0.0]  |0.0       |
|NO000014030|NO     |2012|0.0  |[1.0,0.0]  |0.0       |
|NO000014030|NO     |2013|0.0  |[1.0,0.0]  |0.0       |
|NO000014030|NO     |2014|0.0  |[1.0,0.0]  |0.0       |
|NO000014030|NO     |2016|0.0  |[1.0,0.0]  |0.0       |
|NO000014030|NO     |2018|0.0  |[1.0,0.0]  |0.0       |
|NO000014030|NO     |2022|0.0  |[1.0,0.0]  |0.0       |
|NO000050540|NO     |2013|0.0  |[1.0,0.0]  |0.0 