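# NOTE: the next lines are GitHub web-page residue captured with the file;
# they are kept as comments so that the script parses as R.
# Permalink
# Switch branches/tags
# Nothing to show
# Find file Copy path
# Fetching contributors…
# Cannot retrieve contributors at this time
# 247 lines (211 sloc) 7.51 KB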
# Check transitions that ended with a divergence
#
# Arguments:
#   fit    fit object handed to get_sampler_params() (presumably an rstan
#          stanfit -- confirm against callers)
#   quiet  FALSE: print a summary line (plus a hint when divergences exist)
#          and return NULL invisibly;
#          TRUE: print nothing and return TRUE when no post-warmup transition
#          diverged, FALSE otherwise
check_div <- function(fit, quiet=FALSE) {
  per_chain <- get_sampler_params(fit, inc_warmup=FALSE)
  # Stack the per-chain tables and keep only the divergence indicator column
  div_flags <- do.call(rbind, per_chain)[, 'divergent__']
  n_div <- sum(div_flags)
  n_total <- length(div_flags)

  if (!quiet) {
    print(sprintf('%s of %s iterations ended with a divergence (%s%%)',
                  n_div, n_total, 100 * n_div / n_total))
  }

  if (n_div > 0) {
    if (quiet) {
      return(FALSE)
    }
    print(' Try running with larger adapt_delta to remove the divergences')
  } else if (quiet) {
    return(TRUE)
  }
  invisible(NULL)
}
# Check transitions that ended prematurely due to maximum tree depth limit
#
# Arguments:
#   fit        fit object handed to get_sampler_params() (presumably an rstan
#              stanfit -- confirm against callers)
#   max_depth  tree depth value that counts as saturated (default 10)
#   quiet      FALSE: print a summary (plus a hint when saturation occurred)
#              and return NULL invisibly;
#              TRUE: print nothing and return TRUE when no post-warmup
#              iteration has treedepth__ equal to max_depth, FALSE otherwise
check_treedepth <- function(fit, max_depth = 10, quiet=FALSE) {
  sampler_params <- get_sampler_params(fit, inc_warmup=FALSE)
  treedepths <- do.call(rbind, sampler_params)[, 'treedepth__']
  # Summing a logical mask counts the saturated iterations and is also well
  # defined when there are no draws at all (an sapply()-built mask is a list
  # in that case and cannot be used as a subscript).
  n <- sum(treedepths == max_depth)
  N <- length(treedepths)

  if (!quiet) {
    print(sprintf('%s of %s iterations saturated the maximum tree depth of %s (%s%%)',
                  n, N, max_depth, 100 * n / N))
  }

  if (n > 0) {
    if (quiet) {
      return(FALSE)
    }
    print(' Run again with max_depth set to a larger value to avoid saturation')
  } else if (quiet) {
    return(TRUE)
  }
  invisible(NULL)
}
# Checks the energy fraction of missing information (E-FMI)
#
# For every chain returned by get_sampler_params() the statistic
#   sum(diff(energy)^2) / length(energy) / var(energy)
# is computed from the 'energy__' column; a chain is flagged when the value
# is below 0.2 or cannot be computed (NA/NaN, e.g. constant energies).
#
# Arguments:
#   fit    fit object handed to get_sampler_params() (presumably an rstan
#          stanfit -- confirm against callers)
#   quiet  FALSE: print one line per flagged chain plus a verdict and return
#          NULL invisibly;
#          TRUE: print nothing and return TRUE when no chain was flagged,
#          FALSE otherwise
check_energy <- function(fit, quiet=FALSE) {
  sampler_params <- get_sampler_params(fit, inc_warmup=FALSE)
  no_warning <- TRUE
  # seq_along() iterates zero times for an empty chain list (1:0 would not)
  for (n in seq_along(sampler_params)) {
    energies <- sampler_params[[n]][, 'energy__']
    numer <- sum(diff(energies)^2) / length(energies)
    denom <- var(energies)
    e_fmi <- numer / denom
    # An undefined ratio is reported rather than crashing the comparison
    if (is.na(e_fmi) || e_fmi < 0.2) {
      if (!quiet) print(sprintf('Chain %s: E-FMI = %s', n, e_fmi))
      no_warning <- FALSE
    }
  }

  if (no_warning) {
    if (quiet) {
      return(TRUE)
    }
    print('E-FMI indicated no pathological behavior')
  } else {
    if (quiet) {
      return(FALSE)
    }
    print(' E-FMI below 0.2 indicates you may need to reparameterize your model')
  }
  invisible(NULL)
}
# Checks the effective sample size per iteration
#
# Reads the 'n_eff' column of summary(fit, probs = c(0.5))$summary, divides it
# by the number of draws (first dimension of the first element returned by
# extract(fit)) and flags every parameter whose ratio is below 0.001 or is
# undefined (NA/NaN).
#
# Arguments:
#   fit    fit object understood by summary() and extract() (presumably an
#          rstan stanfit -- confirm against callers)
#   quiet  FALSE: print one line per flagged parameter plus a verdict and
#          return NULL invisibly;
#          TRUE: print nothing and return TRUE when no parameter was flagged,
#          FALSE otherwise
check_n_eff <- function(fit, quiet=FALSE) {
  fit_summary <- summary(fit, probs = c(0.5))$summary
  # Select the column by name so the check does not depend on column order
  n_effs <- fit_summary[, 'n_eff']
  param_names <- rownames(fit_summary)
  iter <- dim(extract(fit)[[1]])[[1]]

  no_warning <- TRUE
  for (n in seq_along(n_effs)) {
    ratio <- n_effs[n] / iter
    if (is.na(ratio)) {
      # An undefined n_eff cannot be compared against the threshold; report
      # it instead of letting the comparison fail.
      if (!quiet) print(sprintf('n_eff for parameter %s is %s!',
                                param_names[n], n_effs[n]))
      no_warning <- FALSE
    } else if (ratio < 0.001) {
      if (!quiet) print(sprintf('n_eff / iter for parameter %s is %s!',
                                param_names[n], ratio))
      no_warning <- FALSE
    }
  }

  if (no_warning) {
    if (quiet) {
      return(TRUE)
    }
    print('n_eff / iter looks reasonable for all parameters')
  } else {
    if (quiet) {
      return(FALSE)
    }
    print(' n_eff / iter below 0.001 indicates that the effective sample size has likely been overestimated')
  }
  invisible(NULL)
}
# Checks the potential scale reduction factors
#
# Reads the 'Rhat' column of summary(fit, probs = c(0.5))$summary and flags
# every parameter whose value is above 1.1, infinite, or missing (NA/NaN).
#
# Arguments:
#   fit    fit object understood by summary() (presumably an rstan stanfit
#          -- confirm against callers)
#   quiet  FALSE: print one line per flagged parameter plus a verdict and
#          return NULL invisibly;
#          TRUE: print nothing and return TRUE when no parameter was flagged,
#          FALSE otherwise
check_rhat <- function(fit, quiet=FALSE) {
  fit_summary <- summary(fit, probs = c(0.5))$summary
  # Select the column by name so the check does not depend on column order
  rhats <- fit_summary[, 'Rhat']
  param_names <- rownames(fit_summary)

  no_warning <- TRUE
  for (n in seq_along(rhats)) {
    rhat <- rhats[n]
    # is.na() goes first: it covers both NA and NaN and keeps the remaining
    # comparisons from ever seeing a missing value.
    if (is.na(rhat) || is.infinite(rhat) || rhat > 1.1) {
      if (!quiet) print(sprintf('Rhat for parameter %s is %s!',
                                param_names[n], rhat))
      no_warning <- FALSE
    }
  }

  if (no_warning) {
    if (quiet) {
      return(TRUE)
    }
    print('Rhat looks reasonable for all parameters')
  } else {
    if (quiet) {
      return(FALSE)
    }
    print(' Rhat above 1.1 indicates that the chains very likely have not mixed')
  }
  invisible(NULL)
}
# Checks tail k_hat of each parameter
#
# khat(x): tail-shape estimate for a sample of strictly positive values.
# A grid of candidate 'b' values is built from the largest observation and the
# lower-quartile order statistic; each candidate receives a log weight, and
# the weight-averaged b_hat is plugged into mean(log(1 - b_hat * x)).
# NOTE(review): this looks like a Zhang & Stephens style generalized Pareto
# fit -- confirm against the intended reference.
#
# If the smallest value is not positive, "x must be positive!" is printed and
# 0 is returned.
khat <- function(x) {
  n_obs <- length(x)  # measured before sorting, exactly as the indices assume
  sorted <- sort(x)
  if (sorted[1] <= 0) {
    print("x must be positive!")
    return(0)
  }

  quartile <- sorted[floor(0.25 * n_obs + 0.5)]
  n_grid <- 20 + floor(sqrt(n_obs))

  # Candidate values for b, one per grid point
  grid_pos <- seq_len(n_grid)
  b_grid <- 1 / sorted[n_obs] + (1 - sqrt(n_grid / (grid_pos - 0.5))) / (3 * quartile)

  # Log weight attached to each candidate
  log_w <- vapply(b_grid, function(b) {
    k_hat <- -mean(log(1 - b * sorted))
    length(sorted) * (log(b / k_hat) + k_hat - 1)
  }, numeric(1))

  # Subtract the maximum before exponentiating to avoid overflow
  rel_w <- exp(log_w - max(log_w))
  b_hat <- sum(b_grid * rel_w) / sum(rel_w)

  mean(log(1 - b_hat * sorted))
}
# Splits a sample at its median and returns khat() of the distances from the
# median for each side, as c(left, right).
#
# Only values strictly below / strictly above the median are used: a value
# equal to the median has distance 0, which khat() rejects (it requires
# strictly positive input).
#
# Returns c(0, 0) when the median coincides with the minimum or the maximum,
# i.e. when one side would be empty.
tail_khats <- function(x) {
  x_center <- median(x)
  # Scalar condition, so use the short-circuit operator
  if (x_center == min(x) || x_center == max(x)) {
    return(c(0, 0))
  }
  x_left <- abs(x[x < x_center] - x_center)
  x_right <- x[x > x_center] - x_center
  c(khat(x_left), khat(x_right))
}
# Checks the left and right tail khat of every parameter in `fit`
#
# The draws of each parameter (third dimension of
# extract(fit, permuted=FALSE)) are flattened and passed to tail_khats();
# a parameter is flagged when either returned value exceeds 0.5.
#
# Arguments:
#   fit    fit object understood by extract() and as.data.frame() (presumably
#          an rstan stanfit -- confirm against callers)
#   quiet  FALSE: print one line per flagged parameter plus a verdict and
#          return NULL invisibly;
#          TRUE: print nothing and return TRUE when no parameter was flagged,
#          FALSE otherwise
check_tail_khat <- function(fit, quiet=FALSE) {
  # Work on the `fit` argument, not on a variable from the calling workspace
  params <- extract(fit, permuted=FALSE)
  N <- dim(params)[3]
  param_names <- names(as.data.frame(fit))

  no_warning <- TRUE
  for (n in seq_len(N)) {
    x <- params[, , n]
    dim(x) <- NULL  # flatten the draws of all chains into one vector
    khats <- tail_khats(x)
    left_heavy <- khats[1] > 0.5
    right_heavy <- khats[2] > 0.5

    if (left_heavy && right_heavy) {
      # Report both tails: left first, then right
      if (!quiet) print(sprintf('Tail khats for parameter %s are %s and %s!',
                                param_names[n], khats[1], khats[2]))
      no_warning <- FALSE
    } else if (left_heavy) {
      if (!quiet) print(sprintf('Left tail khat for parameter %s is %s!',
                                param_names[n], khats[1]))
      no_warning <- FALSE
    } else if (right_heavy) {
      if (!quiet) print(sprintf('Right tail khat for parameter %s is %s!',
                                param_names[n], khats[2]))
      no_warning <- FALSE
    }
  }

  if (no_warning) {
    if (quiet) {
      return(TRUE)
    }
    print('Tail khats looks reasonable for all parameters')
  } else {
    if (quiet) {
      return(FALSE)
    }
    print(' Tail khat above 0.5 indicates that the parameter is probably not square integrable')
  }
  invisible(NULL)
}
# Check all diagnostics
#
# Arguments:
#   fit        fit object forwarded unchanged to every individual check
#   max_depth  forwarded to check_treedepth() (default 10)
#   quiet      FALSE: run every check in verbose mode so each prints its own
#              report; the result of the last check is returned;
#              TRUE: run every check silently and return an integer-valued
#              warning code with one bit per failed check:
#                bit 0  check_n_eff
#                bit 1  check_rhat
#                bit 2  check_tail_khat
#                bit 3  check_div
#                bit 4  check_treedepth
#                bit 5  check_energy
#              The code can be decoded with parse_warning_code().
check_all_diagnostics <- function(fit, max_depth = 10, quiet=FALSE) {
  if (!quiet) {
    check_n_eff(fit)
    check_rhat(fit)
    check_tail_khat(fit)
    check_div(fit)
    check_treedepth(fit, max_depth)
    check_energy(fit)
  } else {
    warning_code <- 0

    if (!check_n_eff(fit, quiet=TRUE))
      warning_code <- bitwOr(warning_code, bitwShiftL(1, 0))
    if (!check_rhat(fit, quiet=TRUE))
      warning_code <- bitwOr(warning_code, bitwShiftL(1, 1))
    # Same tail khat check that the verbose branch runs
    if (!check_tail_khat(fit, quiet=TRUE))
      warning_code <- bitwOr(warning_code, bitwShiftL(1, 2))
    if (!check_div(fit, quiet=TRUE))
      warning_code <- bitwOr(warning_code, bitwShiftL(1, 3))
    if (!check_treedepth(fit, max_depth, quiet=TRUE))
      warning_code <- bitwOr(warning_code, bitwShiftL(1, 4))
    if (!check_energy(fit, quiet=TRUE))
      warning_code <- bitwOr(warning_code, bitwShiftL(1, 5))

    return(warning_code)
  }
}
# Parse warning code encoding returned by check_all_diagnostics
#
# Prints one line for every bit that is set in `warning_code`:
#   bit 0 n_eff / iteration, bit 1 rhat, bit 2 khat,
#   bit 3 divergence, bit 4 treedepth, bit 5 energy
parse_warning_code <- function(warning_code) {
  # TRUE when the bit at `position` (0-based) is set in the code
  bit_is_set <- function(position) {
    bitwAnd(warning_code, bitwShiftL(1, position)) != 0
  }

  if (bit_is_set(0)) {
    print("n_eff / iteration warning")
  }
  if (bit_is_set(1)) {
    print("rhat warning")
  }
  if (bit_is_set(2)) {
    print("khat warning")
  }
  if (bit_is_set(3)) {
    print("divergence warning")
  }
  if (bit_is_set(4)) {
    print("treedepth warning")
  }
  if (bit_is_set(5)) {
    print("energy warning")
  }
}
# Returns parameter arrays separated into divergent and non-divergent transitions
#
# The draws from extract(fit, permuted=FALSE) are stacked chain after chain
# into one data frame (one row per draw, one column per parameter), a
# `divergent` column taken from get_sampler_params() is appended, and the
# rows are split on that column.
#
# Arguments:
#   fit  fit object understood by extract() and get_sampler_params()
#        (presumably an rstan stanfit -- confirm against callers)
#
# Returns an unnamed list of two data frames:
#   [[1]] rows with divergent == 1, [[2]] rows with divergent == 0
partition_div <- function(fit) {
  nom_params <- extract(fit, permuted=FALSE)
  dims <- dim(nom_params)
  n_iter <- dims[1]
  n_chains <- dims[2]
  n_params <- dims[3]
  param_names <- dimnames(nom_params)[[3]]

  # Rebuild each chain as an explicit n_iter x n_params matrix. Plain
  # nom_params[, n, ] collapses to a vector when there is a single parameter
  # or a single iteration, and rbind() would then stack the chains wrongly.
  chain_draws <- lapply(seq_len(n_chains), function(n) {
    matrix(nom_params[, n, ], nrow = n_iter, ncol = n_params,
           dimnames = list(NULL, param_names))
  })
  params <- as.data.frame(do.call(rbind, chain_draws))

  sampler_params <- get_sampler_params(fit, inc_warmup=FALSE)
  divergent <- do.call(rbind, sampler_params)[, 'divergent__']
  params$divergent <- divergent

  div_params <- params[params$divergent == 1, ]
  nondiv_params <- params[params$divergent == 0, ]

  return(list(div_params, nondiv_params))
}