In [0]:
# --------------------------------------------
# SECTION 0 — Install Dependencies
# --------------------------------------------
# Installs the forecasting libraries used below:
#   pmdarima    -> auto_arima (Section 4)
#   statsmodels -> adfuller / ARIMA (Sections 3 and 5)
#   prophet     -> Prophet (Section 10)
# TODO(review): pin versions (e.g. `pmdarima==X.Y.Z`) so a re-run reproduces results.
# NOTE(review): matplotlib, numpy, pandas and scikit-learn are imported later but
# are not installed here — assumed to be preinstalled on the cluster runtime; confirm.

%pip install pmdarima statsmodels prophet
# Databricks utility: restart the Python process so the packages installed above
# can be imported by the following cells.
dbutils.library.restartPython()


In [0]:
# --------------------------------------------
# SECTION 1 — Load Training + Validation Data
# --------------------------------------------

rule = "=" * 80
print(rule)
print("SECTION 1 — LOAD TRAIN + VALIDATION DATA")
print(rule)

from pyspark.sql import functions as F
import pandas as pd

# Spark source tables for the two splits (names suggest imputed, time-binned,
# lagged features — see the table definitions for details).
train_spark = spark.table("workspace.default.train_imputed_timebins_lags")
val_spark   = spark.table("workspace.default.validation_imputed_timebins_lags")

# Columns pulled onto the driver; every country is collected here and the
# per-country filter happens later in pandas.
cols = ["timestamp", "country", "grid_stress_score", "mean_temperature_c", "Actual_Load"]


def _collect_split(sdf):
    """Select `cols` from a Spark DataFrame, collect it to pandas, and parse `timestamp`."""
    pdf = sdf.select(cols).toPandas()
    pdf["timestamp"] = pd.to_datetime(pdf["timestamp"])
    return pdf


df_train = _collect_split(train_spark)
df_val   = _collect_split(val_spark)

print(df_train.head())
for label, frame in (("Train rows:", df_train), ("Validation rows:", df_val)):
    print(label, len(frame))



In [0]:
# --------------------------------------------
# SECTION 2 — Select Country & Build Time Series
# --------------------------------------------

country = "DE"   # <--- CHANGE THIS

print("="*80)
print("SECTION 2 — SELECT COUNTRY & BUILD TIME SERIES")
print("="*80)

print("Available countries:", df_train["country"].unique())


# TRAIN: rows for the selected country in chronological order, and the
# modelling target as a timestamp-indexed series.
df_train_c = df_train[df_train["country"] == country].sort_values("timestamp")
ts_train = df_train_c.set_index("timestamp")["grid_stress_score"]

# VALIDATION: same selection on the hold-out split.
df_val_c = df_val[df_val["country"] == country].sort_values("timestamp")
ts_val = df_val_c.set_index("timestamp")["grid_stress_score"]

# Fail fast when `country` is mistyped or absent from a split. An empty series
# would otherwise only surface later as an obscure error inside adfuller /
# auto_arima / Prophet, or as a zero-step forecast horizon.
if ts_train.empty or ts_val.empty:
    raise ValueError(
        f"Country {country!r} has {len(ts_train)} train rows and "
        f"{len(ts_val)} validation rows; both must be non-empty. "
        f"Countries in train: {df_train['country'].unique().tolist()}"
    )

# NOTE(review): the DatetimeIndex carries no explicit frequency; statsmodels will
# try to infer one when fitting ARIMA — confirm the series is regularly spaced.
print("Train length:", len(ts_train))
print("Validation length:", len(ts_val))
print(ts_train.head())


In [0]:
# --------------------------------------------
# SECTION 3 — ADF Test (Stationarity)
# --------------------------------------------

banner = "=" * 80
print(banner)
print("SECTION 3 — ADF STATIONARITY TEST")
print(banner)

from statsmodels.tsa.stattools import adfuller

# Augmented Dickey-Fuller test on the training target (NaNs dropped).
# The verdict is informational only: auto_arima in Section 4 chooses `d` itself.
ADF_ALPHA = 0.05
adf = adfuller(ts_train.dropna())
adf_stat, adf_pvalue = adf[0], adf[1]

print("ADF Statistic:", adf_stat)
print("p-value:", adf_pvalue)

is_stationary = adf_pvalue < ADF_ALPHA
print("✓ Stationary → d = 0" if is_stationary
      else "✗ Not stationary → differencing needed (d = 1)")


In [0]:
# --------------------------------------------
# SECTION 4 — Auto ARIMA Parameter Search
# --------------------------------------------

banner = "=" * 80
print(banner)
print("SECTION 4 — AUTO ARIMA PARAMETER SEARCH")
print(banner)

from pmdarima import auto_arima

# Search non-seasonal (p, d, q) orders only; the seasonal variant lives in the
# (currently disabled) SARIMA section.
search_options = {
    "seasonal": False,
    "stepwise": True,            # stepwise search instead of an exhaustive one
    "trace": True,               # print each candidate model as it is evaluated
    "suppress_warnings": True,
}
auto_arima_model = auto_arima(ts_train, **search_options)

# p, d, q are module-level names consumed by Section 5.
best_order = auto_arima_model.order
p, d, q = best_order
print(f"Selected ARIMA(p,d,q) = ({p},{d},{q})")


In [0]:
# --------------------------------------------
# SECTION 5 — Fit ARIMA Model
# --------------------------------------------

banner = "=" * 80
print(banner)
print("SECTION 5 — FIT ARIMA")
print(banner)

from statsmodels.tsa.arima.model import ARIMA

# Refit the order selected in Section 4 with statsmodels, which provides the
# results object used for `summary()` here and `forecast()` in Section 6.
# NOTE(review): pmdarima may have selected the order with an intercept
# (`with_intercept`), while statsmodels' ARIMA applies its own default `trend`;
# the refit can then differ from the searched model — confirm if it matters.
selected_order = (p, d, q)
arima = ARIMA(ts_train, order=selected_order)
arima_fit = arima.fit()

print(arima_fit.summary())


In [0]:
# --------------------------------------------
# SECTION 6 — Forecast ARIMA for Validation Horizon
# --------------------------------------------

banner = "=" * 80
print(banner)
print("SECTION 6 — ARIMA FORECAST")
print(banner)

# Forecast horizon = number of validation observations, so the forecast can be
# compared position-by-position with ts_val (plots and metrics below do this).
steps = ts_val.shape[0]
arima_fc = arima_fit.forecast(steps)

print(arima_fc.head())


In [0]:
# --------------------------------------------
# SECTION 7 — PLOT: ARIMA Forecast vs Actual
# --------------------------------------------

import matplotlib.pyplot as plt

# Training tail for context, validation truth, and the ARIMA forecast drawn on
# the validation timestamps (aligned by position).
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(ts_train.iloc[-200:], label="Train (last 200)")
ax.plot(ts_val, label="Validation Actual")
ax.plot(ts_val.index, arima_fc, label="ARIMA Forecast")
ax.set_title(f"ARIMA Forecast vs Actual — {country}")
ax.legend()
plt.show()


In [0]:
# --------------------------------------------
# SECTION 8 — Fit SARIMA (Seasonal ARIMA)
# --------------------------------------------
# DISABLED: this whole section is commented out. The SARIMA entries in the
# metrics table (Section 12) and the comparison plot (Section 13) are commented
# out as well.
# NOTE(review) before re-enabling:
#   * `p, d, q = sarima_auto.order` overwrites the non-seasonal order chosen in
#     Section 4; use distinct names if Section 5 may be re-run afterwards.
#   * `P, D, Q, m = ...` binds `m`, which Section 10 later rebinds to the Prophet model.
#   * m=24 is labelled "daily seasonality", which assumes hourly observations — confirm.
#   * `steps` (forecast horizon) is defined in Section 6.

#print("="*80)
#print("SECTION 8 — AUTO SARIMA")
#print("="*80)

#sarima_auto = auto_arima(
#    ts_train,
#    seasonal=True,
#    m=24,          # daily seasonality
#    trace=True,
#    suppress_warnings=True
#)

#p, d, q = sarima_auto.order
#P, D, Q, m = sarima_auto.seasonal_order
#print("SARIMA order:", (p,d,q))
#print("Seasonal order:", (P,D,Q,m))

#from statsmodels.tsa.statespace.sarimax import SARIMAX

#sarima = SARIMAX(ts_train, order=(p,d,q), seasonal_order=(P,D,Q,m))
#sarima_fit = sarima.fit()

#sarima_fc = sarima_fit.forecast(steps=steps)


In [0]:
# --------------------------------------------
# SECTION 9 — PLOT: SARIMA Forecast
# --------------------------------------------
# DISABLED: depends on `sarima_fc`, which is only produced by the commented-out
# Section 8. Re-enable both sections together.

#plt.figure(figsize=(14,5))
#plt.plot(ts_train[-200:], label="Train (last 200)")
#plt.plot(ts_val, label="Validation Actual")
#plt.plot(ts_val.index, sarima_fc, label="SARIMA Forecast")
#plt.title(f"SARIMA Forecast vs Actual — {country}")
#plt.legend()
#plt.show()


In [0]:
# --------------------------------------------
# SECTION 10 — Prophet Model
# --------------------------------------------

print("="*80)
print("SECTION 10 — PROPHET")
print("="*80)

from prophet import Prophet

# Prophet's input contract: a `ds` timestamp column and a `y` target column.
# Only those two are passed; no extra regressors are registered, so the other
# columns would not be used by the model anyway.
df_prophet = df_train_c.rename(columns={
    "timestamp": "ds",
    "grid_stress_score": "y"
})[["ds", "y"]]

m = Prophet(daily_seasonality=True)
m.fit(df_prophet)

# Predict over the training history plus the *actual* validation timestamps.
# make_future_dataframe(periods=..., freq="H") would instead generate evenly
# spaced hourly stamps right after the last training row. Those stamps drift out
# of alignment with ts_val whenever the validation window has gaps or does not
# start exactly one hour later, and "H" is a deprecated alias in recent pandas.
future = pd.concat(
    [df_prophet[["ds"]], pd.DataFrame({"ds": ts_val.index})],
    ignore_index=True,
)
forecast = m.predict(future)

# Prophet returns predictions sorted by `ds`. With a time-ordered split the
# validation rows are therefore the last len(ts_val) rows. tail() (unlike
# iloc[-n:]) stays correct when n == 0.
prophet_fc = forecast.set_index("ds")["yhat"].tail(len(ts_val))


In [0]:
# --------------------------------------------
# SECTION 11 — PLOT: Prophet Forecast
# --------------------------------------------

# Training tail for context, validation truth, and the Prophet forecast drawn
# on the validation timestamps.
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(ts_train.iloc[-200:], label="Train (last 200)")
ax.plot(ts_val, label="Validation Actual")
ax.plot(ts_val.index, prophet_fc, label="Prophet Forecast")
ax.set_title(f"PROPHET Forecast vs Actual — {country}")
ax.legend()
plt.show()

# Prophet's own figure: observed history points plus the forecast and its
# uncertainty band.
m.plot(forecast);


In [0]:
# --------------------------------------------
# SECTION 12 — Error Metrics (ARIMA, SARIMA, Prophet)
# --------------------------------------------

from sklearn.metrics import mean_squared_error, mean_absolute_error
import numpy as np

def mape(y, yhat):
    """Mean absolute percentage error, in percent.

    Points where the actual value ``y`` is exactly zero are excluded. The
    percentage error is undefined there, and dividing by zero would otherwise
    turn the whole metric into ``inf``/``nan``.

    Parameters
    ----------
    y : array-like
        Observed values.
    yhat : array-like
        Predicted values, same length as ``y``.

    Returns
    -------
    float
        MAPE over the non-zero actuals, or ``nan`` if every actual is zero
        (or the input is empty).
    """
    y = np.asarray(y, dtype=float)
    yhat = np.asarray(yhat, dtype=float)
    nonzero = y != 0
    if not nonzero.any():
        return float("nan")
    return float(np.mean(np.abs((y[nonzero] - yhat[nonzero]) / y[nonzero])) * 100)

# Validation truth and each model's point forecast as plain NumPy arrays.
# Comparison is positional, so every array must hold len(ts_val) values.
y_true = ts_val.values
y_arima = np.array(arima_fc)
y_prophet = np.array(prophet_fc)
# SARIMA is left out while Sections 8-9 are commented out.

# Model name -> predictions; the insertion order becomes the table's row order.
model_predictions = {"ARIMA": y_arima, "Prophet": y_prophet}

# MSE per model (RMSE is derived from it below, as its square root).
mse_by_model = {
    name: mean_squared_error(y_true, preds)
    for name, preds in model_predictions.items()
}
arima_mse = mse_by_model["ARIMA"]
prophet_mse = mse_by_model["Prophet"]

results = pd.DataFrame(
    [
        {
            "Model": name,
            "RMSE": mse_by_model[name] ** 0.5,
            "MAE": mean_absolute_error(y_true, preds),
            "MAPE (%)": mape(y_true, preds),
        }
        for name, preds in model_predictions.items()
    ]
)

print("Forecast accuracy:")
display(results)


In [0]:
# --------------------------------------------
# SECTION 13 — Combined Comparison Plot
# --------------------------------------------

# Validation truth with every enabled model's forecast overlaid, all drawn on
# the validation timestamps. (SARIMA omitted while its sections are disabled.)
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(ts_val, label="Actual")
for label, fc in (("ARIMA", arima_fc), ("Prophet", prophet_fc)):
    ax.plot(ts_val.index, fc, label=label)
ax.set_title(f"Forecast Comparison — {country}")
ax.legend()
plt.show()
