Skip to content

Commit

Permalink
Introduced dependency from assertthat >= 0.2.0 in order to use the ne…
Browse files Browse the repository at this point in the history
…w `msg` argument of the `assert_that` function.
  • Loading branch information
alhauser committed Sep 14, 2017
1 parent a7fd0a8 commit 4dfedbc
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 165 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: CausalImpact
Title: Inferring Causal Effects using Bayesian Structural Time-Series Models
Date: 2017-05-31
Date: 2017-08-16
Author: Kay H. Brodersen <kbrodersen@google.com>,
Alain Hauser <alhauser@google.com>
Maintainer: Alain Hauser <alhauser@google.com>
Expand All @@ -10,9 +10,9 @@ Description: Implements a Bayesian approach to causal impact estimation in time
See the package documentation on GitHub
<https://google.github.io/CausalImpact/> to get started.
Copyright: Copyright (C) 2014-2017 Google, Inc.
Version: 1.2.1
Version: 1.2.2
VignetteBuilder: knitr
License: Apache License 2.0 | file LICENSE
Imports: assertthat, Boom, dplyr, ggplot2, zoo
Imports: assertthat (>= 0.2.0), Boom, dplyr, ggplot2, zoo
Depends: bsts (>= 0.7.0)
Suggests: knitr, testthat
41 changes: 21 additions & 20 deletions R/impact_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ FormatInputPrePostPeriod <- function(pre.period, post.period, data) {
assert_that(!is.null(post.period))
assert_that(length(pre.period) == 2, length(post.period) == 2)
assert_that(!anyNA(pre.period), !anyNA(post.period))
assert(isTRUE(all.equal(class(time(data)), class(pre.period))) ||
(is.numeric(time(data)) && is.numeric(pre.period)),
error = paste0("pre.period (", class(pre.period)[1], ") ",
"must have the same class as the time points in the ",
"data (", class(time(data))[1], ")"))
assert(isTRUE(all.equal(class(time(data)), class(post.period))) ||
(is.numeric(time(data)) && is.numeric(post.period)),
error = paste0("post.period (", class(post.period)[1], ") ",
"must have the same class as the time points in the ",
"data (", class(time(data))[1], ")"))
assert_that(isTRUE(all.equal(class(time(data)), class(pre.period))) ||
(is.numeric(time(data)) && is.numeric(pre.period)),
msg = paste0("pre.period (", class(pre.period)[1], ") ",
"must have the same class as the time points in ",
"the data (", class(time(data))[1], ")"))
assert_that(isTRUE(all.equal(class(time(data)), class(post.period))) ||
(is.numeric(time(data)) && is.numeric(post.period)),
msg = paste0("post.period (", class(post.period)[1], ") ",
"must have the same class as the time points in ",
"the data (", class(time(data))[1], ")"))
if (pre.period[1] < start(data)) {
warning(paste0("Setting pre.period[1] to start of data: ", start(data)))
}
Expand All @@ -128,8 +128,8 @@ FormatInputPrePostPeriod <- function(pre.period, post.period, data) {
period.indices <- list(
pre.period = GetPeriodIndices(pre.period, time(data)),
post.period = GetPeriodIndices(post.period, time(data)))
assert(diff(period.indices$pre.period) >= 2,
"pre.period must span at least 3 time points")
assert_that(diff(period.indices$pre.period) >= 2,
msg = "pre.period must span at least 3 time points")
assert_that(period.indices$post.period[1] > period.indices$pre.period[2])

return(period.indices)
Expand All @@ -154,12 +154,13 @@ FormatInputForCausalImpact <- function(data, pre.period, post.period,
# list of checked (and possibly reformatted) input arguments

# Check that a consistent set of variables has been provided
assert(xor(!is.null(data) && !is.null(pre.period) && !is.null(post.period) &&
is.null(bsts.model) && is.null(post.period.response),
is.null(data) && is.null(pre.period) && is.null(post.period) &&
!is.null(bsts.model) && !is.null(post.period.response)),
paste0("must either provide data, pre.period, post.period, model.args",
"; or bsts.model and post.period.response"))
assert_that(
xor(!is.null(data) && !is.null(pre.period) && !is.null(post.period) &&
is.null(bsts.model) && is.null(post.period.response),
is.null(data) && is.null(pre.period) && is.null(post.period) &&
!is.null(bsts.model) && !is.null(post.period.response)),
msg = paste0("must either provide data, pre.period, post.period, ",
"model.args; or bsts.model and post.period.response"))

# Check <data> and convert to zoo, with rows representing time points
if (!is.null(data)) {
Expand Down Expand Up @@ -505,8 +506,8 @@ PrintSummary <- function(impact, digits = 2L) {
assert_that(class(impact) == "CausalImpact")
summary <- impact$summary
alpha <- impact$model$alpha
assert(!is.null(alpha) && alpha > 0,
"invalid <alpha>; <impact> must be a CausalImpact object")
assert_that(!is.null(alpha) && alpha > 0,
msg = "invalid <alpha>; <impact> must be a CausalImpact object")

# Print title
cat("Posterior inference {CausalImpact}\n")
Expand Down
46 changes: 25 additions & 21 deletions R/impact_inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ ComputePointPredictions <- function(y.samples, state.samples, alpha = 0.05) {
# point.pred.upper: upper limit

# Expectation of data = expectation of state (because noise is centered)
assert(identical(dim(y.samples), dim(state.samples)),
"inconsistent y.samples, state.samples")
assert_that(identical(dim(y.samples), dim(state.samples)),
msg = "inconsistent y.samples, state.samples")
point.pred.mean <- colMeans(state.samples) # e.g., 365

# Quantiles of the data = Quantiles of (state + observation noise)
Expand Down Expand Up @@ -145,9 +145,9 @@ ComputeCumulativePredictions <- function(y.samples, point.pred, y,
cum.pred.mean <- c(cum.pred.mean.pre, cum.pred.mean.post)

# Check for overflow
assert(identical(which(is.na(cum.pred.mean)),
which(is.na(y[1:(post.period.begin - 1)]))),
"unexpected NA found in cum.pred.mean")
assert_that(identical(which(is.na(cum.pred.mean)),
which(is.na(y[1:(post.period.begin - 1)]))),
msg = "unexpected NA found in cum.pred.mean")

# Compute posterior interval
cum.pred.lower.pre <- cum.pred.mean.pre
Expand Down Expand Up @@ -207,13 +207,15 @@ CompileSummaryTable <- function(y.post, y.samples.post,
# data frame of post-period summary statistics

# Check input
assert(ncol(y.samples.post) == length(y.post), "inconsistent y.post")
assert(length(point.pred.mean.post) == length(y.post), "inconsistent y.post")
assert_that(ncol(y.samples.post) == length(y.post),
msg = "inconsistent y.post")
assert_that(length(point.pred.mean.post) == length(y.post),
msg = "inconsistent y.post")

# We will compare the matrix of predicted trajectories (e.g., 900 x 201)
# with a matrix of replicated observations (e.g., 900 x 201)
n.samples <- nrow(y.samples.post)
y.repmat.post <- repmat(y.post, n.samples, 1)
y.repmat.post <- matrix(y.post, nrow = nrow(y.samples.post),
ncol = length(y.post), byrow = TRUE)
assert_that(all(dim(y.repmat.post) == dim(y.samples.post)))

# Define quantiles
Expand Down Expand Up @@ -419,13 +421,14 @@ AssertCumulativePredictionsAreConsistent <- function(cum.pred, post.period,
non.na.indices <- which(!is.na(cum.pred.col[1:(post.period[1] - 1)]))
assert_that(length(non.na.indices) > 0)
last.non.na.index <- max(non.na.indices)
assert(is.numerically.equal(cum.pred.col[post.period[2]] -
cum.pred.col[last.non.na.index],
summary.entry[2]),
paste0("The calculated ", description, " of the cumulative effect ",
"is inconsistent with the previously calculated one. You ",
"might try to run CausalImpact on a shorter time series ",
"to avoid this problem."))
assert_that(
is.numerically.equal(cum.pred.col[post.period[2]] -
cum.pred.col[last.non.na.index],
summary.entry[2]),
msg = paste0("The calculated ", description, " of the cumulative ",
"effect is inconsistent with the previously calculated ",
"one. You might try to run CausalImpact on a shorter ",
"time series to avoid this problem."))
}

AssertCumulativePredictionIsConsistent(cum.pred$cum.pred, summary$Pred,
Expand Down Expand Up @@ -474,11 +477,12 @@ CheckInputForCompilePosteriorInferences <- function(bsts.model, y.cf,
y.cf <- as.vector(y.cf)
assert_that(is.numeric(y.cf))
assert_that(length(y.cf) >= 1)
assert(!anyNA(y.cf[(post.period[1] : post.period[2]) - cf.period.start + 1]),
"NA values in the post-period not currently supported")
assert(all(is.na(tail(bsts.model$original.series, length(y.cf)))),
paste0("bsts.model$original.series must end on a stretch of NA ",
"at least as long as y.cf"))
assert_that(!anyNA(y.cf[(post.period[1] : post.period[2]) -
cf.period.start + 1]),
msg = "NA values in the post-period not currently supported")
assert_that(all(is.na(tail(bsts.model$original.series, length(y.cf)))),
msg = paste0("bsts.model$original.series must end on a stretch ",
"of NA at least as long as y.cf"))

# Check <alpha>
assert_that(is.numeric(alpha))
Expand Down
65 changes: 6 additions & 59 deletions R/impact_misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,6 @@
# gallusser@google.com (Fabian Gallusser)
# alhauser@google.com (Alain Hauser)

repmat <- function(X, m, n) {
# R equivalent of repmat (MATLAB). Replicates a given vector or matrix.
#
# Args:
# X: vector or matrix
# m: number of row replications
# n: number of column replications
#
# Returns:
# Matrix
#
# Examples:
# CausalImpact:::repmat(c(10, 20), 1, 2)
# # [,1] [,2] [,3] [,4]
# # [1,] 10 20 10 20

assert_that(is.vector(X) || is.matrix(X))
assert_that(is.count(m), is.count(n))
if (is.vector(X)) {
X <- t(as.matrix(X))
}
mx = dim(X)[1]
nx = dim(X)[2]
matrix(t(matrix(X, mx, nx * n)), mx * m, nx * n, byrow = TRUE)
}

is.wholenumber <- function(x, tolerance = .Machine$double.eps ^ 0.5) {
# Checks whether a number is a whole number. This is not the same as
# \code{is.integer()}, which tests the data type.
Expand Down Expand Up @@ -87,32 +61,6 @@ cumsum.na.rm <- function(x) {
return(s)
}

assert <- function(expr = TRUE, error = "") {
# Throws a custom error message if a condition is not fulfilled. This function
# is similar to `assertthat::assert_that()`. The main difference is that
# `assert()` allows for a custom error message, while the current CRAN
# version of the `assertthat` package (0.1) does not.
#
# Args:
# expr: expression that evaluates to a logical
# error: error message if expression does not evaluate to \code{TRUE}
#
# Returns:
# Returns quietly or fails with an error.
#
# Examples:
# x <- 1
# CausalImpact:::assert(x > 0)
# CausalImpact:::assert(x > 0, "input argument must be positive")
#
# Documentation:
# seealso: assert_that

if (!isTRUE(expr)) {
stop(error, call. = (error == ""))
}
}

is.numerically.equal <- function(x, y, tolerance = .Machine$double.eps ^ 0.5) {
# Tests whether two numbers are 'numerically equal' by checking whether their
# relative difference is smaller than a given tolerance. 'Relative difference'
Expand Down Expand Up @@ -205,9 +153,9 @@ ParseArguments <- function(args, defaults, allow.extra.args = FALSE) {
# Are extra args allowed?
if (!allow.extra.args) {
illegal.args <- setdiff(names(args), names(defaults))
assert(length(illegal.args) == 0,
paste0("illegal extra args: '",
paste(illegal.args, collapse = "', '"), "'"))
assert_that(length(illegal.args) == 0,
msg = paste0("illegal extra args: '",
paste(illegal.args, collapse = "', '"), "'"))
}

# Return
Expand Down Expand Up @@ -330,7 +278,8 @@ GetPeriodIndices <- function(period, times) {
indices <- seq_along(times)
is.period <- (period[1] <= times) & (times <= period[2])
# Make sure the period does match any time points.
assert(any(is.period), "The period must cover at least one data point")
assert_that(any(is.period),
msg = "The period must cover at least one data point")
period.indices <- range(indices[is.period])
return(period.indices)
}
Expand Down Expand Up @@ -435,10 +384,8 @@ PrettifyNumber <- function(x, letter = "", round.digits = 1L) {
return(sprintf("%0.*fM", round.digits, x / 1e6))
} else if ((letter == "" && abs(x) >= 1e3) || letter == "K") {
return(sprintf("%0.*fK", round.digits, x / 1e3))
} else if (abs(x) >= 1) {
} else if (abs(x) >= 1 || x == 0) {
return(sprintf("%0.*f", round.digits, x))
} else if (x == 0) {
return(sprintf("%0.*f", round.digits, x))
} else {
# Calculate position of first non-zero digit after the decimal point
first.nonzero <- - floor(log10(abs(x)))
Expand Down
12 changes: 7 additions & 5 deletions R/impact_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ FormatInputForConstructModel <- function(data, model.args) {

# Check covariates
if (ncol(data) >= 2) {
assert(all(!is.na(data[, -1])), "covariates must not be NA")
assert_that(all(!is.na(data[, -1])), msg = "covariates must not be NA")
}

# (Re-)parse <model.args>, fill gaps using <.defaults>
Expand All @@ -107,8 +107,9 @@ FormatInputForConstructModel <- function(data, model.args) {
assert_that(!is.na(model.args$niter))
assert_that(is.wholenumber(model.args$niter))
model.args$niter <- round(model.args$niter)
assert(model.args$niter >= 10,
"must draw, at the very least, 10 MCMC samples; recommending 1000")
assert_that(model.args$niter >= 10,
msg = paste0("must draw, at the very least, 10 MCMC samples; ",
"recommending 1000"))
if (model.args$niter < 1000) {
warning("Results potentially inaccurate. Consider using more MCMC samples.")
}
Expand All @@ -124,8 +125,9 @@ FormatInputForConstructModel <- function(data, model.args) {
assert_that(is.numeric(model.args$nseasons))
assert_that(!is.na(model.args$nseasons))
assert_that(is.wholenumber(model.args$nseasons))
assert(model.args$nseasons >= 1,
"nseasons cannot be 0; use 1 in order not to have seaonsal components")
assert_that(model.args$nseasons >= 1,
msg = paste0("nseasons cannot be 0; use 1 in order not to have ",
"seaonsal components"))

# Check <season.duration>
assert_that(is.scalar(model.args$season.duration))
Expand Down
4 changes: 2 additions & 2 deletions R/impact_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ CreateDataFrameForPlot <- function(impact) {

# Check input
assert_that((class(impact) == "CausalImpact"))
assert(!isTRUE(all(is.na(impact$series[, -c(1, 2)]))),
"inference was aborted; cannot create plot")
assert_that(!isTRUE(all(is.na(impact$series[, -c(1, 2)]))),
msg = "inference was aborted; cannot create plot")

# Create data frame from zoo series
data <- as.data.frame(impact$series)
Expand Down
10 changes: 5 additions & 5 deletions man/CausalImpact.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@
\item \code{niter}. Number of MCMC samples to draw. Higher numbers
yield more accurate inferences. Defaults to 1000.

\item \code{standardize.data}. Whether to standardize all columns of
the data over the pre-intervention period before fitting the model. This is
equivalent to an empirical Bayes approach to setting the priors. It ensures
that results are invariant to linear transformations of the data. Defaults
to \code{TRUE}.
\item \code{standardize.data}. Whether to standardize all columns of the
data using moments estimated from the pre-intervention period before fitting
the model. This is equivalent to an empirical Bayes approach to setting the
priors. It ensures that results are invariant to linear transformations of
the data. Defaults to \code{TRUE}.

\item \code{prior.level.sd}. Prior standard deviation of the Gaussian random
walk of the local level, expressed in terms of data standard deviations.
Expand Down
Loading

0 comments on commit 4dfedbc

Please sign in to comment.