In [0]:
# --------------------------------------------
# 📌 SECTION 0 — Install 2 Python TS forecasting libraries
# --------------------------------------------

%pip install pmdarima statsmodels

# pmdarima: Automates ARIMA model selection for time series analysis.
# statsmodels: Provides statistical models, including manual ARIMA fitting.


[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
# --------------------------------------------
# SECTION 1 — Load Data from Databricks
# --------------------------------------------

print("="*80)
print("SECTION 1 — LOAD DATA FROM DATABRICKS")
print("="*80)

from pyspark.sql import functions as F
import pandas as pd

# Load from your existing table
df_spark = spark.table("workspace.default.train_set_imputed")

# Rename index → timestamp
df_spark = df_spark.withColumnRenamed("index", "timestamp")

df_spark = df_spark.select(
    "timestamp", "country", "grid_stress_score",
    "mean_temperature_c", "Actual_Load"
)

df_spark = df_spark.orderBy("timestamp")

df = df_spark.toPandas()
df["timestamp"] = pd.to_datetime(df["timestamp"])

print(df.head())
print(df.shape)


SECTION 1 — LOAD DATA FROM DATABRICKS
   timestamp country  grid_stress_score  mean_temperature_c  Actual_Load
0 2023-01-01      AT               50.0            3.151662     5280.800
1 2023-01-01      BE               62.5           13.812451     7082.920
2 2023-01-01      DE               62.5           12.760933    37777.205
3 2023-01-01      ES               50.0           11.962337    19251.000
4 2023-01-01      FR               62.5           12.742556    45709.000
(222638, 5)


In [0]:
# --------------------------------------------
# SECTION 2 — Prepare TS & Fit ARIMA per country
# --------------------------------------------

# It loops through each country in the dataset, prepares its time series data (hourly stress scores), 
# and fits an ARIMA model using auto_arima. 
# The result is a dictionary of fitted models — one per country

print("="*80)
print("SECTION 2 — PREPARE TIME SERIES (PER COUNTRY)")
print("="*80)

from pmdarima import auto_arima

countries = sorted(df["country"].unique())
print("Countries found:", countries)

arima_models = {}   # to store fitted models
series_info = {}    

for country in countries:
    print("\n" + "-"*80)
    print(f"▶ Training ARIMA for country: {country}")
    print("-"*80)

    # Filter this country's data
    df_country = df[df["country"] == country].copy()

    # Sort & index by timestamp
    df_country = df_country.sort_values("timestamp")
    df_country = df_country.set_index("timestamp")

    # Univariate time series (grid stress)
    ts = df_country["grid_stress_score"]

    print("Head:")
    print(ts.head())
    print("Length:", len(ts))

    # Skip if too few points
    if len(ts) < 30:
        print("⚠️ Not enough observations, skipping.")
        continue

    # Fit ARIMA automatically
    model = auto_arima(
        ts,
        seasonal=False,         
        trace=False,
        suppress_warnings=True,
        stepwise=True
    )

    arima_models[country] = model
    series_info[country] = {"n_obs": len(ts)}

    print(f"✔ Fitted ARIMA for {country}, order={model.order}")


SECTION 2 — PREPARE TIME SERIES (PER COUNTRY)
Countries found: ['AT', 'BE', 'DE', 'ES', 'FR', 'HR', 'HU', 'IT', 'LT', 'NL', 'PL', 'PT', 'SK']

--------------------------------------------------------------------------------
▶ Training ARIMA for country: AT
--------------------------------------------------------------------------------
Head:
timestamp
2023-01-01 00:00:00    50.0
2023-01-01 01:00:00    50.0
2023-01-01 02:00:00    50.0
2023-01-01 03:00:00    62.5
2023-01-01 04:00:00    62.5
Name: grid_stress_score, dtype: float64
Length: 17521
✔ Fitted ARIMA for AT, order=(0, 1, 5)

--------------------------------------------------------------------------------
▶ Training ARIMA for country: BE
--------------------------------------------------------------------------------
Head:
timestamp
2023-01-01 00:00:00    62.5
2023-01-01 01:00:00    62.5
2023-01-01 02:00:00    62.5
2023-01-01 03:00:00    62.5
2023-01-01 04:00:00    62.5
Name: grid_stress_score, dtype: float64
Length: 17520
✔ Fitte

In [0]:
# --------------------------------------------
# SECTION 3 — Forecast with ARIMA models
# --------------------------------------------

# It generates 6-hour-ahead forecasts from the ARIMA models for each country 
# and stores the results in a dictionary (arima_forecasts) for later use or visualization.


print("="*80)
print("SECTION 3 — FORECAST WITH ARIMA (PER COUNTRY)")
print("="*80)

arima_forecasts = {}

n_periods = 6  # forecast horizon

for country, model in arima_models.items():
    print(f"\n▶ Forecasting for {country}")
    fc = model.predict(n_periods=n_periods)
    arima_forecasts[country] = fc
    print(fc[:5])



SECTION 3 — FORECAST WITH ARIMA (PER COUNTRY)

▶ Forecasting for AT
2024-12-31 01:00:00    18.261758
2024-12-31 02:00:00    23.615674
2024-12-31 03:00:00    30.624668
2024-12-31 04:00:00    34.440792
2024-12-31 05:00:00    34.919724
Freq: h, dtype: float64

▶ Forecasting for BE
17520    26.785454
17521    26.968349
17522    27.319779
17523    27.566727
17524    27.740257
dtype: float64

▶ Forecasting for DE
2024-12-31 01:00:00    35.697865
2024-12-31 02:00:00    34.778514
2024-12-31 03:00:00    35.015411
2024-12-31 04:00:00    35.340923
2024-12-31 05:00:00    35.296053
Freq: h, dtype: float64

▶ Forecasting for ES
17377    11.852527
17378    16.529599
17379    19.157737
17380    19.157737
17381    19.157737
dtype: float64

▶ Forecasting for FR
17486    45.162276
17487    41.514989
17488    39.361429
17489    39.745455
17490    37.919006
dtype: float64

▶ Forecasting for HR
2024-12-31 01:00:00    32.076763
2024-12-31 02:00:00    29.104411
2024-12-31 03:00:00    27.296795
2024-12-31 04:0

  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(


In [0]:
# --------------------------------------------
# SECTION 4 — SAVE ARIMA MODELS FOR STREAMLIT
# --------------------------------------------
print("="*80)
print("SECTION 4 — SAVE ARIMA MODELS")
print("="*80)

import os
import pickle

# Use the folder path where you want to save the models
output_dir = "/Workspace/Users/(Enter your account Email for Databricks here)/"
os.makedirs(output_dir, exist_ok=True)

# Create subfolder for ARIMA models
arima_dir = f"{output_dir}/arima_models"
os.makedirs(arima_dir, exist_ok=True)

# Save model for each country
for country, model in arima_models.items():
    model_path = f"{arima_dir}/arima_{country}.pkl"
    
    with open(model_path, "wb") as f:
        pickle.dump(model, f)

    print(f"✔ Saved ARIMA model for {country} → {model_path}")




SECTION 4 — SAVE ARIMA MODELS
✔ Saved ARIMA model for AT → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_AT.pkl
✔ Saved ARIMA model for BE → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_BE.pkl
✔ Saved ARIMA model for DE → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_DE.pkl
✔ Saved ARIMA model for ES → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_ES.pkl
✔ Saved ARIMA model for FR → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_FR.pkl
✔ Saved ARIMA model for HR → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_HR.pkl
✔ Saved ARIMA model for HU → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_HU.pkl
✔ Saved ARIMA model for IT → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_IT.pkl
✔ Saved ARIMA model for LT → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_LT.pkl
✔ Saved ARIMA model for NL → /Workspace/Users/y.hsiao.6666@gmail.com//arima_models/arima_NL.pkl
✔ Saved AR

In [0]:
# --------------------------------------------
# SECTION 5 — ARIMA VALIDATION
# --------------------------------------------

# --------------------------------------------
# 📌 SECTION 5 — ARIMA VALIDATION (PER COUNTRY, NO PLOTS)
# --------------------------------------------

print("="*80)
print("SECTION 5 — ARIMA VALIDATION (MAE & RMSE ONLY)")
print("="*80)

# Load validation dataset
val_spark = spark.table("workspace.default.validation_set_imputed")

# Rename index → timestamp
val_spark = val_spark.withColumnRenamed("index", "timestamp")

# Select required columns
val_spark = val_spark.select(
    "timestamp", "country", "grid_stress_score"
).orderBy("timestamp")

# Convert to pandas
val_df = val_spark.toPandas()
val_df["timestamp"] = pd.to_datetime(val_df["timestamp"])

from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np

validation_results = {}

for country in countries:
    print("\n" + "-"*80)
    print(f"▶ VALIDATING ARIMA for {country}")
    print("-"*80)

    if country not in arima_models:
        print("⚠️ No ARIMA model trained for this country — skipping.")
        continue

    # Filter validation data for this country
    df_val_country = val_df[val_df["country"] == country].copy()

    if len(df_val_country) == 0:
        print("⚠️ No validation data available — skipping.")
        continue

    df_val_country = df_val_country.sort_values("timestamp")
    df_val_country = df_val_country.set_index("timestamp")

    val_ts = df_val_country["grid_stress_score"]

    # Forecast same length as validation set
    model = arima_models[country]
    steps = len(val_ts)
    fc = model.predict(n_periods=steps)

    # Compute MAE and RMSE
    mae = mean_absolute_error(val_ts, fc)
    rmse = np.sqrt(mean_squared_error(val_ts, fc))

    validation_results[country] = {"MAE": mae, "RMSE": rmse}

    print(f"MAE  = {mae:.3f}")
    print(f"RMSE = {rmse:.3f}")

# Summary of all countries
print("\n" + "="*80)
print("FINAL VALIDATION SUMMARY")
print("="*80)

for c, metrics in validation_results.items():
    print(f"{c}:  MAE={metrics['MAE']:.3f}   RMSE={metrics['RMSE']:.3f}")



SECTION 5 — ARIMA VALIDATION (MAE & RMSE ONLY)

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for AT
--------------------------------------------------------------------------------
MAE  = 12.865
RMSE = 15.896

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for BE
--------------------------------------------------------------------------------
MAE  = 10.001
RMSE = 12.811

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for DE
--------------------------------------------------------------------------------


  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(


MAE  = 13.184
RMSE = 16.114

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for ES
--------------------------------------------------------------------------------
MAE  = 12.048
RMSE = 15.445

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for FR
--------------------------------------------------------------------------------
MAE  = 12.773
RMSE = 14.800

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for HR
--------------------------------------------------------------------------------


  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(


MAE  = 12.647
RMSE = 16.548

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for HU
--------------------------------------------------------------------------------
MAE  = 15.440
RMSE = 18.503

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for IT
--------------------------------------------------------------------------------
MAE  = 14.629
RMSE = 18.165

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for LT
--------------------------------------------------------------------------------


  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(


MAE  = 14.275
RMSE = 18.573

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for NL
--------------------------------------------------------------------------------
MAE  = 14.609
RMSE = 18.419

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for PL
--------------------------------------------------------------------------------
MAE  = 12.623
RMSE = 15.484

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for PT
--------------------------------------------------------------------------------


  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(
  return get_prediction_index(


MAE  = 10.690
RMSE = 14.492

--------------------------------------------------------------------------------
▶ VALIDATING ARIMA for SK
--------------------------------------------------------------------------------
⚠️ No validation data available — skipping.

FINAL VALIDATION SUMMARY
AT:  MAE=12.865   RMSE=15.896
BE:  MAE=10.001   RMSE=12.811
DE:  MAE=13.184   RMSE=16.114
ES:  MAE=12.048   RMSE=15.445
FR:  MAE=12.773   RMSE=14.800
HR:  MAE=12.647   RMSE=16.548
HU:  MAE=15.440   RMSE=18.503
IT:  MAE=14.629   RMSE=18.165
LT:  MAE=14.275   RMSE=18.573
NL:  MAE=14.609   RMSE=18.419
PL:  MAE=12.623   RMSE=15.484
PT:  MAE=10.690   RMSE=14.492
