Skip to content
Browse files

Minor refactoring.

  • Loading branch information...
1 parent bbdd349 commit 0fcd0daa833e4d9f88280816b2910d9a85f7b45b @echen committed Apr 22, 2012
Showing with 37 additions and 10 deletions.
  1. +10 −1 chinese_restaurant_process.rb
  2. +6 −6 plots.R
  3. +8 −2 polya_urn_model.R
  4. +7 −1 polya_urn_model.rb
  5. +6 −0 stick_breaking_process.R
View
11 chinese_restaurant_process.rb
@@ -1,7 +1,16 @@
# Generate table assignments for `num_customers` customers, according to
# a Chinese Restaurant Process with dispersion parameter `alpha`.
#
-# returns an array of integer table assignments
+# Returns an array of integer table assignments.
+#
+# Examples
+#
+# chinese_restaurant_process(num_customers = 5, alpha = 1)
+# => [1, 2, 3, 4, 3]
+#
+# chinese_restaurant_process(num_customers = 10, alpha = 3)
+# => [1, 2, 1, 1, 3, 1, 2, 3, 4, 5]
+#
def chinese_restaurant_process(num_customers, alpha)
return [] if num_customers <= 0
View
12 plots.R
@@ -3,9 +3,9 @@ library(reshape)
# Some of the plots used in the blog post.
-##########
+#################
# POLYA URN MODEL
-##########
+#################
polya_urn_model_plots = function(num_balls, alpha) {
# Lazy man's repetition...
@@ -27,9 +27,9 @@ polya_urn_model_plots = function(num_balls, alpha) {
polya_urn_model_plots(10, 1)
-##########
+########################
# STICK-BREAKING PROCESS
-##########
+########################
stick_breaking_process_plots = function(num_weights, alpha) {
x1 = stick_breaking_process(num_weights, alpha)
@@ -50,9 +50,9 @@ stick_breaking_process_plots = function(num_weights, alpha) {
stick_breaking_process_plots(10, 5)
-##########
+##############
# ALL CLUSTERS
-##########
+##############
x = read.table("mcdonalds-data-with-clusters.tsv", header = T, sep = " ", comment.char = "", quote = "")
View
10 polya_urn_model.R
@@ -1,8 +1,14 @@
# Return a vector of `num_balls` ball colors according to a Polya Urn Model
# with dispersion `alpha`, sampling from a specified base color distribution.
+#
+# Examples
+#
+# polya_urn_model(function() rnorm(1), 5, 1)
+# => c(-0.2210029, -0.3013638, 0.8149611, 1.6879720, -0.7803525)
+#
polya_urn_model = function(base_color_distribution, num_balls, alpha) {
balls = c()
-
+
for (i in 1:num_balls) {
if (runif(1) < alpha / (alpha + length(balls))) {
# Add a new ball color.
@@ -15,7 +21,7 @@ polya_urn_model = function(base_color_distribution, num_balls, alpha) {
balls = c(balls, ball)
}
}
-
+
balls
}
View
8 polya_urn_model.rb
@@ -2,7 +2,13 @@
# with a specified base color distribution and dispersion parameter
# `alpha`.
#
-# returns an array of ball colors
+# Returns an array of ball colors.
+#
+# Examples
+#
+# polya_urn_model(lambda { rand }, num_balls = 10, alpha = 1)
+# => [0.55, 0.55, 0.55, 0.55, 0.12, 0.12, 0.46, 0.46, 0.55, 0.55]
+#
def polya_urn_model(base_color_distribution, num_balls, alpha)
return [] if num_balls <= 0
View
6 stick_breaking_process.R
@@ -5,6 +5,12 @@
# \beta_k = (1 - \beta_1) * (1 - \beta_2) * ... * (1 - \beta_{k-1}) * beta_k
# where each $\beta_i$ is drawn from a Beta distribution
# \beta_i ~ Beta(1, \alpha)
+#
+# Examples
+#
+# stick_breaking_process(num_weight = 5, alpha = 1)
+# => c(0.712148550, 0.169208000, 0.101483441, 0.014156001, 0.001498306)
+#
stick_breaking_process = function(num_weights, alpha) {
betas = rbeta(num_weights, 1, alpha)
remaining_stick_lengths = c(1, cumprod(1 - betas))[1:num_weights]

0 comments on commit 0fcd0da

Please sign in to comment.
Something went wrong with that request. Please try again.