Skip to content

Commit

Permalink
Merge pull request #132 from jubatus/add-bandit-service
Browse files Browse the repository at this point in the history
Add bandit service
  • Loading branch information
rimms committed Jan 19, 2019
2 parents 03790c5 + c0c577a commit 2ce0138
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 1 deletion.
5 changes: 4 additions & 1 deletion README.rst
Expand Up @@ -22,7 +22,8 @@ Currently jubakit supports
`Anomaly <http://jubat.us/en/api/api_anomaly.html>`_,
`Recommender <http://jubat.us/en/api/api_recommender.html>`_,
`NearestNeighbor <http://jubat.us/en/api/api_nearest_neighbor.html>`_,
`Clustering <http://jubat.us/en/api/api_clustering.html>`_ and
`Clustering <http://jubat.us/en/api/api_clustering.html>`_,
`Bandit <http://jubat.us/en/api/api_bandit.html>`_ and
`Weight <http://jubat.us/en/api/api_weight.html>`_ engines.

Install
Expand Down Expand Up @@ -116,6 +117,8 @@ See the `example <https://github.com/jubatus/jubakit/tree/master/example>`_ dire
+-----------------------------------+-----------------------------------------------+-----------------------+
| clustering_2d.py | Clustering 2-dimensional dataset | |
+-----------------------------------+-----------------------------------------------+-----------------------+
| bandit_slot.py | Multi-armed bandit with slot machine example | |
+-----------------------------------+-----------------------------------------------+-----------------------+
| weight_shogun.py | Tracing fv_converter behavior using Weight | |
+-----------------------------------+-----------------------------------------------+-----------------------+
| weight_model_extract.py | Extract contents of Weight model file | |
Expand Down
97 changes: 97 additions & 0 deletions example/bandit_slot.py
@@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, \
print_function, unicode_literals

"""
Multi-armed Bandit with Slot Machines
========================================
This is a simple example of the Bandit service.
The player `Jubatan` tries to maximize the cumulative reward of
a sequence of slot machine plays using a multi-armed bandit algorithm.
You can try various simulation settings by modifying the slot machine setting.
Let's edit the `slots` dictionary under "Experimental config" and enjoy!
"""

import random

from jubakit.bandit import Bandit, Config


class Slot(object):
    """Slot machine that pays a Gaussian reward with a fixed hit probability."""

    def __init__(self, probability, average, stddev):
        """
        Initialize slot machine.

        :param float probability: Hit probability.
        :param float average: Average of a gaussian distribution.
        :param float stddev: Standard deviation of a gaussian distribution.
        """
        self.probability = probability
        self.average = average
        self.stddev = stddev

    def hit(self):
        """
        This slot machine hits with the given probability.

        :return bool: Whether this slot machine hits or not.
        """
        # random.random() is drawn from [0.0, 1.0), so a probability of 0
        # never hits and a probability of 1 always hits.
        return random.random() < self.probability

    def reward(self):
        """
        A reward is determined based on
        the given average and standard deviation.

        :return float: A gaussian sample on a hit, ``0.0`` on a miss.
        """
        if self.hit():
            return random.gauss(self.average, self.stddev)
        return 0.0


# Experimental config.
# Which slot machine should we choose?
# E[R] below is the expected reward per play: probability * average.
iteration = 1000
slots = {
    'bad': Slot(0.1, 50, 10),         # E[R] = 5: bad arm
    'normal': Slot(0.01, 600, 100),   # E[R] = 6: normal arm
    'good': Slot(0.001, 8000, 1000)   # E[R] = 8: good arm
}

# Launch bandit service.
player = 'Jubatan'
config = Config(method='epsilon_greedy', parameter={'epsilon': 0.1})
bandit = Bandit.run(config)

# Initialize bandit settings.
bandit.reset(player)
# Only the arm names are needed for registration, so iterate the keys.
for name in slots:
    bandit.register_arm(name)

# Select arms and get rewards.
cumulative_reward = 0
for _ in range(iteration):
    arm = bandit.select_arm(player)
    # Slot.reward() already returns a float (gaussian sample or 0.0).
    reward = slots[arm].reward()
    bandit.register_reward(player, arm, reward)
    cumulative_reward += reward

# Show result.
arm_info = bandit.get_arm_info(player)
frequencies = {name: info.trial_count for name, info in arm_info.items()}

print('cumulative reward: {0:.2f}'.format(cumulative_reward))
print('slot frequencies: {0}'.format(frequencies))
102 changes: 102 additions & 0 deletions jubakit/bandit.py
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

import jubatus
import jubatus.embedded

from .base import BaseService, GenericConfig


class Bandit(BaseService):
    """
    Bandit service.

    Each method normalizes its identifiers to ``str`` (and the reward to
    ``float``) before delegating to the underlying client.
    """

    @classmethod
    def name(cls):
        """Return the service name, ``'bandit'``."""
        return 'bandit'

    @classmethod
    def _client_class(cls):
        """Return the RPC client class used for this service."""
        return jubatus.bandit.client.Bandit

    @classmethod
    def _embedded_class(cls):
        """Return the embedded implementation class used for this service."""
        return jubatus.embedded.Bandit

    def register_arm(self, arm_id):
        """
        Register a new arm.

        :param arm_id: Arm identifier; converted to ``str``.
        :return: The result of the client's ``register_arm`` call.
        """
        arm_id = str(arm_id)
        return self._client().register_arm(arm_id)

    def delete_arm(self, arm_id):
        """
        Delete an arm.

        :param arm_id: Arm identifier; converted to ``str``.
        :return: The result of the client's ``delete_arm`` call.
        """
        arm_id = str(arm_id)
        return self._client().delete_arm(arm_id)

    def select_arm(self, player_id):
        """
        Ask the service which arm the given player should pull.

        :param player_id: Player identifier; converted to ``str``.
        :return: The result of the client's ``select_arm`` call.
        """
        player_id = str(player_id)
        return self._client().select_arm(player_id)

    def register_reward(self, player_id, arm_id, reward):
        """
        Report the reward a player obtained from an arm.

        :param player_id: Player identifier; converted to ``str``.
        :param arm_id: Arm identifier; converted to ``str``.
        :param reward: Reward value; converted to ``float``.
        :return: The result of the client's ``register_reward`` call.
        """
        arm_id = str(arm_id)
        player_id = str(player_id)
        reward = float(reward)
        return self._client().register_reward(player_id, arm_id, reward)

    def get_arm_info(self, player_id):
        """
        Fetch per-arm information for the given player.

        :param player_id: Player identifier; converted to ``str``.
        :return dict: Mapping of arm name (as ``str``) to the arm info
            object returned by the client.
        """
        player_id = str(player_id)
        arm_info = self._client().get_arm_info(player_id)
        # convert key object to string type.
        return {str(name): info for name, info in arm_info.items()}

    def reset(self, player_id):
        """
        Reset the state held for the given player.

        :param player_id: Player identifier; converted to ``str``.
        :return: The result of the client's ``reset`` call.
        """
        player_id = str(player_id)
        return self._client().reset(player_id)


class Config(GenericConfig):
    """
    Configuration to run Bandit service.
    """

    @classmethod
    def methods(cls):
        """Return the list of supported bandit method names."""
        supported = [
            'epsilon_greedy',
            'epsilon_decreasing',
            'ucb1',
            'softmax',
            'exp3',
            'ts'
        ]
        return supported

    @classmethod
    def _default_method(cls):
        """Return the method name used when none is specified."""
        return 'epsilon_greedy'

    @classmethod
    def _default_parameter(cls, method):
        """
        Build the default parameter dict for ``method``.

        Every known method gets ``assume_unrewarded``; some methods get one
        extra method-specific default on top.

        :param method: Bandit method name.
        :return dict: Default parameters for the method.
        :raises RuntimeError: If ``method`` is not a known method.
        """
        # NOTE(review): 'epsilon_decreasing' gets no method-specific default
        # here; confirm against the Jubatus Bandit API whether it needs one.
        method_specific = (
            ('epsilon_greedy', {'epsilon': 0.1}),
            ('softmax', {'tau': 0.05}),
            ('exp3', {'gamma': 0.1}),
            ('epsilon_decreasing', {}),
            ('ucb1', {}),
            ('ts', {}),
        )
        for known_method, extras in method_specific:
            if method == known_method:
                defaults = {
                    'assume_unrewarded': False
                }
                defaults.update(extras)
                return defaults
        raise RuntimeError('unknown method: {0}'.format(method))

    @classmethod
    def _default(cls, cfg):
        """
        Reset ``cfg`` in place to the default method and parameters.

        :param cfg: Mapping-like config object to clear and repopulate.
        """
        cfg.clear()

        chosen = cls._default_method()
        defaults = cls._default_parameter(chosen)

        # Only entries that actually have a value are written back.
        for key, value in (('method', chosen), ('parameter', defaults)):
            if value is not None:
                cfg[key] = value
92 changes: 92 additions & 0 deletions jubakit/test/test_bandit.py
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

from unittest import TestCase

from jubakit.bandit import Bandit, Config

from . import requireEmbedded


class BanditTest(TestCase):
    """Tests for the Bandit service wrapper, run against a launched service."""

    def _launch(self, arm=None):
        """Launch a Bandit service with the default config.

        When ``arm`` is given, it is registered before the service is returned.
        """
        service = Bandit.run(Config())
        if arm is not None:
            service.register_arm(arm)
        return service

    def test_simple(self):
        Bandit()

    def test_simple_launch(self):
        Bandit.run(Config())

    @requireEmbedded
    def test_embedded(self):
        Bandit.run(Config(), embedded=True)

    def test_register_arm(self):
        service = self._launch()
        self.assertIsInstance(service.register_arm(1), bool)

    def test_delete_arm(self):
        service = self._launch(arm=1)
        self.assertIsInstance(service.delete_arm(1), bool)

    def test_select_arm(self):
        service = self._launch(arm=1)
        # Arm ids are stringified on registration, so the only arm comes
        # back as its string form.
        self.assertEqual(service.select_arm('player'), str(1))

    def test_register_reward(self):
        service = self._launch(arm=1)
        service.select_arm('player')
        self.assertIsInstance(service.register_reward('player', 1, 10), bool)

    def test_get_arm_info(self):
        from jubatus.bandit.types import ArmInfo
        service = self._launch(arm=1)
        service.select_arm('player')
        arms = service.get_arm_info('player')
        self.assertIsInstance(arms, dict)
        for arm_name, arm_detail in arms.items():
            self.assertIsInstance(arm_name, str)
            self.assertIsInstance(arm_detail, ArmInfo)

    def test_reset(self):
        service = self._launch(arm=1)
        service.select_arm('player')
        service.register_reward('player', 1, 10)
        self.assertIsInstance(service.reset('player'), bool)


class ConfigTest(TestCase):
    """Tests for the Bandit ``Config`` defaults."""

    def test_simple(self):
        config = Config()
        self.assertEqual('epsilon_greedy', config['method'])

    def test_methods(self):
        config = Config()
        self.assertIsInstance(config.methods(), list)

    def test_default(self):
        config = Config.default()
        self.assertEqual('epsilon_greedy', config['method'])

    def test_method_params(self):
        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(x in y).
        for method in Config.methods():
            self.assertIn(
                'assume_unrewarded', Config(method=method)['parameter'])
        self.assertIn('epsilon', Config(method='epsilon_greedy')['parameter'])
        self.assertIn('tau', Config(method='softmax')['parameter'])
        self.assertIn('gamma', Config(method='exp3')['parameter'])

    def test_invalid_method(self):
        self.assertRaises(
            RuntimeError, Config._default_parameter, 'invalid_method')

0 comments on commit 2ce0138

Please sign in to comment.