---
title: "A Gentle Stan vs. INLA Comparison"
date: 2018-02-02
output: html_notebook
tags: ["R","Stan","INLA"]
---
Not long ago, I came across a nice blog post by Kathryn Morrison called [A gentle INLA tutorial](https://www.precision-analytics.ca/blog-1/inla). The post was nice and helped me better appreciate INLA. But as a fan of the Stan probabilistic language, I felt that comparing INLA to JAGS is not really that relevant, as Stan should - at least in theory - be way faster and better than JAGS. Here, I run a comparison of INLA to Stan on the second example, called "Poisson GLM with an iid random effect".
**The TLDR is:** For this model, Stan scales considerably better than JAGS, but still cannot scale to very large models. Also, for this model Stan and INLA give almost the same results. It seems that Stan becomes useful only when your model cannot be coded in INLA.
Please let me know (via an [issue on GitHub](https://github.com/martinmodrak/blog/issues)) should you find any error or anything else that should be included in this post. Also, if you run the experiment on a different machine and/or with a different seed, let me know the results.
Here are the original numbers from Kathryn's blog:
```{r echo=FALSE,results='asis', warning = FALSE}
library(knitr)
kathryn_results = data.frame(N = c(100, 500, 5000, 25000, 100000), kathryn_rjags = c("30.394","142.532","1714.468","8610.32","got bored after 6 hours"), kathryn_rinla = c(0.383,1.243,5.768,30.077,166.819))
kable(kathryn_results)
```
*Full source of this post is available at [this blog's Github repo](https://github.com/martinmodrak/blog/blob/master/content/post/2018-02-02-stan-vs-inla.Rmd). Keep in mind that installing RStan is unfortunately not as straightforward as running install.packages. Please consult https://github.com/stan-dev/rstan/wiki/RStan-Getting-Started if you don't have RStan already installed.*
## The model
The model we are interested in is a simple GLM with partial pooling of a random effect:
```
y_i ~ poisson(mu_i)
log(mu_i) = alpha + beta * x_i + nu_i
nu_i ~ normal(0, tau_nu)   # tau_nu is a precision (JAGS convention), so sd = 1 / sqrt(tau_nu)
```
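For readers who prefer math notation, here is the same model restated with the precision made explicit. This is only a restatement based on the Stan code later in this post (which converts `tau_nu` to an sd via `sqrt(1 / tau_nu)`); the symbol $\sigma_\nu$ is notation introduced here for convenience.

$$
\begin{aligned}
y_i &\sim \mathrm{Poisson}(\mu_i) \\
\log \mu_i &= \alpha + \beta x_i + \nu_i \\
\nu_i &\sim \mathrm{Normal}\left(0,\ \sigma_\nu^2\right), \qquad \sigma_\nu = \frac{1}{\sqrt{\tau_\nu}}
\end{aligned}
$$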
## The comparison
Let's set up our libraries.
```{r setup,message=FALSE,warning=FALSE}
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(INLA)
library(tidyverse)
set.seed(6619414)
```
The results are stored in files within the repository to let me rebuild the site with blogdown easily. Delete the cache directory to force a complete rerun.
```{r}
cache_dir = "_stan_vs_inla_cache/"
if(!dir.exists(cache_dir)){
  dir.create(cache_dir)
}
```
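The caching pattern used throughout this post boils down to "read the file if it exists, otherwise compute and save". Here is a minimal sketch of that idea as a reusable helper. Note that `cached_csv` and the toy data frame are hypothetical names made up for illustration and are not used anywhere else in the post; the real chunks additionally check that the cached `N` values match `N_values`.

```r
# Hypothetical helper (not part of the original post): return the cached
# data frame if the file exists, otherwise compute it and write it to disk.
cached_csv <- function(path, compute) {
  if (file.exists(path)) {
    return(read.csv(path))
  }
  result <- compute()
  write.csv(result, path, row.names = FALSE)
  result
}

# Toy usage with a temporary file and made-up numbers
demo_file <- tempfile(fileext = ".csv")
first <- cached_csv(demo_file, function() data.frame(N = c(100, 500), time = c(1.5, 7.2)))
# The second call reads the file, so the compute function is never evaluated
second <- cached_csv(demo_file, function() stop("not reached"))
```

Passing the computation as a function means the expensive part is skipped entirely when the cache file is present.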
Let's start by simulating data:
```{r}
#The sizes of datasets to work with
N_values = c(100, 500, 5000, 25000)
data = list()
for(N in N_values) {
  x = rnorm(N, mean=5,sd=1)
  nu = rnorm(N,0,0.1)
  mu = exp(1 + 0.5*x + nu)
  y = rpois(N,mu)
  data[[N]] = list(
    N = N,
    x = x,
    y = y
  )
}
```
### Measuring Stan
Here is the model code in Stan (it is good practice to put it into a file, but I wanted to make this post self-contained). It is an almost 1-to-1 rewrite of the original JAGS code, with a few important changes:
* JAGS parametrizes the normal distribution via precision, Stan via sd. The model converts precision to sd.
* I added the ability to explicitly set the parameters of the prior distributions as data, which is useful later in this post.
* With multilevel models, Stan works waaaaaay better with the so-called non-centered parametrization. This means that instead of having ```nu ~ N(0, nu_sigma), mu = alpha + beta * x + nu``` we have ```nu_normalized ~ N(0,1), mu = alpha + beta * x + nu_normalized * nu_sigma```. This gives exactly the same inferences, but results in a geometry that Stan can explore efficiently.
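To see why the non-centered form describes the same distribution, here is a small sketch in plain R (no Stan involved). It only checks that scaling standard normal draws by an sd gives draws with that sd; the variable names and the value `0.1` are chosen for illustration. It says nothing about sampler geometry, which is where the actual benefit for Stan comes from.

```r
set.seed(1)
nu_sigma <- 0.1   # illustrative sd of the random effect
n_draws <- 1e5

# Centered: draw nu directly with the target sd
nu_centered <- rnorm(n_draws, 0, nu_sigma)

# Non-centered: draw a standard normal and scale it afterwards
nu_normalized <- rnorm(n_draws, 0, 1)
nu_noncentered <- nu_normalized * nu_sigma

# Both empirical sds should be close to nu_sigma
c(centered = sd(nu_centered), noncentered = sd(nu_noncentered))
```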
There are also packages that let you specify common models (including this one) without writing Stan code, using syntax similar to R-INLA - check out [rstanarm](http://mc-stan.org/users/interfaces/rstanarm) and [brms](https://cran.r-project.org/web/packages/brms/index.html). The latter is more flexible, while the former is easier to get running, as it ships with precompiled models and can be installed simply with ```install.packages```.
Note also that Stan developers would advise against the Gamma(0.01,0.01) prior on precision in favor of a normal or Cauchy prior on sd, see https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations.
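Since the priors in this post are stated on the precision scale, it can help to see what they mean on the sd scale. The following is a minimal sketch; `precision_to_sd` is a hypothetical helper name, and the formula is the same `sqrt(1 / precision)` conversion the Stan model uses.

```r
# Hypothetical helper: convert a precision to a standard deviation
precision_to_sd <- function(precision) sqrt(1 / precision)

# Precision 0.001 (used for alpha and beta in the timing runs)
precision_to_sd(0.001)     # approx. 31.6

# Precision 5 (used for alpha and beta in the quality test)
precision_to_sd(5)         # approx. 0.447

# Prior mean of tau_nu under Gamma(shape = 2, rate = 0.01) is 2 / 0.01 = 200
precision_to_sd(2 / 0.01)  # approx. 0.0707
```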
```{r model, cache=TRUE, message=FALSE,warning=FALSE,results='hide'}
model_code = "
data {
  int N;
  vector[N] x;
  int y[N];

  //Allowing to parametrize the priors (useful later)
  real alpha_prior_mean;
  real beta_prior_mean;
  real<lower=0> alpha_beta_prior_precision;
  real<lower=0> tau_nu_prior_shape;
  real<lower=0> tau_nu_prior_rate;
}

transformed data {
  //Stan parametrizes normal with sd not precision
  real alpha_beta_prior_sigma = sqrt(1 / alpha_beta_prior_precision);
}

parameters {
  real alpha;
  real beta;
  vector[N] nu_normalized;
  real<lower=0> tau_nu;
}

model {
  real nu_sigma = sqrt(1 / tau_nu);
  vector[N] nu = nu_normalized * nu_sigma;

  //taking advantage of Stan's implicit vectorization here
  nu_normalized ~ normal(0,1);
  //The built-in poisson_log(x) === poisson(exp(x))
  y ~ poisson_log(alpha + beta*x + nu);

  alpha ~ normal(alpha_prior_mean, alpha_beta_prior_sigma);
  beta ~ normal(beta_prior_mean, alpha_beta_prior_sigma);
  tau_nu ~ gamma(tau_nu_prior_shape,tau_nu_prior_rate);
}

//Uncomment this to have the model generate mu values as well
//Currently commented out as storing the samples of mu consumes
//a lot of memory for the big models
//(nu_sigma is local to the model block, so the sd is recomputed here)
/*
generated quantities {
  vector[N] mu = exp(alpha + beta*x + nu_normalized * sqrt(1 / tau_nu));
}
*/
"
model = stan_model(model_code = model_code)
```
Below is the code to make the actual measurements. Some caveats:
* The compilation of the Stan model is not counted (you can avoid it with rstanarm and need to do it only once otherwise)
* There is some overhead in transferring the posterior samples from Stan to R. This overhead is non-negligible for the larger models, but you can get rid of it by storing the samples in a file and reading them separately. The overhead is not measured here.
* Stan took > 16 hours to converge for the largest data size (1e5), and then I had issues fitting the posterior samples into memory on my computer. Notably, R-INLA also crashed on my computer for this size. The largest size is thus excluded here, but I have to conclude that if you get bored after 6 hours, Stan is not practical for such a big model.
* I was not able to get rjags running in a reasonable amount of time, so I did not rerun the JAGS version of the model.
```{r}
stan_times_file = paste0(cache_dir, "stan_times.csv")
stan_summary_file = paste0(cache_dir, "stan_summary.csv")
run_stan = TRUE
if(file.exists(stan_times_file) && file.exists(stan_summary_file)) {
  stan_times = read.csv(stan_times_file)
  stan_summary = read.csv(stan_summary_file)
  if(setequal(stan_times$N, N_values) && setequal(stan_summary$N, N_values)) {
    run_stan = FALSE
  }
}
if(run_stan) {
  stan_times_values = numeric(length(N_values))
  stan_summary_list = list()
  step = 1
  for(N in N_values) {
    data_stan = data[[N]]
    data_stan$alpha_prior_mean = 0
    data_stan$beta_prior_mean = 0
    data_stan$alpha_beta_prior_precision = 0.001
    data_stan$tau_nu_prior_shape = 0.01
    data_stan$tau_nu_prior_rate = 0.01
    fit = sampling(model, data = data_stan);
    stan_summary_list[[step]] =
      as.data.frame(
        rstan::summary(fit, pars = c("alpha","beta","tau_nu"))$summary
      ) %>% rownames_to_column("parameter")
    stan_summary_list[[step]]$N = N
    all_times = get_elapsed_time(fit)
    stan_times_values[step] = max(all_times[,"warmup"] + all_times[,"sample"])
    step = step + 1
  }
  stan_times = data.frame(N = N_values, stan_time = stan_times_values)
  stan_summary = do.call(rbind, stan_summary_list)
  write.csv(stan_times, stan_times_file,row.names = FALSE)
  write.csv(stan_summary, stan_summary_file,row.names = FALSE)
}
```
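The Stan time recorded above is the total (warmup + sampling) time of the slowest chain, which is a reasonable proxy for wall time when the chains run in parallel. Here is a small sketch of that computation on a made-up matrix shaped like what I understand `get_elapsed_time` to return (one row per chain, columns `warmup` and `sample`); the numbers are invented for illustration.

```r
# Made-up elapsed times in seconds for 4 chains (not real measurements)
all_times <- matrix(
  c(10, 12, 11,  9,   # warmup
    20, 18, 25, 21),  # sample
  ncol = 2,
  dimnames = list(paste0("chain:", 1:4), c("warmup", "sample"))
)

# Total time per chain
per_chain <- all_times[, "warmup"] + all_times[, "sample"]

# The slowest chain determines when the whole fit is finished
slowest <- max(per_chain)
slowest  # 36 with these made-up numbers
```

If the chains ran one after another instead (e.g. with a single core), the sum over chains would be the more appropriate number.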
### Measuring INLA
```{r}
inla_times_file = paste0(cache_dir,"inla_times.csv")
inla_summary_file = paste0(cache_dir,"inla_summary.csv")
run_inla = TRUE
if(file.exists(inla_times_file) && file.exists(inla_summary_file)) {
  inla_times = read.csv(inla_times_file)
  inla_summary = read.csv(inla_summary_file)
  if(setequal(inla_times$N, N_values) && setequal(inla_summary$N, N_values)) {
    run_inla = FALSE
  }
}
if(run_inla) {
  inla_times_values = numeric(length(N_values))
  inla_summary_list = list()
  step = 1
  for(N in N_values) {
    #One index per observation for the iid random effect
    nu = 1:N
    fit_inla = inla(y ~ x + f(nu,model="iid"), family = c("poisson"),
                    data = data[[N]], control.predictor=list(link=1))
    inla_times_values[step] = fit_inla$cpu.used["Total"]
    inla_summary_list[[step]] =
      rbind(fit_inla$summary.fixed %>% select(-kld),
            fit_inla$summary.hyperpar) %>%
      rownames_to_column("parameter")
    inla_summary_list[[step]]$N = N
    step = step + 1
  }
  inla_times = data.frame(N = N_values, inla_time = inla_times_values)
  inla_summary = do.call(rbind, inla_summary_list)
  write.csv(inla_times, inla_times_file,row.names = FALSE)
  write.csv(inla_summary, inla_summary_file,row.names = FALSE)
}
```
### Checking inferences
Here we see side-by-side comparisons of the inferences, and they seem pretty comparable between Stan and INLA:
```{r results='asis', warning=FALSE}
for(N_to_show in N_values) {
  print(kable(stan_summary %>% filter(N == N_to_show) %>%
                select(c("parameter","mean","sd")),
              caption = paste0("Stan results for N = ", N_to_show)))
  print(kable(inla_summary %>% filter(N == N_to_show) %>%
                select(c("parameter","mean","sd")),
              caption = paste0("INLA results for N = ", N_to_show)))
}
```
### Summary of the timing
You can see that Stan keeps reasonable runtimes for larger datasets than JAGS did in the original blog post, but INLA is still way faster. Also, Kathryn probably got very lucky with her seed for N = 25 000, as her INLA run completed very quickly. In my (few) tests, INLA always took at least several minutes for N = 25 000. This may mean that Kathryn's JAGS time is also unusually short.
```{r results='asis'}
my_results = merge.data.frame(inla_times, stan_times, by = "N")
kable(merge.data.frame(my_results, kathryn_results, by = "N"))
```
You could obviously do multiple runs to reduce uncertainty etc., but this post has already taken too much of my time, so this will be left to others.
## Testing quality of the results
I also had a hunch that maybe INLA is less precise than Stan, but that turned out to be based on an error. Thus, without much commentary, I put here my code to test this. Basically, I modify the random data generator to actually draw from priors (those priors are quite constrained to provide similar values of alpha, beta and tau_nu as in the original). I then give both algorithms the knowledge of these priors. For each parameter, I compute both the difference between the true value and a point estimate (the posterior mean) and the quantile of the posterior distribution at which the true value is found. If the algorithms give the best possible estimates, the distribution of such quantiles should be uniform over (0,1). It turns out that INLA and Stan give almost exactly the same results for almost all runs, and the differences in quality are (for this particular model) negligible.
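The claim that the quantiles should be uniform can be checked on a toy problem where the exact posterior is known. The sketch below uses a conjugate normal model with known observation sd, not the Poisson model from this post; all names and numeric settings are chosen for illustration only. The true parameter is drawn from the prior, data are simulated, and the quantile of the true value in the exact posterior is recorded.

```r
set.seed(42)
n_sims <- 2000
n_obs <- 10
prior_mean <- 0
prior_sd <- 1
obs_sd <- 2   # assumed known in this toy model

q <- numeric(n_sims)
for (i in seq_len(n_sims)) {
  # Draw the "true" parameter from the prior, then simulate data
  theta <- rnorm(1, prior_mean, prior_sd)
  y <- rnorm(n_obs, theta, obs_sd)

  # Exact conjugate posterior for a normal mean with known sd
  post_prec <- 1 / prior_sd^2 + n_obs / obs_sd^2
  post_mean <- (prior_mean / prior_sd^2 + sum(y) / obs_sd^2) / post_prec
  post_sd <- sqrt(1 / post_prec)

  # Posterior quantile at which the true parameter is found
  q[i] <- pnorm(theta, post_mean, post_sd)
}

# With an exact posterior, the counts per bin should be roughly equal
bin_counts <- table(cut(q, breaks = seq(0, 1, by = 0.1)))
bin_counts
```

An approximation that is too narrow would pile up quantiles near 0 and 1, while one that is too wide would concentrate them around 0.5.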
```{r}
test_precision = function(N) {
  rejects <- 0
  repeat {
    #Set the priors so that they generate similar parameters as in the example above
    alpha_beta_prior_precision = 5
    prior_sigma = sqrt(1/alpha_beta_prior_precision)
    alpha_prior_mean = 1
    beta_prior_mean = 0.5
    alpha = rnorm(1, alpha_prior_mean, prior_sigma)
    beta = rnorm(1, beta_prior_mean, prior_sigma)
    tau_nu_prior_shape = 2
    tau_nu_prior_rate = 0.01
    tau_nu = rgamma(1,tau_nu_prior_shape,tau_nu_prior_rate)
    sigma_nu = sqrt(1 / tau_nu)
    x = rnorm(N, mean=5,sd=1)
    nu = rnorm(N,0,sigma_nu)
    linear = alpha + beta*x + nu
    #Rejection sampling to avoid NAs and ill-posed problems
    if(max(linear) < 15) {
      mu = exp(linear)
      y = rpois(N,mu)
      if(mean(y == 0) < 0.7) {
        break;
      }
    }
    rejects = rejects + 1
  }
  #cat(rejects, "rejects\n")
  data = list(
    N = N,
    x = x,
    y = y
  )
  #cat("A:",alpha,"B:", beta, "T:", tau_nu,"\n")
  #print(linear)
  #print(data)
  #=============== Fit INLA
  nu = 1:N
  fit_inla = inla(y ~ x + f(nu,model="iid",
                            hyper=list(theta=list(prior="loggamma",
                                                  param=c(tau_nu_prior_shape,tau_nu_prior_rate)))),
                  family = c("poisson"),
                  control.fixed = list(mean = beta_prior_mean,
                                       mean.intercept = alpha_prior_mean,
                                       prec = alpha_beta_prior_precision,
                                       prec.intercept = alpha_beta_prior_precision
                  ),
                  data = data, control.predictor=list(link=1)
  )
  time_inla = fit_inla$cpu.used["Total"]
  alpha_mean_diff_inla = fit_inla$summary.fixed["(Intercept)","mean"] - alpha
  beta_mean_diff_inla = fit_inla$summary.fixed["x","mean"] - beta
  tau_nu_mean_diff_inla = fit_inla$summary.hyperpar[,"mean"] - tau_nu
  #Posterior quantiles of the true values under the INLA marginals
  alpha_q_inla = inla.pmarginal(alpha, fit_inla$marginals.fixed$`(Intercept)`)
  beta_q_inla = inla.pmarginal(beta, fit_inla$marginals.fixed$x)
  tau_nu_q_inla = inla.pmarginal(tau_nu, fit_inla$marginals.hyperpar$`Precision for nu`)
  #================ Fit Stan
  data_stan = data
  data_stan$alpha_prior_mean = alpha_prior_mean
  data_stan$beta_prior_mean = beta_prior_mean
  data_stan$alpha_beta_prior_precision = alpha_beta_prior_precision
  data_stan$tau_nu_prior_shape = tau_nu_prior_shape
  data_stan$tau_nu_prior_rate = tau_nu_prior_rate
  fit = sampling(model, data = data_stan, control = list(adapt_delta = 0.95));
  all_times = get_elapsed_time(fit)
  max_total_time_stan = max(all_times[,"warmup"] + all_times[,"sample"])
  samples = rstan::extract(fit, pars = c("alpha","beta","tau_nu"))
  alpha_mean_diff_stan = mean(samples$alpha) - alpha
  beta_mean_diff_stan = mean(samples$beta) - beta
  tau_nu_mean_diff_stan = mean(samples$tau_nu) - tau_nu
  #Posterior quantiles of the true values under the Stan samples
  alpha_q_stan = ecdf(samples$alpha)(alpha)
  beta_q_stan = ecdf(samples$beta)(beta)
  tau_nu_q_stan = ecdf(samples$tau_nu)(tau_nu)
  return(data.frame(time_rstan = max_total_time_stan,
                    time_rinla = time_inla,
                    alpha_mean_diff_stan = alpha_mean_diff_stan,
                    beta_mean_diff_stan = beta_mean_diff_stan,
                    tau_nu_mean_diff_stan = tau_nu_mean_diff_stan,
                    alpha_q_stan = alpha_q_stan,
                    beta_q_stan = beta_q_stan,
                    tau_nu_q_stan = tau_nu_q_stan,
                    alpha_mean_diff_inla = alpha_mean_diff_inla,
                    beta_mean_diff_inla = beta_mean_diff_inla,
                    tau_nu_mean_diff_inla = tau_nu_mean_diff_inla,
                    alpha_q_inla= alpha_q_inla,
                    beta_q_inla = beta_q_inla,
                    tau_nu_q_inla = tau_nu_q_inla
  ))
}
```
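For the Stan fit, the posterior quantile of the true value is computed with `ecdf`, which is just the fraction of posterior draws that are less than or equal to that value. A minimal sketch with simulated "posterior draws" (the draws and the true value are made up, no model is fit here):

```r
set.seed(7)
# Stand-in for posterior draws of a parameter (made-up, not from a real fit)
samples <- rnorm(4000, mean = 1, sd = 0.2)
true_value <- 1.1

# Quantile via the empirical CDF, as used in test_precision
q_ecdf <- ecdf(samples)(true_value)

# The same number computed by hand
q_manual <- mean(samples <= true_value)

c(ecdf = q_ecdf, manual = q_manual)
```

With more draws, both numbers approach the exact value `pnorm(1.1, 1, 0.2)` for this toy distribution.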
Now let's actually run the comparison. On some occasions, Stan does not converge; my best guess is that the data are somehow pathological, but I didn't investigate thoroughly. You can see that the results for Stan and INLA are very similar, both as point estimates and in the distribution of posterior quantiles. The accuracy of the INLA approximation is also AFAIK going to improve with more data.
```{r precision, results='asis',warning=FALSE,message=FALSE}
library(skimr) #Uses skimr to summarize results easily
precision_results_file = paste0(cache_dir,"precision_results.csv")
if(file.exists(precision_results_file)) {
  results_precision_df = read.csv(precision_results_file)
} else {
  results_precision = list()
  for(i in 1:100) {
    results_precision[[i]] = test_precision(50)
  }
  results_precision_df = do.call(rbind, results_precision)
  write.csv(results_precision_df,precision_results_file,row.names = FALSE)
}
#Remove uninteresting skim statistics
skim_with(numeric = list(missing = NULL, complete = NULL, n = NULL))
skimmed = results_precision_df %>% select(-X) %>% skim()
#Now a hack to display skim histograms properly in the output:
skimmed_better = skimmed %>% rowwise() %>% mutate(formatted =
  if_else(stat == "hist",
          utf8ToInt(formatted) %>% as.character() %>% paste0("&#", . ,";", collapse = ""),
          formatted))
mostattributes(skimmed_better) = attributes(skimmed)
skimmed_better %>% kable(escape = FALSE)
```