Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unlearner to lsh #190

Merged
merged 5 commits into from
Nov 26, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 111 additions & 70 deletions jubatus/core/driver/recommender_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ using jubatus::util::text::json::json_object;
using jubatus::util::text::json::json_integer;
using jubatus::util::text::json::json_string;
using jubatus::util::text::json::json_float;
using jubatus::util::text::json::to_json;
using jubatus::util::lang::lexical_cast;
using jubatus::core::fv_converter::datum;
using jubatus::core::recommender::recommender_base;
using jubatus::core::storage::column_table;
using jubatus::core::unlearner::unlearner_base;
using jubatus::core::unlearner::lru_unlearner;
using jubatus::core::recommender::inverted_index;

namespace jubatus {
namespace core {
namespace driver {
Expand Down Expand Up @@ -165,87 +167,123 @@ INSTANTIATE_TEST_CASE_P(nn_recommender_test_instance,
nn_recommender_test,
testing::ValuesIn(create_recommender_bases()));

TEST(inverted_index_unlearner, lru_update) {
unlearner::lru_unlearner::config conf;
conf.max_size = 3;
shared_ptr<unlearner::unlearner_base> unl(new unlearner::lru_unlearner(conf));
shared_ptr<recommender_base> inv =
shared_ptr<recommender_base>(new core::recommender::inverted_index(unl));
shared_ptr<driver::recommender> recommender =
shared_ptr<driver::recommender>(
new driver::recommender(inv,
make_tf_idf_fv_converter()));
recommender->update_row("id1", create_datum_str("a", "a b c"));
recommender->update_row("id2", create_datum_str("a", "d e f"));
recommender->update_row("id3", create_datum_str("a", "e f g"));
recommender->update_row("id4", create_datum_str("a", "f g h"));
recommender->update_row("id5", create_datum_str("a", "h i j"));
recommender->update_row("id6", create_datum_str("a", "i j a"));
recommender->update_row("id7", create_datum_str("a", "j a b"));

vector<pair<string, float> > ret = recommender->similar_row_from_id("id6", 4);
ASSERT_EQ(3u, ret.size());
// Fixture for recommender driver tests that run with an unlearner configured.
// The test parameter is a (recommender method name, JSON config) pair; both
// values are passed to recommender_factory::create_recommender.
class recommender_with_unlearning_test
    : public ::testing::TestWithParam<pair<string,
                                           common::jsonconfig::config> > {
 protected:
  // Builds a driver around the recommender created from the current test
  // parameter, paired with the converter from make_tf_idf_fv_converter().
  shared_ptr<driver::recommender> create_driver() const {
    const string id("my_id");
    return shared_ptr<driver::recommender>(
        new driver::recommender(
            core::recommender::recommender_factory::create_recommender(
                GetParam().first, GetParam().second, id),
            make_tf_idf_fv_converter()));
  }

  // Creates a fresh driver before every test case.
  virtual void SetUp() {
    recommender_ = create_driver();
  }

  // Clears the model and releases the driver after every test case.
  virtual void TearDown() {
    // googletest runs TearDown() even when SetUp() failed or threw, in which
    // case recommender_ was never assigned; do not dereference an empty
    // pointer in that situation.
    if (recommender_) {
      recommender_->clear();
    }
    recommender_.reset();
  }

  // Driver under test; valid between SetUp() and TearDown().
  shared_ptr<driver::recommender> recommender_;
};

// Row capacity written into the unlearner's "max_size" parameter by
// create_recommender_configs_with_unlearner(); the tests below compare result
// sizes against it.
const size_t MAX_SIZE = 3;

// Returns the (method name, config) pairs used to instantiate the
// unlearner-enabled recommender tests. Every config enables the "lru"
// unlearner with max_size = MAX_SIZE and the sticky pattern "*_sticky".
vector<pair<string, common::jsonconfig::config> >
create_recommender_configs_with_unlearner() {
  typedef pair<string, common::jsonconfig::config> named_config;

  // Unlearner settings shared by every recommender method.
  json base(new json_object);
  base["unlearner"] = to_json(string("lru"));
  base["unlearner_parameter"] = new json_object;
  base["unlearner_parameter"]["max_size"] = to_json(MAX_SIZE);
  base["unlearner_parameter"]["sticky_pattern"] =
      to_json(string("*_sticky"));

  vector<named_config> result;

  // inverted_index: the shared settings are sufficient.
  result.push_back(
      named_config("inverted_index", common::jsonconfig::config(base)));

  // lsh: work on a clone so that adding hash_num leaves `base` untouched.
  json lsh_settings(base.clone());
  lsh_settings["hash_num"] = to_json(64);
  result.push_back(
      named_config("lsh", common::jsonconfig::config(lsh_settings)));

  // TODO(@rimms): Add NN-based algorithm

  return result;
}

// Inserts more rows than MAX_SIZE and expects a similarity query that asks
// for MAX_SIZE + 1 results to return exactly MAX_SIZE rows.
TEST_P(recommender_with_unlearning_test, update_row) {
  // {row id, text stored under key "a"}, in insertion order.
  static const char* const rows[][2] = {
    {"id1", "a b c"},
    {"id2", "d e f"},
    {"id3", "e f g"},
    {"id4", "f g h"},
    {"id5", "h i j"},
    {"id6", "i j a"},
    {"id7", "j a b"},
  };
  const size_t row_count = sizeof(rows) / sizeof(rows[0]);

  for (size_t i = 0; i < row_count; ++i) {
    recommender_->update_row(rows[i][0], create_datum_str("a", rows[i][1]));
  }

  const vector<pair<string, float> > similar =
      recommender_->similar_row_from_id("id6", MAX_SIZE + 1);
  ASSERT_EQ(MAX_SIZE, similar.size());
}

TEST(inverted_index_unlearner, lru_delete) {
unlearner::lru_unlearner::config conf;
conf.max_size = 3;
shared_ptr<unlearner::unlearner_base> unl(new unlearner::lru_unlearner(conf));
shared_ptr<recommender_base> inv =
shared_ptr<recommender_base>(new core::recommender::inverted_index(unl));
shared_ptr<driver::recommender> recommender =
shared_ptr<driver::recommender>(
new driver::recommender(inv,
make_tf_idf_fv_converter()));
recommender->update_row("id1", create_datum_str("a", "a b c"));
recommender->update_row("id2", create_datum_str("a", "d e f"));
recommender->update_row("id3", create_datum_str("a", "e f g"));
recommender->clear_row("id1");
recommender->update_row("id4", create_datum_str("a", "f g h"));
recommender->update_row("id5", create_datum_str("a", "h i j"));
recommender->update_row("id6", create_datum_str("a", "i j a"));
recommender->update_row("id7", create_datum_str("a", "j a b"));

vector<pair<string, float> > ret = recommender->similar_row_from_id("id6", 4);
ASSERT_EQ(3u, ret.size());

vector<string> all = recommender->get_all_rows();
ASSERT_EQ(3u, all.size());
// Same scenario as update_row, except that "id1" is explicitly cleared after
// the first three inserts. Expects both the similarity result and the full
// row list to contain exactly MAX_SIZE rows afterwards.
TEST_P(recommender_with_unlearning_test, clear_row) {
  // {row id, text stored under key "a"}, in insertion order.
  static const char* const rows[][2] = {
    {"id1", "a b c"},
    {"id2", "d e f"},
    {"id3", "e f g"},
    {"id4", "f g h"},
    {"id5", "h i j"},
    {"id6", "i j a"},
    {"id7", "j a b"},
  };
  const size_t row_count = sizeof(rows) / sizeof(rows[0]);
  // Number of rows inserted before "id1" is cleared.
  const size_t rows_before_clear = 3;

  for (size_t i = 0; i < row_count; ++i) {
    if (i == rows_before_clear) {
      recommender_->clear_row("id1");
    }
    recommender_->update_row(rows[i][0], create_datum_str("a", rows[i][1]));
  }

  const vector<pair<string, float> > similar =
      recommender_->similar_row_from_id("id6", MAX_SIZE + 1);
  ASSERT_EQ(MAX_SIZE, similar.size());

  const vector<string> remaining = recommender_->get_all_rows();
  ASSERT_EQ(MAX_SIZE, remaining.size());
}

class inverted_index_mix_test : public ::testing::Test {
// Instantiates the recommender_with_unlearning_test cases once for each
// (method, config) pair returned by create_recommender_configs_with_unlearner().
INSTANTIATE_TEST_CASE_P(
recommender_with_unlearning_test_instance,
recommender_with_unlearning_test,
testing::ValuesIn(create_recommender_configs_with_unlearner()));

class recommender_mix_with_unlearning_test
: public ::testing::TestWithParam<pair<string,
common::jsonconfig::config> > {
protected:
shared_ptr<driver::recommender> create_driver(const string& id) const {
return shared_ptr<driver::recommender>(
new driver::recommender(
core::recommender::recommender_factory::create_recommender(
GetParam().first, GetParam().second, id),
make_fv_converter()));
}

virtual void SetUp() {
lru_unlearner::config conf;
conf.max_size = 3;
conf.sticky_pattern = "*_sticky";
unl1 = shared_ptr<unlearner_base>(new lru_unlearner(conf));
inv1 = shared_ptr<recommender_base>(new inverted_index(unl1));
recommender1 =
shared_ptr<driver::recommender>(
new driver::recommender(inv1,
make_fv_converter()));
recommender1 = create_driver("my_id1");
recommender2 = create_driver("my_id2");

mixable1 =
dynamic_cast<framework::linear_mixable*>(recommender1->get_mixable());
ASSERT_TRUE(mixable1 != NULL);

unl2 = shared_ptr<unlearner_base>(new lru_unlearner(conf));
inv2 = shared_ptr<recommender_base>(new inverted_index(unl2));
recommender2 =
shared_ptr<driver::recommender>(
new driver::recommender(inv2,
make_fv_converter()));
mixable2 =
dynamic_cast<framework::linear_mixable*>(recommender2->get_mixable());
ASSERT_TRUE(mixable2 != NULL);
}

virtual void TearDown() {
unl1.reset();
inv1.reset();
recommender1->clear();
recommender1.reset();
unl2.reset();
inv2.reset();
recommender2->clear();
recommender2.reset();
}

Expand Down Expand Up @@ -274,13 +312,11 @@ class inverted_index_mix_test : public ::testing::Test {
mixable2->mix(unpacked1.get(), diff);
return diff;
}
shared_ptr<unlearner::unlearner_base> unl1, unl2;
shared_ptr<recommender_base> inv1, inv2;
shared_ptr<driver::recommender> recommender1, recommender2;
framework::linear_mixable *mixable1, *mixable2;
};

TEST_F(inverted_index_mix_test, basic) {
TEST_P(recommender_mix_with_unlearning_test, basic) {
recommender1->update_row("id1", create_datum_str("a", "a b c"));
recommender1->update_row("id2", create_datum_str("a", "d e f"));
recommender1->update_row("id3", create_datum_str("a", "e f g"));
Expand All @@ -298,7 +334,7 @@ TEST_F(inverted_index_mix_test, basic) {
ASSERT_EQ(3u, recommender2->get_all_rows().size());
}

TEST_F(inverted_index_mix_test, mix_all) {
TEST_P(recommender_mix_with_unlearning_test, mix_all) {
recommender1->update_row("id1", create_datum_str("a", "a b c"));
recommender1->update_row("id2", create_datum_str("a", "d e f"));
recommender1->update_row("id3", create_datum_str("a", "e f g"));
Expand Down Expand Up @@ -333,7 +369,7 @@ TEST_F(inverted_index_mix_test, mix_all) {
}
}

TEST_F(inverted_index_mix_test, all_sticky) {
TEST_P(recommender_mix_with_unlearning_test, all_sticky) {
recommender1->update_row("id1_sticky", create_datum_str("a", "a b c"));
recommender1->update_row("id2_sticky", create_datum_str("a", "d e f"));
recommender1->update_row("id3_sticky", create_datum_str("a", "e f g"));
Expand Down Expand Up @@ -381,6 +417,11 @@ TEST_F(inverted_index_mix_test, all_sticky) {

// TODO(kumagi): append test if there are all sticky rows

// Instantiates the recommender_mix_with_unlearning_test cases once for each
// (method, config) pair returned by create_recommender_configs_with_unlearner().
INSTANTIATE_TEST_CASE_P(
recommender_mix_with_lru_unlearning_test_instance,
recommender_mix_with_unlearning_test,
testing::ValuesIn(create_recommender_configs_with_unlearner()));

} // namespace driver
} // namespace core
} // namespace jubatus
40 changes: 40 additions & 0 deletions jubatus/core/recommender/lsh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
#include <string>
#include <utility>
#include <vector>
#include "jubatus/util/lang/function.h"
#include "jubatus/util/lang/bind.h"
#include "../common/exception.hpp"
#include "../common/hash.hpp"
#include "lsh_util.hpp"
#include "../storage/bit_index_storage.hpp"
#include "../unlearner/unlearner_factory.hpp"

using std::pair;
using std::string;
Expand Down Expand Up @@ -59,6 +62,25 @@ lsh::lsh(const config& config)
}

initialize_model();

if (config.unlearner) {
if (!config.unlearner_parameter) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"unlearner is set but unlearner_parameter is not found"));
}
unlearner_ = core::unlearner::create_unlearner(*config.unlearner,
core::common::jsonconfig::config(*config.unlearner_parameter));
mixable_storage_->get_model()->set_unlearner(unlearner_);
unlearner_->set_callback(
util::lang::bind(&lsh::remove_row, this, util::lang::_1));
} else {
if (config.unlearner_parameter) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"unlearner_parameter is set but unlearner is not found"));
}
}
}

lsh::lsh()
Expand Down Expand Up @@ -98,9 +120,19 @@ void lsh::clear() {
jubatus::util::data::unordered_map<std::string, std::vector<float> >()
.swap(column2baseval_);
mixable_storage_->get_model()->clear();
if (unlearner_) {
unlearner_->clear();
}
}

// Removes the row `id` from this recommender and, when an unlearner is
// configured, also removes `id` from the unlearner.
void lsh::clear_row(const string& id) {
  remove_row(id);
  if (!unlearner_) {
    return;  // no unlearner configured; nothing else to update
  }
  unlearner_->remove(id);
}

// Removes the row `id` from the original-row store (orig_) and from the
// bit-index model. Does not touch the unlearner itself: the constructor
// registers this function as the unlearner's callback, and clear_row() calls
// it before removing `id` from the unlearner.
void lsh::remove_row(const string& id) {
orig_.remove_row(id);
mixable_storage_->get_model()->remove_row(id);
}
Expand Down Expand Up @@ -128,13 +160,21 @@ void lsh::generate_column_base(const string& column) {
}

void lsh::update_row(const string& id, const sfv_diff_t& diff) {
if (unlearner_ && !unlearner_->can_touch(id)) {
throw JUBATUS_EXCEPTION(common::exception::runtime_error(
"cannot add new row as number of sticky rows reached "
"the maximum size of unlearner: " + id));
}
generate_column_bases(diff);
orig_.set_row(id, diff);
common::sfv_t row;
orig_.get_row(id, row);
bit_vector bv;
calc_lsh_values(row, bv);
mixable_storage_->get_model()->set_row(id, bv);
if (unlearner_) {
unlearner_->touch(id);
}
}

void lsh::get_all_row_ids(std::vector<std::string>& ids) const {
Expand Down
13 changes: 12 additions & 1 deletion jubatus/core/recommender/lsh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
#include <utility>
#include <vector>
#include "jubatus/util/data/serialization.h"
#include "jubatus/util/data/optional.h"
#include "jubatus/util/lang/shared_ptr.h"
#include "recommender_base.hpp"
#include "../unlearner/unlearner_base.hpp"
#include "../common/jsonconfig.hpp"

namespace jubatus {
namespace core {
Expand All @@ -44,9 +47,13 @@ class lsh : public recommender_base {

int64_t hash_num;

util::data::optional<std::string> unlearner;
util::data::optional<core::common::jsonconfig::config> unlearner_parameter;

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_MEMBER(hash_num);
ar & JUBA_MEMBER(hash_num) &
JUBA_MEMBER(unlearner) & JUBA_MEMBER(unlearner_parameter);
}
};

Expand All @@ -65,6 +72,7 @@ class lsh : public recommender_base {
size_t ret_num) const;
void clear();
void clear_row(const std::string& id);
void remove_row(const std::string& id);
void update_row(const std::string& id, const sfv_diff_t& diff);
void get_all_row_ids(std::vector<std::string>& ids) const;
std::string type() const;
Expand All @@ -88,6 +96,9 @@ class lsh : public recommender_base {
jubatus::util::lang::shared_ptr<storage::mixable_bit_index_storage>
mixable_storage_;

jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
unlearner_;

const uint64_t hash_num_;
};

Expand Down
Loading