# Many Models Forecasting Demo

This notebook showcases how to run MMF on serverless compute using local models. We will use [M4 competition](https://www.sciencedirect.com/science/article/pii/S0169207019301128#sec5) data. The descriptions here largely mirror those in the [daily resolution](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/local_univariate_daily.ipynb) notebook, so we skip the redundant parts and focus on the essentials.

**Note that for a large-scale forecasting use case (more than 100 time series), we recommend using non-serverless compute, preferably a multi-node cluster.**

### Serverless Compute setup

Attach this notebook to a [serverless compute](https://docs.databricks.com/aws/en/compute/serverless/notebooks). Then go to the [Configuration](https://docs.databricks.com/aws/en/compute/serverless/dependencies) tab and set the [Environment version](https://docs.databricks.com/aws/en/compute/serverless/dependencies#-select-an-environment-version) to 4. In the Dependencies section, [add the path](https://docs.databricks.com/aws/en/compute/serverless/dependencies#create-common-utilities-to-share-across-your-workspace) to your `many-model-forecasting` directory: e.g., `/Workspace/Users/ryuta.yoshimatsu@databricks.com/many-model-forecasting`. This is required to use the MMF functions within the notebooks.

### Install and import packages

In [0]:
%pip install -r ../../requirements-local.txt --quiet
%pip install datasetsforecast==0.0.8 --quiet
dbutils.library.restartPython()

In [0]:
import logging
import pathlib
import pandas as pd
from datasetsforecast.m4 import M4
from mmf_sa import run_forecast

# Suppress MLflow context resolution warnings on serverless compute
logging.getLogger("mlflow.tracking.context.registry").setLevel(logging.ERROR)

# Suppress py4j verbose logging
logging.getLogger("py4j.clientserver").setLevel(logging.WARNING)
logging.getLogger("py4j.java_gateway").setLevel(logging.WARNING)

### Prepare data
We are using the [`datasetsforecast`](https://github.com/Nixtla/datasetsforecast/tree/main/) package to download the M4 data.

In [0]:
# Number of time series
n = 100


def create_m4_monthly():
    y_df, _, _ = M4.load(directory=str(pathlib.Path.home()), group="Monthly")
    # range() excludes its stop value, so use n + 1 to select exactly n series (M1..Mn)
    target_ids = {f"M{i}" for i in range(1, n + 1)}
    y_df = y_df[y_df["unique_id"].isin(target_ids)]
    y_df = (
        y_df.groupby("unique_id", group_keys=False)
             .apply(lambda g: transform_group(g, g.name))
             .reset_index(drop=True)
    )
    return y_df


def transform_group(df, unique_id):
    if len(df) > 60:
        df = df.iloc[-60:]
    start = pd.Timestamp("2018-01-01")
    date_idx = pd.date_range(start=start, periods=len(df), freq="ME", name="ds")
    res_df = pd.DataFrame({
        "ds": date_idx,
        "unique_id": unique_id,
        "y": df["y"].to_numpy()
    })
    return res_df

We are going to save this data in a Delta Lake table. Provide the catalog and database names where you want to store the data.

In [0]:
catalog = "mmf"  # Name of the catalog we use to manage our assets
db = "m4"  # Name of the schema we use to manage our assets (e.g. datasets)
user = spark.sql('select current_user() as user').collect()[0]['user'] # User email address

In [0]:
# Making sure that the catalog and the schema exist
_ = spark.sql(f"CREATE CATALOG IF NOT EXISTS {catalog}")
_ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db}")

(
    spark.createDataFrame(create_m4_monthly())
    .write.format("delta").mode("overwrite")
    .saveAsTable(f"{catalog}.{db}.serverless_train")
)

Let's take a peek at the dataset:

In [0]:
display(
  spark.sql(f"select unique_id, count(ds) as count from {catalog}.{db}.serverless_train group by unique_id order by unique_id")
  )

In [0]:
display(
  spark.sql(f"select * from {catalog}.{db}.serverless_train where unique_id in ('M1', 'M2', 'M3', 'M4', 'M5') order by unique_id, ds")
  )

In [0]:
# Get the current value of shuffle partitions
current = spark.conf.get("spark.sql.shuffle.partitions")

# If not set to 'auto' (serverless), convert to int; otherwise, use default 200
if current != "auto":
    current_val = int(current)
else:
    current_val = 200

# If n is greater than the current value, update the shuffle partitions setting
if n > current_val:
    spark.conf.set("spark.sql.shuffle.partitions", str(n))

Note that monthly forecasting requires the timestamp column to represent the last day of each month.
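
A quick way to verify this requirement before running MMF is sketched below with pandas. The helper name `non_month_end_rows` and the toy values are made up for illustration and are not part of MMF; the sketch only assumes the `unique_id`, `ds`, `y` layout used above.

```python
import pandas as pd


def non_month_end_rows(df: pd.DataFrame, date_col: str = "ds") -> pd.DataFrame:
    """Return the rows whose timestamp is not the last day of its month."""
    dates = pd.to_datetime(df[date_col])
    return df[~dates.dt.is_month_end]


# Toy frame with made-up values: the second row falls on the first of the month.
toy = pd.DataFrame({
    "unique_id": ["A", "A", "A"],
    "ds": ["2018-01-31", "2018-02-01", "2018-03-31"],
    "y": [1.0, 2.0, 3.0],
})
bad = non_month_end_rows(toy)

# One possible repair: roll each date forward to the end of its month.
# MonthEnd(0) leaves dates that are already month ends unchanged.
fixed = toy.assign(ds=pd.to_datetime(toy["ds"]) + pd.offsets.MonthEnd(0))
```

If `bad` is not empty, the offending rows can be inspected directly. Whether rolling dates forward is appropriate depends on how the source data was aggregated.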

### Models
Let's configure a list of models we are going to apply to our time series for evaluation and forecasting. A comprehensive list of all supported models is available in [mmf_sa/models/README.md](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/README.md). Look for the models where `model_type: local`; these are the local models we import from [statsforecast](https://github.com/Nixtla/statsforecast) and [sktime](https://github.com/sktime/sktime). Check their documentation for a description of each model.

In [0]:
active_models = [
    "StatsForecastBaselineWindowAverage",
    "StatsForecastBaselineSeasonalWindowAverage",
    "StatsForecastBaselineNaive",
    "StatsForecastBaselineSeasonalNaive",
    "StatsForecastAutoArima",
    "StatsForecastAutoETS",
    "StatsForecastAutoCES",
    "StatsForecastAutoTheta",
    "StatsForecastAutoTbats",
    "StatsForecastAutoMfles",
    "StatsForecastTSB",
    "StatsForecastADIDA",
    "StatsForecastIMAPA",
    "StatsForecastCrostonClassic",
    "StatsForecastCrostonOptimized",
    "StatsForecastCrostonSBA",
    # Prophet has compatibility issues on serverless compute
]

### Run MMF

Now we can run the evaluation and forecasting using the `run_forecast` function defined in [mmf_sa/__init__.py](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/__init__.py). Make sure to set `freq="M"` in the `run_forecast` call.

In [0]:
run_forecast(
    spark=spark,
    train_data=f"{catalog}.{db}.serverless_train",
    scoring_data=f"{catalog}.{db}.serverless_train",
    scoring_output=f"{catalog}.{db}.serverless_scoring_output",
    evaluation_output=f"{catalog}.{db}.serverless_evaluation_output",
    group_id="unique_id",
    date_col="ds",
    target="y",
    freq="M",
    prediction_length=3,
    backtest_length=12,
    stride=1,
    metric="smape",
    train_predict_ratio=1,
    data_quality_check=True,
    resample=False,
    active_models=active_models,
    experiment_path=f"/Users/{user}/mmf/serverless",
    use_case_name="serverless",
)
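
For orientation, `metric="smape"` refers to the symmetric mean absolute percentage error. Several variants of this metric are in use, so the sketch below shows one common definition, the mean of 2|y - ŷ| / (|y| + |ŷ|), and should not be read as MMF's exact implementation. The function name `smape` and the convention of scoring a 0/0 term as zero are assumptions made here.

```python
import numpy as np


def smape(actual, forecast) -> float:
    """One common sMAPE variant: mean of 2|y - yhat| / (|y| + |yhat|)."""
    actual = np.asarray(actual, dtype=float)
    forecast = np.asarray(forecast, dtype=float)
    denom = np.abs(actual) + np.abs(forecast)
    # Assumption: when both values are zero, count the term as zero error.
    ratio = np.divide(
        2.0 * np.abs(actual - forecast),
        denom,
        out=np.zeros_like(denom),
        where=denom != 0,
    )
    return float(ratio.mean())


# Made-up numbers, for illustration only.
score = smape([100.0, 200.0], [110.0, 180.0])
print(round(score, 4))  # 0.1003
```

Under this definition the value ranges from 0 (perfect forecast) to 2, and lower is better.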

### Evaluate
In the `evaluation_output` table, we store the evaluation results for all backtesting trials from all models.
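
To picture what a set of backtesting trials looks like, the sketch below enumerates rolling-origin windows from `prediction_length`, `backtest_length`, and `stride`. This is one common reading of such parameters, written here as an assumption; it is not taken from the MMF source, and MMF's actual window logic may differ. The function name `rolling_origin_windows` is hypothetical.

```python
def rolling_origin_windows(n_obs, prediction_length, backtest_length, stride):
    """Return (test_start, test_end) positions, end-exclusive, for each trial.

    Each trial trains on observations [0, test_start) and is scored on
    [test_start, test_end). Illustrative only; not MMF's actual code.
    """
    if backtest_length < prediction_length:
        raise ValueError("backtest_length must be at least prediction_length")
    if n_obs < backtest_length:
        raise ValueError("series is shorter than backtest_length")
    first_start = n_obs - backtest_length   # start of the held-out period
    last_start = n_obs - prediction_length  # last start that still fits
    return [
        (start, start + prediction_length)
        for start in range(first_start, last_start + 1, stride)
    ]


# Same values as the run_forecast call above, for a series of 60 months.
windows = rolling_origin_windows(
    n_obs=60, prediction_length=3, backtest_length=12, stride=1
)
print(len(windows))             # 10
print(windows[0], windows[-1])  # (48, 51) (57, 60)
```

Under this reading, a larger `stride` yields fewer trials per series and model, which reduces compute at the cost of a coarser evaluation.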

In [0]:
display(
  spark.sql(f"""
    select * from {catalog}.{db}.serverless_evaluation_output 
    where unique_id = 'M1'
    order by unique_id, model, backtest_window_start_date
    """))

### Forecast
In the `scoring_output` table, we store the forecasts for each time series from each model.

In [0]:
display(spark.sql(f"""
    select * from {catalog}.{db}.serverless_scoring_output 
    where unique_id = 'M1'
    order by unique_id, model, ds
    """))

Refer to the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/post-evaluation-analysis.ipynb) for guidance on performing fine-grained model selection after running `run_forecast`.
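
As a rough idea of what per-series model selection can look like, the pandas sketch below averages a metric over backtesting trials and keeps the lowest-scoring model for each series. The rows are made up, and the column name `metric_value` is a hypothetical stand-in; check the actual schema of your evaluation table before adapting it.

```python
import pandas as pd

# Made-up evaluation rows; `metric_value` is a hypothetical column name.
evaluation = pd.DataFrame({
    "unique_id": ["M1", "M1", "M1", "M1", "M2", "M2", "M2", "M2"],
    "model": ["ModelA", "ModelA", "ModelB", "ModelB"] * 2,
    "metric_value": [0.10, 0.14, 0.08, 0.20, 0.30, 0.36, 0.25, 0.29],
})

# Average the metric across backtesting trials for each (series, model) pair.
mean_scores = (
    evaluation.groupby(["unique_id", "model"], as_index=False)["metric_value"]
    .mean()
)

# For each series, keep the row with the lowest average metric.
best_rows = mean_scores.groupby("unique_id")["metric_value"].idxmin()
best = mean_scores.loc[best_rows].reset_index(drop=True)
print(best)
```

Averaging is only one possible aggregation; a median or a worst-case score may suit series with occasional outlier windows better.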

### Delete Tables
Let's clean up the tables. Uncomment the statements below to delete all rows from the output tables.

In [0]:
#display(spark.sql(f"delete from {catalog}.{db}.serverless_evaluation_output"))

In [0]:
#display(spark.sql(f"delete from {catalog}.{db}.serverless_scoring_output"))