Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge branch 'release/0.2.2'

  • Loading branch information...
commit e47de9701ee8d7f5a7f71c6b61ccdb98d7c40cb7 2 parents 0a420b9 + da21b73
UENISHI Kota kuenishi authored
Showing with 3,449 additions and 2,278 deletions.
  1. +1 −0  .gitignore
  2. +72 −0 README.rst
  3. +2 −1  src/classifier/classifier_factory.cpp
  4. +10 −0 src/classifier/classifier_test.cpp
  5. +0 −271 src/classifier/hs_classifier.cpp
  6. +11 −6 src/common/exception.hpp
  7. +136 −0 src/common/mprpc/async_client.cpp
  8. +97 −0 src/common/mprpc/async_client.hpp
  9. +194 −0 src/common/mprpc/rpc_client.cpp
  10. +169 −0 src/common/mprpc/rpc_client.hpp
  11. +118 −0 src/common/mprpc/rpc_client_test.cpp
  12. +23 −0 src/common/mprpc/wscript
  13. +53 −0 src/common/unordered_map.hpp
  14. +6 −4 src/common/wscript
  15. +61 −0 src/framework/aggregators.hpp
  16. +27 −14 src/framework/jubatus_serv.cpp
  17. +6 −11 src/framework/jubatus_serv.hpp
  18. +0 −2  src/framework/keeper.cpp
  19. +169 −59 src/framework/keeper.hpp
  20. +20 −0 src/framework/server_util.cpp
  21. +3 −2 src/framework/wscript
  22. +20 −1 src/fv_converter/counter.hpp
  23. +15 −0 src/fv_converter/counter_test.cpp
  24. +31 −1 src/fv_converter/datum_to_fv_converter_test.cpp
  25. +74 −0 src/fv_converter/keyword_weights.cpp
  26. +72 −0 src/fv_converter/keyword_weights.hpp
  27. +57 −0 src/fv_converter/keyword_weights_test.cpp
  28. +11 −18 src/fv_converter/weight_manager.cpp
  29. +47 −5 src/fv_converter/weight_manager.hpp
  30. +46 −13 src/fv_converter/weight_manager_test.cpp
  31. +3 −1 src/fv_converter/wscript
  32. +0 −158 src/recommender/recommender_builder.cpp
  33. +0 −66 src/recommender/recommender_builder.hpp
  34. +1 −1  src/recommender/recommender_factory.cpp
  35. +2 −1  src/regression/regression_factory.cpp
  36. +2 −1  src/regression/regression_factory_test.cpp
  37. +0 −91 src/server/classifier.hpp
  38. +42 −12 src/server/classifier.idl
  39. +14 −14 src/server/classifier_client.hpp
  40. +26 −32 src/server/classifier_impl.cpp
  41. +11 −14 src/server/classifier_keeper.cpp
  42. +29 −9 src/server/classifier_serv.cpp
  43. +7 −3 src/server/classifier_serv.hpp
  44. +5 −5 src/server/classifier_server.hpp
  45. +54 −15 src/server/classifier_test.cpp
  46. +6 −6 src/server/classifier_types.hpp
  47. +57 −0 src/server/mixable_weight_manager.hpp
  48. +0 −102 src/server/recommender.hpp
  49. +69 −19 src/server/recommender.idl
  50. +36 −28 src/server/recommender_client.hpp
  51. +46 −46 src/server/recommender_impl.cpp
  52. +21 −21 src/server/recommender_keeper.cpp
  53. +56 −12 src/server/recommender_serv.cpp
  54. +12 −8 src/server/recommender_serv.hpp
  55. +15 −13 src/server/recommender_server.hpp
  56. +3 −3 src/server/recommender_test.cpp
  57. +6 −6 src/server/recommender_types.hpp
  58. +0 −78 src/server/regression.hpp
  59. +40 −10 src/server/regression.idl
  60. +10 −10 src/server/regression_client.hpp
  61. +26 −32 src/server/regression_impl.cpp
  62. +12 −14 src/server/regression_keeper.cpp
  63. +14 −8 src/server/regression_serv.cpp
  64. +3 −2 src/server/regression_serv.hpp
  65. +5 −5 src/server/regression_server.hpp
  66. +2 −2 src/server/regression_test.cpp
  67. +6 −6 src/server/regression_types.hpp
  68. +0 −25 src/server/stat.hpp
  69. +49 −13 src/server/stat.idl
  70. +26 −22 src/server/stat_client.hpp
  71. +37 −40 src/server/stat_impl.cpp
  72. +17 −18 src/server/stat_keeper.cpp
  73. +11 −8 src/server/stat_serv.cpp
  74. +8 −7 src/server/stat_serv.hpp
  75. +12 −11 src/server/stat_server.hpp
  76. +7 −7 src/server/stat_test.cpp
  77. +4 −3 src/server/wscript
  78. +79 −0 tools/generate_clients.py
  79. +4 −6 tools/generator/.gitignore
  80. +17 −12 tools/generator/OMakefile
  81. +41 −14 tools/generator/README
  82. +294 −74 tools/generator/generator.ml
  83. +0 −108 tools/generator/idl_template.ml
  84. +71 −0 tools/generator/jdl_lexer.mll
  85. +193 −0 tools/generator/jdl_parser.mly
  86. +84 −0 tools/generator/jubatus_idl.ml
  87. +0 −91 tools/generator/keeper_template.ml
  88. +0 −71 tools/generator/lexer.mll
  89. +0 −133 tools/generator/main.ml
  90. +0 −145 tools/generator/parser.mly
  91. +0 −94 tools/generator/server_template.ml
  92. +59 −0 tools/generator/small.idl
  93. +28 −110 tools/generator/stree.ml
  94. +82 −1 tools/generator/util.ml
  95. +127 −0 tools/generator/validator.ml
  96. +7 −3 wscript
1  .gitignore
View
@@ -12,3 +12,4 @@ Makefile
cscope.*
callgrind.*
.unittest-gtest
+*.tar.gz
72 README.rst
View
@@ -9,3 +9,75 @@ LICENSE
=======
LGPL 2.1
+
+Update history
+==============
+
+Release 0.2.2 2012/4/6
+======================
+
+Improvements
+
+- Simpler interfaces at classifier, regression and recommender
+
+ - Clients are *NOT COMPATIBLE* with previous releases
+
+- Now mix works concurrently in multiple threads (except tf-idf counting)
+- Asynchronous RPC to multiple servers at once
+- Add --version option
+- Interface description language changed from C++-like to Annotated MessagePack-IDL
+- Minor error handling
+- A bit more tested than previous releases
+
+Bugfix
+
+ - #30, #29, #22
+
+Release 0.2.1 2012/3/13
+-----------------------
+
+Bugfix release: #28
+
+Release 0.2.0 2012/2/16
+-----------------------
+
+New Features
+
+- recommender
+
+ - support fast similar item search, real-time update, distributed data management
+ - inverted index : exact result, fast search
+ - locality sensitive hash : approximate result, fast search, small working space
+
+- regression
+
+ - online SVR using passive agressive algorithm
+ - as fast as current classifier
+
+- stat
+
+ - a Key(string)-Value(queue<double>)
+ - O(1) cost of getting sum, standard deviation, max, min, statistic moments for each queue
+
+- server framework
+
+ - less-tightly coupled distributed processing framework with each ML implementation
+ - idl & code generator - make it easy to write own jubatus system
+ - removed public release of client libraries (so easy to generate!)
+ - multiple mix - mutiple data objects can be mixed in one jubatus system
+
+Bugfix
+
+ - duplicate key entry in fv_converter breaks the parameter
+
+Release 0.1.1 2011/11/15
+------------------------
+
+Bugfix release
+
+Release 0.1.0 2011/10/26
+------------------------
+
+Hello Jubatus!
+
+First release: including classifier, and mix operation
3  src/classifier/classifier_factory.cpp
View
@@ -17,6 +17,7 @@
#include "classifier.hpp"
#include "classifier_factory.hpp"
+#include "../common/exception.hpp"
using namespace std;
@@ -38,7 +39,7 @@ classifier_base* classifier_factory::create_classifier(const std::string& name,
} else if (name == "NHERD"){
return static_cast<classifier_base*>(new NHERD(storage));
} else {
- return NULL;
+ throw unsupported_method(name);
}
}
10 src/classifier/classifier_test.cpp
View
@@ -24,6 +24,7 @@
#include "classifier_factory.hpp"
#include "classifier.hpp"
#include "../storage/local_storage.hpp"
+#include "../common/exception.hpp"
#include "classifier_test_util.hpp"
using namespace std;
@@ -139,4 +140,13 @@ void InitClassifiers(vector<classifier_base*>& classifiers){
}
}
+
+TEST(classifier_factory, exception){
+ local_storage * p = new local_storage;
+ ASSERT_THROW(classifier_factory::create_classifier("pa", p), unsupported_method);
+ ASSERT_THROW(classifier_factory::create_classifier("", p), unsupported_method);
+ ASSERT_THROW(classifier_factory::create_classifier("saitama", p), unsupported_method);
+ delete p;
+}
+
}
271 src/classifier/hs_classifier.cpp
View
@@ -1,271 +0,0 @@
-// Jubatus: Online machine learning framework for distributed environment
-// Copyright (C) 2011 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
-//
-// This library is free software; you can redistribute it and/or
-// modify it under the terms of the GNU Lesser General Public
-// License as published by the Free Software Foundation; either
-// version 2.1 of the License, or (at your option) any later version.
-//
-// This library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-// Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public
-// License along with this library; if not, write to the Free Software
-// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
-
-#include <fstream>
-#include "hs_classifier.hpp"
-
-using namespace std;
-
-namespace hs{
-
-HSClassifier::HSClassifier(){
-}
-
-HSClassifier::~HSClassifier(){
-}
-
-void HSClassifier::Clear(){
- feature2id_.clear();
- class2id_.clear();
- vector<string>().swap(id2class_);
- vector<vector<float> >().swap(weights_);
-}
-
-void HSClassifier::SaveKey2ID(const key2id_t& key2id,
- ofstream& ofs) const{
- uint64_t key_num = key2id.size();
- ofs.write((const char*)& key_num, sizeof(key_num) * 1);
-
- for (key2id_t::const_iterator it = key2id.begin();
- it != key2id.end(); ++it){
- const string& key = it->first;
- uint64_t len = key.size();
- ofs.write((const char*) &len, sizeof(len) * 1);
- ofs.write((const char*) &key[0], sizeof(key[0]) * len);
- uint64_t val = it->second;
- ofs.write((const char*) &val, sizeof(val));
- }
-}
-
-void HSClassifier::LoadKey2ID(key2id_t& key2id,
- ifstream& ifs) {
- uint64_t key_num = 0;
- ifs.read((char*)&key_num, sizeof(key_num) * 1);
-
- for (uint64_t i = 0; i < key_num; ++i){
- uint64_t len = 0;
- ifs.read((char*) &len, sizeof(len) * 1);
- string key;
- key.resize(len);
- ifs.read((char*) &key[0], sizeof(key[0]) * len);
- uint64_t val = 0;
- ifs.read((char*) &val, sizeof(val));
- key2id[key] = val;
- }
-}
-
-
-int HSClassifier::Save(const string& filename) const{
- ofstream ofs(filename.c_str());
- if (!ofs){
- return -1;
- }
- SaveKey2ID(feature2id_, ofs);
- SaveKey2ID(class2id_, ofs);
-
- for (size_t i = 0; i < weights_.size(); ++i){
- const vector<float>& v = weights_[i];
- ofs.write((const char*)&v[0], sizeof(v[0]) * v.size());
- }
-
- if (!ofs){
- return -1;
- }
-
- return 0;
-}
-
-int HSClassifier::Load(const string& filename) {
- Clear();
- ifstream ifs(filename.c_str());
- if (!ifs){
- return -1;
- }
- LoadKey2ID(feature2id_, ifs);
- LoadKey2ID(class2id_, ifs);
-
- id2class_.resize(class2id_.size());
- for (key2id_t::const_iterator it = class2id_.begin(); it != class2id_.end(); ++it){
- id2class_[it->second] = it->first;
- }
- weights_.resize(class2id_.size(), vector<float>(feature2id_.size()));
-
- if (!ifs) return -1;
-
- return 0;
-}
-
-void HSClassifier::Train(const key_str_t& input,
- const string& output){
- vector<uint64_t> features;
- for (key_str_t::const_iterator it = input.begin(); it != input.end(); ++it){
- ExtractFeatureConst(it->first, it->second, features);
- }
- uint64_t target_class = GetClassID(output);
- vector<float> scores;
- CalcScores(features, scores);
- uint64_t max_score_class = GetMaxScoreID(scores);
- if (target_class == max_score_class) {
- return; // no update
- }
-
- features.clear();
- for (key_str_t::const_iterator it = input.begin(); it != input.end(); ++it){
- ExtractFeature(it->first, it->second, features);
- }
- Update(features, target_class, max_score_class);
-}
-
-void HSClassifier::Update(const std::vector<uint64_t>& features,
- const uint64_t target_class, const uint64_t max_score_class) {
- for (size_t i = 0; i < features.size(); ++i){
- uint64_t id = features[i];
- weights_[target_class][id] += 1;
- weights_[max_score_class][id] -= 1;
- }
-}
-
-uint64_t HSClassifier::GetMaxScoreID(const vector<float>& scores) const {
- if (scores.size() == 0) return 0;
- float max_score = scores[0];
- uint64_t max_id = 0;
- for (size_t i = 1; i < scores.size(); ++i){
- if (scores[i] > max_score) {
- max_score = scores[i];
- max_id = i;
- }
- }
- return max_id;
-}
-
-void HSClassifier::CalcScores(const vector<uint64_t>& features, vector<float>& scores) const{
- size_t class_num = id2class_.size();
- scores.resize(class_num);
- fill(scores.begin(), scores.end(), 0.f);
- for (size_t i = 0; i < features.size(); ++i){
- uint64_t id = features[i];
- for (size_t j = 0; j < class_num; ++j){
- scores[j] += weights_[j][id];
- }
- }
-}
-
-void HSClassifier::ExtractFeatureConst(const std::string& field, const std::string& value,
- vector<uint64_t>& features) const{
- uint64_t total_id = GetIDConst(field + "/" + value);
- if (total_id != NOTFOUND){
- features.push_back(total_id);
- }
-
- // UTF-8 bigram feature
- string cur;
- string prev;
- bool first = true;
- for (size_t i = 0; ; ++i){
- if (first ||
- (i != value.size() && (value[i] & 0xC0) == 0x80)){
- cur += value[i];
- first = false;
- continue;
- }
- uint64_t term_id = GetIDConst(field + "/" + prev + cur);
- if (term_id != NOTFOUND){
- features.push_back(term_id);
- }
- if (i == value.size()) break;
- prev = cur;
- cur = value[i];
- }
-}
-
-void HSClassifier::ExtractFeature(const std::string& field, const std::string& value,
- vector<uint64_t>& features) {
- features.push_back(GetID(field + "/" + value));
-
- // UTF-8 bigram feature
- string cur;
- string prev;
- bool first = true;
- for (size_t i = 0; ; ++i){
- if (first ||
- (i != value.size() && (value[i] & 0xC0) == 0x80)){
- cur += value[i];
- first = false;
- continue;
- }
- features.push_back(GetID(field + "/" + prev + cur));
-
- if (i == value.size()) break;
- prev = cur;
- cur = value[i];
- }
-}
-
-
-
-key_double_t HSClassifier::Classify(const key_str_t& input) const {
- vector<uint64_t> features;
- for (key_str_t::const_iterator it = input.begin(); it != input.end(); ++it){
- ExtractFeatureConst(it->first, it->second, features);
- }
- vector<float> scores;
- CalcScores(features, scores);
-
- key_double_t class2score;
- for (size_t i = 0; i < scores.size(); ++i){
- class2score[id2class_[i]] = scores[i];
- }
- return class2score;
-}
-
-uint64_t HSClassifier::GetID(const string& key){
- key2id_t::const_iterator it = feature2id_.find(key);
- if (it != feature2id_.end()){
- return it->second;
- }
- uint64_t new_id = static_cast<uint64_t>(feature2id_.size());
- feature2id_[key] = new_id;
- for (size_t i = 0; i < weights_.size(); ++i){
- weights_[i].resize(new_id+1);
- }
- return new_id;
-}
-
-uint64_t HSClassifier::GetIDConst(const string& key) const{
- key2id_t::const_iterator it = feature2id_.find(key);
- if (it != feature2id_.end()){
- return it->second;
- } else {
- return NOTFOUND;
- }
-}
-
-uint64_t HSClassifier::GetClassID(const string& output){
- key2id_t::const_iterator it = class2id_.find(output);
- if (it != class2id_.end()){
- return it->second;
- }
- uint64_t new_id = static_cast<uint64_t>(class2id_.size());
- class2id_[output] = new_id;
- id2class_.push_back(output);
-
- weights_.resize(new_id+1);
- weights_[new_id].resize(feature2id_.size());
- return new_id;
-}
-
-}
17 src/common/exception.hpp
View
@@ -23,24 +23,29 @@
namespace jubatus{
- class storage_not_set : std::exception {};
- class config_not_set : std::exception {};
- class unsupported_method : std::runtime_error {
+ class storage_not_set : public std::exception {};
+
+ class config_not_set : public std::runtime_error {
+ public:
+ config_not_set(): runtime_error("config_not_set") {}
+ };
+
+ class unsupported_method : public std::runtime_error {
public:
unsupported_method(const std::string& n): runtime_error(n) {}
};
- class bad_storage_type : std::runtime_error {
+ class bad_storage_type : public std::runtime_error {
public:
bad_storage_type(const std::string& n):runtime_error(n){};
};
- class membership_error : std::runtime_error {
+ class membership_error : public std::runtime_error {
public:
membership_error(const std::string& n):runtime_error(n){};
};
- class argv_error : std::runtime_error {
+ class argv_error : public std::runtime_error {
public:
argv_error(const std::string& n):runtime_error(n){};
};
136 src/common/mprpc/async_client.cpp
View
@@ -0,0 +1,136 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#include "async_client.hpp"
+#include <pficommon/network/mprpc/socket.h>
+#include <pficommon/network/mprpc/exception.h>
+#include <pficommon/network/ipv4.h>
+
+#include <glog/logging.h>
+
+#include "../exception.hpp"
+
+#include <fcntl.h>
+#include <string.h>
+#include <errno.h>
+
+using pfi::lang::shared_ptr;
+using pfi::network::mprpc::socket;
+using pfi::system::time::clock_time;
+using pfi::system::time::get_clock_time;
+
+namespace jubatus { namespace common { namespace mprpc {
+
+bool set_socket_nonblock(int sock, bool on){
+ int res;
+ if(on){
+ res = fcntl(sock, F_SETFL, O_NONBLOCK);
+ return res==0;
+ }else{
+ //FIXME
+ return false;
+ }
+}
+
+async_sock::async_sock():
+ //pfi::network::mprpc::socket(::socket(AF_INET, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0)),
+ pfi::network::mprpc::socket(::socket(AF_INET, SOCK_STREAM, 0)),
+ state(CLOSED),
+ progress(0)
+{
+ // FIXME: SOCK_NONBLOCK is linux only
+ int fd = get();
+ set_socket_nonblock(fd, true);
+ unpacker.reserve_buffer(4096);
+}
+
+async_sock::~async_sock(){
+}
+
+bool async_sock::set_async(bool on){ return on; }
+
+bool async_sock::send_async(const char* buf, size_t size){
+ int fd = this->get();
+ int r = ::write(fd, buf+progress, size-progress);
+ if(r > 0){
+ progress += r;
+ }
+ if(progress == size){
+ progress = 0;
+ return true;
+ }
+ return size == progress;
+};
+
+int async_sock::recv_async()
+{
+ int fd = this->get();
+ if(unpacker.message_size() == 0){
+ unpacker.reset();
+ unpacker.reserve_buffer(4096);
+ }else if(unpacker.buffer_capacity() == 0){
+ //unpacker.expand_buffer(4096);
+ }
+ int r = ::read(fd, unpacker.buffer(), unpacker.buffer_capacity());
+ // if(r < 0){
+ // char msg[1024];
+ // strerror_r(errno, msg, 1024);
+ // cout << "errno:"<< errno << msg << endl;
+ // }
+ if(r > 0){
+ unpacker.buffer_consumed(r);
+ }
+ return r;
+};
+
+int async_sock::connect_async(const std::string& host, uint16_t port){
+ int res;
+ int sock = this->get();
+
+ std::vector<pfi::network::ipv4_address> ips = resolve(host, port);
+ for (int i=0; i < (int)ips.size(); i++){
+ sockaddr_in addrin={};
+ addrin.sin_family = PF_INET;
+ addrin.sin_addr.s_addr = inet_addr(ips[i].to_string().c_str());
+ addrin.sin_port = htons(port);
+
+ res = ::connect(sock,(sockaddr*)&addrin,sizeof(addrin));
+ if (res == -1){
+ if (errno==EINPROGRESS){
+ state = CONNECTING;
+ return 0;
+ }else{
+ DLOG(ERROR) << errno;
+ }
+ }
+ else if(res == 0){
+ state = SENDING;
+ return 0;
+ }
+ }
+ ::close(sock);
+ return -1;
+}
+
+int async_sock::close(){
+ return ::close(this->get());
+}
+
+
+}
+}
+}
97 src/common/mprpc/async_client.hpp
View
@@ -0,0 +1,97 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include <map>
+
+#include <msgpack.hpp>
+
+#include <pficommon/lang/shared_ptr.h>
+#include <pficommon/system/time_util.h>
+#include <pficommon/network/mprpc/socket.h>
+
+namespace jubatus { namespace common { namespace mprpc {
+
+class async_sock : public pfi::network::mprpc::socket {
+public:
+ async_sock();
+ ~async_sock();
+ bool set_async(bool on);
+ bool send_async(const char* buf, size_t size);
+
+ int recv_async();
+
+ template <typename T> bool salvage(T&);
+
+ int connect_async(const std::string& host, uint16_t port);
+ int close();
+
+ void set_sending(){ state = SENDING; };
+ void set_recving(){ state = RECVING; };
+ void disconnected(){ state = CLOSED; };
+ bool is_connecting()const{ return state == CONNECTING; };
+ bool is_sending()const{ return state == SENDING; };
+ bool is_recving()const{ return state == RECVING; };
+private:
+ enum { CLOSED, CONNECTING, SENDING, RECVING } state;
+ size_t progress;
+ msgpack::unpacker unpacker;
+};
+
+template <typename T> bool async_sock::salvage(T& t)
+{
+ msgpack::unpacked msg;
+ if(unpacker.next(&msg)){
+ msgpack::object o = msg.get();
+ std::auto_ptr<msgpack::zone> z = msg.zone();
+ o.convert(&t);
+ return true;
+ }
+ return false;
+};
+
+
+class async_client {
+public:
+ async_client(const std::string& host, uint16_t port, int timeout_sec);
+ ~async_client();
+
+ void send_async(const std::string& method, const msgpack::sbuffer& argv);
+ void join(msgpack::object& o);
+
+ void connect_async();
+
+ void l();
+
+private:
+ bool wait();
+
+ std::string host_;
+ uint16_t port_;
+ int timeout_sec_;
+ pfi::lang::shared_ptr<async_sock> sock_;
+ pfi::system::time::clock_time start_;
+
+ int epfd_;
+};
+
+}
+}
+}
194 src/common/mprpc/rpc_client.cpp
View
@@ -0,0 +1,194 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#include "rpc_client.hpp"
+#include <glog/logging.h>
+
+#include <event.h>
+
+using pfi::lang::shared_ptr;
+using pfi::system::time::clock_time;
+using pfi::system::time::get_clock_time;
+
+namespace jubatus { namespace common { namespace mprpc {
+
+rpc_mclient::rpc_mclient(const std::vector<std::pair<std::string, uint16_t> >& hosts,
+ int timeout_sec):
+ hosts_(hosts),
+ timeout_sec_(timeout_sec),
+ start_(get_clock_time()),
+ evbase_(::event_base_new())
+{
+ connect_async_();
+}
+rpc_mclient::rpc_mclient(const std::vector<std::pair<std::string, int> >& hosts,
+ int timeout_sec):
+ timeout_sec_(timeout_sec),
+ start_(get_clock_time()),
+ evbase_(::event_base_new())
+{
+ for(size_t i=0; i<hosts.size(); ++i){
+ hosts_.push_back(hosts[i]);
+ }
+ connect_async_();
+}
+
+rpc_mclient::~rpc_mclient(){
+ ::event_base_free(evbase_);
+}
+
+void rpc_mclient::call_async(const std::string& m)
+{
+ call_async_(m, std::vector<int>());
+}
+
+void rpc_mclient::connect_async_()
+{
+ clients_.clear();
+ for(size_t i=0; i<hosts_.size(); ++i){
+ shared_ptr<async_sock> p(new async_sock);
+ p->connect_async(hosts_[i].first, hosts_[i].second);
+ clients_[p->get()] = p;
+ }
+}
+
+static void readable_callback(int fd, short int events, void* arg){
+ async_context* ctx = reinterpret_cast<async_context*>(arg);
+ ctx->rest -= ctx->c->readable_callback(fd, events, ctx);
+}
+int rpc_mclient::readable_callback(int fd, int events, async_context* ctx){
+
+ int done = 0;
+ if(events & EV_READ){
+
+ int r = clients_[fd]->recv_async();
+ if(r < 0 ){
+ clients_[fd]->disconnected();
+ clients_[fd]->close();
+ done++;
+ return done;
+ }
+
+ typedef msgpack::type::tuple<uint8_t,uint32_t,msgpack::object,msgpack::object> response_t;
+ response_t res;
+
+ if(clients_[fd]->salvage<response_t>(res)){
+ // cout << __FILE__ << " " << __LINE__ << ":"<< endl;
+ // cout << "\ta0: "<< int(res.a0) << endl;
+ // cout << "\ta2: "<< res.a2.type << " " << res.a2.is_nil() << " " << res.a2 << endl;
+ // cout << "\ta3: "<< res.a3.type << " " << res.a3.is_nil() << " " << res.a3 << endl;;
+ ctx->rest--;
+ done++;
+ if(res.a0 == 1){
+ if(res.a2.is_nil()){
+ ctx->ret.push_back(res.a3);
+
+ // }else{
+ // if(res.a3.is_nil()){
+ // std::string msg;
+ // res.a2.convert(&msg);
+ }
+ return done;
+ }
+ }
+ else{ //more to recieve
+ register_fd_readable_(ctx);
+ return done;
+ }
+
+ }else if(events & EV_TIMEOUT){
+ clients_[fd]->disconnected();
+ clients_[fd]->close();
+ ctx->rest--;
+ done++;
+ }
+ return done;
+}
+
+static void writable_callback(int fd, short int events, void* arg){
+ async_context* ctx = static_cast<async_context*>(arg);
+ ctx->rest -= ctx->c->writable_callback(fd, events, ctx);
+}
+int rpc_mclient::writable_callback(int fd, int events, async_context* ctx){
+ int done = 0;
+ if(events & EV_WRITE){
+
+ if(clients_[fd]->is_connecting()){
+ clients_[fd]->set_sending();
+ }
+ if(clients_[fd]->send_async(ctx->buf->data(), ctx->buf->size())){
+ ctx->rest--;
+ done++;
+ clients_[fd]->set_recving();
+
+ }else{
+ register_fd_writable_(ctx);
+ }
+
+ }else if(events & EV_TIMEOUT){
+ clients_[fd]->disconnected();
+ clients_[fd]->close();
+ ctx->rest--;
+ done++;
+ }
+ return done;
+}
+
+void rpc_mclient::register_fd_readable_(async_context* ctx){
+ register_all_fd_(EV_READ, &mprpc::readable_callback, ctx);
+}
+void rpc_mclient::register_fd_writable_(async_context* ctx){
+ register_all_fd_(EV_WRITE, &mprpc::writable_callback, ctx);
+}
+void rpc_mclient::register_all_fd_(int choice, void(*cb)(int, short, void*), async_context* ctx ) // choice = EV_READ or EV_WRITE
+{
+ struct timeval timeout;
+ timeout.tv_sec = timeout_sec_;
+ timeout.tv_usec = 0;
+ pfi::data::unordered_map<int,pfi::lang::shared_ptr<async_sock> >::iterator it;
+ for(it=clients_.begin(); it!=clients_.end(); ++it){
+ event_base_once(evbase_, it->second->get(), choice, cb, ctx, &timeout);
+ }
+}
+void rpc_mclient::send_async(const msgpack::sbuffer& buf)
+{
+ async_context ctx;
+ ctx.c = this;
+ ctx.rest = clients_.size();
+ ctx.buf = &buf;
+
+ register_fd_writable_(&ctx);
+
+ do{
+ int r = event_base_loop(evbase_, EVLOOP_ONCE);
+ if( r != 0 ){
+ break;
+ }
+ }while(ctx.rest>0);
+}
+
+
+void rpc_mclient::join_some_(async_context& ctx)
+{
+ ctx.ret.clear();
+ event_base_loop(evbase_, EVLOOP_ONCE);
+}
+
+}
+}
+}
+
169 src/common/mprpc/rpc_client.hpp
View
@@ -0,0 +1,169 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include <map>
+
+#include <msgpack.hpp>
+
+#include <pficommon/lang/shared_ptr.h>
+#include <pficommon/lang/function.h>
+#include <pficommon/system/time_util.h>
+#include "async_client.hpp"
+#include <glog/logging.h>
+#include <pficommon/data/unordered_map.h>
+#include <pficommon/lang/noncopyable.h>
+
+extern "C"{
+struct event_base;
+}
+
+namespace jubatus { namespace common { namespace mprpc {
+
+class rpc_mclient;
+struct async_context {
+ rpc_mclient* c;
+ const msgpack::sbuffer* buf;
+ size_t rest;
+ std::vector<msgpack::object> ret;
+};
+
+class rpc_mclient : pfi::lang::noncopyable
+{
+public:
+ rpc_mclient(const std::vector<std::pair<std::string, uint16_t> >& hosts,
+ int timeout_sec);
+ rpc_mclient(const std::vector<std::pair<std::string, int> >& hosts,
+ int timeout_sec);
+ ~rpc_mclient();
+
+ template <typename Res, typename Argv>
+ Res call(const std::string& m, const Argv& a,
+ const pfi::lang::function<Res(Res,Res)>& reducer){
+ call_async(m, a);
+ return join_all(reducer);
+ };
+
+ void send_async(const msgpack::sbuffer& buf);
+
+ void call_async(const std::string&);
+
+ template <typename A0>
+ void call_async(const std::string&, const A0& a0);
+ template <typename A0, typename A1>
+ void call_async(const std::string&, const A0& a0, const A1& a1);
+ template <typename A0, typename A1, typename A2>
+ void call_async(const std::string&, const A0& a0, const A1& a1, const A2& a2);
+ template <typename A0, typename A1, typename A2, typename A3>
+ void call_async(const std::string&, const A0&, const A1&, const A2&, const A3&);
+
+ template <typename Res>
+ Res join_all(const pfi::lang::function<Res(Res,Res)>& reducer);
+
+ int readable_callback(int, int, async_context*);
+ int writable_callback(int, int, async_context*);
+
+
+private:
+ void register_fd_readable_(async_context*);
+ void register_fd_writable_(async_context*);
+ void register_all_fd_(int, void(*)(int,short,void*), async_context*);
+
+ template <typename Arr>
+ void call_async_(const std::string&, const Arr& a);
+
+ void connect_async_();
+ void join_some_(async_context&);
+
+ std::vector<std::pair<std::string, uint16_t> > hosts_;
+ int timeout_sec_;
+
+ pfi::data::unordered_map<int,pfi::lang::shared_ptr<async_sock> > clients_;
+ pfi::system::time::clock_time start_;
+
+ event_base* evbase_;
+};
+
+template <typename Arr>
+void rpc_mclient::call_async_(const std::string& m, const Arr& argv)
+{
+ msgpack::sbuffer sbuf;
+ msgpack::type::tuple<uint8_t,uint32_t,std::string,Arr> rpc_request(0, 0xDEADBEEF, m, argv);
+ msgpack::pack(&sbuf, rpc_request);
+ send_async(sbuf);
+}
+
+template <typename A0>
+void rpc_mclient::call_async(const std::string& m, const A0& a0)
+{
+ call_async_(m, msgpack::type::tuple<A0>(a0));
+}
+
+template <typename A0, typename A1>
+void rpc_mclient::call_async(const std::string& m, const A0& a0, const A1& a1)
+{
+ call_async_(m, msgpack::type::tuple<A0, A1>(a0, a1));
+}
+template <typename A0, typename A1, typename A2>
+void rpc_mclient::call_async(const std::string& m, const A0& a0, const A1& a1, const A2& a2)
+{
+ call_async_(m, msgpack::type::tuple<A0, A1, A2>(a0, a1, a2));
+}
+template <typename A0, typename A1, typename A2, typename A3>
+void rpc_mclient::call_async(const std::string& m, const A0& a0, const A1& a1, const A2& a2, const A3& a3)
+{
+ call_async_(m, msgpack::type::tuple<A0, A1, A2, A3>(a0, a1, a2, a3));
+}
+
+
+template <typename Res>
+Res rpc_mclient::join_all(const pfi::lang::function<Res(Res,Res)>& reducer)
+{
+ async_context ctx;
+ ctx.c = this;
+ ctx.rest = clients_.size();
+ ctx.buf = NULL;
+ ctx.ret = std::vector<msgpack::object>();
+
+ register_fd_readable_(&ctx);
+ join_some_(ctx);
+
+ if(ctx.ret.empty()){
+ throw std::runtime_error("no clients.");
+ }
+
+ Res result = ctx.ret[0].as<Res>();
+ for(size_t i=1;i<ctx.ret.size();++i){
+ result = reducer(result, ctx.ret[i].as<Res>());
+ }
+
+ do{
+ join_some_(ctx);
+ for(size_t i=0;i<ctx.ret.size();++i){
+ result = reducer(result, ctx.ret[i].as<Res>());
+ }
+ }while(ctx.rest>0);
+
+ return result;
+}
+
+}
+}
+}
118 src/common/mprpc/rpc_client_test.cpp
View
@@ -0,0 +1,118 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2011,2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#include "rpc_client.hpp"
+#include "../../framework/aggregators.hpp"
+#include "gtest/gtest.h"
+#include <pficommon/concurrent/thread.h>
+#include <pficommon/network/mprpc.h>
+#include <pficommon/lang/bind.h>
+#include <pficommon/lang/cast.h>
+
+#include <iostream>
+
+#include <vector>
+#include <string>
+#include <map>
+using namespace std;
+
+using pfi::lang::function;
+
+struct strw{
+ string key;
+ string value;
+ MSGPACK_DEFINE(key,value);
+};
+
+MPRPC_PROC(test_bool, bool(int));
+MPRPC_PROC(test_twice, int(int));
+MPRPC_PROC(add_all, int(int,int,int));
+MPRPC_PROC(various, string(int,float,double, strw));
+
+static bool test_bool(int i){ return i%2; };
+static int test_twice(int i){ return i*2; };
+static int add_all(int i, int j, int k){ return (i+j+k); };
+static string various(int i, float f, double d, strw s){
+ string ret = pfi::lang::lexical_cast<string>(i)
+ + pfi::lang::lexical_cast<string>(f)
+ + pfi::lang::lexical_cast<string>(d)
+ + s.key + s.value;
+ return ret;
+}
+static string concat(string l,string r){ return (l+r); };
+
+MPRPC_GEN(1, test_mrpc, test_bool, test_twice, add_all, various);
+
+static void server_thread(unsigned u){
+ test_mrpc_server srv(3.0);
+ srv.set_test_bool(&test_bool);
+ srv.set_test_twice(&test_twice);
+ srv.set_add_all(&add_all);
+ srv.set_various(&various);
+ srv.serv(u, 10);
+}
+
+static void fork_server(unsigned u){
+ pfi::concurrent::thread th(pfi::lang::bind(&server_thread, u));
+ th.start();
+ th.detach();
+}
+
+static const uint16_t PORT0 = 60023;
+static const uint16_t PORT1 = 60024;
+
+TEST(rpc_mclient, small)
+{
+ fork_server(PORT0);
+ fork_server(PORT1);
+ usleep(500000);
+ {
+ test_mrpc_client cli0("localhost", PORT0, 3.0);
+ test_mrpc_client cli1("localhost", PORT1, 3.0);
+ EXPECT_EQ(true, cli0.call_test_bool(23));
+ EXPECT_EQ(24, cli1.call_test_twice(12));
+ }
+ vector<pair<string,uint16_t> > clients;
+ clients.push_back(make_pair(string("localhost"), PORT0));
+ clients.push_back(make_pair("localhost", PORT1));
+ jubatus::common::mprpc::rpc_mclient cli(clients, 3.0);
+ {
+ cli.call_async("test_bool", 73684);
+ EXPECT_FALSE(cli.join_all(function<bool(bool,bool)>(&jubatus::framework::all_and)));
+ }
+ {
+ cli.call_async("test_twice", 73684);
+ EXPECT_EQ(73684*4,
+ cli.join_all(function<int(int,int)>(&jubatus::framework::add<int>)));
+ }
+ {
+ cli.call_async("add_all", 23,21,-234);
+ EXPECT_EQ(2*(23+21-234),
+ cli.join_all(function<int(int,int)>(&jubatus::framework::add<int>)));
+ }
+ {
+ int i = 234;
+ float f = 234.0;
+ double d = 23e-234;
+ strw s;
+ s.key = "keykeykey";
+ s.value = "vvvvvddd";
+ string ans = concat(various(i,f,d,s) , various(i,f,d,s));
+ cli.call_async("various", i,f,d,s);
+ EXPECT_EQ(ans, cli.join_all(function<string(string,string)>(&concat)));
+ }
+}
23 src/common/mprpc/wscript
View
@@ -0,0 +1,23 @@
+
+def options(opt): pass
+
+def configure(conf): pass
+
+def build(bld):
+ src = 'rpc_client.cpp async_client.cpp'
+
+ bld.shlib(
+ source = src,
+ target = 'jubacommon_mprpc',
+ use = 'PFICOMMON GLOG ZOOKEEPER_MT EVENT'
+ )
+
+ bld.program(
+ features = 'gtest',
+ source = 'rpc_client_test.cpp',
+ target = 'rpc_client_test',
+ includes = '. ../framework',
+ use = 'PFICOMMON MSGPACK jubacommon_mprpc',
+ )
+
+ bld.install_files('${PREFIX}/include/jubatus/common/mprpc', bld.path.ant_glob('*.hpp'))
53 src/common/unordered_map.hpp
View
@@ -0,0 +1,53 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2011,2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#pragma once
+#include <msgpack.hpp>
+#include <pficommon/data/unordered_map.h>
+
+// to make pfi::data::unordered_map serializable
+
+namespace msgpack {
+
+template <typename K, typename V>
+inline pfi::data::unordered_map<K, V> operator>> (object o, pfi::data::unordered_map<K, V>& v)
+{
+ if(o.type != type::MAP){
+ throw type_error();
+ }
+ object_kv* const p_end = o.via.map.ptr + o.via.map.size;
+ for(object_kv* p = o.via.map.ptr; p != p_end; ++p) {
+ K key;
+ p->key.convert(&key);
+ p->val.convert(&v[key]);
+ }
+ return v;
+}
+
+template <typename Stream, typename K, typename V>
+inline packer<Stream>& operator<< (packer<Stream>& o, const pfi::data::unordered_map<K,V>& v)
+{
+ o.pack_map(v.size());
+ for(typename std::tr1::unordered_map<K,V>::const_iterator it = v.begin();
+ it != v.end(); ++it) {
+ o.pack(it->first);
+ o.pack(it->second);
+ }
+ return o;
+}
+
+}
10 src/common/wscript
View
@@ -1,12 +1,13 @@
+subdirs = 'mprpc'
+
def options(opt):
- pass
+ opt.recurse(subdirs)
def configure(conf):
-# conf.check_cxx(lib = 'crypt', mandatory = True)
-# conf.check_cxx(function_name = 'crypt', header_name = 'unistd.h', mandatory = True)
conf.check_cxx(header_name = 'sys/socket.h net/if.h sys/ioctl.h', mandatory = True)
conf.check_cxx(header_name = 'netinet/in.h arpa/inet.h', mandatory = True)
+ conf.recurse(subdirs)
def build(bld):
import Options
@@ -18,7 +19,7 @@ def build(bld):
source = src,
target = 'jubacommon',
includes = '.',
- use = 'PFICOMMON GLOG ZOOKEEPER_MT CRYPT'
+ use = 'PFICOMMON GLOG ZOOKEEPER_MT CRYPT jubacommon_mprpc'
)
test_src = [
@@ -43,3 +44,4 @@ def build(bld):
map(make_test, test_src)
bld.install_files('${PREFIX}/include/jubatus/common/', bld.path.ant_glob('*.hpp'))
+ bld.recurse(subdirs)
61 src/framework/aggregators.hpp
View
@@ -0,0 +1,61 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2011,2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#pragma once
+#include <map>
+#include <vector>
+
+namespace jubatus { namespace framework {
+
+
+template <typename K, typename V>
+std::map<K,V> merge(std::map<K,V> lhs, std::map<K,V> rhs)
+{
+ std::map<K,V> ret;
+ typename std::map<K,V>::const_iterator it;
+ for(it = lhs.begin(); it!=lhs.end(); ++it){
+ ret[it->first] = it->second;
+ }
+ for(it = rhs.begin(); it!=rhs.end(); ++it){
+ ret[it->first] = it->second;
+ }
+ return ret;
+}
+
+template <typename T>
+std::vector<T> concat(std::vector<T> lhs, std::vector<T> rhs)
+{
+ std::vector<T> ret = lhs;
+ ret.insert(ret.end(), rhs.begin(), rhs.end());
+ return ret;
+}
+
+template <typename T>
+T random(T lhs, T rhs){
+ return lhs; //TODO: make random? or left(change fun name)?
+}
+
+template <typename T>
+T add(T lhs, T rhs){
+ return lhs+rhs;
+}
+
+bool all_and(bool l, bool r){
+ return l&&r;
+};
+
+}}
41 src/framework/jubatus_serv.cpp
View
@@ -61,7 +61,6 @@ namespace jubatus { namespace framework {
}
if( use_cht_ ){
-
jubatus::common::cht::setup_cht_dir(*zk_, a_.name);
jubatus::common::cht ht(zk_, a_.name);
ht.register_node(a_.eth, a_.port);
@@ -72,10 +71,14 @@ namespace jubatus { namespace framework {
mixer_->start();
}
#endif
-
- { LOG(INFO) << "running in port=" << a_.port; }
- return serv.serv(a_.port, a_.threadnum);
-
+
+ if( serv.serv(a_.port, a_.threadnum) ){
+ LOG(INFO) << "running in port=" << a_.port;
+ return 0;
+ }else{
+ LOG(ERROR) << "failed starting server: any process using port " << a_.port << "?";
+ return -1;
+ }
}
void jubatus_serv::register_mixable(mixable0* m){
@@ -95,7 +98,7 @@ namespace jubatus { namespace framework {
#endif
};
- std::map<std::string, std::map<std::string,std::string> > jubatus_serv::get_status(int) const {
+ std::map<std::string, std::map<std::string,std::string> > jubatus_serv::get_status() const {
std::map<std::string, std::string> data;
util::get_machine_status(data);
@@ -232,7 +235,15 @@ namespace jubatus { namespace framework {
}
#endif
- int jubatus_serv::save(std::string id) {
+ void jubatus_serv::updated(){
+#ifdef HAVE_ZOOKEEPER_H
+ update_count_ = mixer_->updated();
+#else
+ update_count_++;
+#endif
+ }
+
+ bool jubatus_serv::save(std::string id) {
std::string ofile;
build_local_path_(ofile, "jubatus", id);
@@ -246,13 +257,14 @@ namespace jubatus { namespace framework {
}
ofs.close();
LOG(INFO) << "saved to " << ofile;
- return 0;
- }catch(const std::exception& e){
- return -1;
+ return true;
+ }catch(const std::runtime_error& e){
+ LOG(ERROR) << e.what();
+ throw e;
}
}
- int jubatus_serv::load(std::string id) {
+ bool jubatus_serv::load(std::string id) {
std::string ifile;
build_local_path_(ifile, "jubatus", id);
@@ -265,11 +277,12 @@ namespace jubatus { namespace framework {
}
ifs.close();
this->after_load();
- return 0;
- }catch(const std::exception& e){
+ return true;
+ }catch(const std::runtime_error& e){
ifs.close();
+ LOG(ERROR) << e.what();
+ throw e;
}
- return -1; //expected never reaching here.
}
}}
17 src/framework/jubatus_serv.hpp
View
@@ -41,7 +41,7 @@ class jubatus_serv : pfi::lang::noncopyable {
pfi::concurrent::rw_mutex& get_rw_mutex(){ return m_; };
void use_cht();
- std::map<std::string, std::map<std::string,std::string> > get_status(int) const;
+ std::map<std::string, std::map<std::string,std::string> > get_status() const;
std::string get_server_identifier()const;
#ifdef HAVE_ZOOKEEPER_H
@@ -54,16 +54,11 @@ class jubatus_serv : pfi::lang::noncopyable {
void do_mix(const std::vector<std::pair<std::string,int> >& v);
#endif
- void updated(){
-#ifdef HAVE_ZOOKEEPER_H
- update_count_ = mixer_->updated();
-#else
- update_count_++;
-#endif
- };
+public:
+ void updated();
- int save(std::string id);
- int load(std::string id);
+ bool save(std::string id);
+ bool load(std::string id);
// after load( model_ was loaded from file ) called, users reset their own data
// I'm afraid this function is useless
@@ -105,4 +100,4 @@ class jubatus_serv : pfi::lang::noncopyable {
#define JWLOCK__(p) \
pfi::concurrent::scoped_lock lk(wlock((p)->get_rw_mutex())); \
- p_->updated();
+ (p)->updated();
2  src/framework/keeper.cpp
View
@@ -37,8 +37,6 @@ keeper::keeper(const keeper_argv& a)
if(!register_keeper(*zk_, a_.eth, a_.port) ){
throw membership_error("can't register to zookeeper.");
}
- register_broadcast_analysis<int, std::string>("save");
- register_broadcast_update<std::string>("load");
}
keeper::~keeper(){
228 src/framework/keeper.hpp
View
@@ -31,9 +31,12 @@
#include "../common/lock_service.hpp"
#include "../common/cht.hpp"
+#include "../common/mprpc/rpc_client.hpp"
#include "../common/shared_ptr.hpp"
#include "server_util.hpp"
+#include <glog/logging.h>
+#include <iostream>
namespace jubatus {
namespace framework {
@@ -44,107 +47,214 @@ class keeper : public pfi::network::mprpc::rpc_server {
virtual ~keeper();
int run();
- template <typename Q>
- void register_random_update(std::string method_name) {
- pfi::lang::function<int(std::string, Q)> f =
- pfi::lang::bind(&keeper::template random_proxy<int, Q>, this, method_name, pfi::lang::_1, pfi::lang::_2);
+ template <typename R>
+ void register_random(std::string method_name) {
+ pfi::lang::function<R(std::string)> f = pfi::lang::bind(&keeper::template random_proxy<R>, this, method_name, pfi::lang::_1);
add(method_name, f);
}
-
- template <typename R, typename Q>
- void register_random_analysis(std::string method_name) {
- pfi::lang::function<R(std::string, Q)> f =
- pfi::lang::bind(&keeper::template random_proxy<R, Q>, this, method_name, pfi::lang::_1, pfi::lang::_2);
+ template <typename R, typename A0> //, typename A1, typename A2>
+ void register_random(std::string method_name) {
+ pfi::lang::function<R(std::string,A0)> f = pfi::lang::bind(&keeper::template random_proxy<R,A0>, this, method_name, pfi::lang::_1, pfi::lang::_2);
add(method_name, f);
}
-
- template <typename Q>
- void register_broadcast_update(std::string method_name) {
- pfi::lang::function<int(std::string, Q)> f =
- pfi::lang::bind(&keeper::template broadcast_proxy<int,Q>, this, method_name, pfi::lang::_1, pfi::lang::_2);
+ template <typename R, typename A0, typename A1>//, typename A2>
+ void register_random(std::string method_name) {
+ pfi::lang::function<R(std::string,A0,A1)> f = pfi::lang::bind(&keeper::template random_proxy<R,A0,A1>, this, method_name, pfi::lang::_1, pfi::lang::_2, pfi::lang::_3);
+ add(method_name,f);
+ }
+
+ template <typename R, typename A0>
+ void register_broadcast(std::string method_name,
+ pfi::lang::function<R(R,R)> agg){//pfi::lang::function<R(R,R)>& agg) {
+ pfi::lang::function<R(std::string, A0)> f =
+ pfi::lang::bind(&keeper::template broadcast_proxy<R, A0>, this, method_name, pfi::lang::_1, pfi::lang::_2,
+ agg);
add(method_name, f);
}
-
- template <typename R, typename Q>
- void register_broadcast_analysis(std::string method_name) {
- pfi::lang::function<R(std::string, Q)> f =
- pfi::lang::bind(&keeper::template broadcast_proxy<R, Q>, this, method_name, pfi::lang::_1, pfi::lang::_2);
+ template <typename R>
+ void register_broadcast(std::string method_name,
+ pfi::lang::function<R(R,R)> agg){//pfi::lang::function<R(R,R)>& agg) {
+ pfi::lang::function<R(std::string)> f =
+ pfi::lang::bind(&keeper::template broadcast_proxy<R>, this, method_name, pfi::lang::_1,
+ agg);
add(method_name, f);
}
- template <typename Q>
- void register_cht_update(std::string method_name) {
- pfi::lang::function<int(std::string, std::string, Q)> f =
- pfi::lang::bind(&keeper::template cht_proxy<int, Q>, this, method_name, pfi::lang::_1, pfi::lang::_2, pfi::lang::_3);
+
+ template <typename R>
+ void register_cht(std::string method_name, pfi::lang::function<R(R,R)> agg) {
+ pfi::lang::function<R(std::string, std::string)> f =
+ pfi::lang::bind(&keeper::template cht_proxy<R>, this, method_name, pfi::lang::_1, pfi::lang::_2, agg);
+ add(method_name, f);
+ }
+ template <typename R, typename A0>
+ void register_cht(std::string method_name, pfi::lang::function<R(R,R)> agg) {
+ pfi::lang::function<R(std::string, std::string, A0)> f =
+ pfi::lang::bind(&keeper::template cht_proxy<R,A0>, this, method_name, pfi::lang::_1, pfi::lang::_2, pfi::lang::_3, agg);
add(method_name, f);
}
- template <typename R, typename Q>
- void register_cht_analysis(std::string method_name) {
- pfi::lang::function<R(std::string, std::string, Q)> f =
- pfi::lang::bind(&keeper::template cht_proxy<R, Q>, this, method_name, pfi::lang::_1, pfi::lang::_2, pfi::lang::_3);
+ template <typename R, typename A0, typename A1>
+ void register_cht(std::string method_name, pfi::lang::function<R(R,R)> agg) {
+ pfi::lang::function<R(std::string, std::string, A0, A1)> f =
+ pfi::lang::bind(&keeper::template cht_proxy<R, A0, A1>, this, method_name, pfi::lang::_1, pfi::lang::_2, pfi::lang::_3, pfi::lang::_4, agg);
add(method_name, f);
}
private:
+ template <typename R>
+ R random_proxy(const std::string& method_name, const std::string& name){
+ // {DLOG(INFO)<< __func__ << " " << method_name << " " << name;}
+ std::vector<std::pair<std::string, int> > list;
+ get_members_(name, list);
+
+ if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+ const std::pair<std::string, int>& c = list[rng_(list.size())];
+
+ try{
+ pfi::network::mprpc::rpc_client cli(c.first, c.second, a_.timeout);
+ return cli.call<R(std::string)>(method_name)(name);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what() << " from " << c.first << ":" << c.second;
+ throw e;
+ }
+ }
template <typename R, typename A>
- R random_proxy(const std::string& method_name, const std::string& name, const A& arg) {
+ R random_proxy(const std::string& method_name, const std::string& name, const A& arg){
// {DLOG(INFO)<< __func__ << " " << method_name << " " << name;}
std::vector<std::pair<std::string, int> > list;
get_members_(name, list);
- const std::pair<std::string, int>& c = list[rng_(list.size())];
if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+ const std::pair<std::string, int>& c = list[rng_(list.size())];
+
+ try{
+ pfi::network::mprpc::rpc_client cli(c.first, c.second, a_.timeout);
+ return cli.call<R(std::string,A)>(method_name)(name, arg);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what() << " from " << c.first << ":" << c.second;
+ throw e;
+ }
+ }
+ template <typename R, typename A0, typename A1>
+ R random_proxy(const std::string& method_name, const std::string& name, const A0& a0, const A1& a1){
+ std::vector<std::pair<std::string, int> > list;
+
+ get_members_(name, list);
- // this code didn't work: rpc_client instance cannot be desctucted so too much live connections generated
- // and accept/read thread in pficommon server exhausted.
- //return pfi::network::mprpc::rpc_client(c.first, c.second, a_.timeout).call<R(std::string,A)>(method_name)(name, arg);
+ if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+ const std::pair<std::string, int>& c = list[rng_(list.size())];
- pfi::network::mprpc::rpc_client cli(c.first, c.second, a_.timeout);
- // {DLOG(INFO)<< "accssing to " << c.first << " " << c.second;}
- return cli.call<R(std::string,A)>(method_name)(name, arg);
+ try{
+ pfi::network::mprpc::rpc_client cli(c.first, c.second, a_.timeout);
+ return cli.call<R(std::string,A0,A1)>(method_name)(name, a0, a1);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what() << " from " << c.first << ":" << c.second;
+ throw e;
+ }
}
+ template <typename R>
+ R broadcast_proxy(const std::string& method_name, const std::string& name,
+ pfi::lang::function<R(R,R)>& agg) {
+ // {DLOG(INFO)<< __func__ << " " << method_name << " " << name;}
+ std::vector<std::pair<std::string, int> > list;
+
+ get_members_(name, list);
+ if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+
+ try{
+ jubatus::common::mprpc::rpc_mclient c(list, a_.timeout);
+ c.call_async(method_name, name);
+ return c.join_all<R>(agg);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what(); // << " from " << c.first << ":" << c.second;
+ throw e;
+ }
+ }
template <typename R, typename A>
- // FIXME: modify return type
- R broadcast_proxy(const std::string& method_name, const std::string& name, const A& arg) {
- // {LOG(INFO)<< __func__ << " " << method_name << " " << name;}
+ R broadcast_proxy(const std::string& method_name, const std::string& name, const A& arg,
+ pfi::lang::function<R(R,R)>& agg) {
+ // {DLOG(INFO)<< __func__ << " " << method_name << " " << name;}
std::vector<std::pair<std::string, int> > list;
get_members_(name, list);
- // std::vector<R> results;
- // FIXME: needs global lock here
- R res;
- for (size_t i = 0; i < list.size(); ++i) {
- const std::pair<std::string, int>& c = list[i];
- pfi::network::mprpc::rpc_client cli(c.first, c.second, a_.timeout);
- res = cli.call<R(std::string,A)>(method_name)(name,arg);
- // results.push_back(res);
+ if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+
+ try{
+ jubatus::common::mprpc::rpc_mclient c(list, a_.timeout);
+ c.call_async(method_name, name, arg);
+ std::cout << __LINE__ << " name:" << name << " method:" << method_name << std::endl;
+ return c.join_all<R>(agg);
+ }catch(const std::runtime_error& e){
+ std::cout << __LINE__ << e.what() << std::endl;
+ // LOG(ERROR) << e.what(); // << " from " << c.first << ":" << c.second;
+ throw e;
}
- {LOG(INFO)<< __func__;}
- // return results;
- return res;
}
- template <typename R, typename A>
- R cht_proxy(const std::string& method_name, const std::string& name, const std::string& key, const A& arg) {
- // {LOG(INFO)<< __func__ << " " << method_name << " " << name;}
+
+ template <typename R>
+ R cht_proxy(const std::string& method_name, const std::string& name, const std::string& id,
+ pfi::lang::function<R(R,R)>& agg) {
std::vector<std::pair<std::string, int> > list;
{
pfi::concurrent::scoped_lock lk(mutex_);
jubatus::common::cht ht(zk_, name);
- ht.find(key, list);
+ ht.find(id, list);
}
if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
- R result;
- for(size_t i=0; i<list.size(); ++i){
- const std::pair<std::string, int>& c = list[i];
- pfi::network::mprpc::rpc_client cli(c.first, c.second, a_.timeout);
- result = cli.call<R(std::string,std::string,A)>(method_name)(name, key, arg);
+ try{
+ jubatus::common::mprpc::rpc_mclient c(list, a_.timeout);
+ c.call_async(method_name, name, id);
+ return c.join_all<R>(agg);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what(); // << " from " << c.first << ":" << c.second;
+ throw e;
}
- return result;
}
+ template <typename R, typename A0>
+ R cht_proxy(const std::string& method_name, const std::string& name, const std::string& id, const A0& arg,
+ pfi::lang::function<R(R,R)>& agg) {
+ std::vector<std::pair<std::string, int> > list;
+ {
+ pfi::concurrent::scoped_lock lk(mutex_);
+ jubatus::common::cht ht(zk_, name);
+ ht.find(id, list);
+ }
+ if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+
+ try{
+ jubatus::common::mprpc::rpc_mclient c(list, a_.timeout);
+ c.call_async(method_name, name, id, arg);
+ return c.join_all<R>(agg);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what(); // << " from " << c.first << ":" << c.second;
+ throw e;
+ }
+ }
+ template <typename R, typename A0, typename A1>
+ R cht_proxy(const std::string& method_name, const std::string& name, const std::string& id, const A0& a0, const A1& a1,
+ pfi::lang::function<R(R,R)>& agg) {
+ std::vector<std::pair<std::string, int> > list;
+ {
+ pfi::concurrent::scoped_lock lk(mutex_);
+ jubatus::common::cht ht(zk_, name);
+ ht.find(id, list);
+ }
+ if(list.empty())throw std::runtime_error(method_name + ": no worker serving");
+
+ try{
+ jubatus::common::mprpc::rpc_mclient c(list, a_.timeout);
+ c.call_async(method_name, name, id, a0, a1);
+ return c.join_all<R>(agg);
+ }catch(const std::exception& e){
+ LOG(ERROR) << e.what(); // << " from " << c.first << ":" << c.second;
+ throw e;
+ }
+ }
+
void get_members_(const std::string& name, std::vector<std::pair<std::string, int> >& ret);
keeper_argv a_;
20 src/framework/server_util.cpp
View
@@ -18,11 +18,14 @@
#include "server_util.hpp"
#include <glog/logging.h>
+#include <iostream>
+
#include "../common/util.hpp"
#include "../common/cmdline.h"
#include "../common/exception.hpp"
#include "../common/membership.hpp"
+
#define SET_PROGNAME(s) \
static const std::string PROGNAME(JUBATUS_APPNAME "_" s);
@@ -31,6 +34,10 @@ namespace jubatus { namespace framework {
static const std::string VERSION(JUBATUS_VERSION);
+ void print_version(const std::string& progname){
+ std::cout << "jubatus-" << VERSION << " (" << progname << ")" << std::endl;
+ }
+
server_argv::server_argv(int args, char** argv){
google::InitGoogleLogging(argv[0]);
google::LogToStderr(); // only when debug
@@ -48,8 +55,15 @@ namespace jubatus { namespace framework {
p.add<int>("interval_sec", 's', "mix interval by seconds", false, 16);
p.add<int>("interval_count", 'i', "mix interval by update count", false, 512);
+ p.add("version", 'v', "version");
+
p.parse_check(args, argv);
+ if( p.exist("version") ){
+ print_version(argv[0]);
+ exit(0);
+ }
+
port = p.get<int>("rpc-port");
threadnum = p.get<int>("thread");
timeout = p.get<int>("timeout");
@@ -93,9 +107,15 @@ namespace jubatus { namespace framework {
p.add<int>("timeout", 't', "time out (sec)", false, 10);
p.add<std::string>("zookeeper", 'z', "zookeeper location", false, "localhost:2181");
+ p.add("version", 'v', "version");
p.parse_check(args, argv);
+ if( p.exist("version") ){
+ print_version(argv[0]);
+ exit(0);
+ }
+
port = p.get<int>("rpc-port");
threadnum = p.get<int>("thread");
timeout = p.get<int>("timeout");
5 src/framework/wscript
View
@@ -14,7 +14,7 @@ def build(bld):
source = framework_source,
target = 'jubatus_framework',
includes = '.',
- use = 'PFICOMMON jubacommon MSGPACK GLOG'
+ use = 'PFICOMMON jubacommon MSGPACK GLOG jubacommon_mprpc'
)
bld.install_files('${PREFIX}/include/jubatus/framework', [
@@ -22,5 +22,6 @@ def build(bld):
'keeper.hpp',
'server_util.hpp',
'mixable.hpp',
- 'mixer.hpp'
+ 'mixer.hpp',
+ 'aggregators.hpp'
])
21 src/fv_converter/counter.hpp
View
@@ -17,7 +17,10 @@
#pragma once
+#include <pficommon/data/serialization.h>
+#include <pficommon/data/serialization/unordered_map.h>
#include <pficommon/data/unordered_map.h>
+#include "../common/unordered_map.hpp"
namespace jubatus {
namespace fv_converter {
@@ -64,7 +67,23 @@ class counter {
iterator end() {
return data_.end();
}
-
+
+ void clear() {
+ data_.clear();
+ }
+
+ void add(const counter<T>& counts) {
+ for (const_iterator it = counts.begin(); it != counts.end(); ++it) {
+ (*this)[it->first] += it->second;
+ }
+ }
+
+ MSGPACK_DEFINE(data_);
+ template <class Archiver>
+ void serialize(Archiver &ar) {
+ ar
+ & MEMBER(data_);
+ }
private:
pfi::data::unordered_map<T, unsigned> data_;
};
15 src/fv_converter/counter_test.cpp
View
@@ -38,4 +38,19 @@ TEST(counter, trivial) {
EXPECT_EQ(2u, c["fuga"]);
}
+TEST(counter, add) {
+ counter<string> x, y;
+ x["hoge"] = 1;
+ x["fuga"] = 2;
+
+ y["foo"] = 5;
+ y["hoge"] = 3;
+
+ x.add(y);
+
+ EXPECT_EQ(4u, x["hoge"]);
+ EXPECT_EQ(2u, x["fuga"]);
+ EXPECT_EQ(5u, x["foo"]);
+}
+
}
32 src/fv_converter/datum_to_fv_converter_test.cpp
View
@@ -37,6 +37,7 @@
#include "num_filter_impl.hpp"
+#include "weight_manager.hpp"
#include "converter_config.hpp"
#include "exception.hpp"
@@ -47,6 +48,7 @@ using namespace pfi::lang;
TEST(datum_to_fv_converter, trivial) {
datum_to_fv_converter conv;
+ weight_manager wm;
}
TEST(datum_to_fv_converter, num_feature) {
@@ -55,6 +57,7 @@ TEST(datum_to_fv_converter, num_feature) {
datum.num_values_.push_back(make_pair("/val2", 0.));
datum_to_fv_converter conv;
+ weight_manager wm;
typedef shared_ptr<num_feature> num_feature_t;
shared_ptr<key_matcher> a(new match_all());
@@ -62,6 +65,8 @@ TEST(datum_to_fv_converter, num_feature) {
conv.register_num_rule("log", a, num_feature_t(new num_log_feature()));
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
vector<pair<string, float> > expected;
expected.push_back(make_pair("/val1@num", 1.1));
@@ -78,6 +83,7 @@ TEST(datum_to_fv_converter, string_feature) {
typedef shared_ptr<word_splitter> splitter_t;
datum_to_fv_converter conv;
+ weight_manager wm;
{
shared_ptr<word_splitter> s(new space_splitter());
vector<splitter_weight_type> p;
@@ -112,12 +118,16 @@ TEST(datum_to_fv_converter, string_feature) {
datum.string_values_.push_back(make_pair("/name", "doc0"));
datum.string_values_.push_back(make_pair("/title", " this is "));
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
}
{
datum datum;
datum.string_values_.push_back(make_pair("/name", "doc1"));
datum.string_values_.push_back(make_pair("/title", " this is it . it is it ."));
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
}
vector<pair<string, float> > expected;
@@ -154,6 +164,7 @@ TEST(datum_to_fv_converter, string_feature) {
TEST(datum_to_fv_converter, weight) {
datum_to_fv_converter conv;
+ weight_manager wm;
{
shared_ptr<key_matcher> match(new match_all());
shared_ptr<word_splitter> s(new space_splitter());
@@ -161,13 +172,16 @@ TEST(datum_to_fv_converter, weight) {
p.push_back(splitter_weight_type(FREQ_BINARY, WITH_WEIGHT_FILE));
conv.register_string_rule("space", match, s, p);
}
- conv.add_weight("/id$a@space", 3.f);
+ wm.add_weight("/id$a@space", 3.f); // <- new
+ conv.add_weight("/id$a@space", 3.f); // <-deprecated
datum datum;
datum.string_values_.push_back(make_pair("/id", "a b"));
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
ASSERT_EQ(1u, feature.size());
ASSERT_EQ("/id$a@space#bin/weight", feature[0].first);
@@ -176,6 +190,7 @@ TEST(datum_to_fv_converter, weight) {
TEST(datum_to_fv_converter, register_string_rule) {
datum_to_fv_converter conv;
+ weight_manager wm;
initialize_converter(converter_config(), conv);
vector<splitter_weight_type> p;
@@ -189,6 +204,8 @@ TEST(datum_to_fv_converter, register_string_rule) {
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
vector<pair<string, float> > exp;
exp.push_back(make_pair("/id$a@1gram#bin/bin", 1.));
@@ -202,6 +219,7 @@ TEST(datum_to_fv_converter, register_string_rule) {
TEST(datum_to_fv_converter, register_num_rule) {
datum_to_fv_converter conv;
+ weight_manager wm;
datum datum;
datum.num_values_.push_back(make_pair("/age", 20));
@@ -209,6 +227,8 @@ TEST(datum_to_fv_converter, register_num_rule) {
{
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
EXPECT_EQ(0u, feature.size());
}
@@ -219,6 +239,8 @@ TEST(datum_to_fv_converter, register_num_rule) {
{
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
EXPECT_EQ(1u, feature.size());
vector<pair<string, float> > exp;
@@ -232,6 +254,7 @@ TEST(datum_to_fv_converter, register_num_rule) {
TEST(datum_to_fv_converter, register_string_filter) {
datum_to_fv_converter conv;
+ weight_manager wm;
datum datum;
datum.string_values_.push_back(make_pair("/text", "<tag>aaa</tag>"));
@@ -245,6 +268,8 @@ TEST(datum_to_fv_converter, register_string_filter) {
{
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
EXPECT_EQ(1u, feature.size());
}
@@ -256,6 +281,8 @@ TEST(datum_to_fv_converter, register_string_filter) {
{
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
EXPECT_EQ(2u, feature.size());
EXPECT_EQ("/text_filtered$aaa@str#bin/bin", feature[1].first);
}
@@ -264,6 +291,7 @@ TEST(datum_to_fv_converter, register_string_filter) {
TEST(datum_to_fv_converter, register_num_filter) {
datum_to_fv_converter conv;
+ weight_manager wm;
datum datum;
datum.num_values_.push_back(make_pair("/age", 20));
@@ -279,6 +307,8 @@ TEST(datum_to_fv_converter, register_num_filter) {
vector<pair<string, float> > feature;
conv.convert(datum, feature);
+ wm.update_weight(feature);
+ wm.get_weight(feature);
EXPECT_EQ(2u, feature.size());
EXPECT_EQ("/age+5@str$25", feature[1].first);
74 src/fv_converter/keyword_weights.cpp
View
@@ -0,0 +1,74 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#include <cmath>
+#include "../common/type.hpp"
+#include "keyword_weights.hpp"
+#include "datum_to_fv_converter.hpp"
+
+namespace jubatus {
+namespace fv_converter {
+
+using namespace std;
+
+keyword_weights::keyword_weights()
+ : document_count_(),
+ document_frequencies_(),
+ weights_() {}
+
+struct is_zero {
+ bool operator()(const pair<string, float>& p) {
+ return p.second == 0;
+ }
+};
+
+void keyword_weights::update_document_frequency(const sfv_t& fv) {
+ ++document_count_;
+ for (sfv_t::const_iterator it = fv.begin(); it != fv.end(); ++it) {
+ ++document_frequencies_[it->first];
+ }
+}
+
+void keyword_weights::add_weight(const std::string& key, float weight) {
+ weights_[key] = weight;
+}
+
+float keyword_weights::get_user_weight(const std::string& key) const {
+ weight_t::const_iterator wit = weights_.find(key);
+ if (wit != weights_.end()) {
+ return wit->second;
+ } else {
+ return 0;
+ }
+}
+
+void keyword_weights::merge(const keyword_weights& w) {
+ document_count_ += w.document_count_;
+ document_frequencies_.add(w.document_frequencies_);
+ weight_t weights(w.weights_);
+ weights.insert(weights_.begin(), weights_.end());
+ weights_.swap(weights);
+}
+
+void keyword_weights::clear() {
+ document_count_ = 0;
+ document_frequencies_.clear();
+ weights_.clear();
+}
+
+}
+}
72 src/fv_converter/keyword_weights.hpp
View
@@ -0,0 +1,72 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2012 Preferred Infrastracture and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+#pragma once
+
+#include "datum.hpp"
+#include "counter.hpp"
+#include <pficommon/data/unordered_map.h>
+#include "../common/type.hpp"
+#include <msgpack.hpp>
+
+namespace jubatus {
+namespace fv_converter {
+
+class keyword_weights {
+ public:
+ keyword_weights();
+
+ void update_document_frequency(const sfv_t& fv);
+
+ size_t get_document_frequency(const std::string& key) const {
+ return document_frequencies_[key];
+ }
+
+ size_t get_document_count() const {
+ return document_count_;
+ }
+
+ void add_weight(const std::string& key, float weight);
+
+ float get_user_weight(const std::string& key) const;
+
+ void merge(const keyword_weights& w);
+
+ void clear();
+
+ MSGPACK_DEFINE(document_count_, document_frequencies_, weights_);
+ template <class Archiver>
+ void serialize(Archiver &ar) {
+ ar
+ & MEMBER(document_count_)
+ & MEMBER(document_frequencies_)
+ & MEMBER(weights_);
+ }
+
+ private:
+ double get_global_weight(const std::string& key) const;
+
+ size_t document_count_;
+ counter<std::string> document_frequencies_;
+ typedef pfi::data::unordered_map<std::string, float> weight_t;
+ weight_t weights_;
+
+
+};
+
+}
+}
57 src/fv_converter/keyword_weights_test.cpp
View
@@ -0,0 +1,57 @@
+#include <gtest/gtest.h>
+#include <cmath>
+
+#include "keyword_weights.hpp"
+#include "../common/type.hpp"
+
+namespace jubatus {
+namespace fv_converter {
+
+using namespace std;
+
+TEST(keyword_weights, trivial) {
+ keyword_weights m, m2;
+ {
+ sfv_t fv;
+
+ m.update_document_frequency(fv);
+
+ fv.push_back(make_pair("key1", 1.0));
+ fv.push_back(make_pair("key2", 1.0));
+ m.update_document_frequency(fv);
+
+ m.add_weight("key3", 2.0);
+
+ EXPECT_EQ(2u, m.get_document_count());
+ EXPECT_EQ(1u, m.get_document_frequency("key1"));
+ EXPECT_EQ(0u, m.get_document_frequency("unknown"));
+ }
+
+ {
+ sfv_t fv;
+ m2.update_document_frequency(fv);
+
+ fv.push_back(make_pair("key1", 1.0));
+ fv.push_back(make_pair("key2", 1.0));
+ m2.update_document_frequency(fv);
+
+ m2.add_weight("key3", 3.0);
+
+ m.merge(m2);
+
+ EXPECT_EQ(4u, m.get_document_count());
+ EXPECT_EQ(2u, m.get_document_frequency("key1"));
+ EXPECT_EQ(3.0, m.get_user_weight("key3"));
+ EXPECT_EQ(0u, m.get_document_frequency("unknown"));
+ }
+
+ {
+ m.clear();
+ EXPECT_EQ(0u, m.get_document_count());
+ EXPECT_EQ(0u, m.get_document_frequency("key1"));
+ EXPECT_EQ(0.0, m.get_user_weight("key3"));
+ }
+}
+
+}
+}
29 src/fv_converter/weight_manager.cpp
View
@@ -26,9 +26,7 @@ namespace fv_converter {
using namespace std;
weight_manager::weight_manager()
- : document_count_(),
- document_frequencies_(),
- weights_() {}
+ : diff_weights_(), master_weights_() {}
struct is_zero {
bool operator()(const pair<string, float>& p) {
@@ -36,12 +34,11 @@ struct is_zero {
}
};
-void weight_manager::update_weight(sfv_t& fv) {
- ++document_count_;
- for (sfv_t::const_iterator it = fv.begin(); it != fv.end(); ++it) {
- ++document_frequencies_[it->first];
- }
+void weight_manager::update_weight(const sfv_t& fv) {
+ diff_weights_.update_document_frequency(fv);
+}
+void weight_manager::get_weight(sfv_t& fv) const {
for (sfv_t::iterator it = fv.begin(); it != fv.end(); ++it) {
double global_weight = get_global_weight(it->first);
it->second *= global_weight;
@@ -49,7 +46,6 @@ void weight_manager::update_weight(sfv_t& fv) {
fv.erase(remove_if(fv.begin(), fv.end(), is_zero()), fv.end());
}
-
double weight_manager::get_global_weight(const string& key) const {