In [0]:
# --------------------------------------------
# SECTION 0 — Install 2 Python TS forecasting libraries
# --------------------------------------------
# pmdarima provides `auto_arima` (used in Section 2); statsmodels is installed alongside it.
# `%pip` (rather than `!pip`) installs into the notebook kernel's own environment.
# NOTE(review): versions are unpinned. The models pickled in Section 4 will most likely
# need the same pmdarima/statsmodels versions to unpickle in the Streamlit app —
# consider pinning, e.g. `%pip install pmdarima==X.Y.Z statsmodels==A.B.C`.

%pip install pmdarima statsmodels


In [0]:
# --------------------------------------------
# SECTION 1 — Load Data from Databricks
# --------------------------------------------

print("=" * 80)
print("SECTION 1 — LOAD DATA FROM DATABRICKS")
print("=" * 80)

from pyspark.sql import functions as F
import pandas as pd

# Read the imputed training table, expose its `index` column as `timestamp`,
# keep only the columns used later in the notebook, and sort chronologically.
df_spark = (
    spark.table("workspace.default.train_set_imputed")
    .withColumnRenamed("index", "timestamp")
    .select(
        "timestamp", "country", "grid_stress_score",
        "mean_temperature_c", "Actual_Load",
    )
    .orderBy("timestamp")
)

# Collect to the driver as pandas and parse the timestamp column into datetimes
# (assign on an existing column replaces it in place, keeping column order).
df = df_spark.toPandas().assign(
    timestamp=lambda frame: pd.to_datetime(frame["timestamp"])
)

print(df.head())
print(df.shape)


In [0]:
# --------------------------------------------
# SECTION 2 — Prepare TS & Fit ARIMA per country
# --------------------------------------------

# For each country: build its grid_stress_score series (sorted and indexed by
# timestamp) and fit a non-seasonal ARIMA with pmdarima's stepwise auto_arima.
# Series shorter than 30 points are skipped.
# Results: `arima_models` maps country -> fitted model, and `series_info` maps
# country -> {"n_obs": number of training points}.

print("=" * 80)
print("SECTION 2 — PREPARE TIME SERIES (PER COUNTRY)")
print("=" * 80)

from pmdarima import auto_arima

countries = sorted(df["country"].unique())
print("Countries found:", countries)

arima_models = {}   # country -> fitted auto_arima model
series_info = {}    # country -> {"n_obs": training length}

for country in countries:
    separator = "-" * 80
    print("\n" + separator)
    print(f"▶ Training ARIMA for country: {country}")
    print(separator)

    # Univariate series for this country: grid stress, chronological, time-indexed.
    ts = (
        df[df["country"] == country]
        .sort_values("timestamp")
        .set_index("timestamp")["grid_stress_score"]
    )
    n_obs = len(ts)

    print("Head:")
    print(ts.head())
    print("Length:", n_obs)

    # Minimum series length required before attempting a fit.
    if n_obs < 30:
        print("⚠️ Not enough observations, skipping.")
        continue

    # Stepwise order search; seasonality is deliberately disabled.
    model = auto_arima(
        ts,
        seasonal=False,
        trace=False,
        suppress_warnings=True,
        stepwise=True,
    )

    arima_models[country] = model
    series_info[country] = {"n_obs": n_obs}

    print(f"✔ Fitted ARIMA for {country}, order={model.order}")


In [0]:
# --------------------------------------------
# SECTION 3 — Forecast with ARIMA models
# --------------------------------------------

# Produce an `n_periods`-step-ahead forecast from each fitted ARIMA model and
# store it in `arima_forecasts` (country -> forecast) for later use or plotting.
# The forecast starts immediately after the last observation each model was
# trained on. Steps are in the spacing of the training data (presumably
# hourly, which would make this a 6-hour horizon — TODO confirm).

print("="*80)
print("SECTION 3 — FORECAST WITH ARIMA (PER COUNTRY)")
print("="*80)

arima_forecasts = {}

n_periods = 6  # forecast horizon, in training-data time steps

for country, model in arima_models.items():
    print(f"\n▶ Forecasting for {country}")
    fc = model.predict(n_periods=n_periods)
    arima_forecasts[country] = fc
    # Show the whole horizon; slicing to [:5] used to hide the last step.
    print(fc)



In [0]:
# --------------------------------------------
# SECTION 4 — SAVE ARIMA MODELS FOR STREAMLIT - Enter your email account for databricks!
# --------------------------------------------

print("="*80)
print("SECTION 4 — SAVE ARIMA MODELS")
print("="*80)

import os
import pickle

# Workspace folder to save the models in. Replace the placeholder with your
# Databricks user (email) before running this cell.
output_dir = "/Workspace/Users/(Enter your email account for Databricks here)/"

# Fail fast rather than creating a folder literally named after the placeholder
# (or failing later with an opaque permission error).
if "(Enter your email" in output_dir:
    raise ValueError(
        "Set `output_dir` to your Databricks workspace folder, e.g. "
        "'/Workspace/Users/<your-email>/', before saving the ARIMA models."
    )

# Subfolder for ARIMA models; makedirs also creates output_dir if it is missing.
# os.path.join avoids the '//' produced by concatenating onto a trailing '/'.
arima_dir = os.path.join(output_dir, "arima_models")
os.makedirs(arima_dir, exist_ok=True)

# One pickle per country: arima_<country>.pkl.
# NOTE(review): pickles can execute arbitrary code when loaded — only load them
# from this trusted location, and with compatible pmdarima/statsmodels versions.
for country, model in arima_models.items():
    model_path = os.path.join(arima_dir, f"arima_{country}.pkl")

    with open(model_path, "wb") as f:
        pickle.dump(model, f)

    print(f"✔ Saved ARIMA model for {country} → {model_path}")




In [0]:
# --------------------------------------------
# SECTION 5 — ARIMA VALIDATION
# --------------------------------------------

# Score each country's ARIMA model on the held-out validation table with MAE and RMSE.
# Each model forecasts len(validation) steps, starting right after its last TRAINING
# observation. The scores are therefore only meaningful if the validation period
# continues the training period with the same spacing; a warning is printed otherwise.
# NOTE(review): one forecast across the whole validation window drifts toward the
# series mean for long horizons. A rolling short-horizon evaluation (e.g. refitting or
# pmdarima's `update` on a copy of the model) would better match the 6-step use case.

print("="*80)
print("SECTION 5 — ARIMA VALIDATION (MAE & RMSE ONLY)")
print("="*80)

from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np

# Load validation dataset, with the same `index` -> `timestamp` rename as the training table
val_spark = (
    spark.table("workspace.default.validation_set_imputed")
    .withColumnRenamed("index", "timestamp")
    .select("timestamp", "country", "grid_stress_score")
    .orderBy("timestamp")
)

# Convert to pandas
val_df = val_spark.toPandas()
val_df["timestamp"] = pd.to_datetime(val_df["timestamp"])

validation_results = {}

for country in countries:
    print("\n" + "-"*80)
    print(f"▶ VALIDATING ARIMA for {country}")
    print("-"*80)

    if country not in arima_models:
        print(" No ARIMA model trained for this country — skipping.")
        continue

    # Filter validation data for this country
    df_val_country = val_df[val_df["country"] == country].copy()

    if len(df_val_country) == 0:
        print(" No validation data available — skipping.")
        continue

    val_ts = (
        df_val_country.sort_values("timestamp")
        .set_index("timestamp")["grid_stress_score"]
    )

    # The first forecast step lies one training interval after the last training
    # timestamp. If validation starts anywhere else (gap, overlap with training,
    # different spacing), every forecast step is scored against the wrong time.
    train_times = df.loc[df["country"] == country, "timestamp"].sort_values()
    expected_start = train_times.iloc[-1] + train_times.diff().median()
    if val_ts.index[0] != expected_start:
        print(
            f"⚠️ Validation starts at {val_ts.index[0]}, but the forecast starts at "
            f"{expected_start}; MAE/RMSE below compare misaligned time steps."
        )

    # Forecast as many steps as there are validation rows
    model = arima_models[country]
    fc = model.predict(n_periods=len(val_ts))

    # sklearn converts both inputs to arrays, i.e. compares them by position,
    # not by timestamp — which is why the alignment check above matters.
    mae = mean_absolute_error(val_ts, fc)
    rmse = np.sqrt(mean_squared_error(val_ts, fc))

    validation_results[country] = {"MAE": mae, "RMSE": rmse}

    print(f"MAE  = {mae:.3f}")
    print(f"RMSE = {rmse:.3f}")

# Summary of all countries
print("\n" + "="*80)
print("FINAL VALIDATION SUMMARY")
print("="*80)

for c, metrics in validation_results.items():
    print(f"{c}:  MAE={metrics['MAE']:.3f}   RMSE={metrics['RMSE']:.3f}")

