# cluster_parameter_update.R
#' Update the cluster parameters of the Dirichlet process
#'
#' Resamples the parameters of every cluster using only the observations
#' currently assigned to that cluster. For a conjugate mixing distribution a
#' direct sample from the posterior is taken; for a non-conjugate mixture the
#' Metropolis-Hastings algorithm is used instead.
#'
#' @param dpObj Dirichlet process object.
#' @return The Dirichlet process object with updated cluster parameters.
#'
#' @examples
#' dp <- DirichletProcessGaussian(rnorm(10))
#' dp <- ClusterParameterUpdate(dp)
#'
#' @export
ClusterParameterUpdate <- function(dpObj) {
  # S3 dispatch on the class of dpObj (e.g. the conjugate / nonconjugate
  # methods defined alongside this generic).
  UseMethod("ClusterParameterUpdate")
}
# S3 method for conjugate mixtures: every cluster's parameters are replaced by
# one PosteriorDraw() given the observations labelled with that cluster.
#
# Fields of dpObj that are read: data (rows are subset per cluster),
# numberClusters, clusterLabels, clusterParameters (a list of arrays whose
# third dimension indexes clusters) and mixingDistribution.
# Returns dpObj with clusterParameters overwritten; nothing else is changed.
#' @export
ClusterParameterUpdate.conjugate <- function(dpObj) {
  y <- dpObj$data
  numLabels <- dpObj$numberClusters
  clusterLabels <- dpObj$clusterLabels
  clusterParams <- dpObj$clusterParameters
  mdobj <- dpObj$mixingDistribution

  # seq_len() so that zero clusters means zero iterations; 1:numLabels would
  # visit c(1, 0) and index a cluster slice that does not exist.
  for (i in seq_len(numLabels)) {
    # which() ignores NA labels; drop = FALSE keeps a single-row subset a matrix.
    pts <- y[which(clusterLabels == i), , drop = FALSE]
    post_draw <- PosteriorDraw(mdobj, pts)
    for (j in seq_along(clusterParams)) {
      clusterParams[[j]][, , i] <- post_draw[[j]]
    }
  }

  dpObj$clusterParameters <- clusterParams
  dpObj
}
# S3 method for non-conjugate mixtures: for every cluster, PosteriorDraw() is
# asked for mhDraws samples starting from the cluster's current parameters,
# and only the last sample (index mhDraws) is kept.
#
# Fields of dpObj that are read: data, numberClusters, clusterLabels,
# clusterParameters (a list of arrays whose third dimension indexes clusters),
# mixingDistribution and mhDraws.
# Returns dpObj with clusterParameters overwritten; nothing else is changed.
#' @export
ClusterParameterUpdate.nonconjugate <- function(dpObj) {
  y <- dpObj$data
  numLabels <- dpObj$numberClusters
  clusterLabels <- dpObj$clusterLabels
  clusterParams <- dpObj$clusterParameters
  mdobj <- dpObj$mixingDistribution
  mhDraws <- dpObj$mhDraws

  # PriorDraw() is used to obtain a list of the right structure for start_pos;
  # its first length(clusterParams) entries are overwritten below before each
  # use. Called once, before the loop, as in the original implementation.
  start_pos <- PriorDraw(mdobj)

  # seq_len() so that zero clusters means zero iterations; 1:numLabels would
  # visit c(1, 0) and index a cluster slice that does not exist.
  for (i in seq_len(numLabels)) {
    # which() ignores NA labels; drop = FALSE keeps a single-row subset a matrix.
    pts <- y[which(clusterLabels == i), , drop = FALSE]
    for (j in seq_along(clusterParams)) {
      # drop = FALSE keeps the starting value as a 3-d array slice.
      start_pos[[j]] <- clusterParams[[j]][, , i, drop = FALSE]
    }
    parameter_samples <- PosteriorDraw(mdobj, pts, mhDraws,
                                       start_pos = start_pos)
    for (j in seq_along(clusterParams)) {
      clusterParams[[j]][, , i] <- parameter_samples[[j]][, , mhDraws]
    }
  }

  dpObj$clusterParameters <- clusterParams
  dpObj
}
# Draw new parameters for every distinct cluster label in `clusters`.
#
# mdobj:    mixing distribution, passed straight through to PosteriorDraw().
# data:     matrix-like object; its rows are subset per cluster.
# clusters: cluster label for each row of `data`.
# params:   not used by this function; kept for interface compatibility.
#
# Returns an unnamed list with one PosteriorDraw() result per element of
# unique(clusters), in order of first appearance (NOT one per observation).
# To expand to one entry per observation, index the result with
# match(clusters, unique(clusters)).
cluster_parameter_update <- function(mdobj, data, clusters, params) {
  uniqueClusters <- unique(clusters)
  lapply(uniqueClusters, function(i) {
    # which() keeps NA labels from injecting all-NA rows into every cluster's
    # subset; drop = FALSE keeps single-row subsets as matrices.
    updateData <- data[which(clusters == i), , drop = FALSE]
    PosteriorDraw(mdobj, updateData)
  })
}