In [2]:
import math
import random
import copy
from numpy.random import beta
from __future__ import division

class Arms():
    ''' A fixed collection of Bernoulli arms.

        mus    -- success probability of each arm, every value in [0, 1]
        n_arms -- number of arms
        best   -- the largest success probability among the arms
    '''

    def __init__(self, mus):
        self.mus = mus
        self.n_arms = len(mus)
        # max() runs before the range check, so an empty `mus` fails here
        # with ValueError instead of passing the assertion vacuously.
        self.best = max(mus)
        assert all(0 <= mu and mu <= 1 for mu in mus)

    def __str__(self):
        ''' Show the arms as their list of success probabilities. '''
        return '%s' % (self.mus,)

    def pull(self, idx):
        ''' Draw one Bernoulli reward (0 or 1) from arm `idx`. '''
        success = random.random() < self.mus[idx]
        return int(success)


def experiment(arms, policy, T, N=1):
    ''' Simulate `policy` on `arms` for N independent runs of T pulls each
        and return the total regret averaged over the runs.

        Regret is accumulated from the arms' success probabilities
        (best mu minus the mu of the picked arm), not from the sampled
        rewards. The policy passed in is never mutated: every run works
        on a deep copy of a snapshot taken before the first pull.
    '''
    optimal_mu = arms.best
    arm_count = arms.n_arms
    regret_sum = 0
    pristine = copy.deepcopy(policy)

    for _run in range(N):
        # fresh policy state and an empty [total reward, pulls] record per arm
        policy = copy.deepcopy(pristine)
        history = [[0, 0] for _ in range(arm_count)]
        for _step in range(T):
            arm = policy.pick(arm_count, history)
            outcome = arms.pull(arm)
            record = history[arm]
            record[0] += outcome
            record[1] += 1
            regret_sum += optimal_mu - arms.mus[arm]

    return regret_sum / N


def argmax(s):
    ''' Index of the largest element of sequence `s`; ties resolve to the
        earliest position. An empty sequence raises ValueError (from max). '''
    largest = max(s)
    return s.index(largest)


class Policy():
    ''' Base class for arm-selection policies.

        Subclasses override `pick` to return the index of the arm to pull.
    '''

    def __init__(self):
        pass

    def pick(self, n_arms, history, to_pick=None):
        ''' Choose the next arm to pull.

            n_arms  -- number of available arms
            history -- one [total reward, number of pulls] pair per arm
            to_pick -- optional list of already-scheduled future picks.
                       The default is None rather than a list literal: a
                       list default is created once and shared by every
                       call, so a subclass that copied the signature and
                       mutated it would leak state between instances.

            The base implementation makes no choice and returns None.
        '''
        return None


class RandomPick(Policy):
    ''' Baseline policy: ignore the history and pull a uniformly random arm. '''

    def pick(self, n_arms, history):
        ''' Return a uniformly random arm index in [0, n_arms). '''
        arm_indices = range(n_arms)
        return random.choice(arm_indices)

    
class BatchRandomPick(Policy):
    ''' Random policy that commits to each drawn arm for `batch_size`
        consecutive pulls before drawing a new one.
    '''

    def __init__(self, batch_size):
        self.batch_size = batch_size
        # Pending picks live on the instance. They used to live in the
        # mutable default argument of pick(), which is a single list shared
        # by every instance and untouched by copy.deepcopy, so a half-used
        # batch leaked from one experiment run into the next.
        self._queue = []

    def pick(self, n_arms, history, to_pick=None):
        ''' Return the next arm to pull.

            n_arms  -- number of available arms
            history -- per-arm [total reward, pulls]; ignored by this policy
            to_pick -- optional caller-owned list of pending picks; when
                       omitted, the instance's own queue is used

            With batch_size < 1 nothing is queued and the final pop()
            raises IndexError.
        '''
        queue = self._queue if to_pick is None else to_pick
        if not queue:
            # draw one arm and schedule it for the whole batch
            queue.extend([random.choice(range(n_arms))] * self.batch_size)
        return queue.pop()


class EpsGreedy(Policy):
    ''' Epsilon-greedy policy: explore a uniformly random arm with
        probability `eps`, otherwise exploit the best empirical mean.
    '''

    def __init__(self, eps):
        self.eps = eps

    def pick(self, n_arms, history):
        ''' Return the arm to pull given per-arm [total reward, pulls]. '''
        exploring = random.random() < self.eps
        if exploring:
            return random.choice(range(n_arms))

        # When exploiting, any arm that has never been pulled goes first,
        # which also keeps the mean below free of division by zero.
        unpulled = next(
            (idx for idx, (_, pulls) in enumerate(history) if pulls == 0),
            None,
        )
        if unpulled is not None:
            return unpulled

        means = [total / pulls for total, pulls in history]
        return argmax(means)


class UCB(Policy):
    ''' Upper-confidence-bound policy: pull every arm once, then always
        pull the arm maximising  mean + sqrt(log(t) / pulls),  where t is
        the total number of pulls so far.
    '''

    def pick(self, n_arms, history):
        ''' Return the arm to pull given per-arm [total reward, pulls]. '''
        pull_counts = [pulls for _, pulls in history]
        if 0 in pull_counts:
            # first arm that has not been tried yet
            return pull_counts.index(0)

        total_pulls = sum(pull_counts)
        scores = []
        for reward_sum, pulls in history:
            bonus = math.sqrt(math.log(total_pulls) / pulls)
            scores.append(reward_sum / pulls + bonus)
        return argmax(scores)

    
class BatchUCB(Policy):
    ''' UCB policy that re-evaluates its choice only every `batch_size`
        pulls: once every arm has been tried, the arm with the highest
        upper confidence bound is scheduled for a whole batch.
    '''

    def __init__(self, batch_size):
        self.batch_size = batch_size
        # Pending picks live on the instance. They used to live in the
        # mutable default argument of pick(), which is a single list shared
        # by every instance and untouched by copy.deepcopy. Picks left over
        # from one experiment run were therefore replayed at the start of
        # the next, before every arm had been tried once.
        self._queue = []

    def pick(self, n_arms, history, to_pick=None):
        ''' Return the next arm to pull.

            n_arms  -- number of available arms
            history -- one [total reward, number of pulls] pair per arm
            to_pick -- optional caller-owned list of pending picks; when
                       omitted, the instance's own queue is used

            With batch_size < 1 nothing is queued and the final pop()
            raises IndexError.
        '''
        queue = self._queue if to_pick is None else to_pick
        if queue:
            return queue.pop()

        # initial phase: try each arm once so every mean is defined
        for i, (_, n) in enumerate(history):
            if n == 0:
                return i

        t = sum(n for _, n in history)
        ucb = [r / n + math.sqrt(math.log(t) / n) for r, n in history]
        queue.extend([argmax(ucb)] * self.batch_size)
        return queue.pop()
    

class Thompson(Policy):
    ''' Thompson sampling: draw one sample per arm from
        Beta(successes + 1, failures + 1) and pull the arm with the
        largest draw.
    '''

    def pick(self, n_arms, history):
        ''' Return the arm to pull given per-arm [successes, pulls]. '''
        # turn each record into (# success, # failure)
        outcomes = []
        for record in history:
            wins = record[0]
            outcomes.append((wins, record[1] - wins))

        draws = []
        for wins, losses in outcomes:
            draws.append(beta(wins + 1, losses + 1))
        return argmax(draws)

In [3]:
# http://stackoverflow.com/questions/15204070/

from scipy.stats import norm, zscore
def sample_power_probtest(p1, p2, power=0.9, sig=0.05):
    ''' Sample size for telling proportions `p1` and `p2` apart with a
        two-sided normal-approximation (z) test.

        p1, p2 -- the two proportions being compared
        power  -- desired power of the test
        sig    -- two-sided significance level

        Returns the size rounded to the nearest integer, computed as
        2 * p * (1 - p) * (z_power + z_sig)^2 / (p1 - p2)^2  with
        p = (p1 + p2) / 2.
        NOTE(review): presumably this is the size per group -- confirm
        against the linked answer.
    '''
    z_sig = norm.isf([sig / 2])   # critical value, sig split over two tails
    z_power = -norm.isf([power])  # normal quantile at `power`

    mean_p = (p1 + p2) / 2
    spread = 2 * mean_p * (1 - mean_p)
    gap = p1 - p2

    size = spread * ((z_power + z_sig) ** 2) / (gap ** 2)
    return int(round(size[0]))

In [4]:
class ABTesting(Policy):
    ''' Elimination policy built on pairwise two-proportion z-tests.

        Surviving arms are pulled in rounds. After each round every pair
        of survivors is tested; an arm that is significantly worse than
        another is eliminated. Once a single arm survives it is pulled
        forever.
    '''

    def __init__(self, power=0.8, sig=0.05):
        from scipy.stats import norm
        self.sig = sig
        # stored only; nothing in this class reads `power`
        self.power = power
        # critical value of the 2-tail test
        self.z_need = norm.isf(sig / 2)
        self.best = None        # winning arm once only one survives
        self.eliminated = []    # arms ruled out so far (may hold repeats)
        self.to_pick = None     # arms still queued for the current round

    def test_significance(self, history1, history2):
        ''' Compare two arms given their [successes, pulls] records.

            Returns 1 if the first arm is significantly better, -1 if the
            second is, and 0 if the test is inconclusive or the standard
            error is zero.
        '''
        wins1, pulls1 = history1
        wins2, pulls2 = history2
        rate1 = wins1 / pulls1
        rate2 = wins2 / pulls2
        try:
            variance = rate1 * (1 - rate1) / pulls1 + rate2 * (1 - rate2) / pulls2
            z = (rate1 - rate2) / math.sqrt(variance)
        except ZeroDivisionError:
            return 0

        threshold = self.z_need
        if z > threshold:
            # first hand is better
            return 1
        if z < -threshold:
            # second hand is better
            return -1
        # cannot tell which one is better
        return 0

    def _survivors(self, n_arms):
        ''' Arms that have not been eliminated, in ascending index order. '''
        return [arm for arm in range(n_arms) if arm not in self.eliminated]

    def pick(self, n_arms, history):
        ''' Return the next arm to pull given per-arm [successes, pulls]. '''
        if self.to_pick is None:
            # very first call: queue one pull of every arm
            self.to_pick = list(range(n_arms))

        # a winner has already been declared
        if self.best is not None:
            return self.best

        # finish the current round, skipping arms eliminated meanwhile
        while self.to_pick:
            candidate = self.to_pick.pop()
            if candidate not in self.eliminated:
                return candidate

        # round complete: test every ordered pair of survivors
        contenders = self._survivors(n_arms)
        for first in contenders:
            if first in self.eliminated:
                continue
            for second in contenders:
                if second == first or second in self.eliminated:
                    continue
                verdict = self.test_significance(history[first], history[second])
                if verdict == 1:
                    self.eliminated.append(second)
                elif verdict == -1:
                    self.eliminated.append(first)

        contenders = self._survivors(n_arms)
        if len(contenders) == 1:
            self.best = contenders[0]
            return self.best

        # no winner yet: schedule one more round over the survivors
        self.to_pick.extend(contenders)
        return self.to_pick.pop()

In [5]:
class UCBvsLCB(Policy):
    ''' Elimination policy built on confidence intervals.

        Surviving arms are pulled in rounds. After each round an arm is
        eliminated when its upper confidence bound falls below another
        survivor's lower confidence bound. Once a single arm survives it
        is pulled forever.
    '''

    def __init__(self):
        self.best = None        # winning arm once only one survives
        self.eliminated = []    # arms ruled out so far (may hold repeats)
        self.to_pick = None     # arms still queued for the current round

    def _interval(self, record, t):
        ''' (lower, upper) confidence bounds for one [reward, pulls] record:
            mean -/+ sqrt(log(t) / pulls). '''
        reward_sum, pulls = record
        mean = reward_sum / pulls
        radius = math.sqrt(math.log(t) / pulls)
        return mean - radius, mean + radius

    def compare_arm(self, history1, history2, t):
        ''' Compare two arms after `t` total pulls.

            Returns 1 if the first arm's lower bound exceeds the second's
            upper bound, -1 in the mirrored case, and 0 when the intervals
            overlap.
        '''
        lcb1, ucb1 = self._interval(history1, t)
        lcb2, ucb2 = self._interval(history2, t)

        if lcb1 > ucb2:
            return 1
        if lcb2 > ucb1:
            return -1
        return 0

    def pick(self, n_arms, history, to_pick=[]):
        ''' Return the next arm to pull given per-arm [reward, pulls].

            The `to_pick` argument is accepted but never used; the queue
            lives in self.to_pick.
        '''
        if self.to_pick is None:
            # very first call: queue one pull of every arm
            self.to_pick = list(range(n_arms))

        # a winner has already been declared
        if self.best is not None:
            return self.best

        # finish the current round, skipping arms eliminated meanwhile
        while self.to_pick:
            candidate = self.to_pick.pop()
            if candidate not in self.eliminated:
                return candidate

        # round complete: compare every ordered pair of survivors
        contenders = [arm for arm in range(n_arms) if arm not in self.eliminated]
        total_pulls = sum([pulls for _, pulls in history])
        for first in contenders:
            if first in self.eliminated:
                continue
            for second in contenders:
                if second == first or second in self.eliminated:
                    continue
                verdict = self.compare_arm(history[first], history[second], total_pulls)
                if verdict == 1:
                    self.eliminated.append(second)
                elif verdict == -1:
                    self.eliminated.append(first)

        contenders = [arm for arm in range(n_arms) if arm not in self.eliminated]
        if len(contenders) == 1:
            self.best = contenders[0]
            return self.best

        # no winner yet: schedule one more round over the survivors
        self.to_pick.extend(contenders)
        return self.to_pick.pop()

In [6]:
# Shared setup for every run below.
# NOTE(review): no random seed is set in the visible cells, so the printed
# regrets will vary between executions -- confirm whether an earlier cell seeds.
# Five Bernoulli arms; the best one succeeds with probability 0.4.
a = Arms([0.2, 0.25, 0.3, 0.35, 0.4])
# Number of pulls per run.
T = 100000
# Number of independent runs the regret is averaged over.
N = 100

In [7]:
# Baseline: a uniformly random arm on every pull.
experiment(a, RandomPick(), T, N)

9998.307499999602

In [8]:
# Random baseline that repeats each drawn arm for a batch of 100 pulls.
experiment(a, BatchRandomPick(100), T, N)

9973.800000002188

In [9]:
# Epsilon-greedy with a 10% exploration rate.
experiment(a, EpsGreedy(0.1), T, N)

1077.2515000021842

In [10]:
# Upper-confidence-bound policy, re-evaluated on every pull.
experiment(a, UCB(), T, N)

361.3409999991922

In [11]:
# UCB that schedules its chosen arm for a batch of 100 pulls at a time.
experiment(a, BatchUCB(100), T, N)

382.8999999996567

In [12]:
# Thompson sampling with Beta(successes + 1, failures + 1) draws per arm.
experiment(a, Thompson(), T, N)

113.56499999991361

In [13]:
# Elimination by pairwise z-tests at a two-sided significance level of 0.01.
experiment(a, ABTesting(sig=0.01), T, N)

625.4805000130565

In [14]:
# Elimination when one arm's lower bound exceeds another's upper bound.
experiment(a, UCBvsLCB(), T, N)

1630.954499995524