Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Adam Laiacano
committed
Nov 15, 2012
1 parent
5554618
commit 7e4bd10
Showing
4 changed files
with
196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from abc import ABCMeta | ||
import abc | ||
|
||
class BanditAlgo(object):
    """
    Abstract base class shared by the bandit algorithms.

    Subclasses are expected to provide ``update`` and ``select_arm``.
    """
    # Python 2 spelling for attaching the ABC metaclass.
    __metaclass__ = ABCMeta

    @staticmethod
    def ind_max(x):
        """Return the index of the first occurrence of the largest item in x."""
        return x.index(max(x))

    def initialize(self, n_arms):
        """Reset ``counts`` and ``values`` to zero-filled lists of length n_arms."""
        self.counts = [0] * n_arms
        self.values = [0.0] * n_arms

    @abc.abstractmethod
    def update(self):
        """
        Update the parameters of the specific bandit algorithm.
        """

    @abc.abstractmethod
    def select_arm(self):
        """
        Select an arm as determined by the specific bandit algorithm.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
execfile("core.py") | ||
from algorithms.ucb.ucb2 import * | ||
import random | ||
|
||
# Fixed seed so the shuffled arm order (and therefore the best arm) is
# reproducible between runs.
random.seed(1)
means = [0.1, 0.1, 0.1, 0.1, 0.9]
n_arms = len(means)
random.shuffle(means)
# BernoulliArm, ind_max and test_algorithm are not defined in this script;
# they are expected to come from core.py / the ucb2 star-import above.
arms = [BernoulliArm(mu) for mu in means]
print("Best arm is " + str(ind_max(means)))

for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
    algo = UCB2(alpha, [], [])
    algo.initialize(n_arms)
    # NOTE(review): 5000 and 250 are passed straight to test_algorithm;
    # presumably number of simulations and horizon -- confirm in core.py.
    results = test_algorithm(algo, arms, 5000, 250)

    # The context manager closes the file even if a write fails.
    with open("algorithms/ucb/ucb2_results_%s.tsv" % alpha, "w") as f:
        # results is treated as a list of equal-length columns; zip(*...)
        # transposes it so each output line holds one row, with alpha
        # appended as the final tab-separated field.
        for row in zip(*results):
            f.write("\t".join([str(v) for v in row]))
            f.write("\t%s\n" % alpha)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import math | ||
|
||
|
||
def ind_max(x):
    """Return the index of the first occurrence of the largest item in x."""
    return x.index(max(x))
|
||
|
||
class UCB2(object):
    """
    UCB2 bandit algorithm.

    Attributes:
        alpha: parameter controlling epoch growth, tau(r) = ceil((1 + alpha) ** r).
        counts: per-arm number of recorded pulls.
        values: per-arm running mean of observed rewards.
        r: per-arm number of epochs started so far.
    """

    def __init__(self, alpha, counts, values):
        """
        UCB2 algorithm. Implementation of the slides at:
        http://lane.compbio.cmu.edu/courses/slides_ucb.pdf
        """
        self.alpha = alpha
        self.counts = counts
        self.values = values
        # Create the epoch counters here too, so select_arm() does not fail
        # with AttributeError when counts/values are supplied directly and
        # initialize() is never called.
        self.r = [] if counts is None else [0 for _ in counts]
        self.__current_arm = 0
        self.__next_update = 0

    def initialize(self, n_arms):
        """Reset all per-arm statistics and scheduling state for n_arms arms."""
        self.counts = [0 for col in range(n_arms)]
        self.values = [0.0 for col in range(n_arms)]
        self.r = [0 for col in range(n_arms)]
        self.__current_arm = 0
        self.__next_update = 0

    def __bonus(self, n, r):
        """Return the exploration bonus for total play count n and epoch index r."""
        tau = self.__tau(r)
        bonus = math.sqrt((1. + self.alpha) * math.log(math.e * float(n) / tau) / (2 * tau))
        return bonus

    def __tau(self, r):
        """Return tau(r) = ceil((1 + alpha) ** r) as an int."""
        return int(math.ceil((1 + self.alpha) ** r))

    def __set_arm(self, arm):
        """
        When choosing a new arm, make sure we play that arm for
        tau(r+1) - tau(r) episodes.
        """
        self.__current_arm = arm
        # max(1, ...) guarantees at least one play even when consecutive
        # tau values round up to the same integer.
        self.__next_update += max(1, self.__tau(self.r[arm] + 1) - self.__tau(self.r[arm]))
        self.r[arm] += 1

    def select_arm(self):
        """Return the index of the arm to play next."""
        n_arms = len(self.counts)

        # play each arm once
        for arm in range(n_arms):
            if self.counts[arm] == 0:
                self.__set_arm(arm)
                return arm

        total_counts = sum(self.counts)

        # make sure we aren't still playing the previous arm.
        if self.__next_update > total_counts:
            return self.__current_arm

        ucb_values = [0.0 for arm in range(n_arms)]
        for arm in range(n_arms):
            bonus = self.__bonus(total_counts, self.r[arm])
            ucb_values[arm] = self.values[arm] + bonus

        chosen_arm = ind_max(ucb_values)
        self.__set_arm(chosen_arm)
        return chosen_arm

    def update(self, chosen_arm, reward):
        """Record one pull of chosen_arm and fold reward into its running mean."""
        self.counts[chosen_arm] += 1
        # n is the count *including* this pull; the incremental-mean formula
        # needs the post-increment count, otherwise earlier rewards are
        # weighted incorrectly (the second pull would overwrite the first).
        n = self.counts[chosen_arm]

        value = self.values[chosen_arm]
        self.values[chosen_arm] = ((n - 1) / float(n)) * value + (1 / float(n)) * reward
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
library("plyr") | ||
library("ggplot2") | ||
|
||
# One results file per alpha value (0.1, 0.3, 0.5, 0.7, 0.9).
data.files <- c(
  "python/algorithms/ucb/ucb2_results_0.1.tsv",
  "python/algorithms/ucb/ucb2_results_0.3.tsv",
  "python/algorithms/ucb/ucb2_results_0.5.tsv",
  "python/algorithms/ucb/ucb2_results_0.7.tsv",
  "python/algorithms/ucb/ucb2_results_0.9.tsv"
)

# Read each headerless tab-separated file and label its columns.
results <- llply(data.files, function(x) {
  z <- read.csv(x, header = FALSE, sep = "\t")
  names(z) <- c("Sim", "T", "ChosenArm", "Reward", "CumulativeReward", "Alpha")
  z
})
# Stack the per-file data frames into one.
results <- ldply(results)

# Treat alpha as a categorical variable so each value gets its own colour.
results$Alpha <- factor(results$Alpha)
|
||
# Plot average reward as a function of time.
# ddply returns the per-(T, Alpha) mean in a column named V1.
stats <- ddply(idata.frame(results),
               c("T", "Alpha"),
               function (df) {mean(df$Reward)})
ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
  geom_line() +
  ylim(0, 1) +
  xlab("Time") +
  ylab("Average Reward") +
  # ggtitle() replaces the deprecated opts(title = ...).
  ggtitle("Performance of the UCB2 Algorithm")
ggsave("r/graphs/ucb2_average_reward.pdf")
|
||
# Plot frequency of selecting correct arm as a function of time.
# In this instance, 1 is the correct arm.
stats <- ddply(idata.frame(results),
               c("T", "Alpha"),
               function (df) {mean(df$ChosenArm == 1)})
ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
  geom_line() +
  ylim(0, 1) +
  xlab("Time") +
  ylab("Probability of Selecting Best Arm") +
  # ggtitle() replaces the deprecated opts(title = ...).
  ggtitle("Accuracy of the UCB2 Algorithm")
ggsave("r/graphs/ucb2_average_accuracy.pdf")
|
||
# Plot variance of chosen arms as a function of time.
stats <- ddply(idata.frame(results),
               c("T", "Alpha"),
               function (df) {var(df$ChosenArm)})
ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
  geom_line() +
  xlab("Time") +
  ylab("Variance of Chosen Arm") +
  # ggtitle() replaces the deprecated opts(title = ...).
  ggtitle("Variability of the UCB2 Algorithm")
ggsave("r/graphs/ucb2_variance_choices.pdf")
|
||
# Plot cumulative reward as a function of time.
stats <- ddply(idata.frame(results),
               c("T", "Alpha"),
               function (df) {mean(df$CumulativeReward)})
ggplot(stats, aes(x = T, y = V1, color = Alpha)) +
  geom_line() +
  xlab("Time") +
  ylab("Cumulative Reward of Chosen Arm") +
  # ggtitle() replaces the deprecated opts(title = ...).
  ggtitle("Cumulative Reward of the UCB2 Algorithm")
ggsave("r/graphs/ucb2_cumulative_reward.pdf")