Skip to content

Commit

Permalink
Improve robustness and readability of code.
Browse files Browse the repository at this point in the history
This commit merges (and fixes) Github pull request #13 from bfgray3:

  #13

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=185801093
  • Loading branch information
alhauser committed Mar 28, 2018
1 parent e81cdc6 commit d039d9e
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
19 changes: 10 additions & 9 deletions R/impact_inference.R
Expand Up @@ -30,7 +30,7 @@ GetPosteriorStateSamples <- function(bsts.model) {
# discarding burn-in samples (=> 900 x 2 x 365)
burn <- SuggestBurn(0.1, bsts.model)
assert_that(burn > 0)
state.contributions <- bsts.model$state.contributions[-(1 : burn), , ,
state.contributions <- bsts.model$state.contributions[-seq_len(burn), , ,
drop = FALSE]

# Sum across states, call it 'state.samples' (=> 900 x 365)
Expand All @@ -57,7 +57,7 @@ ComputeResponseTrajectories <- function(bsts.model) {
# Get observation noise standard deviation samples
burn <- SuggestBurn(0.1, bsts.model)
assert_that(burn > 0)
sigma.obs <- bsts.model$sigma.obs[-(1 : burn)] # e.g., 900
sigma.obs <- bsts.model$sigma.obs[-seq_len(burn)] # e.g., 900

# Sample from the posterior predictive density over data
n.samples <- dim(state.samples)[1] # e.g., 900 x 365
Expand Down Expand Up @@ -135,9 +135,9 @@ ComputeCumulativePredictions <- function(y.samples, point.pred, y,
# post-period.

# Compute posterior mean
is.post.period <- (1 : length(y)) >= post.period.begin
cum.pred.mean.pre <- cumsum.na.rm(as.vector(y)[1 : (post.period.begin - 1)])
non.na.indices <- which(!is.na(cum.pred.mean.pre[1:(post.period.begin - 1)]))
is.post.period <- seq_along(y) >= post.period.begin
cum.pred.mean.pre <- cumsum.na.rm(as.vector(y)[!is.post.period])
non.na.indices <- which(!is.na(cum.pred.mean.pre))
assert_that(length(non.na.indices) > 0)
last.non.na.index <- max(non.na.indices)
cum.pred.mean.post <- cumsum(point.pred$point.pred[is.post.period]) +
Expand All @@ -146,7 +146,7 @@ ComputeCumulativePredictions <- function(y.samples, point.pred, y,

# Check for overflow
assert_that(identical(which(is.na(cum.pred.mean)),
which(is.na(y[1:(post.period.begin - 1)]))),
which(is.na(y[!is.post.period]))),
msg = "unexpected NA found in cum.pred.mean")

# Compute posterior interval
Expand Down Expand Up @@ -245,7 +245,8 @@ CompileSummaryTable <- function(y.post, y.samples.post,
prob.upper)),
AbsEffect.sd = c(sd(rowMeans(y.repmat.post - y.samples.post)),
sd(rowSums(y.repmat.post - y.samples.post))))
summary <- dplyr::mutate(summary, RelEffect = AbsEffect / Pred,
summary <- dplyr::mutate(summary,
RelEffect = AbsEffect / Pred,
RelEffect.lower = AbsEffect.lower / Pred,
RelEffect.upper = AbsEffect.upper / Pred,
RelEffect.sd = AbsEffect.sd / Pred)
Expand Down Expand Up @@ -418,7 +419,7 @@ AssertCumulativePredictionsAreConsistent <- function(cum.pred, post.period,
AssertCumulativePredictionIsConsistent <- function(cum.pred.col,
summary.entry,
description) {
non.na.indices <- which(!is.na(cum.pred.col[1:(post.period[1] - 1)]))
non.na.indices <- which(!is.na(cum.pred.col[seq_len(post.period[1] - 1)]))
assert_that(length(non.na.indices) > 0)
last.non.na.index <- max(non.na.indices)
assert_that(
Expand Down Expand Up @@ -611,7 +612,7 @@ CompileNaInferences <- function(y.model) {
"cum.pred", "cum.pred.lower", "cum.pred.upper",
"point.effect", "point.effect.lower", "point.effect.upper",
"cum.effect", "cum.effect.lower", "cum.effect.upper")
na.series <- matrix(as.numeric(NA), nrow = length(y.model), ncol = 12)
na.series <- matrix(NA_real_, nrow = length(y.model), ncol = 12)
na.series <- zoo(na.series, time(y.model))
names(na.series) <- vars

Expand Down
5 changes: 3 additions & 2 deletions R/impact_misc.R
Expand Up @@ -56,8 +56,9 @@ cumsum.na.rm <- function(x) {
if (is.null(x)) {
return(x)
}
s <- cumsum(ifelse(is.na(x), 0, x))
s[is.na(x)] <- NA
nas <- is.na(x)
s <- cumsum(ifelse(nas, 0, x))
s[nas] <- NA
return(s)
}

Expand Down
2 changes: 1 addition & 1 deletion R/impact_model.R
Expand Up @@ -48,7 +48,7 @@ ObservationsAreIllConditioned <- function(y) {
ill.conditioned <- TRUE

# Fewer than 3 non-NA values?
} else if (length(y[!is.na(y)]) < 3) {
} else if (sum(!is.na(y)) < 3) {
warning("Aborting inference due to fewer than 3 non-NA values in input")
ill.conditioned <- TRUE

Expand Down
8 changes: 4 additions & 4 deletions R/impact_plot.R
Expand Up @@ -36,18 +36,18 @@ CreateDataFrameForPlot <- function(impact) {

# Reshape data frame
tmp1 <- data[, c("time", "response", "point.pred", "point.pred.lower",
"point.pred.upper")]
"point.pred.upper"), drop = FALSE]
names(tmp1) <- c("time", "response", "mean", "lower", "upper")
tmp1$baseline <- NA
tmp1$metric <- "original"
tmp2 <- data[, c("time", "response", "point.effect", "point.effect.lower",
"point.effect.upper")]
"point.effect.upper"), drop = FALSE]
names(tmp2) <- c("time", "response", "mean", "lower", "upper")
tmp2$baseline <- 0
tmp2$metric <- "pointwise"
tmp2$response <- NA
tmp3 <- data[, c("time", "response", "cum.effect", "cum.effect.lower",
"cum.effect.upper")]
"cum.effect.upper"), drop = FALSE]
names(tmp3) <- c("time", "response", "mean", "lower", "upper")
tmp3$metric <- "cumulative"
tmp3$baseline <- 0
Expand Down Expand Up @@ -118,7 +118,7 @@ CreateImpactPlot <- function(impact, metrics = c("original", "pointwise",
# Select metrics to display (and their order)
assert_that(is.vector(metrics))
metrics <- match.arg(metrics, several.ok = TRUE)
data <- data[data$metric %in% metrics, ]
data <- data[data$metric %in% metrics, , drop = FALSE]
data$metric <- factor(data$metric, metrics)

# Initialize plot
Expand Down

0 comments on commit d039d9e

Please sign in to comment.