stan backend #100

Open
5 tasks
lindeloev opened this issue Jan 5, 2021 · 3 comments

lindeloev commented Jan 5, 2021

mcp 2.0 will support Stan in addition to JAGS. That is far in the future, but this issue collects the work items.

  • Obviously, generate a Stan model, pass data to it, and sample from it.
  • Support bridgesampling-based Bayes factors.
  • Can the JAGS-specific and Stan-specific functions be dropped as hard dependencies and installed only on first use (e.g., on calling mcp(model, data, backend = "stan"))? Otherwise, the dependencies would be quite heavy for users who need only one of the two backends.
  • Check whether Stan samples more efficiently with a continuous step function, e.g., as in this post.
  • Add an option, or make it the default, to have no prior for non-intercept and non-changepoint parameters? Cf. Cross-validation Bayes Factor #122.
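
A minimal sketch of what a "continuous step function" could look like, written in R for illustration. It assumes a logistic curve as the smooth replacement for the hard indicator (x >= cp); the linked post may use a different form, and the function names and the steepness parameter are hypothetical.

```r
# Hard indicator, as used in the current JAGS code: jumps from 0 to 1 at cp.
hard_step <- function(x, cp) as.numeric(x >= cp)

# Smooth alternative (assumption: logistic curve). Larger `steepness`
# approaches the hard step while staying differentiable in cp.
smooth_step <- function(x, cp, steepness = 10) {
  plogis(steepness * (x - cp))
}

x <- seq(-2, 2, by = 0.5)
cbind(x, hard = hard_step(x, 0), smooth = round(smooth_step(x, 0), 3))
```

The idea is that gradient-based samplers such as Stan's may handle the smooth version better, because the log density then changes continuously as the change point moves. Whether that helps in practice is exactly what the work item proposes to check.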
@lindeloev lindeloev added this to the 2.0 stan backend milestone Jan 5, 2021
@lindeloev lindeloev self-assigned this Jan 5, 2021
@lindeloev lindeloev added the enhancement New feature or request label Jan 5, 2021
@jpzhangvincent

Awesome. Excited to see this on the roadmap. I'd love to contribute to this while I'm still learning Bayesian modeling. Do you have any suggestions or a contributor guide? I would also be interested in implementing a Python version with a PyMC3 backend.


lindeloev commented Jan 12, 2021

Thanks, @jpzhangvincent, that would be great! I think getting it to work is simply a matter of (a) rewriting a few JAGS models as Stan models and learning whether they work well, and (b) writing an R function that generates these from mcp's internal representation of the model. I could really use some input on (a) here, as my Stan skills are limited.

mcp is undergoing heavy internal restructuring and a few breaking changes, most of which are tracked in issue #90. I think it makes sense to wait until after that release, when things will hopefully settle down. But I think the JAGS part is finished now. mcp 0.4 takes formulas like this:

model = list(
  y ~ 1 + x:group,
  ~ 0 + x,
  ~ 1 + sigma(1 + group)
)

which for data like

> head(df)
  x group         y          z
1 1     A -1.431554 -5.9042791
2 2     B 12.819796  1.6075971
3 3     C 17.218474  4.8689988
4 4     D  9.243459 -2.1581639
5 5     A  9.609940 10.1076712
6 6     B  9.544842  0.2298296

generates JAGS code like this:

model {
  # mcp helper values
  cp_0 = MINX
  cp_3 = MAXX

  # Priors for population-level effects
  cp_1 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_0, MAXX)
  cp_2 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_1, MAXX)
  Intercept_1 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3) 
  xgroupA_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupB_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupC_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupD_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  sigma_1 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  x_2 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  Intercept_3 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3) 
  sigma_3 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  sigma_groupB_3 ~ dt(0, 1/(SDLINKY)^2, 3) 
  sigma_groupC_3 ~ dt(0, 1/(SDLINKY)^2, 3) 
  sigma_groupD_3 ~ dt(0, 1/(SDLINKY)^2, 3) 

  # Model and likelihood
  for (i_ in 1:length(x)) {
    # par_x local to each segment
    x_local_1_[i_] = min(x[i_], cp_1)
    x_local_2_[i_] = min(x[i_], cp_2) - cp_1
    x_local_3_[i_] = min(x[i_], cp_3) - cp_2
    
    # Formula for mu
    mu_[i_] =
    
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(1)], c(Intercept_1)) * 1 + 
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(2, 3, 4, 5)], c(xgroupA_1, xgroupB_1, xgroupC_1, xgroupD_1)) * x_local_1_[i_] + 
    
      # Segment 2: y ~ 1 ~ 0 + x
      (x[i_] >= cp_1) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(7)], c(x_2)) * x_local_2_[i_] + 
    
      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(8)], c(Intercept_3)) * 1
    
    # Formula for sigma
    sigma_[i_] = max(10^-9, sigma_tmp[i_])  # Count negative sigma as just-above-zero sigma
    sigma_tmp[i_] =  
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(6)], c(sigma_1)) * 1 + 
    
      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(9, 10, 11, 12)], c(sigma_3, sigma_groupB_3, sigma_groupC_3, sigma_groupD_3)) * 1

    # Likelihood and log-density for family = gaussian()
    y[i_] ~ dnorm((mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
    loglik_[i_] = logdensity.norm(y[i_], (mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
  }
}

Here, rhs_data_ is the model.matrix, but with x factored out of all terms. x is then "factored in" again in JAGS, as you can see. inprod is simply the inner product, equivalent to %*% in base R.
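
To make the "factored out" idea concrete, here is a minimal base-R sketch for the slope part of segment 1 only. It is an illustration, not mcp's internal code: the toy data, the coefficient values, the change point, and the inprod() helper are all assumptions.

```r
# Toy data (hypothetical): predictor x and a four-level group factor.
df <- data.frame(x = 1:8, group = factor(rep(c("A", "B", "C", "D"), 2)))

# Ordinary design matrix for x:group; each column holds x for one group.
mm <- model.matrix(~ 0 + x:group, data = df)

# "Factor out" x: dividing row i by x[i] leaves 0/1 group indicators.
rhs <- mm / df$x

# R analogue of JAGS inprod(): the inner product of two vectors.
inprod <- function(a, b) sum(a * b)

# Hypothetical slopes (xgroupA_1 ... xgroupD_1) and change point.
beta <- c(0.5, 1, -1, 2)
cp_1 <- 4.5

# x local to segment 1, then "factor x back in", as in the JAGS loop.
x_local_1 <- pmin(df$x, cp_1)
mu_slope <- as.vector(rhs %*% beta) * x_local_1

# Row-wise inprod() gives the same value as the matrix product.
all.equal(mu_slope[6], inprod(rhs[6, ], beta) * x_local_1[6])
```

Because x enters only through x_local_1, the slope contribution stops growing once x passes cp_1, which is what lets the next segment join on without a jump.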

Some of the work points for generating an equivalent stan model are:

  1. I think some of the priors can be dropped in Stan (JAGS requires priors for everything).
  2. I think Stan allows for vectorizing, so we can get rid of the for-loop.
  3. I have to learn more Stan to see if some of it can be moved to a "data" block, etc.
  4. There are many equivalent ways to represent the formula part, but JAGS samples considerably faster for this particular one. I'd like to see if Stan is more robust so that we needn't have multiple lines of code for each segment.
  5. In general, how can this be made to run most efficiently in Stan? Can we use some of the new primitives, can we make a model that runs on the GPU, etc.?
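
As a starting point for discussion, here is a rough and untested sketch of how points 1 to 3 might look for a much simpler model than the one above: a single change point with two joined slopes and a common sigma. The Stan program is held as an R string; the model structure and all names are assumptions for illustration and not what mcp would generate.

```r
# Hypothetical Stan program: one change point, joined slopes, common sigma.
stan_code <- "
data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
}
transformed data {
  // Helper values computed once from the data (cf. MINX and MAXX in JAGS).
  real MINX = min(x);
  real MAXX = max(x);
}
parameters {
  real<lower=MINX, upper=MAXX> cp_1;  // bounds imply a uniform prior
  real Intercept_1;                   // no explicit prior: implicitly flat
  real x_1;
  real x_2;
  real<lower=0> sigma_1;
}
model {
  vector[N] mu;
  for (i in 1:N) {
    // Slope 1 acts up to cp_1; slope 2 acts on the distance past cp_1.
    mu[i] = Intercept_1 + x_1 * fmin(x[i], cp_1) + x_2 * fmax(x[i] - cp_1, 0);
  }
  y ~ normal(mu, sigma_1);  // vectorized likelihood
}
"

# Toy data (hypothetical) in the list format that Stan interfaces expect.
set.seed(1)
df <- data.frame(x = 1:50)
df$y <- 2 + 0.5 * pmin(df$x, 30) - 1 * pmax(df$x - 30, 0) + rnorm(50)
stan_data <- list(N = nrow(df), x = df$x, y = df$y)

# If rstan is installed, something like this should compile and sample:
# fit <- rstan::stan(model_code = stan_code, data = stan_data)
```

Only the likelihood is vectorized here; the loop that builds mu is kept to stay on function signatures that are safe across Stan versions. Leaving the priors implicit is only meant to illustrate point 1, and weakly informative priors like the JAGS ones would probably still be a better default.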

Would love any tips, example stan models, or thoughts!


mattansb commented Mar 17, 2021

As far as dependencies go, you can:

  1. List the JAGS/Stan packages under Suggests.
  2. On startup:
    • If neither is installed, give the user a message.
    • If only one is installed, set some options() to use that one.
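
A minimal sketch of that startup check, assuming the backends map to the rjags and rstan packages and using a hypothetical option name, mcp_backend. The fallback to JAGS when both are installed is also an assumption, as is the injectable is_installed argument, which is only there to make the function easy to test.

```r
# Hypothetical helper that a package could call from .onAttach().
pick_backend <- function(
  is_installed = function(pkg) requireNamespace(pkg, quietly = TRUE)
) {
  has_jags <- is_installed("rjags")
  has_stan <- is_installed("rstan")

  if (!has_jags && !has_stan) {
    message("Neither 'rjags' nor 'rstan' is installed. ",
            "Install one of them to fit models.")
    return(invisible(NULL))
  }

  # Assumption: prefer JAGS when both are available, since it is
  # the existing backend.
  backend <- if (has_jags) "jags" else "stan"
  options(mcp_backend = backend)
  invisible(backend)
}

# Example: pretend that only rstan is installed.
pick_backend(function(pkg) pkg == "rstan")
getOption("mcp_backend")
```

With both packages under Suggests, the fitting function can then read the option and give a clear error if the requested backend is missing.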
