In [0]:
# Declare one text widget per notebook parameter; each defaults to "".
for widget_name in ("catalog", "db", "model", "run_id", "user"):
    dbutils.widgets.text(widget_name, "")

In [0]:
# The dependency set installed in this cell depends on the selected model:
#   * Chronos2     -> base requirements, then chronos-forecasting on top
#   * TimesFM 2.5  -> base requirements, then TimesFM cloned and installed from source (per HF docs)
#   * other TimesFM -> base requirements, then a pinned timesfm release from PyPI
#   * anything else -> the matching requirements-<global|foundation>.txt file
model_name: str = dbutils.widgets.get("model")

import subprocess
import os
import shutil

def _pip_install(args):
    get_ipython().run_line_magic('pip', f"install {' '.join(args)} --quiet")

# Classify the requested model once; the checks are substring matches on the
# widget value and are evaluated in the same priority order as before
# (Chronos2, then TimesFM 2.5, then any other TimesFM).
wants_chronos2 = "Chronos2" in model_name
wants_timesfm_2_5 = "TimesFM_2_5" in model_name
wants_other_timesfm = "TimesFM_" in model_name

if wants_chronos2 or wants_timesfm_2_5 or wants_other_timesfm:
    # All three special cases start from the base MMF requirements.
    _pip_install(["-r", "../requirements.txt"])

    if wants_chronos2:
        # Latest chronos-forecasting, which has Chronos2Pipeline.
        _pip_install(["chronos-forecasting>=2.0.0"])
    elif wants_timesfm_2_5:
        # TimesFM 2.5 comes from source (per https://huggingface.co/google/timesfm-2.5-200m-pytorch).
        # Pin numpy/pandas first to avoid binary incompatibilities in the runtime.
        _pip_install(["numpy==1.26.4", "pandas==2.2.0"])
        # Start from a fresh clone at a stable path, then install it in editable mode.
        repo_dir = "/local_disk0/timesfm"
        if os.path.exists(repo_dir):
            shutil.rmtree(repo_dir)
        clone_cmd = ["git", "clone", "https://github.com/google-research/timesfm.git", repo_dir]
        subprocess.run(clone_cmd, check=True)
        _pip_install(["-e", repo_dir])
    else:
        # Older TimesFM models use a pinned PyPI release.
        _pip_install(["timesfm[torch]==1.3.0"])
else:
    # Every other model uses one of the two per-class requirements files.
    requirements_flavor = "global" if "NeuralForecast" in model_name else "foundation"
    _pip_install(["-r", f"../requirements-{requirements_flavor}.txt"])

# Restart the Python process after installing packages.
dbutils.library.restartPython()

In [0]:
# Read the notebook parameters, then put the deployed MMF bundle on the Python path.
import os
import sys

catalog = dbutils.widgets.get("catalog")
db = dbutils.widgets.get("db")
model = dbutils.widgets.get("model")
run_id = dbutils.widgets.get("run_id")
user = dbutils.widgets.get("user")

# Prefer the bundle deployed under the `user` widget's workspace home so the
# notebook is not tied to one developer's account. Fall back to the original
# hard-coded location when `user` is empty or that directory does not exist.
# NOTE(review): assumes `user` is the workspace user name that deployed the
# bundle (it is also used for the experiment path) — confirm.
_default_bundle_files = "/Workspace/Users/rohan.parikh@databricks.com/.bundle/mmf_demo/dev/files"
_user_bundle_files = f"/Workspace/Users/{user}/.bundle/mmf_demo/dev/files"
if user and os.path.isdir(_user_bundle_files):
    mmf_bundle_path = _user_bundle_files
else:
    mmf_bundle_path = _default_bundle_files

# Skip the insert if the entry is already there, so re-running the cell does
# not stack duplicates.
if mmf_bundle_path not in sys.path:
    sys.path.insert(0, mmf_bundle_path)

In [0]:
from mmf_sa import run_forecast
import logging
# Raise the py4j loggers to ERROR so lower-severity records are suppressed.
for py4j_logger_name in ("py4j.java_gateway", "py4j.clientserver"):
    logging.getLogger(py4j_logger_name).setLevel(logging.ERROR)

# Every table name is built from the catalog and schema chosen via the widgets.
schema_prefix = f"{catalog}.{db}"
daily_train_table = f"{schema_prefix}.m4_daily_train"

forecast_settings = dict(
    spark=spark,
    # Training and scoring both read the same table.
    train_data=daily_train_table,
    scoring_data=daily_train_table,
    scoring_output=f"{schema_prefix}.daily_scoring_output",
    evaluation_output=f"{schema_prefix}.daily_evaluation_output",
    model_output=schema_prefix,
    group_id="unique_id",
    date_col="ds",
    target="y",
    freq="D",
    prediction_length=10,
    backtest_length=30,
    stride=10,
    metric="smape",
    train_predict_ratio=1,
    data_quality_check=True,
    resample=False,
    # Only the model selected through the widget is run.
    active_models=[model],
    experiment_path=f"/Users/{user}/mmf/m4_daily",
    use_case_name="m4_daily",
    run_id=run_id,
    accelerator="gpu",
)

run_forecast(**forecast_settings)