Permalink
Browse files

Make rabit library thread local

  • Loading branch information...
1 parent aeb4008 commit be50e7b63224b9fb7ff94ce34df9f8752ef83043 @tqchen tqchen committed Mar 2, 2016
Showing with 166 additions and 37 deletions.
  1. +9 −9 guide/Makefile
  2. +3 −3 guide/basic.cc
  3. +1 −1 guide/broadcast.cc
  4. +6 −5 guide/lazy_allreduce.cc
  5. +9 −1 src/allreduce_base.cc
  6. +1 −1 src/allreduce_base.h
  7. +2 −2 src/allreduce_robust.cc
  8. +1 −1 src/allreduce_robust.h
  9. +47 −14 src/engine.cc
  10. +87 −0 src/thread_local.h
View
@@ -2,25 +2,25 @@ export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../lib
-export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include
+export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -fopenmp -I../include
.PHONY: clean all lib libmpi
BIN = basic.rabit broadcast.rabit
MOCKBIN= lazy_allreduce.mock
all: $(BIN)
-basic.rabit: basic.cc lib
-broadcast.rabit: broadcast.cc lib
-lazy_allreduce.mock: lazy_allreduce.cc lib
+basic.rabit: basic.cc lib ../lib/librabit.a
+broadcast.rabit: broadcast.cc lib ../lib/librabit.a
+lazy_allreduce.mock: lazy_allreduce.cc lib ../lib/librabit.a
-$(BIN) :
- $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit
+$(BIN) :
+ $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
-$(MOCKBIN) :
+$(MOCKBIN) :
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) -lrabit_mock
-$(OBJ) :
+$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
clean:
- $(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~
+ $(RM) $(OBJ) $(BIN) $(MOCKBIN) *~ ../src/*~
View
@@ -8,7 +8,7 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#include <vector>
-#include <rabit.h>
+#include <rabit/rabit.h>
using namespace rabit;
int main(int argc, char *argv[]) {
int N = 3;
@@ -19,7 +19,7 @@ int main(int argc, char *argv[]) {
rabit::Init(argc, argv);
for (int i = 0; i < N; ++i) {
a[i] = rabit::GetRank() + i;
- }
+ }
printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
// allreduce take max of each elements in all processes
@@ -29,7 +29,7 @@ int main(int argc, char *argv[]) {
// second allreduce that sums everything up
Allreduce<op::Sum>(&a[0], N);
printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
- rabit::GetRank(), a[0], a[1], a[2]);
+ rabit::GetRank(), a[0], a[1], a[2]);
rabit::Finalize();
return 0;
}
View
@@ -1,4 +1,4 @@
-#include <rabit.h>
+#include <rabit/rabit.h>
using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
@@ -5,7 +5,8 @@
*
* \author Tianqi Chen
*/
-#include <rabit.h>
+#include <rabit/rabit.h>
+
using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
@@ -16,18 +17,18 @@ int main(int argc, char *argv[]) {
printf("@node[%d] run prepare function\n", rabit::GetRank());
for (int i = 0; i < N; ++i) {
a[i] = rabit::GetRank() + i;
- }
+ }
};
printf("@node[%d] before-allreduce: a={%d, %d, %d}\n",
rabit::GetRank(), a[0], a[1], a[2]);
// allreduce take max of each elements in all processes
- Allreduce<op::Max>(&a[0], N, prepare);
+ Allreduce<op::Max>(&a[0], N, prepare);
printf("@node[%d] after-allreduce-sum: a={%d, %d, %d}\n",
- rabit::GetRank(), a[0], a[1], a[2]);
+ rabit::GetRank(), a[0], a[1], a[2]);
// run second allreduce
Allreduce<op::Sum>(&a[0], N);
printf("@node[%d] after-allreduce-max: a={%d, %d, %d}\n",
- rabit::GetRank(), a[0], a[1], a[2]);
+ rabit::GetRank(), a[0], a[1], a[2]);
rabit::Finalize();
return 0;
}
View
@@ -51,7 +51,7 @@ AllreduceBase::AllreduceBase(void) {
}
// initialization function
-void AllreduceBase::Init(void) {
+void AllreduceBase::Init(int argc, char* argv[]) {
// setup from environment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
@@ -60,6 +60,14 @@ void AllreduceBase::Init(void) {
this->SetParam(env_vars[i].c_str(), value);
}
}
+ // Arguments of the form name=value passed in by the caller override the
+ // settings read from the environment variables.
+ for (int i = 0; i < argc; ++i) {
+ // Both conversions are capped at 255 characters so that an over-long
+ // argument cannot write past the 256-byte buffers (255 + terminating NUL).
+ char name[256], val[256];
+ if (sscanf(argv[i], "%255[^=]=%255s", name, val) == 2) {
+ this->SetParam(name, val);
+ }
+ }
+
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
@@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
- virtual void Init(void);
+ virtual void Init(int argc, char* argv[]);
// shutdown the engine
virtual void Shutdown(void);
/*!
@@ -31,8 +31,8 @@ AllreduceRobust::AllreduceRobust(void) {
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
-void AllreduceRobust::Init(void) {
- AllreduceBase::Init();
+void AllreduceRobust::Init(int argc, char* argv[]) {
+ AllreduceBase::Init(argc, argv);
result_buffer_round = std::max(world_size / num_global_replica, 1);
}
/*! \brief shutdown the engine */
@@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase {
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
// initialize the manager
- virtual void Init(void);
+ virtual void Init(int argc, char* argv[]);
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*!
View
@@ -10,42 +10,72 @@
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
+#include <memory>
#include "../include/rabit/internal/engine.h"
#include "./allreduce_base.h"
#include "./allreduce_robust.h"
+#include "./thread_local.h"
namespace rabit {
namespace engine {
// singleton sync manager
#ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK
-AllreduceRobust manager;
+typedef AllreduceRobust Manager;
#else
-AllreduceMock manager;
+typedef AllreduceMock Manager;
#endif
#else
-AllreduceBase manager;
+typedef AllreduceBase Manager;
#endif
+/*! \brief per-thread entry holding the engine instance and its init state */
+struct ThreadLocalEntry {
+ /*! \brief the current engine; null until Init creates it and after Finalize resets it */
+ std::unique_ptr<Manager> engine;
+ /*! \brief whether Init has been called; set by Init and not reset by Finalize */
+ bool initialized;
+ /*! \brief constructor, starts with no engine and initialized == false */
+ ThreadLocalEntry() : initialized(false) {}
+};
+
+// define the threadlocal store.
+typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
+
/*!
 * \brief initialize the synchronization module
 * \param argc number of entries in argv
 * \param argv arguments handed on to the engine's Init
 */
void Init(int argc, char *argv[]) {
- for (int i = 1; i < argc; ++i) {
- char name[256], val[256];
- if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
- manager.SetParam(name, val);
- }
- }
- manager.Init();
+ ThreadLocalEntry* e = EngineThreadLocal::Get();
+ utils::Check(e->engine.get() == nullptr,
+ "rabit::Init is already called in this thread");
+ e->initialized = true;
+ e->engine.reset(new Manager());
+ e->engine->Init(argc, argv);
}
/*! \brief finalize the synchronization module: shut the engine down and release it */
-void Finalize(void) {
- manager.Shutdown();
+void Finalize() {
+ ThreadLocalEntry* e = EngineThreadLocal::Get();
+ utils::Check(e->engine.get() != nullptr,
+ "rabit::Finalize engine is not initialized or already been finalized.");
+ e->engine->Shutdown();
+ e->engine.reset(nullptr);
}
+
/*!
 * \brief get the engine to use for rabit calls
 * \return the engine created by Init, or a static un-initialized default
 *  engine when Init has not been called yet
 */
-IEngine *GetEngine(void) {
- return &manager;
+IEngine *GetEngine() {
+ // un-initialized default manager, handed out while no engine has been created by Init.
+ static AllreduceBase default_manager;
+ ThreadLocalEntry* e = EngineThreadLocal::Get();
+ IEngine* ptr = e->engine.get();
+ if (ptr == nullptr) {
+ utils::Check(!e->initialized,
+ "Doing rabit call after Finalize");
+ return &default_manager;
+ } else {
+ return ptr;
+ }
}
+
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
@@ -63,15 +93,18 @@ void Allreduce_(void *sendrecvbuf,
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
+
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size);
}
+
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
+
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
View
@@ -0,0 +1,87 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file thread_local.h
+ * \brief Common utility for thread local storage.
+ */
+#ifndef RABIT_THREAD_LOCAL_H_
+#define RABIT_THREAD_LOCAL_H_
+
+#include "../include/dmlc/base.h"
+
+#if DMLC_ENABLE_STD_THREAD
+#include <mutex>
+#endif
+
+#include <memory>
+#include <vector>
+
+namespace rabit {
+
+// Macro handling for thread-local variables: MX_TREAD_LOCAL expands to the
+// thread-local storage specifier supported by the compiler in use.
+#ifdef __GNUC__
+ #define MX_TREAD_LOCAL __thread
+#elif __STDC_VERSION__ >= 201112L
+ #define MX_TREAD_LOCAL _Thread_local
+#elif defined(_MSC_VER)
+ #define MX_TREAD_LOCAL __declspec(thread)
+#elif defined(__cplusplus) && __cplusplus >= 201103L
+ #define MX_TREAD_LOCAL thread_local
+#endif
+
+#ifndef MX_TREAD_LOCAL
+// `#message` is not a preprocessing directive and is rejected by compilers;
+// the warning has to be issued through `#pragma message` instead.
+#pragma message("Warning: Threadlocal is not enabled")
+// Define the macro as empty so code using it still compiles; variables it
+// decorates then become ordinary statics shared by all threads.
+#define MX_TREAD_LOCAL
+#endif
+
+/*!
+ * \brief A threadlocal store to store threadlocal variables.
+ * Get() returns a singleton of type T held in a MX_TREAD_LOCAL static pointer.
+ * Instances are created lazily by Get() and are deleted together when the
+ * store's own static singleton is destroyed.
+ * \tparam T the type we like to store; Get() default-constructs it
+ */
+template<typename T>
+class ThreadLocalStore {
+ public:
+ /*!
+ * \brief get the thread local singleton, creating it on first use
+ * \return pointer to the instance; it is also registered with the store,
+ * whose destructor deletes it
+ */
+ static T* Get() {
+ static MX_TREAD_LOCAL T* ptr = nullptr;
+ if (ptr == nullptr) {
+ ptr = new T();
+ Singleton()->RegisterDelete(ptr);
+ }
+ return ptr;
+ }
+
+ private:
+ /*! \brief constructor, private: the only instance is created in Singleton() */
+ ThreadLocalStore() {}
+ /*! \brief destructor, deletes every pointer passed to RegisterDelete */
+ ~ThreadLocalStore() {
+ for (size_t i = 0; i < data_.size(); ++i) {
+ delete data_[i];
+ }
+ }
+ /*! \return singleton of the store, a function-local static shared by all callers */
+ static ThreadLocalStore<T> *Singleton() {
+ static ThreadLocalStore<T> inst;
+ return &inst;
+ }
+ /*!
+ * \brief register a pointer so that the destructor deletes it
+ * The push into data_ is done under mutex_ when DMLC_ENABLE_STD_THREAD is set.
+ * \param str the pointer to delete later (a T*, not a string, despite the name)
+ */
+ void RegisterDelete(T *str) {
+#if DMLC_ENABLE_STD_THREAD
+ std::unique_lock<std::mutex> lock(mutex_);
+ data_.push_back(str);
+ lock.unlock();
+#else
+ data_.push_back(str);
+#endif
+ }
+
+#if DMLC_ENABLE_STD_THREAD
+ /*! \brief internal mutex, held while RegisterDelete appends to data_ */
+ std::mutex mutex_;
+#endif
+ /*!\brief registered pointers, deleted in the destructor */
+ std::vector<T*> data_;
+};
+} // namespace rabit
+#endif // RABIT_THREAD_LOCAL_H_

0 comments on commit be50e7b

Please sign in to comment.