In [0]:
import yaml
from pathlib import Path

# Project-level settings shared with the rest of the pipeline.
# The path is relative to the kernel's working directory — assumes the notebook is run
# from its own folder (TODO confirm for scheduled/job runs).
CONFIG_PATH = Path("../config.yaml")

# Explicit UTF-8: without it, open() decodes with the platform's locale encoding,
# so the same YAML could be read differently on different machines.
with CONFIG_PATH.open("r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Echo the Databricks catalog, schema and volume from the config for a quick visual check.
config['databricks']['catalog'], config['databricks']['schema'], config["databricks"]["volume"]

In [0]:
def running_on_databricks():
    """Report whether the notebook is executing inside a Databricks runtime.

    Detection is based on whether the ``pyspark.dbutils`` module can be imported;
    any ImportError is taken to mean "not on Databricks".

    Returns:
        bool: True if ``pyspark.dbutils`` imports successfully, otherwise False.
    """
    try:
        import pyspark.dbutils  # noqa: F401 — imported only as an availability probe
    except ImportError:
        return False
    return True


# Evaluated once; later cells branch on this flag (data source, output target).
IS_DATABRICKS = running_on_databricks()
print(IS_DATABRICKS)

In [0]:
from pyspark.sql import functions as F
# from helper import run_forecast, aggregate_to_granularity, build_features, train_test_split

from helper import (
    aggregate_to_granularity, assert_unique_series_rows, build_features,
    train_test_split, model_factory, assemble_global_pipeline, fit_global_model, predict_global,
    compute_metrics, fit_predict_local, rolling_backtest, run_forecast, plot_forecast, plot_train_test_forecast)

from pyspark.sql import SparkSession

# Runtime SQL settings: these can be changed on an already-running session,
# so they are applied in every environment (same as before).
SQL_RUNTIME_CONF = {
    "spark.sql.shuffle.partitions": "16",
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.coalescePartitions.enabled": "true",
}

# Launch-time resource settings sized for a local workstation with ~32 GB RAM.
# Driver/executor memory, maxResultSize and default parallelism only take effect when
# the driver JVM / SparkContext is first started, so they are applied only when this
# notebook creates its own local session. On Databricks the cluster already owns them.
LOCAL_LAUNCH_CONF = {
    "spark.driver.memory": "12g",
    "spark.executor.memory": "12g",
    "spark.driver.maxResultSize": "4g",
    "spark.default.parallelism": "8",
}

spark_conf = SQL_RUNTIME_CONF if IS_DATABRICKS else {**LOCAL_LAUNCH_CONF, **SQL_RUNTIME_CONF}

spark_builder = SparkSession.builder.appName("TimeSeriesForecast")
for conf_key, conf_value in spark_conf.items():
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()


In [0]:
# Load raw training data (must include columns: date, sales, family, store_nbr).
# Databricks reads the silver Delta table; a local run reads the parquet export instead.
if IS_DATABRICKS:
    raw_source = spark.read.format("delta").table('portfolio_catalog.databricks_pipeline.silver_training')
else:
    raw_source = spark.read.parquet('../notebooks/data/train.parquet')

# Cast `date` with to_date in both environments so either source yields the same column type.
df_raw = raw_source.withColumn("date", F.to_date(F.col("date")))


In [0]:
# Forecasting run configuration: one global model trained across all (family, store_nbr) series.
cfg = {
    "data": {
        "date_col": "date",
        "target_col": "sales",
        "group_cols": ["family", "store_nbr"],
        "freq": "D",
        "min_train_periods": 56,
    },
    "aggregation": {
        "target_agg": "sum",
        "extra_numeric_aggs": {"dcoilwtico": "mean", "onpromotion": "sum"},
    },
    "features": {"lags": [1, 7, 14, 28], "mas": [7, 28], "add_time_signals": True},
    "split": {"mode": "horizon", "train_end_date": "", "test_horizon": 28},
    "model": {"type": "spark_gbt", "params": {"maxDepth": 7, "maxIter": 120}},
    # Alternative: "model": {"type": "spark_lgbt", "params": {"maxDepth": 7, "maxIter": 120}},
    "evaluation": {
        "mase_seasonality": 7,
        # NOTE(review): "enabled" is not read anywhere — the backtest below is commented out.
        "backtest": {"enabled": True, "folds": 4, "fold_horizon": 14, "step": 14},
    },
}

# Column roles shared by every step below.
DATE_COL = cfg["data"]["date_col"]
TARGET_COL = cfg["data"]["target_col"]
GROUP_COLS = cfg["data"]["group_cols"]

# --- Step 1: Features (pre_aggregate=True with the configured target / extra aggregations) ---
df_feat = build_features(
    df_raw, DATE_COL, TARGET_COL,
    GROUP_COLS, cfg["features"]["lags"], cfg["features"]["mas"],
    cfg["features"]["add_time_signals"],
    pre_aggregate=True,
    target_agg=cfg["aggregation"]["target_agg"],
    extra_numeric_aggs=cfg["aggregation"].get("extra_numeric_aggs"),
)
display(df_feat.limit(5))

# --- Step 2: Split ---
split_cfg = cfg["split"]
train, test = train_test_split(
    df_feat, DATE_COL, GROUP_COLS,
    split_cfg["mode"], split_cfg["train_end_date"], split_cfg["test_horizon"],
    cfg["data"]["min_train_periods"],
)

# --- Step 3: Train (global model) ---
est = model_factory(cfg["model"]["type"], cfg["model"]["params"])
# Anything that is not a series key, the date, the raw target or a "label" column is a feature.
non_feature_cols = GROUP_COLS + [DATE_COL, TARGET_COL, "label"]
feature_cols = [col_name for col_name in train.columns if col_name not in non_feature_cols]
model = fit_global_model(train, TARGET_COL, GROUP_COLS, feature_cols, est)

# --- Step 4: Predict ---
pred = predict_global(model, test, GROUP_COLS, DATE_COL, TARGET_COL)
display(pred.limit(10))

# --- Step 5: Metrics ---
# "y" / "prediction" are the actual / forecast columns of `pred` — presumably named so by
# predict_global (TODO confirm against helper).
by_series, portfolio = compute_metrics(
    pred, DATE_COL, "y", "prediction",
    GROUP_COLS, cfg["evaluation"]["mase_seasonality"],
)
display(by_series.orderBy("wMAPE"))
display(portfolio)

# # --- Optional: Backtest ---
# # NOTE(review): aggregate_to_granularity and rolling_backtest are also imported from `helper`
# # in the imports cell; confirm their signatures match before switching the import below.
# from smartforecast.forecasting import aggregate_to_granularity, rolling_backtest
# df_agg = aggregate_to_granularity(df_raw, cfg["data"]["date_col"], cfg["data"]["target_col"],
#                                   cfg["data"]["group_cols"], cfg["aggregation"]["target_agg"],
#                                   cfg["aggregation"].get("extra_numeric_aggs"))
# bt = rolling_backtest(df_agg, cfg["data"]["date_col"], cfg["data"]["target_col"], cfg["data"]["group_cols"],
#                       feature_params={"lags": cfg["features"]["lags"], "mas": cfg["features"]["mas"], "add_time_signals": cfg["features"]["add_time_signals"], "freq": cfg["data"]["freq"]},
#                       model_type=cfg["model"]["type"], model_params=cfg["model"]["params"],
#                       folds=cfg["evaluation"]["backtest"]["folds"], fold_horizon=cfg["evaluation"]["backtest"]["fold_horizon"],
#                       step=cfg["evaluation"]["backtest"]["step"], mase_seasonality=cfg["evaluation"]["mase_seasonality"])
# display(bt)


In [0]:
# Persist validation predictions to the silver layer.
# The three-level Unity Catalog name is only resolvable on Databricks. The local session
# built above configures no such catalog, and the local run reads from a parquet file rather
# than a table, so writing there is expected to fail. Skip it locally instead of breaking
# Restart & Run All.
VALIDATION_PREDICTIONS_TABLE = 'portfolio_catalog.databricks_pipeline.silver_validation_predictions'

if IS_DATABRICKS:
    (
        pred.write
        .mode('overwrite')
        .option("mergeSchema", "true")  # Delta: allow new columns to be added to the table schema
        .saveAsTable(VALIDATION_PREDICTIONS_TABLE)
    )
else:
    print(f"Not running on Databricks: skipping write to {VALIDATION_PREDICTIONS_TABLE}")