forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 3
/
search_multiclasstask.cc
55 lines (48 loc) · 1.88 KB
/
search_multiclasstask.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/*
CoPyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
#include "search_multiclasstask.h"
namespace MulticlassTask { Search::search_task task = { "multiclasstask", run, initialize, finish, nullptr, nullptr }; }
namespace MulticlassTask {
struct task_data {
size_t max_label;
size_t num_level;
v_array<uint32_t> y_allowed;
};
void initialize(Search::search& sch, size_t& num_actions, po::variables_map& /*vm*/) {
task_data * my_task_data = new task_data();
sch.set_options( 0 );
sch.set_num_learners(num_actions);
my_task_data->max_label = num_actions;
my_task_data->num_level = (size_t)ceil(log(num_actions) /log(2));
my_task_data->y_allowed.push_back(1);
my_task_data->y_allowed.push_back(2);
sch.set_task_data(my_task_data);
}
void finish(Search::search& sch) {
task_data * my_task_data = sch.get_task_data<task_data>();
my_task_data->y_allowed.delete_v();
delete my_task_data;
}
void run(Search::search& sch, vector<example*>& ec) {
task_data * my_task_data = sch.get_task_data<task_data>();
size_t gold_label = ec[0]->l.multi.label;
size_t label = 0;
size_t learner_id = 0;
for(size_t i=0; i<my_task_data->num_level;i++){
size_t mask = 1<<(my_task_data->num_level-i-1);
size_t y_allowed_size = (label+mask +1 <= my_task_data->max_label)?2:1;
action oracle = (((gold_label-1)&mask)>0)+1;
size_t prediction = sch.predict(*ec[0], 0, &oracle, 1, nullptr, nullptr, my_task_data->y_allowed.begin, y_allowed_size, learner_id); // TODO: do we really need y_allowed?
learner_id= (learner_id << 1) + prediction;
if(prediction == 2)
label += mask;
}
label+=1;
sch.loss(!(label == gold_label));
if (sch.output().good())
sch.output() << label << ' ';
}
}