Skip to content

Commit

Permalink
Revised talk
Browse files Browse the repository at this point in the history
  • Loading branch information
johnmyleswhite committed Jul 25, 2012
1 parent b1d1f2f commit bf0ab34
Show file tree
Hide file tree
Showing 26 changed files with 2,000 additions and 139 deletions.
121 changes: 121 additions & 0 deletions code/dpmm/dpmm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Generalized t-Distribution log PDF
t_logpdf <- function(x, mu, v, df)
{
g <- lgamma(0.5 * df + 0.5) - lgamma(0.5 * df) - log(sqrt(df * pi * v))
return(g - 0.5 * (df + 1) * log(1 + (1 / df) * ((x - mu)^2) / v))
}

# Normalize a log probability vector without numerical underflow/overflow
normalize_logprob <- function(log.prob)
{
# Compute the log of the normalizing constant.
g <- log(sum(exp(log.prob - max(log.prob)))) + max(log.prob)

# Find probabilities by exponentiating normalized log probabilities.
return(exp(log.prob - g))
}

# Calculate the log likelihood for the Gaussian DP mixture model
# Mean and variance parameters marginalized under normal-gamma prior
# This corresponds to a generalized Student's t-distribution
# NB: some constant terms are ignored
dpmm_loglik <- function(xn, x, tau0, beta0, mu0, kappa0)
{
N <- length(x)
kappaN <- kappa0 + N
tauN <- tau0 + N / 2

if (N > 0)
{
# If there previous observations, use the current posterior.
xm <- mean(x)
betaN <- beta0 + 0.5 * sum((x - xm)^2) + (kappa0 * N *(xm - mu0)^2) / (2 * kappaN)
muN <- (kappa0 * mu0 + N * xm) / kappaN
}
else
{
# If there are no previous observations, revert to the prior.
betaN <- beta0
muN <- mu0
}

return(t_logpdf(xn, muN, betaN * (kappaN + 1) / (tauN * kappaN), 2 * kappaN))
}

# Gibbs sampling for the Gaussian Dirichlet process mixture model
#
# Inputs:
# x - data (vector of length N)
# alpha - DP concentration parameter
# tau0 - normal-gamma prior shape
# beta0 - normal-gamma prior rate (inverse scale)
# kappa0 - normal-gamma prior precision scaling parameter
# nIter - number of Gibbs iterations
#
# Outputs:
# C, a N x nIter matrix of cluster assignments
dpmm_gibbs <- function(x, alpha, tau0, beta0, mu0, kappa0, nIter)
{
N <- length(x)
logpost <- rep(1, 1, nIter)
p <- rep(1, 1, N)
C <- matrix(data = NA, nrow = N, ncol = nIter)
c <-rep(1, 1, N)
m <- rep(0, 1, N)
m[1] <- N
logprior <- rep(1, 1, N)
loglik <- rep(1, 1, N)
ix <- 1:N

for (i in 1:nIter)
{
print(paste("Iteration", i))

for (n in 1:N)
{
#all customers except n
cn <- c[1:N!=n]

#count cluster assignments
for (k in 1:N)
{
m[k] <- sum(cn == k)
}

if (all(m > 0))
{
#active dishes
K.active <- ix[m > 0]
}
else
{
#active dishes + 1 new dish
K.active <- c(ix[m > 0], min(ix[m == 0]))
}
for (k in K.active)
{
if (m[k] > 0)
{
#prior for old dish
logprior[k] <- log(m[k])
}
else
{
#prior for new dish
logprior[k] <- log(alpha)
}
#calculate log likelihood
loglik[k] <- dpmm_loglik(x[n], x[c==k], tau0, beta0, mu0, kappa0)
}

#posterior
post <- normalize_logprob(logprior[K.active] + loglik[K.active])

#update cluster assignment
c[n] <- sample(K.active, 1, rep = TRUE, prob = post)
}

C[,i] = c
}
C
}
Binary file added code/dpmm/dpmm_0.01.pdf
Binary file not shown.
Binary file added code/dpmm/dpmm_0.5.pdf
Binary file not shown.
Binary file added code/dpmm/dpmm_100.0.pdf
Binary file not shown.
Binary file added code/dpmm/dpmm_12.5.pdf
Binary file not shown.
Binary file added code/dpmm/dpmm_2.5.pdf
Binary file not shown.
Binary file added code/dpmm/dpmm_ground_truth.pdf
Binary file not shown.
81 changes: 81 additions & 0 deletions code/dpmm/example.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
set.seed(1)

source("code/dpmm/dpmm.R")

x <- c(rnorm(100, 100, 8), rnorm(50, 200, 25), rnorm(50, 25, 1))
labels <- c(rep("A", 100), rep("B", 50), rep("C", 50))

df <- data.frame(X = x, Label = labels)

ggplot(df, aes(x = X)) +
geom_histogram(binwidth = 3)

ggplot(df, aes(x = X, fill = Label)) +
geom_histogram(binwidth = 3) +
opts(legend.position = "none") +
opts(title = "Ground Truth Mixture Model")
ggsave("code/dpmm/dpmm_ground_truth.pdf", height = 7, width = 7)

# Gibbs sampling for the Gaussian Dirichlet process mixture model
#
# Inputs:
# x - data (vector of length N)
# alpha - DP concentration parameter
# tau0 - normal-gamma prior shape
# beta0 - normal-gamma prior rate (inverse scale)
# kappa0 - normal-gamma prior precision scaling parameter
# nIter - number of Gibbs iterations
#
# Outputs:
# C, a N x nIter matrix of cluster assignments
#dpmm_gibbs <- function(x, alpha, tau0, beta0, mu0, kappa0, nIter)

nIter <- 100
results <- dpmm_gibbs(x, 0.01, 0.1, 0.1, 0, 0.1, nIter)
results[, nIter]

ggplot(df, aes(x = X, fill = factor(results[, nIter]))) +
geom_histogram(binwidth = 3) +
opts(legend.position = "none") +
opts(title = "dp-MM with alpha = 0.01")
ggsave("code/dpmm/dpmm_0.01.pdf", height = 7, width = 7)

nIter <- 100
results <- dpmm_gibbs(x, 0.5, 0.1, 0.1, 0, 0.1, nIter)
results[, nIter]

ggplot(df, aes(x = X, fill = factor(results[, nIter]))) +
geom_histogram(binwidth = 3) +
opts(legend.position = "none") +
opts(title = "dp-MM with alpha = 0.5")
ggsave("code/dpmm/dpmm_0.5.pdf", height = 7, width = 7)

nIter <- 100
results <- dpmm_gibbs(x, 2.5, 0.1, 0.1, 0, 0.1, nIter)
results[, nIter]

ggplot(df, aes(x = X, fill = factor(results[, nIter]))) +
geom_histogram(binwidth = 3) +
opts(legend.position = "none") +
opts(title = "dp-MM with alpha = 2.5")
ggsave("code/dpmm/dpmm_2.5.pdf", height = 7, width = 7)

nIter <- 100
results <- dpmm_gibbs(x, 12.5, 0.1, 0.1, 0, 0.1, nIter)
results[, nIter]

ggplot(df, aes(x = X, fill = factor(results[, nIter]))) +
geom_histogram(binwidth = 3) +
opts(legend.position = "none") +
opts(title = "dp-MM with alpha = 12.5")
ggsave("code/dpmm/dpmm_12.5.pdf", height = 7, width = 7)

nIter <- 100
results <- dpmm_gibbs(x, 100.0, 0.1, 0.1, 0, 0.1, nIter)
results[, nIter]

ggplot(df, aes(x = X, fill = factor(results[, nIter]))) +
geom_histogram(binwidth = 3) +
opts(legend.position = "none") +
opts(title = "dp-MM with alpha = 100.0")
ggsave("code/dpmm/dpmm_100.0.pdf", height = 7, width = 7)
Loading

0 comments on commit bf0ab34

Please sign in to comment.