# Retail Sales Forecasting Example Notebook

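This notebook mirrors the video walkthrough. It configures the data pipeline, trains the JAX/Flax LSTM and GRU models, and regenerates the kinds of plots bundled in `artifacts/run_20251215_212247`. The notebook itself writes its own run under `artifacts/example_demo`. If the raw data under `data/store-sales-time-series-forecasting` is missing, the pipeline falls back to synthetic data, which keeps the runtime under 2 minutes on CPU.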

In [1]:
from pathlib import Path
import sys
import warnings

# Package that the next cell imports from; PROJECT_DIR must be its parent directory.
PACKAGE_NAME = "retail_sales_forecasting_with_lstms"
# Project location relative to the repository root (the launch directory this cell originally assumed).
PROJECT_SUBDIR = Path("class_project/MSML610/Fall2025/Projects/UmdTask77_Retail_Sales_Forecasting_with_LSTMs")

# Accept either launch location: the repository root (original assumption) or the project
# directory itself. The first candidate that really contains the package wins; otherwise
# keep the original repo-root-based path so behaviour is unchanged.
_project_candidates = (Path.cwd() / PROJECT_SUBDIR, Path.cwd())
PROJECT_DIR = next(
    (candidate for candidate in _project_candidates if (candidate / PACKAGE_NAME).is_dir()),
    _project_candidates[0],
)
if not (PROJECT_DIR / PACKAGE_NAME).is_dir():
    # Not fatal: the package may still be importable (e.g. installed into the environment),
    # so surface the problem instead of silently adding a non-existent path.
    warnings.warn(
        f"{PACKAGE_NAME!r} not found under {PROJECT_DIR}; start the kernel from the repository "
        "root or the project directory, or install the package.",
        stacklevel=1,
    )
# Guard so re-running the cell does not add duplicate sys.path entries.
if str(PROJECT_DIR) not in sys.path:
    sys.path.append(str(PROJECT_DIR))


In [None]:
from pathlib import Path

from retail_sales_forecasting_with_lstms.API import (
    TrainingConfig,
    run_training_pipeline,
    load_run_metrics,
    plot_training_curves,
    plot_final_metrics_comparison,
    plot_sales_metrics,
    plot_breakdowns,
    run_inference,
    plot_predictions_sample,
)

## 1. Configure data and hyperparameters

In [None]:
cfg_example = TrainingConfig(
    # Data source; with synthetic_if_missing=True synthetic data is used if this directory is absent.
    # NOTE(review): relative to the kernel's working directory — confirm it points at the real data.
    data_dir=Path("data/store-sales-time-series-forecasting"),
    synthetic_if_missing=True,
    # Subset of the data: five product families over at most five stores.
    families=("GROCERY I", "BEVERAGES", "PRODUCE", "CLEANING", "DAIRY"),
    max_stores=5,
    # Window shape: input context length and forecast horizon (presumably in days — TODO confirm).
    context_length=30,
    horizon=7,
    # Small optimisation budget to keep the demo quick.
    epochs=3,
    batch_size=256,
)
cfg_example

## 2. Train the LSTM and GRU models and save the run

In [None]:
# Train each requested model for this config; the pipeline hands back the run directory,
# one result object per model, and the dataset it built (reused for inference below).
run_dir, results, dataset = run_training_pipeline(
    cfg_example,
    model_names=("lstm", "gru"),
    run_name="example_notebook_demo",
    output_dir=Path("artifacts/example_demo"),
)
# Quick summary: (model name, normalized metrics) for every trained model.
[(result.name, result.normalized_metrics) for result in results]

## 3. Inspect metrics

In [None]:
# Load the metrics recorded for this run via its directory, keyed by model name.
metrics = load_run_metrics(run_dir)
# Side-by-side normalized metrics, LSTM first then GRU.
tuple(metrics[model_name]["normalized_metrics"] for model_name in ("lstm", "gru"))

## 4. Generate plots

In [None]:
# Run-level comparison plots; each takes the loaded metrics and the run directory.
for plot_summary in (plot_training_curves, plot_final_metrics_comparison, plot_sales_metrics):
    plot_summary(metrics, run_dir)

# Per-model breakdowns: one loop instead of a copy-pasted call per model
# (the loop also avoids echoing the last call's return value as cell output).
for model_name in ("lstm", "gru"):
    plot_breakdowns(metrics, run_dir, model_name=model_name)

## 5. Run inference and visualize predictions

In [None]:
# Model used for inference; must be one of the model names trained in section 2.
# The output filename is derived from it so the two can never drift apart.
INFERENCE_MODEL = "lstm"

inference = run_inference(run_dir, INFERENCE_MODEL, dataset=dataset)
plot_predictions_sample(
    inference["predictions"],
    inference["targets"],
    dataset,
    run_dir / f"{INFERENCE_MODEL}_test_predictions.png",
)