Skip to content

Commit

Permalink
Merge pull request #15 from Infectious-Disease-Modeling-Hubs/vignette
Browse files Browse the repository at this point in the history
Vignette
  • Loading branch information
lshandross committed Oct 19, 2023
2 parents a0c9c7d + 5a46842 commit bbe056b
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 0 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Description: Functions for combining model outputs (e.g. predictions or
estimates) from multiple models into an aggregated ensemble model output.
License: MIT + file LICENSE
Suggests:
knitr,
rmarkdown,
testthat (>= 3.0.0)
Config/testthat/edition: 3
Encoding: UTF-8
Expand Down
2 changes: 2 additions & 0 deletions vignettes/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.html
*.R
163 changes: 163 additions & 0 deletions vignettes/hubEnsembles.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
---
title: "The `hubEnsembles` package"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{The `hubEnsembles` package}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.width = 7,
fig.height = 4
)
```

# Introduction

The `hubEnsembles` package includes functionality for aggregating model outputs, such as forecasts or projections, that are submitted to a hub by multiple models and combined into ensemble model outputs. The package includes two main functions: `simple_ensemble` and `linear_pool`. We illustrate these functions in this vignette, and briefly compare them.

This vignette uses the following R packages:

```{r setup}
library(dplyr)
library(plotly)
library(hubUtils)
library(hubEnsembles)
```

# Example data: a simple forecast hub

The `example-simple-forecast-hub` has been created by the Consortium of Infectious Disease Modeling Hubs as a simple example hub to demonstrate the set up and functionality for the hubverse. The hub includes both target data and example model output data.

```{r}
hub_path <- system.file("example-data/example-simple-forecast-hub",
package = "hubEnsembles")
model_outputs <- hubUtils::connect_hub(hub_path) %>%
dplyr::collect()
head(model_outputs)
target_data_path <- file.path(hub_path, "target-data",
"covid-hospitalizations.csv")
target_data <- read.csv(target_data_path)
head(target_data)
```

# Creating ensembles with `simple_ensemble`

The `simple_ensemble` function is used to summarize across component model outputs; this function can be applied to predictions with an `output_type` of `mean`, `median`, `quantile`, `cdf`, or `pmf`.

The `simple_ensemble` function defaults to calculating an equally weighted mean across all component model outputs for each unique `output_type_id`. For our example data, which contains two output types (`median` and `quantile`), this means the resulting ensemble will be the mean of component model submitted values for each quantile.

```{r}
mean_ens <- hubEnsembles::simple_ensemble(model_outputs)
head(mean_ens)
```

## Changing the aggregation function

We can change the function used to aggregate across model outputs. For example, we may want to calculate a median of component model submitted values for each quantile. We will also use the `model_id` argument to distinguish this ensemble.

```{r}
median_ens <- hubEnsembles::simple_ensemble(model_outputs,
agg_fun = median,
model_id = "hub-ensemble-median")
head(median_ens)
```

Custom functions can also be passed into the `agg_fun` argument. For example, a geometric mean may be a more appropriate way to combine component model outputs. Any custom function to be used requires an argument `x` for the vector of numeric values to summarize, and if relevant, an argument `w` of numeric weights.

```{r}
geometric_mean <- function(x){
n <- length(x)
return(prod(x)^(1/n))
}
geometric_mean_ens <- hubEnsembles::simple_ensemble(model_outputs,
agg_fun = geometric_mean,
model_id = "hub-ensemble-geometric")
head(geometric_mean_ens)
```

## Weighting model contributions

In addition, we can weight the contributions of each model by providing a table of weights, which are provided in a `data.frame` with a `model_id` column and a `weight` column.

```{r}
model_weights <- data.frame(model_id = c("UMass-ar", "UMass-gbq", "simple_hub-baseline"),
weight = c(0.4, 0.4, 0.2))
weighted_mean_ens <- hubEnsembles::simple_ensemble(model_outputs,
weights = model_weights,
model_id = "hub-ensemble-weighted-mean")
head(weighted_mean_ens)
```

# Creating ensembles with `linear_pool`

An alternative approach to generate an ensemble is a linear pool, or a distributional mixture; this function can be applied to predictions with an `output_type` of `mean`, `quantile`, `cdf`, or `pmf`. Our example hub includes `median` output type, so we exclude it from the calculation.

For `mean`, `cdf` and `pmf` output types, the linear pool is equivalent to using a mean `simple_ensemble`. For `quantile` model outputs, the `linear_pool` function needs to approximate a full probability distribution using the value-quantile pairs from each component model. As a default, this is done with functions in the `distfromq` package, which defaults to fitting a monotonic cubic spline.

```{r}
linear_pool_ens <- hubEnsembles::linear_pool(model_outputs %>%
filter(output_type != "median"))
head(linear_pool_ens)
```

# Plots

```{r}
basic_plot_function <- function(plot_df, truth_df, plain_line = 0.5, ribbon = c(0.975, 0.025),
forecast_date) {
plain_df <- dplyr::filter(plot_df, output_type_id == plain_line)
ribbon_df <- dplyr::filter(plot_df, output_type_id %in% ribbon) %>%
dplyr::mutate(output_type_id = ifelse(output_type_id == min(ribbon),
"min", "max")) %>%
tidyr::pivot_wider(names_from = output_type_id, values_from = value)
plot_model <- plot_ly(height = 600, colors = scales::hue_pal()(50))
if (!is.null(truth_df)) {
plot_model <- plot_model %>%
add_trace(data = truth_df, x = ~time_idx, y = ~value, type = "scatter",
mode = "lines+markers", line = list(color = "#6e6e6e"),
hoverinfo = "text", name = "ground truth",
hovertext = paste("Date: ", truth_df$time_value, "<br>",
"Ground truth: ",
format(truth_df$value, big.mark = ","),
sep = ""),
marker = list(color = "#6e6e6e", size = 7))
}
plot_model <- plot_model %>%
add_lines(data = plain_df, x = ~target_date, y = ~value,
color = ~model_id) %>%
add_ribbons(data = ribbon_df, x = ~target_date, ymin = ~min,
ymax = ~max, color = ~model_id, opacity = 0.25,
line = list(width = 0), showlegend = FALSE) %>%
plotly::layout(shapes = list(type = "line", y0 = 0, y1 = 1, yref = "paper",
x0 = forecast_date, x1 = forecast_date,
line = list(color = "gray")))
}
```

```{r}
plot_df <- dplyr::bind_rows(model_outputs, mean_ens) %>%
dplyr::filter(location == "US", origin_date == "2022-12-12") %>%
dplyr::mutate(target_date = origin_date + horizon)
plot <- basic_plot_function(
plot_df,
truth_df = target_data %>%
dplyr::filter(location == "US",
time_idx >= "2022-10-01",
time_idx <= "2023-03-01"),
forecast_date = "2022-12-12")
plot
```

0 comments on commit bbe056b

Please sign in to comment.