In [1]:
# Make the in-repo `banditpylib` package importable from this examples directory.
# The path is resolved relative to the notebook's working directory, and it must be
# appended before the `banditpylib` imports below.
import sys
sys.path.append('../')

# NOTE(review): np, pd, json and the argmax_or_min* helpers below are not used in the
# cells visible here — confirm they are unused elsewhere before removing them.
import numpy as np
import pandas as pd
import tempfile
import json
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn plotting style for every figure in this notebook.
sns.set(style="darkgrid")

import logging
# Lower the root logger threshold so INFO messages are shown
# (e.g. the per-learner "runs N seconds" timing lines in the output).
logging.getLogger().setLevel(logging.INFO)

from banditpylib import trials_to_dataframe
from banditpylib.arms import GaussianArm
from banditpylib.bandits import MultiArmedBandit
from banditpylib.protocols import SinglePlayerProtocol
from banditpylib.learners.mab_fcbai_learner import ExpGap, LilUCBHeuristic, TrackAndStop
from banditpylib.utils import argmax_or_min_tuple, argmax_or_min, argmax_or_min_tuple_second

In [2]:
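# Alternative arm-mean configurations: copy one into `means` in the next cell to try it.
# A trailing `w*(µ)` comment records reference optimal sampling proportions for that instance.
# NOTE(review): these w* values appear to come from the Bernoulli experiments of
# Garivier & Kaufmann (2016, "Optimal Best Arm Identification with Fixed Confidence"),
# so they may not apply to the Gaussian arms used here. Also, the shifted instance
# [0.6, 0.51, ...] reuses the w*(µ2) values of [0.3, 0.21, ...]. Confirm before comparing.
#means =  [0.7, 0.4, 0.1]
#means = [0.6, 0.5, 0.5]
#means = [0.5, 0.45, 0.43, 0.4] #w*(µ1) = [0.417 0.390 0.136 0.057]
#means =  [0.6, 0.51, 0.5, 0.49, 0.48] #w*(µ2) = [0.336, 0.251, 0.177, 0.132, 0.104]
#means =  [0.3, 0.21, 0.2, 0.19, 0.18] #w*(µ2) = [0.336, 0.251, 0.177, 0.132, 0.104]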

In [3]:
# Experiment configuration: every learner is built with the same confidence level
# and arm count, and is played on one Gaussian multi-armed bandit.
confidence = 0.95          # confidence level passed to every learner
means = [0.7, 0.4, 0.1]    # mean reward of each arm
max_pulls = 5000           # pull budget for the learners that accept `max_pulls`
std = 1                    # common standard deviation of every Gaussian arm

arms = [GaussianArm(mu=mean, std=std) for mean in means]
bandit = MultiArmedBandit(arms=arms)
arm_num = len(arms)

learners = [
    # NOTE(review): `threshold=3` is an ExpGap tuning constant whose meaning is defined
    # inside banditpylib — document it or lift it into the config above.
    ExpGap(arm_num=arm_num, confidence=confidence, threshold=3,
           name='Exponential-Gap Elimination'),
    LilUCBHeuristic(arm_num=arm_num, confidence=confidence, max_pulls=max_pulls,
                    name='Heuristic lilUCB'),
    TrackAndStop(arm_num=arm_num, confidence=confidence, tracking_rule="C",
                 max_pulls=max_pulls, name='Track and stop C-Tracking'),
    TrackAndStop(arm_num=arm_num, confidence=confidence, tracking_rule="D",
                 max_pulls=max_pulls, name='Track and stop D-Tracking'),
]

# Number of independent trials to run for each learner.
trials = 5
# `game.play` writes results to this file's name, and `trials_to_dataframe` reads them back.
# The object is kept open on purpose: the file is deleted once it is closed or garbage-collected.
temp_file = tempfile.NamedTemporaryFile()

In [4]:
# Wire the bandit and every learner into a single-player protocol, then run `trials`
# trials, writing the per-trial results to `temp_file.name`.
# Pass `debug=True` to `play` when debugging.
# NOTE(review): if `play` appends to the output file, re-running this cell without first
# re-running the config cell (fresh `temp_file`) would mix old and new results — confirm.
game = SinglePlayerProtocol(learners=learners, bandit=bandit)
game.play(output_filename=temp_file.name, trials=trials)

INFO:absl:start Exponential-Gap Elimination's play with multi_armed_bandit
INFO:absl:Exponential-Gap Elimination's play with multi_armed_bandit runs 2.52 seconds.
INFO:absl:start Heuristic lilUCB's play with multi_armed_bandit
INFO:absl:Heuristic lilUCB's play with multi_armed_bandit runs 2.61 seconds.
INFO:absl:start Track and stop C-Tracking's play with multi_armed_bandit


w_star:  [0. 1. 0.]
w_star:  [1. 0. 0.]
w_star:  [0. 0. 1.]
w_star:  [0. 1. 0.]
w_star:  [1. 0. 0.]


  kl_divergence(
  return (
Process SpawnPoolWorker-29:
Traceback (most recent call last):
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/Users/mwai/Library/CloudStorage/OneDrive-Chalmers/Documents/Chalmers Research/Bandits/banditpylib/examples/../banditpylib/protocols/single_player_protocol.py", line 68, in _one_trial
    actions = current_learner.actions(self._bandit.context)
  File "/Users/mwai/Library/CloudStorage/OneDrive-Chalmers/Documents/Chalmers Research/Bandits/banditpylib/examples/../banditpylib/learners/mab_fcbai_learner/track_and_stop.py", line 232, in actions
    w_star = self.solve_wstar(self.mu_hat)
  File "/Users/mwai/Libr

  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/queues.py", line 356, in get
    res = self._reader.recv_bytes()
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/Users/mwai/opt/anaconda3/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt


KeyboardInterrupt: 

In [None]:
# Load the per-trial results that `game.play` wrote to the temporary file.
trials_df = trials_to_dataframe(
    temp_file.name,
)

In [None]:
# Display (up to) the last 100 result rows.
trials_df.tail(n=100)

In [None]:
# Tag every row with this run's confidence level so it can be used as the bar-plot
# x category. `assign` returns a new frame instead of mutating the one displayed
# above, and re-running this cell gives the same result.
trials_df = trials_df.assign(confidence=confidence)

In [None]:
# Compare the total number of actions (pulls) taken by each learner, with one bar per learner.
# Bar heights use seaborn's default `barplot` estimator across trials.
fig, ax = plt.subplots()
sns.barplot(x='confidence', y='total_actions', hue='learner', data=trials_df, ax=ax)
ax.set(ylabel='pulls', title='Total pulls per learner')
# Put the legend outside the axes so it does not cover the bars; the trailing `;`
# suppresses the Legend object's repr in the cell output.
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5));