Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
476 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import, division, \ | ||
print_function, unicode_literals | ||
|
||
""" | ||
Using Classifier and CSV file | ||
======================================== | ||
This is an simple example of Bandit service. | ||
The player `Jubatun` tries to maximize the cumulative reward of | ||
a sequence of slot machine plays by multi-armed bandit algorithm. | ||
You can try various simulation settings by modifying the slot machine setting. | ||
Let's edit lines 67-72 and enjoy! | ||
""" | ||
|
||
import random | ||
|
||
from jubakit.bandit import Bandit, Config | ||
|
||
|
||
class Slot(object): | ||
"""Slot machine.""" | ||
|
||
def __init__(self, probability, average, stddev): | ||
""" | ||
Initialize slot machine. | ||
:param float probability: Hit probability. | ||
:param float average: Average of a gaussian distribution. | ||
:param float stddev: Standard deviation of a gaussian distribution. | ||
:return: self | ||
""" | ||
self.probability = probability | ||
self.average = average | ||
self.stddev = stddev | ||
|
||
def hit(self): | ||
""" | ||
This slot machine hits with the given probability. | ||
:return bool: Whether this slot machine hits or not. | ||
""" | ||
if random.random() < self.probability: | ||
return True | ||
else: | ||
return False | ||
|
||
def reward(self): | ||
""" | ||
A reward is determined based on | ||
the given average and standard deviation. | ||
:return float: A reward. | ||
""" | ||
if self.hit(): | ||
return random.gauss(self.average, self.stddev) | ||
else: | ||
return 0.0 | ||
|
||
|
||
# Experimental config. | ||
# Which slot machine should we choose? | ||
iteration = 1000 | ||
slots = { | ||
'bad': Slot(0.1, 50, 10), # E[R] = 5: bad arm | ||
'normal': Slot(0.01, 600, 100), # E[R] = 6: normal arm | ||
'good': Slot(0.001, 8000, 1000) # E[R] = 8: good arm | ||
} | ||
|
||
# Launch bandit service. | ||
player = 'Jubatan' | ||
config = Config(method='epsilon_greedy', parameter={'epsilon': 0.1}) | ||
bandit = Bandit.run(config) | ||
|
||
# Initialize bandit settings. | ||
bandit.reset(player) | ||
for name, slot in slots.items(): | ||
bandit.register_arm(name) | ||
|
||
# Select arms and get rewards. | ||
cumulative_reward = 0 | ||
for i in range(iteration): | ||
arm = bandit.select_arm(player) | ||
reward = float(slots[arm].reward()) | ||
bandit.register_reward(player, arm, reward) | ||
cumulative_reward += reward | ||
|
||
# Show result. | ||
arm_info = bandit.get_arm_info(player) | ||
frequencies = {name: info.trial_count for name, info in arm_info.items()} | ||
|
||
print('cumulative reward: {0:.2f}'.format(cumulative_reward)) | ||
print('slot frequencies: {0}'.format(frequencies)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
""" | ||
Visualize training process with TensorBoard | ||
=========================================== | ||
In this example, we show the training process of Jubatus with TensorBoard. | ||
TensorBoard syntax is little complicated and in this example we use tensorboardX library. | ||
tensorboardX is a simple wrapper of TensorBoard that write events with simple function call. | ||
[How to Use] | ||
1. Install tensorboard. | ||
``` | ||
$ pip install tensorboardX | ||
``` | ||
2. Run this example. | ||
3. Check the training process using tensorboard. | ||
``` | ||
$ tensorboard --logdir runs/*** | ||
``` | ||
4. Enjoy! | ||
""" | ||
|
||
from sklearn.datasets import load_digits | ||
from sklearn.metrics import ( | ||
accuracy_score, f1_score, precision_score, recall_score, log_loss) | ||
|
||
from tensorboardX import SummaryWriter | ||
|
||
import jubakit | ||
from jubakit.classifier import Classifier, Dataset, Config | ||
from jubakit.model import JubaDump | ||
|
||
# Load the digits dataset. | ||
digits = load_digits() | ||
|
||
# Create a dataset. | ||
dataset = Dataset.from_array(digits.data, digits.target) | ||
n_samples = len(dataset) | ||
n_train_samples = int(n_samples * 0.7) | ||
train_ds = dataset[:n_train_samples] | ||
test_ds = dataset[n_train_samples:] | ||
|
||
# Create a classifier. | ||
config = Config(method='AROW', | ||
parameter={'regularization_weight': 0.1}) | ||
classifier = Classifier.run(config) | ||
|
||
model_name = 'classifier_digits' | ||
model_path = '/tmp/{}_{}_classifier_{}.jubatus'.format( | ||
classifier._host, classifier._port, model_name) | ||
|
||
# show the feature weights of the target label. | ||
target_label = 4 | ||
|
||
# Initialize summary writer. | ||
writer = SummaryWriter() | ||
|
||
# train and test the classifier. | ||
epochs = 100 | ||
for epoch in range(epochs): | ||
# train | ||
for _ in classifier.train(train_ds): pass | ||
|
||
# test | ||
y_true, y_pred = [], [] | ||
for (_, label, result) in classifier.classify(test_ds): | ||
y_true.append(label) | ||
y_pred.append(result[0][0]) | ||
|
||
# save model to check the feature weights | ||
classifier.save(model_name) | ||
|
||
model = JubaDump.dump_file(model_path) | ||
weights = model['storage']['storage']['weight'] | ||
for feature, label_values in weights.items(): | ||
for label, value in label_values.items(): | ||
if str(label) != str(target_label): | ||
continue | ||
writer.add_scalar('weights/{}'.format(feature), value['v1'], epoch) | ||
|
||
# write scores to tensorboardX summary writer. | ||
acc = accuracy_score(y_true, y_pred) | ||
prec = precision_score(y_true, y_pred, average='macro') | ||
recall = recall_score(y_true, y_pred, average='macro') | ||
f1 = f1_score(y_true, y_pred, average='macro') | ||
writer.add_scalar('metrics/accuracy', acc, epoch) | ||
writer.add_scalar('metrics/precision', prec, epoch) | ||
writer.add_scalar('metrics/recall', recall, epoch) | ||
writer.add_scalar('metrics/f1_score', f1, epoch) | ||
|
||
writer.close() | ||
classifier.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
VERSION = (0, 6, 1) | ||
VERSION = (0, 6, 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import jubatus | ||
import jubatus.embedded | ||
|
||
from .base import BaseService, GenericConfig | ||
|
||
|
||
class Bandit(BaseService): | ||
""" | ||
Bandit service. | ||
""" | ||
|
||
@classmethod | ||
def name(cls): | ||
return 'bandit' | ||
|
||
@classmethod | ||
def _client_class(cls): | ||
return jubatus.bandit.client.Bandit | ||
|
||
@classmethod | ||
def _embedded_class(cls): | ||
return jubatus.embedded.Bandit | ||
|
||
def register_arm(self, arm_id): | ||
arm_id = str(arm_id) | ||
return self._client().register_arm(arm_id) | ||
|
||
def delete_arm(self, arm_id): | ||
arm_id = str(arm_id) | ||
return self._client().delete_arm(arm_id) | ||
|
||
def select_arm(self, player_id): | ||
player_id = str(player_id) | ||
return self._client().select_arm(player_id) | ||
|
||
def register_reward(self, player_id, arm_id, reward): | ||
arm_id = str(arm_id) | ||
player_id = str(player_id) | ||
reward = float(reward) | ||
return self._client().register_reward(player_id, arm_id, reward) | ||
|
||
def get_arm_info(self, player_id): | ||
player_id = str(player_id) | ||
arm_info = self._client().get_arm_info(player_id) | ||
# convert key object to string type. | ||
return {str(name): info for name, info in arm_info.items()} | ||
|
||
def reset(self, player_id): | ||
player_id = str(player_id) | ||
return self._client().reset(str(player_id)) | ||
|
||
|
||
class Config(GenericConfig): | ||
""" | ||
Configuration to run Bandit service. | ||
""" | ||
|
||
@classmethod | ||
def methods(cls): | ||
return [ | ||
'epsilon_greedy', | ||
'epsilon_decreasing', | ||
'ucb1', | ||
'softmax', | ||
'exp3', | ||
'ts' | ||
] | ||
|
||
@classmethod | ||
def _default_method(cls): | ||
return 'epsilon_greedy' | ||
|
||
@classmethod | ||
def _default_parameter(cls, method): | ||
params = { | ||
'assume_unrewarded': False | ||
} | ||
if method in ('epsilon_greedy',): | ||
params['epsilon'] = 0.1 | ||
elif method in ('softmax',): | ||
params['tau'] = 0.05 | ||
elif method in ('exp3',): | ||
params['gamma'] = 0.1 | ||
elif method not in ('epsilon_decreasing', 'ucb1', 'ts'): | ||
raise RuntimeError('unknown method: {0}'.format(method)) | ||
return params | ||
|
||
@classmethod | ||
def _default(cls, cfg): | ||
cfg.clear() | ||
|
||
method = cls._default_method() | ||
parameter = cls._default_parameter(method) | ||
|
||
if method is not None: | ||
cfg['method'] = method | ||
if parameter is not None: | ||
cfg['parameter'] = parameter |
Oops, something went wrong.