Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1403cee
change makeParsedVarList to avoid protect stack overflow
perrydv Feb 7, 2019
9be73b1
adding wrapper to avoid sampling empty clusters
paciorek Feb 9, 2019
27f42e0
full draft of wrapping empty cluster samplers
paciorek Feb 9, 2019
023659e
fix minor issue
paciorek Feb 9, 2019
9132b1c
add in wrapping for nosample empty clusters
paciorek Feb 9, 2019
7669900
fix typo
paciorek Feb 9, 2019
3a3a2c6
cleaned up sampler that wraps
paciorek Feb 9, 2019
e76322d
fix various bnp tests under
paciorek Feb 13, 2019
4ed6fcc
fix more test issues for
paciorek Feb 14, 2019
caad2ae
monkeying with testing
paciorek Feb 20, 2019
3c9e3a9
include cluster nodes in calcNodes to
paciorek Feb 22, 2019
88bf6bb
added additional testing of bnp hotfix
paciorek Feb 23, 2019
5f99904
fix test-bnp conflict
paciorek Feb 23, 2019
b6343fc
update testing of cluster param
paciorek Feb 23, 2019
30735d7
minor edit to comment
paciorek Feb 23, 2019
53d1ead
set up nosample empty clusters
paciorek Feb 26, 2019
8bcf7e1
fix a few test errors
paciorek Feb 27, 2019
fcb52a1
simplify adding of CRP sampler wrappers to use buildSampler
paciorek Feb 27, 2019
53fedb2
fix a couple minor est-bnp issues
paciorek Mar 2, 2019
c83fc0e
more robustification of conj checks in test-bnp
paciorek Mar 2, 2019
519ccda
more fixup of test-bnp minor issues
paciorek Mar 4, 2019
20ff17e
monkey with bnp test syntax
paciorek Mar 5, 2019
1eca3d0
more cleanup of test-bnp
paciorek Mar 6, 2019
0d2308c
more test-bnp minor fixup
paciorek Mar 9, 2019
34066b8
add note for nosample empty clusters setup
paciorek May 16, 2019
28df51d
added comment for handling more general BNP models
paciorek May 17, 2019
51205c1
remove stray browser()
paciorek May 17, 2019
b56d074
clean up merge of
paciorek May 23, 2019
da1a645
fix up bnp tests for nosample_empty_clusters2
paciorek May 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions packages/nimble/R/BNP_samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ CRP_nonconjugate <- nimbleFunction(
name = "CRP_nonconjugate",
contains = CRP_helper,
setup = function(model, marginalizedNodes, dataNodes, p, nTilde) {
savedIdx <- 1
saved <- nimNumeric(2) # treated as scalar if length 1
saved2 <- saved
saved3 <- saved
},
methods = list(
storeParams = function() {}, ## nothing needed for non-conjugate
Expand All @@ -446,14 +450,26 @@ CRP_nonconjugate <- nimbleFunction(
return(model$getLogProb(dataNodes[i]))
},
sample = function(i = integer(), j = integer() ) {
## sample from prior
if( p == 1 ) {
model$simulate(marginalizedNodes[j])
} else {
for(l in 1:p) { ## marginalized nodes should be in correct order based on findClusterNodes.
model$simulate(marginalizedNodes[(l-1)*nTilde + j])
if(j == 0) { ## reset to stored values (for case of new cluster not opened)
values(model, marginalizedNodes[savedIdx]) <<- saved
if(p > 1) {
values(model, marginalizedNodes[nTilde + savedIdx]) <<- saved2
if(p > 2)
values(model, marginalizedNodes[2*nTilde + savedIdx]) <<- saved3
}
} else {
savedIdx <<- j
saved <<- values(model, marginalizedNodes[j])
model$simulate(marginalizedNodes[j])
if(p > 1) {
saved2 <<- values(model, marginalizedNodes[nTilde + j])
if(p > 2)
saved3 <<- values(model, marginalizedNodes[2*nTilde + j])
for(l in 2:p) { ## marginalized nodes should be in correct order based on findClusterNodes.
model$simulate(marginalizedNodes[(l-1)*nTilde + j])
}
}
}
}
}
)
)
Expand Down Expand Up @@ -984,6 +1000,8 @@ sampler_CRP <- nimbleFunction(
## p and nTilde only needed for non-conjugate currently.
## Note that the elements of tildeNodes will be in order such that the first element corresponds to the cluster
## obtained when xi[i] = 1, the second when xi[i] = 2, etc.
if(sampler == 'CRP_nonconjugate' && p > 3)
stop("sampler_CRP: CRP_nonconjugate sampler not yet set up to handle clustering of more than three variables.") ## This is because of how we put old values back into model when proposing a new cluster that is not accepted.
marginalizedNodes <- unlist(clusterVarInfo$clusterNodes)
helperFunctions[[1]] <- eval(as.name(sampler))(model, marginalizedNodes, dataNodes, p, min_nTilde)
calcNodes <- model$getDependencies(c(target, marginalizedNodes))
Expand Down Expand Up @@ -1160,6 +1178,8 @@ sampler_CRP <- nimbleFunction(
}
xiCounts[model[[target]][i]] <- 1
} else { # an existing label is sampled
if(sampler == 'CRP_nonconjugate') # reset to previous marginalized node value
helperFunctions[[1]]$sample(i, 0)
if( xiCounts[xi[i]] == 0 ) { # xi_i is a singleton, a component was deleted
k <- k - 1
xiUniques <- reorderXiUniques
Expand All @@ -1175,7 +1195,7 @@ sampler_CRP <- nimbleFunction(
}
}

## We have updated cluster variables but not all logProb values are up-to-date.
## We have updated cluster variables but not all logProb values are up-to-date.
model$calculate(calcNodes)
copy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE)
},
Expand Down Expand Up @@ -1560,3 +1580,19 @@ checkNormalInvGammaConjugacy <- function(model, clusterVarInfo) {
}
return(conjugate)
}

sampler_CRP_cluster_wrapper <- nimbleFunction(
name = "CRP_cluster_wrapper",
contains = sampler_BASE,
setup = function(model, mvSaved, target, control) {
regular_sampler <- nimbleFunctionList(sampler_BASE)
regular_sampler[[1]] <- control$wrapped_conf$buildSampler(model, mvSaved)
dcrpNode <- control$dcrpNode
clusterID <- control$clusterID
},
run = function() {
if(any(model[[dcrpNode]] == clusterID)) regular_sampler[[1]]$run()
},
methods = list(
reset = function() {regular_sampler[[1]]$reset()}
))
31 changes: 29 additions & 2 deletions packages/nimble/R/MCMC_configuration.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ print: A logical argument specifying whether to print the ordered list of defaul
isEndNode <- model$isEndNode(nodes)
if(useConjugacy) conjugacyResultsAll <- model$checkConjugacy(nodes)

clusterNodeInfo <- NULL; dcrpNode <- NULL; numCRPnodes <- 0
for(i in seq_along(nodes)) {
node <- nodes[i]
discrete <- model$isDiscrete(node)
Expand All @@ -222,7 +223,13 @@ print: A logical argument specifying whether to print the ordered list of defaul
if(nodeDist == 'dinvwish') { addSampler(target = node, type = 'RW_wishart'); next }
if(nodeDist == 'dcar_normal') { addSampler(target = node, type = 'CAR_normal'); next }
if(nodeDist == 'dcar_proper') { addSampler(target = node, type = 'CAR_proper'); next }
if(nodeDist == 'dCRP') { addSampler(target = node, type = 'CRP', control = list(useConjugacy = useConjugacy)); next }
if(nodeDist == 'dCRP') {
addSampler(target = node, type = 'CRP', control = list(useConjugacy = useConjugacy))
numCRPnodes <- numCRPnodes + 1
clusterNodeInfo[[numCRPnodes]] <- findClusterNodes(model, node)
dcrpNode[numCRPnodes] <- node
next
}
if(multivariateNodesAsScalars) {
for(scalarNode in nodeScalarComponents) {
if(onlySlice) addSampler(target = scalarNode, type = 'slice')
Expand Down Expand Up @@ -262,12 +269,32 @@ print: A logical argument specifying whether to print the ordered list of defaul
## default: 'RW' sampler
addSampler(target = node, type = 'RW'); next
}

## For CRP-based models, wrap samplers for cluster parameters so not sampled if cluster is unoccupied.
if(!is.null(clusterNodeInfo)) {
for(k in seq_along(clusterNodeInfo)) {
for(clusterNodes in clusterNodeInfo[[k]]$clusterNodes) {
samplers <- getSamplers(clusterNodes)
removeSamplers(clusterNodes)
for(i in seq_along(samplers)) {
node <- samplers[[i]]$target
addSampler(target = node, type = 'CRP_cluster_wrapper',
control = list(wrapped_type = samplers[[i]]$name, wrapped_conf = samplers[[i]],
dcrpNode = dcrpNode[[k]], clusterID = i))
## Note for more general clustering: will probably change to
## 'clusterID=clusterNodeInfo[[k]]$clusterIDs[[??]][i]'
## which means we probably need to change to for(clusterNodesIdx in seq_along(clusterNodeInfo[[k]]$clusterNodes))
}
}
}
}

}

if(print) printSamplers()
},

addConjugateSampler = function(conjugacyResult, dynamicallyIndexed = FALSE, print = FALSE) {
addConjugateSampler = function(conjugacyResult, dynamicallyIndexed = FALSE, dcrpNode = NULL, clusterID = NULL, print = FALSE) {
## update May 2016: old (non-dynamic) system is no longer supported -DT
##if(!getNimbleOption('useDynamicConjugacy')) {
## addSampler(target = conjugacyResult$target, type = conjugacyResult$type, control = conjugacyResult$control)
Expand Down
5 changes: 4 additions & 1 deletion packages/nimble/R/MCMC_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,11 @@ mcmc_generateControlListArgument <- function(control, controlDefaults) {

mcmc_listContentsToStr <- function(ls, displayControlDefaults=FALSE, displayNonScalars=FALSE, displayConjugateDependencies=FALSE) {
##if(any(unlist(lapply(ls, is.function)))) warning('probably provided wrong type of function argument')
if(!displayConjugateDependencies)
if(!displayConjugateDependencies) {
if(grepl('^conjugate_d', names(ls)[1])) ls <- ls[1] ## for conjugate samplers, remove all 'dep_dnorm', etc, control elements (don't print them!)
if(grepl('^CRP_cluster_wrapper', names(ls)[1]) && 'wrapped_type' %in% names(ls) &&
grepl('conjugate_d', ls$wrapped_type)[1]) ls <- ls[!names(ls) %in% c('wrapped_conf')]
}
ls <- lapply(ls, function(el) if(is.nf(el) || is.function(el)) 'function' else el) ## functions -> 'function'
ls2 <- list()
## to make displayControlDefaults argument work again, would need to code process
Expand Down
Loading