---
title: Probabilistic Forecasting at Scale with Numpyro
format:
  poster-typst: 
    size: "36x24"
    poster-authors: "Juan Orduz"
    departments: "PyMC Labs"
    institution-logo: "./images/numpyro.png"
    footer-text: "3rd Vienna Workshop on Economic Forecasting 2025"
    footer-url: "https://juanitorduz.github.io/"
    footer-emails: "juanitorduz@gmail.com"
    footer-color: "FFFF00"
    keywords: ["Probabilistic Models", "Forecasting", "JAX", "NumPyro"]
---



# JAX and NumPyro

> "JAX is a Python library for accelerator-oriented array computation and program transformation, designed for
> high-performance numerical computing and large-scale machine learning."

> "NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX
> for automatic differentiation and JIT compilation to GPU / CPU."
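
The quotes above mention automatic differentiation, JIT compilation and program transformation. As a toy illustration (the function and numbers are made up for this sketch), `jax.grad` turns a Python function into its derivative, `jax.jit` compiles it with XLA, and `jax.vmap` maps it over a batch without a Python loop. NumPyro relies on the same transformations, for example to compute the gradients used by NUTS.

```python
import jax
import jax.numpy as jnp


def squared_error(theta: float, x: jax.Array, y: jax.Array) -> jax.Array:
    # Toy loss for a one-parameter linear model y = theta * x + error.
    return (theta * x - y) ** 2


# Derivative with respect to the first argument (theta), compiled with XLA.
grad_fn = jax.jit(jax.grad(squared_error))

# Vectorize over observations: theta is shared, x and y are batched.
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])
per_point_grads = batched_grad(1.0, x, y)  # 2 * x * (theta * x - y)
```

For `theta = 1.0` the per-point gradients are `[-2.0, -8.0, -18.0]`.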

# Classical Time Series Models

We provide implementations of the most common statistical time series models (exponential smoothing, ARIMAX, Croston's
method, TSB, and many more), as well as state space models.

```{.python}
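# See https://juanitorduz.github.io/exponential_smoothing_numpyro/
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import Array
from numpyro.contrib.control_flow import scan


# Simple exponential smoothing (level-only) model.
def level_model(y: Array, future: int = 0) -> None: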
    t_max = y.shape[0]
    # --- Priors ---
    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))
    # --- Transition Function ---
    def transition_fn(carry, t):
        previous_level = carry
        level = jnp.where(
            t < t_max,
            level_smoothing * y[t] + (1 - level_smoothing) * previous_level,
            previous_level,
        )
        mu = previous_level
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))
        return level, pred
    # --- Run Scan ---
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(transition_fn, level_init, jnp.arange(t_max + future))
    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
```
The same model can be fit with different inference methods (MCMC or SVI) and, for MCMC, with different samplers (NUTS, HMC, etc.).

# Hierarchical Models

- Vectorize a model across many time series and add hierarchies (partial pooling) to its parameters.
- Full flexibility to write custom models.


::: {.callout-tip}
We provide a macro to vectorize the model and add hierarchies to the parameters.
:::


![Hierarchical State Space Models](./images/hierarchical_forecasting.png)

# Censoring Likelihoods

![Censored Data](./images/censored_data.png)

![Censored Forecast](./images/censored_forecast.png)

# Dynamic Models & Calibration

![Electricity Demand](./images/electricity_data.png)

![Calibrated Gaussian process dynamic latent variable](./images/calibration.png)