In [1]:
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from collections import namedtuple
import random

In [2]:
# An Arm's Beta posterior paired with the true win probability used to simulate pulls of it.
Probable_Arm = namedtuple('Probable_Arm', ['arm', 'win_probability'])

class Arm:
    """Beta-Bernoulli posterior over one bandit arm's success rate.

    Starts from a Beta(alpha, beta) prior (uniform Beta(1, 1) by default) and is
    updated one Bernoulli observation at a time, so at any point
    ``alpha == prior_alpha + rewards`` and ``beta == prior_beta + (tries - rewards)``.

    Attributes:
        alpha: posterior alpha parameter (prior alpha + successes).
        beta: posterior beta parameter (prior beta + failures).
        rewards: number of successful pulls observed.
        tries: total number of pulls observed.
    """

    def __init__(self, alpha=1, beta=1):
        self.alpha = alpha
        self.beta = beta

        self.rewards = 0
        self.tries = 0

    def update(self, was_success):
        """Record the outcome of one pull of this arm.

        Args:
            was_success: truthy if the pull paid out, falsy otherwise.
        """
        success = int(bool(was_success))
        self.tries += 1
        self.rewards += success
        # Conjugate update: add only THIS observation. Adding the running totals
        # (rewards / tries - rewards) on every call would re-count all past pulls.
        self.alpha += success
        self.beta += 1 - success
        

In [3]:

# The arm currently leading a Thompson-sampling draw, with the sample that put it there.
Winner_Pair = namedtuple('Winner_Pair', ['arm', 'probability'])

def get_winner(arms):
    """Choose the next arm to pull by Thompson sampling.

    Draws one sample from each arm's Beta(alpha, beta) posterior, in iteration
    order, and returns the arm with the largest sample (the first one wins ties).

    Args:
        arms: non-empty iterable of objects exposing ``.arm.alpha`` and
            ``.arm.beta`` (e.g. ``Probable_Arm``).

    Returns:
        The element of ``arms`` whose posterior sample was highest.

    Raises:
        ValueError: if ``arms`` is empty.
    """
    winner = None
    for candidate in arms:
        sample = stats.beta.rvs(candidate.arm.alpha, candidate.arm.beta)
        if winner is None or sample > winner.probability:
            winner = Winner_Pair(arm=candidate, probability=sample)
    if winner is None:
        raise ValueError("get_winner() requires at least one arm")
    return winner.arm


In [4]:
def create_arms(probabilities):
    """Build one fresh, untrained arm per true win probability.

    Args:
        probabilities: iterable of true success rates, one per arm.

    Returns:
        A list of ``Probable_Arm`` records, in input order, each wrapping a new ``Arm()``.
    """
    arms = []
    for win_probability in probabilities:
        arms.append(Probable_Arm(arm=Arm(), win_probability=win_probability))
    return arms
        

In [5]:
def run_trial(arm):
    """Simulate one pull of ``arm`` and record the outcome in its posterior.

    Args:
        arm: object exposing ``.win_probability`` (true success rate) and
            ``.arm.update(was_success)`` (e.g. ``Probable_Arm``).

    Returns:
        The same ``arm`` object, whose ``.arm`` has been updated in place.
    """
    # Exactly one draw per trial. random.random() is in [0, 1), so
    # win_probability=1.0 always succeeds and 0.0 never does.
    was_success = random.random() < arm.win_probability
    arm.arm.update(was_success)
    return arm


In [6]:
# Density height below which the right-hand tail of every posterior is treated as
# empty and trimmed from the plot.
min_value = 0.001


def _trim_negligible_tail(df, threshold):
    """Drop trailing rows where no column exceeds ``threshold``; keep everything up to the last row that does."""
    above = (df > threshold).any(axis=1)
    if not above.any():
        return df
    last_label = above[above].index[-1]
    # Label-based and inclusive of last_label (the index holds x positions in [0, 1]).
    return df.loc[:last_label]


def graph_arms(probable_arms):
    """Plot each arm's Beta posterior density with the negligible right tail trimmed.

    Args:
        probable_arms: iterable of objects exposing ``.win_probability`` (used as
            the series label) and ``.arm.alpha`` / ``.arm.beta``. Arms that share
            a win probability overwrite each other's column.
    """
    x = np.linspace(0, 1, 1000)
    densities = pd.DataFrame(
        {arm.win_probability: stats.beta.pdf(x, arm.arm.alpha, arm.arm.beta)
         for arm in probable_arms},
        index=x,
    )

    fig, ax = plt.subplots()
    sns.lineplot(data=_trim_negligible_tail(densities, min_value), ax=ax)
    for line in ax.lines:
        xy = line.get_xydata()
        # Use the line's own colour; fill_between otherwise draws from a separate colour cycle.
        ax.fill_between(xy[:, 0], xy[:, 1], alpha=0.3, color=line.get_color())
    ax.autoscale(tight=True)
    ax.set(xlabel="Estimated win probability", ylabel="Posterior density")
    ax.legend(title="Probability")
    plt.show()

In [7]:
def graph_history(history):
    """Plot the cumulative number of pulls of each arm over a run.

    Args:
        history: DataFrame indexed by trial number, with one column per arm
            (labelled by its win probability) holding that arm's running pull count.
    """
    fig, ax = plt.subplots()
    sns.lineplot(data=history, ax=ax)
    # The legend lists arms by win probability; the quantity plotted is the y-axis.
    ax.set(xlabel="Trial", ylabel="Total impressions per Arm")
    ax.legend(title="Probability")
    plt.show()

In [8]:
def _parse_probabilities(probs):
    """Parse a comma-separated string of floats, skipping blank entries such as a trailing comma."""
    tokens = (part.strip() for part in probs.split(','))
    prob_list = [float(token) for token in tokens if token]
    if not prob_list:
        raise ValueError("Enter at least one probability, e.g. '0.05,0.1'")
    return prob_list


def run(num_trials, probs):
    """Run a Thompson-sampling simulation and plot the posteriors and pull history.

    Args:
        num_trials: number of pulls to simulate.
        probs: comma-separated true win probabilities, one per arm (e.g. "0.05,0.1").
    """
    # A list (not a set of identity-hashed tuples) keeps arm order stable, so
    # tie-breaking, column order and plot colours are reproducible.
    arms = create_arms(_parse_probabilities(probs))

    snapshots = []
    for _ in range(num_trials):
        # run_trial updates the chosen arm in place, so the list needs no bookkeeping.
        run_trial(get_winner(arms))
        # NOTE: arms sharing a win probability collide on the same history column.
        snapshots.append({arm.win_probability: arm.arm.tries for arm in arms})

    columns = list(dict.fromkeys(arm.win_probability for arm in arms))
    history = pd.DataFrame(snapshots, index=range(num_trials), columns=columns)

    graph_arms(arms)
    graph_history(history)
    


In [9]:
# Interactive driver: choose the number of trials and the true win probabilities,
# then press the run button to simulate and plot.
# interact_manual displays the controls itself and returns `run` unchanged, so the
# result is not echoed; echoing it would only print the function's repr.
widget = interact_manual(
    run,
    num_trials=widgets.IntSlider(min=1, max=1000, step=1, value=10, continuous_update=False),
    probs=widgets.Text(
        placeholder='Comma-sep floats',
        description='Probabilities:',
        value='0.05,0.075,0.1',
        continuous_update=False,
    ),
)

interactive(children=(IntSlider(value=10, continuous_update=False, description='num_trials', max=1000, min=1),…

<function __main__.run(num_trials, probs)>