add iter_solver_test
mli committed Aug 9, 2015
1 parent 40aefb9 commit 16a4c2e
Showing 7 changed files with 189 additions and 76 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -40,3 +40,5 @@ deps
bin
tracker
*.d

doc
69 changes: 35 additions & 34 deletions learn/base/progress.h
@@ -10,40 +10,6 @@
#include "dmlc/logging.h"

namespace dmlc {
namespace wormhole {

/**
* \brief Serializable progress for objective value, accuracy, ...
*/
class Progress : public Serializable {
public:
Progress() { }
virtual ~Progress() { }

/// \brief merge from another progress
virtual void Merge(const Progress* const other) {
CHECK_EQ(other->vec_.size(), vec_.size());
for (size_t i = 0; i < vec_.size(); ++i) {
vec_[i] += other->vec_[i];
}
}

/// \brief Returns a header string for printing
virtual std::string HeadStr() = 0;

/// \brief Returns a string for printing
virtual std::string PrintStr() = 0;

virtual void Load(Stream* fi) { fi->Read(&vec_); }
virtual void Save(Stream *fo) const { fo->Write(vec_); }

double& objv() { return vec_[0]; }
double objv() const { return vec_[0]; }
protected:
std::vector<double> vec_;
};

} // namespace wormhole

/// DEPRECATED

@@ -120,3 +86,38 @@ class VectorProgress : public IProgress {
};

} // namespace dmlc

// namespace wormhole {

// /**
// * \brief Serializable progress for objective value, accuracy, ...
// */
// class Progress : public Serializable {
// public:
// Progress() { }
// virtual ~Progress() { }

// /// \brief merge from another progress
// virtual void Merge(const Progress* const other) {
// CHECK_EQ(other->vec_.size(), vec_.size());
// for (size_t i = 0; i < vec_.size(); ++i) {
// vec_[i] += other->vec_[i];
// }
// }

// /// \brief Returns a header string for printing
// virtual std::string HeadStr() = 0;

// /// \brief Returns a string for printing
// virtual std::string PrintStr() = 0;

// virtual void Load(Stream* fi) { fi->Read(&vec_); }
// virtual void Save(Stream *fo) const { fo->Write(vec_); }

// double& objv() { return vec_[0]; }
// double objv() const { return vec_[0]; }
// protected:
// std::vector<double> vec_;
// };

// } // namespace wormhole
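
For reference: the Progress class above (which this commit moves into the commented-out block under the DEPRECATED marker) merges per-worker metric vectors by element-wise addition over vec_. A minimal standalone sketch of that merge pattern using only the standard library; MiniProgress and its two-slot metric layout are illustrative, not the wormhole API.

#include <cassert>
#include <cstdio>
#include <vector>

struct MiniProgress {
  // vec_[0] is the objective value; further slots hold other metrics.
  std::vector<double> vec_ = std::vector<double>(2, 0.0);

  double& objv() { return vec_[0]; }

  // Element-wise accumulation, mirroring Progress::Merge above.
  void Merge(const MiniProgress& other) {
    assert(other.vec_.size() == vec_.size());
    for (size_t i = 0; i < vec_.size(); ++i) vec_[i] += other.vec_[i];
  }
};

int main() {
  MiniProgress a, b;
  a.objv() = 1.5; a.vec_[1] = 10;
  b.objv() = 2.5; b.vec_[1] = 20;
  a.Merge(b);  // a.vec_ is now {4.0, 30}
  std::printf("objv = %.1f, metric[1] = %.0f\n", a.objv(), a.vec_[1]);
  return 0;
}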
28 changes: 18 additions & 10 deletions learn/base/workload_pool.h
@@ -6,18 +6,30 @@
#include <sstream>
#include <vector>
#include <list>
#include <unordered_set>
#include <unordered_map>
#include <mutex>
namespace dmlc {

/**
* @brief A thread-safe workload pool
*/
class WorkloadPool {
public:
WorkloadPool(bool shuffle = true) : shuffle_(shuffle) {
static void Match(const std::string& pattern, Workload* wl) {
std::vector<std::string> files;
MatchFile(pattern, &files);
wl->file.resize(files.size());
for (size_t i = 0; i < files.size(); ++i) {
wl->file[i].filename = files[i];
}
}

WorkloadPool() {
straggler_killer_ = new std::thread([this]() {
while (!done_) {
RemoveStraggler();
sleep(2);
// detect stragglers every 2 seconds
RemoveStraggler(); sleep(2);
}
});
}
@@ -29,13 +41,9 @@ class WorkloadPool {
}
}

static void Match(const std::string& pattern, Workload* wl) {
std::vector<std::string> files;
MatchFile(pattern, &files);
wl->file.resize(files.size());
for (size_t i = 0; i < files.size(); ++i) {
wl->file[i].filename = files[i];
}

void Init(bool shuffle, int timeout) {
// TODO
}

void Add(const std::vector<Workload::File>& files, int npart,
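
For reference: the new WorkloadPool constructor above starts a background thread that checks for stragglers every two seconds until done_ is set. A standalone sketch of that pattern using std::thread and an atomic flag; RemoveStraggler() is stubbed out here, and joining the thread from the destructor is an assumption, since the pool's shutdown path is not shown in this hunk. Build with -pthread.

#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

class StragglerWatcher {
 public:
  StragglerWatcher() {
    killer_ = std::thread([this]() {
      while (!done_) {
        RemoveStraggler();
        std::this_thread::sleep_for(std::chrono::seconds(2));
      }
    });
  }
  ~StragglerWatcher() {
    // Assumed shutdown: flip the flag, then wait for the loop to exit.
    done_ = true;
    if (killer_.joinable()) killer_.join();
  }
 private:
  void RemoveStraggler() { std::printf("checking for stragglers\n"); }
  std::atomic<bool> done_{false};
  std::thread killer_;
};

int main() {
  StragglerWatcher watcher;
  std::this_thread::sleep_for(std::chrono::seconds(5));  // let it tick twice
  return 0;
}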
79 changes: 49 additions & 30 deletions learn/solver/iter_solver.h
@@ -2,8 +2,10 @@
* @file iter_solver.h
* @brief Template for an iterate solver
*/
#include "base/workload.h"
#include "ps.h"
#include "base/string_stream.h"
#include "base/workload.h"
#include "base/workload_pool.h"
namespace dmlc {
namespace solver {

@@ -145,7 +147,7 @@ class IterScheduler : public ps::App {
/// \brief run iterations
virtual bool Run() {
printf("Connected %d servers and %d workers\n",
ps::NodeInfo()::NumServers(), ps::NodeInfo()::NumWorkers());
ps::NodeInfo::NumServers(), ps::NodeInfo::NumWorkers());

start_time_ = GetTime();

@@ -185,27 +187,27 @@
}

SaveModel(true);
Printf("Training finished!\n");
printf("Training finished!\n");
return true;
}


virtual void ProcessResponse(ps::Message* response) {
if (response->task.cmd() == kProcess) {
IterCmd cmd(response->task.cmd());
if (cmd.process()) {
auto id = response->sender;
pool_.Finish(id);

if (response->task.msg().size()) {
CHECK(worker_local_data_);
CHECK(use_worker_local_data_);
StringStream ss(response->task.msg());
Workload wl; wl.Load(&ss);
pool_.Add(wl.file, num_part_per_file_, id);
pool_.Add(wl.file, num_parts_per_file_, id);
return;
}

pool_.Finish(id);
Workload wl; pool_.Get(id, &wl);
if (!wl.Empty()) {
CHECK_EQ(wl.file.size(), (size_t)1);
wl.type = cur_type_;
wl.type = cur_task_;
wl.data_pass = cur_data_pass_;
wl.file[0].format = data_format_;
SendWorkload(id, wl);
@@ -238,29 +240,31 @@

pool_.Clear(); pool_.Init(shuffle_, straggler_);

if (!use_worker_local_data_) {
if (use_worker_local_data_) {
// ask the workers to match the files
Workload wl;
wl.file.resize(1);
wl.file[0].filename = data;
wl.file[0].n = 0;
Wait(SendWorkload(ps::kWorkerGroup, wl));
} else {
// i will do it
Workload wl; pool_.Match(data, &wl);
pool_.Add(wl.file, num_part_per_file_);
pool_.Add(wl.file, num_parts_per_file_);
if (is_predict) {
CHECK_EQ(wl.file.size(), (size_t)1)
<< "use single file for prediction";
}
int npart = wl.file.size() * num_part_per_file_;
if (cur_data_pass_ == 0 && (npart < ps::NodeInfo()::NumWorkers())) {
int npart = wl.file.size() * num_parts_per_file_;
if (cur_data_pass_ == 0 && (npart < ps::NodeInfo::NumWorkers())) {
fprintf(stderr, "WARNING: # of data parts (%d) < # of workers (%d)\n",
npart, ps::NodeInfo()::NumWorkers());
npart, ps::NodeInfo::NumWorkers());
fprintf(stderr, " You may want to increase \"num_parts_per_file\"\n");
}
}

// ask all workers to start
Workload wl;
if (use_worker_local_data_) {
wl.file.resize(1);
wl.file[0].filename = data;
wl.file[0].n = 0;
}
SendWorkload(ps::kWorkerGroup, wl);
// ask all workers to start by sending an empty workload
Workload wl; SendWorkload(ps::kWorkerGroup, wl);

// print every k sec for training
printf(" sec %s\n", ProgHeader().c_str());
@@ -295,14 +299,29 @@
return Stop(prog, is_train);
}

void SendWorkload(const std::string id, const Workload& wl) {
int SendWorkload(const std::string id, const Workload& wl) {
StringStream ss; wl.Save(&ss);
ps::Task task; task.set_msg(ss.str());
IterCmd cmd; cmd.set_process();
task.set_cmd(cmd.cmd); Submit(task, id);
task.set_cmd(cmd.cmd); return Submit(task, id);
}

void SaveModel(bool force) {
if (model_out_.size() == 0) return;
if (force || (save_iter_ > 0 && (cur_data_pass_+1) % save_iter_ == 0)) {
int iter = force ? 0 : cur_data_pass_;
if (iter == 0) {
printf("Saving final model to %s\n", model_out_.c_str());
} else {
printf("Saving model to %s-iter_%d\n", model_out_.c_str(), iter);
}
IterCmd cmd; cmd.set_save_model(); cmd.set_iter(iter);
ps::Task task; task.set_cmd(cmd.cmd); task.set_msg(model_out_);
Wait(Submit(task, ps::kServerGroup));
}
}

Root<double> monitor_;
ps::Root<double> monitor_;
WorkloadPool pool_;
double start_time_;
};
@@ -333,7 +352,7 @@ class IterServer : public ps::App {
auto filename = ModelName(request->task.msg(), cmd.iter());
if (cmd.save_model()) {
Stream* fo = CHECK_NOTNULL(Stream::Create(filename.c_str(), "w"));
SaveModel(fi);
SaveModel(fo);
} else if (cmd.load_model()) {
Stream* fi = CHECK_NOTNULL(Stream::Create(filename.c_str(), "r"));
LoadModel(fi);
@@ -344,11 +363,11 @@
CHECK(base.size()) << "empty model name";
std::string name = base;
if (iter > 0) name += "_iter-" + std::to_string(iter);
return name + "_part-" + std::to_string(ps::NodeInfo()::MyRank());
return name + "_part-" + std::to_string(ps::NodeInfo::MyRank());
}

public:
Slave<double> reporter_;
ps::Slave<double> reporter_;
};

/**
@@ -388,7 +407,7 @@ class IterWorker : public ps::App {
}
}
private:
Slave<double> reporter_;
ps::Slave<double> reporter_;
};

} // namespace solver
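
For reference: IterServer::ModelName above composes per-server checkpoint names from the base model path, the iteration, and the server's rank, so each server writes its own shard of the model. A standalone sketch of the same naming scheme; the rank is passed in explicitly here instead of coming from ps::NodeInfo::MyRank().

#include <cstdio>
#include <string>

// Mirrors the naming in IterServer::ModelName: iteration checkpoints get an
// "_iter-<k>" suffix, and every server appends its own "_part-<rank>".
std::string ModelName(const std::string& base, int iter, int rank) {
  std::string name = base;
  if (iter > 0) name += "_iter-" + std::to_string(iter);
  return name + "_part-" + std::to_string(rank);
}

int main() {
  std::printf("%s\n", ModelName("model", 0, 2).c_str());  // model_part-2
  std::printf("%s\n", ModelName("model", 3, 2).c_str());  // model_iter-3_part-2
  return 0;
}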
9 changes: 9 additions & 0 deletions learn/test/Makefile
@@ -0,0 +1,9 @@
include ../../make/ps_app.mk

all: build/iter_solver_test

clean:
rm -rf build

build/iter_solver_test: build/iter_solver_test.o $(DMLC_SLIB)
$(CXX) $(CFLAGS) $(filter %.o %.a, $^) $(LDFLAGS) -o $@
74 changes: 74 additions & 0 deletions learn/test/iter_solver_test.cc
@@ -0,0 +1,74 @@
/**
* @file iter_solver_test.cc
* @brief
*
* run: tracker/dmlc_local.py -s 2 -n 4 learn/test/build/iter_solver_test
*/

#include "solver/iter_solver.h"

DEFINE_string(train_data, "", "");
DEFINE_string(val_data, "", "");
DEFINE_bool(batch, false, "");

namespace dmlc {

class IterTestScheduler : public solver::IterScheduler {
public:
IterTestScheduler() {
train_data_ = FLAGS_train_data;
val_data_ = FLAGS_val_data;
batch_ = FLAGS_batch;
data_format_ = "libsvm";
}
virtual ~IterTestScheduler() { }

};

class IterTestServer : public solver::IterServer {
public:
IterTestServer() { }
virtual ~IterTestServer() { }

virtual void SaveModel(Stream* fo) const { fo->Write(model_); }
virtual void LoadModel(Stream* fi) { fi->Read(&model_); }

private:
std::vector<float> model_; // a fake model
};

class IterTestWorker : public solver::IterWorker {
public:
IterTestWorker() { }
virtual ~IterTestWorker() { }

virtual void Process(const Workload& wl) {
printf("worker %d: %s\n", ps::NodeInfo::MyRank(), wl.ShortDebugString().c_str());
srand(time(NULL));
int t = rand() % 10000;
usleep(t);
std::vector<double> p(1, t);
Report(p);
}
};
} // namespace dmlc

namespace ps {

App* App::Create(int argc, char *argv[]) {
NodeInfo info;
if (info.IsWorker()) {
return new ::dmlc::IterTestWorker();
} else if (info.IsServer()) {
return new ::dmlc::IterTestServer();
} else if (info.IsScheduler()) {
return new ::dmlc::IterTestScheduler();
}
return NULL;
}

} // namespace ps

int main(int argc, char *argv[]) {
return ps::RunSystem(&argc, &argv);
}
4 changes: 2 additions & 2 deletions make/ps_app.mk
@@ -21,8 +21,8 @@ include $(CORE_PATH)/make/dmlc.mk

INCLUDE=-I./ -I../ -I$(PS_PATH)/src -I$(CORE_PATH)/include -I$(CORE_PATH)/src -I$(DEPS_PATH)/include

CFLAGS = -O3 -ggdb -Wall -std=c++11 $(INCLUDE) $(DMLC_CFLAGS) $(PS_CFLAGS) $(EXTRA_CFLAGS)
LDFLAGS = $(DMLC_LDFLAGS) $(PS_LDFLAGS) $(EXTRA_LDFLAGS)
CFLAGS += -O3 -ggdb -Wall -std=c++11 $(INCLUDE) $(DMLC_CFLAGS) $(PS_CFLAGS) $(EXTRA_CFLAGS)
LDFLAGS += $(DMLC_LDFLAGS) $(PS_LDFLAGS) $(EXTRA_LDFLAGS)

.DEFAULT_GOAL := all

