Add mdape performance metric to R (#1472)
* add test and initial function for mdape in R

* Add separate rolling_median_func and tests

* Modify rolling median function

* fix syntax in rolling median function

* sort by h

* R/diagnostics.R

* update .rd docs and notebook

* Add mdape to performance metrics params docstring
ryankarlos committed May 20, 2020
1 parent 16e632a commit f16d9df
Showing 6 changed files with 211 additions and 34 deletions.
91 changes: 86 additions & 5 deletions R/R/diagnostics.R
@@ -216,10 +216,11 @@ prophet_copy <- function(m, cutoff = NULL) {
#'
#' Computes a suite of performance metrics on the output of cross-validation.
#' By default the following metrics are included:
#' 'mse': mean squared error
#' 'rmse': root mean squared error
#' 'mae': mean absolute error
#' 'mape': mean percent error
#' 'mse': mean squared error,
#' 'rmse': root mean squared error,
#' 'mae': mean absolute error,
#' 'mape': mean percent error,
#' 'mdape': median percent error,
#' 'coverage': coverage of the upper and lower intervals
#'
#' A subset of these can be specified by passing a list of names as the
@@ -244,7 +245,7 @@ prophet_copy <- function(m, cutoff = NULL) {
#'
#' @param df The dataframe returned by cross_validation.
#' @param metrics An array of performance metrics to compute. If not provided,
#' will use c('mse', 'rmse', 'mae', 'mape', 'coverage').
#' will use c('mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage').
#' @param rolling_window Proportion of data to use in each rolling window for
#' computing the metrics. Should be in [0, 1] to average.
#'
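
For orientation, a minimal usage sketch of the updated interface (hypothetical object names; `df.cv` stands for the dataframe returned by `cross_validation`):

```r
library(prophet)

# All default metrics, including the new 'mdape', on a 10% rolling window.
df.p <- performance_metrics(df.cv, rolling_window = 0.1)

# Or request a subset of metrics explicitly.
df.p <- performance_metrics(df.cv, metrics = c('mape', 'mdape'))
head(df.p)
```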
@@ -275,6 +276,10 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
message('Skipping MAPE because y close to 0')
metrics <- metrics[metrics != 'mape']
}
if (('mdape' %in% metrics) & (min(abs(df_m$y)) < 1e-8)) {
message('Skipping MDAPE because y close to 0')
metrics <- metrics[metrics != 'mdape']
}
if (length(metrics) == 0) {
return(NULL)
}
@@ -351,6 +356,64 @@ rolling_mean_by_h <- function(x, h, w, name) {
return(res)
}


#' Compute a rolling median of x, after first aggregating by h
#'
#' Right-aligned. Computes a single median for each unique value of h. Each median
#' is over at least w samples.
#'
#' For each h where there are fewer than w samples, we take samples from the previous h,
#' moving backwards. (In other words, we assume that the x's are shuffled within each h.)
#'
#' @param x Array.
#' @param h Array of horizon for each value in x.
#' @param w Integer window size (number of elements).
#' @param name String name for metric in result dataframe.
#'
#' @return Dataframe with columns horizon and name, the rolling median of x.
#'
#' @importFrom dplyr "%>%"
rolling_median_by_h <- function(x, h, w, name) {
# Aggregate over h
df <- data.frame(x=x, h=h)
grouped <- df %>% dplyr::group_by(h)
df2 <- grouped %>%
dplyr::summarise(size=dplyr::n()) %>%
dplyr::arrange(h) %>%
dplyr::select(h, size)

hs <- df2$h
res <- data.frame(horizon=c())
res[[name]] <- c()

# Start from the right and work backwards
i <- length(hs)
while (i > 0) {
h_i <- hs[i]
xs <- grouped %>%
dplyr::filter(h==h_i)
xs <- xs$x

next_idx_to_add = which.max(h==h_i) - 1

while ((length(xs) < w) & (next_idx_to_add > 0)) {
# Include points from the previous horizon. All of them if still less
# than w, otherwise just enough to get to w.
xs <- c(x[next_idx_to_add], xs)
next_idx_to_add = next_idx_to_add - 1
}
if (length(xs) < w) {
# Ran out of horizons before enough points.
break
}
res.i <- data.frame(horizon=hs[i])
res.i[[name]] <- median(xs)
res <- rbind(res.i, res)
i <- i - 1
}
return(res)
}

# The functions below specify performance metrics for cross-validation results.
# Each takes as input the output of cross_validation, and returns the statistic
# as a dataframe, given a window size for rolling aggregation.
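
To make the windowing behaviour documented above concrete, here is a small sketch (values chosen to match the unit tests added further down; `rolling_median_by_h` is internal, hence the `:::`):

```r
x <- 0:9
h <- c(1, 2, 3, 4, 4, 4, 4, 4, 7, 7)

# w = 3: horizon 3 borrows the points from horizons 1 and 2, horizon 4 already
# has five points of its own, and horizon 7 borrows one point from horizon 4.
prophet:::rolling_median_by_h(x = x, h = h, w = 3, name = 'mdape')
# Expected medians: horizons 3, 4, 7 -> 1, 5, 8
```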
@@ -418,6 +481,24 @@ mape <- function(df, w) {
return(rolling_mean_by_h(x = ape, h = df$horizon, w = w, name = 'mape'))
}


#' Median absolute percent error
#'
#' @param df Cross-validation results dataframe.
#' @param w Aggregation window size.
#'
#' @return Array of median absolute percent errors.
#'
#' @keywords internal
mdape <- function(df, w) {
ape <- abs((df$y - df$yhat) / df$y)
if (w < 0) {
return(data.frame(horizon = df$horizon, mdape = ape))
}
return(rolling_median_by_h(x = ape, h = df$horizon, w = w, name = 'mdape'))
}


#' Coverage
#'
#' @param df Cross-validation results dataframe.
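The practical difference between the new `mdape` and the existing `mape` is robustness: because `mdape` takes a median rather than a mean of the absolute percent errors, a handful of badly missed forecasts does not dominate the metric. A toy illustration (numbers invented for the example, not from the commit):

```r
y    <- c(100, 100, 100, 100)
yhat <- c( 99, 101,  98, 300)      # one badly missed forecast
ape  <- abs((y - yhat) / y)

mean(ape)    # ~0.51  -- the single outlier dominates MAPE
median(ape)  # 0.015  -- MDAPE reflects the typical error
```
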
20 changes: 20 additions & 0 deletions R/man/mdape.Rd

Some generated files are not rendered by default.

11 changes: 6 additions & 5 deletions R/man/performance_metrics.Rd

Some generated files are not rendered by default.

27 changes: 27 additions & 0 deletions R/man/rolling_median_by_h.Rd

Some generated files are not rendered by default.

29 changes: 28 additions & 1 deletion R/tests/testthat/test_diagnostics.R
@@ -150,7 +150,7 @@ test_that("performance_metrics", {
expect_true(all(
sort(colnames(df_horizon)) == sort(c('coverage', 'mse', 'horizon'))
))
# Skip MAPE
# Skip MAPE and MDAPE
df_cv$y[1] <- 0.
df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mape'))
expect_true(all(
@@ -189,6 +189,33 @@ test_that("rolling_mean", {
expect_equal(c(4.5), df$x)
})


test_that("rolling_median", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
x <- 0:9
h <- 0:9
df <- prophet:::rolling_median_by_h(x=x, h=h, w=1, name='x')
expect_equal(x, df$x)
expect_equal(h, df$horizon)

df <- prophet:::rolling_median_by_h(x=x, h=h, w=4, name='x')
x.true <- x[4:10] - 1.5
expect_equal(x.true, df$x)
expect_equal(3:9, df$horizon)

h <- c(1., 2., 3., 4., 4., 4., 4., 4., 7., 7.)
x.true <- c(1., 5., 8.)
h.true <- c(3., 4., 7.)
df <- prophet:::rolling_median_by_h(x=x, h=h, w=3, name='x')
expect_equal(x.true, df$x)
expect_equal(h.true, df$horizon)

df <- prophet:::rolling_median_by_h(x=x, h=h, w=10, name='x')
expect_equal(c(7.), df$horizon)
expect_equal(c(4.5), df$x)
})


test_that("copy", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
df <- DATA_all
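The expectations in the `rolling_median` test above can be reproduced by hand. With `x = h = 0:9` and `w = 4`, each right-aligned window holds four consecutive integers, so its median is the window's last value minus 1.5; horizons 0-2 never accumulate four points and are dropped:

```r
median(0:3)   # 1.5, the window ending at horizon 3
median(6:9)   # 7.5, the window ending at horizon 9
# Hence x.true <- x[4:10] - 1.5 and the horizons are 3:9.
```
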
67 changes: 44 additions & 23 deletions notebooks/diagnostics.ipynb
@@ -166,9 +166,30 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a9ebee9abb44b3a97eb4df74beb6346",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
@@ -202,45 +223,45 @@
" <tr>\n",
" <th>0</th>\n",
" <td>2010-02-16</td>\n",
" <td>8.956572</td>\n",
" <td>8.460049</td>\n",
" <td>9.460400</td>\n",
" <td>8.957284</td>\n",
" <td>8.480761</td>\n",
" <td>9.415366</td>\n",
" <td>8.242493</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2010-02-17</td>\n",
" <td>8.723004</td>\n",
" <td>8.200557</td>\n",
" <td>9.236561</td>\n",
" <td>8.723736</td>\n",
" <td>8.206191</td>\n",
" <td>9.234075</td>\n",
" <td>8.008033</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2010-02-18</td>\n",
" <td>8.606823</td>\n",
" <td>8.070835</td>\n",
" <td>9.123754</td>\n",
" <td>8.607496</td>\n",
" <td>8.112153</td>\n",
" <td>9.092314</td>\n",
" <td>8.045268</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2010-02-19</td>\n",
" <td>8.528688</td>\n",
" <td>8.034782</td>\n",
" <td>9.042712</td>\n",
" <td>8.529364</td>\n",
" <td>8.017767</td>\n",
" <td>9.013877</td>\n",
" <td>7.928766</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2010-02-20</td>\n",
" <td>8.270706</td>\n",
" <td>7.754891</td>\n",
" <td>8.739012</td>\n",
" <td>8.271329</td>\n",
" <td>7.751250</td>\n",
" <td>8.775341</td>\n",
" <td>7.745003</td>\n",
" <td>2010-02-15</td>\n",
" </tr>\n",
@@ -250,14 +271,14 @@
],
"text/plain": [
" ds yhat yhat_lower yhat_upper y cutoff\n",
"0 2010-02-16 8.956572 8.460049 9.460400 8.242493 2010-02-15\n",
"1 2010-02-17 8.723004 8.200557 9.236561 8.008033 2010-02-15\n",
"2 2010-02-18 8.606823 8.070835 9.123754 8.045268 2010-02-15\n",
"3 2010-02-19 8.528688 8.034782 9.042712 7.928766 2010-02-15\n",
"4 2010-02-20 8.270706 7.754891 8.739012 7.745003 2010-02-15"
"0 2010-02-16 8.957284 8.480761 9.415366 8.242493 2010-02-15\n",
"1 2010-02-17 8.723736 8.206191 9.234075 8.008033 2010-02-15\n",
"2 2010-02-18 8.607496 8.112153 9.092314 8.045268 2010-02-15\n",
"3 2010-02-19 8.529364 8.017767 9.013877 7.928766 2010-02-15\n",
"4 2010-02-20 8.271329 7.751250 8.775341 7.745003 2010-02-15"
]
},
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -313,7 +334,7 @@
" parallel=\"dask\")\n",
"```\n",
"\n",
"The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument."
"The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), median absolute percent error (MDAPE) and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument."
]
},
{
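In R, the workflow described in the notebook cell above looks roughly like this (a sketch only; the fitted model `m`, the 365-day horizon, and the cutoff spacing are illustrative, not taken from the commit):

```r
library(prophet)

# m is an already-fitted prophet model (assumed).
df.cv <- cross_validation(m, horizon = 365, units = 'days',
                          period = 180, initial = 730)
df.p  <- performance_metrics(df.cv)   # now includes an 'mdape' column
head(df.p[, c('horizon', 'mape', 'mdape', 'coverage')])
```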
