forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cb_explore_adf_bag.cc
165 lines (136 loc) · 5.85 KB
/
cb_explore_adf_bag.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.
#include "cb_explore_adf_bag.h"
#include "cb_explore_adf_common.h"
#include "reductions.h"
#include "cb_adf.h"
#include "rand48.h"
#include "bs.h"
#include "gen_cs_example.h"
#include "cb_explore.h"
#include "explore.h"
#include <vector>
#include <algorithm>
#include <cmath>
#include <memory>
#include <utility>
// All exploration algorithms return a vector of id, probability tuples, sorted in order of scores. The probabilities
// are the probability with which each action should be replaced to the top of the list.
namespace VW
{
namespace cb_explore_adf
{
namespace bag
{
// Bagging-based exploration over action-dependent features: a bag of
// `_bag_size` policies is trained on bootstrap-reweighted examples, and the
// exploration distribution is derived from how often each action is the
// bag's top prediction.
struct cb_explore_adf_bag
{
private:
// Minimum-probability floor enforced on the final action distribution.
float _epsilon;
// Number of policies in the bag.
size_t _bag_size;
// When true, policy 0 is always updated exactly once per example (greedy member).
bool _greedify;
// When true, only the single first-ranked action gets top-action credit on ties.
bool _first_only;
std::shared_ptr<rand_state> _random_state;
// Scratch buffers reused across calls (cleared each invocation, capacity kept).
v_array<ACTION_SCORE::action_score> _action_probs;
std::vector<float> _scores;
std::vector<float> _top_actions;
public:
cb_explore_adf_bag(
float epsilon, size_t bag_size, bool greedify, bool first_only, std::shared_ptr<rand_state> random_state);
~cb_explore_adf_bag();
// Should be called through cb_explore_adf_base for pre/post-processing
void predict(VW::LEARNER::multi_learner& base, multi_ex& examples) { predict_or_learn_impl<false>(base, examples); }
void learn(VW::LEARNER::multi_learner& base, multi_ex& examples) { predict_or_learn_impl<true>(base, examples); }
private:
// Shared predict/learn implementation; is_learn selects whether the base
// policies are updated or only queried.
template <bool is_learn>
void predict_or_learn_impl(VW::LEARNER::multi_learner& base, multi_ex& examples);
};
// Construct a bagging explorer.
//
// @param epsilon       Minimum-probability floor for the final distribution.
// @param bag_size      Number of policies in the bag.
// @param greedify      Always update the first policy exactly once (greedy member).
// @param first_only    Credit only the first action on top-action ties.
// @param random_state  Shared RNG state used to draw bootstrap weights.
cb_explore_adf_bag::cb_explore_adf_bag(
    float epsilon, size_t bag_size, bool greedify, bool first_only, std::shared_ptr<rand_state> random_state)
    : _epsilon(epsilon)
    , _bag_size(bag_size)
    , _greedify(greedify)
    , _first_only(first_only)
    // Move the by-value shared_ptr into the member instead of copying it:
    // avoids a needless atomic refcount increment/decrement
    // (clang-tidy: performance-unnecessary-value-param).
    , _random_state(std::move(random_state))
{
}
// Core bagging routine. For each of the _bag_size policies: (optionally)
// learn with a bootstrap-drawn integer weight, accumulate the policy's
// per-action scores, and tally which action(s) the policy ranks first.
// The final prediction is a probability distribution over actions
// proportional to top-action counts, floored by _epsilon.
//
// NOTE(review): the order of RNG draws and learn calls below is
// significant for reproducibility — do not reorder.
template <bool is_learn>
void cb_explore_adf_bag::predict_or_learn_impl(VW::LEARNER::multi_learner& base, multi_ex& examples)
{
// Randomize over predictions from a base set of predictors
// preds aliases the shared prediction slot on the first example.
v_array<ACTION_SCORE::action_score>& preds = examples[0]->pred.a_s;
uint32_t num_actions = (uint32_t)examples.size();
if (num_actions == 0)
{
preds.clear();
return;
}
_scores.clear();
for (uint32_t i = 0; i < num_actions; i++) _scores.push_back(0.f);
_top_actions.assign(num_actions, 0);
for (uint32_t i = 0; i < _bag_size; i++)
{
// avoid updates to the random num generator
// for greedify, always update first policy once
// count = number of times this policy learns this example; 0 on
// pure predict. BS::weight_gen draws the bootstrap weight
// (presumably Poisson-distributed — see bs.h).
uint32_t count = is_learn ? ((_greedify && i == 0) ? 1 : BS::weight_gen(_random_state)) : 0;
// First pass: learn once if count > 0, otherwise just predict.
// Either way, preds is populated with policy i's ranking.
if (is_learn && count > 0)
VW::LEARNER::multiline_learn_or_predict<true>(base, examples, examples[0]->ft_offset, i);
else
VW::LEARNER::multiline_learn_or_predict<false>(base, examples, examples[0]->ft_offset, i);
assert(preds.size() == num_actions);
// Accumulate this policy's scores into the bag-wide totals
// (used later only as a sort tie-breaker).
for (auto e : preds) _scores[e.action] += e.score;
if (!_first_only)
{
// Split one unit of top-action credit evenly among all tied leaders.
size_t tied_actions = fill_tied(preds);
for (size_t j = 0; j < tied_actions; ++j) _top_actions[preds[j].action] += 1.f / tied_actions;
}
else
_top_actions[preds[0].action] += 1.f;
// Remaining (count - 1) updates: learn-only passes, predictions ignored.
if (is_learn)
for (uint32_t j = 1; j < count; j++)
VW::LEARNER::multiline_learn_or_predict<true>(base, examples, examples[0]->ft_offset, i);
}
// Build (action, probability) pairs; probabilities start at zero.
_action_probs.clear();
for (uint32_t i = 0; i < _scores.size(); i++) _action_probs.push_back({i, 0.});
// generate distribution over actions
exploration::generate_bag(
begin(_top_actions), end(_top_actions), begin_scores(_action_probs), end_scores(_action_probs));
// Enforce the epsilon floor so every action keeps nonzero probability.
exploration::enforce_minimum_probability(_epsilon, true, begin_scores(_action_probs), end_scores(_action_probs));
sort_action_probs(_action_probs, _scores);
// Write the final distribution back into the shared prediction slot.
for (size_t i = 0; i < num_actions; i++) preds[i] = _action_probs[i];
}
cb_explore_adf_bag::~cb_explore_adf_bag() { _action_probs.delete_v(); }
// Reduction entry point: registers the --cb_explore_adf/--bag options and,
// if both were supplied, wires up a bagging explorer on top of the cb_adf
// reduction stack. Returns nullptr when this reduction is not enabled.
VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
{
using config::make_option;
bool cb_explore_adf_option = false;
float epsilon = 0.;
size_t bag_size = 0;
bool greedify = false;
bool first_only = false;
config::option_group_definition new_options("Contextual Bandit Exploration with Action Dependent Features");
new_options
.add(make_option("cb_explore_adf", cb_explore_adf_option)
.keep()
.help("Online explore-exploit for a contextual bandit problem with multiline action dependent features"))
.add(make_option("epsilon", epsilon).keep().allow_override().help("epsilon-greedy exploration"))
.add(make_option("bag", bag_size).keep().help("bagging-based exploration"))
.add(make_option("first_only", first_only).keep().help("Only explore the first action in a tie-breaking event"));
options.add_and_parse(new_options);
// This reduction is only active when both --cb_explore_adf and --bag are given.
if (!cb_explore_adf_option || !options.was_supplied("bag"))
return nullptr;
// Ensure serialization of cb_adf in all cases.
if (!options.was_supplied("cb_adf"))
{
options.insert("cb_adf", "");
}
all.delete_prediction = ACTION_SCORE::delete_action_scores;
// Each policy in the bag gets its own copy of the base learner's weights.
size_t problem_multiplier = bag_size;
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
// Examples carry contextual-bandit labels.
all.p->lp = CB::cb_label;
all.label_type = label_type_t::cb;
using explore_type = cb_explore_adf_base<cb_explore_adf_bag>;
auto data = scoped_calloc_or_throw<explore_type>(epsilon, bag_size, greedify, first_only, all.get_random_state());
VW::LEARNER::learner<explore_type, multi_ex>& l = VW::LEARNER::init_learner(
data, base, explore_type::learn, explore_type::predict, problem_multiplier, prediction_type_t::action_probs);
l.set_finish_example(explore_type::finish_multiline_example);
return make_base(l);
}
} // namespace bag
} // namespace cb_explore_adf
} // namespace VW