Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #132 from jubatus/add-bandit-service
Add bandit service
- Loading branch information
Showing
4 changed files
with
295 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import, division, \ | ||
print_function, unicode_literals | ||
|
||
""" | ||
Using Classifier and CSV file | ||
======================================== | ||
This is an simple example of Bandit service. | ||
The player `Jubatun` tries to maximize the cumulative reward of | ||
a sequence of slot machine plays by multi-armed bandit algorithm. | ||
You can try various simulation settings by modifying the slot machine setting. | ||
Let's edit lines 67-72 and enjoy! | ||
""" | ||
|
||
import random | ||
|
||
from jubakit.bandit import Bandit, Config | ||
|
||
|
||
class Slot(object): | ||
"""Slot machine.""" | ||
|
||
def __init__(self, probability, average, stddev): | ||
""" | ||
Initialize slot machine. | ||
:param float probability: Hit probability. | ||
:param float average: Average of a gaussian distribution. | ||
:param float stddev: Standard deviation of a gaussian distribution. | ||
:return: self | ||
""" | ||
self.probability = probability | ||
self.average = average | ||
self.stddev = stddev | ||
|
||
def hit(self): | ||
""" | ||
This slot machine hits with the given probability. | ||
:return bool: Whether this slot machine hits or not. | ||
""" | ||
if random.random() < self.probability: | ||
return True | ||
else: | ||
return False | ||
|
||
def reward(self): | ||
""" | ||
A reward is determined based on | ||
the given average and standard deviation. | ||
:return float: A reward. | ||
""" | ||
if self.hit(): | ||
return random.gauss(self.average, self.stddev) | ||
else: | ||
return 0.0 | ||
|
||
|
||
# Experimental config. | ||
# Which slot machine should we choose? | ||
iteration = 1000 | ||
slots = { | ||
'bad': Slot(0.1, 50, 10), # E[R] = 5: bad arm | ||
'normal': Slot(0.01, 600, 100), # E[R] = 6: normal arm | ||
'good': Slot(0.001, 8000, 1000) # E[R] = 8: good arm | ||
} | ||
|
||
# Launch bandit service. | ||
player = 'Jubatan' | ||
config = Config(method='epsilon_greedy', parameter={'epsilon': 0.1}) | ||
bandit = Bandit.run(config) | ||
|
||
# Initialize bandit settings. | ||
bandit.reset(player) | ||
for name, slot in slots.items(): | ||
bandit.register_arm(name) | ||
|
||
# Select arms and get rewards. | ||
cumulative_reward = 0 | ||
for i in range(iteration): | ||
arm = bandit.select_arm(player) | ||
reward = float(slots[arm].reward()) | ||
bandit.register_reward(player, arm, reward) | ||
cumulative_reward += reward | ||
|
||
# Show result. | ||
arm_info = bandit.get_arm_info(player) | ||
frequencies = {name: info.trial_count for name, info in arm_info.items()} | ||
|
||
print('cumulative reward: {0:.2f}'.format(cumulative_reward)) | ||
print('slot frequencies: {0}'.format(frequencies)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import jubatus | ||
import jubatus.embedded | ||
|
||
from .base import BaseService, GenericConfig | ||
|
||
|
||
class Bandit(BaseService): | ||
""" | ||
Bandit service. | ||
""" | ||
|
||
@classmethod | ||
def name(cls): | ||
return 'bandit' | ||
|
||
@classmethod | ||
def _client_class(cls): | ||
return jubatus.bandit.client.Bandit | ||
|
||
@classmethod | ||
def _embedded_class(cls): | ||
return jubatus.embedded.Bandit | ||
|
||
def register_arm(self, arm_id): | ||
arm_id = str(arm_id) | ||
return self._client().register_arm(arm_id) | ||
|
||
def delete_arm(self, arm_id): | ||
arm_id = str(arm_id) | ||
return self._client().delete_arm(arm_id) | ||
|
||
def select_arm(self, player_id): | ||
player_id = str(player_id) | ||
return self._client().select_arm(player_id) | ||
|
||
def register_reward(self, player_id, arm_id, reward): | ||
arm_id = str(arm_id) | ||
player_id = str(player_id) | ||
reward = float(reward) | ||
return self._client().register_reward(player_id, arm_id, reward) | ||
|
||
def get_arm_info(self, player_id): | ||
player_id = str(player_id) | ||
arm_info = self._client().get_arm_info(player_id) | ||
# convert key object to string type. | ||
return {str(name): info for name, info in arm_info.items()} | ||
|
||
def reset(self, player_id): | ||
player_id = str(player_id) | ||
return self._client().reset(str(player_id)) | ||
|
||
|
||
class Config(GenericConfig): | ||
""" | ||
Configuration to run Bandit service. | ||
""" | ||
|
||
@classmethod | ||
def methods(cls): | ||
return [ | ||
'epsilon_greedy', | ||
'epsilon_decreasing', | ||
'ucb1', | ||
'softmax', | ||
'exp3', | ||
'ts' | ||
] | ||
|
||
@classmethod | ||
def _default_method(cls): | ||
return 'epsilon_greedy' | ||
|
||
@classmethod | ||
def _default_parameter(cls, method): | ||
params = { | ||
'assume_unrewarded': False | ||
} | ||
if method in ('epsilon_greedy',): | ||
params['epsilon'] = 0.1 | ||
elif method in ('softmax',): | ||
params['tau'] = 0.05 | ||
elif method in ('exp3',): | ||
params['gamma'] = 0.1 | ||
elif method not in ('epsilon_decreasing', 'ucb1', 'ts'): | ||
raise RuntimeError('unknown method: {0}'.format(method)) | ||
return params | ||
|
||
@classmethod | ||
def _default(cls, cfg): | ||
cfg.clear() | ||
|
||
method = cls._default_method() | ||
parameter = cls._default_parameter(method) | ||
|
||
if method is not None: | ||
cfg['method'] = method | ||
if parameter is not None: | ||
cfg['parameter'] = parameter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
from unittest import TestCase | ||
|
||
from jubakit.bandit import Bandit, Config | ||
|
||
from . import requireEmbedded | ||
|
||
|
||
class BanditTest(TestCase): | ||
|
||
def test_simple(self): | ||
Bandit() | ||
|
||
def test_simple_launch(self): | ||
Bandit.run(Config()) | ||
|
||
@requireEmbedded | ||
def test_embedded(self): | ||
Bandit.run(Config(), embedded=True) | ||
|
||
def test_register_arm(self): | ||
bandit = Bandit.run(Config()) | ||
ret = bandit.register_arm(1) | ||
self.assertIsInstance(ret, bool) | ||
|
||
def test_delete_arm(self): | ||
bandit = Bandit.run(Config()) | ||
bandit.register_arm(1) | ||
ret = bandit.delete_arm(1) | ||
self.assertIsInstance(ret, bool) | ||
|
||
def test_select_arm(self): | ||
bandit = Bandit.run(Config()) | ||
bandit.register_arm(1) | ||
ret = bandit.select_arm('player') | ||
self.assertEqual(ret, str(1)) | ||
|
||
def test_register_reward(self): | ||
bandit = Bandit.run(Config()) | ||
bandit.register_arm(1) | ||
bandit.select_arm('player') | ||
ret = bandit.register_reward('player', 1, 10) | ||
self.assertIsInstance(ret, bool) | ||
|
||
def test_get_arm_info(self): | ||
from jubatus.bandit.types import ArmInfo | ||
bandit = Bandit.run(Config()) | ||
bandit.register_arm(1) | ||
bandit.select_arm('player') | ||
ret = bandit.get_arm_info('player') | ||
self.assertIsInstance(ret, dict) | ||
for name, info in ret.items(): | ||
self.assertIsInstance(name, str) | ||
self.assertIsInstance(info, ArmInfo) | ||
|
||
def test_reset(self): | ||
bandit = Bandit.run(Config()) | ||
bandit.register_arm(1) | ||
bandit.select_arm('player') | ||
bandit.register_reward('player', 1, 10) | ||
ret = bandit.reset('player') | ||
self.assertIsInstance(ret, bool) | ||
|
||
|
||
class ConfigTest(TestCase): | ||
|
||
def test_simple(self): | ||
config = Config() | ||
self.assertEqual('epsilon_greedy', config['method']) | ||
|
||
def test_methods(self): | ||
config = Config() | ||
self.assertIsInstance(config.methods(), list) | ||
|
||
def test_default(self): | ||
config = Config.default() | ||
self.assertEqual('epsilon_greedy', config['method']) | ||
|
||
def test_method_params(self): | ||
for method in Config.methods(): | ||
self.assertTrue( | ||
'assume_unrewarded' in Config(method=method)['parameter']) | ||
self.assertTrue('epsilon' in Config(method='epsilon_greedy')['parameter']) | ||
self.assertTrue('tau' in Config(method='softmax')['parameter']) | ||
self.assertTrue('gamma' in Config(method='exp3')['parameter']) | ||
|
||
def test_invalid_method(self): | ||
self.assertRaises( | ||
RuntimeError, Config._default_parameter, 'invalid_method') |