-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample_and_summarise.r
72 lines (72 loc) · 1.71 KB
/
sample_and_summarise.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Fit a Stan model to one simulated dataset and summarise the fit.
#
# For every parameter that has a known true value, computes `rank`: the
# fraction of posterior draws strictly below the true value. This is joined
# with posterior convergence measures and the 10/25/50/75/90% quantiles.
# Sampler-level diagnostics, total run time, the model name and the
# data-generating arguments are returned alongside.
#
# Args:
#   generated: a list with elements
#     - gen_args: arguments used to simulate the data. They are pasted into
#       the CSV basename and column-bound onto the result (presumably a
#       one-row list/tibble -- TODO confirm against the generator).
#     - data_for_stan: the data list passed to `$sample()`; it is also hashed
#       to derive a reproducible seed.
#     - true_pars: a data frame with columns `name` and `true`.
#   stan_mod_for_sampling: a compiled cmdstanr model (object exposing
#     `$sample()` and `$stan_file()`).
#
# Returns:
#   A one-row tibble with columns max_treedepth, num_divergent, var_energy,
#   time, var_summary (list-column holding the per-variable summary tibble),
#   model, followed by the columns of generated$gen_args.
#
# Side effects:
#   Writes sampler output into ./stan_temp (created if missing). On exit,
#   including on error, it deletes every file there whose name starts with
#   the CSV basename.
sample_and_summarise <- function(generated, stan_mod_for_sampling) {
  out_dir <- "stan_temp"
  csv_base <- paste(as.list(generated$gen_args), collapse = "_")

  # Run half the available cores' worth of chains, as a whole number >= 1.
  # detectCores() can be odd, 1, or NA, and cmdstanr rejects fractional or
  # zero chain counts.
  n_cores <- parallel::detectCores()
  if (is.na(n_cores)) {
    n_cores <- 2L
  }
  n_chains <- max(1L, as.integer(n_cores %/% 2))

  dir.create(out_dir, showWarnings = FALSE)
  # Clean up this run's output even if sampling/summarising fails. unlink()
  # avoids shelling out to `rm`, which is unavailable on Windows and unsafe
  # when the basename contains spaces or shell metacharacters.
  on.exit(
    unlink(Sys.glob(file.path(out_dir, paste0(csv_base, "*")))),
    add = TRUE
  )

  sampled <- stan_mod_for_sampling$sample(
    data = generated$data_for_stan,
    chains = n_chains,
    parallel_chains = n_chains,
    show_messages = FALSE,
    refresh = 1,
    output_dir = out_dir,
    output_basename = csv_base,
    # Deterministic seed derived from the data, so reruns on the same
    # simulated dataset reproduce the same fit.
    seed = abs(digest::digest2int(
      digest::digest(generated$data_for_stan, algo = "xxhash64")
    ))
  )
  draws <- sampled$draws()

  # Rank of each true parameter value among its posterior draws.
  ranks <- draws %>%
    posterior::as_draws_df() %>%
    as_tibble() %>%
    select(-.chain, -.iteration) %>%
    pivot_longer(-.draw) %>%
    left_join(generated$true_pars, by = "name") %>%
    filter(!is.na(true)) %>%
    rename(variable = name) %>%
    group_by(variable) %>%
    summarise(
      rank = mean(true > value),
      true = true[1],
      .groups = "drop"
    )

  posterior_summary <- ranks %>%
    left_join(
      posterior::summarise_draws(
        draws,
        posterior::default_convergence_measures()
      ),
      by = "variable"
    ) %>%
    left_join(
      posterior::summarise_draws(
        draws,
        ~ posterior::quantile2(.x, probs = c(.1, .25, .5, .75, .9))
      ),
      by = "variable"
    )

  # Strip only a trailing ".stan" extension, not the first ".stan" anywhere.
  model_name <- sub(
    "\\.stan$", "",
    basename(stan_mod_for_sampling$stan_file())
  )

  sampled$sampler_diagnostics() %>%
    posterior::as_draws_df() %>%
    summarise(
      max_treedepth = max(treedepth__),
      num_divergent = sum(divergent__),
      var_energy = var(energy__)
    ) %>%
    mutate(
      time = sampled$time()$total,
      var_summary = list(posterior_summary),
      model = model_name
    ) %>%
    bind_cols(generated$gen_args)
}