add ucb2 algorithm

commit 7e4bd109eeea7f53fd5bfbd6fc228cc16b7fb051 (parent 5554618), authored by Adam Laiacano
34 python/algorithms/base.py
@@ -0,0 +1,34 @@
+from abc import ABCMeta
+import abc
+
+class BanditAlgo(object):
+    """
+    This is a base class for all other bandit algorithms.
+    """
+    __metaclass__ = ABCMeta
+
+    @staticmethod
+    def ind_max(x):
+        m = max(x)
+        return x.index(m)
+
+    def initialize(self, n_arms):
+        self.counts = [0 for col in range(n_arms)]
+        self.values = [0.0 for col in range(n_arms)]
+        return
+
+    @abc.abstractmethod
+    def update(self, chosen_arm, reward):
+        """
+        Update the parameters of the specific bandit algorithm
+        after observing a reward for the chosen arm.
+        """
+        pass
+
+    @abc.abstractmethod
+    def select_arm(self):
+        """
+        Select an arm as determined by the specific
+        bandit algorithm.
+        """
+        pass
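For context, a concrete algorithm plugs into this base class by implementing select_arm and update and reusing initialize and ind_max. A minimal sketch of such a subclass follows; it is illustrative only and not part of this commit, the EpsilonGreedy name and epsilon parameter are stand-ins, and the import path assumes base.py is importable as algorithms.base.

# Illustrative sketch (not part of this commit): a concrete subclass of BanditAlgo,
# assuming the base class is importable as algorithms.base.BanditAlgo.
import random

from algorithms.base import BanditAlgo


class EpsilonGreedy(BanditAlgo):
    def __init__(self, epsilon):
        self.epsilon = epsilon
        self.counts = []
        self.values = []

    def select_arm(self):
        # explore a random arm with probability epsilon,
        # otherwise exploit the arm with the highest estimated value
        if random.random() < self.epsilon:
            return random.randrange(len(self.values))
        return self.ind_max(self.values)

    def update(self, chosen_arm, reward):
        # keep a running mean of the rewards observed for the chosen arm
        self.counts[chosen_arm] += 1
        n = self.counts[chosen_arm]
        value = self.values[chosen_arm]
        self.values[chosen_arm] = ((n - 1) / float(n)) * value + (1 / float(n)) * reward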
23 python/algorithms/ucb/test_ucb2.py
@@ -0,0 +1,23 @@
+execfile("core.py")
+from algorithms.ucb.ucb2 import *
+import random
+
+random.seed(1)
+means = [0.1, 0.1, 0.1, 0.1, 0.9]
+n_arms = len(means)
+random.shuffle(means)
+arms = [BernoulliArm(mu) for mu in means]
+print("Best arm is " + str(ind_max(means)))
+
+for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
+    algo = UCB2(alpha, [], [])
+    algo.initialize(n_arms)
+    results = test_algorithm(algo, arms, 5000, 250)
+
+    f = open("algorithms/ucb/ucb2_results_%s.tsv" % alpha, "w")
+
+    for i in range(len(results[0])):
+        f.write("\t".join([str(results[j][i]) for j in range(len(results))]))
+        f.write("\t%s\n" % alpha)
+
+    f.close()
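To make the loop that test_algorithm runs concrete, here is a self-contained sketch of driving UCB2 directly. It is not part of this commit: BernoulliArm below is a stand-in for the class defined in core.py, the import path assumes the same package layout as the test script above, and the alpha of 0.5 and horizon of 5000 are just illustrative values.

# Self-contained sketch of the select/update loop (not part of this commit).
import random

from algorithms.ucb.ucb2 import UCB2


class BernoulliArm(object):
    """Stand-in for the arm class defined in core.py."""
    def __init__(self, p):
        self.p = p

    def draw(self):
        # reward of 1.0 with probability p, else 0.0
        return 1.0 if random.random() < self.p else 0.0


means = [0.1, 0.1, 0.1, 0.1, 0.9]
arms = [BernoulliArm(mu) for mu in means]

algo = UCB2(0.5, [], [])
algo.initialize(len(arms))
for t in range(5000):
    arm = algo.select_arm()
    algo.update(arm, arms[arm].draw())

# after enough plays, the high-mean arm should dominate the counts
print(algo.counts)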
77 python/algorithms/ucb/ucb2.py
@@ -0,0 +1,77 @@
+import math
+
+
+def ind_max(x):
+    m = max(x)
+    return x.index(m)
+
+
+class UCB2(object):
+    def __init__(self, alpha, counts, values):
+        """
+        UCB2 algorithm. Implementation of the slides at:
+        http://lane.compbio.cmu.edu/courses/slides_ucb.pdf
+        """
+        self.alpha = alpha
+        self.counts = counts
+        self.values = values
+        self.__current_arm = 0
+        self.__next_update = 0
+        return
+
+    def initialize(self, n_arms):
+        self.counts = [0 for col in range(n_arms)]
+        self.values = [0.0 for col in range(n_arms)]
+        self.r = [0 for col in range(n_arms)]
+        self.__current_arm = 0
+        self.__next_update = 0
+
+    def __bonus(self, n, r):
+        tau = self.__tau(r)
+        bonus = math.sqrt((1. + self.alpha) * math.log(math.e * float(n) / tau) / (2 * tau))
+        return bonus
+
+    def __tau(self, r):
+        return int(math.ceil((1 + self.alpha) ** r))
+
+    def __set_arm(self, arm):
+        """
+        When choosing a new arm, make sure we play that arm for
+        tau(r+1) - tau(r) episodes.
+        """
+        self.__current_arm = arm
+        self.__next_update += max(1, self.__tau(self.r[arm] + 1) - self.__tau(self.r[arm]))
+        self.r[arm] += 1
+
+    def select_arm(self):
+        n_arms = len(self.counts)
+
+        # play each arm once
+        for arm in range(n_arms):
+            if self.counts[arm] == 0:
+                self.__set_arm(arm)
+                return arm
+
+        # keep playing the current arm until its tau(r+1) - tau(r) epoch is over
+        if self.__next_update > sum(self.counts):
+            return self.__current_arm
+
+        ucb_values = [0.0 for arm in range(n_arms)]
+        total_counts = sum(self.counts)
+        for arm in range(n_arms):
+            bonus = self.__bonus(total_counts, self.r[arm])
+            ucb_values[arm] = self.values[arm] + bonus
+
+        chosen_arm = ind_max(ucb_values)
+        self.__set_arm(chosen_arm)
+        return chosen_arm
+
+    def update(self, chosen_arm, reward):
+        # increment the play count first so that n is the number of pulls
+        # including this one, then fold the reward into the running mean
+        self.counts[chosen_arm] += 1
+        n = self.counts[chosen_arm]
+
+        value = self.values[chosen_arm]
+        new_value = ((n - 1) / float(n)) * value + (1 / float(n)) * reward
+        self.values[chosen_arm] = new_value
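For reference, the quantities computed by __tau and __bonus above are the UCB2 epoch length and padding term. Writing alpha for the exploration parameter, r_j for the number of epochs in which arm j has been chosen, and n for the total number of plays so far:

\[ \tau(r) = \left\lceil (1 + \alpha)^{r} \right\rceil, \qquad a_{n, r_j} = \sqrt{\frac{(1 + \alpha)\,\ln\!\left(e\, n / \tau(r_j)\right)}{2\, \tau(r_j)}} . \]

select_arm picks the arm maximizing \hat{\mu}_j + a_{n, r_j} (the running mean stored in self.values plus the bonus), and __set_arm then commits to that arm for the next \tau(r_j + 1) - \tau(r_j) plays before the comparison is repeated.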
62 r/ucb/plot_ucb2.R
@@ -0,0 +1,62 @@
+library("plyr")
+library("ggplot2")
+
+data.files <- c(
+  "python/algorithms/ucb/ucb2_results_0.1.tsv",
+  "python/algorithms/ucb/ucb2_results_0.3.tsv",
+  "python/algorithms/ucb/ucb2_results_0.5.tsv",
+  "python/algorithms/ucb/ucb2_results_0.7.tsv",
+  "python/algorithms/ucb/ucb2_results_0.9.tsv"
+)
+
+results <- llply(data.files, function(x){z <- read.csv(x, header=F, sep="\t"); names(z)=c("Sim", "T", "ChosenArm", "Reward", "CumulativeReward", "Alpha"); return(z)})
+results <- ldply(results)
+
+results$Alpha <- factor(results$Alpha)
+
+# Plot average reward as a function of time.
+stats <- ddply(idata.frame(results),
+               c("T", "Alpha"),
+               function(df) {mean(df$Reward)})
+ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
+  geom_line() +
+  ylim(0, 1) +
+  xlab("Time") +
+  ylab("Average Reward") +
+  ggtitle("Performance of the UCB2 Algorithm")
+ggsave("r/graphs/ucb2_average_reward.pdf")
+
+# Plot frequency of selecting correct arm as a function of time.
+# In this instance, 1 is the correct arm.
+stats <- ddply(idata.frame(results),
+               c("T", "Alpha"),
+               function(df) {mean(df$ChosenArm == 1)})
+ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
+  geom_line() +
+  ylim(0, 1) +
+  xlab("Time") +
+  ylab("Probability of Selecting Best Arm") +
+  ggtitle("Accuracy of the UCB2 Algorithm")
+ggsave("r/graphs/ucb2_average_accuracy.pdf")
+
+# Plot variance of chosen arms as a function of time.
+stats <- ddply(idata.frame(results),
+               c("T", "Alpha"),
+               function(df) {var(df$ChosenArm)})
+ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
+  geom_line() +
+  xlab("Time") +
+  ylab("Variance of Chosen Arm") +
+  ggtitle("Variability of the UCB2 Algorithm")
+ggsave("r/graphs/ucb2_variance_choices.pdf")
+
+# Plot cumulative reward as a function of time.
+stats <- ddply(idata.frame(results),
+               c("T", "Alpha"),
+               function(df) {mean(df$CumulativeReward)})
+ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
+  geom_line() +
+  xlab("Time") +
+  ylab("Cumulative Reward of Chosen Arm") +
+  ggtitle("Cumulative Reward of the UCB2 Algorithm")
+ggsave("r/graphs/ucb2_cumulative_reward.pdf")