
adding ldf

1 parent 0cbfe00 commit f8122b759c0c41c95d2795c62872883fb8960471 by @hal3, committed Jan 27, 2012
Showing with 350 additions and 37 deletions.
  1. +165 −4 vowpalwabbit/csoaa.cc
  2. +29 −0 vowpalwabbit/csoaa.h
  3. +4 −0 vowpalwabbit/parse_args.cc
  4. +9 −33 vowpalwabbit/sequence.cc
  5. +131 −0 vowpalwabbit/wap.cc
  6. +12 −0 vowpalwabbit/wap.h
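
This commit adds a label-dependent-features (LDF) variant of cost-sensitive one-against-all: instead of scoring one shared feature vector against a fixed set of classes, each candidate class supplies its own example with its own features, and a blank (newline) example closes the candidate block. Judging from the new code below, every candidate line is read with the OAA multiclass label parser and its cost rides in the label's weight field, so a training block might plausibly look like this (the format is an inference from the parser and driver, not something documented in the commit):

  1 0.0 | features_of_candidate_one
  2 1.0 | features_of_candidate_two
  3 2.0 | features_of_candidate_three

with an empty line terminating the block.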
vowpalwabbit/csoaa.cc
@@ -124,13 +124,12 @@ void parse_label(void* v, v_array<substring>& words)
}
}
-void print_update(example *ec)
+void print_update(bool is_test, example *ec)
{
if (global.sd->weighted_examples > global.sd->dump_interval && !global.quiet && !global.bfgs)
{
- label* ld = (label*) ec->ld;
char label_buf[32];
- if (is_test_label(ld))
+ if (is_test)
strcpy(label_buf," unknown");
else
sprintf(label_buf," known");
@@ -185,7 +184,7 @@ void output_example(example* ec)
global.sd->example_number++;
- print_update(ec);
+ print_update(is_test_label((label*)ec->ld), ec);
}
@@ -282,3 +281,165 @@ void parse_flag(size_t s)
}
}
+
+namespace CSOAA_LDF {
+
+ v_array<example*> ec_seq = v_array<example*>();
+ size_t read_example_this_loop = 0;
+
+ void do_actual_learning()
+ {
+ if (ec_seq.index() <= 0) return; // nothing to do
+
+ int K = ec_seq.index();
+ float min_cost = FLT_MAX;
+ v_array<float> predictions = v_array<float>();
+ float min_score = FLT_MAX;
+ size_t prediction = 0;
+ float prediction_cost = 0.;
+ bool isTest = example_is_test(*ec_seq.begin);
+ for (int k=0; k<K; k++) {
+ example *ec = ec_seq.begin[k];
+ label *ld = (label*)ec->ld;
+
+ label_data simple_label;
+ simple_label.initial = 0.;
+ simple_label.label = FLT_MAX;
+ simple_label.weight = 0.;
+
+ if (ld->weight < min_cost)
+ min_cost = ld->weight;
+ if (example_is_test(ec) != isTest) {
+ isTest = true;
+ cerr << "warning: got mix of train/test data; assuming test" << endl;
+ }
+
+ ec->ld = &simple_label;
+ global.learn(ec); // make a prediction
+ push(predictions, ec->partial_prediction);
+ if (ec->partial_prediction < min_score) {
+ min_score = ec->partial_prediction;
+ prediction = ld->label;
+ prediction_cost = ld->weight;
+ }
+
+ ec->ld = ld;
+ }
+ prediction_cost -= min_cost;
+ // do actual learning
+ for (int k=0; k<K; k++) {
+ example *ec = ec_seq.begin[k];
+ label *ld = (label*)ec->ld;
+
+ // learn
+ label_data simple_label;
+ simple_label.initial = 0.;
+ simple_label.label = ld->weight;
+ simple_label.weight = 1.;
+ ec->ld = &simple_label;
+ ec->partial_prediction = 0.;
+ global.learn(ec);
+
+ // fill in test predictions
+ *(OAA::prediction_t*)&(ec->final_prediction) = (prediction == ld->label) ? 1 : 0;
+ ec->partial_prediction = predictions.begin[k];
+
+ // restore label
+ ec->ld = ld;
+ }
+ }
+
+ void output_example(example* ec)
+ {
+ label* ld = (label*)ec->ld;
+ global.sd->weighted_examples += 1.;
+ global.sd->total_features += ec->num_features;
+ float loss = 0.;
+ if (!example_is_test(ec) && (ec->final_prediction == 1))
+ loss = ld->weight;
+ global.sd->sum_loss += loss;
+ global.sd->sum_loss_since_last_dump += loss;
+
+ for (size_t i = 0; i<global.final_prediction_sink.index(); i++) {
+ int f = global.final_prediction_sink[i];
+ global.print(f, *(OAA::prediction_t*)&ec->final_prediction, 0, ec->tag);
+ }
+
+ global.sd->example_number++;
+
+ CSOAA::print_update(example_is_test(ec), ec);
+ }
+
+ void clear_seq(bool output)
+ {
+ if (ec_seq.index() > 0)
+ for (example** ecc=ec_seq.begin; ecc!=ec_seq.end; ecc++) {
+ if (output)
+ output_example(*ecc);
+ free_example(*ecc);
+ }
+ ec_seq.erase();
+ }
+
+ void learn(example *ec) {
+ // TODO: break long examples
+ if (example_is_newline(ec)) {
+ do_actual_learning();
+ clear_seq(true);
+ global_print_newline();
+ } else {
+ push(ec_seq, ec);
+ }
+ }
+
+ void initialize()
+ {
+ global.initialize();
+ }
+
+ void finalize()
+ {
+ clear_seq(true);
+ if (ec_seq.begin != NULL)
+ free(ec_seq.begin);
+ global.finish();
+ }
+
+ void drive_csoaa_ldf()
+ {
+ example* ec = NULL;
+ initialize();
+ read_example_this_loop = 0;
+ while (true) {
+ if ((ec = get_example()) != NULL) { // semiblocking operation
+ learn(ec);
+ } else if (parser_done()) {
+ do_actual_learning();
+ finalize();
+ return;
+ }
+ }
+ }
+
+ void parse_flag(size_t s)
+ {
+ *(global.lp) = OAA::mc_label_parser;
+ global.driver = drive_csoaa_ldf;
+ global.cs_initialize = initialize;
+ global.cs_learn = learn;
+ global.cs_finish = finalize;
+ }
+
+ void global_print_newline()
+ {
+ char temp[1];
+ temp[0] = '\n';
+ for (size_t i=0; i<global.final_prediction_sink.index(); i++) {
+ int f = global.final_prediction_sink[i];
+ ssize_t t = write(f, temp, 1);
+ if (t != 1)
+ std::cerr << "write error" << std::endl;
+ }
+ }
+
+}
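
In outline, CSOAA_LDF::do_actual_learning above makes two passes over the buffered block: a prediction pass that scores every candidate under a zero-weight simple label (so the call should act as a pure prediction) and keeps the argmin, and a learning pass that regresses each candidate's features onto its cost. The following is a minimal, self-contained sketch of that reduction; Candidate, Regressor, score and update are hypothetical stand-ins for VW's example and global.learn machinery, not VW API.

#include <cfloat>
#include <cstddef>
#include <vector>

struct Candidate {
  std::vector<float> features;   // label-dependent features for this class
  size_t label;                  // class id
  float cost;                    // cost of choosing this class
};

struct Regressor {
  std::vector<float> w;
  float lr;
  Regressor(size_t dim, float rate) : w(dim, 0.f), lr(rate) {}
  float score(const std::vector<float>& x) const {
    float s = 0.f;
    for (size_t i = 0; i < x.size() && i < w.size(); i++) s += w[i] * x[i];
    return s;
  }
  void update(const std::vector<float>& x, float target) {
    float err = target - score(x);   // squared-loss gradient step
    for (size_t i = 0; i < x.size() && i < w.size(); i++) w[i] += lr * err * x[i];
  }
};

// One LDF block: pass 1 predicts (argmin of the scores), pass 2 regresses
// each candidate's features onto its cost, mirroring the two loops above.
size_t learn_ldf_block(Regressor& r, const std::vector<Candidate>& block) {
  size_t prediction = 0;
  float min_score = FLT_MAX;
  for (const Candidate& c : block) {
    float s = r.score(c.features);
    if (s < min_score) { min_score = s; prediction = c.label; }
  }
  for (const Candidate& c : block)
    r.update(c.features, c.cost);
  return prediction;
}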
vowpalwabbit/csoaa.h
@@ -5,6 +5,8 @@
#include "parse_primitives.h"
#include "global_data.h"
#include "example.h"
+#include "oaa.h"
+#include "parser.h"
namespace CSOAA {
@@ -27,4 +29,31 @@ namespace CSOAA {
delete_label, weight, initial,
sizeof(label)};
}
+
+namespace CSOAA_LDF {
+ typedef OAA::mc_label label;
+
+ inline int example_is_newline(example* ec)
+ {
+ // if only index is constant namespace or no index
+ return ((ec->indices.index() == 0) ||
+ ((ec->indices.index() == 1) &&
+ (ec->indices.last() == constant_namespace)));
+ }
+
+ inline int example_is_test(example* ec)
+ {
+ return (((OAA::mc_label*)ec->ld)->label == (uint32_t)-1);
+ }
+
+ void parse_flag(size_t s);
+ void global_print_newline();
+ void output_example(example* ec);
+
+ const label_parser cs_label_parser = {OAA::default_label, OAA::parse_label,
+ OAA::cache_label, OAA::read_cached_label,
+ OAA::delete_label, OAA::weight, OAA::initial,
+ sizeof(label)};
+}
+
#endif
vowpalwabbit/parse_args.cc
@@ -72,6 +72,7 @@ po::variables_map parse_args(int argc, char *argv[],
("conjugate_gradient", "use conjugate gradient based optimization")
("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs")
("wap", po::value<size_t>(), "Use weighted all-pairs multiclass learning with <k> costs")
+ ("csoaa_ldf", "Use one-against-all multiclass learning with label dependent features")
("nonormalize", "Do not normalize online updates")
("l1", po::value<float>(&global.l1_lambda)->default_value(0.0), "l_1 lambda")
("l2", po::value<float>(&global.l2_lambda)->default_value(0.0), "l_2 lambda")
@@ -558,6 +559,9 @@ po::variables_map parse_args(int argc, char *argv[],
if(vm.count("csoaa"))
CSOAA::parse_flag(vm["csoaa"].as<size_t>());
+ if(vm.count("csoaa_ldf"))
+ CSOAA_LDF::parse_flag(0);
+
if (vm.count("sequence")) {
if (vm.count("wap")) {
// do nothing, WAP is already initialized
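
With the flag registered here, the new mode would be enabled from the command line roughly as follows (a hedged example: only --csoaa_ldf is added by this commit, -d is VW's existing data option, and train.ldf is a hypothetical file in the block format sketched near the top of this page):

  vw --csoaa_ldf -d train.ldf

Unlike --csoaa <k> and --wap <k>, the LDF flag takes no class count; the number of candidates is whatever each newline-terminated block supplies.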
vowpalwabbit/sequence.cc
@@ -301,18 +301,6 @@ void global_print_label(example *ec, size_t label)
}
}
-void global_print_newline()
-{
- char temp[1];
- temp[0] = '\n';
- for (size_t i=0; i<global.final_prediction_sink.index(); i++) {
- int f = global.final_prediction_sink[i];
- ssize_t t = write(f, temp, 1);
- if (t != 1)
- cerr << "write error" << endl;
- }
-}
-
void print_history(history h)
{
@@ -529,18 +517,6 @@ inline void clear_history(history h)
*** EXAMPLE MANIPULATION
********************************************************************************************/
-inline int example_is_newline(example* ec)
-{
- // if only index is constant namespace or no index
- return ((ec->indices.index() == 0) ||
- ((ec->indices.index() == 1) &&
- (ec->indices.last() == constant_namespace)));
-}
-
-inline int example_is_test(example* ec)
-{
- return (((OAA::mc_label*)ec->ld)->label == (uint32_t)-1);
-}
string audit_feature_space("history");
@@ -881,15 +857,15 @@ void run_test(example* ec)
clear_history(current_history);
- while ((ec != NULL) && (! example_is_newline(ec))) {
+ while ((ec != NULL) && (! CSOAA_LDF::example_is_newline(ec))) {
int policy = random_policy(0);
old_label = (OAA::mc_label*)ec->ld;
seq_num_features += ec->num_features;
global.sd->weighted_examples += old_label->weight;
global.sd->total_features += ec->num_features;
- if (! example_is_test(ec)) {
+ if (! CSOAA_LDF::example_is_test(ec)) {
if (!warned) {
cerr << "warning: mix of train and test data in sequence prediction at " << ec->example_counter << "; assuming all test" << endl;
warned = 1;
@@ -908,7 +884,7 @@ void run_test(example* ec)
}
if (ec != NULL) {
free_example(ec);
- global_print_newline();
+ CSOAA_LDF::global_print_newline();
}
global.sd->example_number++;
@@ -925,24 +901,24 @@ void process_next_example_sequence()
return;
// skip initial newlines
- while (example_is_newline(cur_ec)) {
- global_print_newline();
+ while (CSOAA_LDF::example_is_newline(cur_ec)) {
+ CSOAA_LDF::global_print_newline();
free_example(cur_ec);
cur_ec = safe_get_example(1);
if (cur_ec == NULL)
return;
}
- if (example_is_test(cur_ec)) {
+ if (CSOAA_LDF::example_is_test(cur_ec)) {
run_test(cur_ec);
return;
}
// we know we're training
size_t n = 0;
int skip_this_one = 0;
- while ((cur_ec != NULL) && (! example_is_newline(cur_ec))) {
- if (example_is_test(cur_ec) && !skip_this_one) {
+ while ((cur_ec != NULL) && (! CSOAA_LDF::example_is_newline(cur_ec))) {
+ if (CSOAA_LDF::example_is_test(cur_ec) && !skip_this_one) {
cerr << "warning: mix of train and test data in sequence prediction at " << cur_ec->example_counter << "; skipping" << endl;
skip_this_one = 1;
}
@@ -985,7 +961,7 @@ void process_next_example_sequence()
if (random_policy(1) == -1)
policy_seq[t] = -1;
}
- global_print_newline();
+ CSOAA_LDF::global_print_newline();
global.sd->example_number++;
print_update(1, seq_num_features);
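
The sequence.cc changes above introduce no new logic: the local example_is_newline, example_is_test and global_print_newline are deleted, and the remaining sequence code calls the copies that now live in the CSOAA_LDF namespace (declared in csoaa.h), so the sequence driver and the new LDF driver share the same block-delimiting and test-detection helpers.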
(diffs for vowpalwabbit/wap.cc and vowpalwabbit/wap.h are not shown)
